File size: 10,034 Bytes
c3fe8e0
 
 
762a2c6
 
 
 
c3fe8e0
762a2c6
c3fe8e0
 
2846c4d
c3fe8e0
2846c4d
 
 
 
c3fe8e0
 
 
 
2846c4d
 
 
13499b7
2846c4d
 
 
 
 
13499b7
 
 
 
9a9d6a7
2846c4d
 
9a9d6a7
 
 
 
 
 
 
 
 
 
 
 
 
 
2846c4d
 
 
 
 
 
 
 
 
9a9d6a7
 
 
 
2846c4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a9d6a7
 
 
 
 
 
 
 
2846c4d
9a9d6a7
2846c4d
13499b7
2846c4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13499b7
 
2846c4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13499b7
 
2846c4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3fe8e0
2846c4d
 
 
 
 
 
 
 
 
 
 
 
 
 
13499b7
2846c4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13499b7
2846c4d
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
from __future__ import annotations

# === Writable config dirs (must come right after future import) ===
# Pick writable data/config directories before any later import that may load
# matplotlib (which reads MPLCONFIGDIR). Values already present in the
# environment always win over these defaults.
import os
os.environ.setdefault(
    "APP_DATA_DIR",
    "/data/app_data" if os.access("/data", os.W_OK) else "/tmp/app_data",
)
_mpl_config_dir = os.environ.setdefault(
    "MPLCONFIGDIR",
    os.path.join(os.environ["APP_DATA_DIR"], "mplconfig"),
)
os.makedirs(_mpl_config_dir, exist_ok=True)
# =================================================================

import json
import logging
import uuid
from datetime import datetime, timezone
from typing import Dict

import gradio as gr
import pandas as pd

from app.storage import (
    init_db, insert_variant, upsert_campaign, get_variant, get_metrics,
    get_campaign_value_per_conversion, log_event
)
from app.bandit import ThompsonBandit
from app.forecast import SeasonalityModel
from app.compliance import rule_based_check, llm_check_and_fix
from app.openai_client import openai_chat_json

# Initialization (runs once, at import time).
# init_db() comes from app.storage; presumably it prepares the SQLite storage
# used by the handlers below — confirm in app/storage.py.
init_db()
# Per-process cache of SeasonalityModel instances keyed by campaign_id, filled
# lazily by _seasonal(). Nothing in this file evicts or refreshes entries.
_seasonality_cache: Dict[str, SeasonalityModel] = {}

# Fixed column layouts. Handlers always return DataFrames with exactly these
# columns (even when there are no rows) so the Gradio tables keep their headers.
GENERATE_COLUMNS = [
    "variant_id",
    "status",
    "rejection_reason",
    "text",
]
REPORT_COLUMNS = [
    "variant_id",
    "impressions",
    "clicks",
    "conversions",
    "ctr",
    "cvr",
    "expected_value",
]

# Strict system prompt for variant generation; written on the assumption that
# the model is called in JSON mode (via openai_chat_json) and must return only
# {"variants": [{"headline": ..., "body": ...}, ...]}.
# NOTE: this string is sent verbatim (never passed through str.format), so the
# literal JSON braces below are safe as written.
GEN_SYSTEM = """
あなたは日本語広告コピーのプロフェッショナルコピーライターです。
出力は**次のJSONオブジェクトのみ**で厳密に返してください。余計な文章・説明・前置きは禁止です。

形式:
{
  "variants": [
    {"headline": "全角15-25字程度", "body": "全角40-90字程度"},
    ...
  ]
}

ルール:
- 医薬効能の断定、100%、永久、即効、根拠のない数値などの誇大表現は禁止
- CTAは自然に
- 日本語で、句読点や記号は自然に
"""

# User prompt template, rendered with
#   GEN_USER_TEMPLATE.format(brand=..., product=..., target=..., tone=...,
#                            constraints=<JSON string>, k=<int>)
# Literal JSON braces MUST be doubled ({{ ... }}): a single "{" is parsed by
# str.format as a replacement field, and '{"headline": ...}' would raise
# KeyError('"headline"') on every call.
GEN_USER_TEMPLATE = """
ブランド: {brand}
商品/サービス: {product}
想定ターゲット: {target}
トーン: {tone}
制約: {constraints}
生成本数: {k}

要件:
- "variants" 配列の要素数は **ちょうど {k}** 件にしてください
- 各要素は {{"headline": "...", "body": "..."}} のみ
"""

