import logging
import gc
import json
import time  # Add time module
from datetime import timedelta
from typing import List, Optional, Tuple, Type, TypeVar, Union, Dict

import torch
import torch.nn.functional as F
import torch.distributions as dists
import transformers
from transformers import AutoTokenizer
from peft import LoraConfig, get_peft_model
from accelerate import (
    Accelerator,
    InitProcessGroupKwargs,
)
from datasets import Dataset
from packaging import version
from tqdm import tqdm
from peft import PeftConfig, PeftModel
import numpy as np  # Add numpy import
import os
import jinja2

# Import LLaDA model related modules
from model_cache.llada.modeling_llada import LLaDAModelLM
from model_cache.llada.configuration_llada import LLaDAConfig

from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import get_dtype
from lm_eval.__main__ import cli_evaluate

eval_logger = logging.getLogger(__name__)

T = TypeVar("T", bound="TemplateLM")

import random


def set_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def create_full_block_attention_mask(prompt_length, max_length, block_size, device=None, dtype=None):
    """
    Creates a complete attention mask for the entire sequence with block-based causal attention.

    Args:
        prompt_length: Length of the prompt (first irregular block)
        max_length: Maximum total sequence length
        block_size: Size of each regular block
        device: Device to create tensor on
        dtype: Data type for the attention mask

    Returns:
        attention_mask: Tensor of shape [1, 1, max_length, max_length]
    """
    # Use the provided dtype or default to bfloat16
    if dtype is None:
        dtype = torch.bfloat16

    # Initialize mask with -inf (no attention)
    attention_mask = torch.full((1, 1, max_length, max_length), -torch.inf, device=device, dtype=dtype)

    # Block 0: Prompt (can see itself)
    attention_mask[:, :, :prompt_length, :prompt_length] = 0

    # Calculate the number of regular blocks after prompt
    remaining_length = max_length - prompt_length
    num_blocks = (remaining_length + block_size - 1) // block_size

    # Process each regular block
    for b in range(num_blocks):
        block_start = prompt_length + b * block_size
        block_end = min(prompt_length + (b + 1) * block_size, max_length)

        # Current block can see the prompt
        attention_mask[:, :, block_start:block_end, :prompt_length] = 0

        # Current block can see all previous regular blocks
        for prev_b in range(b):
            prev_start = prompt_length + prev_b * block_size
            prev_end = min(prompt_length + (prev_b + 1) * block_size, max_length)
            attention_mask[:, :, block_start:block_end, prev_start:prev_end] = 0

        # Current block can see itself (full attention within block)
        attention_mask[:, :, block_start:block_end, block_start:block_end] = 0

    return attention_mask
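# Illustrative sketch (added note, not part of the original implementation): with a prompt
# of length 2, max_length=6 and block_size=2, the mask above lets prompt rows attend only
# to the prompt, block-0 rows (positions 2-3) attend to positions 0-3, and block-1 rows
# (positions 4-5) attend to positions 0-5.
#
#   _m = create_full_block_attention_mask(prompt_length=2, max_length=6, block_size=2,
#                                         device="cpu", dtype=torch.float32)
#   assert _m.shape == (1, 1, 6, 6)
#   assert (_m[0, 0, 4, :6] == 0).all() and _m[0, 0, 2, 4] == -torch.inf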
def extract_attention_mask(full_mask, start_pos, input_length, cache_length):
    """
    Extract the relevant portion of attention mask for current forward pass.

    Args:
        full_mask: Complete attention mask [1, 1, max_length, max_length]
        start_pos: Starting position in the full sequence
        input_length: Length of current input sequence
        cache_length: Length of cached sequence

    Returns:
        attention_mask: Extracted mask [1, 1, input_length, cache_length + input_length]
    """
    end_pos = start_pos + input_length
    total_length = cache_length + input_length

    # Extract the relevant rows (current input positions)
    # and columns (cache + current input positions)
    extracted_mask = torch.full((1, 1, input_length, total_length), -torch.inf,
                                device=full_mask.device, dtype=full_mask.dtype)

    # Copy cache columns (0 to cache_length in the extracted mask corresponds to 0 to cache_length in full mask)
    extracted_mask[:, :, :, :cache_length] = full_mask[:, :, start_pos:end_pos, :cache_length]

    # Copy current input columns
    extracted_mask[:, :, :, cache_length:] = full_mask[:, :, start_pos:end_pos, start_pos:end_pos]

    return extracted_mask
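# Illustrative sketch (added note): continuing the example above, once the 2-token prompt
# is cached (cache_length=2) and block 0 (positions 2-3) is the current input, the
# extracted mask has shape [1, 1, 2, 4] and is all zeros, i.e. block 0 may attend to the
# cached prompt and to itself:
#
#   _sub = extract_attention_mask(_m, start_pos=2, input_length=2, cache_length=2)
#   assert _sub.shape == (1, 1, 2, 4) and (_sub == 0).all()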
def build_custom_float_attention_mask(input_ids, prompt_length, block_size, device=None, dtype=None):
    """
    Builds a custom float attention mask with block-based causal attention.

    Args:
        input_ids: Input token IDs.
        prompt_length: Length of the prompt for each sequence in the batch.
        block_size: Size of each regular block.
        device: Device to create tensor on.
        dtype: Data type for the attention mask.

    Returns:
        attn_mask: Tensor of shape [B, 1, seq_len, seq_len].
    """
    B, seq_len = input_ids.shape

    # Use the provided dtype or default to float32
    if dtype is None:
        dtype = torch.float32

    # Initialize to all -inf
    attn_mask = torch.full((B, 1, seq_len, seq_len), float('-inf'), dtype=dtype, device=device)

    # 1. Prompt section: each token can attend to the entire prompt
    for i in range(B):
        attn_mask[i, :, :, :prompt_length[i]] = 0.0  # Allow all tokens to see the prompt

        # 2. Block division: divide blocks starting from prompt_length
        num_blocks = (seq_len - prompt_length[i] + block_size - 1) // block_size
        for b in range(num_blocks):
            block_start = prompt_length[i] + b * block_size
            block_end = min(block_start + block_size, seq_len)

            # Full attention within the block
            attn_mask[i, :, block_start:block_end, block_start:block_end] = 0.0

            # Causal attention between blocks (can only see previous blocks)
            for prev_b in range(b):
                prev_start = prompt_length[i] + prev_b * block_size
                prev_end = min(prev_start + block_size, seq_len)
                # Current block can see previous blocks
                attn_mask[i, :, block_start:block_end, prev_start:prev_end] = 0.0

    return attn_mask


def top_p_logits(logits, top_p=None):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the indices to the right to keep the first token above the threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
    return logits


def top_k_logits(logits, top_k=None):
    top_k = min(top_k, logits.size(-1))  # Safety check
    # Remove all tokens with a probability less than the last token of the top-k
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
    return logits


def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            initial_confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except:
            initial_confidence, x0 = probs.max(dim=-1)
    else:
        initial_confidence, x0 = probs.max(dim=-1)

    # Save initial confidence
    confidence = initial_confidence.clone()

    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # Extract top1 and top2 probabilities
        top1_probs = sorted_probs[:, 0]
        top2_probs = sorted_probs[:, 1]
        # Calculate confidence as top1 - top2
        confidence = top1_probs - top2_probs

    if neg_entropy:
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0, initial_confidence
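# Illustrative sketch (added note): sample_tokens returns, per position, a confidence score
# used for ranking, the sampled token ids, and the raw probability of each sampled token.
# With temperature=0.0 it reduces to greedy decoding:
#
#   _logits = torch.randn(4, 32)              # 4 masked positions, toy vocab of 32
#   _conf, _x0, _p0 = sample_tokens(_logits, temperature=0.0)
#   assert torch.equal(_x0, _logits.argmax(dim=-1)) and torch.equal(_conf, _p0)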
@register_model("dream_lora")
class DreamLoRA(TemplateLM):
    def __init__(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        lora_path: str,
        batch_size: Optional[Union[int, str]] = 1,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        max_new_tokens: Optional[int] = 128,
        max_length: Optional[int] = 4096,  # Updated to match example code
        add_bos_token: Optional[bool] = False,
        nll_type: Optional[str] = "mc",
        log_type: Optional[str] = "ftb",
        mc_num: Optional[int] = 128,
        classifier_free_guidance: Optional[float] = 1.0,
        sampling_eps: Optional[float] = 1e-3,
        diffusion_steps: Optional[int] = 128,
        trust_remote_code: Optional[bool] = True,
        parallelize: Optional[bool] = False,
        autogptq: Optional[Union[bool, str]] = False,
        temperature: Optional[float] = 0.2,  # Updated default value
        top_p: Optional[float] = None,  # Updated default value
        top_k: Optional[float] = None,
        alg: Optional[str] = "entropy",
        alg_temp: Optional[float] = 0.0,
        escape_until: Optional[bool] = False,
        block_size: Optional[int] = 4,  # Updated to match example code
        mask_token_id: Optional[int] = 126336,  # Added mask_token_id parameter
        block_add_threshold: Optional[float] = 0.5,  # Added block_add_threshold parameter
        decoded_token_threshold: Optional[float] = 0.9,  # Added decoded token threshold parameter
        skip_threshold: Optional[float] = 1.0,  # Added skip_threshold parameter
        sampling_strategy: Optional[str] = "default",  # Added sampling strategy parameter
        save_dir: Optional[str] = None,  # Added save directory parameter
        show_speed: Optional[bool] = True,  # Added speed statistics parameter
        **kwargs,
    ) -> None:
        super().__init__()

        # prepare for parallelism
        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, (int, str))

        gpus = torch.cuda.device_count()
        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        if accelerator.num_processes > 1:
            self.accelerator = accelerator

        if "npu" in accelerator.device.type:
            gpus = torch.npu.device_count()

        # using one process with no model parallelism
        if not (parallelize or accelerator.num_processes > 1):
            # use user-passed device
            device_list = set(
                ["cuda", "cpu"]
                + [f"cuda:{i}" for i in range(gpus)]
                + ["mps", "mps:0"]
                + [f"npu:{i}" for i in range(gpus)]
            )
            if device and device in device_list:
                self._device = torch.device(device)
                eval_logger.info(f"Using device '{device}'")
                if device in ("mps", "mps:0") and version.parse(
                    torch.__version__
                ) < version.parse("2.1"):
                    raise RuntimeError(
                        f"mps requires torch >= 2.1. You have {torch.__version__}"
                    )
            else:
                eval_logger.info("Device not specified")
                eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                self._device = (
                    torch.device("cuda")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
        else:  # Parallelism managed by accelerate
            if device != "cuda":
                eval_logger.info(
                    f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                )
            # TODO: include in warning that `load_in_8bit` etc. affect this too
            self._device = (
                self.accelerator.device
                if hasattr(self, "accelerator")
                else torch.device(device)
            )

        self.batch_size_per_gpu = batch_size
        if isinstance(batch_size, str):
            self.batch_size_per_gpu = int(batch_size)

        # Save LoRA path and block_size
        self.lora_path = lora_path
        self.block_size = block_size
        self.block_add_threshold = block_add_threshold  # Added block_add_threshold attribute
        self.skip_threshold = skip_threshold  # Added skip_threshold attribute
        self.sampling_strategy = sampling_strategy  # Save sampling strategy parameter
        self.decoded_token_threshold = decoded_token_threshold  # Added decoded token threshold attribute

        # Save target_dtype for later use
        self.target_dtype = get_dtype(dtype)

        self._create_model_and_tokenizer(pretrained, dtype, trust_remote_code)

        if isinstance(pretrained, str):
            if gpus >= 1 or str(self.device) == "mps":
                # TODO: can remove this whole snippet except in the mps case, perhaps?
                if not (parallelize or autogptq or hasattr(self, "accelerator")):
                    # place model onto device requested manually,
                    # if not using HF Accelerate or device_map
                    # or any other option that preloads model onto device
                    try:
                        self.model.to(self.device)
                    except ValueError:
                        eval_logger.debug(
                            "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
                        )
            # multigpu data-parallel support when launched with accelerate
            if gpus > 1:
                if accelerator.num_processes > 1:
                    if parallelize:
                        eval_logger.warning(
                            "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
                        )
                    elif gpus > accelerator.num_processes:
                        eval_logger.warning(
                            "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                            "If you would like to use data parallelism, please launch the script "
                            "with 'accelerate launch *script*'. "
                            f"Current run will proceed with {accelerator.num_processes} devices."
                        )
                    if self.accelerator.is_local_main_process:
                        eval_logger.info(
                            f"Using {gpus} devices with data parallelism"
                        )

                    self._device = torch.device(f"{accelerator.device}")
                    self.accelerator = accelerator

                    self._rank = self.accelerator.local_process_index
                    self._world_size = self.accelerator.num_processes
                else:
                    # if we aren't launching via accelerate, ditch
                    self._rank = 0
                    self._world_size = 1
        else:
            # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
            eval_logger.warning(
                "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
            )
            self._rank = 0
            self._world_size = 1

        self.max_length = max_length
        self.add_bos_token = add_bos_token

        # generation params
        self.max_new_tokens = max_new_tokens
        self.diffusion_steps = diffusion_steps
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.alg = alg
        self.alg_temp = alg_temp
        self.escape_until = escape_until
        self.block_size = block_size
        self.mask_token_id = mask_token_id

        # loglikelihood params
        self.nll_type = nll_type
        self.log_type = log_type
        self.mc_num = mc_num
        self.classifier_free_guidance = classifier_free_guidance
        self.sampling_eps = sampling_eps

        # Add backend attribute, consistent with LLaDA.py
        self.backend = "causal"
        # Add truncation attribute, consistent with LLaDA.py
        self.truncation = False

        self.save_dir = save_dir
        self.show_speed = show_speed

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def _create_model_and_tokenizer(self, pretrained, dtype, trust_remote_code):
        # Get correct data type
        target_dtype = get_dtype(dtype)

        # Load LLaDA model and configuration
        config = LLaDAConfig.from_pretrained(pretrained)
        self.model = LLaDAModelLM.from_pretrained(
            pretrained,
            config=config,
            torch_dtype=target_dtype,
            trust_remote_code=False,
        ).eval()

        # Load LoRA configuration and model
        peft_config = PeftConfig.from_pretrained(self.lora_path)
        self.model = PeftModel.from_pretrained(self.model, self.lora_path)

        # Convert data type only when target_dtype is not None and not "auto"
        if target_dtype is not None and target_dtype != "auto":
            self.model = self.model.to(target_dtype)

        # Move to specified device
        self.model = self.model.to(self.device)

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            pretrained, trust_remote_code=trust_remote_code
        )
    def tok_encode(
        self, string: str, left_truncate_len=None, add_special_tokens=None
    ) -> List[int]:
        """ """
        # default for None - empty dict, use predefined tokenizer param
        # used for all models except for CausalLM or predefined value
        special_tokens_kwargs = {}

        # by default for CausalLM - false or self.add_bos_token is set
        if add_special_tokens is None:
            if self.backend == "causal":
                special_tokens_kwargs = {
                    "add_special_tokens": False or self.add_bos_token
                }
        # otherwise the method explicitly defines the value
        else:
            special_tokens_kwargs = {"add_special_tokens": add_special_tokens}

        encoding = self.tokenizer.encode(string, **special_tokens_kwargs)

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]

        return encoding

    def tok_batch_encode(
        self,
        strings: List[str],
        padding_side: str = "left",
        left_truncate_len: int = None,
        truncation: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side

        add_special_tokens = {}
        if self.backend == "causal":
            add_special_tokens = {"add_special_tokens": False or self.add_bos_token}

        encoding = self.tokenizer(
            strings,
            truncation=truncation,
            padding="longest",
            return_tensors="pt",
            **add_special_tokens,
        )
        if left_truncate_len:
            original_lengths = encoding["input_ids"].size(1)
            if original_lengths > left_truncate_len:
                eval_logger.warn(
                    f"Left truncation applied. Original sequence length was {original_lengths}, "
                    f"truncating to last {left_truncate_len} tokens. Some content will be lost.",
                )
            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
            encoding["attention_mask"] = encoding["attention_mask"][
                :, -left_truncate_len:
            ]
        self.tokenizer.padding_side = old_padding_side

        return encoding["input_ids"].to(self.device), encoding["attention_mask"].to(self.device)

    def tok_decode(self, tokens, skip_special_tokens=True):
        return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)

    def _count_tokens_after_truncation(self, response_text: str, until_terms: List[str] = None) -> int:
        """
        Unified token counting function: calculates the number of non-126081 tokens after truncating the response.
        """
        # Apply truncation based on until parameters
        truncated_text = response_text
        if until_terms and not self.escape_until:
            for term in until_terms:
                if len(term) > 0:
                    truncated_text = truncated_text.split(term)[0]

        # Re-tokenize processed answer and count non-126081 tokens
        generated_answer_ids = torch.tensor(self.tokenizer(truncated_text)["input_ids"])
        return int((generated_answer_ids != 126081).sum())
""" additional_config = {} if additional_config is None else additional_config args = utils.simple_parse_args_string(arg_string) args2 = {k: v for k, v in additional_config.items() if v is not None} return cls(**args, **args2) def apply_chat_template( self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True ) -> str: """ Method to apply a chat template to a list of chat history between user and model. """ try: chat_templated = self.tokenizer.apply_chat_template( chat_history, tokenize=False, add_generation_prompt=add_generation_prompt, continue_final_message=not add_generation_prompt, ) except jinja2.exceptions.TemplateError: eval_logger.warning( "Failed to apply chat template. removing the system role in chat history." ) chat_history = [msg for msg in chat_history if msg["role"] != "system"] chat_templated = self.tokenizer.apply_chat_template( chat_history, tokenize=False, add_generation_prompt=add_generation_prompt, continue_final_message=not add_generation_prompt, ) return chat_templated @property def tokenizer_name(self) -> str: return self.tokenizer.name_or_path.replace("/", "__") def _generate_block_single(self, prompt): """ Generates a response for a single prompt using parallel block generation, based on KV cache, and uses pre-generated attention masks. Returns: generated_sequence (List[int]) - List of generated token IDs """ self.model.eval() mask_id = self.mask_token_id block_size = self.block_size block_add_threshold = self.block_add_threshold skip_threshold = self.skip_threshold # Pre-generate the full attention mask, using the model's data type prompt_length = prompt.shape[1] full_attention_mask = create_full_block_attention_mask( prompt_length=prompt_length, max_length=self.max_length, block_size=block_size, device=self.device, dtype=self.target_dtype if self.target_dtype is not None and self.target_dtype != "auto" else torch.bfloat16 ) with torch.inference_mode(): # Initialization x_t = prompt.to(self.device) # Track block states - states can be: 'active', 'to_cache', 'in_cache' # Added 'is_complete' field to indicate whether it's a complete state (True) or incomplete state (False) block_states = { 0: { 'start_pos': 0, 'end_pos': prompt.shape[1], 'mask_count': 0, 'total_masks': prompt.shape[1], 'state': 'to_cache', # Prompt is immediately ready for caching 'is_complete': True, # Prompt is always in a complete state }, } # Initialize cache past_key_values = None current_blocks = 0 # Number of active blocks step = 0 eos_detected = False # EOS detection flag cache_length = 0 while current_blocks >= 0: step += 1 # Check if a new block needs to be added if len(block_states)-1 < (self.max_new_tokens // block_size) and not eos_detected: last_block_id = len(block_states) - 1 current_progress = (block_states[last_block_id]['total_masks'] - block_states[last_block_id]['mask_count']) / block_states[last_block_id]['total_masks'] if current_progress >= block_add_threshold: # Add new block new_block_id = len(block_states) new_start_pos = x_t.shape[1] x_t = torch.cat([x_t, torch.tensor([[mask_id] * block_size]).to(self.device)], dim=1) block_states[new_block_id] = { 'start_pos': new_start_pos, 'end_pos': new_start_pos + block_size, 'mask_count': block_size, 'total_masks': block_size, 'state': 'active', 'is_complete': False, # New block defaults to an incomplete state } current_blocks += 1 # At the beginning of each loop, update the block's complete/incomplete states self._update_block_completion_states(block_states, self.decoded_token_threshold) # Check if there are 
    def _generate_block_single(self, prompt):
        """
        Generates a response for a single prompt using parallel block generation,
        based on KV cache, and uses pre-generated attention masks.

        Returns:
            generated_sequence (List[int]) - List of generated token IDs
        """
        self.model.eval()
        mask_id = self.mask_token_id
        block_size = self.block_size
        block_add_threshold = self.block_add_threshold
        skip_threshold = self.skip_threshold

        # Pre-generate the full attention mask, using the model's data type
        prompt_length = prompt.shape[1]
        full_attention_mask = create_full_block_attention_mask(
            prompt_length=prompt_length,
            max_length=self.max_length,
            block_size=block_size,
            device=self.device,
            dtype=self.target_dtype if self.target_dtype is not None and self.target_dtype != "auto" else torch.bfloat16
        )

        with torch.inference_mode():
            # Initialization
            x_t = prompt.to(self.device)

            # Track block states - states can be: 'active', 'to_cache', 'in_cache'
            # Added 'is_complete' field to indicate whether it's a complete state (True) or incomplete state (False)
            block_states = {
                0: {
                    'start_pos': 0,
                    'end_pos': prompt.shape[1],
                    'mask_count': 0,
                    'total_masks': prompt.shape[1],
                    'state': 'to_cache',  # Prompt is immediately ready for caching
                    'is_complete': True,  # Prompt is always in a complete state
                },
            }

            # Initialize cache
            past_key_values = None
            current_blocks = 0  # Number of active blocks
            step = 0
            eos_detected = False  # EOS detection flag
            cache_length = 0

            while current_blocks >= 0:
                step += 1

                # Check if a new block needs to be added
                if len(block_states) - 1 < (self.max_new_tokens // block_size) and not eos_detected:
                    last_block_id = len(block_states) - 1
                    current_progress = (block_states[last_block_id]['total_masks'] - block_states[last_block_id]['mask_count']) / block_states[last_block_id]['total_masks']
                    if current_progress >= block_add_threshold:
                        # Add new block
                        new_block_id = len(block_states)
                        new_start_pos = x_t.shape[1]
                        x_t = torch.cat([x_t, torch.tensor([[mask_id] * block_size]).to(self.device)], dim=1)
                        block_states[new_block_id] = {
                            'start_pos': new_start_pos,
                            'end_pos': new_start_pos + block_size,
                            'mask_count': block_size,
                            'total_masks': block_size,
                            'state': 'active',
                            'is_complete': False,  # New block defaults to an incomplete state
                        }
                        current_blocks += 1

                # At the beginning of each loop, update the block's complete/incomplete states
                self._update_block_completion_states(block_states, self.decoded_token_threshold)

                # Check if there are still mask tokens
                mask_index = (x_t == mask_id)
                if mask_index.sum() == 0 and current_blocks == 0:
                    break

                # Determine which blocks need to be added to the cache
                blocks_to_cache = [bid for bid, state in block_states.items() if state['state'] == 'to_cache']

                # Determine the part to be processed
                update_kvcache = 0
                if blocks_to_cache:
                    # Find the earliest block to be cached
                    earliest_block_id = min(blocks_to_cache)
                    earliest_pos = block_states[earliest_block_id]['start_pos']
                    # Find the latest block to be cached
                    latest_block_id = max(blocks_to_cache)
                    latest_pos = block_states[latest_block_id]['end_pos']
                    # Update the cache for all blocks within this range
                    update_kvcache = latest_pos - earliest_pos

                # Create input sequence for forward pass
                process_start_pos = cache_length
                if update_kvcache > 0:
                    # Need to update cache - use completed blocks
                    earliest_block_to_cache = min(blocks_to_cache)
                    input_seq = x_t[:, block_states[earliest_block_to_cache]['start_pos']:]
                    process_start_pos = block_states[earliest_block_to_cache]['start_pos']
                else:
                    # Only process active blocks
                    active_blocks = [bid for bid, state in block_states.items() if state['state'] == 'active']
                    if active_blocks:
                        # Get all active blocks after caching
                        earliest_active_after_cache = float('inf')
                        for bid in active_blocks:
                            if block_states[bid]['start_pos'] >= cache_length:
                                earliest_active_after_cache = min(earliest_active_after_cache, block_states[bid]['start_pos'])

                        if earliest_active_after_cache < float('inf'):
                            input_seq = x_t[:, earliest_active_after_cache:]
                            process_start_pos = earliest_active_after_cache
                        else:
                            # No active blocks after caching, this should not happen
                            input_seq = x_t[:, cache_length:]
                            # If cache length is already equal to or exceeds sequence length, exit
                            if cache_length >= x_t.shape[1]:
                                print(f"Cache length ({cache_length}) >= sequence length ({x_t.shape[1]}) at step {step}. Exiting generation loop.")
                                raise Exception("Cache length >= sequence length")
                    else:
                        # No active blocks, but blocks might need to be cached in the next iteration
                        break
Breaking generation loop.") raise Exception("input_seq is empty") # Extract the attention mask for the current input from the pre-generated full mask input_length = input_seq.shape[1] attention_mask = extract_attention_mask( full_mask=full_attention_mask, start_pos=process_start_pos, input_length=input_length, cache_length=cache_length ) outputs = self.model( input_seq, attention_bias=attention_mask, past_key_values=past_key_values, use_cache=True, update_kvcache=update_kvcache+cache_length, ) # Get current logits - LLaDA model directly uses logits, no shifting needed logits = outputs.logits # Update cache if needed if update_kvcache > 0: # Update cache past_key_values = outputs.past_key_values # Mark blocks as cached for block_id in blocks_to_cache: block_states[block_id]['state'] = 'in_cache' # Process mask tokens for each active block blocks_to_deactivate = [] for block_id in sorted(block_states.keys()): if block_states[block_id]['state'] != 'active': continue # Get mask positions for this block block_start = block_states[block_id]['start_pos'] block_end = block_states[block_id]['end_pos'] block_mask_index = mask_index.clone() block_mask_index[:, :block_start] = False block_mask_index[:, block_end:] = False # Skip if the current block has no masks if block_mask_index.sum() == 0: blocks_to_deactivate.append(block_id) continue # Calculate relative position of logits logit_offset = block_start - process_start_pos block_rel_positions = torch.where(block_mask_index[0, block_start:block_end])[0] if block_rel_positions.size(0) > 0: # Get logits for masked positions block_mask_logits = logits[:, logit_offset + block_rel_positions, :] # Sample tokens confidence, x0, initial_confidence = sample_tokens( block_mask_logits.squeeze(0), self.temperature, top_p=self.top_p, top_k=self.top_k, neg_entropy=(self.sampling_strategy == "neg_entropy"), margin_confidence=(self.sampling_strategy == "margin_confidence") ) # Use different sampling strategies based on the block's complete/incomplete state is_complete = block_states[block_id]['is_complete'] if is_complete: # Complete state: apply confidence threshold, if no high confidence, select the highest high_conf_indices = torch.where(initial_confidence > skip_threshold)[0] if len(high_conf_indices) == 0: number_transfer_tokens = 1 _, transfer_index = torch.topk(confidence, number_transfer_tokens) else: transfer_index = torch.tensor([], device=self.device, dtype=torch.long) # Merge indices all_indices = torch.unique(torch.cat([transfer_index, high_conf_indices])) else: # Incomplete state: only apply confidence threshold, if no tokens exceed the threshold, select none high_conf_indices = torch.where(initial_confidence > skip_threshold)[0] all_indices = high_conf_indices # Update tokens if len(all_indices) > 0: x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_id x0_[all_indices] = x0[all_indices].clone() # Map indices back to original positions for i, idx in enumerate(all_indices): abs_pos = block_start + block_rel_positions[idx] x_t[0, abs_pos] = x0_[idx] # Update block state block_states[block_id]['mask_count'] -= len(all_indices) # Check for EOS token eos_token_id = 126081 if eos_token_id is not None: for idx in all_indices: if x0[idx].item() == eos_token_id: eos_detected = True break # Deactivate this block if no masks remain mask_index = (x_t == mask_id) block_mask_index = mask_index.clone() block_mask_index[:, :block_start] = False block_mask_index[:, block_end:] = False if block_mask_index.sum() == 0: 
                # Deactivate completed blocks and mark them for caching in the next iteration
                for block_id in blocks_to_deactivate:
                    if block_states[block_id]['state'] == 'active':
                        # Check if all preceding blocks are already in a non-active state
                        can_deactivate = True
                        for prev_block_id in range(block_id):
                            if prev_block_id in block_states and block_states[prev_block_id]['state'] == 'active':
                                can_deactivate = False
                                break
                        # Only mark the current block as 'to_cache' if all preceding blocks are not active
                        if can_deactivate:
                            block_states[block_id]['state'] = 'to_cache'
                            current_blocks -= 1
                        # If there are active preceding blocks, keep the current block in active state (do nothing)

                if update_kvcache > 0:
                    cache_length += update_kvcache

                # Safety check
                if step > 10000:
                    print(f"WARNING: Hit safety check at step {step}. Exiting generation loop.")
                    break

            current_text = self.tokenizer.decode(x_t[0, prompt.shape[1]:].tolist(), skip_special_tokens=False)

        # Generate final answer
        generated_sequence = x_t[0, prompt.shape[1]:].tolist()
        return generated_sequence

    def generate_until(self, requests: List[Instance], disable_tqdm: bool = False):
        res = []

        start_time = time.time()

        # Statistics variables
        num_tokens = 0
        num_nfe = 0

        bar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)), desc="Running generate_until requests")

        for i, req in enumerate(requests):
            question = req.args[0]
            # print("question:", question)
            # exit()
            gen_kwargs = req.args[1]

            # Process input in LLaDA.py style
            # print("Self.add_bos_token:", self.add_bos_token)
            contexts = [question]
            if self.add_bos_token:
                contexts = [self.tokenizer.bos_token + p for p in contexts]

            # Use the same tokenization method as LLaDA.py
            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                truncation=self.truncation,
            )
            input_ids = context_enc[0].unsqueeze(0)  # Take the first one and add batch dimension

            # Add length check
            if input_ids.shape[1] > self.max_length - self.max_new_tokens:
                eval_logger.warning(f"Prompt length {input_ids.shape[1]} is larger than {self.max_length - self.max_new_tokens}, cutoff on the left side")
                input_ids = input_ids[:, -(self.max_length - self.max_new_tokens):]

            # Generate token IDs
            generated_answer = self._generate_block_single(input_ids)

            # Use tokenizer.batch_decode for decoding, consistent with LLaDA.py
            cont_toks_list = self.tokenizer.batch_decode([generated_answer], skip_special_tokens=True)
            s = cont_toks_list[0]  # Take the first (and only) result

            # Use unified token counting function
            if self.show_speed:
                num_tokens += self._count_tokens_after_truncation(s, gen_kwargs.get("until", []))
                num_nfe += 1  # NFE uses simplified statistics (fixed to 1)

            # Handle until truncation in LLaDA.py style
            if not self.escape_until:
                for term in gen_kwargs.get("until", []):
                    if len(term) > 0:
                        s = s.split(term)[0]

            res.append(s)
            bar.update(1)

        bar.close()

        # Save statistics only at the end
        if self.save_dir is not None:
            os.makedirs(self.save_dir, exist_ok=True)
            final_time = time.time()
            total_time = final_time - start_time
            final_stats = {
                "processed_samples": len(res),
                "total_samples": len(requests),
                "total_tokens": int(num_tokens),
                "total_nfe": int(num_nfe),
                "total_time": total_time,
                "tokens_per_second": float(num_tokens) / total_time if total_time > 0 else 0.0,
                "nfe_per_token": float(num_nfe) / float(num_tokens) if num_tokens > 0 else 0.0,
                "timestamp": final_time
            }
            final_stats_path = os.path.join(self.save_dir, f'rank_{self.rank}_final_stats.json')
            with open(final_stats_path, 'w', encoding='utf-8') as f:
                json.dump(final_stats, f, ensure_ascii=False, indent=2)

        if self.show_speed:
            final_time = time.time()
            total_time = final_time - start_time
            print(f"\n=== Final Statistics ===")
            print(f"Processed samples: {len(res)}")
            print(f"Total tokens: {num_tokens}")
            print(f"Total time: {total_time:.2f} seconds")
            print(f"Throughput: {num_tokens / total_time:.2f} tokens/s")
            print(f"Total NFE: {num_nfe}")

        return res
    def _forward_process(self, batch):
        b, l = batch.shape
        # sample from U[0, 1] following https://arxiv.org/pdf/2107.00630 I.1
        u0 = torch.rand(1, device=batch.device, dtype=torch.float32)
        indices = torch.arange(b, device=batch.device).float()
        t = (u0 + indices / b) % 1

        p_mask = (1 - self.sampling_eps) * t + self.sampling_eps
        p_mask = p_mask[:, None].repeat(1, l)

        mask_indices = torch.rand((b, l), device=batch.device) < p_mask
        # always unmask bos and eos
        mask_indices[:, 0] = False
        mask_indices[:, -1] = False
        noisy_batch = torch.where(mask_indices, self.mask_token_id, batch)
        return noisy_batch, p_mask

    @torch.no_grad()
    def get_logits(self, batch, prompt_index):
        '''
        prompt_index : 1D bool tensor, length=batch.shape[1]
        '''
        if self.classifier_free_guidance > 1.:
            assert len(prompt_index) == batch.shape[1]
            prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
            un_batch = batch.clone()
            un_batch[prompt_index] = self.mask_token_id
            batch = torch.cat([batch, un_batch])

        input = batch

        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            logits = self.model(input).logits
            # since bos always unmask, the first logits will not be used
            logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)

        if self.classifier_free_guidance > 1.:
            logits, un_logits = torch.chunk(logits, 2, dim=0)
            logits = un_logits + self.classifier_free_guidance * (logits - un_logits)
        return logits[:, :batch.shape[1]]
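    # Note on the Monte Carlo estimator below (added summary): each iteration re-noises the
    # sequence with _forward_process and scores only the masked positions, weighting each
    # cross-entropy term by 1 / p_mask at that position, i.e. it averages
    #   sum_{i in masked} CE(logits_i, x_i) / p_mask_i
    # over the Monte Carlo draws to estimate the negative log-likelihood of the target span.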
    @torch.no_grad()
    def _eval_target_nll_mc(self, prefix, target):
        if prefix is None:
            seq = target[None, :]
        else:
            seq = torch.concatenate([prefix, target])[None, :]
        seq = seq.repeat((self.batch_size, 1)).to(self.device)

        if self.log_type == 'ftb':
            prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
        else:
            prompt_index = torch.arange(seq.shape[1], device=self.device) >= len(prefix)

        loss_acc = []
        for _ in range(max(self.mc_num // self.batch_size, 1)):
            perturbed_seq = seq.clone()
            # eval_logger.info("before noising")
            perturbed_seq_, p_mask = self._forward_process(seq)
            # eval_logger.info("end noising")
            if self.log_type == 'ftb':
                perturbed_seq[:, -len(target):] = perturbed_seq_[:, -len(target):]
            elif self.log_type == 'btf':
                perturbed_seq[:, :len(prefix)] = perturbed_seq_[:, :len(prefix)]
            elif self.log_type == 'union':
                perturbed_seq = perturbed_seq_
            else:
                raise NotImplementedError(self.log_type)

            mask_indices = perturbed_seq == self.mask_token_id
            logits = self.get_logits(perturbed_seq, prompt_index)
            loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices]
            loss = loss.sum() / self.batch_size
            loss_acc.append(loss.item())

        return sum(loss_acc) / len(loss_acc)

    @torch.no_grad()
    def _eval_target_nll_ar(self, prefix, target):
        prefix, target = prefix.unsqueeze(0), target.unsqueeze(0)  # 1*l1, 1*l2

        assert self.log_type in ['ftb', 'btf']
        assert self.nll_type in ['ar_ftb', 'ar_btf']
        if self.log_type == 'ftb':
            prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) < prefix.shape[1]
        else:
            prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) >= prefix.shape[1]

        if self.log_type == 'ftb':
            perturbed_ = target.repeat(target.shape[1], 1).clone().contiguous()  # l2*l2
        else:
            perturbed_ = prefix.repeat(prefix.shape[1], 1).clone().contiguous()  # l1*l1

        mask_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
        if self.nll_type == 'ar_ftb':
            mask_index = torch.triu(mask_index)
        else:
            mask_index = torch.tril(mask_index)
        perturbed_[mask_index] = self.mask_token_id

        if self.log_type == 'ftb':
            perturbed_seq = torch.cat([prefix.repeat(perturbed_.shape[0], 1), perturbed_], dim=-1)
        else:
            perturbed_seq = torch.cat([perturbed_, target.repeat(perturbed_.shape[0], 1)], dim=-1)

        logits_ = []
        num = len(perturbed_seq) // self.batch_size if len(perturbed_seq) % self.batch_size == 0 else len(perturbed_seq) // self.batch_size + 1
        for i in range(num):
            end = (i + 1) * self.batch_size if (i + 1) * self.batch_size < len(perturbed_seq) else len(perturbed_seq)
            perturbed_seq_ = perturbed_seq[i * self.batch_size: end]
            perturbed_seq_ = perturbed_seq_.to(self.device)
            if len(perturbed_seq_.shape) == 1:
                perturbed_seq_ = perturbed_seq_.unsqueeze(0)
            logits = self.get_logits(perturbed_seq_, prompt_index)
            logits_.append(logits.cpu())
        logits = torch.cat(logits_, dim=0)

        temp_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
        if self.nll_type == 'ar_ftb':
            temp_index = torch.triu(temp_index, diagonal=1)
        else:
            temp_index = torch.tril(temp_index, diagonal=-1)
        mask_index[temp_index] = False

        if self.log_type == 'ftb':
            logits_index = torch.cat([torch.zeros((perturbed_.shape[1], prefix.shape[1]), dtype=torch.bool), mask_index], dim=-1)
        else:
            logits_index = torch.cat([mask_index, torch.zeros((perturbed_.shape[1], target.shape[1]), dtype=torch.bool)], dim=-1)

        if self.log_type == 'ftb':
            loss = F.cross_entropy(logits[logits_index], target[0], reduction='sum').cpu().item()
        else:
            loss = F.cross_entropy(logits[logits_index], prefix[0], reduction='sum').cpu().item()
        return loss

    def _encode_pair(self, context, continuation):
        if self.add_bos_token:
            context = self.tokenizer.bos_token + context

        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]

        whole_enc = self.tokenizer.encode(context + continuation) + [self.tokenizer.eos_token_id]
        context_enc = self.tokenizer.encode(context)

        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]

        # by default truncate on the left
        cutoff_length = max(len(whole_enc) - self.max_length, 0)
        if cutoff_length > 0:
            eval_logger.warning(f"Text length {len(whole_enc)} is larger than {self.max_length}, cutoff on the left side")
            context_remain = context_enc_len - cutoff_length
            if context_remain > 0:
                context_enc = context_enc[-context_remain:]
            else:
                eval_logger.warning(f"All context (prompt) is truncated.")
                context_enc = ""
                continuation_enc = whole_enc[-self.max_length:]
        return context_enc, continuation_enc

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        def _tokenize(e):
            prefix, target = self._encode_pair(e["prefix"], e["target"])
            return {
                "prefix_text": e["prefix"],
                "target_text": e["target"],
                "prefix": prefix,
                "target": target,
            }

        ds = []
        ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
        ds = Dataset.from_list(ds)
        print(ds[0])
        ds = ds.map(_tokenize)
        ds = ds.with_format("torch")

        out = []
        with torch.no_grad():
            for elem in tqdm(ds, desc="Computing likelihood..."):
                prefix = elem["prefix"]
                target = elem["target"]

                # likelihood calculations are modified from https://github.com/ML-GSAI/SMDM/blob/main/evaluate_diff.py
                if self.nll_type == 'mc':
                    ll = -self._eval_target_nll_mc(prefix, target)
                    if self.log_type == 'union':
                        ll = ll / (len(target) + len(prefix))
                elif self.nll_type == 'ar_ftb' or self.nll_type == 'ar_btf':
                    ll = -self._eval_target_nll_ar(prefix, target)
                else:
                    raise NotImplementedError(self.nll_type)

                # TODO: greedy decoding
                is_target_greedy_dec = False

                out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
        return out

    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
        raise NotImplementedError

    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
        raise NotImplementedError

    def _update_block_completion_states(self, block_states, decoded_token_threshold):
        """
        Updates the complete/incomplete state of blocks.
        Iterates through blocks from front to back. If a block's decoded token count exceeds the
        threshold, the next block to its right (if it exists) is set to a complete state.
        """
        for block_id in sorted(block_states.keys()):
            # if block_id == 0:  # Skip prompt block
            #     continue

            # Calculate decoded tokens for the current block
            decoded_tokens = block_states[block_id]['total_masks'] - block_states[block_id]['mask_count']
            decode_ratio = decoded_tokens / block_states[block_id]['total_masks']

            # If current block's decoded token count exceeds the threshold, the next block (if exists) is set to a complete state
            # print("decode_ratio", decode_ratio)
            # print("decoded_token_threshold", decoded_token_threshold)
            if decode_ratio >= decoded_token_threshold:
                next_block_id = block_id + 1
                if next_block_id in block_states:
                    block_states[next_block_id]['is_complete'] = True


if __name__ == "__main__":
    set_seed(1234)
    cli_evaluate()