Spaces:
Build error
Build error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from typing import List, Optional | |
| import networkx as nx | |
| import torch | |
| import llm_transparency_tool.routes.contributions as contributions | |
| from llm_transparency_tool.models.transparent_llm import TransparentLlm | |
class GraphBuilder:
    """
    Constructs the contributions graph with edges given one by one. The resulting graph
    is a networkx directed graph that can be accessed via the `graph` field. It
    contains the following types of nodes:
    - X0_<token>: the original token.
    - A<layer>_<token>: the residual stream after attention at the given layer for the
      given token.
    - M<layer>_<token>: the ffn block.
    - I<layer>_<token>: the residual stream after the ffn block.

    Every edge carries a "weight" attribute. Adding an edge between the same ordered
    pair of nodes more than once accumulates the weights on a single edge.
    """

    def __init__(self, n_layers: int, n_tokens: int):
        """Create the A/I/M nodes for every (layer, token) and the X0 node per token.

        No edges are created here; use the `add_*` methods to add them.
        """
        self._n_layers = n_layers
        self._n_tokens = n_tokens

        self.graph = nx.DiGraph()
        # Node insertion order: A, I, M for each (layer, token), then all X0 nodes.
        for layer in range(n_layers):
            for token in range(n_tokens):
                self.graph.add_node(f"A{layer}_{token}")
                self.graph.add_node(f"I{layer}_{token}")
                self.graph.add_node(f"M{layer}_{token}")
        for token in range(n_tokens):
            self.graph.add_node(f"X0_{token}")

    @staticmethod
    def _node_before_attn(layer: int, token: int) -> str:
        """Return the name of the node that feeds the attention block of `layer`.

        For layer 0 this is the original token node; for any later layer it is the
        residual stream after the ffn block of the previous layer.
        """
        if layer > 0:
            return f"I{layer - 1}_{token}"
        return f"X0_{token}"

    def get_output_node(self, token: int) -> str:
        """Return the name of the last layer's post-ffn residual node for `token`."""
        return f"I{self._n_layers - 1}_{token}"

    def _add_edge(self, u: str, v: str, weight: float) -> None:
        """Add `weight` to the edge u -> v, creating the edge if it does not exist."""
        # TODO(igortufanov): Here we sum up weights for multi-edges. It happens with
        # attention from the current token and the residual edge. Ideally these need to
        # be 2 separate edges, but then we need to do a MultiGraph. Multigraph is fine,
        # but when we try to traverse it, we face some NetworkX issue with EDGE_OK
        # receiving 3 arguments instead of 2.
        if self.graph.has_edge(u, v):
            self.graph[u][v]["weight"] += weight
        else:
            self.graph.add_edge(u, v, weight=weight)

    def add_attention_edge(
        self, layer: int, token_from: int, token_to: int, w: float
    ) -> None:
        """Add the attention edge from `token_from`'s input node to A<layer>_<token_to>."""
        self._add_edge(
            self._node_before_attn(layer, token_from),
            f"A{layer}_{token_to}",
            w,
        )

    def add_residual_to_attn(self, layer: int, token: int, w: float) -> None:
        """Add the residual edge that bypasses attention for `token` at `layer`.

        It connects the same pair of nodes as the attention edge from `token` to
        itself, so the two weights end up summed on one edge.
        """
        self._add_edge(
            self._node_before_attn(layer, token),
            f"A{layer}_{token}",
            w,
        )

    def add_ffn_edge(self, layer: int, token: int, w: float) -> None:
        """Add the path A -> M -> I through the ffn block, both edges with weight `w`."""
        self._add_edge(f"A{layer}_{token}", f"M{layer}_{token}", w)
        self._add_edge(f"M{layer}_{token}", f"I{layer}_{token}", w)

    def add_residual_to_ffn(self, layer: int, token: int, w: float) -> None:
        """Add the residual edge A -> I that bypasses the ffn block."""
        self._add_edge(f"A{layer}_{token}", f"I{layer}_{token}", w)
