Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import gradio as gr | |
| import torchvision | |
| from PIL import Image | |
| from utils import * | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.optim as optim | |
| from huggingface_hub import Repository, upload_file | |
| from torch.utils.data import Dataset | |
| import numpy as np | |
| from collections import Counter | |
# Custom stylesheet for the Gradio Blocks UI, read once at import time.
with open('app.css', 'r') as f:
    BLOCK_CSS = f.read()

# ---------------------------------------------------------------------------
# Training hyper-parameters
# ---------------------------------------------------------------------------
n_epochs = 10
batch_size_train = 128
batch_size_test = 1000
learning_rate = 0.01
adv_learning_rate = 0.001
momentum = 0.5
log_interval = 10  # print a progress line every `log_interval` batches
random_seed = 1

# `flag` retrains whenever the dataset folder count is a multiple of this.
TRAIN_CUTOFF = 10
# Cap on the number of MNIST-C test samples taken per corruption / per digit.
TEST_PER_SAMPLE = 5000

# The templates below are not defined in this file; presumably they come from
# `from utils import *` -- TODO confirm.
DASHBOARD_EXPLANATION = DASHBOARD_EXPLANATION.format(TEST_PER_SAMPLE=TEST_PER_SAMPLE)
WHAT_TO_DO = WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)

# ---------------------------------------------------------------------------
# Local paths
# ---------------------------------------------------------------------------
MODEL_PATH = 'model'
METRIC_PATH, MODEL_WEIGHTS_PATH, OPTIMIZER_PATH = (
    os.path.join(MODEL_PATH, file_name)
    for file_name in ('metrics.json', 'mnist_model.pth', 'optimizer.pth')
)

REPOSITORY_DIR = "data"
LOCAL_DIR = 'data_local'
os.makedirs(LOCAL_DIR, exist_ok=True)

GET_STATISTICS_MESSAGE = "Get Statistics"

# ---------------------------------------------------------------------------
# Hugging Face Hub repositories (dataset of flagged samples + model weights)
# ---------------------------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_REPO = 'mnist-adversarial-model'
HF_DATASET = "mnist-adversarial-dataset"
DATASET_REPO_URL = f"https://huggingface.co/datasets/chrisjay/{HF_DATASET}"
# NOTE(review): model repos are normally addressed without a "/model/" path
# segment -- confirm this URL resolves as intended.
MODEL_REPO_URL = f"https://huggingface.co/model/chrisjay/{MODEL_REPO}"

# Clone the dataset repo into ./data_mnist and bring it up to date.
repo = Repository(
    local_dir="data_mnist",
    clone_from=DATASET_REPO_URL,
    use_auth_token=HF_TOKEN,
)
repo.git_pull()

# Clone the model repo into MODEL_PATH and bring it up to date.
model_repo = Repository(
    local_dir=MODEL_PATH,
    clone_from=MODEL_REPO_URL,
    use_auth_token=HF_TOKEN,
    repo_type="model",
)
model_repo.git_pull()

