import gradio as gr
import py3Dmol
from Bio.PDB import *
import numpy as np
from Bio.PDB import PDBParser
import pandas as pd
import os, sys
import matplotlib.pyplot as plt
#sys.path.append(os.getcwd())
print('importing...')
from run_gfn import run_gfn2
print('done')
# JavaScript callback sources kept as Python strings. Their signature,
# function(atom, viewer), looks like a 3Dmol.js hover/unhover callback.
# NOTE(review): none of them is referenced by the code shown in this file
# section; confirm they are still needed.

# Template meant for str.format(): "{0}" is replaced by a caller-supplied
# prefix, and the doubled braces turn into literal JS braces after formatting.
# On hover it adds the label "<prefix>:<atom name><serial>", at most once per
# atom (guarded by the `if(!atom.label)` check).
resid_hover = """function(atom,viewer) {{
    if(!atom.label) {{
        atom.label = viewer.addLabel('{0}:'+atom.atom+atom.serial,
            {{position: atom, backgroundColor: 'mintcream', fontColor:'black'}});
    }}
}}"""
# Plain JS (not a format template): labels the hovered atom with its
# `interaction` property, white text on black, at most once per atom.
# NOTE(review): `atom.interaction` is presumably a custom property set
# elsewhere on the model atoms; verify before using this callback.
hover_func = """
function(atom,viewer) {
    if(!atom.label) {
        atom.label = viewer.addLabel(atom.interaction,
            {position: atom, backgroundColor: 'black', fontColor:'white'});
    }
}"""
# Plain JS: removes a label previously stored on `atom.label` and deletes the
# property so a later hover can create it again.
unhover_func = """
function(atom,viewer) {
    if(atom.label) {
        viewer.removeLabel(atom.label);
        delete atom.label;
    }
}"""
def get_qm_atom_features(gfn2_output):
    """Collect the per-atom GFN2 properties as a column-name -> values mapping.

    The mapping is fed to ``pd.DataFrame`` in ``predict``, so the insertion
    order below is also the column order shown to (and exported for) the user.

    Args:
        gfn2_output: Result dict of ``run_gfn2``. Its ``"fetchatomicprops"``
            entry must provide ``"atmlist"``, ``"charges"`` and
            ``"polarisabilities"``.

    Returns:
        dict with the keys ``'atom type'``, ``'Charge'`` and
        ``'Polarizability'``, whose values are the original (uncopied)
        sequences from ``gfn2_output``.
    """
    atomic_props = gfn2_output["fetchatomicprops"]
    # (output column, key inside "fetchatomicprops")
    column_sources = (
        ('atom type', "atmlist"),
        ('Charge', "charges"),
        ('Polarizability', "polarisabilities"),
    )
    return {column: atomic_props[source] for column, source in column_sources}
def get_thermo_data(gfn2_output):
    """Pull the thermochemistry series out of a GFN2 result.

    Args:
        gfn2_output: Result dict of ``run_gfn2``. Its ``"fetchthermo"`` entry
            must provide ``"temperature"``, ``"entropy"``, ``"enthalpy"`` and
            ``"cp"``.

    Returns:
        dict keyed ``'Temperature'``, ``'entropy'``, ``'enthalpy'``, ``'cp'``
        (in that order), holding the original sequences uncopied.
    """
    thermo_block = gfn2_output["fetchthermo"]
    renamed = {}
    # Only the temperature column is renamed; the rest keep their source key.
    for source_key, column in (
        ("temperature", 'Temperature'),
        ("entropy", 'entropy'),
        ("enthalpy", 'enthalpy'),
        ("cp", 'cp'),
    ):
        renamed[column] = thermo_block[source_key]
    return renamed
def get_qm_mol_features(gfn2_output):
    """Return the molecule-level GFN2 scalars as a small dict.

    Args:
        gfn2_output: Result dict of ``run_gfn2`` providing ``"etotal"`` and
            ``"totalpol"``.

    Returns:
        ``{'Total Energy': ..., 'Total Polarizability': ...}``.
    """
    return {
        'Total Energy': gfn2_output["etotal"],
        'Total Polarizability': gfn2_output["totalpol"],
    }
def _write_csv_for_download(table, filename):
    """Write ``table`` to ``filename`` (relative to the working directory) and reveal it in a ``gr.File``.

    ``table`` is presumably the pandas DataFrame that ``gr.Dataframe`` hands to
    click handlers; only its ``to_csv`` method is used here.
    NOTE(review): the path is fixed, so concurrent sessions share one file.
    """
    table.to_csv(filename)
    return gr.File.update(value=filename, visible=True)


