Dusit-P's picture
Update app.py
c968c6a verified
raw
history blame
15.3 kB
import os, json, importlib.util, tempfile, traceback, torch, re, math
import torch.nn.functional as F
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoTokenizer
# ===== Tunable from the Space's Settings > Variables & secrets =====
# Hub repository holding <model>/config.json, <model>/model.safetensors and common/models.py.
REPO_ID = os.getenv("REPO_ID", "Dusit-P/thai-sentiment-wcb")
# Model pre-selected in the UI ("cnn_bilstm" or "baseline"). The predict_*
# handlers treat any value other than "baseline" as "cnn_bilstm", so an
# unknown env value is normalized to that same fallback instead of being
# handed to gr.Radio as a value outside its choices.
_KNOWN_MODELS = ("cnn_bilstm", "baseline")
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "cnn_bilstm").strip()
if DEFAULT_MODEL not in _KNOWN_MODELS:
    DEFAULT_MODEL = "cnn_bilstm"
# Access token; add a secret with this name when the model repo is private.
# An empty value (blank secret) is treated as "no token".
HF_TOKEN = os.getenv("HF_TOKEN") or None
# ---- theme colors (soft modern) ----
NEG_COLOR = os.getenv("NEG_COLOR", "#F87171")  # red-400 (soft)
POS_COLOR = os.getenv("POS_COLOR", "#34D399")  # emerald-400 (soft)
TEMPLATE = "plotly_white"
# Process-wide memo: "models_module" -> downloaded models module,
# "model:<name>" -> (model, tokenizer, config) tuples.
CACHE = {}
# ---------- load architecture & weights from model repo ----------
def _import_models():
    """Return the ``models`` module shipped in the model repo as ``common/models.py``.

    On the first call the file is downloaded from ``REPO_ID`` (using
    ``HF_TOKEN``), executed as a stand-alone module named "models" and stored
    in ``CACHE["models_module"]``. Later calls return that cached module.
    """
    if "models_module" not in CACHE:
        source_path = hf_hub_download(REPO_ID, filename="common/models.py", token=HF_TOKEN)
        module_spec = importlib.util.spec_from_file_location("models", source_path)
        loaded = importlib.util.module_from_spec(module_spec)
        # NOTE: this executes Python code fetched from REPO_ID; only point
        # REPO_ID at a repository you trust.
        module_spec.loader.exec_module(loaded)
        CACHE["models_module"] = loaded
    return CACHE["models_module"]
def load_model(model_name: str):
    """Load once and return ``(model, tokenizer, config)`` for *model_name*.

    Downloads ``<model_name>/config.json`` and ``<model_name>/model.safetensors``
    from ``REPO_ID``. The config's ``base_model`` names the tokenizer and its
    ``arch`` is passed to ``models.create_model_by_name``. The weights are
    loaded strictly and the model is switched to eval mode. The tuple is
    memoized in ``CACHE`` under ``"model:<model_name>"``.
    """
    cache_key = f"model:{model_name}"
    if cache_key not in CACHE:
        config_file = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
        weights_file = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)
        with open(config_file, "r", encoding="utf-8") as fh:
            config = json.load(fh)
        models = _import_models()
        tokenizer = AutoTokenizer.from_pretrained(config["base_model"])
        net = models.create_model_by_name(config["arch"])
        net.load_state_dict(load_file(weights_file), strict=True)
        net.eval()
        CACHE[cache_key] = (net, tokenizer, config)
    return CACHE[cache_key]
# ---------- helpers ----------
def _format_pct(x: float) -> str:
    """Render a fraction (e.g. a probability) as a percentage with 2 decimals: 0.5 -> "50.00%"."""
    percent = x * 100
    return "{:.2f}%".format(percent)
