Corin1998 commited on
Commit
d5d67fa
·
verified ·
1 Parent(s): 13499b7

Update ui.py

Browse files
Files changed (1) hide show
  1. ui.py +89 -12
ui.py CHANGED
@@ -24,6 +24,7 @@ from app.forecast import SeasonalityModel
24
  from app.compliance import rule_based_check, llm_check_and_fix
25
  from app.openai_client import openai_chat_json
26
 
 
27
  # 初期化
28
  init_db()
29
  _seasonality_cache: Dict[str, SeasonalityModel] = {}
@@ -64,18 +65,64 @@ GEN_USER_TEMPLATE = """
64
  - 各要素は {"headline": "...", "body": "..."} のみ
65
  """
66
 
 
67
  def _seasonal(campaign_id: str) -> SeasonalityModel:
68
  if campaign_id not in _seasonality_cache:
69
  m = SeasonalityModel(campaign_id)
70
  try:
71
  m.fit()
72
  except Exception:
 
73
  pass
74
  _seasonality_cache[campaign_id] = m
75
  return _seasonality_cache[campaign_id]
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  async def ui_generate(campaign_id: str, brand: str, product: str, target: str, tone: str, k_variants: int,
78
  ng_words: str, value_per_conversion: float):
 
79
  constraints = {"ng_words": [w.strip() for w in ng_words.splitlines() if w.strip()]} if ng_words else {}
80
 
