import torch

from torch_geometric.data import Data
from torch_geometric.datasets.motif_generator import CustomMotif

class CycleMotif(CustomMotif):
    r"""Generates the cycle motif from the `"GNNExplainer: Generating
    Explanations for Graph Neural Networks"
    <https://arxiv.org/abs/1903.03894>`__ paper.

    Node :obj:`i` is connected to its two ring neighbors
    :obj:`(i - 1) % num_nodes` and :obj:`(i + 1) % num_nodes` in both
    directions. The resulting :obj:`edge_index` therefore holds
    :obj:`2 * num_nodes` directed edges, ordered by source node and then by
    target node.

    Args:
        num_nodes (int): The number of nodes in the cycle.
    """
    def __init__(self, num_nodes: int):
        self.num_nodes = num_nodes

        # Every node is the source of exactly two edges: [0, 0, 1, 1, ...].
        row = torch.arange(num_nodes).view(-1, 1).repeat(1, 2).view(-1)
        # Previous and next node on the ring (wrapping around via modulo).
        col1 = torch.arange(-1, num_nodes - 1) % num_nodes
        col2 = torch.arange(1, num_nodes + 1) % num_nodes
        # Sort each node's two neighbors so edges are ordered by (row, col),
        # then flatten so the targets line up with `row`.
        col = torch.stack([col1, col2], dim=1).sort(dim=-1)[0].view(-1)

        # NOTE(review): for num_nodes < 3 this degenerates into duplicate
        # edges (num_nodes == 2) or two self-loops on node 0 (num_nodes == 1).
        structure = Data(
            num_nodes=num_nodes,
            edge_index=torch.stack([row, col], dim=0),
        )
        super().__init__(structure)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.num_nodes})'