# Source: ui.py, last commit "Update ui.py" (e5dff79, verified) by Corin1998
from __future__ import annotations
# === Writable config dirs (must come right after future import) ===
# Choose a writable base directory before the heavier imports below run:
# /data/app_data when /data is writable, otherwise /tmp/app_data.
# setdefault() is used, so values already present in the environment win.
import os
os.environ.setdefault("APP_DATA_DIR", "/data/app_data" if os.access("/data", os.W_OK) else "/tmp/app_data")
# MPLCONFIGDIR is placed under the app data dir chosen above.
os.environ.setdefault("MPLCONFIGDIR", os.path.join(os.environ["APP_DATA_DIR"], "mplconfig"))
# Create the directory (including missing parents) up front; no error if it already exists.
os.makedirs(os.environ["MPLCONFIGDIR"], exist_ok=True)
# =================================================================
import json
import logging
import uuid
from datetime import datetime, timezone
from typing import Dict

import gradio as gr
import pandas as pd

from app.storage import (
    init_db, insert_variant, upsert_campaign, get_variant, get_metrics,
    get_campaign_value_per_conversion, log_event
)
from app.bandit import ThompsonBandit
from app.forecast import SeasonalityModel
from app.compliance import rule_based_check, llm_check_and_fix
from app.openai_client import openai_chat_json
# Initialization: init_db() (from app.storage) runs once at import time.
init_db()
# Cache of SeasonalityModel instances, keyed by campaign_id (filled by _seasonal()).
_seasonality_cache: Dict[str, SeasonalityModel] = {}
# Fixed column sets: the result tables are always returned in exactly this shape.
GENERATE_COLUMNS = ["variant_id", "status", "rejection_reason", "text"]
REPORT_COLUMNS = ["variant_id","impressions","clicks","conversions","ctr","cvr","expected_value"]
# Strict prompts written on the assumption that the model is called in JSON mode.
# GEN_SYSTEM is the system message: it demands a single JSON object of the form
# {"variants": [{"headline": ..., "body": ...}, ...]} and lists the copy rules.
GEN_SYSTEM = """
あなたは日本語広告コピーのプロフェッショナルコピーライターです。
出力は**次のJSONオブジェクトのみ**で厳密に返してください。余計な文章・説明・前置きは禁止です。
形式:
{
"variants": [
{"headline": "全角15-25字程度", "body": "全角40-90字程度"},
...
]
}
ルール:
- 医薬効能の断定、100%、永久、即効、根拠のない数値などの誇大表現は禁止
- CTAは自然に
- 日本語で、句読点や記号は自然に
"""
# GEN_USER_TEMPLATE is the user message; it is filled with str.format() using
# brand, product, target, tone, constraints and k. The doubled braces ({{ }})
# are literal braces that survive formatting.
GEN_USER_TEMPLATE = """
ブランド: {brand}
商品/サービス: {product}
想定ターゲット: {target}
トーン: {tone}
制約: {constraints}
生成本数: {k}
要件:
- "variants" 配列の要素数は **ちょうど {k}** 件にしてください
- 各要素は {{"headline": "...", "body": "..."}} のみ
"""
def _seasonal(campaign_id: str) -> SeasonalityModel:
    """Return the cached seasonality model for *campaign_id*, creating it on first use.

    On a cache miss a new ``SeasonalityModel`` is built and ``fit()`` is
    attempted. A failed fit is logged and the model is cached anyway, so a
    fitting problem never blocks the caller (the original author's note says
    a fallback exists when fitting fails).

    Args:
        campaign_id: Campaign identifier, also used as the cache key.

    Returns:
        The model instance stored in ``_seasonality_cache`` for this campaign.
    """
    if campaign_id not in _seasonality_cache:
        model = SeasonalityModel(campaign_id)
        try:
            model.fit()
        except Exception:
            # Deliberately best-effort: continue with the unfitted model, but
            # leave a trace instead of swallowing the failure silently.
            logging.getLogger(__name__).warning(
                "SeasonalityModel.fit() failed for campaign %r; using the unfitted model",
                campaign_id,
                exc_info=True,
            )
        _seasonality_cache[campaign_id] = model
    return _seasonality_cache[campaign_id]
