Source code for torch_geometric.utils.get_laplacian

from typing import Optional

import torch
from torch_scatter import scatter_add

from torch_geometric.utils import add_self_loops, remove_self_loops

from .num_nodes import maybe_num_nodes


[docs]def get_laplacian(edge_index, edge_weight: Optional[torch.Tensor] = None, normalization: Optional[str] = None, dtype: Optional[int] = None, num_nodes: Optional[int] = None): r""" Computes the graph Laplacian of the graph given by :obj:`edge_index` and optional :obj:`edge_weight`. Args: edge_index (LongTensor): The edge indices. edge_weight (Tensor, optional): One-dimensional edge weights. (default: :obj:`None`) normalization (str, optional): The normalization scheme for the graph Laplacian (default: :obj:`None`): 1. :obj:`None`: No normalization :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}` 2. :obj:`"sym"`: Symmetric normalization :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}` 3. :obj:`"rw"`: Random-walk normalization :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}` dtype (torch.dtype, optional): The desired data type of returned tensor in case :obj:`edge_weight=None`. (default: :obj:`None`) num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) Examples: >>> edge_index = torch.tensor([[0, 1, 1, 2], ... [1, 0, 2, 1]]) >>> edge_weight = torch.tensor([1., 2., 2., 4.]) >>> # No normalization >>> lap = get_laplacian(edge_index, edge_weight) >>> # Symmetric normalization >>> lap_sym = get_laplacian(edge_index, edge_weight, normalization='sym') >>> # Random-walk normalization >>> lap_rw = get_laplacian(edge_index, edge_weight, normalization='rw') """ if normalization is not None: assert normalization in ['sym', 'rw'] # 'Invalid normalization' edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) if edge_weight is None: edge_weight = torch.ones(edge_index.size(1), dtype=dtype, device=edge_index.device) num_nodes = maybe_num_nodes(edge_index, num_nodes) row, col = edge_index[0], edge_index[1] deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) if normalization is None: # L = D - A. edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) edge_weight = torch.cat([-edge_weight, deg], dim=0) elif normalization == 'sym': # Compute A_norm = -D^{-1/2} A D^{-1/2}. deg_inv_sqrt = deg.pow_(-0.5) deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0) edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] # L = I - A_norm. edge_index, tmp = add_self_loops(edge_index, -edge_weight, fill_value=1., num_nodes=num_nodes) assert tmp is not None edge_weight = tmp else: # Compute A_norm = -D^{-1} A. deg_inv = 1.0 / deg deg_inv.masked_fill_(deg_inv == float('inf'), 0) edge_weight = deg_inv[row] * edge_weight # L = I - A_norm. edge_index, tmp = add_self_loops(edge_index, -edge_weight, fill_value=1., num_nodes=num_nodes) assert tmp is not None edge_weight = tmp return edge_index, edge_weight