# Source code for torch_geometric.nn.models.node2vec

from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Embedding
from torch.utils.data import DataLoader

from torch_geometric.index import index2ptr
from torch_geometric.typing import WITH_PYG_LIB, WITH_TORCH_CLUSTER
from torch_geometric.utils import sort_edge_index
from torch_geometric.utils.num_nodes import maybe_num_nodes


class Node2Vec(torch.nn.Module):
    r"""The Node2Vec model from the
    `"node2vec: Scalable Feature Learning for Networks"
    <https://arxiv.org/abs/1607.00653>`_ paper where random walks of
    length :obj:`walk_length` are sampled in a given graph, and node
    embeddings are learned via negative sampling optimization.

    .. note::

        For an example of using Node2Vec, see `examples/node2vec.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        node2vec.py>`_.

    Args:
        edge_index (torch.Tensor): The edge indices.
        embedding_dim (int): The size of each embedding vector.
        walk_length (int): The walk length.
        context_size (int): The actual context size which is considered for
            positive samples. This parameter increases the effective sampling
            rate by reusing samples across different source nodes.
        walks_per_node (int, optional): The number of walks to sample for each
            node. (default: :obj:`1`)
        p (float, optional): Likelihood of immediately revisiting a node in the
            walk. (default: :obj:`1`)
        q (float, optional): Control parameter to interpolate between
            breadth-first strategy and depth-first strategy (default: :obj:`1`)
        num_negative_samples (int, optional): The number of negative samples to
            use for each positive sample. (default: :obj:`1`)
        num_nodes (int, optional): The number of nodes. (default: :obj:`None`)
        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. the
            weight matrix will be sparse. (default: :obj:`False`)
    """
    def __init__(
        self,
        edge_index: Tensor,
        embedding_dim: int,
        walk_length: int,
        context_size: int,
        walks_per_node: int = 1,
        p: float = 1.0,
        q: float = 1.0,
        num_negative_samples: int = 1,
        num_nodes: Optional[int] = None,
        sparse: bool = False,
    ):
        super().__init__()

        # Pick the random-walk backend: the 'pyg-lib' op is only selected for
        # the p == q == 1 case; every other configuration goes through
        # 'torch-cluster'.
        if WITH_PYG_LIB and p == 1.0 and q == 1.0:
            self.random_walk_fn = torch.ops.pyg.random_walk
        elif WITH_TORCH_CLUSTER:
            self.random_walk_fn = torch.ops.torch_cluster.random_walk
        else:
            if p == 1.0 and q == 1.0:
                raise ImportError(f"'{self.__class__.__name__}' "
                                  f"requires either the 'pyg-lib' or "
                                  f"'torch-cluster' package")
            else:
                raise ImportError(f"'{self.__class__.__name__}' "
                                  f"requires the 'torch-cluster' package")

        self.num_nodes = maybe_num_nodes(edge_index, num_nodes)

        # Sort the edges, move them to the CPU and keep them as a
        # (rowptr, col) pair, which is what `random_walk_fn` is called with.
        row, col = sort_edge_index(edge_index, num_nodes=self.num_nodes).cpu()
        self.rowptr, self.col = index2ptr(row, self.num_nodes), col

        # Added inside the logarithms of `loss` so that log(0) never occurs.
        self.EPS = 1e-15

        assert walk_length >= context_size, (
            f"'walk_length' (got {walk_length}) must be greater than or "
            f"equal to 'context_size' (got {context_size})")

        self.embedding_dim = embedding_dim
        # Stored as the number of entries that follow the start node: the
        # samplers below treat every walk as `self.walk_length + 1` columns.
        self.walk_length = walk_length - 1
        self.context_size = context_size
        self.walks_per_node = walks_per_node
        self.p = p
        self.q = q
        self.num_negative_samples = num_negative_samples

        self.embedding = Embedding(self.num_nodes, embedding_dim,
                                   sparse=sparse)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.embedding.reset_parameters()

    def forward(self, batch: Optional[Tensor] = None) -> Tensor:
        """Returns the embeddings for the nodes in :obj:`batch`.

        If :obj:`batch` is :obj:`None`, the full embedding matrix is returned.
        """
        emb = self.embedding.weight
        return emb if batch is None else emb[batch]

    def loader(self, **kwargs) -> DataLoader:
        r"""Returns a :class:`~torch.utils.data.DataLoader` over the node
        indices :obj:`range(num_nodes)` that uses :meth:`sample` as its
        :obj:`collate_fn`, so every batch is a
        :obj:`(pos_rw, neg_rw)` tuple.

        Args:
            **kwargs: Additional arguments forwarded to
                :class:`~torch.utils.data.DataLoader`, *e.g.*,
                :obj:`batch_size` or :obj:`shuffle`.
        """
        return DataLoader(range(self.num_nodes), collate_fn=self.sample,
                          **kwargs)

    def _context_windows(self, rw: Tensor) -> Tensor:
        r"""Slices every walk in :obj:`rw` (one walk per row) into all of its
        contiguous windows of :obj:`context_size` columns and stacks them
        along the first dimension.

        The result is ordered window-major: first the windows starting at
        column :obj:`0` for all walks, then those starting at column
        :obj:`1`, and so on.
        """
        # A walk has `self.walk_length + 1` columns, which leaves
        # `(self.walk_length + 1) - context_size + 1` window start offsets.
        num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
        walks = []
        for j in range(num_walks_per_rw):
            walks.append(rw[:, j:j + self.context_size])
        return torch.cat(walks, dim=0)

    @torch.jit.export
    def pos_sample(self, batch: Tensor) -> Tensor:
        r"""Samples positive context windows by running
        :obj:`walks_per_node` random walks from every node in :obj:`batch`.

        Args:
            batch (torch.Tensor): The start node indices.

        Returns:
            A tensor with :obj:`context_size` columns in which each row is one
            window of a sampled random walk.
        """
        batch = batch.repeat(self.walks_per_node)
        rw = self.random_walk_fn(self.rowptr, self.col, batch,
                                 self.walk_length, self.p, self.q)
        # The backend may return a container instead of a tensor; in that
        # case the walks are its first element.
        if not isinstance(rw, Tensor):
            rw = rw[0]

        return self._context_windows(rw)

    @torch.jit.export
    def neg_sample(self, batch: Tensor) -> Tensor:
        r"""Samples negative context windows: each "walk" starts at a node of
        :obj:`batch` and continues with node indices drawn uniformly at
        random from :obj:`[0, num_nodes)`.

        Args:
            batch (torch.Tensor): The start node indices.

        Returns:
            A tensor with :obj:`context_size` columns in which each row is one
            window of a negative walk.
        """
        batch = batch.repeat(self.walks_per_node * self.num_negative_samples)

        rw = torch.randint(self.num_nodes, (batch.size(0), self.walk_length),
                           dtype=batch.dtype, device=batch.device)
        # Prepend the start nodes so negatives have the same layout as the
        # walks handled in `pos_sample`.
        rw = torch.cat([batch.view(-1, 1), rw], dim=-1)

        return self._context_windows(rw)

    @torch.jit.export
    def sample(self, batch: Union[List[int], Tensor]) -> Tuple[Tensor, Tensor]:
        r"""Returns the :obj:`(pos_rw, neg_rw)` tuple produced by
        :meth:`pos_sample` and :meth:`neg_sample` for the nodes in
        :obj:`batch`.

        Args:
            batch (List[int] or torch.Tensor): The start node indices. A list
                is converted to a tensor first.
        """
        if not isinstance(batch, Tensor):
            batch = torch.tensor(batch)
        return self.pos_sample(batch), self.neg_sample(batch)

    @torch.jit.export
    def loss(self, pos_rw: Tensor, neg_rw: Tensor) -> Tensor:
        r"""Computes the loss given positive and negative random walks.

        In every row, the first column is the start node and the remaining
        columns are its context nodes. Each (start, context) pair is scored by
        the dot product of the two embeddings; positive pairs are penalized
        with :math:`-\log(\sigma(x) + \epsilon)` and negative pairs with
        :math:`-\log(1 - \sigma(x) + \epsilon)`. Both terms are averaged over
        their pairs and summed.

        Args:
            pos_rw (torch.Tensor): The positive windows, *e.g.*, from
                :meth:`pos_sample`.
            neg_rw (torch.Tensor): The negative windows, *e.g.*, from
                :meth:`neg_sample`.
        """
        # Positive loss.
        start, rest = pos_rw[:, 0], pos_rw[:, 1:].contiguous()

        h_start = self.embedding(start).view(pos_rw.size(0), 1,
                                             self.embedding_dim)
        h_rest = self.embedding(rest.view(-1)).view(pos_rw.size(0), -1,
                                                    self.embedding_dim)

        # `h_start` broadcasts over the context dimension of `h_rest`.
        out = (h_start * h_rest).sum(dim=-1).view(-1)
        pos_loss = -torch.log(torch.sigmoid(out) + self.EPS).mean()

        # Negative loss.
        start, rest = neg_rw[:, 0], neg_rw[:, 1:].contiguous()

        h_start = self.embedding(start).view(neg_rw.size(0), 1,
                                             self.embedding_dim)
        h_rest = self.embedding(rest.view(-1)).view(neg_rw.size(0), -1,
                                                    self.embedding_dim)

        out = (h_start * h_rest).sum(dim=-1).view(-1)
        neg_loss = -torch.log(1 - torch.sigmoid(out) + self.EPS).mean()

        return pos_loss + neg_loss

    def test(
        self,
        train_z: Tensor,
        train_y: Tensor,
        test_z: Tensor,
        test_y: Tensor,
        solver: str = 'lbfgs',
        *args,
        **kwargs,
    ) -> float:
        r"""Evaluates latent space quality via a logistic regression downstream
        task.

        A :class:`sklearn.linear_model.LogisticRegression` classifier is fit
        on :obj:`(train_z, train_y)` and its :obj:`score` on
        :obj:`(test_z, test_y)` is returned.

        Args:
            train_z (torch.Tensor): The training features.
            train_y (torch.Tensor): The training labels.
            test_z (torch.Tensor): The test features.
            test_y (torch.Tensor): The test labels.
            solver (str, optional): The solver passed to the classifier.
                (default: :obj:`"lbfgs"`)
            *args: Additional positional arguments of the classifier.
            **kwargs: Additional keyword arguments of the classifier.
        """
        # Imported lazily so that scikit-learn is only needed for evaluation.
        from sklearn.linear_model import LogisticRegression

        clf = LogisticRegression(*args, solver=solver, **kwargs).fit(
            train_z.detach().cpu().numpy(),
            train_y.detach().cpu().numpy(),
        )
        return clf.score(
            test_z.detach().cpu().numpy(),
            test_y.detach().cpu().numpy(),
        )

    def __repr__(self) -> str:
        # Shown as `ClassName(num_embeddings, embedding_dim)`.
        return (f'{self.__class__.__name__}({self.embedding.weight.size(0)}, '
                f'{self.embedding.weight.size(1)})')