import argparse
import time
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig

MODEL_PATH = 'THUDM/glm-4-9b-chat'
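
# Stress test for GLM-4-9B-Chat inference: measures time-to-first-token
# (prefill latency) and decode throughput (tokens/second) on random prompts.
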
def stress_test(token_len, n, num_gpu):
    # Run on the last of the available GPUs (cuda:{num_gpu - 1}); fall back to CPU.
    device = torch.device(f"cuda:{num_gpu - 1}" if torch.cuda.is_available() and num_gpu > 0 else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
        padding_side="left"
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16
    ).to(device).eval()

    # Alternative: INT4 weight-only inference (uncomment and remove the bfloat16 load above):
    # model = AutoModelForCausalLM.from_pretrained(
    #     MODEL_PATH,
    #     trust_remote_code=True,
    #     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    #     low_cpu_mem_usage=True,
    # ).eval()

    times = []
    decode_speeds = []

    print("Warming up...")
    vocab_size = tokenizer.vocab_size
    warmup_token_len = 20
    # Random prompt wrapped in the GLM-4 chat-template special tokens:
    # [gMASK] <sop> <|user|> "\n" ... <|assistant|>
    random_token_ids = torch.randint(3, vocab_size - 200, (warmup_token_len - 5,), dtype=torch.long)
    start_tokens = [151331, 151333, 151336, 198]
    end_tokens = [151337]
    input_ids = torch.tensor(start_tokens + random_token_ids.tolist() + end_tokens, dtype=torch.long).unsqueeze(0).to(device)
    # attention_mask and position_ids are index tensors and must be integer-typed,
    # not bfloat16 (float position IDs break the embedding lookup).
    attention_mask = torch.ones_like(input_ids, dtype=torch.long).to(device)
    position_ids = torch.arange(input_ids.shape[1], dtype=torch.long).unsqueeze(0).to(device)
    # position_ids is built for completeness; generate() below derives positions
    # from the attention mask and only consumes input_ids and attention_mask.
    warmup_inputs = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'position_ids': position_ids
    }
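    # One untimed full generation to warm up CUDA kernels and allocator caches
    # so the timed iterations below are stable.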
    with torch.no_grad():
        _ = model.generate(
            input_ids=warmup_inputs['input_ids'],
            attention_mask=warmup_inputs['attention_mask'],
            max_new_tokens=2048,
            do_sample=False,
            repetition_penalty=1.0,
            eos_token_id=[151329, 151336, 151338]  # GLM-4 stop tokens: <|endoftext|>, <|user|>, <|observation|>
        )
    print("Warming up complete. Starting stress test...")
    for i in range(n):
        random_token_ids = torch.randint(3, vocab_size - 200, (token_len - 5,), dtype=torch.long)
        input_ids = torch.tensor(start_tokens + random_token_ids.tolist() + end_tokens, dtype=torch.long).unsqueeze(0).to(device)
        attention_mask = torch.ones_like(input_ids, dtype=torch.long).to(device)
        position_ids = torch.arange(input_ids.shape[1], dtype=torch.long).unsqueeze(0).to(device)
        test_inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'position_ids': position_ids
        }
        streamer = TextIteratorStreamer(
            tokenizer=tokenizer,
            timeout=36000,
            skip_prompt=True,
            skip_special_tokens=True
        )
        generate_kwargs = {
            "input_ids": test_inputs['input_ids'],
            "attention_mask": test_inputs['attention_mask'],
            "max_new_tokens": 512,
            "do_sample": False,
            "repetition_penalty": 1.0,
            "eos_token_id": [151329, 151336, 151338],
            "streamer": streamer
        }
        # Generate in a background thread so the streamer can be consumed
        # (and timed) on the main thread.
        start_time = time.time()
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
        first_token_time = None
        all_token_times = []
        for token in streamer:
            current_time = time.time()
            if first_token_time is None:
                first_token_time = current_time
                times.append(first_token_time - start_time)
            all_token_times.append(current_time)
        t.join()
        end_time = time.time()
        # The streamer yields decoded text chunks (roughly one per token), so
        # this is decode throughput in tokens/second.
        decode_speed = len(all_token_times) / (end_time - first_token_time) if all_token_times else 0
        decode_speeds.append(decode_speed)
        print(
            f"Iteration {i + 1}/{n} - Prefill Time: {times[-1]:.4f} seconds - Decode Speed: {decode_speed:.4f} tokens/second")
        torch.cuda.empty_cache()

    avg_first_token_time = sum(times) / n
    avg_decode_speed = sum(decode_speeds) / n
    print(f"\nAverage First Token Time over {n} iterations: {avg_first_token_time:.4f} seconds")
    print(f"Average Decode Speed over {n} iterations: {avg_decode_speed:.4f} tokens/second")
    return times, avg_first_token_time, decode_speeds, avg_decode_speed


def main():
    parser = argparse.ArgumentParser(description="Stress test for model inference")
    parser.add_argument('--token_len', type=int, default=1000, help='Prompt length in tokens for each test')
    parser.add_argument('--n', type=int, default=3, help='Number of iterations for the stress test')
    parser.add_argument('--num_gpu', type=int, default=1,
                        help='Number of GPUs available; the model runs on the last one (cuda:num_gpu-1)')
    args = parser.parse_args()
    stress_test(args.token_len, args.n, args.num_gpu)


if __name__ == "__main__":
    main()
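
# Example invocation (assuming this file is saved as stress_test.py):
# python stress_test.py --token_len 1000 --n 3 --num_gpu 1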