File size: 9,297 Bytes
2846c4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
from __future__ import annotations
import os, json, uuid
from datetime import datetime
from typing import List, Dict, Any

import gradio as gr
import pandas as pd

from app.storage import init_db, insert_variant, upsert_campaign, get_variant, get_metrics, get_campaign_value_per_conversion, log_event
from app.bandit import ThompsonBandit
from app.forecast import SeasonalityModel
from app.compliance import rule_based_check, llm_check_and_fix
from app.openai_client import openai_chat

# Initialization: runs at import time, before any UI callback can fire.
# NOTE(review): presumably creates the storage schema if missing — confirm in app.storage.
init_db()
# Per-process cache of seasonality models keyed by campaign_id (filled by _seasonal()).
_seasonality_cache: Dict[str, SeasonalityModel] = {}

# System prompt sent as the "system" message in ui_generate().
# It is used verbatim (never passed through str.format), so its braces are literal.
GEN_SYSTEM = """
あなたは日本語広告コピーのプロフェッショナルコピーライターです。
出力はJSON配列(各要素は{\"headline\":..., \"body\":...})のみで返してください。句読点や記号は自然に。誇大・断定は避け、事実ベースで魅力を伝えます。
"""

# User prompt template; ui_generate() fills the placeholders via str.format with
# brand, product, target, tone, constraints (a JSON string) and k (variant count).
GEN_USER_TEMPLATE = """
ブランド: {brand}
商品/サービス: {product}
想定ターゲット: {target}
トーン: {tone}
制約: {constraints}
生成本数: {k}
条件:
- 1本あたり見出し(全角15-25字目安)+ 本文(全角40-90字目安)
- 禁止: 医薬効能の断定、100%、永久、即効、根拠のない数値
- CTAは自然に
"""

def _seasonal(campaign_id: str) -> SeasonalityModel:
    """Return the cached SeasonalityModel for *campaign_id*, building it on first use.

    Fitting is best-effort: if ``fit()`` raises, the failure is logged (with
    traceback) and the unfitted model is still cached and returned, so serving
    keeps working. The model is only fitted once per process per campaign.

    Args:
        campaign_id: Campaign whose seasonality model is requested.

    Returns:
        The cached (possibly unfitted) SeasonalityModel instance.
    """
    if campaign_id not in _seasonality_cache:
        # Local import keeps this block self-contained.
        import logging

        model = SeasonalityModel(campaign_id)
        try:
            model.fit()
        except Exception:
            # Deliberately non-fatal, but no longer silent: surface why the
            # fit failed instead of hiding it.
            logging.getLogger(__name__).warning(
                "Seasonality fit failed for campaign %r; using unfitted model",
                campaign_id,
                exc_info=True,
            )
        _seasonality_cache[campaign_id] = model
    return _seasonality_cache[campaign_id]

def _parse_llm_items(raw: Any) -> List[Dict[str, Any]]:
    """Decode the LLM reply into a list of variant dicts.

    Tolerates a reply wrapped in a Markdown code fence (``` or ```json), which
    chat models frequently emit even when asked for bare JSON. Array elements
    that are not JSON objects are dropped.

    Raises:
        TypeError: If *raw* is not something ``json.loads`` accepts.
        ValueError: If the payload is not valid JSON or is not a JSON array.
    """
    payload = raw
    if isinstance(raw, str):
        payload = raw.strip()
        if payload.startswith("```"):
            # Drop the opening fence line, then a trailing closing fence if any.
            _, _, payload = payload.partition("\n")
            payload = payload.rstrip()
            if payload.endswith("```"):
                payload = payload[:-3]
    items = json.loads(payload)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(items, list):
        raise ValueError("expected a JSON array of variants")
    return [item for item in items if isinstance(item, dict)]


