burtenshaw HF Staff commited on
Commit
c183265
·
verified ·
1 Parent(s): 731e4f3

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +27 -1
README.md CHANGED
@@ -12,7 +12,33 @@ It was trained with a depth of 20 on 2 billion tokens and corresponds to this [t
12
 
13
  ## Usage
14
 
15
- coming...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  ## Base model evaluation
 
12
 
13
  ## Usage
14
 
15
+ ```python
16
+ from transformers import AutoConfig, AutoModel, AutoTokenizer
17
+ import torch
18
+
19
+ model_dir = "nanochat-students/base-d20"
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ model = AutoModel.from_pretrained(model_dir, trust_remote_code=True)
22
+ model = model.to(device)
23
+ model.eval()
24
+
25
+ tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
26
+
27
+ prompt = "The capital of Belgium is "
28
+ input_ids = tokenizer.encode(prompt, prepend=tokenizer.get_bos_token_id())
29
+ ids = torch.tensor([input_ids], dtype=torch.long, device=device)
30
+
31
+ max_new_tokens = 50
32
+ with torch.inference_mode():
33
+ for _ in range(max_new_tokens):
34
+ outputs = model(input_ids=ids)
35
+ logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
36
+ next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
37
+ ids = torch.cat([ids, next_token], dim=1)
38
+
39
+ decoded = tokenizer.decode(ids[0].tolist())
40
+ print(decoded)
41
+ ```
42
 
43
 
44
  ## Base model evaluation