# ====== Filters for text that is not a review: empty values / placeholders / symbols ======
# Placeholder values meaning "no review"; compared against the lower-cased text.
_INVALID_STRINGS = {"-", "--", "—", "n/a", "na", "null", "none", "nan", ".", "…", ""}
# Text must contain at least one Thai consonant (ก..ฮ, U+0E01–U+0E2E) or ASCII
# letter. Thai digits (๐–๙), the baht sign ฿ and marks such as ๆ/ฯ are
# deliberately excluded, so "๑๒๓" is rejected just like "123". Thai vowels and
# tone marks are written on a consonant, so real Thai words still match.
_RE_HAS_LETTER = re.compile(r"[ก-ฮA-Za-z]")
def _norm_text(v) -> str:
    """Convert a cell/line value to a stripped string; missing values become "".

    ``None`` and any scalar pandas considers missing (``float('nan')``,
    ``numpy.nan``/``numpy.float32('nan')``, ``pd.NA``, ``pd.NaT``) map to
    ``""``. Without this, ``pd.NA`` from nullable "string" columns would be
    rendered as the text "<NA>" and treated as a review. Everything else is
    ``str(v).strip()``.
    """
    if v is None:
        return ""
    # is_scalar guard: pd.isna on a list/array returns an array, not a bool.
    if pd.api.types.is_scalar(v) and pd.isna(v):
        return ""
    return str(v).strip()
def _is_substantive_text(s: str, min_chars: int = 2) -> bool:
    """Return True if *s* is text worth sending to the sentiment model.

    Callers in this file pass text already normalized by ``_norm_text``.
    *s* is rejected when it is empty, equals a placeholder from
    ``_INVALID_STRINGS`` (case-insensitive), has no match for
    ``_RE_HAS_LETTER``, or has fewer than *min_chars* non-whitespace characters.
    """
    if not s or s.lower() in _INVALID_STRINGS:
        return False
    if _RE_HAS_LETTER.search(s) is None:
        return False
    # Count every non-whitespace character: tabs, newlines and other Unicode
    # spaces must not pad a too-short string past the threshold.
    visible = sum(1 for ch in s if not ch.isspace())
    return visible >= min_chars
def _clean_texts(texts):
    """Normalize *texts* and keep only the substantive entries.

    Returns ``(kept, skipped)``. *kept* holds the normalized strings that pass
    ``_is_substantive_text``, in input order. *skipped* is the number of
    inputs that were dropped.
    """
    kept = []
    seen = 0
    for raw in texts:
        seen += 1
        candidate = _norm_text(raw)
        if _is_substantive_text(candidate):
            kept.append(candidate)
    return kept, seen - len(kept)
def _detect_cols(df: pd.DataFrame):
    """Guess which CSV columns hold the review text and the shop name.

    Candidate names are tried in list order. Exact column labels are tried
    first, then a case/whitespace-insensitive match
    (``str(col).strip().lower()``), so headers such as "Review" or
    " Shop_Name " are also recognised. When no review candidate matches, the
    first column with a text dtype (``object`` or a pandas string dtype) is
    used instead.

    Returns:
        ``(review_col, shop_col)`` using the DataFrame's own column labels.
        Either may be ``None`` when nothing suitable is found.
    """
    rev_cands = ["review", "text", "comment", "content", "message", "ข้อความ", "รีวิว"]
    shop_cands = ["shop", "shop_name", "store", "restaurant", "brand", "merchant", "ชื่อร้าน"]

    # Normalized header -> actual label (first occurrence wins).
    by_norm = {}
    for col in df.columns:
        by_norm.setdefault(str(col).strip().lower(), col)

    def _pick(cands):
        # Exact matches first, so files that already worked keep the same choice.
        for cand in cands:
            if cand in df.columns:
                return cand
        for cand in cands:
            if cand in by_norm:
                return by_norm[cand]
        return None

    review_col = _pick(rev_cands)
    shop_col = _pick(shop_cands)
    if review_col is None:
        # is_string_dtype also accepts pandas' dedicated string dtypes
        # (e.g. nullable "string" columns), which `dtype == object` misses.
        review_col = next(
            (col for col, dtype in df.dtypes.items() if pd.api.types.is_string_dtype(dtype)),
            None,
        )
    return review_col, shop_col
