# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
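
"""Tests for the `trl vllm-serve` server and the `VLLMClient` helper.

The server-backed test classes below spawn `trl vllm-serve` in a subprocess on
dedicated accelerator(s) and exercise generation, weight syncing, and prefix
cache resets through `VLLMClient`.
"""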

import os
import signal
import subprocess
import unittest

import psutil
import pytest
import torch
from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_torch_multi_accelerator, torch_device

from trl.extras.vllm_client import VLLMClient
from trl.scripts.vllm_serve import chunk_list

from .testing_utils import require_3_accelerators
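

# `chunk_list` splits a list into `n` chunks: extra elements go to the earliest
# chunks, and the result is padded with empty chunks when `n` exceeds the list
# length.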
class TestChunkList(unittest.TestCase):
    def test_even_split(self):
        self.assertEqual(chunk_list([1, 2, 3, 4, 5, 6], 2), [[1, 2, 3], [4, 5, 6]])

    def test_uneven_split(self):
        self.assertEqual(chunk_list([1, 2, 3, 4, 5, 6], 4), [[1, 2], [3, 4], [5], [6]])

    def test_more_chunks_than_elements(self):
        self.assertEqual(chunk_list([1, 2, 3, 4, 5, 6], 8), [[1], [2], [3], [4], [5], [6], [], []])

    def test_n_equals_len(self):
        self.assertEqual(chunk_list([1, 2, 3], 3), [[1], [2], [3]])

    def test_n_is_1(self):
        self.assertEqual(chunk_list([1, 2, 3], 1), [[1, 2, 3]])

    def test_single_element_list(self):
        self.assertEqual(chunk_list([42], 2), [[42], []])

    def test_any_dtype(self):
        self.assertEqual(
            chunk_list([1, "two", 3.0, {"four": 4}, ["f", "i", "v", "e"]], 2),
            [[1, "two", 3.0], [{"four": 4}, ["f", "i", "v", "e"]]],
        )
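

# Each server-backed test class below follows the same pattern: `setUpClass` starts `trl vllm-serve` in a subprocess
# pinned to accelerator(s) other than device 0 (which the client-side model uses), and `tearDownClass` kills the
# server process tree explicitly so that no zombie processes are left behind.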
@pytest.mark.slow
@require_torch_multi_accelerator
class TestVLLMClientServer(unittest.TestCase):
    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setUpClass(cls):
        # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1"
        env = os.environ.copy()
        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
        env[VISIBLE_DEVICES] = "1"  # Restrict to accelerator 1

        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
        )

        # Initialize the client, allowing up to 240s for the server to come up
        cls.client = VLLMClient(connection_timeout=240)
        cls.client.init_communicator()

    def test_generate(self):
        prompts = ["Hello, AI!", "Tell me a joke"]
        outputs = self.client.generate(prompts)

        # Check that the output is a list
        self.assertIsInstance(outputs, list)

        # Check that the number of generated sequences is equal to the number of prompts
        self.assertEqual(len(outputs), len(prompts))

        # Check that the generated sequences are lists of integers
        for seq in outputs:
            self.assertTrue(all(isinstance(tok, int) for tok in seq))

    def test_generate_with_params(self):
        prompts = ["Hello, AI!", "Tell me a joke"]
        outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)

        # Check that the output is a list
        self.assertIsInstance(outputs, list)

        # Check that the number of generated sequences is 2 times the number of prompts
        self.assertEqual(len(outputs), 2 * len(prompts))

        # Check that the generated sequences are lists of integers
        for seq in outputs:
            self.assertTrue(all(isinstance(tok, int) for tok in seq))

        # Check that the length of the generated sequences is less than or equal to 32
        for seq in outputs:
            self.assertLessEqual(len(seq), 32)

    def test_update_model_params(self):
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
        self.client.update_model_params(model)

    def test_reset_prefix_cache(self):
        # Test resetting the prefix cache
        self.client.reset_prefix_cache()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        # Close the client
        cls.client.close_communicator()

        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need
        # to kill the server process and its children explicitly.
        parent = psutil.Process(cls.server_process.pid)
        children = parent.children(recursive=True)
        for child in children:
            child.send_signal(signal.SIGTERM)
        cls.server_process.terminate()
        cls.server_process.wait()


# Same as above, but using `base_url` to instantiate the client instead of the default host/port.
@pytest.mark.slow
@require_torch_multi_accelerator
class TestVLLMClientServerBaseURL(unittest.TestCase):
    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setUpClass(cls):
        # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1"
        env = os.environ.copy()
        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
        env[VISIBLE_DEVICES] = "1"  # Restrict to accelerator 1

        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
        )

        # Initialize the client via its base URL (the server listens on port 8000 by default)
        cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=240)
        cls.client.init_communicator()

    def test_generate(self):
        prompts = ["Hello, AI!", "Tell me a joke"]
        outputs = self.client.generate(prompts)

        # Check that the output is a list
        self.assertIsInstance(outputs, list)

        # Check that the number of generated sequences is equal to the number of prompts
        self.assertEqual(len(outputs), len(prompts))

        # Check that the generated sequences are lists of integers
        for seq in outputs:
            self.assertTrue(all(isinstance(tok, int) for tok in seq))

    def test_generate_with_params(self):
        prompts = ["Hello, AI!", "Tell me a joke"]
        outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)

        # Check that the output is a list
        self.assertIsInstance(outputs, list)

        # Check that the number of generated sequences is 2 times the number of prompts
        self.assertEqual(len(outputs), 2 * len(prompts))

        # Check that the generated sequences are lists of integers
        for seq in outputs:
            self.assertTrue(all(isinstance(tok, int) for tok in seq))

        # Check that the length of the generated sequences is less than or equal to 32
        for seq in outputs:
            self.assertLessEqual(len(seq), 32)

    def test_update_model_params(self):
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
        self.client.update_model_params(model)

    def test_reset_prefix_cache(self):
        # Test resetting the prefix cache
        self.client.reset_prefix_cache()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        # Close the client
        cls.client.close_communicator()

        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need
        # to kill the server process and its children explicitly.
        parent = psutil.Process(cls.server_process.pid)
        children = parent.children(recursive=True)
        for child in children:
            child.send_signal(signal.SIGTERM)
        cls.server_process.terminate()
        cls.server_process.wait()
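

# Same as above, but the server shards the model across two accelerators with tensor parallelism.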
@pytest.mark.slow
@require_3_accelerators
class TestVLLMClientServerTP(unittest.TestCase):
    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setUpClass(cls):
        # We want the server to run on accelerators 1 and 2, so we set VISIBLE_DEVICES to "1,2"
        env = os.environ.copy()
        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
        env[VISIBLE_DEVICES] = "1,2"  # Restrict to accelerators 1 and 2

        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id, "--tensor_parallel_size", "2"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
        )

        # Initialize the client
        cls.client = VLLMClient(connection_timeout=240)
        cls.client.init_communicator()

    def test_generate(self):
        prompts = ["Hello, AI!", "Tell me a joke"]
        outputs = self.client.generate(prompts)

        # Check that the output is a list
        self.assertIsInstance(outputs, list)

        # Check that the number of generated sequences is equal to the number of prompts
        self.assertEqual(len(outputs), len(prompts))

        # Check that the generated sequences are lists of integers
        for seq in outputs:
            self.assertTrue(all(isinstance(tok, int) for tok in seq))

    def test_update_model_params(self):
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
        self.client.update_model_params(model)

    def test_reset_prefix_cache(self):
        # Test resetting the prefix cache
        self.client.reset_prefix_cache()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        # Close the client
        cls.client.close_communicator()

        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need
        # to kill the server process and its children explicitly.
        parent = psutil.Process(cls.server_process.pid)
        children = parent.children(recursive=True)
        for child in children:
            child.send_signal(signal.SIGTERM)
        cls.server_process.terminate()
        cls.server_process.wait()
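

# Same as above, but the server runs two data-parallel replicas, one per accelerator.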
@pytest.mark.slow
@require_3_accelerators
class TestVLLMClientServerDP(unittest.TestCase):
    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setUpClass(cls):
        # We want the server to run on accelerators 1 and 2, so we set VISIBLE_DEVICES to "1,2"
        env = os.environ.copy()
        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
        env[VISIBLE_DEVICES] = "1,2"  # Restrict to accelerators 1 and 2

        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id, "--data_parallel_size", "2"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
        )

        # Initialize the client
        cls.client = VLLMClient(connection_timeout=240)
        cls.client.init_communicator()

    def test_generate(self):
        prompts = ["Hello, AI!", "Tell me a joke"]
        outputs = self.client.generate(prompts)

        # Check that the output is a list
        self.assertIsInstance(outputs, list)

        # Check that the number of generated sequences is equal to the number of prompts
        self.assertEqual(len(outputs), len(prompts))

        # Check that the generated sequences are lists of integers
        for seq in outputs:
            self.assertTrue(all(isinstance(tok, int) for tok in seq))

    def test_update_model_params(self):
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
        self.client.update_model_params(model)

    def test_reset_prefix_cache(self):
        # Test resetting the prefix cache
        self.client.reset_prefix_cache()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        # Close the client
        cls.client.close_communicator()

        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need
        # to kill the server process and its children explicitly.
        parent = psutil.Process(cls.server_process.pid)
        children = parent.children(recursive=True)
        for child in children:
            child.send_signal(signal.SIGTERM)
        cls.server_process.terminate()
        cls.server_process.wait()


@pytest.mark.slow
@require_torch_multi_accelerator
class TestVLLMClientServerDeviceParameter(unittest.TestCase):
    """Test the device parameter functionality in init_communicator."""

    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setUpClass(cls):
        # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1"
        env = os.environ.copy()
        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
        env[VISIBLE_DEVICES] = "1"  # Restrict to accelerator 1

        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
        )

    def test_init_communicator_with_device_int(self):
        """Test init_communicator with an integer device parameter."""
        client = VLLMClient(connection_timeout=240)
        client.init_communicator(device=0)  # Explicitly specify device 0

        # Test basic functionality
        prompts = ["Hello, AI!"]
        outputs = client.generate(prompts)
        self.assertIsInstance(outputs, list)
        self.assertEqual(len(outputs), len(prompts))

        client.close_communicator()

    def test_init_communicator_with_device_string(self):
        """Test init_communicator with a string device parameter."""
        client = VLLMClient(connection_timeout=240)
        client.init_communicator(device="cuda:0")  # Explicitly specify the device as a string

        # Test basic functionality
        prompts = ["Hello, AI!"]
        outputs = client.generate(prompts)
        self.assertIsInstance(outputs, list)
        self.assertEqual(len(outputs), len(prompts))

        client.close_communicator()

    def test_init_communicator_with_torch_device(self):
        """Test init_communicator with a torch.device object."""
        client = VLLMClient(connection_timeout=240)
        device = torch.device("cuda:0")
        client.init_communicator(device=device)  # Explicitly specify a torch.device object

        # Test basic functionality
        prompts = ["Hello, AI!"]
        outputs = client.generate(prompts)
        self.assertIsInstance(outputs, list)
        self.assertEqual(len(outputs), len(prompts))

        client.close_communicator()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need
        # to kill the server process and its children explicitly.
        parent = psutil.Process(cls.server_process.pid)
        children = parent.children(recursive=True)
        for child in children:
            child.send_signal(signal.SIGTERM)
        cls.server_process.terminate()
        cls.server_process.wait()