[Kernel][Model] logits_soft_cap for Gemma2 with flashinfer #6051

Merged
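For context on what this PR wires through: Gemma-2 soft-caps its attention logits before the softmax, and this change passes that cap (logits_soft_cap) down to the FlashInfer kernels. A minimal standalone sketch of the transform, mirroring what the reference implementation in the new kernel test below computes (plain torch, not code from this PR):

import torch

def soft_cap_logits(scores: torch.Tensor, soft_cap: float) -> torch.Tensor:
    # Squash raw attention scores into (-soft_cap, soft_cap) while staying
    # roughly linear near zero, as in ref_paged_attn below.
    return soft_cap * torch.tanh(scores / soft_cap)

scores = torch.tensor([-200.0, -10.0, 0.0, 10.0, 200.0])
print(soft_cap_logits(scores, soft_cap=50.0))  # all values bounded by +/-50
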
7 changes: 5 additions & 2 deletions .buildkite/test-pipeline.yaml
@@ -107,12 +107,15 @@ steps:

- label: Kernels Test %N
#mirror_hardwares: [amd]
command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
command:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.7/flashinfer-0.0.7+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4

- label: Models Test
#mirror_hardwares: [amd]
commands:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.7/flashinfer-0.0.7+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- pytest -v -s models -m \"not vlm\"

- label: Vision Language Models Test
@@ -223,7 +226,7 @@ steps:
- pytest -v -s distributed/test_custom_all_reduce.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.7/flashinfer-0.0.7+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s -x lora/test_mixtral.py
248 changes: 248 additions & 0 deletions tests/kernels/test_flashinfer.py
@@ -0,0 +1,248 @@
from typing import List, Optional, Tuple

import flashinfer
import pytest
import torch

NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.


def ref_paged_attn(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
query_lens: List[int],
kv_lens: List[int],
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
soft_cap: Optional[float] = None,
) -> torch.Tensor:
num_seqs = len(query_lens)
block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape

outputs: List[torch.Tensor] = []
start_idx = 0
for i in range(num_seqs):
query_len = query_lens[i]
kv_len = kv_lens[i]
q = query[start_idx:start_idx + query_len]
q *= scale

num_kv_blocks = (kv_len + block_size - 1) // block_size
block_indices = block_tables[i, :num_kv_blocks]

k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
k = k[:kv_len]
v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
v = v[:kv_len]

if q.shape[1] != k.shape[1]:
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
attn = torch.einsum("qhd,khd->hqk", q, k).float()
empty_mask = torch.ones(query_len, kv_len)
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
if sliding_window is not None:
sliding_window_mask = torch.triu(empty_mask,
diagonal=kv_len -
(query_len + sliding_window) +
1).bool().logical_not()
mask |= sliding_window_mask
if soft_cap is not None:
attn = soft_cap * torch.tanh(attn / soft_cap)
attn.masked_fill_(mask, float("-inf"))
attn = torch.softmax(attn, dim=-1).to(v.dtype)
out = torch.einsum("hqk,khd->qhd", attn, v)

outputs.append(out)
start_idx += query_len

return torch.cat(outputs, dim=0)


@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode
def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
num_heads: Tuple[int,
int], head_size: int,
dtype: torch.dtype, block_size: int,
soft_cap: Optional[float]) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5

query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
key_value_cache = torch.randn(NUM_BLOCKS,
2,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)

kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
wrapper.begin_forward(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
data_type=dtype)

output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap)

ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"


@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode
def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
num_heads: Tuple[int, int],
head_size: int, dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float]) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5

query = torch.randn(sum(query_lens),
num_query_heads,
head_size,
dtype=dtype)
key_value_cache = torch.randn(NUM_BLOCKS,
2,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)

# Normalize the scale of the key and value caches to mitigate
# numerical instability.
key_cache /= head_size**0.5
value_cache /= head_size**0.5

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

qo_indptr = [0]
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
qo_indptr.append(qo_indptr[-1] + query_lens[i])

qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD")
wrapper.begin_forward(
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
)

output = wrapper.forward(
query,
key_value_cache,
logits_soft_cap=soft_cap,
)

ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=query_lens,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
41 changes: 41 additions & 0 deletions tests/models/test_gemma2.py
@@ -0,0 +1,41 @@
"""Compare the outputs of HF and vLLM for Gemma2 models using greedy sampling.

Run `pytest tests/models/test_gemma2.py`.
"""
import os

import pytest

from .utils import check_logprobs_close

MODELS = ["google/gemma-2-9b"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
os.environ['VLLM_ATTENTION_BACKEND'] = "FLASHINFER"
# TODO(sang): Sliding window should be tested separately.
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
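Outside the test harness, the same backend selection applies when running the model directly. A hedged usage sketch, assuming the standard vllm.LLM Python entry point (the prompt and sampling settings are illustrative):

import os
# Must be set before the engine is constructed so the FlashInfer backend is picked.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(model="google/gemma-2-9b", dtype="bfloat16")
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
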
12 changes: 8 additions & 4 deletions vllm/attention/backends/flashinfer.py
@@ -102,6 +102,8 @@ class FlashInferMetadata(AttentionMetadata):
# The data type of the paged kv cache
data_type: torch.dtype = None
device: torch.device = torch.device("cuda")
# Only used by gemma2 model
logits_soft_cap: Optional[float] = None

def __post_init__(self):
# Refer to
@@ -269,15 +271,17 @@ def forward(
else:
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
output = prefill_meta.prefill_wrapper.forward(query,
kv_cache,
causal=True)
output = prefill_meta.prefill_wrapper.forward(
query,
kv_cache,
logits_soft_cap=attn_metadata.logits_soft_cap,
causal=True)
else:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
output = attn_metadata.decode_metadata.decode_wrapper.forward(
query,
kv_cache,
sm_scale=self.scale,
)
logits_soft_cap=attn_metadata.logits_soft_cap)
return output.view(num_tokens, hidden_size)
2 changes: 1 addition & 1 deletion vllm/attention/selector.py
@@ -78,7 +78,7 @@ def get_attn_backend(
elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.")
logger.warning(("Flashinfer will be stuck on llma-2-7b,"
" please avoid using Flashinfer as the"
" please avoid using Flashinfer as the "
"backend when running on llma-2-7b."))
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
14 changes: 12 additions & 2 deletions vllm/worker/model_runner.py
@@ -659,6 +659,16 @@ def _prepare_model_input_tensors(
dtype=torch.long,
device=self.device)

logits_soft_cap = getattr(self.model_config.hf_config,
'final_logit_softcapping', None)
if logits_soft_cap is not None and self.attn_backend.get_name(
) != "flashinfer":
Review comment: Could I check if logits_soft_cap is supposed to be the attn_logit_softcapping value instead? The two values are different in the Gemma2 config:

"attn_logit_softcapping": 50.0,
"final_logit_softcapping": 30.0,

Reply (Collaborator): @yongxb Nice catch! final_logit_softcapping is used to cap the final logits before sampling. @LiuXiaoxuanPKU Could you please fix this?

(A short sketch of where each cap applies is given at the end of this page.)

raise ValueError("Please use Flashinfer backend for models with"
"logits_soft_cap (i.e., Gemma-2)."
" Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")

if self.attn_backend.get_name() == "flashinfer":
if len(paged_kv_indptr) > 0:
paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
@@ -676,7 +686,6 @@

kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
self.model_config.dtype)

attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor,
Expand All @@ -697,7 +706,8 @@ def _prepare_model_input_tensors(
query_start_loc=query_start_loc,
device=self.device,
data_type=kv_cache_dtype,
use_cuda_graph=use_captured_graph)
use_cuda_graph=use_captured_graph,
logits_soft_cap=logits_soft_cap)

else:
attn_metadata = self.attn_backend.make_metadata(
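To illustrate the distinction raised in the review comments on model_runner.py above: attn_logit_softcapping bounds the attention scores inside each attention layer, while final_logit_softcapping bounds the output-head logits just before sampling. A minimal sketch using the config values quoted in the thread (random tensors for illustration; this is not the PR's code):

import torch

def soft_cap(x: torch.Tensor, cap: float) -> torch.Tensor:
    return cap * torch.tanh(x / cap)

# Inside attention: cap the Q.K^T scores before the causal mask and softmax.
attn_scores = torch.randn(8, 8) * 100
attn_scores = soft_cap(attn_scores, cap=50.0)    # attn_logit_softcapping

# At the LM head: cap the vocabulary logits before sampling.
final_logits = torch.randn(256000) * 100
final_logits = soft_cap(final_logits, cap=30.0)  # final_logit_softcapping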