def _summarize_df(df: pd.DataFrame):
    """Summarize a prediction table: label counts plus mean confidence.

    *df* must have a ``label`` column ("negative"/"positive") and
    ``negative(%)`` / ``positive(%)`` columns holding strings such as "12.34%".
    Returns a dict with keys ``total``, ``neg``, ``pos``, ``neg_avg`` and
    ``pos_avg``, where the means skip unparsable cells. The ``md`` key holds a
    Markdown rendering of the same numbers.
    """
    def _mean_pct(column):
        # "12.34%" -> 12.34; unparsable values become NaN and are ignored by mean().
        return pd.to_numeric(df[column].str.rstrip("%"), errors="coerce").mean()

    labels = df["label"]
    total = len(df)
    neg = int((labels == "negative").sum())
    pos = int((labels == "positive").sum())
    neg_avg = _mean_pct("negative(%)")
    pos_avg = _mean_pct("positive(%)")
    md = (
        f"**Summary** \n"
        f"- Total: {total} \n"
        f"- Negative: {neg} \n"
        f"- Positive: {pos} \n"
        f"- Avg negative: {neg_avg:.2f}% \n"
        f"- Avg positive: {pos_avg:.2f}%"
    )
    summary = dict(total=total, neg=neg, pos=pos, neg_avg=neg_avg, pos_avg=pos_avg)
    summary["md"] = md
    return summary
def _make_figures(df: pd.DataFrame):
    """Build the overall label-count bar chart and label-share donut chart.

    Returns ``(fig_bar, fig_pie, summary_markdown)``, where the Markdown comes
    from ``_summarize_df``. Negative and positive always use NEG_COLOR and
    POS_COLOR, so both charts agree.
    """
    stats = _summarize_df(df)
    series = [
        ("negative", stats["neg"], NEG_COLOR),
        ("positive", stats["pos"], POS_COLOR),
    ]

    # Bar: one trace per label, so each keeps a fixed colour and legend entry.
    fig_bar = go.Figure()
    for label, count, color in series:
        fig_bar.add_bar(name=label, x=[label], y=[count], marker_color=color)
    fig_bar.update_layout(
        barmode="group",
        title="Label counts",
        xaxis_title="label",
        yaxis_title="count",
        template=TEMPLATE,
        legend_title="label",
    )

    # Donut: same order and colours as the bar chart (sort=False keeps the order).
    donut = go.Pie(
        labels=[label for label, _, _ in series],
        values=[count for _, count, _ in series],
        hole=0.35,
        sort=False,
        marker=dict(colors=[color for _, _, color in series]),
    )
    fig_pie = go.Figure(donut)
    fig_pie.update_layout(title="Label share", template=TEMPLATE)
    return fig_bar, fig_pie, stats["md"]
def _shop_summary(out_df: pd.DataFrame, max_shops=15):
    """Per-shop summary: a table plus a stacked positive/negative bar chart.

    Counts "positive"/"negative" labels per value of ``out_df["shop"]``;
    groupby drops rows whose shop is missing. Returns ``(figure, table)``.
    The table is sorted by total count (descending) and has the columns shop,
    total, positive, negative, positive_rate(%) and negative_rate(%), with
    rates rounded to 2 decimals. The chart shows the first *max_shops* rows
    using the fixed POS_COLOR / NEG_COLOR. Without a ``shop`` column it
    returns an empty figure and an empty table with those headers.
    """
    if "shop" not in out_df.columns:
        empty_tbl = pd.DataFrame(columns=["shop","total","positive","negative","positive_rate(%)","negative_rate(%)"])
        return go.Figure(), empty_tbl

    g = out_df.groupby("shop")["label"].value_counts().unstack(fill_value=0)
    # A label absent from every shop gets no column from unstack(); add it as 0.
    for col in ["positive","negative"]:
        if col not in g.columns:
            g[col] = 0
    g["total"] = g["positive"] + g["negative"]
    g = g.sort_values("total", ascending=False)

    table = g[["total","positive","negative"]].copy()
    table["positive_rate(%)"] = (table["positive"] / table["total"] * 100).round(2)
    table["negative_rate(%)"] = (table["negative"] / table["total"] * 100).round(2)
    # groupby("shop") names the index "shop", so reset_index() yields that
    # column. Also drop the "label" axis name unstack() left on the columns,
    # so it does not leak into the UI table.
    table = table.reset_index().rename(columns={"index":"shop"}).rename_axis(None, axis="columns")

    # Chart shows the top N shops.
    top = table.head(max_shops)
    fig = go.Figure()
    fig.add_bar(name="positive", x=top["shop"], y=top["positive"], marker_color=POS_COLOR)
    fig.add_bar(name="negative", x=top["shop"], y=top["negative"], marker_color=NEG_COLOR)
    fig.update_layout(
        barmode="stack",
        title=f"Per-shop counts (top {len(top)})",
        xaxis_title="shop",
        yaxis_title="count",
        legend_title="label",
        template=TEMPLATE,
        # Force a categorical axis. Plotly auto-types numeric-looking strings
        # (e.g. branch codes "101", "205") as a linear axis, which spaces bars
        # by value instead of giving each shop its own slot.
        xaxis=dict(tickangle=-30, type="category"),
    )
    return fig, table
