Source code for torch_geometric.transforms.add_metapaths
import warnings
from typing import List, Optional, Tuple, Union, cast
import torch
from torch import Tensor
from torch_geometric import EdgeIndex
from torch_geometric.data import HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.typing import EdgeType
from torch_geometric.utils import coalesce, degree
@functional_transform('add_metapaths')
class AddMetaPaths(BaseTransform):
    r"""Augments a :class:`~torch_geometric.data.HeteroData` object with one
    new edge type per given :obj:`metapath`, connecting the source node type
    of the metapath to its destination node type, as described in the
    `"Heterogeneous Graph Attention Network"
    <https://arxiv.org/abs/1903.07293>`_ paper
    (functional name: :obj:`add_metapaths`).

    Meta-path based neighbors can exploit different aspects of structure
    information in heterogeneous graphs.
    Formally, a metapath is a path of the form

    .. math::

        \mathcal{V}_1 \xrightarrow{R_1} \mathcal{V}_2 \xrightarrow{R_2} \ldots
        \xrightarrow{R_{\ell-1}} \mathcal{V}_{\ell}

    in which :math:`\mathcal{V}_i` represents node types, and :math:`R_j`
    represents the edge type connecting two node types.
    The added edge type is given by the sequential multiplication of
    adjacency matrices along the metapath, and is added to the
    :class:`~torch_geometric.data.HeteroData` object as edge type
    :obj:`(src_node_type, "metapath_*", dst_node_type)`, where
    :obj:`src_node_type` and :obj:`dst_node_type` denote :math:`\mathcal{V}_1`
    and :math:`\mathcal{V}_{\ell}`, respectively.

    In addition, a :obj:`metapath_dict` object is added to the
    :class:`~torch_geometric.data.HeteroData` object which maps the
    metapath-based edge type to its original metapath.

    .. code-block:: python

        from torch_geometric.datasets import DBLP
        from torch_geometric.data import HeteroData
        from torch_geometric.transforms import AddMetaPaths

        data = DBLP(root)[0]
        # 4 node types: "paper", "author", "conference", and "term"
        # 6 edge types: ("paper", "author"), ("author", "paper"),
        #               ("paper", "term"), ("paper", "conference"),
        #               ("term", "paper"), ("conference", "paper")

        # Add two metapaths:
        # 1. From "paper" to "paper" through "conference"
        # 2. From "author" to "conference" through "paper"
        metapaths = [[("paper", "conference"), ("conference", "paper")],
                     [("author", "paper"), ("paper", "conference")]]
        data = AddMetaPaths(metapaths)(data)

        print(data.edge_types)
        >>> [("author", "to", "paper"), ("paper", "to", "author"),
             ("paper", "to", "term"), ("paper", "to", "conference"),
             ("term", "to", "paper"), ("conference", "to", "paper"),
             ("paper", "metapath_0", "paper"),
             ("author", "metapath_1", "conference")]

        print(data.metapath_dict)
        >>> {("paper", "metapath_0", "paper"): [("paper", "conference"),
                                                ("conference", "paper")],
             ("author", "metapath_1", "conference"): [("author", "paper"),
                                                      ("paper", "conference")]}

    Args:
        metapaths (List[List[Tuple[str, str, str]]]): The metapaths described
            by a list of lists of
            :obj:`(src_node_type, rel_type, dst_node_type)` tuples. Every
            metapath needs at least two edge types, and consecutive edge
            types must agree on the node type they share.
        drop_orig_edge_types (bool, optional): If set to :obj:`True`, existing
            edge types will be dropped. (default: :obj:`False`)
        keep_same_node_type (bool, optional): If set to :obj:`True`, existing
            edge types between the same node type are not dropped even in case
            :obj:`drop_orig_edge_types` is set to :obj:`True`.
            (default: :obj:`False`)
        drop_unconnected_node_types (bool, optional): If set to :obj:`True`,
            will drop node types not connected by any edge type.
            (default: :obj:`False`)
        max_sample (int, optional): If set, will sample at maximum
            :obj:`max_sample` neighbors within metapaths. Useful in order to
            tackle very dense metapath edges. (default: :obj:`None`)
        weighted (bool, optional): If set to :obj:`True`, computes weights for
            each metapath edge and stores them in :obj:`edge_weight`. The
            weight of each metapath edge is computed as the number of metapaths
            from the start to the end of the metapath edge.
            (default: :obj:`False`)
        **kwargs (bool, optional): Only the deprecated names
            :obj:`drop_orig_edges` and :obj:`drop_unconnected_nodes` are
            consulted; any other keyword is ignored.
    """
    def __init__(
        self,
        metapaths: List[List[EdgeType]],
        drop_orig_edge_types: bool = False,
        keep_same_node_type: bool = False,
        drop_unconnected_node_types: bool = False,
        max_sample: Optional[int] = None,
        weighted: bool = False,
        **kwargs: bool,
    ) -> None:
        # A deprecated keyword, when present, overrides its replacement.
        if 'drop_orig_edges' in kwargs:
            warnings.warn("'drop_orig_edges' is deprecated. Use "
                          "'drop_orig_edge_types' instead")
            drop_orig_edge_types = kwargs['drop_orig_edges']

        if 'drop_unconnected_nodes' in kwargs:
            warnings.warn("'drop_unconnected_nodes' is deprecated. Use "
                          "'drop_unconnected_node_types' instead")
            drop_unconnected_node_types = kwargs['drop_unconnected_nodes']

        for path in metapaths:
            assert len(path) >= 2, f"Invalid metapath '{path}'"
            # Each hop has to end on the node type the next hop starts from.
            hops_line_up = [
                prev[-1] == nxt[0] for prev, nxt in zip(path, path[1:])
            ]
            assert all(hops_line_up), (
                f"Invalid sequence of node types in '{path}'")

        self.metapaths = metapaths
        self.drop_orig_edge_types = drop_orig_edge_types
        self.keep_same_node_type = keep_same_node_type
        self.drop_unconnected_node_types = drop_unconnected_node_types
        self.max_sample = max_sample
        self.weighted = weighted

    def forward(self, data: HeteroData) -> HeteroData:
        """Add one ``metapath_<j>`` edge type per configured metapath.

        :obj:`data` is modified in place (its :obj:`metapath_dict` attribute
        is reset first) and returned.
        """
        # Snapshot taken before new edge types are added, so that the
        # post-processing step only ever drops pre-existing edge types.
        orig_edge_types = data.edge_types
        data.metapath_dict = {}
        subsample = self.max_sample is not None

        for j, metapath in enumerate(self.metapaths):
            for hop in metapath:
                assert data._to_canonical(hop) in orig_edge_types

            adj, weight = self._edge_index(data, metapath[0])
            if subsample:
                adj, weight = self._sample(adj, weight)

            # Chain the remaining hops by multiplying adjacencies.
            for hop in metapath[1:]:
                hop_adj, hop_weight = self._edge_index(data, hop)
                adj, weight = adj.matmul(hop_adj, weight, hop_weight)

                if not self.weighted:
                    weight = None

                if subsample:
                    adj, weight = self._sample(adj, weight)

            new_edge_type = (metapath[0][0], f'metapath_{j}', metapath[-1][-1])
            store = data[new_edge_type]
            store.edge_index = adj.as_tensor()
            if self.weighted:
                store.edge_weight = weight
            data.metapath_dict[new_edge_type] = metapath

        postprocess(data, orig_edge_types, self.drop_orig_edge_types,
                    self.keep_same_node_type, self.drop_unconnected_node_types)

        return data

    def _edge_index(
        self,
        data: HeteroData,
        edge_type: EdgeType,
    ) -> Tuple[EdgeIndex, Optional[Tensor]]:
        """Return the row-sorted :class:`EdgeIndex` of :obj:`edge_type`,
        together with its permuted :obj:`edge_weight` (:obj:`None` when
        :obj:`weighted` is off or no :obj:`edge_weight` is stored).
        """
        store = data[edge_type]
        adj = EdgeIndex(store.edge_index, sparse_size=store.size())
        adj, perm = adj.sort_by('row')

        weight = None
        if self.weighted:
            weight = store.get('edge_weight')
            if weight is not None:
                assert weight.dim() == 1
                # Keep weights aligned with the re-ordered edges.
                weight = weight[perm]

        return adj, weight

    def _sample(
        self,
        edge_index: EdgeIndex,
        edge_weight: Optional[Tensor],
    ) -> Tuple[EdgeIndex, Optional[Tensor]]:
        """Randomly thin out edges: each edge is kept independently with
        probability ``max_sample / deg(source node)``.
        """
        src = edge_index[0]
        deg = degree(src, num_nodes=edge_index.get_sparse_size(0))
        # Isolated nodes yield `inf` here but are never looked up via `src`.
        keep_prob = (self.max_sample * (1. / deg))[src]
        keep = torch.rand_like(keep_prob) < keep_prob

        sampled = edge_index[:, keep]
        assert isinstance(sampled, EdgeIndex)

        if edge_weight is None:
            return sampled, None
        return sampled, edge_weight[keep]
