import torch
import torch.nn as nn
import dgl
from dgllife.model import GCN, GAT

class MolEnc(nn.Module):
    """Molecular graph encoder: a DGL-LifeSci GNN backbone followed by max
    pooling and a two-layer projection head.

    For the contrastive model variants the per-node embeddings are returned
    directly and the projection head is never built.
    """

    def __init__(self, args, in_dim):
        super().__init__()

        # Contrastive variants consume per-node embeddings directly, so the
        # pooling / projection head is skipped for them.
        self.return_emb = args.model in ('crossAttenContrastive', 'filipContrastive')

        # Per-layer dropout and batch-norm settings for the GNN backbone.
        dropout = [args.gnn_dropout] * len(args.gnn_channels)
        batchnorm = [True] * len(args.gnn_channels)
        # Build only the requested backbone (constructing both just to index a
        # dict wastes work and requires GAT hyperparameters even for GCN runs).
        if args.gnn_type == "gcn":
            self.GNN = GCN(in_dim, args.gnn_channels, batchnorm=batchnorm, dropout=dropout)
        elif args.gnn_type == "gat":
            self.GNN = GAT(in_dim, args.gnn_channels, args.attn_heads)
        else:
            raise ValueError(f"Unsupported gnn_type: {args.gnn_type}")
        self.pool = dgl.nn.pytorch.glob.MaxPooling()
        
        if not self.return_emb:
            # Two-layer projection head from the pooled graph embedding to the
            # final embedding size. Note: if a fingerprint `fp` is concatenated
            # in forward(), fc1_graph's input size must be enlarged accordingly.
            self.fc1_graph = nn.Linear(args.gnn_channels[-1], args.gnn_hidden_dim * 2)
            self.fc2_graph = nn.Linear(args.gnn_hidden_dim * 2, args.final_embedding_dim)

            self.dropout = nn.Dropout(args.fc_dropout)
            self.relu = nn.ReLU()

    def forward(self, g, fp=None) -> torch.Tensor:
        # Node features are expected under the 'h' key of the (batched) DGL graph.
        feats = g.ndata['h']

        node_emb = self.GNN(g, feats)
        if self.return_emb:
            # Contrastive variants work on per-node embeddings.
            return node_emb

        # Graph-level readout, optionally concatenated with a molecular fingerprint.
        h = self.pool(g, node_emb)
        if fp is not None:
            h = torch.cat((h, fp), dim=-1)

        h = self.relu(self.fc1_graph(h))
        h = self.dropout(h)
        h = self.fc2_graph(h)
        h = self.dropout(h)

        return h
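

# Minimal usage sketch (not part of the original module). The attribute names
# on `args` mirror what MolEnc reads; the SimpleNamespace container, the
# concrete hyperparameter values, and the toy graph are assumptions for
# illustration only.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        model="regression",        # any value outside the contrastive variants
        gnn_type="gcn",
        gnn_channels=[64, 128],
        gnn_dropout=0.1,
        attn_heads=4,              # only used when gnn_type == "gat"
        gnn_hidden_dim=128,
        fc_dropout=0.2,
        final_embedding_dim=256,
    )

    in_dim = 74                    # e.g. an atom-feature dimension
    model = MolEnc(args, in_dim)
    model.eval()

    # Toy molecular graph with random node features under the expected 'h' key.
    g = dgl.graph(([0, 1, 2], [1, 2, 0]), num_nodes=3)
    g.ndata['h'] = torch.randn(3, in_dim)

    with torch.no_grad():
        out = model(g)
    print(out.shape)  # torch.Size([1, 256]) for a single, unbatched graph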