import torch
from torch.nn import Embedding
from torch.utils.data import DataLoader
from torch_sparse import SparseTensor
from sklearn.linear_model import LogisticRegression
from torch_geometric.utils.num_nodes import maybe_num_nodes
# `torch-cluster` is an optional dependency. When it cannot be imported,
# `random_walk` is bound to `None` so that code needing it can test for that
# value instead of failing at import time of this module.
try:
    import torch_cluster  # noqa
    random_walk = torch.ops.torch_cluster.random_walk
except ImportError:
    random_walk = None

# Small constant added to the argument of `torch.log` in this file so that
# the logarithm is never taken of exactly zero.
EPS = 1e-15
class Node2Vec(torch.nn.Module):
    r"""The Node2Vec model from the
    `"node2vec: Scalable Feature Learning for Networks"
    <https://arxiv.org/abs/1607.00653>`_ paper where random walks of
    length :obj:`walk_length` are sampled in a given graph, and node embeddings
    are learned via negative sampling optimization.

    .. note::

        For an example of using Node2Vec, see `examples/node2vec.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        node2vec.py>`_.

    Args:
        edge_index (LongTensor): The edge indices.
        embedding_dim (int): The size of each embedding vector.
        walk_length (int): The walk length. Must not be smaller than
            :obj:`context_size`.
        context_size (int): The actual context size which is considered for
            positive samples. This parameter increases the effective sampling
            rate by reusing samples across different source nodes.
        walks_per_node (int, optional): The number of walks to sample for each
            node. (default: :obj:`1`)
        p (float, optional): Likelihood of immediately revisiting a node in the
            walk. (default: :obj:`1`)
        q (float, optional): Control parameter to interpolate between
            breadth-first strategy and depth-first strategy (default: :obj:`1`)
        num_negative_samples (int, optional): The number of negative samples to
            use for each positive sample. (default: :obj:`1`)
        num_nodes (int, optional): The number of nodes. (default: :obj:`None`)
        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. the
            weight matrix will be sparse. (default: :obj:`False`)

    Raises:
        ImportError: If :obj:`torch-cluster` is not installed.
        AssertionError: If :obj:`walk_length` is smaller than
            :obj:`context_size`.
    """
    def __init__(self, edge_index, embedding_dim, walk_length, context_size,
                 walks_per_node=1, p=1, q=1, num_negative_samples=1,
                 num_nodes=None, sparse=False):
        super().__init__()

        if random_walk is None:
            raise ImportError('`Node2Vec` requires `torch-cluster`.')

        N = maybe_num_nodes(edge_index, num_nodes)
        row, col = edge_index
        # The adjacency structure is always kept on the CPU.
        adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N))
        self.adj = adj.to('cpu')

        assert walk_length >= context_size, (
            '`walk_length` must be greater than or equal to `context_size`')

        self.embedding_dim = embedding_dim
        # Stored as "number of columns after the start node": `neg_sample`
        # prepends the start node to `self.walk_length` random columns, so
        # each of its walks has `walk_length` entries in total.
        self.walk_length = walk_length - 1
        self.context_size = context_size
        self.walks_per_node = walks_per_node
        self.p = p
        self.q = q
        self.num_negative_samples = num_negative_samples

        self.embedding = Embedding(N, embedding_dim, sparse=sparse)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initializes the embedding table."""
        self.embedding.reset_parameters()

    def forward(self, batch=None):
        """Returns the embeddings for the nodes in :obj:`batch`.

        If :obj:`batch` is :obj:`None`, the full embedding matrix is
        returned.
        """
        emb = self.embedding.weight
        return emb if batch is None else emb[batch]

    def loader(self, **kwargs):
        """Returns a :class:`DataLoader` over all node indices.

        Every batch of node indices is collated by :meth:`sample`, so the
        loader yields :obj:`(pos_rw, neg_rw)` pairs. All keyword arguments
        are forwarded to :class:`torch.utils.data.DataLoader`.
        """
        return DataLoader(range(self.adj.sparse_size(0)),
                          collate_fn=self.sample, **kwargs)

    def _context_windows(self, rw):
        """Cuts every row of :obj:`rw` into all contiguous windows of
        :obj:`context_size` columns and stacks the windows along dim 0."""
        num_windows = 1 + self.walk_length + 1 - self.context_size
        return torch.cat(
            [rw[:, j:j + self.context_size] for j in range(num_windows)],
            dim=0)

    def pos_sample(self, batch):
        """Samples random walks starting at the nodes in :obj:`batch` and
        returns their context windows as positive samples."""
        batch = batch.repeat(self.walks_per_node)
        rowptr, col, _ = self.adj.csr()
        rw = random_walk(rowptr, col, batch, self.walk_length, self.p, self.q)
        # NOTE(review): the sampler may return a non-tensor container;
        # presumably its first item holds the node walks -- confirm against
        # the installed `torch-cluster` version.
        if not isinstance(rw, torch.Tensor):
            rw = rw[0]
        return self._context_windows(rw)

    def neg_sample(self, batch):
        """Builds negative samples: every start node in :obj:`batch` is
        followed by uniformly drawn node indices, and the resulting rows are
        cut into context windows."""
        batch = batch.repeat(self.walks_per_node * self.num_negative_samples)
        # Draw with the dtype and device of `batch` so the concatenation
        # below neither fails nor silently changes the index dtype.
        rw = torch.randint(self.adj.sparse_size(0),
                           (batch.size(0), self.walk_length),
                           dtype=batch.dtype, device=batch.device)
        rw = torch.cat([batch.view(-1, 1), rw], dim=-1)
        return self._context_windows(rw)

    def sample(self, batch):
        """Returns a :obj:`(pos_rw, neg_rw)` tuple of positive and negative
        samples for the node indices in :obj:`batch`.

        Anything that is not already a tensor (e.g. the list of integers
        handed over by a data loader) is converted with
        :func:`torch.tensor` first.
        """
        if not isinstance(batch, torch.Tensor):
            batch = torch.tensor(batch)
        return self.pos_sample(batch), self.neg_sample(batch)

    def _walk_scores(self, rw):
        """Returns the flattened dot products between the embedding of the
        first node of each row in :obj:`rw` and the embeddings of all
        remaining nodes of that row."""
        num_rows = rw.size(0)
        start, rest = rw[:, 0], rw[:, 1:].contiguous()
        h_start = self.embedding(start).view(num_rows, 1, self.embedding_dim)
        h_rest = self.embedding(rest.view(-1)).view(num_rows, -1,
                                                    self.embedding_dim)
        # `h_start` broadcasts over the context dimension of `h_rest`.
        return (h_start * h_rest).sum(dim=-1).view(-1)

    def loss(self, pos_rw, neg_rw):
        r"""Computes the loss given positive and negative random walks.

        The result is the sum of the mean negative log-sigmoid score of the
        positive pairs and the mean negative log of one minus the sigmoid
        score of the negative pairs.
        """
        # Positive loss.
        pos_out = self._walk_scores(pos_rw)
        pos_loss = -torch.log(torch.sigmoid(pos_out) + EPS).mean()

        # Negative loss.
        neg_out = self._walk_scores(neg_rw)
        neg_loss = -torch.log(1 - torch.sigmoid(neg_out) + EPS).mean()

        return pos_loss + neg_loss

    def test(self, train_z, train_y, test_z, test_y, solver='lbfgs',
             multi_class='auto', *args, **kwargs):
        r"""Evaluates latent space quality via a logistic regression downstream
        task.

        A :class:`sklearn.linear_model.LogisticRegression` classifier is
        fitted on :obj:`(train_z, train_y)` and its :meth:`score` on
        :obj:`(test_z, test_y)` is returned. Additional positional and
        keyword arguments are forwarded to the classifier.
        """
        # `multi_class` is deprecated in recent scikit-learn releases and
        # `'auto'` is the library default there, so it is only forwarded
        # when the caller explicitly asks for something else.
        if multi_class != 'auto':
            kwargs['multi_class'] = multi_class

        clf = LogisticRegression(*args, solver=solver, **kwargs)
        clf.fit(train_z.detach().cpu().numpy(),
                train_y.detach().cpu().numpy())
        return clf.score(test_z.detach().cpu().numpy(),
                         test_y.detach().cpu().numpy())

    def __repr__(self):
        num_nodes, dim = self.embedding.weight.size()
        return f'{self.__class__.__name__}({num_nodes}, {dim})'