Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| import pickle | |
| import gradio as gr | |
| import numpy as np | |
| import subprocess | |
| import shutil | |
| import matplotlib.pyplot as plt | |
| from sklearn.metrics import roc_curve, auc | |
# Fine-tuned BERT checkpoint evaluated for each dropdown choice.
_CHECKPOINTS = {
    "FS": "ratio_proportion_change3/output/FS/bert_fine_tuned.model.ep32",
    "IS": "ratio_proportion_change3/output/IS/bert_fine_tuned.model.ep14",
    "CORRECTNESS": "ratio_proportion_change3/output/correctness/bert_fine_tuned.model.ep48",
    "EFFECTIVENESS": "ratio_proportion_change3/output/effectiveness/bert_fine_tuned.model.ep28",
}


def _upload_path(upload):
    """Return the filesystem path of a ``gr.File`` value.

    Older Gradio releases pass a tempfile-like object whose ``.name`` is the
    path; newer releases pass the path itself as a ``str``. Accept both.
    """
    return upload if isinstance(upload, str) else upload.name


def _read_results(path):
    """Parse ``key: value`` lines from *path* into a dict.

    The ``epoch`` value is kept as a string; every other value is converted
    to ``float``. Blank lines are skipped.
    """
    results = {}
    with open(path, "r") as fh:
        for raw_line in fh:
            line = raw_line.strip()
            if not line:
                continue
            key, value = line.split(": ", 1)
            results[key] = value if key == "epoch" else float(value)
    return results


def _plot_roc(roc_path, model_name, plot_path):
    """Plot the ROC curve pickled in *roc_path* and save it to *plot_path*.

    *roc_path* must unpickle to a 3-tuple whose first two items are the
    false- and true-positive rates. Returns *plot_path*.
    """
    # NOTE(review): roc_data.pkl is presumably written by the evaluation
    # script, not uploaded by the user — confirm; never unpickle user data.
    with open(roc_path, "rb") as fh:
        fpr, tpr, _ = pickle.load(fh)
    roc_auc = auc(fpr, tpr)
    fig, ax = plt.subplots()
    try:
        ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
        ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate', title=f'ROC Curve: {model_name}')
        ax.legend(loc="lower right")
        ax.grid()
        fig.savefig(plot_path)
    finally:
        # pyplot keeps every open figure alive; close it even if saving fails.
        plt.close(fig)
    return plot_path


def process_file(file, label, model_name):
    """Evaluate a fine-tuned checkpoint on the uploaded data and plot its ROC curve.

    Args:
        file: ``gr.File`` value holding the test data (.txt).
        label: ``gr.File`` value holding the matching labels (.txt).
        model_name: Dropdown choice; must be a key of ``_CHECKPOINTS``.

    Returns:
        ``(text_output, plot_path)``: a summary of the metrics parsed from
        ``result.txt`` and the path of the saved ROC plot image.

    Raises:
        gr.Error: if an upload is missing, the model choice is missing or
            unknown, or the evaluation script exits with a non-zero status.
    """
    import sys  # local import: only needed to launch the evaluation script

    if file is None or label is None:
        raise gr.Error("Please upload both the data file and the label file.")
    checkpoint = _CHECKPOINTS.get(model_name)
    if checkpoint is None:
        # Passing None on to subprocess.run would fail with an opaque TypeError.
        raise gr.Error(f"Please select a model ({', '.join(_CHECKPOINTS)}).")

    # Stage the uploads at fixed paths; presumably src/test_saved_model.py
    # reads train.txt / train_label.txt — confirm in that script.
    shutil.copyfile(_upload_path(file), "train.txt")
    shutil.copyfile(_upload_path(label), "train_label.txt")

    print(checkpoint)
    # sys.executable runs the script under the same interpreter/venv as the app.
    completed = subprocess.run(
        [sys.executable, "src/test_saved_model.py",
         "--finetuned_bert_checkpoint", checkpoint]
    )
    if completed.returncode != 0:
        # result.txt / roc_data.pkl are fixed paths reused across requests, so
        # continuing here could report a previous run's metrics.
        raise gr.Error(f"Evaluation script failed with exit code {completed.returncode}.")

    result = _read_results("result.txt")
    plot_path = _plot_roc("roc_data.pkl", model_name, "plot.png")
    text_output = f"Model: {model_name}\nResult:\n{result}"
    return text_output, plot_path
# Dropdown choices; each must match a model name process_file knows.
models = ["FS", "IS", "CORRECTNESS", "EFFECTIVENESS"]

# Build the Gradio UI: data upload, label upload and model picker in;
# metrics text and ROC plot out.
with gr.Blocks() as demo:
    gr.Markdown("# ASTRA")
    gr.Markdown("Upload a data .txt file and its label .txt file, then select a model from the dropdown menu.")
    with gr.Row():
        # Distinct labels so users can tell the two uploads apart.
        file_input = gr.File(label="Upload the data .txt file", file_types=['.txt'])
        label_input = gr.File(label="Upload the label .txt file", file_types=['.txt'])
        # Pre-select the first model so a submit without touching the
        # dropdown still sends a valid choice instead of None.
        model_dropdown = gr.Dropdown(choices=models, value=models[0], label="Select a model")
    with gr.Row():
        output_text = gr.Textbox(label="Output Text")
        output_image = gr.Image(label="Output Plot")
    btn = gr.Button("Submit")
    btn.click(
        fn=process_file,
        inputs=[file_input, label_input, model_dropdown],
        outputs=[output_text, output_image],
    )

# Launch the app
demo.launch()