# Source code for torch_geometric.transforms.fixed_points

import re
import math

import torch
import numpy as np

[docs]class FixedPoints(object):
r"""Samples a fixed number of :obj:num points and features from a point
cloud.

Args:
num (int): The number of points to sample.
replace (bool, optional): If set to :obj:False, samples points
without replacement. (default: :obj:True)
allow_duplicates (bool, optional): In case :obj:replace is
:objFalse and :obj:num is greater than the number of points,
this option determines whether to add duplicated nodes to the
output points or not.
In case :obj:allow_duplicates is :obj:False, the number of
output points might be smaller than :obj:num.
In case :obj:allow_duplicates is :obj:True, the number of
duplicated points are kept to a minimum. (default: :obj:False)
"""
def __init__(self, num, replace=True, allow_duplicates=False):
self.num = num
self.replace = replace
self.allow_duplicates = allow_duplicates

def __call__(self, data):
num_nodes = data.num_nodes

if self.replace:
choice = np.random.choice(num_nodes, self.num, replace=True)
choice = torch.from_numpy(choice).to(torch.long)
elif not self.allow_duplicates:
choice = torch.randperm(num_nodes)[:self.num]
else:
choice = torch.cat([
torch.randperm(num_nodes)
for _ in range(math.ceil(self.num / num_nodes))
], dim=0)[:self.num]

for key, item in data:
if bool(re.search('edge', key)):
continue
if torch.is_tensor(item) and item.size(0) == num_nodes:
data[key] = item[choice]

return data

def __repr__(self):
return '{}({}, replace={})'.format(self.__class__.__name__, self.num,
self.replace)