def export_csv(d):
    """Download handler for the per-atom feature table (``qm_atom_features.csv``)."""
    return _write_csv_for_download(d, "qm_atom_features.csv")


def export_csv_thermo(d):
    """Download handler for the thermodynamics table (``thermodynamics.csv``)."""
    return _write_csv_for_download(d, "thermodynamics.csv")
def get_basic_visualization(input_f,input_format):
    """Render a structure as a single 600x400 py3Dmol stick model.

    Args:
        input_f: Contents of the structure file.
        input_format: py3Dmol format name (e.g. ``"xyz"``).

    Returns:
        HTML for an ``<iframe>`` whose ``srcdoc`` holds the complete viewer
        page. Scripts inside markup inserted as plain HTML are not executed
        by browsers, so the py3Dmol page needs its own document to run.
    """
    import html

    view = py3Dmol.view(width=600, height=400)
    view.setBackgroundColor('white')
    view.addModel(input_f, input_format)
    view.setStyle({'stick': {'colorscheme': {'prop': 'resi', 'C': 'turquoise'}}})
    view.zoomTo()
    # html.escape encodes &, <, >, " and ' so the page can sit inside a
    # double-quoted attribute; the browser decodes it back verbatim. This
    # replaces rewriting every ' to ", which could corrupt the generated JS.
    viewer_page = html.escape(view._make_html(), quote=True)
    # Only scripts are allowed in the sandbox: enough for 3Dmol.js, without
    # granting the embedded page access to the surrounding app.
    return (
        '<iframe style="width: 100%; height: 420px; border: none" '
        'sandbox="allow-scripts" '
        f'srcdoc="{viewer_page}"></iframe>'
    )
def add_spheres_feature_view(view, feature,xyz, viewnum, sizefactor, spec_color):
    """Draw one sphere per atom whose radius encodes a scalar per-atom feature.

    Radii are ``|feature[i]| / max_j |feature[j]| * sizefactor``, so the
    largest-magnitude value gets radius ``sizefactor``. Negative values are
    drawn in red (``#a0210f``), non-negative ones in ``spec_color``.

    Args:
        view: py3Dmol view (anything exposing ``addSphere(spec, viewer=...)``).
        feature: Per-atom scalar values.
        xyz: Per-atom coordinates, indexable as ``xyz[i][0..2]``. It needs at
            least ``len(feature)`` entries.
        viewnum: Grid cell ``(row, col)`` forwarded as ``viewer=``.
        sizefactor: Radius given to the largest-magnitude value.
        spec_color: Colour for non-negative values.

    Returns:
        ``view`` (for chaining).
    """
    normalization = max((abs(value) for value in feature), default=0)
    if not normalization:
        # Empty or all-zero feature (e.g. charges of a single neutral atom):
        # every radius would be zero and the division below would fail.
        return view
    for i, value in enumerate(feature):
        coords = xyz[i]
        color = "#a0210f" if value < 0 else spec_color
        # float() keeps numpy scalar types (e.g. float32) JSON-serializable
        # when py3Dmol forwards the spec to JavaScript.
        view.addSphere({'center': {
            'x': float(coords[0]),
            'y': float(coords[1]),
            'z': float(coords[2])},
            'radius': float(abs(value) / normalization * sizefactor),
            'color': color,
            'alpha': 1.00}, viewer=viewnum)
    return view
def add_densities(view, dens, color, viewnum, isoval=0.01, opacity=1.0):
    """Add an isosurface of cube-format volumetric data to one viewer cell.

    Args:
        view: py3Dmol view (anything exposing ``addVolumetricData``).
        dens: Volumetric data as text in ``"cube"`` format. In this file the
            caller passes the contents of a ``.cub`` file.
        color: Isosurface colour.
        viewnum: Grid cell ``(row, col)`` forwarded as ``viewer=``.
        isoval: Value at which the isosurface is drawn (default 0.01, the
            previously hard-coded value).
        opacity: Surface opacity (default 1.0, the previously hard-coded value).

    Returns:
        ``view`` (for chaining).
    """
    view.addVolumetricData(dens, "cube", {'isoval': isoval, 'color': color, 'opacity': opacity}, viewer=viewnum)
    return view
