Spaces:
Running
Running
Fix sneaky problem caused by caching
Browse files
app.py
CHANGED
|
@@ -233,6 +233,11 @@ def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
|
|
| 233 |
model_kwargs = dict(use_cache=True)
|
| 234 |
for i in range(max_steps):
|
| 235 |
pbar.progress(i / max_steps, f"{i}/{max_steps}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
| 237 |
model_outputs = model(**model_inputs)
|
| 238 |
model_kwargs = model._update_model_kwargs_for_generation(model_outputs, model_kwargs, is_encoder_decoder=False)
|
|
@@ -250,7 +255,6 @@ def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
|
|
| 250 |
input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=1)
|
| 251 |
if input_ids.shape[1] > window_len:
|
| 252 |
input_ids = input_ids[:, 1:]
|
| 253 |
-
model_kwargs.update(use_cache=False, past_key_values=None)
|
| 254 |
if logprobs_window.shape[0] == window_len:
|
| 255 |
logprobs.append(
|
| 256 |
logprobs_window[torch.arange(window_len), input_ids.squeeze(0)]
|
|
|
|
| 233 |
model_kwargs = dict(use_cache=True)
|
| 234 |
for i in range(max_steps):
|
| 235 |
pbar.progress(i / max_steps, f"{i}/{max_steps}")
|
| 236 |
+
|
| 237 |
+
if input_ids.shape[1] == window_len:
|
| 238 |
+
model_kwargs.update(use_cache=False)
|
| 239 |
+
if "past_key_values" in model_kwargs:
|
| 240 |
+
del model_kwargs["past_key_values"]
|
| 241 |
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
| 242 |
model_outputs = model(**model_inputs)
|
| 243 |
model_kwargs = model._update_model_kwargs_for_generation(model_outputs, model_kwargs, is_encoder_decoder=False)
|
|
|
|
| 255 |
input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=1)
|
| 256 |
if input_ids.shape[1] > window_len:
|
| 257 |
input_ids = input_ids[:, 1:]
|
|
|
|
| 258 |
if logprobs_window.shape[0] == window_len:
|
| 259 |
logprobs.append(
|
| 260 |
logprobs_window[torch.arange(window_len), input_ids.squeeze(0)]
|