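"""Jacobi (fixed-point) decoding with streamed output for consistency LLMs (CLLMs).

A block of draft tokens is refined in parallel until it stops changing; converged
text is printed to stdout as it is accepted, and a final usage summary is yielded.
"""
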
import torch
import gc
import os
import time
import random
from typing import Dict, Optional, Sequence, List, Tuple

from transformers.cache_utils import Cache, DynamicCache
from transformers import (
    LlamaModel,
    LlamaForCausalLM,
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
import torch.nn.functional as F


def get_jacobian_trajectory(
    model, tokenizer, input_ids, attention_mask, max_new_tokens
):
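    """Run one Jacobi trajectory: iterate the model over the full draft sequence
    until it reaches a fixed point (or an <eos> / iteration limit is hit), printing
    newly converged tokens of the first sequence as they are accepted.

    Returns the final token tensor and the number of iterations used.
    """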
    bsz = input_ids.shape[0]
    prompt_len = [torch.sum(t) for t in attention_mask]
    max_prompt_len = max(prompt_len)
    total_len = max_prompt_len + max_new_tokens

    # initialize the first point of jacobian trajectory
    tokens = torch.full(
        (bsz, total_len), tokenizer.pad_token_id, dtype=torch.long, device=model.device
    )
    for i in range(bsz):
        tokens[i, :] = torch.tensor(
            random.choices(input_ids[i][attention_mask[i] == 1], k=total_len),
            dtype=torch.long,
            device=model.device,
        )
        tokens[i, : prompt_len[i]] = input_ids[i][: prompt_len[i]].to(
            dtype=torch.long, device=model.device
        )

    itr = 0
    next_generation = tokens
    generate_attention_mask = torch.full_like(next_generation, 1).to(model.device)
    accurate_lengths = torch.tensor([prompt_len[i].item()] * bsz, device=model.device)
    prev_len = 0
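
    # Jacobi iteration: run the model over the whole draft in parallel, replace each
    # position with its greedy prediction, and repeat until two consecutive iterates
    # match (a fixed point) or one of the exit conditions below is met.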
    while True:
        current_generation = next_generation
        with torch.no_grad():
            logits = model(current_generation, generate_attention_mask).logits
        # greedy update; dividing the softmax by the small temperature 0.001 does not
        # change the argmax
        next_generation = torch.argmax(
            torch.nn.functional.softmax(logits, dim=-1) / 0.001, dim=-1
        )

        # hold prompt unchanged and update generated tokens
        # (logits at position t predict token t + 1, hence the one-position shift)
        for i in range(bsz):
            next_generation[i, :] = torch.cat(
                (
                    tokens[i, : prompt_len[i]],
                    next_generation[i, prompt_len[i] - 1 : total_len - 1],
                ),
                dim=0,
            )
        if (
            torch.all(torch.eq(next_generation, current_generation)).item()
            and itr == max_new_tokens
            or len(
                torch.where(
                    current_generation[0, : accurate_lengths[0]]
                    == tokenizer.eos_token_id
                )[0]
            )
            > 0
        ):
            # forced exit due to max_new_tokens constraint or eos reached
            return next_generation, itr

        # skip the first itr, current_generation has not been updated yet
        if itr != 0:
            if torch.all(torch.eq(next_generation, current_generation)).item():
                matched_position = total_len
            else:
                # first position where the two consecutive iterates disagree;
                # everything before it counts as converged
                matched_position = (
                    torch.eq(current_generation, next_generation).squeeze(0) == False
                ).nonzero(as_tuple=True)[0][0]
            fast_forward_cnt = matched_position - accurate_lengths[0]

            for i in range(bsz):
                accurate_lengths[i] = matched_position.item()

            # flush and print the first sequence
            generated_str = tokenizer.decode(
                next_generation[0, prompt_len[0] : accurate_lengths[0]],
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=True,
            )
            print(generated_str[prev_len:], flush=True, end="")
            prev_len = len(generated_str)

            if torch.all(torch.eq(next_generation, current_generation)).item():
                # early termination: itr < max_new_tokens
                return next_generation, itr

        itr += 1


def generate_stream_cllm(
    model,
    tokenizer,
    params,
    device,
    context_len,
    stream_interval=2,
    judge_sent_end=False,
):
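    """Streaming generation entry point. Repeatedly runs Jacobi trajectories of
    n_token_seq_length tokens until every sequence has emitted <eos> or the overall
    max_new_tokens budget is exhausted, then yields a usage / finish_reason dict.
    The text itself is streamed to stdout inside get_jacobian_trajectory.
    """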
    # converge_step = []
    prompt = params["prompt"]
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    max_new_tokens = int(params.get("n_token_seq_length", 32))
    max_new_seq_len = int(params.get("max_new_tokens", 1024))

    prompt_len = torch.sum(inputs["attention_mask"], dim=-1)
    generation = inputs["input_ids"]
    # number of prompt tokens, used for usage accounting
    input_echo_len = generation.shape[1]

    ### generation phase
    itr = 0
    eos_reached = False
    while True:
        if itr == 0:
            input_ids = inputs["input_ids"]
            input_masks = inputs["attention_mask"]
        else:
            # attend up to prompt_len[j] + itr * max_new_tokens for each sample and
            # mask everything beyond it
            input_masks = torch.ones_like(input_ids).to(device)
            for j in range(bsz):
                input_masks[j][
                    torch.sum(inputs["attention_mask"], dim=-1)[j]
                    + itr * max_new_tokens :
                ] = 0

        bsz = input_ids.shape[0]
        eos_reached = torch.tensor([False] * bsz, device=device)
        generation, iter_steps = get_jacobian_trajectory(
            model=model,
            tokenizer=tokenizer,
            input_ids=input_ids,
            attention_mask=input_masks,
            max_new_tokens=max_new_tokens,
        )
        ### inspect <eos>
        for j in range(bsz):
            prompt_len = torch.sum(input_masks, dim=-1)
            eos_positions = torch.where(generation[j] == tokenizer.eos_token_id)[0]

            if len(eos_positions) == 0:
                # no EOS, continue to the next item in the batch
                generation[j][prompt_len[j] + max_new_tokens :] = tokenizer.pad_token_id
                continue
            # otherwise, set tokens coming after EOS as pad
            eos_reached[j] = True
            generation[j, int(eos_positions[0]) + 1 :] = tokenizer.pad_token_id

        itr += 1

        if all(eos_reached) or itr * max_new_tokens >= max_new_seq_len:
            break

        # delete samples with <eos> generated before launching the next block
        input_ids = generation[torch.where(eos_reached == False)[0].tolist(), ...]
    if all(eos_reached):
        finish_reason = "eos"
    elif itr * max_new_tokens >= max_new_seq_len:
        finish_reason = "length"
    else:
        finish_reason = "stop"

    # full decoded output of the first sequence (the text was already streamed to
    # stdout inside get_jacobian_trajectory)
    output = tokenizer.decode(generation[0], skip_special_tokens=False)

    yield {
        "text": "",
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": itr * max_new_tokens,
            "total_tokens": input_echo_len + itr * max_new_tokens,
        },
        "finish_reason": finish_reason,
    }

    # clean
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()
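

# --- Example usage sketch (illustrative, not part of the original module) ---
# Assumes a Llama-style causal LM checkpoint trained for consistency / Jacobi
# decoding; "path/to/cllm-checkpoint" is a placeholder, and the tokenizer is
# assumed to define pad_token_id (required by get_jacobian_trajectory above).
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "path/to/cllm-checkpoint"  # placeholder checkpoint path
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, device_map="cuda"
    )

    params = {
        "prompt": "Explain Jacobi decoding in one paragraph.",
        "n_token_seq_length": 32,  # tokens refined per Jacobi block
        "max_new_tokens": 256,     # overall generation budget
    }
    # Text is streamed to stdout; the generator yields a final usage summary.
    for out in generate_stream_cllm(
        model, tokenizer, params, "cuda", context_len=2048
    ):
        print("\nfinish_reason:", out["finish_reason"], "usage:", out["usage"])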