from typing import Optional
from torch_geometric.typing import Adj
import torch
from torch import Tensor
from torch_sparse import SparseTensor
from torch_scatter import scatter_add
[docs]class WLConv(torch.nn.Module):
r"""The Weisfeiler Lehman operator from the `"A Reduction of a Graph to a
Canonical Form and an Algebra Arising During this Reduction"
<https://www.iti.zcu.cz/wl2018/pdf/wl_paper_translation.pdf>`_ paper, which
iteratively refines node colorings:
.. math::
\mathbf{x}^{\prime}_i = \textrm{hash} \left( \mathbf{x}_i, \{
\mathbf{x}_j \colon j \in \mathcal{N}(i) \} \right)
"""
def __init__(self):
super(WLConv, self).__init__()
self.hashmap = {}
[docs] def reset_parameters(self):
self.hashmap = {}
[docs] @torch.no_grad()
def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
""""""
if x.dim() > 1:
assert (x.sum(dim=-1) == 1).sum() == x.size(0)
x = x.argmax(dim=-1) # one-hot -> integer.
assert x.dtype == torch.long
adj_t = edge_index
if not isinstance(adj_t, SparseTensor):
adj_t = SparseTensor(row=edge_index[1], col=edge_index[0],
sparse_sizes=(x.size(0), x.size(0)))
out = []
_, col, _ = adj_t.coo()
deg = adj_t.storage.rowcount().tolist()
for node, neighbors in zip(x.tolist(), x[col].split(deg)):
idx = hash(tuple([node] + neighbors.sort()[0].tolist()))
if idx not in self.hashmap:
self.hashmap[idx] = len(self.hashmap)
out.append(self.hashmap[idx])
return torch.tensor(out, device=x.device)
[docs] def histogram(self, x: Tensor, batch: Optional[Tensor] = None,
norm: bool = False) -> Tensor:
r"""Given a node coloring :obj:`x`, computes the color histograms of
the respective graphs (separated by :obj:`batch`)."""
if batch is None:
batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
num_colors = len(self.hashmap)
batch_size = int(batch.max()) + 1
index = batch * num_colors + x
out = scatter_add(torch.ones_like(index), index, dim=0,
dim_size=num_colors * batch_size)
out = out.view(batch_size, num_colors)
if norm:
out = out.to(torch.float)
out /= out.norm(dim=-1, keepdim=True)
return out
def __repr__(self):
return '{}()'.format(self.__class__.__name__)