torch_geometric.transforms.RemoveTrainingClasses
class RemoveTrainingClasses(classes: List[int])
Bases: BaseTransform
Removes the given classes from the node-level training set defined by data.train_mask, e.g., to set up a zero-shot label scenario (functional name: remove_training_classes).
Parameters:
classes (List[int]) – The classes to remove from the training set.
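The masking step described above can be sketched in plain PyTorch. The sketch assumes node labels live in a 1-D tensor y and the training set in a boolean tensor train_mask, as on a typical PyG Data object. The helper remove_training_classes below is hypothetical: it mirrors the documented behavior (nodes whose label is in classes are dropped from train_mask, while the graph itself is unchanged), and the library's own implementation may differ in detail.

```python
from typing import List

import torch


def remove_training_classes(train_mask: torch.Tensor, y: torch.Tensor,
                            classes: List[int]) -> torch.Tensor:
    """Hypothetical helper: return a copy of train_mask in which every node
    whose label is in `classes` is excluded from the training set."""
    # Clone so the caller's original mask is left untouched.
    new_mask = train_mask.clone()
    # torch.isin flags nodes whose label belongs to one of the removed classes.
    drop = torch.isin(y, torch.tensor(classes, dtype=y.dtype))
    # Only the mask changes; nodes and edges stay in the graph.
    new_mask[drop] = False
    return new_mask


# Toy example: 6 nodes, 3 classes, all nodes initially in the training set.
y = torch.tensor([0, 1, 2, 0, 1, 2])
train_mask = torch.ones(6, dtype=torch.bool)

new_mask = remove_training_classes(train_mask, y, classes=[2])
print(new_mask.tolist())  # [True, True, False, True, True, False]
```

With PyTorch Geometric installed, the transform itself would typically be constructed as RemoveTrainingClasses([2]) and either called on a Data object or passed as a dataset's transform argument. Nodes of the removed classes keep their labels, so they can still be evaluated at test time, which is what makes the zero-shot setting possible.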