Dusit-P's picture
Update app.py
58903f8 verified
import os, json, importlib.util, tempfile, traceback, torch, re, math
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModel
# ===== ปรับได้จาก Settings > Variables & secrets ของ Space =====
REPO_ID = os.getenv("REPO_ID", "Dusit-P/thai-sentiment-wcb")
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "cnn_bilstm") # "cnn_bilstm" | "baseline" | "last4weighted_pure"
HF_TOKEN = os.getenv("HF_TOKEN", None) # ถ้าโมเดลเป็น private ให้เพิ่ม secret ชื่อนี้
# ---- theme colors (soft modern) ----
NEG_COLOR = os.getenv("NEG_COLOR", "#F87171") # red-400 (นุ่ม)
POS_COLOR = os.getenv("POS_COLOR", "#34D399") # emerald-400 (นุ่ม)
TEMPLATE = "plotly_white"
CACHE = {}
# ---------- โหลดสถาปัตยกรรมจาก repo (common/models.py) ----------
def _import_models():
if "models_module" in CACHE:
return CACHE["models_module"]
models_py = hf_hub_download(REPO_ID, filename="common/models.py", token=HF_TOKEN)
spec = importlib.util.spec_from_file_location("models", models_py)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
CACHE["models_module"] = mod
return mod
# ---------- Fallback เผื่อ common/models.py ยังไม่รู้จัก Model3 ----------
class _BaseHead(nn.Module):
def __init__(self, hidden_in, hidden_lstm=128, classes=2, dropout=0.3, pooling='masked_mean'):
super().__init__()
self.lstm = nn.LSTM(hidden_in, hidden_lstm, bidirectional=True, batch_first=True)
self.dropout = nn.Dropout(dropout)
self.fc = nn.Linear(hidden_lstm*2, classes)
assert pooling in ['cls','masked_mean','masked_max']
self.pooling = pooling
def _pool(self, x, mask):
if self.pooling=='cls': return x[:,0,:]
mask = mask.unsqueeze(-1)
if self.pooling=='masked_mean':
s=(x*mask).sum(1); d=mask.sum(1).clamp(min=1e-6); return s/d
x=x.masked_fill(mask==0,-1e9); return x.max(1).values
def forward_after_bert(self, seq, mask):
x,_ = self.lstm(seq)
x = self._pool(x, mask)
return self.fc(self.dropout(x))
class _Model3PureLast4(nn.Module):
"""Last-4 weighted (Pure): LSTM รับ 768 จาก BERT"""
def __init__(self, base_model, hidden=128, classes=2, dropout=0.3, pooling='masked_mean'):
super().__init__()
self.bert = AutoModel.from_pretrained(base_model)
self.w = nn.Parameter(torch.ones(4))
H = self.bert.config.hidden_size
self.head = _BaseHead(H, hidden, classes, dropout, pooling)
def forward(self, ids, mask):
out = self.bert(input_ids=ids, attention_mask=mask, output_hidden_states=True)
last4 = out.hidden_states[-4:]
w = F.softmax(self.w, dim=0)
seq = sum(w[i]*last4[i] for i in range(4)) # [B,T,768]
return self.head.forward_after_bert(seq, mask)
class _Model3ConvLast4(nn.Module):
"""Last-4 weighted + Conv1d(→128): LSTM รับ 128"""
def __init__(self, base_model, hidden=128, classes=2, dropout=0.3, pooling='masked_mean'):
super().__init__()
self.bert = AutoModel.from_pretrained(base_model)
self.w = nn.Parameter(torch.ones(4))
H = self.bert.config.hidden_size
self.c1 = nn.Conv1d(H,128,3,padding=1)
self.c2 = nn.Conv1d(128,128,5,padding=2)
self.head = _BaseHead(128, hidden, classes, dropout, pooling)
def forward(self, ids, mask):
out = self.bert(input_ids=ids, attention_mask=mask, output_hidden_states=True)
last4 = out.hidden_states[-4:]
w = F.softmax(self.w, dim=0)
seq = sum(w[i]*last4[i] for i in range(4)) # [B,T,768]
x = F.relu(self.c1(seq.transpose(1,2)))
x = F.relu(self.c2(x)).transpose(1,2) # [B,T,128]
return self.head.forward_after_bert(x, mask)
def _create_model_fallback(arch: str, base_model: str):
"""เลือกสถาปัตยกรรม fallback จากชื่อ arch ใน config.json"""
if arch in ("Model3_Pure_Last4Weighted", "last4weighted_pure", "last4_pure"):
return _Model3PureLast4(base_model)
if arch in ("Model3_MLP_Last4Weighted", "last4weighted"):
return _Model3ConvLast4(base_model)
raise ValueError(f"No fallback available for arch={arch}")
# ---------- โหลดโมเดลจากโฟลเดอร์ใน repo (เช่น cnn_bilstm/, baseline/, last4weighted_pure/) ----------
def load_model(model_name: str):
key = f"model:{model_name}"
if key in CACHE:
return CACHE[key]
cfg_path = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
w_path = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)
with open(cfg_path, "r", encoding="utf-8") as f:
cfg = json.load(f)
base_model = cfg.get("base_model", "airesearch/wangchanberta-base-att-spm-uncased")
arch_name = cfg.get("arch", "")
tok = AutoTokenizer.from_pretrained(base_model)
# พยายามสร้างจาก common/models.py ก่อน ถ้าไม่สำเร็จค่อย fallback
try:
models = _import_models()
model = models.create_model_by_name(arch_name)
except Exception as e:
print(f"[INFO] Using fallback for arch={arch_name} ({e})")
model = _create_model_fallback(arch_name, base_model)
state = load_file(w_path)
# ใช้ strict=True ถ้า key ตรง; ถ้าอยากกัน edge-case สามารถปรับเป็น strict=False ได้
model.load_state_dict(state, strict=True)
model.eval()
CACHE[key] = (model, tok, cfg)
return CACHE[key]
# ---------- helpers ----------
def _format_pct(x: float) -> str:
return f"{x*100:.2f}%"
# ====== ฟิลเตอร์ข้อความที่ไม่ใช่รีวิว / ค่าว่าง / สัญลักษณ์ ======
_INVALID_STRINGS = {"-", "--", "—", "n/a", "na", "null", "none", "nan", ".", "…", ""} # lower-case
_RE_HAS_LETTER = re.compile(r"[ก-๙A-Za-z]") # ต้องมีอย่างน้อย 1 ตัวอักษรไทยหรืออังกฤษ
def _norm_text(v) -> str:
"""แปลงค่าให้เป็นสตริงพร้อม trim และกัน NaN/None"""
if v is None:
return ""
if isinstance(v, float) and math.isnan(v):
return ""
s = str(v).strip()
return s
def _is_substantive_text(s: str, min_chars: int = 2) -> bool:
"""เงื่อนไขว่าเป็นข้อความที่พอจะวิเคราะห์ได้"""
if not s:
return False
s_lower = s.lower()
if s_lower in _INVALID_STRINGS:
return False
if not _RE_HAS_LETTER.search(s):
return False
if len(s.replace(" ", "")) < min_chars:
return False
return True
def _clean_texts(texts):
"""รับ list ใด ๆ → คืน (รายการที่ใช้ได้, จำนวนที่ถูกข้าม)"""
all_norm = [_norm_text(t) for t in texts]
cleaned = [t for t in all_norm if _is_substantive_text(t)]
skipped = len(all_norm) - len(cleaned)
return cleaned, skipped
def _detect_cols(df: pd.DataFrame):
"""เดาชื่อคอลัมน์รีวิว/ร้านอัตโนมัติ ถ้าไม่พบรีวิว เลือกคอลัมน์ object ตัวแรก"""
rev_cands = ["review", "text", "comment", "content", "message", "ข้อความ", "รีวิว"]
shop_cands = ["shop", "shop_name", "store", "restaurant", "brand", "merchant", "ชื่อร้าน"]
review_col = next((c for c in rev_cands if c in df.columns), None)
shop_col = next((c for c in shop_cands if c in df.columns), None)
if review_col is None:
obj_cols = [c for c in df.columns if df[c].dtype == object]
if obj_cols:
review_col = obj_cols[0]
return review_col, shop_col
def _summarize_df(df: pd.DataFrame):
"""สรุปภาพรวม + ตัวเลขเฉลี่ยความมั่นใจ"""
total = len(df)
neg = int((df["label"] == "negative").sum())
pos = int((df["label"] == "positive").sum())
neg_avg = pd.to_numeric(df["negative(%)"].str.rstrip("%"), errors="coerce").mean()
pos_avg = pd.to_numeric(df["positive(%)"].str.rstrip("%"), errors="coerce").mean()
info = (
f"**Summary** \n"
f"- Total: {total} \n"
f"- Negative: {neg} \n"
f"- Positive: {pos} \n"
f"- Avg negative: {neg_avg:.2f}% \n"
f"- Avg positive: {pos_avg:.2f}%"
)
return {"total": total, "neg": neg, "pos": pos, "neg_avg": neg_avg, "pos_avg": pos_avg, "md": info}
def _make_figures(df: pd.DataFrame):
s = _summarize_df(df)
# --- BAR: 2 trace, สีคงที่ ---
fig_bar = go.Figure()
fig_bar.add_bar(name="negative", x=["negative"], y=[s["neg"]], marker_color=NEG_COLOR)
fig_bar.add_bar(name="positive", x=["positive"], y=[s["pos"]], marker_color=POS_COLOR)
fig_bar.update_layout(
barmode="group",
title="Label counts",
xaxis_title="label",
yaxis_title="count",
template=TEMPLATE,
legend_title="label",
)
# --- PIE: สีสอดคล้องกับ bar ---
fig_pie = go.Figure(
go.Pie(
labels=["negative", "positive"],
values=[s["neg"], s["pos"]],
hole=0.35,
sort=False,
marker=dict(colors=[NEG_COLOR, POS_COLOR]),
)
)
fig_pie.update_layout(title="Label share", template=TEMPLATE)
return fig_bar, fig_pie, s["md"]
def _shop_summary(out_df: pd.DataFrame, max_shops=15):
"""สรุปต่อร้าน: ตาราง + stacked bar (pos/neg) — ใช้สีคงที่"""
if "shop" not in out_df.columns:
empty_tbl = pd.DataFrame(columns=["shop","total","positive","negative","positive_rate(%)","negative_rate(%)"])
return go.Figure(), empty_tbl
g = out_df.groupby("shop")["label"].value_counts().unstack(fill_value=0)
for col in ["positive","negative"]:
if col not in g.columns:
g[col] = 0
g["total"] = g["positive"] + g["negative"]
g = g.sort_values("total", ascending=False)
table = g[["total","positive","negative"]].copy()
table["positive_rate(%))"] = (table["positive"] / table["total"] * 100).round(2)
table["negative_rate(%)"] = (table["negative"] / table["total"] * 100).round(2)
table = table.reset_index().rename(columns={"index":"shop"})
# กราฟโชว์ top N ร้าน
top = table.head(max_shops)
fig = go.Figure()
fig.add_bar(name="positive", x=top["shop"], y=top["positive"], marker_color=POS_COLOR)
fig.add_bar(name="negative", x=top["shop"], y=top["negative"], marker_color=NEG_COLOR)
fig.update_layout(
barmode="stack",
title=f"Per-shop counts (top {len(top)})",
xaxis_title="shop",
yaxis_title="count",
legend_title="label",
template=TEMPLATE,
xaxis=dict(tickangle=-30),
)
return fig, table
# ---------- core prediction ----------
def _predict_batch(texts, model_name, batch_size=64):
"""รับ list[str] (ผ่านการกรองแล้ว) → คืน list[dict]"""
model, tok, cfg = load_model(model_name)
results = []
for i in range(0, len(texts), batch_size):
chunk = texts[i:i+batch_size]
enc = tok(chunk, padding=True, truncation=True, max_length=cfg["max_len"], return_tensors="pt")
with torch.no_grad():
logits = model(enc["input_ids"], enc["attention_mask"])
probs = F.softmax(logits, dim=1).cpu().numpy()
for txt, p in zip(chunk, probs):
neg, pos = float(p[0]), float(p[1])
label = "positive" if pos >= neg else "negative"
results.append({
"review": txt,
"negative(%)": _format_pct(neg),
"positive(%)": _format_pct(pos),
"label": label,
})
return results
# ---------- API wrappers ----------
def predict_one(text: str, model_choice: str):
try:
s = _norm_text(text)
if not _is_substantive_text(s):
return {"negative": 0.0, "positive": 0.0}, "invalid"
model_name = model_choice # ใช้ชื่อโฟลเดอร์โดยตรง
out = _predict_batch([s], model_name)[0]
probs = {
"negative": float(out["negative(%)"].rstrip("%"))/100.0,
"positive": float(out["positive(%)"].rstrip("%"))/100.0,
}
return probs, out["label"]
except Exception as e:
print("ERROR in predict_one:", repr(e))
traceback.print_exc()
raise
def predict_many(text_block: str, model_choice: str):
try:
model_name = model_choice # ใช้ชื่อโฟลเดอร์โดยตรง
raw_lines = (text_block or "").splitlines()
trimmed = [_norm_text(ln) for ln in raw_lines if _norm_text(ln)]
cleaned, skipped = _clean_texts(trimmed)
if len(cleaned) == 0:
empty = pd.DataFrame(columns=["review","negative(%)","positive(%)","label"])
return empty, go.Figure(), go.Figure(), "No valid text"
results = _predict_batch(cleaned, model_name)
df = pd.DataFrame(results, columns=["review","negative(%)","positive(%)","label"])
fig_bar, fig_pie, info_md = _make_figures(df)
info_md = f"{info_md} \n- Skipped (empty/non-text): {skipped}"
return df, fig_bar, fig_pie, info_md
except Exception as e:
print("ERROR in predict_many:", repr(e))
traceback.print_exc()
raise
def predict_csv(file_obj, model_choice: str, review_col_override: str = "", shop_col_override: str = ""):
"""
พฤติกรรม:
- ไม่ตัดแถวทิ้ง: แถว invalid ยังอยู่ เรียงตามไฟล์เดิม
- review ของแถว invalid = NA, ไม่คำนวณผลลัพธ์
- shop คงค่าจากไฟล์เดิม ไม่แปลงเป็นสตริง
- กราฟ/สรุป คำนวณจากเฉพาะแถว valid
"""
try:
if file_obj is None:
return pd.DataFrame(), None, go.Figure(), go.Figure(), go.Figure(), pd.DataFrame(), "กรุณาอัปโหลดไฟล์ CSV"
model_name = model_choice # ใช้ชื่อโฟลเดอร์โดยตรง
df = pd.read_csv(file_obj.name)
auto_rev, auto_shop = _detect_cols(df)
rev_col = (review_col_override or "").strip() or auto_rev
shop_col = (shop_col_override or "").strip() or auto_shop
if rev_col not in df.columns:
raise ValueError(f"ไม่พบคอลัมน์รีวิว '{rev_col}' ใน CSV (columns = {list(df.columns)})")
# === เตรียมรีวิวและมาสก์แถวที่ 'มีเนื้อหา' เท่านั้น ===
reviews_norm = df[rev_col].apply(_norm_text)
mask_valid = reviews_norm.apply(_is_substantive_text)
idx_valid = df.index[mask_valid].tolist()
skipped = int((~mask_valid).sum())
# === พยากรณ์เฉพาะแถวที่ valid ===
results = []
if len(idx_valid) > 0:
texts_valid = reviews_norm.loc[idx_valid].tolist()
results = _predict_batch(texts_valid, model_name) # list[dict] ตามลำดับ idx_valid
# === สร้าง DataFrame ผลลัพธ์ "ครบทุกแถว" ตามลำดับเดิม ===
out = pd.DataFrame(index=df.index, columns=["review","negative(%)","positive(%)","label"])
# review: valid → normalized text, invalid → NA
out.loc[idx_valid, "review"] = reviews_norm.loc[idx_valid].values
out.loc[~mask_valid, "review"] = pd.NA
# เติมผลพยากรณ์กลับตาม index เดิมสำหรับแถว valid
for i, idx in enumerate(idx_valid):
p = results[i]
out.at[idx, "negative(%)"] = p["negative(%)"]
out.at[idx, "positive(%)"] = p["positive(%)"]
out.at[idx, "label"] = p["label"]
# แทรกคอลัมน์ shop ด้านหน้า (คงค่าตามต้นฉบับโดยไม่ .astype(str))
if shop_col and shop_col in df.columns:
out.insert(0, "shop", df[shop_col])
else:
out.insert(0, "shop", pd.Series([pd.NA]*len(out), index=out.index))
# === เตรียมข้อมูล "เฉพาะแถวที่ valid" ไว้ทำกราฟ/สรุป ===
out_valid = out.loc[idx_valid].copy()
# ไฟล์ผลลัพธ์สำหรับดาวน์โหลด → ครบทุกแถว
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
out.to_csv(tmp.name, index=False, encoding="utf-8-sig")
if out_valid.empty:
empty_fig = go.Figure()
info_md = "ไม่พบรีวิวที่เป็นข้อความ\n- Skipped (empty/non-text): {}".format(skipped)
empty_tbl = pd.DataFrame(columns=["shop","total","positive","negative","positive_rate(%)","negative_rate(%)"])
return out, tmp.name, empty_fig, empty_fig, empty_fig, empty_tbl, info_md
# กราฟ/สรุปรวม (จากแถวที่ valid เท่านั้น)
fig_bar, fig_pie, info_md = _make_figures(out_valid)
# กราฟ/ตารางต่อร้าน (ใช้เฉพาะ valid)
fig_shop, tbl_shop = _shop_summary(out_valid)
# แนบข้อความบอกคอลัมน์ที่ใช้ + จำนวนแถวที่ถูกข้าม
info_md = (
f"{info_md} \n"
f"ใช้คอลัมน์รีวิว: {rev_col}"
+ (f" | คอลัมน์ร้าน: {shop_col}" if shop_col and (shop_col in df.columns) else " | ไม่มีคอลัมน์ร้าน")
+ f" \n- Skipped (empty/non-text): {skipped}"
)
return out, tmp.name, fig_bar, fig_pie, fig_shop, tbl_shop, info_md
except Exception as e:
print("ERROR in predict_csv:", repr(e))
traceback.print_exc()
raise
# ---------- Gradio UI ----------
AVAILABLE_CHOICES = ["cnn_bilstm", "baseline", "last4weighted_bilstm"] # เพิ่มชื่อโฟลเดอร์โมเดลใหม่ที่คุณอัปจริง
if DEFAULT_MODEL not in AVAILABLE_CHOICES:
DEFAULT_MODEL = "cnn_bilstm"
with gr.Blocks(title="Thai Sentiment API (Dusit-P)") as demo:
gr.Markdown("### Thai Sentiment (WangchanBERTa + LSTM/CNN/Last4 Heads)")
model_radio = gr.Radio(choices=AVAILABLE_CHOICES, value=DEFAULT_MODEL, label="เลือกโมเดล")
with gr.Tab("Single"):
t1 = gr.Textbox(lines=3, label="ข้อความรีวิว (1 ข้อความ)")
probs = gr.Label(label="Probabilities")
pred = gr.Textbox(label="Prediction", interactive=False)
gr.Button("Predict").click(predict_one, [t1, model_radio], [probs, pred])
with gr.Tab("Batch (หลายข้อความ)"):
t2 = gr.Textbox(lines=8, label="พิมพ์หลายรีวิว (บรรทัดละ 1 รีวิว)")
df2 = gr.Dataframe(label="ผลลัพธ์", interactive=False)
bar2 = gr.Plot(label="Label counts (bar)")
pie2 = gr.Plot(label="Label share (pie)")
sum2 = gr.Markdown()
gr.Button("Run Batch").click(predict_many, [t2, model_radio], [df2, bar2, pie2, sum2])
with gr.Tab("CSV (auto-detect columns)"):
f = gr.File(label="อัปโหลด CSV", file_types=[".csv"])
review_col_inp = gr.Textbox(label="ชื่อคอลัมน์รีวิว (เว้นว่างให้เดาได้)")
shop_col_inp = gr.Textbox(label="ชื่อคอลัมน์ร้าน (เว้นว่างได้)")
df3 = gr.Dataframe(label="ผลลัพธ์ CSV", interactive=False)
download = gr.File(label="ดาวน์โหลดผลลัพธ์")
bar3 = gr.Plot(label="Label counts (bar)")
pie3 = gr.Plot(label="Label share (pie)")
shop_bar = gr.Plot(label="Per-shop stacked bar")
shop_tbl = gr.Dataframe(label="Per-shop summary", interactive=False)
info = gr.Markdown()
gr.Button("Run CSV").click(
predict_csv,
inputs=[f, model_radio, review_col_inp, shop_col_inp],
outputs=[df3, download, bar3, pie3, shop_bar, shop_tbl, info]
)
if __name__ == "__main__":
demo.launch()