Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os, json, uuid | |
| from datetime import datetime | |
| from typing import List, Dict, Any | |
| import gradio as gr | |
| import pandas as pd | |
| from app.storage import init_db, insert_variant, upsert_campaign, get_variant, get_metrics, get_campaign_value_per_conversion, log_event | |
| from app.bandit import ThompsonBandit | |
| from app.forecast import SeasonalityModel | |
| from app.compliance import rule_based_check, llm_check_and_fix | |
| from app.openai_client import openai_chat | |
# Module initialization (runs once, at import time): prepare storage through
# app.storage.init_db before any UI handler reads or writes it.
init_db()

# campaign_id -> SeasonalityModel; filled lazily by _seasonal() and kept for the
# lifetime of the process.
_seasonality_cache: Dict[str, SeasonalityModel] = dict()
# System prompt for copy generation. It asks the model to reply with nothing but a
# JSON array of {"headline": ..., "body": ...} objects, which ui_generate parses as JSON.
GEN_SYSTEM = "\n" + "\n".join((
    "あなたは日本語広告コピーのプロフェッショナルコピーライターです。",
    '出力はJSON配列(各要素は{"headline":..., "body":...})のみで返してください。句読点や記号は自然に。誇大・断定は避け、事実ベースで魅力を伝えます。',
)) + "\n"

# User prompt template, filled via str.format(brand=, product=, target=, tone=,
# constraints=, k=). Any literal brace added here would have to be doubled ({{ }}).
GEN_USER_TEMPLATE = "\n" + "\n".join((
    "ブランド: {brand}",
    "商品/サービス: {product}",
    "想定ターゲット: {target}",
    "トーン: {tone}",
    "制約: {constraints}",
    "生成本数: {k}",
    "条件:",
    "- 1本あたり見出し(全角15-25字目安)+ 本文(全角40-90字目安)",
    "- 禁止: 医薬効能の断定、100%、永久、即効、根拠のない数値",
    "- CTAは自然に",
)) + "\n"
def _seasonal(campaign_id: str) -> SeasonalityModel:
    """Return the cached SeasonalityModel for ``campaign_id``, creating and fitting it on first use.

    Fitting is best-effort. If ``fit()`` raises, the failure is logged and the
    unfitted model is still cached and returned, so serving keeps working.
    Per the app description, the model presumably falls back to a simple heuristic
    in that case (confirm in app.forecast).

    NOTE: a model is fitted at most once per process and campaign. Events logged
    later are not picked up until the process restarts.
    """
    model = _seasonality_cache.get(campaign_id)
    if model is None:
        model = SeasonalityModel(campaign_id)
        try:
            model.fit()
        except Exception:
            # Deliberately non-fatal, but no longer silent: record why the fit failed.
            import logging

            logging.getLogger(__name__).warning(
                "Seasonality fit failed for campaign %s; using the unfitted model",
                campaign_id,
                exc_info=True,
            )
        _seasonality_cache[campaign_id] = model
    return model
def _parse_variant_items(raw: str) -> List[Any]:
    """Parse the LLM reply as a JSON array, tolerating a surrounding Markdown code fence.

    Raises:
        ValueError: if the reply is not valid JSON or is not a JSON array.
    """
    payload = (raw or "").strip()
    if payload.startswith("```"):
        # Models often wrap JSON in ```json ... ```: drop the opening fence line and
        # everything from the closing fence on.
        _, _, payload = payload.partition("\n")
        payload = payload.rsplit("```", 1)[0].strip()
    items = json.loads(payload)
    if not isinstance(items, list):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise ValueError("LLM output is not a JSON array")
    return items


def _review_copy(text: str, ng_words: List[str] | None) -> tuple[str, str, str | None]:
    """Run the rule-based and LLM compliance checks on one draft.

    Returns:
        ``(final_text, status, rejection_reason)``. ``status`` is "approved" or
        "rejected"; ``rejection_reason`` is None for approved drafts. Approved text
        always passes ``rule_based_check`` with ``ng_words``, so an LLM "ok" can
        never let a user NG word through.
    """
    ok_rule, bads = rule_based_check(text, ng_words)
    ok_llm, reasons, fixed = llm_check_and_fix(text)
    if ok_rule and ok_llm:
        return text, "approved", None
    if not ok_rule and not ok_llm:
        return text, "rejected", "; ".join(bads + reasons)
    if not ok_llm and not fixed:
        # The LLM flagged a problem but offered no rewrite: nothing safe to approve.
        return text, "rejected", "; ".join(reasons)
    candidate = fixed or text
    # Re-check what would actually be stored. The LLM fix must not (re)introduce NG
    # words, and an LLM "ok" must not override a rule violation still present.
    ok_after, bads_after = rule_based_check(candidate, ng_words)
    if not ok_after:
        return text, "rejected", "; ".join(bads_after + reasons)
    return candidate, "approved", None


