torch_geometric.nn.models.GLEM

class GLEM(lm_to_use: str = 'prajjwal1/bert-tiny', gnn_to_use: torch_geometric.nn.models.basic_gnn = GraphSAGE, out_channels: int = 47, gnn_loss=CrossEntropyLoss(), lm_loss=CrossEntropyLoss(), alpha: float = 0.5, beta: float = 0.5, lm_dtype: torch.dtype = torch.bfloat16, lm_use_lora: bool = True, lora_target_modules: Optional[Union[str, List[str]]] = None, device: Union[str, torch.device] = device(type='cpu'))[source]

Bases: Module

This GNN+LM co-training model is based on GLEM from the “Learning on Large-scale Text-attributed Graphs via Variational Inference” paper.

Parameters:
  • lm_to_use (str) – A text encoder from the Hugging Face model hub with a classifier head. (default: TinyBERT)

  • gnn_to_use (torch_geometric.nn.models) – The GNN model to co-train. (default: GraphSAGE)

  • out_channels (int) – Number of output channels for both the LM and the GNN; the two must match.

  • num_gnn_heads (Optional[int]) – Number of attention heads, if needed.

  • num_gnn_layers (int) – Number of GNN layers.

  • gnn_loss – Loss function for the GNN. (default: CrossEntropyLoss)

  • lm_loss – Loss function for the language model. (default: CrossEntropyLoss)

  • alpha (float) – Pseudo-label weight in the E-step (LM optimization). (default: 0.5)

  • beta (float) – Pseudo-label weight in the M-step (GNN optimization). (default: 0.5)

  • lm_dtype (torch.dtype) – Data type the LM is loaded in. (default: torch.bfloat16)

  • lm_use_lora (bool) – Whether to fine-tune the LM with LoRA (PEFT). (default: True)

  • lora_target_modules (Union[str, List[str], None], default: None) – Names of the target modules to apply the LoRA adapter to, e.g., ['q_proj', 'v_proj'] for an LLM. (default: None)

  • device (Union[str, torch.device]) – Device to run the model on. (default: 'cpu')

Note

See examples/llm_plus_gnn/glem.py for example usage.
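For orientation, a minimal instantiation sketch is shown below. It assumes only the constructor documented above; the out_channels value is dataset-specific and illustrative:

    import torch
    from torch_geometric.nn.models import GLEM, GraphSAGE

    # Minimal sketch: co-train TinyBERT with GraphSAGE (values are illustrative).
    model = GLEM(
        lm_to_use='prajjwal1/bert-tiny',
        gnn_to_use=GraphSAGE,
        out_channels=47,  # must match the number of classes for both LM and GNN
        lm_use_lora=True,
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    )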

forward(*input: Any) → None

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance itself instead, since the instance call takes care of running the registered hooks while this method silently ignores them.

Return type:

None

train(em_phase: str, train_loader: Union[DataLoader, NeighborLoader], optimizer: Optimizer, pseudo_labels: Tensor, epoch: int, is_augmented: bool = False, verbose: bool = False)[source]

GLEM training step for one EM phase.

Parameters:
  • em_phase (str) – ‘gnn’ or ‘lm’, selecting which phase to train.

  • train_loader (Union[DataLoader, NeighborLoader]) – Use a DataLoader for LM training (providing tokenized data, labels, and the is_gold mask); use a NeighborLoader for GNN training (providing x and edge_index).

  • optimizer (torch.optim.Optimizer) – Optimizer for training.

  • pseudo_labels (torch.Tensor) – Predicted labels used as pseudo labels.

  • epoch (int) – Current epoch.

  • is_augmented (bool) – Whether to train with pseudo labels.

  • verbose (bool) – Whether to print a training progress bar.

Returns:

acc (float): training accuracy

loss (float): loss value
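A hedged sketch of how the two phases are typically alternated; lm_loader, gnn_loader, the per-phase optimizers, pseudo_labels, and num_epochs are assumed to exist, and train is assumed to return the accuracy/loss pair documented above (see examples/llm_plus_gnn/glem.py for the full recipe):

    # E-step: fine-tune the LM on gold labels plus GNN pseudo labels.
    for epoch in range(num_epochs):
        acc, loss = model.train('lm', lm_loader, lm_optimizer,
                                pseudo_labels, epoch, is_augmented=True)

    # M-step: train the GNN on gold labels plus LM pseudo labels.
    for epoch in range(num_epochs):
        acc, loss = model.train('gnn', gnn_loader, gnn_optimizer,
                                pseudo_labels, epoch, is_augmented=True)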

train_lm(train_loader: DataLoader, optimizer: Optimizer, epoch: int, pseudo_labels: Optional[Tensor] = None, is_augmented: bool = False, verbose: bool = True)[source]

Language model training step, run once per epoch.

Parameters:
  • train_loader (DataLoader) – LM data loader, including tokenized data, labels, and the is_gold mask

  • optimizer (torch.optim.Optimizer) – model optimizer

  • epoch (int) – current train epoch

  • pseudo_labels (torch.Tensor) – 1-D tensor of predicted labels used as pseudo labels

  • is_augmented (bool) – whether to use pseudo-labeled nodes

  • verbose (bool) – whether to print a training progress bar

Returns:

approx_acc (torch.Tensor): training accuracy

loss (torch.float): loss value

train_gnn(train_loader: NeighborLoader, optimizer: Optimizer, epoch: int, pseudo_labels: Optional[Tensor] = None, is_augmented: bool = False, verbose: bool = True)[source]

GNN training step, run once per epoch.

Parameters:
  • train_loader (NeighborLoader) – GNN neighbor node loader, including x and edge_index

  • optimizer (torch.optim.Optimizer) – model optimizer

  • epoch (int) – current train epoch

  • pseudo_labels (torch.Tensor) – 1-D tensor of predicted labels from the LM

  • is_augmented (bool) – whether to use pseudo-labeled nodes

  • verbose (bool) – whether to print training progress

Returns:

approx_acc (torch.Tensor): training accuracy

loss (torch.float): loss value

inference(em_phase: str, data_loader: Union[NeighborLoader, DataLoader], verbose: bool = False)[source]

GLEM inference step.

Parameters:
  • em_phase (str) – ‘gnn’ or ‘lm’

  • data_loader (Union[DataLoader, NeighborLoader]) – a DataLoader for LM inference, including tokenized data; a NeighborLoader for GNN inference, including x and edge_index

  • verbose (bool) – whether to print inference progress

Returns:

an n * m tensor, where n is the number of nodes and m is the number of classes

Return type:

out (torch.Tensor)
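A hedged usage sketch, assuming a model and loaders as above: the output of one phase becomes the pseudo labels for the other, following the argmax conversion documented in inference_lm and inference_gnn below.

    # Soft predictions from the LM over all n nodes ([n, m] tensor).
    logits = model.inference('lm', lm_loader)

    # Convert to hard pseudo labels for the next GNN phase.
    pseudo_labels = logits.argmax(dim=-1).unsqueeze(1)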

inference_lm(data_loader: DataLoader, verbose: bool = True)[source]

LM inference step.

Parameters:
  • data_loader (DataLoader) – includes tokens, labels, and the is_gold mask

  • verbose (bool) – whether to print a progress bar

Returns:

predictions from the LM; convert to pseudo labels via preds.argmax(dim=-1).unsqueeze(1)

Return type:

preds (torch.Tensor)

inference_gnn(data_loader: NeighborLoader, verbose: bool = True)[source]

GNN inference step.

Parameters:
  • data_loader (NeighborLoader) – includes x and edge_index

  • verbose (bool) – whether to print a progress bar

Returns:

predictions from the GNN; convert to pseudo labels via preds.argmax(dim=-1).unsqueeze(1)

Return type:

preds (torch.Tensor)

loss(logits: torch.Tensor, labels: torch.Tensor, loss_func: torch.nn.functional, is_gold: torch.Tensor, pseudo_labels: Optional[torch.Tensor] = None, pl_weight: float = 0.5, is_augmented: bool = True)[source]

Core function of variational EM inference: combines the loss on gold (original training) labels with the loss on pseudo labels.

Reference: https://github.com/AndyJZhao/GLEM/blob/main/src/models/GLEM/GLEM_utils.py

Parameters:
  • logits (torch.Tensor) – predictions from the LM or GNN

  • labels (torch.Tensor) – node labels combining ground truth and pseudo labels (if provided)

  • loss_func (torch.nn.modules.loss) – loss function for classification

  • is_gold (torch.Tensor) – a boolean mask selecting ground-truth labels during training; ~is_gold therefore masks the pseudo labels

  • pseudo_labels (torch.Tensor) – predictions from the other model

  • pl_weight (float, default: 0.5) – weight of the pseudo-label loss: alpha in the E-step and beta in the M-step, respectively

  • is_augmented (bool, default: True) – whether to use EM training or to train the GNN and LM on gold data only
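For intuition, a minimal sketch of the mixing this function performs when is_augmented=True; the names mirror the parameters above, and the reference implementation linked earlier may normalize the two terms differently:

    # A hedged sketch, not the library's exact implementation.
    def glem_style_loss(logits, labels, loss_func, is_gold,
                        pseudo_labels, pl_weight=0.5):
        # Loss on gold (original training) labels.
        gold_loss = loss_func(logits[is_gold], labels[is_gold])
        # Loss on pseudo-labeled nodes (~is_gold selects them).
        pl_loss = loss_func(logits[~is_gold], pseudo_labels[~is_gold])
        # pl_weight is alpha in the E-step and beta in the M-step.
        return pl_weight * pl_loss + (1 - pl_weight) * gold_loss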