Spaces:
Sleeping
Sleeping
| import os, json, importlib.util, tempfile, traceback, torch, re, math | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| from transformers import AutoTokenizer | |
# ===== Configurable from the Space's Settings > Variables & secrets =====
REPO_ID = os.environ.get("REPO_ID", "Dusit-P/thai-sentiment-wcb")  # Hub repo with model code + weights
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "cnn_bilstm")  # or "baseline"
HF_TOKEN = os.environ.get("HF_TOKEN")  # add a secret with this name if the model repo is private

# ---- theme colors (soft modern) ----
NEG_COLOR = os.environ.get("NEG_COLOR", "#F87171")  # red-400 (soft)
POS_COLOR = os.environ.get("POS_COLOR", "#34D399")  # emerald-400 (soft)
TEMPLATE = "plotly_white"

# In-process cache for the imported architecture module and the loaded
# (model, tokenizer, config) tuples.
CACHE = {}
# ---------- load architecture & weights from model repo ----------
def _import_models():
    """Download ``common/models.py`` from ``REPO_ID`` and import it as a module.

    The module object is memoised in ``CACHE["models_module"]``, so the file is
    downloaded and executed at most once per process.

    NOTE(security): ``exec_module`` runs arbitrary Python fetched from the Hub;
    only point ``REPO_ID`` at a repository you trust.

    Returns:
        The executed module (``load_model`` calls its ``create_model_by_name``).
    """
    try:
        return CACHE["models_module"]
    except KeyError:
        pass  # first call in this process: fetch and import below

    source_path = hf_hub_download(REPO_ID, filename="common/models.py", token=HF_TOKEN)
    module_spec = importlib.util.spec_from_file_location("models", source_path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    CACHE["models_module"] = module
    return module
def load_model(model_name: str):
    """Load (and cache) one model variant from the Hub repo.

    Args:
        model_name: Sub-folder of ``REPO_ID`` that holds ``config.json`` and
            ``model.safetensors`` (the UI uses ``"cnn_bilstm"`` / ``"baseline"``).

    Returns:
        ``(model, tokenizer, cfg)``: the model switched to eval mode, the
        tokenizer for ``cfg["base_model"]``, and the parsed config dict.
        The tuple is cached in ``CACHE`` under ``"model:<model_name>"``.
    """
    cache_key = f"model:{model_name}"
    if cache_key not in CACHE:
        config_file = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
        weights_file = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)
        with open(config_file, "r", encoding="utf-8") as fh:
            cfg = json.load(fh)

        arch_module = _import_models()
        tokenizer = AutoTokenizer.from_pretrained(cfg["base_model"])
        net = arch_module.create_model_by_name(cfg["arch"])
        # strict=True: fail loudly if checkpoint keys and architecture disagree.
        net.load_state_dict(load_file(weights_file), strict=True)
        net.eval()
        CACHE[cache_key] = (net, tokenizer, cfg)
    return CACHE[cache_key]
| # ---------- helpers ---------- | |
| def _format_pct(x: float) -> str: | |
| return f"{x*100:.2f}%" | |
| # ====== ฟิลเตอร์ข้อความที่ไม่ใช่รีวิว / ค่าว่าง / สัญลักษณ์ ====== | |
| _INVALID_STRINGS = {"-", "--", "—", "n/a", "na", "null", "none", "nan", ".", "…", ""} # lower-case | |
| _RE_HAS_LETTER = re.compile(r"[ก-๙A-Za-z]") # ต้องมีอย่างน้อย 1 ตัวอักษรไทยหรืออังกฤษ | |
| def _norm_text(v) -> str: | |
| """แปลงค่าให้เป็นสตริงพร้อม trim และกัน NaN/None""" | |
| if v is None: | |
| return "" | |
| if isinstance(v, float) and math.isnan(v): | |
| return "" | |
| s = str(v).strip() | |
| return s | |
def _is_substantive_text(s: str, min_chars: int = 2) -> bool:
    """Return True when *s* looks like review text worth classifying.

    Rejected are: empty strings, known placeholders (lower-cased lookup in
    ``_INVALID_STRINGS``), strings in which ``_RE_HAS_LETTER`` finds no match,
    and strings shorter than *min_chars* once ASCII spaces are removed.
    """
    if not s or s.lower() in _INVALID_STRINGS:
        return False
    if _RE_HAS_LETTER.search(s) is None:
        return False
    return len(s.replace(" ", "")) >= min_chars
