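# app.py for a Hugging Face Space that chats with a nanochat mid-training checkpoint.
# Assumes the Space image provides torch, gradio, huggingface_hub, and the nanochat
# package (nanochat.gpt); these are inferred from the imports below, not pinned here.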
import os

# Hugging Face Spaces containers only guarantee /tmp is writable, so point all
# caches there before importing libraries that read these variables.
os.environ.setdefault("HF_HOME", "/tmp/hf")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/hf/hub")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf/transformers")
os.environ.setdefault("NANOCHAT_BASE_DIR", "/tmp/nanochat")

from huggingface_hub import hf_hub_download
import torch
import gradio as gr
import json
import pickle

from nanochat.gpt import GPT, GPTConfig

# Hardcoded model selection for this Space
MODEL_REPO = "loocorez/nanochat-mid-d20-step765"
STEP = "000765"
DEPTH = "20"

# Fetch the tokenizer and mid-training checkpoint files from the Hub into /tmp.
ckpt_dir = f"/tmp/ckpt/d{DEPTH}"
os.makedirs(ckpt_dir, exist_ok=True)
tok_local = hf_hub_download(MODEL_REPO, "tokenizer/tokenizer.pkl", local_dir="/tmp")
model_path = hf_hub_download(MODEL_REPO, f"mid_checkpoints/d{DEPTH}/model_{STEP}.pt", local_dir=ckpt_dir)
meta_path = hf_hub_download(MODEL_REPO, f"mid_checkpoints/d{DEPTH}/meta_{STEP}.json", local_dir=ckpt_dir)
class PklTokenizer:
    def __init__(self, pkl_path):
        with open(pkl_path, "rb") as f:
            self.enc = pickle.load(f)
        self._bos_id = self.encode_special("<|bos|>")

    def get_bos_token_id(self):
        return self._bos_id

    def encode_special(self, text):
        return self.enc.encode_single_token(text)

    def encode(self, text):
        return self.enc.encode_ordinary(text)

    def decode(self, ids):
        return self.enc.decode(ids)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open(meta_path, "r") as f:
    meta = json.load(f)
cfg = GPTConfig(**meta["model_config"])

# Build the module on the meta device, then materialize it and load the weights.
with torch.device("meta"):
    model = GPT(cfg)
model.to_empty(device=device)
model.init_weights()
state = torch.load(model_path, map_location=device)
# Checkpoints saved from a torch.compile'd model prefix keys with "_orig_mod.";
# strip that prefix so the keys match the plain module.
state = {k.removeprefix("_orig_mod."): v for k, v in state.items()}
model.load_state_dict(state, strict=True, assign=True)
model.eval()

tokenizer = PklTokenizer(tok_local)
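# The chat format wraps each turn in the checkpoint's special tokens:
#   <|bos|> <|user_start|>...<|user_end|> <|assistant_start|>...<|assistant_end|>
# chat_fn rebuilds that token stream from the conversation history and decodes the reply.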
def chat_fn(history, temperature=0.8, top_k=50, max_new_tokens=256):
    bos = tokenizer.get_bos_token_id()
    user_start = tokenizer.encode_special("<|user_start|>")
    user_end = tokenizer.encode_special("<|user_end|>")
    assistant_start = tokenizer.encode_special("<|assistant_start|>")
    assistant_end = tokenizer.encode_special("<|assistant_end|>")

    # history is a list of {"role": ..., "content": ...} dicts (Gradio "messages" format).
    tokens = [bos]
    for message in history:
        if message["role"] == "user":
            tokens += [user_start] + tokenizer.encode(message["content"]) + [user_end]
        else:
            tokens += [assistant_start] + tokenizer.encode(message["content"]) + [assistant_end]
    tokens += [assistant_start]

    generated = []
    use_cuda = device.type == "cuda"
    dtype = torch.bfloat16 if use_cuda else torch.float32
    with torch.amp.autocast(device_type="cuda" if use_cuda else "cpu", dtype=dtype):
        for token in model.generate(tokens, max_tokens=int(max_new_tokens), temperature=temperature, top_k=int(top_k)):
            # Stop when the model closes the assistant turn (or emits a new <|bos|>).
            if token == assistant_end or token == bos:
                break
            generated.append(token)
    return tokenizer.decode(generated)
with gr.Blocks() as demo:
    gr.Markdown("# NanoChat MID")
    # "messages" mode keeps the history as role/content dicts, matching chat_fn above.
    chat = gr.Chatbot(type="messages")
    msg = gr.Textbox()
    temp = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label="Temperature")
    topk = gr.Slider(1, 200, value=50, step=1, label="Top-k")
    max_toks = gr.Slider(16, 1024, value=256, step=16, label="Max new tokens")

    def respond(user_message, chat_history, temperature, top_k, max_new_tokens):
        chat_history = (chat_history or []) + [{"role": "user", "content": user_message}]
        reply = chat_fn(chat_history, temperature, top_k, max_new_tokens)
        chat_history = chat_history + [{"role": "assistant", "content": reply}]
        return "", chat_history

    msg.submit(respond, [msg, chat, temp, topk, max_toks], [msg, chat])

demo.launch()