torch_geometric.transforms.RemoveTrainingClasses

class RemoveTrainingClasses(classes: List[int])

Bases: BaseTransform

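Removes the given classes from the node-level training set defined by data.train_mask, e.g., to obtain a zero-shot label scenario (functional name: remove_training_classes). Nodes whose label in data.y belongs to one of the given classes are excluded from data.train_mask; the labels themselves are left unchanged.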

Parameters:

classes (List[int]) – The classes to remove from the training set.