Spaces:
Runtime error
Runtime error
fix to dashboard not loading
Browse files- app.py +25 -18
- data_mnist +1 -1
- utils.py +20 -9
app.py
CHANGED
|
@@ -37,7 +37,7 @@ os.makedirs(LOCAL_DIR,exist_ok=True)
|
|
| 37 |
|
| 38 |
|
| 39 |
|
| 40 |
-
|
| 41 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 42 |
MODEL_REPO = 'mnist-adversarial-model'
|
| 43 |
HF_DATASET ="mnist-adversarial-dataset"
|
|
@@ -74,10 +74,11 @@ class MNISTAdversarial_Dataset(Dataset):
|
|
| 74 |
|
| 75 |
image_path =os.path.join(self.FOLDER,'image.png')
|
| 76 |
if os.path.exists(image_path) and os.path.exists(metadata_path):
|
| 77 |
-
img = Image.open(image_path)
|
| 78 |
-
self.images.append(img)
|
| 79 |
metadata = read_json_lines(metadata_path)
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
| 81 |
assert len(self.images)==len(self.numbers), f"Length of images and numbers must be the same. Got {len(self.images)} for images and {len(self.numbers)} for numbers."
|
| 82 |
def __len__(self):
|
| 83 |
return len(self.images)
|
|
@@ -395,8 +396,15 @@ def flag(input_image,correct_result,adversarial_number):
|
|
| 395 |
return output,adversarial_number
|
| 396 |
|
| 397 |
def get_number_dict(DATA_DIR):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 398 |
files = [f.name for f in os.scandir(DATA_DIR)]
|
| 399 |
-
|
|
|
|
| 400 |
numbers_count = Counter(numbers)
|
| 401 |
numbers_count_keys = list(numbers_count.keys())
|
| 402 |
numbers_count_values = [numbers_count[k] for k in numbers_count_keys]
|
|
@@ -425,10 +433,8 @@ def get_statistics():
|
|
| 425 |
repo.git_pull()
|
| 426 |
DATA_DIR = './data_mnist/data'
|
| 427 |
numbers_count_keys,numbers_count_values = get_number_dict(DATA_DIR)
|
| 428 |
-
|
| 429 |
STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples = sum(numbers_count_values))
|
| 430 |
-
|
| 431 |
-
plt_digits = plot_bar(numbers_count_values,numbers_count_keys,'Number of adversarial samples',"Digit",f"Distribution of adversarial samples per digit")
|
| 432 |
|
| 433 |
fig_d, ax_d = plt.subplots(tight_layout=True)
|
| 434 |
|
|
@@ -440,26 +446,25 @@ def get_statistics():
|
|
| 440 |
ax_d.plot(x_i, metric_dict[str(i)],label=str(i))
|
| 441 |
except Exception:
|
| 442 |
continue
|
|
|
|
| 443 |
|
| 444 |
else:
|
| 445 |
metric_dict={}
|
| 446 |
|
| 447 |
fig_d.legend()
|
| 448 |
ax_d.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy over digits per train step")
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
<p> ✅ Statistics loaded successfully!</p>
|
| 452 |
</div>
|
| 453 |
"""
|
| 454 |
-
|
| 455 |
# Plot for total test accuracy for all digits
|
| 456 |
fig_all, ax_all = plt.subplots(tight_layout=True)
|
| 457 |
x_i = [i+1 for i in range(len(metric_dict['all']))]
|
| 458 |
|
| 459 |
ax_all.plot(x_i, metric_dict['all'])
|
| 460 |
-
fig_all.legend()
|
| 461 |
ax_all.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy for all digits")
|
| 462 |
-
|
|
|
|
| 463 |
return plt_digits,ax_d.figure,ax_all.figure,done_html,STATS_EXPLANATION_
|
| 464 |
|
| 465 |
|
|
@@ -485,7 +490,7 @@ def main():
|
|
| 485 |
|
| 486 |
number_dropdown = gr.Dropdown(choices=[i for i in range(10)],type='value',default=None,label="What was the correct prediction?")
|
| 487 |
|
| 488 |
-
|
| 489 |
flag_btn = gr.Button("Flag")
|
| 490 |
|
| 491 |
output_result = gr.outputs.HTML()
|
|
@@ -496,8 +501,9 @@ def main():
|
|
| 496 |
flag_btn.click(flag,inputs=[image_input,number_dropdown,adversarial_number],outputs=[output_result,adversarial_number])
|
| 497 |
|
| 498 |
with gr.TabItem('Dashboard') as dashboard:
|
| 499 |
-
|
| 500 |
-
|
|
|
|
| 501 |
</div>
|
| 502 |
""")
|
| 503 |
|
|
@@ -508,7 +514,8 @@ def main():
|
|
| 508 |
gr.Markdown(DASHBOARD_EXPLANATION_TEST)
|
| 509 |
test_results_all=gr.Plot(type="matplotlib")
|
| 510 |
|
| 511 |
-
dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,
|
|
|
|
| 512 |
|
| 513 |
|
| 514 |
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
|
| 40 |
+
GET_STATISTICS_MESSAGE = "Get Statistics"
|
| 41 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 42 |
MODEL_REPO = 'mnist-adversarial-model'
|
| 43 |
HF_DATASET ="mnist-adversarial-dataset"
|
|
|
|
| 74 |
|
| 75 |
image_path =os.path.join(self.FOLDER,'image.png')
|
| 76 |
if os.path.exists(image_path) and os.path.exists(metadata_path):
|
|
|
|
|
|
|
| 77 |
metadata = read_json_lines(metadata_path)
|
| 78 |
+
if metadata is not None:
|
| 79 |
+
img = Image.open(image_path)
|
| 80 |
+
self.images.append(img)
|
| 81 |
+
self.numbers.append(metadata[0]['correct_number'])
|
| 82 |
assert len(self.images)==len(self.numbers), f"Length of images and numbers must be the same. Got {len(self.images)} for images and {len(self.numbers)} for numbers."
|
| 83 |
def __len__(self):
|
| 84 |
return len(self.images)
|
|
|
|
| 396 |
return output,adversarial_number
|
| 397 |
|
| 398 |
def get_number_dict(DATA_DIR):
|
| 399 |
+
"""
|
| 400 |
+
It takes a directory as input, and returns a list of the number of times each number appears in the
|
| 401 |
+
metadata.jsonl files in that directory
|
| 402 |
+
|
| 403 |
+
:param DATA_DIR: The directory where the data is stored
|
| 404 |
+
"""
|
| 405 |
files = [f.name for f in os.scandir(DATA_DIR)]
|
| 406 |
+
metadata_jsons = [read_json_lines(os.path.join(os.path.join(DATA_DIR,f),'metadata.jsonl')) for f in files]
|
| 407 |
+
numbers = [m[0]['correct_number'] for m in metadata_jsons if m is not None]
|
| 408 |
numbers_count = Counter(numbers)
|
| 409 |
numbers_count_keys = list(numbers_count.keys())
|
| 410 |
numbers_count_values = [numbers_count[k] for k in numbers_count_keys]
|
|
|
|
| 433 |
repo.git_pull()
|
| 434 |
DATA_DIR = './data_mnist/data'
|
| 435 |
numbers_count_keys,numbers_count_values = get_number_dict(DATA_DIR)
|
|
|
|
| 436 |
STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples = sum(numbers_count_values))
|
| 437 |
+
plt_digits = plot_bar(numbers_count_values,numbers_count_keys,'Number of adversarial samples',"Digit",f"Distribution of adversarial samples per digit",True)
|
|
|
|
| 438 |
|
| 439 |
fig_d, ax_d = plt.subplots(tight_layout=True)
|
| 440 |
|
|
|
|
| 446 |
ax_d.plot(x_i, metric_dict[str(i)],label=str(i))
|
| 447 |
except Exception:
|
| 448 |
continue
|
| 449 |
+
ax_d.set_xticks(range(0, len(metric_dict['0'])+1, 1))
|
| 450 |
|
| 451 |
else:
|
| 452 |
metric_dict={}
|
| 453 |
|
| 454 |
fig_d.legend()
|
| 455 |
ax_d.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy over digits per train step")
|
| 456 |
+
done_html = f"""<div style="color: green">
|
| 457 |
+
<p> ✅ Statistics loaded successfully! Click `{GET_STATISTICS_MESSAGE}`to reload.</p>
|
|
|
|
| 458 |
</div>
|
| 459 |
"""
|
|
|
|
| 460 |
# Plot for total test accuracy for all digits
|
| 461 |
fig_all, ax_all = plt.subplots(tight_layout=True)
|
| 462 |
x_i = [i+1 for i in range(len(metric_dict['all']))]
|
| 463 |
|
| 464 |
ax_all.plot(x_i, metric_dict['all'])
|
|
|
|
| 465 |
ax_all.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy for all digits")
|
| 466 |
+
ax_all.set_xticks(range(0, x_i[-1]+1, 1))
|
| 467 |
+
|
| 468 |
return plt_digits,ax_d.figure,ax_all.figure,done_html,STATS_EXPLANATION_
|
| 469 |
|
| 470 |
|
|
|
|
| 490 |
|
| 491 |
number_dropdown = gr.Dropdown(choices=[i for i in range(10)],type='value',default=None,label="What was the correct prediction?")
|
| 492 |
|
| 493 |
+
gr.Markdown('Please wait a while after you press `Flag`. It takes time.')
|
| 494 |
flag_btn = gr.Button("Flag")
|
| 495 |
|
| 496 |
output_result = gr.outputs.HTML()
|
|
|
|
| 501 |
flag_btn.click(flag,inputs=[image_input,number_dropdown,adversarial_number],outputs=[output_result,adversarial_number])
|
| 502 |
|
| 503 |
with gr.TabItem('Dashboard') as dashboard:
|
| 504 |
+
get_stat = gr.Button(f'{GET_STATISTICS_MESSAGE}')
|
| 505 |
+
notification = gr.HTML(f"""<div style="color: green">
|
| 506 |
+
<p> ⌛ Click `{GET_STATISTICS_MESSAGE}` to generate statistics... </p>
|
| 507 |
</div>
|
| 508 |
""")
|
| 509 |
|
|
|
|
| 514 |
gr.Markdown(DASHBOARD_EXPLANATION_TEST)
|
| 515 |
test_results_all=gr.Plot(type="matplotlib")
|
| 516 |
|
| 517 |
+
#dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,notification,stats])
|
| 518 |
+
get_stat.click(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,test_results_all,notification,stats])
|
| 519 |
|
| 520 |
|
| 521 |
|
data_mnist
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
Subproject commit
|
|
|
|
| 1 |
+
Subproject commit 0d5120c897f5b71d2f99b7fb2ef5dc28e3d7000d
|
utils.py
CHANGED
|
@@ -3,6 +3,7 @@ import json
|
|
| 3 |
import hashlib
|
| 4 |
import random
|
| 5 |
import string
|
|
|
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
|
| 8 |
TITLE = "# MNIST Adversarial: Try to fool this MNIST model"
|
|
@@ -25,7 +26,7 @@ MODEL_IS_WRONG = """
|
|
| 25 |
DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
|
| 26 |
|
| 27 |
DASHBOARD_EXPLANATION="To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543). We are using {TEST_PER_SAMPLE} samples per digit."
|
| 28 |
-
DASHBOARD_EXPLANATION_TEST="Test accuracy on out-of-distribution data for all numbers."
|
| 29 |
|
| 30 |
STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
|
| 31 |
|
|
@@ -39,12 +40,16 @@ def read_json(file):
|
|
| 39 |
return json.load(f)
|
| 40 |
|
| 41 |
def read_json_lines(file):
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
def json_dump(thing):
|
|
@@ -63,11 +68,17 @@ def dump_json(thing,file):
|
|
| 63 |
json.dump(thing,f)
|
| 64 |
|
| 65 |
|
| 66 |
-
def plot_bar(value,name,x_name,y_name,title):
|
| 67 |
fig, ax = plt.subplots(tight_layout=True)
|
| 68 |
|
| 69 |
ax.set(xlabel=x_name, ylabel=y_name,title=title)
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
ax.barh(name, value)
|
| 72 |
|
| 73 |
-
return ax.figure
|
|
|
|
| 3 |
import hashlib
|
| 4 |
import random
|
| 5 |
import string
|
| 6 |
+
import warnings
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
|
| 9 |
TITLE = "# MNIST Adversarial: Try to fool this MNIST model"
|
|
|
|
| 26 |
DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
|
| 27 |
|
| 28 |
DASHBOARD_EXPLANATION="To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543). We are using {TEST_PER_SAMPLE} samples per digit."
|
| 29 |
+
DASHBOARD_EXPLANATION_TEST="Test accuracy on out-of-distribution data for all numbers combined."
|
| 30 |
|
| 31 |
STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
|
| 32 |
|
|
|
|
| 40 |
return json.load(f)
|
| 41 |
|
| 42 |
def read_json_lines(file):
    """Read a JSON Lines file and return its records as a list.

    Every non-blank line of *file* is parsed as one JSON document.

    :param file: Path of the ``.jsonl`` file to read (opened as UTF-8 text).
    :return: A list with one parsed object per non-blank line (empty when the
        file holds no records), or ``None`` when the file cannot be opened,
        decoded, or parsed. On failure a warning carrying the error text is
        emitted instead of raising, so callers must check for ``None``
        before indexing the result.
    """
    try:
        with open(file, 'r', encoding="utf8") as f:
            # A blank line (typically a trailing newline at the end of the
            # file) is not valid JSON on its own; skipping it keeps one stray
            # newline from discarding every record in the file.
            return [json.loads(line) for line in f if line.strip()]
    except (OSError, ValueError) as err:
        # OSError covers a missing or unreadable path; ValueError covers
        # malformed JSON (json.JSONDecodeError) and invalid UTF-8
        # (UnicodeDecodeError).
        warnings.warn(f"{err}")
        return None
|
| 53 |
|
| 54 |
|
| 55 |
def json_dump(thing):
|
|
|
|
| 68 |
json.dump(thing,f)
|
| 69 |
|
| 70 |
|
| 71 |
+
def plot_bar(value,name,x_name,y_name,title,set_yticks=False,set_xticks=False):
    """Draw a horizontal bar chart and return its matplotlib figure.

    :param value: Bar lengths, one per entry of ``name`` (drawn along x).
    :param name: Bar positions along y, one per bar.
    :param x_name: Label for the x axis.
    :param y_name: Label for the y axis.
    :param title: Title of the plot.
    :param set_yticks: When true, place a y tick on every integer between the
        smallest and largest entry of ``name`` (``name`` must hold integers).
    :param set_xticks: When true, place an x tick on every integer from 0 up
        to the largest entry of ``value``.
    :return: The figure that contains the bar chart.
    """
    fig, ax = plt.subplots(tight_layout=True)
    ax.set(xlabel=x_name, ylabel=y_name, title=title)

    # min()/max() raise ValueError on an empty sequence; with no data to plot
    # the default ticks are kept instead of crashing.
    if set_yticks and len(name) > 0:
        ax.set_yticks(range(min(name), max(name) + 1))
    if set_xticks and len(value) > 0:
        # barh() draws ``value`` along x, so the x ticks follow the bar
        # lengths (bars start at 0), not the category positions in ``name``.
        ax.set_xticks(range(0, int(max(value)) + 1))

    ax.barh(name, value)
    return fig
|