async def ui_generate(campaign_id: str, brand: str, product: str, target: str, tone: str, k_variants: int,
                      ng_words: str, value_per_conversion: float) -> pd.DataFrame:
    """Generate ad-copy variants with the LLM, run compliance checks, and store them.

    Args:
        campaign_id: Campaign the variants belong to; the campaign row is upserted first.
        brand: Brand name inserted into the prompt.
        product: Product/service description inserted into the prompt.
        target: Target audience inserted into the prompt.
        tone: Desired tone inserted into the prompt.
        k_variants: Number of variants to request and keep (coerced to int,
            since a UI slider may deliver a float).
        ng_words: Newline-separated banned words; blank lines are ignored.
        value_per_conversion: Stored on the campaign via upsert_campaign.

    Returns:
        A DataFrame with columns variant_id, status, rejection_reason, text;
        an empty DataFrame when no variants were produced.

    Raises:
        gr.Error: If the LLM reply cannot be parsed as a JSON array.
    """
    limit = int(k_variants)
    constraints = {"ng_words": [w.strip() for w in ng_words.splitlines() if w.strip()]} if ng_words else {}

    upsert_campaign(
        campaign_id, brand, product, target, tone, "ja", constraints, value_per_conversion
    )

    user = GEN_USER_TEMPLATE.format(
        brand=brand,
        product=product,
        target=target,
        tone=tone,
        constraints=json.dumps(constraints, ensure_ascii=False),
        k=limit,
    )

    raw = await openai_chat([
        {"role": "system", "content": GEN_SYSTEM},
        {"role": "user", "content": user}
    ], temperature=0.8, max_tokens=800)

    try:
        items = _parse_llm_items(raw)
    except (TypeError, ValueError) as err:
        raise gr.Error("LLM出力のJSONパースに失敗しました。プロンプトを短くするか、再実行してください。") from err

    banned = constraints.get("ng_words")  # loop-invariant
    rows = []
    for it in items[:limit]:
        headline = (it.get("headline") or "").strip()
        body = (it.get("body") or "").strip()
        text = f"{headline}\n{body}".strip()
        # Short id: first 8 characters of a random UUID4 string.
        vid = str(uuid.uuid4())[:8]

        ok_rule, bads = rule_based_check(text, banned)
        rejection_reason = None
        status = "approved"

        ok_llm, reasons, fixed = llm_check_and_fix(text)
        if not ok_rule:
            if ok_llm:
                # NOTE(review): the (possibly rewritten) text is approved without
                # re-running the rule check — confirm this is intended.
                text = fixed or text
            else:
                status = "rejected"
                rejection_reason = "; ".join(bads + reasons)
        elif not ok_llm:
            # NOTE(review): status stays "approved" even when the LLM check
            # fails and offers no fix — confirm this is intended.
            text = fixed or text

        insert_variant(campaign_id, vid, text, status, rejection_reason)
        rows.append({
            "variant_id": vid,
            "status": status,
            "rejection_reason": rejection_reason or "",
            "text": text,
        })

    df = pd.DataFrame(rows, columns=["variant_id", "status", "rejection_reason", "text"]) if rows else pd.DataFrame()
    return df

def ui_serve(campaign_id: str, hour: int, segment: str):
    """Choose a variant for the given context, record an impression, and return it.

    Args:
        campaign_id: Campaign to serve from.
        hour: Hour value placed in the sampling context (coerced with int()).
        segment: Optional segment label; blank or whitespace-only becomes None.

    Returns:
        A ``(variant_id, text)`` tuple for the served variant.

    Raises:
        gr.Error: If the bandit yields no variant, or the variant row is missing.
    """
    hour_value = int(hour)
    segment_value = (segment or "").strip()
    context = {"hour": hour_value, "segment": segment_value if segment_value else None}

    seasonal_model = _seasonal(campaign_id)
    sampler = ThompsonBandit(campaign_id)
    chosen_id, _unused = sampler.sample_arm(context, seasonal_fn=seasonal_model.expected_ctr)
    if not chosen_id:
        raise gr.Error("配信可能なバリアントがありません。まずは Generate してください。")

    variant = get_variant(campaign_id, chosen_id)
    if not variant:
        raise gr.Error("バリアントが見つかりません。")

    # Record the impression in the event log, then update the bandit state.
    served_at = datetime.utcnow().isoformat()
    log_event(campaign_id, chosen_id, "impression", served_at, None)
    ThompsonBandit.update_with_event(campaign_id, chosen_id, "impression")

    return chosen_id, variant["text"]