# Disable cuDNN and fix the RNG seed.
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
class MNISTAdversarial_Dataset(Dataset):
    """In-memory dataset of the user-flagged images stored in the dataset repo.

    Every sub-folder of ``<data_dir>/data`` that contains both ``image.png``
    and ``metadata.jsonl`` contributes one sample. The label is the
    ``correct_number`` field of the first metadata record.
    """

    def __init__(self, data_dir, transform):
        """Pull the dataset repo and load every complete sample.

        :param data_dir: local clone of the dataset repo; samples are read from its ``data`` sub-folder
        :param transform: callable applied to each PIL image in ``__getitem__``
        """
        repo.git_pull()
        self.data_dir = os.path.join(data_dir, 'data')
        self.transform = transform
        self.images = []
        self.numbers = []

        # Close the directory iterator deterministically.
        with os.scandir(self.data_dir) as entries:
            folder_names = [entry.name for entry in entries]

        for folder_name in folder_names:
            folder = os.path.join(self.data_dir, folder_name)
            metadata_path = os.path.join(folder, 'metadata.jsonl')
            image_path = os.path.join(folder, 'image.png')
            if not (os.path.exists(image_path) and os.path.exists(metadata_path)):
                continue
            metadata = read_json_lines(metadata_path)
            # Skip both a None result and an empty record list, so that
            # `metadata[0]` below cannot raise IndexError.
            if not metadata:
                continue
            # Image.open is lazy and keeps the file open; copy() forces the
            # pixel data into memory so the handle can be released right away
            # instead of leaking one descriptor per sample.
            with Image.open(image_path) as img:
                self.images.append(img.copy())
            self.numbers.append(metadata[0]['correct_number'])

        assert len(self.images)==len(self.numbers), f"Length of images and numbers must be the same. Got {len(self.images)} for images and {len(self.numbers)} for numbers."

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(transform(image), label)`` for the sample at ``idx``."""
        img, label = self.images[idx], self.numbers[idx]
        img = self.transform(img)
        return img, label
class MNISTCorrupted_By_Digit(Dataset):
    """MNIST-C test samples restricted to a single digit class.

    The test arrays of every sub-folder of ``./mnist_c`` are concatenated,
    filtered down to the requested digit, and truncated to ``limit`` samples.
    """

    def __init__(self, transform, digit, limit=TEST_PER_SAMPLE):
        """
        :param transform: optional callable applied to each image (as a PIL image)
        :param digit: the label value to keep
        :param limit: maximum number of samples to keep for this digit
        """
        self.transform = transform
        self.digit = digit

        corrupted_dir = "./mnist_c"
        image_arrays = []
        label_arrays = []
        for entry in os.scandir(corrupted_dir):
            folder = os.path.join(corrupted_dir, entry.name)
            image_arrays.append(np.load(os.path.join(folder, 'test_images.npy')))
            label_arrays.append(np.load(os.path.join(folder, 'test_labels.npy')))

        self.data = np.vstack(image_arrays)
        self.labels = np.hstack(label_arrays)
        assert (self.data.shape[0] == self.labels.shape[0])

        # Boolean mask selecting only the samples of the requested digit.
        digit_mask = self.labels == self.digit
        digit_images = self.data[digit_mask]
        digit_labels = self.labels[digit_mask]

        # Never ask for more samples than are actually available.
        limit = min(limit, digit_images.shape[0])
        self.data_for_use = digit_images[:limit]
        self.labels_for_use = digit_labels[:limit]
        assert (self.data_for_use.shape[0] == self.labels_for_use.shape[0])

    def __len__(self):
        """Return the number of retained samples for this digit."""
        return len(self.data_for_use)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is transformed when a transform is set."""
        position = idx.tolist() if torch.is_tensor(idx) else idx
        image = self.data_for_use[position]
        label = self.labels_for_use[position]
        if self.transform:
            # The transform is fed a PIL image rather than the raw array.
            to_pil = torchvision.transforms.ToPILImage()
            image = self.transform(to_pil(image))
        return image, label
class MNISTCorrupted(Dataset):
    """MNIST-C test set built from the first ``TEST_PER_SAMPLE`` samples of each sub-folder of ``./mnist_c``."""

    def __init__(self, transform):
        """
        :param transform: optional callable applied to each image (as a PIL image)
        """
        self.transform = transform

        corrupted_dir = "./mnist_c"
        image_arrays = []
        label_arrays = []
        for entry in os.scandir(corrupted_dir):
            folder = os.path.join(corrupted_dir, entry.name)
            images = np.load(os.path.join(folder, 'test_images.npy'))
            labels = np.load(os.path.join(folder, 'test_labels.npy'))
            image_arrays.append(images[:TEST_PER_SAMPLE])
            label_arrays.append(labels[:TEST_PER_SAMPLE])

        self.data = np.vstack(image_arrays)
        self.labels = np.hstack(label_arrays)
        assert (self.data.shape[0] == self.labels.shape[0])

    def __len__(self):
        """Return the total number of samples across all sub-folders."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is transformed when a transform is set."""
        position = idx.tolist() if torch.is_tensor(idx) else idx
        image = self.data[position]
        label = self.labels[position]
        if self.transform:
            # The transform is fed a PIL image rather than the raw array.
            to_pil = torchvision.transforms.ToPILImage()
            image = self.transform(to_pil(image))
        return image, label
