Spaces:
Running
on
Zero
Running
on
Zero
| from .attribute import * | |
| import numpy as np | |
| import random | |
| from src.utils import * | |
| import time | |
| from sklearn.linear_model import LinearRegression | |
| from scipy.spatial.distance import cosine | |
| class PerturbationBasedAttribution(Attribution): | |
| def __init__(self, llm,explanation_level = "segment",K=5, attr_type = "tracllm",score_funcs=['stc','loo','denoised_shapley'], sh_N=5,w=2,beta = 0.2,verbose =1): | |
| super().__init__(llm,explanation_level,K,verbose) | |
| self.K=K | |
| self.w = w | |
| self.sh_N = sh_N | |
| self.attr_type = attr_type | |
| self.score_funcs = score_funcs | |
| self.beta = beta | |
| if "gpt" not in self.llm.name: | |
| self.model = llm.model | |
| self.tokenizer = llm.tokenizer | |
| self.func_map = { | |
| "shapley": self.shapley_scores, | |
| "denoised_shapley": self.denoised_shapley_scores, | |
| "stc": self.stc_scores, | |
| "loo": self.loo_scores | |
| } | |
| def marginal_contributions(self, question: str, contexts: list, answer: str) -> list: | |
| """ | |
| Estimate the Shapley values using a Monte Carlo approximation method, handling duplicate contexts. | |
| Each occurrence of a context, even if duplicated, is treated separately. | |
| Parameters: | |
| - contexts: a list of contexts, possibly with duplicates. | |
| - v: a function that takes a list of contexts and returns the total value for that coalition. | |
| - N: the number of random permutations to consider for the approximation. | |
| Returns: | |
| - A list with every context's Shapley value. | |
| """ | |
| k = len(contexts) | |
| # Initialize a list of Shapley values for each context occurrence | |
| shapley_values = [[] for _ in range(k)] | |
| count = 0 | |
| for j in range(self.sh_N): | |
| # Generate a random permutation of the indices of the contexts (to handle duplicates properly) | |
| perm_indices = random.sample(range(k), k) | |
| # Calculate the coalition value for the empty set + cf | |
| coalition_value = self.context_value(question, [""], answer) | |
| for i, index in enumerate(perm_indices): | |
| count += 1 | |
| # Create the coalition up to the current context (based on its index in the permutation) | |
| coalition = [contexts[idx] for idx in perm_indices[:i + 1]] | |
| coalition = sorted(coalition, key=lambda x: contexts.index(x)) # Sort based on original context order | |
| # Calculate the value for the current coalition | |
| context_value = self.context_value(question, coalition, answer) | |
| marginal_contribution = context_value - coalition_value | |
| # Update the Shapley value for the specific context at this index | |
| shapley_values[index].append(marginal_contribution) | |
| # Update the coalition value for the next iteration | |
| coalition_value = context_value | |
| return shapley_values | |
| def shapley_scores(self, question:str, contexts:list, answer:str) -> list: | |
| """ | |
| Estimate the Shapley values using a Monte Carlo approximation method. | |
| Parameters: | |
| - contexts: a list of contexts. | |
| - v: a function that takes a list of contexts and returns the total value for that coalition. | |
| - N: the number of random permutations to consider for the approximation. | |
| Returns: | |
| - A dictionary with contexts as keys and their approximate Shapley values as values. | |
| - A list with every context's shapley value. | |
| """ | |
| marginal_values= self.marginal_contributions(question, contexts, answer) | |
| shapley_values = np.zeros(len(marginal_values)) | |
| for i,value_list in enumerate(marginal_values): | |
| shapley_values[i] = np.mean(value_list) | |
| return shapley_values | |
| def denoised_shapley_scores(self, question:str, contexts:list, answer:str) -> list: | |
| marginal_values = self.marginal_contributions(question, contexts, answer) | |
| new_shapley_values = np.zeros(len(marginal_values)) | |
| for i,value_list in enumerate(marginal_values): | |
| new_shapley_values[i] = mean_of_percent(value_list,self.beta) | |
| return new_shapley_values | |
| def stc_scores(self, question:str, contexts:list, answer:str) -> list: | |
| k = len(contexts) | |
| scores = np.zeros(k) | |
| goal_score = self.context_value(question,[''],answer) | |
| for i,text in enumerate(contexts): | |
| scores[i] = (self.context_value(question, [text], answer) - goal_score) | |
| return scores.tolist() | |
| def loo_scores(self, question:str, contexts:list, answer:str) -> list: | |
| k = len(contexts) | |
| scores = np.zeros(k) | |
| v_all = self.context_value(question, contexts, answer) | |
| for i,text in enumerate(contexts): | |
| rest_texts = contexts[:i] + contexts[i+1:] | |
| scores[i] = v_all - self.context_value(question, rest_texts, answer) | |
| return scores.tolist() | |
| def tracllm(self, question:str, contexts:list, answer:str, score_func): | |
| current_nodes =[manual_zip(contexts, list(range(len(contexts))))] | |
| current_nodes_scores = None | |
| def get_important_nodes(nodes,importance_values): | |
| combined = list(zip(nodes, importance_values)) | |
| combined_sorted = sorted(combined, key=lambda x: x[1], reverse=True) | |
| # Determine the number of top nodes to keep | |
| k = min(self.K, len(combined)) | |
| top_nodes = combined_sorted[:k] | |
| top_nodes_sorted = sorted(top_nodes, key=lambda x: combined.index(x)) | |
| # Extract the top k important nodes and their scores in the original order | |
| important_nodes = [node for node, _ in top_nodes_sorted] | |
| important_nodes_scores = [score for _, score in top_nodes_sorted] | |
| return important_nodes, important_nodes_scores | |
| level = 0 | |
| while len(current_nodes)>0 and any(len(node) > 1 for node in current_nodes): | |
| level+=1 | |
| if self.verbose == 1: | |
| print(f"======= layer: {level}=======") | |
| new_nodes = [] | |
| for node in current_nodes: | |
| if len(node)>1: | |
| mid = len(node) // 2 | |
| node_left, node_right = node[:mid], node[mid:] | |
| new_nodes.append(node_left) | |
| new_nodes.append(node_right) | |
| else: | |
| new_nodes.append(node) | |
| if len(new_nodes)<= self.K: | |
| current_nodes = new_nodes | |
| else: | |
| importance_values= self.func_map[score_func](question, [" ".join(unzip_tuples(node)[0]) for node in new_nodes], answer) | |
| current_nodes,current_nodes_scores = get_important_nodes(new_nodes,importance_values) | |
| flattened_current_nodes = [item for sublist in current_nodes for item in sublist] | |
| return flattened_current_nodes, current_nodes_scores | |
| def vanilla_explanation(self, question:str, texts:list, answer:str,score_func): | |
| texts_scores = self.func_map[score_func](question, texts, answer) | |
| return texts,texts_scores | |
| def attribute(self, question:str, contexts:list, answer:str): | |
| """ | |
| Given question, contexts and answer, return attribution results | |
| """ | |
| ensemble_list = dict() | |
| texts = split_context(self.explanation_level,contexts) | |
| start_time = time.time() | |
| importance_dict = {} | |
| max_score_func_dict = {} | |
| score_funcs = self.score_funcs | |
| for score_func in score_funcs: | |
| if self.verbose == 1: | |
| print(f"-Start {score_func}") | |
| if score_func == "loo": | |
| weight = self.w | |
| else: | |
| weight = 1 | |
| if self.attr_type == "tracllm": | |
| important_nodes,importance_scores = self.tracllm(question, texts, answer,score_func) | |
| important_texts, important_ids = unzip_tuples(important_nodes) | |
| elif self.attr_type== "vanilla_perturb": | |
| important_texts,importance_scores = self.vanilla_explanation(question, texts, answer,score_func) | |
| texts = split_context(self.explanation_level,contexts) | |
| important_ids = [texts.index(text) for text in important_texts] | |
| else: | |
| raise ValueError("Unsupported attr_type.") | |
| ensemble_list[score_func] = list(zip(important_ids,importance_scores)) | |
| for idx, important_id in enumerate(important_ids): | |
| if important_id in importance_dict: | |
| if importance_dict[important_id]<weight*importance_scores[idx]: | |
| max_score_func_dict[important_id] = score_func | |
| importance_dict[important_id] = max(importance_dict[important_id],weight*importance_scores[idx]) | |
| else: | |
| importance_dict[important_id] = weight*importance_scores[idx] | |
| max_score_func_dict[important_id] = score_func | |
| end_time = time.time() | |
| important_ids = list(importance_dict.keys()) | |
| importance_scores = list(importance_dict.values()) | |
| return texts,important_ids, importance_scores, end_time-start_time,ensemble_list | |