- class GLEM(lm_to_use: str = 'prajjwal1/bert-tiny', gnn_to_use: <module 'torch_geometric.nn.models.basic_gnn' from '/home/docs/checkouts/'> = <class 'torch_geometric.nn.models.basic_gnn.GraphSAGE'>, out_channels: int = 47, gnn_loss=CrossEntropyLoss(), lm_loss=CrossEntropyLoss(), alpha: float = 0.5, beta: float = 0.5, lm_dtype: torch.dtype = torch.bfloat16, lm_use_lora: bool = True, lora_target_modules: Optional[Union[str, List[str]]] = None, device: Union[str, torch.device] = device(type='cpu'))
This GNN+LM co-training model is based on GLEM from the “Learning on Large-scale Text-attributed Graphs via Variational Inference” paper.
- Parameters:
lm_to_use (str) – A text encoder with a classifier head from the Hugging Face model repository (default: 'prajjwal1/bert-tiny')
gnn_to_use (torch_geometric.nn.models) – The GNN model class to use (default: GraphSAGE)
out_channels (int) – Number of output channels for the LM and the GNN; must be the same for both
num_gnn_heads (Optional[int]) – Number of attention heads, if the GNN needs them
num_gnn_layers (int) – Number of GNN layers
gnn_loss – Loss function for the GNN (default: CrossEntropyLoss)
lm_loss – Loss function for the language model (default: CrossEntropyLoss)
alpha (float) – Pseudo-label weight in the E-step, which optimizes the LM (default: 0.5)
beta (float) – Pseudo-label weight in the M-step, which optimizes the GNN (default: 0.5)
lm_dtype (torch.dtype) – Data type used when loading the LM into memory (default: torch.bfloat16)
lm_use_lora (bool) – Whether to fine-tune the LM with LoRA via PEFT (default: True)
lora_target_modules (Optional[Union[str, List[str]]]) – The names of the target modules to apply the LoRA adapter to, e.g. ['q_proj', 'v_proj'] for an LLM (default: None)
See examples/llm_plus_gnn/ for example usage.
- forward(*input: Any) -> None
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Return type:
None
- train(em_phase: str, train_loader: Union[DataLoader, NeighborLoader], optimizer: Optimizer, pseudo_labels: Tensor, epoch: int, is_augmented: bool = False, verbose: bool = False)
GLEM training step, EM steps.
- Parameters:
em_phase (str) – 'gnn' or 'lm'; selects which phase is being trained
train_loader (Union[DataLoader, NeighborLoader]) – use a DataLoader for LM training (batches include the tokenized data, the labels and the is_gold mask); use a NeighborLoader for GNN training (batches include x and edge_index)
optimizer (torch.optim.Optimizer) – optimizer for training
pseudo_labels (torch.Tensor) – the predicted labels used as pseudo labels
epoch (int) – current epoch
is_augmented (bool) – whether to use pseudo_labels
verbose (bool) – whether to print a training progress bar
- Returns:
acc (float) – training accuracy
loss (float) – loss value
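The alternation between the two phases can be sketched as a small driver loop. This is a minimal illustration rather than the library's training script: `StubGLEM` and `run_em` are hypothetical names, the stub only mimics the documented `train` and `inference` call signatures, and reusing one loader per phase for both training and inference is a simplifying assumption.

```python
from typing import Any, Dict, List, Tuple


class StubGLEM:
    """Hypothetical stand-in that mimics only the documented call signatures."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, str]] = []

    def train(self, em_phase: str, train_loader: Any, optimizer: Any,
              pseudo_labels: Any, epoch: int, is_augmented: bool = False,
              verbose: bool = False) -> Tuple[float, float]:
        self.calls.append(("train", em_phase))
        return 0.0, 0.0  # placeholder (acc, loss)

    def inference(self, em_phase: str, data_loader: Any,
                  verbose: bool = False) -> List[List[float]]:
        self.calls.append(("inference", em_phase))
        # Fake n x m class scores: one row per item, class 1 always wins.
        return [[0.0, 1.0] for _ in data_loader]


def run_em(model: Any, loaders: Dict[str, Any], optimizers: Dict[str, Any],
           pseudo_labels: List[List[int]], num_rounds: int,
           epochs_per_phase: int = 1) -> List[List[int]]:
    """Alternate the E-step ('lm') and the M-step ('gnn').

    Each phase trains on the pseudo labels produced by the other model,
    then refreshes the pseudo labels from its own predictions.
    """
    for _ in range(num_rounds):
        for phase in ("lm", "gnn"):
            for epoch in range(epochs_per_phase):
                model.train(phase, loaders[phase], optimizers[phase],
                            pseudo_labels, epoch, is_augmented=True)
            scores = model.inference(phase, loaders[phase])
            # Row-wise argmax, kept as a column of single-element rows.
            pseudo_labels = [[max(range(len(row)), key=row.__getitem__)]
                             for row in scores]
    return pseudo_labels


model = StubGLEM()
loaders = {"lm": [0, 1, 2], "gnn": [0, 1, 2]}
optimizers = {"lm": None, "gnn": None}
final_labels = run_em(model, loaders, optimizers, [[0], [0], [0]], num_rounds=2)
print(final_labels)
```

With the real model, the loaders would be a DataLoader for the 'lm' phase and a NeighborLoader for the 'gnn' phase, as described in the parameters above.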
- train_lm(train_loader: DataLoader, optimizer: Optimizer, epoch: int, pseudo_labels: Optional[Tensor] = None, is_augmented: bool = False, verbose: bool = True)
Language model training step in every epoch.
- Parameters:
train_loader (loader.dataloader.DataLoader) – text token dataloader
optimizer (torch.optim.Optimizer) – model optimizer
epoch (int) – current train epoch
pseudo_labels (torch.Tensor) – 1-D tensor, predictions from the GNN
is_augmented (bool) – train with pseudo labels or not
verbose (bool) – print training progress bar or not
- Returns:
approx_acc (torch.tensor) – training accuracy
loss (torch.float) – loss value
- train_gnn(train_loader: NeighborLoader, optimizer: Optimizer, epoch: int, pseudo_labels: Optional[Tensor] = None, is_augmented: bool = False, verbose: bool = True)
GNN training step in every epoch.
- Parameters:
train_loader (loader.NeighborLoader) – neighbor-sampling node loader for the GNN
optimizer (torch.optim.Optimizer) – model optimizer
epoch (int) – current train epoch
pseudo_labels (torch.Tensor) – 1-D tensor, predictions from the LM
is_augmented (bool) – whether to use pseudo-labeled nodes
verbose (bool) – whether to print training progress
- Returns:
approx_acc (torch.tensor) – training accuracy
loss (torch.float) – loss value
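- inference(em_phase: str, data_loader: Union[NeighborLoader, DataLoader], verbose: bool = False)
GLEM inference step.
- Parameters:
em_phase (str) – 'gnn' or 'lm'; selects which model runs inference
data_loader (Union[NeighborLoader, DataLoader]) – a DataLoader for LM inference or a NeighborLoader for GNN inference
verbose (bool) – whether to print a progress bar
- Returns:
n * m tensor, where n is the number of nodes and m is the number of classes
- Return type:
out (torch.Tensor)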
- inference_lm(data_loader: DataLoader, verbose: bool = True)
LM inference step.
- Parameters:
data_loader (DataLoader) – batches include the tokens, the labels and the gold mask
verbose (bool) – whether to print a progress bar
- Returns:
predictions from the LM; convert them to pseudo labels with preds.argmax(dim=-1).unsqueeze(1)
- Return type:
preds (tensor)
- inference_gnn(data_loader: NeighborLoader, verbose: bool = True)
GNN inference step.
- Parameters:
data_loader (NeighborLoader) – batches include x and edge_index
verbose (bool) – whether to print a progress bar
- Returns:
predictions from the GNN; convert them to pseudo labels with preds.argmax(dim=-1).unsqueeze(1)
- Return type:
preds (tensor)
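The conversion from class scores to pseudo labels can be illustrated without any tensor library. The sketch below is a plain-list analogue of `preds.argmax(dim=-1).unsqueeze(1)`. The helper names `to_pseudo_labels` and `merge_labels` are hypothetical, and keeping the gold label wherever `is_gold` is set is an assumption about how the combined labels are built, not something this reference states.

```python
from typing import List, Sequence


def to_pseudo_labels(preds: Sequence[Sequence[float]]) -> List[List[int]]:
    """Row-wise argmax over an n x m score table, kept as an n x 1 column."""
    return [[max(range(len(row)), key=row.__getitem__)] for row in preds]


def merge_labels(gold: Sequence[int], is_gold: Sequence[bool],
                 pseudo: Sequence[Sequence[int]]) -> List[int]:
    """Keep the gold label where is_gold is True, else take the pseudo label."""
    return [g if keep else p[0] for g, keep, p in zip(gold, is_gold, pseudo)]


# Two nodes, three classes; -1 marks a node without a gold label.
scores = [[0.1, 0.7, 0.2], [0.9, 0.05, 0.05]]
pseudo = to_pseudo_labels(scores)
combined = merge_labels([2, -1], [True, False], pseudo)
print(pseudo, combined)
```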
- loss(logits: torch.Tensor, labels: torch.Tensor, loss_func: <module 'torch.nn.functional' from '/home/docs/checkouts/'>, is_gold: torch.Tensor, pseudo_labels: Optional[torch.Tensor] = None, pl_weight: float = 0.5, is_augmented: bool = True)
The core function of variational EM inference: it combines the loss on gold (original training) labels with the loss on pseudo labels.
- Parameters:
logits (torch.Tensor) – predictions from the LM or the GNN
labels (torch.Tensor) – combined node labels from ground truth and pseudo labels (if provided)
loss_func (torch.nn.modules.loss) – loss function for classification
is_gold (torch.Tensor) – a boolean mask that marks the nodes with ground-truth labels during training; ~is_gold therefore marks the pseudo-labeled nodes
pseudo_labels (torch.Tensor) – predictions from the other model
pl_weight (float) – the pseudo-label weight used in E-step and M-step optimization: alpha in the E-step, beta in the M-step (default: 0.5)
is_augmented (bool) – use EM, or just train the GNN and the LM with gold data (default: True)
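One way to read the parameters above is as a weighted blend of two classification losses. The sketch below is a dependency-free illustration, not the library's implementation. It assumes the blend is the convex combination `pl_weight * pseudo_loss + (1 - pl_weight) * gold_loss`, that softmax cross-entropy is the loss function, and that the method falls back to the plain loss when augmentation is off or no pseudo-labeled nodes exist. `cross_entropy` and `blended_loss` are hypothetical names.

```python
import math
from typing import Optional, Sequence


def cross_entropy(logits: Sequence[Sequence[float]],
                  targets: Sequence[int]) -> float:
    """Mean softmax cross-entropy over rows; 0.0 for an empty batch."""
    if not logits:
        return 0.0
    total = 0.0
    for row, target in zip(logits, targets):
        peak = max(row)  # subtract the max for numerical stability
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_norm - row[target]
    return total / len(logits)


def blended_loss(logits: Sequence[Sequence[float]], labels: Sequence[int],
                 is_gold: Sequence[bool],
                 pseudo_labels: Optional[Sequence[int]] = None,
                 pl_weight: float = 0.5, is_augmented: bool = True) -> float:
    """Blend the gold-label loss with the pseudo-label loss (assumed form)."""
    has_pseudo = pseudo_labels is not None and not all(is_gold)
    if not (is_augmented and has_pseudo):
        return cross_entropy(logits, labels)  # gold-only training
    gold_idx = [i for i, g in enumerate(is_gold) if g]
    pseudo_idx = [i for i, g in enumerate(is_gold) if not g]
    gold_loss = cross_entropy([logits[i] for i in gold_idx],
                              [labels[i] for i in gold_idx])
    pseudo_loss = cross_entropy([logits[i] for i in pseudo_idx],
                                [pseudo_labels[i] for i in pseudo_idx])
    return pl_weight * pseudo_loss + (1.0 - pl_weight) * gold_loss


# Node 0 has a gold label; node 1 only has a pseudo label.
print(blended_loss([[2.0, 0.0], [0.0, 0.0]], [0, 1], [True, False], [0, 1]))
```

Under this reading, `pl_weight=0.0` ignores the pseudo labels entirely, and `pl_weight=1.0` trains only on them.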