async def ui_generate(campaign_id: str, brand: str, product: str, target: str, tone: str, k_variants: int,
                      ng_words: str, value_per_conversion: float):
    """Generate ad-copy variants with the LLM, review and store each one, and return a summary table.

    Args:
        campaign_id: Campaign to upsert and attach the variants to.
        brand, product, target, tone: Briefing values substituted into GEN_USER_TEMPLATE.
        k_variants: Number of variants to request and keep (the UI slider allows 1-10).
        ng_words: Newline-separated forbidden words; blank lines are ignored.
        value_per_conversion: Stored on the campaign via upsert_campaign.

    Returns:
        pandas.DataFrame with columns variant_id, status, rejection_reason, text.
        The columns are present even when no variant was produced.

    Raises:
        gr.Error: if the LLM reply cannot be parsed as a JSON array.
    """
    columns = ["variant_id", "status", "rejection_reason", "text"]
    # Slider values may arrive as floats; slicing and the prompt need an int.
    k = int(k_variants)
    constraints = {"ng_words": [w.strip() for w in ng_words.splitlines() if w.strip()]} if ng_words else {}
    upsert_campaign(
        campaign_id, brand, product, target, tone, "ja", constraints, value_per_conversion
    )
    user = GEN_USER_TEMPLATE.format(
        brand=brand,
        product=product,
        target=target,
        tone=tone,
        constraints=json.dumps(constraints, ensure_ascii=False),
        k=k,
    )
    # Each variant is roughly 115 full-width characters plus JSON syntax. A fixed
    # 800-token budget can truncate the array for larger k, which then fails to
    # parse, so the budget grows with k.
    raw = await openai_chat([
        {"role": "system", "content": GEN_SYSTEM},
        {"role": "user", "content": user}
    ], temperature=0.8, max_tokens=max(800, 200 * k))
    try:
        items = _parse_variant_items(raw)
    except Exception as err:
        raise gr.Error("LLM出力のJSONパースに失敗しました。プロンプトを短くするか、再実行してください。") from err

    ng_list = constraints.get("ng_words")
    rows = []
    for it in items[:k]:
        if not isinstance(it, dict):
            continue  # malformed entry (e.g. a bare string): skip instead of crashing on .get
        headline = str(it.get("headline") or "").strip()
        body = str(it.get("body") or "").strip()
        text = f"{headline}\n{body}".strip()
        if not text:
            continue  # nothing to review or store
        vid = str(uuid.uuid4())[:8]
        # NOTE(review): llm_check_and_fix is called synchronously inside this async
        # handler; if it does blocking network I/O it stalls the event loop — confirm.
        text, status, rejection_reason = _review_copy(text, ng_list)
        insert_variant(campaign_id, vid, text, status, rejection_reason)
        rows.append({
            "variant_id": vid,
            "status": status,
            "rejection_reason": rejection_reason or "",
            "text": text,
        })
    return pd.DataFrame(rows, columns=columns)
def ui_serve(campaign_id: str, hour: int, segment: str):
    """Choose one variant with Thompson sampling, record an impression for it, and return it.

    Args:
        campaign_id: Campaign whose variants the bandit samples from.
        hour: Hour of day (0-23 in the UI), passed to the bandit as context.
        segment: Optional segment label; blank or whitespace becomes None.

    Returns:
        ``(variant_id, text)`` of the served variant.

    Raises:
        gr.Error: if the bandit returns no variant, or the chosen variant cannot be loaded.
    """
    from datetime import timezone  # local import keeps this block self-contained

    ctx = {"hour": int(hour), "segment": (segment or "").strip() or None}
    m = _seasonal(campaign_id)
    bandit = ThompsonBandit(campaign_id)
    # The seasonality model's expected_ctr is handed to the bandit as seasonal_fn;
    # how it is combined with the posterior is defined in app.bandit.
    vid, _ = bandit.sample_arm(ctx, seasonal_fn=m.expected_ctr)
    if not vid:
        raise gr.Error("配信可能なバリアントがありません。まずは Generate してください。")
    row = get_variant(campaign_id, vid)
    if not row:
        raise gr.Error("バリアントが見つかりません。")
    # Record the impression. The timestamp is naive UTC in the same ISO format that
    # datetime.utcnow().isoformat() produced; utcnow() is deprecated since Python 3.12.
    ts = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    log_event(campaign_id, vid, "impression", ts, None)
    ThompsonBandit.update_with_event(campaign_id, vid, "impression")
    return vid, row["text"]
