Spaces:
Sleeping
Sleeping
| import os, json, importlib.util, tempfile, traceback, torch, re, math | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| from transformers import AutoTokenizer, AutoModel | |
| # ===== ปรับได้จาก Settings > Variables & secrets ของ Space ===== | |
# Runtime configuration, overridable from the Space's
# "Settings > Variables & secrets" page.
REPO_ID = os.environ.get("REPO_ID", "Dusit-P/thai-sentiment-wcb")
# Model folder selected by default: "cnn_bilstm" | "baseline" | "last4weighted_pure"
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "cnn_bilstm")
# If the model repo is private, add a secret with this name.
HF_TOKEN = os.environ.get("HF_TOKEN")

# ---- theme colors (soft modern) ----
NEG_COLOR = os.environ.get("NEG_COLOR", "#F87171")  # red-400 (soft)
POS_COLOR = os.environ.get("POS_COLOR", "#34D399")  # emerald-400 (soft)
TEMPLATE = "plotly_white"

# In-memory cache shared by the loader functions in this file.
CACHE = {}
| # ---------- โหลดสถาปัตยกรรมจาก repo (common/models.py) ---------- | |
def _import_models():
    """Download ``common/models.py`` from the model repo and import it.

    The imported module is stored in ``CACHE["models_module"]`` so the
    download and import run at most once per process.

    NOTE(review): this executes Python source fetched from ``REPO_ID``;
    only point ``REPO_ID`` at a repository you trust.

    Returns:
        The imported module object.
    """
    try:
        return CACHE["models_module"]
    except KeyError:
        pass

    source_path = hf_hub_download(REPO_ID, filename="common/models.py", token=HF_TOKEN)
    module_spec = importlib.util.spec_from_file_location("models", source_path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)

    CACHE["models_module"] = module
    return module
| # ---------- Fallback เผื่อ common/models.py ยังไม่รู้จัก Model3 ---------- | |
| class _BaseHead(nn.Module): | |
| def __init__(self, hidden_in, hidden_lstm=128, classes=2, dropout=0.3, pooling='masked_mean'): | |
| super().__init__() | |
| self.lstm = nn.LSTM(hidden_in, hidden_lstm, bidirectional=True, batch_first=True) | |
| self.dropout = nn.Dropout(dropout) | |
| self.fc = nn.Linear(hidden_lstm*2, classes) | |
| assert pooling in ['cls','masked_mean','masked_max'] | |
| self.pooling = pooling | |
| def _pool(self, x, mask): | |
| if self.pooling=='cls': return x[:,0,:] | |
| mask = mask.unsqueeze(-1) | |
| if self.pooling=='masked_mean': | |
| s=(x*mask).sum(1); d=mask.sum(1).clamp(min=1e-6); return s/d | |
| x=x.masked_fill(mask==0,-1e9); return x.max(1).values | |
| def forward_after_bert(self, seq, mask): | |
| x,_ = self.lstm(seq) | |
| x = self._pool(x, mask) | |
| return self.fc(self.dropout(x)) | |
class _Model3PureLast4(nn.Module):
    """Last-4 weighted (Pure): the LSTM head consumes the encoder's hidden states directly.

    The last four hidden-state layers of the encoder are mixed with
    softmax-normalized learnable weights before being passed to the head.
    """

    def __init__(self, base_model, hidden=128, classes=2, dropout=0.3, pooling='masked_mean'):
        super().__init__()
        # Attribute names (bert / w / head) are part of the checkpoint key layout.
        self.bert = AutoModel.from_pretrained(base_model)
        self.w = nn.Parameter(torch.ones(4))  # one mixing weight per layer
        encoder_dim = self.bert.config.hidden_size
        self.head = _BaseHead(encoder_dim, hidden, classes, dropout, pooling)

    def forward(self, ids, mask):
        """Return the head's output for token ids ``ids`` and attention mask ``mask``."""
        encoded = self.bert(input_ids=ids, attention_mask=mask, output_hidden_states=True)
        top_layers = encoded.hidden_states[-4:]
        layer_weights = F.softmax(self.w, dim=0)
        # Weighted sum of the four layers -> [B, T, hidden_size]
        mixed = sum(wt * layer for wt, layer in zip(layer_weights, top_layers))
        return self.head.forward_after_bert(mixed, mask)
