Spaces:
Running
on
Zero
Running
on
Zero
| ''' | |
| Evaluation methods for no ground truth. | |
| 1.NLI | |
| 2.AttrScore | |
| 3.GPT-4 AttrScore | |
| ''' | |
| import torch | |
| from src.models import create_model | |
| from src.prompts import wrap_prompt | |
| from src.utils import * | |
| from src.utils import _read_results,_save_results | |
| import PromptInjectionAttacks as PI | |
| import signal | |
| import gc | |
| import math | |
| import time | |
| from sentence_transformers import SentenceTransformer, util | |
| def get_similarity(text1, text2,model): | |
| start_time = time.time() | |
| emb1 = model.encode(text1, convert_to_tensor=True) | |
| emb2 = model.encode(text2, convert_tensor=True) | |
| end_time = time.time() | |
| print("Time taken to calculate similarity: ", end_time - start_time) | |
| similarity = float(util.pytorch_cos_sim(emb1, emb2).item()) | |
| return similarity | |
| def calculate_precision_recall_f1(predicted, actual): | |
| predicted_set = set(predicted) | |
| actual_set = set(actual) | |
| TP = len(predicted_set & actual_set) # Intersection of predicted and actual sets | |
| FP = len(predicted_set - actual_set) # Elements in predicted but not in actual | |
| FN = len(actual_set - predicted_set) # Elements in actual but not in predicted | |
| precision = TP / (TP + FP) if (TP + FP) > 0 else 0 | |
| recall = TP / (TP + FN) if (TP + FN) > 0 else 0 | |
| f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 | |
| return precision, recall, f1_score | |
| def remove_specific_indexes(lst, indexes_to_remove): | |
| return [item for idx, item in enumerate(lst) if idx not in indexes_to_remove] | |
| def retain_specific_indexes(lst, indexes_to_retain): | |
| return [item for idx, item in enumerate(lst) if idx in indexes_to_retain] | |
| def check_condition(args,llm,model,question,all_texts,important_ids,importance_scores,answer, k): | |
| top_k=top_k_indexes(importance_scores, k) | |
| topk_ids = [important_ids[j] for j in top_k] | |
| #remove top-K texts to check ASR change | |
| new_texts = remove_specific_indexes(all_texts, topk_ids) | |
| new_prompt = wrap_prompt(question, new_texts) | |
| new_answer =llm.query(new_prompt) | |
| completeness_condition = get_similarity(answer, new_answer,model) <0.99 | |
| print("==============================================================") | |
| print("current k: ", k) | |
| print("answer: ", answer, "new_answer: ", new_answer, "comp similarity: ", get_similarity(answer, new_answer)) | |
| new_texts = retain_specific_indexes(all_texts, topk_ids) | |
| new_prompt = wrap_prompt(question, new_texts) | |
| new_answer =llm.query(new_prompt) | |
| sufficiency_condition = get_similarity(answer, new_answer,model) > 0.99 | |
| print("answer: ", answer, "new_answer: ", new_answer, "suff similarity: ", get_similarity(answer, new_answer)) | |
| print("current k: ", k, "suff: ", sufficiency_condition, "comp: ", completeness_condition) | |
| print("==============================================================") | |
| return sufficiency_condition and completeness_condition | |
| def evaluate_prompt_injection(args,llm): | |
| pred_results_path = args.results_path | |
| new_attr_result = [] | |
| attr_result = _read_results(args, pred_results_path) | |
| for i, dict in enumerate(attr_result): | |
| print("Question number: ",i) | |
| important_ids = dict['important_ids'] | |
| importance_scores = dict['scores'] | |
| assert len(importance_scores) == len(important_ids) | |
| question = dict['question'] | |
| target_answer = dict['target_answer'] | |
| llm_answer = dict['answer'] | |
| malicious_texts = dict['gt_important_texts'] | |
| all_texts = split_context(args.explanation_level,dict['contexts']) | |
| gt_ids,_ = get_gt_ids(all_texts,malicious_texts) | |
| k = args.K | |
| if gt_ids == []: | |
| continue | |
| top_k=top_k_indexes(importance_scores, k) | |
| topk_ids = [important_ids[j] for j in top_k] | |
| print("gt_ids",gt_ids) | |
| print("topk_ids",topk_ids) | |
| precision, recall, f1 = calculate_precision_recall_f1( topk_ids,gt_ids) | |
| print(f'precision: {precision}, recall: {recall}, f1: {f1}') | |
| #remove top-K texts to check ASR change | |
| new_texts = remove_specific_indexes(all_texts, topk_ids) | |
| prompt = wrap_prompt(question, new_texts) | |
| try: | |
| signal.alarm(60) | |
| new_answer =llm.query(prompt) | |
| signal.alarm(0) | |
| except TimeoutError as exc: | |
| print("time out") | |
| ASV = 0 | |
| continue | |
| ASV = clean_str(target_answer) in clean_str(new_answer) or clean_str(new_answer) in clean_str(target_answer) | |
| dict[f'precision'] = precision | |
| dict[f'recall'] = recall | |
| dict[f'f1'] = f1 | |
| dict[f'asv'] = ASV | |
| new_attr_result.append(dict) | |
| total_recall = 0 | |
| total_precision = 0 | |
| total_f1 = 0 | |
| total_asv = 0 | |
| total_time = 0 | |
| for dict in new_attr_result: | |
| total_recall+= dict["recall"] | |
| total_precision+= dict["precision"] | |
| total_f1+= dict["f1"] | |
| total_asv+= dict["asv"] | |
| total_time+= dict["time"] | |
| print("AVG ASV after removal: ",total_asv/args.data_num) | |
| print("AVG PRECISION: ",total_precision/len(new_attr_result)) | |
| print("AVG RECALL: ",total_recall/len(new_attr_result)) | |
| print("AVG F1: ",total_f1/len(new_attr_result)) | |
| print("AVG TIME: ",total_time/len(new_attr_result)) | |
| del llm | |
| # Run the garbage collector | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| def evaluate_poison_rag(args,llm): | |
| pred_results_path = args.results_path | |
| new_attr_result = [] | |
| attr_result = _read_results(args, pred_results_path) | |
| for i, dict in enumerate(attr_result): | |
| print("Question number: ",i) | |
| important_ids = dict['important_ids'] | |
| importance_scores = dict['scores'] | |
| assert len(importance_scores) == len(important_ids) | |
| question = dict['question'] | |
| target_answer = dict['target_answer'] | |
| llm_answer = dict['answer'] | |
| injected_adv = dict['gt_important_texts'] | |
| print("Question: ", question) | |
| all_texts = dict['contexts'] | |
| k = args.K | |
| top_k=top_k_indexes(importance_scores, k) | |
| topk_ids = [important_ids[j] for j in top_k] | |
| gt_ids,_ = get_gt_ids(all_texts,injected_adv) | |
| new_texts = remove_specific_indexes(all_texts, topk_ids) | |
| prompt = wrap_prompt(question, new_texts) | |
| precision, recall, f1 = calculate_precision_recall_f1( topk_ids,gt_ids) | |
| try: | |
| signal.alarm(60) | |
| new_answer =llm.query(prompt) | |
| ASV = int(clean_str(target_answer) in clean_str(new_answer)) | |
| signal.alarm(0) | |
| except TimeoutError as exc: | |
| print("time out") | |
| ASV = 1 | |
| dict[f'precision'] = precision | |
| dict[f'recall'] = recall | |
| dict[f'f1'] = f1 | |
| dict[f'asv'] = ASV | |
| new_attr_result.append(dict) | |
| total_recall = 0 | |
| total_precision = 0 | |
| total_asv = 0 | |
| total_time = 0 | |
| for dict in new_attr_result: | |
| total_recall+= dict["recall"] | |
| total_precision+= dict["precision"] | |
| total_asv+= dict["asv"] | |
| total_time+= dict["time"] | |
| print("AVG ASV after removal:: ",total_asv/args.data_num) | |
| print("AVG PRECISION: ",total_precision/len(new_attr_result)) | |
| print("AVG RECALL: ",total_recall/len(new_attr_result)) | |
| print("AVG TIME: ",total_time/len(new_attr_result)) | |
| _save_results(args, new_attr_result, pred_results_path) | |
| del llm | |
| # Run the garbage collector | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| def evaluate_needle_in_haystack(args,llm): | |
| pred_results_path = args.results_path | |
| new_attr_result = [] | |
| attr_result = _read_results(args, pred_results_path) | |
| k = args.K | |
| for i, dict in enumerate(attr_result): | |
| print("Question number: ",i) | |
| important_ids = dict['important_ids'] | |
| importance_scores = dict['scores'] | |
| assert len(importance_scores) == len(important_ids) | |
| question = dict['question'] | |
| target_answer = dict['target_answer'] | |
| needles = dict['gt_important_texts'] | |
| all_texts = split_context(args.explanation_level,dict['contexts'])#contexts_to_sentences(dict['topk_contexts']) | |
| gt_ids=[] | |
| gt_texts = [] | |
| for j, segment in enumerate(all_texts): | |
| for needle in needles: | |
| if check_overlap(segment,needle,10): | |
| gt_ids.append(j) | |
| gt_texts.append(all_texts[j]) | |
| if gt_ids == []: | |
| continue | |
| top_k=top_k_indexes(importance_scores, k) | |
| topk_ids = [important_ids[j] for j in top_k] | |
| new_sentences = remove_specific_indexes(all_texts, topk_ids) | |
| precision, recall, f1 = calculate_precision_recall_f1( topk_ids,gt_ids) | |
| print(f'precision: {precision}, recall: {recall}, f1: {f1}') | |
| prompt = wrap_prompt(question, new_sentences) | |
| try: | |
| signal.alarm(60) | |
| new_answer =llm.query(prompt) | |
| signal.alarm(0) | |
| except TimeoutError as exc: | |
| print("time out") | |
| continue | |
| print("target answer:",target_answer) | |
| print("new answer:", new_answer) | |
| ACC = 1 | |
| for target in target_answer: | |
| if (clean_str(target_answer) not in clean_str(new_answer)): | |
| ACC = 0 | |
| dict[f'precision'] = precision | |
| dict[f'recall'] = recall | |
| dict[f'f1'] = f1 | |
| dict[f'acc'] = ACC | |
| new_attr_result.append(dict) | |
| total_recall = 0 | |
| total_precision = 0 | |
| total_acc = 0 | |
| total_time = 0 | |
| for dict in new_attr_result: | |
| total_recall+= dict["recall"] | |
| total_precision+= dict["precision"] | |
| total_acc+= dict["acc"] | |
| total_time+= dict["time"] | |
| print("AVG ACC after removal: ",total_acc/args.data_num) | |
| print("AVG PRECISION: ",total_precision/len(new_attr_result)) | |
| print("AVG RECALL: ",total_recall/len(new_attr_result)) | |
| print("AVG TIME: ",total_time/len(new_attr_result)) | |
| del llm | |
| # Run the garbage collector | |
| gc.collect() | |
| torch.cuda.empty_cache() | |