Source code for torch_geometric.nn.models.metapath2vec
from typing import Dict, List, Optional, Tuple
import torch
from torch import Tensor
from torch.nn import Embedding
from torch.utils.data import DataLoader
from torch_geometric.index import index2ptr
from torch_geometric.typing import EdgeType, NodeType, OptTensor
from torch_geometric.utils import sort_edge_index
EPS = 1e-15
class MetaPath2Vec(torch.nn.Module):
r"""The MetaPath2Vec model from the `"metapath2vec: Scalable Representation
Learning for Heterogeneous Networks"
<https://ericdongyx.github.io/papers/
KDD17-dong-chawla-swami-metapath2vec.pdf>`_ paper where random walks based
on a given :obj:`metapath` are sampled in a heterogeneous graph, and node
embeddings are learned via negative sampling optimization.
.. note::
For an example of using MetaPath2Vec, see
`examples/hetero/metapath2vec.py
<https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
hetero/metapath2vec.py>`_.
Args:
edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): Dictionary
holding edge indices for each
:obj:`(src_node_type, rel_type, dst_node_type)` edge type present
in the heterogeneous graph.
embedding_dim (int): The size of each embedding vector.
metapath (List[Tuple[str, str, str]]): The metapath described as a list
of :obj:`(src_node_type, rel_type, dst_node_type)` tuples.
walk_length (int): The walk length.
context_size (int): The actual context size which is considered for
positive samples. This parameter increases the effective sampling
rate by reusing samples across different source nodes.
walks_per_node (int, optional): The number of walks to sample for each
node. (default: :obj:`1`)
num_negative_samples (int, optional): The number of negative samples to
use for each positive sample. (default: :obj:`1`)
num_nodes_dict (Dict[str, int], optional): Dictionary holding the
number of nodes for each node type. (default: :obj:`None`)
sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to the
weight matrix will be sparse. (default: :obj:`False`)
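
    A minimal usage sketch (the node and edge type names, the metapath, and
    all hyperparameter values below are illustrative placeholders rather than
    recommendations for a particular dataset):

    .. code-block:: python

        # `data` is assumed to be a heterogeneous graph object (e.g., a
        # HeteroData instance) exposing an `edge_index_dict`.
        model = MetaPath2Vec(
            data.edge_index_dict,
            embedding_dim=128,
            metapath=[
                ('author', 'writes', 'paper'),
                ('paper', 'written_by', 'author'),
            ],
            walk_length=50,
            context_size=7,
            walks_per_node=5,
            num_negative_samples=5,
        )

        loader = model.loader(batch_size=128, shuffle=True)
        # Use torch.optim.SparseAdam instead when `sparse=True`.
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        for pos_rw, neg_rw in loader:
            optimizer.zero_grad()
            loss = model.loss(pos_rw, neg_rw)
            loss.backward()
            optimizer.step()

        z = model('author')  # Embeddings for all 'author' nodes.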
"""
def __init__(
self,
edge_index_dict: Dict[EdgeType, Tensor],
embedding_dim: int,
metapath: List[EdgeType],
walk_length: int,
context_size: int,
walks_per_node: int = 1,
num_negative_samples: int = 1,
num_nodes_dict: Optional[Dict[NodeType, int]] = None,
sparse: bool = False,
):
super().__init__()
if num_nodes_dict is None:
num_nodes_dict = {}
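            # Infer the number of nodes of each node type from the largest
            # index that occurs as a source or destination node of any edge
            # type.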
for keys, edge_index in edge_index_dict.items():
key = keys[0]
N = int(edge_index[0].max() + 1)
num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N))
key = keys[-1]
N = int(edge_index[1].max() + 1)
num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N))
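
        # Convert every edge type to CSR format (rowptr, col), which allows
        # sampling a random neighbor of a node in constant time during walks.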
self.rowptr_dict, self.col_dict, self.rowcount_dict = {}, {}, {}
for keys, edge_index in edge_index_dict.items():
sizes = (num_nodes_dict[keys[0]], num_nodes_dict[keys[-1]])
row, col = sort_edge_index(edge_index, num_nodes=max(sizes)).cpu()
rowptr = index2ptr(row, size=sizes[0])
self.rowptr_dict[keys] = rowptr
self.col_dict[keys] = col
self.rowcount_dict[keys] = rowptr[1:] - rowptr[:-1]
for edge_type1, edge_type2 in zip(metapath[:-1], metapath[1:]):
if edge_type1[-1] != edge_type2[0]:
raise ValueError(
"Found invalid metapath. Ensure that the destination node "
"type matches with the source node type across all "
"consecutive edge types.")
assert walk_length + 1 >= context_size
if walk_length > len(metapath) and metapath[0][0] != metapath[-1][-1]:
raise AttributeError(
"The 'walk_length' is longer than the given 'metapath', but "
"the 'metapath' does not denote a cycle")
self.embedding_dim = embedding_dim
self.metapath = metapath
self.walk_length = walk_length
self.context_size = context_size
self.walks_per_node = walks_per_node
self.num_negative_samples = num_negative_samples
self.num_nodes_dict = num_nodes_dict
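
        # All node types appearing in the metapath share one embedding matrix;
        # each type is assigned a contiguous [start, end) range of rows.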
types = {x[0] for x in metapath} | {x[-1] for x in metapath}
types = sorted(list(types))
count = 0
self.start, self.end = {}, {}
for key in types:
self.start[key] = count
count += num_nodes_dict[key]
self.end[key] = count
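
        # Per walk-position offsets that shift local node indices at each step
        # into the global index space of the shared embedding matrix.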
offset = [self.start[metapath[0][0]]]
offset += [self.start[keys[-1]] for keys in metapath
] * int((walk_length / len(metapath)) + 1)
offset = offset[:walk_length + 1]
assert len(offset) == walk_length + 1
self.offset = torch.tensor(offset)
        # The +1 accounts for a dummy node that isolated nodes are linked to.
self.embedding = Embedding(count + 1, embedding_dim, sparse=sparse)
self.dummy_idx = count
self.reset_parameters()
    def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
self.embedding.reset_parameters()
    def forward(self, node_type: str, batch: OptTensor = None) -> Tensor:
r"""Returns the embeddings for the nodes in :obj:`batch` of type
:obj:`node_type`.
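
        For example (the node type name and node indices below are
        illustrative):

        .. code-block:: python

            z = model('author', batch=torch.tensor([0, 1, 2]))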
"""
emb = self.embedding.weight[self.start[node_type]:self.end[node_type]]
return emb if batch is None else emb.index_select(0, batch)
    def loader(self, **kwargs):
r"""Returns the data loader that creates both positive and negative
        random walks on the heterogeneous graph.

        Args:
**kwargs (optional): Arguments of
:class:`torch.utils.data.DataLoader`, such as
:obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or
:obj:`num_workers`.
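
        A typical usage sketch (the keyword-argument values are illustrative):

        .. code-block:: python

            loader = model.loader(batch_size=128, shuffle=True, num_workers=4)
            pos_rw, neg_rw = next(iter(loader))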
"""
return DataLoader(range(self.num_nodes_dict[self.metapath[0][0]]),
collate_fn=self._sample, **kwargs)
def _pos_sample(self, batch: Tensor) -> Tensor:
batch = batch.repeat(self.walks_per_node)
rws = [batch]
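        # Extend each walk step by step, sampling one neighbor along the next
        # edge type of the (cyclically repeated) metapath.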
for i in range(self.walk_length):
edge_type = self.metapath[i % len(self.metapath)]
batch = sample(
self.rowptr_dict[edge_type],
self.col_dict[edge_type],
self.rowcount_dict[edge_type],
batch,
num_neighbors=1,
dummy_idx=self.dummy_idx,
).view(-1)
rws.append(batch)
rw = torch.stack(rws, dim=-1)
rw.add_(self.offset.view(1, -1))
rw[rw > self.dummy_idx] = self.dummy_idx
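
        # Chop every walk into all overlapping windows of `context_size`
        # nodes; the first node of each window serves as the skip-gram center.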
walks = []
num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
for j in range(num_walks_per_rw):
walks.append(rw[:, j:j + self.context_size])
return torch.cat(walks, dim=0)
def _neg_sample(self, batch: Tensor) -> Tensor:
batch = batch.repeat(self.walks_per_node * self.num_negative_samples)
rws = [batch]
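        # Negative walks start from the same nodes, but every subsequent node
        # is drawn uniformly at random from the node type prescribed by the
        # metapath at that step.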
for i in range(self.walk_length):
keys = self.metapath[i % len(self.metapath)]
batch = torch.randint(0, self.num_nodes_dict[keys[-1]],
(batch.size(0), ), dtype=torch.long)
rws.append(batch)
rw = torch.stack(rws, dim=-1)
rw.add_(self.offset.view(1, -1))
walks = []
num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
for j in range(num_walks_per_rw):
walks.append(rw[:, j:j + self.context_size])
return torch.cat(walks, dim=0)
def _sample(self, batch: List[int]) -> Tuple[Tensor, Tensor]:
if not isinstance(batch, Tensor):
batch = torch.tensor(batch, dtype=torch.long)
return self._pos_sample(batch), self._neg_sample(batch)
    def loss(self, pos_rw: Tensor, neg_rw: Tensor) -> Tensor:
r"""Computes the loss given positive and negative random walks."""
# Positive loss.
start, rest = pos_rw[:, 0], pos_rw[:, 1:].contiguous()
h_start = self.embedding(start).view(pos_rw.size(0), 1,
self.embedding_dim)
h_rest = self.embedding(rest.view(-1)).view(pos_rw.size(0), -1,
self.embedding_dim)
out = (h_start * h_rest).sum(dim=-1).view(-1)
pos_loss = -torch.log(torch.sigmoid(out) + EPS).mean()
# Negative loss.
start, rest = neg_rw[:, 0], neg_rw[:, 1:].contiguous()
h_start = self.embedding(start).view(neg_rw.size(0), 1,
self.embedding_dim)
h_rest = self.embedding(rest.view(-1)).view(neg_rw.size(0), -1,
self.embedding_dim)
out = (h_start * h_rest).sum(dim=-1).view(-1)
neg_loss = -torch.log(1 - torch.sigmoid(out) + EPS).mean()
return pos_loss + neg_loss
    def test(self, train_z: Tensor, train_y: Tensor, test_z: Tensor,
test_y: Tensor, solver: str = "lbfgs", *args, **kwargs) -> float:
r"""Evaluates latent space quality via a logistic regression downstream
task.
"""
from sklearn.linear_model import LogisticRegression
clf = LogisticRegression(solver=solver, *args,
**kwargs).fit(train_z.detach().cpu().numpy(),
train_y.detach().cpu().numpy())
return clf.score(test_z.detach().cpu().numpy(),
test_y.detach().cpu().numpy())
def __repr__(self) -> str:
return (f'{self.__class__.__name__}('
f'{self.embedding.weight.size(0) - 1}, '
f'{self.embedding.weight.size(1)})')
def sample(rowptr: Tensor, col: Tensor, rowcount: Tensor, subset: Tensor,
num_neighbors: int, dummy_idx: int) -> Tensor:
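    # Uniformly sample `num_neighbors` neighbors for every node in `subset`
    # from the CSR representation (rowptr, col). Nodes without neighbors, and
    # entries already pointing at the dummy node, are mapped to `dummy_idx`.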
mask = subset >= dummy_idx
subset = subset.clamp(min=0, max=rowptr.numel() - 2)
count = rowcount[subset]
rand = torch.rand((subset.size(0), num_neighbors), device=subset.device)
rand *= count.to(rand.dtype).view(-1, 1)
rand = rand.to(torch.long) + rowptr[subset].view(-1, 1)
rand = rand.clamp(max=col.numel() - 1) # If last node is isolated.
col = col[rand] if col.numel() > 0 else rand
col[mask | (count == 0)] = dummy_idx
return col