vllm.model_executor.layers.fla.ops.chunk_scaled_dot_kkt
Functions:
- chunk_scaled_dot_kkt_fwd – Compute beta * K * K^T.
chunk_scaled_dot_kkt_fwd(k, g=None, beta=None, cu_seqlens=None, chunk_indices=None, chunk_size=FLA_CHUNK_SIZE, output_dtype=torch.float32)
Compute beta * K * K^T.
Parameters:
- k (Tensor) – The key tensor of shape [B, T, H, K].
- beta (Tensor, default: None) – The beta tensor of shape [B, T, H].
- g (Tensor, default: None) – The cumulative sum of the gate tensor of shape [B, T, H].
- cu_seqlens (Tensor, default: None) – The cumulative sequence lengths of the input tensor.
- chunk_indices (Tensor, default: None) – Pre-computed chunk indices. If None and cu_seqlens is provided, they are computed internally.
- chunk_size (int, default: FLA_CHUNK_SIZE) – The chunk size. Default: 64.
- output_dtype (dtype, default: torch.float32) – The dtype of the output tensor.
Returns:
- Tensor – beta * K * K^T of shape [B, T, H, BT], where BT is the chunk size.
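
The output layout can be easier to see in a naive PyTorch version. The following is a minimal sketch, not the vLLM kernel: the helper name `chunk_scaled_dot_kkt_ref` is hypothetical, and it assumes that T is a multiple of the chunk size and that each row t holds the beta-scaled dot products of key t with the keys of its own chunk. It covers only the documented product; the gate term `g`, variable-length inputs via `cu_seqlens`, and any masking the fused kernel may apply are left out.

```python
import torch


def chunk_scaled_dot_kkt_ref(k, beta, chunk_size=64):
    """Naive sketch (hypothetical helper): per-chunk beta * K @ K^T as [B, T, H, BT]."""
    B, T, H, K = k.shape
    BT = chunk_size
    # Assumption for this sketch: no ragged final chunk.
    assert T % BT == 0, "T must be a multiple of chunk_size in this sketch"
    N = T // BT
    # Split time into N chunks of BT steps and move heads forward: [B, H, N, BT, K].
    kc = k.float().reshape(B, N, BT, H, K).permute(0, 3, 1, 2, 4)
    bc = beta.float().reshape(B, N, BT, H).permute(0, 3, 1, 2)
    # Scale each key row by its beta, then take dot products within the chunk.
    A = (kc * bc[..., None]) @ kc.transpose(-1, -2)  # [B, H, N, BT, BT]
    # Back to time-major layout: row = time step, column = position inside its chunk.
    return A.permute(0, 2, 3, 1, 4).reshape(B, T, H, BT)


torch.manual_seed(0)
k = torch.randn(2, 8, 3, 4)    # [B, T, H, K]
beta = torch.rand(2, 8, 3)     # [B, T, H]
A = chunk_scaled_dot_kkt_ref(k, beta, chunk_size=4)
print(A.shape)
```

With these inputs, `A[b, t, h, j]` is `beta[b, t, h]` times the dot product of `k[b, t, h]` with the j-th key of the chunk containing t, which is why the last dimension has size BT rather than T.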