| """Text generation utilities for the trained model.""" | |
| import torch | |
| import torch.nn.functional as F | |
| from typing import List, Optional, Union | |
| from transformers import AutoTokenizer | |
| import logging | |
| logger = logging.getLogger(__name__) | |
| class TextGenerator: | |
| """Text generation with various decoding strategies.""" | |
| def __init__(self, model, tokenizer, device='cuda'): | |
| self.model = model | |
| self.tokenizer = tokenizer | |
| self.device = device | |
| self.model.to(device) | |
| self.model.eval() | |

    @torch.no_grad()
    def generate(
        self,
        prompt: Union[str, List[str]],
        max_length: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = 50,
        top_p: Optional[float] = 0.9,
        num_return_sequences: int = 1,
        do_sample: bool = True,
        repetition_penalty: float = 1.0,
    ) -> List[str]:
        """Generate text from prompt(s) with temperature, top-k, and top-p sampling.

        Note: this simple decoding loop returns one sequence per prompt;
        num_return_sequences > 1 is not yet implemented.
        """
        # Handle single string input
        if isinstance(prompt, str):
            prompts = [prompt]
        else:
            prompts = prompt

        # Tokenize prompts
        inputs = self.tokenizer(
            prompts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(self.device)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']

        # Generate
        batch_size = input_ids.shape[0]
        generated_ids = input_ids.clone()

        for _ in range(max_length - input_ids.shape[1]):
            # Get model predictions
            outputs = self.model(
                input_ids=generated_ids,
                attention_mask=attention_mask,
            )

            # Get logits for the last token
            next_token_logits = outputs.logits[:, -1, :]

            # Apply repetition penalty: push already-generated tokens toward
            # lower probability (divide positive logits, multiply negative ones)
            if repetition_penalty != 1.0:
                for i in range(batch_size):
                    for token_id in set(generated_ids[i].tolist()):
                        if next_token_logits[i, token_id] > 0:
                            next_token_logits[i, token_id] /= repetition_penalty
                        else:
                            next_token_logits[i, token_id] *= repetition_penalty
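            # For example, with repetition_penalty=1.2 a previously generated
            # token's logit of 2.0 drops to ~1.67 while a logit of -2.0 drops
            # to -2.4, so repeated tokens become less likely in both cases.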

            # Apply temperature
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature

            # Apply top-k filtering
            if top_k is not None:
                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                next_token_logits[indices_to_remove] = float('-inf')

            # Apply top-p (nucleus) filtering
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                next_token_logits[indices_to_remove] = float('-inf')
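                # Illustration of the filter above: with top_p=0.9 and sorted
                # probabilities [0.5, 0.3, 0.15, 0.05], the cumulative sums are
                # [0.5, 0.8, 0.95, 1.0]; after shifting the mask right, the
                # first three tokens are kept and only the last is removed.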

            # Sample from the distribution
            if do_sample:
                probs = F.softmax(next_token_logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_logits, dim=-1)

            # Append to generated sequence
            generated_ids = torch.cat([generated_ids, next_tokens.unsqueeze(1)], dim=1)

            # Update attention mask (match its dtype to avoid mixing types)
            attention_mask = torch.cat([
                attention_mask,
                torch.ones((batch_size, 1), device=self.device, dtype=attention_mask.dtype)
            ], dim=1)

            # Stop early once every sequence has emitted an EOS token
            if (next_tokens == self.tokenizer.eos_token_id).all():
                break

        # Decode generated sequences
        generated_texts = []
        for i in range(batch_size):
            generated_text = self.tokenizer.decode(
                generated_ids[i],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
            generated_texts.append(generated_text)

        return generated_texts
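
    # Example usage (illustrative; assumes a compatible `model` and `tokenizer`
    # are already loaded):
    #   gen = TextGenerator(model, tokenizer, device='cuda')
    #   texts = gen.generate(["The sky is", "Once upon a time"], max_length=50)
    #   # -> a list with one decoded string per prompt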

    def beam_search(
        self,
        prompt: str,
        max_length: int = 100,
        num_beams: int = 4,
        length_penalty: float = 1.0,
        early_stopping: bool = True,
    ) -> str:
        """Generate text using beam search.

        This is a simplified placeholder: a full implementation would keep
        num_beams running hypotheses, rescore them with length_penalty, and
        optionally stop early.
        """
        # NOTE: `inputs` is currently unused; the greedy fallback below
        # re-tokenizes the prompt itself.
        inputs = self.tokenizer(
            prompt,
            return_tensors='pt',
            truncation=True,
            max_length=max_length,
        ).to(self.device)

        # For now, fall back to greedy decoding
        return self.generate(
            prompt,
            max_length=max_length,
            do_sample=False,
            num_return_sequences=1,
        )[0]


def load_generator(checkpoint_path: str, device: str = 'cuda'):
    """Load model and create generator."""
    import json
    import sys
    from pathlib import Path

    sys.path.append(str(Path(__file__).parent.parent.parent))
    from src.model.transformer import TransformerForCausalLM

    # Load config
    config_path = Path(checkpoint_path) / 'config.json'
    with open(config_path, 'r') as f:
        config = json.load(f)

    # Create model config
    class ModelConfig:
        def __init__(self, config_dict):
            for key, value in config_dict.items():
                setattr(self, key, value)

    model_config = ModelConfig(config['model'])

    # Load model
    model = TransformerForCausalLM(model_config)
    state_dict = torch.load(
        Path(checkpoint_path) / 'pytorch_model.bin',
        map_location=device
    )
    model.load_state_dict(state_dict)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config['tokenizer']['tokenizer_name'])
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Create generator
    generator = TextGenerator(model, tokenizer, device)
    return generator
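
# The checkpoint directory passed to load_generator() is expected to contain
# pytorch_model.bin and a config.json roughly shaped like the sketch below.
# Only "model" and "tokenizer.tokenizer_name" are read here; the keys under
# "model" are illustrative placeholders, not the project's actual schema:
#
#   {
#     "model": {"vocab_size": 50257, "hidden_size": 768, "num_layers": 12},
#     "tokenizer": {"tokenizer_name": "gpt2"}
#   }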

if __name__ == '__main__':
    # Example command-line usage.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--prompt', type=str, required=True, help='Input prompt')
    parser.add_argument('--max-length', type=int, default=100, help='Maximum generation length')
    parser.add_argument('--temperature', type=float, default=0.8, help='Sampling temperature')
    parser.add_argument('--top-k', type=int, default=50, help='Top-k filtering')
    parser.add_argument('--top-p', type=float, default=0.9, help='Top-p (nucleus) filtering')
    parser.add_argument('--device', type=str, default='cuda', help='Device to use')
    args = parser.parse_args()

    # Load generator
    print("Loading model...")
    generator = load_generator(args.checkpoint, args.device)

    # Generate text
    print(f"Prompt: {args.prompt}")
    print("Generating...")
    generated = generator.generate(
        args.prompt,
        max_length=args.max_length,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
    )
    print(f"Generated: {generated[0]}")