modeling script
Browse files- .gitignore +15 -0
- modeling.py +66 -0
- ultra/__init__.py +0 -0
- ultra/base_nbfnet.py +336 -0
- ultra/datasets.py +1095 -0
- ultra/eval.py +153 -0
- ultra/layers.py +234 -0
- ultra/models.py +214 -0
- ultra/rspmm/rspmm.py +204 -0
- ultra/rspmm/source/operator.cuh +82 -0
- ultra/rspmm/source/rspmm.cpp +283 -0
- ultra/rspmm/source/rspmm.cu +386 -0
- ultra/rspmm/source/rspmm.h +108 -0
- ultra/rspmm/source/util.cuh +28 -0
- ultra/tasks.py +201 -0
- ultra/util.py +172 -0
.gitignore
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
output/
|
| 10 |
+
.vscode/
|
| 11 |
+
.DS_Store
|
| 12 |
+
datasets/
|
| 13 |
+
ckpts/
|
| 14 |
+
*.csv
|
| 15 |
+
*.txt
|
modeling.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from transformers import PretrainedConfig, PreTrainedModel
|
| 4 |
+
#sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
| 5 |
+
from ultra.models import Ultra
|
| 6 |
+
from ultra.datasets import WN18RR, CoDExSmall, FB15k237, FB15k237Inductive
|
| 7 |
+
from ultra.eval import test
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class UltraConfig(PretrainedConfig):
    """Hugging Face configuration object for the ULTRA link-prediction model.

    The four integer hyper-parameters are expanded into two keyword
    dictionaries, ``relation_model_cfg`` and ``entity_model_cfg``, which are
    stored on the config instance.

    Args:
        relation_model_layers: number of hidden layers of the relation model.
        relation_model_dim: input and hidden width of the relation model.
        entity_model_layers: number of hidden layers of the entity model.
        entity_model_dim: input and hidden width of the entity model.
        **kwargs: forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "ultra"

    def __init__(
            self,
            relation_model_layers: int = 6,
            relation_model_dim: int = 64,
            entity_model_layers: int = 6,
            entity_model_dim: int = 64,
            **kwargs):

        def nbfnet_cfg(dim, num_layers):
            # Both sub-models share every setting except width and depth.
            return {
                "input_dim": dim,
                "hidden_dims": [dim] * num_layers,
                "message_func": "distmult",
                "aggregate_func": "sum",
                "short_cut": True,
                "layer_norm": True,
            }

        self.relation_model_cfg = nbfnet_cfg(relation_model_dim, relation_model_layers)
        self.entity_model_cfg = nbfnet_cfg(entity_model_dim, entity_model_layers)

        # NOTE(review): the parent constructor is deliberately still called
        # last, exactly as before; presumably this lets values passed through
        # **kwargs take precedence over the attributes set above - confirm.
        super().__init__(**kwargs)
|
| 41 |
+
|
| 42 |
+
class UltraLinkPrediction(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the ``Ultra`` model.

    The wrapped network is built from the two configuration dictionaries
    carried by :class:`UltraConfig` and is stored as ``self.model``.
    """

    config_class = UltraConfig

    def __init__(self, config):
        """Build the wrapped ``Ultra`` network from ``config``."""
        super().__init__(config)

        self.model = Ultra(
            rel_model_cfg=config.relation_model_cfg,
            entity_model_cfg=config.entity_model_cfg,
        )

    def forward(self, data, batch):
        """Run the wrapped model and return its output unchanged.

        Args:
            data: PyG data object.
            batch: tensor of triples, shape ``(bs, 1 + num_negs, 3)``.
        """
        # Call the module itself rather than ``.forward`` so that
        # ``nn.Module.__call__`` runs and registered hooks are honoured.
        return self.model(data, batch)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
if __name__ == "__main__":
    # Load the published checkpoint, then evaluate it on the test split of
    # CoDExSmall (datasets are stored under ./datasets/).
    ultra_model = UltraLinkPrediction.from_pretrained("mgalkin/ultra_4g")
    codex_small = CoDExSmall(root="./datasets/")
    test(ultra_model, mode="test", dataset=codex_small, gpus=None)
    # Results recorded by the original author:
    # mrr: 0.463971
    # hits@10: 0.666028
|
ultra/__init__.py
ADDED
|
File without changes
|
ultra/base_nbfnet.py
ADDED
|
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from collections.abc import Sequence
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn, autograd
|
| 6 |
+
|
| 7 |
+
from torch_scatter import scatter_add
|
| 8 |
+
from . import tasks, layers
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BaseNBFNet(nn.Module):
    """Shared machinery for NBFNet-style (Bellman-Ford) link predictors.

    NOTE: this base class does not create ``self.layers``, ``self.query`` or
    ``self.mlp`` (their construction used to live here as commented-out
    code), yet ``bellmanford``, ``forward`` and ``visualize`` read them.
    Presumably subclasses are expected to define them before those methods
    are called - confirm against the subclasses.
    """

    def __init__(self, input_dim, hidden_dims, num_relation, message_func="distmult", aggregate_func="sum",
                 short_cut=False, layer_norm=False, activation="relu", concat_hidden=False, num_mlp_layer=2,
                 dependent=False, remove_one_hop=False, num_beam=10, path_topk=10, **kwargs):
        """Store the hyper-parameters; no sub-modules are created here.

        ``hidden_dims`` may be a single value or a sequence; a single value
        is wrapped in a list.  ``dependent`` and ``**kwargs`` are accepted
        for interface compatibility but are not used by this class.
        """
        super().__init__()

        if not isinstance(hidden_dims, Sequence):
            hidden_dims = [hidden_dims]

        self.dims = [input_dim] + list(hidden_dims)
        self.num_relation = num_relation
        self.short_cut = short_cut  # whether to use residual connections between GNN layers
        self.concat_hidden = concat_hidden  # whether to compute final states as a function of all layer outputs or last
        self.remove_one_hop = remove_one_hop  # whether to dynamically remove one-hop edges from edge_index
        self.num_beam = num_beam
        self.path_topk = path_topk

        self.message_func = message_func
        self.aggregate_func = aggregate_func
        self.layer_norm = layer_norm
        self.activation = activation
        self.num_mlp_layers = num_mlp_layer

    def remove_easy_edges(self, data, h_index, t_index, r_index=None):
        """Return a shallow copy of ``data`` without the edges of the batch.

        We remove training edges (we need to predict them at training time)
        from the edge index - think of it as a dynamic edge dropout.

        With ``self.remove_one_hop`` every edge between a batch head and tail
        is dropped regardless of relation, and ``r_index`` is not needed.
        Otherwise only edges that also match the relation (or its inverse,
        offset by ``data.num_relations // 2``) are dropped.

        Raises:
            TypeError: if ``r_index`` is omitted while ``remove_one_hop`` is
                False.
        """
        h_index_ext = torch.cat([h_index, t_index], dim=-1)
        t_index_ext = torch.cat([t_index, h_index], dim=-1)
        if self.remove_one_hop:
            # we remove all existing immediate edges between heads and tails in the batch
            edge_index = data.edge_index
            easy_edge = torch.stack([h_index_ext, t_index_ext]).flatten(1)
        else:
            if r_index is None:
                raise TypeError("remove_easy_edges() requires r_index when remove_one_hop is False")
            # Only built on this path: the one-hop path never reads it.
            r_index_ext = torch.cat([r_index, r_index + data.num_relations // 2], dim=-1)
            # we remove existing immediate edges between heads and tails in the batch with the given relation
            edge_index = torch.cat([data.edge_index, data.edge_type.unsqueeze(0)])
            # note that here we add relation types r_index_ext to the matching query
            easy_edge = torch.stack([h_index_ext, t_index_ext, r_index_ext]).flatten(1)
        index = tasks.edge_match(edge_index, easy_edge)[0]
        mask = ~index_to_mask(index, data.num_edges)

        # Shallow copy so the caller's graph object keeps its full edge list.
        data = copy.copy(data)
        data.edge_index = data.edge_index[:, mask]
        data.edge_type = data.edge_type[mask]
        return data

    def negative_sample_to_tail(self, h_index, t_index, r_index, num_direct_rel):
        """Rewrite head-corrupted rows as tail prediction over inverse relations.

        Converts p(h | t, r) to p(t' | h', r') with h' = t, r' = r^{-1},
        t' = h.  A row whose head is constant across its samples is treated
        as already being tail-corrupted and is returned unchanged; in every
        other row head and tail are swapped and ``num_direct_rel`` is added
        to the relation index.
        """
        is_t_neg = (h_index == h_index[:, [0]]).all(dim=-1, keepdim=True)
        new_h_index = torch.where(is_t_neg, h_index, t_index)
        new_t_index = torch.where(is_t_neg, t_index, h_index)
        new_r_index = torch.where(is_t_neg, r_index, r_index + num_direct_rel)
        return new_h_index, new_t_index, new_r_index

    def bellmanford(self, data, h_index, r_index, separate_grad=False):
        """Run the message-passing layers from the query heads.

        Returns a dict with ``"node_feature"`` (node states concatenated with
        the broadcast query embedding) and ``"edge_weights"`` (one edge-weight
        tensor per layer).  With ``separate_grad`` each layer gets its own
        clone of the edge weights with gradients enabled.
        """
        batch_size = len(r_index)

        # initialize queries (relation types of the given triples)
        query = self.query(r_index)
        index = h_index.unsqueeze(-1).expand_as(query)

        # initial (boundary) condition - initialize all node states as zeros
        boundary = torch.zeros(batch_size, data.num_nodes, self.dims[0], device=h_index.device)
        # by the scatter operation we put query (relation) embeddings as init features of source (index) nodes
        boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))
        size = (data.num_nodes, data.num_nodes)
        edge_weight = torch.ones(data.num_edges, device=h_index.device)

        hiddens = []
        edge_weights = []
        layer_input = boundary

        for layer in self.layers:
            if separate_grad:
                edge_weight = edge_weight.clone().requires_grad_()
            # Bellman-Ford iteration, we send the original boundary condition in addition to the updated node states
            hidden = layer(layer_input, query, boundary, data.edge_index, data.edge_type, size, edge_weight)
            if self.short_cut and hidden.shape == layer_input.shape:
                # residual connection here
                hidden = hidden + layer_input
            hiddens.append(hidden)
            edge_weights.append(edge_weight)
            layer_input = hidden

        # original query (relation type) embeddings
        node_query = query.unsqueeze(1).expand(-1, data.num_nodes, -1)  # (batch_size, num_nodes, input_dim)
        if self.concat_hidden:
            output = torch.cat(hiddens + [node_query], dim=-1)
        else:
            output = torch.cat([hiddens[-1], node_query], dim=-1)

        return {
            "node_feature": output,
            "edge_weights": edge_weights,
        }

    def forward(self, data, batch):
        """Score every triple in ``batch``; the result has the shape of ``batch[..., 0]``.

        ``batch`` is unbound along its last axis into head, tail and relation
        indices.
        """
        h_index, t_index, r_index = batch.unbind(-1)
        if self.training:
            # Edge dropout in the training mode
            # here we want to remove immediate edges (head, relation, tail) from the edge_index and edge_types
            # to make NBFNet iteration learn non-trivial paths
            # (remove_easy_edges derives the inverse-relation offset from
            # `data` itself, so it takes no extra positional argument.)
            data = self.remove_easy_edges(data, h_index, t_index, r_index)

        shape = h_index.shape
        # turn all triples in a batch into a tail prediction mode
        h_index, t_index, r_index = self.negative_sample_to_tail(
            h_index, t_index, r_index, num_direct_rel=data.num_relations // 2)
        assert (h_index[:, [0]] == h_index).all()
        assert (r_index[:, [0]] == r_index).all()

        # message passing and updated node representations
        output = self.bellmanford(data, h_index[:, 0], r_index[:, 0])  # (batch_size, num_nodes, feature_dim)
        feature = output["node_feature"]
        index = t_index.unsqueeze(-1).expand(-1, -1, feature.shape[-1])
        # extract representations of tail entities from the updated node states
        feature = feature.gather(1, index)  # (batch_size, num_negative + 1, feature_dim)

        # probability logit for each tail node in the batch
        # (batch_size, num_negative + 1, dim) -> (batch_size, num_negative + 1)
        score = self.mlp(feature).squeeze(-1)
        return score.view(shape)

    def visualize(self, data, batch):
        """Return the top paths and their weights for one triple.

        ``batch`` must have shape ``(1, 3)``.  Gradients of the score with
        respect to the per-layer edge weights drive the beam search.
        """
        assert batch.shape == (1, 3)
        h_index, t_index, r_index = batch.unbind(-1)

        output = self.bellmanford(data, h_index, r_index, separate_grad=True)
        feature = output["node_feature"]
        edge_weights = output["edge_weights"]

        index = t_index.unsqueeze(0).unsqueeze(-1).expand(-1, -1, feature.shape[-1])
        feature = feature.gather(1, index).squeeze(0)
        score = self.mlp(feature).squeeze(-1)

        edge_grads = autograd.grad(score, edge_weights)
        distances, back_edges = self.beam_search_distance(data, edge_grads, h_index, t_index, self.num_beam)
        paths, weights = self.topk_average_length(distances, back_edges, t_index, self.path_topk)

        return paths, weights

    @torch.no_grad()
    def beam_search_distance(self, data, edge_grads, h_index, t_index, num_beam=10):
        """Beam search the top-k distance from h to t (and to every other node).

        Returns two lists with one entry per element of ``edge_grads``: the
        per-node beam distances and the matching back-pointers
        ``(node_in, node_out, relation, prev_rank)`` used for backtracking.
        """
        num_nodes = data.num_nodes
        beam_scores = torch.full((num_nodes, num_beam), float("-inf"), device=h_index.device)
        beam_scores[h_index, 0] = 0
        edge_mask = data.edge_index[0, :] != t_index

        distances = []
        back_edges = []
        for edge_grad in edge_grads:
            # we don't allow any path goes out of t once it arrives at t
            node_in, node_out = data.edge_index[:, edge_mask]
            relation = data.edge_type[edge_mask]
            edge_grad = edge_grad[edge_mask]

            message = beam_scores[node_in] + edge_grad.unsqueeze(-1)  # (num_edges, num_beam)
            # (num_edges, num_beam, 3)
            msg_source = torch.stack([node_in, node_out, relation], dim=-1).unsqueeze(1).expand(-1, num_beam, -1)

            # (num_edges, num_beam)
            is_duplicate = (
                torch.isclose(message.unsqueeze(-1), message.unsqueeze(-2))
                & (msg_source.unsqueeze(-2) == msg_source.unsqueeze(-3)).all(dim=-1)
            )
            # pick the first occurrence as the ranking in the previous node's beam
            # this makes deduplication easier later
            # and store it in msg_source
            is_duplicate = (
                is_duplicate.float()
                - torch.arange(num_beam, dtype=torch.float, device=message.device) / (num_beam + 1)
            )
            prev_rank = is_duplicate.argmax(dim=-1, keepdim=True)
            msg_source = torch.cat([msg_source, prev_rank], dim=-1)  # (num_edges, num_beam, 4)

            node_out, order = node_out.sort()
            node_out_set = torch.unique(node_out)
            # sort messages w.r.t. node_out
            message = message[order].flatten()  # (num_edges * num_beam)
            msg_source = msg_source[order].flatten(0, -2)  # (num_edges * num_beam, 4)
            size = node_out.bincount(minlength=num_nodes)
            msg2out = size_to_index(size[node_out_set] * num_beam)
            # deduplicate messages that are from the same source and the same beam
            is_duplicate = (msg_source[1:] == msg_source[:-1]).all(dim=-1)
            is_duplicate = torch.cat([torch.zeros(1, dtype=torch.bool, device=message.device), is_duplicate])
            message = message[~is_duplicate]
            msg_source = msg_source[~is_duplicate]
            msg2out = msg2out[~is_duplicate]
            size = msg2out.bincount(minlength=len(node_out_set))

            if not torch.isinf(message).all():
                # take the topk messages from the neighborhood
                # distance: (len(node_out_set) * num_beam)
                distance, rel_index = scatter_topk(message, size, k=num_beam)
                abs_index = rel_index + (size.cumsum(0) - size).unsqueeze(-1)
                # store msg_source for backtracking
                back_edge = msg_source[abs_index]  # (len(node_out_set) * num_beam, 4)
                distance = distance.view(len(node_out_set), num_beam)
                back_edge = back_edge.view(len(node_out_set), num_beam, 4)
                # scatter distance / back_edge back to all nodes
                distance = scatter_add(distance, node_out_set, dim=0, dim_size=num_nodes)  # (num_nodes, num_beam)
                back_edge = scatter_add(back_edge, node_out_set, dim=0, dim_size=num_nodes)  # (num_nodes, num_beam, 4)
            else:
                distance = torch.full((num_nodes, num_beam), float("-inf"), device=message.device)
                back_edge = torch.zeros(num_nodes, num_beam, 4, dtype=torch.long, device=message.device)

            distances.append(distance)
            back_edges.append(back_edge)
            beam_scores = distance

        return distances, back_edges

    def topk_average_length(self, distances, back_edges, t_index, k=10):
        """Backtrack distances and back_edges to generate the paths.

        For every layer the ``k`` best finite beams ending at ``t_index`` are
        followed back through the earlier layers.  Returns ``(paths,
        average_lengths)``: the overall top ``k`` paths, ranked by distance
        divided by path length, and those averaged values.  When no path
        exists both results are empty lists.
        """
        paths = []
        average_lengths = []

        for i in range(len(distances)):
            distance, order = distances[i][t_index].flatten(0, -1).sort(descending=True)
            back_edge = back_edges[i][t_index].flatten(0, -2)[order]
            for d, (h, t, r, prev_rank) in zip(distance[:k].tolist(), back_edge[:k].tolist()):
                if d == float("-inf"):
                    break
                path = [(h, t, r)]
                # Walk the back-pointers of the earlier layers, last to first.
                for j in range(i - 1, -1, -1):
                    h, t, r, prev_rank = back_edges[j][h, prev_rank].tolist()
                    path.append((h, t, r))
                paths.append(path[::-1])
                average_lengths.append(d / len(path))

        if paths:
            average_lengths, paths = zip(*sorted(zip(average_lengths, paths), reverse=True)[:k])

        return paths, average_lengths
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def index_to_mask(index, size):
    """Return a boolean mask of length ``size`` that is True at ``index``.

    ``index`` is flattened first.  When ``size`` is None the mask is made
    just long enough to hold the largest index.
    """
    flat = index.view(-1)
    if size is None:
        size = int(flat.max()) + 1
    selected = flat.new_zeros(size, dtype=torch.bool)
    selected[flat] = True
    return selected
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def size_to_index(size):
    """Expand per-group sizes into a per-element group index.

    Group ``i`` is repeated ``size[i]`` times, e.g. sizes ``[2, 0, 3]``
    give ``[0, 0, 2, 2, 2]``.  The result lives on the same device as
    ``size``.
    """
    # Named `group_ids` rather than `range` so the builtin is not shadowed.
    group_ids = torch.arange(len(size), device=size.device)
    return group_ids.repeat_interleave(size)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def multi_slice_mask(starts, ends, length):
    """Return a boolean mask of ``length`` covering every ``[start, end)`` slice.

    A +1 is accumulated at each start and a -1 at each end; the running sum
    of those deltas is non-zero exactly inside the slices.
    """
    deltas = torch.cat([torch.ones_like(starts), -torch.ones_like(ends)])
    positions = torch.cat([starts, ends])
    # One extra slot absorbs ends that equal `length`; it is dropped again.
    counts = scatter_add(deltas, positions, dim=0, dim_size=length + 1)[:-1]
    return counts.cumsum(0).bool()
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def scatter_extend(data, size, input, input_size):
    """Append a variable-length chunk of ``input`` to every group of ``data``.

    ``size`` and ``input_size`` hold the per-group lengths of ``data`` and
    ``input``.  Returns the merged tensor, in which each group's ``data``
    rows are followed by its ``input`` rows, together with the merged
    per-group sizes.
    """
    merged_size = size + input_size
    merged_ends = merged_size.cumsum(0)
    total = merged_ends[-1]
    merged = torch.zeros(total, *data.shape[1:], dtype=data.dtype, device=data.device)
    # Slots [group_start, group_start + size) of every group hold `data`.
    group_starts = merged_ends - merged_size
    is_data = multi_slice_mask(group_starts, group_starts + size, total)
    merged[is_data] = data
    merged[~is_data] = input
    return merged, merged_size
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def scatter_topk(input, size, k, largest=True):
    """Take the top-``k`` entries of every variable-length group of ``input``.

    ``size`` gives the length of each consecutive group along dim 0.  Groups
    shorter than ``k`` are padded by repeating their last valid entry.

    Returns:
        ``(value, index)`` where ``index`` is relative to the start of each
        group.  Both are viewed as ``(-1, k, ...)`` unless ``k`` is a tensor
        with the same shape as ``size``, in which case they stay flat.
    """
    index2graph = size_to_index(size)
    index2graph = index2graph.view([-1] + [1] * (input.ndim - 1))

    # Infinite entries are clamped into a finite band around the observed
    # range so that adding a per-group offset keeps groups separated.
    finite = ~torch.isinf(input)
    finite_max = input[finite].max().item()
    finite_min = input[finite].min().item()
    safe_input = input.clamp(2 * finite_min - finite_max, 2 * finite_max - finite_min)
    # NOTE(review): when all finite values are equal this offset is 0, so
    # groups are no longer separated by the sort below - confirm whether
    # callers can hit that case.
    offset = (finite_max - finite_min) * 4
    if largest:
        offset = -offset
    input_ext = safe_input + offset * index2graph
    index_ext = input_ext.argsort(dim=0, descending=largest)
    num_actual = size.clamp(max=k)
    num_padding = k - num_actual
    starts = size.cumsum(0) - size
    ends = starts + num_actual
    selected = multi_slice_mask(starts, ends, len(index_ext)).nonzero().flatten()

    if (num_padding > 0).any():
        # special case: size < k, pad with the last valid index
        padding = ends - 1
        padding2graph = size_to_index(num_padding)
        selected = scatter_extend(selected, num_actual, padding[padding2graph], num_padding)[0]

    index = index_ext[selected]  # (N * k, ...)
    value = input.gather(0, index)
    # `starts` is the offset of each group; subtracting it makes the indices
    # group-relative.
    if isinstance(k, torch.Tensor) and k.shape == size.shape:
        value = value.view(-1, *input.shape[1:])
        index = index.view(-1, *input.shape[1:])
        index = index - starts.repeat_interleave(k).view([-1] + [1] * (index.ndim - 1))
    else:
        value = value.view(-1, k, *input.shape[1:])
        index = index.view(-1, k, *input.shape[1:])
        index = index - starts.view([-1] + [1] * (index.ndim - 1))

    return value, index
|
ultra/datasets.py
ADDED
|
@@ -0,0 +1,1095 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import csv
|
| 3 |
+
import shutil
|
| 4 |
+
import torch
|
| 5 |
+
from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip
|
| 6 |
+
from torch_geometric.datasets import RelLinkPredDataset, WordNet18RR
|
| 7 |
+
|
| 8 |
+
from ultra.tasks import build_relation_graph
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class GrailInductiveDataset(InMemoryDataset):
    """Inductive link-prediction dataset stored in the GraIL file layout.

    Subclasses supply ``name`` and ``urls``: five URL templates that take the
    version string, ordered like ``raw_file_names``.  The processed dataset
    holds three graphs in the order train, valid, test.  Train and valid use
    the transductive training graph; test uses the inductive graph, which has
    its own entity vocabulary but must reuse the training relations.
    """

    def __init__(self, root, version, transform=None, pre_transform=build_relation_graph, merge_valid_test=True):
        """Load the cached dataset, downloading and processing it if needed.

        Args:
            root: base directory; files live under ``<root>/grail/<name>/<version>``.
            version: one of ``"v1"`` .. ``"v4"``.
            transform: passed through to ``InMemoryDataset``.
            pre_transform: applied to each split graph inside ``process``.
            merge_valid_test: if true, the test split is the inductive valid
                and test files combined; otherwise the inductive test file only.
        """
        self.version = version
        assert version in ["v1", "v2", "v3", "v4"]

        # by default, most models on Grail datasets merge inductive valid and test splits as the final test split
        # with this choice, the validation set is that of the transductive train (on the seen graph)
        # by default it's turned on but you can experiment with turning this option off
        # you'll need to delete the processed datasets then and re-run to cache a new dataset
        self.merge_valid_test = merge_valid_test
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def num_relations(self):
        """Largest edge type id in the collated data, plus one."""
        return int(self.data.edge_type.max()) + 1

    @property
    def raw_dir(self):
        """Directory that holds the downloaded text files."""
        return os.path.join(self.root, "grail", self.name, self.version, "raw")

    @property
    def processed_dir(self):
        """Directory that holds the cached processed dataset."""
        return os.path.join(self.root, "grail", self.name, self.version, "processed")

    @property
    def processed_file_names(self):
        """Name of the single cache file written by ``process``."""
        return "data.pt"

    @property
    def raw_file_names(self):
        """Raw files: three inductive files first, then the two transductive ones."""
        return [
            "train_ind.txt", "valid_ind.txt", "test_ind.txt", "train.txt", "valid.txt"
        ]

    def download(self):
        """Fetch every URL (formatted with the version) and rename it to its raw path."""
        for url, path in zip(self.urls, self.raw_paths):
            download_path = download_url(url % self.version, self.raw_dir)
            os.rename(download_path, path)

    @staticmethod
    def _read_triplets(txt_file, inv_entity_vocab, inv_relation_vocab, allow_new_relations):
        """Parse a ``head<TAB>relation<TAB>tail`` file into ``(h, t, r)`` id tuples.

        Unseen entities are added to ``inv_entity_vocab`` in place.  Unseen
        relations are added to ``inv_relation_vocab`` only when
        ``allow_new_relations`` is true; otherwise they fail an assertion.
        """
        triplets = []
        with open(txt_file, "r", encoding="utf-8") as fin:
            for line in fin:
                h_token, r_token, t_token = line.strip().split("\t")
                # len() is evaluated before insertion, so a new token gets the next free id
                h = inv_entity_vocab.setdefault(h_token, len(inv_entity_vocab))
                if allow_new_relations:
                    r = inv_relation_vocab.setdefault(r_token, len(inv_relation_vocab))
                else:
                    assert r_token in inv_relation_vocab
                    r = inv_relation_vocab[r_token]
                t = inv_entity_vocab.setdefault(t_token, len(inv_entity_vocab))
                triplets.append((h, t, r))
        return triplets

    def process(self):
        """Build the train / valid / test graphs and cache them in ``processed_paths[0]``."""
        test_files = self.raw_paths[:3]
        train_files = self.raw_paths[3:]

        inv_train_entity_vocab = {}
        inv_test_entity_vocab = {}
        inv_relation_vocab = {}
        triplets = []
        num_samples = []

        # transductive files define the relation vocabulary ...
        for txt_file in train_files:
            file_triplets = self._read_triplets(
                txt_file, inv_train_entity_vocab, inv_relation_vocab, allow_new_relations=True)
            triplets.extend(file_triplets)
            num_samples.append(len(file_triplets))

        # ... inductive files get a separate entity vocabulary and may not add relations
        for txt_file in test_files:
            file_triplets = self._read_triplets(
                txt_file, inv_test_entity_vocab, inv_relation_vocab, allow_new_relations=False)
            triplets.extend(file_triplets)
            num_samples.append(len(file_triplets))
        triplets = torch.tensor(triplets)

        edge_index = triplets[:, :2].t()
        edge_type = triplets[:, 2]
        num_relations = int(edge_type.max()) + 1

        # bounds[i] is the offset of the first triplet of the i-th file, in reading order:
        # 0 train.txt, 1 valid.txt, 2 train_ind.txt, 3 valid_ind.txt, 4 test_ind.txt
        bounds = [0]
        for num_sample in num_samples:
            bounds.append(bounds[-1] + num_sample)

        # creating fact graphs - those are graphs sent to a model, based on which we'll predict missing facts
        # also, those fact graphs will be used for filtered evaluation
        train_fact_slice = slice(None, bounds[1])
        test_fact_slice = slice(bounds[2], bounds[3])
        train_fact_index = edge_index[:, train_fact_slice]
        train_fact_type = edge_type[train_fact_slice]
        test_fact_index = edge_index[:, test_fact_slice]
        test_fact_type = edge_type[test_fact_slice]

        # add flipped triplets for the fact graphs
        train_fact_index = torch.cat([train_fact_index, train_fact_index.flip(0)], dim=-1)
        train_fact_type = torch.cat([train_fact_type, train_fact_type + num_relations])
        test_fact_index = torch.cat([test_fact_index, test_fact_index.flip(0)], dim=-1)
        test_fact_type = torch.cat([test_fact_type, test_fact_type + num_relations])

        train_slice = slice(None, bounds[1])
        valid_slice = slice(bounds[1], bounds[2])
        # by default, SOTA models on Grail datasets merge inductive valid and test splits as the final test split
        # with this choice, the validation set is that of the transductive train (on the seen graph)
        # by default it's turned on but you can experiment with turning this option off
        test_start = bounds[3] if self.merge_valid_test else bounds[4]
        test_slice = slice(test_start, bounds[-1])

        train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=len(inv_train_entity_vocab),
                          target_edge_index=edge_index[:, train_slice], target_edge_type=edge_type[train_slice], num_relations=num_relations*2)
        valid_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=len(inv_train_entity_vocab),
                          target_edge_index=edge_index[:, valid_slice], target_edge_type=edge_type[valid_slice], num_relations=num_relations*2)
        test_data = Data(edge_index=test_fact_index, edge_type=test_fact_type, num_nodes=len(inv_test_entity_vocab),
                         target_edge_index=edge_index[:, test_slice], target_edge_type=edge_type[test_slice], num_relations=num_relations*2)

        if self.pre_transform is not None:
            train_data = self.pre_transform(train_data)
            valid_data = self.pre_transform(valid_data)
            test_data = self.pre_transform(test_data)

        torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])

    def __repr__(self):
        return "%s(%s)" % (self.name, self.version)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class FB15k237Inductive(GrailInductiveDataset):
    """FB15k-237 inductive splits downloaded from the GraIL repository."""

    # URL templates take the version string; order matches ``raw_file_names``
    urls = [
        "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s_ind/train.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s_ind/valid.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s_ind/test.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s/train.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s/valid.txt",
    ]

    name = "IndFB15k237"

    def __init__(self, root, version):
        """Use the parent's defaults for every option except ``root`` and ``version``."""
        super().__init__(root, version)
