Adjust limit, turn off caching for logprobs
app.py
CHANGED
@@ -96,11 +96,12 @@ tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)
 # Make sure the logprobs do not use up more than ~4 GB of memory
 MAX_MEM = 4e9 / (torch.finfo(torch.float16).bits / 8)
 # Select window lengths such that we are allowed to fill the whole window without running out of memory
-# (otherwise the window length is irrelevant)
-
+# (otherwise the window length is irrelevant); if using NLL, memory is not a consideration, but we want
+# to limit runtime
+multiplier = tokenizer.vocab_size if metric_name == "KL divergence" else 16384  # arbitrary number
 window_len_options = [
     w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
-    if w == 8 or w * (2 * w) * tokenizer.vocab_size <= MAX_MEM
+    if w == 8 or w * (2 * w) * multiplier <= MAX_MEM
 ]
 window_len = st.select_slider(
     r"Window size ($c_\text{max}$)",
@@ -109,8 +110,7 @@ window_len = st.select_slider(
 )
 # Now figure out how many tokens we are allowed to use:
 # window_len * (num_tokens + window_len) * vocab_size <= MAX_MEM
-max_tokens = int(MAX_MEM / (tokenizer.vocab_size * window_len) - window_len)
-max_tokens = min(max_tokens, 2048)
+max_tokens = int(MAX_MEM / (multiplier * window_len) - window_len)

 DEFAULT_TEXT = """
 We present context length probing, a novel explanation technique for causal
@@ -151,10 +151,8 @@ with st.spinner("Loading model…"):

 window_len = min(window_len, len(input_ids))

-@st.cache_data(show_spinner=False)
 @torch.inference_mode()
-def get_logprobs(_model, _inputs, cache_key):
-    del cache_key
+def get_logprobs(_model, _inputs):
     return _model(**_inputs).logits.log_softmax(dim=-1).to(torch.float16)

 @st.cache_data(show_spinner=False)
@@ -179,7 +177,7 @@ def run_context_length_probing(_model, _tokenizer, _inputs, window_len, metric,
         batch_logprobs = get_logprobs(
             _model,
             batch,
-            cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
+            #cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
         )
         batch_labels = batch["labels"]
         if metric != "KL divergence":
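
Note (illustrative, not part of the commit): the memory budget above counts float16 elements, so MAX_MEM works out to 2e9 entries. A standalone sketch of the same arithmetic, assuming a GPT-2-sized vocabulary of 50257 (the actual model and vocabulary size depend on how the Space is configured):

import torch

# ~4 GB expressed as a count of float16 elements (2 bytes each) -> 2e9 entries
MAX_MEM = 4e9 / (torch.finfo(torch.float16).bits / 8)

vocab_size = 50257           # assumption: GPT-2-sized vocabulary
multiplier = vocab_size      # KL divergence case; the commit uses 16384 for NLL

# A fully filled window of length w needs about w * (2 * w) * multiplier logprob entries
window_len_options = [
    w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
    if w == 8 or w * (2 * w) * multiplier <= MAX_MEM
]
print(window_len_options)    # [8, 16, 32, 64, 128] with the numbers above

# Invert window_len * (num_tokens + window_len) * multiplier <= MAX_MEM for num_tokens
window_len = window_len_options[-1]
max_tokens = int(MAX_MEM / (multiplier * window_len) - window_len)
print(max_tokens)            # ~182 input tokens for a 128-token window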
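
Relatedly, get_logprobs casts the log-softmax output to float16 before returning it, which is why MAX_MEM is counted in float16 entries. A rough sketch of the footprint of one window (the shapes here are illustrative, not taken from app.py):

import torch

logits = torch.randn(1, 128, 50257)                       # (batch, window, vocab), float32
logprobs = logits.log_softmax(dim=-1).to(torch.float16)   # halves the storage per entry
print(logprobs.element_size() * logprobs.nelement())      # ~12.9 MB instead of ~25.7 MB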
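
Finally, the cache_key parameter removed by this commit was a workaround for st.cache_data, which skips hashing any argument whose name starts with an underscore; the hashable identity then has to be passed separately. A minimal sketch of that pattern (slow_sum and its arguments are made-up names, not part of app.py):

import streamlit as st

@st.cache_data(show_spinner=False)
def slow_sum(_values, cache_key):
    # `_values` is excluded from Streamlit's cache hashing because of the leading
    # underscore, so `cache_key` must uniquely identify the input on its own.
    del cache_key
    return sum(_values)

values = [1, 2, 3]
result = slow_sum(values, cache_key=tuple(values))

With caching turned off for get_logprobs, no such key is needed, which is why both the parameter and the cache_key= argument at the call site are dropped (the latter left behind as a comment).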