def _clean_texts(texts):
    """Normalise every value in *texts* and keep only substantive ones.

    Returns:
        ``(cleaned, skipped)``: the usable strings in input order, and the
        number of inputs that were dropped.
    """
    normalised = list(map(_norm_text, texts))
    cleaned = list(filter(_is_substantive_text, normalised))
    return cleaned, len(normalised) - len(cleaned)
| def _detect_cols(df: pd.DataFrame): | |
| """เดาชื่อคอลัมน์รีวิว/ร้านอัตโนมัติ ถ้าไม่พบรีวิว เลือกคอลัมน์ object ตัวแรก""" | |
| rev_cands = ["review", "text", "comment", "content", "message", "ข้อความ", "รีวิว"] | |
| shop_cands = ["shop", "shop_name", "store", "restaurant", "brand", "merchant", "ชื่อร้าน"] | |
| review_col = next((c for c in rev_cands if c in df.columns), None) | |
| shop_col = next((c for c in shop_cands if c in df.columns), None) | |
| if review_col is None: | |
| obj_cols = [c for c in df.columns if df[c].dtype == object] | |
| if obj_cols: | |
| review_col = obj_cols[0] | |
| return review_col, shop_col | |
| def _summarize_df(df: pd.DataFrame): | |
| """สรุปภาพรวม + ตัวเลขเฉลี่ยความมั่นใจ""" | |
| total = len(df) | |
| neg = int((df["label"] == "negative").sum()) | |
| pos = int((df["label"] == "positive").sum()) | |
| neg_avg = pd.to_numeric(df["negative(%)"].str.rstrip("%"), errors="coerce").mean() | |
| pos_avg = pd.to_numeric(df["positive(%)"].str.rstrip("%"), errors="coerce").mean() | |
| info = ( | |
| f"**Summary** \n" | |
| f"- Total: {total} \n" | |
| f"- Negative: {neg} \n" | |
| f"- Positive: {pos} \n" | |
| f"- Avg negative: {neg_avg:.2f}% \n" | |
| f"- Avg positive: {pos_avg:.2f}%" | |
| ) | |
| return {"total": total, "neg": neg, "pos": pos, "neg_avg": neg_avg, "pos_avg": pos_avg, "md": info} | |
def _make_figures(df: pd.DataFrame):
    """Build the overall label-count bar chart and the label-share donut chart.

    Returns:
        ``(fig_bar, fig_pie, summary_md)``, where ``summary_md`` is the
        Markdown produced by ``_summarize_df``.
    """
    summary = _summarize_df(df)
    order = ["negative", "positive"]
    counts = {"negative": summary["neg"], "positive": summary["pos"]}
    palette = {"negative": NEG_COLOR, "positive": POS_COLOR}

    # Bar: one trace per label so each keeps a fixed colour and its own legend entry.
    fig_bar = go.Figure()
    for lbl in order:
        fig_bar.add_bar(name=lbl, x=[lbl], y=[counts[lbl]], marker_color=palette[lbl])
    fig_bar.update_layout(
        barmode="group",
        title="Label counts",
        xaxis_title="label",
        yaxis_title="count",
        template=TEMPLATE,
        legend_title="label",
    )

    # Donut: same label order and colours as the bar chart (sort=False keeps the order).
    fig_pie = go.Figure(
        go.Pie(
            labels=order,
            values=[counts[lbl] for lbl in order],
            hole=0.35,
            sort=False,
            marker=dict(colors=[palette[lbl] for lbl in order]),
        )
    )
    fig_pie.update_layout(title="Label share", template=TEMPLATE)
    return fig_bar, fig_pie, summary["md"]
