import os

# Route all Hugging Face and nanochat caches to /tmp, the writable location
# on a Space with an ephemeral filesystem. Must run before importing huggingface_hub.
os.environ.setdefault("HF_HOME", "/tmp/hf")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/hf/hub")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf/transformers")
os.environ.setdefault("NANOCHAT_BASE_DIR", "/tmp/nanochat")

import json
import pickle

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

from nanochat.gpt import GPT, GPTConfig

# Hardcoded model selection for this Space
MODEL_REPO = "loocorez/nanochat-mid-d20-step765"
STEP = "000765"
DEPTH = "20"

ckpt_dir = f"/tmp/ckpt/d{DEPTH}"
os.makedirs(ckpt_dir, exist_ok=True)

# Fetch the tokenizer and the mid-training checkpoint (weights + config) from the Hub.
tok_local = hf_hub_download(MODEL_REPO, "tokenizer/tokenizer.pkl", local_dir="/tmp", local_dir_use_symlinks=False)
model_path = hf_hub_download(MODEL_REPO, f"mid_checkpoints/d{DEPTH}/model_{STEP}.pt", local_dir=ckpt_dir, local_dir_use_symlinks=False)
meta_path = hf_hub_download(MODEL_REPO, f"mid_checkpoints/d{DEPTH}/meta_{STEP}.json", local_dir=ckpt_dir, local_dir_use_symlinks=False)


class PklTokenizer:
    """Thin wrapper around the pickled tiktoken encoding shipped with the checkpoint."""

    def __init__(self, pkl_path):
        with open(pkl_path, "rb") as f:
            self.enc = pickle.load(f)
        self._bos_id = self.encode_special("<|bos|>")

    def get_bos_token_id(self):
        return self._bos_id

    def encode_special(self, text):
        return self.enc.encode_single_token(text)

    def encode(self, text):
        return self.enc.encode_ordinary(text)

    def decode(self, ids):
        return self.enc.decode(ids)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model skeleton on the meta device, materialize it on the target device,
# then load the checkpoint weights.
with open(meta_path, "r") as f:
    meta = json.load(f)
cfg = GPTConfig(**meta["model_config"])
with torch.device("meta"):
    model = GPT(cfg)
model.to_empty(device=device)
model.init_weights()
state = torch.load(model_path, map_location=device)
# Checkpoints saved from a torch.compile'd model prefix keys with "_orig_mod.".
# Strip that prefix; str.lstrip would remove arbitrary leading characters, not the prefix.
state = {k.removeprefix("_orig_mod."): v for k, v in state.items()}
model.load_state_dict(state, strict=True, assign=True)
model.eval()

tokenizer = PklTokenizer(tok_local)


def chat_fn(history, temperature=0.8, top_k=50, max_new_tokens=256):
    """Render the chat history into nanochat's special-token format and sample a reply."""
    bos = tokenizer.get_bos_token_id()
    user_start = tokenizer.encode_special("<|user_start|>")
    user_end = tokenizer.encode_special("<|user_end|>")
    assistant_start = tokenizer.encode_special("<|assistant_start|>")
    assistant_end = tokenizer.encode_special("<|assistant_end|>")

    tokens = [bos]
    for message in history:
        content_ids = tokenizer.encode(message["content"])
        if message["role"] == "user":
            tokens += [user_start] + content_ids + [user_end]
        else:
            tokens += [assistant_start] + content_ids + [assistant_end]
    tokens += [assistant_start]  # open an assistant turn for the model to complete

    generated = []
    use_cuda = device.type == "cuda"
    dtype = torch.bfloat16 if use_cuda else torch.float32
    with torch.amp.autocast(device_type=("cuda" if use_cuda else "cpu"), dtype=dtype):
        for token in model.generate(tokens, max_tokens=max_new_tokens, temperature=temperature, top_k=top_k):
            # Stop at the end-of-turn marker (or a stray BOS).
            if token == assistant_end or token == bos:
                break
            generated.append(token)
    return tokenizer.decode(generated)


with gr.Blocks() as demo:
    gr.Markdown("# NanoChat MID")
    # "messages" mode expects a list of {"role", "content"} dicts, which matches
    # the history format chat_fn consumes.
    chat = gr.Chatbot(type="messages")
    msg = gr.Textbox()
    temp = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label="Temperature")
    topk = gr.Slider(1, 200, value=50, step=1, label="Top-k")
    max_toks = gr.Slider(16, 1024, value=256, step=16, label="Max new tokens")

    def respond(user_message, chat_history, temperature, top_k, max_new_tokens):
        chat_history = chat_history + [{"role": "user", "content": user_message}]
        reply = chat_fn(chat_history, temperature, top_k, max_new_tokens)
        chat_history = chat_history + [{"role": "assistant", "content": reply}]
        return "", chat_history

    msg.submit(respond, [msg, chat, temp, topk, max_toks], [msg, chat])

demo.launch()