# ---------- core prediction ----------
def _predict_batch(texts, model_name, batch_size=64):
    """Score a list of already-filtered review strings with *model_name*.

    Texts are tokenized *batch_size* at a time (padded, truncated to
    ``cfg["max_len"]``) and the model runs under ``torch.no_grad()``.
    Softmax column 0 is reported as negative and column 1 as positive. The
    label is "positive" when ``pos >= neg``, so ties count as positive.
    Returns one dict per input, in order, with keys ``review``,
    ``negative(%)``, ``positive(%)`` (strings from ``_format_pct``) and ``label``.
    """
    model, tok, cfg = load_model(model_name)
    rows = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        encoded = tok(batch, padding=True, truncation=True, max_length=cfg["max_len"], return_tensors="pt")
        with torch.no_grad():
            logits = model(encoded["input_ids"], encoded["attention_mask"])
        prob_rows = F.softmax(logits, dim=1).cpu().numpy()
        for review, row in zip(batch, prob_rows):
            p_neg = float(row[0])
            p_pos = float(row[1])
            rows.append({
                "review": review,
                "negative(%)": _format_pct(p_neg),
                "positive(%)": _format_pct(p_pos),
                "label": "positive" if p_pos >= p_neg else "negative",
            })
    return rows
# ---------- API wrappers ----------
def predict_one(text: str, model_choice: str):
    """Classify one review (handler for the "Single" tab).

    Returns ``(probabilities, label)``. *probabilities* maps "negative" and
    "positive" to fractions, parsed back from the 2-decimal percentage
    strings. *label* is "negative"/"positive". Input rejected by
    ``_is_substantive_text`` yields zero probabilities and the label
    "invalid". A *model_choice* other than "baseline" selects "cnn_bilstm".
    Errors are printed with a traceback and re-raised.
    """
    try:
        cleaned = _norm_text(text)
        if not _is_substantive_text(cleaned):
            return {"negative": 0.0, "positive": 0.0}, "invalid"
        chosen = "cnn_bilstm" if model_choice != "baseline" else "baseline"
        row = _predict_batch([cleaned], chosen)[0]
        probabilities = {}
        for key, column in (("negative", "negative(%)"), ("positive", "positive(%)")):
            probabilities[key] = float(row[column].rstrip("%")) / 100.0
        return probabilities, row["label"]
    except Exception as exc:
        print("ERROR in predict_one:", repr(exc))
        traceback.print_exc()
        raise
def predict_many(text_block: str, model_choice: str):
    """Classify one review per line (handler for the "Batch" tab).

    Blank lines are dropped up front and are not counted. Other lines failing
    ``_is_substantive_text`` are counted in the "Skipped" line of the summary.
    Returns ``(results_df, fig_bar, fig_pie, summary_md)``, or an empty table,
    empty figures and "No valid text" when nothing usable remains. Errors are
    printed with a traceback and re-raised.
    """
    columns = ["review","negative(%)","positive(%)","label"]
    try:
        chosen = "baseline" if model_choice == "baseline" else "cnn_bilstm"
        non_blank = [line for line in map(_norm_text, (text_block or "").splitlines()) if line]
        reviews, skipped = _clean_texts(non_blank)
        if not reviews:
            return pd.DataFrame(columns=columns), go.Figure(), go.Figure(), "No valid text"
        table = pd.DataFrame(_predict_batch(reviews, chosen), columns=columns)
        fig_bar, fig_pie, summary_md = _make_figures(table)
        return table, fig_bar, fig_pie, f"{summary_md} \n- Skipped (empty/non-text): {skipped}"
    except Exception as exc:
        print("ERROR in predict_many:", repr(exc))
        traceback.print_exc()
        raise