|
| 156 |
+
|
| 157 |
+
class WN18RRInductive(GrailInductiveDataset):
    """WN18RR inductive splits downloaded from the GraIL repository."""

    # URL templates take the version string; order matches ``raw_file_names``
    urls = [
        "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s_ind/train.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s_ind/valid.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s_ind/test.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s/train.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s/valid.txt",
    ]

    name = "IndWN18RR"

    def __init__(self, root, version):
        """Use the parent's defaults for every option except ``root`` and ``version``."""
        super().__init__(root, version)
|
| 171 |
+
|
| 172 |
+
class NELLInductive(GrailInductiveDataset):
    """NELL inductive splits downloaded from the GraIL repository."""

    # URL templates take the version string; order matches ``raw_file_names``
    urls = [
        "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s_ind/train.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s_ind/valid.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s_ind/test.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s/train.txt",
        "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s/valid.txt",
    ]
    name = "IndNELL"

    def __init__(self, root, version):
        """Use the parent's defaults for every option except ``root`` and ``version``."""
        super().__init__(root, version)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def FB15k237(root):
    """Load FB15k-237 via ``RelLinkPredDataset`` and repackage it as three split graphs.

    Every split shares the dataset's ``edge_index`` / ``edge_type`` as the
    message-passing graph and differs only in its target edges.  Each split is
    passed through ``build_relation_graph``; the three results are collated
    back into the returned dataset in the order train, valid, test.
    """
    dataset = RelLinkPredDataset(name="FB15k-237", root=root+"/fb15k237/")
    data = dataset.data

    split_graphs = []
    for split in ("train", "valid", "test"):
        split_data = Data(
            edge_index=data.edge_index,
            edge_type=data.edge_type,
            num_nodes=data.num_nodes,
            target_edge_index=getattr(data, split + "_edge_index"),
            target_edge_type=getattr(data, split + "_edge_type"),
            num_relations=dataset.num_relations,
        )
        # build relation graphs
        split_graphs.append(build_relation_graph(split_data))

    dataset.data, dataset.slices = dataset.collate(split_graphs)
    return dataset
|
| 206 |
+
|
| 207 |
+
def WN18RR(root):
    """Load WN18RR via ``WordNet18RR`` and convert it to the FB15k-237 split format.

    The message-passing graph is the training edges plus their flipped copies,
    with flipped edges given type ``r + num_relations``.  Target edges for each
    split are selected with the dataset's train / val / test masks.  The
    returned dataset has ``num_relations`` overwritten with the doubled count.
    """
    dataset = WordNet18RR(root=root+"/wn18rr/")
    # convert wn18rr into the same format as fb15k-237
    data = dataset.data
    num_nodes = int(data.edge_index.max()) + 1
    num_relations = int(data.edge_type.max()) + 1

    fact_index = data.edge_index[:, data.train_mask]
    fact_type = data.edge_type[data.train_mask]
    fact_index = torch.cat([fact_index, fact_index.flip(0)], dim=-1)
    fact_type = torch.cat([fact_type, fact_type + num_relations])

    split_graphs = []
    for mask_name in ("train_mask", "val_mask", "test_mask"):
        mask = getattr(data, mask_name)
        split_data = Data(
            edge_index=fact_index,
            edge_type=fact_type,
            num_nodes=num_nodes,
            target_edge_index=data.edge_index[:, mask],
            target_edge_type=data.edge_type[mask],
            num_relations=num_relations*2,
        )
        # build relation graphs
        split_graphs.append(build_relation_graph(split_data))

    dataset.data, dataset.slices = dataset.collate(split_graphs)
    dataset.num_relations = num_relations * 2
    return dataset
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class TransductiveDataset(InMemoryDataset):
    """Base class for transductive datasets stored as train/valid/test triple files.

    Subclasses supply ``name`` and ``urls`` (one URL per raw file) and may set
    ``delimiter`` when fields are not whitespace separated.  The processed
    dataset holds three graphs in the order train, valid, test; all of them
    share the training graph (with flipped edges) for message passing.
    """

    # None means "split on any whitespace"
    delimiter = None

    def __init__(self, root, transform=None, pre_transform=build_relation_graph, **kwargs):
        """Load the cached dataset, downloading and processing it if needed.

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Raw files expected in ``raw_dir``."""
        return ["train.txt", "valid.txt", "test.txt"]

    def download(self):
        """Fetch every URL and rename the download to its raw path."""
        for url, path in zip(self.urls, self.raw_paths):
            download_path = download_url(url, self.raw_dir)
            os.rename(download_path, path)

    def load_file(self, triplet_file, inv_entity_vocab=None, inv_rel_vocab=None):
        """Read ``head relation tail`` lines and map tokens to integer ids.

        Args:
            triplet_file: path of the file to read (UTF-8).
            inv_entity_vocab: token -> id mapping for entities.  It is extended
                IN PLACE with unseen tokens, so a caller can pass the same dict
                to successive calls to share one vocabulary.  ``None`` starts a
                fresh vocabulary.
            inv_rel_vocab: same, for relations.

        Returns:
            dict with ``triplets`` (list of ``(head, tail, relation)`` ids),
            ``num_node``, ``num_relation`` and the two vocabularies.
        """
        # None sentinels instead of `{}` defaults: a mutable default would be
        # shared between calls and silently carry ids from one file to the next.
        if inv_entity_vocab is None:
            inv_entity_vocab = {}
        if inv_rel_vocab is None:
            inv_rel_vocab = {}

        triplets = []
        entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)

        with open(triplet_file, "r", encoding="utf-8") as fin:
            for line in fin:
                u, r, v = line.split() if self.delimiter is None else line.strip().split(self.delimiter)
                if u not in inv_entity_vocab:
                    inv_entity_vocab[u] = entity_cnt
                    entity_cnt += 1
                if v not in inv_entity_vocab:
                    inv_entity_vocab[v] = entity_cnt
                    entity_cnt += 1
                if r not in inv_rel_vocab:
                    inv_rel_vocab[r] = rel_cnt
                    rel_cnt += 1
                u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]

                triplets.append((u, v, r))

        return {
            "triplets": triplets,
            "num_node": len(inv_entity_vocab),
            "num_relation": rel_cnt,
            "inv_entity_vocab": inv_entity_vocab,
            "inv_rel_vocab": inv_rel_vocab
        }

    # default loading procedure: process train/valid/test files, create graphs from them
    def process(self):
        """Build the train / valid / test graphs and cache them in ``processed_paths[0]``."""
        train_files = self.raw_paths[:3]

        # valid and test extend the train vocabularies in place
        train_results = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
        valid_results = self.load_file(train_files[1],
                                       train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])
        test_results = self.load_file(train_files[2],
                                      train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])

        # in some datasets, there are several new nodes in the test set, eg 123,143 YAGO train and 123,182 in YAGO test
        # for consistency with other experimental results, we'll include those in the full vocab and num nodes
        num_node = test_results["num_node"]
        # the same for rels: in most cases train == test for transductive
        # for AristoV4 train rels 1593, test 1604
        num_relations = test_results["num_relation"]

        train_triplets = train_results["triplets"]
        valid_triplets = valid_results["triplets"]
        test_triplets = test_results["triplets"]

        train_target_edges = torch.tensor([[t[0], t[1]] for t in train_triplets], dtype=torch.long).t()
        train_target_etypes = torch.tensor([t[2] for t in train_triplets])

        valid_edges = torch.tensor([[t[0], t[1]] for t in valid_triplets], dtype=torch.long).t()
        valid_etypes = torch.tensor([t[2] for t in valid_triplets])

        test_edges = torch.tensor([[t[0], t[1]] for t in test_triplets], dtype=torch.long).t()
        test_etypes = torch.tensor([t[2] for t in test_triplets])

        # message-passing graph: training edges plus flipped copies typed r + num_relations
        train_edges = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
        train_etypes = torch.cat([train_target_etypes, train_target_etypes+num_relations])

        train_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
                          target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_relations*2)
        valid_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
                          target_edge_index=valid_edges, target_edge_type=valid_etypes, num_relations=num_relations*2)
        test_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
                         target_edge_index=test_edges, target_edge_type=test_etypes, num_relations=num_relations*2)

        # build graphs of relations
        if self.pre_transform is not None:
            train_data = self.pre_transform(train_data)
            valid_data = self.pre_transform(valid_data)
            test_data = self.pre_transform(test_data)

        torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])

    def __repr__(self):
        return "%s()" % (self.name)

    @property
    def num_relations(self):
        """Largest edge type id in the collated data, plus one."""
        return int(self.data.edge_type.max()) + 1

    @property
    def raw_dir(self):
        """Directory that holds the downloaded text files."""
        return os.path.join(self.root, self.name, "raw")

    @property
    def processed_dir(self):
        """Directory that holds the cached processed dataset."""
        return os.path.join(self.root, self.name, "processed")

    @property
    def processed_file_names(self):
        """Name of the single cache file written by ``process``."""
        return "data.pt"
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
class CoDEx(TransductiveDataset):
    """Base class for the CoDEx datasets; subclasses set ``name`` to pick the variant."""

    name = "codex"
    # URL templates take the dataset name
    urls = [
        "https://raw.githubusercontent.com/tsafavi/codex/master/data/triples/%s/train.txt",
        "https://raw.githubusercontent.com/tsafavi/codex/master/data/triples/%s/valid.txt",
        "https://raw.githubusercontent.com/tsafavi/codex/master/data/triples/%s/test.txt",
    ]

    def download(self):
        """Fetch every URL (formatted with ``name``) and rename it to its raw path."""
        for template, raw_path in zip(self.urls, self.raw_paths):
            fetched = download_url(template % self.name, self.raw_dir)
            os.rename(fetched, raw_path)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class CoDExSmall(CoDEx):
    """
    CoDEx-S.

    #node: 2034
    #edge: 36543
    #relation: 42
    """
    url = "https://zenodo.org/record/4281094/files/codex-s.tar.gz"
    md5 = "63cd8186fc2aeddc154e20cf4a10087e"
    name = "codex-s"

    def __init__(self, root):
        # `size` ends up in the base class's **kwargs, which ignores it
        super().__init__(root=root, size='s')
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
class CoDExMedium(CoDEx):
    """
    CoDEx-M.

    #node: 17050
    #edge: 206205
    #relation: 51
    """
    url = "https://zenodo.org/record/4281094/files/codex-m.tar.gz"
    md5 = "43e561cfdca1c6ad9cc2f5b1ca4add76"
    name = "codex-m"

    def __init__(self, root):
        # `size` ends up in the base class's **kwargs, which ignores it
        super().__init__(root=root, size='m')
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
class CoDExLarge(CoDEx):
    """
    CoDEx-L.

    #node: 77951
    #edge: 612437
    #relation: 69
    """
    url = "https://zenodo.org/record/4281094/files/codex-l.tar.gz"
    md5 = "9a10f4458c4bd2b16ef9b92b677e0d71"
    name = "codex-l"

    def __init__(self, root):
        # `size` ends up in the base class's **kwargs, which ignores it
        super().__init__(root=root, size='l')
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
class NELL995(TransductiveDataset):
    """NELL-995 with a training graph built from the ``facts`` and ``train`` files."""

    # from the RED-GNN paper https://github.com/LARS-research/RED-GNN/tree/main/transductive/data/nell
    # the OG dumps were found to have test set leakages
    # training set is made out of facts+train files, so we sum up their samples to build one training graph

    urls = [
        "https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/facts.txt",
        "https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/train.txt",
        "https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/valid.txt",
        "https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/test.txt",
    ]
    name = "nell995"

    @property
    def raw_file_names(self):
        """Raw files expected in ``raw_dir``."""
        return ["facts.txt", "train.txt", "valid.txt", "test.txt"]

    def process(self):
        """Build the train / valid / test graphs and cache them in ``processed_paths[0]``."""
        raw_files = self.raw_paths[:4]

        # all four files share one pair of vocabularies, extended in place
        facts_results = self.load_file(raw_files[0], inv_entity_vocab={}, inv_rel_vocab={})
        entity_vocab = facts_results["inv_entity_vocab"]
        rel_vocab = facts_results["inv_rel_vocab"]
        train_results = self.load_file(raw_files[1], entity_vocab, rel_vocab)
        valid_results = self.load_file(raw_files[2], train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])
        test_results = self.load_file(raw_files[3], train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])

        # NOTE(review): the node count is taken after the valid file and the relation
        # count after the train file, so entities first seen in test (or relations first
        # seen in valid/test) would not be counted - confirm the dumps contain none.
        num_node = valid_results["num_node"]
        num_relations = train_results["num_relation"]

        def as_tensors(triplets):
            # (head, tail, relation) tuples -> (2 x N edge index, N edge types)
            index = torch.tensor([[t[0], t[1]] for t in triplets], dtype=torch.long).t()
            types = torch.tensor([t[2] for t in triplets])
            return index, types

        train_target_edges, train_target_etypes = as_tensors(facts_results["triplets"] + train_results["triplets"])
        valid_edges, valid_etypes = as_tensors(valid_results["triplets"])
        test_edges, test_etypes = as_tensors(test_results["triplets"])

        # message-passing graph: training edges plus flipped copies typed r + num_relations
        train_edges = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
        train_etypes = torch.cat([train_target_etypes, train_target_etypes+num_relations])

        graph_kwargs = dict(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node)
        train_data = Data(target_edge_index=train_target_edges, target_edge_type=train_target_etypes,
                          num_relations=num_relations*2, **graph_kwargs)
        valid_data = Data(target_edge_index=valid_edges, target_edge_type=valid_etypes,
                          num_relations=num_relations*2, **graph_kwargs)
        test_data = Data(target_edge_index=test_edges, target_edge_type=test_etypes,
                         num_relations=num_relations*2, **graph_kwargs)

        # build graphs of relations
        splits = [train_data, valid_data, test_data]
        if self.pre_transform is not None:
            splits = [self.pre_transform(split) for split in splits]

        torch.save((self.collate(splits)), self.processed_paths[0])
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
class ConceptNet100k(TransductiveDataset):
    """ConceptNet-100k triples from the BiQUE repository (tab separated)."""

    urls = [
        "https://raw.githubusercontent.com/guojiapub/BiQUE/master/src_data/conceptnet-100k/train",
        "https://raw.githubusercontent.com/guojiapub/BiQUE/master/src_data/conceptnet-100k/valid",
        "https://raw.githubusercontent.com/guojiapub/BiQUE/master/src_data/conceptnet-100k/test",
    ]
    name = "cnet100k"
    delimiter = "\t"
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
class DBpedia100k(TransductiveDataset):
    """DB100K triples from the ComplEx-NNE_AER repository."""

    urls = [
        "https://raw.githubusercontent.com/iieir-km/ComplEx-NNE_AER/master/datasets/DB100K/_train.txt",
        "https://raw.githubusercontent.com/iieir-km/ComplEx-NNE_AER/master/datasets/DB100K/_valid.txt",
        "https://raw.githubusercontent.com/iieir-km/ComplEx-NNE_AER/master/datasets/DB100K/_test.txt",
    ]
    name = "dbp100k"
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
class YAGO310(TransductiveDataset):
    """YAGO3-10 triples from the KnowledgeGraphEmbedding repository."""

    urls = [
        "https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/YAGO3-10/train.txt",
        "https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/YAGO3-10/valid.txt",
        "https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/YAGO3-10/test.txt",
    ]
    name = "yago310"
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
class Hetionet(TransductiveDataset):
    """Hetionet triples hosted on Dropbox."""

    urls = [
        "https://www.dropbox.com/s/y47bt9oq57h6l5k/train.txt?dl=1",
        "https://www.dropbox.com/s/a0pbrx9tz3dgsff/valid.txt?dl=1",
        "https://www.dropbox.com/s/4dhrvg3fyq5tnu4/test.txt?dl=1",
    ]
    name = "hetionet"
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
class AristoV4(TransductiveDataset):
    """Aristo-v4 triples distributed as a single zip archive (tab separated)."""

    url = "https://zenodo.org/record/5942560/files/aristo-v4.zip"

    name = "aristov4"
    delimiter = "\t"

    def download(self):
        """Download and unpack the archive, then rename its members to the raw paths."""
        archive = download_url(self.url, self.raw_dir)
        extract_zip(archive, self.raw_dir)
        os.unlink(archive)
        # extracted members carry no extension; map them onto raw_paths in order
        extracted_names = ['train', 'valid', 'test']
        for extracted, raw_path in zip(extracted_names, self.raw_paths):
            os.rename(os.path.join(self.raw_dir, extracted), raw_path)
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
class SparserKG(TransductiveDataset):
    """Base class for the five sparse datasets shipped in the DacKGR archive.

    Subclasses only set ``name`` to one of the dataset directory names handled
    in ``download``.
    """

    # 5 datasets based on FB/NELL/WD, introduced in https://github.com/THU-KEG/DacKGR
    # re-writing the loading function because dumps are in the format (h, t, r) while the standard is (h, r, t)

    url = "https://raw.githubusercontent.com/THU-KEG/DacKGR/master/data.zip"
    delimiter = "\t"
    base_name = "SparseKG"

    @property
    def raw_dir(self):
        """Directory that holds this dataset's raw files."""
        return os.path.join(self.root, self.base_name, self.name, "raw")

    @property
    def processed_dir(self):
        """Directory that holds this dataset's cached processed file."""
        return os.path.join(self.root, self.base_name, self.name, "processed")

    def download(self):
        """Download the shared archive and lay out the raw files of ALL five datasets."""
        base_path = os.path.join(self.root, self.base_name)
        download_path = download_url(self.url, base_path)
        extract_zip(download_path, base_path)
        for dsname in ['NELL23K', 'WD-singer', 'FB15K-237-10', 'FB15K-237-20', 'FB15K-237-50']:
            for oldname, newname in zip(['train.triples', 'dev.triples', 'test.triples'], self.raw_file_names):
                # os.renames creates the destination directories as needed
                os.renames(os.path.join(base_path, "data", dsname, oldname), os.path.join(base_path, dsname, "raw", newname))
        shutil.rmtree(os.path.join(base_path, "data"))

    def load_file(self, triplet_file, inv_entity_vocab=None, inv_rel_vocab=None):
        """Read ``head tail relation`` lines and map tokens to integer ids.

        Args:
            triplet_file: path of the file to read (UTF-8).
            inv_entity_vocab: token -> id mapping for entities.  It is extended
                IN PLACE with unseen tokens, so a caller can pass the same dict
                to successive calls to share one vocabulary.  ``None`` starts a
                fresh vocabulary.
            inv_rel_vocab: same, for relations.

        Returns:
            dict with ``triplets`` (list of ``(head, tail, relation)`` ids),
            ``num_node``, ``num_relation`` and the two vocabularies.
        """
        # None sentinels instead of `{}` defaults: a mutable default would be
        # shared between calls and silently carry ids from one file to the next.
        if inv_entity_vocab is None:
            inv_entity_vocab = {}
        if inv_rel_vocab is None:
            inv_rel_vocab = {}

        triplets = []
        entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)

        with open(triplet_file, "r", encoding="utf-8") as fin:
            for line in fin:
                # column order in these dumps is head, tail, relation
                u, v, r = line.split() if self.delimiter is None else line.strip().split(self.delimiter)
                if u not in inv_entity_vocab:
                    inv_entity_vocab[u] = entity_cnt
                    entity_cnt += 1
                if v not in inv_entity_vocab:
                    inv_entity_vocab[v] = entity_cnt
                    entity_cnt += 1
                if r not in inv_rel_vocab:
                    inv_rel_vocab[r] = rel_cnt
                    rel_cnt += 1
                u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]

                triplets.append((u, v, r))

        return {
            "triplets": triplets,
            "num_node": len(inv_entity_vocab),
            "num_relation": rel_cnt,
            "inv_entity_vocab": inv_entity_vocab,
            "inv_rel_vocab": inv_rel_vocab
        }
