update model

- evaluation.ipynb  +56 -12
- model.onnx        +2 -2
- weights.pb        +2 -2

evaluation.ipynb CHANGED
@@ -86,14 +86,18 @@
     "    input_ids = pad(input_ids, (0, pad_len), value=1)\n",
     "    ort_inputs = {\n",
     "        'input_ids': input_ids.detach().cpu().numpy(),\n",
-    "        'attention_mask': torch.ones(input_ids.shape).detach().cpu().numpy().astype('int64')\n",
+    "        'attention_mask': torch.cat([torch.ones(input_ids.shape), torch.ones([1, 1])], dim=-1).detach().cpu().numpy().astype('int64')\n",
     "    }\n",
+    "    for i in range(28):\n",
+    "        ort_inputs[\"past_key_values.{}.key\".format(i)] = np.zeros((1,16,1,256), dtype='float32')\n",
+    "        ort_inputs[\"past_key_values.{}.value\".format(i)] = np.zeros((1,16,1,256), dtype='float32')\n",
     "    predictions = session.run(None, ort_inputs)\n",
     "    outputs = torch.from_numpy(predictions[0])\n",
     "    last_token_logits = outputs[:, -2 - pad_len, :]\n",
     "    pred = last_token_logits.argmax(dim=-1)\n",
     "    total += label.size(0)\n",
     "    hit += (pred == label).sum().item()\n",
+    "\n",
     "acc = hit / total\n",
     "print('acc: ', acc)"
     ]
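The updated model is exported with a KV-cache signature, so every forward pass must supply past_key_values.{i}.key/.value for all 28 layers; the evaluation cell therefore feeds zero-filled dummy caches of shape (1, 16, 1, 256) and widens the attention mask by one position to cover that dummy past step. A minimal sketch of building those inputs, assuming the same layer count, head count, and head size as this export (the session path and helper name are illustrative):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx")  # the exported decoder from this repo

def build_ort_inputs(input_ids):
    # input_ids: int64 numpy array of shape (batch, seq_len)
    batch, seq_len = input_ids.shape
    inputs = {
        'input_ids': input_ids.astype('int64'),
        # one extra mask position covers the single dummy past step
        'attention_mask': np.ones((batch, seq_len + 1), dtype='int64'),
    }
    for i in range(28):  # 28 decoder layers in this export
        # dummy KV cache: (batch, 16 heads, 1 past step, head_dim 256), all zeros
        inputs['past_key_values.{}.key'.format(i)] = np.zeros((batch, 16, 1, 256), dtype='float32')
        inputs['past_key_values.{}.value'.format(i)] = np.zeros((batch, 16, 1, 256), dtype='float32')
    return inputs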
@@ -132,19 +136,59 @@
     "\n",
     "print(\"prompt: \", prompt)\n",
     "\n",
+    "total_time = 0.0\n",
+    "num_iter = 10\n",
+    "num_warmup = 3\n",
+    "\n",
     "# start\n",
-    "
-    "
+    "for idx in range(num_iter):\n",
+    "    text = []\n",
+    "    tic = time.time()\n",
+    "\n",
+    "    input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
+    "\n",
+    "    attention_mask = torch.ones(input_ids.shape[1] + 1)\n",
+    "    attention_mask[0] = 0\n",
+    "    attention_mask = attention_mask.unsqueeze(0)\n",
+    "\n",
     "    inp = {'input_ids': input_ids.detach().cpu().numpy(),\n",
-    "           'attention_mask': 
-    "
-    "
-    "
-    "
-    "
-    "
-    "
-    "
+    "           'attention_mask': attention_mask.detach().cpu().numpy().astype('int64')}\n",
+    "    for i in range(28):\n",
+    "        inp[\"past_key_values.{}.key\".format(i)] = torch.zeros([1,16,1,256]).detach().cpu().numpy()\n",
+    "        inp[\"past_key_values.{}.value\".format(i)] = torch.zeros([1,16,1,256]).detach().cpu().numpy()\n",
+    "\n",
+    "    for step in range(32):\n",
+    "\n",
+    "        output = session.run(None, inp)\n",
+    "        logits = output[0]\n",
+    "        logits = torch.from_numpy(logits)\n",
+    "        next_token_logits = logits[:, -1, :]\n",
+    "        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)\n",
+    "        next_tokens = torch.argmax(probs, dim=-1)\n",
+    "        present_kv = output[1]\n",
+    "        for i in range(28):\n",
+    "\n",
+    "            if step == 0:\n",
+    "                inp[\"past_key_values.{}.key\".format(i)] = output[2*i+1][:, :, 1:, :]\n",
+    "                inp[\"past_key_values.{}.value\".format(i)] = output[2*i+2][:, :, 1:, :]\n",
+    "            else:\n",
+    "                inp[\"past_key_values.{}.key\".format(i)] = output[2*i+1]\n",
+    "                inp[\"past_key_values.{}.value\".format(i)] = output[2*i+2]\n",
+    "\n",
+    "        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)\n",
+    "        if step == 0:\n",
+    "            attention_mask = torch.cat([attention_mask[:, 1:], torch.ones([1, 1])], dim=-1)\n",
+    "        else:\n",
+    "            attention_mask = torch.cat([attention_mask, torch.ones([1, 1])], dim=-1)\n",
+    "\n",
+    "        inp['attention_mask'] = attention_mask.detach().cpu().numpy().astype('int64')\n",
+    "        inp['input_ids'] = input_ids[:, -1:].detach().cpu().numpy()\n",
+    "\n",
+    "    print(tokenizer.decode(input_ids[0]))\n",
+    "    toc = time.time()\n",
+    "    if idx >= num_warmup:\n",
+    "        total_time += (toc - tic)\n",
+    "print(\"Inference latency: %.3f s.\" % (total_time / (num_iter - num_warmup)))"
     ]
     }
    ],
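Pulled out of the notebook JSON, the new benchmark cell amounts to a greedy KV-cache decoding loop: prime the model with a zero-filled dummy cache (masked out via attention_mask[0] = 0), then after the prompt pass slice the dummy step off the returned present tensors and feed only the last generated token on each subsequent step. A standalone sketch under the same assumptions (28 layers, 16 heads, head_dim 256; the session and tokenizer paths are placeholders):

import numpy as np
import torch
import onnxruntime as ort
from transformers import AutoTokenizer

# Placeholders: point these at the exported model and its matching tokenizer.
session = ort.InferenceSession("model.onnx")
tokenizer = AutoTokenizer.from_pretrained("tokenizer_dir")

def greedy_generate(prompt, max_new_tokens=32):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # Mask position 0 (the zero-filled dummy past step) until real KV exists.
    attention_mask = torch.ones(input_ids.shape[1] + 1)
    attention_mask[0] = 0
    attention_mask = attention_mask.unsqueeze(0)

    inp = {'input_ids': input_ids.detach().cpu().numpy(),
           'attention_mask': attention_mask.detach().cpu().numpy().astype('int64')}
    for i in range(28):
        inp['past_key_values.{}.key'.format(i)] = np.zeros((1, 16, 1, 256), dtype='float32')
        inp['past_key_values.{}.value'.format(i)] = np.zeros((1, 16, 1, 256), dtype='float32')

    for step in range(max_new_tokens):
        output = session.run(None, inp)
        logits = torch.from_numpy(output[0])
        next_tokens = torch.argmax(logits[:, -1, :], dim=-1)

        for i in range(28):
            key, value = output[2 * i + 1], output[2 * i + 2]
            if step == 0:
                # Drop the dummy step now that the prompt pass produced real KV.
                key, value = key[:, :, 1:, :], value[:, :, 1:, :]
            inp['past_key_values.{}.key'.format(i)] = key
            inp['past_key_values.{}.value'.format(i)] = value

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if step == 0:
            # Replace the dummy mask slot with the newly generated token's slot.
            attention_mask = torch.cat([attention_mask[:, 1:], torch.ones([1, 1])], dim=-1)
        else:
            attention_mask = torch.cat([attention_mask, torch.ones([1, 1])], dim=-1)
        inp['attention_mask'] = attention_mask.detach().cpu().numpy().astype('int64')
        inp['input_ids'] = input_ids[:, -1:].detach().cpu().numpy()

    return tokenizer.decode(input_ids[0])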
model.onnx CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size 
+oid sha256:99af1fc6a93e6b02902f3f4c3fe32bf3d7bb4441406bee3bf0cbceaa5b9f64e3
+size 6332176
weights.pb CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size 
+oid sha256:e9641d64847996acc53c7093cf4ff9c02443b9c4fd61699cb9ac00b86861c528
+size 6057661312
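model.onnx and weights.pb are Git LFS pointer files: the repository tracks only each blob's SHA-256 (the oid) and byte size, while the payload itself (the new weights.pb is about 6 GB) lives in LFS storage. A quick way to check a downloaded blob against its pointer, using the new oid and size from this commit (file path assumed to be local):

import hashlib
import os

def verify_lfs_blob(path, expected_oid, expected_size):
    # Compare the byte size first (cheap), then the streaming SHA-256 digest.
    if os.path.getsize(path) != expected_size:
        return False
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            digest.update(chunk)
    return digest.hexdigest() == expected_oid

# Values taken from the updated weights.pb pointer above.
print(verify_lfs_blob('weights.pb',
                      'e9641d64847996acc53c7093cf4ff9c02443b9c4fd61699cb9ac00b86861c528',
                      6057661312))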