More efficient NLL implementation
app.py
CHANGED
@@ -17,7 +17,7 @@ def get_windows_batched(examples: BatchEncoding, window_len: int, stride: int =
     return BatchEncoding({
         k: [
             t[i][j : j + window_len] + [
-                pad_id if k
+                pad_id if k in ["input_ids", "labels"] else 0
             ] * (j + window_len - len(t[i]))
             for i in range(len(examples["input_ids"]))
             for j in range(0, len(examples["input_ids"][i]) - 1, stride)
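The change above pads every windowed field consistently: once the `labels` key is added further down in this diff, it is padded with `pad_id` just like `input_ids`, while other keys (e.g. `attention_mask`) keep 0. Below is a minimal, self-contained sketch of that windowing-and-padding behaviour using plain dicts and toy token ids instead of the app's real `BatchEncoding` inputs; the outer `for k, t in examples.items()` loop is assumed from context and not shown in the hunk.

    # Simplified sketch (plain dict instead of BatchEncoding, toy values).
    def get_windows_sketch(examples: dict, window_len: int, stride: int = 1, pad_id: int = 0) -> dict:
        return {
            k: [
                t[i][j : j + window_len] + [
                    # After this commit: labels are padded with pad_id too, not just input_ids.
                    pad_id if k in ["input_ids", "labels"] else 0
                ] * (j + window_len - len(t[i]))
                for i in range(len(examples["input_ids"]))
                for j in range(0, len(examples["input_ids"][i]) - 1, stride)
            ]
            for k, t in examples.items()
        }

    examples = {
        "input_ids": [[10, 11, 12, 13]],
        "labels": [[11, 12, 13, 50256]],      # next-token targets, EOS at the end
        "attention_mask": [[1, 1, 1, 1]],
    }
    windows = get_windows_sketch(examples, window_len=3, pad_id=50256)
    print(windows["labels"])          # [[11, 12, 13], [12, 13, 50256], [13, 50256, 50256]]
    print(windows["attention_mask"])  # [[1, 1, 1], [1, 1, 1], [1, 1, 0]]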
@@ -43,7 +43,10 @@ def ids_to_readable_tokens(tokenizer, ids, strip_whitespace=False):
     return result
 
 def nll_score(logprobs, labels):
-
+    if logprobs.shape[-1] == 1:
+        return -logprobs.squeeze(-1)
+    else:
+        return -logprobs[:, torch.arange(len(labels)), labels]
 
 def kl_div_score(logprobs):
     log_p = logprobs[
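The reworked `nll_score` accepts either full-vocabulary log-probabilities (the KL-divergence path) or log-probabilities already reduced to a single column per position (the cheap NLL path added near the bottom of this diff). A small self-contained check with dummy tensors and toy shapes, just to illustrate that the two branches agree:

    import torch

    def nll_score(logprobs, labels):
        # Last dim of size 1 means the label log-probs were already picked out upstream.
        if logprobs.shape[-1] == 1:
            return -logprobs.squeeze(-1)
        else:
            return -logprobs[:, torch.arange(len(labels)), labels]

    labels = torch.randint(0, 11, (5,))                       # toy labels: 5 positions, vocab of 11
    full = torch.randn(2, 5, 11).log_softmax(dim=-1)          # 2 windows, full distributions
    gathered = full[:, torch.arange(5), labels][..., None]    # same values, one column per position
    assert torch.allclose(nll_score(full, labels), nll_score(gathered, labels))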
@@ -75,8 +78,18 @@ if not compact_layout:
     """
     )
 
+generation_mode = False
+# st.radio("Mode", ["Standard", "Generation"], horizontal=True) == "Generation"
+# st.caption(
+#     "In standard mode, we analyze the model's predictions on the input text. "
+#     "In generation mode, we generate a continuation of the input text "
+#     "and visualize the contributions of different contexts to each generated token."
+# )
+
 model_name = st.selectbox("Model", ["distilgpt2", "gpt2", "EleutherAI/gpt-neo-125m"])
-metric_name = st.
+metric_name = st.radio(
+    "Metric", (["KL divergence"] if not generation_mode else []) + ["NLL loss"], index=0, horizontal=True
+)
 
 tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name, use_fast=False)
 
@@ -84,9 +97,10 @@ tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)
 MAX_MEM = 4e9 / (torch.finfo(torch.float16).bits / 8)
 # Select window lengths such that we are allowed to fill the whole window without running out of memory
 # (otherwise the window length is irrelevant)
+logprobs_dim = tokenizer.vocab_size if metric_name == "KL divergence" else 1
 window_len_options = [
     w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
-    if w == 8 or w * (2 * w) *
+    if w == 8 or w * (2 * w) * logprobs_dim <= MAX_MEM
 ]
 window_len = st.select_slider(
     r"Window size ($c_\text{max}$)",
@@ -95,7 +109,8 @@ window_len = st.select_slider(
 )
 # Now figure out how many tokens we are allowed to use:
 # window_len * (num_tokens + window_len) * vocab_size <= MAX_MEM
-max_tokens = int(MAX_MEM / (
+max_tokens = int(MAX_MEM / (logprobs_dim * window_len) - window_len)
+max_tokens = min(max_tokens, 2048)
 
 DEFAULT_TEXT = """
 We present context length probing, a novel explanation technique for causal
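The new `logprobs_dim` is what makes the NLL mode cheap: the budget `window_len * (num_tokens + window_len) * logprobs_dim <= MAX_MEM` no longer has to reserve the full vocabulary dimension when only NLL is requested (the inline comment still says `vocab_size`, but the code now budgets with `logprobs_dim`). A back-of-the-envelope check, assuming a GPT-2-sized vocabulary of 50257:

    MAX_MEM = 4e9 / 2            # ≈ 2e9 fp16 elements, i.e. about 4 GB
    VOCAB_SIZE = 50257           # assumed GPT-2 vocabulary size

    def allowed_max_tokens(window_len: int, logprobs_dim: int) -> int:
        # Solve window_len * (num_tokens + window_len) * logprobs_dim <= MAX_MEM for num_tokens.
        return int(MAX_MEM / (logprobs_dim * window_len) - window_len)

    # The same budget gates the window-size slider: w * (2 * w) * logprobs_dim <= MAX_MEM.
    print([w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
           if w == 8 or w * (2 * w) * VOCAB_SIZE <= MAX_MEM])  # [8, 16, 32, 64, 128] for KL divergence
    print(allowed_max_tokens(128, VOCAB_SIZE))                 # 182 tokens with KL divergence
    print(allowed_max_tokens(128, 1))                          # ~15.6M with NLL, clamped to 2048 by the app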
@@ -117,6 +132,7 @@ if tokenizer.eos_token:
     text += tokenizer.eos_token
 inputs = tokenizer([text])
 [input_ids] = inputs["input_ids"]
+inputs["labels"] = [[*input_ids[1:], tokenizer.eos_token_id]]
 num_user_tokens = len(input_ids) - (1 if tokenizer.eos_token else 0)
 
 if num_user_tokens < 1:
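The new `labels` entry is simply `input_ids` shifted left by one position, with `eos_token_id` appended so the final position still has a target. A toy illustration with hypothetical token ids (assuming GPT-2's EOS id of 50256):

    eos_token_id = 50256                     # GPT-2's EOS id, used as a stand-in here
    input_ids = [464, 3290, 318, 257]        # hypothetical token ids for the user text
    labels = [*input_ids[1:], eos_token_id]  # next-token target for every position
    print(labels)                            # [3290, 318, 257, 50256]

In the app the list is additionally wrapped in a batch dimension (`[[...]]`), since `inputs` holds a single sequence.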
@@ -160,13 +176,17 @@ def run_context_length_probing(_model, _tokenizer, _inputs, window_len, metric,
     for i in range(0, num_items, batch_size):
         pbar.progress(i / num_items, f"{i}/{num_items}")
         batch = {k: v[i:i + batch_size] for k, v in inputs_sliding.items()}
-
-
-
-
-            cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
-        )
+        batch_logprobs = get_logprobs(
+            _model,
+            batch,
+            cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
         )
+        batch_labels = batch["labels"]
+        if metric != "KL divergence":
+            batch_logprobs = torch.gather(
+                batch_logprobs, dim=-1, index=batch_labels[..., None]
+            )
+        logprobs.append(batch_logprobs)
     logprobs = torch.cat(logprobs, dim=0)
     pbar.empty()
 
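For the NLL metric, the loop now keeps only the log-probability of the gold label at each position instead of the whole distribution, which is where the memory saving (and the `logprobs_dim = 1` budget above) comes from. A standalone sketch of that `torch.gather` step, with dummy shapes rather than the app's real batches:

    import torch

    batch_size, window_len, vocab_size = 4, 8, 11       # toy dimensions
    batch_logprobs = torch.randn(batch_size, window_len, vocab_size).log_softmax(dim=-1)
    batch_labels = torch.randint(0, vocab_size, (batch_size, window_len))

    # Keep only the log-prob of the gold label at every position: (B, W, V) -> (B, W, 1).
    picked = torch.gather(batch_logprobs, dim=-1, index=batch_labels[..., None])
    print(picked.shape)                                  # torch.Size([4, 8, 1])

    # Same values as indexing the full distribution, but stored at 1/vocab_size the size.
    assert torch.allclose(
        picked.squeeze(-1),
        batch_logprobs[torch.arange(batch_size)[:, None], torch.arange(window_len), batch_labels],
    )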