# Preprocessing shared by every dataset in this file: convert the image to a
# tensor, then normalise it with mean 0.1307 and standard deviation 0.3081.
TRAIN_TRANSFORM = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Evaluation loader over the MNIST-C test data, iterated in a fixed order.
test_loader = torch.utils.data.DataLoader(
    dataset=MNISTCorrupted(TRAIN_TRANSFORM),
    batch_size=batch_size_test,
    shuffle=False,
)
# Source: https://nextjournal.com/gkoehler/pytorch-mnist
class MNIST_Model(nn.Module):
    """Small CNN for 1-channel 28x28 digit images, producing 10 log-probabilities.

    Two 5x5 convolutions (each followed by 2x2 max-pooling and ReLU) feed two
    fully connected layers. Dropout is applied after the second convolution
    and after the first fully connected layer.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, 10)``.

        :param x: image batch; the flatten to 320 features below implies a
            ``(batch, 1, 28, 28)`` input
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Name the class dimension explicitly: relying on the implicit dim is
        # deprecated and emits a warning on every forward pass.
        return F.log_softmax(x, dim=1)
def train(epochs, network, optimizer, train_loader):
    """Train ``network`` for ``epochs`` passes over ``train_loader`` and checkpoint it.

    After every epoch in which at least one optimisation step ran, the model
    and optimizer state dicts are written to ``MODEL_WEIGHTS_PATH`` and
    ``OPTIMIZER_PATH``. Saving after the last batch, rather than only on
    logging steps, keeps the checkpoint identical to the in-memory weights.

    :param epochs: number of passes over ``train_loader``
    :param network: the model to optimise (switched to training mode)
    :param optimizer: optimizer bound to ``network``'s parameters
    :param train_loader: iterable of ``(data, target)`` batches
    :return: list of the loss values recorded every ``log_interval`` batches
    """
    train_losses = []
    network.train()
    for epoch in range(epochs):
        stepped = False
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = network(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            stepped = True
            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
                train_losses.append(loss.item())
        if stepped:
            # Nothing to checkpoint when the loader yielded no batches.
            torch.save(network.state_dict(), MODEL_WEIGHTS_PATH)
            torch.save(optimizer.state_dict(), OPTIMIZER_PATH)
    return train_losses
def test():
    """Evaluate the module-level ``network`` on ``test_loader``.

    :return: ``(test_metric, acc)`` where ``test_metric`` is a formatted
        summary string (also printed) and ``acc`` is the accuracy in percent
    """
    network.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            # `reduction='sum'` replaces the deprecated `size_average=False`.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            # Accumulate a plain int so the result is a Python number even
            # when the loader yields no batches.
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    if num_samples:
        test_loss /= num_samples
        acc = 100. * correct / num_samples
    else:
        # Empty test set: report zeros instead of dividing by zero.
        test_loss, acc = 0.0, 0.0

    test_metric = '〽Current test metric -> Avg. loss: `{:.4f}`, Accuracy: `{:.0f}%`\n'.format(
        test_loss, acc)
    print(test_metric)
    return test_metric, acc
# Re-apply the seed right before the model is built.
random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

# The single model/optimizer pair shared by every function in this file.
network = MNIST_Model()
optimizer = optim.SGD(
    params=network.parameters(),
    lr=learning_rate,
    momentum=momentum,
)

# Standard MNIST loaders (downloaded into ./files/ when missing).
train_loader = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.MNIST(
        root='./files/', train=True, download=True, transform=TRAIN_TRANSFORM
    ),
    batch_size=batch_size_train,
    shuffle=True,
)
test_iid_loader = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.MNIST(
        root='./files/', train=False, download=True, transform=TRAIN_TRANSFORM
    ),
    batch_size=batch_size_test,
    shuffle=True,
)

# Resume from the checkpoint in the model repo when both files are present.
# Note that `optimizer_state_dict` starts as a path and is rebound to the
# loaded state dict.
model_state_dict = MODEL_WEIGHTS_PATH
optimizer_state_dict = OPTIMIZER_PATH
if all(os.path.exists(path) for path in (model_state_dict, optimizer_state_dict)):
    network_state_dict = torch.load(model_state_dict)
    network.load_state_dict(network_state_dict)
    optimizer_state_dict = torch.load(optimizer_state_dict)
    optimizer.load_state_dict(optimizer_state_dict)

# No training happens at import time; `train`/`test` are invoked on demand.
def train_and_test(train_model=True):
    """Optionally fine-tune on the adversarial data, then record MNIST-C accuracy.

    The overall accuracy is appended to ``metric_dict['all']`` and the
    accuracy for each digit ``d`` to ``metric_dict[str(d)]``. The metrics are
    written to ``METRIC_PATH`` and the model repo is pushed to the hub.

    :param train_model: when true, train for ``n_epochs`` on the adversarial
        dataset before evaluating
    :return: the summary string produced by ``test()``
    """
    if train_model:
        train_dataset = MNISTAdversarial_Dataset('./data_mnist', TRAIN_TRANSFORM)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size_test, shuffle=True)
        train(n_epochs, network, optimizer, train_loader)

    test_metric, test_acc = test()
    network.eval()

    metric_dict = read_json(METRIC_PATH) if os.path.exists(METRIC_PATH) else {}
    # setdefault tolerates a metrics file that lacks a given key.
    metric_dict.setdefault('all', []).append(test_acc)

    for digit in range(10):
        data_per_digit = MNISTCorrupted_By_Digit(TRAIN_TRANSFORM, digit)
        # One batch holding every sample of this digit.
        dataloader_per_digit = torch.utils.data.DataLoader(
            data_per_digit, batch_size=len(data_per_digit), shuffle=False)
        # Use the built-in next(): loader iterators no longer expose `.next()`.
        images, labels = next(iter(dataloader_per_digit))
        # Evaluation only; skip building the autograd graph for the batch.
        with torch.no_grad():
            output = network(images)
        pred = output.argmax(dim=1, keepdim=True)
        correct = pred.eq(labels.view_as(pred)).sum().item()
        acc = 100. * correct / len(images)
        metric_dict.setdefault(str(digit), []).append(acc)

    dump_json(thing=metric_dict, file=METRIC_PATH)
    # Push models and metrics to hub
    model_repo.push_to_hub()
    return test_metric
# Refresh the in-memory model once more after pulling the model repo.
model_state_dict = MODEL_WEIGHTS_PATH
optimizer_state_dict = OPTIMIZER_PATH
model_repo.git_pull()

if not (os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict)):
    # The repo has no complete checkpoint: fall back to the bundled weights.
    BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth"
    BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth"
    network_state_dict = torch.load(BEST_WEIGHTS_MODEL)
    network.load_state_dict(network_state_dict)
    optimizer_state_dict = torch.load(BEST_WEIGHTS_OPTIMIZER)
    optimizer.load_state_dict(optimizer_state_dict)
else:
    network_state_dict = torch.load(model_state_dict)
    network.load_state_dict(network_state_dict)
    optimizer_state_dict = torch.load(optimizer_state_dict)
    optimizer.load_state_dict(optimizer_state_dict)

# First run: evaluate without training so a metrics file exists.
if not os.path.exists(METRIC_PATH):
    _ = train_and_test(train_model=False)
def image_classifier(inp):
    """Classify a drawn digit with the latest weights from the model repo.

    The model repo is pulled first. If both checkpoint files are present they
    are loaded, otherwise the bundled ``best_weights`` are used.

    :param inp: the image to be classified, or ``None`` when there is nothing
        to classify
    :return: a dictionary of the form ``{class_number: confidence}`` ordered
        by descending confidence, or ``None`` when ``inp`` is ``None``
    """
    if inp is None:
        # Feeding None to the transform would raise, so return early.
        # NOTE(review): presumably the change event fires with None when the
        # canvas is cleared, and how the Label output renders a None result
        # depends on the installed gradio version -- confirm both.
        return None

    # Get latest model weights ----------------
    model_repo.git_pull()
    model_state_dict = MODEL_WEIGHTS_PATH
    optimizer_state_dict = OPTIMIZER_PATH
    if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
        network.load_state_dict(torch.load(model_state_dict))
        optimizer.load_state_dict(torch.load(optimizer_state_dict))
    else:
        # Use best weights
        BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth"
        BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth"
        network.load_state_dict(torch.load(BEST_WEIGHTS_MODEL))
        optimizer.load_state_dict(torch.load(BEST_WEIGHTS_OPTIMIZER))

    network.eval()
    input_image = TRAIN_TRANSFORM(inp).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        prediction = torch.nn.functional.softmax(network(input_image)[0], dim=0)

    sorted_prediction = torch.sort(prediction, descending=True)
    return dict(zip(sorted_prediction.indices.tolist(),
                    sorted_prediction.values.tolist()))
def flag(input_image, correct_result, adversarial_number):
    """Save a flagged image with its label, upload both, and retrain when due.

    The image and a metadata record are written under ``LOCAL_DIR`` and
    uploaded to the dataset repo, which is then pulled. When the number of
    entries in ``./data_mnist/data`` is a multiple of ``TRAIN_CUTOFF``, the
    model is trained on the adversarial data via ``train_and_test()``.

    Nothing is saved or uploaded when the image or the label is missing,
    because an unlabeled sample cannot be used for training.

    :param input_image: the adversarial image to save (must provide ``.save``)
    :param correct_result: the correct number that the image represents
    :param adversarial_number: running count of samples flagged so far;
        ``None`` is treated as 0
    :return: ``(output, adversarial_number)`` -- an HTML status message and
        the updated count
    :raises Exception: if the image cannot be written to disk
    """
    # The previous `0 if None else ...` never took the 0 branch, so a None
    # counter survived until `+= 1` and raised TypeError.
    adversarial_number = 0 if adversarial_number is None else adversarial_number

    if input_image is None or correct_result is None:
        output = '<div> ❌ Please draw a digit and select the correct number before flagging. </div>'
        return output, adversarial_number

    metadata_name = get_unique_name()
    SAVE_FILE_DIR = os.path.join(LOCAL_DIR, metadata_name)
    os.makedirs(SAVE_FILE_DIR, exist_ok=True)

    image_output_filename = os.path.join(SAVE_FILE_DIR, 'image.png')
    try:
        input_image.save(image_output_filename)
    except Exception as err:
        # Chain the original error so the real cause stays in the traceback.
        raise Exception("Had issues saving PIL image to file") from err

    # Write the metadata record to file
    json_file_path = os.path.join(SAVE_FILE_DIR, 'metadata.jsonl')
    metadata = {
        'id': metadata_name,
        'file_name': 'image.png',
        'correct_number': correct_result,
    }
    dump_json(metadata, json_file_path)

    # Upload the image, then the metadata, with the hub's upload_file.
    uploads = (
        (image_output_filename, 'image.png'),
        (json_file_path, 'metadata.jsonl'),
    )
    for local_path, file_name in uploads:
        upload_file(
            path_or_fileobj=local_path,
            # Build the in-repo path with "/" on every platform.
            path_in_repo="/".join((REPOSITORY_DIR, metadata_name, file_name)),
            repo_id=f'chrisjay/{HF_DATASET}',
            repo_type='dataset',
            token=HF_TOKEN,
        )

    adversarial_number += 1
    output = f'<div> ✔ ({adversarial_number}) Successfully saved your adversarial data. </div>'

    repo.git_pull()
    length_of_dataset = len(os.listdir("./data_mnist/data"))
    if length_of_dataset % TRAIN_CUTOFF == 0:
        train_and_test()
        output = f'<div> ✔ ({adversarial_number}) Successfully saved your adversarial data and trained the model on adversarial data! </div>'
    return output, adversarial_number
def get_number_dict(DATA_DIR):
    """Count how often each ``correct_number`` occurs among the stored samples.

    Reads ``metadata.jsonl`` from every entry of ``DATA_DIR`` and uses the
    ``correct_number`` of its first record. Entries whose metadata is ``None``
    or empty are skipped.

    :param DATA_DIR: The directory where the data is stored
    :return: ``(numbers, counts)`` -- two parallel lists, in first-seen order
    """
    # Close the directory iterator deterministically.
    with os.scandir(DATA_DIR) as entries:
        folder_names = [entry.name for entry in entries]

    numbers = []
    for folder_name in folder_names:
        metadata = read_json_lines(os.path.join(DATA_DIR, folder_name, 'metadata.jsonl'))
        # Truthiness covers both None and an empty record list, so the
        # `[0]` lookup cannot raise IndexError.
        if metadata:
            numbers.append(metadata[0]['correct_number'])

    numbers_count = Counter(numbers)
    return list(numbers_count.keys()), list(numbers_count.values())
def get_statistics():
    """Build the dashboard plots and messages.

    Pulls the model and dataset repos, reloads whichever checkpoint files
    exist, counts the adversarial samples per digit, and plots (a) that
    distribution, (b) the recorded test accuracy per digit per train step and
    (c) the recorded overall test accuracy per train step. Missing metrics
    produce empty plots instead of raising.

    NOTE(review): ``plt``, ``plot_bar``, ``read_json`` and
    ``STATS_EXPLANATION`` are not defined in this file; presumably they come
    from ``from utils import *`` -- confirm.

    :return: ``(plt_digits, per-digit figure, overall figure, done_html, STATS_EXPLANATION_)``
    """
    model_repo.git_pull()
    model_state_dict = MODEL_WEIGHTS_PATH
    optimizer_state_dict = OPTIMIZER_PATH
    if os.path.exists(model_state_dict):
        network.load_state_dict(torch.load(model_state_dict))
    if os.path.exists(optimizer_state_dict):
        optimizer.load_state_dict(torch.load(optimizer_state_dict))

    repo.git_pull()
    DATA_DIR = './data_mnist/data'
    numbers_count_keys, numbers_count_values = get_number_dict(DATA_DIR)
    STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples=sum(numbers_count_values))
    plt_digits = plot_bar(numbers_count_values, numbers_count_keys, 'Number of adversarial samples', "Digit", "Distribution of adversarial samples per digit", True)

    metric_dict = read_json(METRIC_PATH) if os.path.exists(METRIC_PATH) else {}

    # Plot of test accuracy per digit
    fig_d, ax_d = plt.subplots(tight_layout=True)
    longest_history = 0
    for digit in range(10):
        # .get() replaces the old blanket `except Exception: continue`, so a
        # missing digit is skipped without hiding unrelated errors.
        accuracies = metric_dict.get(str(digit))
        if not accuracies:
            continue
        steps = list(range(1, len(accuracies) + 1))
        ax_d.plot(steps, accuracies, label=str(digit))
        longest_history = max(longest_history, len(accuracies))
    # Size the axis from the longest series rather than assuming key '0' exists.
    ax_d.set_xticks(range(0, longest_history + 1, 1))
    if longest_history:
        fig_d.legend()
    ax_d.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy over digits per train step")

    done_html = f"""<div style="color: green">