def _shop_summary(out_df: pd.DataFrame, max_shops=15):
    """Per-shop summary: a table plus a stacked positive/negative bar chart.

    Args:
        out_df: Prediction table. It is only summarised when it has a ``shop``
            column (used together with its ``label`` column).
        max_shops: Number of shops (by total reviews) shown in the chart; the
            table keeps every shop.

    Returns:
        ``(fig, table)``. ``table`` has the columns ``shop, total, positive,
        negative, positive_rate(%), negative_rate(%)`` and is sorted by
        ``total`` descending. Without a ``shop`` column, the result is an
        empty figure plus an empty table with those columns.
    """
    columns = ["shop", "total", "positive", "negative", "positive_rate(%)", "negative_rate(%)"]
    if "shop" not in out_df.columns:
        return go.Figure(), pd.DataFrame(columns=columns)

    counts = out_df.groupby("shop")["label"].value_counts().unstack(fill_value=0)
    # Guarantee both label columns exist even if one label never occurs.
    counts = counts.reindex(columns=["positive", "negative"], fill_value=0)
    counts.columns.name = None  # drop the "label" axis name left by value_counts/unstack
    counts["total"] = counts["positive"] + counts["negative"]
    # Stable sort: shops with equal totals keep groupby's (sorted-by-name) order,
    # so both the table and the top-N chart are deterministic.
    counts = counts.sort_values("total", ascending=False, kind="mergesort")

    table = counts[["total", "positive", "negative"]].copy()
    table["positive_rate(%)"] = (table["positive"] / table["total"] * 100).round(2)
    table["negative_rate(%)"] = (table["negative"] / table["total"] * 100).round(2)
    table = table.reset_index()  # the index is named "shop", so it becomes the first column

    # Chart only the busiest max_shops shops.
    top = table.head(max_shops)
    fig = go.Figure()
    fig.add_bar(name="positive", x=top["shop"], y=top["positive"], marker_color=POS_COLOR)
    fig.add_bar(name="negative", x=top["shop"], y=top["negative"], marker_color=NEG_COLOR)
    fig.update_layout(
        barmode="stack",
        title=f"Per-shop counts (top {len(top)})",
        xaxis_title="shop",
        yaxis_title="count",
        legend_title="label",
        template=TEMPLATE,
        xaxis=dict(tickangle=-30),
    )
    return fig, table
# ---------- core prediction ----------
def _predict_batch(texts, model_name, batch_size=64):
    """Classify already-filtered texts in mini-batches.

    Args:
        texts: List of strings (callers filter them beforehand).
        model_name: Model variant passed to ``load_model``.
        batch_size: Number of texts tokenised and scored per forward pass.

    Returns:
        One dict per text, in input order, with the keys ``review``,
        ``negative(%)`` / ``positive(%)`` (strings from ``_format_pct``) and
        ``label``. A tie is labelled "positive".
    """
    model, tok, cfg = load_model(model_name)
    rows = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        encoded = tok(batch, padding=True, truncation=True, max_length=cfg["max_len"], return_tensors="pt")
        with torch.no_grad():
            logits = model(encoded["input_ids"], encoded["attention_mask"])
            # Class index 0 is read as negative and 1 as positive below.
            probabilities = F.softmax(logits, dim=1).cpu().numpy()
        for text, row_probs in zip(batch, probabilities):
            p_neg = float(row_probs[0])
            p_pos = float(row_probs[1])
            rows.append({
                "review": text,
                "negative(%)": _format_pct(p_neg),
                "positive(%)": _format_pct(p_pos),
                "label": "positive" if p_pos >= p_neg else "negative",
            })
    return rows
# ---------- API wrappers ----------
def predict_one(text: str, model_choice: str):
    """Gradio handler for the single-review tab.

    Returns:
        ``(probs, label)``. ``probs`` maps "negative"/"positive" to fractions
        parsed back from the two-decimal percentage strings, and ``label`` is
        the predicted class. Non-substantive input yields zero probabilities
        and the label "invalid". Exceptions are printed with a traceback and
        re-raised.
    """
    try:
        cleaned = _norm_text(text)
        if not _is_substantive_text(cleaned):
            return {"negative": 0.0, "positive": 0.0}, "invalid"
        model_name = "baseline" if model_choice == "baseline" else "cnn_bilstm"
        row = _predict_batch([cleaned], model_name)[0]
        probs = {
            key: float(row[column].rstrip("%")) / 100.0
            for key, column in (("negative", "negative(%)"), ("positive", "positive(%)"))
        }
        return probs, row["label"]
    except Exception as e:
        print("ERROR in predict_one:", repr(e))
        traceback.print_exc()
        raise
