Orion Weller committed
Commit · a09b56d · 1 Parent(s): 56649db

saliency maps
Files changed:
- .gitignore +3 -1
- analysis.py +93 -1
- app.py +88 -11
- dataset_loading.py +11 -2
- requirements.txt +3 -1
.gitignore CHANGED
@@ -1,3 +1,5 @@
 datasets/
 __pycache__/
-env/
+env/
+.ipynb_checkpoints/
+*.ipynb
analysis.py CHANGED
@@ -1,8 +1,21 @@
 import pandas as pd
 import numpy as np
+import os
+import torch
+from transformers import pipeline
+import streamlit as st
+
 import plotly.express as px
 import plotly.figure_factory as ff
 
+from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization
+from captum.attr import visualization as viz
+from captum import attr
+from captum.attr._utils.visualization import format_word_importances, format_special_tokens, _get_color
+
+
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+
 
 def results_to_df(results: dict, metric_name: str):
     metric_scores = []
@@ -38,4 +51,83 @@ def create_boxplot_diff(results1, results2, metric_name):
 
     x_axis = f"Difference in {metric_name} from 1 to 2"
     fig = px.histogram(pd.DataFrame({x_axis: diff}), x=x_axis, marginal="box")
-    return fig
+    return fig
+
+
+def summarize_attributions(attributions):
+    attributions = attributions.sum(dim=-1).squeeze(0)
+    attributions = attributions / torch.norm(attributions)
+    return attributions
+
+
+def get_words(words, importances):
+    words_colored = []
+    for word, importance in zip(words, importances[: len(words)]):
+        word = format_special_tokens(word)
+        color = _get_color(importance)
+        unwrapped_tag = '<span style="background-color: {color}; opacity:1.0; line-height:1.75">{word}</span>'.format(
+            color=color, word=word
+        )
+        words_colored.append(unwrapped_tag)
+    return words_colored
+
+@st.cache_resource
+def get_model(model_name: str):
+    if model_name == "MonoT5":
+        pipe = pipeline('text2text-generation',
+                        model='castorini/monot5-small-msmarco-10k',
+                        tokenizer='castorini/monot5-small-msmarco-10k',
+                        device='cpu')
+        def formatter(query, doc):
+            return f"Query: {query} Document: {doc} Relevant:"
+
+    return pipe, formatter
+
+def prep_func(pipe, formatter):
+    # variables that only need to be run once
+    decoder_input_ids = pipe.tokenizer(["<pad>"], return_tensors="pt", add_special_tokens=False, truncation=True).input_ids.to('cpu')
+    decoder_embedding_layer = pipe.model.base_model.decoder.embed_tokens
+    decoder_inputs_emb = decoder_embedding_layer(decoder_input_ids)
+
+    token_false_id = pipe.tokenizer.get_vocab()['▁false']
+    token_true_id = pipe.tokenizer.get_vocab()["▁true"]
+
+    # this function needs to be run for each combination
+    @st.cache_data
+    def get_saliency(query, doc):
+        input_ids = pipe.tokenizer(
+            [formatter(query, doc)],
+            padding=False,
+            truncation=True,
+            return_tensors="pt",
+            max_length=pipe.tokenizer.model_max_length,
+        )["input_ids"].to('cpu')
+
+        embedding_layer = pipe.model.base_model.encoder.embed_tokens
+        inputs_emb = embedding_layer(input_ids)
+
+        def forward_from_embeddings(inputs_embeds, decoder_inputs_embeds):
+            logits = pipe.model.forward(inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds)['logits'][:, -1, :]
+            batch_scores = logits[:, [token_false_id, token_true_id]]
+            batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
+            scores = batch_scores[:, 1].exp()  # relevant token
+            return scores
+
+        lig = attr.Saliency(forward_from_embeddings)
+        attributions_ig, delta = lig.attribute(
+            inputs=(inputs_emb, decoder_inputs_emb)
+        )
+        attributions_normed = summarize_attributions(attributions_ig)
+        return "\n".join(get_words(pipe.tokenizer.convert_ids_to_tokens(input_ids.squeeze(0).tolist()), attributions_normed))
+
+    return get_saliency
+
+
+if __name__ == "__main__":
+    query = "how to add dll to visual studio?"
+    doc = "StackOverflow In the days of 16-bit Windows, a WPARAM was a 16-bit word, while LPARAM was a 32-bit long. These distinctions went away in Win32; they both became 32-bit values. ... WPARAM is defined as UINT_PTR , which in 64-bit Windows is an unsigned, 64-bit value."
+    model, formatter = get_model("MonoT5")
+    get_saliency = prep_func(model, formatter)
+    print(get_saliency(query, doc))
+
+
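Taken together, the new analysis.py helpers form a small saliency pipeline: get_model loads the MonoT5 reranker as a Hugging Face text2text pipeline, prep_func binds it into a cached get_saliency(query, doc) closure that attributes the probability of the "true" (relevant) token back to the encoder input embeddings using Captum's Saliency (plain input gradients, despite the lig/attributions_ig variable names), and get_words renders the per-token scores as colored <span> tags. A minimal usage sketch mirroring the __main__ block above (the query and document strings here are placeholders, and outside a running Streamlit app the st.cache decorators simply execute uncached):

# Sketch only: assumes analysis.py from this commit is importable.
from analysis import get_model, prep_func

pipe, formatter = get_model("MonoT5")       # pipeline + "Query: ... Document: ... Relevant:" prompt
get_saliency = prep_func(pipe, formatter)   # closure over the loaded model and token ids

html = get_saliency("what is a wparam?", "WPARAM is defined as UINT_PTR, an unsigned 64-bit value on 64-bit Windows.")
print(html)  # newline-joined <span> tags, one per input token, colored by saliency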
app.py CHANGED
@@ -13,9 +13,10 @@ import plotly.express as px
 
 from constants import ALL_DATASETS, ALL_METRICS
 from dataset_loading import get_dataset, load_run, load_local_qrels, load_local_corpus, load_local_queries
-from analysis import create_boxplot_1df, create_boxplot_2df, create_boxplot_diff
+from analysis import create_boxplot_1df, create_boxplot_2df, create_boxplot_diff, get_model, prep_func
 
 
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
 st.set_page_config(layout="wide")
 
 
@@ -41,6 +42,7 @@ def check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus)
         return True
     return False
 
+
 def validate(config_option, file_loaded):
     if config_option != "None" and file_loaded is None:
         st.error("Please upload a file for " + config_option)
@@ -90,6 +92,14 @@ with st.sidebar:
     incorrect_only = st.checkbox("Show only incorrect instances", value=False)
     one_better_than_two = st.checkbox("Show only instances where run 1 is better than run 2", value=False)
     two_better_than_one = st.checkbox("Show only instances where run 2 is better than run 1", value=False)
+    use_model_saliency = st.checkbox("Use model saliency (slow!)", value=False)
+    if use_model_saliency:
+        # choose from a list of models
+        model_name = st.selectbox("Choose from a list of models", ["MonoT5"])
+        model, formatter = get_model("MonoT5")
+        get_saliency = prep_func(model, formatter)
+
+
     advanced_options1 = st.checkbox("Show advanced options for Run 1", value=False)
     doc_expansion1 = doc_expansion2 = None
     query_expansion1 = query_expansion2 = None
@@ -307,9 +317,16 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
             if doc_expansion1 is not None and run1_uses_doc_expansion != "None" and not show_orig_rel:
                 alt_text = doc_expansion1[docid]["text"]
                 text = combine(text, alt_text, run1_uses_doc_expansion)
-            st.text_area(f"{docid}:", text)
 
-
+            if use_model_saliency:
+                if st.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency", value=False):
+                    st.markdown(get_saliency(query_text, doc_texts),unsafe_allow_html=True)
+                else:
+                    st.text_area(f"{docid}:", text)
+
+            else:
+                st.text_area(f"{docid}:", text)
+
 
             pred_doc = run1_pandas[run1_pandas.doc_id.isin(relevant_docs)]
             rank_pred = pred_doc[pred_doc.qid == str(inst_num)]["rank"].tolist()
@@ -320,6 +337,7 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
                 ranking_str = "--"
             rank_col.metric(f"Rank of Relevant Doc(s)", ranking_str)
 
+
             st.divider()
 
             # top ranked
@@ -336,10 +354,22 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
                 for d_idx, doc in enumerate(run1_top_n_docs):
                     alt_text = run1_top_n_docs_alt[d_idx]["text"]
                     doc_text = combine(doc["text"], alt_text, run1_uses_doc_expansion)
-
+                    if use_model_saliency:
+                        if st.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency", value=False):
+                            st.markdown(get_saliency(query_text, doc_text),unsafe_allow_html=True)
+                        else:
+                            st.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}")
+                    else:
+                        st.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}")
             else:
                 for d_idx, doc in enumerate(run1_top_n_docs):
-
+                    if use_model_saliency:
+                        if st.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{d_idx}ranked", value=False):
+                            st.markdown(get_saliency(query_text, doc),unsafe_allow_html=True)
+                        else:
+                            st.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}")
+                    else:
+                        st.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}")
             st.divider()
 
             # none checked
@@ -384,20 +414,28 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
                 combined_text2 = combine(query_text_og, alt_text2, run2_uses_query_expansion)
                 col_run1.markdown(combined_text1)
                 col_run2.markdown(combined_text2)
+                query_text1 = combined_text1
+                query_text2 = combined_text2
             elif run1_uses_query_expansion != "None":
                 alt_text = query_expansion1[str(inst_num)]
                 combined_text1 = combine(query_text_og, alt_text, run1_uses_query_expansion)
                 col_run1.markdown(combined_text1)
                 col_run2.markdown(query_text_og)
+                query_text1 = combined_text1
+                query_text2 = query_text_og
             elif run2_uses_query_expansion != "None":
                 alt_text = query_expansion2[str(inst_num)]
                 combined_text2 = combine(query_text_og, alt_text, run2_uses_query_expansion)
                 col_run1.markdown(query_text_og)
                 col_run2.markdown(combined_text2)
+                query_text1 = query_text_og
+                query_text2 = combined_text2
             else:
                 query_text = query_text_og
                 col_run1.markdown(query_text)
                 col_run2.markdown(query_text)
+                query_text1 = query_text
+                query_text2 = query_text
 
             st.divider()
 
@@ -420,13 +458,27 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
                 if doc_expansion1 is not None and run1_uses_doc_expansion != "None" and not show_orig_rel1:
                     alt_text = doc_expansion1[docid]["text"]
                     text = combine(text, alt_text, run1_uses_doc_expansion)
-
+
+                if use_model_saliency:
+                    if col_run1.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{docid}relevant", value=False):
+                        col_run1.markdown(get_saliency(query_text1, text),unsafe_allow_html=True)
+                    else:
+                        col_run1.text_area(f"{docid}:", text, key=f"{inst_num}doc{docid}1")
+                else:
+                    col_run1.text_area(f"{docid}:", text, key=f"{inst_num}doc{docid}1")
 
             for (docid, title, text) in doc_texts:
                 if doc_expansion2 is not None and run2_uses_doc_expansion != "None" and not show_orig_rel2:
                     alt_text = doc_expansion2[docid]["text"]
                     text = combine(text, alt_text, run2_uses_doc_expansion)
-
+
+                if use_model_saliency:
+                    if col_run2.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{docid}relevant2", value=False):
+                        col_run2.markdown(get_saliency(query_text2, text),unsafe_allow_html=True)
+                    else:
+                        col_run2.text_area(f"{docid}:", text, key=f"{inst_num}doc{docid}2")
+                else:
+                    col_run2.text_area(f"{docid}:", text, key=f"{inst_num}doc{docid}2")
 
             # top ranked
             # NOTE: BEIR calls trec_eval which ranks by score, then doc_id for ties
@@ -474,10 +526,23 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
                     for d_idx, doc in enumerate(run1_top_n_docs):
                         alt_text = run1_top_n_docs_alt[d_idx]["text"]
                         doc_text = combine(doc["text"], alt_text, run1_uses_doc_expansion)
-
+                        if use_model_saliency:
+                            if col_run1.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{d_idx}ranked1", value=False):
+                                col_run1.markdown(get_saliency(query_text1, doc_text),unsafe_allow_html=True)
+                            else:
+                                col_run1.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}1")
+                        else:
+                            col_run1.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}1")
                 else:
                     for d_idx, doc in enumerate(run1_top_n_docs):
-
+                        if use_model_saliency:
+                            if col_run1.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{d_idx}ranked1", value=False):
+                                col_run1.markdown(get_saliency(query_text1, doc),unsafe_allow_html=True)
+                            else:
+                                col_run1.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}1")
+                        else:
+                            col_run1.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}1")
+
 
             if col_run2.checkbox('Show top ranked documents for Run 2', key=f"{inst_index}top-2run"):
                 col_run2.subheader("Top N Ranked Documents")
@@ -492,10 +557,22 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
                     for d_idx, doc in enumerate(run2_top_n_docs):
                         alt_text = run2_top_n_docs_alt[d_idx]["text"]
                         doc_text = combine(doc["text"], alt_text, run2_uses_doc_expansion)
-
+                        if use_model_saliency:
+                            if col_run2.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{d_idx}ranked2", value=False):
+                                col_run2.markdown(get_saliency(query_text2, doc_text),unsafe_allow_html=True)
+                            else:
+                                col_run2.text_area(f"{run2_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}2")
+                        else:
+                            col_run2.text_area(f"{run2_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}2")
                 else:
                     for d_idx, doc in enumerate(run2_top_n_docs):
-
+                        if use_model_saliency:
+                            if col_run2.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{d_idx}ranked2", value=False):
+                                col_run2.markdown(get_saliency(query_text2, doc),unsafe_allow_html=True)
+                            else:
+                                col_run2.text_area(f"{run2_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}2")
+                        else:
+                            col_run2.text_area(f"{run2_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}2")
 
             st.divider()
 
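Throughout app.py the same rendering pattern is repeated: when the sidebar's "Use model saliency (slow!)" option is enabled, each document panel gets its own "Show Model Saliency" checkbox that swaps the plain st.text_area for st.markdown with unsafe_allow_html=True, so the <span> tags returned by get_saliency render as colored tokens. A condensed sketch of that branch (the show_doc helper is hypothetical and only factors out the repeated logic; container is st itself or a column like col_run1):

import streamlit as st

def show_doc(container, query_text, doc_id, doc_text, key, use_model_saliency, get_saliency=None):
    # Render one document either as saliency-colored HTML or as a plain text area.
    if use_model_saliency and container.checkbox("Show Model Saliency", key=f"{key}-saliency", value=False):
        # get_saliency returns newline-joined <span> tags, so HTML rendering must be allowed
        container.markdown(get_saliency(query_text, doc_text), unsafe_allow_html=True)
    else:
        container.text_area(f"{doc_id}:", doc_text, key=f"{key}-text")

# e.g. show_doc(st, query_text, docid, text, f"{inst_num}doc{docid}1", use_model_saliency, get_saliency)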
dataset_loading.py CHANGED
@@ -14,6 +14,8 @@ import ir_datasets
 
 from constants import BEIR, IR_DATASETS, LOCAL_DATASETS
 
+
+@st.cache_data
 def load_local_corpus(corpus_file, columns_to_combine=["title", "text"]):
     if corpus_file is None:
         return None
@@ -39,6 +41,8 @@ def load_local_corpus(corpus_file, columns_to_combine=["title", "text"]):
     }
     return did2text
 
+
+@st.cache_data
 def load_local_queries(queries_file):
     if queries_file is None:
         return None
@@ -60,6 +64,8 @@ def load_local_queries(queries_file):
         qid2text[inst[id_key]] = inst["text"]
     return qid2text
 
+
+@st.cache_data
 def load_local_qrels(qrels_file):
     if qrels_file is None:
         return None
@@ -84,6 +90,7 @@ def load_local_qrels(qrels_file):
     return qid2did2label
 
 
+@st.cache_data
 def load_run(f_run):
     run = pytrec_eval.parse_run(copy.deepcopy(f_run))
     # convert bytes to strings for keys
@@ -102,7 +109,7 @@ def load_run(f_run):
     return new_run, run_pandas
 
 
-
+@st.cache_data
 def load_jsonl(f):
     did2text = defaultdict(list)
     sub_did2text = {}
@@ -126,7 +133,7 @@ def load_jsonl(f):
     return did2text, sub_did2text
 
 
-
+@st.cache_data
 def get_beir(dataset: str):
     url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
     out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
@@ -134,6 +141,7 @@ def get_beir(dataset: str):
     return GenericDataLoader(data_folder=data_path).load(split="test")
 
 
+@st.cache_data
 def get_ir_datasets(dataset_name: str):
     dataset = ir_datasets.load(dataset_name)
     queries = {}
@@ -145,6 +153,7 @@ def get_ir_datasets(dataset_name: str):
     return dataset.doc_store(), queries, dataset.qrels_dict()
 
 
+@st.cache_data
 def get_dataset(dataset_name: str):
    if dataset_name == "":
        return {}, {}, {}
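The only substantive change to dataset_loading.py is decorating each loader with Streamlit's @st.cache_data. Because Streamlit re-executes the whole script on every widget interaction, this caches the parsed corpora, queries, qrels, and runs across reruns instead of reloading them each time. A toy illustration of the decorator's effect (not part of the commit):

import streamlit as st

@st.cache_data
def load_numbers(path: str) -> list[int]:
    # Executed once per distinct `path`; later calls in any rerun
    # return a cached copy instead of re-reading the file.
    with open(path) as f:
        return [int(line) for line in f]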
requirements.txt CHANGED
@@ -5,4 +5,6 @@ streamlit==1.24.1
 ir_datasets==0.5.5
 pyserini==0.21.0
 torch==2.0.1
-plotly==5.15.0
+plotly==5.15.0
+captum==0.6.0
+protobuf==4.21.11