@functional_transform('add_random_metapaths')
class AddRandomMetaPaths(BaseTransform):
    r"""Adds additional edge types similar to :class:`AddMetaPaths`.
    The key difference is that the added edge type is given by
    multiple random walks along the metapath.
    One might want to increase the number of random walks
    via :obj:`walks_per_node` to achieve competitive performance with
    :class:`AddMetaPaths`.

    Args:
        metapaths (List[List[Tuple[str, str, str]]]): The metapaths described
            by a list of lists of
            :obj:`(src_node_type, rel_type, dst_node_type)` tuples. Every
            metapath needs at least two edge types, and consecutive edge
            types must agree on the node type they share.
        drop_orig_edge_types (bool, optional): If set to :obj:`True`, existing
            edge types will be dropped. (default: :obj:`False`)
        keep_same_node_type (bool, optional): If set to :obj:`True`, existing
            edge types between the same node type are not dropped even in case
            :obj:`drop_orig_edge_types` is set to :obj:`True`.
            (default: :obj:`False`)
        drop_unconnected_node_types (bool, optional): If set to :obj:`True`,
            will drop node types not connected by any edge type.
            (default: :obj:`False`)
        walks_per_node (int, List[int], optional): The number of random walks
            for each starting node in a metapath. A single integer is used
            for every metapath; a list must hold one entry per metapath.
            (default: :obj:`1`)
        sample_ratio (float, optional): The ratio of source nodes to start
            random walks from. (default: :obj:`1.0`)
    """
    def __init__(
        self,
        metapaths: List[List[EdgeType]],
        drop_orig_edge_types: bool = False,
        keep_same_node_type: bool = False,
        drop_unconnected_node_types: bool = False,
        walks_per_node: Union[int, List[int]] = 1,
        sample_ratio: float = 1.0,
    ) -> None:
        for path in metapaths:
            assert len(path) >= 2, f"Invalid metapath '{path}'"
            # Each hop has to end on the node type the next hop starts from.
            assert all([
                j[-1] == path[i + 1][0] for i, j in enumerate(path[:-1])
            ]), f"Invalid sequence of node types in '{path}'"

        self.metapaths = metapaths
        self.drop_orig_edge_types = drop_orig_edge_types
        self.keep_same_node_type = keep_same_node_type
        self.drop_unconnected_node_types = drop_unconnected_node_types
        self.sample_ratio = sample_ratio
        if isinstance(walks_per_node, int):
            walks_per_node = [walks_per_node] * len(metapaths)
        assert len(walks_per_node) == len(metapaths)
        self.walks_per_node = walks_per_node

    def forward(self, data: HeteroData) -> HeteroData:
        """Add one ``metapath_<j>`` edge type per configured metapath, built
        from random walks.

        :obj:`data` is modified in place (its :obj:`metapath_dict` attribute
        is reset first) and returned.
        """
        edge_types = data.edge_types  # save original edge types
        data.metapath_dict = {}

        for j, metapath in enumerate(self.metapaths):
            for edge_type in metapath:
                assert data._to_canonical(
                    edge_type) in edge_types, f"'{edge_type}' not present"

            src_node = metapath[0][0]
            num_nodes = data[src_node].num_nodes
            num_starts = round(num_nodes * self.sample_ratio)
            # Draw the start nodes on the device holding the graph so that
            # `sample` never mixes CPU walkers with device-resident edges.
            device = data[metapath[0]].edge_index.device
            row = start = torch.randperm(
                num_nodes, device=device)[:num_starts].repeat(
                    self.walks_per_node[j])

            for edge_type in metapath:
                edge_index = EdgeIndex(
                    data[edge_type].edge_index,
                    sparse_size=data[edge_type].size(),
                )
                col, mask = self.sample(edge_index, start)
                # Walkers stuck on a node without outgoing edges are dropped.
                row, col = row[mask], col[mask]
                start = col

            new_edge_type = (metapath[0][0], f'metapath_{j}', metapath[-1][-1])
            data[new_edge_type].edge_index = coalesce(torch.vstack([row, col]))
            data.metapath_dict[new_edge_type] = metapath

        postprocess(data, edge_types, self.drop_orig_edge_types,
                    self.keep_same_node_type, self.drop_unconnected_node_types)

        return data

    @staticmethod
    def sample(edge_index: EdgeIndex, subset: Tensor) -> Tuple[Tensor, Tensor]:
        """Sample neighbors from :obj:`edge_index` for each node in
        :obj:`subset`.

        Returns:
            A tuple :obj:`(col, mask)` of one-dimensional tensors with one
            entry per element of :obj:`subset`. :obj:`mask` is :obj:`True`
            where the node has at least one outgoing edge; :obj:`col` holds
            the sampled neighbor there and a placeholder value elsewhere.
        """
        edge_index, _ = edge_index.sort_by('row')
        rowptr = edge_index.get_indptr()
        rowcount = rowptr.diff()[subset]
        mask = rowcount > 0

        neighbors = edge_index[1]
        if neighbors.numel() == 0:
            # No edges at all: nothing can be gathered, every walker stops.
            return torch.zeros_like(subset), mask

        offset = torch.zeros_like(subset)
        offset[mask] = rowptr[subset[mask]]

        # Pick a uniformly random position inside each node's neighbor slice.
        rand = torch.rand((rowcount.size(0), 1), device=subset.device)
        rand.mul_(rowcount.to(rand.dtype).view(-1, 1))
        rand = rand.to(torch.long)
        rand.add_(offset.view(-1, 1))

        # Only drop the trailing axis: a bare `squeeze()` would collapse a
        # single-walker result into a 0-dim tensor that cannot be masked.
        col = neighbors[rand].squeeze(-1)
        return col, mask

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'sample_ratio={self.sample_ratio}, '
                f'walks_per_node={self.walks_per_node})')
def postprocess(
    data: HeteroData,
    edge_types: List[EdgeType],
    drop_orig_edge_types: bool,
    keep_same_node_type: bool,
    drop_unconnected_node_types: bool,
) -> None:
    """Prune :obj:`data` in place after metapath edge types were added.

    Args:
        data: The heterogeneous graph to prune.
        edge_types: The edge types that are candidates for removal.
        drop_orig_edge_types: If :obj:`True`, delete the entries of
            :obj:`edge_types` from :obj:`data`.
        keep_same_node_type: If :obj:`True`, spare edge types whose first and
            last element are equal from that deletion.
        drop_unconnected_node_types: If :obj:`True`, delete every node type
            that is not an endpoint of any remaining edge type.
    """
    if drop_orig_edge_types:
        for edge_type in edge_types:
            if not keep_same_node_type or edge_type[0] != edge_type[-1]:
                del data[edge_type]

    if not drop_unconnected_node_types:
        return

    # Collect both endpoints of every edge type that is still present.
    connected = {
        endpoint
        for edge_type in data.edge_types
        for endpoint in (edge_type[0], edge_type[-1])
    }
    for node_type in data.node_types:
        if node_type not in connected:
            del data[node_type]