torch_geometric.utils
- scatter(src: Tensor, index: Tensor, dim: int = 0, dim_size: Optional[int] = None, reduce: str = 'sum') → Tensor [source]
Reduces all values from the src tensor at the indices specified in the index tensor along a given dimension dim. See the documentation of the torch_scatter package for more information.
Parameters:
- src (torch.Tensor) – The source tensor.
- index (torch.Tensor) – The index tensor.
- dim (int, optional) – The dimension along which to index. (default: 0)
- dim_size (int, optional) – The size of the output tensor at dimension dim. If set to None, will create a minimal-sized output tensor according to index.max() + 1. (default: None)
- reduce (str, optional) – The reduce operation ("sum", "mean", "mul", "min", "max" or "any"). (default: "sum")
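A minimal usage sketch (illustrative, not part of the original reference), summing four values into two buckets:
Example
>>> src = torch.tensor([1., 2., 3., 4.])
>>> index = torch.tensor([0, 0, 1, 1])
>>> scatter(src, index, dim=0, reduce='sum')
tensor([3., 7.])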
- segment(src: Tensor, ptr: Tensor, reduce: str = 'sum') → Tensor [source]
Reduces all values in the first dimension of the src tensor within the ranges specified in ptr. See the documentation of the torch_scatter package for more information.
Parameters:
- src (torch.Tensor) – The source tensor.
- ptr (torch.Tensor) – A monotonically increasing pointer tensor that refers to the boundaries of segments such that ptr[0] = 0 and ptr[-1] = src.size(0).
- reduce (str, optional) – The reduce operation ("sum", "mean", "mul", "min" or "max"). (default: "sum")
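A minimal usage sketch (illustrative, not part of the original reference); ptr splits the four values into segments [0:2] and [2:4]:
Example
>>> src = torch.tensor([1., 2., 3., 4.])
>>> ptr = torch.tensor([0, 2, 4])
>>> segment(src, ptr, reduce='sum')
tensor([3., 7.])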
- index_sort(inputs: Tensor, max_value: Optional[int] = None) → Tuple[Tensor, Tensor] [source]
Sorts the elements of the inputs tensor in ascending order. It is expected that inputs is one-dimensional and that it only contains positive integer values. If max_value is given, it can be used by the underlying algorithm for better performance.
Parameters:
- inputs (torch.Tensor) – A vector with positive integer values.
- max_value (int, optional) – The maximum value stored inside inputs. This value can be an estimation, but needs to be greater than or equal to the real maximum. (default: None)
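A minimal usage sketch (illustrative, not part of the original reference); the second return value holds the permutation that sorts the input:
Example
>>> inputs = torch.tensor([3, 1, 2])
>>> index_sort(inputs)
(tensor([1, 2, 3]), tensor([1, 2, 0]))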
- degree(index: Tensor, num_nodes: Optional[int] = None, dtype: Optional[dtype] = None) → Tensor [source]
Computes the (unweighted) degree of a given one-dimensional index tensor.
Parameters:
- index (LongTensor) – Index tensor.
- num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of index. (default: None)
- dtype (torch.dtype, optional) – The desired data type of the returned tensor.
Return type:
Tensor
Example
>>> row = torch.tensor([0, 1, 0, 2, 0])
>>> degree(row, dtype=torch.long)
tensor([3, 1, 1])
- softmax(src: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, num_nodes: Optional[int] = None, dim: int = 0) → Tensor [source]
Computes a sparsely evaluated softmax. Given a value tensor src, this function first groups the values along the first dimension based on the indices specified in index, and then proceeds to compute the softmax individually for each group.
Parameters:
- src (Tensor) – The source tensor.
- index (LongTensor, optional) – The indices of elements for applying the softmax. (default: None)
- ptr (LongTensor, optional) – If given, computes the softmax based on sorted inputs in CSR representation. (default: None)
- num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of index. (default: None)
- dim (int, optional) – The dimension in which to normalize. (default: 0)
Return type:
Tensor
Examples
>>> src = torch.tensor([1., 1., 1., 1.])
>>> index = torch.tensor([0, 0, 1, 2])
>>> ptr = torch.tensor([0, 2, 3, 4])
>>> softmax(src, index)
tensor([0.5000, 0.5000, 1.0000, 1.0000])

>>> softmax(src, None, ptr)
tensor([0.5000, 0.5000, 1.0000, 1.0000])

>>> src = torch.randn(4, 4)
>>> ptr = torch.tensor([0, 4])
>>> softmax(src, index, dim=-1)
tensor([[0.7404, 0.2596, 1.0000, 1.0000],
        [0.1702, 0.8298, 1.0000, 1.0000],
        [0.7607, 0.2393, 1.0000, 1.0000],
        [0.8062, 0.1938, 1.0000, 1.0000]])
- dropout_node(edge_index: Tensor, p: float = 0.5, num_nodes: Optional[int] = None, training: bool = True) → Tuple[Tensor, Tensor, Tensor] [source]
Randomly drops nodes from the adjacency matrix edge_index with probability p using samples from a Bernoulli distribution.
The method returns (1) the retained edge_index, (2) the edge mask indicating which edges were retained, and (3) the node mask indicating which nodes were retained.
Parameters:
- edge_index (LongTensor) – The edge indices.
- p (float, optional) – Dropout probability. (default: 0.5)
- num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
- training (bool, optional) – If set to False, this operation is a no-op. (default: True)
Return type:
(LongTensor, BoolTensor, BoolTensor)
Examples
>>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
...                            [1, 0, 2, 1, 3, 2]])
>>> edge_index, edge_mask, node_mask = dropout_node(edge_index)
>>> edge_index
tensor([[0, 1],
        [1, 0]])
>>> edge_mask
tensor([ True,  True, False, False, False, False])
>>> node_mask
tensor([ True,  True, False, False])
- dropout_edge(edge_index: Tensor, p: float = 0.5, force_undirected: bool = False, training: bool = True) → Tuple[Tensor, Tensor] [source]
Randomly drops edges from the adjacency matrix edge_index with probability p using samples from a Bernoulli distribution.
The method returns (1) the retained edge_index, and (2) the edge mask or index indicating which edges were retained, depending on the argument force_undirected.
Parameters:
- edge_index (LongTensor) – The edge indices.
- p (float, optional) – Dropout probability. (default: 0.5)
- force_undirected (bool, optional) – If set to True, will either drop or keep both edges of an undirected edge. (default: False)
- training (bool, optional) – If set to False, this operation is a no-op. (default: True)
Return type:
(LongTensor, BoolTensor or LongTensor)
Examples
>>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
...                            [1, 0, 2, 1, 3, 2]])
>>> edge_index, edge_mask = dropout_edge(edge_index)
>>> edge_index
tensor([[0, 1, 2, 2],
        [1, 2, 1, 3]])
>>> edge_mask  # masks indicating which edges are retained
tensor([ True, False,  True,  True,  True, False])

>>> edge_index, edge_id = dropout_edge(edge_index,
...                                    force_undirected=True)
>>> edge_index
tensor([[0, 1, 2, 1, 2, 3],
        [1, 2, 3, 0, 1, 2]])
>>> edge_id  # indices indicating which edges are retained
tensor([0, 2, 4, 0, 2, 4])
- dropout_path(edge_index: Tensor, p: float = 0.2, walks_per_node: int = 1, walk_length: int = 3, num_nodes: Optional[int] = None, is_sorted: bool = False, training: bool = True) → Tuple[Tensor, Tensor] [source]
Drops edges from the adjacency matrix edge_index based on random walks. The source nodes to start random walks from are sampled from edge_index with probability p, following a Bernoulli distribution.
The method returns (1) the retained edge_index, and (2) the edge mask indicating which edges were retained.
Parameters:
- edge_index (LongTensor) – The edge indices.
- p (float, optional) – Sample probability. (default: 0.2)
- walks_per_node (int, optional) – The number of walks per node, same as Node2Vec. (default: 1)
- walk_length (int, optional) – The walk length, same as Node2Vec. (default: 3)
- num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
- is_sorted (bool, optional) – If set to True, will expect edge_index to be already sorted row-wise. (default: False)
- training (bool, optional) – If set to False, this operation is a no-op. (default: True)
Return type:
(LongTensor, BoolTensor)
Example
>>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
...                            [1, 0, 2, 1, 3, 2]])
>>> edge_index, edge_mask = dropout_path(edge_index)
>>> edge_index
tensor([[1, 2],
        [2, 3]])
>>> edge_mask  # masks indicating which edges are retained
tensor([False, False,  True, False,  True, False])
- dropout_adj(edge_index: Tensor, edge_attr: Optional[Tensor] = None, p: float = 0.5, force_undirected: bool = False, num_nodes: Optional[int] = None, training: bool = True) → Tuple[Tensor, Optional[Tensor]] [source]
Randomly drops edges from the adjacency matrix (edge_index, edge_attr) with probability p using samples from a Bernoulli distribution.
Warning: dropout_adj is deprecated and will be removed in a future release. Use torch_geometric.utils.dropout_edge instead.
Parameters:
- edge_index (LongTensor) – The edge indices.
- edge_attr (Tensor, optional) – Edge weights or multi-dimensional edge features. (default: None)
- p (float, optional) – Dropout probability. (default: 0.5)
- force_undirected (bool, optional) – If set to True, will either drop or keep both edges of an undirected edge. (default: False)
- num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
- training (bool, optional) – If set to False, this operation is a no-op. (default: True)
Examples
>>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
...                            [1, 0, 2, 1, 3, 2]])
>>> edge_attr = torch.tensor([1, 2, 3, 4, 5, 6])
>>> dropout_adj(edge_index, edge_attr)
(tensor([[0, 1, 2, 3],
         [1, 2, 3, 2]]),
 tensor([1, 3, 5, 6]))

>>> # The returned graph is kept undirected
>>> dropout_adj(edge_index, edge_attr, force_undirected=True)
(tensor([[0, 1, 2, 1, 2, 3],
         [1, 2, 3, 0, 1, 2]]),
 tensor([1, 3, 5, 1, 3, 5]))
- shuffle_node(x: Tensor, batch: Optional[Tensor] = None, training: bool = True) → Tuple[Tensor, Tensor] [source]
Randomly shuffles the feature matrix x along the first dimension.
The method returns (1) the shuffled x, and (2) the permutation indicating the orders of original nodes after shuffling.
Parameters:
- x (FloatTensor) – The feature matrix.
- batch (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)
- training (bool, optional) – If set to False, this operation is a no-op. (default: True)
Return type:
(FloatTensor, LongTensor)
Example
>>> # Standard case
>>> x = torch.tensor([[ 0,  1,  2],
...                   [ 3,  4,  5],
...                   [ 6,  7,  8],
...                   [ 9, 10, 11]], dtype=torch.float)
>>> x, node_perm = shuffle_node(x)
>>> x
tensor([[ 3.,  4.,  5.],
        [ 9., 10., 11.],
        [ 0.,  1.,  2.],
        [ 6.,  7.,  8.]])
>>> node_perm
tensor([1, 3, 0, 2])

>>> # For batched graphs as inputs
>>> batch = torch.tensor([0, 0, 1, 1])
>>> x, node_perm = shuffle_node(x, batch)
>>> x
tensor([[ 3.,  4.,  5.],
        [ 0.,  1.,  2.],
        [ 9., 10., 11.],
        [ 6.,  7.,  8.]])
>>> node_perm
tensor([1, 0, 3, 2])
- mask_feature(x: Tensor, p: float = 0.5, mode: str = 'col', fill_value: float = 0.0, training: bool = True) → Tuple[Tensor, Tensor] [source]
Randomly masks features from the feature matrix x with probability p using samples from a Bernoulli distribution.
The method returns (1) the retained x, and (2) the feature mask broadcastable with x (mode='row' and mode='col') or with the same shape as x (mode='all'), indicating where features are retained.
Parameters:
- x (FloatTensor) – The feature matrix.
- p (float, optional) – The masking ratio. (default: 0.5)
- mode (str, optional) – The masked scheme to use for feature masking ("row", "col" or "all"). If mode='col', will mask entire features of all nodes from the feature matrix. If mode='row', will mask entire nodes from the feature matrix. If mode='all', will mask individual features across all nodes. (default: 'col')
- fill_value (float, optional) – The value for masked features in the output tensor. (default: 0)
- training (bool, optional) – If set to False, this operation is a no-op. (default: True)
Return type:
(FloatTensor, BoolTensor)
Examples
>>> # Masked features are column-wise sampled
>>> x = torch.tensor([[1, 2, 3],
...                   [4, 5, 6],
...                   [7, 8, 9]], dtype=torch.float)
>>> x, feat_mask = mask_feature(x)
>>> x
tensor([[1., 0., 3.],
        [4., 0., 6.],
        [7., 0., 9.]])
>>> feat_mask
tensor([[True, False, True]])

>>> # Masked features are row-wise sampled
>>> x, feat_mask = mask_feature(x, mode='row')
>>> x
tensor([[1., 2., 3.],
        [0., 0., 0.],
        [7., 8., 9.]])
>>> feat_mask
tensor([[True], [False], [True]])

>>> # Masked features are uniformly sampled
>>> x, feat_mask = mask_feature(x, mode='all')
>>> x
tensor([[0., 0., 0.],
        [4., 0., 6.],
        [0., 0., 9.]])
>>> feat_mask
tensor([[False, False, False],
        [ True, False,  True],
        [False, False,  True]])
- add_random_edge(edge_index, p: float, force_undirected: bool = False, num_nodes: Optional[Union[Tuple[int], int]] = None, training: bool = True) → Tuple[Tensor, Tensor] [source]
Randomly adds edges to edge_index.
The method returns (1) the retained edge_index, and (2) the added edge indices.
Parameters:
- edge_index (LongTensor) – The edge indices.
- p (float) – Ratio of added edges to the existing edges.
- force_undirected (bool, optional) – If set to True, added edges will be undirected. (default: False)
- num_nodes (int, Tuple[int], optional) – The overall number of nodes, i.e. max_val + 1, or the number of source and destination nodes, i.e. (max_src_val + 1, max_dst_val + 1) of edge_index. (default: None)
- training (bool, optional) – If set to False, this operation is a no-op. (default: True)
Return type:
(LongTensor, LongTensor)
Examples
>>> # Standard case
>>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
...                            [1, 0, 2, 1, 3, 2]])
>>> edge_index, added_edges = add_random_edge(edge_index, p=0.5)
>>> edge_index
tensor([[0, 1, 1, 2, 2, 3, 2, 1, 3],
        [1, 0, 2, 1, 3, 2, 0, 2, 1]])
>>> added_edges
tensor([[2, 1, 3],
        [0, 2, 1]])

>>> # The returned graph is kept undirected
>>> edge_index, added_edges = add_random_edge(edge_index, p=0.5,
...                                           force_undirected=True)
>>> edge_index
tensor([[0, 1, 1, 2, 2, 3, 2, 1, 3, 0, 2, 1],
        [1, 0, 2, 1, 3, 2, 0, 2, 1, 2, 1, 3]])
>>> added_edges
tensor([[2, 1, 3, 0, 2, 1],
        [0, 2, 1, 2, 1, 3]])

>>> # For bipartite graphs
>>> edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
...                            [2, 3, 1, 4, 2, 1]])
>>> edge_index, added_edges = add_random_edge(edge_index, p=0.5,
...                                           num_nodes=(6, 5))
>>> edge_index
tensor([[0, 1, 2, 3, 4, 5, 3, 4, 1],
        [2, 3, 1, 4, 2, 1, 1, 3, 2]])
>>> added_edges
tensor([[3, 4, 1],
        [1, 3, 2]])
- sort_edge_index(edge_index: Tensor, edge_attr: Union[Tensor, None, List[Tensor], str] = '???', num_nodes: Optional[int] = None, sort_by_row: bool = True) → Union[Tensor, Tuple[Tensor, Optional[Tensor]], Tuple[Tensor, List[Tensor]]] [source]
Row-wise sorts edge_index.
Parameters:
- edge_index (LongTensor) – The edge indices.
- edge_attr (Tensor or List[Tensor], optional) – Edge weights or multi-dimensional edge features. If given as a list, will re-shuffle and remove duplicates for all its entries. (default: None)
- num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
- sort_by_row (bool, optional) – If set to False, will sort edge_index column-wise. (default: True)
Return type:
LongTensor if edge_attr is not passed, else (LongTensor, Optional[Tensor] or List[Tensor])
Warning: From PyG >= 2.3.0 onwards, this function will always return a tuple whenever edge_attr is passed as an argument (even in case it is set to None).
Examples
>>> edge_index = torch.tensor([[2, 1, 1, 0],
...                            [1, 2, 0, 1]])
>>> edge_attr = torch.tensor([[1], [2], [3], [4]])
>>> sort_edge_index(edge_index)
tensor([[0, 1, 1, 2],
        [1, 0, 2, 1]])

>>> sort_edge_index(edge_index, edge_attr)
(tensor([[0, 1, 1, 2],
         [1, 0, 2, 1]]),
 tensor([[4], [3], [2], [1]]))
- coalesce(edge_index: Tensor, edge_attr: Union[Tensor, None, List[Tensor], str] = '???', num_nodes: Optional[int] = None, reduce: str = 'add', is_sorted: bool = False, sort_by_row: bool = True) → Union[Tensor, Tuple[Tensor, Optional[Tensor]], Tuple[Tensor, List[Tensor]]] [source]
Row-wise sorts edge_index and removes its duplicated entries. Duplicate entries in edge_attr are merged by scattering them together according to the given reduce option.
Parameters:
- edge_index (LongTensor) – The edge indices.
- edge_attr (Tensor or List[Tensor], optional) – Edge weights or multi-dimensional edge features. If given as a list, will re-shuffle and remove duplicates for all its entries. (default: None)
- num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
- reduce (str, optional) – The reduce operation to use for merging edge features ("add", "mean", "min", "max", "mul" or "any"). (default: "add")
- is_sorted (bool, optional) – If set to True, will expect edge_index to be already sorted row-wise. (default: False)
- sort_by_row (bool, optional) – If set to False, will sort edge_index column-wise. (default: True)
Return type:
LongTensor if edge_attr is not passed, else (LongTensor, Optional[Tensor] or List[Tensor])
Warning: From PyG >= 2.3.0 onwards, this function will always return a tuple whenever edge_attr is passed as an argument (even in case it is set to None).
Example
>>> edge_index = torch.tensor([[1, 1, 2, 3],
...                            [3, 3, 1, 2]])
>>> edge_attr = torch.tensor([1., 1., 1., 1.])
>>> coalesce(edge_index)
tensor([[1, 2, 3],
        [3, 1, 2]])

>>> # Sort `edge_index` column-wise
>>> coalesce(edge_index, sort_by_row=False)
tensor([[2, 3, 1],
        [1, 2, 3]])

>>> coalesce(edge_index, edge_attr)
(tensor([[1, 2, 3],
         [3, 1, 2]]),
 tensor([2., 1., 1.]))

>>> # Use 'mean' operation to merge edge features
>>> coalesce(edge_index, edge_attr, reduce='mean')
(tensor([[1, 2, 3],
         [3, 1, 2]]),
 tensor([1., 1., 1.]))
- is_undirected(edge_index: Tensor, edge_attr: Union[Tensor, None, List[Tensor]] = None, num_nodes: Optional[int] = None) → bool [source]
Returns True if the graph given by edge_index is undirected.
Parameters:
- edge_index (LongTensor) – The edge indices.
- edge_attr (Tensor or List[Tensor], optional) – Edge weights or multi-dimensional edge features. If given as a list, will check for equivalence in all its entries. (default: None)
- num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
Return type:
bool
Examples
>>> edge_index = torch.tensor([[0, 1, 0],
...                            [1, 0, 0]])
>>> weight = torch.tensor([0, 0, 1])
>>> is_undirected(edge_index, weight)
True

>>> weight = torch.tensor([0, 1, 1])
>>> is_undirected(edge_index, weight)
False
- to_undirected(edge_index: Tensor, edge_attr: Union[Tensor, None, List[Tensor], str] = '???', num_nodes: Optional[int] = None, reduce: str = 'add') → Union[Tensor, Tuple[Optional[Tensor], Tensor], Tuple[Tensor, List[Tensor]]] [source]
Converts the graph given by edge_index to an undirected graph such that \((j,i) \in \mathcal{E}\) for every edge \((i,j) \in \mathcal{E}\).
Parameters:
- edge_index (LongTensor) – The edge indices.
- edge_attr (Tensor or List[Tensor], optional) – Edge weights or multi-dimensional edge features. If given as a list, will remove duplicates for all its entries. (default: None)
- num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
- reduce (str, optional) – The reduce operation to use for merging edge features ("add", "mean", "min", "max" or "mul"). (default: "add")
Return type:
LongTensor if edge_attr is not passed, else (LongTensor, Optional[Tensor] or List[Tensor])
Warning: From PyG >= 2.3.0 onwards, this function will always return a tuple whenever edge_attr is passed as an argument (even in case it is set to None).
Examples
>>> edge_index = torch.tensor([[0, 1, 1],
...                            [1, 0, 2]])
>>> to_undirected(edge_index)
tensor([[0, 1, 1, 2],
        [1, 0, 2, 1]])

>>> edge_weight = torch.tensor([1., 1., 1.])
>>> to_undirected(edge_index, edge_weight)
(tensor([[0, 1, 1, 2],
         [1, 0, 2, 1]]),
 tensor([2., 2., 1., 1.]))

>>> # Use 'mean' operation to merge edge features
>>> to_undirected(edge_index, edge_weight, reduce='mean')
(tensor([[0, 1, 1, 2],
         [1, 0, 2, 1]]),
 tensor([1., 1., 1., 1.]))
- contains_self_loops(edge_index: Tensor) → bool [source]
Returns True if the graph given by edge_index contains self-loops.
Parameters:
- edge_index (LongTensor) – The edge indices.
Return type:
bool
Examples
>>> edge_index = torch.tensor([[0, 1, 0],
...                            [1, 0, 0]])
>>> contains_self_loops(edge_index)
True

>>> edge_index = torch.tensor([[0, 1, 1],
...                            [1, 0, 2]])
>>> contains_self_loops(edge_index)
False
- remove_self_loops(edge_index: Tensor, edge_attr: Optional[Tensor] = None) → Tuple[Tensor, Optional[Tensor]] [source]
Removes every self-loop in the graph given by edge_index, so that \((i,i) \not\in \mathcal{E}\) for every \(i \in \mathcal{V}\).
Parameters:
- edge_index (LongTensor) – The edge indices.
- edge_attr (Tensor, optional) – Edge weights or multi-dimensional edge features. (default: None)
Return type:
(LongTensor, Tensor)
Example
>>> edge_index = torch.tensor([[0, 1, 0],
...                            [1, 0, 0]])
>>> edge_attr = torch.tensor([[1, 2], [3, 4], [5, 6]])
>>> remove_self_loops(edge_index, edge_attr)
(tensor([[0, 1],
         [1, 0]]),
 tensor([[1, 2],
         [3, 4]]))
- segregate_self_loops(edge_index: Tensor, edge_attr: Optional[Tensor] = None) → Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]] [source]
Segregates self-loops from the graph.
Parameters:
- edge_index (LongTensor) – The edge indices.
- edge_attr (Tensor, optional) – Edge weights or multi-dimensional edge features. (default: None)
Return type:
(LongTensor, Tensor, LongTensor, Tensor)
Example
>>> edge_index = torch.tensor([[0, 0, 1],
...                            [0, 1, 0]])
>>> (edge_index, edge_attr,
...  loop_edge_index,
...  loop_edge_attr) = segregate_self_loops(edge_index)
>>> loop_edge_index
tensor([[0],
        [0]])
- add_self_loops(edge_index: Tensor, edge_attr: Optional[Tensor] = None, fill_value: Optional[Union[float, Tensor, str]] = None, num_nodes: Optional[Union[int, Tuple[int, int]]] = None) → Tuple[Tensor, Optional[Tensor]] [source]
Adds a self-loop \((i,i) \in \mathcal{E}\) to every node \(i \in \mathcal{V}\) in the graph given by edge_index. In case the graph is weighted or has multi-dimensional edge features (edge_attr != None), edge features of self-loops will be added according to fill_value.
Parameters:
- edge_index (LongTensor) – The edge indices.
- edge_attr (Tensor, optional) – Edge weights or multi-dimensional edge features. (default: None)
- fill_value (float or Tensor or str, optional) – The way to generate edge features of self-loops (in case edge_attr != None). If given as float or torch.Tensor, edge features of self-loops will be directly given by fill_value. If given as str, edge features of self-loops are computed by aggregating all features of edges that point to the specific node, according to a reduce operation ("add", "mean", "min", "max" or "mul"). (default: 1.)
- num_nodes (int or Tuple[int, int], optional) – The number of nodes, i.e. max_val + 1 of edge_index. If given as a tuple, then edge_index is interpreted as a bipartite graph with shape (num_src_nodes, num_dst_nodes). (default: None)
Return type:
(LongTensor, Tensor)
Examples
>>> edge_index = torch.tensor([[0, 1, 0],
...                            [1, 0, 0]])
>>> edge_weight = torch.tensor([0.5, 0.5, 0.5])
>>> add_self_loops(edge_index)
(tensor([[0, 1, 0, 0, 1],
         [1, 0, 0, 0, 1]]),
 None)

>>> add_self_loops(edge_index, edge_weight)
(tensor([[0, 1, 0, 0, 1],
         [1, 0, 0, 0, 1]]),
 tensor([0.5000, 0.5000, 0.5000, 1.0000, 1.0000]))

>>> # edge features of self-loops are filled by constant `2.0`
>>> add_self_loops(edge_index, edge_weight,
...                fill_value=2.)
(tensor([[0, 1, 0, 0, 1],
         [1, 0, 0, 0, 1]]),
 tensor([0.5000, 0.5000, 0.5000, 2.0000, 2.0000]))

>>> # Use 'add' operation to merge edge features for self-loops
>>> add_self_loops(edge_index, edge_weight,
...                fill_value='add')
(tensor([[0, 1, 0, 0, 1],
         [1, 0, 0, 0, 1]]),
 tensor([0.5000, 0.5000, 0.5000, 1.0000, 0.5000]))
- add_remaining_self_loops(edge_index: Tensor, edge_attr: Optional[Tensor] = None, fill_value: Optional[Union[float, Tensor, str]] = None, num_nodes: Optional[int] = None) → Tuple[Tensor, Optional[Tensor]] [source]
Adds remaining self-loops \((i,i) \in \mathcal{E}\) to every node \(i \in \mathcal{V}\) in the graph given by edge_index. In case the graph is weighted or has multi-dimensional edge features (edge_attr != None), edge features of non-existing self-loops will be added according to fill_value.
Parameters:
- edge_index (LongTensor) – The edge indices.
- edge_attr (Tensor, optional) – Edge weights or multi-dimensional edge features. (default: None)
- fill_value (float or Tensor or str, optional) – The way to generate edge features of self-loops (in case edge_attr != None). If given as float or torch.Tensor, edge features of self-loops will be directly given by fill_value. If given as str, edge features of self-loops are computed by aggregating all features of edges that point to the specific node, according to a reduce operation ("add", "mean", "min", "max" or "mul"). (default: 1.)
- num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
Return type:
(LongTensor, Tensor)
Example
>>> edge_index = torch.tensor([[0, 1],
...                            [1, 0]])
>>> edge_weight = torch.tensor([0.5, 0.5])
>>> add_remaining_self_loops(edge_index, edge_weight)
(tensor([[0, 1, 0, 1],
         [1, 0, 0, 1]]),
 tensor([0.5000, 0.5000, 1.0000, 1.0000]))
- get_self_loop_attr(edge_index: Tensor, edge_attr: Optional[Tensor] = None, num_nodes: Optional[int] = None) → Tensor [source]
Returns the edge features or weights of self-loops \((i, i)\) of every node \(i \in \mathcal{V}\) in the graph given by edge_index. Edge features of missing self-loops not present in edge_index will be filled with zeros. If edge_attr is not given, it will be the vector of ones.
Note: This operation is analogous to getting the diagonal elements of the dense adjacency matrix.
Parameters:
- edge_index (LongTensor) – The edge indices.
- edge_attr (Tensor, optional) – Edge weights or multi-dimensional edge features. (default: None)
- num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
Return type:
Tensor
Examples
>>> edge_index = torch.tensor([[0, 1, 0],
...                            [1, 0, 0]])
>>> edge_weight = torch.tensor([0.2, 0.3, 0.5])
>>> get_self_loop_attr(edge_index, edge_weight)
tensor([0.5000, 0.0000])

>>> get_self_loop_attr(edge_index, edge_weight, num_nodes=4)
tensor([0.5000, 0.0000, 0.0000, 0.0000])
- contains_isolated_nodes(edge_index: Tensor, num_nodes: Optional[int] = None) → bool [source]
Returns True if the graph given by edge_index contains isolated nodes.
Parameters:
- edge_index (LongTensor) – The edge indices.
- num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
Return type:
bool
Examples
>>> edge_index = torch.tensor([[0, 1, 0],
...                            [1, 0, 0]])
>>> contains_isolated_nodes(edge_index)
False

>>> contains_isolated_nodes(edge_index, num_nodes=3)
True
- remove_isolated_nodes(edge_index: Tensor, edge_attr: Optional[Tensor] = None, num_nodes: Optional[int] = None) → Tuple[Tensor, Optional[Tensor], Tensor] [source]
Removes the isolated nodes from the graph given by edge_index with optional edge attributes edge_attr. In addition, returns a mask of shape [num_nodes] to manually filter out isolated node features later on. Self-loops are preserved for non-isolated nodes.
Parameters:
- edge_index (LongTensor) – The edge indices.
- edge_attr (Tensor, optional) – Edge weights or multi-dimensional edge features. (default: None)
- num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
Return type:
(LongTensor, Tensor, BoolTensor)
Examples
>>> edge_index = torch.tensor([[0, 1, 0],
...                            [1, 0, 0]])
>>> edge_index, edge_attr, mask = remove_isolated_nodes(edge_index)
>>> mask  # node mask (2 nodes)
tensor([True, True])

>>> edge_index, edge_attr, mask = remove_isolated_nodes(edge_index,
...                                                     num_nodes=3)
>>> mask  # node mask (3 nodes)
tensor([True, True, False])
- get_num_hops(model: Module) → int [source]
Returns the number of hops the model is aggregating information from.
Example
>>> class GNN(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.conv1 = GCNConv(3, 16)
...         self.conv2 = GCNConv(16, 16)
...         self.lin = Linear(16, 2)
...
...     def forward(self, x, edge_index):
...         x = self.conv1(x, edge_index).relu()
...         x = self.conv2(x, edge_index)
...         return self.lin(x)
>>> get_num_hops(GNN())
2
- subgraph(subset: Union[Tensor, List[int]], edge_index: Tensor, edge_attr: Optional[Tensor] = None, relabel_nodes: bool = False, num_nodes: Optional[int] = None, return_edge_mask: bool = False) → Union[Tuple[Tensor, Optional[Tensor]], Tuple[Tensor, Optional[Tensor], Optional[Tensor]]] [source]
Returns the induced subgraph of (edge_index, edge_attr) containing the nodes in subset.
Parameters:
- subset (LongTensor, BoolTensor or [int]) – The nodes to keep.
- edge_index (LongTensor) – The edge indices.
- edge_attr (Tensor, optional) – Edge weights or multi-dimensional edge features. (default: None)
- relabel_nodes (bool, optional) – If set to True, the resulting edge_index will be relabeled to hold consecutive indices starting from zero. (default: False)
- num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
- return_edge_mask (bool, optional) – If set to True, will return the edge mask to filter out additional edge features. (default: False)
Return type:
(LongTensor, Tensor)
Examples
>>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6],
...                            [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5]])
>>> edge_attr = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
>>> subset = torch.tensor([3, 4, 5])
>>> subgraph(subset, edge_index, edge_attr)
(tensor([[3, 4, 4, 5],
         [4, 3, 5, 4]]),
 tensor([ 7.,  8.,  9., 10.]))

>>> subgraph(subset, edge_index, edge_attr, return_edge_mask=True)
(tensor([[3, 4, 4, 5],
         [4, 3, 5, 4]]),
 tensor([ 7.,  8.,  9., 10.]),
 tensor([False, False, False, False, False, False,  True,
          True,  True,  True, False, False]))
- bipartite_subgraph(subset: Union[Tuple[Tensor, Tensor], Tuple[List[int], List[int]]], edge_index: Tensor, edge_attr: Optional[Tensor] = None, relabel_nodes: bool = False, size: Optional[Tuple[int, int]] = None, return_edge_mask: bool = False) → Union[Tuple[Tensor, Optional[Tensor]], Tuple[Tensor, Optional[Tensor], Optional[Tensor]]] [source]
Returns the induced subgraph of the bipartite graph (edge_index, edge_attr) containing the nodes in subset.
Parameters:
- subset (Tuple[Tensor, Tensor] or tuple([int], [int])) – The nodes to keep.
- edge_index (LongTensor) – The edge indices.
- edge_attr (Tensor, optional) – Edge weights or multi-dimensional edge features. (default: None)
- relabel_nodes (bool, optional) – If set to True, the resulting edge_index will be relabeled to hold consecutive indices starting from zero. (default: False)
- size (tuple, optional) – The number of nodes. (default: None)
- return_edge_mask (bool, optional) – If set to True, will return the edge mask to filter out additional edge features. (default: False)
Return type:
(LongTensor, Tensor)
Examples
>>> edge_index = torch.tensor([[0, 5, 2, 3, 3, 4, 4, 3, 5, 5, 6],
...                            [0, 0, 3, 2, 0, 0, 2, 1, 2, 3, 1]])
>>> edge_attr = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
>>> subset = (torch.tensor([2, 3, 5]), torch.tensor([2, 3]))
>>> bipartite_subgraph(subset, edge_index, edge_attr)
(tensor([[2, 3, 5, 5],
         [3, 2, 2, 3]]),
 tensor([ 3,  4,  9, 10]))

>>> bipartite_subgraph(subset, edge_index, edge_attr,
...                    return_edge_mask=True)
(tensor([[2, 3, 5, 5],
         [3, 2, 2, 3]]),
 tensor([ 3,  4,  9, 10]),
 tensor([False, False,  True,  True, False, False, False,
         False,  True,  True, False]))
- k_hop_subgraph(node_idx: Union[int, List[int], Tensor], num_hops: int, edge_index: Tensor, relabel_nodes: bool = False, num_nodes: Optional[int] = None, flow: str = 'source_to_target', directed: bool = False) → Tuple[Tensor, Tensor, Tensor, Tensor] [source]
Computes the induced subgraph of edge_index around all nodes in node_idx reachable within \(k\) hops.
The flow argument denotes the direction of edges for finding \(k\)-hop neighbors. If set to "source_to_target", then the method will find all neighbors that point to the initial set of seed nodes in node_idx. This mimics the natural flow of message passing in Graph Neural Networks.
The method returns (1) the nodes involved in the subgraph, (2) the filtered edge_index connectivity, (3) the mapping from node indices in node_idx to their new location, and (4) the edge mask indicating which edges were preserved.
Parameters:
- node_idx (int, list, tuple or torch.Tensor) – The central seed node(s).
- num_hops (int) – The number of hops \(k\).
- edge_index (LongTensor) – The edge indices.
- relabel_nodes (bool, optional) – If set to True, the resulting edge_index will be relabeled to hold consecutive indices starting from zero. (default: False)
- num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
- flow (str, optional) – The flow direction of \(k\)-hop aggregation ("source_to_target" or "target_to_source"). (default: "source_to_target")
- directed (bool, optional) – If set to False, will include all edges between all sampled nodes. (default: False)
Return type:
(LongTensor, LongTensor, LongTensor, BoolTensor)
Examples
>>> edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
...                            [2, 2, 4, 4, 6, 6]])

>>> # Center node 6, 2-hops
>>> subset, edge_index, mapping, edge_mask = k_hop_subgraph(
...     6, 2, edge_index, relabel_nodes=True)
>>> subset
tensor([2, 3, 4, 5, 6])
>>> edge_index
tensor([[0, 1, 2, 3],
        [2, 2, 4, 4]])
>>> mapping
tensor([4])
>>> edge_mask
tensor([False, False,  True,  True,  True,  True])
>>> subset[mapping]
tensor([6])

>>> edge_index = torch.tensor([[1, 2, 4, 5],
...                            [0, 1, 5, 6]])
>>> (subset, edge_index,
...  mapping, edge_mask) = k_hop_subgraph([0, 6], 2,
...                                       edge_index,
...                                       relabel_nodes=True)
>>> subset
tensor([0, 1, 2, 4, 5, 6])
>>> edge_index
tensor([[1, 2, 3, 4],
        [0, 1, 4, 5]])
>>> mapping
tensor([0, 5])
>>> edge_mask
tensor([True, True, True, True])
>>> subset[mapping]
tensor([0, 6])
- homophily(edge_index: Union[Tensor, SparseTensor], y: Tensor, batch: Optional[Tensor] = None, method: str = 'edge') → Union[float, Tensor] [source]
The homophily of a graph characterizes how likely nodes with the same label are near each other in a graph. There are many measures of homophily that fit this definition. In particular:
In the "Beyond Homophily in Graph Neural Networks: Current Limitations and Effective Designs" paper, the homophily is the fraction of edges in a graph which connect nodes that have the same class label:
\[\frac{| \{ (v,w) : (v,w) \in \mathcal{E} \wedge y_v = y_w \} | } {|\mathcal{E}|}\]
That measure is called the edge homophily ratio.
In the "Geom-GCN: Geometric Graph Convolutional Networks" paper, edge homophily is normalized across neighborhoods:
\[\frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \frac{ | \{ (w,v) : w \in \mathcal{N}(v) \wedge y_v = y_w \} | } { |\mathcal{N}(v)| }\]
That measure is called the node homophily ratio.
In the "Large-Scale Learning on Non-Homophilous Graphs: New Benchmarks and Strong Simple Methods" paper, edge homophily is modified to be insensitive to the number of classes and size of each class:
\[\frac{1}{C-1} \sum_{k=1}^{C} \max \left(0, h_k - \frac{|\mathcal{C}_k|} {|\mathcal{V}|} \right),\]
where \(C\) denotes the number of classes, \(|\mathcal{C}_k|\) denotes the number of nodes of class \(k\), and \(h_k\) denotes the edge homophily ratio of nodes of class \(k\).
Thus, that measure is called the class insensitive edge homophily ratio.
Parameters:
- edge_index (Tensor or SparseTensor) – The graph connectivity.
- y (Tensor) – The labels.
- batch (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)
- method (str, optional) – The method used to calculate the homophily, either "edge" (first formula), "node" (second formula) or "edge_insensitive" (third formula). (default: "edge")
Examples
>>> edge_index = torch.tensor([[0, 1, 2, 3],
...                            [1, 2, 0, 4]])
>>> y = torch.tensor([0, 0, 0, 0, 1])
>>> # Edge homophily ratio
>>> homophily(edge_index, y, method='edge')
0.75

>>> # Node homophily ratio
>>> homophily(edge_index, y, method='node')
0.6000000238418579

>>> # Class insensitive edge homophily ratio
>>> homophily(edge_index, y, method='edge_insensitive')
0.19999998807907104
- assortativity(edge_index: Union[Tensor, SparseTensor]) → float [source]
The degree assortativity coefficient from the "Mixing patterns in networks" paper. Assortativity in a network refers to the tendency of nodes to connect with other similar nodes over dissimilar nodes. It is computed from the Pearson correlation coefficient of the node degrees.
Parameters:
- edge_index (Tensor or SparseTensor) – The graph connectivity.
Returns:
The value of the degree assortativity coefficient for the input graph \(\in [-1, 1]\).
Example
>>> edge_index = torch.tensor([[0, 1, 2, 3, 2],
...                            [1, 2, 0, 1, 3]])
>>> assortativity(edge_index)
-0.666667640209198
- get_laplacian(edge_index: Tensor, edge_weight: Optional[Tensor] = None, normalization: Optional[str] = None, dtype: Optional[dtype] = None, num_nodes: Optional[int] = None) → Tuple[Tensor, Optional[Tensor]] [source]
Computes the graph Laplacian of the graph given by edge_index and optional edge_weight.
Parameters:
- edge_index (LongTensor) – The edge indices.
- edge_weight (Tensor, optional) – One-dimensional edge weights. (default: None)
- normalization (str, optional) – The normalization scheme for the graph Laplacian (default: None):
  1. None: No normalization \(\mathbf{L} = \mathbf{D} - \mathbf{A}\)
  2. "sym": Symmetric normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\)
  3. "rw": Random-walk normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}\)
- dtype (torch.dtype, optional) – The desired data type of the returned tensor in case edge_weight=None. (default: None)
- num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
Examples
>>> edge_index = torch.tensor([[0, 1, 1, 2],
...                            [1, 0, 2, 1]])
>>> edge_weight = torch.tensor([1., 2., 2., 4.])

>>> # No normalization
>>> lap = get_laplacian(edge_index, edge_weight)

>>> # Symmetric normalization
>>> lap_sym = get_laplacian(edge_index, edge_weight,
...                         normalization='sym')

>>> # Random-walk normalization
>>> lap_rw = get_laplacian(edge_index, edge_weight,
...                        normalization='rw')
- get_mesh_laplacian(pos: Tensor, face: Tensor, normalization: Optional[str] = None) → Tuple[Tensor, Tensor] [source]
Computes the mesh Laplacian of a mesh given by pos and face. Computation is based on the cotangent matrix defined as
\[\begin{split}\mathbf{C}_{ij} = \begin{cases} \frac{\cot \angle_{ikj}~+\cot \angle_{ilj}}{2} & \text{if } i, j \text{ is an edge} \\ -\sum_{j \in N(i)}{C_{ij}} & \text{if } i \text{ is in the diagonal} \\ 0 & \text{otherwise} \end{cases}\end{split}\]
Normalization depends on the mass matrix defined as
\[\begin{split}\mathbf{M}_{ij} = \begin{cases} a(i) & \text{if } i \text{ is in the diagonal} \\ 0 & \text{otherwise} \end{cases}\end{split}\]
where \(a(i)\) is obtained by joining the barycenters of the triangles around vertex \(i\).
Parameters:
- pos (Tensor) – The node positions.
- face (LongTensor) – The face indices.
- normalization (str, optional) – The normalization scheme for the mesh Laplacian (default: None):
  1. None: No normalization \(\mathbf{L} = \mathbf{C}\)
  2. "sym": Symmetric normalization \(\mathbf{L} = \mathbf{M}^{-1/2} \mathbf{C}\mathbf{M}^{-1/2}\)
  3. "rw": Row-wise normalization \(\mathbf{L} = \mathbf{M}^{-1} \mathbf{C}\)
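A minimal usage sketch (illustrative, not part of the original reference) on a single right triangle; the returned pair holds the Laplacian's edge indices and weights:
Example
>>> pos = torch.tensor([[0., 0., 0.],
...                     [1., 0., 0.],
...                     [0., 1., 0.]])
>>> face = torch.tensor([[0], [1], [2]])  # one triangle (0, 1, 2)
>>> edge_index, edge_weight = get_mesh_laplacian(pos, face)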
- mask_select(src: Tensor, dim: int, mask: Tensor) → Tensor [source]
Returns a new tensor which masks the src tensor along the dimension dim according to the boolean mask mask.
Parameters:
- src (torch.Tensor) – The input tensor.
- dim (int) – The dimension in which to mask.
- mask (torch.BoolTensor) – The 1-D tensor containing the binary mask to index with.
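A minimal usage sketch (illustrative, not part of the original reference), keeping the first and last row:
Example
>>> src = torch.tensor([[0, 1], [2, 3], [4, 5]])
>>> mask = torch.tensor([True, False, True])
>>> mask_select(src, 0, mask)
tensor([[0, 1],
        [4, 5]])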
- index_to_mask(index: Tensor, size: Optional[int] = None) → Tensor [source]
Converts indices to a mask representation.
Parameters:
- index (Tensor) – The indices.
- size (int, optional) – The size of the output mask. If set to None, a minimal-sized mask of size index.max() + 1 is returned. (default: None)
Example
>>> index = torch.tensor([1, 3, 5])
>>> index_to_mask(index)
tensor([False,  True, False,  True, False,  True])

>>> index_to_mask(index, size=7)
tensor([False,  True, False,  True, False,  True, False])
- mask_to_index(mask: Tensor) → Tensor [source]
Converts a mask to an index representation.
Parameters:
- mask (Tensor) – The mask.
Example
>>> mask = torch.tensor([False, True, False])
>>> mask_to_index(mask)
tensor([1])
- select(src: Union[Tensor, List[Any]], index_or_mask: Tensor, dim: int) → Union[Tensor, List[Any]] [source]
Selects the input tensor or input list according to a given index or mask vector.
Parameters:
- src (torch.Tensor or list) – The input tensor or list.
- index_or_mask (torch.Tensor) – The index or mask vector.
- dim (int) – The dimension along which to select.
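A minimal usage sketch (illustrative, not part of the original reference) with an index vector:
Example
>>> src = torch.arange(5)
>>> select(src, torch.tensor([0, 3]), dim=0)
tensor([0, 3])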
- narrow(src: Union[Tensor, List[Any]], dim: int, start: int, length: int) → Union[Tensor, List[Any]] [source]
Narrows the input tensor or input list to the specified range.
Parameters:
- src (torch.Tensor or list) – The input tensor or list.
- dim (int) – The dimension along which to narrow.
- start (int) – The index where narrowing starts.
- length (int) – The number of elements to keep.
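A minimal usage sketch (illustrative, not part of the original reference), mirroring torch.narrow():
Example
>>> src = torch.arange(5)
>>> narrow(src, dim=0, start=1, length=3)
tensor([1, 2, 3])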
- to_dense_batch(x: Tensor, batch: Optional[Tensor] = None, fill_value: float = 0.0, max_num_nodes: Optional[int] = None, batch_size: Optional[int] = None) → Tuple[Tensor, Tensor] [source]
Given a sparse batch of node features \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\) (with \(N_i\) indicating the number of nodes in graph \(i\)), creates a dense node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N_{\max} \times F}\) (with \(N_{\max} = \max_i^B N_i\)). In addition, a mask of shape \(\mathbf{M} \in \{ 0, 1 \}^{B \times N_{\max}}\) is returned, holding information about the existence of fake-nodes in the dense representation.
Parameters:
- x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).
- batch (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. Must be ordered. (default: None)
- fill_value (float, optional) – The value for invalid entries in the resulting dense output tensor. (default: 0)
- max_num_nodes (int, optional) – The size of the output node dimension. (default: None)
- batch_size (int, optional) – The batch size \(B\). (default: None)
Return type:
(Tensor, BoolTensor)
Examples
>>> x = torch.arange(12).view(6, 2)
>>> x
tensor([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11]])

>>> out, mask = to_dense_batch(x)
>>> mask
tensor([[True, True, True, True, True, True]])

>>> batch = torch.tensor([0, 0, 1, 2, 2, 2])
>>> out, mask = to_dense_batch(x, batch)
>>> out
tensor([[[ 0,  1],
         [ 2,  3],
         [ 0,  0]],
        [[ 4,  5],
         [ 0,  0],
         [ 0,  0]],
        [[ 6,  7],
         [ 8,  9],
         [10, 11]]])
>>> mask
tensor([[ True,  True, False],
        [ True, False, False],
        [ True,  True,  True]])

>>> out, mask = to_dense_batch(x, batch, max_num_nodes=4)
>>> out
tensor([[[ 0,  1],
         [ 2,  3],
         [ 0,  0],
         [ 0,  0]],
        [[ 4,  5],
         [ 0,  0],
         [ 0,  0],
         [ 0,  0]],
        [[ 6,  7],
         [ 8,  9],
         [10, 11],
         [ 0,  0]]])
>>> mask
tensor([[ True,  True, False, False],
        [ True, False, False, False],
        [ True,  True,  True, False]])
- to_dense_adj(edge_index: Tensor, batch: Optional[Tensor] = None, edge_attr: Optional[Tensor] = None, max_num_nodes: Optional[int] = None, batch_size: Optional[int] = None) → Tensor [source]
Converts batched sparse adjacency matrices given by edge indices and edge attributes to a single dense batched adjacency matrix.
Parameters:
- edge_index (LongTensor) – The edge indices.
- batch (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)
- edge_attr (Tensor, optional) – Edge weights or multi-dimensional edge features. (default: None)
- max_num_nodes (int, optional) – The size of the output node dimension. (default: None)
- batch_size (int, optional) – The batch size \(B\). (default: None)
Return type:
Tensor
Examples
>>> edge_index = torch.tensor([[0, 0, 1, 2, 3],
...                            [0, 1, 0, 3, 0]])
>>> batch = torch.tensor([0, 0, 1, 1])
>>> to_dense_adj(edge_index, batch)
tensor([[[1., 1.],
         [1., 0.]],
        [[0., 1.],
         [1., 0.]]])

>>> to_dense_adj(edge_index, batch, max_num_nodes=4)
tensor([[[1., 1., 0., 0.],
         [1., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],
        [[0., 1., 0., 0.],
         [1., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]])

>>> edge_attr = torch.Tensor([1, 2, 3, 4, 5])
>>> to_dense_adj(edge_index, batch, edge_attr)
tensor([[[1., 2.],
         [3., 0.]],
        [[0., 4.],
         [5., 0.]]])
- to_nested_tensor(x: Tensor, batch: Optional[Tensor] = None, ptr: Optional[Tensor] = None, batch_size: Optional[int] = None) → Tensor [source]
Given a contiguous batch of tensors \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times *}\) (with \(N_i\) indicating the number of elements in example \(i\)), creates a nested PyTorch tensor. Reverse operation of from_nested_tensor().
Parameters:
- x (torch.Tensor) – The input tensor \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times *}\).
- batch (torch.Tensor, optional) – The batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each element to a specific example. Must be ordered. (default: None)
- ptr (torch.Tensor, optional) – Alternative representation of batch in compressed format. (default: None)
- batch_size (int, optional) – The batch size \(B\). (default: None)
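A minimal usage sketch (illustrative, not part of the original reference), packing two examples of two and four rows into a nested tensor:
Example
>>> x = torch.arange(12).view(6, 2)
>>> batch = torch.tensor([0, 0, 1, 1, 1, 1])
>>> out = to_nested_tensor(x, batch)  # nested tensor with 2 entries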
- from_nested_tensor(x: Tensor, return_batch: bool = False) → Union[Tensor, Tuple[Tensor, Tensor]] [source]
Given a nested PyTorch tensor, creates a contiguous batch of tensors \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times *}\), and optionally a batch vector which assigns each element to a specific example. Reverse operation of to_nested_tensor().
Parameters:
- x (torch.Tensor) – The nested input tensor. The sizes of the nested tensors need to match except for the first dimension.
- return_batch (bool, optional) – If set to True, will also return the batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\). (default: False)
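A minimal usage sketch (illustrative, not part of the original reference); the batch vector marks which example each row came from:
Example
>>> x = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)])
>>> out, batch = from_nested_tensor(x, return_batch=True)
>>> batch
tensor([0, 0, 1, 1, 1, 1])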
- dense_to_sparse(adj: Tensor) → Tuple[Tensor, Tensor] [source]
Converts a dense adjacency matrix to a sparse adjacency matrix defined by edge indices and edge attributes.
Parameters:
- adj (Tensor) – The dense adjacency matrix of shape [num_nodes, num_nodes] or [batch_size, num_nodes, num_nodes].
Return type:
(LongTensor, Tensor)
Examples
>>> # For a single adjacency matrix
>>> adj = torch.tensor([[3, 1],
...                     [2, 0]])
>>> dense_to_sparse(adj)
(tensor([[0, 0, 1],
         [0, 1, 0]]),
 tensor([3, 1, 2]))

>>> # For two adjacency matrices
>>> adj = torch.tensor([[[3, 1],
...                      [2, 0]],
...                     [[0, 1],
...                      [0, 2]]])
>>> dense_to_sparse(adj)
(tensor([[0, 0, 1, 2, 3],
         [0, 1, 0, 3, 3]]),
 tensor([3, 1, 2, 1, 2]))
- is_torch_sparse_tensor(src: Any) → bool [source]
Returns True if the input src is a torch.sparse.Tensor (in any sparse layout).
Parameters:
- src (Any) – The input object to be checked.
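A minimal usage sketch (illustrative, not part of the original reference):
Example
>>> adj = torch.eye(3).to_sparse()
>>> is_torch_sparse_tensor(adj)
True
>>> is_torch_sparse_tensor(torch.eye(3))
False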
- is_sparse(src: Any) → bool [source]
Returns True if the input src is of type torch.sparse.Tensor (in any sparse layout) or of type torch_sparse.SparseTensor.
Parameters:
- src (Any) – The input object to be checked.
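A minimal usage sketch (illustrative, not part of the original reference); unlike is_torch_sparse_tensor(), this also accepts torch_sparse.SparseTensor inputs:
Example
>>> is_sparse(torch.eye(3).to_sparse())
True
>>> is_sparse(torch.eye(3))
False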
- to_torch_coo_tensor(edge_index: Tensor, edge_attr: Optional[Tensor] = None, size: Optional[Union[int, Tuple[int, int]]] = None, is_coalesced: bool = False) → Tensor [source]
Converts a sparse adjacency matrix defined by edge indices and edge attributes to a torch.sparse.Tensor with layout torch.sparse_coo. See to_edge_index() for the reverse operation.
Parameters:
- edge_index (LongTensor) – The edge indices.
- edge_attr (Tensor, optional) – The edge attributes. (default: None)
- size (int or (int, int), optional) – The size of the sparse matrix. If given as an integer, will create a quadratic sparse matrix. If set to None, will infer a quadratic sparse matrix based on edge_index.max() + 1. (default: None)
- is_coalesced (bool) – If set to True, will assume that edge_index is already coalesced and thus avoids expensive computation. (default: False)
Return type:
torch.sparse.Tensor
Example
>>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
...                            [1, 0, 2, 1, 3, 2]])
>>> to_torch_coo_tensor(edge_index)
tensor(indices=tensor([[0, 1, 1, 2, 2, 3],
                       [1, 0, 2, 1, 3, 2]]),
       values=tensor([1., 1., 1., 1., 1., 1.]),
       size=(4, 4), nnz=6, layout=torch.sparse_coo)
- to_torch_csr_tensor(edge_index: Tensor, edge_attr: Optional[Tensor] = None, size: Optional[Union[int, Tuple[int, int]]] = None, is_coalesced: bool = False) → Tensor [source]
Converts a sparse adjacency matrix defined by edge indices and edge attributes to a torch.sparse.Tensor with layout torch.sparse_csr. See to_edge_index() for the reverse operation.
Parameters:
- edge_index (LongTensor) – The edge indices.
- edge_attr (Tensor, optional) – The edge attributes. (default: None)
- size (int or (int, int), optional) – The size of the sparse matrix. If given as an integer, will create a quadratic sparse matrix. If set to None, will infer a quadratic sparse matrix based on edge_index.max() + 1. (default: None)
- is_coalesced (bool) – If set to True, will assume that edge_index is already coalesced and thus avoids expensive computation. (default: False)
Return type:
torch.sparse.Tensor
Example
>>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
...                            [1, 0, 2, 1, 3, 2]])
>>> to_torch_csr_tensor(edge_index)
tensor(crow_indices=tensor([0, 1, 3, 5, 6]),
       col_indices=tensor([1, 0, 2, 1, 3, 2]),
       values=tensor([1., 1., 1., 1., 1., 1.]),
       size=(4, 4), nnz=6, layout=torch.sparse_csr)
- to_torch_csc_tensor(edge_index: Tensor, edge_attr: Optional[Tensor] = None, size: Optional[Union[int, Tuple[int, int]]] = None, is_coalesced: bool = False) → Tensor [source]
Converts a sparse adjacency matrix defined by edge indices and edge attributes to a torch.sparse.Tensor with layout torch.sparse_csc. See to_edge_index() for the reverse operation.
Parameters:
- edge_index (LongTensor) – The edge indices.
- edge_attr (Tensor, optional) – The edge attributes. (default: None)
- size (int or (int, int), optional) – The size of the sparse matrix. If given as an integer, will create a quadratic sparse matrix. If set to None, will infer a quadratic sparse matrix based on edge_index.max() + 1. (default: None)
- is_coalesced (bool) – If set to True, will assume that edge_index is already coalesced and thus avoids expensive computation. (default: False)
Return type:
torch.sparse.Tensor
Example
>>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
...                            [1, 0, 2, 1, 3, 2]])
>>> to_torch_csc_tensor(edge_index)
tensor(ccol_indices=tensor([0, 1, 3, 5, 6]),
       row_indices=tensor([1, 0, 2, 1, 3, 2]),
       values=tensor([1., 1., 1., 1., 1., 1.]),
       size=(4, 4), nnz=6, layout=torch.sparse_csc)
- to_torch_sparse_tensor(edge_index: Tensor, edge_attr: Optional[Tensor] = None, size: Optional[Union[int, Tuple[int, int]]] = None, is_coalesced: bool = False, layout: layout = torch.sparse_coo) [source]
Converts a sparse adjacency matrix defined by edge indices and edge attributes to a torch.sparse.Tensor with custom layout. See to_edge_index() for the reverse operation.
Parameters:
- edge_index (LongTensor) – The edge indices.
- edge_attr (Tensor, optional) – The edge attributes. (default: None)
- size (int or (int, int), optional) – The size of the sparse matrix. If given as an integer, will create a quadratic sparse matrix. If set to None, will infer a quadratic sparse matrix based on edge_index.max() + 1. (default: None)
- is_coalesced (bool) – If set to True, will assume that edge_index is already coalesced and thus avoids expensive computation. (default: False)
- layout (torch.layout, optional) – The layout of the output sparse tensor (torch.sparse_coo, torch.sparse_csr or torch.sparse_csc). (default: torch.sparse_coo)
Return type:
torch.sparse.Tensor
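A minimal usage sketch (illustrative, not part of the original reference), requesting a CSR layout:
Example
>>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
...                            [1, 0, 2, 1, 3, 2]])
>>> to_torch_sparse_tensor(edge_index, layout=torch.sparse_csr)
tensor(crow_indices=tensor([0, 1, 3, 5, 6]),
       col_indices=tensor([1, 0, 2, 1, 3, 2]),
       values=tensor([1., 1., 1., 1., 1., 1.]),
       size=(4, 4), nnz=6, layout=torch.sparse_csr)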
- to_edge_index(adj: Union[Tensor, SparseTensor]) → Tuple[Tensor, Tensor] [source]
Converts a torch.sparse.Tensor or a torch_sparse.SparseTensor to edge indices and edge attributes.
Parameters:
- adj (torch.sparse.Tensor or SparseTensor) – The adjacency matrix.
Return type:
(LongTensor, Tensor)
Example
>>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
...                            [1, 0, 2, 1, 3, 2]])
>>> adj = to_torch_coo_tensor(edge_index)
>>> to_edge_index(adj)
(tensor([[0, 1, 1, 2, 2, 3],
         [1, 0, 2, 1, 3, 2]]),
 tensor([1., 1., 1., 1., 1., 1.]))
- spmm(src: Union[Tensor, SparseTensor], other: Tensor, reduce: str = 'sum') → Tensor [source]
Matrix product of a sparse matrix with a dense matrix.
Parameters:
- src (torch.Tensor or torch_sparse.SparseTensor) – The input sparse matrix, either a PyG torch_sparse.SparseTensor or a PyTorch torch.sparse.Tensor.
- other (torch.Tensor) – The input dense matrix.
- reduce (str, optional) – The reduce operation to use ("sum", "mean", "min" or "max"). (default: "sum")
Return type:
Tensor
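A minimal usage sketch (illustrative, not part of the original reference), multiplying a sparse adjacency matrix with a dense feature matrix:
Example
>>> edge_index = torch.tensor([[0, 1],
...                            [1, 0]])
>>> adj = to_torch_coo_tensor(edge_index)
>>> x = torch.tensor([[1., 2.],
...                   [3., 4.]])
>>> spmm(adj, x, reduce='sum')
tensor([[3., 4.],
        [1., 2.]])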
- unbatch(src: Tensor, batch: Tensor, dim: int = 0) → List[Tensor] [source]
Splits src according to a batch vector along dimension dim.
Parameters:
- src (Tensor) – The source tensor.
- batch (LongTensor) – The batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each entry in src to a specific example. Must be ordered.
- dim (int, optional) – The dimension along which to split the src tensor. (default: 0)
Return type:
List[Tensor]
Example
>>> src = torch.arange(7)
>>> batch = torch.tensor([0, 0, 0, 1, 1, 2, 2])
>>> unbatch(src, batch)
(tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
- unbatch_edge_index(edge_index: Tensor, batch: Tensor) → List[Tensor] [source]
Splits the edge_index according to a batch vector.
Parameters:
- edge_index (Tensor) – The edge_index tensor. Must be ordered.
- batch (LongTensor) – The batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. Must be ordered.
Return type:
List[Tensor]
Example
>>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6],
...                            [1, 0, 2, 1, 3, 2, 5, 4, 6, 5]])
>>> batch = torch.tensor([0, 0, 0, 0, 1, 1, 1])
>>> unbatch_edge_index(edge_index, batch)
(tensor([[0, 1, 1, 2, 2, 3],
         [1, 0, 2, 1, 3, 2]]),
 tensor([[0, 1, 1, 2],
         [1, 0, 2, 1]]))
- one_hot(index: Tensor, num_classes: Optional[int] = None, dtype: Optional[dtype] = None) Tensor [source]
Takes a one-dimensional index tensor and returns a one-hot encoded representation of it with shape [*, num_classes] that has zeros everywhere except where the index of the last dimension matches the corresponding value of the input tensor, in which case it will be 1.
Note
This is a more memory-efficient version of torch.nn.functional.one_hot(), as you can customize the output dtype.
- Parameters
index (torch.Tensor) – The one-dimensional input tensor.
num_classes (int, optional) – The total number of classes. If set to None, the number of classes will be inferred as one greater than the largest class value in the input tensor. (default: None)
dtype (torch.dtype, optional) – The dtype of the output tensor.
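Example
A minimal sketch, assuming a floating-point default output dtype:
>>> index = torch.tensor([0, 2, 1])
>>> one_hot(index, num_classes=3)
tensor([[1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.]])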
- normalized_cut(edge_index: Tensor, edge_attr: Tensor, num_nodes: Optional[int] = None) Tensor [source]
Computes the normalized cut \(\mathbf{e}_{i,j} \cdot \left( \frac{1}{\deg(i)} + \frac{1}{\deg(j)} \right)\) of a weighted graph given by edge indices and edge attributes.
- Parameters
edge_index (LongTensor) – The edge indices.
edge_attr (Tensor) – The edge attributes.
num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
- Return type
Tensor
Example
>>> edge_index = torch.tensor([[1, 1, 2, 3], ... [3, 3, 1, 2]]) >>> edge_attr = torch.tensor([1., 1., 1., 1.]) >>> normalized_cut(edge_index, edge_attr) tensor([1.5000, 1.5000, 2.0000, 1.5000])
- grid(height: int, width: int, dtype: Optional[dtype] = None, device: Optional[device] = None) Tuple[Tensor, Tensor] [source]
Returns the edge indices of a two-dimensional grid graph with height height and width width, and its node positions.
- Parameters
height (int) – The height of the grid.
width (int) – The width of the grid.
dtype (torch.dtype, optional) – The desired data type of the returned position tensor.
device (torch.device, optional) – The desired device of the returned tensors.
- Return type
(LongTensor, Tensor)
Example
>>> (row, col), pos = grid(height=2, width=2) >>> row tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]) >>> col tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]) >>> pos tensor([[0., 1.], [1., 1.], [0., 0.], [1., 0.]])
- geodesic_distance(pos: Tensor, face: Tensor, src: Optional[Tensor] = None, dest: Optional[Tensor] = None, norm: bool = True, max_distance: Optional[float] = None, num_workers: int = 0) Tensor [source]
Computes (normalized) geodesic distances of a mesh given by pos and face. If src and dest are given, this method only computes the geodesic distances for the respective source and target node pairs.
Note
This function requires the gdist package. To install, run pip install cython && pip install gdist.
- Parameters
pos (Tensor) – The node positions.
face (LongTensor) – The face indices.
src (LongTensor, optional) – If given, only compute geodesic distances for the specified source indices. (default: None)
dest (LongTensor, optional) – If given, only compute geodesic distances for the specified target indices. (default: None)
norm (bool, optional) – Normalizes geodesic distances by \(\sqrt{\textrm{area}(\mathcal{M})}\). (default: True)
max_distance (float, optional) – If given, only yields results for geodesic distances less than max_distance. This will speed up runtime dramatically. (default: None)
num_workers (int, optional) – How many subprocesses to use for calculating geodesic distances. 0 means that computation takes place in the main process. -1 means that all available CPU cores are used. (default: 0)
- Return type
Tensor
Example
>>> pos = torch.Tensor([[0, 0, 0], ... [2, 0, 0], ... [0, 2, 0], ... [2, 2, 0]]) >>> face = torch.tensor([[0, 0], ... [1, 2], ... [3, 3]]) >>> geodesic_distance(pos, face) [[0, 1, 1, 1.4142135623730951], [1, 0, 1.4142135623730951, 1], [1, 1.4142135623730951, 0, 1], [1.4142135623730951, 1, 1, 0]]
- to_scipy_sparse_matrix(edge_index: Tensor, edge_attr: Optional[Tensor] = None, num_nodes: Optional[int] = None) coo_matrix [source]
Converts a graph given by edge indices and edge attributes to a scipy sparse matrix.
- Parameters
edge_index (LongTensor) – The edge indices.
edge_attr (Tensor, optional) – The edge attributes. (default: None)
num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
Examples
>>> edge_index = torch.tensor([ ... [0, 1, 1, 2, 2, 3], ... [1, 0, 2, 1, 3, 2], ... ]) >>> to_scipy_sparse_matrix(edge_index) <4x4 sparse matrix of type '<class 'numpy.float32'>' with 6 stored elements in COOrdinate format>
- from_scipy_sparse_matrix(A: spmatrix) Tuple[Tensor, Tensor] [source]
Converts a scipy sparse matrix to edge indices and edge attributes.
- Parameters
A (scipy.sparse) – A sparse matrix.
Examples
>>> edge_index = torch.tensor([ ... [0, 1, 1, 2, 2, 3], ... [1, 0, 2, 1, 3, 2], ... ]) >>> adj = to_scipy_sparse_matrix(edge_index) >>> # `edge_index` and `edge_weight` are both returned >>> from_scipy_sparse_matrix(adj) (tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]), tensor([1., 1., 1., 1., 1., 1.]))
- to_networkx(data: Data, node_attrs: Optional[Iterable[str]] = None, edge_attrs: Optional[Iterable[str]] = None, graph_attrs: Optional[Iterable[str]] = None, to_undirected: Optional[Union[bool, str]] = False, remove_self_loops: bool = False) Any [source]
Converts a torch_geometric.data.Data instance to a networkx.Graph if to_undirected is set to True, or a directed networkx.DiGraph otherwise.
- Parameters
data (torch_geometric.data.Data) – The data object.
node_attrs (iterable of str, optional) – The node attributes to be copied. (default: None)
edge_attrs (iterable of str, optional) – The edge attributes to be copied. (default: None)
graph_attrs (iterable of str, optional) – The graph attributes to be copied. (default: None)
to_undirected (bool or str, optional) – If set to True or "upper", will return a networkx.Graph instead of a networkx.DiGraph. The undirected graph will correspond to the upper triangle of the corresponding adjacency matrix. Similarly, if set to "lower", the undirected graph will correspond to the lower triangle of the adjacency matrix. (default: False)
remove_self_loops (bool, optional) – If set to True, will not include self-loops in the resulting graph. (default: False)
Examples
>>> edge_index = torch.tensor([ ... [0, 1, 1, 2, 2, 3], ... [1, 0, 2, 1, 3, 2], ... ]) >>> data = Data(edge_index=edge_index, num_nodes=4) >>> to_networkx(data) <networkx.classes.digraph.DiGraph at 0x2713fdb40d0>
- from_networkx(G: Any, group_node_attrs: Optional[Union[List[str], all]] = None, group_edge_attrs: Optional[Union[List[str], all]] = None) Data [source]
Converts a networkx.Graph or networkx.DiGraph to a torch_geometric.data.Data instance.
- Parameters
G (networkx.Graph or networkx.DiGraph) – A networkx graph.
group_node_attrs (List[str] or all, optional) – The node attributes to be concatenated and added to data.x. (default: None)
group_edge_attrs (List[str] or all, optional) – The edge attributes to be concatenated and added to data.edge_attr. (default: None)
Note
All group_node_attrs and group_edge_attrs values must be numeric.
Examples
>>> edge_index = torch.tensor([ ... [0, 1, 1, 2, 2, 3], ... [1, 0, 2, 1, 3, 2], ... ]) >>> data = Data(edge_index=edge_index, num_nodes=4) >>> g = to_networkx(data) >>> # A `Data` object is returned >>> from_networkx(g) Data(edge_index=[2, 6], num_nodes=4)
- to_networkit(edge_index: Tensor, edge_weight: Optional[Tensor] = None, num_nodes: Optional[int] = None, directed: bool = True) Any [source]
Converts a (edge_index, edge_weight) tuple to a networkit.Graph.
- Parameters
edge_index (torch.Tensor) – The edge indices of the graph.
edge_weight (torch.Tensor, optional) – The edge weights of the graph. (default: None)
num_nodes (int, optional) – The number of nodes in the graph. (default: None)
directed (bool, optional) – If set to False, the graph will be undirected. (default: True)
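Example
A minimal sketch (requires the networkit package):
>>> edge_index = torch.tensor([[0, 1, 2],
...                            [1, 2, 0]])
>>> g = to_networkit(edge_index, num_nodes=3)
>>> g.numberOfNodes()
3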
- from_networkit(g: Any) Tuple[Tensor, Optional[Tensor]] [source]
Converts a networkit.Graph to a (edge_index, edge_weight) tuple. If the networkit.Graph is not weighted, the returned edge_weight will be None.
- Parameters
g (networkit.graph.Graph) – A networkit graph object.
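Example
A minimal roundtrip sketch (requires networkit); since the graph is unweighted, the returned edge_weight is None:
>>> edge_index = torch.tensor([[0, 1, 2],
...                            [1, 2, 0]])
>>> g = to_networkit(edge_index, num_nodes=3)
>>> edge_index, edge_weight = from_networkit(g)
>>> edge_weight is None
True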
- to_trimesh(data)[source]
Converts a torch_geometric.data.Data instance to a trimesh.Trimesh.
- Parameters
data (torch_geometric.data.Data) – The data object.
Example
>>> pos = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]], ... dtype=torch.float) >>> face = torch.tensor([[0, 1, 2], [1, 2, 3]]).t()
>>> data = Data(pos=pos, face=face) >>> to_trimesh(data) <trimesh.Trimesh(vertices.shape=(4, 3), faces.shape=(2, 3))>
- from_trimesh(mesh)[source]
Converts a trimesh.Trimesh to a torch_geometric.data.Data instance.
- Parameters
mesh (trimesh.Trimesh) – A trimesh mesh.
Example
>>> pos = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]], ... dtype=torch.float) >>> face = torch.tensor([[0, 1, 2], [1, 2, 3]]).t()
>>> data = Data(pos=pos, face=face) >>> mesh = to_trimesh(data) >>> from_trimesh(mesh) Data(pos=[4, 3], face=[3, 2])
- to_cugraph(edge_index: Tensor, edge_weight: Optional[Tensor] = None, relabel_nodes: bool = True, directed: bool = True)[source]
Converts a graph given by edge_index and optional edge_weight into a cugraph graph object.
- Parameters
edge_index (torch.Tensor) – The edge indices of the graph.
edge_weight (torch.Tensor, optional) – The edge weights of the graph. (default: None)
relabel_nodes (bool, optional) – If set to True, cugraph will remove any isolated nodes, leading to a relabeling of nodes. (default: True)
directed (bool, optional) – If set to False, the graph will be undirected. (default: True)
- from_cugraph(g: Any) Tuple[Tensor, Optional[Tensor]] [source]
Converts a cugraph graph object into edge_index and optional edge_weight tensors.
- Parameters
g (cugraph.Graph) – A cugraph graph object.
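Example
A minimal roundtrip sketch (requires the cugraph package and a CUDA device; the tensors are assumed to live on the GPU):
>>> edge_index = torch.tensor([[0, 1, 1, 2],
...                            [1, 0, 2, 1]], device='cuda')
>>> g = to_cugraph(edge_index)
>>> edge_index, edge_weight = from_cugraph(g)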
- to_dgl(data: Union[Data, HeteroData]) Any [source]
Converts a torch_geometric.data.Data or torch_geometric.data.HeteroData instance to a dgl graph object.
- Parameters
data (torch_geometric.data.Data or torch_geometric.data.HeteroData) – The data object.
Example
>>> edge_index = torch.tensor([[0, 1, 1, 2, 3, 0], [1, 0, 2, 1, 4, 4]]) >>> x = torch.randn(5, 3) >>> edge_attr = torch.randn(6, 2) >>> data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) >>> g = to_dgl(data) >>> g Graph(num_nodes=5, num_edges=6, ndata_schemes={'x': Scheme(shape=(3,))} edata_schemes={'edge_attr': Scheme(shape=(2, ))})
>>> data = HeteroData() >>> data['paper'].x = torch.randn(5, 3) >>> data['author'].x = torch.ones(5, 3) >>> edge_index = torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]) >>> data['author', 'cites', 'paper'].edge_index = edge_index >>> g = to_dgl(data) >>> g Graph(num_nodes={'author': 5, 'paper': 5}, num_edges={('author', 'cites', 'paper'): 5}, metagraph=[('author', 'paper', 'cites')])
- from_dgl(g: Any) Union[Data, HeteroData] [source]
Converts a dgl graph object to a torch_geometric.data.Data or torch_geometric.data.HeteroData instance.
- Parameters
g (dgl.DGLGraph) – The dgl graph object.
Example
>>> g = dgl.graph(([0, 0, 1, 5], [1, 2, 2, 0])) >>> g.ndata['x'] = torch.randn(g.num_nodes(), 3) >>> g.edata['edge_attr'] = torch.randn(g.num_edges(), 2) >>> data = from_dgl(g) >>> data Data(x=[6, 3], edge_attr=[4, 2], edge_index=[2, 4])
>>> g = dgl.heterograph({ ... ('author', 'writes', 'paper'): ([0, 1, 1, 2, 3, 3, 4], ... [0, 0, 1, 1, 1, 2, 2])}) >>> g.nodes['author'].data['x'] = torch.randn(5, 3) >>> g.nodes['paper'].data['x'] = torch.randn(3, 3) >>> data = from_dgl(g) >>> data HeteroData( author={ x=[5, 3] }, paper={ x=[3, 3] }, (author, writes, paper)={ edge_index=[2, 7] } )
- from_smiles(smiles: str, with_hydrogen: bool = False, kekulize: bool = False) Data [source]
Converts a SMILES string to a torch_geometric.data.Data instance.
- Parameters
smiles (str) – The SMILES string.
with_hydrogen (bool, optional) – If set to True, will store hydrogens in the molecule graph. (default: False)
kekulize (bool, optional) – If set to True, converts aromatic bonds to single/double bonds. (default: False)
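Example
A minimal sketch (requires the rdkit package); only the node count is shown, since the exact feature dimensions depend on the PyG version:
>>> data = from_smiles('CC(=O)O')  # acetic acid: four heavy atoms
>>> data.num_nodes
4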
- to_smiles(data: Data, kekulize: bool = False) Any [source]
Converts a torch_geometric.data.Data instance to a SMILES string.
- Parameters
data (torch_geometric.data.Data) – The molecular graph.
kekulize (bool, optional) – If set to True, converts aromatic bonds to single/double bonds. (default: False)
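Example
A minimal roundtrip sketch (requires rdkit); the output assumes RDKit canonicalizes benzene back to the input string:
>>> data = from_smiles('c1ccccc1')  # benzene
>>> to_smiles(data)
'c1ccccc1'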
- erdos_renyi_graph(num_nodes: int, edge_prob: float, directed: bool = False) Tensor [source]
Returns the edge_index of a random Erdos-Renyi graph.
- Parameters
num_nodes (int) – The number of nodes.
edge_prob (float) – Probability of an edge.
directed (bool, optional) – If set to True, will return a directed graph. (default: False)
Examples
>>> erdos_renyi_graph(5, 0.2, directed=False) tensor([[0, 1, 1, 4], [1, 0, 4, 1]])
>>> erdos_renyi_graph(5, 0.2, directed=True) tensor([[0, 1, 3, 3, 4, 4], [4, 3, 1, 2, 1, 3]])
- stochastic_blockmodel_graph(block_sizes: Union[List[int], Tensor], edge_probs: Union[List[List[float]], Tensor], directed: bool = False) Tensor [source]
Returns the edge_index of a stochastic blockmodel graph.
- Parameters
block_sizes ([int] or LongTensor) – The sizes of blocks.
edge_probs ([[float]] or FloatTensor) – The density of edges going from each block to each other block. Must be symmetric if the graph is undirected.
directed (bool, optional) – If set to True, will return a directed graph. (default: False)
Examples
>>> block_sizes = [2, 2, 4] >>> edge_probs = [[0.25, 0.05, 0.02], ... [0.05, 0.35, 0.07], ... [0.02, 0.07, 0.40]] >>> stochastic_blockmodel_graph(block_sizes, edge_probs, ... directed=False) tensor([[2, 4, 4, 5, 5, 6, 7, 7], [5, 6, 7, 2, 7, 4, 4, 5]])
>>> stochastic_blockmodel_graph(block_sizes, edge_probs, ... directed=True) tensor([[0, 2, 3, 4, 4, 5, 5], [3, 4, 1, 5, 6, 6, 7]])
- barabasi_albert_graph(num_nodes: int, num_edges: int) Tensor [source]
Returns the edge_index of a Barabasi-Albert preferential attachment model, where a graph of num_nodes nodes grows by attaching new nodes with num_edges edges that are preferentially attached to existing nodes with high degree.
- Parameters
num_nodes (int) – The number of nodes.
num_edges (int) – The number of edges from a new node to existing nodes.
Example
>>> barabasi_albert_graph(num_nodes=4, num_edges=3) tensor([[0, 0, 0, 1, 1, 2, 2, 3], [1, 2, 3, 0, 2, 0, 1, 0]])
- negative_sampling(edge_index: Tensor, num_nodes: Optional[Union[int, Tuple[int, int]]] = None, num_neg_samples: Optional[int] = None, method: str = 'sparse', force_undirected: bool = False) Tensor [source]
Samples random negative edges of a graph given by edge_index.
- Parameters
edge_index (LongTensor) – The edge indices.
num_nodes (int or Tuple[int, int], optional) – The number of nodes, i.e. max_val + 1 of edge_index. If given as a tuple, then edge_index is interpreted as a bipartite graph with shape (num_src_nodes, num_dst_nodes). (default: None)
num_neg_samples (int, optional) – The (approximate) number of negative samples to return. If set to None, will try to return a negative edge for every positive edge. (default: None)
method (str, optional) – The method to use for negative sampling, i.e. "sparse" or "dense". This is a memory/runtime trade-off. "sparse" will work on any graph of any size, while "dense" can perform faster true-negative checks. (default: "sparse")
force_undirected (bool, optional) – If set to True, sampled negative edges will be undirected. (default: False)
- Return type
LongTensor
Examples
>>> # Standard usage >>> edge_index = torch.as_tensor([[0, 0, 1, 2], ... [0, 1, 2, 3]]) >>> negative_sampling(edge_index) tensor([[3, 0, 0, 3], [2, 3, 2, 1]])
>>> # For bipartite graph >>> negative_sampling(edge_index, num_nodes=(3, 4)) tensor([[0, 2, 2, 1], [2, 2, 1, 3]])
- batched_negative_sampling(edge_index: Tensor, batch: Union[Tensor, Tuple[Tensor, Tensor]], num_neg_samples: Optional[int] = None, method: str = 'sparse', force_undirected: bool = False) Tensor [source]
Samples random negative edges of multiple graphs given by edge_index and batch.
- Parameters
edge_index (LongTensor) – The edge indices.
batch (LongTensor or Tuple[LongTensor, LongTensor]) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. If given as a tuple, then edge_index is interpreted as a bipartite graph connecting two different node types.
num_neg_samples (int, optional) – The number of negative samples to return. If set to None, will try to return a negative edge for every positive edge. (default: None)
method (str, optional) – The method to use for negative sampling, i.e. "sparse" or "dense". This is a memory/runtime trade-off. "sparse" will work on any graph of any size, while "dense" can perform faster true-negative checks. (default: "sparse")
force_undirected (bool, optional) – If set to True, sampled negative edges will be undirected. (default: False)
- Return type
LongTensor
Examples
>>> # Standard usage >>> edge_index = torch.as_tensor([[0, 0, 1, 2], [0, 1, 2, 3]]) >>> edge_index = torch.cat([edge_index, edge_index + 4], dim=1) >>> edge_index tensor([[0, 0, 1, 2, 4, 4, 5, 6], [0, 1, 2, 3, 4, 5, 6, 7]]) >>> batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]) >>> batched_negative_sampling(edge_index, batch) tensor([[3, 1, 3, 2, 7, 7, 6, 5], [2, 0, 1, 1, 5, 6, 4, 4]])
>>> # For bipartite graph >>> edge_index1 = torch.as_tensor([[0, 0, 1, 1], [0, 1, 2, 3]]) >>> edge_index2 = edge_index1 + torch.tensor([[2], [4]]) >>> edge_index3 = edge_index2 + torch.tensor([[2], [4]]) >>> edge_index = torch.cat([edge_index1, edge_index2, ... edge_index3], dim=1) >>> edge_index tensor([[ 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5], [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]) >>> src_batch = torch.tensor([0, 0, 1, 1, 2, 2]) >>> dst_batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]) >>> batched_negative_sampling(edge_index, ... (src_batch, dst_batch)) tensor([[ 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5], [ 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9]])
- structured_negative_sampling(edge_index, num_nodes: Optional[int] = None, contains_neg_self_loops: bool = True)[source]
Samples a negative edge (i,k) for every positive edge (i,j) in the graph given by edge_index, and returns it as a tuple of the form (i,j,k).
- Parameters
edge_index (LongTensor) – The edge indices.
num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
contains_neg_self_loops (bool, optional) – If set to False, sampled negative edges will not contain self loops. (default: True)
- Return type
(LongTensor, LongTensor, LongTensor)
Example
>>> edge_index = torch.as_tensor([[0, 0, 1, 2], ... [0, 1, 2, 3]]) >>> structured_negative_sampling(edge_index) (tensor([0, 0, 1, 2]), tensor([0, 1, 2, 3]), tensor([2, 3, 0, 2]))
- structured_negative_sampling_feasible(edge_index: Tensor, num_nodes: Optional[int] = None, contains_neg_self_loops: bool = True) bool [source]
Returns True if structured_negative_sampling() is feasible on the graph given by edge_index. structured_negative_sampling() is infeasible if at least one node is connected to all other nodes.
- Parameters
edge_index (LongTensor) – The edge indices.
num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
contains_neg_self_loops (bool, optional) – If set to False, sampled negative edges will not contain self loops. (default: True)
- Return type
bool
Examples
>>> edge_index = torch.LongTensor([[0, 0, 1, 1, 2, 2, 2], ... [1, 2, 0, 2, 0, 1, 1]]) >>> structured_negative_sampling_feasible(edge_index, 3, False) False
>>> structured_negative_sampling_feasible(edge_index, 3, True) True
- tree_decomposition(mol: Any, return_vocab: bool = False) Union[Tuple[Tensor, Tensor, int], Tuple[Tensor, Tensor, int, Tensor]] [source]
The tree decomposition algorithm of molecules from the “Junction Tree Variational Autoencoder for Molecular Graph Generation” paper. Returns the graph connectivity of the junction tree, the assignment mapping of each atom to the clique in the junction tree, and the number of cliques.
- Parameters
mol (rdkit.Chem.Mol) – An rdkit molecule.
return_vocab (bool, optional) – If set to True, will also return an identifier for each clique (ring, bond, bridged compound, or single atom). (default: False)
- Return type
(LongTensor, LongTensor, int) if return_vocab is False, else (LongTensor, LongTensor, int, LongTensor)
- get_embeddings(model: Module, *args, **kwargs) List[Tensor] [source]
Returns the output embeddings of all MessagePassing layers in model.
Internally, this method registers forward hooks on all MessagePassing layers of a model, and runs the forward pass of the model by calling model(*args, **kwargs).
- Parameters
model (torch.nn.Module) – The message passing model.
*args – Arguments passed to the model.
**kwargs (optional) – Additional keyword arguments passed to the model.
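Example
A minimal sketch; GCN from torch_geometric.nn is used here purely as an illustrative two-layer message passing model:
>>> model = GCN(in_channels=16, hidden_channels=32, num_layers=2)
>>> x = torch.randn(8, 16)
>>> edge_index = torch.tensor([[0, 1, 2, 3],
...                            [1, 2, 3, 4]])
>>> embeddings = get_embeddings(model, x, edge_index)
>>> len(embeddings)  # one tensor per MessagePassing layer
2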
- trim_to_layer(layer: int, num_sampled_nodes_per_hop: Union[List[int], Dict[str, List[int]]], num_sampled_edges_per_hop: Union[List[int], Dict[Tuple[str, str, str], List[int]]], x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], edge_attr: Optional[Union[Tensor, Dict[Tuple[str, str, str], Tensor]]] = None) Tuple[Union[Tensor, Dict[str, Tensor]], Union[Tensor, Dict[Tuple[str, str, str], Tensor]], Optional[Union[Tensor, Dict[Tuple[str, str, str], Tensor]]]] [source]
Trims the edge_index representation, node features x and edge features edge_attr to a minimal-sized representation for the current GNN layer layer in directed NeighborLoader scenarios.
This ensures that no computation is performed for nodes and edges that are not included in the current GNN layer, thus avoiding unnecessary computation within the GNN when performing neighborhood sampling.
- Parameters
layer (int) – The current GNN layer.
num_sampled_nodes_per_hop (List[int] or Dict[NodeType, List[int]]) – The number of sampled nodes per hop.
num_sampled_edges_per_hop (List[int] or Dict[EdgeType, List[int]]) – The number of sampled edges per hop.
x (torch.Tensor or Dict[NodeType, torch.Tensor]) – The homogeneous or heterogeneous (hidden) node features.
edge_index (torch.Tensor or Dict[EdgeType, torch.Tensor]) – The homogeneous or heterogeneous edge indices.
edge_attr (torch.Tensor or Dict[EdgeType, torch.Tensor], optional) – The homogeneous or heterogeneous (hidden) edge features.
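Example
A sketch of the intended usage inside a model's forward pass; self.convs and the sampling metadata passed by NeighborLoader are illustrative assumptions:
>>> for i, conv in enumerate(self.convs):
...     x, edge_index, _ = trim_to_layer(
...         i, num_sampled_nodes_per_hop,
...         num_sampled_edges_per_hop, x, edge_index)
...     x = conv(x, edge_index)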
- train_test_split_edges(data: Data, val_ratio: float = 0.05, test_ratio: float = 0.1) Data [source]
Splits the edges of a torch_geometric.data.Data object into positive and negative train/val/test edges. As such, it will replace the edge_index attribute with train_pos_edge_index, train_neg_adj_mask, val_pos_edge_index, val_neg_edge_index, test_pos_edge_index and test_neg_edge_index attributes. If data has edge features named edge_attr, then train_pos_edge_attr, val_pos_edge_attr and test_pos_edge_attr will be added as well.
Warning
train_test_split_edges() is deprecated and will be removed in a future release. Use torch_geometric.transforms.RandomLinkSplit instead.
- Parameters
data (Data) – The data object.
val_ratio (float, optional) – The ratio of positive validation edges. (default: 0.05)
test_ratio (float, optional) – The ratio of positive test edges. (default: 0.1)
- Return type
Data
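Example
A minimal usage sketch; data is assumed to be an existing torch_geometric.data.Data object, and the resulting split sizes depend on the graph and the given ratios:
>>> data = train_test_split_edges(data, val_ratio=0.05, test_ratio=0.1)
>>> train_pos = data.train_pos_edge_index  # replaces data.edge_index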