Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
| import pathlib | |
| import pandas as pd | |
| from collections import defaultdict | |
| import json | |
| import copy | |
| import plotly.express as px | |
def load_local_corpus(corpus_file, columns_to_combine=("title", "text")):
    """Load a JSONL corpus into a mapping of document id -> {"text", "title"}.

    Args:
        corpus_file: an open file-like object (text or binary mode) yielding one
            JSON object per line, or None.
        columns_to_combine: field names whose values are space-joined to build
            the document text; fields that are missing or None are skipped.

    Returns:
        dict mapping document id to {"text": combined text, "title": title or ""},
        or None when corpus_file is None.
    """
    if corpus_file is None:
        return None
    did2text = {}
    # Corpora use either "_id" or "doc_id" as the identifier field; detect lazily.
    id_key = "_id"
    with corpus_file as f:
        for idx, line in enumerate(f):
            # Binary uploads yield bytes; normalize to str once up front.
            if isinstance(line, bytes):
                line = line.decode("utf-8")
            try:
                inst = json.loads(line)
            except json.JSONDecodeError:
                # A non-JSON first line is a CSV/TSV header (e.g. "doc_id\ttitle"):
                # skip it. The old substring test ("doc_id" in line) also dropped a
                # *valid* first record whose id key was "doc_id".
                if idx == 0:
                    continue
                raise
            all_text = " ".join(
                inst[col] for col in columns_to_combine if col in inst and inst[col] is not None
            )
            if id_key not in inst:
                id_key = "doc_id"
            did2text[inst[id_key]] = {
                "text": all_text,
                "title": inst.get("title", ""),
            }
    return did2text
def load_local_queries(queries_file):
    """Load a JSONL queries file into a mapping of query id -> query text.

    Tabs in the query text are replaced with " === " (the UI uses tabs as a
    column separator downstream).

    Args:
        queries_file: an open file-like object (text or binary mode) yielding one
            JSON object per line, or None.

    Returns:
        dict mapping query id to query text, or None when queries_file is None.
    """
    if queries_file is None:
        return None
    qid2text = {}
    # Queries use either "_id" or "query_id" as the identifier field; detect lazily.
    id_key = "_id"
    with queries_file as f:
        for idx, line in enumerate(f):
            # Binary uploads yield bytes; normalize to str once up front.
            if isinstance(line, bytes):
                line = line.decode("utf-8")
            try:
                inst = json.loads(line)
            except json.JSONDecodeError:
                # A non-JSON first line is a CSV/TSV header: skip it. The old
                # substring test ("query_id" in line) also dropped a *valid*
                # first record whose id key was "query_id".
                if idx == 0:
                    continue
                raise
            if id_key not in inst:
                id_key = "query_id"
            qid2text[inst[id_key]] = inst["text"].replace("\t", " === ")
    return qid2text
def load_local_qrels(qrels_file):
    """Load a whitespace-separated qrels file into {qid: {doc_id: int label}}.

    Accepts both the 4-column TREC format (qid, iteration, doc_id, label) and a
    3-column format (qid, doc_id, label). A first line containing "qid" or
    "query-id" is treated as a header and skipped.

    Args:
        qrels_file: an open file-like object (text or binary mode), or None.

    Returns:
        defaultdict mapping str(qid) -> {str(doc_id): int(label)}, or None when
        qrels_file is None.
    """
    if qrels_file is None:
        return None
    qid2did2label = defaultdict(dict)
    with qrels_file as f:
        for idx, line in enumerate(f):
            # Binary uploads yield bytes; normalize to str once up front.
            if isinstance(line, bytes):
                line = line.decode("utf-8")
            # BUGFIX: the original condition `idx == 0 and "qid" in line or
            # "query-id" in line` parsed as `(idx == 0 and ...) or ("query-id"
            # in line)` and therefore skipped EVERY line containing "query-id",
            # not just the header row.
            if idx == 0 and ("qid" in line or "query-id" in line):
                continue
            parts = line.split()
            if len(parts) == 4:
                qid, _, doc_id, label = parts
            else:
                qid, doc_id, label = parts
            qid2did2label[str(qid)][str(doc_id)] = int(label)
    return qid2did2label
def load_jsonl(f):
    """Group the records of a JSONL stream by document id.

    Each line must carry exactly one of "question", "text", or "query"; its value
    is appended to the document's list. The document id is "doc_id" when present,
    otherwise "did" (or metadata[0]["passage_id"] for question records).

    Args:
        f: an iterable of JSON strings (e.g. an open file).

    Returns:
        (did2text, sub_did2text): a defaultdict(list) mapping doc id -> list of
        question/text/query strings, and a dict mapping sub-passage "did" -> text
        (populated only for "text" records).

    Raises:
        NotImplementedError: if a line has none of the recognized fields.
    """
    did2text = defaultdict(list)
    sub_did2text = {}
    for line in f:
        inst = json.loads(line)
        if "question" in inst:
            docid = inst["metadata"][0]["passage_id"] if "doc_id" not in inst else inst["doc_id"]
            did2text[docid].append(inst["question"])
        elif "text" in inst:
            docid = inst["doc_id"] if "doc_id" in inst else inst["did"]
            did2text[docid].append(inst["text"])
            # NOTE(review): keyed on "did" unconditionally -- this raises KeyError
            # for "text" records that only carry "doc_id"; presumably "did" is the
            # sub-passage id. Confirm against callers before changing.
            sub_did2text[inst["did"]] = inst["text"]
        elif "query" in inst:
            docid = inst["doc_id"] if "doc_id" in inst else inst["did"]
            did2text[docid].append(inst["query"])
        else:
            # Removed a leftover debugger breakpoint() that would hang the app here.
            raise NotImplementedError("Need to handle this case")
    return did2text, sub_did2text
def get_dataset(dataset_name: str, input_fields_doc, input_fields_query):
    """Resolve a dataset name to (corpus, queries, qrels) mappings.

    Comma-separated strings for the field arguments are split into lists so
    callers may pass either "title,text" or ["title", "text"].

    Args:
        dataset_name: name of the dataset; only "" (local upload mode) is
            currently supported.
        input_fields_doc: document fields, as a list or comma-separated string.
        input_fields_query: query fields, as a list or comma-separated string.

    Returns:
        (corpus, queries, qrels) — three empty dicts for the "" dataset.

    Raises:
        NotImplementedError: for any non-empty dataset name.
    """
    if isinstance(input_fields_doc, str):
        input_fields_doc = input_fields_doc.strip().split(",")
    if isinstance(input_fields_query, str):
        input_fields_query = input_fields_query.strip().split(",")
    if dataset_name == "":
        return {}, {}, {}
    raise NotImplementedError("Dataset not implemented")