def get_feature_visualization(input_f,input_format, features, xyz):
    """Render a structure and its GFN2 properties in a 2x2 py3Dmol grid.

    Panels:
        (0,0) stick model;
        (0,1) atomic charges as spheres (negative values in red);
        (1,0) atomic polarizabilities as spheres;
        (1,1) electron-density isosurface read from ``dummy_struct_dens.cub``.

    Args:
        input_f: Contents of the structure file.
        input_format: py3Dmol format name (the file extension, e.g. ``"xyz"``).
        features: Result dict of ``run_gfn2``. ``features["fetchatomicprops"]``
            must provide per-atom ``"charges"`` and ``"polarisabilities"``.
        xyz: Per-atom coordinates, indexable as ``xyz[i][0..2]``.

    Returns:
        HTML for an ``<iframe>`` whose ``srcdoc`` holds the complete viewer
        page. Scripts inside markup inserted as plain HTML are not executed
        by browsers, so the py3Dmol page needs its own document to run.
    """
    import html

    grid_cells = ((0, 0), (0, 1), (1, 0), (1, 1))
    view = py3Dmol.view(width=620, height=620, viewergrid=(2,2))
    view.setBackgroundColor('white')
    # One copy of the molecule per panel. Previously (1,0) received two
    # overlapping copies and the density panel (1,1) none.
    for cell in grid_cells:
        view.addModel(input_f, input_format, viewer=cell)
    view.setStyle({'stick': {'colorscheme': {'prop': 'resi', 'C': '#cccccc'}}}, viewer=(0,0))
    # Thin sticks keep the property spheres / density surface readable.
    thin_sticks = {'stick': {'colorscheme': {'prop': 'resi', 'C': '#cccccc'}, "radius": "0.07"}}
    for cell in ((0, 1), (1, 0), (1, 1)):
        view.setStyle(thin_sticks, viewer=cell)

    atomic_props = features["fetchatomicprops"]
    add_spheres_feature_view(view, atomic_props["charges"], xyz, (0,1), 1.0, '#4c4e9e')
    add_spheres_feature_view(view, atomic_props["polarisabilities"], xyz, (1,0), 1.0, '#809BAC')
    # NOTE(review): presumably written by run_gfn2 because predict requests
    # calcdens=1; the fixed name is shared by all sessions.
    with open('dummy_struct_dens.cub', "r") as dens_file:
        density_cube = dens_file.read()
    add_densities(view, density_cube, '#F7D7BE', (1,1))

    for cell in grid_cells:
        view.zoomTo(viewer=cell)
    # html.escape encodes &, <, >, " and ' so the page can sit inside a
    # double-quoted attribute; the browser decodes it back verbatim.
    viewer_page = html.escape(view._make_html(), quote=True)
    return (
        '<iframe style="width: 100%; height: 640px; border: none" '
        'sandbox="allow-scripts" '
        f'srcdoc="{viewer_page}"></iframe>'
    )
def create_input_files(input_file):
    """Copy an uploaded structure file to ``dummy_struct.<ext>`` in the working directory.

    Args:
        input_file: Object with a ``.name`` attribute holding the path of the
            uploaded file (as provided by ``gr.File``).

    Returns:
        The file extension without its dot (e.g. ``"xyz"``), which the rest of
        the app uses as the structure format name.

    Raises:
        ValueError: if the file name has no extension, since the format
            could not be determined.
    """
    # splitext inspects only the final path component, so dots in directory
    # names cannot leak into the format string.
    input_format = os.path.splitext(input_file.name)[1][1:]
    if not input_format:
        raise ValueError(
            f"Cannot determine the structure format of {input_file.name!r}: no file extension."
        )
    with open(input_file.name, "r") as in_file:
        input_f = in_file.read()
    with open('dummy_struct.' + input_format, "w") as oF:
        oF.write(input_f)
    return input_format
def plot_thermo(thermo_data):
    """Plot entropy, enthalpy and heat capacity against temperature.

    Args:
        thermo_data: Mapping with equal-length sequences under
            ``'Temperature'``, ``'entropy'``, ``'enthalpy'`` and ``'cp'``
            (as built by ``get_thermo_data``).

    Returns:
        A matplotlib ``Figure`` with three vertically stacked panels.
    """
    # Build the Figure directly rather than through pyplot: pyplot keeps every
    # figure it creates alive until plt.close(), so a server calling this once
    # per request would accumulate figures and memory indefinitely.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(8, 8))
    ax1, ax2, ax3 = fig.subplots(3, 1)
    temperature = thermo_data['Temperature']
    ax1.plot(temperature, thermo_data['entropy'], lw=2, color="#a0210f")
    ax2.plot(temperature, thermo_data['enthalpy'], lw=2, color="#4c4e9e")
    ax3.plot(temperature, thermo_data['cp'], lw=2, color="#809BAC")
    fig.suptitle('Thermodynamics')
    ax3.set_xlabel('Temperature K')
    ax1.set_ylabel('Entropy [J/Kmol]')
    ax2.set_ylabel('Enthalpy [J]')
    ax3.set_ylabel('Heat capacity [J/Kmol]')
    fig.tight_layout()
    return fig