def predict_csv(file_obj, model_choice: str, review_col_override: str = "", shop_col_override: str = ""):
    """Classify every usable review in an uploaded CSV (handler for the CSV tab).

    Args:
        file_obj: Value of the ``gr.File`` input. This is either a path
            string / ``os.PathLike`` (newer Gradio, ``type="filepath"``) or an
            object exposing the path as ``.name`` (older tempfile wrappers).
            It is ``None`` when nothing was uploaded.
        model_choice: "baseline" selects the baseline model; any other value
            selects "cnn_bilstm".
        review_col_override: Review column name; blank means auto-detect.
        shop_col_override: Shop column name; blank means auto-detect.

    Returns:
        ``(results_df, download_path, fig_bar, fig_pie, fig_shop, shop_table,
        info_md)``. *results_df* has one row per usable review (normalized
        text), with a leading ``shop`` column when a shop column was found.
        *download_path* points to a UTF-8-with-BOM CSV copy of it.

    Raises:
        ValueError: the review column is not in the CSV. Every error is
        printed with its traceback and re-raised.
    """
    try:
        if file_obj is None:
            return pd.DataFrame(), None, go.Figure(), go.Figure(), go.Figure(), pd.DataFrame(), "กรุณาอัปโหลดไฟล์ CSV"
        model_name = "baseline" if model_choice == "baseline" else "cnn_bilstm"

        # Newer Gradio passes a filepath string; older versions a tempfile wrapper.
        csv_path = file_obj if isinstance(file_obj, (str, os.PathLike)) else file_obj.name
        df = pd.read_csv(csv_path)
        auto_rev, auto_shop = _detect_cols(df)
        rev_col = (review_col_override or "").strip() or auto_rev
        shop_col = (shop_col_override or "").strip() or auto_shop
        if rev_col not in df.columns:
            raise ValueError(f"ไม่พบคอลัมน์รีวิว '{rev_col}' ใน CSV (columns = {list(df.columns)})")

        # === Keep only rows whose review is real text ===
        reviews_norm = df[rev_col].map(_norm_text)
        # astype(bool) keeps the mask boolean even for a header-only CSV (empty Series).
        mask_use = reviews_norm.map(_is_substantive_text).astype(bool)
        skipped = int((~mask_use).sum())
        if not mask_use.any():
            empty = pd.DataFrame(columns=["review","negative(%)","positive(%)","label"])
            return empty, None, go.Figure(), go.Figure(), go.Figure(), pd.DataFrame(), "ไม่พบรีวิวที่เป็นข้อความ"

        # Send the same normalized text the filter inspected (as the Single/Batch tabs do).
        results = _predict_batch(reviews_norm[mask_use].tolist(), model_name)
        out = pd.DataFrame(results, columns=["review","negative(%)","positive(%)","label"])
        if shop_col and shop_col in df.columns:
            # Assign positionally. Filtered rows keep their original index
            # labels while `out` is numbered 0..n-1, so inserting a Series
            # would align on the index and shift/blank shop names after the
            # first skipped row. _norm_text maps NaN to "" (astype(str) gives "nan").
            shops = df.loc[mask_use, shop_col].map(_norm_text)
            out.insert(0, "shop", shops.to_numpy())

        # Result file for download; mkstemp + close avoids leaking an open handle.
        fd, out_path = tempfile.mkstemp(suffix=".csv")
        os.close(fd)
        out.to_csv(out_path, index=False, encoding="utf-8-sig")

        # Overall charts / summary
        fig_bar, fig_pie, info_md = _make_figures(out)
        # Per-shop chart / table (only meaningful when a shop column exists)
        fig_shop, tbl_shop = _shop_summary(out)
        # Append which columns were used + how many rows were skipped
        info_md = (
            f"{info_md} \n"
            f"ใช้คอลัมน์รีวิว: {rev_col}"
            + (f" | คอลัมน์ร้าน: {shop_col}" if ("shop" in out.columns) else " | ไม่มีคอลัมน์ร้าน")
            + f" \n- Skipped (empty/non-text): {skipped}"
        )
        return out, out_path, fig_bar, fig_pie, fig_shop, tbl_shop, info_md
    except Exception as e:
        print("ERROR in predict_csv:", repr(e))
        traceback.print_exc()
        raise
