Source code for torch_geometric.explain.algorithm.dummy_explainer

from typing import Optional

import torch
from torch import Tensor

from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm import ExplainerAlgorithm


class DummyExplainer(ExplainerAlgorithm):
    r"""A dummy explainer for testing purposes.

    The model is never evaluated; all returned masks are filled with uniform
    random values in :math:`[0, 1)` and are shaped after the given inputs.
    """
    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Optional[Tensor] = None,
        **kwargs,
    ) -> Explanation:
        r"""Returns random explanations based on the shape of the inputs.

        Args:
            model (torch.nn.Module): The model to explain (not used).
            x (torch.Tensor): The node features.
            edge_index (torch.Tensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge attributes.
                (default: :obj:`None`)
            **kwargs: Additional arguments (ignored).

        Returns:
            Explanation: An explanation holding :obj:`x`, :obj:`edge_index`
            and :obj:`edge_attr` together with the random masks
            :obj:`node_feat_mask` (same shape as :obj:`x`),
            :obj:`node_mask` (one entry per node, i.e. :obj:`x.size(0)`),
            :obj:`edge_mask` (one entry per edge, i.e.
            :obj:`edge_index.size(1)`) and, if :obj:`edge_attr` is given,
            :obj:`edge_feat_mask` (same shape as :obj:`edge_attr`).
        """
        num_nodes, num_edges = x.size(0), edge_index.size(1)

        # Insertion order is kept identical to the historic implementation so
        # that, for a fixed seed, the same random numbers land in each mask.
        mask_dict = {
            'node_feat_mask': self._rand_mask_like(x),
            'node_mask': torch.rand(num_nodes, device=x.device),
            'edge_mask': torch.rand(num_edges, device=x.device),
        }
        if edge_attr is not None:
            mask_dict['edge_feat_mask'] = self._rand_mask_like(edge_attr)

        return Explanation(
            edge_index=edge_index,
            x=x,
            edge_attr=edge_attr,
            **mask_dict,
        )

    @staticmethod
    def _rand_mask_like(src: Tensor) -> Tensor:
        r"""Returns a uniform random mask with the shape and device of
        :obj:`src`.

        :meth:`torch.rand_like` only supports floating-point (and complex)
        dtypes and raises for integer/bool tensors such as categorical node
        features or integer edge types. For those inputs, a mask of the
        default floating-point dtype is sampled instead.
        """
        if src.is_floating_point() or src.is_complex():
            return torch.rand_like(src)
        return torch.rand(src.size(), device=src.device)

    def supports(self) -> bool:
        r"""Always returns :obj:`True`; this dummy explainer imposes no
        restrictions on the explanation setup."""
        return True