gbyuvd committed
Commit 306fa47 · verified · 1 Parent(s): cd018ff

Update RL utils and train-sa using new KL and Beta computation+capping

Files changed (3):
  1. __init__.py +13 -13
  2. rl_utils.py +60 -24
  3. train_ppokl_withsa.py +36 -8
__init__.py CHANGED
@@ -1,14 +1,14 @@
-# __init__.py
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
-from .configuration_chemq3mtp import ChemQ3MTPConfig
-from .modeling_chemq3mtp import ChemQ3MTPForCausalLM
-from .FastChemTokenizerHF import FastChemTokenizerSelfies
-
-# Register the model
-AutoConfig.register("chemq3_mtp", ChemQ3MTPConfig)
-AutoModelForCausalLM.register(ChemQ3MTPConfig, ChemQ3MTPForCausalLM)
-
-# Register the tokenizer
-AutoTokenizer.register(ChemQ3MTPConfig, FastChemTokenizerSelfies)
-
+# __init__.py
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from .configuration_chemq3mtp import ChemQ3MTPConfig
+from .modeling_chemq3mtp import ChemQ3MTPForCausalLM
+from .FastChemTokenizerHF import FastChemTokenizerSelfies
+
+# Register the model
+AutoConfig.register("chemq3_mtp", ChemQ3MTPConfig)
+AutoModelForCausalLM.register(ChemQ3MTPConfig, ChemQ3MTPForCausalLM)
+
+# Register the tokenizer
+AutoTokenizer.register(ChemQ3MTPConfig, FastChemTokenizerSelfies)
+
 __all__ = ["ChemQ3MTPConfig", "ChemQ3MTPForCausalLM", "FastChemTokenizerSelfies"]
rl_utils.py CHANGED
@@ -290,31 +290,67 @@ def selfies_to_lipinski_reward(selfies_str: str) -> float:
 # ========================
 
 class AdaptiveKLController:
-    def __init__(self, init_kl_coef: float = 0.1, target_kl: float = 0.01,
-                 kl_horizon: int = 200, increase_rate: float = 2.0,
-                 decrease_rate: float = 0.7):
-        self.kl_coef = float(init_kl_coef)
-        self.target_kl = float(target_kl)
-        self.kl_horizon = int(kl_horizon)
-        self.inc = float(increase_rate)
-        self.dec = float(decrease_rate)
-        self.buffer: List[float] = []
-
-    def update(self, kl: float) -> float:
-        self.buffer.append(float(kl))
-        if len(self.buffer) >= self.kl_horizon:
-            avg_kl = sum(self.buffer) / len(self.buffer)
-            self.buffer.clear()
-            if avg_kl > self.target_kl * 1.5:
-                self.kl_coef *= self.inc
-                print(f"KL too high ({avg_kl:.6f}), increasing β to {self.kl_coef:.6f}")
-            elif avg_kl < self.target_kl * 0.5:
-                self.kl_coef *= self.dec
-                print(f"KL too low ({avg_kl:.6f}), decreasing β to {self.kl_coef:.6f}")
-        return self.kl_coef
-
-    def reset(self):
-        self.buffer.clear()
+    """
+    Adaptive KL controller with hard clipping and EMA smoothing.
+    Prevents runaway beta values and exploding KL penalties.
+    """
+
+    def __init__(
+        self,
+        init_kl_coef: float = 0.2,
+        target_kl: float = 6.0,
+        horizon: int = 10000,
+        max_kl_coef: float = 10.0,
+        max_inc_factor: float = 2.0,
+        ema_alpha: float = 0.9,
+        kl_penalty_cap: float = 10.0,
+    ):
+        self.value = init_kl_coef
+        self.target = target_kl
+        self.horizon = horizon
+        self.max_kl_coef = max_kl_coef
+        self.max_inc_factor = max_inc_factor
+        self.ema_alpha = ema_alpha
+        self.kl_penalty_cap = kl_penalty_cap
+
+        # Exponential moving average of KL
+        self.ema_kl = None
+
+    def update(self, current_kl: float, n_steps: int) -> None:
+        # update EMA
+        if self.ema_kl is None:
+            self.ema_kl = current_kl
+        else:
+            self.ema_kl = (
+                self.ema_alpha * self.ema_kl + (1 - self.ema_alpha) * current_kl
+            )
+
+        proportional_error = np.clip(
+            (self.ema_kl - self.target) / self.target, -1.0, 1.0
+        )
+        mult = 1.0 + proportional_error * (n_steps / self.horizon)
+
+        # cap growth
+        if mult > self.max_inc_factor:
+            mult = self.max_inc_factor
+
+        # update beta
+        new_val = self.value * mult
+        self.value = min(new_val, self.max_kl_coef)
+
+    def __call__(self) -> float:
+        return self.value
+
+
+def compute_kl_penalty(kl_vals: torch.Tensor, kl_coef: float, kl_penalty_cap: float):
+    """
+    Compute KL penalty with clipping.
+    Returns (clipped_penalty, raw_penalty, kl_mean).
+    """
+    kl_mean = kl_vals.mean()
+    raw_penalty = kl_coef * kl_mean
+    clipped_penalty = torch.clamp(raw_penalty, max=kl_penalty_cap)
+    return clipped_penalty, raw_penalty, kl_mean
 
 
 class EnhancedEntropyController:
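
For context, the following is a minimal standalone sketch of how the updated controller and compute_kl_penalty are intended to be used together. The KL values, step counts, and printout are illustrative only and are not taken from the training script; defaults mirror the class definition above.

import torch
from ChemQ3MTP.rl_utils import AdaptiveKLController, compute_kl_penalty

controller = AdaptiveKLController(init_kl_coef=0.2, target_kl=6.0, horizon=10000)

for step, kl_value in enumerate([2.0, 8.5, 12.0, 7.3]):  # made-up per-step KL values
    # EMA-smoothed KL drives a bounded multiplicative update of beta.
    controller.update(kl_value, n_steps=step)
    beta = controller()  # current coefficient, never above max_kl_coef

    # Per-example KL tensor; the clipped penalty never exceeds kl_penalty_cap.
    kl_vals = torch.full((4,), kl_value)
    clipped, raw, kl_mean = compute_kl_penalty(kl_vals, beta, controller.kl_penalty_cap)
    print(f"step={step} beta={beta:.4f} raw={raw.item():.4f} clipped={clipped.item():.4f}")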
train_ppokl_withsa.py CHANGED
@@ -12,7 +12,7 @@ import numpy as np
 from tqdm import tqdm
 from FastChemTokenizerHF import FastChemTokenizerSelfies
 from ChemQ3MTP import ChemQ3MTPForCausalLM
-from ChemQ3MTP.rl_utils import CurriculumManager, AdaptiveKLController, batch_compute_rewards, compute_ppo_loss, compute_kl_divergence, compute_entropy_bonus
+from ChemQ3MTP.rl_utils import CurriculumManager, AdaptiveKLController, batch_compute_rewards, compute_ppo_loss, compute_kl_divergence, compute_entropy_bonus, compute_kl_penalty
 
 def main():
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -30,8 +30,15 @@ def main():
     print("\n🎯 Phase 2: RL Fine-tuning with PPO + Curriculum Learning")
     model.set_mtp_training(False)
 
-    # Initialize KL controller
-    kl_controller = AdaptiveKLController(init_kl_coef=0.1, target_kl=0.01, kl_horizon=100)
+    # Initialize KL controller - using the parameter names from the updated class definition
+    kl_controller = AdaptiveKLController(
+        init_kl_coef=0.1,
+        target_kl=0.01,
+        horizon=100,          # <-- use horizon instead of kl_horizon
+        max_kl_coef=100.0,    # optional
+        ema_alpha=0.9,        # optional
+        kl_penalty_cap=10.0   # optional
+    )
     model.kl_controller = kl_controller  # Set on model for consistency
 
     optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
@@ -45,13 +52,14 @@ def main():
     input_ids = dummy_input.input_ids.to(device)
 
     # Training config
-    total_steps = 10000
+    total_steps = 2500
     checkpoint_steps = {total_steps // 4, total_steps // 2, 3 * total_steps // 4, total_steps}
     checkpoint_dir = "./ppo_checkpoints_test"
     os.makedirs(checkpoint_dir, exist_ok=True)
 
     # --- RL Training Loop with tqdm ---
     for step in tqdm(range(total_steps), desc="RL Training"):
+        global_step = step  # Define global_step for KL controller
         max_new_tokens = curriculum.get_max_new_tokens()
 
         # === PPO Rollout ===
@@ -83,7 +91,7 @@ def main():
         # === Compute rewards using rl_utils ===
         rewards_dict = batch_compute_rewards(
            selfies_list=selfies_list,
-            reward_mode="sa",  # SA-only mode
+            reward_mode="chemq3",  # Bioaware-only mode
         )
         rewards = rewards_dict["total_rewards"].to(device)
 
@@ -96,10 +104,28 @@ def main():
            baseline=baseline
         )
 
+        # === Compute KL divergence and update controller ===
+        # Compute KL divergence per batch
         # === Compute KL divergence and update controller ===
         kl_div = compute_kl_divergence(old_action_probs, new_action_probs)
-        beta = kl_controller.update(kl_div.mean().item())
-        kl_penalty = beta * kl_div.mean()
+        kl_mean = kl_div.mean().item()
+
+        # Update KL controller using EMA-smoothed KL
+        kl_controller.update(kl_mean, n_steps=global_step)
+        beta = kl_controller()  # get current coefficient
+
+        # Compute clipped KL penalty
+        kl_penalty, raw_kl_penalty, kl_mean_tensor = compute_kl_penalty(
+            kl_div, beta, kl_controller.kl_penalty_cap
+        )
+
+        # --- Logging (safe, interpretable values) ---
+        logs = {}
+        logs["kl_mean"] = kl_mean_tensor.item()
+        logs["kl_beta"] = beta
+        logs["kl_penalty_raw"] = raw_kl_penalty.item()
+        logs["kl_penalty_clipped"] = kl_penalty.item()
+
 
         # === Compute entropy bonus with adaptive weighting ===
         entropy_per_example = compute_entropy_bonus(new_action_probs)
@@ -184,7 +210,9 @@ def main():
            f"Lipinski={lipinski_score:.3f} | "
            f"Reward={rewards.mean().item():.3f} | "
            f"Entropy={entropy.item():.3f} | "
-            f"EntropyW={adaptive_entropy_weight:.4f}"
+            f"EntropyW={adaptive_entropy_weight:.4f} | "
+            f"KL_Beta={beta:.4f} | "
+            f"KL_Mean={kl_mean:.4f}"
         )
         if avg_sa_reward is not None:
            log_line += f" | SA={avg_sa_reward:.3f}"
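
For reference, the adaptive update wired into the loop above can be restated compactly. The helper below is a paraphrase of AdaptiveKLController.update() from the rl_utils.py hunk, not an additional function in the codebase; its default arguments mirror the class.

import numpy as np

def next_beta(beta, ema_kl, target_kl, n_steps, horizon,
              max_inc_factor=2.0, max_kl_coef=10.0):
    # Bounded proportional error on the EMA-smoothed KL ...
    error = np.clip((ema_kl - target_kl) / target_kl, -1.0, 1.0)
    # ... a growth factor capped at max_inc_factor ...
    mult = min(1.0 + error * (n_steps / horizon), max_inc_factor)
    # ... and a hard ceiling on beta itself.
    return min(beta * mult, max_kl_coef)

Combined with compute_kl_penalty, the clipped penalty computed each step is min(beta * mean(KL), kl_penalty_cap), which is what the KL_Beta and KL_Mean fields in the log line expose.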