def _seasonal(campaign_id: str) -> SeasonalityModel:
    """Return the cached seasonality model for *campaign_id*, fitting it on first use.

    Fitting is best-effort: if ``fit()`` raises, the failure is logged with its
    traceback and the unfitted model is still cached and returned, so serving
    keeps working. (Per the UI description the model presumably falls back to a
    simple heuristic when unfitted — confirm in app/forecast.py.)

    NOTE(review): cached models are never refit as new events arrive; the
    process must be restarted (or the cache cleared) to pick up new data.
    """
    model = _seasonality_cache.get(campaign_id)
    if model is None:
        model = SeasonalityModel(campaign_id)
        try:
            model.fit()
        except Exception:
            # Deliberately non-fatal, but no longer silent.
            logging.getLogger(__name__).warning(
                "Seasonality fit failed for campaign %s; using unfitted model",
                campaign_id,
                exc_info=True,
            )
        _seasonality_cache[campaign_id] = model
    return model

def _extract_variant_items(data) -> list:
    """Return the variant dicts contained in the LLM's JSON response.

    The prompt asks for ``{"variants": [...]}``, but a bare list is accepted
    too. Non-list payloads and non-dict entries are dropped, so a malformed
    response yields fewer (or zero) variants instead of an AttributeError.
    """
    if isinstance(data, list):
        items = data
    elif isinstance(data, dict):
        items = data.get("variants", [])
    else:
        items = []
    if not isinstance(items, list):
        return []
    return [item for item in items if isinstance(item, dict)]


def _review_variant(text: str, ng_words):
    """Run both compliance checks on one candidate copy and decide its fate.

    Returns ``(status, rejection_reason, final_text)``; ``status`` is
    ``"approved"`` or ``"rejected"`` and ``rejection_reason`` is ``None`` when
    approved.

    Policy:
    - rule-based check and LLM check both pass -> approve the text as is;
    - exactly one check fails -> approve the LLM rewrite, but only if it is
      non-empty AND itself passes the rule-based (NG-word) check;
    - otherwise (both fail, or no acceptable rewrite) -> reject.
    """
    ok_rule, bads = rule_based_check(text, ng_words)
    # NOTE(review): llm_check_and_fix is synchronous (called without await), so
    # it blocks the event loop while running inside the async ui_generate.
    ok_llm, reasons, fixed = llm_check_and_fix(text)
    if ok_rule and ok_llm:
        return "approved", None, text

    fixed_text = str(fixed).strip() if fixed else ""
    # Re-check the rewrite instead of trusting it: a copy must never be approved
    # while it still violates the user's NG-word list.
    if (ok_rule or ok_llm) and fixed_text and rule_based_check(fixed_text, ng_words)[0]:
        return "approved", None, fixed_text

    reason = "; ".join([*(bads or []), *(reasons or [])])
    return "rejected", reason or "コンプライアンス審査NG(有効な修正案なし)", text


async def ui_generate(campaign_id: str, brand: str, product: str, target: str, tone: str, k_variants: int,
                      ng_words: str, value_per_conversion: float):
    """Generate ad-copy variants with the LLM, review them, and persist them.

    Args:
        campaign_id: Campaign to upsert and attach the new variants to.
        brand, product, target, tone: Brief forwarded to the generation prompt.
        k_variants: Number of variants to request. Gradio sliders may deliver a
            float, so it is coerced to ``int`` before slicing/formatting.
        ng_words: Newline-separated NG words enforced by the rule-based check.
        value_per_conversion: Stored on the campaign via upsert_campaign
            (read back by ui_report).

    Returns:
        DataFrame with exactly ``GENERATE_COLUMNS`` (possibly with zero rows).
    """
    k = int(k_variants)
    constraints = {"ng_words": [w.strip() for w in ng_words.splitlines() if w.strip()]} if ng_words else {}
    ng_list = constraints.get("ng_words")

    upsert_campaign(
        campaign_id, brand, product, target, tone, "ja", constraints, value_per_conversion
    )

    user = GEN_USER_TEMPLATE.format(
        brand=brand,
        product=product,
        target=target,
        tone=tone,
        constraints=json.dumps(constraints, ensure_ascii=False),
        k=k,
    )

    data = await openai_chat_json(
        [
            {"role": "system", "content": GEN_SYSTEM},
            {"role": "user", "content": user},
        ],
        temperature=0.3,
        max_tokens=1200,
    )

    rows = []
    for item in _extract_variant_items(data)[:k]:
        headline = str(item.get("headline") or "").strip()
        body = str(item.get("body") or "").strip()
        text = f"{headline}\n{body}".strip()
        if not text:
            # Nothing usable in this entry; never store an empty ad.
            continue

        review_status, rejection_reason, text = _review_variant(text, ng_list)
        vid = str(uuid.uuid4())[:8]

        insert_variant(campaign_id, vid, text, review_status, rejection_reason)
        rows.append({
            "variant_id": vid,
            "status": review_status,
            "rejection_reason": rejection_reason or "",
            "text": text,
        })

    # Always return the fixed columns (a DataFrame with headers even when empty).
    return pd.DataFrame(rows, columns=GENERATE_COLUMNS)