class _Model3ConvLast4(nn.Module):
    """Last-4 weighted + Conv1d (-> 128 channels): the LSTM head consumes 128 features.

    The last four hidden-state layers of the encoder are mixed with
    softmax-normalized learnable weights, passed through two ReLU-activated
    1-D convolutions along the time axis, and then handed to the head.
    """

    def __init__(self, base_model, hidden=128, classes=2, dropout=0.3, pooling='masked_mean'):
        super().__init__()
        # Attribute names (bert / w / c1 / c2 / head) are part of the checkpoint key layout.
        self.bert = AutoModel.from_pretrained(base_model)
        self.w = nn.Parameter(torch.ones(4))  # one mixing weight per layer
        encoder_dim = self.bert.config.hidden_size
        # "Same"-length convolutions: padding keeps T unchanged.
        self.c1 = nn.Conv1d(encoder_dim, 128, 3, padding=1)
        self.c2 = nn.Conv1d(128, 128, 5, padding=2)
        self.head = _BaseHead(128, hidden, classes, dropout, pooling)

    def forward(self, ids, mask):
        """Return the head's output for token ids ``ids`` and attention mask ``mask``."""
        encoded = self.bert(input_ids=ids, attention_mask=mask, output_hidden_states=True)
        top_layers = encoded.hidden_states[-4:]
        layer_weights = F.softmax(self.w, dim=0)
        # Weighted sum of the four layers -> [B, T, hidden_size]
        mixed = sum(wt * layer for wt, layer in zip(layer_weights, top_layers))
        # Conv1d expects channels first: [B, hidden_size, T]
        feats = F.relu(self.c1(mixed.transpose(1, 2)))
        feats = F.relu(self.c2(feats))
        # Back to [B, T, 128] for the batch-first LSTM head.
        return self.head.forward_after_bert(feats.transpose(1, 2), mask)
| def _create_model_fallback(arch: str, base_model: str): | |
| """เลือกสถาปัตยกรรม fallback จากชื่อ arch ใน config.json""" | |
| if arch in ("Model3_Pure_Last4Weighted", "last4weighted_pure", "last4_pure"): | |
| return _Model3PureLast4(base_model) | |
| if arch in ("Model3_MLP_Last4Weighted", "last4weighted"): | |
| return _Model3ConvLast4(base_model) | |
| raise ValueError(f"No fallback available for arch={arch}") | |
| # ---------- โหลดโมเดลจากโฟลเดอร์ใน repo (เช่น cnn_bilstm/, baseline/, last4weighted_pure/) ---------- | |
def load_model(model_name: str):
    """Load the model stored under ``<model_name>/`` in the repo, memoized in ``CACHE``.

    Downloads ``config.json`` and ``model.safetensors`` from ``REPO_ID``,
    builds the architecture (first through the repo's ``common/models.py``;
    on any failure, through the local fallback classes), loads the weights
    strictly and puts the model in eval mode.

    Args:
        model_name: Folder name inside the repo.

    Returns:
        ``(model, tokenizer, cfg)`` where ``cfg`` is the parsed ``config.json``.
    """
    cache_key = f"model:{model_name}"
    cached = CACHE.get(cache_key)
    if cached is not None:
        return cached

    cfg_path = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
    weights_path = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)

    with open(cfg_path, "r", encoding="utf-8") as fh:
        cfg = json.load(fh)
    arch_name = cfg.get("arch", "")
    base_model = cfg.get("base_model", "airesearch/wangchanberta-base-att-spm-uncased")

    tok = AutoTokenizer.from_pretrained(base_model)

    # Prefer the architecture shipped in common/models.py; fall back to the
    # local classes if that cannot build it.
    try:
        model = _import_models().create_model_by_name(arch_name)
    except Exception as e:
        print(f"[INFO] Using fallback for arch={arch_name} ({e})")
        model = _create_model_fallback(arch_name, base_model)

    # strict=True requires checkpoint keys to match the module exactly;
    # relax to strict=False only if an edge case needs it.
    model.load_state_dict(load_file(weights_path), strict=True)
    model.eval()

    entry = (model, tok, cfg)
    CACHE[cache_key] = entry
    return entry