|
| 583 |
+
|
| 584 |
+
class WDsinger(SparserKG):
    """The ``WD-singer`` dataset from the DacKGR archive."""

    name = "WD-singer"
|
| 586 |
+
|
| 587 |
+
class NELL23k(SparserKG):
    """The ``NELL23K`` dataset from the DacKGR archive."""

    name = "NELL23K"
|
| 589 |
+
|
| 590 |
+
class FB15k237_10(SparserKG):
    """The ``FB15K-237-10`` dataset from the DacKGR archive."""

    name = "FB15K-237-10"
|
| 592 |
+
|
| 593 |
+
class FB15k237_20(SparserKG):
    """The ``FB15K-237-20`` dataset from the DacKGR archive."""

    name = "FB15K-237-20"
|
| 595 |
+
|
| 596 |
+
class FB15k237_50(SparserKG):
    """The ``FB15K-237-50`` dataset from the DacKGR archive."""

    name = "FB15K-237-50"
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
class InductiveDataset(InMemoryDataset):
|
| 601 |
+
|
| 602 |
+
delimiter = None
|
| 603 |
+
# some datasets (4 from Hamaguchi et al and Indigo) have validation set based off the train graph, not inference
|
| 604 |
+
valid_on_inf = True #
|
| 605 |
+
|
| 606 |
+
def __init__(self, root, version, transform=None, pre_transform=build_relation_graph, **kwargs):
|
| 607 |
+
|
| 608 |
+
self.version = str(version)
|
| 609 |
+
super().__init__(root, transform, pre_transform)
|
| 610 |
+
self.data, self.slices = torch.load(self.processed_paths[0])
|
| 611 |
+
|
| 612 |
+
def download(self):
|
| 613 |
+
for url, path in zip(self.urls, self.raw_paths):
|
| 614 |
+
download_path = download_url(url % self.version, self.raw_dir)
|
| 615 |
+
os.rename(download_path, path)
|
| 616 |
+
|
| 617 |
+
def load_file(self, triplet_file, inv_entity_vocab={}, inv_rel_vocab={}):
|
| 618 |
+
|
| 619 |
+
triplets = []
|
| 620 |
+
entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)
|
| 621 |
+
|
| 622 |
+
with open(triplet_file, "r", encoding="utf-8") as fin:
|
| 623 |
+
for l in fin:
|
| 624 |
+
u, r, v = l.split() if self.delimiter is None else l.strip().split(self.delimiter)
|
| 625 |
+
if u not in inv_entity_vocab:
|
| 626 |
+
inv_entity_vocab[u] = entity_cnt
|
| 627 |
+
entity_cnt += 1
|
| 628 |
+
if v not in inv_entity_vocab:
|
| 629 |
+
inv_entity_vocab[v] = entity_cnt
|
| 630 |
+
entity_cnt += 1
|
| 631 |
+
if r not in inv_rel_vocab:
|
| 632 |
+
inv_rel_vocab[r] = rel_cnt
|
| 633 |
+
rel_cnt += 1
|
| 634 |
+
u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]
|
| 635 |
+
|
| 636 |
+
triplets.append((u, v, r))
|
| 637 |
+
|
| 638 |
+
return {
|
| 639 |
+
"triplets": triplets,
|
| 640 |
+
"num_node": len(inv_entity_vocab), #entity_cnt,
|
| 641 |
+
"num_relation": rel_cnt,
|
| 642 |
+
"inv_entity_vocab": inv_entity_vocab,
|
| 643 |
+
"inv_rel_vocab": inv_rel_vocab
|
| 644 |
+
}
|
| 645 |
+
|
| 646 |
+
def process(self):
|
| 647 |
+
|
| 648 |
+
train_files = self.raw_paths[:4]
|
| 649 |
+
|
| 650 |
+
train_res = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
|
| 651 |
+
inference_res = self.load_file(train_files[1], inv_entity_vocab={}, inv_rel_vocab={})
|
| 652 |
+
valid_res = self.load_file(
|
| 653 |
+
train_files[2],
|
| 654 |
+
inference_res["inv_entity_vocab"] if self.valid_on_inf else train_res["inv_entity_vocab"],
|
| 655 |
+
inference_res["inv_rel_vocab"] if self.valid_on_inf else train_res["inv_rel_vocab"]
|
| 656 |
+
)
|
| 657 |
+
test_res = self.load_file(train_files[3], inference_res["inv_entity_vocab"], inference_res["inv_rel_vocab"])
|
| 658 |
+
|
| 659 |
+
num_train_nodes, num_train_rels = train_res["num_node"], train_res["num_relation"]
|
| 660 |
+
inference_num_nodes, inference_num_rels = test_res["num_node"], test_res["num_relation"]
|
| 661 |
+
|
| 662 |
+
train_edges, inf_graph, inf_valid_edges, inf_test_edges = train_res["triplets"], inference_res["triplets"], valid_res["triplets"], test_res["triplets"]
|
| 663 |
+
|
| 664 |
+
train_target_edges = torch.tensor([[t[0], t[1]] for t in train_edges], dtype=torch.long).t()
|
| 665 |
+
train_target_etypes = torch.tensor([t[2] for t in train_edges])
|
| 666 |
+
|
| 667 |
+
train_fact_index = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
|
| 668 |
+
train_fact_type = torch.cat([train_target_etypes, train_target_etypes + num_train_rels])
|
| 669 |
+
|
| 670 |
+
inf_edges = torch.tensor([[t[0], t[1]] for t in inf_graph], dtype=torch.long).t()
|
| 671 |
+
inf_edges = torch.cat([inf_edges, inf_edges.flip(0)], dim=1)
|
| 672 |
+
inf_etypes = torch.tensor([t[2] for t in inf_graph])
|
| 673 |
+
inf_etypes = torch.cat([inf_etypes, inf_etypes + inference_num_rels])
|
| 674 |
+
|
| 675 |
+
inf_valid_edges = torch.tensor(inf_valid_edges, dtype=torch.long)
|
| 676 |
+
inf_test_edges = torch.tensor(inf_test_edges, dtype=torch.long)
|
| 677 |
+
|
| 678 |
+
train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=num_train_nodes,
|
| 679 |
+
target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_train_rels*2)
|
| 680 |
+
valid_data = Data(edge_index=inf_edges if self.valid_on_inf else train_fact_index,
|
| 681 |
+
edge_type=inf_etypes if self.valid_on_inf else train_fact_type,
|
| 682 |
+
num_nodes=inference_num_nodes if self.valid_on_inf else num_train_nodes,
|
| 683 |
+
target_edge_index=inf_valid_edges[:, :2].T,
|
| 684 |
+
target_edge_type=inf_valid_edges[:, 2],
|
| 685 |
+
num_relations=inference_num_rels*2 if self.valid_on_inf else num_train_rels*2)
|
| 686 |
+
test_data = Data(edge_index=inf_edges, edge_type=inf_etypes, num_nodes=inference_num_nodes,
|
| 687 |
+
target_edge_index=inf_test_edges[:, :2].T, target_edge_type=inf_test_edges[:, 2], num_relations=inference_num_rels*2)
|
| 688 |
+
|
| 689 |
+
if self.pre_transform is not None:
|
| 690 |
+
train_data = self.pre_transform(train_data)
|
| 691 |
+
valid_data = self.pre_transform(valid_data)
|
| 692 |
+
test_data = self.pre_transform(test_data)
|
| 693 |
+
|
| 694 |
+
torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])
|
| 695 |
+
|
| 696 |
+
@property
|
| 697 |
+
def num_relations(self):
|
| 698 |
+
return int(self.data.edge_type.max()) + 1
|
| 699 |
+
|
| 700 |
+
@property
def raw_dir(self):
    """Directory of the raw files: ``<root>/<name>/<version>/raw``."""
    path_parts = (self.root, self.name, self.version, "raw")
    return os.path.join(*path_parts)
|
| 703 |
+
|
| 704 |
+
@property
def processed_dir(self):
    """Directory of the processed files: ``<root>/<name>/<version>/processed``."""
    path_parts = (self.root, self.name, self.version, "processed")
    return os.path.join(*path_parts)
|
| 707 |
+
|
| 708 |
+
@property
def raw_file_names(self):
    """Names of the four raw files, as a fresh list on every access.

    NOTE(review): the ``process`` implementations visible in this file read
    ``raw_paths[0..3]`` as train graph, inference graph, validation triples and
    test triples, so the order of the names matters.
    """
    names = (
        "transductive_train.txt",
        "inference_graph.txt",
        "inf_valid.txt",
        "inf_test.txt",
    )
    return list(names)
|
| 713 |
+
|
| 714 |
+
@property
def processed_file_names(self):
    """Name of the single processed file (``process`` saves to ``processed_paths[0]``)."""
    processed_name = "data.pt"
    return processed_name
|
| 717 |
+
|
| 718 |
+
def __repr__(self):
    """Render the dataset as ``name(version)``."""
    return f"{self.name!s}({self.version!s})"
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
class IngramInductive(InductiveDataset):
    """Inductive dataset whose files are stored under an extra ``ingram`` folder."""

    @property
    def raw_dir(self):
        """Raw files location: ``<root>/ingram/<name>/<version>/raw``."""
        segments = (self.root, "ingram", self.name, self.version, "raw")
        return os.path.join(*segments)

    @property
    def processed_dir(self):
        """Processed files location: ``<root>/ingram/<name>/<version>/processed``."""
        segments = (self.root, "ingram", self.name, self.version, "processed")
        return os.path.join(*segments)
|
| 731 |
+
|
| 732 |
+
|
| 733 |
+
class FBIngram(IngramInductive):
    """InGram ``FB-*`` benchmark files hosted in the bdi-lab/InGram repository."""

    name = "fb"

    # URL templates for train / msg / valid / test files. The ``%s`` placeholder is
    # presumably filled with the dataset version by the parent class - TODO confirm.
    urls = [
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/train.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/msg.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/valid.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/test.txt",
    ]
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
class WKIngram(IngramInductive):
    """InGram ``WK-*`` benchmark files hosted in the bdi-lab/InGram repository."""

    name = "wk"

    # URL templates for train / msg / valid / test files. The ``%s`` placeholder is
    # presumably filled with the dataset version by the parent class - TODO confirm.
    urls = [
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/train.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/msg.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/valid.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/test.txt",
    ]
|
| 753 |
+
|
| 754 |
+
class NLIngram(IngramInductive):
    """InGram ``NL-*`` benchmark files hosted in the bdi-lab/InGram repository."""

    name = "nl"

    # URL templates for train / msg / valid / test files. The ``%s`` placeholder is
    # presumably filled with the dataset version by the parent class - TODO confirm.
    urls = [
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/train.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/msg.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/valid.txt",
        "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/test.txt",
    ]
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
class ILPC2022(InductiveDataset):
    """ILPC 2022 benchmark files hosted in the pykeen/ilpc2022 repository."""

    name = "ilpc2022"

    # URL templates for train / inference / inference-validation / inference-test
    # files. The ``%s`` placeholder is presumably filled with the dataset version
    # by the parent class - TODO confirm.
    urls = [
        "https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/train.txt",
        "https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/inference.txt",
        "https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/inference_validation.txt",
        "https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/inference_test.txt",
    ]
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
class HM(InductiveDataset):
    """Benchmarks from Hamaguchi et al. and INDIGO-BM, hosted in the INDIGO repository.

    Validation triples are evaluated on the training graph (``valid_on_inf`` is
    ``False``) and may mention nodes the training graph has not seen, which is why
    ``process`` is overridden.
    """

    name = "hm"

    urls = [
        "https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/train/train.txt",
        "https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/test/test-graph.txt",
        "https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/train/valid.txt",
        "https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/test/test-fact.txt",
    ]

    # short key accepted by the constructor -> full version string handed to the parent
    versions = {
        '1k': "Hamaguchi-BM_both-1000",
        '3k': "Hamaguchi-BM_both-3000",
        '5k': "Hamaguchi-BM_both-5000",
        'indigo': "INDIGO-BM"
    }

    # in 4 HM graphs, the validation set is based off the training graph,
    # so the dataset creation is adjusted accordingly
    valid_on_inf = False

    def __init__(self, root, version, **kwargs):
        """Map the short ``version`` key through ``versions`` and defer to the parent.

        Raises:
            KeyError: if ``version`` is not a key of ``versions``.
        """
        resolved_version = self.versions[version]
        super().__init__(root, resolved_version, **kwargs)

    def process(self):
        """Build train / valid / test graphs from the four raw files and save them.

        HM datasets are a bit unusual: the validation set (based off the train
        graph) has a few hundred new nodes, so the validation graph takes its
        node count from the validation load result rather than the training one.
        """
        raw_files = self.raw_paths[:4]

        # The load order is kept as is: the validation and test loads receive the
        # vocabularies returned by the earlier loads (presumably extended in place
        # by ``load_file`` - confirm against the parent implementation).
        train_res = self.load_file(raw_files[0], inv_entity_vocab={}, inv_rel_vocab={})
        inference_res = self.load_file(raw_files[1], inv_entity_vocab={}, inv_rel_vocab={})
        valid_vocab_source = inference_res if self.valid_on_inf else train_res
        valid_res = self.load_file(
            raw_files[2],
            valid_vocab_source["inv_entity_vocab"],
            valid_vocab_source["inv_rel_vocab"]
        )
        test_res = self.load_file(raw_files[3], inference_res["inv_entity_vocab"], inference_res["inv_rel_vocab"])

        num_train_nodes = train_res["num_node"]
        num_train_rels = train_res["num_relation"]
        inference_num_nodes = test_res["num_node"]
        inference_num_rels = test_res["num_relation"]

        train_triples = train_res["triplets"]
        inference_triples = inference_res["triplets"]

        # training graph: direct edges followed by flipped edges whose relation
        # ids are shifted by the number of training relations
        train_target_edges = torch.tensor([[trip[0], trip[1]] for trip in train_triples], dtype=torch.long).t()
        train_target_etypes = torch.tensor([trip[2] for trip in train_triples])
        train_fact_index = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
        train_fact_type = torch.cat([train_target_etypes, train_target_etypes + num_train_rels])

        # inference graph, built the same way
        inf_direct_edges = torch.tensor([[trip[0], trip[1]] for trip in inference_triples], dtype=torch.long).t()
        inf_edges = torch.cat([inf_direct_edges, inf_direct_edges.flip(0)], dim=1)
        inf_direct_etypes = torch.tensor([trip[2] for trip in inference_triples])
        inf_etypes = torch.cat([inf_direct_etypes, inf_direct_etypes + inference_num_rels])

        inf_valid_edges = torch.tensor(valid_res["triplets"], dtype=torch.long)
        inf_test_edges = torch.tensor(test_res["triplets"], dtype=torch.long)

        valid_num_relations = inference_num_rels*2 if self.valid_on_inf else num_train_rels*2

        train_data = Data(
            edge_index=train_fact_index,
            edge_type=train_fact_type,
            num_nodes=num_train_nodes,
            target_edge_index=train_target_edges,
            target_edge_type=train_target_etypes,
            num_relations=num_train_rels*2,
        )
        valid_data = Data(
            edge_index=train_fact_index,
            edge_type=train_fact_type,
            num_nodes=valid_res["num_node"],  # the only difference from the parent's processing
            target_edge_index=inf_valid_edges[:, :2].T,
            target_edge_type=inf_valid_edges[:, 2],
            num_relations=valid_num_relations,
        )
        test_data = Data(
            edge_index=inf_edges,
            edge_type=inf_etypes,
            num_nodes=inference_num_nodes,
            target_edge_index=inf_test_edges[:, :2].T,
            target_edge_type=inf_test_edges[:, 2],
            num_relations=inference_num_rels*2,
        )

        if self.pre_transform is not None:
            train_data, valid_data, test_data = (
                self.pre_transform(split) for split in (train_data, valid_data, test_data)
            )

        torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])
|
| 851 |
+
|
| 852 |
+
|
| 853 |
+
class MTDEAInductive(InductiveDataset):
    """Base class for the datasets shipped in the MTDEA archive.

    Subclasses are expected to define ``name``, ``prefix`` and ``versions``.
    Validation triples are evaluated on the training graph
    (``valid_on_inf`` is ``False``).
    """

    valid_on_inf = False
    url = "https://reltrans.s3.us-east-2.amazonaws.com/MTDEA_data.zip"
    base_name = "mtdea"

    def __init__(self, root, version, **kwargs):
        """Check ``version`` against the subclass' ``versions`` and defer to the parent."""
        assert version in self.versions, f"unknown version {version} for {self.name}, available: {self.versions}"
        super().__init__(root, version, **kwargs)

    @property
    def raw_dir(self):
        """Raw files location: ``<root>/<base_name>/<name>/<version>/raw``."""
        return os.path.join(self.root, self.base_name, self.name, self.version, "raw")

    @property
    def processed_dir(self):
        """Processed files location: ``<root>/<base_name>/<name>/<version>/processed``."""
        return os.path.join(self.root, self.base_name, self.name, self.version, "processed")

    @property
    def raw_file_names(self):
        """Names of the four raw files; the validation file is the transductive one here."""
        return [
            "transductive_train.txt", "inference_graph.txt", "transductive_valid.txt", "inf_test.txt"
        ]

    def download(self):
        """Download the MTDEA archive and lay out the raw files of every dataset in it.

        The archive holds all datasets, so all of them are unpacked at once and
        their files are moved to ``<root>/<base_name>/<dataset>/<version>/raw``.
        """
        base_path = os.path.join(self.root, self.base_name)
        download_path = download_url(self.url, base_path)
        extract_zip(download_path, base_path)
        # unzip all datasets at once
        for dsname in ['FBNELL', 'Metafam', 'WikiTopics-MT1', 'WikiTopics-MT2', 'WikiTopics-MT3', 'WikiTopics-MT4']:
            # dataset classes are looked up by name among the module globals
            cl = globals()[dsname.replace("-", "")]
            for version in cl.versions:
                for oldname, newname in zip(['train.txt', 'observe.txt', 'valid.txt', 'test.txt'], self.raw_file_names):
                    # "transductive_*" files come from the "-trans" folder, the others from "-ind"
                    suffix = "-trans" if "transductive" in newname else "-ind"
                    foldername = cl.prefix % version + suffix
                    os.renames(
                        os.path.join(base_path, "MTDEA_datasets", dsname, foldername, oldname),
                        os.path.join(base_path, dsname, version, "raw", newname)
                    )
        shutil.rmtree(os.path.join(base_path, "MTDEA_datasets"))

    def load_file(self, triplet_file, inv_entity_vocab=None, inv_rel_vocab=None, limit_vocab=False):
        """Read a triple file and map entities and relations to integer ids.

        Args:
            triplet_file: path of a text file with one ``head relation tail`` triple
                per line, split on ``self.delimiter`` (or on whitespace when it is ``None``).
            inv_entity_vocab: entity -> id dictionary to start from. It is extended
                IN PLACE with unseen entities unless ``limit_vocab`` is set.
                ``None`` (the default) starts from a fresh empty dictionary.
            inv_rel_vocab: relation -> id dictionary, handled like ``inv_entity_vocab``.
            limit_vocab: when true, triples mentioning an entity or a relation that
                is missing from the given vocabularies are dropped and the
                vocabularies are left untouched.

        Returns:
            dict with ``triplets`` (list of ``(head, tail, relation)`` id tuples),
            ``num_node``, ``num_relation``, ``inv_entity_vocab`` and ``inv_rel_vocab``.
        """
        # The defaults used to be shared ``{}`` objects that this method mutates,
        # leaking ids between unrelated calls; fresh dictionaries are created instead.
        if inv_entity_vocab is None:
            inv_entity_vocab = {}
        if inv_rel_vocab is None:
            inv_rel_vocab = {}

        triplets = []
        entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)

        # limit_vocab is for dropping triples with heads/tails/relations not seen in the main vocab;
        # can be used for FBNELL and MT3:art, other datasets seem to be ok and share
        # num_nodes/num_relations in the train/inference graph
        with open(triplet_file, "r", encoding="utf-8") as fin:
            for line in fin:
                u, r, v = line.split() if self.delimiter is None else line.strip().split(self.delimiter)
                if u not in inv_entity_vocab:
                    if limit_vocab:
                        continue
                    inv_entity_vocab[u] = entity_cnt
                    entity_cnt += 1
                if v not in inv_entity_vocab:
                    if limit_vocab:
                        continue
                    inv_entity_vocab[v] = entity_cnt
                    entity_cnt += 1
                if r not in inv_rel_vocab:
                    if limit_vocab:
                        continue
                    inv_rel_vocab[r] = rel_cnt
                    rel_cnt += 1
                u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]

                triplets.append((u, v, r))

        return {
            "triplets": triplets,
            "num_node": entity_cnt,
            "num_relation": rel_cnt,
            "inv_entity_vocab": inv_entity_vocab,
            "inv_rel_vocab": inv_rel_vocab
        }

    # special processing for MTDEA datasets: two fixes around the validation set
    def process(self):
        """Build train / valid / test graphs from the four raw files and save them.

        Compared to the superclass processor (per the original comments), the
        validation file is loaded with ``limit_vocab=True`` and the validation
        graph takes ``num_nodes`` from the validation load result.
        """
        train_files = self.raw_paths[:4]

        train_res = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
        inference_res = self.load_file(train_files[1], inv_entity_vocab={}, inv_rel_vocab={})
        valid_res = self.load_file(
            train_files[2],
            inference_res["inv_entity_vocab"] if self.valid_on_inf else train_res["inv_entity_vocab"],
            inference_res["inv_rel_vocab"] if self.valid_on_inf else train_res["inv_rel_vocab"],
            limit_vocab=True,  # the 1st fix in this function compared to the superclass processor
        )
        # the test load extends the inference vocabularies in place (limit_vocab is off)
        test_res = self.load_file(train_files[3], inference_res["inv_entity_vocab"], inference_res["inv_rel_vocab"])

        num_train_nodes, num_train_rels = train_res["num_node"], train_res["num_relation"]
        inference_num_nodes, inference_num_rels = test_res["num_node"], test_res["num_relation"]

        train_edges, inf_graph = train_res["triplets"], inference_res["triplets"]
        inf_valid_edges, inf_test_edges = valid_res["triplets"], test_res["triplets"]

        train_target_edges = torch.tensor([[t[0], t[1]] for t in train_edges], dtype=torch.long).t()
        train_target_etypes = torch.tensor([t[2] for t in train_edges])

        # direct edges followed by flipped edges with relation ids shifted by the relation count
        train_fact_index = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
        train_fact_type = torch.cat([train_target_etypes, train_target_etypes + num_train_rels])

        inf_edges = torch.tensor([[t[0], t[1]] for t in inf_graph], dtype=torch.long).t()
        inf_edges = torch.cat([inf_edges, inf_edges.flip(0)], dim=1)
        inf_etypes = torch.tensor([t[2] for t in inf_graph])
        inf_etypes = torch.cat([inf_etypes, inf_etypes + inference_num_rels])

        inf_valid_edges = torch.tensor(inf_valid_edges, dtype=torch.long)
        inf_test_edges = torch.tensor(inf_test_edges, dtype=torch.long)

        train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=num_train_nodes,
                          target_edge_index=train_target_edges, target_edge_type=train_target_etypes,
                          num_relations=num_train_rels*2)
        valid_data = Data(edge_index=train_fact_index,
                          edge_type=train_fact_type,
                          num_nodes=valid_res["num_node"],  # the 2nd fix in this function
                          target_edge_index=inf_valid_edges[:, :2].T,
                          target_edge_type=inf_valid_edges[:, 2],
                          num_relations=inference_num_rels*2 if self.valid_on_inf else num_train_rels*2)
        test_data = Data(edge_index=inf_edges, edge_type=inf_etypes, num_nodes=inference_num_nodes,
                         target_edge_index=inf_test_edges[:, :2].T, target_edge_type=inf_test_edges[:, 2],
                         num_relations=inference_num_rels*2)

        if self.pre_transform is not None:
            train_data = self.pre_transform(train_data)
            valid_data = self.pre_transform(valid_data)
            test_data = self.pre_transform(test_data)

        torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])