def ui_serve(campaign_id: str, hour: int, segment: str):
    """Choose one variant via Thompson sampling, log an impression, and return it.

    Args:
        campaign_id: Campaign whose variants are eligible for serving.
        hour: Hour of day (0-23 from the UI slider) placed in the bandit context.
        segment: Optional audience segment; blank/whitespace becomes ``None``.

    Returns:
        ``(variant_id, text)`` of the served variant.

    Raises:
        gr.Error: If the bandit yields no variant, or the variant row is missing.
    """
    ctx = {"hour": int(hour), "segment": (segment or "").strip() or None}
    model = _seasonal(campaign_id)
    bandit = ThompsonBandit(campaign_id)
    vid, _ = bandit.sample_arm(ctx, seasonal_fn=model.expected_ctr)
    if not vid:
        raise gr.Error("配信可能なバリアントがありません。まずは Generate してください。")

    row = get_variant(campaign_id, vid)
    if not row:
        raise gr.Error("バリアントが見つかりません。")

    # Record the impression. The timestamp is a naive-UTC ISO string, identical
    # in format to the deprecated datetime.utcnow().isoformat(), so previously
    # stored events stay comparable.
    ts = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    log_event(campaign_id, vid, "impression", ts, None)
    ThompsonBandit.update_with_event(campaign_id, vid, "impression")

    return vid, row["text"]

def ui_feedback(campaign_id: str, variant_id: str, event_type: str):
    """Record a feedback event for a served variant and update the bandit.

    Args:
        campaign_id: Campaign the event belongs to.
        variant_id: Variant previously returned by ui_serve; empty means nothing
            has been served yet.
        event_type: Event name passed through to storage/bandit (the UI sends
            "click" or "conversion").

    Returns:
        A confirmation message for the UI.

    Raises:
        gr.Error: If no variant has been served yet.
    """
    if not variant_id:
        raise gr.Error("先に Serve してください。")
    # Naive-UTC ISO timestamp: same string datetime.utcnow().isoformat() produced,
    # without relying on the API deprecated in Python 3.12.
    ts = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    log_event(campaign_id, variant_id, event_type, ts, None)
    ThompsonBandit.update_with_event(campaign_id, variant_id, event_type)
    return f"{event_type} を記録しました。"

def ui_report(campaign_id: str):
    """Build the per-variant performance table for *campaign_id*.

    For each metrics record:
      ctr = clicks / impressions, cvr = conversions / clicks (each 0.0 when the
      denominator is 0), expected_value = ctr * cvr * value_per_conversion.
    Rates are rounded to 4 decimals, expected_value to 6. The result always has
    exactly REPORT_COLUMNS, even when there is no data.
    """
    metrics = get_metrics(campaign_id)
    value_per_conv = get_campaign_value_per_conversion(campaign_id)

    def _safe_ratio(numerator: int, denominator: int) -> float:
        return numerator / denominator if denominator > 0 else 0.0

    table = []
    for rec in metrics:
        impressions = int(rec["impressions"])
        clicks = int(rec["clicks"])
        conversions = int(rec["conversions"])
        ctr = _safe_ratio(clicks, impressions)
        cvr = _safe_ratio(conversions, clicks)
        table.append({
            "variant_id": rec["variant_id"],
            "impressions": impressions,
            "clicks": clicks,
            "conversions": conversions,
            "ctr": round(ctr, 4),
            "cvr": round(cvr, 4),
            "expected_value": round(ctr * cvr * value_per_conv, 6),
        })

    return pd.DataFrame(table, columns=REPORT_COLUMNS)

def ui_check(text: str):
    """Run the standalone compliance check behind the "Compliance Check" tab.

    No campaign NG words apply here (the rule check receives an empty list).

    Returns:
        ``(status, reasons, fixed_text)``:
        - status is "pass" only when both the rule-based and the LLM checks
          succeed, otherwise "needs_fix";
        - reasons joins every finding from both checks with "; ";
        - fixed_text is the LLM suggestion when there is one, else the input on a
          pass, else "".
    """
    rule_ok, rule_hits = rule_based_check(text, [])
    llm_ok, llm_findings, suggestion = llm_check_and_fix(text)

    passed = rule_ok and llm_ok
    verdict = "pass" if passed else "needs_fix"

    if suggestion:
        output_text = suggestion
    elif passed:
        output_text = text
    else:
        output_text = ""

    return verdict, "; ".join(rule_hits + llm_findings), output_text

