xma.functional.fused_linear_cross_entropy

fused_linear_cross_entropy(x: Tensor, weight: Tensor, labels: Tensor, reduction: str = 'mean', logits_multiplier: float | None = None, *, kernel_backend: KernelBackend | None = None) → Tensor

Compute the cross-entropy loss of the linear projection of x by weight against labels, without materializing the full output logits matrix.
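
One way to avoid materializing the logits is to stream over the vocabulary with an online log-sum-exp, so each logit is used as soon as it is computed and then discarded. The plain-Python sketch below illustrates that general technique. It is not xma's kernel. The list-of-lists layout, with one weight row per vocabulary entry (so each logit is the dot product of an x row with a weight row), and the helper name streaming_linear_cross_entropy are assumptions for illustration.

```python
import math


def streaming_linear_cross_entropy(x, weight, labels, logits_multiplier=None):
    """Per-row cross-entropy losses of the projection of x by weight.

    Hypothetical helper. Assumed layout (not confirmed by the xma docs):
      x:      list of N rows, each a list of H floats (hidden states)
      weight: list of V rows, each a list of H floats (one row per vocab entry)
      labels: list of N ints in [0, V)
    No row of V logits is ever stored; each logit is folded into a running
    log-sum-exp as soon as it is computed.
    """
    # None behaves like a multiplier of 1.
    scale = 1.0 if logits_multiplier is None else logits_multiplier
    losses = []
    for row, label in zip(x, labels):
        running_max = -math.inf  # largest logit seen so far for this row
        running_sum = 0.0        # sum of exp(logit - running_max) so far
        target_logit = None
        for v, w in enumerate(weight):
            logit = scale * sum(a * b for a, b in zip(row, w))
            if v == label:
                target_logit = logit
            # Online log-sum-exp: rescale the running sum when the max grows.
            if logit > running_max:
                running_sum = running_sum * math.exp(running_max - logit) + 1.0
                running_max = logit
            else:
                running_sum += math.exp(logit - running_max)
        log_sum_exp = running_max + math.log(running_sum)
        # Cross-entropy for one row: logsumexp(logits) - logits[label].
        losses.append(log_sum_exp - target_logit)
    return losses
```

A practical kernel would tile this loop over blocks of rows and vocabulary entries; the sketch only shows why memory need not grow with the vocabulary size.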

Parameters:
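  • x (torch.Tensor) – input hidden states; they are projected by weight to produce the logits, which are not materialized in full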

  • weight (torch.Tensor) – vocabulary projection weight that maps each row of x to one logit per vocabulary entry

  • labels (torch.Tensor) – target vocabulary indices, one per input position in x

  • reduction (str) – how the per-token losses are reduced; must be either “sum” or “mean”. Defaults to “mean”.

  • logits_multiplier (float | None) – scalar by which the logits are multiplied before the loss is computed; None is equivalent to 1. Defaults to None.

  • kernel_backend (KernelBackend | None) – kernel backend used to run the computation. Defaults to None.
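
The semantics of reduction and logits_multiplier can be checked against a naive reference that builds every row of logits explicitly. The helper below is hypothetical and not part of xma. It assumes one weight row per vocabulary entry and that “mean” divides the summed loss by the number of labels, which may differ from the library if it skips some label values.

```python
import math


def reference_linear_cross_entropy(x, weight, labels, reduction="mean", logits_multiplier=None):
    """Naive reference: materializes each row of logits, then reduces.

    Hypothetical helper; x is a list of N rows of H floats, weight is a list
    of V rows of H floats (assumed: one row per vocabulary entry), labels is
    a list of N ints in [0, V).
    """
    # Mirrors the documented constraint; raising ValueError is this sketch's choice.
    if reduction not in ("sum", "mean"):
        raise ValueError(f"reduction should be 'sum' or 'mean', got {reduction!r}")
    # None is treated as a multiplier of 1, per the parameter description.
    scale = 1.0 if logits_multiplier is None else logits_multiplier
    total = 0.0
    for row, label in zip(x, labels):
        # Full row of V logits: the memory cost the fused function avoids.
        logits = [scale * sum(a * b for a, b in zip(row, w)) for w in weight]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[label]
    # Assumed: "mean" divides by the number of labels.
    return total if reduction == "sum" else total / len(labels)
```

Because it materializes every logit, this reference is only practical for checking results on small inputs.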

Returns:

the cross-entropy loss, reduced according to reduction

Return type:

Tensor