import gradio as gr
import pickle
import numpy as np
import glob
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel
import logging
import os
import spaces
import ir_datasets
import pytrec_eval
from huggingface_hub import login
import faiss

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Authenticate with HF_TOKEN
login(token=os.environ['HF_TOKEN'])

# Global variables
CUR_MODEL = "orionweller/repllama-instruct-hard-positives-v2-joint"
BASE_MODEL = "meta-llama/Llama-2-7b-hf"
tokenizer = None
model = None
retrievers = {}
corpus_lookups = {}
queries = {}
q_lookups = {}
qrels = {}
datasets = ["scifact"]
current_dataset = "scifact"
def pool(last_hidden_states, attention_mask, pool_type="last"):
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    if pool_type == "last":
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            emb = last_hidden[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden.shape[0]
            emb = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths]
    else:
        raise ValueError(f"pool_type {pool_type} not supported")
    return emb
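
# Tokenize a batch without padding, truncating to max_length - 1 to leave one
# slot for the EOS token appended below, then pad to a multiple of 8
# (friendlier shapes for tensor-core matmuls).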
def create_batch_dict(tokenizer, input_texts, always_add_eos="last", max_length=512):
    batch_dict = tokenizer(
        input_texts,
        max_length=max_length - 1,
        return_token_type_ids=False,
        return_attention_mask=False,
        padding=False,
        truncation=True
    )
    if always_add_eos == "last":
        batch_dict['input_ids'] = [input_ids + [tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
    return tokenizer.pad(
        batch_dict,
        padding=True,
        pad_to_multiple_of=8,
        return_attention_mask=True,
        return_tensors="pt",
    )
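
# Wraps the Llama-2 base model with the Promptriever LoRA adapter, merging the
# adapter weights into the base model so inference runs on a single plain
# transformer.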
class RepLlamaModel:
    def __init__(self, model_name_or_path):
        self.base_model = BASE_MODEL
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model)
        self.tokenizer.model_max_length = 2048
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"
        self.model = self.get_model(model_name_or_path)
        self.model.config.max_length = 2048

    def get_model(self, peft_model_name):
        base_model = AutoModel.from_pretrained(self.base_model)
        model = PeftModel.from_pretrained(base_model, peft_model_name)
        model = model.merge_and_unload()
        model.eval()
        return model

    def encode(self, texts, batch_size=32, **kwargs):
        self.model = self.model.cuda()
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]
            batch_dict = create_batch_dict(self.tokenizer, batch_texts, always_add_eos="last")
            batch_dict = {key: value.cuda() for key, value in batch_dict.items()}
            with torch.cuda.amp.autocast():
                with torch.no_grad():
                    outputs = self.model(**batch_dict)
                    embeddings = pool(outputs.last_hidden_state, batch_dict['attention_mask'], 'last')
                    embeddings = F.normalize(embeddings, p=2, dim=-1)
                    all_embeddings.append(embeddings.cpu().numpy())
        self.model = self.model.cpu()
        return np.concatenate(all_embeddings, axis=0)
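
# Memory-map the prebuilt FAISS index from disk rather than copying it fully
# into RAM; returns None if no index ships with this dataset.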
def load_faiss_index(dataset_name):
    index_path = f"{dataset_name}/faiss_index.bin"
    if os.path.exists(index_path):
        logger.info(f"Loading existing FAISS index for {dataset_name} from {index_path}")
        return faiss.read_index(index_path, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
    return None
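
# Run the query embeddings against the FAISS index and map the returned row
# positions back to corpus document ids via corpus_lookups.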
def search_queries(dataset_name, q_reps, depth=1000):
    faiss_index = load_faiss_index(dataset_name)
    if faiss_index is None:
        raise ValueError(f"No FAISS index found for dataset {dataset_name}")

    # Ensure q_reps is a contiguous 2D float32 array; FAISS search expects float32
    q_reps = np.ascontiguousarray(q_reps.astype('float32'))

    # Perform the search
    all_scores, all_indices = faiss_index.search(q_reps, depth)
    psg_indices = [[str(corpus_lookups[dataset_name][x]) for x in q_dd] for q_dd in all_indices]

    # Clean up
    del faiss_index
    return all_scores, np.array(psg_indices)
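
# Each corpus_emb.*.pkl shard stores an (embeddings, doc_id_lookup) tuple in
# the Tevatron encoding format; only the doc-id lookups are needed here, since
# the embeddings themselves already live in the FAISS index.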
def load_corpus_lookups(dataset_name):
    global corpus_lookups
    corpus_path = f"{dataset_name}/corpus_emb.*.pkl"
    # Sort the shards so the concatenated doc-id order matches the order the
    # FAISS index was built in (glob alone does not guarantee an order)
    index_files = sorted(glob.glob(corpus_path))

    corpus_lookups[dataset_name] = []
    for file in index_files:
        with open(file, 'rb') as f:
            _, p_lookup = pickle.load(f)
        corpus_lookups[dataset_name] += p_lookup
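
# Load BEIR queries and relevance judgments via ir_datasets; for scifact the
# test split is selected explicitly.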
def load_queries(dataset_name):
    global queries, q_lookups, qrels
    dataset = ir_datasets.load(f"beir/{dataset_name.lower()}" + ("/test" if dataset_name == "scifact" else ""))

    queries[dataset_name] = []
    q_lookups[dataset_name] = {}
    qrels[dataset_name] = {}
    for query in dataset.queries_iter():
        queries[dataset_name].append(query.text)
        q_lookups[dataset_name][query.query_id] = query.text
    for qrel in dataset.qrels_iter():
        if qrel.query_id not in qrels[dataset_name]:
            qrels[dataset_name][qrel.query_id] = {}
        qrels[dataset_name][qrel.query_id][qrel.doc_id] = qrel.relevance
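
# Score the run with pytrec_eval: mean nDCG@k and Recall@k over all queries,
# rounded to three decimals.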
def evaluate(qrels, results, k_values):
    evaluator = pytrec_eval.RelevanceEvaluator(
        qrels, {f"ndcg_cut.{k}" for k in k_values} | {f"recall.{k}" for k in k_values}
    )
    scores = evaluator.evaluate(results)

    metrics = {}
    for k in k_values:
        metrics[f"NDCG@{k}"] = round(np.mean([query_scores[f"ndcg_cut_{k}"] for query_scores in scores.values()]), 3)
        metrics[f"Recall@{k}"] = round(np.mean([query_scores[f"recall_{k}"] for query_scores in scores.values()]), 3)
    return metrics
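
# End-to-end evaluation: append the user's prompt to every query, encode,
# search the corpus, and score against the qrels.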
@spaces.GPU  # request a ZeroGPU slot for this call (the Space runs on Zero)
def run_evaluation(dataset, postfix):
    global current_dataset, queries, model
    current_dataset = dataset

    input_texts = [f"query: {query.strip()} {postfix}".strip() for query in queries[current_dataset]]
    q_reps = model.encode(input_texts)
    all_scores, psg_indices = search_queries(dataset, q_reps)

    results = {qid: dict(zip(doc_ids, map(float, scores)))
               for qid, scores, doc_ids in zip(q_lookups[dataset].keys(), all_scores, psg_indices)}

    metrics = evaluate(qrels[dataset], results, k_values=[10, 100])
    return {
        "NDCG@10": metrics["NDCG@10"],
        "Recall@100": metrics["Recall@100"]
    }
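
# Gradio callback: forwards the selected dataset and prompt to run_evaluation.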
def gradio_interface(dataset, postfix):
    return run_evaluation(dataset, postfix)

# One-time startup: load the model, doc-id lookups, and queries before serving
if model is None:
    model = RepLlamaModel(model_name_or_path=CUR_MODEL)
    load_corpus_lookups(current_dataset)
    load_queries(current_dataset)

# Create Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Dropdown(choices=datasets, label="Dataset", value="scifact"),
        gr.Textbox(label="Prompt")
    ],
    outputs=gr.JSON(label="Evaluation Results"),
    title="Promptriever Demo",
    description="Select a dataset and enter a prompt to evaluate the model's performance. Note: it takes about **ten seconds** to evaluate.",
    examples=[
        ["scifact", ""],
        ["scifact", "Think carefully about these conditions when determining relevance"]
    ],
    cache_examples=False,
)

# Launch the interface
iface.launch()