|
| 983 |
+
|
| 984 |
+
|
| 985 |
+
class FBNELL(MTDEAInductive):
    """FBNELL dataset from the MTDEA archive; it has a single version."""

    name = "FBNELL"
    prefix = "%s"
    versions = ["FBNELL_v1"]

    def __init__(self, **kwargs):
        """Force the only available version, whatever ``version`` the caller passed.

        ``version`` may also be omitted: the previous ``kwargs.pop("version")``
        raised ``KeyError`` in that case even though the value was discarded.
        """
        kwargs['version'] = self.versions[0]
        super().__init__(**kwargs)
|
| 995 |
+
|
| 996 |
+
|
| 997 |
+
class Metafam(MTDEAInductive):
    """Metafam dataset from the MTDEA archive; it has a single version."""

    name = "Metafam"
    prefix = "%s"
    versions = ["Metafam"]

    def __init__(self, **kwargs):
        """Force the only available version, whatever ``version`` the caller passed.

        ``version`` may also be omitted: the previous ``kwargs.pop("version")``
        raised ``KeyError`` in that case even though the value was discarded.
        """
        kwargs['version'] = self.versions[0]
        super().__init__(**kwargs)
|
| 1007 |
+
|
| 1008 |
+
|
| 1009 |
+
class WikiTopicsMT1(MTDEAInductive):
    """WikiTopics-MT1 datasets from the MTDEA archive."""

    versions = ['mt', 'health', 'tax']
    # folder-name template; ``download`` of the parent formats it with a version
    prefix = "wikidata_%sv1"
    name = "WikiTopics-MT1"

    def __init__(self, **kwargs):
        """Validate the ``version`` keyword before handing all keywords to the parent."""
        assert kwargs['version'] in self.versions, f"unknown version {kwargs['version']}, available: {self.versions}"
        super().__init__(**kwargs)
|
| 1018 |
+
|
| 1019 |
+
|
| 1020 |
+
class WikiTopicsMT2(MTDEAInductive):
    """WikiTopics-MT2 datasets from the MTDEA archive."""

    versions = ['mt2', 'org', 'sci']
    # folder-name template; ``download`` of the parent formats it with a version
    prefix = "wikidata_%sv1"
    name = "WikiTopics-MT2"

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the parent constructor."""
        super().__init__(**kwargs)
|
| 1028 |
+
|
| 1029 |
+
|
| 1030 |
+
class WikiTopicsMT3(MTDEAInductive):
    """WikiTopics-MT3 datasets from the MTDEA archive."""

    versions = ['mt3', 'art', 'infra']
    # folder-name template; ``download`` of the parent formats it with a version
    prefix = "wikidata_%sv2"
    name = "WikiTopics-MT3"

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the parent constructor."""
        super().__init__(**kwargs)
