Spaces:
Running
Running
Allow overwrite
Browse files
app.py
CHANGED
|
@@ -316,10 +316,10 @@ def run_context_length_probing(
|
|
| 316 |
logprobs = torch.cat([unigram_logprobs.unsqueeze(0), logprobs], dim=0)
|
| 317 |
|
| 318 |
if metric == "NLL loss":
|
| 319 |
-
scores = nll_score(logprobs=logprobs, labels=label_ids)
|
| 320 |
elif metric == "KL divergence":
|
| 321 |
-
scores = kl_div_score(logprobs, labels=label_ids)
|
| 322 |
-
del logprobs # possibly
|
| 323 |
|
| 324 |
scores = (-scores).diff(dim=0).transpose(0, 1)
|
| 325 |
scores = scores.nan_to_num()
|
|
|
|
| 316 |
logprobs = torch.cat([unigram_logprobs.unsqueeze(0), logprobs], dim=0)
|
| 317 |
|
| 318 |
if metric == "NLL loss":
|
| 319 |
+
scores = nll_score(logprobs=logprobs, labels=label_ids, allow_overwrite=True)
|
| 320 |
elif metric == "KL divergence":
|
| 321 |
+
scores = kl_div_score(logprobs, labels=label_ids, allow_overwrite=True)
|
| 322 |
+
del logprobs # possibly overwritten by the score computation to save memory
|
| 323 |
|
| 324 |
scores = (-scores).diff(dim=0).transpose(0, 1)
|
| 325 |
scores = scores.nan_to_num()
|