Spaces:
Sleeping
Sleeping
File size: 1,541 Bytes
d9df210 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
import torch.nn as nn
import dgl
from dgllife.model import GCN, GAT
class MolEnc(nn.Module):
    """Molecular graph encoder built on a dgllife GCN or GAT backbone.

    Node features are read from ``g.ndata['h']`` and passed through the GNN
    selected by ``args.gnn_type`` ("gcn" or "gat"). For the
    ``crossAttenContrastive`` / ``filipContrastive`` models the GNN output is
    returned directly (no pooling). Otherwise the GNN output is max-pooled,
    optionally concatenated with an extra feature tensor ``fp``, and projected
    by two linear layers to ``args.final_embedding_dim``.
    """

    # Models that consume the un-pooled GNN output instead of a graph vector.
    _NODE_EMB_MODELS = ('crossAttenContrastive', 'filipContrastive')

    def __init__(self, args, in_dim, fp_dim=0):
        """Build the GNN backbone and, unless returning raw embeddings, the head.

        Args:
            args: Config namespace. Reads ``model``, ``gnn_type``,
                ``gnn_channels`` (sequence of per-layer sizes), ``gnn_dropout``,
                ``attn_heads`` (only for "gat"), ``gnn_hidden_dim``,
                ``final_embedding_dim`` and ``fc_dropout``.
            in_dim: Width of the input node features in ``g.ndata['h']``.
            fp_dim: Width of the optional ``fp`` tensor that ``forward``
                concatenates to the pooled vector. Defaults to 0 (no ``fp``);
                must equal the last dimension of ``fp`` whenever one is passed.

        Raises:
            KeyError: If ``args.gnn_type`` is not "gcn" or "gat".
        """
        super().__init__()
        self.return_emb = args.model in self._NODE_EMB_MODELS

        n_layers = len(args.gnn_channels)
        dropout = [args.gnn_dropout] * n_layers
        batchnorm = [True] * n_layers
        # Factories so only the requested backbone is instantiated; building
        # both wastes memory/RNG draws and would demand ``args.attn_heads``
        # even when a GCN is requested.
        gnn_factories = {
            "gcn": lambda: GCN(in_dim, args.gnn_channels,
                               batchnorm=batchnorm, dropout=dropout),
            # NOTE(review): gnn_dropout is not forwarded to GAT; confirm intended.
            "gat": lambda: GAT(in_dim, args.gnn_channels, args.attn_heads),
        }
        try:
            build_gnn = gnn_factories[args.gnn_type]
        except KeyError:
            raise KeyError(
                f"unsupported gnn_type {args.gnn_type!r}; "
                f"expected one of {sorted(gnn_factories)}"
            ) from None
        self.GNN = build_gnn()
        self.pool = dgl.nn.pytorch.glob.MaxPooling()

        if not self.return_emb:
            # Assumes the GNN output width is gnn_channels[-1]; fc1 also has
            # to accept the fp_dim columns that forward() may concatenate.
            self.fc1_graph = nn.Linear(args.gnn_channels[-1] + fp_dim,
                                       args.gnn_hidden_dim * 2)
            self.fc2_graph = nn.Linear(args.gnn_hidden_dim * 2,
                                       args.final_embedding_dim)
        self.dropout = nn.Dropout(args.fc_dropout)
        self.relu = nn.ReLU()

    def forward(self, g, fp=None) -> torch.Tensor:
        """Encode graph ``g``.

        Args:
            g: DGL graph whose node features are stored in ``g.ndata['h']``.
            fp: Optional tensor concatenated to the pooled vector along the
                last dim; ignored when ``return_emb`` is set. Its width must
                equal the ``fp_dim`` given at construction.

        Returns:
            The raw GNN output when ``return_emb`` is set; otherwise the
            projected embedding of width ``args.final_embedding_dim``.
        """
        node_feats = self.GNN(g, g.ndata['h'])
        if self.return_emb:
            return node_feats

        h = self.pool(g, node_feats)
        if fp is not None:
            h = torch.cat((h, fp), dim=-1)
        h = self.dropout(self.relu(self.fc1_graph(h)))
        # NOTE(review): dropout is applied to the final output as well;
        # confirm this is intended.
        return self.dropout(self.fc2_graph(h))
|