# Source code for torch_geometric.nn.conv.wl_conv

from typing import Optional

import torch
from torch import Tensor

from torch_geometric.utils import (
degree,
is_sparse,
scatter,
sort_edge_index,
to_edge_index,
)

[docs]class WLConv(torch.nn.Module):
r"""The Weisfeiler Lehman (WL) operator from the "A Reduction of a Graph
to a Canonical Form and an Algebra Arising During this Reduction"
<https://www.iti.zcu.cz/wl2018/pdf/wl_paper_translation.pdf>_ paper.

:class:WLConv iteratively refines node colorings according to:

.. math::
\mathbf{x}^{\prime}_i = \textrm{hash} \left( \mathbf{x}_i, \{
\mathbf{x}_j \colon j \in \mathcal{N}(i) \} \right)

Shapes:
- **input:**
node coloring :math:(|\mathcal{V}|, F_{in}) *(one-hot encodings)*
or :math:(|\mathcal{V}|) *(integer-based)*,
edge indices :math:(2, |\mathcal{E}|)
- **output:** node coloring :math:(|\mathcal{V}|) *(integer-based)*
"""
def __init__(self):
super().__init__()
self.hashmap = {}

[docs]    def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
self.hashmap = {}

def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
r"""Runs the forward pass of the module."""
if x.dim() > 1:
assert (x.sum(dim=-1) == 1).sum() == x.size(0)
x = x.argmax(dim=-1)  # one-hot -> integer.
assert x.dtype == torch.long

if is_sparse(edge_index):
col_and_row, _ = to_edge_index(edge_index)
col = col_and_row[0]
row = col_and_row[1]
else:
edge_index = sort_edge_index(edge_index, num_nodes=x.size(0),
sort_by_row=False)
row, col = edge_index[0], edge_index[1]

# col is sorted, so we can use it to split neighbors to groups:
deg = degree(col, x.size(0), dtype=torch.long).tolist()

out = []
for node, neighbors in zip(x.tolist(), x[row].split(deg)):
idx = hash(tuple([node] + neighbors.sort()[0].tolist()))
if idx not in self.hashmap:
self.hashmap[idx] = len(self.hashmap)
out.append(self.hashmap[idx])

[docs]    def histogram(self, x: Tensor, batch: Optional[Tensor] = None,
norm: bool = False) -> Tensor:
r"""Given a node coloring :obj:x, computes the color histograms of
the respective graphs (separated by :obj:batch).
"""
if batch is None:
batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

num_colors = len(self.hashmap)
batch_size = int(batch.max()) + 1

index = batch * num_colors + x
out = scatter(torch.ones_like(index), index, dim=0,
dim_size=num_colors * batch_size, reduce='sum')
out = out.view(batch_size, num_colors)

if norm:
out = out.to(torch.float)
out /= out.norm(dim=-1, keepdim=True)

return out