def ui_feedback(campaign_id: str, variant_id: str, event_type: str):
    """Record a feedback event for a served variant and return a status message.

    Args:
        campaign_id: Campaign the variant belongs to.
        variant_id: Variant that was served; must be non-empty.
        event_type: Event name forwarded to the event log and the bandit update.

    Returns:
        A confirmation message naming the recorded event type.

    Raises:
        gr.Error: If no variant has been served yet (empty *variant_id*).
    """
    if variant_id:
        occurred_at = datetime.utcnow().isoformat()
        log_event(campaign_id, variant_id, event_type, occurred_at, None)
        ThompsonBandit.update_with_event(campaign_id, variant_id, event_type)
        return f"{event_type} を記録しました。"
    raise gr.Error("先に Serve してください。")

def ui_report(campaign_id: str):
    """Build the per-variant performance table for a campaign.

    For each metrics row this derives CTR (clicks / impressions), CVR
    (conversions / clicks) and expected value (CTR * CVR * value per
    conversion); a ratio with a zero denominator is reported as 0.0.

    Args:
        campaign_id: Campaign to report on.

    Returns:
        A DataFrame with one row per variant, or an empty DataFrame when
        there are no metrics.
    """
    metrics = get_metrics(campaign_id)
    value_per_conv = get_campaign_value_per_conversion(campaign_id)
    column_names = ["variant_id","impressions","clicks","conversions","ctr","cvr","expected_value"]

    records = []
    for metric in metrics:
        impressions = int(metric["impressions"])
        clicks = int(metric["clicks"])
        conversions = int(metric["conversions"])

        if impressions > 0:
            click_rate = clicks / impressions
        else:
            click_rate = 0.0
        if clicks > 0:
            conv_rate = conversions / clicks
        else:
            conv_rate = 0.0
        expected = click_rate * conv_rate * value_per_conv

        records.append({
            "variant_id": metric["variant_id"],
            "impressions": impressions,
            "clicks": clicks,
            "conversions": conversions,
            "ctr": round(click_rate, 4),
            "cvr": round(conv_rate, 4),
            "expected_value": round(expected, 6),
        })

    if not records:
        return pd.DataFrame()
    return pd.DataFrame(records, columns=column_names)

def ui_check(text: str):
    """Run the rule-based and LLM compliance checks on *text*.

    Args:
        text: Ad copy to check.

    Returns:
        A ``(status, reasons, fixed_text)`` tuple: status is "pass" only when
        both checks pass, otherwise "needs_fix"; reasons joins the findings of
        both checks with "; "; fixed_text is the LLM's suggested fix if it gave
        one, else the original text when passing, else an empty string.
    """
    rule_ok, rule_hits = rule_based_check(text, [])
    llm_ok, llm_reasons, suggestion = llm_check_and_fix(text)

    if rule_ok and llm_ok:
        status = "pass"
        fallback = text
    else:
        status = "needs_fix"
        fallback = ""

    fixed_text = suggestion if suggestion else fallback
    joined_reasons = "; ".join(rule_hits + llm_reasons)
    return status, joined_reasons, fixed_text

