class TritonAttentionImpl(AttentionImpl):
# Per-token-head quant: float32 scale views carved out of the padding at the
# end of each head; built lazily by _ensure_scale_caches().
_k_scale_cache: torch.Tensor | None = None
_v_scale_cache: torch.Tensor | None = None
def _ensure_scale_caches(self, kv_cache: torch.Tensor) -> None:
    """Lazily build float32 scale views over the padded tail of each head.

    The KV cache shape is ``(num_blocks, 2, block_size, nkv, hs+pad)``
    where ``pad = sizeof(float32) / sizeof(cache_dtype)``. The last
    ``pad`` elements of each head hold one float32 scale. This method
    creates strided float32 views over those bytes, initialises every
    scale to ``1.0`` and stores the views in ``self._k_scale_cache`` and
    ``self._v_scale_cache``.

    The views are derived from ``kv_cache``'s own strides and storage
    offset, so a cache that starts at a non-zero offset inside its
    storage, or that is not laid out contiguously, is still addressed
    correctly.

    The views are built only once: as soon as ``self._k_scale_cache`` is
    set, later calls return immediately without looking at ``kv_cache``.

    Scale shape: ``(num_blocks, block_size, num_kv_heads)``

    Args:
        kv_cache: The padded KV cache tensor described above.

    Raises:
        ValueError: If the cache element size cannot hold a float32 in a
            whole number of elements, if the padded head is too small,
            if the padding elements are not adjacent in memory, or if
            any offset/stride needed for the views is not a multiple of
            the float32 size.
    """
    if self._k_scale_cache is not None:
        return

    from vllm.utils.torch_utils import get_dtype_size

    num_blocks, _, block_size, nkv, padded_hs = kv_cache.shape
    dtype_sz = kv_cache.element_size()
    f32_sz = get_dtype_size(torch.float32)
    if dtype_sz > f32_sz or f32_sz % dtype_sz != 0:
        raise ValueError(
            f"KV cache element size {dtype_sz} cannot hold a float32 scale "
            "in a whole number of elements"
        )
    scale_pad = f32_sz // dtype_sz  # e.g. 4 for a 1-byte cache dtype
    hs = padded_hs - scale_pad
    if hs < 0:
        raise ValueError(
            f"padded head size {padded_hs} is smaller than the scale "
            f"padding {scale_pad}"
        )
    # The ``scale_pad`` trailing elements are reinterpreted as one float32,
    # so they have to be adjacent in memory.
    if scale_pad > 1 and kv_cache.stride(-1) != 1:
        raise ValueError(
            "KV cache must have unit stride along the head dimension to "
            "carve out float32 scales"
        )

    def _as_f32_units(num_elems: int, what: str, in_use: bool = True) -> int:
        # Convert a count of cache elements into a count of float32s,
        # refusing to truncate silently when the value is actually used
        # for addressing (strides of size-1 dims are never dereferenced).
        num_bytes = num_elems * dtype_sz
        if in_use and num_bytes % f32_sz != 0:
            raise ValueError(
                f"KV cache {what} ({num_bytes} bytes) is not a multiple of "
                f"{f32_sz} bytes; cannot create float32 scale views"
            )
        return num_bytes // f32_sz

    block_stride, half_stride, slot_stride, head_stride, _ = kv_cache.stride()
    scale_shape = (num_blocks, block_size, nkv)
    scale_strides = (
        _as_f32_units(block_stride, "block stride", num_blocks > 1),
        _as_f32_units(slot_stride, "slot stride", block_size > 1),
        _as_f32_units(head_stride, "head stride", nkv > 1),
    )
    # Within each head the scale sits right after the ``hs`` payload
    # elements; the V half starts one ``half_stride`` after the K half.
    first_elem = kv_cache.storage_offset()
    k_offset = _as_f32_units(first_elem + hs, "K scale offset")
    v_offset = _as_f32_units(first_elem + half_stride + hs, "V scale offset")

    # Flat float32 alias of the whole underlying storage.
    raw = kv_cache.untyped_storage()
    base_f32 = torch.tensor([], dtype=torch.float32, device=kv_cache.device).set_(
        raw
    )

    k_scales = torch.as_strided(
        base_f32,
        size=scale_shape,
        stride=scale_strides,
        storage_offset=k_offset,
    )
    v_scales = torch.as_strided(
        base_f32,
        size=scale_shape,
        stride=scale_strides,
        storage_offset=v_offset,
    )
    k_scales.fill_(1.0)
    v_scales.fill_(1.0)

    # Publish both views together so a failure above never leaves the
    # early-return guard armed with only one of them populated.
    self._k_scale_cache = k_scales
    self._v_scale_cache = v_scales