# ---------------------------------------------------------------------------
# Gradio UI. Four tabs, each wired to one handler defined above:
#   1) Generate          -> ui_generate (LLM generation + compliance review + DB insert)
#   2) Serve & Feedback  -> ui_serve / ui_feedback (bandit serving + event logging)
#   3) Report            -> ui_report
#   4) Compliance Check  -> ui_check
# Components appear on screen in the order they are created inside each `with`.
# Each tab has its own campaign_id textbox; nothing here keeps them in sync.
# ---------------------------------------------------------------------------
with gr.Blocks(title="AdCopy MAB Optimizer", fill_height=True) as demo:
    gr.Markdown("""
    # AdCopy MAB Optimizer(HF UI)
    **広告コピー自動生成 → Thompson Sampling(CTR×CVR) → レポート** を、Hugging Face Spaces 上で完結。
    - LLM: OpenAI (`OPENAI_API_KEY` を Space Secrets に設定)
    - DB: SQLite(`/data/app_data/data.db` など、書き込み可能ディレクトリ)
    - 季節性: Prophet/NeuralProphet(なければ簡易ヒューリスティック)
    """)

    with gr.Tab("1) Generate"):
        with gr.Row():
            campaign_id = gr.Textbox(label="campaign_id", value="cmp-demo", scale=1)
            k_variants = gr.Slider(1, 10, value=5, step=1, label="生成本数")
            value_per_conv = gr.Number(value=5000, label="value_per_conversion")
        brand = gr.Textbox(label="ブランド", value="SFM")
        product = gr.Textbox(label="商品/サービス", value="HbA1c測定アプリ")
        target = gr.Textbox(label="ターゲット", value="30-50代の健康意識が高い層")
        tone = gr.Textbox(label="トーン", value="エビデンス重視で安心感")
        ng_words = gr.Textbox(label="NGワード(改行区切り)", value="治る\n奇跡")
        btn_gen = gr.Button("広告案を生成&審査&保存")
        # Headers match the fixed GENERATE_COLUMNS returned by ui_generate.
        table_gen = gr.Dataframe(headers=GENERATE_COLUMNS, interactive=False)
        btn_gen.click(ui_generate, [campaign_id, brand, product, target, tone, k_variants, ng_words, value_per_conv], [table_gen])

    with gr.Tab("2) Serve & Feedback"):
        with gr.Row():
            campaign_id2 = gr.Textbox(label="campaign_id", value="cmp-demo", scale=1)
            hour = gr.Slider(0, 23, value=20, step=1, label="hour")
            segment = gr.Textbox(label="segment (任意)")
        btn_serve = gr.Button("Serve Ad(impressionを記録)")
        served_vid = gr.Textbox(label="served variant_id", interactive=False)
        served_text = gr.Textbox(label="served text", lines=6, interactive=False)
        btn_serve.click(ui_serve, [campaign_id2, hour, segment], [served_vid, served_text])

        with gr.Row():
            btn_click = gr.Button("Clickを記録")
            btn_conv = gr.Button("Conversionを記録")
            msg = gr.Markdown()
        # Feedback is attributed to whatever variant id is currently shown in
        # `served_vid` (the most recently served ad) for the campaign in campaign_id2.
        btn_click.click(lambda cid, vid: ui_feedback(cid, vid, "click"), [campaign_id2, served_vid], [msg])
        btn_conv.click(lambda cid, vid: ui_feedback(cid, vid, "conversion"), [campaign_id2, served_vid], [msg])

    with gr.Tab("3) Report"):
        campaign_id3 = gr.Textbox(label="campaign_id", value="cmp-demo")
        btn_rep = gr.Button("更新")
        # Headers match the fixed REPORT_COLUMNS returned by ui_report.
        table_rep = gr.Dataframe(headers=REPORT_COLUMNS, interactive=False)
        btn_rep.click(ui_report, [campaign_id3], [table_rep])

    with gr.Tab("4) Compliance Check"):
        cand = gr.Textbox(label="チェックする文面", lines=5)
        btn_chk = gr.Button("判定")
        # NOTE: `status`, `reasons` and `fixed` become module-level names here;
        # they are unrelated to the same-named locals inside the handlers above.
        status = gr.Textbox(label="status")
        reasons = gr.Textbox(label="reasons")
        fixed = gr.Textbox(label="fixed (修正案)", lines=5)
        btn_chk.click(ui_check, [cand], [status, reasons, fixed])

if __name__ == "__main__":
    # Enable Gradio's request queue, then listen on all interfaces at port 7860.
    queued_demo = demo.queue()
    queued_demo.launch(server_name="0.0.0.0", server_port=7860)