import numpy as np

SEP = '[SEP]'
CLS = '[CLS]'
MASK = '[MASK]'


def drop_bad_inds(arr, left_drop, right_drop):
    """Given the 4d attention array of shape (n_layer, n_head, n_left_text, n_right_text),
    return that array with the positions flagged by the boolean mask left_drop
    removed from n_left_text and those flagged by right_drop removed from n_right_text.
    """
    arr = arr[:, :, ~left_drop, :]
    # Keys and queries don't match in the final dimension
    if arr.shape[-1] == len(right_drop):
        arr = arr[:, :, :, ~right_drop]
    return arr
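
# A minimal usage sketch of drop_bad_inds (the shapes and masks below are
# hypothetical toy data, not output from the original model): a 2-layer,
# 2-head attention map over 4 left and 4 right tokens, dropping position 0
# on each side.
def _demo_drop_bad_inds():
    arr = np.zeros((2, 2, 4, 4))
    left_drop = np.array([True, False, False, False])
    right_drop = np.array([True, False, False, False])
    out = drop_bad_inds(arr, left_drop, right_drop)
    # One position removed from each of the last two axes.
    assert out.shape == (2, 2, 3, 3)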

def strip_attention(attention):
    """Given an attention output of the BERT model,
    return the same object without CLS and SEP token weightings.

    NOTE: Not currently fixing key and query
    """
    attention_out = {}
    # Iterate through sentence combinations.
    # Each entry needs queries, keys, att, left_text, right_text
    for k, v in attention.items():
        stripped_resp = {}
        left_tokens = np.array(v['left_text'])
        right_tokens = np.array(v['right_text'])
        att = np.array(v['att'])
        # key = np.array(v['keys'])
        # quer = np.array(v['queries'])
        left_drop = (left_tokens == CLS) | (left_tokens == SEP)
        right_drop = (right_tokens == CLS) | (right_tokens == SEP)
        att_out = drop_bad_inds(att, left_drop, right_drop)
        # key_out = drop_bad_inds(key, left_drop, right_drop)
        # quer_out = drop_bad_inds(quer, left_drop, right_drop)
        left_out = left_tokens[~left_drop]
        right_out = right_tokens[~right_drop]
        # assert att_out.shape[:3] == key_out.shape[:3] == quer_out.shape[:3]
        assert att_out.shape[2] == len(left_out)
        assert att_out.shape[3] == len(right_out)
        stripped_resp['att'] = att_out.tolist()
        stripped_resp['keys'] = v['keys']
        stripped_resp['queries'] = v['queries']
        stripped_resp['left_text'] = left_out.tolist()
        stripped_resp['right_text'] = right_out.tolist()
        attention_out[k] = stripped_resp
    return attention_out
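
# A minimal usage sketch of strip_attention with a hand-built entry. The
# dict layout mirrors the fields read above; the key name 'all' and the
# zero-filled attention values are hypothetical toy data.
def _demo_strip_attention():
    left = [CLS, 'the', 'cat', SEP]
    right = [CLS, 'a', 'dog', SEP]
    attention = {
        'all': {
            'att': np.zeros((2, 2, len(left), len(right))).tolist(),
            'keys': None,
            'queries': None,
            'left_text': left,
            'right_text': right,
        }
    }
    out = strip_attention(attention)
    # CLS and SEP are gone; the attention array shrank to match.
    assert out['all']['left_text'] == ['the', 'cat']
    assert out['all']['right_text'] == ['a', 'dog']
    assert np.array(out['all']['att']).shape == (2, 2, 2, 2)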

def mask_attention(deets, maskA, maskB):
    """Deets have the form:
    (tokens_a, tokens_b, query_tensor.data.numpy(), key_tensor.data.numpy(), attn_tensor.data.numpy())

    Mask the first two entries of the tuple according to maskA and maskB,
    which are lists of indices to mask.
    """
    deets = list(deets)  # tuples are immutable, so edit a list copy
    # Use object dtype so '[MASK]' is not truncated by numpy's
    # fixed-width string dtype.
    tokens_a = np.array(deets[0], dtype=object)
    tokens_a[maskA] = MASK
    tokens_b = np.array(deets[1], dtype=object)
    tokens_b[maskB] = MASK
    deets[0] = tokens_a.tolist()
    deets[1] = tokens_b.tolist()
    return deets
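
# A minimal usage sketch of mask_attention (hypothetical toy tokens; the
# tensor fields are left as None since only the first two entries are
# touched here).
def _demo_mask_attention():
    deets = (['the', 'cat', 'sat'], ['a', 'dog', 'ran'], None, None, None)
    out = mask_attention(deets, maskA=[1], maskB=[0, 2])
    assert out[0] == ['the', MASK, 'sat']
    assert out[1] == [MASK, 'dog', 'ran']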