Spaces:
Runtime error
Runtime error
| import gc | |
| from threading import Thread | |
| import torch | |
| from transformers import TextIteratorStreamer | |
def generate_stream_xft(
    model,
    tokenizer,
    params,
    device,
    context_len=8192,
    stream_interval=2,
    judge_sent_end=False,
):
    """Stream generated text for ``params["prompt"]`` as a series of dicts.

    ``model.model.generate`` is run on a background thread and feeds a
    ``TextIteratorStreamer``; every piece of text the streamer produces is
    appended to the running output and yielded immediately.  After the
    streamer is exhausted one final dict carrying the finish reason is
    yielded, and then ``gc.collect()`` is called.

    Args:
        model: Object exposing ``config`` (``padding``, ``beam_width``,
            ``num_return_sequences``, ``early_stopping``, ``eos_token_id``,
            ``pad_token_id``) and ``model.generate``.
        tokenizer: Callable used to encode the prompt; also handed to the
            streamer for decoding.
        params: Mapping that must contain ``"prompt"``.  Optional keys read
            here: ``"repetition_penalty"`` (default ``1.0``),
            ``"max_new_tokens"`` (default ``4096``) and ``"echo"``
            (default ``True``).
        device: Accepted for interface compatibility; not used here.
        context_len: Accepted for interface compatibility; not used here.
        stream_interval: Accepted for interface compatibility; not used here.
        judge_sent_end: Accepted for interface compatibility; not used here.

    Yields:
        dict: ``{"text", "usage", "finish_reason"}``.  ``finish_reason`` is
        ``None`` for intermediate results and ``"length"`` or ``"stop"`` for
        the final one.  ``usage["completion_tokens"]`` is the zero-based
        index of the most recent piece received from the streamer.
    """
    prompt = params["prompt"]
    penalty = float(params.get("repetition_penalty", 1.0))
    # "temperature" and "top_p" are not read or forwarded yet; they are
    # reserved for future use.
    max_new_tokens = int(params.get("max_new_tokens", 4096))
    keep_prompt = params.get("echo", True)

    encoded = tokenizer(prompt, return_tensors="pt", padding=model.config.padding)
    input_ids = encoded.input_ids
    prompt_len = len(input_ids[0])

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    config = model.config
    generate_args = {
        "input_ids": input_ids,
        "streamer": streamer,
        # Cap on the total sequence length: prompt plus newly generated part.
        "max_length": max_new_tokens + prompt_len,
        "num_beams": config.beam_width,
        # NOTE(review): the repetition penalty is forwarded under the
        # "length_penalty" keyword -- confirm this mapping is intended.
        "length_penalty": penalty,
        "num_return_sequences": config.num_return_sequences,
        "early_stopping": config.early_stopping,
        "eos_token_id": config.eos_token_id,
        "pad_token_id": config.pad_token_id,
    }

    # Generation runs on its own thread so the streamer can be drained here.
    worker = Thread(target=model.model.generate, kwargs=generate_args)
    worker.start()

    def _payload(text, index, reason):
        # Build one result dict in the shape callers consume.
        return {
            "text": text,
            "usage": {
                "prompt_tokens": prompt_len,
                "completion_tokens": index,
                "total_tokens": prompt_len + index,
            },
            "finish_reason": reason,
        }

    # With echo enabled the prompt is kept at the front of the output.
    generated = prompt if keep_prompt else ""

    step = 0
    for step, piece in enumerate(streamer):
        generated = generated + piece
        yield _payload(generated, step, None)

    generated = generated.strip()
    reason = "length" if step == max_new_tokens - 1 else "stop"
    yield _payload(generated, step, reason)

    gc.collect()