import gradio as gr
from qasrl_model_pipeline import QASRL_Pipeline

# Identifiers of the seq2seq checkpoints selectable in the demo.
models = [
    "kleinay/qanom-seq2seq-model-baseline",
    "kleinay/qanom-seq2seq-model-joint",
]
# One pipeline per checkpoint, built eagerly at import time and keyed by
# the checkpoint identifier so requests can pick one by name.
pipelines = dict((name, QASRL_Pipeline(name)) for name in models)


# Text handed to the Interface as its description. Written as a plain
# literal: it contains no placeholders, so no f-string prefix is needed.
description = (
    "Using Seq2Seq T5 model which takes a sequence of items and outputs "
    "another sequence this model generates Questions and Answers (QA) "
    "with focus on Semantic Role Labeling (SRL)"
)
# Text handed to the Interface as its title.
title = "Seq2Seq T5 Questions and Answers (QA) with Semantic Role Labeling (SRL)"
# Pre-filled example rows. Each row lists its values in the order of the
# parameters of `call`: model name, marked sentence, nominalization flag,
# verbal form.
examples = [
    [
        models[0],
        "In March and April the patient <p> had two falls.  One was related to asthma, heart palpitations.  The second was due to syncope and post covid vaccination dizziness during exercise.  The patient is now getting an EKG.  Former EKG had shown that there was a bundle branch block.  Patient had some uncontrolled immune system reactions like anaphylaxis and shortness of breath.",
        True,
        "fall",
    ],
    [
        models[1],
        "In March and April the patient had two falls.  One was related to asthma, heart palpitations.  The second was due to syncope and post covid vaccination dizziness during exercise.  The patient is now getting an EKG.  Former EKG had shown that there was a bundle branch block.  Patient had some uncontrolled immune system reactions <p> like anaphylaxis and shortness of breath.",
        True,
        "reactions",
    ],
    [
        models[0],
        "In March and April the patient had two falls.  One was related <p> to asthma, heart palpitations.  The second was due to syncope and post covid vaccination dizziness during exercise.  The patient is now getting an EKG.  Former EKG had shown that there was a bundle branch block.  Patient had some uncontrolled immune system reactions like anaphylaxis and shortness of breath.",
        True,
        "relate",
    ],
    [
        models[1],
        "In March and April the patient <p> had two falls.  One was related to asthma, heart palpitations.  The second was due to syncope and post covid vaccination dizziness during exercise.  The patient is now getting an EKG.  Former EKG had shown that there was a bundle branch block.  Patient had some uncontrolled immune system reactions like anaphylaxis and shortness of breath.",
        False,
        "fall",
    ],
]

# Placeholder text for the sentence input box.
input_sent_box_label = (
    "Insert sentence here. "
    "Mark the predicate by adding the token '<p>' before it."
)
# Placeholder text for the verbal-form input box.
verb_form_inp_placeholder = (
    "e.g. 'decide' for the nominalization 'decision', "
    "'teach' for 'teacher', etc."
)
# HTML snippet handed to the Interface as its article: two centered links.
links = (
    "<p style='text-align: center'>\n"
    "<a href='https://www.qasrl.org' target='_blank'>QASRL Website</a>"
    "  |  "
    "<a href='https://huggingface.co/kleinay/qanom-seq2seq-model-baseline' target='_blank'>Model Repo at Huggingface Hub</a>\n"
    "</p>"
)
def call(model_name, sentence, is_nominal, verb_form):
    """Run the selected QASRL pipeline on one predicate-marked sentence.

    Args:
        model_name: Key into the module-level ``pipelines`` mapping.
        sentence: Input text in which the predicate is preceded by the
            ``<p>`` marker.
        is_nominal: Truthy when the marked predicate is a nominalization;
            selects ``predicate_type="nominal"`` instead of ``"verbal"``.
        verb_form: Verbal form of the predicate. May be empty for verbal
            predicates, in which case the word following the marker is used.

    Returns:
        A ``(QAs, generated_text)`` pair read from the first pipeline output.

    Raises:
        ValueError: If the sentence contains no ``<p>`` marker, if a
            nominal predicate is given without ``verb_form``, or if no
            word follows the marker when the verb form must be inferred.
    """
    predicate_marker = "<p>"
    if predicate_marker not in sentence:
        raise ValueError("You must highlight one word of the sentence as a predicate using preceding '<p>'.")

    if not verb_form:
        if is_nominal:
            raise ValueError("You should provide the verbal form of the nominalization")

        # Verbal predicate without an explicit verb form: fall back to the
        # marked word itself. Splitting on arbitrary whitespace after the
        # first marker tolerates repeated spaces, newlines, and a marker
        # glued to the word.
        _, _, after_marker = sentence.partition(predicate_marker)
        following_words = after_marker.split()
        if not following_words:
            raise ValueError("The predicate marker '<p>' must be followed by the predicate word.")
        verb_form = following_words[0]

    pipeline = pipelines[model_name]
    # The pipeline takes a batch; a single sentence goes in, so only the
    # first result is read.
    pipe_out = pipeline(
        [sentence],
        predicate_marker=predicate_marker,
        predicate_type="nominal" if is_nominal else "verbal",
        verb_form=verb_form,
    )[0]
    return pipe_out["QAs"], pipe_out["generated_text"]
# Input widgets, listed in the same order as the parameters of `call`:
# model name, sentence, nominalization flag, verbal form.
demo_inputs = [
    gr.inputs.Radio(choices=models, default=models[0], label="Model"),
    gr.inputs.Textbox(placeholder=input_sent_box_label, label="Sentence", lines=4),
    gr.inputs.Checkbox(default=True, label="Is Nominalization?"),
    gr.inputs.Textbox(
        placeholder=verb_form_inp_placeholder,
        label="Verbal form (for nominalizations)",
        default='',
    ),
]
# Output widgets, matching the two values returned by `call`.
demo_outputs = [
    gr.outputs.JSON(label="Model Output - QASRL"),
    gr.outputs.Textbox(label="Raw output sequence"),
]

iface = gr.Interface(
    fn=call,
    inputs=demo_inputs,
    outputs=demo_outputs,
    title=title,
    description=description,
    article=links,
    examples=examples,
)

# Start the app as soon as the script is executed.
iface.launch()