| # ---------- helpers ---------- | |
| def _format_pct(x: float) -> str: | |
| return f"{x*100:.2f}%" | |
| # ====== ฟิลเตอร์ข้อความที่ไม่ใช่รีวิว / ค่าว่าง / สัญลักษณ์ ====== | |
| _INVALID_STRINGS = {"-", "--", "—", "n/a", "na", "null", "none", "nan", ".", "…", ""} # lower-case | |
| _RE_HAS_LETTER = re.compile(r"[ก-๙A-Za-z]") # ต้องมีอย่างน้อย 1 ตัวอักษรไทยหรืออังกฤษ | |
| def _norm_text(v) -> str: | |
| """แปลงค่าให้เป็นสตริงพร้อม trim และกัน NaN/None""" | |
| if v is None: | |
| return "" | |
| if isinstance(v, float) and math.isnan(v): | |
| return "" | |
| s = str(v).strip() | |
| return s | |
def _is_substantive_text(s: str, min_chars: int = 2) -> bool:
    """Decide whether ``s`` is text worth analysing.

    A string qualifies when it is non-empty, is not one of the known
    placeholder values (case-insensitive), matches ``_RE_HAS_LETTER``, and
    has at least ``min_chars`` characters once spaces are removed.
    """
    if not s or s.lower() in _INVALID_STRINGS:
        return False
    has_letter = _RE_HAS_LETTER.search(s) is not None
    long_enough = len(s.replace(" ", "")) >= min_chars
    return has_letter and long_enough
def _clean_texts(texts):
    """Normalize every item of ``texts`` and keep only the analysable ones.

    Returns:
        ``(kept, skipped)``: the list of usable strings and the number of
        items that were filtered out.
    """
    normalized = list(map(_norm_text, texts))
    kept = list(filter(_is_substantive_text, normalized))
    return kept, len(normalized) - len(kept)
