"""Text generation utilities for the trained model."""
import torch
import torch.nn.functional as F
from typing import List, Optional, Union
from transformers import AutoTokenizer
import logging
logger = logging.getLogger(__name__)
class TextGenerator:
"""Text generation with various decoding strategies."""
def __init__(self, model, tokenizer, device='cuda'):
self.model = model
        self.tokenizer = tokenizer
        # Left-pad batched prompts so the last position of every row is a real
        # token (next-token logits are read at index -1)
        self.tokenizer.padding_side = 'left'
        self.device = device
self.model.to(device)
self.model.eval()
@torch.no_grad()
def generate(
self,
prompt: Union[str, List[str]],
max_length: int = 100,
temperature: float = 1.0,
top_k: Optional[int] = 50,
top_p: Optional[float] = 0.9,
num_return_sequences: int = 1,
do_sample: bool = True,
repetition_penalty: float = 1.0,
) -> List[str]:
"""Generate text from prompt(s)."""
# Handle single string input
if isinstance(prompt, str):
prompts = [prompt]
        else:
            prompts = list(prompt)
        # Draw num_return_sequences samples per prompt by repeating each prompt in the batch
        if num_return_sequences > 1:
            prompts = [p for p in prompts for _ in range(num_return_sequences)]
# Tokenize prompts
inputs = self.tokenizer(
prompts,
return_tensors='pt',
padding=True,
truncation=True,
max_length=max_length,
).to(self.device)
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
        # Generate
        batch_size = input_ids.shape[0]
        generated_ids = input_ids.clone()
        # Track which sequences have not yet produced EOS so finished rows can
        # be padded instead of extended
        unfinished = torch.ones(batch_size, dtype=torch.long, device=self.device)
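        # Autoregressive decoding loop: each step re-runs the full forward pass
        # (no KV cache) and appends one token per sequence until max_length is
        # reached or every sequence has emitted EOS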
for _ in range(max_length - input_ids.shape[1]):
# Get model predictions
outputs = self.model(
input_ids=generated_ids,
attention_mask=attention_mask,
)
# Get logits for the last token
next_token_logits = outputs.logits[:, -1, :]
            # Apply repetition penalty (CTRL-style): previously generated tokens
            # are made less likely regardless of the sign of their logit
            if repetition_penalty != 1.0:
                for i in range(batch_size):
                    for token_id in set(generated_ids[i].tolist()):
                        if next_token_logits[i, token_id] > 0:
                            next_token_logits[i, token_id] /= repetition_penalty
                        else:
                            next_token_logits[i, token_id] *= repetition_penalty
# Apply temperature
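            # (values < 1.0 sharpen the distribution, values > 1.0 flatten it)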
if temperature != 1.0:
next_token_logits = next_token_logits / temperature
            # Apply top-k filtering
            if top_k is not None and top_k > 0:
                # Guard against k larger than the vocabulary
                k = min(top_k, next_token_logits.size(-1))
                indices_to_remove = next_token_logits < torch.topk(next_token_logits, k)[0][..., -1, None]
                next_token_logits[indices_to_remove] = float('-inf')
# Apply top-p (nucleus) filtering
if top_p is not None:
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
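                # (the shift below keeps the first token that crosses top_p, so
                # at least one candidate always survives)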
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
next_token_logits[indices_to_remove] = float('-inf')
# Sample from the distribution
if do_sample:
probs = F.softmax(next_token_logits, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(next_token_logits, dim=-1)
            # Keep emitting the pad token for sequences that have already finished
            if self.tokenizer.eos_token_id is not None:
                next_tokens = next_tokens * unfinished + self.tokenizer.pad_token_id * (1 - unfinished)
            # Append to generated sequence
            generated_ids = torch.cat([generated_ids, next_tokens.unsqueeze(1)], dim=1)
            # Update attention mask
            attention_mask = torch.cat([
                attention_mask,
                torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=self.device)
            ], dim=1)
            # Mark sequences that produced EOS as finished; stop once all are done
            if self.tokenizer.eos_token_id is not None:
                unfinished = unfinished * (next_tokens != self.tokenizer.eos_token_id).long()
                if unfinished.max() == 0:
                    break
# Decode generated sequences
generated_texts = []
for i in range(batch_size):
generated_text = self.tokenizer.decode(
generated_ids[i],
skip_special_tokens=True,
clean_up_tokenization_spaces=True
)
generated_texts.append(generated_text)
return generated_texts
def beam_search(
self,
prompt: str,
max_length: int = 100,
num_beams: int = 4,
length_penalty: float = 1.0,
early_stopping: bool = True,
) -> str:
"""Generate text using beam search."""
# Implementation of beam search
# This is a simplified version - full implementation would be more complex
inputs = self.tokenizer(
prompt,
return_tensors='pt',
truncation=True,
max_length=max_length,
).to(self.device)
# For now, fallback to greedy decoding
return self.generate(
prompt,
max_length=max_length,
do_sample=False,
num_return_sequences=1
)[0]
def load_generator(checkpoint_path: str, device: str = 'cuda'):
"""Load model and create generator."""
    import json
    import sys
    from pathlib import Path
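    # Make the project root importable when this file is run as a script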
sys.path.append(str(Path(__file__).parent.parent.parent))
from src.model.transformer import TransformerForCausalLM
# Load config
config_path = Path(checkpoint_path) / 'config.json'
    with open(config_path, 'r') as f:
        config = json.load(f)
# Create model config
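    # (a lightweight attribute-style view over the config dict so the model can
    # read hyperparameters as ``config.<name>``)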
class ModelConfig:
def __init__(self, config_dict):
for key, value in config_dict.items():
setattr(self, key, value)
model_config = ModelConfig(config['model'])
# Load model
model = TransformerForCausalLM(model_config)
state_dict = torch.load(
Path(checkpoint_path) / 'pytorch_model.bin',
map_location=device
)
model.load_state_dict(state_dict)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer']['tokenizer_name'])
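    # GPT-style tokenizers often define no pad token; reuse EOS so that batched
    # generation with padding works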
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Create generator
generator = TextGenerator(model, tokenizer, device)
return generator
if __name__ == '__main__':
"""Example usage."""
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
parser.add_argument('--prompt', type=str, required=True, help='Input prompt')
parser.add_argument('--max-length', type=int, default=100, help='Maximum generation length')
parser.add_argument('--temperature', type=float, default=0.8, help='Sampling temperature')
parser.add_argument('--top-k', type=int, default=50, help='Top-k filtering')
parser.add_argument('--top-p', type=float, default=0.9, help='Top-p (nucleus) filtering')
parser.add_argument('--device', type=str, default='cuda', help='Device to use')
args = parser.parse_args()
# Load generator
print("Loading model...")
generator = load_generator(args.checkpoint, args.device)
# Generate text
print(f"Prompt: {args.prompt}")
print("Generating...")
generated = generator.generate(
args.prompt,
max_length=args.max_length,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
)
print(f"Generated: {generated[0]}")