Spaces:
Runtime error
Runtime error
| import gc | |
| from threading import Thread | |
| import torch | |
| from diffusers import DDIMScheduler | |
| import transformers | |
| from transformers import ( | |
| GenerationConfig, | |
| StoppingCriteria, | |
| StoppingCriteriaList, | |
| TextIteratorStreamer, | |
| ) | |
| from fastchat.utils import build_logger | |
# Module-wide logger named "diffusion_infer" with log file name
# "diffusion_infer.log" (where the file lives is decided by fastchat's
# build_logger).
logger = build_logger(
    "diffusion_infer",
    "diffusion_infer.log",
)
def generate_stream_sde(
    model,
    tokenizer,
    params,
    device,
    context_len=256,
    stream_interval=2,
):
    """Run the diffusion pipeline once on ``params["prompt"]`` and yield the result.

    Streaming is not performed: a single final chunk is yielded after the
    pipeline finishes.

    Args:
        model: presumably a diffusers pipeline. It must expose ``scheduler``
            (with a ``.config``) and be callable as ``model(prompt=...)``,
            returning an object with an ``.images`` sequence.
        tokenizer: callable tokenizer, used only to count prompt tokens for
            the ``usage`` field.
        params: request dict; only ``params["prompt"]`` is read.
        device: device string; ``"xpu"`` / ``"npu"`` additionally trigger the
            matching backend cache clear during cleanup.
        context_len: unused; kept for interface compatibility.
        stream_interval: unused; kept for interface compatibility.

    Yields:
        One dict with ``"text"`` set to ``images[0]`` from the pipeline output
        (presumably an image object, not a string), ``"usage"`` token counts,
        and ``"finish_reason": "stop"``.
    """
    prompt = params["prompt"]

    # Tokenize only to report the prompt length, so there is no need to move
    # the tensors to `device`. With return_tensors="pt" the ids are 2-D
    # (batch, seq_len); len() would return the batch size (1), so take the
    # last dimension for the token count.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_echo_len = input_ids.shape[-1]

    # Replace the pipeline's scheduler with a DDIM scheduler built from the
    # current scheduler's config. NOTE: this mutates the shared model object.
    model.scheduler = DDIMScheduler.from_config(model.scheduler.config)
    logger.info(f"model.scheduler: {model.scheduler}")
    logger.info(f"prompt: {prompt}")

    try:
        # Invoke the pipeline exactly once. The earlier version also ran the
        # same call on a background Thread whose result was discarded. That
        # doubled the work and ran two generations on one pipeline at once.
        output = model(prompt=prompt).images[0]
        yield {
            "text": output,
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": 0,
                "total_tokens": input_echo_len,
            },
            "finish_reason": "stop",
        }
    finally:
        # Runs even if the consumer stops iterating early (generator close()
        # raises GeneratorExit at the yield) or the pipeline raises.
        gc.collect()
        torch.cuda.empty_cache()
        if device == "xpu":
            torch.xpu.empty_cache()
        if device == "npu":
            torch.npu.empty_cache()