Source code for xma.utils.tensor

# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

import torch

from ..math import get_powers_of_2


def get_num_elements_and_hidden_size(x: torch.Tensor) -> tuple[int, int]:
    """Split ``x`` into a (leading-element count, last-dimension size) pair.

    ``hidden_size`` is the size of the last dimension. ``num_elements`` is the
    product of all preceding dimensions, i.e. the number of rows ``x`` would have
    if viewed as a 2-D ``(num_elements, hidden_size)`` tensor. A 1-D tensor
    therefore yields ``num_elements == 1``.

    Args:
        x: tensor with at least one dimension (``x.size(-1)`` raises for 0-dim tensors).

    Returns:
        ``(num_elements, hidden_size)`` as Python ints.
    """
    hidden_size = x.size(-1)
    # Multiply the leading dims directly rather than computing
    # ``x.numel() // hidden_size``: the division raised ZeroDivisionError when
    # the last dimension is empty, while the product is well-defined for any shape.
    num_elements = x.shape[:-1].numel()
    return num_elements, hidden_size
[docs] def empty_like_contiguous(x: torch.Tensor, dtype: torch.dtype | None = None) -> torch.Tensor: return torch.empty_like(x, dtype=dtype, memory_format=torch.contiguous_format)
[docs] def zeros_like_contiguous(x: torch.Tensor, dtype: torch.dtype | None = None) -> torch.Tensor: return torch.zeros_like(x, dtype=dtype, memory_format=torch.contiguous_format)
def get_alignment(x: torch.Tensor) -> int:
    """Estimate the byte alignment of ``x``'s data pointer.

    Walks the candidates produced by ``get_powers_of_2(4, 16)`` (presumably
    4, 8, 16 in ascending order — confirm against ``..math``). Each candidate
    that divides the address becomes the running answer; the scan stops at the
    first candidate that does not divide it.

    Returns:
        The last accepted candidate, or 4 if none was accepted.

    NOTE(review): 4 is returned even when the address is not a multiple of 4,
    so the result acts as a floor of 4 rather than a strict alignment guarantee.
    """
    address = x.data_ptr()
    best = 4
    for candidate in get_powers_of_2(4, 16):
        if address % candidate:
            break
        best = candidate
    return best