Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import random | |
| import numpy as np | |
| import torch | |
| import textattack | |
# Device for model execution. If the `TA_DEVICE` environment variable is set,
# its raw string value is used; otherwise a `torch.device` is chosen: CUDA
# when available, CPU if not.
device = os.getenv(
    "TA_DEVICE",
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
)
def html_style_from_dict(style_dict):
    """Render a mapping of CSS properties as an inline HTML ``style`` attribute.

    For example, ``{"color": "red", "height": "100px"}`` becomes
    ``style="color: red;height: 100px;"``. Each declaration is terminated by
    ``;`` and no space is inserted between declarations.

    Args:
        style_dict: Mapping of CSS property name to value. Keys and values are
            converted with ``str()``, so non-string values such as numbers
            are accepted.

    Returns:
        The string ``style="..."``; ``style=""`` for an empty mapping.
    """
    # f-string formatting stringifies keys and values, so a numeric value no
    # longer raises TypeError the way ``str + int`` concatenation did.
    declarations = "".join(
        f"{key!s}: {style_dict[key]!s};" for key in style_dict
    )
    return f'style="{declarations}"'
def html_table_from_rows(rows, title=None, header=None, style_dict=None):
    """Build an HTML table, wrapped in a container ``<div>``, from rows of cells.

    Args:
        rows: Iterable of rows; each row is an iterable of cells. Cells are
            converted with ``str()`` and inserted verbatim (they are not
            HTML-escaped), so a cell may itself contain markup.
        title: Optional heading, rendered as ``<h1>`` above the table.
        header: Optional iterable of column headings, rendered as one row of
            ``<th>`` cells before the data rows.
        style_dict: Optional mapping of CSS properties applied to the
            container ``<div>`` via ``html_style_from_dict``.

    Returns:
        The complete markup as a single string.
    """
    parts = []

    # Stylize the container div.
    if style_dict:
        parts.append("<div {}>".format(html_style_from_dict(style_dict)))
    else:
        parts.append("<div>")

    # Emit the title string.
    if title:
        parts.append("<h1>{}</h1>".format(title))

    # The table is appended to what was built above. Previously the string was
    # reassigned here, which dropped the container div and the title and left
    # the closing </div> without a matching opening tag.
    parts.append('<table class="table">')

    if header:
        parts.append("<tr>")
        parts.extend(f"<th>{element!s}</th>" for element in header)
        parts.append("</tr>")

    for row in rows:
        parts.append("<tr>")
        parts.extend(f"<td>{element!s}</td>" for element in row)
        parts.append("</tr>")

    # Close the table and its container.
    parts.append("</table></div>")
    return "".join(parts)
def get_textattack_model_num_labels(model_name, model_path):
    """Return the number of labels recorded for a trained TextAttack model.

    Looks for ``train_args.json`` inside the directory returned by
    ``textattack.shared.utils.download_from_s3(model_path)`` and reads its
    ``"num_labels"`` entry.

    Args:
        model_name: Not used by this function; kept so existing callers that
            pass it keep working.
        model_path: Model path handed to ``download_from_s3``.

    Returns:
        The ``"num_labels"`` value from ``train_args.json``, or ``2`` when the
        file does not exist or has no such key.
    """
    model_cache_path = textattack.shared.utils.download_from_s3(model_path)
    train_args_path = os.path.join(model_cache_path, "train_args.json")

    if not os.path.exists(train_args_path):
        textattack.shared.logger.warn(
            f"train_args.json not found in model path {model_path}. Defaulting to 2 labels."
        )
        return 2

    # Use a context manager so the file handle is closed deterministically;
    # a bare ``open(...).read()`` leaves it open until garbage collection.
    with open(train_args_path) as train_args_file:
        args = json.load(train_args_file)
    return args.get("num_labels", 2)
def load_textattack_model_from_path(model_name, model_path):
    """Instantiate a pre-trained TextAttack model chosen by its name prefix.

    The prefix of ``model_name`` selects the model class: ``"lstm"``,
    ``"cnn"`` or ``"t5"``. For example, model_name "lstm-yelp" with model
    path "models/classification/lstm/yelp".

    Args:
        model_name: Model identifier; only its prefix decides the class.
        model_path: Path passed on to the model constructor.

    Returns:
        The constructed model object.

    Raises:
        ValueError: If ``model_name`` starts with none of the known prefixes.
    """
    # Name rendered in blue (ANSI) for the log messages below.
    display_name = textattack.shared.utils.color_text(
        model_name, color="blue", method="ansi"
    )

    if model_name.startswith("lstm"):
        label_count = get_textattack_model_num_labels(model_name, model_path)
        textattack.shared.logger.info(
            f"Loading pre-trained TextAttack LSTM: {display_name}"
        )
        return textattack.models.helpers.LSTMForClassification(
            model_path=model_path, num_labels=label_count
        )

    if model_name.startswith("cnn"):
        label_count = get_textattack_model_num_labels(model_name, model_path)
        textattack.shared.logger.info(
            f"Loading pre-trained TextAttack CNN: {display_name}"
        )
        return textattack.models.helpers.WordCNNForClassification(
            model_path=model_path, num_labels=label_count
        )

    if model_name.startswith("t5"):
        return textattack.models.helpers.T5ForTextToText(model_path)

    raise ValueError(f"Unknown textattack model {model_path}")
def set_seed(random_seed):
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and CUDA) generators.

    Args:
        random_seed: Seed value passed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    # Same call order as before: random, numpy, torch, torch.cuda.
    for seed_generator in seeders:
        seed_generator(random_seed)
def hashable(key):
    """Return True if ``hash(key)`` succeeds, False if it raises ``TypeError``."""
    try:
        hash(key)
    except TypeError:
        return False
    return True
def sigmoid(n):
    """Return the logistic sigmoid ``1 / (1 + exp(-n))`` computed with NumPy."""
    denominator = 1 + np.exp(-n)
    return 1 / denominator
# Module-level registry, initially empty. NOTE(review): presumably filled in
# by other modules at runtime; no writer is visible in this file.
GLOBAL_OBJECTS: dict = {}
# Separator character. NOTE(review): the name suggests it delimits packed
# argument strings; confirm against its users.
ARGS_SPLIT_TOKEN: str = "^"