import textattack
import transformers
from FlowCorrector import Flow_Corrector
import torch
import torch.nn.functional as F
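# Note: FlowCorrector is a local module (not a PyPI package); the JSON
# word-rank and word-frequency files referenced below are assumed to be
# available in the working directory.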
def count_matching_classes(original, corrected):
    """Return the number of positions where the two class lists agree."""
    if len(original) != len(corrected):
        raise ValueError("Arrays must have the same length")

    matching_count = 0
    for i in range(len(corrected)):
        if original[i] == corrected[i]:
            matching_count += 1

    return matching_count
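
# Illustrative call (hypothetical labels, not actual model predictions):
# count_matching_classes([0, 1, 2, 3], [0, 1, 2, 0]) -> 3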

if __name__ == "__main__":
    # Load the victim model, tokenizer, and TextAttack model wrapper
    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        "textattack/bert-base-uncased-ag-news"
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "textattack/bert-base-uncased-ag-news"
    )
    model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

    # Construct our four components for `Attack`
    from textattack.constraints.pre_transformation import (
        RepeatModification,
        StopwordModification,
    )
    from textattack.constraints.semantics import WordEmbeddingDistance
    from textattack.transformations import WordSwapEmbedding
    from textattack.search_methods import GreedyWordSwapWIR
    goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
    constraints = [
        RepeatModification(),
        StopwordModification(),
        WordEmbeddingDistance(min_cos_sim=0.9),
    ]
    transformation = WordSwapEmbedding(max_candidates=50)
    search_method = GreedyWordSwapWIR(wir_method="weighted-saliency")

    # Construct the actual attack
    attack = textattack.Attack(goal_function, constraints, transformation, search_method)
    attack.cuda_()
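    # The combination above performs counter-fitted embedding word swaps, with
    # words to replace chosen greedily by TextAttack's "weighted-saliency"
    # word importance ranking.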

    # Initialize the corrector
    corrector = Flow_Corrector(
        attack,
        word_rank_file="en_full_ranked.json",
        word_freq_file="en_full_freq.json",
    )
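    # Assumption: the two JSON files are English word-rank / word-frequency
    # tables consumed by Flow_Corrector; they belong to this project, not to
    # TextAttack itself.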

    # All of these texts are adversarial examples
    with open("perturbed_texts_ag_news.txt", "r") as f:
        detected_texts = [line.strip() for line in f]

    # These are the original texts, in the same order as the adversarial ones
    with open("original_texts_ag_news.txt", "r") as f:
        original_texts = [line.strip() for line in f]

    victim_model = attack.goal_function.model

    # Get the original labels for benchmarking later; the model wrapper expects
    # a list of strings and returns logits, hence the single-element list.
    original_classes = [
        torch.argmax(F.softmax(victim_model([original_text]), dim=1)).item()
        for original_text in original_texts
    ]
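    # A single batched call such as victim_model(original_texts) would also
    # work and avoids one forward pass per text.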
| """ 0 :World | |
| 1 : Sports | |
| 2 : Business | |
| 3 : Sci/Tech""" | |
    # Run the corrector on the detected adversarial texts and count how many
    # recovered predictions match the original classes
    corrected_classes = corrector.correct(detected_texts)
    print(f"match {count_matching_classes(original_classes, corrected_classes)}")