# Gradio UI definition. Components are created in display order inside the
# Blocks context, so the statement order below determines the layout.
with gr.Blocks(title="AdCopy MAB Optimizer", fill_height=True) as demo:
    # Header text rendered at the top of the page (user-facing, kept verbatim).
    gr.Markdown("""
    # AdCopy MAB Optimizer(HF UI)
    **広告コピー自動生成 → Thompson Sampling(CTR×CVR) → レポート** を、Hugging Face Spaces 上で完結。
    - LLM: OpenAI (`OPENAI_API_KEY` を Space Secrets に設定)
    - DB: SQLite(`data/data.db`)
    - 季節性: Prophet/NeuralProphet(なければ簡易ヒューリスティック)
    """)

    # Tab 1: collect campaign inputs and call ui_generate; results go to table_gen.
    with gr.Tab("1) Generate"):
        with gr.Row():
            campaign_id = gr.Textbox(label="campaign_id", value="cmp-demo", scale=1)
            k_variants = gr.Slider(1, 10, value=5, step=1, label="生成本数")
            value_per_conv = gr.Number(value=5000, label="value_per_conversion")
        brand = gr.Textbox(label="ブランド", value="SFM")
        product = gr.Textbox(label="商品/サービス", value="HbA1c測定アプリ")
        target = gr.Textbox(label="ターゲット", value="30-50代の健康意識が高い層")
        tone = gr.Textbox(label="トーン", value="エビデンス重視で安心感")
        # One banned word per line; ui_generate splits on newlines.
        ng_words = gr.Textbox(label="NGワード(改行区切り)", value="治る\n奇跡")
        btn_gen = gr.Button("広告案を生成&審査&保存")
        table_gen = gr.Dataframe(headers=["variant_id","status","rejection_reason","text"], interactive=False)
        # Input order must match ui_generate's positional parameters.
        btn_gen.click(ui_generate, [campaign_id, brand, product, target, tone, k_variants, ng_words, value_per_conv], [table_gen])

    # Tab 2: serve a variant (ui_serve) and record click/conversion feedback (ui_feedback).
    with gr.Tab("2) Serve & Feedback"):
        with gr.Row():
            campaign_id2 = gr.Textbox(label="campaign_id", value="cmp-demo", scale=1)
            hour = gr.Slider(0, 23, value=20, step=1, label="hour")
            segment = gr.Textbox(label="segment (任意)")
        btn_serve = gr.Button("Serve Ad(impressionを記録)")
        # Read-only outputs; served_vid is also the input for the feedback buttons below.
        served_vid = gr.Textbox(label="served variant_id", interactive=False)
        served_text = gr.Textbox(label="served text", lines=6, interactive=False)
        btn_serve.click(ui_serve, [campaign_id2, hour, segment], [served_vid, served_text])

        with gr.Row():
            btn_click = gr.Button("Clickを記録")
            btn_conv = gr.Button("Conversionを記録")
            msg = gr.Markdown()
        # Lambdas bind the fixed event type; campaign and variant ids come from the UI.
        btn_click.click(lambda cid, vid: ui_feedback(cid, vid, "click"), [campaign_id2, served_vid], [msg])
        btn_conv.click(lambda cid, vid: ui_feedback(cid, vid, "conversion"), [campaign_id2, served_vid], [msg])

    # Tab 3: per-variant metrics table produced by ui_report.
    with gr.Tab("3) Report"):
        campaign_id3 = gr.Textbox(label="campaign_id", value="cmp-demo")
        btn_rep = gr.Button("更新")
        table_rep = gr.Dataframe(headers=["variant_id","impressions","clicks","conversions","ctr","cvr","expected_value"], interactive=False)
        btn_rep.click(ui_report, [campaign_id3], [table_rep])

    # Tab 4: ad-hoc compliance check of arbitrary text via ui_check.
    with gr.Tab("4) Compliance Check"):
        cand = gr.Textbox(label="チェックする文面", lines=5)
        btn_chk = gr.Button("判定")
        status = gr.Textbox(label="status")
        reasons = gr.Textbox(label="reasons")
        fixed = gr.Textbox(label="fixed (修正案)", lines=5)
        # Output order matches ui_check's (status, reasons, fixed_text) return tuple.
        btn_chk.click(ui_check, [cand], [status, reasons, fixed])

if __name__ == "__main__":
    # Script entry point: enable the request queue, then serve on all
    # interfaces at port 7860.
    queued_app = demo.queue()
    queued_app.launch(server_name="0.0.0.0", server_port=7860)