Source code for xma.utils.contiguous
# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************
from typing import Any, Callable
import torch
from torch.utils._pytree import tree_map
def _make_contiguous(x: Any) -> Any:
return x.contiguous() if isinstance(x, torch.Tensor) else x
[docs]
def ensure_contiguous(func: Callable) -> Callable:
def inner(*args, **kwargs):
args = tree_map(_make_contiguous, args)
kwargs = tree_map(_make_contiguous, kwargs)
return func(*args, **kwargs)
return inner