from typing import List, Optional, Union
import torch
from torch import Tensor
from torch_geometric.data import Data, Dataset, InMemoryDataset
[docs]class ImbalancedSampler(torch.utils.data.WeightedRandomSampler):
r"""A weighted random sampler that randomly samples elements according to
class distribution.
As such, it will either remove samples from the majority class
(under-sampling) or add more examples from the minority class
(over-sampling).
**Graph-level sampling:**
.. code-block:: python
from torch_geometric.loader import DataLoader, ImbalancedSampler
sampler = ImbalancedSampler(dataset)
loader = DataLoader(dataset, batch_size=64, sampler=sampler, ...)
**Node-level sampling:**
.. code-block:: python
from torch_geometric.loader import NeighborLoader, ImbalancedSampler
sampler = ImbalancedSampler(data, input_nodes=data.train_mask)
loader = NeighborLoader(data, input_nodes=data.train_mask,
batch_size=64, num_neighbors=[-1, -1],
sampler=sampler, ...)
You can also pass in the class labels directly as a :class:`torch.Tensor`:
.. code-block:: python
from torch_geometric.loader import NeighborLoader, ImbalancedSampler
sampler = ImbalancedSampler(data.y)
loader = NeighborLoader(data, input_nodes=data.train_mask,
batch_size=64, num_neighbors=[-1, -1],
sampler=sampler, ...)
Args:
dataset (Dataset or Data or Tensor): The dataset or class distribution
from which to sample the data, given either as a
:class:`~torch_geometric.data.Dataset`,
:class:`~torch_geometric.data.Data`, or :class:`torch.Tensor`
object.
input_nodes (Tensor, optional): The indices of nodes that are used by
the corresponding loader, *e.g.*, by
:class:`~torch_geometric.loader.NeighborLoader`.
If set to :obj:`None`, all nodes will be considered.
This argument should only be set for node-level loaders and does
not have any effect when operating on a set of graphs as given by
:class:`~torch_geometric.data.Dataset`. (default: :obj:`None`)
num_samples (int, optional): The number of samples to draw for a single
epoch. If set to :obj:`None`, will sample as much elements as there
exists in the underlying data. (default: :obj:`None`)
"""
def __init__(
self,
dataset: Union[Dataset, Data, List[Data], Tensor],
input_nodes: Optional[Tensor] = None,
num_samples: Optional[int] = None,
):
if isinstance(dataset, Data):
y = dataset.y.view(-1)
assert dataset.num_nodes == y.numel()
y = y[input_nodes] if input_nodes is not None else y
elif isinstance(dataset, Tensor):
y = dataset.view(-1)
y = y[input_nodes] if input_nodes is not None else y
elif isinstance(dataset, InMemoryDataset):
y = dataset.y.view(-1)
assert len(dataset) == y.numel()
else:
ys = [data.y for data in dataset]
if isinstance(ys[0], Tensor):
y = torch.cat(ys, dim=0).view(-1)
else:
y = torch.tensor(ys).view(-1)
assert len(dataset) == y.numel()
assert y.dtype == torch.long # Require classification.
num_samples = y.numel() if num_samples is None else num_samples
class_weight = 1. / y.bincount()
weight = class_weight[y]
return super().__init__(weight, num_samples, replacement=True)