diff --git a/environments/README.md b/environments/README.md
index 8399ed49..e6d82935 100644
--- a/environments/README.md
+++ b/environments/README.md
@@ -587,6 +587,12 @@ Environment for training and evaluating exact string reversal with optional thinking
 - Defaults: `α=0.5`, `p=2`, `penalty_min_score=0.2`.
 - Incorrect rollouts remain at 0.0. If no valid think block (or thinking disabled), penalty is skipped for that rollout.
 
+**Curriculum: One-Epoch + Hard Retries (optional):**
+- Controlled by `curriculum_one_epoch_enabled` (default: True).
+- First pass (one epoch): each item is attempted once. If any rollout in the group is correct (≥1/N), the item is considered solved and never revisited. If the group has zero correct rollouts (0/N), the item is marked "hard" and placed into a retry pool.
+- Retry phase: begins only after the first pass over all training items completes. Items in the retry pool are revisited up to `hard_retry_max_attempts` times (default: 3). Items still unsolved after that are dropped, and training completes naturally once the retry pool is exhausted.
+- Tip: use a large `total_steps`. The environment stops serving items once the first-pass and retry queues are exhausted (it signals completion by raising in `get_next_item`).
+
 **Configuration Options (`TextReversalEnvConfig`):**
 - `use_thinking` (bool, default: False): include thinking system prompt.
 - `dataset_name` (str, default: `PrimeIntellect/Reverse-Text-SFT`): training dataset.
@@ -599,6 +605,8 @@
 - `penalty_alpha` (float, default: 0.5): penalty scale.
 - `penalty_power` (float, default: 2.0): penalty exponent (quadratic by default).
 - `penalty_min_score` (float, default: 0.2): lower bound for penalized correct rollouts.
+- `curriculum_one_epoch_enabled` (bool, default: True): enables one-pass training plus a late retry phase for hard items.
+- `hard_retry_max_attempts` (int, default: 3): maximum retry attempts per hard item in the retry phase.
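The README section above describes a two-phase scheduler. Before the implementation diff, here is a minimal, self-contained sketch of that policy; everything in it (`run_curriculum`, `group_is_correct`, the queue names) is illustrative shorthand under our own naming, not the environment's actual API:

```python
from collections import deque

def run_curriculum(items, group_is_correct, hard_retry_max_attempts=3):
    """Sketch of one-epoch + hard-retries scheduling.

    `group_is_correct(item)` stands in for running a rollout group and
    returning True when at least 1 of N rollouts is correct (hypothetical).
    """
    retry_pool = deque()
    # First pass: every item is served exactly once.
    for item in items:
        if not group_is_correct(item):  # 0/N correct
            retry_pool.append(item)     # mark "hard", revisit later
    # Retry phase: revisit hard items until solved or out of attempts.
    attempts = {id(item): 0 for item in retry_pool}
    while retry_pool:
        item = retry_pool.popleft()
        attempts[id(item)] += 1         # one retry attempt consumed
        if group_is_correct(item):
            continue                    # solved on retry: never requeued
        if attempts[id(item)] < hard_retry_max_attempts:
            retry_pool.append(item)     # still hard: try again later
    # Both queues drained: training completes naturally.
```

Note that nothing is served twice during the first pass; only 0/N groups re-enter circulation, and each of those is bounded by the retry budget.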
 
 **Usage Examples:**
 ```bash
@@ -625,6 +633,11 @@ python text_reversal_environment.py serve \
   --env.penalty_alpha=0.6 \
   --env.penalty_power=2.0 \
   --env.penalty_min_score=0.3
+
+# Enable one-epoch + retries curriculum and set max retries
+python text_reversal_environment.py serve \
+  --env.curriculum_one_epoch_enabled=True \
+  --env.hard_retry_max_attempts=3
 ```
 
 **Evaluation Metric:**
diff --git a/environments/text_reversal_environment.py b/environments/text_reversal_environment.py
index 5e11e7a9..5b7961fb 100644
--- a/environments/text_reversal_environment.py
+++ b/environments/text_reversal_environment.py
@@ -57,6 +57,10 @@ class TextReversalEnvConfig(BaseEnvConfig):
     penalty_power: float = 2.0
     penalty_min_score: float = 0.2
 
+    # Curriculum: single-epoch + hard-item retries
+    curriculum_one_epoch_enabled: bool = True
+    hard_retry_max_attempts: int = 3
+
 
 class TextReversalEnv(BaseEnv):
     env_config_cls = TextReversalEnvConfig
@@ -77,6 +81,14 @@ class TextReversalEnv(BaseEnv):
         self.train = None
         self.test = None
         self.iter = 0
+        # Curriculum state
+        self.first_pass_queue: List[Dict] = []
+        self.retry_pool_ids: set = set()
+        self.retry_queue: List[Dict] = []
+        self.retry_attempt_counts: Dict[str, int] = {}
+        self.in_retry_phase: bool = False
+        self.training_completed: bool = False
+        self._prompt_to_raw: Dict[Tuple[frozenset, ...], Dict] = {}
 
     @classmethod
     def config_init(
@@ -111,6 +123,8 @@
             penalty_alpha=0.5,
             penalty_power=2.0,
             penalty_min_score=0.2,
+            curriculum_one_epoch_enabled=True,
+            hard_retry_max_attempts=3,
         )
 
         server_configs = [
@@ -243,6 +257,14 @@
         self.train = processed_items[test_size:]
         self.iter = 0
 
+        # Initialize curriculum queues
+        self.first_pass_queue = list(self.train)
+        self.retry_pool_ids = set()
+        self.retry_queue = []
+        self.retry_attempt_counts = {}
+        self.in_retry_phase = False
+        self.training_completed = False
+        self._prompt_to_raw = {}
 
     def _extract_fields(
         self, row: Dict
@@ -394,9 +416,54 @@
         trajectory_messages.append({"role": "assistant", "content": choice.text})
         to_score.append((tuple(trajectory_messages), item[1]))
 
+        # Determine correctness-only (pre-penalty) for curriculum handling
+        any_correct = False
+        expected_text_for_group = item[1]
+        for trajectory_messages, _ in to_score:
+            model_response_text = trajectory_messages[-1]["content"]
+            model_answer_text = self._strip_think_and_trailing(model_response_text)
+            if (model_answer_text or "").strip() == (expected_text_for_group or "").strip():
+                any_correct = True
+                break
+
         scored = await self.score(to_score)
+
+        # Update curriculum after group outcome
+        if getattr(self.config, "curriculum_one_epoch_enabled", True):
+            await self._update_curriculum_after_group(item, any_correct)
+
         return scored, []
 
+    async def _update_curriculum_after_group(self, item: Item, any_correct: bool):
+        """Update curriculum state for one-epoch + retries.
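One design detail worth noting in the hunks above: curriculum state is keyed by `f"{hash(str(raw_item))}"`. Python salts `hash()` for strings per interpreter run (PEP 456), so these IDs are stable within a single training process, which is all this scheduler needs, but they would not survive a restart. If resumable IDs were ever required, a content digest is the usual substitute; a sketch with a hypothetical helper that is not part of this patch:

```python
import hashlib
from typing import Dict

def stable_item_id(raw_item: Dict) -> str:
    # Deterministic across processes, unlike built-in hash(), which is
    # randomized per run for strings. Sorting the keys makes the digest
    # independent of dict insertion order.
    canonical = repr(sorted(raw_item.items()))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
```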
+
+        First pass:
+        - If any_correct: mark solved (do nothing further)
+        - If none correct: add to retry pool for later
+        Retry phase:
+        - If any_correct: solved (do not requeue)
+        - Else: requeue until attempts reach max, then drop
+        """
+        prompt_tuple = item[0]
+        raw_item = self._prompt_to_raw.get(prompt_tuple)
+        if raw_item is None:
+            return
+        item_id = f"{hash(str(raw_item))}"
+
+        if not self.in_retry_phase:
+            if any_correct:
+                return
+            if item_id not in self.retry_pool_ids:
+                self.retry_pool_ids.add(item_id)
+            return
+        else:
+            if any_correct:
+                return
+            current_attempts = self.retry_attempt_counts.get(item_id, 0)
+            max_attempts = int(getattr(self.config, "hard_retry_max_attempts", 3))
+            if current_attempts < max_attempts:
+                self.retry_queue.append(raw_item)
+
     async def score(
         self, rollout_group_data: List[Tuple[Tuple[Dict, ...], str]]
     ) -> Optional[ScoredDataGroup]:
@@ -492,16 +559,56 @@
         if not self.train:
             raise RuntimeError("Training data not initialized")
 
-        item = self.train[self.iter % len(self.train)]
-        self.iter += 1
+        if getattr(self.config, "curriculum_one_epoch_enabled", True):
+            if self.training_completed:
+                raise RuntimeError("Training completed: no more items to process")
 
-        messages = self._build_messages(
-            item.get("system_content", ""), item.get("user_content", "")
-        )
+            selected_item: Optional[Dict] = None
 
-        prompt_tuple = tuple(frozenset(m.items()) for m in messages)
-        answer_text = item.get("expected_assistant", "")
-        return (prompt_tuple, answer_text)
+            if not self.in_retry_phase:
+                if len(self.first_pass_queue) > 0:
+                    selected_item = self.first_pass_queue.pop(0)
+                else:
+                    # enter retry phase
+                    self.in_retry_phase = True
+                    # Build retry queue preserving original order
+                    self.retry_queue = [ri for ri in self.train if f"{hash(str(ri))}" in self.retry_pool_ids]
+                    self.retry_attempt_counts = {f"{hash(str(ri))}": 0 for ri in self.retry_queue}
+
+            if self.in_retry_phase:
+                while selected_item is None:
+                    if len(self.retry_queue) == 0:
+                        self.training_completed = True
+                        raise RuntimeError("Training completed: retry pool exhausted")
+                    candidate = self.retry_queue.pop(0)
+                    cand_id = f"{hash(str(candidate))}"
+                    attempts = self.retry_attempt_counts.get(cand_id, 0)
+                    max_attempts = int(getattr(self.config, "hard_retry_max_attempts", 3))
+                    if attempts >= max_attempts:
+                        continue
+                    self.retry_attempt_counts[cand_id] = attempts + 1
+                    selected_item = candidate
+
+            item = selected_item
+            messages = self._build_messages(
+                item.get("system_content", ""), item.get("user_content", "")
+            )
+            prompt_tuple = tuple(frozenset(m.items()) for m in messages)
+            answer_text = item.get("expected_assistant", "")
+            # Map prompt to raw for curriculum updates
+            self._prompt_to_raw[prompt_tuple] = item
+            self.iter += 1
+            return (prompt_tuple, answer_text)
+        else:
+            item = self.train[self.iter % len(self.train)]
+            self.iter += 1
+            messages = self._build_messages(
+                item.get("system_content", ""), item.get("user_content", "")
+            )
+            prompt_tuple = tuple(frozenset(m.items()) for m in messages)
+            answer_text = item.get("expected_assistant", "")
+            self._prompt_to_raw[prompt_tuple] = item
+            return (prompt_tuple, answer_text)
 
     async def add_rollouts_for_wandb(
         self,
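The `get_next_item` rewrite above turns queue exhaustion into an exception rather than a wrap-around, which is what the README's tip about a large `total_steps` refers to. A rough sketch of the consumption pattern a trainer loop would follow; `env` and `run_group` here are assumptions for illustration, not part of the patch:

```python
async def drain(env, run_group) -> int:
    """Serve items until the curriculum raises to signal completion."""
    served = 0
    while True:
        try:
            item = await env.get_next_item()
        except RuntimeError:
            break  # first-pass and retry queues are both exhausted
        await run_group(item)  # rollouts + scoring; updates curriculum state
        served += 1
    return served
```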