# Source code for torch_geometric.profile.benchmark

import time
from typing import Any, Callable, List, Optional, Tuple

import torch
from torch import Tensor

from torch_geometric.utils import is_torch_sparse_tensor

def benchmark(
    funcs: List[Callable],
    args: Tuple[Any, ...],
    num_steps: int,
    func_names: Optional[List[str]] = None,
    num_warmups: int = 10,
    backward: bool = False,
):
    r"""Benchmark a list of functions :obj:`funcs` that receive the same set
    of arguments :obj:`args`.

    Every function is run :obj:`num_warmups + num_steps` times; only the last
    :obj:`num_steps` runs are timed. The reported times are the *sums* over
    those timed runs (not per-step averages) and are printed as a table.

    Args:
        funcs ([Callable]): The list of functions to benchmark.
        args ((Any, )): The arguments to pass to the functions.
        num_steps (int): The number of steps to run the benchmark.
        func_names ([str], optional): The names of the functions. If not
            given, will try to infer the name from the function itself.
            (default: :obj:`None`)
        num_warmups (int, optional): The number of warmup steps.
            (default: :obj:`10`)
        backward (bool, optional): If set to :obj:`True`, will benchmark both
            forward and backward passes. (default: :obj:`False`)

    Raises:
        ValueError: If :obj:`num_steps` or :obj:`num_warmups` is not
            positive, or if :obj:`funcs` and :obj:`func_names` have
            different lengths.
    """
    from tabulate import tabulate

    if num_steps <= 0:
        raise ValueError(f"'num_steps' must be a positive integer "
                         f"(got {num_steps})")

    if num_warmups <= 0:
        raise ValueError(f"'num_warmups' must be a positive integer "
                         f"(got {num_warmups})")

    if func_names is None:
        func_names = [get_func_name(func) for func in funcs]

    if len(funcs) != len(func_names):
        raise ValueError(f"Length of 'funcs' (got {len(funcs)}) and "
                         f"'func_names' (got {len(func_names)}) must be equal")

    # CUDA work is queued asynchronously, so the device has to be drained
    # before every clock read. Availability cannot change mid-benchmark, so
    # it is queried only once.
    use_cuda = torch.cuda.is_available()

    def _sync() -> None:
        if use_cuda:
            torch.cuda.synchronize()

    ts: List[List[str]] = []
    for func, name in zip(funcs, func_names):
        t_forward = t_backward = 0.0
        for i in range(num_warmups + num_steps):
            # Give each step fresh detached inputs (with autograd enabled
            # only when benchmarking the backward pass), so no step reuses
            # the graph or gradients of a previous one. Sparse and
            # non-floating tensors are passed through unchanged.
            step_args = [
                arg.detach().requires_grad_(backward)
                if isinstance(arg, Tensor) and arg.is_floating_point()
                and not is_torch_sparse_tensor(arg) else arg for arg in args
            ]

            _sync()
            t_start = time.perf_counter()
            out = func(*step_args)
            _sync()
            if i >= num_warmups:
                t_forward += time.perf_counter() - t_start

            if backward:
                out_grad = torch.randn_like(out)
                # Let the asynchronous `randn_like` kernel finish first so it
                # is not counted as part of the backward pass.
                _sync()
                t_start = time.perf_counter()
                out.backward(out_grad)
                _sync()
                if i >= num_warmups:
                    t_backward += time.perf_counter() - t_start

        row = [name, f'{t_forward:.4f}s']
        if backward:
            row.append(f'{t_backward:.4f}s')
            row.append(f'{t_forward + t_backward:.4f}s')
        ts.append(row)

    header = ['Name', 'Forward']
    if backward:
        header.extend(['Backward', 'Total'])

    print(tabulate(ts, headers=header, tablefmt='psql'))
def get_func_name(func: Callable) -> str:
    r"""Infer a human-readable name for the callable :obj:`func`.

    Returns :obj:`func.__name__` when the object has one (plain functions,
    builtins, classes, lambdas). Otherwise it falls back to the name of the
    object's class, which covers callable instances such as
    :func:`functools.partial` objects.

    Raises:
        ValueError: If neither :obj:`__name__` nor :obj:`__class__` can be
            read from :obj:`func`.
    """
    if hasattr(func, '__name__'):
        return func.__name__
    if hasattr(func, '__class__'):
        return func.__class__.__name__
    # Must be an f-string; previously this printed the literal text "{func}".
    raise ValueError(f"Could not infer name for function '{func}'")