| def _detect_cols(df: pd.DataFrame): | |
| """เดาชื่อคอลัมน์รีวิว/ร้านอัตโนมัติ ถ้าไม่พบรีวิว เลือกคอลัมน์ object ตัวแรก""" | |
| rev_cands = ["review", "text", "comment", "content", "message", "ข้อความ", "รีวิว"] | |
| shop_cands = ["shop", "shop_name", "store", "restaurant", "brand", "merchant", "ชื่อร้าน"] | |
| review_col = next((c for c in rev_cands if c in df.columns), None) | |
| shop_col = next((c for c in shop_cands if c in df.columns), None) | |
| if review_col is None: | |
| obj_cols = [c for c in df.columns if df[c].dtype == object] | |
| if obj_cols: | |
| review_col = obj_cols[0] | |
| return review_col, shop_col | |
| def _summarize_df(df: pd.DataFrame): | |
| """สรุปภาพรวม + ตัวเลขเฉลี่ยความมั่นใจ""" | |
| total = len(df) | |
| neg = int((df["label"] == "negative").sum()) | |
| pos = int((df["label"] == "positive").sum()) | |
| neg_avg = pd.to_numeric(df["negative(%)"].str.rstrip("%"), errors="coerce").mean() | |
| pos_avg = pd.to_numeric(df["positive(%)"].str.rstrip("%"), errors="coerce").mean() | |
| info = ( | |
| f"**Summary** \n" | |
| f"- Total: {total} \n" | |
| f"- Negative: {neg} \n" | |
| f"- Positive: {pos} \n" | |
| f"- Avg negative: {neg_avg:.2f}% \n" | |
| f"- Avg positive: {pos_avg:.2f}%" | |
| ) | |
| return {"total": total, "neg": neg, "pos": pos, "neg_avg": neg_avg, "pos_avg": pos_avg, "md": info} | |
def _make_figures(df: pd.DataFrame):
    """Build the label-count bar chart and label-share pie chart for ``df``.

    Returns:
        ``(fig_bar, fig_pie, summary_markdown)``.
    """
    summary = _summarize_df(df)
    counts = {"negative": summary["neg"], "positive": summary["pos"]}
    colors = {"negative": NEG_COLOR, "positive": POS_COLOR}

    # --- BAR: one trace per label so each keeps a fixed colour ---
    fig_bar = go.Figure()
    for label in ("negative", "positive"):
        fig_bar.add_bar(name=label, x=[label], y=[counts[label]], marker_color=colors[label])
    fig_bar.update_layout(
        barmode="group",
        title="Label counts",
        xaxis_title="label",
        yaxis_title="count",
        template=TEMPLATE,
        legend_title="label",
    )

    # --- PIE: same colours as the bar chart ---
    pie_trace = go.Pie(
        labels=["negative", "positive"],
        values=[counts["negative"], counts["positive"]],
        hole=0.35,
        sort=False,
        marker=dict(colors=[colors["negative"], colors["positive"]]),
    )
    fig_pie = go.Figure(pie_trace)
    fig_pie.update_layout(title="Label share", template=TEMPLATE)

    return fig_bar, fig_pie, summary["md"]
def _shop_summary(out_df: pd.DataFrame, max_shops=15):
    """Per-shop summary: a counts/rates table plus a stacked pos/neg bar chart.

    Args:
        out_df: Prediction table with a ``label`` column and, optionally, a
            ``shop`` column.
        max_shops: Number of shops (by total reviews, descending) shown in
            the chart; the table always contains every shop.

    Returns:
        ``(fig, table)``. ``table`` has the columns ``shop``, ``total``,
        ``positive``, ``negative``, ``positive_rate(%)`` and
        ``negative_rate(%)``. When there is no usable shop information an
        empty figure and an empty table with those columns are returned.
    """
    columns = ["shop", "total", "positive", "negative", "positive_rate(%)", "negative_rate(%)"]

    # No shop column, or a placeholder column that is entirely NA: groupby
    # would drop every row, so there is nothing to summarize.
    if "shop" not in out_df.columns or out_df["shop"].isna().all():
        return go.Figure(), pd.DataFrame(columns=columns)

    g = out_df.groupby("shop")["label"].value_counts().unstack(fill_value=0)
    # A label that never occurs produces no column; add it as zeros.
    for col in ("positive", "negative"):
        if col not in g.columns:
            g[col] = 0
    g["total"] = g["positive"] + g["negative"]
    g = g.sort_values("total", ascending=False)

    table = g[["total", "positive", "negative"]].copy()
    # Column names must match the empty-table schema above.
    table["positive_rate(%)"] = (table["positive"] / table["total"] * 100).round(2)
    table["negative_rate(%)"] = (table["negative"] / table["total"] * 100).round(2)
    table = table.reset_index().rename(columns={"index": "shop"})

    # The chart shows only the top N shops.
    top = table.head(max_shops)
    fig = go.Figure()
    fig.add_bar(name="positive", x=top["shop"], y=top["positive"], marker_color=POS_COLOR)
    fig.add_bar(name="negative", x=top["shop"], y=top["negative"], marker_color=NEG_COLOR)
    fig.update_layout(
        barmode="stack",
        title=f"Per-shop counts (top {len(top)})",
        xaxis_title="shop",
        yaxis_title="count",
        legend_title="label",
        template=TEMPLATE,
        xaxis=dict(tickangle=-30),
    )
    return fig, table
