# Source code for torch_geometric.utils._homophily

from typing import Union, overload

import torch
from torch import Tensor

from torch_geometric.typing import Adj, OptTensor, SparseTensor
from torch_geometric.utils import degree, scatter

@overload
def homophily(
    edge_index: Adj,
    y: Tensor,
    batch: None = ...,
    method: str = ...,
) -> float:
    # Typing-only signature: without a batch vector, the homophily of the
    # whole graph is returned as a single Python float.
    pass

@overload
def homophily(
    edge_index: Adj,
    y: Tensor,
    batch: Tensor,
    method: str = ...,
) -> Tensor:
    # Typing-only signature: with a batch vector, one homophily value per
    # example is returned as a Tensor.
    # NOTE: method='edge_insensitive' still returns a float when `batch`
    # describes only a single example.
    pass

def homophily(
    edge_index: Adj,
    y: Tensor,
    batch: OptTensor = None,
    method: str = 'edge',
) -> Union[float, Tensor]:
    r"""The homophily of a graph characterizes how likely nodes with the same
    label are near each other in a graph.

    There are many measures of homophily that fit this definition.
    In particular:

    - In the `"Beyond Homophily in Graph Neural Networks: Current Limitations
      and Effective Designs" <https://arxiv.org/abs/2006.11468>`_ paper, the
      homophily is the fraction of edges in a graph which connect nodes
      that have the same class label:

      .. math::
        \frac{| \{ (v,w) : (v,w) \in \mathcal{E} \wedge y_v = y_w \} | }
        {|\mathcal{E}|}

      That measure is called the *edge homophily ratio*.

    - In the `"Geom-GCN: Geometric Graph Convolutional Networks"
      <https://arxiv.org/abs/2002.05287>`_ paper, edge homophily is normalized
      across neighborhoods:

      .. math::
        \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \frac{ | \{ (w,v) : w
        \in \mathcal{N}(v) \wedge y_v = y_w \} |  } { |\mathcal{N}(v)| }

      That measure is called the *node homophily ratio*.

    - In the `"Large-Scale Learning on Non-Homophilous Graphs: New Benchmarks
      and Strong Simple Methods" <https://arxiv.org/abs/2110.14446>`_ paper,
      edge homophily is modified to be insensitive to the number of classes
      and size of each class:

      .. math::
        \frac{1}{C-1} \sum_{k=1}^{C} \max \left(0, h_k - \frac{|\mathcal{C}_k|}
        {|\mathcal{V}|} \right),

      where :math:`C` denotes the number of classes, :math:`|\mathcal{C}_k|`
      denotes the number of nodes of class :math:`k`, and :math:`h_k` denotes
      the edge homophily ratio of nodes of class :math:`k`.

      Thus, that measure is called the *class insensitive edge homophily
      ratio*.

    Args:
        edge_index (Tensor or SparseTensor): The graph connectivity.
        y (Tensor): The labels. A trailing singleton dimension is squeezed
            away; labels are used as non-negative integer class indices.
        batch (LongTensor, optional): Batch vector\
            :math:`\mathbf{b} \in {\{ 0, \ldots,B-1\}}^N`, which assigns
            each node to a specific example. (default: :obj:`None`)
        method (str, optional): The method used to calculate the homophily,
            either :obj:`"edge"` (first formula), :obj:`"node"` (second
            formula) or :obj:`"edge_insensitive"` (third formula).
            (default: :obj:`"edge"`)

    Returns:
        A Python :obj:`float` if :obj:`batch` is :obj:`None`, otherwise a
        :class:`Tensor` holding one value per example. For
        :obj:`"edge_insensitive"`, a :obj:`float` is returned whenever only a
        single example is present.

    Examples:
        >>> edge_index = torch.tensor([[0, 1, 2, 3],
        ...                            [1, 2, 0, 4]])
        >>> y = torch.tensor([0, 0, 0, 0, 1])
        >>> # Edge homophily ratio
        >>> homophily(edge_index, y, method='edge')
        0.75

        >>> # Node homophily ratio
        >>> homophily(edge_index, y, method='node')
        0.6000000238418579

        >>> # Class insensitive edge homophily ratio
        >>> homophily(edge_index, y, method='edge_insensitive')
        0.19999998807907104
    """
    assert method in {'edge', 'node', 'edge_insensitive'}
    y = y.squeeze(-1) if y.dim() > 1 else y

    if isinstance(edge_index, SparseTensor):
        row, col, _ = edge_index.coo()
    else:
        row, col = edge_index

    # 1.0 for every edge (row -> col) whose endpoints share a label, else 0.0.
    # Shared by all three measures.
    same = torch.zeros(row.size(0), device=row.device)
    same[y[row] == y[col]] = 1.

    if method == 'edge':
        if batch is None:
            return float(same.mean())
        # Each edge is attributed to the example of its target node `col`.
        dim_size = int(batch.max()) + 1
        return scatter(same, batch[col], 0, dim_size, reduce='mean')

    elif method == 'node':
        # Per-node fraction of same-label incoming neighbors (grouped by the
        # target node `col`); nodes without incoming edges contribute 0.
        out = scatter(same, col, 0, dim_size=y.size(0), reduce='mean')
        if batch is None:
            return float(out.mean())
        return scatter(out, batch, dim=0, reduce='mean')

    elif method == 'edge_insensitive':
        assert y.dim() == 1
        num_classes = int(y.max()) + 1
        assert num_classes >= 2
        batch = torch.zeros_like(y) if batch is None else batch
        num_nodes = degree(batch, dtype=torch.int64)  # nodes per example
        num_graphs = num_nodes.numel()

        # Joint (example, class) group index of every node.
        group = num_classes * batch + y
        num_groups = num_graphs * num_classes

        # h_k per (example, class): edge homophily of edges whose target node
        # falls into that group. `dim_size` is fixed explicitly so that groups
        # missing at the tail (e.g. the last example lacking the highest
        # class) still get a slot and the reshape below cannot fail.
        h = scatter(same, group[col], 0, num_groups, reduce='mean')
        h = h.view(num_graphs, num_classes)

        counts = group.bincount(minlength=num_groups)
        counts = counts.view(num_graphs, num_classes)
        proportions = counts / num_nodes.view(-1, 1)

        out = (h - proportions).clamp_(min=0).sum(dim=-1)
        out /= num_classes - 1
        return out if out.numel() > 1 else float(out)

    else:
        # Only reachable when assertions are disabled (python -O).
        raise NotImplementedError