xma.functional.fused_linear_cross_entropy
- fused_linear_cross_entropy(x: Tensor, weight: Tensor, labels: Tensor, reduction: str = 'mean', logits_multiplier: float | None = None, *, kernel_backend: KernelBackend | None = None) -> Tensor
Fuses the final linear projection (x with weight) with the cross-entropy loss, so the loss is computed without materializing the full output logits matrix.
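Written out, and assuming that weight stores one row $w_v$ per vocabulary entry (so the logits are $x W^\top$), the quantity computed for $N$ input rows $x_i$ with labels $y_i$ is presumably the standard cross entropy below. Here $s$ is logits_multiplier ($s = 1$ when it is None), $V$ is the vocabulary size, and "mean" is assumed to divide by $N$:

$$\ell_i = \log \sum_{v=1}^{V} \exp\left(s\, x_i^\top w_v\right) - s\, x_i^\top w_{y_i}$$

$$L_{\text{sum}} = \sum_{i=1}^{N} \ell_i, \qquad L_{\text{mean}} = \frac{1}{N} \sum_{i=1}^{N} \ell_i$$

The saving comes from never holding all $N \times V$ values $s\, x_i^\top w_v$ in memory at once.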
- Parameters:
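x (torch.Tensor) – input activations (e.g. final hidden states) that are projected by weight to produce the logits
weight (torch.Tensor) – weight of the vocabulary (output) projection
labels (torch.Tensor) – target class (token) indices
reduction (str) – reduction applied to the per-example losses; must be either “sum” or “mean”. Defaults to “mean”.
logits_multiplier (float | None) – scalar by which the logits are multiplied before the loss is computed; None is equivalent to 1. Defaults to None.
kernel_backend (KernelBackend | None) – kernel backend used to run the computation. Defaults to None.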
- Returns:
the cross-entropy loss, reduced as specified by reduction
- Return type:
Tensor
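As a rough, non-authoritative illustration of how the loss can be computed without the full logits matrix, the pure-Python sketch below processes chunk_size rows of x at a time. At any moment only a (chunk_size, V) block of logits exists. It assumes x has shape (N, H), weight has shape (V, H) (so logits are x @ weight.T), labels holds N integer class indices, and "mean" divides by N. The names chunked_linear_cross_entropy and chunk_size are hypothetical. The xma kernels' actual tiling strategy is not described on this page and may differ, for example by tiling over the vocabulary.

```python
import math
from typing import List, Optional, Sequence

Matrix = List[List[float]]


def _dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(p * q for p, q in zip(a, b))


def chunked_linear_cross_entropy(
    x: Matrix,
    weight: Matrix,
    labels: Sequence[int],
    reduction: str = "mean",
    logits_multiplier: Optional[float] = None,
    chunk_size: int = 2,
) -> float:
    """Hypothetical sketch: cross entropy of (x @ weight^T) * multiplier.

    Assumed layout: x is (N, H), weight is (V, H), labels has N entries.
    Rows are processed chunk_size at a time, so only a (chunk_size, V)
    block of logits is held in memory at once.
    """
    if reduction not in ("sum", "mean"):
        raise ValueError("reduction must be 'sum' or 'mean'")
    # None behaves like a multiplier of 1.
    scale = 1.0 if logits_multiplier is None else logits_multiplier
    total = 0.0
    for start in range(0, len(x), chunk_size):
        rows = x[start:start + chunk_size]
        # Logits for this chunk only: len(rows) x V values.
        block = [[scale * _dot(row, w) for w in weight] for row in rows]
        for i, logits in enumerate(block):
            # Numerically stable log-sum-exp over the vocabulary.
            m = max(logits)
            lse = m + math.log(sum(math.exp(z - m) for z in logits))
            # Per-row loss: logsumexp(logits) - logit of the true label.
            total += lse - logits[labels[start + i]]
        # `block` is rebound on the next iteration, so the previous
        # chunk's logits can be freed.
    return total / len(x) if reduction == "mean" else total
```

Chunking over rows keeps the result identical to the unchunked computation, because each row's loss depends only on that row's logits. The chunk size then simply trades peak memory against the number of passes.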