| # ---------- core prediction ---------- | |
def _predict_batch(texts, model_name, batch_size=64):
    """Classify already-filtered texts in mini-batches.

    Args:
        texts: ``list[str]`` of reviews that passed the text filter.
        model_name: Model folder name, forwarded to ``load_model``.
        batch_size: Number of texts tokenized and scored per forward pass.

    Returns:
        One dict per input text, in input order, with the keys ``review``,
        ``negative(%)``, ``positive(%)`` (formatted percentage strings) and
        ``label`` (``"positive"`` when the positive probability is at least
        the negative one, otherwise ``"negative"``).
    """
    model, tok, cfg = load_model(model_name)
    # "max_len" is treated as optional like the other config keys. When it is
    # missing, None is passed as max_length.
    # NOTE(review): with max_length=None the tokenizer is expected to
    # truncate at its own model maximum — confirm against the tokenizer docs.
    max_len = cfg.get("max_len")

    results = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        enc = tok(chunk, padding=True, truncation=True, max_length=max_len, return_tensors="pt")
        with torch.no_grad():
            logits = model(enc["input_ids"], enc["attention_mask"])
            probs = F.softmax(logits, dim=1).cpu().numpy()
        for txt, p in zip(chunk, probs):
            # Column 0 is read as negative, column 1 as positive.
            neg, pos = float(p[0]), float(p[1])
            label = "positive" if pos >= neg else "negative"
            results.append({
                "review": txt,
                "negative(%)": _format_pct(neg),
                "positive(%)": _format_pct(pos),
                "label": label,
            })
    return results
| # ---------- API wrappers ---------- | |
def predict_one(text: str, model_choice: str):
    """Classify a single review.

    Args:
        text: Raw review text.
        model_choice: Model folder name, used as-is.

    Returns:
        ``(probs, label)`` where ``probs`` maps ``"negative"``/``"positive"``
        to fractions in [0, 1]. Text that fails the filter yields zero
        probabilities and the label ``"invalid"``.

    Any exception is logged with its traceback and re-raised.
    """
    try:
        cleaned = _norm_text(text)
        if not _is_substantive_text(cleaned):
            return {"negative": 0.0, "positive": 0.0}, "invalid"

        row = _predict_batch([cleaned], model_choice)[0]
        # The batch result stores formatted percentages; convert back to fractions.
        neg_fraction = float(row["negative(%)"].rstrip("%")) / 100.0
        pos_fraction = float(row["positive(%)"].rstrip("%")) / 100.0
        return {"negative": neg_fraction, "positive": pos_fraction}, row["label"]
    except Exception as e:
        print("ERROR in predict_one:", repr(e))
        traceback.print_exc()
        raise
def predict_many(text_block: str, model_choice: str):
    """Classify several reviews given as one text block, one review per line.

    Args:
        text_block: Multi-line input; blank lines are ignored.
        model_choice: Model folder name, used as-is.

    Returns:
        ``(df, fig_bar, fig_pie, info_md)``. When no line is usable, an empty
        table, two empty figures and the message ``"No valid text"``.

    Any exception is logged with its traceback and re-raised.
    """
    try:
        result_columns = ["review", "negative(%)", "positive(%)", "label"]

        lines = (text_block or "").splitlines()
        non_empty = [norm for norm in map(_norm_text, lines) if norm]
        cleaned, skipped = _clean_texts(non_empty)

        if not cleaned:
            return pd.DataFrame(columns=result_columns), go.Figure(), go.Figure(), "No valid text"

        df = pd.DataFrame(_predict_batch(cleaned, model_choice), columns=result_columns)
        fig_bar, fig_pie, summary_md = _make_figures(df)
        info_md = f"{summary_md} \n- Skipped (empty/non-text): {skipped}"
        return df, fig_bar, fig_pie, info_md
    except Exception as e:
        print("ERROR in predict_many:", repr(e))
        traceback.print_exc()
        raise
