import copy
import torch
from torch import Tensor
from torch_geometric.nn.conv import MessagePassing
class DirGNNConv(torch.nn.Module):
    r"""A generic wrapper for computing graph convolution on directed
    graphs as described in the `"Edge Directionality Improves Learning on
    Heterophilic Graphs" <https://arxiv.org/abs/2305.10498>`_ paper.

    :class:`DirGNNConv` will pass messages both from source nodes to target
    nodes and from target nodes to source nodes.

    Args:
        conv (MessagePassing): The underlying
            :class:`~torch_geometric.nn.conv.MessagePassing` layer to use.
            Two independent deep copies are made, one applied to the given
            edges and one applied to the reversed edges.
        alpha (float, optional): The alpha coefficient used to weight the
            aggregations of in- and out-edges as part of a convex combination.
            The reversed-edge result is scaled by :obj:`alpha` and the
            original-edge result by :obj:`1 - alpha`.
            (default: :obj:`0.5`)
        root_weight (bool, optional): If set to :obj:`True`, the layer will add
            transformed root node features to the output. If
            :obj:`conv.in_channels` is a negative int (lazy initialization),
            the root transformation is lazily initialized on the first
            forward call as well.
            (default: :obj:`True`)
    """
    def __init__(
        self,
        conv: MessagePassing,
        alpha: float = 0.5,
        root_weight: bool = True,
    ):
        super().__init__()

        self.alpha = alpha
        self.root_weight = root_weight

        # Separate copies so each edge direction has its own parameters.
        self.conv_in = copy.deepcopy(conv)
        self.conv_out = copy.deepcopy(conv)

        # Turn off self-loops and the built-in root term on both copies.
        # Presumably this is so that a node's own features are not counted in
        # each direction; the only root contribution is the optional `self.lin`.
        if hasattr(conv, 'add_self_loops'):
            self.conv_in.add_self_loops = False
            self.conv_out.add_self_loops = False
        if hasattr(conv, 'root_weight'):
            self.conv_in.root_weight = False
            self.conv_out.root_weight = False

        if root_weight:
            in_channels = conv.in_channels
            if isinstance(in_channels, int) and in_channels < 0:
                # `torch.nn.Linear` cannot be built with a negative input
                # size; defer shape inference to the first forward call,
                # mirroring the lazy behavior of the wrapped convolution.
                self.lin = torch.nn.LazyLinear(conv.out_channels)
            else:
                self.lin = torch.nn.Linear(in_channels, conv.out_channels)
        else:
            self.lin = None

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.conv_in.reset_parameters()
        self.conv_out.reset_parameters()
        if self.lin is not None:
            self.lin.reset_parameters()

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor): The input node features, passed unchanged to
                both wrapped convolutions (and to the root transformation).
            edge_index (torch.Tensor): The edge indices. Presumably of shape
                :obj:`[2, num_edges]` (PyG convention); reversing is done by
                flipping along dimension 0.

        Returns:
            torch.Tensor: :obj:`alpha * conv_out(x, reversed_edges) +
            (1 - alpha) * conv_in(x, edge_index)`, plus the transformed root
            features when :obj:`root_weight=True`.
        """
        x_in = self.conv_in(x, edge_index)
        # Flipping along dim 0 swaps the two index rows, reversing each edge.
        x_out = self.conv_out(x, edge_index.flip([0]))

        out = self.alpha * x_out + (1 - self.alpha) * x_in

        # Guard on the module itself rather than the flag so that a
        # post-construction change of `self.root_weight` cannot call None.
        if self.lin is not None:
            out = out + self.lin(x)

        return out

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.conv_in}, alpha={self.alpha})'