# ---------- Gradio UI ----------
# Gradio app: a shared model selector followed by three tabs (Single, Batch, CSV).
with gr.Blocks(title="Thai Sentiment API (Dusit-P)") as demo:
    gr.Markdown("### Thai Sentiment (WangchanBERTa + LSTM/CNN Heads)")
    # Model selector shared by all tabs; its value is passed to each handler
    # as `model_choice`. Label: "Select model".
    # NOTE(review): DEFAULT_MODEL comes from an environment variable and is
    # expected to be one of these two choices.
    model_radio = gr.Radio(choices=["cnn_bilstm","baseline"], value=DEFAULT_MODEL, label="เลือกโมเดล")
    # Tab 1: one review -> class probabilities + predicted label (predict_one).
    with gr.Tab("Single"):
        t1 = gr.Textbox(lines=3, label="ข้อความรีวิว (1 ข้อความ)")  # "Review text (1 message)"
        probs = gr.Label(label="Probabilities")
        pred = gr.Textbox(label="Prediction", interactive=False)
        gr.Button("Predict").click(predict_one, [t1, model_radio], [probs, pred])
    # Tab 2 ("multiple messages"): one review per line -> results table,
    # bar/pie charts and a Markdown summary (predict_many).
    with gr.Tab("Batch (หลายข้อความ)"):
        t2 = gr.Textbox(lines=8, label="พิมพ์หลายรีวิว (บรรทัดละ 1 รีวิว)")  # "Type several reviews (one per line)"
        df2 = gr.Dataframe(label="ผลลัพธ์", interactive=False)  # "Results"
        bar2 = gr.Plot(label="Label counts (bar)")
        pie2 = gr.Plot(label="Label share (pie)")
        sum2 = gr.Markdown()
        gr.Button("Run Batch").click(predict_many, [t2, model_radio], [df2, bar2, pie2, sum2])
    # Tab 3: CSV upload. Review/shop columns are auto-detected unless named in
    # the two textboxes. The outputs list matches predict_csv's 7-value return.
    with gr.Tab("CSV (auto-detect columns)"):
        f = gr.File(label="อัปโหลด CSV", file_types=[".csv"])  # "Upload CSV"
        review_col_inp = gr.Textbox(label="ชื่อคอลัมน์รีวิว (เว้นว่างให้เดาได้)")  # "Review column name (leave blank to auto-detect)"
        shop_col_inp = gr.Textbox(label="ชื่อคอลัมน์ร้าน (เว้นว่างได้)")  # "Shop column name (optional)"
        df3 = gr.Dataframe(label="ผลลัพธ์ CSV", interactive=False)  # "CSV results"
        download = gr.File(label="ดาวน์โหลดผลลัพธ์")  # "Download results"
        bar3 = gr.Plot(label="Label counts (bar)")
        pie3 = gr.Plot(label="Label share (pie)")
        shop_bar = gr.Plot(label="Per-shop stacked bar")
        shop_tbl = gr.Dataframe(label="Per-shop summary", interactive=False)
        info = gr.Markdown()
        gr.Button("Run CSV").click(
            predict_csv,
            inputs=[f, model_radio, review_col_inp, shop_col_inp],
            outputs=[df3, download, bar3, pie3, shop_bar, shop_tbl, info]
        )
# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()