def fused_output_quant_supported(self, quant_key: QuantKey):
    """Report whether fused output quantization is available for *quant_key*.

    Only the static per-tensor symmetric FP8 key (``kFp8StaticTensorSym``)
    is accepted; the result is whatever the equality comparison yields.
    """
    is_static_fp8 = quant_key == kFp8StaticTensorSym
    return is_static_fp8
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: list[float] | None,
    sliding_window: int | None,
    kv_cache_dtype: str,
    logits_soft_cap: float | None = None,
    attn_type: AttentionType = AttentionType.DECODER,
    kv_sharing_target_layer_name: int | None = None,
    sinks: torch.Tensor | None = None,
    use_alibi_sqrt: bool = False,
    chunk_lookback: int = -1,
) -> None:
    """Store the attention configuration for the Triton backend.

    Args:
        num_heads: Number of query heads.
        head_size: Size of each head.
        scale: Softmax scale; stored as a Python ``float``.
        num_kv_heads: Number of KV heads; ``num_heads // num_kv_heads``
            is stored as ``num_queries_per_kv``.
        alibi_slopes: Optional slopes, converted to a float32 tensor.
        sliding_window: ``None`` is stored as ``(-1, -1)``; otherwise
            ``(w - 1, w - 1)`` for encoder attention types and
            ``(w - 1, 0)`` for every other type.
        kv_cache_dtype: KV cache dtype string; also selects the KV
            quantization mode via ``get_kv_quant_mode``.
        logits_soft_cap: ``None`` is stored as ``0``.
        attn_type: Attention type of the layer.
        kv_sharing_target_layer_name: Stored unchanged.
        sinks: Optional tensor whose first dimension must equal
            ``num_heads``.
        use_alibi_sqrt: Stored unchanged.
        chunk_lookback: Stored unchanged.
    """
    # NOTE(review): `kv_sharing_target_layer_name` is annotated `int | None`
    # but the name suggests a layer-name string -- confirm against callers.
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads
    self.attn_type = attn_type
    self.kv_cache_dtype = kv_cache_dtype
    self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

    self.alibi_slopes = (
        None
        if alibi_slopes is None
        else torch.tensor(alibi_slopes, dtype=torch.float32)
    )

    # Sliding window is kept as a (left, right) pair.
    if sliding_window is None:
        window = (-1, -1)
    else:
        left = sliding_window - 1
        is_encoder = attn_type in (
            AttentionType.ENCODER,
            AttentionType.ENCODER_ONLY,
        )
        window = (left, left) if is_encoder else (left, 0)
    self.sliding_window = window

    # In flash-attn, a logits_soft_cap of 0 means "no soft cap".
    self.logits_soft_cap = 0 if logits_soft_cap is None else logits_soft_cap

    self.fp8_dtype = current_platform.fp8_dtype()

    self.sinks = sinks
    if sinks is not None:
        assert sinks.shape[0] == num_heads, (
            "Sinks must have the same number of heads as the number of "
            f"heads in the layer. Sinks shape: {sinks.shape}, "
            f"num_heads: {num_heads}."
        )

    self.use_alibi_sqrt = use_alibi_sqrt
    self.chunk_lookback = chunk_lookback
    self.supports_quant_query_input = current_platform.is_cuda()

    quant_mode = get_kv_quant_mode(kv_cache_dtype)
    self._kv_quant_mode = quant_mode
    self._is_per_token_head_quant = quant_mode.is_per_token_head

    # Tensor descriptors (TD) for Q/K/V load/store help on platforms with
    # HW 2D block reads (Intel Xe2/Xe3). The unused branch is eliminated
    # at Triton compile time, so platforms with TD off pay nothing.
    #
    # ``VLLM_TRITON_ATTN_USE_TD`` is tri-state:
    #   - unset (None): auto-select (TD on for XPU, off elsewhere),
    #   - ``1``: force TD on regardless of platform,
    #   - ``0``: force TD off regardless of platform (useful for A/B).
    td_override = envs.VLLM_TRITON_ATTN_USE_TD
    self.use_td = current_platform.is_xpu() if td_override is None else td_override
def forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: TritonAttentionMetadata,
    output: torch.Tensor,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run paged attention through the Triton unified-attention kernel.

    Args:
        layer: Attention layer; supplies ``_q_scale``/``_k_scale``/
            ``_v_scale`` on the non-per-token-head path.
        query: shape = [num_tokens, num_heads, head_size]
        key: shape = [num_tokens, num_kv_heads, head_size]
        value: shape = [num_tokens, num_kv_heads, head_size]
        kv_cache: shape =
            [num_blocks, 2, block_size, num_kv_heads, head_size]
        attn_metadata: Metadata for attention; ``None`` marks a
            profiling run, in which case ``output`` is zero-filled.
        output: Pre-allocated tensor that receives the result.
        output_scale: Forwarded to the kernel unchanged.
        output_block_scale: Must be ``None``.
    Returns:
        shape = [num_tokens, num_heads * head_size]
    Raises:
        NotImplementedError: If ``output_block_scale`` is given.
    """
    if output_block_scale is not None:
        raise NotImplementedError(
            "fused block_scale output quantization is not yet supported"
            " for TritonAttentionImpl"
        )

    if attn_metadata is None:
        # Profiling run: nothing to attend over.
        return output.fill_(0)

    assert attn_metadata.use_cascade is False

    # IMPORTANT!
    # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
    # eager-mode PyTorch, so CPU overhead here matters. Even `view` and
    # `slice` (or `[:n]`) are surprisingly slow when they launch no GPU
    # work. Keep the number of PyTorch ops in this method to a minimum and
    # benchmark any change to make sure it adds no overhead.

    num_actual_tokens = attn_metadata.num_actual_tokens

    # Encoder attention works on the raw Q/K/V tensors; no KV cache involved.
    if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
        return self._forward_encoder_attention(
            query[:num_actual_tokens],
            key[:num_actual_tokens],
            value[:num_actual_tokens],
            output[:num_actual_tokens],
            attn_metadata,
            layer,
        )

    if self._is_per_token_head_quant:
        # Per-token-head quantized KV cache: scales live in separate views.
        self._ensure_scale_caches(kv_cache)
        k_pages, v_pages = kv_cache.unbind(1)
        if k_pages.dtype == torch.uint8:
            k_pages = k_pages.view(self.fp8_dtype)
            v_pages = v_pages.view(self.fp8_dtype)
        q_descale = k_descale = v_descale = None
        k_scales, v_scales = self._k_scale_cache, self._v_scale_cache
    else:
        # FP8 per-tensor / auto path (original flow).
        k_pages, v_pages = kv_cache.unbind(1)
        needs_fp8_view = (
            is_quantized_kv_cache(self.kv_cache_dtype)
            and k_pages.dtype != self.fp8_dtype
        )
        if needs_fp8_view:
            k_pages = k_pages.view(self.fp8_dtype)
            v_pages = v_pages.view(self.fp8_dtype)

        descale_shape = (
            attn_metadata.query_start_loc.shape[0] - 1,
            k_pages.shape[2],
        )
        query_is_per_tensor_fp8 = (
            self._kv_quant_mode == KVQuantMode.FP8_PER_TENSOR
            and query.dtype == self.fp8_dtype
        )
        q_descale = layer._q_scale if query_is_per_tensor_fp8 else None
        k_descale = layer._k_scale.expand(descale_shape)
        v_descale = layer._v_scale.expand(descale_shape)
        k_scales = v_scales = None

    unified_attention(
        q=query[:num_actual_tokens],
        k=k_pages,
        v=v_pages,
        out=output[:num_actual_tokens],
        cu_seqlens_q=attn_metadata.query_start_loc,
        max_seqlen_q=attn_metadata.max_query_len,
        seqused_k=attn_metadata.seq_lens,
        max_seqlen_k=attn_metadata.max_seq_len,
        softmax_scale=self.scale,
        causal=True,
        alibi_slopes=self.alibi_slopes,
        use_alibi_sqrt=self.use_alibi_sqrt,
        window_size=self.sliding_window,
        block_table=attn_metadata.block_table,
        softcap=self.logits_soft_cap,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
        seq_threshold_3D=attn_metadata.seq_threshold_3D,
        num_par_softmax_segments=attn_metadata.num_par_softmax_segments,
        softmax_segm_output=attn_metadata.softmax_segm_output,
        softmax_segm_max=attn_metadata.softmax_segm_max,
        softmax_segm_expsum=attn_metadata.softmax_segm_expsum,
        sinks=self.sinks,
        output_scale=output_scale,
        mm_prefix_range=attn_metadata.mm_prefix_range_tensor,
        kv_quant_mode=self._kv_quant_mode,
        k_scale_cache=k_scales,
        v_scale_cache=v_scales,
        chunk_lookback=self.chunk_lookback,
        use_td=self.use_td,
    )

    return output
def _forward_encoder_attention(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    attn_metadata: TritonAttentionMetadata,
    layer: torch.nn.Module,
) -> torch.Tensor:
    """Compute bidirectional encoder attention directly on Q, K and V.

    Args:
        query: shape = [num_encoder_tokens, num_heads, head_size]
        key: shape = [num_encoder_tokens, num_kv_heads, head_size]
        value: shape = [num_encoder_tokens, num_kv_heads, head_size]
        output: shape = [num_encoder_tokens, num_heads, head_size]
        attn_metadata: Encoder attention metadata
        layer: The attention layer (not read by this method)
    Returns:
        ``output``, after the kernel has been invoked with it as the
        destination.
    Raises:
        NotImplementedError: If the KV cache dtype is a quantized one.
    """
    # Encoder attention has no quantized-KV-cache implementation.
    if is_quantized_kv_cache(self.kv_cache_dtype):
        raise NotImplementedError(
            "quantized KV cache is not supported for encoder attention"
        )

    window_q, window_k = self.sliding_window[0], self.sliding_window[1]

    # Sequence layout comes from the encoder metadata; the kernel is run
    # straight on the Q/K/V tensors, with no KV cache in between.
    context_attention_fwd(
        q=query,
        k=key,
        v=value,
        o=output,
        b_start_loc=attn_metadata.query_start_loc,
        b_seq_len=attn_metadata.seq_lens,
        max_input_len=attn_metadata.max_query_len,
        is_causal=False,  # Encoder attention is bidirectional
        softmax_scale=self.scale,
        sliding_window_q=window_q,
        sliding_window_k=window_k,
    )
    return output
def do_kv_cache_update(
    self,
    layer: AttentionLayer,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
):
    """Write the new keys and values into the paged KV cache.

    Encoder attention types skip the cache entirely. With per-token-head
    quantization the per-head scale views are passed to the quantizing
    kernel; otherwise the regular reshape-and-cache kernel is used with
    the layer's ``_k_scale`` / ``_v_scale``.

    Args:
        layer: Attention layer supplying ``_k_scale`` and ``_v_scale``.
        key: New key tensor to store.
        value: New value tensor to store.
        kv_cache: Cache tensor; K and V are split along dimension 1.
        slot_mapping: Passed through to the cache-writing kernel.
    """
    if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
        # Encoder attention uses Q, K, V directly; there is nothing to cache.
        return

    if not self._is_per_token_head_quant:
        # Decoder and cross-attention: regular KV cache write.
        k_pages, v_pages = kv_cache.unbind(1)
        if is_quantized_kv_cache(self.kv_cache_dtype):
            k_pages = k_pages.view(self.fp8_dtype)
            v_pages = v_pages.view(self.fp8_dtype)
        triton_reshape_and_cache_flash(
            key,
            value,
            k_pages,
            v_pages,
            slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )
        return

    # Per-token-head quantization: make sure the scale views exist, then
    # let the kernel write the cache and the scale views together.
    self._ensure_scale_caches(kv_cache)
    k_pages, v_pages = kv_cache.unbind(1)
    if k_pages.dtype == torch.uint8:
        k_pages = k_pages.view(self.fp8_dtype)
        v_pages = v_pages.view(self.fp8_dtype)
    triton_reshape_and_cache_flash_per_token_head_quant(
        key,
        value,
        k_pages,
        v_pages,
        self._k_scale_cache,
        self._v_scale_cache,
        slot_mapping,
    )
def fused_rope_kvcache_supported(self):
    """Tell whether the fused RoPE + KV-cache update path may be used.

    Always ``False`` with per-token-head quantization; otherwise the
    answer is whatever ``rocm_aiter_ops.is_enabled()`` reports.
    """
    return not self._is_per_token_head_quant and rocm_aiter_ops.is_enabled()
def do_rope_and_kv_cache_update(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    kv_cache: torch.Tensor,
    layer_slot_mapping: torch.Tensor,
):
    """Apply RoPE and write K/V into the cache in one fused AITER call.

    Args:
        layer: Attention layer supplying ``_k_scale`` and ``_v_scale``.
        query: Query tensor handed to the fused kernel.
        key: Key tensor handed to the fused kernel.
        value: Value tensor handed to the fused kernel.
        positions: Passed through to the fused kernel.
        cos_sin_cache: Passed through to the fused kernel.
        is_neox: Passed through to the fused kernel.
        kv_cache: Cache tensor; K and V are split along dimension 1.
        layer_slot_mapping: Passed through to the fused kernel.
    """
    k_pages, v_pages = kv_cache.unbind(1)

    # A quantized cache is reinterpreted as the platform FP8 dtype before
    # being handed to the kernel.
    fp8_cache = is_quantized_kv_cache(self.kv_cache_dtype)
    if fp8_cache:
        k_pages = k_pages.view(self.fp8_dtype)
        v_pages = v_pages.view(self.fp8_dtype)

    use_flash_layout = True
    rocm_aiter_ops.triton_rope_and_cache(
        query,
        key,
        value,
        positions,
        cos_sin_cache,
        is_neox,
        k_pages,
        v_pages,
        layer_slot_mapping,
        layer._k_scale,
        layer._v_scale,
        use_flash_layout,
        fp8_cache,
    )