import string

import torch
from torchprofile import profile_macs

from blackbox_utils.Attack_base import MyAttack
class CharacterAttack(MyAttack):
    # TODO: keep a list so that each round mutates a different token position
    def __init__(self, name, model, tokenizer, device, max_per, padding,
                 max_length, label_to_id, sentence1_key, sentence2_key):
        super().__init__(name, model, tokenizer, device, max_per, padding,
                         max_length, label_to_id, sentence1_key, sentence2_key)
    def compute_importance(self, text):
        # Rank token ids by importance: delete one token at a time and
        # measure how the loss of the remaining sentence changes.
        current_tensor = self.preprocess_function(text)["input_ids"][0]
        word_losses = {}
        # Skip the special tokens at both ends ([CLS] ... [SEP]).
        for idx in range(1, len(current_tensor) - 1):
            sentence_tokens_without = torch.cat([current_tensor[:idx], current_tensor[idx + 1:]])
            sentence_without = self.tokenizer.decode(sentence_tokens_without)
            sentence_without = [sentence_without, text[1]]
            word_losses[int(current_tensor[idx])] = self.compute_loss(sentence_without)
        # Token ids sorted by descending loss, i.e. most important first.
        word_losses = [k for k, _ in sorted(word_losses.items(), key=lambda item: item[1], reverse=True)]
        return word_losses
    def compute_loss(self, text):
        # Efficiency objective: profile the model's MACs on this input and
        # normalise by the token length of the raw (sentence1, sentence2) pair.
        inputs = self.preprocess_function(text)
        shift_inputs = (inputs['input_ids'], inputs['attention_mask'], inputs['token_type_ids'])
        macs = profile_macs(self.model, shift_inputs)
        # Tokenize the text pair itself to count its tokens.
        result = self.random_tokenizer(*text, padding=self.padding,
                                       max_length=self.max_length, truncation=True)
        token_length = len(result["input_ids"])
        macs_per_token = macs / (token_length * 10 ** 8)
        return self.predict(macs_per_token)
    def mutation(self, current_adv_text):
        # Produce candidate adversarial texts via character-level edits.
        current_tensor = self.preprocess_function(current_adv_text)["input_ids"][0]
        new_strings = self.character_replace_mutation(current_adv_text, current_tensor)
        return new_strings
    @staticmethod
    def transfer(c: str):
        # Swap the case of a single ASCII letter; leave other characters unchanged.
        if c in string.ascii_lowercase:
            return c.upper()
        elif c in string.ascii_uppercase:
            return c.lower()
        return c
    def character_replace_mutation(self, current_text, current_tensor):
        important_tensor = self.compute_importance(current_text)
        new_strings = [current_text]
        # Walk the vocabulary ids by importance and mutate the first token
        # that actually appears in the current text.
        for t in important_tensor:
            if int(t) not in current_tensor:
                continue
            ori_token = self.tokenizer.decode([int(t)])
            # Skip single-character tokens and tokens absent from the raw string.
            if len(ori_token) == 1 or ori_token not in current_text[0]:  # TODO
                continue
            # Insert one character at every position inside the token ...
            candidate = [ori_token[:i] + insert + ori_token[i:]
                         for i in range(len(ori_token))
                         for insert in self.insert_character]
            # ... and flip the case of one character at a time.
            candidate += [ori_token[:i - 1] + self.transfer(ori_token[i - 1]) + ori_token[i:]
                          for i in range(1, len(ori_token))]
            # Replace at most one occurrence of the token in sentence1.
            new_strings += [[current_text[0].replace(ori_token, c, 1), current_text[1]]
                            for c in candidate]
            # Return as soon as one token has produced valid mutations.
            if len(new_strings) > 1:
                return new_strings
        return new_strings
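

# A minimal sketch of the candidate generation performed inside
# character_replace_mutation, isolated as a pure function for clarity.
# `insert_chars` stands in for the `insert_character` attribute inherited
# from MyAttack; its exact contents are an assumption here.
def _candidate_sketch(token, insert_chars=string.ascii_lowercase):
    # One-character insertions at every interior position of the token.
    cands = [token[:i] + ins + token[i:]
             for i in range(len(token)) for ins in insert_chars]
    # Case flips, one character at a time (mirrors the attack's second list).
    cands += [token[:i - 1] + CharacterAttack.transfer(token[i - 1]) + token[i:]
              for i in range(1, len(token))]
    return cands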
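

# Hedged, standalone usage sketch of the MACs-per-token signal behind
# compute_loss, runnable with `transformers` and `torchprofile` installed.
# The checkpoint name and example sentences are assumptions for illustration,
# not part of this repo; any BERT-style model with token_type_ids works.
if __name__ == "__main__":
    from transformers import BertTokenizer, BertForSequenceClassification

    demo_tok = BertTokenizer.from_pretrained("bert-base-uncased")
    demo_model = BertForSequenceClassification.from_pretrained("bert-base-uncased").eval()
    enc = demo_tok("a first sentence", "a second sentence", return_tensors="pt")
    with torch.no_grad():
        macs = profile_macs(demo_model,
                            (enc["input_ids"], enc["attention_mask"], enc["token_type_ids"]))
    # Same normalisation as compute_loss: MACs per token, in units of 10^8.
    print(macs / (enc["input_ids"].shape[1] * 10 ** 8))
    print(_candidate_sketch("example")[:5])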