Source code for torch_geometric.transforms.remove_isolated_nodes

import re

import torch

from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import remove_isolated_nodes

[docs]class RemoveIsolatedNodes(BaseTransform): r"""Removes isolated nodes from the graph.""" def __call__(self, data): num_nodes = data.num_nodes out = remove_isolated_nodes(data.edge_index, data.edge_attr, num_nodes) data.edge_index, data.edge_attr, mask = out if hasattr(data, '__num_nodes__'): data.num_nodes = int(mask.sum()) for key, item in data: if bool('edge', key)): continue if torch.is_tensor(item) and item.size(0) == num_nodes: data[key] = item[mask] return data