def _safe_get_variants(data, k: int):
"""LLM応答から variants 配列を安全に取り出して正規化。失敗時は None を返す。"""
items = []
if isinstance(data, dict) and isinstance(data.get("variants"), list):
items = data["variants"]
elif isinstance(data, list):
items = data
if not items or not all(isinstance(x, dict) for x in items):
return None
out = []
for it in items[:k]:
out.append({
"headline": str(it.get("headline", "")).strip(),
"body": str(it.get("body", "")).strip(),
})
return out
def _local_variants(brand: str, product: str, k: int):
"""LLMが不調でも UI を止めないための最終フォールバック生成(簡易・無害表現)。"""
base_head = [
"使いやすさで選ばれています",
"日々の習慣をシンプルに",
"はじめてでも安心",
"続けやすいサポートを",
"いま必要な機能だけを"
]
base_body = [
"{brand}の「{product}」。生活になじむ設計で、今日からムリなく始められます。まずは詳細をご覧ください。",
"毎日を少しラクに。{brand}の{product}が、あなたの習慣づくりを後押しします。今すぐチェック。",
"難しい操作は不要。{brand}の{product}なら、使い始めから自然に続けられます。詳しくはサイトへ。",
"必要な情報をひと目で。{brand}の{product}で、日々の管理をシンプルに。詳細を見る。",
"続けやすさを重視。{brand}の{product}で、小さな一歩から。"
]
out = []
for i in range(k):
hi = base_head[i % len(base_head)]
bo = base_body[i % len(base_body)].format(brand=brand, product=product)
out.append({"headline": hi, "body": bo})
return out
async def _request_variants(user_prompt: str, k: int):
    """Ask the LLM for variants; retry once with a stricter prompt and lower temperature.

    Returns the normalized variant list, or ``None`` when both attempts fail
    or yield nothing usable. Failures are logged rather than raised.
    """
    strict_prompt = user_prompt + "\n\n注意: 'variants' は必ず指定件数、各要素は {\"headline\":\"...\",\"body\":\"...\"} のみ。"
    # (prompt, temperature, max_tokens) for the normal call and the stricter retry.
    attempts = (
        (user_prompt, 0.2, 1200),
        (strict_prompt, 0.1, 1000),
    )
    for attempt_no, (prompt, temperature, max_tokens) in enumerate(attempts, start=1):
        try:
            data = await openai_chat_json(
                [
                    {"role": "system", "content": GEN_SYSTEM},
                    {"role": "user", "content": prompt},
                ],
                temperature=temperature,
                max_tokens=max_tokens,
            )
            items = _safe_get_variants(data, k)
        except Exception:
            # Keep the UI alive, but record why the attempt failed.
            logging.getLogger(__name__).warning(
                "Variant generation attempt %d failed", attempt_no, exc_info=True
            )
            items = None
        if items:
            return items
    return None


def _review_variant_text(text: str, ng_words):
    """Run the rule-based and LLM checks on *text*.

    Returns ``(final_text, status, rejection_reason)`` where ``status`` is
    ``"approved"`` or ``"rejected"`` and ``rejection_reason`` is ``None``
    unless the text was rejected.
    """
    ok_rule, bads = rule_based_check(text, ng_words)
    ok_llm, reasons, fixed = llm_check_and_fix(text)
    if ok_rule:
        # Rule check passed: adopt the LLM's rewrite only if it flagged the text.
        if not ok_llm:
            text = fixed or text
        return text, "approved", None
    if ok_llm:
        # NOTE(review): llm_check_and_fix() is not given the campaign's NG words,
        # so a text that failed the rule check is approved unchanged when the LLM
        # accepts it and returns no fix. Confirm this is intended.
        return fixed or text, "approved", None
    return text, "rejected", "; ".join(bads + reasons)


async def ui_generate(campaign_id: str, brand: str, product: str, target: str, tone: str, k_variants: int,
                      ng_words: str, value_per_conversion: float):
    """Generate ad-copy variants, review them, store them, and return a result table.

    Steps: upsert the campaign, ask the LLM for ``k_variants`` variants (with
    one stricter retry), fall back to locally generated copy if that fails,
    then run each text through the compliance checks and persist it.

    Args:
        campaign_id: Campaign the variants belong to.
        brand: Brand name used in the prompt.
        product: Product/service name used in the prompt.
        target: Target-audience description used in the prompt.
        tone: Desired tone used in the prompt.
        k_variants: Number of variants requested (the slider may deliver a float).
        ng_words: Newline-separated forbidden words; may be empty.
        value_per_conversion: Value stored with the campaign.

    Returns:
        A ``pandas.DataFrame`` with the columns in ``GENERATE_COLUMNS`` (the
        columns are present even when there are no rows).
    """
    k_variants = int(k_variants)  # the Slider value may arrive as a float
    if ng_words:
        constraints = {"ng_words": [w.strip() for w in ng_words.splitlines() if w.strip()]}
    else:
        constraints = {}
    upsert_campaign(
        campaign_id, brand, product, target, tone, "ja", constraints, value_per_conversion
    )
    user = GEN_USER_TEMPLATE.format(
        brand=brand,
        product=product,
        target=target,
        tone=tone,
        constraints=json.dumps(constraints, ensure_ascii=False),
        k=k_variants,
    )

    items = await _request_variants(user, k_variants)
    if not items:
        # Last resort so the UI never stalls: canned local copy.
        logging.getLogger(__name__).warning(
            "LLM returned no usable variants for campaign %r; using local fallback",
            campaign_id,
        )
        items = _local_variants(brand, product, k_variants)

    rows = []
    for it in items[:k_variants]:
        headline = it["headline"]
        body = it["body"]
        text = f"{headline}\n{body}".strip()
        vid = str(uuid.uuid4())[:8]
        text, status, rejection_reason = _review_variant_text(text, constraints.get("ng_words"))
        insert_variant(campaign_id, vid, text, status, rejection_reason)
        rows.append({
            "variant_id": vid,
            "status": status,
            "rejection_reason": rejection_reason or "",
            "text": text,
        })
    # Always return the fixed columns, even for an empty result.
    return pd.DataFrame(rows, columns=GENERATE_COLUMNS)
