Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import dgl | |
| from dgllife.model import GCN, GAT | |
| class MolEnc(nn.Module): | |
| def __init__(self, | |
| args, | |
| in_dim,): | |
| super().__init__() | |
| self.return_emb = False | |
| if args.model in ('crossAttenContrastive', 'filipContrastive'): | |
| self.return_emb = True | |
| dropout = [args.gnn_dropout for _ in range(len(args.gnn_channels))] | |
| batchnorm = [True for _ in range(len(args.gnn_channels))] | |
| gnn_map = { | |
| "gcn": GCN(in_dim, args.gnn_channels, batchnorm = batchnorm, dropout = dropout), | |
| "gat": GAT(in_dim, args.gnn_channels, args.attn_heads) | |
| } | |
| self.GNN = gnn_map[args.gnn_type] | |
| self.pool = dgl.nn.pytorch.glob.MaxPooling() | |
| if not self.return_emb: | |
| self.fc1_graph = nn.Linear(args.gnn_channels[len(args.gnn_channels) - 1], args.gnn_hidden_dim * 2) | |
| self.fc2_graph = nn.Linear(args.gnn_hidden_dim * 2, args.final_embedding_dim) | |
| self.dropout = nn.Dropout(args.fc_dropout) | |
| self.relu = nn.ReLU() | |
| def forward(self, g, fp=None) -> torch.Tensor: | |
| g1 = g | |
| f1 = g.ndata['h'] | |
| f = self.GNN(g1, f1) | |
| if self.return_emb: | |
| return f | |
| h = self.pool(g1, f) | |
| if fp is not None: | |
| h = torch.concat((h, fp), dim=-1) | |
| h1 = self.relu(self.fc1_graph(h)) | |
| h1 = self.dropout(h1) | |
| h1 = self.fc2_graph(h1) | |
| h1 = self.dropout(h1) | |
| return h1 | |