def ui_feedback(campaign_id: str, variant_id: str, event_type: str):
    """Record a feedback event for the variant currently shown in the Serve tab.

    Args:
        campaign_id: Campaign the variant belongs to.
        variant_id: ID from the "served variant_id" box; empty means nothing was served yet.
        event_type: Event name; the UI wires this to "click" or "conversion".

    Returns:
        A short confirmation message for the status Markdown.

    Raises:
        gr.Error: if no variant has been served yet.
    """
    from datetime import timezone  # local import keeps this block self-contained

    if not variant_id:
        raise gr.Error("先に Serve してください。")
    # Naive UTC timestamp in the same ISO format datetime.utcnow().isoformat() produced;
    # utcnow() is deprecated since Python 3.12.
    ts = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    log_event(campaign_id, variant_id, event_type, ts, None)
    ThompsonBandit.update_with_event(campaign_id, variant_id, event_type)
    return f"{event_type} を記録しました。"
def ui_report(campaign_id: str):
    """Build the per-variant performance table for the "3) Report" tab.

    For each row returned by ``get_metrics(campaign_id)``:
      ctr            = clicks / impressions      (0.0 when impressions == 0)
      cvr            = conversions / clicks      (0.0 when clicks == 0)
      expected_value = ctr * cvr * value_per_conversion

    Args:
        campaign_id: Campaign to report on.

    Returns:
        pandas.DataFrame with columns variant_id, impressions, clicks, conversions,
        ctr, cvr (rounded to 4 places) and expected_value (rounded to 6). The
        columns are present even when there are no rows.
    """
    columns = ["variant_id", "impressions", "clicks", "conversions", "ctr", "cvr", "expected_value"]
    metrics = get_metrics(campaign_id)
    raw_vpc = get_campaign_value_per_conversion(campaign_id)
    # NOTE(review): this can presumably be None (value_per_conversion left empty in
    # the UI, or an unknown campaign). Treat it as 0 instead of raising TypeError.
    vpc = float(raw_vpc) if raw_vpc is not None else 0.0

    rows = []
    for r in metrics:
        # `or 0` guards against NULL aggregates coming back from storage.
        imp = int(r["impressions"] or 0)
        clk = int(r["clicks"] or 0)
        conv = int(r["conversions"] or 0)
        ctr = (clk / imp) if imp > 0 else 0.0
        # The UI lets conversions be logged without a preceding click, so cvr can exceed 1.
        cvr = (conv / clk) if clk > 0 else 0.0
        ev = ctr * cvr * vpc
        rows.append({
            "variant_id": r["variant_id"],
            "impressions": imp,
            "clicks": clk,
            "conversions": conv,
            "ctr": round(ctr, 4),
            "cvr": round(cvr, 4),
            "expected_value": round(ev, 6),
        })
    return pd.DataFrame(rows, columns=columns)
def ui_check(text: str):
    """Run both compliance checks on free text for the "4) Compliance Check" tab.

    The rule-based check runs with an empty NG-word list; the LLM check may propose
    a rewrite.

    Returns:
        ``(status, reasons, fixed_text)``:
        - status: "pass" only when both checks pass, otherwise "needs_fix".
        - reasons: rule findings followed by LLM findings, joined with "; ".
        - fixed_text: the LLM's rewrite if it gave one, else the input text on
          "pass", else "".
    """
    rule_ok, rule_findings = rule_based_check(text, [])
    llm_ok, llm_findings, suggestion = llm_check_and_fix(text)

    verdict = "needs_fix"
    if rule_ok and llm_ok:
        verdict = "pass"

    if suggestion:
        shown = suggestion
    elif verdict == "pass":
        shown = text
    else:
        shown = ""

    return verdict, "; ".join(rule_findings + llm_findings), shown
