I have a batched input to GATv2Conv with a node matrix of shape [batch_sz, num_nodes, node_feature_dim], but GATv2Conv only accepts 2-D input. Searching the internet, I found the following solution (not the one I want):
```python
import torch
from torch_geometric.data import Data, Batch
# Build one Data object per graph in the dense batch, then collate them
data_list = [Data(x=torch.squeeze(torch.index_select(x, dim=0, index=torch.tensor([idx]))),
                  edge_index=self.edge_indices,
                  edge_attr=torch.squeeze(torch.index_select(adj_mats[i], dim=0, index=torch.tensor([idx]))))
             for idx in range(self.batch_sz)]
batch = Batch.from_data_list(data_list)
```
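To make the shapes concrete, here is a minimal pure-PyTorch sketch that mimics, for the case where every graph shares the same edge_index, what Batch.from_data_list produces: node features are stacked into one 2-D matrix, each graph's copy of edge_index is shifted by that graph's node offset, and a node-to-graph index vector is built (PyG stores that vector as batch.batch). The helper name collate_disjoint and the toy sizes are hypothetical; this is an illustration, not PyG's actual implementation.

```python
import torch

def collate_disjoint(x: torch.Tensor, edge_index: torch.Tensor):
    """Hypothetical helper: flatten a dense batch x of shape [B, N, F] whose graphs
    all share one edge_index of shape [2, E] into a single disjoint-union graph."""
    B, N, F = x.shape
    x_flat = x.reshape(B * N, F)                          # [B*N, F], graph 0's nodes first
    offsets = torch.arange(B, device=x.device) * N        # first node index of each graph
    ei = edge_index.unsqueeze(0) + offsets.view(B, 1, 1)  # [B, 2, E], per-graph shifted copies
    ei = ei.permute(1, 0, 2).reshape(2, -1)               # [2, B*E]
    batch = torch.arange(B, device=x.device).repeat_interleave(N)  # graph id of every node
    return x_flat, ei, batch

# Toy example: 3 graphs, 4 nodes each, 5 features, shared path graph 0-1-2-3
B, N, F = 3, 4, 5
x = torch.randn(B, N, F)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
x_flat, ei, batch = collate_disjoint(x, edge_index)
print(x_flat.shape, ei.shape)  # torch.Size([12, 5]) torch.Size([2, 9])
```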
But with the above solution, the distinction between graphs gets lost, because batch.x.shape is [batch_sz * num_nodes, node_feature_dim].
It simply puts all nodes of all graphs into one single graph. Now there are shared calculations between different graphs, which is strictly undesirable. And when applying a graph pooling layer, I don't know which nodes belonged to which graph.
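For reference, graph-level pooling over a flattened [total_nodes, F] node matrix only needs a node-to-graph index vector, such as the batch vector that PyG's Batch stores; torch_geometric.nn.global_mean_pool(x, batch) takes exactly that kind of argument. Below is a minimal pure-PyTorch sketch of per-graph mean pooling under that assumption; mean_pool is a hypothetical helper, not the PyG function itself.

```python
import torch

def mean_pool(x: torch.Tensor, batch: torch.Tensor, num_graphs: int) -> torch.Tensor:
    """Hypothetical helper: average node features per graph.
    x: [total_nodes, F]; batch[i] is the graph id of node i."""
    sums = torch.zeros(num_graphs, x.size(1), dtype=x.dtype, device=x.device)
    sums.index_add_(0, batch, x)                      # add each node's row to its graph's row
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)  # avoid division by 0
    return sums / counts.unsqueeze(1).to(x.dtype)

# 6 nodes with 2 features each, split into 2 graphs of 3 nodes
x = torch.arange(12, dtype=torch.float32).view(6, 2)
batch = torch.tensor([0, 0, 0, 1, 1, 1])
pooled = mean_pool(x, batch, num_graphs=2)  # one row per graph: [[2., 3.], [8., 9.]]
```

Each output row only mixes nodes whose batch entry equals that row's index, so the pooled features stay separated per graph.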
Please suggest a fix for this issue.
Thanks in advance