81
  upsert_campaign(
@@ -91,20 +138,45 @@ async def ui_generate(campaign_id: str, brand: str, product: str, target: str, t
91
  k=k_variants,
92
  )
93
 
94
- data = await openai_chat_json(
95
- [
96
- {"role": "system", "content": GEN_SYSTEM},
97
- {"role": "user", "content": user},
98
- ],
99
- temperature=0.3,
100
- max_tokens=1200,
101
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
- items = data.get("variants", data if isinstance(data, list) else [])
104
  rows = []
105
- for it in (items[:k_variants] if isinstance(items, list) else []):
106
- headline = (it.get("headline") or "").strip()
107
- body = (it.get("body") or "").strip()
108
  text = f"{headline}\n{body}".strip()
109
  vid = str(uuid.uuid4())[:8]
110
 
@@ -136,6 +208,7 @@ async def ui_generate(campaign_id: str, brand: str, product: str, target: str, t
136
  df = pd.DataFrame(rows, columns=GENERATE_COLUMNS)
137
  return df
138
 
 
139
  def ui_serve(campaign_id: str, hour: int, segment: str):
140
  ctx = {"hour": int(hour), "segment": (segment or "").strip() or None}
141
  m = _seasonal(campaign_id)
@@ -154,6 +227,7 @@ def ui_serve(campaign_id: str, hour: int, segment: str):
154
 
155
  return vid, row["text"]
156
 
 
157
  def ui_feedback(campaign_id: str, variant_id: str, event_type: str):
158
  if not variant_id:
159
  raise gr.Error("先に Serve してください。")
@@ -161,6 +235,7 @@ def ui_feedback(campaign_id: str, variant_id: str, event_type: str):
161
  ThompsonBandit.update_with_event(campaign_id, variant_id, event_type)
162
  return f"{event_type} を記録しました。"
163
 
 
164
  def ui_report(campaign_id: str):
165
  mets = get_metrics(campaign_id)
166
  vpc = get_campaign_value_per_conversion(campaign_id)
@@ -183,6 +258,7 @@ def ui_report(campaign_id: str):
183
  df = pd.DataFrame(rows, columns=REPORT_COLUMNS)
184
  return df
185
 
 
186
  def ui_check(text: str):
187
  ok_rule, bads = rule_based_check(text, [])
188
  ok_llm, reasons, fixed = llm_check_and_fix(text)
@@ -191,6 +267,7 @@ def ui_check(text: str):
191
  reasons_joined = "; ".join(bads + reasons)
192
  return status, reasons_joined, fixed_text
193
 
 
194
  with gr.Blocks(title="AdCopy MAB Optimizer", fill_height=True) as demo:
195
  gr.Markdown("""
196
  # AdCopy MAB Optimizer(HF UI)
 
24
  from app.compliance import rule_based_check, llm_check_and_fix
25
  from app.openai_client import openai_chat_json
26
 
27
+
28
  # 初期化
29
  init_db()
30
  _seasonality_cache: Dict[str, SeasonalityModel] = {}
 
65
  - 各要素は {"headline": "...", "body": "..."} のみ
66
  """
67
 
68
+
69
  def _seasonal(campaign_id: str) -> SeasonalityModel:
70
  if campaign_id not in _seasonality_cache:
71
  m = SeasonalityModel(campaign_id)
72
  try:
73
  m.fit()
74
  except Exception:
75
+ # 失敗してもフォールバックあり
76
  pass
77
  _seasonality_cache[campaign_id] = m
78
  return _seasonality_cache[campaign_id]
79
 
80
+
81
+ def _safe_get_variants(data, k: int):
82
+ """LLM応答から variants 配列を安全に取り出して正規化。失敗時は None を返す。"""
83
+ items = []
84
+ if isinstance(data, dict) and isinstance(data.get("variants"), list):
85
+ items = data["variants"]
86
+ elif isinstance(data, list):
87
+ items = data
88
+ if not items or not all(isinstance(x, dict) for x in items):
89
+ return None
90
+ out = []
91
+ for it in items[:k]:
92
+ out.append({
93
+ "headline": str(it.get("headline", "")).strip(),
94
+ "body": str(it.get("body", "")).strip(),
95
+ })
96
+ return out
97
+
98
+
99
+ def _local_variants(brand: str, product: str, k: int):
100
+ """LLMが不調でも UI を止めないための最終フォールバック生成(簡易・無害表現)。"""
101
+ base_head = [
102
+ "使いやすさで選ばれています",
103
+ "日々の習慣をシンプルに",
104
+ "はじめてでも安心",
105
+ "続けやすいサポートを",
106
+ "いま必要な機能だけを"
107
+ ]
108
+ base_body = [
109
+ "{brand}の「{product}」。生活になじむ設計で、今日からムリなく始められます。まずは詳細をご覧ください。",
110
+ "毎日を少しラクに。{brand}の{product}が、あなたの習慣づくりを後押しします。今すぐチェック。",
111
+ "難しい操作は不要。{brand}の{product}なら、使い始めから自然に続けられます。詳しくはサイトへ。",
112
+ "必要な情報をひと目で。{brand}の{product}で、日々の管理をシンプルに。詳細を見る。",
113
+ "続けやすさを重視。{brand}の{product}で、小さな一歩から。"
114
+ ]
115
+ out = []
116
+ for i in range(k):
117
+ hi = base_head[i % len(base_head)]
118
+ bo = base_body[i % len(base_body)].format(brand=brand, product=product)
119
+ out.append({"headline": hi, "body": bo})
120
+ return out
121
+
122
+
123
  async def ui_generate(campaign_id: str, brand: str, product: str, target: str, tone: str, k_variants: int,
124
  ng_words: str, value_per_conversion: float):
125
+ k_variants = int(k_variants) # Slider の float を明示的に int 化
126
  constraints = {"ng_words": [w.strip() for w in ng_words.splitlines() if w.strip()]} if ng_words else {}
127
 
128
  upsert_campaign(
 
138
  k=k_variants,
139
  )
140
 
141
+ # まずは通常プロンプトで JSON モード呼び出し
142
+ items = None
143
+ try:
144
+ data = await openai_chat_json(
145
+ [
146
+ {"role": "system", "content": GEN_SYSTEM},
147
+ {"role": "user", "content": user},
148
+ ],
149
+ temperature=0.2,
150
+ max_tokens=1200,
151
+ )
152
+ items = _safe_get_variants(data, k_variants)
153
+ except Exception:
154
+ items = None
155
+
156
+ # 失敗/空のときは、温度をさらに下げて再試行(より厳格に)
157
+ if not items:
158
+ try:
159
+ retry_user = user + "\n\n注意: 'variants' は必ず指定件数、各要素は {\"headline\":\"...\",\"body\":\"...\"} のみ。"
160
+ data = await openai_chat_json(
161
+ [
162
+ {"role": "system", "content": GEN_SYSTEM},
163
+ {"role": "user", "content": retry_user},
164
+ ],
165
+ temperature=0.1,
166
+ max_tokens=1000,
167
+ )
168
+ items = _safe_get_variants(data, k_variants)
169
+ except Exception:
170
+ items = None
171
+
172
+ # それでも無理ならローカル生成(UIを止めない)
173
+ if not items:
174
+ items = _local_variants(brand, product, k_variants)
175
 
 
176
  rows = []
177
+ for it in items[:k_variants]:
178
+ headline = it["headline"]
179
+ body = it["body"]
180
  text = f"{headline}\n{body}".strip()
181
  vid = str(uuid.uuid4())[:8]
182
 
 
208
  df = pd.DataFrame(rows, columns=GENERATE_COLUMNS)
209
  return df
210
 
211
+
212
  def ui_serve(campaign_id: str, hour: int, segment: str):
213
  ctx = {"hour": int(hour), "segment": (segment or "").strip() or None}
214
  m = _seasonal(campaign_id)
 
227
 
228
  return vid, row["text"]
229
 
230
+
231
  def ui_feedback(campaign_id: str, variant_id: str, event_type: str):
232
  if not variant_id:
233
  raise gr.Error("先に Serve してください。")
 
235
  ThompsonBandit.update_with_event(campaign_id, variant_id, event_type)
236
  return f"{event_type} を記録しました。"
237
 
238
+
239
  def ui_report(campaign_id: str):
240
  mets = get_metrics(campaign_id)
241
  vpc = get_campaign_value_per_conversion(campaign_id)
 
258
  df = pd.DataFrame(rows, columns=REPORT_COLUMNS)
259
  return df
260
 
261
+
262
  def ui_check(text: str):
263
  ok_rule, bads = rule_based_check(text, [])
264
  ok_llm, reasons, fixed = llm_check_and_fix(text)
 
267
  reasons_joined = "; ".join(bads + reasons)
268
  return status, reasons_joined, fixed_text
269
 
270
+
271
  with gr.Blocks(title="AdCopy MAB Optimizer", fill_height=True) as demo:
272
  gr.Markdown("""
273
  # AdCopy MAB Optimizer(HF UI)