diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index a491ffa763505..9398aeb2c214c 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -26,6 +26,7 @@ def test_beam_search_single_input( max_tokens: int, beam_width: int, ) -> None: + example_prompts = example_prompts[:1] hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width, max_tokens)