def ui_serve(campaign_id: str, hour: int, segment: str):
    """Pick a variant to serve and record the impression.

    Args:
        campaign_id: Campaign to serve from.
        hour: Hour value placed in the sampling context (coerced to ``int``).
        segment: Optional segment label; blank input is passed as ``None``.

    Returns:
        ``(variant_id, text)`` of the served variant.

    Raises:
        gr.Error: If the bandit yields no variant, or the chosen variant
            cannot be loaded.
    """
    ctx = {"hour": int(hour), "segment": (segment or "").strip() or None}
    model = _seasonal(campaign_id)
    bandit = ThompsonBandit(campaign_id)
    vid, _ = bandit.sample_arm(ctx, model.expected_ctr)
    if not vid:
        raise gr.Error("配信可能なバリアントがありません。まずは Generate してください。")
    row = get_variant(campaign_id, vid)
    if not row:
        raise gr.Error("バリアントが見つかりません。")
    # Record the impression. datetime.utcnow() is deprecated; take an aware
    # "now" and drop tzinfo so the stored string keeps the same naive-UTC format.
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    log_event(campaign_id, vid, "impression", timestamp, None)
    ThompsonBandit.update_with_event(campaign_id, vid, "impression")
    return vid, row["text"]
def ui_feedback(campaign_id: str, variant_id: str, event_type: str):
    """Record a feedback event for a previously served variant.

    Args:
        campaign_id: Campaign the variant belongs to.
        variant_id: Variant that was served; must be non-empty.
        event_type: Event name to log (the UI passes ``"click"`` or ``"conversion"``).

    Returns:
        A confirmation message for display.

    Raises:
        gr.Error: If no variant has been served yet (*variant_id* is empty).
    """
    if not variant_id:
        raise gr.Error("先に Serve してください。")
    # datetime.utcnow() is deprecated; take an aware "now" and drop tzinfo so
    # the stored string keeps the same naive-UTC format.
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    log_event(campaign_id, variant_id, event_type, timestamp, None)
    ThompsonBandit.update_with_event(campaign_id, variant_id, event_type)
    return f"{event_type} を記録しました。"
def ui_report(campaign_id: str):
    """Build the per-variant performance table for a campaign.

    For each metrics row: ``ctr = clicks / impressions``,
    ``cvr = conversions / clicks`` (each 0.0 when the denominator is 0) and
    ``expected_value = ctr * cvr * value_per_conversion``.

    Args:
        campaign_id: Campaign to report on.

    Returns:
        A ``pandas.DataFrame`` with the columns in ``REPORT_COLUMNS`` (the
        columns are present even when there are no rows).
    """
    metrics = get_metrics(campaign_id)
    # Defensive: treat a missing (None) value per conversion as 0 instead of
    # failing with a TypeError in the multiplication below.
    vpc = get_campaign_value_per_conversion(campaign_id) or 0.0
    rows = []
    for rec in metrics:
        # Defensive: a None counter is counted as 0.
        impressions = int(rec["impressions"] or 0)
        clicks = int(rec["clicks"] or 0)
        conversions = int(rec["conversions"] or 0)
        ctr = clicks / impressions if impressions > 0 else 0.0
        cvr = conversions / clicks if clicks > 0 else 0.0
        rows.append({
            "variant_id": rec["variant_id"],
            "impressions": impressions,
            "clicks": clicks,
            "conversions": conversions,
            "ctr": round(ctr, 4),
            "cvr": round(cvr, 4),
            "expected_value": round(ctr * cvr * vpc, 6),
        })
    # Always return the fixed columns, even for an empty result.
    return pd.DataFrame(rows, columns=REPORT_COLUMNS)
