torch_geometric.explain.algorithm.GraphMaskExplainer
- class GraphMaskExplainer(num_layers: int, epochs: int = 100, lr: float = 0.01, penalty_scaling: int = 5, lambda_optimizer_lr: float = 0.01, init_lambda: float = 0.55, allowance: float = 0.03, allow_multiple_explanations: bool = False, log: bool = True, **kwargs)
Bases: ExplainerAlgorithm
The GraphMask-Explainer model from the “Interpreting Graph Neural Networks for NLP With Differentiable Edge Masking” paper for identifying layer-wise compact subgraph structures and node features that play a crucial role in the predictions made by a GNN.
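To make layer-wise edge masking concrete, the following stdlib-only sketch gates every message of a toy sum-aggregation GNN with a per-layer, per-edge value z in [0, 1]. Following the GraphMask paper, a closed gate swaps the message for a baseline instead of deleting it. The helper names (`gated_sum_layer`, `gated_forward`), the identity message function, the self term and the fixed gate values are illustrative assumptions. In GraphMask the gates are predicted by small learned networks and the baseline is learned, and this is not PyG's implementation.

```python
from typing import Dict, List, Tuple

Edge = Tuple[int, int]  # (source node, target node)


def gated_sum_layer(
    h: List[float],
    edges: List[Edge],
    gates: Dict[Edge, float],
    baseline: float,
) -> List[float]:
    """One sum-aggregation layer with a gate z in [0, 1] on every edge.

    A closed gate (z = 0) does not delete the message; it swaps it for the
    baseline, so the aggregation still receives one input per edge.
    """
    out = list(h)  # self term: each node keeps its own state (toy choice)
    for src, dst in edges:
        z = gates[(src, dst)]
        message = h[src]  # identity message function (illustrative only)
        out[dst] += z * message + (1.0 - z) * baseline
    return out


def gated_forward(
    h: List[float],
    edges: List[Edge],
    layer_gates: List[Dict[Edge, float]],
    baseline: float = 0.0,
) -> List[float]:
    """Apply one gated layer per entry of `layer_gates` (one gate set per layer)."""
    for gates in layer_gates:
        h = gated_sum_layer(h, edges, gates, baseline)
    return h


if __name__ == "__main__":
    x = [1.0, 2.0, 0.0]
    edges = [(0, 2), (1, 2)]
    # Layer 0 keeps edge (0, 2) and drops (1, 2); layer 1 keeps both edges.
    layer_gates = [
        {(0, 2): 1.0, (1, 2): 0.0},
        {(0, 2): 1.0, (1, 2): 1.0},
    ]
    print(gated_forward(x, edges, layer_gates))
```

With every gate open, the same graph yields `[1.0, 2.0, 6.0]`; closing edge (1, 2) only in layer 0 lowers node 2's final state to `4.0`. Distinguishing such per-layer effects is what makes the learned masks layer-wise.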
Note
For an example of using GraphMaskExplainer, see examples/explain/graphmask_explainer.py. A working real-time example of GraphMaskExplainer in the form of a deployed app can be accessed here.
- Parameters:
num_layers (int) – The number of GNN layers for which layer-wise masks are learned.
epochs (int, optional) – The number of epochs to train. (default: 100)
lr (float, optional) – The learning rate to apply. (default: 0.01)
penalty_scaling (int, optional) – Scaling value of the penalty term. Value must lie between 0 and 10. (default: 5)
lambda_optimizer_lr (float, optional) – The learning rate used to optimize the Lagrange multiplier. (default: 1e-2)
init_lambda (float, optional) – The initial value of the Lagrange multiplier. Value must lie between 0 and 1. (default: 0.55)
allowance (float, optional) – A float value between 0 and 1 that denotes the tolerance level. (default: 0.03)
log (bool, optional) – If set to False, will not log any learning progress. (default: True)
**kwargs (optional) – Additional hyper-parameters to override default settings in coeffs.
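The roles of `penalty_scaling`, `init_lambda`, `lambda_optimizer_lr` and `allowance` can be illustrated with a generic Lagrangian relaxation: a scaled sparsity penalty is minimized while a Lagrange multiplier pushes the change in the model's prediction back under the allowance. The sketch below is a plain-Python assumption of that scheme (`ConstraintConfig`, `lagrangian` and `update_lambda` are hypothetical names) and is not necessarily the exact loss PyG computes. It also checks the value ranges stated in the parameter list.

```python
from dataclasses import dataclass


@dataclass
class ConstraintConfig:
    """Hypothetical container mirroring the documented hyper-parameters."""

    penalty_scaling: float = 5
    init_lambda: float = 0.55
    allowance: float = 0.03
    lambda_optimizer_lr: float = 1e-2

    def __post_init__(self) -> None:
        # Ranges taken from the parameter descriptions.
        if not 0 <= self.penalty_scaling <= 10:
            raise ValueError("penalty_scaling must lie between 0 and 10")
        if not 0 <= self.init_lambda <= 1:
            raise ValueError("init_lambda must lie between 0 and 1")
        if not 0 <= self.allowance <= 1:
            raise ValueError("allowance must lie between 0 and 1")


def lagrangian(penalty: float, divergence: float, lam: float, cfg: ConstraintConfig) -> float:
    """Scaled sparsity penalty plus lam times the constraint `divergence - allowance <= 0`."""
    return cfg.penalty_scaling * penalty + lam * (divergence - cfg.allowance)


def update_lambda(lam: float, divergence: float, cfg: ConstraintConfig) -> float:
    """One projected gradient-ascent step on the multiplier, kept non-negative."""
    grad = divergence - cfg.allowance  # derivative of `lagrangian` w.r.t. lam
    return max(0.0, lam + cfg.lambda_optimizer_lr * grad)


if __name__ == "__main__":
    cfg = ConstraintConfig()
    lam = cfg.init_lambda
    # The prediction moved by 0.13, which exceeds the 0.03 allowance, so lam grows.
    lam = update_lambda(lam, divergence=0.13, cfg=cfg)
    print(lam)
```

Starting from `init_lambda`, the multiplier grows while the divergence exceeds the allowance and shrinks toward zero once the constraint is met, so the sparsity penalty dominates only when the masked model stays close enough to the original prediction.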
- forward(model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs) → Explanation
Computes the explanation.
- Parameters:
model (torch.nn.Module) – The model to explain.
x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]) – The input node features of a homogeneous or heterogeneous graph.
edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]) – The input edge indices of a homogeneous or heterogeneous graph.
target (torch.Tensor) – The target of the model.
index (Union[int, Tensor], optional) – The index of the model output to explain. Can be a single index or a tensor of indices. (default: None)
**kwargs (optional) – Additional keyword arguments passed to model.
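The `index` argument restricts which model outputs (and the matching entries of `target`) the explanation is optimized for. Assuming that `index=None` means all outputs are used, the selection can be sketched in plain Python. `select_explained` is a hypothetical helper, and a list of ints stands in for the index `Tensor`.

```python
from typing import List, Optional, Sequence, Tuple, Union

Index = Optional[Union[int, Sequence[int]]]


def select_explained(
    outputs: List[List[float]],
    target: List[int],
    index: Index,
) -> Tuple[List[List[float]], List[int]]:
    """Return the output rows and targets an explanation would refer to."""
    if index is None:
        rows = list(range(len(outputs)))  # assumption: None means every output
    elif isinstance(index, int):
        rows = [index]  # a single index
    else:
        rows = list(index)  # several indices at once
    return [outputs[i] for i in rows], [target[i] for i in rows]


if __name__ == "__main__":
    outputs = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]  # e.g. per-node class scores
    target = [1, 0, 1]
    print(select_explained(outputs, target, [0, 2]))
```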