Corin1998 committed on
Commit
57af9d9
·
verified ·
1 Parent(s): 12e0e33

Update app/bandit.py

Browse files
Files changed (1) hide show
  1. app/bandit.py +27 -25
app/bandit.py CHANGED
@@ -1,56 +1,58 @@
1
  from __future__ import annotations
2
- import math
3
  import random
4
- from typing import Dict, Any, List, Tuple
5
  from . import storage
6
 
7
  class ThompsonBandit:
8
  """
9
- CTR(クリック)とCVR(コンバージョン)の二段ベータモデル。
10
- 目的関数は E[value] = p_click * p_conv * value_per_conversion
11
- 季節性補正は、意思決定時に小さな事前カウント(virtual counts)として加算。
12
  """
13
-
14
  def __init__(self, campaign_id: str, seasonality_boost: float = 5.0):
15
  self.campaign_id = campaign_id
16
  self.seasonality_boost = seasonality_boost
17
 
18
- def sample_arm(self, context: Dict[str, Any], seasonal_fin) -> Tuple[str, float]:
 
 
 
 
19
  metrics = storage.get_metrics(self.campaign_id)
 
 
 
20
  vpc = storage.get_campaign_value_per_conversion(self.campaign_id)
21
- best_score = -1.0
22
- best_variant = None
23
 
 
24
  s = 0.5
25
- if seasonal_fin:
26
  try:
27
- s = float(seasonal_fin(context))
28
- s = max(0.01, min(0.99,s))
29
  except Exception:
30
  s = 0.5
31
 
 
32
  for row in metrics:
33
- ac = row["alpha_click"]; bc = row["beta_click"]
34
- av = row["alpha_conv"]; bv =row["beta_conv"]
35
 
 
36
  ac_eff = ac + self.seasonality_boost * s
37
- bc_eff = bc + self.seasonality_boost * (1-s)
38
- av_eff = av
39
- bv_eff = bv
40
 
41
- pc = random.betavariate(ac_eff, bc_eff)
42
- pv = random.betavariate(av_eff, bv_eff)
43
  score = pc * pv * vpc
44
  if score > best_score:
45
- best_score = score
46
- best_variant = row["variant_id"]
47
-
48
  return best_variant, best_score
49
-
50
  @staticmethod
51
  def update_with_event(campaign_id: str, variant_id: str, event_type: str):
52
  if event_type == "impression":
53
- storage.update_metric(campaign_id, variant_id, "impressions",1)
54
  storage.update_metric(campaign_id, variant_id, "beta_click", 1)
55
  elif event_type == "click":
56
  storage.update_metric(campaign_id, variant_id, "clicks", 1)
@@ -59,4 +61,4 @@ class ThompsonBandit:
59
  storage.update_metric(campaign_id, variant_id, "conversions", 1)
60
  storage.update_metric(campaign_id, variant_id, "alpha_conv", 1)
61
  else:
62
- raise ValueError("unknown event_type")
 
1
  from __future__ import annotations
 
2
  import random
3
+ from typing import Dict, Any, Tuple, Callable, Optional
4
  from . import storage
5
 
6
  class ThompsonBandit:
7
  """
8
+ CTR(クリック)とCVR(コンバージョン)の二段ベータ。
9
+ 目的関数: E[value] = p_click * p_conv * value_per_conversion
10
+ 季節性は CTR 側の仮想カウントで補正。
11
  """
 
12
  def __init__(self, campaign_id: str, seasonality_boost: float = 5.0):
13
  self.campaign_id = campaign_id
14
  self.seasonality_boost = seasonality_boost
15
 
16
def sample_arm(
    self,
    context: Dict[str, Any],
    seasonal_fn: Optional[Callable[[Dict[str, Any]], float]] = None,
) -> Tuple[Optional[str], float]:
    """Draw one Thompson sample per variant and return the best-scoring one.

    Each variant is scored as ``p_click * p_conv * value_per_conversion``,
    where ``p_click`` and ``p_conv`` are drawn from Beta distributions built
    from the stored ``alpha_*`` / ``beta_*`` counts. Seasonality only touches
    the click side: ``seasonality_boost`` pseudo-counts are split between the
    click alpha and beta according to the seasonal score ``s``.

    Args:
        context: Arbitrary request context, passed unchanged to
            ``seasonal_fn``.
        seasonal_fn: Optional callable returning a seasonal score. The score
            is clamped to [0.01, 0.99]. If the callable raises, or returns
            something that is not a usable number (including NaN), the
            neutral score 0.5 is used instead.

    Returns:
        ``(variant_id, sampled_score)`` for the winning variant, or
        ``(None, -1.0)`` when the campaign has no metric rows or no variant
        produced a comparable score.
    """
    import math  # function-scope import: the module no longer imports math

    metrics = storage.get_metrics(self.campaign_id)
    if not metrics:
        return None, -1.0

    vpc = storage.get_campaign_value_per_conversion(self.campaign_id)

    # Seasonal score s; 0.5 is neutral (boost split evenly between alpha/beta).
    s = 0.5
    if seasonal_fn is not None:
        try:
            raw = float(seasonal_fn(context))
        except Exception:
            # Deliberate best-effort: a failing seasonal model must not stop
            # arm selection, so fall back to the neutral score.
            raw = math.nan
        # NaN would slip through the min/max clamp and come out as 0.99,
        # i.e. the maximum boost; treat it as "no signal" instead.
        if not math.isnan(raw):
            s = max(0.01, min(0.99, raw))

    # Loop-invariant pseudo-counts added to every variant's click prior.
    click_alpha_boost = self.seasonality_boost * s
    click_beta_boost = self.seasonality_boost * (1.0 - s)

    # Start from -inf (not -1.0) so a variant is still selected when all
    # sampled scores are very negative, e.g. a negative value per conversion.
    best_score, best_variant = -math.inf, None
    for row in metrics:
        ac, bc = float(row["alpha_click"]), float(row["beta_click"])
        av, bv = float(row["alpha_conv"]), float(row["beta_conv"])

        # betavariate requires strictly positive parameters; floor them.
        pc = random.betavariate(
            max(1e-6, ac + click_alpha_boost),
            max(1e-6, bc + click_beta_boost),
        )
        pv = random.betavariate(max(1e-6, av), max(1e-6, bv))

        score = pc * pv * vpc
        if score > best_score:
            best_score, best_variant = score, row["variant_id"]

    if best_variant is None:
        # No comparable score (e.g. NaN value per conversion): keep the
        # same "nothing selected" result callers get for an empty campaign.
        return None, -1.0
    return best_variant, best_score
51
+
52
  @staticmethod
53
  def update_with_event(campaign_id: str, variant_id: str, event_type: str):
54
  if event_type == "impression":
55
+ storage.update_metric(campaign_id, variant_id, "impressions", 1)
56
  storage.update_metric(campaign_id, variant_id, "beta_click", 1)
57
  elif event_type == "click":
58
  storage.update_metric(campaign_id, variant_id, "clicks", 1)
 
61
  storage.update_metric(campaign_id, variant_id, "conversions", 1)
62
  storage.update_metric(campaign_id, variant_id, "alpha_conv", 1)
63
  else:
64
+ raise ValueError("unknown event_type")