def predict_csv(file_obj, model_choice: str, review_col_override: str = "", shop_col_override: str = ""):
    """Classify the reviews of an uploaded CSV file.

    Behaviour:
      - No row is dropped: invalid rows stay, in the file's original order.
      - ``review`` of an invalid row is NA and no prediction is computed for it.
      - ``shop`` keeps the values from the file (not cast to string).
      - Charts and summaries are computed from the valid rows only.

    Args:
        file_obj: The uploaded file, either a filesystem path or an object
            whose ``name`` attribute is the path. ``None`` means nothing
            was uploaded.
        model_choice: Model folder name, used as-is.
        review_col_override: Review column name; blank means auto-detect.
        shop_col_override: Shop column name; blank means auto-detect.

    Returns:
        ``(out_df, download_path, fig_bar, fig_pie, fig_shop, shop_table, info_md)``.

    Raises:
        ValueError: If the review column cannot be found in the CSV.

    Any exception is logged with its traceback and re-raised.
    """
    try:
        if file_obj is None:
            return pd.DataFrame(), None, go.Figure(), go.Figure(), go.Figure(), pd.DataFrame(), "กรุณาอัปโหลดไฟล์ CSV"

        model_name = model_choice  # the folder name is used directly
        # Accept both a plain path and a file-like wrapper exposing ``.name``.
        csv_path = file_obj if isinstance(file_obj, (str, os.PathLike)) else file_obj.name
        df = pd.read_csv(csv_path)

        auto_rev, auto_shop = _detect_cols(df)
        rev_col = (review_col_override or "").strip() or auto_rev
        shop_col = (shop_col_override or "").strip() or auto_shop
        if rev_col not in df.columns:
            raise ValueError(f"ไม่พบคอลัมน์รีวิว '{rev_col}' ใน CSV (columns = {list(df.columns)})")
        has_shop = bool(shop_col) and shop_col in df.columns

        # === Normalize reviews and mask the rows that carry real text ===
        reviews_norm = df[rev_col].apply(_norm_text)
        mask_valid = reviews_norm.apply(_is_substantive_text)
        idx_valid = df.index[mask_valid].tolist()
        skipped = int((~mask_valid).sum())

        # === Predict only for the valid rows ===
        results = []
        if len(idx_valid) > 0:
            texts_valid = reviews_norm.loc[idx_valid].tolist()
            results = _predict_batch(texts_valid, model_name)  # same order as idx_valid

        # === Result frame covering every row, in the original order ===
        out = pd.DataFrame(index=df.index, columns=["review", "negative(%)", "positive(%)", "label"])
        # review: valid -> normalized text, invalid -> NA
        out.loc[idx_valid, "review"] = reviews_norm.loc[idx_valid].values
        out.loc[~mask_valid, "review"] = pd.NA
        # Write predictions back at the original index of each valid row.
        for idx, pred in zip(idx_valid, results):
            out.at[idx, "negative(%)"] = pred["negative(%)"]
            out.at[idx, "positive(%)"] = pred["positive(%)"]
            out.at[idx, "label"] = pred["label"]

        # Put the shop column first, keeping the original values untouched.
        if has_shop:
            out.insert(0, "shop", df[shop_col])
        else:
            out.insert(0, "shop", pd.Series([pd.NA] * len(out), index=out.index))

        # === Valid rows only, for charts and summaries ===
        out_valid = out.loc[idx_valid].copy()

        # Downloadable result file containing every row. The handle is closed
        # before pandas writes to the path, so no descriptor is leaked and the
        # path can be reopened on every platform.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
            out_path = tmp.name
        out.to_csv(out_path, index=False, encoding="utf-8-sig")

        if out_valid.empty:
            info_md = "ไม่พบรีวิวที่เป็นข้อความ\n- Skipped (empty/non-text): {}".format(skipped)
            empty_tbl = pd.DataFrame(columns=["shop", "total", "positive", "negative", "positive_rate(%)", "negative_rate(%)"])
            return out, out_path, go.Figure(), go.Figure(), go.Figure(), empty_tbl, info_md

        # Overall charts/summary (valid rows only)
        fig_bar, fig_pie, info_md = _make_figures(out_valid)
        # Per-shop chart/table (valid rows only)
        fig_shop, tbl_shop = _shop_summary(out_valid)

        # Append which columns were used and how many rows were skipped.
        info_md = (
            f"{info_md} \n"
            f"ใช้คอลัมน์รีวิว: {rev_col}"
            + (f" | คอลัมน์ร้าน: {shop_col}" if has_shop else " | ไม่มีคอลัมน์ร้าน")
            + f" \n- Skipped (empty/non-text): {skipped}"
        )
        return out, out_path, fig_bar, fig_pie, fig_shop, tbl_shop, info_md
    except Exception as e:
        print("ERROR in predict_csv:", repr(e))
        traceback.print_exc()
        raise
