MVP / mvp /utils /general.py
yzhouchen001's picture
model code
d9df210
raw
history blame
843 Bytes
import torch
from torch import nn
import torch.nn.functional as F
def pad_graph_nodes(mol_enc, g_n_nodes):
    """Pack concatenated per-graph node embeddings into a padded batch.

    Args:
        mol_enc: 2D tensor of shape (sum_nodes, D) holding the node
            embeddings of all graphs concatenated along dim 0, in the same
            order as ``g_n_nodes``.
        g_n_nodes: Number of nodes per graph (len = B). Any sized iterable
            of integer-like values is accepted (list, tuple, 1D tensor).

    Returns:
        padded: (B, max_nodes, D) tensor with the same dtype and device as
            ``mol_enc``; positions past a graph's node count are zero.
        mask: (B, max_nodes) bool tensor on ``mol_enc``'s device, True for
            valid nodes.

        For an empty batch both outputs have B == 0 and max_nodes == 0.
    """
    # Normalize once to plain ints so slicing and shape construction behave
    # the same whether the counts arrive as a list or as a tensor/array.
    node_counts = [int(n) for n in g_n_nodes]

    B = len(node_counts)
    D = mol_enc.shape[1]
    # default=0 keeps an empty batch from raising ValueError in max().
    max_nodes = max(node_counts, default=0)

    padded = mol_enc.new_zeros((B, max_nodes, D))
    mask = torch.zeros((B, max_nodes), dtype=torch.bool, device=mol_enc.device)

    # Walk the concatenated rows, copying each graph's slice into its own row.
    idx = 0
    for i, n in enumerate(node_counts):
        padded[i, :n] = mol_enc[idx:idx + n]
        mask[i, :n] = True
        idx += n
    return padded, mask