import torch
from torch import Tensor
def cumsum(x: Tensor, dim: int = 0) -> Tensor:
    r"""Returns the cumulative sum of elements of :obj:`x`.

    In contrast to :meth:`torch.cumsum`, prepends the output with zero, so
    the result has one more entry than :obj:`x` along :obj:`dim`.

    Args:
        x (torch.Tensor): The input tensor.
        dim (int, optional): The dimension to do the operation over.
            Negative values count from the last dimension.
            (default: :obj:`0`)

    Returns:
        torch.Tensor: A new tensor created via :meth:`x.new_empty` whose size
        along :obj:`dim` is :obj:`x.size(dim) + 1`; the first slice along
        :obj:`dim` is zero and the remainder holds the running sums.

    Example:
        >>> x = torch.tensor([2, 4, 1])
        >>> cumsum(x)
        tensor([0, 2, 6, 7])
    """
    # Query the length with the caller's `dim` first so that an out-of-range
    # value is rejected by PyTorch before we normalize it.
    length = x.size(dim)

    # Tuple slicing below needs a non-negative index: with a negative `dim`,
    # `shape[dim + 1:]` would wrap around (e.g. `dim=-1` -> `shape[0:]`) and
    # yield a wrongly-shaped output.
    if dim < 0:
        dim += x.dim()

    shape = x.size()
    out = x.new_empty(shape[:dim] + (length + 1, ) + shape[dim + 1:])

    # Leading slice along `dim` is the prepended zero.
    out.narrow(dim, 0, 1).zero_()
    # Write the running sums directly into the remaining view of `out`.
    torch.cumsum(x, dim=dim, out=out.narrow(dim, 1, length))

    return out