File size: 3,766 Bytes
98926d5
 
 
 
 
 
 
 
 
2e62b02
 
 
98926d5
19642cc
 
 
 
98926d5
 
 
 
2e62b02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98926d5
 
2e62b02
 
 
 
 
 
 
 
 
 
 
 
98926d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e62b02
 
 
 
 
 
 
 
 
98926d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
# Default every cache/data directory to /tmp unless the environment already
# sets it. Keep this above the huggingface_hub import below: the hub library
# resolves its cache locations from these variables when it is imported.
for _env_name, _env_default in (
    ("HF_HOME", "/tmp/hf"),
    ("HF_HUB_CACHE", "/tmp/hf/hub"),
    ("TRANSFORMERS_CACHE", "/tmp/hf/transformers"),
    ("NANOCHAT_BASE_DIR", "/tmp/nanochat"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default

from huggingface_hub import hf_hub_download
import torch
import gradio as gr
import json
import pickle
from nanochat.gpt import GPT, GPTConfig

# Hardcoded model selection for this Space
MODEL_REPO = "loocorez/nanochat-mid-d20-step765"
STEP = "000765"  # zero-padded checkpoint step, exactly as it appears in the file names
DEPTH = "20"     # model depth; selects the d{DEPTH} checkpoint subfolder

ckpt_dir = f"/tmp/ckpt/d{DEPTH}"
os.makedirs(ckpt_dir, exist_ok=True)

# hf_hub_download recreates the repo-relative path under local_dir, so the
# tokenizer lands at /tmp/tokenizer/tokenizer.pkl and the checkpoint files
# under {ckpt_dir}/mid_checkpoints/d{DEPTH}/. The `local_dir_use_symlinks`
# argument is not passed: recent huggingface_hub releases deprecate and
# ignore it, emitting a warning on every call.
tok_local = hf_hub_download(MODEL_REPO, "tokenizer/tokenizer.pkl", local_dir="/tmp")

_ckpt_subdir = f"mid_checkpoints/d{DEPTH}"
model_path = hf_hub_download(MODEL_REPO, f"{_ckpt_subdir}/model_{STEP}.pt", local_dir=ckpt_dir)
meta_path = hf_hub_download(MODEL_REPO, f"{_ckpt_subdir}/meta_{STEP}.json", local_dir=ckpt_dir)

class PklTokenizer:
    """Thin tokenizer facade over a pickled encoder object.

    The unpickled object must provide ``encode_single_token``,
    ``encode_ordinary`` and ``decode``; those are the only methods used.

    NOTE(review): ``pickle.load`` can execute arbitrary code from the file.
    In this app the path comes from a download out of the hard-coded
    MODEL_REPO; never point this class at an untrusted file.
    """

    def __init__(self, pkl_path):
        with open(pkl_path, "rb") as handle:
            encoder = pickle.load(handle)
        self.enc = encoder
        # Looked up once here; chat_fn asks for the BOS id on every call.
        self._bos_id = encoder.encode_single_token("<|bos|>")

    def get_bos_token_id(self):
        """Return the cached id of the ``<|bos|>`` special token."""
        return self._bos_id

    def encode_special(self, text):
        """Return the single token id for a special-token string such as ``<|user_start|>``."""
        lookup = self.enc.encode_single_token
        return lookup(text)

    def encode(self, text):
        """Encode plain text into token ids via ``encode_ordinary``.

        Presumably special-token markup in ``text`` is treated as ordinary
        text (tiktoken's ``encode_ordinary`` semantics) — confirm for the
        pickled encoder in use.
        """
        encode_plain = self.enc.encode_ordinary
        return encode_plain(text)

    def decode(self, ids):
        """Convert a sequence of token ids back into text."""
        decoder = self.enc.decode
        return decoder(ids)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The model hyperparameters are stored next to the weights in the meta JSON.
with open(meta_path, "r", encoding="utf-8") as f:
    meta = json.load(f)
cfg = GPTConfig(**meta["model_config"])

# Build on the meta device (no storage allocated), then materialize empty
# tensors on the target device. NOTE(review): init_weights() presumably also
# sets up buffers that are not part of the checkpoint; the parameters it
# initializes are overwritten by load_state_dict below.
with torch.device("meta"):
    model = GPT(cfg)
model.to_empty(device=device)
model.init_weights()

state = torch.load(model_path, map_location=device)
# Checkpoints saved from a torch.compile'd module prefix every key with
# "_orig_mod.". str.lstrip would strip any leading run of those *characters*
# (mangling keys that begin with e.g. "o", "m" or "d"), so remove the exact
# prefix instead.
state = {k.removeprefix("_orig_mod."): v for k, v in state.items()}
if device.type == "cpu":
    # With assign=True the checkpoint tensors replace the parameters as-is.
    # bf16 tensors mixed with fp32 ones would otherwise hit dtype mismatches,
    # since chat_fn runs CPU inference without autocast. Upcast to fp32.
    state = {
        k: v.float() if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16 else v
        for k, v in state.items()
    }
model.load_state_dict(state, strict=True, assign=True)
model.eval()
tokenizer = PklTokenizer(tok_local)

def chat_fn(history, temperature=0.8, top_k=50, max_new_tokens=256):
    """Generate the assistant's reply to a conversation.

    Args:
        history: Iterable of ``(role, content)`` pairs in chronological order.
            A role of ``"user"`` is wrapped in user delimiters; every other
            role is wrapped in assistant delimiters.
        temperature: Sampling temperature passed to ``model.generate``.
        top_k: Top-k cutoff passed to ``model.generate``. ``None`` is passed
            through unchanged.
        max_new_tokens: Maximum number of tokens to request from ``generate``.

    Returns:
        The decoded reply. Generation stops before the first
        ``<|assistant_end|>`` or ``<|bos|>`` token, or when ``generate`` is
        exhausted.
    """
    from contextlib import nullcontext

    bos = tokenizer.get_bos_token_id()
    user_start = tokenizer.encode_special("<|user_start|>")
    user_end = tokenizer.encode_special("<|user_end|>")
    assistant_start = tokenizer.encode_special("<|assistant_start|>")
    assistant_end = tokenizer.encode_special("<|assistant_end|>")

    # Prompt layout: <|bos|>, then each turn wrapped in its role's delimiters,
    # ending with an open <|assistant_start|> for the model to continue.
    tokens = [bos]
    for role, content in history:
        if role == "user":
            tokens += [user_start] + tokenizer.encode(content) + [user_end]
        else:
            tokens += [assistant_start] + tokenizer.encode(content) + [assistant_end]
    tokens += [assistant_start]

    # UI widgets may hand back floats (e.g. 50.0). Coerce to plain Python
    # numbers so generate() receives integer counts.
    # NOTE(review): assumes generate() uses max_tokens/top_k as ints
    # (range()/topk-style APIs) — confirm against nanochat.
    temperature = float(temperature)
    top_k = None if top_k is None else int(top_k)
    max_new_tokens = int(max_new_tokens)

    # bf16 autocast is applied on CUDA only. CPU autocast does not accept
    # float32 as a target dtype (PyTorch warns and disables it on every call),
    # so CPU inference runs without autocast.
    if device.type == "cuda":
        autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
    else:
        autocast_ctx = nullcontext()

    generated = []
    with autocast_ctx:
        for token in model.generate(tokens, max_tokens=max_new_tokens, temperature=temperature, top_k=top_k):
            if token == assistant_end or token == bos:
                break
            generated.append(token)
    return tokenizer.decode(generated)

with gr.Blocks() as demo:
    gr.Markdown("# NanoChat MID")
    # Gradio accepts only "messages" or "tuples" here. "messages" stores the
    # history as {"role", "content"} dicts, so user and assistant turns render
    # in the correct bubbles.
    chat = gr.Chatbot(type="messages")
    msg = gr.Textbox()
    temp = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label="Temperature")
    topk = gr.Slider(1, 200, value=50, step=1, label="Top-k")
    max_toks = gr.Slider(16, 1024, value=256, step=16, label="Max new tokens")

    def respond(user_message, chat_history, temperature, top_k, max_new_tokens):
        """Append the user's message, generate a reply, and update the chat.

        Args:
            user_message: Text submitted from the textbox.
            chat_history: Current Chatbot value, a list of message dicts with
                "role" and "content" keys (may be None/empty at start).
            temperature, top_k, max_new_tokens: Slider values.

        Returns:
            A pair ``("", new_history)``. The empty string clears the textbox.
        """
        history = list(chat_history or [])
        history.append({"role": "user", "content": user_message})
        # chat_fn expects (role, content) pairs rather than message dicts.
        turns = [(m["role"], m["content"]) for m in history]
        reply = chat_fn(turns, float(temperature), int(top_k), int(max_new_tokens))
        history.append({"role": "assistant", "content": reply})
        return "", history

    msg.submit(respond, [msg, chat, temp, topk, max_toks], [msg, chat])

demo.launch()