Spaces:
Runtime error
Runtime error
| import copy | |
| import math | |
| import numpy as np | |
| from gradio import utils | |
| from gradio.components import Label, Number | |
async def run_interpret(interface, raw_input):
    """
    Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box
    interpretation for a certain set of UI component types, as well as the custom interpretation case.

    Parameters:
        interface: the interface being interpreted. This function reads its `interpretation`,
            `input_components`, `output_components` and `num_shap` attributes and calls its
            `call_function` coroutine.
        raw_input: a list of raw inputs to apply the interpretation(s) on.
    Returns:
        A `(scores, alternative_outputs)` tuple.
        When `interface.interpretation` is a list, `scores` holds one entry per input (the value
        returned by the component's `get_interpretation_scores`, or None when the method for that
        input is None) and `alternative_outputs` holds, per input, the postprocessed outputs obtained
        for each neighbor (an empty list for the "shap" and None methods).
        Otherwise `interface.interpretation` is called as a custom function on the preprocessed
        inputs; its result is returned (wrapped in a list when there is a single input) together
        with an empty list.
    Raises:
        ValueError: if `shap` is requested but not installed, if the input component does not
            support `shap`, or if the interpretation method is unknown.
    """
    # Both interpretation modes start from the same preprocessed inputs.
    processed_input = [
        input_component.preprocess(raw_input[i])
        for i, input_component in enumerate(interface.input_components)
    ]

    if not isinstance(interface.interpretation, list):  # custom interpretation function
        interpreter = interface.interpretation
        interpretation = interpreter(*processed_input)
        if len(raw_input) == 1:
            interpretation = [interpretation]
        return interpretation, []

    # Either "default" or "shap": compute the baseline prediction once.
    original_output = await interface.call_function(0, processed_input)
    original_output = original_output["prediction"]
    if len(interface.output_components) == 1:
        original_output = [original_output]

    scores, alternative_outputs = [], []
    for i, (x, interp) in enumerate(zip(raw_input, interface.interpretation)):
        if interp == "default":
            input_component = interface.input_components[i]
            if input_component.interpret_by_tokens:
                tokens, neighbor_values, masks = input_component.tokenize(x)
                interface_scores, alternative_output = await _collect_neighbor_scores(
                    interface, raw_input, i, neighbor_values, original_output
                )
                alternative_outputs.append(alternative_output)
                scores.append(
                    input_component.get_interpretation_scores(
                        raw_input[i],
                        neighbor_values,
                        interface_scores,
                        masks=masks,
                        tokens=tokens,
                    )
                )
            else:
                (
                    neighbor_values,
                    interpret_kwargs,
                ) = input_component.get_interpretation_neighbors(x)
                interface_scores, alternative_output = await _collect_neighbor_scores(
                    interface, raw_input, i, neighbor_values, original_output
                )
                alternative_outputs.append(alternative_output)
                # Only the neighbor-based (non-token) path flips the sign of the scores.
                interface_scores = [-score for score in interface_scores]
                scores.append(
                    input_component.get_interpretation_scores(
                        raw_input[i],
                        neighbor_values,
                        interface_scores,
                        **interpret_kwargs
                    )
                )
        elif interp == "shap" or interp == "shapley":
            try:
                import shap  # type: ignore
            except ImportError as err:  # ModuleNotFoundError is a subclass of ImportError
                raise ValueError(
                    "The package `shap` is required for this interpretation method. Try: `pip install shap`"
                ) from err
            input_component = interface.input_components[i]
            if not input_component.interpret_by_tokens:
                raise ValueError(
                    "Input component {} does not support `shap` interpretation".format(
                        input_component
                    )
                )

            tokens, _, masks = input_component.tokenize(x)

            # construct a masked version of the input
            def get_masked_prediction(binary_mask):
                masked_xs = input_component.get_masked_inputs(tokens, binary_mask)
                preds = []
                for masked_x in masked_xs:
                    processed_masked_input = copy.deepcopy(processed_input)
                    processed_masked_input[i] = input_component.preprocess(masked_x)
                    # This callback is a plain (non-async) function, so the
                    # coroutine is driven through utils.synchronize_async.
                    new_output = utils.synchronize_async(
                        interface.call_function, 0, processed_masked_input
                    )
                    new_output = new_output["prediction"]
                    if len(interface.output_components) == 1:
                        new_output = [new_output]
                    pred = get_regression_or_classification_value(
                        interface, original_output, new_output
                    )
                    preds.append(pred)
                return np.array(preds)

            num_total_segments = len(tokens)
            explainer = shap.KernelExplainer(
                get_masked_prediction, np.zeros((1, num_total_segments))
            )
            shap_values = explainer.shap_values(
                np.ones((1, num_total_segments)),
                nsamples=int(interface.num_shap * num_total_segments),
                silent=True,
            )
            scores.append(
                input_component.get_interpretation_scores(
                    raw_input[i], None, shap_values[0], masks=masks, tokens=tokens
                )
            )
            alternative_outputs.append([])
        elif interp is None:
            scores.append(None)
            alternative_outputs.append([])
        else:
            raise ValueError("Unknown interpretation method: {}".format(interp))
    return scores, alternative_outputs


