Spaces:
Sleeping
Sleeping
Update app/bandit.py
Browse files- app/bandit.py +27 -25
app/bandit.py
CHANGED
|
@@ -1,56 +1,58 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
-
import math
|
| 3 |
import random
|
| 4 |
-
from typing import Dict, Any,
|
| 5 |
from . import storage
|
| 6 |
|
| 7 |
class ThompsonBandit:
|
| 8 |
"""
|
| 9 |
-
CTR(クリック)とCVR
|
| 10 |
-
|
| 11 |
-
|
| 12 |
"""
|
| 13 |
-
|
| 14 |
def __init__(self, campaign_id: str, seasonality_boost: float = 5.0):
|
| 15 |
self.campaign_id = campaign_id
|
| 16 |
self.seasonality_boost = seasonality_boost
|
| 17 |
|
| 18 |
-
def sample_arm(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
metrics = storage.get_metrics(self.campaign_id)
|
|
|
|
|
|
|
|
|
|
| 20 |
vpc = storage.get_campaign_value_per_conversion(self.campaign_id)
|
| 21 |
-
best_score = -1.0
|
| 22 |
-
best_variant = None
|
| 23 |
|
|
|
|
| 24 |
s = 0.5
|
| 25 |
-
if
|
| 26 |
try:
|
| 27 |
-
s = float(
|
| 28 |
-
s = max(0.01, min(0.99,s))
|
| 29 |
except Exception:
|
| 30 |
s = 0.5
|
| 31 |
|
|
|
|
| 32 |
for row in metrics:
|
| 33 |
-
ac = row["alpha_click"]
|
| 34 |
-
av = row["alpha_conv"]
|
| 35 |
|
|
|
|
| 36 |
ac_eff = ac + self.seasonality_boost * s
|
| 37 |
-
bc_eff = bc + self.seasonality_boost * (1-s)
|
| 38 |
-
av_eff = av
|
| 39 |
-
bv_eff = bv
|
| 40 |
|
| 41 |
-
pc = random.betavariate(ac_eff, bc_eff)
|
| 42 |
-
pv = random.betavariate(
|
| 43 |
score = pc * pv * vpc
|
| 44 |
if score > best_score:
|
| 45 |
-
best_score = score
|
| 46 |
-
best_variant = row["variant_id"]
|
| 47 |
-
|
| 48 |
return best_variant, best_score
|
| 49 |
-
|
| 50 |
@staticmethod
|
| 51 |
def update_with_event(campaign_id: str, variant_id: str, event_type: str):
|
| 52 |
if event_type == "impression":
|
| 53 |
-
storage.update_metric(campaign_id, variant_id, "impressions",1)
|
| 54 |
storage.update_metric(campaign_id, variant_id, "beta_click", 1)
|
| 55 |
elif event_type == "click":
|
| 56 |
storage.update_metric(campaign_id, variant_id, "clicks", 1)
|
|
@@ -59,4 +61,4 @@ class ThompsonBandit:
|
|
| 59 |
storage.update_metric(campaign_id, variant_id, "conversions", 1)
|
| 60 |
storage.update_metric(campaign_id, variant_id, "alpha_conv", 1)
|
| 61 |
else:
|
| 62 |
-
raise ValueError("unknown event_type")
|
|
|
|
| 1 |
from __future__ import annotations
|
|
|
|
| 2 |
import random
|
| 3 |
+
from typing import Dict, Any, Tuple, Callable, Optional
|
| 4 |
from . import storage
|
| 5 |
|
| 6 |
class ThompsonBandit:
|
| 7 |
"""
|
| 8 |
+
CTR(クリック)とCVR(コンバージョン)の二段ベータ。
|
| 9 |
+
目的関数: E[value] = p_click * p_conv * value_per_conversion
|
| 10 |
+
季節性は CTR 側の仮想カウントで補正。
|
| 11 |
"""
|
|
|
|
| 12 |
def __init__(self, campaign_id: str, seasonality_boost: float = 5.0):
|
| 13 |
self.campaign_id = campaign_id
|
| 14 |
self.seasonality_boost = seasonality_boost
|
| 15 |
|
| 16 |
+
def sample_arm(
|
| 17 |
+
self,
|
| 18 |
+
context: Dict[str, Any],
|
| 19 |
+
seasonal_fn: Optional[Callable[[Dict[str, Any]], float]] = None, # ← キーワード引数OK
|
| 20 |
+
) -> Tuple[Optional[str], float]:
|
| 21 |
metrics = storage.get_metrics(self.campaign_id)
|
| 22 |
+
if not metrics:
|
| 23 |
+
return None, -1.0
|
| 24 |
+
|
| 25 |
vpc = storage.get_campaign_value_per_conversion(self.campaign_id)
|
|
|
|
|
|
|
| 26 |
|
| 27 |
+
# 季節性スコア s ∈ (0, 1)
|
| 28 |
s = 0.5
|
| 29 |
+
if seasonal_fn is not None:
|
| 30 |
try:
|
| 31 |
+
s = float(seasonal_fn(context))
|
| 32 |
+
s = max(0.01, min(0.99, s))
|
| 33 |
except Exception:
|
| 34 |
s = 0.5
|
| 35 |
|
| 36 |
+
best_score, best_variant = -1.0, None
|
| 37 |
for row in metrics:
|
| 38 |
+
ac, bc = float(row["alpha_click"]), float(row["beta_click"])
|
| 39 |
+
av, bv = float(row["alpha_conv"]), float(row["beta_conv"])
|
| 40 |
|
| 41 |
+
# 季節性でクリック側の事前分布を微調整
|
| 42 |
ac_eff = ac + self.seasonality_boost * s
|
| 43 |
+
bc_eff = bc + self.seasonality_boost * (1.0 - s)
|
|
|
|
|
|
|
| 44 |
|
| 45 |
+
pc = random.betavariate(max(1e-6, ac_eff), max(1e-6, bc_eff))
|
| 46 |
+
pv = random.betavariate(max(1e-6, av), max(1e-6, bv))
|
| 47 |
score = pc * pv * vpc
|
| 48 |
if score > best_score:
|
| 49 |
+
best_score, best_variant = score, row["variant_id"]
|
|
|
|
|
|
|
| 50 |
return best_variant, best_score
|
| 51 |
+
|
| 52 |
@staticmethod
|
| 53 |
def update_with_event(campaign_id: str, variant_id: str, event_type: str):
|
| 54 |
if event_type == "impression":
|
| 55 |
+
storage.update_metric(campaign_id, variant_id, "impressions", 1)
|
| 56 |
storage.update_metric(campaign_id, variant_id, "beta_click", 1)
|
| 57 |
elif event_type == "click":
|
| 58 |
storage.update_metric(campaign_id, variant_id, "clicks", 1)
|
|
|
|
| 61 |
storage.update_metric(campaign_id, variant_id, "conversions", 1)
|
| 62 |
storage.update_metric(campaign_id, variant_id, "alpha_conv", 1)
|
| 63 |
else:
|
| 64 |
+
raise ValueError("unknown event_type")
|