# Source code for torch_geometric.nn.conv.wl_conv

from typing import Optional

import torch
from torch import Tensor

from torch_geometric.typing import Adj
from torch_geometric.utils import (
    degree,
    is_sparse,
    scatter,
    sort_edge_index,
    to_edge_index,
)


class WLConv(torch.nn.Module):
    r"""The Weisfeiler-Lehman (WL) operator from the `"A Reduction of a Graph
    to a Canonical Form and an Algebra Arising During this Reduction"
    <https://www.iti.zcu.cz/wl2018/pdf/wl_paper_translation.pdf>`_ paper.

    :class:`WLConv` iteratively refines node colorings according to:

    .. math::
        \mathbf{x}^{\prime}_i = \textrm{hash} \left( \mathbf{x}_i, \{
        \mathbf{x}_j \colon j \in \mathcal{N}(i) \} \right)

    New colors are consecutive integer ids handed out in order of first
    appearance. The signature-to-id mapping lives in :obj:`self.hashmap`
    and persists across calls, so colorings produced by successive calls
    share one color space until :meth:`reset_parameters` is called.

    Shapes:
        - **input:**
          node coloring :math:`(|\mathcal{V}|, F_{in})` *(one-hot encodings)*
          or :math:`(|\mathcal{V}|)` *(integer-based)*,
          edge indices :math:`(2, |\mathcal{E}|)`
        - **output:** node coloring :math:`(|\mathcal{V}|)` *(integer-based)*
    """
    def __init__(self):
        super().__init__()
        # Maps a color signature (own color, *sorted neighbor colors) to the
        # integer id of the refined color.
        self.hashmap = {}

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.hashmap = {}

    @torch.no_grad()
    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor): Node coloring, either one-hot encoded with shape
                :obj:`[num_nodes, num_colors]` or integer-based (dtype
                :obj:`torch.long`) with shape :obj:`[num_nodes]`.
            edge_index (torch.Tensor or sparse tensor): Graph connectivity.

        Returns:
            torch.Tensor: The refined coloring as a :obj:`torch.long` tensor of
            shape :obj:`[num_nodes]`, on the same device as :obj:`x`.
        """
        if x.dim() > 1:
            # Every row must be a valid one-hot vector.
            assert bool((x.sum(dim=-1) == 1).all())
            x = x.argmax(dim=-1)  # one-hot -> integer.
        assert x.dtype == torch.long

        if is_sparse(edge_index):
            # The first index of the sparse layout is used as `col` (the
            # grouping index). NOTE(review): this presumes a transposed
            # (adj_t) layout whose indices come back sorted by that first
            # index - confirm against `to_edge_index`.
            col_and_row, _ = to_edge_index(edge_index)
            col = col_and_row[0]
            row = col_and_row[1]
        else:
            edge_index = sort_edge_index(edge_index, num_nodes=x.size(0),
                                         sort_by_row=False)
            row, col = edge_index[0], edge_index[1]

        # `col` is sorted, so we can use it to `split` neighbors to groups:
        deg = degree(col, x.size(0), dtype=torch.long).tolist()

        out = []
        for node, neighbors in zip(x.tolist(), x[row].split(deg)):
            # Key on the signature itself rather than on `hash()` of it:
            # Python integer hashes collide (e.g. `hash(-1) == hash(-2)`),
            # which would silently merge distinct colors.
            key = (node, *neighbors.sort()[0].tolist())
            if key not in self.hashmap:
                self.hashmap[key] = len(self.hashmap)
            out.append(self.hashmap[key])

        # Explicit dtype: `torch.tensor([])` would otherwise infer float.
        return torch.tensor(out, dtype=torch.long, device=x.device)

    def histogram(self, x: Tensor, batch: Optional[Tensor] = None,
                  norm: bool = False) -> Tensor:
        r"""Given a node coloring :obj:`x`, computes the color histograms of
        the respective graphs (separated by :obj:`batch`).

        Args:
            x (torch.Tensor): Integer node coloring. Values are used directly
                as column indices, so they are expected to lie in
                :obj:`[0, len(self.hashmap))` (e.g. the output of
                :meth:`forward`).
            batch (torch.Tensor, optional): Graph assignment of each node. If
                :obj:`None`, all nodes belong to a single graph.
            norm (bool, optional): If :obj:`True`, each histogram is
                L2-normalized and returned as float. Histograms of graphs
                without nodes stay all-zero.

        Returns:
            torch.Tensor: Histograms of shape :obj:`[batch_size, num_colors]`.
        """
        num_colors = len(self.hashmap)

        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
            batch_size = 1  # Avoids `max()` on an empty tensor.
        else:
            batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 0

        # Flatten (graph, color) pairs into one index for a single scatter.
        index = batch * num_colors + x
        out = scatter(torch.ones_like(index), index, dim=0,
                      dim_size=num_colors * batch_size, reduce='sum')
        out = out.view(batch_size, num_colors)

        if norm:
            out = out.to(torch.float)
            # Non-empty histograms have norm >= 1, so the clamp only turns
            # 0 / 0 (NaN) for empty graphs into 0.
            out /= out.norm(dim=-1, keepdim=True).clamp_min(1e-12)

        return out