async def _collect_neighbor_scores(
    interface, raw_input, index, neighbor_values, original_output
):
    """Run the model once per neighbor of input `index`; return (scores, postprocessed outputs)."""
    neighbor_raw_input = list(raw_input)
    interface_scores = []
    alternative_output = []
    for neighbor_input in neighbor_values:
        # Swap in the neighbor for the input under study; other inputs stay as given.
        neighbor_raw_input[index] = neighbor_input
        processed_neighbor_input = [
            component.preprocess(neighbor_raw_input[j])
            for j, component in enumerate(interface.input_components)
        ]
        neighbor_output = await interface.call_function(0, processed_neighbor_input)
        neighbor_output = neighbor_output["prediction"]
        if len(interface.output_components) == 1:
            neighbor_output = [neighbor_output]
        processed_neighbor_output = [
            component.postprocess(neighbor_output[j])
            for j, component in enumerate(interface.output_components)
        ]
        alternative_output.append(processed_neighbor_output)
        interface_scores.append(
            quantify_difference_in_label(interface, original_output, neighbor_output)
        )
    return interface_scores, alternative_output
def diff(original, perturbed):
    """
    Scores how much `perturbed` differs from `original`.

    Parameters:
        original: the reference value (a number, a numeric string, or any label).
        perturbed: the value to compare against the reference.
    Returns:
        `float(original) - float(perturbed)` when both values can be converted to float;
        otherwise 0 if the two values compare equal and 1 if they do not.
    """
    try:  # try computing numerical difference
        score = float(original) - float(perturbed)
    except (TypeError, ValueError):  # otherwise, look at strict difference in label
        # float() raises ValueError for non-numeric strings but TypeError for
        # values such as None, so both must fall back to label comparison.
        score = int(not (original == perturbed))
    return score
def quantify_difference_in_label(interface, original_output, perturbed_output):
    """
    Scores the change in the first output between the original and a perturbed prediction.

    Parameters:
        interface: the interface whose first output component decides how the score is computed.
        original_output: list of raw outputs for the unmodified input; only index 0 is used.
        perturbed_output: list of raw outputs for the modified input; only index 0 is used.
    Returns:
        For a Label output whose postprocessed value has "confidences": the raw value stored
        under the original label in the original output minus the one in the perturbed output.
        For a Label output without "confidences": `diff` of the two labels.
        For a Number output: `diff` of the two postprocessed values.
    Raises:
        ValueError: if the first output component is neither a Label nor a Number.
    """
    component = interface.output_components[0]
    baseline = component.postprocess(original_output[0])
    perturbed = component.postprocess(perturbed_output[0])

    if isinstance(component, Label):
        baseline_label = baseline["label"]
        changed_label = perturbed["label"]
        # Handle different return types of Label interface
        if "confidences" not in baseline:
            return diff(baseline_label, changed_label)
        # Compare the confidence assigned to the originally predicted label.
        return original_output[0][baseline_label] - perturbed_output[0][baseline_label]

    if isinstance(component, Number):
        return diff(baseline, perturbed)

    raise ValueError(
        "This interpretation method doesn't support the Output component: {}".format(
            component
        )
    )
def get_regression_or_classification_value(
    interface, original_output, perturbed_output
):
    """Used to combine regression/classification for Shap interpretation method.

    Parameters:
        interface: the interface whose first output component must be a Label (or a subclass).
        original_output: list of raw outputs for the unmodified input; only index 0 is used.
        perturbed_output: list of raw outputs for the masked input; only index 0 is used.
    Returns:
        When the postprocessed original output has "confidences": the value stored in the
        perturbed raw output under the original label, or 0 if that value is NaN.
        Otherwise: `diff(perturbed_label, original_label)`.
    Raises:
        ValueError: if the first output component is not a Label.
    """
    output_component = interface.output_components[0]
    post_original_output = output_component.postprocess(original_output[0])
    post_perturbed_output = output_component.postprocess(perturbed_output[0])

    # isinstance (rather than an exact type comparison) so that Label subclasses
    # are accepted, matching quantify_difference_in_label.
    if not isinstance(output_component, Label):
        raise ValueError(
            "This interpretation method doesn't support the Output component: {}".format(
                output_component
            )
        )

    original_label = post_original_output["label"]
    # Handle different return types of Label interface
    if "confidences" in post_original_output:
        perturbed_confidence = perturbed_output[0][original_label]
        if math.isnan(perturbed_confidence):
            return 0
        return perturbed_confidence

    perturbed_label = post_perturbed_output["label"]
    # Intentionally inverted order of arguments.
    return diff(perturbed_label, original_label)