Spaces:
Runtime error
Runtime error
| import json | |
| import hashlib | |
| import random | |
| import string | |
| import warnings | |
| import matplotlib.pyplot as plt | |
# --- UI copy for the demo (Markdown/HTML fragments shown in the app) -------

# Page title rendered as Markdown.
TITLE = "# MNIST Adversarial: Try to fool this MNIST model"

# Intro paragraph explaining dynamic adversarial data collection (DADC).
description = """This project is about dynamic adversarial data collection (DADC).
The basic idea is to collect “adversarial data” - the kind of data that is difficult for a model to predict correctly.
This kind of data is presumably the most valuable for a model, so this can be helpful in low-resource settings where data is hard to collect and label.
"""

# Step-by-step instructions. "{num_samples}" is a str.format placeholder —
# presumably filled in by the app code that displays this text (TODO: confirm
# against the caller).
WHAT_TO_DO="""
### What to do:
1. Draw any number from 0-9. The model will automatically try to predict it after drawing.
2. If the model misclassifies it, Flag that example.
3. This will add your (adversarial) example to a dataset on which the model will be trained later.
4. The model will finetune on the adversarial samples after every __{num_samples}__ samples have been generated.
"""

# Prompt shown when the user should flag a wrong/low-confidence prediction.
MODEL_IS_WRONG = """
---
### Did the model get it wrong or has a low confidence? Choose the correct prediction below and flag it. When you flag it, the instance is saved [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset) and the model learns from it periodically.
"""

# Placeholder metric line shown before any real evaluation has run.
# NOTE(review): "30/1000 (30%)" is arithmetically inconsistent (30/1000 is
# 3%) — harmless as dummy text, but confirm the intended placeholder values.
DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"

# NOTE(review): contains "{TEST_PER_SAMPLE}" but is NOT an f-string — it must
# be interpolated later via str.format, otherwise the literal braces are
# displayed to the user. Verify against the caller.
DASHBOARD_EXPLANATION="To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543). We are using {TEST_PER_SAMPLE} samples per digit."

# Caption for the combined out-of-distribution accuracy plot.
DASHBOARD_EXPLANATION_TEST="Test accuracy on out-of-distribution data for all numbers combined."

# Caption for the per-digit sample-distribution chart; "{num_adv_samples}" is
# a str.format placeholder filled in by the caller.
STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
def get_unique_name():
    """Return a random 32-character identifier of ASCII letters and digits.

    NOTE(review): built on the ``random`` module, so the result is NOT
    cryptographically secure — fine for file/sample names; switch to the
    ``secrets`` module if an adversary-resistant token is ever required.
    """
    # Hoist the alphabet out of the sampling call (the original rebuilt the
    # concatenation on every one of the 32 iterations) and use the idiomatic
    # random.choices(..., k=32) instead of a per-character loop.
    alphabet = string.ascii_letters + string.digits
    return ''.join(random.choices(alphabet, k=32))
def read_json(file):
    """Parse *file* (a UTF-8 JSON document) and return the decoded object."""
    with open(file, 'r', encoding="utf8") as handle:
        return json.load(handle)
def read_json_lines(file):
    """Read a JSON-lines file and return a list of decoded objects.

    Best-effort: on any failure (missing file, malformed line, ...) a
    warning is emitted and ``None`` is returned instead of raising, so
    callers can treat the result as "no data yet".
    """
    try:
        with open(file, 'r', encoding="utf8") as handle:
            # Iterate the handle directly instead of materialising
            # readlines() first; each line is one JSON document.
            return [json.loads(line) for line in handle]
    except Exception as err:
        warnings.warn(f"{err}")
        return None
def json_dump(thing):
    """Serialise *thing* to a compact, canonical JSON string.

    Keys are sorted and the separators carry no whitespace, so equal
    objects always serialise to byte-identical output (used for stable
    hashing); non-ASCII characters are kept as-is rather than escaped.
    """
    return json.dumps(
        thing,
        ensure_ascii=False,
        sort_keys=True,
        indent=None,
        separators=(',', ':'),
    )
def get_hash(thing):
    """Return a stable MD5 hex digest of *thing*'s canonical JSON form.

    MD5 is used purely as a content fingerprint here (deduplication /
    identity), not for anything security-sensitive.
    """
    canonical = json_dump(thing).encode('utf-8')
    return str(hashlib.md5(canonical).hexdigest())
def dump_json(thing, file):
    """Write *thing* to *file* as JSON using the default ``json.dump`` settings."""
    with open(file, 'w+', encoding="utf8") as handle:
        json.dump(thing, handle)
def plot_bar(value, name, x_name, y_name, title, set_yticks=False, set_xticks=False):
    """Render a horizontal bar chart and return its Matplotlib figure.

    value -- bar lengths (x axis); name -- bar positions/labels (y axis).
    set_yticks / set_xticks place integer ticks covering
    min(name)..max(name) inclusive (both assume *name* holds integers).

    NOTE(review): set_xticks also derives its tick range from *name*, even
    though *value* drives the x axis in barh — confirm that is intentional.
    """
    _, axes = plt.subplots(tight_layout=True)
    axes.set(xlabel=x_name, ylabel=y_name, title=title)
    tick_range = None
    if set_yticks or set_xticks:
        tick_range = range(min(name), max(name) + 1, 1)
    if set_yticks:
        axes.set_yticks(tick_range)
    if set_xticks:
        axes.set_xticks(tick_range)
    axes.barh(name, value)
    return axes.figure