Spaces:
Sleeping
Sleeping
Update ui.py
Browse files
ui.py
CHANGED
|
@@ -24,6 +24,7 @@ from app.forecast import SeasonalityModel
|
|
| 24 |
from app.compliance import rule_based_check, llm_check_and_fix
|
| 25 |
from app.openai_client import openai_chat_json
|
| 26 |
|
|
|
|
| 27 |
# 初期化
|
| 28 |
init_db()
|
| 29 |
_seasonality_cache: Dict[str, SeasonalityModel] = {}
|
|
@@ -64,18 +65,64 @@ GEN_USER_TEMPLATE = """
|
|
| 64 |
- 各要素は {"headline": "...", "body": "..."} のみ
|
| 65 |
"""
|
| 66 |
|
|
|
|
| 67 |
def _seasonal(campaign_id: str) -> SeasonalityModel:
|
| 68 |
if campaign_id not in _seasonality_cache:
|
| 69 |
m = SeasonalityModel(campaign_id)
|
| 70 |
try:
|
| 71 |
m.fit()
|
| 72 |
except Exception:
|
|
|
|
| 73 |
pass
|
| 74 |
_seasonality_cache[campaign_id] = m
|
| 75 |
return _seasonality_cache[campaign_id]
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
async def ui_generate(campaign_id: str, brand: str, product: str, target: str, tone: str, k_variants: int,
|
| 78 |
ng_words: str, value_per_conversion: float):
|
|
|
|
| 79 |
constraints = {"ng_words": [w.strip() for w in ng_words.splitlines() if w.strip()]} if ng_words else {}
|
| 80 |
|
| 81 |
upsert_campaign(
|
|
@@ -91,20 +138,45 @@ async def ui_generate(campaign_id: str, brand: str, product: str, target: str, t
|
|
| 91 |
k=k_variants,
|
| 92 |
)
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
-
items = data.get("variants", data if isinstance(data, list) else [])
|
| 104 |
rows = []
|
| 105 |
-
for it in
|
| 106 |
-
headline =
|
| 107 |
-
body =
|
| 108 |
text = f"{headline}\n{body}".strip()
|
| 109 |
vid = str(uuid.uuid4())[:8]
|
| 110 |
|
|
@@ -136,6 +208,7 @@ async def ui_generate(campaign_id: str, brand: str, product: str, target: str, t
|
|
| 136 |
df = pd.DataFrame(rows, columns=GENERATE_COLUMNS)
|
| 137 |
return df
|
| 138 |
|
|
|
|
| 139 |
def ui_serve(campaign_id: str, hour: int, segment: str):
|
| 140 |
ctx = {"hour": int(hour), "segment": (segment or "").strip() or None}
|
| 141 |
m = _seasonal(campaign_id)
|
|
@@ -154,6 +227,7 @@ def ui_serve(campaign_id: str, hour: int, segment: str):
|
|
| 154 |
|
| 155 |
return vid, row["text"]
|
| 156 |
|
|
|
|
| 157 |
def ui_feedback(campaign_id: str, variant_id: str, event_type: str):
|
| 158 |
if not variant_id:
|
| 159 |
raise gr.Error("先に Serve してください。")
|
|
@@ -161,6 +235,7 @@ def ui_feedback(campaign_id: str, variant_id: str, event_type: str):
|
|
| 161 |
ThompsonBandit.update_with_event(campaign_id, variant_id, event_type)
|
| 162 |
return f"{event_type} を記録しました。"
|
| 163 |
|
|
|
|
| 164 |
def ui_report(campaign_id: str):
|
| 165 |
mets = get_metrics(campaign_id)
|
| 166 |
vpc = get_campaign_value_per_conversion(campaign_id)
|
|
@@ -183,6 +258,7 @@ def ui_report(campaign_id: str):
|
|
| 183 |
df = pd.DataFrame(rows, columns=REPORT_COLUMNS)
|
| 184 |
return df
|
| 185 |
|
|
|
|
| 186 |
def ui_check(text: str):
|
| 187 |
ok_rule, bads = rule_based_check(text, [])
|
| 188 |
ok_llm, reasons, fixed = llm_check_and_fix(text)
|
|
@@ -191,6 +267,7 @@ def ui_check(text: str):
|
|
| 191 |
reasons_joined = "; ".join(bads + reasons)
|
| 192 |
return status, reasons_joined, fixed_text
|
| 193 |
|
|
|
|
| 194 |
with gr.Blocks(title="AdCopy MAB Optimizer", fill_height=True) as demo:
|
| 195 |
gr.Markdown("""
|
| 196 |
# AdCopy MAB Optimizer(HF UI)
|
|
|
|
| 24 |
from app.compliance import rule_based_check, llm_check_and_fix
|
| 25 |
from app.openai_client import openai_chat_json
|
| 26 |
|
| 27 |
+
|
| 28 |
# 初期化
|
| 29 |
init_db()
|
| 30 |
_seasonality_cache: Dict[str, SeasonalityModel] = {}
|
|
|
|
| 65 |
- 各要素は {"headline": "...", "body": "..."} のみ
|
| 66 |
"""
|
| 67 |
|
| 68 |
+
|
| 69 |
def _seasonal(campaign_id: str) -> SeasonalityModel:
|
| 70 |
if campaign_id not in _seasonality_cache:
|
| 71 |
m = SeasonalityModel(campaign_id)
|
| 72 |
try:
|
| 73 |
m.fit()
|
| 74 |
except Exception:
|
| 75 |
+
# 失敗してもフォールバックあり
|
| 76 |
pass
|
| 77 |
_seasonality_cache[campaign_id] = m
|
| 78 |
return _seasonality_cache[campaign_id]
|
| 79 |
|
| 80 |
+
|
| 81 |
+
def _safe_get_variants(data, k: int):
|
| 82 |
+
"""LLM応答から variants 配列を安全に取り出して正規化。失敗時は None を返す。"""
|
| 83 |
+
items = []
|
| 84 |
+
if isinstance(data, dict) and isinstance(data.get("variants"), list):
|
| 85 |
+
items = data["variants"]
|
| 86 |
+
elif isinstance(data, list):
|
| 87 |
+
items = data
|
| 88 |
+
if not items or not all(isinstance(x, dict) for x in items):
|
| 89 |
+
return None
|
| 90 |
+
out = []
|
| 91 |
+
for it in items[:k]:
|
| 92 |
+
out.append({
|
| 93 |
+
"headline": str(it.get("headline", "")).strip(),
|
| 94 |
+
"body": str(it.get("body", "")).strip(),
|
| 95 |
+
})
|
| 96 |
+
return out
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _local_variants(brand: str, product: str, k: int):
|
| 100 |
+
"""LLMが不調でも UI を止めないための最終フォールバック生成(簡易・無害表現)。"""
|
| 101 |
+
base_head = [
|
| 102 |
+
"使いやすさで選ばれています",
|
| 103 |
+
"日々の習慣をシンプルに",
|
| 104 |
+
"はじめてでも安心",
|
| 105 |
+
"続けやすいサポートを",
|
| 106 |
+
"いま必要な機能だけを"
|
| 107 |
+
]
|
| 108 |
+
base_body = [
|
| 109 |
+
"{brand}の「{product}」。生活になじむ設計で、今日からムリなく始められます。まずは詳細をご覧ください。",
|
| 110 |
+
"毎日を少しラクに。{brand}の{product}が、あなたの習慣づくりを後押しします。今すぐチェック。",
|
| 111 |
+
"難しい操作は不要。{brand}の{product}なら、使い始めから自然に続けられます。詳しくはサイトへ。",
|
| 112 |
+
"必要な情報をひと目で。{brand}の{product}で、日々の管理をシンプルに。詳細を見る。",
|
| 113 |
+
"続けやすさを重視。{brand}の{product}で、小さな一歩から。"
|
| 114 |
+
]
|
| 115 |
+
out = []
|
| 116 |
+
for i in range(k):
|
| 117 |
+
hi = base_head[i % len(base_head)]
|
| 118 |
+
bo = base_body[i % len(base_body)].format(brand=brand, product=product)
|
| 119 |
+
out.append({"headline": hi, "body": bo})
|
| 120 |
+
return out
|
| 121 |
+
|
| 122 |
+
|
| 123 |
async def ui_generate(campaign_id: str, brand: str, product: str, target: str, tone: str, k_variants: int,
|
| 124 |
ng_words: str, value_per_conversion: float):
|
| 125 |
+
k_variants = int(k_variants) # Slider の float を明示的に int 化
|
| 126 |
constraints = {"ng_words": [w.strip() for w in ng_words.splitlines() if w.strip()]} if ng_words else {}
|
| 127 |
|
| 128 |
upsert_campaign(
|
|
|
|
| 138 |
k=k_variants,
|
| 139 |
)
|
| 140 |
|
| 141 |
+
# まずは通常プロンプトで JSON モード呼び出し
|
| 142 |
+
items = None
|
| 143 |
+
try:
|
| 144 |
+
data = await openai_chat_json(
|
| 145 |
+
[
|
| 146 |
+
{"role": "system", "content": GEN_SYSTEM},
|
| 147 |
+
{"role": "user", "content": user},
|
| 148 |
+
],
|
| 149 |
+
temperature=0.2,
|
| 150 |
+
max_tokens=1200,
|
| 151 |
+
)
|
| 152 |
+
items = _safe_get_variants(data, k_variants)
|
| 153 |
+
except Exception:
|
| 154 |
+
items = None
|
| 155 |
+
|
| 156 |
+
# 失敗/空のときは、温度をさらに下げて再試行(より厳格に)
|
| 157 |
+
if not items:
|
| 158 |
+
try:
|
| 159 |
+
retry_user = user + "\n\n注意: 'variants' は必ず指定件数、各要素は {\"headline\":\"...\",\"body\":\"...\"} のみ。"
|
| 160 |
+
data = await openai_chat_json(
|
| 161 |
+
[
|
| 162 |
+
{"role": "system", "content": GEN_SYSTEM},
|
| 163 |
+
{"role": "user", "content": retry_user},
|
| 164 |
+
],
|
| 165 |
+
temperature=0.1,
|
| 166 |
+
max_tokens=1000,
|
| 167 |
+
)
|
| 168 |
+
items = _safe_get_variants(data, k_variants)
|
| 169 |
+
except Exception:
|
| 170 |
+
items = None
|
| 171 |
+
|
| 172 |
+
# それでも無理ならローカル生成(UIを止めない)
|
| 173 |
+
if not items:
|
| 174 |
+
items = _local_variants(brand, product, k_variants)
|
| 175 |
|
|
|
|
| 176 |
rows = []
|
| 177 |
+
for it in items[:k_variants]:
|
| 178 |
+
headline = it["headline"]
|
| 179 |
+
body = it["body"]
|
| 180 |
text = f"{headline}\n{body}".strip()
|
| 181 |
vid = str(uuid.uuid4())[:8]
|
| 182 |
|
|
|
|
| 208 |
df = pd.DataFrame(rows, columns=GENERATE_COLUMNS)
|
| 209 |
return df
|
| 210 |
|
| 211 |
+
|
| 212 |
def ui_serve(campaign_id: str, hour: int, segment: str):
|
| 213 |
ctx = {"hour": int(hour), "segment": (segment or "").strip() or None}
|
| 214 |
m = _seasonal(campaign_id)
|
|
|
|
| 227 |
|
| 228 |
return vid, row["text"]
|
| 229 |
|
| 230 |
+
|
| 231 |
def ui_feedback(campaign_id: str, variant_id: str, event_type: str):
|
| 232 |
if not variant_id:
|
| 233 |
raise gr.Error("先に Serve してください。")
|
|
|
|
| 235 |
ThompsonBandit.update_with_event(campaign_id, variant_id, event_type)
|
| 236 |
return f"{event_type} を記録しました。"
|
| 237 |
|
| 238 |
+
|
| 239 |
def ui_report(campaign_id: str):
|
| 240 |
mets = get_metrics(campaign_id)
|
| 241 |
vpc = get_campaign_value_per_conversion(campaign_id)
|
|
|
|
| 258 |
df = pd.DataFrame(rows, columns=REPORT_COLUMNS)
|
| 259 |
return df
|
| 260 |
|
| 261 |
+
|
| 262 |
def ui_check(text: str):
|
| 263 |
ok_rule, bads = rule_based_check(text, [])
|
| 264 |
ok_llm, reasons, fixed = llm_check_and_fix(text)
|
|
|
|
| 267 |
reasons_joined = "; ".join(bads + reasons)
|
| 268 |
return status, reasons_joined, fixed_text
|
| 269 |
|
| 270 |
+
|
| 271 |
with gr.Blocks(title="AdCopy MAB Optimizer", fill_height=True) as demo:
|
| 272 |
gr.Markdown("""
|
| 273 |
# AdCopy MAB Optimizer(HF UI)
|