def ui_check(text: str):
    """Run both compliance checks on *text* for the ad-hoc check tab.

    The rule-based check is run with an empty NG-word list, followed by the
    LLM check.

    Args:
        text: Copy to check.

    Returns:
        ``(status, reasons, fixed_text)``: ``status`` is ``"pass"`` only when
        both checks succeed, otherwise ``"needs_fix"``; ``reasons`` joins the
        findings of both checks with ``"; "``; ``fixed_text`` is the LLM's
        rewrite if it returned one, else the original text on a pass and an
        empty string otherwise.
    """
    rule_ok, rule_hits = rule_based_check(text, [])
    llm_ok, llm_reasons, llm_fix = llm_check_and_fix(text)
    if rule_ok and llm_ok:
        status = "pass"
        fixed_text = llm_fix or text
    else:
        status = "needs_fix"
        fixed_text = llm_fix or ""
    return status, "; ".join(rule_hits + llm_reasons), fixed_text
# Gradio UI definition: four tabs (generate, serve & feedback, report, compliance check).
# NOTE(review): the nesting below (which widgets sit inside each gr.Row) was
# reconstructed from a source whose indentation was lost. Confirm the layout.
with gr.Blocks(title="AdCopy MAB Optimizer", fill_height=True) as demo:
    gr.Markdown("""
# AdCopy MAB Optimizer(HF UI)
**広告コピー自動生成 → Thompson Sampling(CTR×CVR) → レポート** を、Hugging Face Spaces 上で完結。
- LLM: OpenAI (`OPENAI_API_KEY` を Space Secrets に設定)
- DB: SQLite(`/data/app_data/data.db` など、書き込み可能ディレクトリ)
- 季節性: Prophet/NeuralProphet(なければ簡易ヒューリスティック)
""")
    # Tab 1: collect campaign inputs and generate/review/store variants via ui_generate.
    with gr.Tab("1) Generate"):
        with gr.Row():
            campaign_id = gr.Textbox(label="campaign_id", value="cmp-demo", scale=1)
            k_variants = gr.Slider(1, 10, value=5, step=1, label="生成本数")
            value_per_conv = gr.Number(value=5000, label="value_per_conversion")
        brand = gr.Textbox(label="ブランド", value="SFM")
        product = gr.Textbox(label="商品/サービス", value="HbA1c測定アプリ")
        target = gr.Textbox(label="ターゲット", value="30-50代の健康意識が高い層")
        tone = gr.Textbox(label="トーン", value="エビデンス重視で安心感")
        ng_words = gr.Textbox(label="NGワード(改行区切り)", value="治る\n奇跡")
        btn_gen = gr.Button("広告案を生成&審査&保存")
        table_gen = gr.Dataframe(headers=GENERATE_COLUMNS, interactive=False)
        # Input order must match the parameter order of ui_generate.
        btn_gen.click(ui_generate, [campaign_id, brand, product, target, tone, k_variants, ng_words, value_per_conv], [table_gen])
    # Tab 2: serve one variant (ui_serve) and record click/conversion feedback (ui_feedback).
    with gr.Tab("2) Serve & Feedback"):
        with gr.Row():
            campaign_id2 = gr.Textbox(label="campaign_id", value="cmp-demo", scale=1)
            hour = gr.Slider(0, 23, value=20, step=1, label="hour")
            segment = gr.Textbox(label="segment (任意)")
        btn_serve = gr.Button("Serve Ad(impressionを記録)")
        served_vid = gr.Textbox(label="served variant_id", interactive=False)
        served_text = gr.Textbox(label="served text", lines=6, interactive=False)
        btn_serve.click(ui_serve, [campaign_id2, hour, segment], [served_vid, served_text])
        with gr.Row():
            btn_click = gr.Button("Clickを記録")
            btn_conv = gr.Button("Conversionを記録")
        msg = gr.Markdown()
        # The lambdas pin the event type; the served variant id comes from the read-only textbox.
        btn_click.click(lambda cid, vid: ui_feedback(cid, vid, "click"), [campaign_id2, served_vid], [msg])
        btn_conv.click(lambda cid, vid: ui_feedback(cid, vid, "conversion"), [campaign_id2, served_vid], [msg])
    # Tab 3: per-variant metrics table produced by ui_report.
    with gr.Tab("3) Report"):
        campaign_id3 = gr.Textbox(label="campaign_id", value="cmp-demo")
        btn_rep = gr.Button("更新")
        table_rep = gr.Dataframe(headers=REPORT_COLUMNS, interactive=False)
        btn_rep.click(ui_report, [campaign_id3], [table_rep])
    # Tab 4: ad-hoc compliance check of arbitrary text via ui_check.
    with gr.Tab("4) Compliance Check"):
        cand = gr.Textbox(label="チェックする文面", lines=5)
        btn_chk = gr.Button("判定")
        status = gr.Textbox(label="status")
        reasons = gr.Textbox(label="reasons")
        fixed = gr.Textbox(label="fixed (修正案)", lines=5)
        btn_chk.click(ui_check, [cand], [status, reasons, fixed])
# Script entry point: enable the Gradio queue, then serve on all interfaces at port 7860.
if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)