Source code for xma.functional.continuous_count
# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************
import torch
from ...accelerator import Accelerator, KernelBackend
from .cuda_implementation import continuous_count_cuda
[docs]
@torch.no_grad()
def continuous_count(x: torch.Tensor, bins: int, *, kernel_backend: KernelBackend | None = None) -> torch.Tensor:
    """
    counts the number of occurrences of the values [0, 1, ..., `bins`) in the input tensor (`bins` is excluded).

    NOTE: the user is responsible for ensuring that the values lie in the valid range, any values outside this
    range are ignored and not counted.

    NOTE: when `bins` is 1 the result is simply the number of elements of `x`; the values, shape and dtype of
    `x` are not inspected in that case.

    NOTE: the torch backend relies on `torch.Tensor.bincount`, which rejects negative values instead of
    ignoring them.

    :param x: input tensor, 1-dimensional with dtype `torch.int32` or `torch.long` (checked when `bins` != 1)
    :type x: torch.Tensor
    :param bins: values [0, 1, ..., `bins`) are counted (`bins` is excluded)
    :type bins: int
    :param kernel_backend: backend used for the computation, if None the backend reported by
        `Accelerator.get_kernel_backend()` is used
    :type kernel_backend: KernelBackend | None
    :return: output tensor with `bins` elements of dtype `torch.uint32` on the same device as `x`
    :rtype: Tensor
    :raises AssertionError: if `x` is not 1-dimensional, has an unsupported dtype or the explicitly passed
        `kernel_backend` fails `verify_accelerator()`
    :raises ValueError: if the resolved `kernel_backend` is not torch, cuda or triton
    """

    if bins == 1:
        # with a single bin every element lands in bin 0, so the count is the element count.
        # this shortcut intentionally stays ahead of the shape/dtype checks to keep accepting the same inputs
        return torch.tensor([x.numel()], dtype=torch.uint32, device=x.device)

    assert x.dim() == 1, "x should be 1-dimensional"
    assert x.dtype in [torch.int32, torch.long]

    if kernel_backend is None:
        kernel_backend = Accelerator.get_kernel_backend()
    else:
        # only a backend chosen by the caller is verified
        assert kernel_backend.verify_accelerator()

    if kernel_backend == KernelBackend.torch:
        # bincount sizes its result to max(x) + 1 when x holds a value >= bins; slicing drops those extra
        # bins so the output always has exactly `bins` entries (a no-op when every value is in range)
        output = x.bincount(minlength=bins)[:bins].to(torch.uint32)
    elif kernel_backend in [KernelBackend.cuda, KernelBackend.triton]:
        # a triton request is served by the same CUDA kernel as a cuda request.
        # NOTE(review): output is allocated uninitialized, presumably the kernel writes every entry - confirm
        output = torch.empty(bins, dtype=torch.uint32, device=x.device)
        continuous_count_cuda(x=x, output=output, E=bins, THREAD_BLOCK_CLUSTER_SIZE=1, BLOCK_SIZE=1024)
    else:
        raise ValueError(f"unexpected kernel_backend ({kernel_backend})")

    return output