Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
import os, json, uuid
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from typing import List, Dict, Any
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import pandas as pd
|
| 8 |
+
|
| 9 |
+
from app.storage import init_db, insert_variant, upsert_campaign, get_variant, get_metrics, get_campaign_value_per_conversion, log_event
|
| 10 |
+
from app.bandit import ThompsonBandit
|
| 11 |
+
from app.forecast import SeasonalityModel
|
| 12 |
+
from app.compliance import rule_based_check, llm_check_and_fix
|
| 13 |
+
from app.openai_client import openai_chat
|
| 14 |
+
|
| 15 |
+
# 初期化
|
| 16 |
+
init_db()
|
| 17 |
+
_seasonality_cache: Dict[str, SeasonalityModel] = {}
|
| 18 |
+
|
| 19 |
+
GEN_SYSTEM = """
|
| 20 |
+
あなたは日本語広告コピーのプロフェッショナルコピーライターです。
|
| 21 |
+
出力はJSON配列(各要素は{\"headline\":..., \"body\":...})のみで返してください。句読点や記号は自然に。誇大・断定は避け、事実ベースで魅力を伝えます。
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
GEN_USER_TEMPLATE = """
|
| 25 |
+
ブランド: {brand}
|
| 26 |
+
商品/サービス: {product}
|
| 27 |
+
想定ターゲット: {target}
|
| 28 |
+
トーン: {tone}
|
| 29 |
+
制約: {constraints}
|
| 30 |
+
生成本数: {k}
|
| 31 |
+
条件:
|
| 32 |
+
- 1本あたり見出し(全角15-25字目安)+ 本文(全角40-90字目安)
|
| 33 |
+
- 禁止: 医薬効能の断定、100%、永久、即効、根拠のない数値
|
| 34 |
+
- CTAは自然に
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def _seasonal(campaign_id: str) -> SeasonalityModel:
|
| 38 |
+
if campaign_id not in _seasonality_cache:
|
| 39 |
+
m = SeasonalityModel(campaign_id)
|
| 40 |
+
try:
|
| 41 |
+
m.fit()
|
| 42 |
+
except Exception:
|
| 43 |
+
pass
|
| 44 |
+
_seasonality_cache[campaign_id] = m
|
| 45 |
+
return _seasonality_cache[campaign_id]
|
| 46 |
+
|
| 47 |
+
async def ui_generate(campaign_id: str, brand: str, product: str, target: str, tone: str, k_variants: int,
|
| 48 |
+
ng_words: str, value_per_conversion: float):
|
| 49 |
+
constraints = {"ng_words": [w.strip() for w in ng_words.splitlines() if w.strip()]} if ng_words else {}
|
| 50 |
+
|
| 51 |
+
upsert_campaign(
|
| 52 |
+
campaign_id, brand, product, target, tone, "ja", constraints, value_per_conversion
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
user = GEN_USER_TEMPLATE.format(
|
| 56 |
+
brand=brand,
|
| 57 |
+
product=product,
|
| 58 |
+
target=target,
|
| 59 |
+
tone=tone,
|
| 60 |
+
constraints=json.dumps(constraints, ensure_ascii=False),
|
| 61 |
+
k=k_variants,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
raw = await openai_chat([
|
| 65 |
+
{"role": "system", "content": GEN_SYSTEM},
|
| 66 |
+
{"role": "user", "content": user}
|
| 67 |
+
], temperature=0.8, max_tokens=800)
|
| 68 |
+
|
| 69 |
+
try:
|
| 70 |
+
items = json.loads(raw)
|
| 71 |
+
assert isinstance(items, list)
|
| 72 |
+
except Exception:
|
| 73 |
+
raise gr.Error("LLM出力のJSONパースに失敗しました。プロンプトを短くするか、再実行してください。")
|
| 74 |
+
|
| 75 |
+
rows = []
|
| 76 |
+
for it in items[:k_variants]:
|
| 77 |
+
headline = (it.get("headline") or "").strip()
|
| 78 |
+
body = (it.get("body") or "").strip()
|
| 79 |
+
text = f"{headline}\n{body}".strip()
|
| 80 |
+
vid = str(uuid.uuid4())[:8]
|
| 81 |
+
|
| 82 |
+
ok_rule, bads = rule_based_check(text, (constraints or {}).get("ng_words"))
|
| 83 |
+
rejection_reason = None
|
| 84 |
+
status = "approved"
|
| 85 |
+
|
| 86 |
+
if not ok_rule:
|
| 87 |
+
ok_llm, reasons, fixed = llm_check_and_fix(text)
|
| 88 |
+
if ok_llm:
|
| 89 |
+
text = fixed or text
|
| 90 |
+
else:
|
| 91 |
+
status = "rejected"
|
| 92 |
+
rejection_reason = "; ".join(bads + reasons)
|
| 93 |
+
else:
|
| 94 |
+
ok_llm, reasons, fixed = llm_check_and_fix(text)
|
| 95 |
+
if not ok_llm:
|
| 96 |
+
text = fixed or text
|
| 97 |
+
|
| 98 |
+
insert_variant(campaign_id, vid, text, status, rejection_reason)
|
| 99 |
+
rows.append({
|
| 100 |
+
"variant_id": vid,
|
| 101 |
+
"status": status,
|
| 102 |
+
"rejection_reason": rejection_reason or "",
|
| 103 |
+
"text": text,
|
| 104 |
+
})
|
| 105 |
+
|
| 106 |
+
df = pd.DataFrame(rows, columns=["variant_id", "status", "rejection_reason", "text"]) if rows else pd.DataFrame()
|
| 107 |
+
return df
|
| 108 |
+
|
| 109 |
+
def ui_serve(campaign_id: str, hour: int, segment: str):
|
| 110 |
+
ctx = {"hour": int(hour), "segment": (segment or "").strip() or None}
|
| 111 |
+
m = _seasonal(campaign_id)
|
| 112 |
+
bandit = ThompsonBandit(campaign_id)
|
| 113 |
+
vid, _ = bandit.sample_arm(ctx, seasonal_fn=m.expected_ctr)
|
| 114 |
+
if not vid:
|
| 115 |
+
raise gr.Error("配信可能なバリアントがありません。まずは Generate してください。")
|
| 116 |
+
|
| 117 |
+
row = get_variant(campaign_id, vid)
|
| 118 |
+
if not row:
|
| 119 |
+
raise gr.Error("バリアントが見つかりません。")
|
| 120 |
+
|
| 121 |
+
# impression 記録
|
| 122 |
+
log_event(campaign_id, vid, "impression", datetime.utcnow().isoformat(), None)
|
| 123 |
+
ThompsonBandit.update_with_event(campaign_id, vid, "impression")
|
| 124 |
+
|
| 125 |
+
return vid, row["text"]
|
| 126 |
+
|
| 127 |
+
def ui_feedback(campaign_id: str, variant_id: str, event_type: str):
|
| 128 |
+
if not variant_id:
|
| 129 |
+
raise gr.Error("先に Serve してください。")
|
| 130 |
+
log_event(campaign_id, variant_id, event_type, datetime.utcnow().isoformat(), None)
|
| 131 |
+
ThompsonBandit.update_with_event(campaign_id, variant_id, event_type)
|
| 132 |
+
return f"{event_type} を記録しました。"
|
| 133 |
+
|
| 134 |
+
def ui_report(campaign_id: str):
|
| 135 |
+
mets = get_metrics(campaign_id)
|
| 136 |
+
vpc = get_campaign_value_per_conversion(campaign_id)
|
| 137 |
+
rows = []
|
| 138 |
+
for r in mets:
|
| 139 |
+
imp = int(r["impressions"]); clk = int(r["clicks"]); conv = int(r["conversions"])
|
| 140 |
+
ctr = (clk / imp) if imp > 0 else 0.0
|
| 141 |
+
cvr = (conv / clk) if clk > 0 else 0.0
|
| 142 |
+
ev = ctr * cvr * vpc
|
| 143 |
+
rows.append({
|
| 144 |
+
"variant_id": r["variant_id"],
|
| 145 |
+
"impressions": imp,
|
| 146 |
+
"clicks": clk,
|
| 147 |
+
"conversions": conv,
|
| 148 |
+
"ctr": round(ctr, 4),
|
| 149 |
+
"cvr": round(cvr, 4),
|
| 150 |
+
"expected_value": round(ev, 6),
|
| 151 |
+
})
|
| 152 |
+
df = pd.DataFrame(rows, columns=["variant_id","impressions","clicks","conversions","ctr","cvr","expected_value"]) if rows else pd.DataFrame()
|
| 153 |
+
return df
|
| 154 |
+
|
| 155 |
+
def ui_check(text: str):
|
| 156 |
+
ok_rule, bads = rule_based_check(text, [])
|
| 157 |
+
ok_llm, reasons, fixed = llm_check_and_fix(text)
|
| 158 |
+
status = "pass" if (ok_rule and ok_llm) else "needs_fix"
|
| 159 |
+
fixed_text = fixed or (text if status == "pass" else "")
|
| 160 |
+
reasons_joined = "; ".join(bads + reasons)
|
| 161 |
+
return status, reasons_joined, fixed_text
|
| 162 |
+
|
| 163 |
+
with gr.Blocks(title="AdCopy MAB Optimizer", fill_height=True) as demo:
|
| 164 |
+
gr.Markdown("""
|
| 165 |
+
# AdCopy MAB Optimizer(HF UI)
|
| 166 |
+
**広告コピー自動生成 → Thompson Sampling(CTR×CVR) → レポート** を、Hugging Face Spaces 上で完結。
|
| 167 |
+
- LLM: OpenAI (`OPENAI_API_KEY` を Space Secrets に設定)
|
| 168 |
+
- DB: SQLite(`data/data.db`)
|
| 169 |
+
- 季節性: Prophet/NeuralProphet(なければ簡易ヒューリスティック)
|
| 170 |
+
""")
|
| 171 |
+
|
| 172 |
+
with gr.Tab("1) Generate"):
|
| 173 |
+
with gr.Row():
|
| 174 |
+
campaign_id = gr.Textbox(label="campaign_id", value="cmp-demo", scale=1)
|
| 175 |
+
k_variants = gr.Slider(1, 10, value=5, step=1, label="生成本数")
|
| 176 |
+
value_per_conv = gr.Number(value=5000, label="value_per_conversion")
|
| 177 |
+
brand = gr.Textbox(label="ブランド", value="SFM")
|
| 178 |
+
product = gr.Textbox(label="商品/サービス", value="HbA1c測定アプリ")
|
| 179 |
+
target = gr.Textbox(label="ターゲット", value="30-50代の健康意識が高い層")
|
| 180 |
+
tone = gr.Textbox(label="トーン", value="エビデンス重視で安心感")
|
| 181 |
+
ng_words = gr.Textbox(label="NGワード(改行区切り)", value="治る\n奇跡")
|
| 182 |
+
btn_gen = gr.Button("広告案を生成&審査&保存")
|
| 183 |
+
table_gen = gr.Dataframe(headers=["variant_id","status","rejection_reason","text"], interactive=False)
|
| 184 |
+
btn_gen.click(ui_generate, [campaign_id, brand, product, target, tone, k_variants, ng_words, value_per_conv], [table_gen])
|
| 185 |
+
|
| 186 |
+
with gr.Tab("2) Serve & Feedback"):
|
| 187 |
+
with gr.Row():
|
| 188 |
+
campaign_id2 = gr.Textbox(label="campaign_id", value="cmp-demo", scale=1)
|
| 189 |
+
hour = gr.Slider(0, 23, value=20, step=1, label="hour")
|
| 190 |
+
segment = gr.Textbox(label="segment (任意)")
|
| 191 |
+
btn_serve = gr.Button("Serve Ad(impressionを記録)")
|
| 192 |
+
served_vid = gr.Textbox(label="served variant_id", interactive=False)
|
| 193 |
+
served_text = gr.Textbox(label="served text", lines=6, interactive=False)
|
| 194 |
+
btn_serve.click(ui_serve, [campaign_id2, hour, segment], [served_vid, served_text])
|
| 195 |
+
|
| 196 |
+
with gr.Row():
|
| 197 |
+
btn_click = gr.Button("Clickを記録")
|
| 198 |
+
btn_conv = gr.Button("Conversionを記録")
|
| 199 |
+
msg = gr.Markdown()
|
| 200 |
+
btn_click.click(lambda cid, vid: ui_feedback(cid, vid, "click"), [campaign_id2, served_vid], [msg])
|
| 201 |
+
btn_conv.click(lambda cid, vid: ui_feedback(cid, vid, "conversion"), [campaign_id2, served_vid], [msg])
|
| 202 |
+
|
| 203 |
+
with gr.Tab("3) Report"):
|
| 204 |
+
campaign_id3 = gr.Textbox(label="campaign_id", value="cmp-demo")
|
| 205 |
+
btn_rep = gr.Button("更新")
|
| 206 |
+
table_rep = gr.Dataframe(headers=["variant_id","impressions","clicks","conversions","ctr","cvr","expected_value"], interactive=False)
|
| 207 |
+
btn_rep.click(ui_report, [campaign_id3], [table_rep])
|
| 208 |
+
|
| 209 |
+
with gr.Tab("4) Compliance Check"):
|
| 210 |
+
cand = gr.Textbox(label="チェックする文面", lines=5)
|
| 211 |
+
btn_chk = gr.Button("判定")
|
| 212 |
+
status = gr.Textbox(label="status")
|
| 213 |
+
reasons = gr.Textbox(label="reasons")
|
| 214 |
+
fixed = gr.Textbox(label="fixed (修正案)", lines=5)
|
| 215 |
+
btn_chk.click(ui_check, [cand], [status, reasons, fixed])
|
| 216 |
+
|
| 217 |
+
if __name__ == "__main__":
|
| 218 |
+
demo.queue().launch(server_name="0.0.0.0", server_port=7860)
|