from pathlib import Path

import streamlit as st
import streamlit.components.v1 as components
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding
root_dir = Path(__file__).resolve().parent
highlighted_text_component = components.declare_component(
    "highlighted_text", path=root_dir / "highlighted_text" / "build"
)
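
# Build one model input per target position: the window starting at j covers
# tokens [j, j + window_len); windows that run past the end of the sequence are
# padded (with pad_id for input_ids, 0 for the other keys) to keep the batch rectangular.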
def get_windows_batched(examples: BatchEncoding, window_len: int, stride: int = 1, pad_id: int = 0) -> BatchEncoding:
    return BatchEncoding({
        k: [
            t[i][j : j + window_len] + [
                pad_id if k == "input_ids" else 0
            ] * (j + window_len - len(t[i]))
            for i in range(len(examples["input_ids"]))
            for j in range(0, len(examples["input_ids"][i]) - 1, stride)
        ]
        for k, t in examples.items()
    })
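
# chr(0xfffd) is the Unicode replacement character, which tokenizer.decode emits
# while a multi-byte character is still split across several tokens. The helper
# below buffers ids until they decode cleanly and attaches the decoded text to
# the last token of each such group, emitting "" for the earlier ones.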
BAD_CHAR = chr(0xfffd)

def ids_to_readable_tokens(tokenizer, ids, strip_whitespace=False):
    cur_ids = []
    result = []
    for idx in ids:
        cur_ids.append(idx)
        decoded = tokenizer.decode(cur_ids)
        if BAD_CHAR not in decoded:
            if strip_whitespace:
                decoded = decoded.strip()
            result.append(decoded)
            del cur_ids[:]
        else:
            result.append("")
    return result
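
# Passing ?compact=true in the URL hides the title and header links, which is
# presumably useful when embedding the app in another page.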
compact_layout = st.experimental_get_query_params().get("compact", ["false"]) == ["true"]

if not compact_layout:
    st.title("Context length probing")
    st.markdown(
        """[📃 Paper](https://arxiv.org/abs/2212.14815) |
        [🌍 Website](https://cifkao.github.io/context-probing) |
        [🧑‍💻 Code](https://github.com/cifkao/context-probing)
        """
    )
model_name = st.selectbox("Model", ["distilgpt2", "gpt2", "EleutherAI/gpt-neo-125m"])
metric_name = st.selectbox("Metric", ["KL divergence", "Cross entropy"], index=1)

tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name, use_fast=False)
# Make sure the logprobs do not use up more than ~4 GB of memory
MAX_MEM = 4e9 / (torch.finfo(torch.float16).bits / 8)
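# Note that MAX_MEM is a budget in float16 *elements*, not bytes (4e9 bytes
# divided by 2 bytes per element), since the tensors bounded below hold float16 logprobs.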
# Select window lengths such that we are allowed to fill the whole window without running out of memory
# (otherwise the window length is irrelevant)
window_len_options = [
    w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
    if w == 8 or w * (2 * w) * tokenizer.vocab_size <= MAX_MEM
]
window_len = st.select_slider(
    r"Window size ($c_\text{max}$)",
    options=window_len_options,
    value=min(128, window_len_options[-1])
)
# Now figure out how many tokens we are allowed to use:
# window_len * (num_tokens + window_len) * vocab_size <= MAX_MEM
max_tokens = int(MAX_MEM / (tokenizer.vocab_size * window_len) - window_len)
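# For example, with GPT-2's vocabulary of 50257 tokens and window_len = 128,
# this works out to int(2e9 / (50257 * 128) - 128) ≈ 182 tokens.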
DEFAULT_TEXT = """
We present context length probing, a novel explanation technique for causal
language models, based on tracking the predictions of a model as a function of the length of
available context, and allowing to assign differential importance scores to different contexts.
The technique is model-agnostic and does not rely on access to model internals beyond computing
token-level probabilities. We apply context length probing to large pre-trained language models
and offer some initial analyses and insights, including the potential for studying long-range
dependencies.
""".replace("\n", " ").strip()

text = st.text_area(
    f"Input text (≤{max_tokens} tokens)",
    DEFAULT_TEXT,
)
inputs = tokenizer([text])
[input_ids] = inputs["input_ids"]

if len(input_ids) < 2:
    st.error("Please enter at least 2 tokens.", icon="🚨")
    st.stop()
if len(input_ids) > max_tokens:
    st.error(
        f"Your input has {len(input_ids)} tokens. Please enter at most {max_tokens} tokens "
        f"or try reducing the window size.",
        icon="🚨"
    )
    st.stop()
if metric_name == "KL divergence":
    st.error("KL divergence is not supported yet. Stay tuned!", icon="😭")
    st.stop()
with st.spinner("Loading model…"):
    model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)

window_len = min(window_len, len(input_ids))
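
# The two functions below are cached with st.cache_data. Streamlit skips hashing
# arguments whose names start with an underscore, so the unhashable model and
# inputs are passed as _model/_inputs, and a separate, hashable cache_key argument
# identifies each computation instead (hence the immediate `del cache_key`).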
@st.cache_data(show_spinner=False)
def get_logprobs(_model, _inputs, cache_key):
    del cache_key
    # No gradients are needed at inference time, so avoid building the autograd graph
    with torch.no_grad():
        return _model(**_inputs).logits.log_softmax(dim=-1).to(torch.float16)
@st.cache_data(show_spinner=False)
def run_context_length_probing(_model, _tokenizer, _inputs, window_len, cache_key):
    del cache_key
    inputs_sliding = get_windows_batched(
        _inputs,
        window_len=window_len,
        pad_id=_tokenizer.eos_token_id
    ).convert_to_tensors("pt")
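    # One window per target position: len(input_ids) - 1 windows, each padded to
    # window_len tokens, so input_ids here has shape (num_windows, window_len).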
    logprobs = []
    with st.spinner("Running model…"):
        batch_size = 8
        num_items = len(inputs_sliding["input_ids"])
        pbar = st.progress(0)
        for i in range(0, num_items, batch_size):
            pbar.progress(i / num_items, f"{i}/{num_items}")
            batch = {k: v[i:i + batch_size] for k, v in inputs_sliding.items()}
            logprobs.append(
                get_logprobs(
                    _model,
                    batch,
                    cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
                )
            )
        logprobs = torch.cat(logprobs, dim=0)
        pbar.empty()
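    # logprobs has shape (num_windows, window_len, vocab_size): logprobs[j, k] is
    # the model's next-token distribution after seeing tokens j through j + k.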
    with st.spinner("Computing scores…"):
        # Line up predictions for the same target token along dim 0 (the context
        # length axis): after the permute, row k holds the predictions made at
        # offset k within each window; padding with window_len NaN columns and
        # reshaping shifts row k right by k positions.
        logprobs = logprobs.permute(1, 0, 2)
        logprobs = F.pad(logprobs, (0, 0, 0, window_len, 0, 0), value=torch.nan)
        logprobs = logprobs.view(-1, logprobs.shape[-1])[:-window_len]
        logprobs = logprobs.view(window_len, len(input_ids) + window_len - 2, logprobs.shape[-1])

        # Look up the log-probability assigned to each actual target token, then
        # take differences between consecutive context lengths: a large jump means
        # the newly added context token mattered for that prediction.
        scores = logprobs[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
        scores = scores.diff(dim=0).transpose(0, 1)
        scores = scores.nan_to_num()
        # Normalize each token's scores into [-1, 1] for display
        scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-6
        scores = scores.to(torch.float16)

    return scores
scores = run_context_length_probing(
    _model=model,
    _tokenizer=tokenizer,
    _inputs=inputs,
    window_len=window_len,
    cache_key=(model_name, text),
)
tokens = ids_to_readable_tokens(tokenizer, input_ids)

st.markdown('<label style="font-size: 14px;">Output</label>', unsafe_allow_html=True)
highlighted_text_component(tokens=tokens, scores=scores.tolist())