Source code for torch_geometric.datasets.motif_generator.custom

from typing import Any, Optional

from torch_geometric.data import Data
from torch_geometric.datasets.motif_generator import MotifGenerator
from torch_geometric.utils import from_networkx


[docs]class CustomMotif(MotifGenerator): r"""Generates a motif based on a custom structure coming from a :class:`torch_geometric.data.Data` or :class:`networkx.Graph` object. Args: structure (torch_geometric.data.Data or networkx.Graph): The structure to use as a motif. """ def __init__(self, structure: Any): super().__init__() self.structure: Optional[Data] = None if isinstance(structure, Data): self.structure = structure else: try: import networkx as nx if isinstance(structure, nx.Graph): self.structure = from_networkx(structure) except ImportError: pass if self.structure is None: raise ValueError(f"Expected a motif structure of type " f"'torch_geometric.data.Data' or 'networkx.Graph'" f"(got {type(structure)})") def __call__(self) -> Data: assert isinstance(self.structure, Data) return self.structure