def build_full_graph(
    model: TransparentLlm,
    batch_i: int = 0,
    renormalizing_threshold: Optional[float] = None,
) -> nx.DiGraph:
    """
    Build the contribution graph for all blocks of the model and all tokens.

    model: The transparent llm which already did the inference.
    batch_i: Which sentence to use from the batch that was given to the model.
    renormalizing_threshold: If specified, will apply renormalizing thresholding to the
        contributions. All contributions below the threshold will be erased and the rest
        will be renormalized.

    Returns the directed graph assembled by `GraphBuilder`.
    """
    n_layers = model.model_info().n_layers
    n_tokens = model.tokens()[batch_i].shape[0]

    builder = GraphBuilder(n_layers, n_tokens)

    # Every tensor passed to `contributions` below is first sliced down to sentence
    # `batch_i` and then given a fresh leading axis of size 1. The results are
    # therefore indexed at position 0 on that axis, not at `batch_i` (which would be
    # out of range for any batch_i > 0).
    # NOTE(review): assumes the contributions helpers keep that leading axis — confirm.
    only = 0

    for layer in range(n_layers):
        resid_pre = model.residual_in(layer)[batch_i].unsqueeze(0)
        # Used by both the attention and the ffn contribution computations.
        resid_mid = model.residual_after_attn(layer)[batch_i].unsqueeze(0)

        c_attn, c_resid_attn = contributions.get_attention_contributions(
            resid_pre=resid_pre,
            resid_mid=resid_mid,
            decomposed_attn=model.decomposed_attn(batch_i, layer).unsqueeze(0),
        )
        if renormalizing_threshold is not None:
            c_attn, c_resid_attn = contributions.apply_threshold_and_renormalize(
                renormalizing_threshold, c_attn, c_resid_attn
            )
        for token_from in range(n_tokens):
            for token_to in range(n_tokens):
                # Sum attention contributions over heads.
                c = c_attn[only, token_to, token_from].sum().item()
                builder.add_attention_edge(layer, token_from, token_to, c)
        for token in range(n_tokens):
            builder.add_residual_to_attn(
                layer, token, c_resid_attn[only, token].item()
            )

        c_ffn, c_resid_ffn = contributions.get_mlp_contributions(
            resid_mid=resid_mid,
            resid_post=model.residual_out(layer)[batch_i].unsqueeze(0),
            mlp_out=model.ffn_out(layer)[batch_i].unsqueeze(0),
        )
        if renormalizing_threshold is not None:
            c_ffn, c_resid_ffn = contributions.apply_threshold_and_renormalize(
                renormalizing_threshold, c_ffn, c_resid_ffn
            )
        for token in range(n_tokens):
            builder.add_ffn_edge(layer, token, c_ffn[only, token].item())
            builder.add_residual_to_ffn(
                layer, token, c_resid_ffn[only, token].item()
            )

    return builder.graph
def build_paths_to_predictions(
    graph: nx.Graph,
    n_layers: int,
    n_tokens: int,
    starting_tokens: List[int],
    threshold: float,
) -> List[nx.Graph]:
    """
    Given the full graph, this function returns only the trees leading to the specified
    tokens. Edges with weight below `threshold` will be ignored.

    graph: The full contribution graph. It must be a directed graph, because it is
        reversed for the search.
    n_layers: Number of layers the graph was built for.
    n_tokens: Number of tokens the graph was built for.
    starting_tokens: Token positions whose top-layer output nodes are the search roots.
    threshold: Only edges whose weight is strictly greater than this are followed.

    Returns one graph per entry of `starting_tokens`, in the same order, with edges
    pointing in the same direction as in `graph`.

    Raises ValueError if any starting token is outside [0, n_tokens).
    """
    # Validate with a real exception rather than `assert`, so the check is not
    # stripped when Python runs with -O.
    for start in starting_tokens:
        if not 0 <= start < n_tokens:
            raise ValueError(
                f"starting token {start} is out of range [0, {n_tokens})"
            )

    # The builder is only used to resolve output node names.
    builder = GraphBuilder(n_layers, n_tokens)

    reversed_graph = graph.reverse()

    def is_above_threshold(u, v) -> bool:
        return reversed_graph[u][v]["weight"] > threshold

    search_graph = nx.subgraph_view(reversed_graph, filter_edge=is_above_threshold)

    result = []
    for start in starting_tokens:
        root = builder.get_output_node(start)
        reachable_edges = nx.edge_dfs(search_graph, source=root)
        tree = search_graph.edge_subgraph(reachable_edges)
        # Reverse the edges because the dfs was going from upper layer downwards.
        result.append(tree.reverse())

    return result