torch_geometric.nn.models.MetaPath2Vec
- class MetaPath2Vec(edge_index_dict: Dict[Tuple[str, str, str], Tensor], embedding_dim: int, metapath: List[Tuple[str, str, str]], walk_length: int, context_size: int, walks_per_node: int = 1, num_negative_samples: int = 1, num_nodes_dict: Optional[Dict[str, int]] = None, sparse: bool = False)
Bases: torch.nn.Module
The MetaPath2Vec model from the “metapath2vec: Scalable Representation Learning for Heterogeneous Networks” paper, where random walks based on a given metapath are sampled in a heterogeneous graph, and node embeddings are learned via negative sampling optimization.
Note: For an example of using MetaPath2Vec, see examples/hetero/metapath2vec.py.
- Parameters:
edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]) – Dictionary holding edge indices for each (src_node_type, rel_type, dst_node_type) edge type present in the heterogeneous graph.
embedding_dim (int) – The size of each embedding vector.
metapath (List[Tuple[str, str, str]]) – The metapath described as a list of (src_node_type, rel_type, dst_node_type) tuples.
walk_length (int) – The walk length.
context_size (int) – The actual context size which is considered for positive samples. This parameter increases the effective sampling rate by reusing samples across different source nodes.
walks_per_node (int, optional) – The number of walks to sample for each node. (default: 1)
num_negative_samples (int, optional) – The number of negative samples to use for each positive sample. (default: 1)
num_nodes_dict (Dict[str, int], optional) – Dictionary holding the number of nodes for each node type. (default: None)
sparse (bool, optional) – If set to True, gradients w.r.t. the weight matrix will be sparse. (default: False)
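To make the interaction between walk_length and context_size more concrete, the following dependency-free sketch slices a single walk into overlapping context windows. It assumes that a walk of length walk_length visits walk_length + 1 nodes; the helper name context_windows and the node labels are hypothetical, and this illustrates the sliding-window idea rather than reproducing the library's code.

```python
def context_windows(walk, context_size):
    """Return every contiguous window of `context_size` nodes in `walk`."""
    if context_size > len(walk):
        raise ValueError("context_size must not exceed the number of nodes in the walk")
    return [walk[i:i + context_size] for i in range(len(walk) - context_size + 1)]


# One author -> paper -> author -> ... walk with walk_length = 4 (5 nodes).
walk = ["a0", "p3", "a2", "p1", "a5"]
windows = context_windows(walk, context_size=3)

for window in windows:
    # The first node of each window plays the role of the source node,
    # the remaining nodes are its positive context.
    print(window[0], "->", window[1:])
```

Under this assumption, a single sampled walk contributes walk_length + 2 - context_size windows, which is one way to read the statement that context_size increases the effective sampling rate.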
- forward(node_type: str, batch: Optional[Tensor] = None) -> Tensor
Returns the embeddings for the nodes in batch of type node_type.
- loader(**kwargs)
Returns the data loader that creates both positive and negative random walks on the heterogeneous graph.
- Parameters:
**kwargs (optional) – Arguments of torch.utils.data.DataLoader, such as batch_size, shuffle, drop_last or num_workers.