|
| 1038 |
+
|
| 1039 |
+
|
| 1040 |
+
class WikiTopicsMT4(MTDEAInductive):
    """WikiTopics-MT4 datasets from the MTDEA archive."""

    versions = ['mt4', 'sci', 'health']
    # folder-name template; ``download`` of the parent formats it with a version
    prefix = "wikidata_%sv2"
    name = "WikiTopics-MT4"

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the parent constructor."""
        super().__init__(**kwargs)
|
| 1048 |
+
|
| 1049 |
+
|
| 1050 |
+
# a joint dataset for pre-training ULTRA on several graphs
|
| 1051 |
+
class JointDataset(InMemoryDataset):
    """Several graphs bundled together, e.g. for pre-training on more than one dataset.

    The processed file stores a tuple of three lists (train, valid, test), each
    holding one entry per graph, in the order the graph names were given.
    """

    # graph name accepted in ``graphs`` -> callable building that dataset from ``root``
    datasets_map = {
        'FB15k237': FB15k237,
        'WN18RR': WN18RR,
        'CoDExSmall': CoDExSmall,
        'CoDExMedium': CoDExMedium,
        'CoDExLarge': CoDExLarge,
        'NELL995': NELL995,
        'ConceptNet100k': ConceptNet100k,
        'DBpedia100k': DBpedia100k,
        'YAGO310': YAGO310,
        'AristoV4': AristoV4,
    }

    def __init__(self, root, graphs, transform=None, pre_transform=None):
        """Instantiate every requested graph, then load the processed bundle.

        Args:
            root: root directory, also passed to each individual dataset.
            graphs: iterable of keys of ``datasets_map``.
            transform, pre_transform: forwarded to the parent constructor.
        """
        loaded_graphs = []
        for graph_name in graphs:
            builder = self.datasets_map[graph_name]
            loaded_graphs.append(builder(root=root))
        self.graphs = loaded_graphs
        self.num_graphs = len(graphs)
        super().__init__(root, transform, pre_transform)
        self.data = torch.load(self.processed_paths[0])

    # NOTE(review): both directories below depend only on the NUMBER of graphs,
    # so two different mixtures of equal size resolve to the same folder.
    @property
    def raw_dir(self):
        """``<root>/joint/<num_graphs>g/raw``."""
        size_tag = f'{self.num_graphs}g'
        return os.path.join(self.root, "joint", size_tag, "raw")

    @property
    def processed_dir(self):
        """``<root>/joint/<num_graphs>g/processed``."""
        size_tag = f'{self.num_graphs}g'
        return os.path.join(self.root, "joint", size_tag, "processed")

    @property
    def processed_file_names(self):
        """Name of the single processed file."""
        return "data.pt"

    def process(self):
        """Collect item 0 / 1 / 2 of every graph as train / valid / test and save them."""
        train_data = [graph[0] for graph in self.graphs]
        valid_data = [graph[1] for graph in self.graphs]
        test_data = [graph[2] for graph in self.graphs]

        bundle = (train_data, valid_data, test_data)
        torch.save(bundle, self.processed_paths[0])
|
ultra/eval.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import distributed as dist
|
| 5 |
+
from torch.utils import data as torch_data
|
| 6 |
+
from torch_geometric.data import Data
|
| 7 |
+
|
| 8 |
+
from ultra import tasks, util
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Class names of the datasets that ``get_filtered_data`` treats as transductive.
TRANSDUCTIVE = (
    "WordNet18RR",
    "RelLinkPredDataset",
    "CoDExSmall",
    "CoDExMedium",
    "CoDExLarge",
    "YAGO310",
    "NELL995",
    "ConceptNet100k",
    "DBpedia100k",
    "Hetionet",
    "AristoV4",
    "WDsinger",
    "NELL23k",
    "FB15k237_10",
    "FB15k237_20",
    "FB15k237_50",
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_filtered_data(dataset, mode):
    """Build the graph of known triples used to filter out true answers when ranking.

    Args:
        dataset: indexable dataset whose items 0 / 1 / 2 are the train / valid /
            test graphs; its class name selects the filtering strategy.
        mode: ``"test"`` selects the test filtering graph; any other value selects
            the validation one. Only consulted for inductive datasets that are
            neither ILPC nor Ingram.

    Returns:
        A ``Data`` object holding ``edge_index``, ``edge_type`` and ``num_nodes``.
    """
    train_data, valid_data, test_data = dataset[0], dataset[1], dataset[2]
    ds_name = dataset.__class__.__name__

    if ds_name in TRANSDUCTIVE:
        # transductive: all target edges stored in the dataset
        return Data(edge_index=dataset._data.target_edge_index,
                    edge_type=dataset._data.target_edge_type,
                    num_nodes=train_data.num_nodes)

    if "ILPC" in ds_name or "Ingram" in ds_name:
        # validation graph edges + validation targets + test targets
        full_inference_edges = torch.cat(
            [valid_data.edge_index, valid_data.target_edge_index, test_data.target_edge_index], dim=1)
        full_inference_etypes = torch.cat(
            [valid_data.edge_type, valid_data.target_edge_type, test_data.target_edge_type])
        return Data(edge_index=full_inference_edges, edge_type=full_inference_etypes,
                    num_nodes=test_data.num_nodes)

    if mode == "test":
        # test filtering graph: inference edges + test edges
        full_inference_edges = torch.cat([test_data.edge_index, test_data.target_edge_index], dim=1)
        full_inference_etypes = torch.cat([test_data.edge_type, test_data.target_edge_type])
        return Data(edge_index=full_inference_edges, edge_type=full_inference_etypes,
                    num_nodes=test_data.num_nodes)

    # validation filtering graph: train edges + validation edges.
    # num_nodes is set explicitly so that masks built from this graph have the
    # same width as predictions made on valid_data, instead of depending on a
    # node count inferred from the largest index in edge_index.
    return Data(
        edge_index=torch.cat([train_data.edge_index, valid_data.target_edge_index], dim=1),
        edge_type=torch.cat([train_data.edge_type, valid_data.target_edge_type]),
        num_nodes=valid_data.num_nodes,
    )
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@torch.no_grad()
def test(model, mode, dataset, batch_size=32, eval_metrics=["mrr", "hits@10"], gpus=None, return_metrics=False):
    """Rank the true head and tail of every evaluation triple and report metrics.

    Args:
        model: callable as ``model(graph, batch)``, returning one score per candidate.
        mode: ``"valid"`` evaluates on ``dataset[1]``; anything else on ``dataset[2]``.
        dataset: dataset providing the graphs (see ``get_filtered_data``).
        batch_size: number of triples per evaluation batch.
        eval_metrics: metric names among ``"mr"``, ``"mrr"``, ``"hits@K"`` and
            ``"hits@K_N"`` (unbiased estimate over ``N`` sampled candidates);
            a ``"-tail"`` suffix restricts the metric to tail predictions.
        gpus: forwarded to ``util.get_devices``.
        return_metrics: return the metric dictionary instead of the MRR.

    Returns:
        The MRR over all head and tail rankings, or, when ``return_metrics`` is
        set, a dict metric name -> score (filled on rank 0 only).

    Raises:
        ValueError: for a suffix other than ``-tail`` or an unknown metric name.
    """
    logger = util.get_root_logger()
    test_data = dataset[1] if mode == "valid" else dataset[2]
    filtered_data = get_filtered_data(dataset, mode)

    device = util.get_devices(gpus)
    world_size = util.get_world_size()
    rank = util.get_rank()

    test_triplets = torch.cat([test_data.target_edge_index, test_data.target_edge_type.unsqueeze(0)]).t()
    sampler = torch_data.DistributedSampler(test_triplets, world_size, rank)
    test_loader = torch_data.DataLoader(test_triplets, batch_size, sampler=sampler)

    model.eval()
    rankings = []
    num_negatives = []
    tail_rankings, num_tail_negs = [], []  # for explicit tail-only evaluation needed for 5 datasets
    for batch in test_loader:
        t_batch, h_batch = tasks.all_negative(test_data, batch)
        t_pred = model(test_data, t_batch)
        h_pred = model(test_data, h_batch)

        if filtered_data is None:
            t_mask, h_mask = tasks.strict_negative_mask(test_data, batch)
        else:
            t_mask, h_mask = tasks.strict_negative_mask(filtered_data, batch)
        pos_h_index, pos_t_index, pos_r_index = batch.t()
        t_ranking = tasks.compute_ranking(t_pred, pos_t_index, t_mask)
        h_ranking = tasks.compute_ranking(h_pred, pos_h_index, h_mask)
        num_t_negative = t_mask.sum(dim=-1)
        num_h_negative = h_mask.sum(dim=-1)

        rankings += [t_ranking, h_ranking]
        num_negatives += [num_t_negative, num_h_negative]

        tail_rankings += [t_ranking]
        num_tail_negs += [num_t_negative]

    ranking = torch.cat(rankings)
    num_negative = torch.cat(num_negatives)
    all_size = torch.zeros(world_size, dtype=torch.long, device=device)
    all_size[rank] = len(ranking)

    # repetitive code for tail-only ranks processing
    tail_ranking = torch.cat(tail_rankings)
    num_tail_neg = torch.cat(num_tail_negs)
    all_size_t = torch.zeros(world_size, dtype=torch.long, device=device)
    all_size_t[rank] = len(tail_ranking)
    if world_size > 1:
        dist.all_reduce(all_size, op=dist.ReduceOp.SUM)
        dist.all_reduce(all_size_t, op=dist.ReduceOp.SUM)

    # obtaining all ranks: every process writes its own slice of a zero buffer,
    # so the SUM reduction below amounts to a concatenation
    cum_size = all_size.cumsum(0)
    all_ranking = torch.zeros(all_size.sum(), dtype=torch.long, device=device)
    all_ranking[cum_size[rank] - all_size[rank]: cum_size[rank]] = ranking
    all_num_negative = torch.zeros(all_size.sum(), dtype=torch.long, device=device)
    all_num_negative[cum_size[rank] - all_size[rank]: cum_size[rank]] = num_negative

    # the same for tails-only ranks
    cum_size_t = all_size_t.cumsum(0)
    all_ranking_t = torch.zeros(all_size_t.sum(), dtype=torch.long, device=device)
    all_ranking_t[cum_size_t[rank] - all_size_t[rank]: cum_size_t[rank]] = tail_ranking
    all_num_negative_t = torch.zeros(all_size_t.sum(), dtype=torch.long, device=device)
    all_num_negative_t[cum_size_t[rank] - all_size_t[rank]: cum_size_t[rank]] = num_tail_neg
    if world_size > 1:
        dist.all_reduce(all_ranking, op=dist.ReduceOp.SUM)
        dist.all_reduce(all_num_negative, op=dist.ReduceOp.SUM)
        dist.all_reduce(all_ranking_t, op=dist.ReduceOp.SUM)
        dist.all_reduce(all_num_negative_t, op=dist.ReduceOp.SUM)

    metrics = {}
    if rank == 0:
        for metric in eval_metrics:
            if "-tail" in metric:
                _metric_name, direction = metric.split("-")
                if direction != "tail":
                    raise ValueError("Only tail metric is supported in this mode")
                _ranking = all_ranking_t
                _num_neg = all_num_negative_t
            else:
                _ranking = all_ranking
                _num_neg = all_num_negative
                _metric_name = metric

            if _metric_name == "mr":
                score = _ranking.float().mean()
            elif _metric_name == "mrr":
                score = (1 / _ranking.float()).mean()
            elif _metric_name.startswith("hits@"):
                values = _metric_name[5:].split("_")
                threshold = int(values[0])
                if len(values) > 1:
                    num_sample = int(values[1])
                    # unbiased estimation
                    fp_rate = (_ranking - 1).float() / _num_neg
                    score = 0
                    # at most num_sample - 1 false positives can be drawn, hence the cap
                    for i in range(min(threshold, num_sample)):
                        # choose i false positive from num_sample - 1 negatives;
                        # math.comb avoids dividing huge factorials, which overflowed
                        # the float conversion for large num_sample
                        num_comb = float(math.comb(num_sample - 1, i))
                        score += num_comb * (fp_rate ** i) * ((1 - fp_rate) ** (num_sample - i - 1))
                    score = score.mean()
                else:
                    score = (_ranking <= threshold).float().mean()
            else:
                # previously an unknown name reused the score of the preceding metric
                raise ValueError("Unknown metric `%s`" % metric)
            logger.warning("%s: %g" % (metric, score))
            metrics[metric] = score
    mrr = (1 / all_ranking.float()).mean()

    return mrr if not return_metrics else metrics
|
ultra/layers.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
from torch_scatter import scatter
|
| 5 |
+
|
| 6 |
+
from torch_geometric.nn.conv import MessagePassing
|
| 7 |
+
from torch_geometric.utils import degree
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class GeneralizedRelationalConv(MessagePassing):
|
| 12 |
+
|
| 13 |
+
eps = 1e-6
|
| 14 |
+
|
| 15 |
+
message2mul = {
|
| 16 |
+
"transe": "add",
|
| 17 |
+
"distmult": "mul",
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
# TODO for compile() - doesn't work currently
|
| 21 |
+
# propagate_type = {"edge_index": torch.LongTensor, "size": Tuple[int, int]}
|
| 22 |
+
|
| 23 |
+
def __init__(self, input_dim, output_dim, num_relation, query_input_dim, message_func="distmult",
             aggregate_func="pna", layer_norm=False, activation="relu", dependent=False, project_relations=False):
    """Store the layer configuration and create its sub-modules.

    Args:
        input_dim: size of the incoming features.
        output_dim: size of the features produced by ``self.linear``.
        num_relation: number of relation types.
        query_input_dim: input size of ``relation_linear`` (used when ``dependent``).
        message_func: name of the message function.
        aggregate_func: name of the aggregation; ``"pna"`` widens the input of
            ``self.linear`` to ``13 * input_dim`` instead of ``2 * input_dim``.
        layer_norm: create a ``LayerNorm`` over ``output_dim`` when true.
        activation: name of a function in ``torch.nn.functional`` or a callable.
        dependent: derive relation features from the query with a linear layer.
        project_relations: when not ``dependent``, leave ``self.relation`` unset
            (``None``) and create a two-layer projection for it instead of an
            embedding table.
    """
    super(GeneralizedRelationalConv, self).__init__()
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.num_relation = num_relation
    self.query_input_dim = query_input_dim
    self.message_func = message_func
    self.aggregate_func = aggregate_func
    self.dependent = dependent
    self.project_relations = project_relations

    self.layer_norm = nn.LayerNorm(output_dim) if layer_norm else None
    self.activation = getattr(F, activation) if isinstance(activation, str) else activation

    # sub-modules are created in the same order as before: linear, then relation modules
    linear_in_factor = 13 if self.aggregate_func == "pna" else 2
    self.linear = nn.Linear(input_dim * linear_in_factor, output_dim)

    if dependent:
        # obtain relation embeddings as a projection of the query relation
        self.relation_linear = nn.Linear(query_input_dim, num_relation * input_dim)
    elif not self.project_relations:
        # relation embeddings as an independent embedding matrix per each layer
        self.relation = nn.Embedding(num_relation, input_dim)
    else:
        # will be initialized after the pass over relation graph
        self.relation = None
        self.relation_projection = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, input_dim)
        )
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def forward(self, input, query, boundary, edge_index, edge_type, size, edge_weight=None):
    """Pick the relation features for this layer and run message passing.

    Relation features come from one of three sources: a projection of ``query``
    (``dependent``), this layer's own embedding table, or a projection of the
    externally assigned ``self.relation`` (``project_relations``). Edge weights
    default to ones, one per entry of ``edge_type``.
    """
    batch_size = len(query)

    if self.dependent:
        # layer-specific relation features as a projection of input "query" (relation) embeddings
        projected = self.relation_linear(query)
        relation = projected.view(batch_size, self.num_relation, self.input_dim)
    elif self.project_relations:
        # projecting relation features to unique features for this layer
        relation = self.relation_projection(self.relation)
    else:
        # layer-specific relation features as a special embedding matrix unique to each layer
        relation = self.relation.weight.expand(batch_size, -1, -1)

    if edge_weight is None:
        edge_weight = torch.ones(len(edge_type), device=input.device)

    # note that we send the initial boundary condition (node states at layer0) to the message passing
    # correspond to Eq.6 on p5 in https://arxiv.org/pdf/2106.06935.pdf
    return self.propagate(input=input, relation=relation, boundary=boundary, edge_index=edge_index,
                          edge_type=edge_type, size=size, edge_weight=edge_weight)
|
| 88 |
+
|
| 89 |
+
def propagate(self, edge_index, size=None, **kwargs):
    """Propagate messages, using the fused ``message_and_aggregate`` path when possible.

    Falls back to the parent class ``propagate`` (separate message and
    aggregate steps) when the edge weights require gradients or the message
    function is ``"rotate"``. Otherwise it runs the registered hooks, the fused
    ``message_and_aggregate`` call, and ``update``, in that order.
    """
    use_parent_path = kwargs["edge_weight"].requires_grad or self.message_func == "rotate"
    if use_parent_path:
        # the rspmm cuda kernel only works for TransE and DistMult message functions
        # otherwise we invoke separate message & aggregate functions
        return super(GeneralizedRelationalConv, self).propagate(edge_index, size, **kwargs)

    for pre_hook in self._propagate_forward_pre_hooks.values():
        replaced = pre_hook(self, (edge_index, size, kwargs))
        if replaced is not None:
            edge_index, size, kwargs = replaced

    # in newer PyG,
    # __check_input__ -> _check_input()
    # __collect__ -> _collect()
    # __fused_user_args__ -> _fused_user_args
    size = self._check_input(edge_index, size)
    collected = self._collect(self._fused_user_args, edge_index, size, kwargs)

    fused_kwargs = self.inspector.distribute("message_and_aggregate", collected)
    for pre_hook in self._message_and_aggregate_forward_pre_hooks.values():
        replaced = pre_hook(self, (edge_index, fused_kwargs))
        if replaced is not None:
            edge_index, fused_kwargs = replaced

    result = self.message_and_aggregate(edge_index, **fused_kwargs)

    for post_hook in self._message_and_aggregate_forward_hooks.values():
        replaced = post_hook(self, (edge_index, fused_kwargs), result)
        if replaced is not None:
            result = replaced

    result = self.update(result, **self.inspector.distribute("update", collected))

    for post_hook in self._propagate_forward_hooks.values():
        replaced = post_hook(self, (edge_index, size, kwargs), result)
        if replaced is not None:
            result = replaced

    return result
|
| 127 |
+
|
| 128 |
+
def message(self, input_j, relation, boundary, edge_type):
    """Compute per-edge messages and append the boundary condition.

    The relation features are gathered per edge with ``edge_type`` along
    ``self.node_dim`` and combined with the source states ``input_j``
    according to ``self.message_func`` ("transe", "distmult" or "rotate").

    Raises:
        ValueError: if ``self.message_func`` is not one of the known names.
    """
    rel_per_edge = relation.index_select(self.node_dim, edge_type)

    if self.message_func == "transe":
        edge_message = input_j + rel_per_edge
    elif self.message_func == "distmult":
        edge_message = input_j * rel_per_edge
    elif self.message_func == "rotate":
        # complex multiplication: the last axis holds the real half followed by the imaginary half
        src_re, src_im = input_j.chunk(2, dim=-1)
        rel_re, rel_im = rel_per_edge.chunk(2, dim=-1)
        real_part = src_re * rel_re - src_im * rel_im
        imag_part = src_re * rel_im + src_im * rel_re
        edge_message = torch.cat([real_part, imag_part], dim=-1)
    else:
        raise ValueError("Unknown message function `%s`" % self.message_func)

    # augment messages with the boundary condition; the concatenation runs along
    # self.node_dim, giving num_edges + num_nodes entries on that axis
    return torch.cat([edge_message, boundary], dim=self.node_dim)
|
| 148 |
+
|
| 149 |
+
def aggregate(self, input, edge_weight, index, dim_size):
    """Aggregate weighted messages per target node.

    ``index`` and ``edge_weight`` are first extended with one unit-weight
    self-loop per node so the boundary entries appended by ``message`` are
    aggregated too. For ``"pna"`` the mean / max / min / std aggregates are
    combined with three degree scalers; any other ``self.aggregate_func`` is
    forwarded to ``scatter`` as the ``reduce`` argument.
    """
    device = input.device
    # augment aggregation index with self-loops for the boundary condition
    index = torch.cat([index, torch.arange(dim_size, device=device)])  # (num_edges + num_nodes,)
    edge_weight = torch.cat([edge_weight, torch.ones(dim_size, device=device)])

    # reshape the weights so they broadcast along self.node_dim only
    broadcast_shape = [1] * input.ndim
    broadcast_shape[self.node_dim] = -1
    edge_weight = edge_weight.view(broadcast_shape)

    if self.aggregate_func != "pna":
        return scatter(input * edge_weight, index, dim=self.node_dim, dim_size=dim_size,
                       reduce=self.aggregate_func)

    def reduce_by(values, how):
        return scatter(values, index, dim=self.node_dim, dim_size=dim_size, reduce=how)

    mean = reduce_by(input * edge_weight, "mean")
    sq_mean = reduce_by(input ** 2 * edge_weight, "mean")
    max_value = reduce_by(input * edge_weight, "max")
    min_value = reduce_by(input * edge_weight, "min")
    std = (sq_mean - mean ** 2).clamp(min=self.eps).sqrt()

    # stack the four aggregates on a new trailing axis, then fold it into the feature axis
    features = torch.stack([mean, max_value, min_value, std], dim=-1).flatten(-2)

    # three scalers: identity / log degree / reciprocal of log degree
    degree_out = degree(index, dim_size).unsqueeze(0).unsqueeze(-1)
    scale = degree_out.log()
    scale = scale / scale.mean()
    scales = torch.cat([torch.ones_like(scale), scale, 1 / scale.clamp(min=1e-2)], dim=-1)

    return (features.unsqueeze(-1) * scales.unsqueeze(-2)).flatten(-2)
|
| 175 |
+
|
| 176 |
+
def message_and_aggregate(self, edge_index, input, relation, boundary, edge_type, edge_weight, index, dim_size):
    """Fused message + aggregate step backed by the custom rspmm kernel.

    Compared with separate message / aggregate calls this speeds up the
    computation by several times and reduces memory complexity from O(|E|d)
    to O(|V|d), so it can be applied to larger graphs.

    ``input`` is read as (batch_size, num_node, ...); the result is returned
    in the same (batch_size, num_node, ...) layout.

    Raises:
        ValueError: for an unknown ``self.message_func`` or ``self.aggregate_func``.
    """
    from ultra.rspmm.rspmm import generalized_rspmm

    batch_size, num_node = input.shape[:2]
    # put the node axis first and fold batch and feature axes together
    input = input.transpose(0, 1).flatten(1)
    relation = relation.transpose(0, 1).flatten(1)
    boundary = boundary.transpose(0, 1).flatten(1)
    # +1 accounts for the boundary term that is added to every node below
    degree_out = degree(index, dim_size).unsqueeze(-1) + 1

    if self.message_func not in self.message2mul:
        raise ValueError("Unknown message function `%s`" % self.message_func)
    mul = self.message2mul[self.message_func]

    def fused(rel, feat, reduce):
        return generalized_rspmm(edge_index, edge_type, edge_weight, rel, feat, sum=reduce, mul=mul)

    if self.aggregate_func == "sum":
        update = fused(relation, input, "add")
        update = update + boundary
    elif self.aggregate_func == "mean":
        update = fused(relation, input, "add")
        update = (update + boundary) / degree_out
    elif self.aggregate_func == "max":
        update = fused(relation, input, "max")
        update = torch.max(update, boundary)
    elif self.aggregate_func == "pna":
        # we use PNA with 4 aggregators (mean / max / min / std)
        # and 3 scalars (identity / log degree / reciprocal of log degree)
        total = fused(relation, input, "add")
        # NOTE(review): squaring both operands equals the squared message only when the
        # message function is a product - confirm the intended behaviour for other functions
        sq_total = fused(relation ** 2, input ** 2, "add")
        max_value = fused(relation, input, "max")
        min_value = fused(relation, input, "min")

        mean = (total + boundary) / degree_out
        sq_mean = (sq_total + boundary ** 2) / degree_out
        max_value = torch.max(max_value, boundary)
        min_value = torch.min(min_value, boundary)  # (node, batch_size * input_dim)
        std = (sq_mean - mean ** 2).clamp(min=self.eps).sqrt()

        features = torch.stack([mean, max_value, min_value, std], dim=-1)
        features = features.flatten(-2)  # (node, batch_size * input_dim * 4)

        scale = degree_out.log()
        scale = scale / scale.mean()
        scales = torch.cat([torch.ones_like(scale), scale, 1 / scale.clamp(min=1e-2)], dim=-1)  # (node, 3)
        update = (features.unsqueeze(-1) * scales.unsqueeze(-2)).flatten(-2)  # (node, batch_size * input_dim * 4 * 3)
    else:
        raise ValueError("Unknown aggregation function `%s`" % self.aggregate_func)

    # restore the (batch_size, num_node, ...) layout
    return update.view(num_node, batch_size, -1).transpose(0, 1)
|
| 225 |
+
|
| 226 |
+
def update(self, update, input):
    """Combine old node states with this layer's aggregated output.

    ``input`` (old states) and ``update`` (this layer's output) are
    concatenated on the last axis and passed through ``self.linear``;
    ``self.layer_norm`` and ``self.activation`` are then applied, in that
    order, whenever they are truthy.
    """
    output = self.linear(torch.cat([input, update], dim=-1))
    for stage in (self.layer_norm, self.activation):
        if stage:
            output = stage(output)
    return output
|
| 234 |
+
|
ultra/models.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
from . import tasks, layers
|
| 5 |
+
from ultra.base_nbfnet import BaseNBFNet
|
| 6 |
+
|
| 7 |
+
class Ultra(nn.Module):
    """Two-stage model: a relation-graph network feeding an entity-graph network."""

    def __init__(self, rel_model_cfg, entity_model_cfg):
        """Build both sub-models from their keyword-argument dictionaries."""
        super().__init__()

        self.relation_model = RelNBFNet(**rel_model_cfg)
        self.entity_model = EntityNBFNet(**entity_model_cfg)

    def forward(self, data, batch):
        """Score the triples in ``batch`` against the graph ``data``.

        ``batch`` has shape (bs, 1 + num_negs, 3). The relation is the same for
        the positive and all negative triples of a sample, so it is read from
        the first triple among the 1 + num_negs only.
        """
        query_rels = batch[:, 0, 2]
        relation_representations = self.relation_model(data.relation_graph, query=query_rels)
        return self.entity_model(data, relation_representations, batch)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# NBFNet to work on the graph of relations with 4 fundamental interactions
|
| 29 |
+
# Doesn't have the final projection MLP from hidden dim -> 1, returns all node representations
|
| 30 |
+
# of shape [bs, num_rel, hidden]
|
| 31 |
+
class RelNBFNet(BaseNBFNet):
    """NBFNet over the graph of relations; returns node states instead of scores."""

    def __init__(self, input_dim, hidden_dims, num_relation=4, **kwargs):
        """Create one relational conv layer per consecutive pair in ``self.dims``.

        When ``self.concat_hidden`` is set, an MLP is added that maps the
        concatenated hidden states (plus the query features) back to ``input_dim``.
        """
        super().__init__(input_dim, hidden_dims, num_relation, **kwargs)

        self.layers = nn.ModuleList([
            layers.GeneralizedRelationalConv(
                self.dims[i], self.dims[i + 1], num_relation,
                self.dims[0], self.message_func, self.aggregate_func, self.layer_norm,
                self.activation, dependent=False)
            for i in range(len(self.dims) - 1)
        ])

        if self.concat_hidden:
            feature_dim = sum(hidden_dims) + input_dim
            self.mlp = nn.Sequential(
                nn.Linear(feature_dim, feature_dim),
                nn.ReLU(),
                nn.Linear(feature_dim, input_dim),
            )

    def bellmanford(self, data, h_index, separate_grad=False):
        """Run the layer stack starting from an indicator boundary at ``h_index``.

        Returns a dict with ``"node_feature"`` (final node states) and
        ``"edge_weights"`` (the edge-weight tensor used by each layer).
        ``separate_grad`` is accepted but not used here.
        """
        device = h_index.device
        batch_size = len(h_index)
        num_nodes = data.num_nodes

        # initialize initial nodes (relations of interest in the batch) with all ones
        query = torch.ones(h_index.shape[0], self.dims[0], device=device, dtype=torch.float)
        index = h_index.unsqueeze(-1).expand_as(query)

        # initial (boundary) condition - all node states start as zeros
        boundary = torch.zeros(batch_size, num_nodes, self.dims[0], device=device)
        # indicator function: the scatter puts ones as init features of the source (index) nodes
        boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))

        size = (num_nodes, num_nodes)
        edge_weight = torch.ones(data.num_edges, device=device)

        hiddens, edge_weights = [], []
        state = boundary
        for layer in self.layers:
            # Bellman-Ford iteration: the original boundary condition is sent along with the updated node states
            hidden = layer(state, query, boundary, data.edge_index, data.edge_type, size, edge_weight)
            if self.short_cut and hidden.shape == state.shape:
                # residual connection
                hidden = hidden + state
            hiddens.append(hidden)
            edge_weights.append(edge_weight)
            state = hidden

        # original query (relation type) embeddings
        node_query = query.unsqueeze(1).expand(-1, num_nodes, -1)  # (batch_size, num_nodes, input_dim)
        if self.concat_hidden:
            output = self.mlp(torch.cat(hiddens + [node_query], dim=-1))
        else:
            output = hiddens[-1]

        return {
            "node_feature": output,
            "edge_weights": edge_weights,
        }

    def forward(self, rel_graph, query):
        """Return updated node representations (that are in fact relations).

        The result is the ``"node_feature"`` entry of ``bellmanford``,
        documented upstream as (batch_size, num_nodes, hidden_dim).
        """
        return self.bellmanford(rel_graph, h_index=query)["node_feature"]
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class EntityNBFNet(BaseNBFNet):
    """NBFNet over the entity graph, conditioned on externally supplied relation features."""

    def __init__(self, input_dim, hidden_dims, num_relation=1, **kwargs):
        """Create the conv layers and the scoring MLP.

        ``num_relation`` defaults to a dummy 1 because the layers take their
        relation features from the outside (``project_relations=True``).
        """
        super().__init__(input_dim, hidden_dims, num_relation, **kwargs)

        self.layers = nn.ModuleList([
            layers.GeneralizedRelationalConv(
                self.dims[i], self.dims[i + 1], num_relation,
                self.dims[0], self.message_func, self.aggregate_func, self.layer_norm,
                self.activation, dependent=False, project_relations=True)
            for i in range(len(self.dims) - 1)
        ])

        hidden_width = sum(hidden_dims) if self.concat_hidden else hidden_dims[-1]
        feature_dim = hidden_width + input_dim

        # (num_mlp_layers - 1) hidden Linear+ReLU pairs followed by a projection to one logit
        stack = []
        for _ in range(self.num_mlp_layers - 1):
            stack.append(nn.Linear(feature_dim, feature_dim))
            stack.append(nn.ReLU())
        stack.append(nn.Linear(feature_dim, 1))
        self.mlp = nn.Sequential(*stack)

    def bellmanford(self, data, h_index, r_index, separate_grad=False):
        """Run the layer stack from the head nodes, seeded with their query relation features.

        Returns a dict with ``"node_feature"`` (hidden states concatenated with
        the query features on the last axis) and ``"edge_weights"`` (the
        edge-weight tensor used by each layer).
        """
        device = h_index.device
        batch_size = len(r_index)
        num_nodes = data.num_nodes

        # initialize queries (relation types of the given triples)
        query = self.query[torch.arange(batch_size, device=r_index.device), r_index]
        index = h_index.unsqueeze(-1).expand_as(query)

        # initial (boundary) condition - all node states start as zeros
        boundary = torch.zeros(batch_size, num_nodes, self.dims[0], device=device)
        # the scatter puts query (relation) embeddings as init features of the source (index) nodes
        boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))

        size = (num_nodes, num_nodes)
        edge_weight = torch.ones(data.num_edges, device=device)

        hiddens, edge_weights = [], []
        state = boundary
        for layer in self.layers:
            if separate_grad:
                # for visualization: give every layer its own differentiable copy of the edge weights
                edge_weight = edge_weight.clone().requires_grad_()

            # Bellman-Ford iteration: the original boundary condition is sent along with the updated node states
            hidden = layer(state, query, boundary, data.edge_index, data.edge_type, size, edge_weight)
            if self.short_cut and hidden.shape == state.shape:
                # residual connection
                hidden = hidden + state
            hiddens.append(hidden)
            edge_weights.append(edge_weight)
            state = hidden

        # original query (relation type) embeddings
        node_query = query.unsqueeze(1).expand(-1, num_nodes, -1)  # (batch_size, num_nodes, input_dim)
        kept = hiddens if self.concat_hidden else [hiddens[-1]]
        output = torch.cat(kept + [node_query], dim=-1)

        return {
            "node_feature": output,
            "edge_weights": edge_weights,
        }

    def forward(self, data, relation_representations, batch):
        """Score every (head, tail, relation) triple in ``batch``.

        The returned scores have the same shape as the head-index tensor
        unbound from ``batch``.
        """
        h_index, t_index, r_index = batch.unbind(-1)

        # initial query representations are those from the relation graph
        self.query = relation_representations

        # initialize relations in each NBFNet layer (with a unique projection internally)
        for layer in self.layers:
            layer.relation = relation_representations

        if self.training:
            # Edge dropout in the training mode:
            # remove immediate edges (head, relation, tail) from the edge_index and edge_types
            # to make the NBFNet iteration learn non-trivial paths
            data = self.remove_easy_edges(data, h_index, t_index, r_index)

        shape = h_index.shape
        # turn all triples in a batch into a tail prediction mode
        h_index, t_index, r_index = self.negative_sample_to_tail(
            h_index, t_index, r_index, num_direct_rel=data.num_relations // 2)
        assert (h_index[:, [0]] == h_index).all()
        assert (r_index[:, [0]] == r_index).all()

        # message passing and updated node representations
        # NOTE(review): the gather on dim 1 below implies the node axis is dim 1 - confirm layout
        feature = self.bellmanford(data, h_index[:, 0], r_index[:, 0])["node_feature"]

        # extract representations of tail entities from the updated node states
        index = t_index.unsqueeze(-1).expand(-1, -1, feature.shape[-1])
        feature = feature.gather(1, index)  # (batch_size, num_negative + 1, feature_dim)

        # probability logit for each tail node in the batch
        # (batch_size, num_negative + 1, dim) -> (batch_size, num_negative + 1)
        score = self.mlp(feature).squeeze(-1)
        return score.view(shape)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
|
ultra/rspmm/rspmm.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import torch.backends.openmp
|
| 5 |
+
from torch import autograd
|
| 6 |
+
from torch.utils import cpp_extension
|
| 7 |
+
|
| 8 |
+
# Handle to this very module; `generalized_rspmm` uses it to look up the RSPMM*Function classes by name.
module = sys.modules[__name__]
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RSPMMAddMulFunction(autograd.Function):
    """Autograd wrapper around the compiled ``rspmm_add_mul`` forward / backward kernels."""

    @staticmethod
    def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
        """Check that the edges are sorted, run the forward kernel, and save tensors for backward."""
        src, dst = edge_index
        sort_key = src * (dst.max() + 1) + dst
        assert (sort_key.diff() >= 0).all(), "Expect sorted `edge_index`"

        on_cuda = input.device.type == "cuda"
        kernel = rspmm.rspmm_add_mul_forward_cuda if on_cuda else rspmm.rspmm_add_mul_forward_cpu
        output = kernel(edge_index, edge_type, edge_weight, relation, input)
        ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        """Return gradients for edge_weight, relation and input; none for edge_index / edge_type."""
        on_cuda = output_grad.device.type == "cuda"
        kernel = rspmm.rspmm_add_mul_backward_cuda if on_cuda else rspmm.rspmm_add_mul_backward_cpu
        weight_grad, relation_grad, input_grad = kernel(*ctx.saved_tensors, output_grad)
        return None, None, weight_grad, relation_grad, input_grad
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class RSPMMMinMulFunction(autograd.Function):
    """Autograd wrapper around the compiled ``rspmm_min_mul`` forward / backward kernels."""

    @staticmethod
    def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
        """Check that the edges are sorted, run the forward kernel, and save tensors for backward."""
        src, dst = edge_index
        sort_key = src * (dst.max() + 1) + dst
        assert (sort_key.diff() >= 0).all(), "Expect sorted `edge_index`"

        on_cuda = input.device.type == "cuda"
        kernel = rspmm.rspmm_min_mul_forward_cuda if on_cuda else rspmm.rspmm_min_mul_forward_cpu
        output = kernel(edge_index, edge_type, edge_weight, relation, input)
        ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        """Return gradients for edge_weight, relation and input; none for edge_index / edge_type."""
        on_cuda = output_grad.device.type == "cuda"
        kernel = rspmm.rspmm_min_mul_backward_cuda if on_cuda else rspmm.rspmm_min_mul_backward_cpu
        weight_grad, relation_grad, input_grad = kernel(*ctx.saved_tensors, output_grad)
        return None, None, weight_grad, relation_grad, input_grad
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class RSPMMMaxMulFunction(autograd.Function):
    """Autograd wrapper around the compiled ``rspmm_max_mul`` forward / backward kernels."""

    @staticmethod
    def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
        """Check that the edges are sorted, run the forward kernel, and save tensors for backward."""
        src, dst = edge_index
        sort_key = src * (dst.max() + 1) + dst
        assert (sort_key.diff() >= 0).all(), "Expect sorted `edge_index`"

        on_cuda = input.device.type == "cuda"
        kernel = rspmm.rspmm_max_mul_forward_cuda if on_cuda else rspmm.rspmm_max_mul_forward_cpu
        output = kernel(edge_index, edge_type, edge_weight, relation, input)
        ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        """Return gradients for edge_weight, relation and input; none for edge_index / edge_type."""
        on_cuda = output_grad.device.type == "cuda"
        kernel = rspmm.rspmm_max_mul_backward_cuda if on_cuda else rspmm.rspmm_max_mul_backward_cpu
        weight_grad, relation_grad, input_grad = kernel(*ctx.saved_tensors, output_grad)
        return None, None, weight_grad, relation_grad, input_grad
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class RSPMMAddAddFunction(autograd.Function):
    """Autograd wrapper around the compiled ``rspmm_add_add`` forward / backward kernels."""

    @staticmethod
    def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
        """Check that the edges are sorted, run the forward kernel, and save tensors for backward."""
        src, dst = edge_index
        sort_key = src * (dst.max() + 1) + dst
        assert (sort_key.diff() >= 0).all(), "Expect sorted `edge_index`"

        on_cuda = input.device.type == "cuda"
        kernel = rspmm.rspmm_add_add_forward_cuda if on_cuda else rspmm.rspmm_add_add_forward_cpu
        output = kernel(edge_index, edge_type, edge_weight, relation, input)
        ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        """Return gradients for edge_weight, relation and input; none for edge_index / edge_type."""
        on_cuda = output_grad.device.type == "cuda"
        kernel = rspmm.rspmm_add_add_backward_cuda if on_cuda else rspmm.rspmm_add_add_backward_cpu
        weight_grad, relation_grad, input_grad = kernel(*ctx.saved_tensors, output_grad)
        return None, None, weight_grad, relation_grad, input_grad
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class RSPMMMinAddFunction(autograd.Function):
    """Autograd wrapper around the compiled ``rspmm_min_add`` forward / backward kernels."""

    @staticmethod
    def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
        """Check that the edges are sorted, run the forward kernel, and save tensors for backward."""
        src, dst = edge_index
        sort_key = src * (dst.max() + 1) + dst
        assert (sort_key.diff() >= 0).all(), "Expect sorted `edge_index`"

        on_cuda = input.device.type == "cuda"
        kernel = rspmm.rspmm_min_add_forward_cuda if on_cuda else rspmm.rspmm_min_add_forward_cpu
        output = kernel(edge_index, edge_type, edge_weight, relation, input)
        ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        """Return gradients for edge_weight, relation and input; none for edge_index / edge_type."""
        on_cuda = output_grad.device.type == "cuda"
        kernel = rspmm.rspmm_min_add_backward_cuda if on_cuda else rspmm.rspmm_min_add_backward_cpu
        weight_grad, relation_grad, input_grad = kernel(*ctx.saved_tensors, output_grad)
        return None, None, weight_grad, relation_grad, input_grad
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class RSPMMMaxAddFunction(autograd.Function):
    """Autograd wrapper around the compiled ``rspmm_max_add`` forward / backward kernels."""

    @staticmethod
    def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
        """Check that the edges are sorted, run the forward kernel, and save tensors for backward."""
        src, dst = edge_index
        sort_key = src * (dst.max() + 1) + dst
        assert (sort_key.diff() >= 0).all(), "Expect sorted `edge_index`"

        on_cuda = input.device.type == "cuda"
        kernel = rspmm.rspmm_max_add_forward_cuda if on_cuda else rspmm.rspmm_max_add_forward_cpu
        output = kernel(edge_index, edge_type, edge_weight, relation, input)
        ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        """Return gradients for edge_weight, relation and input; none for edge_index / edge_type."""
        on_cuda = output_grad.device.type == "cuda"
        kernel = rspmm.rspmm_max_add_backward_cuda if on_cuda else rspmm.rspmm_max_add_backward_cpu
        weight_grad, relation_grad, input_grad = kernel(*ctx.saved_tensors, output_grad)
        return None, None, weight_grad, relation_grad, input_grad
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="add", mul="mul"):
    """Dispatch to the RSPMM autograd function matching ``sum`` and ``mul``.

    The function class is looked up on this module by the name
    ``RSPMM<Sum><Mul>Function``. Edges are sorted by (node_in, node_out)
    before the call, since the functions expect a sorted ``edge_index``.

    Raises:
        ValueError: if no function exists for the requested combination.
    """
    class_name = "RSPMM%s%sFunction" % (sum.capitalize(), mul.capitalize())
    function_cls = getattr(module, class_name, None)
    if function_cls is None:
        raise ValueError("No generalized rspmm implementation found for summation `%s` and multiplication `%s`"
                         % (sum, mul))

    src, dst = edge_index
    # a single integer key that orders edges by source node first, then by target node
    order = (src * (dst.max() + 1) + dst).argsort()

    return function_cls.apply(edge_index[:, order], edge_type[order], edge_weight[order], relation, input)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def load_extension(name, sources, extra_cflags=None, extra_cuda_cflags=None, **kwargs):
    """JIT-compile and load a PyTorch C++/CUDA extension.

    Compilation flags are chosen automatically when not given:

    * ``extra_cflags`` defaults to ``-Ofast`` plus the OpenMP or native
      parallel-backend define, depending on ``torch.backends.openmp``.
    * When ``extra_cuda_cflags`` is ``None`` and CUDA is available, ``-O3``
      is used for nvcc and ``-DCUDA_OP`` is added to the C++ flags; without
      CUDA, the CUDA source files are dropped from ``sources`` instead.

    A caller-supplied ``extra_cflags`` list is copied and never modified.
    Remaining keyword arguments are forwarded to ``cpp_extension.load``, whose
    result is returned.
    """
    if extra_cflags is None:
        extra_cflags = ["-Ofast"]
        if torch.backends.openmp.is_available():
            extra_cflags += ["-fopenmp", "-DAT_PARALLEL_OPENMP"]
        else:
            extra_cflags.append("-DAT_PARALLEL_NATIVE")
    else:
        # work on a copy: "-DCUDA_OP" may be appended below and must not leak into the caller's list
        extra_cflags = list(extra_cflags)

    if extra_cuda_cflags is None:
        if torch.cuda.is_available():
            extra_cuda_cflags = ["-O3"]
            extra_cflags.append("-DCUDA_OP")
        else:
            # CPU-only build: keep only the sources that are not CUDA files
            sources = [source for source in sources if not cpp_extension._is_cuda_file(source)]

    return cpp_extension.load(name, sources, extra_cflags, extra_cuda_cflags, **kwargs)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# Build and load the rspmm extension when this module is imported; the sources live in ./source.
print("Load rspmm extension. This may take a while...")
path = os.path.join(os.path.dirname(__file__), "source")
rspmm = load_extension(
    "rspmm",
    [os.path.join(path, source_file) for source_file in ("rspmm.cpp", "rspmm.cu")],
)
|
ultra/rspmm/source/operator.cuh
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <limits>
|
| 4 |
+
|
| 5 |
+
#ifdef __CUDA_ARCH__
|
| 6 |
+
#define HOST_DEVICE __host__ __device__
|
| 7 |
+
#else
|
| 8 |
+
#define HOST_DEVICE
|
| 9 |
+
#endif
|
| 10 |
+
|
| 11 |
+
namespace at {
|
| 12 |
+
|
| 13 |
+
template <class scalar_t>
|
| 14 |
+
struct BinaryAdd {
|
| 15 |
+
HOST_DEVICE static scalar_t forward(scalar_t x, scalar_t y) {
|
| 16 |
+
return x + y;
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
HOST_DEVICE static scalar_t backward_lhs(scalar_t x, scalar_t y) {
|
| 20 |
+
return 1;
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
HOST_DEVICE static scalar_t backward_rhs(scalar_t x, scalar_t y) {
|
| 24 |
+
return 1;
|
| 25 |
+
}
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
template <class scalar_t>
|
| 29 |
+
struct BinaryMul {
|
| 30 |
+
HOST_DEVICE static scalar_t forward(scalar_t x, scalar_t y) {
|
| 31 |
+
return x * y;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
HOST_DEVICE static scalar_t backward_lhs(scalar_t x, scalar_t y) {
|
| 35 |
+
return y;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
HOST_DEVICE static scalar_t backward_rhs(scalar_t x, scalar_t y) {
|
| 39 |
+
return x;
|
| 40 |
+
}
|
| 41 |
+
};
|
| 42 |
+
|
| 43 |
+
template <class scalar_t>
|
| 44 |
+
struct NaryAdd {
|
| 45 |
+
HOST_DEVICE static scalar_t forward(scalar_t result, scalar_t x) {
|
| 46 |
+
return result + x;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
HOST_DEVICE static scalar_t backward(scalar_t result, scalar_t x) {
|
| 50 |
+
return 1;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
static constexpr scalar_t zero = 0;
|
| 54 |
+
};
|
| 55 |
+
|
| 56 |
+
template <class scalar_t>
|
| 57 |
+
struct NaryMin {
|
| 58 |
+
HOST_DEVICE static scalar_t forward(scalar_t result, scalar_t x) {
|
| 59 |
+
return result < x ? result : x;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
HOST_DEVICE static scalar_t backward(scalar_t result, scalar_t x) {
|
| 63 |
+
return result == x ? 1 : 0;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
static constexpr scalar_t zero = std::numeric_limits<scalar_t>::max();
|
| 67 |
+
};
|
| 68 |
+
|
| 69 |
+
template <class scalar_t>
|
| 70 |
+
struct NaryMax {
|
| 71 |
+
HOST_DEVICE static scalar_t forward(scalar_t result, scalar_t x) {
|
| 72 |
+
return result > x ? result : x;
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
HOST_DEVICE static scalar_t backward(scalar_t result, scalar_t x) {
|
| 76 |
+
return result == x ? 1 : 0;
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
static constexpr scalar_t zero = std::numeric_limits<scalar_t>::lowest();
|
| 80 |
+
};
|
| 81 |
+
|
| 82 |
+
} // namespace at
|
ultra/rspmm/source/rspmm.cpp
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <mutex>
|
| 2 |
+
|
| 3 |
+
#include <ATen/Parallel.h>
|
| 4 |
+
|
| 5 |
+
#include "operator.cuh"
|
| 6 |
+
#include "rspmm.h"
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
|
| 10 |
+
// In PyTorch 1.4.0, parallel_for depends on some functions from at::internal in ATen/Parallel.h
|
| 11 |
+
// which are not explicitly included
|
| 12 |
+
// This is fixed in some new PyTorch release
|
| 13 |
+
using namespace at::internal;
|
| 14 |
+
|
| 15 |
+
// Validates the arguments shared by every rspmm forward entry point.
// Enforced layout:
//   edge_index  : 2-D, first dimension of size 2 (one column per edge)
//   edge_type   : 1-D, one entry per edge, same dtype as edge_index
//   edge_weight : 1-D, one entry per edge
//   relation    : 2-D, second dimension equal to input's second dimension
//   input       : 2-D
// edge_weight, relation and input must share one dtype.
// `c` names the calling function in the error raised by a failed check.
void rspmm_forward_check(CheckedFrom c, const TensorArg &edge_index_arg, const TensorArg &edge_type_arg,
                         const TensorArg &edge_weight_arg, const TensorArg &relation_arg,
                         const TensorArg &input_arg) {
    // --- ranks ---
    checkDim(c, edge_index_arg, 2);
    checkDim(c, edge_type_arg, 1);
    checkDim(c, edge_weight_arg, 1);
    checkDim(c, relation_arg, 2);
    checkDim(c, input_arg, 2);

    // --- dtypes ---
    checkSameType(c, edge_index_arg, edge_type_arg);
    checkAllSameType(c, {edge_weight_arg, relation_arg, input_arg});

    // --- sizes ---
    checkSize(c, edge_index_arg, 0, 2);
    const int64_t num_edge = edge_index_arg->size(1);
    checkSize(c, edge_type_arg, {num_edge});
    checkSize(c, edge_weight_arg, {num_edge});
    const int64_t feature_dim = input_arg->size(1);
    checkSize(c, relation_arg, 1, feature_dim);
}
|
| 29 |
+
|
| 30 |
+
// Validates the arguments shared by every rspmm backward entry point.
// Runs all forward checks first, then additionally requires that
//   output      : 2-D, second dimension equal to input's second dimension
//   output_grad : same size as output
// and that input, output and output_grad share one dtype.
void rspmm_backward_check(CheckedFrom c, const TensorArg &edge_index_arg, const TensorArg &edge_type_arg,
                          const TensorArg &edge_weight_arg, const TensorArg &relation_arg,
                          const TensorArg &input_arg, const TensorArg &output_arg,
                          const TensorArg &output_grad_arg) {
    rspmm_forward_check(c, edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg);

    checkDim(c, output_arg, 2);
    checkSameSize(c, output_arg, output_grad_arg);
    checkAllSameType(c, {input_arg, output_arg, output_grad_arg});

    const int64_t feature_dim = input_arg->size(1);
    checkSize(c, output_arg, 1, feature_dim);
}
|
| 39 |
+
|
| 40 |
+
// Converts a vector of bucket indices into start offsets (exclusive prefix sum of the
// per-bucket counts): entry i of the result is the number of elements in `index`
// whose value is smaller than i. The result has `size` int64 entries.
Tensor ind2ptr(const Tensor &index, int size) {
    // scatter_add is super slow for int64, due to non-hardware atomic operations,
    // so the occurrences are counted in int32 and widened afterwards.
    const auto counter_options = index.options().dtype(at::ScalarType::Int);
    Tensor count = at::zeros({size}, counter_options);
    const Tensor ones = at::ones(index.sizes(), counter_options);
    count.scatter_add_(0, index, ones);

    const Tensor count_long = count.toType(at::ScalarType::Long);
    // inclusive cumulative sum minus the own count == exclusive cumulative sum
    return count_long.cumsum(0) - count_long;
}
|
| 49 |
+
|
| 50 |
+
// CPU forward pass of the generalized relational SpMM.
// For every row and feature d:
//     output[row, d] = NaryOp-fold, starting at NaryOp::zero, over the edges of `row` of
//                      weight[e] * BinaryOp(relation[layer_ind[e], d], input[col_ind[e], d])
// The edges of `row` are the positions [row_ptr[row], row_ptr[row + 1]); the last row ends at nnz.
// Rows are distributed over threads with parallel_for; each row writes only its own output slice.
template <class scalar_t, class NaryOp, class BinaryOp>
void rspmm_forward_out_cpu(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind,
                           const scalar_t *weight, const scalar_t *relation, const scalar_t *input,
                           scalar_t *output,
                           int64_t num_row, int64_t nnz, int64_t dim) {
    parallel_for(0, num_row, 0, [&](int64_t row_begin, int64_t row_end) {
        for (int64_t row = row_begin; row < row_end; ++row) {
            scalar_t *out_row = output + row * dim;
            for (int64_t d = 0; d < dim; ++d)
                out_row[d] = NaryOp::zero;

            const int64_t edge_begin = row_ptr[row];
            const int64_t edge_end = (row + 1 < num_row) ? row_ptr[row + 1] : nnz;
            for (int64_t edge = edge_begin; edge < edge_end; ++edge) {
                const scalar_t *rel_row = relation + layer_ind[edge] * dim;
                const scalar_t *in_row = input + col_ind[edge] * dim;
                const scalar_t w = weight[edge];
                for (int64_t d = 0; d < dim; ++d) {
                    const scalar_t message = w * BinaryOp::forward(rel_row[d], in_row[d]);
                    out_row[d] = NaryOp::forward(out_row[d], message);
                }
            }
        }
    });
}
|
| 76 |
+
|
| 77 |
+
// CPU backward pass of the generalized relational SpMM.
// With x = BinaryOp(rel, in) and y = w * x for one edge and feature d, it applies the chain rule
//     weight_grad[e]          = sum_d output_grad * dout/dy * x
//     relation_grad[layer, d] += output_grad * dout/dy * w * dx/drel
//     input_grad[col, d]      += output_grad * dout/dy * w * dx/din
// Rows are processed in parallel. Several rows can touch the same relation / input element,
// so every accumulation into relation_grad and input_grad is guarded by the mutex with the
// same flat index. weight_grad[e] is written by exactly one row and needs no lock.
template <class scalar_t, class NaryOp, class BinaryOp>
void rspmm_backward_out_cpu(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind,
                            const scalar_t *weight, const scalar_t *relation, const scalar_t *input,
                            const scalar_t *output, const scalar_t *output_grad,
                            scalar_t *weight_grad, scalar_t *relation_grad, scalar_t *input_grad,
                            int64_t num_row, int64_t nnz, int64_t dim,
                            std::vector<std::mutex> &relation_mutex, std::vector<std::mutex> &input_mutex) {
    parallel_for(0, num_row, 0, [&](int64_t row_begin, int64_t row_end) {
        for (int64_t row = row_begin; row < row_end; ++row) {
            const int64_t out_base = row * dim;
            const int64_t edge_begin = row_ptr[row];
            const int64_t edge_end = (row + 1 < num_row) ? row_ptr[row + 1] : nnz;

            for (int64_t edge = edge_begin; edge < edge_end; ++edge) {
                const int64_t in_base = col_ind[edge] * dim;
                const int64_t rel_base = layer_ind[edge] * dim;
                const scalar_t w = weight[edge];
                scalar_t w_grad = 0;

                for (int64_t d = 0; d < dim; ++d) {
                    const int64_t rel_idx = rel_base + d;
                    const int64_t in_idx = in_base + d;
                    const scalar_t rel = relation[rel_idx];
                    const scalar_t in = input[in_idx];
                    const scalar_t out = output[out_base + d];
                    const scalar_t out_grad = output_grad[out_base + d];

                    const scalar_t x = BinaryOp::forward(rel, in);
                    const scalar_t y = w * x;
                    const scalar_t dx_drel = BinaryOp::backward_lhs(rel, in);
                    const scalar_t dx_din = BinaryOp::backward_rhs(rel, in);
                    // gradient arriving at y: output_grad * dout/dy
                    const scalar_t y_grad = out_grad * NaryOp::backward(out, y);

                    w_grad += y_grad * x;
                    {
                        std::lock_guard<std::mutex> guard(relation_mutex[rel_idx]);
                        relation_grad[rel_idx] += y_grad * w * dx_drel;
                    }
                    {
                        std::lock_guard<std::mutex> guard(input_mutex[in_idx]);
                        input_grad[in_idx] += y_grad * w * dx_din;
                    }
                }
                weight_grad[edge] = w_grad;
            }
        }
    });
}
|
| 120 |
+
|
| 121 |
+
template <template<class> class NaryOp, template<class> class BinaryOp>
|
| 122 |
+
Tensor rspmm_forward_cpu(const Tensor &edge_index_, const Tensor &edge_type_, const Tensor &edge_weight_,
|
| 123 |
+
const Tensor &relation_, const Tensor &input_) {
|
| 124 |
+
constexpr const char *fn_name = "rspmm_forward_cpu";
|
| 125 |
+
TensorArg edge_index_arg(edge_index_, "edge_index", 1), edge_type_arg(edge_type_, "edge_type", 2),
|
| 126 |
+
edge_weight_arg(edge_weight_, "edge_weight", 3), relation_arg(relation_, "relation", 4),
|
| 127 |
+
input_arg(input_, "input", 5);
|
| 128 |
+
|
| 129 |
+
rspmm_forward_check(fn_name, edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg);
|
| 130 |
+
checkDeviceType(fn_name, {edge_index_, edge_type_, edge_weight_, relation_, input_}, kCPU);
|
| 131 |
+
|
| 132 |
+
const Tensor edge_index = edge_index_.contiguous();
|
| 133 |
+
const Tensor edge_type = edge_type_.contiguous();
|
| 134 |
+
const Tensor edge_weight = edge_weight_.contiguous();
|
| 135 |
+
const Tensor relation = relation_.contiguous();
|
| 136 |
+
const Tensor input = input_.contiguous();
|
| 137 |
+
|
| 138 |
+
int64_t nnz = edge_index.size(0);
|
| 139 |
+
int64_t num_row = input.size(0);
|
| 140 |
+
int64_t dim = input.size(1);
|
| 141 |
+
Tensor output = at::empty({num_row, dim}, input.options());
|
| 142 |
+
|
| 143 |
+
Tensor row_ind = edge_index.select(0, 0);
|
| 144 |
+
Tensor row_ptr = ind2ptr(row_ind, num_row);
|
| 145 |
+
Tensor col_ind = edge_index.select(0, 1);
|
| 146 |
+
Tensor layer_ind = edge_type;
|
| 147 |
+
|
| 148 |
+
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_forward_cpu", [&] {
|
| 149 |
+
rspmm_forward_out_cpu<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>(
|
| 150 |
+
row_ptr.data_ptr<int64_t>(),
|
| 151 |
+
col_ind.data_ptr<int64_t>(),
|
| 152 |
+
layer_ind.data_ptr<int64_t>(),
|
| 153 |
+
edge_weight.data_ptr<scalar_t>(),
|
| 154 |
+
relation.data_ptr<scalar_t>(),
|
| 155 |
+
input.data_ptr<scalar_t>(),
|
| 156 |
+
output.data_ptr<scalar_t>(),
|
| 157 |
+
num_row, nnz, dim
|
| 158 |
+
);
|
| 159 |
+
});
|
| 160 |
+
|
| 161 |
+
return output;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
template <template<class> class NaryOp, template<class> class BinaryOp>
|
| 165 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_backward_cpu(
|
| 166 |
+
const Tensor &edge_index_, const Tensor &edge_type_, const Tensor &edge_weight_,
|
| 167 |
+
const Tensor &relation_, const Tensor &input_, const Tensor &output_, const Tensor &output_grad_) {
|
| 168 |
+
constexpr const char *fn_name = "rspmm_backward_cpu";
|
| 169 |
+
TensorArg edge_index_arg(edge_index_, "edge_index", 1), edge_type_arg(edge_type_, "edge_type", 2),
|
| 170 |
+
edge_weight_arg(edge_weight_, "edge_weight", 3), relation_arg(relation_, "relation", 4),
|
| 171 |
+
input_arg(input_, "input", 5), output_arg(output_, "output", 6),
|
| 172 |
+
output_grad_arg(output_grad_, "output_grad", 7);
|
| 173 |
+
|
| 174 |
+
rspmm_backward_check(fn_name, edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg,
|
| 175 |
+
output_arg, output_grad_arg);
|
| 176 |
+
checkDeviceType(fn_name, {edge_index_, edge_type_, edge_weight_, relation_, input_, output_, output_grad_}, kCPU);
|
| 177 |
+
|
| 178 |
+
const Tensor edge_index = edge_index_.contiguous();
|
| 179 |
+
const Tensor edge_type = edge_type_.contiguous();
|
| 180 |
+
const Tensor edge_weight = edge_weight_.contiguous();
|
| 181 |
+
const Tensor relation = relation_.contiguous();
|
| 182 |
+
const Tensor input = input_.contiguous();
|
| 183 |
+
const Tensor output = output_.contiguous();
|
| 184 |
+
const Tensor output_grad = output_grad_.contiguous();
|
| 185 |
+
|
| 186 |
+
int64_t nnz = edge_index.size(0);
|
| 187 |
+
int64_t num_row = input.size(0);
|
| 188 |
+
int64_t dim = input.size(1);
|
| 189 |
+
Tensor weight_grad = at::zeros_like(edge_weight);
|
| 190 |
+
Tensor relation_grad = at::zeros_like(relation);
|
| 191 |
+
Tensor input_grad = at::zeros_like(input);
|
| 192 |
+
|
| 193 |
+
Tensor row_ind = edge_index.select(0, 0);
|
| 194 |
+
Tensor row_ptr = ind2ptr(row_ind, num_row);
|
| 195 |
+
Tensor col_ind = edge_index.select(0, 1);
|
| 196 |
+
Tensor layer_ind = edge_type;
|
| 197 |
+
std::vector<std::mutex> relation_mutex(relation.numel());
|
| 198 |
+
std::vector<std::mutex> input_mutex(input.numel());
|
| 199 |
+
|
| 200 |
+
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_backward_cpu", [&] {
|
| 201 |
+
rspmm_backward_out_cpu<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>(
|
| 202 |
+
row_ptr.data_ptr<int64_t>(),
|
| 203 |
+
col_ind.data_ptr<int64_t>(),
|
| 204 |
+
layer_ind.data_ptr<int64_t>(),
|
| 205 |
+
edge_weight.data_ptr<scalar_t>(),
|
| 206 |
+
relation.data_ptr<scalar_t>(),
|
| 207 |
+
input.data_ptr<scalar_t>(),
|
| 208 |
+
output.data_ptr<scalar_t>(),
|
| 209 |
+
output_grad.data_ptr<scalar_t>(),
|
| 210 |
+
weight_grad.data_ptr<scalar_t>(),
|
| 211 |
+
relation_grad.data_ptr<scalar_t>(),
|
| 212 |
+
input_grad.data_ptr<scalar_t>(),
|
| 213 |
+
num_row, nnz, dim,
|
| 214 |
+
relation_mutex, input_mutex
|
| 215 |
+
);
|
| 216 |
+
});
|
| 217 |
+
|
| 218 |
+
return std::make_tuple(weight_grad, relation_grad, input_grad);
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
#define DECLARE_FORWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \
|
| 222 |
+
Tensor rspmm_##ADD##_##MUL##_forward_cpu( \
|
| 223 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, \
|
| 224 |
+
const Tensor &relation, const Tensor &input) { \
|
| 225 |
+
return rspmm_forward_cpu<NARYOP, BINARYOP>(edge_index, edge_type, edge_weight, relation, input); \
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
#define DECLARE_BACKWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \
|
| 229 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_##ADD##_##MUL##_backward_cpu( \
|
| 230 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, \
|
| 231 |
+
const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad) { \
|
| 232 |
+
return rspmm_backward_cpu<NARYOP, BINARYOP>(edge_index, edge_type, edge_weight, relation, input, \
|
| 233 |
+
output, output_grad); \
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
DECLARE_FORWARD_IMPL(add, mul, NaryAdd, BinaryMul)
|
| 237 |
+
DECLARE_BACKWARD_IMPL(add, mul, NaryAdd, BinaryMul)
|
| 238 |
+
|
| 239 |
+
DECLARE_FORWARD_IMPL(min, mul, NaryMin, BinaryMul)
|
| 240 |
+
DECLARE_BACKWARD_IMPL(min, mul, NaryMin, BinaryMul)
|
| 241 |
+
|
| 242 |
+
DECLARE_FORWARD_IMPL(max, mul, NaryMax, BinaryMul)
|
| 243 |
+
DECLARE_BACKWARD_IMPL(max, mul, NaryMax, BinaryMul)
|
| 244 |
+
|
| 245 |
+
DECLARE_FORWARD_IMPL(add, add, NaryAdd, BinaryAdd)
|
| 246 |
+
DECLARE_BACKWARD_IMPL(add, add, NaryAdd, BinaryAdd)
|
| 247 |
+
|
| 248 |
+
DECLARE_FORWARD_IMPL(min, add, NaryMin, BinaryAdd)
|
| 249 |
+
DECLARE_BACKWARD_IMPL(min, add, NaryMin, BinaryAdd)
|
| 250 |
+
|
| 251 |
+
DECLARE_FORWARD_IMPL(max, add, NaryMax, BinaryAdd)
|
| 252 |
+
DECLARE_BACKWARD_IMPL(max, add, NaryMax, BinaryAdd)
|
| 253 |
+
|
| 254 |
+
} // namespace at
|
| 255 |
+
|
| 256 |
+
// Exposes `at::FN` to Python under the identical name "FN".
#define RSPMM_DEF(FN) m.def(#FN, &at::FN)

// Python bindings of the extension. The CPU operators are always registered;
// the CUDA operators only when the extension is compiled with CUDA_OP defined.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    RSPMM_DEF(rspmm_add_mul_forward_cpu);
    RSPMM_DEF(rspmm_add_mul_backward_cpu);
    RSPMM_DEF(rspmm_min_mul_forward_cpu);
    RSPMM_DEF(rspmm_min_mul_backward_cpu);
    RSPMM_DEF(rspmm_max_mul_forward_cpu);
    RSPMM_DEF(rspmm_max_mul_backward_cpu);
    RSPMM_DEF(rspmm_add_add_forward_cpu);
    RSPMM_DEF(rspmm_add_add_backward_cpu);
    RSPMM_DEF(rspmm_min_add_forward_cpu);
    RSPMM_DEF(rspmm_min_add_backward_cpu);
    RSPMM_DEF(rspmm_max_add_forward_cpu);
    RSPMM_DEF(rspmm_max_add_backward_cpu);
#ifdef CUDA_OP
    RSPMM_DEF(rspmm_add_mul_forward_cuda);
    RSPMM_DEF(rspmm_add_mul_backward_cuda);
    RSPMM_DEF(rspmm_min_mul_forward_cuda);
    RSPMM_DEF(rspmm_min_mul_backward_cuda);
    RSPMM_DEF(rspmm_max_mul_forward_cuda);
    RSPMM_DEF(rspmm_max_mul_backward_cuda);
    RSPMM_DEF(rspmm_add_add_forward_cuda);
    RSPMM_DEF(rspmm_add_add_backward_cuda);
    RSPMM_DEF(rspmm_min_add_forward_cuda);
    RSPMM_DEF(rspmm_min_add_backward_cuda);
    RSPMM_DEF(rspmm_max_add_forward_cuda);
    RSPMM_DEF(rspmm_max_add_backward_cuda);
#endif
}

#undef RSPMM_DEF
|
ultra/rspmm/source/rspmm.cu
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 2 |
+
#include <THC/THCAtomics.cuh>
|
| 3 |
+
|
| 4 |
+
#include "util.cuh"
|
| 5 |
+
#include "operator.cuh"
|
| 6 |
+
#include "rspmm.h"
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
|
| 10 |
+
// Memory & time efficient implementation of generalized spmm
|
| 11 |
+
// Much of the code is inspired by GE-SpMM
|
| 12 |
+
// https://github.com/hgyhungry/ge-spmm
|
| 13 |
+
|
| 14 |
+
namespace {

// Number of feature columns handled by each thread (columns are warpSize apart).
constexpr int kCoarseningFactor = 2;
// Total number of threads per block used by the kernel launches.
constexpr int kThreadPerBlock = 256;

} // namespace anonymous
|
| 20 |
+
|
| 21 |
+
// CUDA forward kernel of the generalized relational SpMM.
// For every row and feature d:
//     output[row, d] = NaryOp-fold, starting at NaryOp::zero, over the edges of `row` of
//                      weight[e] * BinaryOp(relation[layer_ind[e], d], input[col_ind[e], d])
// Indexing used below:
//   * blockDim.x must equal warpSize (asserted); threadIdx.y selects the row within the block,
//     so all threads sharing a threadIdx.y work on the same row.
//   * blockIdx.y selects a chunk of warpSize * kCoarseningFactor feature columns; each thread
//     handles kCoarseningFactor columns that are warpSize apart.
//   * dynamic shared memory holds, per thread of the block, two int64_t and one scalar_t
//     (column index, relation index, weight of a staged edge).
template <class scalar_t, class NaryOp, class BinaryOp>
__global__
void rspmm_forward_out_cuda(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind,
                            const scalar_t *weight, const scalar_t *relation, const scalar_t *input,
                            scalar_t *output,
                            int64_t num_row, int64_t nnz, int64_t dim) {
    // for best optimization, the following code is compiled with constant warpSize
    assert(blockDim.x == warpSize);

    // shared layout: [column indices | relation indices | weights], one slot per thread in each
    // part; every threadIdx.y owns a private window of warpSize slots in each part
    extern __shared__ int64_t buffer[];
    const int64_t part_size = blockDim.y * warpSize;
    const int64_t window = threadIdx.y * warpSize;
    int64_t *col_cache = buffer + window;
    int64_t *layer_cache = buffer + part_size + window;
    scalar_t *weight_cache = reinterpret_cast<scalar_t *>(buffer + 2 * part_size) + window;

    const int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
    if (row >= num_row)
        return;
    const int64_t d_start = blockIdx.y * warpSize * kCoarseningFactor + threadIdx.x;
    const int64_t edge_begin = row_ptr[row];
    const int64_t edge_end = row + 1 < num_row ? row_ptr[row + 1] : nnz;

    scalar_t acc[kCoarseningFactor];
#pragma unroll
    for (int64_t i = 0; i < kCoarseningFactor; i++)
        acc[i] = NaryOp::zero;

    for (int64_t chunk = edge_begin; chunk < edge_end; chunk += warpSize) {
        // stage up to warpSize edges of this row, one edge per thread
        const int64_t edge = chunk + threadIdx.x;
        if (edge < edge_end) {
            col_cache[threadIdx.x] = col_ind[edge];
            layer_cache[threadIdx.x] = layer_ind[edge];
            weight_cache[threadIdx.x] = weight[edge];
        }
        __syncwarp();

        // number of edges actually staged in this chunk
        int64_t num_staged = edge_end - chunk;
        if (num_staged > warpSize)
            num_staged = warpSize;

        // every thread walks all staged edges for its own feature columns
        for (int64_t k = 0; k < num_staged; k++) {
            const int64_t col = col_cache[k];
            const int64_t layer = layer_cache[k];
            const scalar_t w = weight_cache[k];
#pragma unroll
            for (int64_t i = 0; i < kCoarseningFactor; i++) {
                const int64_t d = d_start + i * warpSize;
                if (d >= dim)
                    break;
                const scalar_t message = w * BinaryOp::forward(relation[layer * dim + d], input[col * dim + d]);
                acc[i] = NaryOp::forward(acc[i], message);
            }
        }
        // the staged edges must not be overwritten before everyone is done reading them
        __syncwarp();
    }

#pragma unroll
    for (int64_t i = 0; i < kCoarseningFactor; i++) {
        const int64_t d = d_start + i * warpSize;
        if (d >= dim)
            break;
        output[row * dim + d] = acc[i];
    }
}
|
| 84 |
+
|
| 85 |
+
// CUDA backward kernel of the generalized relational SpMM, producing gradients for the
// edge weights, the relation features and the input features.
// With x = BinaryOp(rel, in) and y = w * x for one edge and feature d:
//     weight_grad[e]          += sum_d output_grad * dout/dy * x
//     relation_grad[layer, d] += output_grad * dout/dy * w * dx/drel
//     input_grad[col, d]      += output_grad * dout/dy * w * dx/din
// All three are accumulated with atomicAdd. Thread / shared-memory indexing is the same as in
// rspmm_forward_out_cuda: blockDim.x must equal warpSize (asserted), threadIdx.y selects the
// row, blockIdx.y selects a chunk of warpSize * kCoarseningFactor feature columns.
template <class scalar_t, class NaryOp, class BinaryOp>
__global__
void rspmm_backward_out_cuda(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind,
                             const scalar_t *weight, const scalar_t *relation, const scalar_t *input,
                             const scalar_t *output, const scalar_t *output_grad,
                             scalar_t *weight_grad, scalar_t *relation_grad, scalar_t *input_grad,
                             int64_t num_row, int64_t nnz, int64_t dim) {
    // for best optimization, the following code is compiled with constant warpSize
    assert(blockDim.x == warpSize);

    // shared layout: [column indices | relation indices | weights], one slot per thread in each
    // part; every threadIdx.y owns a private window of warpSize slots in each part
    extern __shared__ int64_t buffer[];
    const int64_t part_size = blockDim.y * warpSize;
    const int64_t window = threadIdx.y * warpSize;
    int64_t *col_cache = buffer + window;
    int64_t *layer_cache = buffer + part_size + window;
    scalar_t *weight_cache = reinterpret_cast<scalar_t *>(buffer + 2 * part_size) + window;

    const int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
    if (row >= num_row)
        return;
    const int64_t d_start = blockIdx.y * warpSize * kCoarseningFactor + threadIdx.x;
    const int64_t edge_begin = row_ptr[row];
    const int64_t edge_end = row + 1 < num_row ? row_ptr[row + 1] : nnz;

    for (int64_t chunk = edge_begin; chunk < edge_end; chunk += warpSize) {
        // stage up to warpSize edges of this row, one edge per thread
        const int64_t edge = chunk + threadIdx.x;
        if (edge < edge_end) {
            col_cache[threadIdx.x] = col_ind[edge];
            layer_cache[threadIdx.x] = layer_ind[edge];
            weight_cache[threadIdx.x] = weight[edge];
        }
        __syncwarp();

        // number of edges actually staged in this chunk
        int64_t num_staged = edge_end - chunk;
        if (num_staged > warpSize)
            num_staged = warpSize;

        for (int64_t k = 0; k < num_staged; k++) {
            const int64_t col = col_cache[k];
            const int64_t layer = layer_cache[k];
            const scalar_t w = weight_cache[k];
            // partial weight gradient over the feature columns owned by this thread
            scalar_t w_grad = 0;
#pragma unroll
            for (int64_t i = 0; i < kCoarseningFactor; i++) {
                const int64_t d = d_start + i * warpSize;
                if (d >= dim)
                    break;
                const int64_t rel_idx = layer * dim + d;
                const int64_t in_idx = col * dim + d;
                const scalar_t rel = relation[rel_idx];
                const scalar_t in = input[in_idx];
                const scalar_t out = output[row * dim + d];
                const scalar_t out_grad = output_grad[row * dim + d];

                const scalar_t x = BinaryOp::forward(rel, in);
                const scalar_t y = w * x;
                const scalar_t dx_drel = BinaryOp::backward_lhs(rel, in);
                const scalar_t dx_din = BinaryOp::backward_rhs(rel, in);
                // gradient arriving at y: output_grad * dout/dy
                const scalar_t y_grad = out_grad * NaryOp::backward(out, y);

                w_grad += y_grad * x;
                atomicAdd(&relation_grad[rel_idx], y_grad * w * dx_drel);
                atomicAdd(&input_grad[in_idx], y_grad * w * dx_din);
            }
            // combine the per-thread partial sums; thread 0 publishes the result. atomicAdd is
            // needed because blocks with a different blockIdx.y contribute to the same edge.
            w_grad = warp_reduce(w_grad);
            if (threadIdx.x == 0)
                atomicAdd(&weight_grad[chunk + k], w_grad);
        }
        // the staged edges must not be overwritten before everyone is done reading them
        __syncwarp();
    }
}
|
| 152 |
+
|
| 153 |
+
// Overload used when only relation & input require gradients (no weight_grad output).
// With x = BinaryOp(rel, in) and y = w * x for one edge and feature d:
//     relation_grad[layer, d] += output_grad * dout/dy * w * dx/drel
//     input_grad[col, d]      += output_grad * dout/dy * w * dx/din
// Both are accumulated with atomicAdd. Thread / shared-memory indexing is the same as in
// rspmm_forward_out_cuda: blockDim.x must equal warpSize (asserted), threadIdx.y selects the
// row, blockIdx.y selects a chunk of warpSize * kCoarseningFactor feature columns.
template <class scalar_t, class NaryOp, class BinaryOp>
__global__
void rspmm_backward_out_cuda(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind,
                             const scalar_t *weight, const scalar_t *relation, const scalar_t *input,
                             const scalar_t *output, const scalar_t *output_grad,
                             scalar_t *relation_grad, scalar_t *input_grad,
                             int64_t num_row, int64_t nnz, int64_t dim) {
    // for best optimization, the following code is compiled with constant warpSize
    assert(blockDim.x == warpSize);

    // shared layout: [column indices | relation indices | weights], one slot per thread in each
    // part; every threadIdx.y owns a private window of warpSize slots in each part
    extern __shared__ int64_t buffer[];
    const int64_t part_size = blockDim.y * warpSize;
    const int64_t window = threadIdx.y * warpSize;
    int64_t *col_cache = buffer + window;
    int64_t *layer_cache = buffer + part_size + window;
    scalar_t *weight_cache = reinterpret_cast<scalar_t *>(buffer + 2 * part_size) + window;

    const int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
    if (row >= num_row)
        return;
    const int64_t d_start = blockIdx.y * warpSize * kCoarseningFactor + threadIdx.x;
    const int64_t edge_begin = row_ptr[row];
    const int64_t edge_end = row + 1 < num_row ? row_ptr[row + 1] : nnz;

    for (int64_t chunk = edge_begin; chunk < edge_end; chunk += warpSize) {
        // stage up to warpSize edges of this row, one edge per thread
        const int64_t edge = chunk + threadIdx.x;
        if (edge < edge_end) {
            col_cache[threadIdx.x] = col_ind[edge];
            layer_cache[threadIdx.x] = layer_ind[edge];
            weight_cache[threadIdx.x] = weight[edge];
        }
        __syncwarp();

        // number of edges actually staged in this chunk
        int64_t num_staged = edge_end - chunk;
        if (num_staged > warpSize)
            num_staged = warpSize;

        for (int64_t k = 0; k < num_staged; k++) {
            const int64_t col = col_cache[k];
            const int64_t layer = layer_cache[k];
            const scalar_t w = weight_cache[k];
#pragma unroll
            for (int64_t i = 0; i < kCoarseningFactor; i++) {
                const int64_t d = d_start + i * warpSize;
                if (d >= dim)
                    break;
                const int64_t rel_idx = layer * dim + d;
                const int64_t in_idx = col * dim + d;
                const scalar_t rel = relation[rel_idx];
                const scalar_t in = input[in_idx];
                const scalar_t out = output[row * dim + d];
                const scalar_t out_grad = output_grad[row * dim + d];

                const scalar_t x = BinaryOp::forward(rel, in);
                const scalar_t y = w * x;
                const scalar_t dx_drel = BinaryOp::backward_lhs(rel, in);
                const scalar_t dx_din = BinaryOp::backward_rhs(rel, in);
                // gradient arriving at y: output_grad * dout/dy
                const scalar_t y_grad = out_grad * NaryOp::backward(out, y);

                atomicAdd(&relation_grad[rel_idx], y_grad * w * dx_drel);
                atomicAdd(&input_grad[in_idx], y_grad * w * dx_din);
            }
        }
        // the staged edges must not be overwritten before everyone is done reading them
        __syncwarp();
    }
}
|
| 215 |
+
|
| 216 |
+
// CUDA entry point of the generalized relational SpMM (forward).
//
// Arguments (validated by rspmm_forward_check, all must live on the same GPU):
//   edge_index_  : (2, num_edge) indices; row 0 holds the output row of each edge, row 1 the input row
//   edge_type_   : (num_edge) row of `relation_` used by each edge
//   edge_weight_ : (num_edge) scalar weight of each edge
//   relation_    : (*, dim) relation features
//   input_       : (num_row, dim) input features
// Returns a (num_row, dim) tensor; see rspmm_forward_out_cuda for the formula.
//
// Row ranges are derived from per-row edge counts (ind2ptr), so the edges are expected to be
// ordered by their output row. The index tensors are read through data_ptr<int64_t>().
// The kernel is enqueued on the current CUDA stream of the input's device.
template <template<class> class NaryOp, template<class> class BinaryOp>
Tensor rspmm_forward_cuda(const Tensor &edge_index_, const Tensor &edge_type_, const Tensor &edge_weight_,
                          const Tensor &relation_, const Tensor &input_) {
    constexpr const char *fn_name = "rspmm_forward_cuda";
    TensorArg edge_index_arg(edge_index_, "edge_index", 1), edge_type_arg(edge_type_, "edge_type", 2),
              edge_weight_arg(edge_weight_, "edge_weight", 3), relation_arg(relation_, "relation", 4),
              input_arg(input_, "input", 5);

    rspmm_forward_check(fn_name, edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg);
    checkAllSameGPU(fn_name, {edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg});

    // the kernel indexes raw pointers, so every operand must be densely laid out
    const Tensor edge_index = edge_index_.contiguous();
    const Tensor edge_type = edge_type_.contiguous();
    const Tensor edge_weight = edge_weight_.contiguous();
    const Tensor relation = relation_.contiguous();
    const Tensor input = input_.contiguous();

    // edge_index is (2, num_edge): the edge count is the SECOND dimension. Using size(0) here
    // would always yield 2 and truncate the edge range of the last row (which ends at nnz).
    int64_t nnz = edge_index.size(1);
    int64_t num_row = input.size(0);
    int64_t dim = input.size(1);
    Tensor output = at::empty({num_row, dim}, input.options());

    Tensor row_ind = edge_index.select(0, 0);
    Tensor row_ptr = ind2ptr(row_ind, num_row);
    Tensor col_ind = edge_index.select(0, 1);
    Tensor layer_ind = edge_type;

    cudaSetDevice(input.get_device());
    auto stream = at::cuda::getCurrentCUDAStream();

    // block = (warpSize feature lanes) x (rows per block); each thread covers kCoarseningFactor
    // feature columns, so one block in y spans dim_per_block * kCoarseningFactor columns
    const int dim_per_block = 32; // warpSize
    const int num_dim_block = (dim + dim_per_block * kCoarseningFactor - 1) / (dim_per_block * kCoarseningFactor);
    const int row_per_block = kThreadPerBlock / dim_per_block;
    const int num_row_block = (num_row + row_per_block - 1) / row_per_block;

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_forward_cuda", [&] {
        // per thread: one column index, one relation index (int64_t each) and one weight
        const int memory_size = kThreadPerBlock * (sizeof(int64_t) * 2 + sizeof(scalar_t));
        rspmm_forward_out_cuda<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>
            <<<dim3(num_row_block, num_dim_block), dim3(dim_per_block, row_per_block), memory_size, stream>>>(
            row_ptr.data_ptr<int64_t>(),
            col_ind.data_ptr<int64_t>(),
            layer_ind.data_ptr<int64_t>(),
            edge_weight.data_ptr<scalar_t>(),
            relation.data_ptr<scalar_t>(),
            input.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            num_row, nnz, dim
        );
    });

    return output;
}
|
| 268 |
+
|
| 269 |
+
template <template<class> class NaryOp, template<class> class BinaryOp>
|
| 270 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_backward_cuda(
|
| 271 |
+
const Tensor &edge_index_, const Tensor &edge_type_, const Tensor &edge_weight_,
|
| 272 |
+
const Tensor &relation_, const Tensor &input_, const Tensor &output_, const Tensor &output_grad_) {
|
| 273 |
+
constexpr const char *fn_name = "rspmm_backward_cuda";
|
| 274 |
+
TensorArg edge_index_arg(edge_index_, "edge_index", 1), edge_type_arg(edge_type_, "edge_type", 2),
|
| 275 |
+
edge_weight_arg(edge_weight_, "edge_weight", 3), relation_arg(relation_, "relation", 4),
|
| 276 |
+
input_arg(input_, "input", 5), output_arg(output_, "output", 6),
|
| 277 |
+
output_grad_arg(output_grad_, "output_grad", 7);
|
| 278 |
+
|
| 279 |
+
rspmm_backward_check(fn_name, edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg,
|
| 280 |
+
output_arg, output_grad_arg);
|
| 281 |
+
checkAllSameGPU(fn_name, {edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg, output_arg,
|
| 282 |
+
output_grad_arg});
|
| 283 |
+
|
| 284 |
+
const Tensor edge_index = edge_index_.contiguous();
|
| 285 |
+
const Tensor edge_type = edge_type_.contiguous();
|
| 286 |
+
const Tensor edge_weight = edge_weight_.contiguous();
|
| 287 |
+
const Tensor relation = relation_.contiguous();
|
| 288 |
+
const Tensor input = input_.contiguous();
|
| 289 |
+
const Tensor output = output_.contiguous();
|
| 290 |
+
const Tensor output_grad = output_grad_.contiguous();
|
| 291 |
+
|
| 292 |
+
int64_t nnz = edge_index.size(0);
|
| 293 |
+
int64_t num_row = input.size(0);
|
| 294 |
+
int64_t dim = input.size(1);
|
| 295 |
+
Tensor weight_grad = at::zeros_like(edge_weight);
|
| 296 |
+
Tensor relation_grad = at::zeros_like(relation);
|
| 297 |
+
Tensor input_grad = at::zeros_like(input);
|
| 298 |
+
|
| 299 |
+
Tensor row_ind = edge_index.select(0, 0);
|
| 300 |
+
Tensor row_ptr = ind2ptr(row_ind, num_row);
|
| 301 |
+
Tensor col_ind = edge_index.select(0, 1);
|
| 302 |
+
Tensor layer_ind = edge_type;
|
| 303 |
+
|
| 304 |
+
cudaSetDevice(input.get_device());
|
| 305 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
| 306 |
+
|
| 307 |
+
const int dim_per_block = 32; // warpSize
|
| 308 |
+
const int num_dim_block = (dim + dim_per_block * kCoarseningFactor - 1) / (dim_per_block * kCoarseningFactor);
|
| 309 |
+
const int row_per_block = kThreadPerBlock / dim_per_block;
|
| 310 |
+
const int num_row_block = (num_row + row_per_block - 1) / row_per_block;
|
| 311 |
+
|
| 312 |
+
if (edge_weight.requires_grad())
|
| 313 |
+
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_backward_cuda", [&] {
|
| 314 |
+
const int memory_size = kThreadPerBlock * (sizeof(int64_t) * 2 + sizeof(scalar_t));
|
| 315 |
+
rspmm_backward_out_cuda<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>
|
| 316 |
+
<<<dim3(num_row_block, num_dim_block), dim3(dim_per_block, row_per_block), memory_size, stream>>>(
|
| 317 |
+
row_ptr.data_ptr<int64_t>(),
|
| 318 |
+
col_ind.data_ptr<int64_t>(),
|
| 319 |
+
layer_ind.data_ptr<int64_t>(),
|
| 320 |
+
edge_weight.data_ptr<scalar_t>(),
|
| 321 |
+
relation.data_ptr<scalar_t>(),
|
| 322 |
+
input.data_ptr<scalar_t>(),
|
| 323 |
+
output.data_ptr<scalar_t>(),
|
| 324 |
+
output_grad.data_ptr<scalar_t>(),
|
| 325 |
+
weight_grad.data_ptr<scalar_t>(),
|
| 326 |
+
relation_grad.data_ptr<scalar_t>(),
|
| 327 |
+
input_grad.data_ptr<scalar_t>(),
|
| 328 |
+
num_row, nnz, dim
|
| 329 |
+
);
|
| 330 |
+
});
|
| 331 |
+
else
|
| 332 |
+
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_backward_cuda", [&] {
|
| 333 |
+
const int memory_size = kThreadPerBlock * (sizeof(int64_t) * 2 + sizeof(scalar_t));
|
| 334 |
+
rspmm_backward_out_cuda<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>
|
| 335 |
+
<<<dim3(num_row_block, num_dim_block), dim3(dim_per_block, row_per_block), memory_size, stream>>>(
|
| 336 |
+
row_ptr.data_ptr<int64_t>(),
|
| 337 |
+
col_ind.data_ptr<int64_t>(),
|
| 338 |
+
layer_ind.data_ptr<int64_t>(),
|
| 339 |
+
edge_weight.data_ptr<scalar_t>(),
|
| 340 |
+
relation.data_ptr<scalar_t>(),
|
| 341 |
+
input.data_ptr<scalar_t>(),
|
| 342 |
+
output.data_ptr<scalar_t>(),
|
| 343 |
+
output_grad.data_ptr<scalar_t>(),
|
| 344 |
+
relation_grad.data_ptr<scalar_t>(),
|
| 345 |
+
input_grad.data_ptr<scalar_t>(),
|
| 346 |
+
num_row, nnz, dim
|
| 347 |
+
);
|
| 348 |
+
});
|
| 349 |
+
|
| 350 |
+
return std::make_tuple(weight_grad, relation_grad, input_grad);
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
#define DECLARE_FORWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \
|
| 354 |
+
Tensor rspmm_##ADD##_##MUL##_forward_cuda( \
|
| 355 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, \
|
| 356 |
+
const Tensor &relation, const Tensor &input) { \
|
| 357 |
+
return rspmm_forward_cuda<NARYOP, BINARYOP>(edge_index, edge_type, edge_weight, relation, input); \
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
#define DECLARE_BACKWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \
|
| 361 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_##ADD##_##MUL##_backward_cuda( \
|
| 362 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, \
|
| 363 |
+
const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad) { \
|
| 364 |
+
return rspmm_backward_cuda<NARYOP, BINARYOP>(edge_index, edge_type, edge_weight, relation, input, \
|
| 365 |
+
output, output_grad); \
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
DECLARE_FORWARD_IMPL(add, mul, NaryAdd, BinaryMul)
|
| 369 |
+
DECLARE_BACKWARD_IMPL(add, mul, NaryAdd, BinaryMul)
|
| 370 |
+
|
| 371 |
+
DECLARE_FORWARD_IMPL(min, mul, NaryMin, BinaryMul)
|
| 372 |
+
DECLARE_BACKWARD_IMPL(min, mul, NaryMin, BinaryMul)
|
| 373 |
+
|
| 374 |
+
DECLARE_FORWARD_IMPL(max, mul, NaryMax, BinaryMul)
|
| 375 |
+
DECLARE_BACKWARD_IMPL(max, mul, NaryMax, BinaryMul)
|
| 376 |
+
|
| 377 |
+
DECLARE_FORWARD_IMPL(add, add, NaryAdd, BinaryAdd)
|
| 378 |
+
DECLARE_BACKWARD_IMPL(add, add, NaryAdd, BinaryAdd)
|
| 379 |
+
|
| 380 |
+
DECLARE_FORWARD_IMPL(min, add, NaryMin, BinaryAdd)
|
| 381 |
+
DECLARE_BACKWARD_IMPL(min, add, NaryMin, BinaryAdd)
|
| 382 |
+
|
| 383 |
+
DECLARE_FORWARD_IMPL(max, add, NaryMax, BinaryAdd)
|
| 384 |
+
DECLARE_BACKWARD_IMPL(max, add, NaryMax, BinaryAdd)
|
| 385 |
+
|
| 386 |
+
} // namespace at
|
ultra/rspmm/source/rspmm.h
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <tuple>
|
| 4 |
+
|
| 5 |
+
#include <torch/extension.h>
|
| 6 |
+
//#include <ATen/SparseTensorUtils.h>
|
| 7 |
+
#include <ATen/native/SparseTensorUtils.h>
|
| 8 |
+
|
| 9 |
+
namespace at {
|
| 10 |
+
|
| 11 |
+
using namespace at::sparse;
|
| 12 |
+
|
| 13 |
+
void rspmm_forward_check(CheckedFrom c, const TensorArg &edge_index_arg, const TensorArg &edge_type_arg,
|
| 14 |
+
const TensorArg &edge_weight_arg, const TensorArg &relation_arg, const TensorArg &input_arg);
|
| 15 |
+
|
| 16 |
+
void rspmm_backward_check(CheckedFrom c, const TensorArg &edge_index_arg, const TensorArg &edge_type_arg,
|
| 17 |
+
const TensorArg &edge_weight_arg, const TensorArg &relation_arg, const TensorArg &input_arg,
|
| 18 |
+
const TensorArg &output_arg, const TensorArg &output_grad_arg);
|
| 19 |
+
|
| 20 |
+
Tensor ind2ptr(const Tensor &index, int size);
|
| 21 |
+
|
| 22 |
+
Tensor rspmm_add_mul_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
| 23 |
+
const Tensor &relation, const Tensor &input);
|
| 24 |
+
|
| 25 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_add_mul_backward_cpu(
|
| 26 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
| 27 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
| 28 |
+
|
| 29 |
+
Tensor rspmm_min_mul_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
| 30 |
+
const Tensor &relation, const Tensor &input);
|
| 31 |
+
|
| 32 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_min_mul_backward_cpu(
|
| 33 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
| 34 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
| 35 |
+
|
| 36 |
+
Tensor rspmm_max_mul_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
| 37 |
+
const Tensor &relation, const Tensor &input);
|
| 38 |
+
|
| 39 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_max_mul_backward_cpu(
|
| 40 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
| 41 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
| 42 |
+
|
| 43 |
+
Tensor rspmm_add_add_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
| 44 |
+
const Tensor &relation, const Tensor &input);
|
| 45 |
+
|
| 46 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_add_add_backward_cpu(
|
| 47 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
| 48 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
| 49 |
+
|
| 50 |
+
Tensor rspmm_min_add_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
| 51 |
+
const Tensor &relation, const Tensor &input);
|
| 52 |
+
|
| 53 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_min_add_backward_cpu(
|
| 54 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
| 55 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
| 56 |
+
|
| 57 |
+
Tensor rspmm_max_add_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
| 58 |
+
const Tensor &relation, const Tensor &input);
|
| 59 |
+
|
| 60 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_max_add_backward_cpu(
|
| 61 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
| 62 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
| 63 |
+
|
| 64 |
+
#ifdef CUDA_OP
|
| 65 |
+
Tensor rspmm_add_mul_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
| 66 |
+
const Tensor &relation, const Tensor &input);
|
| 67 |
+
|
| 68 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_add_mul_backward_cuda(
|
| 69 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
| 70 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
| 71 |
+
|
| 72 |
+
Tensor rspmm_min_mul_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
| 73 |
+
const Tensor &relation, const Tensor &input);
|
| 74 |
+
|
| 75 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_min_mul_backward_cuda(
|
| 76 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
| 77 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
| 78 |
+
|
| 79 |
+
Tensor rspmm_max_mul_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
| 80 |
+
const Tensor &relation, const Tensor &input);
|
| 81 |
+
|
| 82 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_max_mul_backward_cuda(
|
| 83 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
| 84 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
| 85 |
+
|
| 86 |
+
Tensor rspmm_add_add_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
| 87 |
+
const Tensor &relation, const Tensor &input);
|
| 88 |
+
|
| 89 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_add_add_backward_cuda(
|
| 90 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
| 91 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
| 92 |
+
|
| 93 |
+
Tensor rspmm_min_add_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
| 94 |
+
const Tensor &relation, const Tensor &input);
|
| 95 |
+
|
| 96 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_min_add_backward_cuda(
|
| 97 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
| 98 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
| 99 |
+
|
| 100 |
+
Tensor rspmm_max_add_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
| 101 |
+
const Tensor &relation, const Tensor &input);
|
| 102 |
+
|
| 103 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_max_add_backward_cuda(
|
| 104 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
| 105 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
| 106 |
+
#endif
|
| 107 |
+
|
| 108 |
+
} // namespace at
|
ultra/rspmm/source/util.cuh
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
namespace at {
|
| 4 |
+
|
| 5 |
+
const unsigned kFullMask = 0xFFFFFFFF;
|
| 6 |
+
|
| 7 |
+
template <class scalar_t>
|
| 8 |
+
__device__ scalar_t warp_reduce(scalar_t value) {
|
| 9 |
+
#pragma unroll
|
| 10 |
+
for (int delta = 1; delta < warpSize; delta *= 2)
|
| 11 |
+
#if __CUDACC_VER_MAJOR__ >= 9
|
| 12 |
+
value += __shfl_down_sync(kFullMask, value, delta);
|
| 13 |
+
#else
|
| 14 |
+
value += __shfl_down(value, delta);
|
| 15 |
+
#endif
|
| 16 |
+
return value;
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
template<class scalar_t>
|
| 20 |
+
__device__ scalar_t warp_broadcast(scalar_t value, int lane_id) {
|
| 21 |
+
#if __CUDACC_VER_MAJOR__ >= 9
|
| 22 |
+
return __shfl_sync(kFullMask, value, lane_id);
|
| 23 |
+
#else
|
| 24 |
+
return __shfl(value, lane_id);
|
| 25 |
+
#endif
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
} // namespace at
|
ultra/tasks.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import reduce
|
| 2 |
+
from torch_scatter import scatter_add
|
| 3 |
+
from torch_geometric.data import Data
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def edge_match(edge_index, query_index):
    """Find every column of ``edge_index`` that equals a column of ``query_index``.

    Columns are hashed to single integers with a mixed-radix encoding, the
    graph hashes are sorted once, and each query is located by binary search:
    O((n + q) log n) time, O(n) memory.

    Args:
        edge_index: integer tensor whose columns are the tuples of the big
            underlying graph (callers in this file pass ``(node, relation)``).
        query_index: integer tensor with the same number of rows whose columns
            are the tuples to look up.

    Returns:
        ``(edge_id, num_match)`` where ``num_match[i]`` is the number of graph
        columns equal to query ``i`` and ``edge_id`` concatenates, query by
        query, the positions of those columns in ``edge_index``.
    """
    # Radix of each row. It must exceed every value that gets hashed, queries
    # included: a query value outside the graph's range would otherwise alias
    # the hash of a different tuple and produce false matches.
    base = edge_index.max(dim=1)[0]
    if query_index.numel() > 0:
        base = torch.maximum(base, query_index.max(dim=1)[0])
    base = base + 1
    # The hashes live in int64, so the product of the radices must fit.
    assert reduce(int.__mul__, base.tolist()) < torch.iinfo(torch.long).max, \
        "edge hash does not fit into int64"
    scale = base.cumprod(0)
    scale = scale[-1] // scale

    # hash both the original edge index and the query index to unique integers
    edge_hash = (edge_index * scale.unsqueeze(-1)).sum(dim=0)
    edge_hash, order = edge_hash.sort()
    query_hash = (query_index * scale.unsqueeze(-1)).sum(dim=0)

    # matched ranges in the sorted hashes: [start[i], end[i])
    start = torch.bucketize(query_hash, edge_hash)
    end = torch.bucketize(query_hash, edge_hash, right=True)
    num_match = end - start

    # Expand every [start, end) interval into explicit positions: a running
    # counter over all matches, shifted per query so it restarts at start[i].
    offset = num_match.cumsum(0) - num_match
    position = torch.arange(num_match.sum(), device=edge_index.device)
    position = position + (start - offset).repeat_interleave(num_match)

    # map positions in the sorted order back to columns of edge_index
    return order[position], num_match
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def negative_sampling(data, batch, num_negative, strict=True):
    """Attach ``num_negative`` corrupted triples to every positive triple.

    The first half of the batch gets corrupted tails, the second half gets
    corrupted heads. With ``strict=True`` negatives are drawn only from the
    candidates allowed by ``strict_negative_mask``; otherwise they are drawn
    uniformly from all ``data.num_nodes`` nodes.

    Args:
        data: graph object providing ``num_nodes`` (and whatever
            ``strict_negative_mask`` reads when ``strict`` is set).
        batch: tensor of ``(head, tail, relation)`` rows.
        num_negative: number of negatives per positive.
        strict: whether to exclude known true triples from the negatives.

    Returns:
        Tensor of shape ``(batch_size, num_negative + 1, 3)``; entry 0 along
        dim 1 is the positive triple.
    """
    batch_size = len(batch)
    half = batch_size // 2
    pos_h_index, pos_t_index, pos_r_index = batch.t()

    def _draw(mask):
        # Candidates of all rows are flattened into one vector; a uniform
        # sample in [0, 1) is scaled to each row's candidate count and then
        # shifted by the row's start offset in that flattened vector.
        candidate = mask.nonzero()[:, 1]
        num_candidate = mask.sum(dim=-1)
        rand = torch.rand(len(mask), num_negative, device=batch.device)
        index = (rand * num_candidate.unsqueeze(-1)).long()
        index = index + (num_candidate.cumsum(0) - num_candidate).unsqueeze(-1)
        return candidate[index]

    if not strict:
        # random negative sampling
        neg_index = torch.randint(data.num_nodes, (batch_size, num_negative), device=batch.device)
        neg_t_index, neg_h_index = neg_index[:half], neg_index[half:]
    else:
        # strict negative sampling: tails first, then heads
        t_mask, h_mask = strict_negative_mask(data, batch)
        neg_t_index = _draw(t_mask[:half])
        neg_h_index = _draw(h_mask[half:])

    # column 0 keeps the positive, columns 1.. receive the negatives
    width = num_negative + 1
    h_index = pos_h_index.unsqueeze(-1).repeat(1, width)
    t_index = pos_t_index.unsqueeze(-1).repeat(1, width)
    r_index = pos_r_index.unsqueeze(-1).repeat(1, width)
    t_index[:half, 1:] = neg_t_index
    h_index[half:, 1:] = neg_h_index

    return torch.stack([h_index, t_index, r_index], dim=-1)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def all_negative(data, batch):
    """Pair every positive triple with every node as candidate tail and head.

    Args:
        data: graph object providing ``num_nodes``.
        batch: tensor of ``(head, tail, relation)`` rows.

    Returns:
        ``(t_batch, h_batch)``, both of shape ``(batch_size, num_nodes, 3)``.
        In ``t_batch`` the tail runs over all nodes while head and relation
        are kept; in ``h_batch`` the head runs over all nodes while tail and
        relation are kept.
    """
    pos_h_index, pos_t_index, pos_r_index = batch.t()
    num_nodes = data.num_nodes
    num_query = len(pos_h_index)

    def _per_node(values):
        # repeat each query value once per node: (batch_size, num_nodes)
        return values.unsqueeze(-1).expand(-1, num_nodes)

    # every node id, repeated for each query: (batch_size, num_nodes)
    every_node = torch.arange(num_nodes, device=batch.device).unsqueeze(0).expand(num_query, -1)
    r_index = _per_node(pos_r_index)

    # all candidate tails for this batch
    t_batch = torch.stack([_per_node(pos_h_index), every_node, r_index], dim=-1)
    # all candidate heads for this batch
    h_batch = torch.stack([every_node, _per_node(pos_t_index), r_index], dim=-1)

    return t_batch, h_batch
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def strict_negative_mask(data, batch):
    """Build boolean masks of nodes that may be used as negatives.

    For a query ``(h, r)`` no tail that forms an edge ``(h, r, t)`` in ``data``
    may be sampled as a negative, and likewise no known head for ``(t, r)``.
    The positive tail / head of each batch row is excluded as well.

    Args:
        data: graph object providing ``edge_index``, ``edge_type`` and
            ``num_nodes``.
        batch: tensor of ``(head, tail, relation)`` rows.

    Returns:
        ``(t_mask, h_mask)``, boolean tensors of shape
        ``(batch_size, num_nodes)``; ``True`` marks an allowed negative.
    """
    pos_h_index, pos_t_index, pos_r_index = batch.t()
    num_nodes = data.num_nodes

    def _allowed(known_row, pos_known_index, pos_missing_index):
        # known_row picks the endpoint fixed by the query (0: head, 1: tail);
        # the other endpoint is the one being corrupted.
        missing_row = 1 - known_row
        # (endpoint, relation) pairs of every graph edge and of every query
        graph_pairs = torch.stack([data.edge_index[known_row], data.edge_type])
        query_pairs = torch.stack([pos_known_index, pos_r_index])
        # all graph edges sharing the query's (endpoint, relation)
        edge_id, num_truth = edge_match(graph_pairs, query_pairs)
        truth_index = data.edge_index[missing_row, edge_id]
        num_query = len(num_truth)
        # batch row each matched edge belongs to
        sample_id = torch.arange(num_query, device=batch.device).repeat_interleave(num_truth)
        mask = torch.ones(num_query, num_nodes, dtype=torch.bool, device=batch.device)
        # clear the true endpoints found in the graph ...
        mask[sample_id, truth_index] = 0
        # ... and the positive endpoint of the batch row itself
        mask.scatter_(1, pos_missing_index.unsqueeze(-1), 0)
        return mask

    # part I: allowed negative tails for each (h, r)
    t_mask = _allowed(0, pos_h_index, pos_t_index)
    # part II: allowed negative heads for each (t, r)
    h_mask = _allowed(1, pos_t_index, pos_h_index)

    return t_mask, h_mask
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def compute_ranking(pred, target, mask=None):
    """Rank the target entity among the candidates of each query.

    The rank is one plus the number of competing candidates whose score is at
    least the target's score (ties count against the target).

    Args:
        pred: score tensor; the last dimension enumerates the candidates.
        target: index of the true candidate for each query.
        mask: optional boolean tensor shaped like ``pred`` selecting the
            candidates that compete with the target (filtered ranking). The
            caller is expected to leave the target itself out of the mask, as
            ``strict_negative_mask`` does.

    Returns:
        Integer tensor of ranks, with ``pred``'s shape minus the last dimension.
    """
    target_column = target.unsqueeze(-1)
    pos_pred = pred.gather(-1, target_column)
    not_worse = pos_pred <= pred
    if mask is not None:
        # filtered ranking: only candidates kept by the mask compete
        competitors = not_worse & mask
    else:
        # unfiltered ranking: every candidate competes except the target
        # itself, which would otherwise always count and push the best
        # achievable rank to 2
        competitors = not_worse.scatter(-1, target_column, 0)
    return competitors.sum(dim=-1) + 1
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def build_relation_graph(graph):
    """Attach a graph over relations to ``graph`` as ``graph.relation_graph``.

    Two relations are linked when they share an entity: as head of both
    (type 0, head-to-head), as tail of both (type 1, tail-to-tail), as head of
    the first and tail of the second (type 2, head-to-tail), or as tail of the
    first and head of the second (type 3, tail-to-head).

    The input graph is expected to already contain inverse edges.

    Args:
        graph: object providing ``edge_index``, ``edge_type``, ``num_nodes``
            and ``num_relations``.

    Returns:
        The same ``graph`` object, with ``relation_graph`` set to a ``Data``
        whose nodes are the relations and which has 4 edge types.
    """
    edge_index, edge_type = graph.edge_index, graph.edge_type
    num_nodes, num_rels = graph.num_nodes, graph.num_relations
    device = edge_index.device

    def _incidence(node_index):
        # Sparse entity/relation incidence for one endpoint of the edges.
        # Returns (relations x entities, scaled by 1 / number of distinct
        # relations at the entity) and the unscaled entities x relations matrix.
        pairs = torch.vstack([node_index, edge_type]).T.unique(dim=0)  # (num_pairs, 2)
        degree = scatter_add(torch.ones_like(pairs[:, 1]), pairs[:, 0])
        pair_degree = degree[pairs[:, 0]]
        assert not (pair_degree == 0).any()
        ones = torch.ones(pairs.shape[0], device=device)
        rel_by_node = torch.sparse_coo_tensor(
            torch.flip(pairs, dims=[1]).T,
            ones / pair_degree,
            (num_rels, num_nodes)
        )
        node_by_rel = torch.sparse_coo_tensor(
            pairs.T,
            ones,
            (num_nodes, num_rels)
        )
        return rel_by_node, node_by_rel

    EhT, Eh = _incidence(edge_index[0])
    EtT, Et = _incidence(edge_index[1])

    # Position in this list is the edge type id of the relation graph.
    interactions = [
        (EhT, Eh),  # 0: head to head
        (EtT, Et),  # 1: tail to tail
        (EhT, Et),  # 2: head to tail
        (EtT, Eh),  # 3: tail to head
    ]

    rel_edge_index = []
    rel_edge_type = []
    for type_id, (left, right) in enumerate(interactions):
        # only the sparsity pattern of the product is used
        indices = torch.sparse.mm(left, right).coalesce().indices()
        rel_edge_index.append(indices)
        # created on the graph's device so the result can be concatenated and
        # used together with the indices when the graph lives on a GPU
        rel_edge_type.append(
            torch.full((indices.shape[1],), type_id, dtype=torch.long, device=device)
        )

    rel_graph = Data(
        edge_index=torch.cat(rel_edge_index, dim=1),
        edge_type=torch.cat(rel_edge_type, dim=0),
        num_nodes=num_rels,
        num_relations=4
    )

    graph.relation_graph = rel_graph
    return graph
|
| 200 |
+
|
| 201 |
+
|
ultra/util.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import ast
|
| 4 |
+
import copy
|
| 5 |
+
import time
|
| 6 |
+
import logging
|
| 7 |
+
import argparse
|
| 8 |
+
|
| 9 |
+
import yaml
|
| 10 |
+
import jinja2
|
| 11 |
+
from jinja2 import meta
|
| 12 |
+
import easydict
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from torch import distributed as dist
|
| 16 |
+
from torch_geometric.data import Data
|
| 17 |
+
from torch_geometric.datasets import RelLinkPredDataset, WordNet18RR
|
| 18 |
+
|
| 19 |
+
from ultra import models, datasets
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__file__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def detect_variables(cfg_file):
    """Return the Jinja2 variables that the config template leaves undeclared.

    Args:
        cfg_file: path of the templated configuration file.

    Returns:
        The result of ``jinja2.meta.find_undeclared_variables`` for the parsed
        template.
    """
    with open(cfg_file, "r") as fin:
        template_source = fin.read()
    syntax_tree = jinja2.Environment().parse(template_source)
    return meta.find_undeclared_variables(syntax_tree)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def load_config(cfg_file, context=None):
    """Render a Jinja2-templated YAML file and load it as an ``EasyDict``.

    Args:
        cfg_file: path of the templated configuration file.
        context: values for the template variables, passed to ``render``.

    Returns:
        The parsed configuration wrapped in ``easydict.EasyDict``.
    """
    with open(cfg_file, "r") as fin:
        template_source = fin.read()
    rendered = jinja2.Template(template_source).render(context)
    # safe_load: the rendered text is parsed as plain YAML data only
    return easydict.EasyDict(yaml.safe_load(rendered))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def literal_eval(string):
    """Parse ``string`` as a Python literal, falling back to the raw string.

    Returns the evaluated literal when ``ast.literal_eval`` accepts the input,
    and ``string`` unchanged when it raises ``ValueError`` or ``SyntaxError``.
    """
    try:
        parsed = ast.literal_eval(string)
    except (ValueError, SyntaxError):
        # not a literal (e.g. a bare word or a path): keep it as text
        parsed = string
    return parsed
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def parse_args():
    """Parse the fixed options plus the options required by the config template.

    The first pass reads ``--config`` and ``--seed`` and keeps the remaining
    tokens. The second pass turns every undeclared template variable of the
    config file into a required ``--<name>`` option and parses those tokens.

    Returns:
        ``(args, vars)``: the namespace of the fixed options and a dict mapping
        each template variable to its value, converted with ``literal_eval``.
    """
    fixed_parser = argparse.ArgumentParser()
    fixed_parser.add_argument("-c", "--config", help="yaml configuration file", required=True)
    fixed_parser.add_argument("-s", "--seed", help="random seed for PyTorch", type=int, default=1024)
    args, unparsed = fixed_parser.parse_known_args()

    # dynamic arguments defined in the config file
    template_parser = argparse.ArgumentParser()
    for name in detect_variables(args.config):
        template_parser.add_argument("--%s" % name, required=True)
    template_args, _ = template_parser.parse_known_args(unparsed)
    context = {key: literal_eval(value) for key, value in template_args._get_kwargs()}

    return args, context
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def get_root_logger(file=True):
    """Configure the root logger at INFO level and return it.

    Args:
        file: when true, additionally log to ``log.txt`` in the current
            working directory using the same format.

    Returns:
        The root logger.
    """
    message_format = "%(asctime)-10s %(message)s"
    time_format = "%H:%M:%S"
    logging.basicConfig(format=message_format, datefmt=time_format)
    root = logging.getLogger("")
    root.setLevel(logging.INFO)

    if file:
        file_handler = logging.FileHandler("log.txt")
        file_handler.setFormatter(logging.Formatter(message_format, time_format))
        root.addHandler(file_handler)

    return root
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_rank():
    """Return the rank of this process.

    Uses the initialised ``torch.distributed`` process group when there is
    one, otherwise the ``RANK`` environment variable, otherwise 0.
    """
    if dist.is_initialized():
        return dist.get_rank()
    # launchers export RANK before the process group exists
    return int(os.environ.get("RANK", 0))
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def get_world_size():
    """Return the number of participating processes.

    Uses the initialised ``torch.distributed`` process group when there is
    one, otherwise the ``WORLD_SIZE`` environment variable, otherwise 1.
    """
    if dist.is_initialized():
        return dist.get_world_size()
    # launchers export WORLD_SIZE before the process group exists
    return int(os.environ.get("WORLD_SIZE", 1))
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def synchronize():
    """Block on a distributed barrier; do nothing in a single-process run."""
    if get_world_size() <= 1:
        return
    dist.barrier()
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def get_device(cfg):
    """Return the ``torch.device`` this process should use.

    When ``cfg.train.gpus`` is empty or ``None`` the CPU device is returned;
    otherwise the entry of ``cfg.train.gpus`` at this process's rank is used.
    """
    gpus = cfg.train.gpus
    if not gpus:
        return torch.device("cpu")
    # One entry per process: pick the one matching our rank.
    return torch.device(gpus[get_rank()])
|
| 111 |
+
|
| 112 |
+
def get_devices(gpus):
    """Return the ``torch.device`` this process should use.

    Args:
        gpus: a sequence of device specifiers indexed by process rank, or
            ``None``/empty to run on CPU.

    Returns:
        ``torch.device("cpu")`` when ``gpus`` is ``None`` or empty, otherwise
        the device built from ``gpus[get_rank()]``.
    """
    # An empty list previously fell through to `gpus[get_rank()]` and raised
    # IndexError; treat it as "no GPUs", which is how get_device(cfg) behaves.
    if not gpus:
        return torch.device("cpu")
    return torch.device(gpus[get_rank()])
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def create_working_directory(cfg):
    """Create a timestamped working directory and ``chdir`` into it.

    The directory is ``<output_dir>/<model class>/<dataset class>/<timestamp>``.
    Rank 0 creates it and writes its path to a temporary file in the current
    directory; the other ranks read the path from that file so that every
    process ends up in the same directory even if their clocks differ.

    If more than one process is running and no process group exists yet, an
    NCCL process group is initialized from the environment.

    Raises:
        ValueError: if ``cfg.train.gpus`` is given and its length differs from
            the world size.

    Returns:
        The path of the working directory.
    """
    sync_file = "working_dir.tmp"
    num_procs = get_world_size()
    gpus = cfg.train.gpus
    if gpus is not None and len(gpus) != num_procs:
        message = "World size is %d but found %d GPUs in the argument"
        if num_procs == 1:
            message += ". Did you launch with `python -m torch.distributed.launch`?"
        raise ValueError(message % (num_procs, len(gpus)))
    if num_procs > 1 and not dist.is_initialized():
        dist.init_process_group("nccl", init_method="env://")

    output_root = os.path.expanduser(cfg.output_dir)
    model_name = cfg.model["class"]
    dataset_name = cfg.dataset["class"]
    timestamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    working_dir = os.path.join(output_root, model_name, dataset_name, timestamp)

    # synchronize working directory
    is_master = get_rank() == 0
    if is_master:
        # Publish the path first, then create the directory.
        with open(sync_file, "w") as fout:
            fout.write(working_dir)
        os.makedirs(working_dir)
    synchronize()
    if not is_master:
        # Adopt rank 0's path rather than the locally computed one.
        with open(sync_file, "r") as fin:
            working_dir = fin.read()
    synchronize()
    if is_master:
        # Every rank has read the file by now, so it can be removed.
        os.remove(sync_file)

    os.chdir(working_dir)
    return working_dir
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def build_dataset(cfg):
    """Instantiate the dataset described by ``cfg.dataset``.

    The ``"class"`` entry of the config names an attribute of the ``datasets``
    module; the remaining entries are passed to it as keyword arguments. The
    config itself is left untouched because a deep copy is consumed.

    On rank 0 the dataset name and the number of train/valid/test target edges
    are logged. For ``JointDataset`` the counts are summed over the graphs in
    each split of ``dataset._data``.

    Returns:
        The constructed dataset.
    """
    data_config = copy.deepcopy(cfg.dataset)
    cls = data_config.pop("class")
    dataset = getattr(datasets, cls)(**data_config)

    # Only the first process reports statistics.
    if get_rank() != 0:
        return dataset

    label = f'{cls}({cfg.dataset.version})' if "version" in cfg.dataset else cls
    logger.warning("%s dataset" % label)

    if cls == "JointDataset":
        # Each split holds several graphs; add their edge counts together.
        counts = tuple(
            sum(d.target_edge_index.shape[1] for d in dataset._data[split])
            for split in range(3)
        )
    else:
        counts = tuple(dataset[split].target_edge_index.shape[1] for split in range(3))
    logger.warning("#train: %d, #valid: %d, #test: %d" % counts)

    return dataset
|
| 172 |
+
|