Source code for torch_geometric.transforms.radius_graph

import torch_geometric
from torch_geometric.transforms import BaseTransform

[docs]class RadiusGraph(BaseTransform): r"""Creates edges based on node positions :obj:`pos` to all points within a given distance. Args: r (float): The distance. loop (bool, optional): If :obj:`True`, the graph will contain self-loops. (default: :obj:`False`) max_num_neighbors (int, optional): The maximum number of neighbors to return for each element in :obj:`y`. This flag is only needed for CUDA tensors. (default: :obj:`32`) flow (string, optional): The flow direction when using in combination with message passing (:obj:`"source_to_target"` or :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`) """ def __init__( self, r: float, loop: bool = False, max_num_neighbors: int = 32, flow: str = 'source_to_target', ): self.r = r self.loop = loop self.max_num_neighbors = max_num_neighbors self.flow = flow def __call__(self, data): data.edge_attr = None batch = data.batch if 'batch' in data else None data.edge_index = torch_geometric.nn.radius_graph( data.pos, self.r, batch, self.loop, self.max_num_neighbors, self.flow) return data def __repr__(self) -> str: return f'{self.__class__.__name__}(r={self.r})'