def predict_many(text_block: str, model_choice: str):
    """Gradio handler for the batch tab (one review per line).

    Blank lines are dropped up front and are not counted as skipped. Other
    non-substantive lines are counted in the "Skipped" line of the summary.

    Returns:
        ``(df, fig_bar, fig_pie, summary_md)``. When no line is usable, the
        result is an empty table, two empty figures and "No valid text".
        Exceptions are printed with a traceback and re-raised.
    """
    columns = ["review", "negative(%)", "positive(%)", "label"]
    try:
        model_name = "baseline" if model_choice == "baseline" else "cnn_bilstm"
        non_blank = [line for line in map(_norm_text, (text_block or "").splitlines()) if line]
        cleaned, skipped = _clean_texts(non_blank)
        if not cleaned:
            return pd.DataFrame(columns=columns), go.Figure(), go.Figure(), "No valid text"
        df = pd.DataFrame(_predict_batch(cleaned, model_name), columns=columns)
        fig_bar, fig_pie, info_md = _make_figures(df)
        return df, fig_bar, fig_pie, f"{info_md} \n- Skipped (empty/non-text): {skipped}"
    except Exception as e:
        print("ERROR in predict_many:", repr(e))
        traceback.print_exc()
        raise
def predict_csv(file_obj, model_choice: str, review_col_override: str = "", shop_col_override: str = ""):
    """Gradio handler for the CSV tab.

    The handler works in four steps:
    1. Read the uploaded CSV.
    2. Pick the review and shop columns. Non-blank overrides win; otherwise
       ``_detect_cols`` guesses.
    3. Classify every row whose review is substantive text.
    4. Write the results to a temporary CSV for download.

    Args:
        file_obj: Value of the ``gr.File`` component. This is a path (str or
            ``os.PathLike``) or a tempfile-like object whose ``.name`` is the
            path, depending on the Gradio version. It is ``None`` when nothing
            was uploaded.
        model_choice: "baseline" selects the baseline model; anything else
            selects "cnn_bilstm".
        review_col_override: Optional explicit review column name.
        shop_col_override: Optional explicit shop column name.

    Returns:
        ``(out_df, download_path, fig_bar, fig_pie, fig_shop, shop_table, info_md)``.

    Raises:
        ValueError: If the chosen review column is not in the CSV. All
            exceptions are printed with a traceback and re-raised.
    """
    result_cols = ["review", "negative(%)", "positive(%)", "label"]
    try:
        if file_obj is None:
            return pd.DataFrame(), None, go.Figure(), go.Figure(), go.Figure(), pd.DataFrame(), "กรุณาอัปโหลดไฟล์ CSV"
        model_name = "baseline" if model_choice == "baseline" else "cnn_bilstm"
        # Accept both a plain path and an object exposing the path as .name
        # (a pathlib.Path also has .name, but that is only the basename).
        csv_path = file_obj if isinstance(file_obj, (str, os.PathLike)) else file_obj.name
        df = pd.read_csv(csv_path)

        auto_rev, auto_shop = _detect_cols(df)
        rev_col = (review_col_override or "").strip() or auto_rev
        shop_col = (shop_col_override or "").strip() or auto_shop
        if rev_col not in df.columns:
            raise ValueError(f"ไม่พบคอลัมน์รีวิว '{rev_col}' ใน CSV (columns = {list(df.columns)})")

        # === Keep only rows whose review is substantive text ===
        reviews_norm = df[rev_col].apply(_norm_text)
        mask_use = reviews_norm.apply(_is_substantive_text)
        skipped = int((~mask_use).sum())
        used_df = df.loc[mask_use]
        if used_df.empty:
            empty = pd.DataFrame(columns=result_cols)
            return empty, None, go.Figure(), go.Figure(), go.Figure(), pd.DataFrame(), "ไม่พบรีวิวที่เป็นข้อความ"

        # Classify exactly the normalised strings that passed the filter.
        results = _predict_batch(reviews_norm[mask_use].tolist(), model_name)
        out = pd.DataFrame(results, columns=result_cols)
        if shop_col and shop_col in used_df.columns:
            # .to_numpy(): used_df keeps the original (gappy) row labels while
            # `out` has a fresh RangeIndex. Inserting a Series would align on
            # the index and shift or NaN the shop names once rows are skipped.
            # _norm_text maps missing shops to "" (astype(str) gave "nan").
            out.insert(0, "shop", used_df[shop_col].map(_norm_text).to_numpy())

        # Result file for download: close the handle first (no leaked fd, and
        # Windows cannot reopen a still-open NamedTemporaryFile); delete=False
        # keeps the file for Gradio to serve.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
            download_path = tmp.name
        out.to_csv(download_path, index=False, encoding="utf-8-sig")

        # Overall charts/summary, then per-shop chart/table (empty if no shop column).
        fig_bar, fig_pie, info_md = _make_figures(out)
        fig_shop, tbl_shop = _shop_summary(out)

        # Append which columns were used and how many rows were skipped.
        shop_note = f" | คอลัมน์ร้าน: {shop_col}" if ("shop" in out.columns) else " | ไม่มีคอลัมน์ร้าน"
        info_md = (
            f"{info_md} \n"
            f"ใช้คอลัมน์รีวิว: {rev_col}"
            + shop_note
            + f" \n- Skipped (empty/non-text): {skipped}"
        )
        return out, download_path, fig_bar, fig_pie, fig_shop, tbl_shop, info_md
    except Exception as e:
        print("ERROR in predict_csv:", repr(e))
        traceback.print_exc()
        raise
