# Source code for torch_geometric.profile.nvtx

from functools import wraps
from typing import Optional

import torch

# Module-level flag tracking whether a CUDA profiling session is currently
# open. `begin_cuda_profile()` sets it to True when it starts the profiler
# and `end_cuda_profile()` restores the value it is handed.
CUDA_PROFILE_STARTED = False


def begin_cuda_profile():
    """Start the CUDA profiler unless a session is already marked as open.

    Returns:
        The value of ``CUDA_PROFILE_STARTED`` as it was *before* this call.
        Pass it to :func:`end_cuda_profile` so that only the call which
        actually started the profiler also stops it.
    """
    global CUDA_PROFILE_STARTED
    was_started = CUDA_PROFILE_STARTED

    # A session is already open (e.g. nested profiled calls): leave the
    # profiler alone and just report the previous state.
    if was_started is not False:
        return was_started

    CUDA_PROFILE_STARTED = True
    torch.cuda.cudart().cudaProfilerStart()
    return was_started


def end_cuda_profile(prev_state: bool):
    """Restore the profiling flag and stop the profiler if we started it.

    Args:
        prev_state (bool): The value previously returned by
            :func:`begin_cuda_profile`. When it is ``False`` the matching
            ``begin`` call started the profiler, so it is stopped here.
    """
    global CUDA_PROFILE_STARTED
    CUDA_PROFILE_STARTED = prev_state

    # The profiler was already running before the matching begin call;
    # whoever started it is responsible for stopping it.
    if prev_state is not False:
        return

    torch.cuda.cudart().cudaProfilerStop()


def nvtxit(name: Optional[str] = None, n_warmups: int = 0,
           n_iters: Optional[int] = None):
    """Enables NVTX profiling for a function.

    The returned decorator wraps a function so that, once ``n_warmups``
    calls have gone by, each call runs inside an NVTX range named
    ``"{name}_{i}"`` (``i`` being the wrapper's call counter) with the CUDA
    profiler started around it. After ``n_iters`` recorded calls, further
    calls run un-instrumented. When CUDA is not available, every call is
    forwarded unchanged and is not counted.

    Args:
        name (Optional[str], optional): Name to give the reference frame for
            the function being wrapped. Defaults to the name of the
            function in code.
        n_warmups (int, optional): Number of iters to call that function
            before starting. Defaults to 0.
        n_iters (Optional[int], optional): Number of iters of that function
            to record. Defaults to all of them.

    Returns:
        A decorator that takes the function to profile and returns the
        instrumented wrapper.
    """
    def nvtx(func):
        nonlocal name
        iters_so_far = 0
        if name is None:
            name = func.__name__

        @wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal iters_so_far

            # Without CUDA there is nothing to profile: plain pass-through.
            if not torch.cuda.is_available():
                return func(*args, **kwargs)

            # Warm-up calls are counted but not recorded.
            if iters_so_far < n_warmups:
                iters_so_far += 1
                return func(*args, **kwargs)

            # The recording budget is exhausted: plain pass-through.
            if n_iters is not None and iters_so_far >= n_iters + n_warmups:
                return func(*args, **kwargs)

            prev_state = begin_cuda_profile()
            torch.cuda.nvtx.range_push(f"{name}_{iters_so_far}")
            try:
                result = func(*args, **kwargs)
            finally:
                # Always close the range and restore the profiler state,
                # even if `func` raises. Otherwise the NVTX range stack
                # stays unbalanced and the module-level "started" flag is
                # stuck, so later calls never start/stop the profiler.
                torch.cuda.nvtx.range_pop()
                end_cuda_profile(prev_state)

            # Only successful calls advance the counter.
            iters_so_far += 1
            return result

        return wrapper

    return nvtx