diff --git a/environments/text_reversal_environment.py b/environments/text_reversal_environment.py index 903d5036..fbe2fd49 100644 --- a/environments/text_reversal_environment.py +++ b/environments/text_reversal_environment.py @@ -17,7 +17,6 @@ from atroposlib.envs.base import ( ) from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer - # Thinking-enabled system prompt thinking_system_prompt = ( "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the " @@ -127,9 +126,9 @@ class TextReversalEnv(BaseEnv): wandb_metrics = {} try: - wandb_metrics["train/percent_correct"] = sum(self.percent_correct_buffer) / max( - 1, len(self.percent_correct_buffer) - ) + wandb_metrics["train/percent_correct"] = sum( + self.percent_correct_buffer + ) / max(1, len(self.percent_correct_buffer)) except ZeroDivisionError: pass @@ -151,7 +150,9 @@ class TextReversalEnv(BaseEnv): - user_content - expected_assistant """ - dataset_name = getattr(self.config, "dataset_name", "PrimeIntellect/Reverse-Text-SFT") + dataset_name = getattr( + self.config, "dataset_name", "PrimeIntellect/Reverse-Text-SFT" + ) eval_dataset_name = getattr(self.config, "eval_dataset_name", None) try: full_dataset = load_dataset(dataset_name, split="train") @@ -229,7 +230,9 @@ class TextReversalEnv(BaseEnv): self.iter = 0 - def _extract_fields(self, row: Dict) -> Tuple[Optional[str], Optional[str], Optional[str]]: + def _extract_fields( + self, row: Dict + ) -> Tuple[Optional[str], Optional[str], Optional[str]]: """Extract (system_text, user_text, expected_text) from a raw dataset row. Expected forms: @@ -315,7 +318,9 @@ class TextReversalEnv(BaseEnv): completion = await self.server.completion( prompt=prompt, n=1, - max_tokens=getattr(self.config, "max_eval_token_length", self.config.max_token_length), + max_tokens=getattr( + self.config, "max_eval_token_length", self.config.max_token_length + ), temperature=0.2, split="eval", ) @@ -335,7 +340,9 @@ class TextReversalEnv(BaseEnv): percent_correct = sum(scores) / len(scores) if scores else 0.0 self.eval_metrics.append(("eval/percent_correct", percent_correct)) - async def collect_trajectories(self, item: Item) -> Tuple[Optional[ScoredDataGroup], List]: + async def collect_trajectories( + self, item: Item + ) -> Tuple[Optional[ScoredDataGroup], List]: # item: (prompt_messages_tuple, expected_text_string) prompt_messages_list = [dict(m) for m in item[0]] @@ -346,7 +353,9 @@ class TextReversalEnv(BaseEnv): completions = await self.server.completion( prompt=prompt, n=self.config.group_size, - max_tokens=getattr(self.config, "max_train_token_length", self.config.max_token_length), + max_tokens=getattr( + self.config, "max_train_token_length", self.config.max_token_length + ), temperature=0.8, ) @@ -415,7 +424,9 @@ class TextReversalEnv(BaseEnv): item = self.train[self.iter % len(self.train)] self.iter += 1 - messages = self._build_messages(item.get("system_content", ""), item.get("user_content", "")) + messages = self._build_messages( + item.get("system_content", ""), item.get("user_content", "") + ) prompt_tuple = tuple(frozenset(m.items()) for m in messages) answer_text = item.get("expected_assistant", "") @@ -433,7 +444,9 @@ class TextReversalEnv(BaseEnv): expected_text = item[1] if item else "" group_rows = [] for i in range(min(num_keep, len(scored_data["tokens"]))): - decoded = self.tokenizer.decode(scored_data["tokens"][i], skip_special_tokens=False) + decoded = self.tokenizer.decode( + scored_data["tokens"][i], skip_special_tokens=False + ) score_val = scored_data["scores"][i] group_rows.append((decoded, score_val, expected_text)) @@ -444,5 +457,3 @@ class TextReversalEnv(BaseEnv): if __name__ == "__main__": TextReversalEnv.cli() - -