# ---------- Gradio UI ----------
# Radio choices; every handler maps "baseline" -> baseline and anything else -> cnn_bilstm.
_MODEL_CHOICES = ["cnn_bilstm", "baseline"]
# If the DEFAULT_MODEL variable is misconfigured, fall back to the first choice
# instead of handing gr.Radio a value that is not among its choices.
_DEFAULT_CHOICE = DEFAULT_MODEL if DEFAULT_MODEL in _MODEL_CHOICES else _MODEL_CHOICES[0]

with gr.Blocks(title="Thai Sentiment API (Dusit-P)") as demo:
    gr.Markdown("### Thai Sentiment (WangchanBERTa + LSTM/CNN Heads)")
    # Shared model selector used by all three tabs.
    model_radio = gr.Radio(choices=_MODEL_CHOICES, value=_DEFAULT_CHOICE, label="เลือกโมเดล")

    # Single review -> class probabilities + predicted label.
    with gr.Tab("Single"):
        t1 = gr.Textbox(lines=3, label="ข้อความรีวิว (1 ข้อความ)")
        probs = gr.Label(label="Probabilities")
        pred = gr.Textbox(label="Prediction", interactive=False)
        gr.Button("Predict").click(predict_one, [t1, model_radio], [probs, pred])

    # One review per line -> table, overall charts, summary.
    with gr.Tab("Batch (หลายข้อความ)"):
        t2 = gr.Textbox(lines=8, label="พิมพ์หลายรีวิว (บรรทัดละ 1 รีวิว)")
        df2 = gr.Dataframe(label="ผลลัพธ์", interactive=False)
        bar2 = gr.Plot(label="Label counts (bar)")
        pie2 = gr.Plot(label="Label share (pie)")
        sum2 = gr.Markdown()
        gr.Button("Run Batch").click(predict_many, [t2, model_radio], [df2, bar2, pie2, sum2])

    # CSV upload -> table, downloadable result CSV, overall and per-shop charts.
    with gr.Tab("CSV (auto-detect columns)"):
        f = gr.File(label="อัปโหลด CSV", file_types=[".csv"])
        review_col_inp = gr.Textbox(label="ชื่อคอลัมน์รีวิว (เว้นว่างให้เดาได้)")
        shop_col_inp = gr.Textbox(label="ชื่อคอลัมน์ร้าน (เว้นว่างได้)")
        df3 = gr.Dataframe(label="ผลลัพธ์ CSV", interactive=False)
        download = gr.File(label="ดาวน์โหลดผลลัพธ์")
        bar3 = gr.Plot(label="Label counts (bar)")
        pie3 = gr.Plot(label="Label share (pie)")
        shop_bar = gr.Plot(label="Per-shop stacked bar")
        shop_tbl = gr.Dataframe(label="Per-shop summary", interactive=False)
        info = gr.Markdown()
        gr.Button("Run CSV").click(
            predict_csv,
            inputs=[f, model_radio, review_col_inp, shop_col_inp],
            outputs=[df3, download, bar3, pie3, shop_bar, shop_tbl, info],
        )

if __name__ == "__main__":
    demo.launch()