Spaces:
Runtime error
Runtime error
work on test
Browse files
app.py
CHANGED
|
@@ -102,9 +102,10 @@ def test():
|
|
| 102 |
correct += pred.eq(target.data.view_as(pred)).sum()
|
| 103 |
test_loss /= len(test_loader.dataset)
|
| 104 |
test_losses.append(test_loss)
|
| 105 |
-
|
| 106 |
test_loss, correct, len(test_loader.dataset),
|
| 107 |
-
100. * correct / len(test_loader.dataset))
|
|
|
|
| 108 |
|
| 109 |
|
| 110 |
|
|
@@ -153,8 +154,12 @@ def image_classifier(inp):
|
|
| 153 |
confidences.update({s:v})
|
| 154 |
return confidences
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
-
def flag(input_image,correct_result):
|
| 158 |
# take an image, the wrong result, the correct result.
|
| 159 |
# push to dataset.
|
| 160 |
# get size of current dataset
|
|
@@ -197,8 +202,12 @@ def flag(input_image,correct_result):
|
|
| 197 |
token=HF_TOKEN
|
| 198 |
)
|
| 199 |
|
| 200 |
-
output = f'<div> Successfully saved to flagged dataset. </div>'
|
| 201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
|
| 204 |
|
|
@@ -241,8 +250,12 @@ def main():
|
|
| 241 |
|
| 242 |
flag_btn = gr.Button("Flag")
|
| 243 |
output_result = gr.outputs.HTML()
|
|
|
|
| 244 |
submit.click(image_classifier,inputs = [image_input],outputs=[label_output])
|
| 245 |
-
flag_btn.click(flag,inputs=[image_input,number_dropdown],outputs=[output_result])
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
|
| 248 |
block.launch()
|
|
|
|
| 102 |
correct += pred.eq(target.data.view_as(pred)).sum()
|
| 103 |
test_loss /= len(test_loader.dataset)
|
| 104 |
test_losses.append(test_loss)
|
| 105 |
+
test_metric = '〽Current test metric - Avg. loss: `{:.4f}`, Accuracy: `{}/{}` (`{:.0f}%`)\n'.format(
|
| 106 |
test_loss, correct, len(test_loader.dataset),
|
| 107 |
+
100. * correct / len(test_loader.dataset))
|
| 108 |
+
return test_metric
|
| 109 |
|
| 110 |
|
| 111 |
|
|
|
|
| 154 |
confidences.update({s:v})
|
| 155 |
return confidences
|
| 156 |
|
| 157 |
+
def train_and_test():
|
| 158 |
+
# Train for one epoch and test
|
| 159 |
+
train(1,network,optimizer)
|
| 160 |
+
test_metric = test()
|
| 161 |
|
| 162 |
+
def flag(input_image,correct_result,train):
|
| 163 |
# take an image, the wrong result, the correct result.
|
| 164 |
# push to dataset.
|
| 165 |
# get size of current dataset
|
|
|
|
| 202 |
token=HF_TOKEN
|
| 203 |
)
|
| 204 |
|
| 205 |
+
output = f'<div> ✔ Successfully saved to flagged dataset. </div>'
|
| 206 |
+
train=True
|
| 207 |
+
if train:
|
| 208 |
+
output = f'<div> ✔ Successfully saved to flagged dataset. Training the model on adversarial data! </div>'
|
| 209 |
+
|
| 210 |
+
return output,train
|
| 211 |
|
| 212 |
|
| 213 |
|
|
|
|
| 250 |
|
| 251 |
flag_btn = gr.Button("Flag")
|
| 252 |
output_result = gr.outputs.HTML()
|
| 253 |
+
to_train = gr.Variable(value=False)
|
| 254 |
submit.click(image_classifier,inputs = [image_input],outputs=[label_output])
|
| 255 |
+
flag_btn.click(flag,inputs=[image_input,number_dropdown,to_train],outputs=[output_result,to_train])
|
| 256 |
+
if to_train.value:
|
| 257 |
+
import pdb;pdb.set_trace()
|
| 258 |
+
train_and_test()
|
| 259 |
|
| 260 |
|
| 261 |
block.launch()
|