# Gradio UI: four tabs wired to the ui_* handlers in this module. Components render
# in creation order within each `with` context.
# NOTE(review): the Row/Tab nesting here was reconstructed from a copy that lost its
# indentation — confirm against the deployed Space.
with gr.Blocks(title="AdCopy MAB Optimizer", fill_height=True) as demo:
    # Header / description shown above the tabs.
    gr.Markdown("""
# AdCopy MAB Optimizer(HF UI)
**広告コピー自動生成 → Thompson Sampling(CTR×CVR) → レポート** を、Hugging Face Spaces 上で完結。
- LLM: OpenAI (`OPENAI_API_KEY` を Space Secrets に設定)
- DB: SQLite(`data/data.db`)
- 季節性: Prophet/NeuralProphet(なければ簡易ヒューリスティック)
""")
    # Tab 1: generate variants, run the compliance review, and save them.
    with gr.Tab("1) Generate"):
        with gr.Row():
            campaign_id = gr.Textbox(label="campaign_id", value="cmp-demo", scale=1)
            k_variants = gr.Slider(1, 10, value=5, step=1, label="生成本数")
            value_per_conv = gr.Number(value=5000, label="value_per_conversion")
        brand = gr.Textbox(label="ブランド", value="SFM")
        product = gr.Textbox(label="商品/サービス", value="HbA1c測定アプリ")
        target = gr.Textbox(label="ターゲット", value="30-50代の健康意識が高い層")
        tone = gr.Textbox(label="トーン", value="エビデンス重視で安心感")
        ng_words = gr.Textbox(label="NGワード(改行区切り)", value="治る\n奇跡")
        btn_gen = gr.Button("広告案を生成&審査&保存")
        table_gen = gr.Dataframe(headers=["variant_id","status","rejection_reason","text"], interactive=False)
        # The input list order must match ui_generate's positional parameters.
        btn_gen.click(ui_generate, [campaign_id, brand, product, target, tone, k_variants, ng_words, value_per_conv], [table_gen])
    # Tab 2: serve one variant (logs an impression), then log click/conversion for it.
    with gr.Tab("2) Serve & Feedback"):
        with gr.Row():
            campaign_id2 = gr.Textbox(label="campaign_id", value="cmp-demo", scale=1)
            hour = gr.Slider(0, 23, value=20, step=1, label="hour")
            segment = gr.Textbox(label="segment (任意)")
        btn_serve = gr.Button("Serve Ad(impressionを記録)")
        served_vid = gr.Textbox(label="served variant_id", interactive=False)
        served_text = gr.Textbox(label="served text", lines=6, interactive=False)
        btn_serve.click(ui_serve, [campaign_id2, hour, segment], [served_vid, served_text])
        with gr.Row():
            btn_click = gr.Button("Clickを記録")
            btn_conv = gr.Button("Conversionを記録")
        msg = gr.Markdown()
        # Each lambda fixes event_type; the served variant_id box identifies the variant.
        btn_click.click(lambda cid, vid: ui_feedback(cid, vid, "click"), [campaign_id2, served_vid], [msg])
        btn_conv.click(lambda cid, vid: ui_feedback(cid, vid, "conversion"), [campaign_id2, served_vid], [msg])
    # Tab 3: per-variant CTR / CVR / expected-value report.
    with gr.Tab("3) Report"):
        campaign_id3 = gr.Textbox(label="campaign_id", value="cmp-demo")
        btn_rep = gr.Button("更新")
        table_rep = gr.Dataframe(headers=["variant_id","impressions","clicks","conversions","ctr","cvr","expected_value"], interactive=False)
        btn_rep.click(ui_report, [campaign_id3], [table_rep])
    # Tab 4: ad-hoc compliance check of arbitrary text (ui_check applies no NG-word list).
    with gr.Tab("4) Compliance Check"):
        cand = gr.Textbox(label="チェックする文面", lines=5)
        btn_chk = gr.Button("判定")
        status = gr.Textbox(label="status")
        reasons = gr.Textbox(label="reasons")
        fixed = gr.Textbox(label="fixed (修正案)", lines=5)
        btn_chk.click(ui_check, [cand], [status, reasons, fixed])
if __name__ == "__main__":
    # Enable Gradio's request queue, then serve the app. Host and port default to the
    # Hugging Face Spaces values (0.0.0.0:7860). They can be overridden for local runs
    # via Gradio's own GRADIO_SERVER_NAME / GRADIO_SERVER_PORT environment variables.
    demo.queue().launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
    )