Corin1998 commited on
Commit
2846c4d
·
verified ·
1 Parent(s): f454b51

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +218 -0
app.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import os, json, uuid
3
+ from datetime import datetime
4
+ from typing import List, Dict, Any
5
+
6
+ import gradio as gr
7
+ import pandas as pd
8
+
9
+ from app.storage import init_db, insert_variant, upsert_campaign, get_variant, get_metrics, get_campaign_value_per_conversion, log_event
10
+ from app.bandit import ThompsonBandit
11
+ from app.forecast import SeasonalityModel
12
+ from app.compliance import rule_based_check, llm_check_and_fix
13
+ from app.openai_client import openai_chat
14
+
15
+ # 初期化
16
+ init_db()
17
+ _seasonality_cache: Dict[str, SeasonalityModel] = {}
18
+
19
+ GEN_SYSTEM = """
20
+ あなたは日本語広告コピーのプロフェッショナルコピーライターです。
21
+ 出力はJSON配列(各要素は{\"headline\":..., \"body\":...})のみで返してください。句読点や記号は自然に。誇大・断定は避け、事実ベースで魅力を伝えます。
22
+ """
23
+
24
+ GEN_USER_TEMPLATE = """
25
+ ブランド: {brand}
26
+ 商品/サービス: {product}
27
+ 想定ターゲット: {target}
28
+ トーン: {tone}
29
+ 制約: {constraints}
30
+ 生成本数: {k}
31
+ 条件:
32
+ - 1本あたり見出し(全角15-25字目安)+ 本文(全角40-90字目安)
33
+ - 禁止: 医薬効能の断定、100%、永久、即効、根拠のない数値
34
+ - CTAは自然に
35
+ """
36
+
37
+ def _seasonal(campaign_id: str) -> SeasonalityModel:
38
+ if campaign_id not in _seasonality_cache:
39
+ m = SeasonalityModel(campaign_id)
40
+ try:
41
+ m.fit()
42
+ except Exception:
43
+ pass
44
+ _seasonality_cache[campaign_id] = m
45
+ return _seasonality_cache[campaign_id]
46
+
47
+ async def ui_generate(campaign_id: str, brand: str, product: str, target: str, tone: str, k_variants: int,
48
+ ng_words: str, value_per_conversion: float):
49
+ constraints = {"ng_words": [w.strip() for w in ng_words.splitlines() if w.strip()]} if ng_words else {}
50
+
51
+ upsert_campaign(
52
+ campaign_id, brand, product, target, tone, "ja", constraints, value_per_conversion
53
+ )
54
+
55
+ user = GEN_USER_TEMPLATE.format(
56
+ brand=brand,
57
+ product=product,
58
+ target=target,
59
+ tone=tone,
60
+ constraints=json.dumps(constraints, ensure_ascii=False),
61
+ k=k_variants,
62
+ )
63
+
64
+ raw = await openai_chat([
65
+ {"role": "system", "content": GEN_SYSTEM},
66
+ {"role": "user", "content": user}
67
+ ], temperature=0.8, max_tokens=800)
68
+
69
+ try:
70
+ items = json.loads(raw)
71
+ assert isinstance(items, list)
72
+ except Exception:
73
+ raise gr.Error("LLM出力のJSONパースに失敗しました。プロンプトを短くするか、再実行してください。")
74
+
75
+ rows = []
76
+ for it in items[:k_variants]:
77
+ headline = (it.get("headline") or "").strip()
78
+ body = (it.get("body") or "").strip()
79
+ text = f"{headline}\n{body}".strip()
80
+ vid = str(uuid.uuid4())[:8]
81
+
82
+ ok_rule, bads = rule_based_check(text, (constraints or {}).get("ng_words"))
83
+ rejection_reason = None
84
+ status = "approved"
85
+
86
+ if not ok_rule:
87
+ ok_llm, reasons, fixed = llm_check_and_fix(text)
88
+ if ok_llm:
89
+ text = fixed or text
90
+ else:
91
+ status = "rejected"
92
+ rejection_reason = "; ".join(bads + reasons)
93
+ else:
94
+ ok_llm, reasons, fixed = llm_check_and_fix(text)
95
+ if not ok_llm:
96
+ text = fixed or text
97
+
98
+ insert_variant(campaign_id, vid, text, status, rejection_reason)
99
+ rows.append({
100
+ "variant_id": vid,
101
+ "status": status,
102
+ "rejection_reason": rejection_reason or "",
103
+ "text": text,
104
+ })
105
+
106
+ df = pd.DataFrame(rows, columns=["variant_id", "status", "rejection_reason", "text"]) if rows else pd.DataFrame()
107
+ return df
108
+
109
+ def ui_serve(campaign_id: str, hour: int, segment: str):
110
+ ctx = {"hour": int(hour), "segment": (segment or "").strip() or None}
111
+ m = _seasonal(campaign_id)
112
+ bandit = ThompsonBandit(campaign_id)
113
+ vid, _ = bandit.sample_arm(ctx, seasonal_fn=m.expected_ctr)
114
+ if not vid:
115
+ raise gr.Error("配信可能なバリアントがありません。まずは Generate してください。")
116
+
117
+ row = get_variant(campaign_id, vid)
118
+ if not row:
119
+ raise gr.Error("バリアントが見つかりません。")
120
+
121
+ # impression 記録
122
+ log_event(campaign_id, vid, "impression", datetime.utcnow().isoformat(), None)
123
+ ThompsonBandit.update_with_event(campaign_id, vid, "impression")
124
+
125
+ return vid, row["text"]
126
+
127
+ def ui_feedback(campaign_id: str, variant_id: str, event_type: str):
128
+ if not variant_id:
129
+ raise gr.Error("先に Serve してください。")
130
+ log_event(campaign_id, variant_id, event_type, datetime.utcnow().isoformat(), None)
131
+ ThompsonBandit.update_with_event(campaign_id, variant_id, event_type)
132
+ return f"{event_type} を記録しました。"
133
+
134
+ def ui_report(campaign_id: str):
135
+ mets = get_metrics(campaign_id)
136
+ vpc = get_campaign_value_per_conversion(campaign_id)
137
+ rows = []
138
+ for r in mets:
139
+ imp = int(r["impressions"]); clk = int(r["clicks"]); conv = int(r["conversions"])
140
+ ctr = (clk / imp) if imp > 0 else 0.0
141
+ cvr = (conv / clk) if clk > 0 else 0.0
142
+ ev = ctr * cvr * vpc
143
+ rows.append({
144
+ "variant_id": r["variant_id"],
145
+ "impressions": imp,
146
+ "clicks": clk,
147
+ "conversions": conv,
148
+ "ctr": round(ctr, 4),
149
+ "cvr": round(cvr, 4),
150
+ "expected_value": round(ev, 6),
151
+ })
152
+ df = pd.DataFrame(rows, columns=["variant_id","impressions","clicks","conversions","ctr","cvr","expected_value"]) if rows else pd.DataFrame()
153
+ return df
154
+
155
+ def ui_check(text: str):
156
+ ok_rule, bads = rule_based_check(text, [])
157
+ ok_llm, reasons, fixed = llm_check_and_fix(text)
158
+ status = "pass" if (ok_rule and ok_llm) else "needs_fix"
159
+ fixed_text = fixed or (text if status == "pass" else "")
160
+ reasons_joined = "; ".join(bads + reasons)
161
+ return status, reasons_joined, fixed_text
162
+
163
+ with gr.Blocks(title="AdCopy MAB Optimizer", fill_height=True) as demo:
164
+ gr.Markdown("""
165
+ # AdCopy MAB Optimizer(HF UI)
166
+ **広告コピー自動生成 → Thompson Sampling(CTR×CVR) → レポート** を、Hugging Face Spaces 上で完結。
167
+ - LLM: OpenAI (`OPENAI_API_KEY` を Space Secrets に設定)
168
+ - DB: SQLite(`data/data.db`)
169
+ - 季節性: Prophet/NeuralProphet(なければ簡易ヒューリスティック)
170
+ """)
171
+
172
+ with gr.Tab("1) Generate"):
173
+ with gr.Row():
174
+ campaign_id = gr.Textbox(label="campaign_id", value="cmp-demo", scale=1)
175
+ k_variants = gr.Slider(1, 10, value=5, step=1, label="生成本数")
176
+ value_per_conv = gr.Number(value=5000, label="value_per_conversion")
177
+ brand = gr.Textbox(label="ブランド", value="SFM")
178
+ product = gr.Textbox(label="商品/サービス", value="HbA1c測定アプリ")
179
+ target = gr.Textbox(label="ターゲット", value="30-50代の健康意識が高い層")
180
+ tone = gr.Textbox(label="トーン", value="エビデンス重視で安心感")
181
+ ng_words = gr.Textbox(label="NGワード(改行区切り)", value="治る\n奇跡")
182
+ btn_gen = gr.Button("広告案を生成&審査&保存")
183
+ table_gen = gr.Dataframe(headers=["variant_id","status","rejection_reason","text"], interactive=False)
184
+ btn_gen.click(ui_generate, [campaign_id, brand, product, target, tone, k_variants, ng_words, value_per_conv], [table_gen])
185
+
186
+ with gr.Tab("2) Serve & Feedback"):
187
+ with gr.Row():
188
+ campaign_id2 = gr.Textbox(label="campaign_id", value="cmp-demo", scale=1)
189
+ hour = gr.Slider(0, 23, value=20, step=1, label="hour")
190
+ segment = gr.Textbox(label="segment (任意)")
191
+ btn_serve = gr.Button("Serve Ad(impressionを記録)")
192
+ served_vid = gr.Textbox(label="served variant_id", interactive=False)
193
+ served_text = gr.Textbox(label="served text", lines=6, interactive=False)
194
+ btn_serve.click(ui_serve, [campaign_id2, hour, segment], [served_vid, served_text])
195
+
196
+ with gr.Row():
197
+ btn_click = gr.Button("Clickを記録")
198
+ btn_conv = gr.Button("Conversionを記録")
199
+ msg = gr.Markdown()
200
+ btn_click.click(lambda cid, vid: ui_feedback(cid, vid, "click"), [campaign_id2, served_vid], [msg])
201
+ btn_conv.click(lambda cid, vid: ui_feedback(cid, vid, "conversion"), [campaign_id2, served_vid], [msg])
202
+
203
+ with gr.Tab("3) Report"):
204
+ campaign_id3 = gr.Textbox(label="campaign_id", value="cmp-demo")
205
+ btn_rep = gr.Button("更新")
206
+ table_rep = gr.Dataframe(headers=["variant_id","impressions","clicks","conversions","ctr","cvr","expected_value"], interactive=False)
207
+ btn_rep.click(ui_report, [campaign_id3], [table_rep])
208
+
209
+ with gr.Tab("4) Compliance Check"):
210
+ cand = gr.Textbox(label="チェックする文面", lines=5)
211
+ btn_chk = gr.Button("判定")
212
+ status = gr.Textbox(label="status")
213
+ reasons = gr.Textbox(label="reasons")
214
+ fixed = gr.Textbox(label="fixed (修正案)", lines=5)
215
+ btn_chk.click(ui_check, [cand], [status, reasons, fixed])
216
+
217
+ if __name__ == "__main__":
218
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860)