import gradio as gr
import torch
from transformers import pipeline
# -------------------
# 1. Model definitions
# -------------------
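# Short display keys mapped to their Hugging Face Hub repo IDs.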
MODELS = {
"econbert": "climatebert/econbert",
"controversy-classification": "climatebert/ClimateControversyBERT_classification",
"controversy-bert": "climatebert/ClimateControversyBert",
"netzero-reduction": "climatebert/netzero-reduction",
"transition-physical": "climatebert/transition-physical",
"renewable": "climatebert/renewable",
"climate-detector": "climatebert/distilroberta-base-climate-detector",
"climate-commitment": "climatebert/distilroberta-base-climate-commitment",
"climate-tcfd": "climatebert/distilroberta-base-climate-tcfd",
"climate-s": "climatebert/distilroberta-base-climate-s",
"climate-specificity": "climatebert/distilroberta-base-climate-specificity",
"climate-sentiment": "climatebert/distilroberta-base-climate-sentiment",
"environmental-claims": "climatebert/environmental-claims",
"climate-f": "climatebert/distilroberta-base-climate-f",
"climate-d-s": "climatebert/distilroberta-base-climate-d-s",
"climate-d": "climatebert/distilroberta-base-climate-d",
}
# -------------------
# 2. Human-readable label maps
# -------------------
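# Models missing from this dict fall back to their raw LABEL_n outputs in predict_all_models().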
LABEL_MAPS = {
"climate-commitment": {
"LABEL_0": "Not about climate commitments",
"LABEL_1": "About climate commitments",
},
"climate-detector": {
"LABEL_0": "Not climate-related",
"LABEL_1": "Climate-related",
},
"climate-sentiment": {
"LABEL_0": "Negative",
"LABEL_1": "Neutral",
"LABEL_2": "Positive",
},
"climate-specificity": {
"LABEL_0": "Low specificity",
"LABEL_1": "Medium specificity",
"LABEL_2": "High specificity",
},
"netzero-reduction": {
"LABEL_0": "No net-zero / reduction commitment",
"LABEL_1": "Net-zero / reduction commitment",
},
"transition-physical": {
"LABEL_0": "Transition risk",
"LABEL_1": "Physical risk",
},
"renewable": {
"LABEL_0": "Not about renewables",
"LABEL_1": "About renewables",
},
}
# -------------------
# 3. Pipeline cache
# -------------------
pipelines = {}
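# Note: the first call for each model downloads its weights from the Hugging Face Hub,
# so the first full run can be slow; later calls reuse the cached pipeline.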
def load_model(model_key):
"""Load and cache a model pipeline."""
if model_key not in pipelines:
repo_id = MODELS[model_key]
device = 0 if torch.cuda.is_available() else -1
print(f"๐Ÿ”น Loading model: {model_key} ({repo_id})")
pipelines[model_key] = pipeline(
"text-classification",
model=repo_id,
device=device,
torch_dtype=torch.float16 if device == 0 else None,
truncation=True,
max_length=512
)
return pipelines[model_key]
# -------------------
# 4. Inference across all models
# -------------------
def predict_all_models(text):
"""Run inference across all ClimateBERT models and return structured output."""
if not text.strip():
return "โš ๏ธ Please enter some text."
results_summary = []
for model_key, repo in MODELS.items():
try:
model = load_model(model_key)
outputs = model(text)
label_map = LABEL_MAPS.get(model_key, {})
formatted = "\n".join([
f"โ€ข {label_map.get(r['label'], r['label'])}: {r['score']:.2f}"
for r in outputs
])
results_summary.append(f"### {model_key}\n{formatted}")
except Exception as e:
results_summary.append(f"### {model_key}\nโŒ Error: {str(e)}")
return "\n\n".join(results_summary)
# -------------------
# 5. Gradio UI
# -------------------
with gr.Blocks(title="🌍 ClimateBERT All-Models Analyzer") as demo:
    gr.Markdown("""
    # 🌍 ClimateBERT Multi-Model Analysis
    This app runs **all ClimateBERT models** on your input text (`mergedMarkdown` style).
    It detects sentiment, specificity, renewables, commitments, and more, all at once.
    """)
    text_input = gr.Textbox(
        label="Input Text (mergedMarkdown)",
        placeholder="Paste the sustainability report, ESG statement, or corporate disclosure here...",
        lines=5,
    )
    output = gr.Markdown(label="Model Outputs")
    run_btn = gr.Button("🔍 Run All Models")
    run_btn.click(predict_all_models, inputs=text_input, outputs=output)
    gr.Markdown("""
    ---
    **Note:** Each model captures a different aspect of climate-related discourse (e.g., sentiment, specificity, commitments).
    """)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)