import csv
import os
import os.path as osp
from collections.abc import Sequence
from typing import Dict, List, Optional, Union
import numpy as np
import torch
from torch import Tensor
from tqdm import tqdm
from torch_geometric.data import InMemoryDataset, download_google_url
from torch_geometric.data.data import BaseData
from torch_geometric.io import fs
try:
from transformers import PreTrainedTokenizerBase
WITH_TRANSFORMERS = True
except ImportError: # pragma: no cover
PreTrainedTokenizerBase = None
WITH_TRANSFORMERS = False
try:
from pandas import DataFrame, read_csv
WITH_PANDAS = True
except ImportError:
WITH_PANDAS = False
IndexType = Union[slice, Tensor, np.ndarray, Sequence]
def _safe_auto_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
try:
from transformers import AutoTokenizer
return AutoTokenizer.from_pretrained(model_name)
except ValueError as err:
if "bert" in model_name:
from transformers import BertTokenizer
return BertTokenizer.from_pretrained(model_name)
elif "gpt2" in model_name:
from transformers import GPT2Tokenizer
return GPT2Tokenizer.from_pretrained(model_name)
else:
raise RuntimeError(
f"Unsupported legacy model: {model_name}") from err
[docs]class TAGDataset(InMemoryDataset):
r"""The Text Attributed Graph datasets from the
`"Learning on Large-scale Text-attributed Graphs via Variational Inference"
<https://arxiv.org/abs/2210.14709>`_ paper and `"Harnessing Explanations:
LLM-to-LM Interpreter for Enhanced Text-Attributed Graph Representation
Learning" <https://arxiv.org/abs/2305.19523>`_ paper.
This dataset is aiming on transform `ogbn products`, `ogbn arxiv`
into Text Attributed Graph that each node in graph is associate with a
raw text, LLM prediction and explanation, that dataset can be adapt to
DataLoader (for LM training) and NeighborLoader(for GNN training).
In addition, this class can be use as a wrapper class by convert a
InMemoryDataset with Tokenizer and text into Text Attributed Graph.
Args:
root (str): Root directory where the dataset should be saved.
dataset (InMemoryDataset): The name of the dataset
(:obj:`"ogbn-products"`, :obj:`"ogbn-arxiv"`).
tokenizer_name (str): The tokenizer name for language model,
Be sure to use same tokenizer name as your `model id` of model repo
on huggingface.co.
text (List[str]): list of raw text associate with node, the order of
list should be align with node list
split_idx (Optional[Dict[str, torch.Tensor]]): Optional dictionary,
for saving split index, it is required that if your dataset doesn't
have get_split_idx function
tokenize_batch_size (int): batch size of tokenizing text, the
tokenizing process will run on cpu, default: 256
token_on_disk (bool): save token as .pt file on disk or not,
default: False
text_on_disk (bool): save given text(list of str) as dataframe on disk
or not, default: False
force_reload (bool): default: False
.. note::
See `example/llm/glem.py` for example usage
"""
raw_text_id = {
'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3',
'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt'
}
llm_prediction_url = 'https://github.com/XiaoxinHe/TAPE/raw/main/gpt_preds'
llm_explanation_id = {
'ogbn-arxiv': '1o8n2xRen-N_elF9NQpIca0iCHJgEJbRQ',
}
def __init__(
self,
root: str,
dataset: InMemoryDataset,
tokenizer_name: str,
text: Optional[List[str]] = None,
split_idx: Optional[Dict[str, Tensor]] = None,
tokenize_batch_size: int = 256,
token_on_disk: bool = False,
text_on_disk: bool = False,
force_reload: bool = False,
) -> None:
# list the vars you want to pass in before run download & process
self.name = dataset.name
self.text = text
self.llm_prediction_topk = 5
self.tokenizer_name = tokenizer_name
self.tokenizer = _safe_auto_tokenizer(tokenizer_name)
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.dir_name = '_'.join(dataset.name.split('-'))
self.root = osp.join(root, self.dir_name)
missing_str_list = []
if not WITH_PANDAS:
missing_str_list.append('pandas')
if len(missing_str_list) > 0:
missing_str = ' '.join(missing_str_list)
error_out = f"`pip install {missing_str}` to use this dataset."
raise ImportError(error_out)
if hasattr(dataset, 'get_idx_split'):
self.split_idx = dataset.get_idx_split()
elif split_idx is not None:
self.split_idx = split_idx
else:
raise ValueError("TAGDataset need split idx for generating "
"is_gold mask, please pass splited index "
"in format of dictionaty with 'train', 'valid' "
"'test' index tensor to 'split_idx'")
if text_on_disk:
if text is not None:
self.save_node_text(text)
self.text_on_disk = text_on_disk
# init will call download and process
super().__init__(self.root, transform=None, pre_transform=None,
pre_filter=None, force_reload=force_reload)
# after processing and download
# Dataset has to have BaseData as _data
assert dataset._data is not None
self._data = dataset._data # reassign reference
assert self._data is not None
assert dataset._data.y is not None
assert isinstance(self._data, BaseData)
assert self._data.num_nodes is not None
assert isinstance(dataset._data.num_nodes, int)
assert isinstance(self._data.num_nodes, int)
self._n_id = torch.arange(self._data.num_nodes)
is_good_tensor = self.load_gold_mask()
self._is_gold = is_good_tensor.squeeze()
self._data['is_gold'] = is_good_tensor
if self.text is not None and len(self.text) != self._data.num_nodes:
raise ValueError("The number of text sequence in 'text' should be "
"equal to number of nodes!")
self.token_on_disk = token_on_disk
self.tokenize_batch_size = tokenize_batch_size
self._token = self.tokenize_graph(self.tokenize_batch_size)
self._llm_explanation_token: Dict[str, Tensor] = {}
self._all_token: Dict[str, Tensor] = {}
if self.name in self.llm_explanation_id:
self._llm_explanation_token = self.tokenize_graph(
self.tokenize_batch_size, text_type='llm_explanation')
self._all_token = self.tokenize_graph(self.tokenize_batch_size,
text_type='all')
self.__num_classes__ = dataset.num_classes
@property
def num_classes(self) -> int:
return self.__num_classes__
@property
def raw_file_names(self) -> List[str]:
file_names = []
for _, _, files in os.walk(osp.join(self.root, 'raw')):
for file in files:
file_names.append(file)
return file_names
@property
def processed_file_names(self) -> List[str]:
return [
'geometric_data_processed.pt', 'pre_filter.pt',
'pre_transformed.pt'
]
@property
def token(self) -> Dict[str, Tensor]:
if self._token is None: # lazy load
self._token = self.tokenize_graph()
return self._token
@property
def llm_explanation_token(self) -> Dict[str, Tensor]:
if self._llm_explanation_token is None and \
self.name in self.llm_explanation_id:
self._llm_explanation_token = self.tokenize_graph(
text_type='llm_explanation')
return self._llm_explanation_token
@property
def all_token(self) -> Dict[str, Tensor]:
if self._all_token is None and \
self.name in self.llm_explanation_id:
self._all_token = self.tokenize_graph(text_type='all')
return self._all_token
# load is_gold after init
@property
def is_gold(self) -> Tensor:
if self._is_gold is None:
print('lazy load is_gold!!')
self._is_gold = self.load_gold_mask()
return self._is_gold
def get_n_id(self, node_idx: IndexType) -> Tensor:
if self._n_id is None:
assert self._data is not None
assert self._data.num_nodes is not None
assert isinstance(self._data.num_nodes, int)
self._n_id = torch.arange(self._data.num_nodes)
return self._n_id[node_idx]
def load_gold_mask(self) -> Tensor:
r"""Use original train split as gold split, generating is_gold mask
for picking ground truth labels and pseudo labels.
"""
train_split_idx = self.get_idx_split()['train']
assert self._data is not None
assert self._data.num_nodes is not None
assert isinstance(self._data.num_nodes, int)
is_good_tensor = torch.zeros(self._data.num_nodes,
dtype=torch.bool).view(-1, 1)
is_good_tensor[train_split_idx] = True
return is_good_tensor
def get_gold(self, node_idx: IndexType) -> Tensor:
r"""Get gold mask for given node_idx.
Args:
node_idx (torch.tensor): a tensor contain node idx
"""
if self._is_gold is None:
self._is_gold = self.is_gold
return self._is_gold[node_idx]
def get_idx_split(self) -> Dict[str, Tensor]:
return self.split_idx
def download(self) -> None:
print('downloading raw text')
raw_text_path = download_google_url(id=self.raw_text_id[self.name],
folder=f'{self.root}/raw',
filename='node-text.csv.gz',
log=True)
self.text = list(read_csv(raw_text_path)['text'])
if self.name in self.llm_explanation_id:
print('downloading llm explanations')
llm_explanation_path = download_google_url(
id=self.llm_explanation_id[self.name],
folder=f'{self.root}/raw', filename='node-gpt-response.csv.gz',
log=True)
self.llm_explanation = list(read_csv(llm_explanation_path)['text'])
print('downloading llm predictions')
fs.cp(f'{self.llm_prediction_url}/{self.name}.csv', self.raw_dir)
def process(self) -> None:
# process Title and Abstraction
if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')):
text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz'))
self.text = list(text_df['text'])
elif self.name in self.raw_text_id:
self.download()
else:
print('The dataset is not ogbn-products nor ogbn-arxiv,'
'please pass in your raw text string list to `text`')
if self.text is None:
raise ValueError("The TAGDataset only have ogbn-products and "
"ogbn-arxiv raw text in default "
"The raw text of each node is not specified"
"Please pass in 'text' when convert your dataset "
"to Text Attribute Graph Dataset")
# process LLM explanation and prediction
llm_explanation_path = f'{self.raw_dir}/node-gpt-response.csv.gz'
llm_prediction_path = f'{self.raw_dir}/{self.name}.csv'
if osp.exists(llm_explanation_path) and osp.exists(
llm_prediction_path):
# load LLM explanation
self.llm_explanation = list(read_csv(llm_explanation_path)['text'])
# load LLM prediction
preds = []
with open(llm_prediction_path) as file:
reader = csv.reader(file)
for row in reader:
inner_list = []
for value in row:
inner_list.append(int(value))
preds.append(inner_list)
pl = torch.zeros(len(preds), self.llm_prediction_topk,
dtype=torch.long)
for i, pred in enumerate(preds):
pl[i][:len(pred)] = torch.tensor(
pred[:self.llm_prediction_topk], dtype=torch.long) + 1
if self.llm_explanation is None or pl is None:
raise ValueError(
"The TAGDataset only have ogbn-arxiv LLM explanations"
"and predictions in default. The llm explanation and"
"prediction of each node is not specified.Please pass in"
"'llm_explanation' and 'llm_prediction' when"
"convert your dataset to Text Attribute Graph Dataset")
elif self.name in self.llm_explanation_id:
self.download()
else:
print(
'The dataset is not ogbn-arxiv,'
'please pass in your llm explanation list to `llm_explanation`'
'and llm prediction list to `llm_prediction`')
def save_node_text(self, text: List[str]) -> None:
node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz')
if osp.exists(node_text_path):
print(f'The raw text is existed at {node_text_path}')
else:
print(f'Saving raw text file at {node_text_path}')
os.makedirs(f'{self.root}/raw', exist_ok=True)
text_df = DataFrame(text, columns=['text'])
text_df.to_csv(osp.join(node_text_path), compression='gzip',
index=False)
def tokenize_graph(self, batch_size: int = 256,
text_type: str = 'raw_text') -> Dict[str, Tensor]:
r"""Tokenizing the text associate with each node, running in cpu.
Args:
batch_size (Optional[int]): batch size of list of text for
generating emebdding
text_type (Optional[str]): type of text
Returns:
Dict[str, torch.Tensor]: tokenized graph
"""
assert text_type in ['raw_text', 'llm_explanation', 'all']
if text_type == 'raw_text':
_text = self.text
elif text_type == 'llm_explanation':
_text = self.llm_explanation
elif text_type == 'all':
if self.text is None or self.llm_explanation is None:
raise ValueError("The TAGDataset need text and llm explanation"
"for tokenizing all text")
_text = [
f'{raw_txt} Explanation: {exp_txt}'
for raw_txt, exp_txt in zip(self.text, self.llm_explanation)
]
data_len = 0
if _text is not None:
data_len = len(_text)
else:
raise ValueError("The TAGDataset need text for tokenization")
token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
path = os.path.join(self.processed_dir, 'token', text_type,
self.tokenizer_name)
# Check if the .pt files already exist
token_files_exist = any(
os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys)
if token_files_exist and self.token_on_disk:
print('Found tokenized file, loading may take several minutes...')
all_encoded_token = {
k: torch.load(os.path.join(path, f'{k}.pt'), weights_only=True)
for k in token_keys
if os.path.exists(os.path.join(path, f'{k}.pt'))
}
return all_encoded_token
all_encoded_token = {k: [] for k in token_keys}
pbar = tqdm(total=data_len)
pbar.set_description(f'Tokenizing Text Attributed Graph {text_type}')
for i in range(0, data_len, batch_size):
end_index = min(data_len, i + batch_size)
token = self.tokenizer(_text[i:end_index], padding='max_length',
truncation=True, max_length=512,
return_tensors="pt")
for k in token.keys():
all_encoded_token[k].append(token[k])
pbar.update(end_index - i)
pbar.close()
all_encoded_token = {
k: torch.cat(v)
for k, v in all_encoded_token.items() if len(v) > 0
}
if self.token_on_disk:
os.makedirs(path, exist_ok=True)
print('Saving tokens on Disk')
for k, tensor in all_encoded_token.items():
torch.save(tensor, os.path.join(path, f'{k}.pt'))
print('Token saved:', os.path.join(path, f'{k}.pt'))
os.environ["TOKENIZERS_PARALLELISM"] = 'true' # suppressing warning
return all_encoded_token
def __repr__(self) -> str:
return f'{self.__class__.__name__}()'
class TextDataset(torch.utils.data.Dataset):
r"""This nested dataset provides textual data for each node in
the graph. Factory method to create TextDataset from TAGDataset.
Args:
tag_dataset (TAGDataset): the parent dataset
text_type (str): type of text
"""
def __init__(self, tag_dataset: 'TAGDataset',
text_type: str = 'raw_text') -> None:
assert text_type in ['raw_text', 'llm_explanation', 'all']
self.tag_dataset = tag_dataset
if text_type == 'raw_text':
self.token = tag_dataset.token
elif text_type == 'llm_explanation':
self.token = tag_dataset.llm_explanation_token
elif text_type == 'all':
self.token = tag_dataset.all_token
assert tag_dataset._data is not None
self._data = tag_dataset._data
assert tag_dataset._data.y is not None
self.labels = tag_dataset._data.y
def get_token(self, node_idx: IndexType) -> Dict[str, Tensor]:
r"""This function will be called in __getitem__().
Args:
node_idx (IndexType): selected node idx in each batch
Returns:
items (Dict[str, Tensor]): input for LM
"""
items = {k: v[node_idx] for k, v in self.token.items()}
return items
# for LM training
def __getitem__(
self,
node_id: IndexType,
) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]:
r"""This function will override the function in
torch.utils.data.Dataset, and will be called when you
iterate batch in the dataloader, make sure all following
key value pairs are present in the return dict.
Args:
node_id (List[int]): list of node idx for selecting tokens,
labels etc. when iterating data loader for LM
Returns:
items (dict): input k,v pairs for Language model training and
inference
"""
item: Dict[str, Union[Tensor, Dict[str, Tensor]]] = {}
item['input'] = self.get_token(node_id)
item['labels'] = self.labels[node_id]
item['is_gold'] = self.tag_dataset.get_gold(node_id)
item['n_id'] = self.tag_dataset.get_n_id(node_id)
return item
def __len__(self) -> int:
assert self._data.num_nodes is not None
return self._data.num_nodes
def get(self, idx: int) -> BaseData:
return self._data
def __repr__(self) -> str:
return f'{self.__class__.__name__}()'
def to_text_dataset(self, text_type: str = 'raw_text') -> TextDataset:
r"""Factory Build text dataset from Text Attributed Graph Dataset
each data point is node's associated text token.
"""
return TAGDataset.TextDataset(self, text_type)