<p> ✅ Statistics loaded successfully! Click `{GET_STATISTICS_MESSAGE}`to reload.</p>
</div>
"""

    # Plot for total test accuracy for all digits
    fig_all, ax_all = plt.subplots(tight_layout=True)
    # Default to an empty series; `metric_dict['all']` raised KeyError when
    # no metrics file existed yet.
    all_accuracies = metric_dict.get('all', [])
    steps_all = list(range(1, len(all_accuracies) + 1))
    ax_all.plot(steps_all, all_accuracies)
    ax_all.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy for all digits")
    ax_all.set_xticks(range(0, len(all_accuracies) + 1, 1))

    return plt_digits, ax_d.figure, ax_all.figure, done_html, STATS_EXPLANATION_
def main():
    """Assemble the two-tab Gradio Blocks interface and launch it."""
    demo = gr.Blocks(css=BLOCK_CSS)
    with demo:
        gr.Markdown(TITLE)
        gr.Markdown(description)

        with gr.Tabs():
            # Tab 1: draw a digit, see the prediction, flag wrong predictions.
            with gr.TabItem('MNIST'):
                gr.Markdown(WHAT_TO_DO)
                with gr.Row():
                    image_input = gr.inputs.Image(
                        source="canvas",
                        shape=(28, 28),
                        invert_colors=True,
                        image_mode="L",
                        type="pil",
                    )
                    label_output = gr.outputs.Label(num_top_classes=2)

                gr.Markdown(MODEL_IS_WRONG)
                number_dropdown = gr.Dropdown(
                    choices=list(range(10)),
                    type='value',
                    default=None,
                    label="What was the correct prediction?",
                )
                gr.Markdown('Please wait a while after you press `Flag`. It takes time.')
                flag_btn = gr.Button("Flag")
                output_result = gr.outputs.HTML()
                # State holding the running count of flagged samples.
                adversarial_number = gr.Variable(value=0)

                # Re-classify whenever the drawing changes.
                image_input.change(
                    image_classifier,
                    inputs=[image_input],
                    outputs=[label_output],
                )
                flag_btn.click(
                    flag,
                    inputs=[image_input, number_dropdown, adversarial_number],
                    outputs=[output_result, adversarial_number],
                )

            # Tab 2: statistics, generated on demand.
            with gr.TabItem('Dashboard') as dashboard:
                get_stat = gr.Button(f'{GET_STATISTICS_MESSAGE}')
                notification = gr.HTML(f"""<div style="color: green">
<p> ⌛ Click `{GET_STATISTICS_MESSAGE}` to generate statistics... </p>
</div>
""")
                stats = gr.Markdown()
                stat_adv_image = gr.Plot(type="matplotlib")
                gr.Markdown(DASHBOARD_EXPLANATION)
                test_results = gr.Plot(type="matplotlib")
                gr.Markdown(DASHBOARD_EXPLANATION_TEST)
                test_results_all = gr.Plot(type="matplotlib")

                get_stat.click(
                    get_statistics,
                    inputs=[],
                    outputs=[stat_adv_image, test_results, test_results_all, notification, stats],
                )

    demo.launch()


if __name__ == "__main__":
    main()