[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-08-12 20:51:01 +00:00
parent bcdc51a6fc
commit 21605d100c

View file

@ -17,7 +17,6 @@ from atroposlib.envs.base import (
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
# Thinking-enabled system prompt
thinking_system_prompt = (
"You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
@ -127,9 +126,9 @@ class TextReversalEnv(BaseEnv):
wandb_metrics = {}
try:
wandb_metrics["train/percent_correct"] = sum(self.percent_correct_buffer) / max(
1, len(self.percent_correct_buffer)
)
wandb_metrics["train/percent_correct"] = sum(
self.percent_correct_buffer
) / max(1, len(self.percent_correct_buffer))
except ZeroDivisionError:
pass
@ -151,7 +150,9 @@ class TextReversalEnv(BaseEnv):
- user_content
- expected_assistant
"""
dataset_name = getattr(self.config, "dataset_name", "PrimeIntellect/Reverse-Text-SFT")
dataset_name = getattr(
self.config, "dataset_name", "PrimeIntellect/Reverse-Text-SFT"
)
eval_dataset_name = getattr(self.config, "eval_dataset_name", None)
try:
full_dataset = load_dataset(dataset_name, split="train")
@ -229,7 +230,9 @@ class TextReversalEnv(BaseEnv):
self.iter = 0
def _extract_fields(self, row: Dict) -> Tuple[Optional[str], Optional[str], Optional[str]]:
def _extract_fields(
self, row: Dict
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""Extract (system_text, user_text, expected_text) from a raw dataset row.
Expected forms:
@ -315,7 +318,9 @@ class TextReversalEnv(BaseEnv):
completion = await self.server.completion(
prompt=prompt,
n=1,
max_tokens=getattr(self.config, "max_eval_token_length", self.config.max_token_length),
max_tokens=getattr(
self.config, "max_eval_token_length", self.config.max_token_length
),
temperature=0.2,
split="eval",
)
@ -335,7 +340,9 @@ class TextReversalEnv(BaseEnv):
percent_correct = sum(scores) / len(scores) if scores else 0.0
self.eval_metrics.append(("eval/percent_correct", percent_correct))
async def collect_trajectories(self, item: Item) -> Tuple[Optional[ScoredDataGroup], List]:
async def collect_trajectories(
self, item: Item
) -> Tuple[Optional[ScoredDataGroup], List]:
# item: (prompt_messages_tuple, expected_text_string)
prompt_messages_list = [dict(m) for m in item[0]]
@ -346,7 +353,9 @@ class TextReversalEnv(BaseEnv):
completions = await self.server.completion(
prompt=prompt,
n=self.config.group_size,
max_tokens=getattr(self.config, "max_train_token_length", self.config.max_token_length),
max_tokens=getattr(
self.config, "max_train_token_length", self.config.max_token_length
),
temperature=0.8,
)
@ -415,7 +424,9 @@ class TextReversalEnv(BaseEnv):
item = self.train[self.iter % len(self.train)]
self.iter += 1
messages = self._build_messages(item.get("system_content", ""), item.get("user_content", ""))
messages = self._build_messages(
item.get("system_content", ""), item.get("user_content", "")
)
prompt_tuple = tuple(frozenset(m.items()) for m in messages)
answer_text = item.get("expected_assistant", "")
@ -433,7 +444,9 @@ class TextReversalEnv(BaseEnv):
expected_text = item[1] if item else ""
group_rows = []
for i in range(min(num_keep, len(scored_data["tokens"]))):
decoded = self.tokenizer.decode(scored_data["tokens"][i], skip_special_tokens=False)
decoded = self.tokenizer.decode(
scored_data["tokens"][i], skip_special_tokens=False
)
score_val = scored_data["scores"][i]
group_rows.append((decoded, score_val, expected_text))
@ -444,5 +457,3 @@ class TextReversalEnv(BaseEnv):
if __name__ == "__main__":
TextReversalEnv.cli()