| # ---------- Gradio UI ---------- | |
# Model folders offered in the UI; add the names of any new model folders
# actually uploaded to the repo.
AVAILABLE_CHOICES = ["cnn_bilstm", "baseline", "last4weighted_bilstm"]
# An unknown default falls back to the first-party model.
DEFAULT_MODEL = DEFAULT_MODEL if DEFAULT_MODEL in AVAILABLE_CHOICES else "cnn_bilstm"
with gr.Blocks(title="Thai Sentiment API (Dusit-P)") as demo:
    gr.Markdown("### Thai Sentiment (WangchanBERTa + LSTM/CNN/Last4 Heads)")
    # Shared model selector used by every tab.
    model_radio = gr.Radio(choices=AVAILABLE_CHOICES, value=DEFAULT_MODEL, label="เลือกโมเดล")

    # --- Tab 1: one review at a time ---
    with gr.Tab("Single"):
        t1 = gr.Textbox(lines=3, label="ข้อความรีวิว (1 ข้อความ)")
        probs = gr.Label(label="Probabilities")
        pred = gr.Textbox(label="Prediction", interactive=False)
        single_btn = gr.Button("Predict")
        single_btn.click(predict_one, [t1, model_radio], [probs, pred])

    # --- Tab 2: several reviews, one per line ---
    with gr.Tab("Batch (หลายข้อความ)"):
        t2 = gr.Textbox(lines=8, label="พิมพ์หลายรีวิว (บรรทัดละ 1 รีวิว)")
        df2 = gr.Dataframe(label="ผลลัพธ์", interactive=False)
        bar2 = gr.Plot(label="Label counts (bar)")
        pie2 = gr.Plot(label="Label share (pie)")
        sum2 = gr.Markdown()
        batch_btn = gr.Button("Run Batch")
        batch_btn.click(predict_many, [t2, model_radio], [df2, bar2, pie2, sum2])

    # --- Tab 3: CSV upload with optional column overrides ---
    with gr.Tab("CSV (auto-detect columns)"):
        f = gr.File(label="อัปโหลด CSV", file_types=[".csv"])
        review_col_inp = gr.Textbox(label="ชื่อคอลัมน์รีวิว (เว้นว่างให้เดาได้)")
        shop_col_inp = gr.Textbox(label="ชื่อคอลัมน์ร้าน (เว้นว่างได้)")
        df3 = gr.Dataframe(label="ผลลัพธ์ CSV", interactive=False)
        download = gr.File(label="ดาวน์โหลดผลลัพธ์")
        bar3 = gr.Plot(label="Label counts (bar)")
        pie3 = gr.Plot(label="Label share (pie)")
        shop_bar = gr.Plot(label="Per-shop stacked bar")
        shop_tbl = gr.Dataframe(label="Per-shop summary", interactive=False)
        info = gr.Markdown()
        csv_btn = gr.Button("Run CSV")
        csv_btn.click(
            predict_csv,
            inputs=[f, model_radio, review_col_inp, shop_col_inp],
            outputs=[df3, download, bar3, pie3, shop_bar, shop_tbl, info],
        )

if __name__ == "__main__":
    demo.launch()