def predict(input_file, charge):
    """Run a GFN2 calculation on an uploaded structure and build all UI outputs.

    Args:
        input_file: Uploaded file object from ``gr.File``; only ``.name`` (its
            path on disk) is used. ``None`` when nothing was uploaded.
        charge: Total molecular charge as typed in the textbox. Blank (or
            ``None``) means the advertised default of 0.

    Returns:
        ``(viewer_html, atom_features_df, thermo_df, thermo_figure)``, matching
        the ``outputs=`` list wired to the Run button.

    Raises:
        ValueError: if no file was uploaded or ``charge`` is not an integer.
    """
    if input_file is None:
        raise ValueError("Please upload a structure file before running the calculation.")
    # Previously a blank textbox produced "charge=" despite the "(Default=0)" label.
    charge_text = (charge or "").strip() or "0"
    try:
        charge_value = int(charge_text)
    except ValueError:
        raise ValueError(f"Total charge must be an integer, got {charge_text!r}.") from None

    with open(input_file.name, "r") as in_file:
        input_f = in_file.read()
    input_format = input_file.name.split('.')[-1]
    # The geometry is handed to run_gfn2 as a file name, so stage a copy.
    # NOTE(review): this fixed name is shared by all sessions; concurrent runs
    # can overwrite each other's input (and density) files.
    with open('dummy_struct.' + input_format, "w") as oF:
        oF.write(input_f)

    gfn2_input = [
        "filename",
        "geom=dummy_struct." + input_format,
        'calcdens=1',
        'thermo=1',
        "charge=" + str(charge_value),
    ]
    gfn2_output = run_gfn2(gfn2_input)
    # input_f is exactly what was just written, so there is no need to re-read the file.
    feature_visualization_html = get_feature_visualization(input_f, input_format, gfn2_output, gfn2_output['xyz'])
    qm_atom_features = get_qm_atom_features(gfn2_output)
    thermo_data = get_thermo_data(gfn2_output)
    plot = plot_thermo(thermo_data)
    return feature_visualization_html, pd.DataFrame(qm_atom_features), pd.DataFrame(thermo_data), plot
# Gradio UI: upload + charge input, a Run button, the 2x2 property viewer,
# the thermodynamics plot, and CSV downloads of both result tables.
with gr.Blocks() as demo:
    gr.Markdown("# QM property calculation")
    with gr.Row():
        input_file = gr.File(label="Structure file for input (xyz format)")
        charge = gr.Textbox(placeholder="Total charge", label="Give the total charge of the input molecule. (Default=0)")

    # gr.Button's first positional argument is the text shown on the button.
    single_btn = gr.Button("Run")
    with gr.Row():
        basic_html = gr.HTML()
        plot = gr.Plot()
    # Colour legend for the spheres/surface drawn by get_feature_visualization.
    gr.HighlightedText(value=[("Positive Charge","Purple"),("Negative charge","red"),("Polarizability","Light blue"), ("Electronic Densities", "Beige")], color_map={"red":"#a0210f", "Light blue":"#809BAC", "Purple":"#4c4e9e", "Beige":"#F7D7BE"})
    with gr.Row():
        Dbutton = gr.Button("Download  calculated atom features")
        csv = gr.File(interactive=False, visible=False)
        D2button = gr.Button("Download  thermodynamic properties")
        csv2 = gr.File(interactive=False, visible=False)
    with gr.Row():
        df_atom_features = gr.Dataframe()
        df_thermo_props = gr.Dataframe()

    single_btn.click(fn=predict, inputs=[input_file, charge], outputs=[basic_html, df_atom_features, df_thermo_props, plot])
    Dbutton.click(export_csv, df_atom_features, csv)
    D2button.click(export_csv_thermo, df_thermo_props, csv2)

# Launch only after the Blocks context has closed, and only when run as a
# script, so importing this module does not start a server.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)