# NOTE: page-header residue from the scraped listing ("Spaces: Paused"); not part of the program.
import argparse

import torch
import torch.nn.functional as F

from config import ModelArgs
from model import Llama
from tokenizer import Tokenizer
# Build the shared tokenizer once at import time. `topk_sampling` reads this
# module-level name for both encoding the prompt and decoding the result, so
# it must be bound to the object returned by `ready_tokenizer()`.
tokenizer = Tokenizer().ready_tokenizer()
def remove_prefix(state_dict, prefix):
    """Return a copy of ``state_dict`` with ``prefix`` stripped from its keys.

    Only a *leading* occurrence of ``prefix`` is removed; keys that do not
    start with it are copied unchanged. Values are carried over as-is (not
    cloned), and the input mapping is not modified.

    Args:
        state_dict: Mapping from parameter name to value.
        prefix: String to strip from the start of each key.

    Returns:
        A new ``dict`` with the same values and insertion order.
    """
    cut = len(prefix)
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
def topk_sampling(model, prompt, device, max_length=50, top_k=50, temperature=1.0):
    """Extend ``prompt`` token by token using top-k sampling.

    At every step the model is run on the whole sequence so far, the logits
    of the last position are turned into probabilities, and the next token is
    drawn from the ``top_k`` most probable candidates.

    Args:
        model: Callable that takes the token-id tensor and returns logits
            indexed as ``[batch, position, vocab]``.
        prompt: Text to encode with the module-level ``tokenizer``.
        device: Device the encoded prompt is moved to.
        max_length: Number of new tokens to sample.
        top_k: Number of candidates to sample from; clamped to the
            vocabulary size.
        temperature: Divides the logits before the softmax. Values below 1
            sharpen the distribution, values above 1 flatten it; 1.0 leaves
            it unchanged.

    Returns:
        The decoded text of the full sequence (prompt tokens followed by the
        sampled tokens), with special tokens skipped.

    Raises:
        ValueError: If ``temperature`` is not positive or ``top_k`` < 1.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    # Flag set on the shared config class, exactly as before.
    # NOTE(review): presumably read by the model's forward pass; confirm in
    # model.py. It is never reset, so it stays True for the whole process.
    ModelArgs.inference = True

    # NOTE(review): the sequence grows by one token per step and is never
    # cropped; confirm the model accepts inputs longer than its block size.
    with torch.no_grad():
        for _ in range(max_length):
            last_logits = model(input_ids)[:, -1, :]

            # Temperature has to scale the logits *before* the softmax;
            # dividing probabilities would not change what is sampled.
            probs = F.softmax(last_logits / temperature, dim=-1)

            # torch.topk raises if k exceeds the size of the vocab dimension.
            k = min(top_k, probs.size(-1))
            top_k_probs, top_k_indices = torch.topk(probs, k, dim=-1)

            # multinomial treats its input as unnormalised weights, so the
            # truncated probabilities do not need renormalising.
            choice = torch.multinomial(top_k_probs, num_samples=1)

            # `choice` indexes into the top-k list; map it back to a vocab id.
            next_token = torch.gather(top_k_indices, -1, choice)

            # dim=1 is the sequence dimension.
            input_ids = torch.cat([input_ids, next_token], dim=1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
def main():
    """Load the pretrained checkpoint and print a sampled continuation.

    Command-line options:
        --prompt       Text to continue (default "Once upon a time").
        --max_length   Number of tokens to sample (default 128).
        --temperature  Sampling temperature (default 1.0).
        --top_k        Number of candidates per step (default 50).
        --checkpoint   Snapshot file containing a ``MODEL_STATE`` entry
                       (default ``weights/pretrained/snapshot_4650.pt``).
    """
    torch.set_float32_matmul_precision('high')

    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default="Once upon a time")
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument(
        "--checkpoint",
        type=str,
        default='weights/pretrained/snapshot_4650.pt',
    )
    args = parser.parse_args()

    model = Llama(
        device=ModelArgs.device,
        embeddings_dims=ModelArgs.embeddings_dims,
        no_of_decoder_layers=ModelArgs.no_of_decoder_layers,
        block_size=ModelArgs.block_size,
        vocab_size=ModelArgs.vocab_size,
        dropout=ModelArgs.dropout,
    )
    model = model.to(ModelArgs.device)

    # map_location lets a snapshot saved on one device (e.g. a GPU) load on
    # whatever device this run is configured for.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    snapshot = torch.load(args.checkpoint, map_location=ModelArgs.device)

    # Strip the '_orig_mod.' key prefix (presumably left by torch.compile
    # when the snapshot was saved) so the keys match the uncompiled model.
    model_state = remove_prefix(snapshot['MODEL_STATE'], '_orig_mod.')
    model.load_state_dict(model_state)
    model.eval()
    print("Model ready")

    with torch.no_grad():
        generated_text = topk_sampling(
            model,
            args.prompt,
            device=ModelArgs.device,
            max_length=args.max_length,
            top_k=args.top_k,  # honour the CLI flag instead of a fixed 50
            temperature=args.temperature,
        )

    # topk_sampling decodes the whole sequence, prompt included, so the
    # prompt must not be prepended again when printing.
    print("Generated: ", generated_text)


if __name__ == '__main__':
    main()