mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
bcdc51a6fc
commit
21605d100c
1 changed files with 24 additions and 13 deletions
|
|
@ -17,7 +17,6 @@ from atroposlib.envs.base import (
|
|||
)
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
|
||||
# Thinking-enabled system prompt
|
||||
thinking_system_prompt = (
|
||||
"You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
|
||||
|
|
@ -127,9 +126,9 @@ class TextReversalEnv(BaseEnv):
|
|||
wandb_metrics = {}
|
||||
|
||||
try:
|
||||
wandb_metrics["train/percent_correct"] = sum(self.percent_correct_buffer) / max(
|
||||
1, len(self.percent_correct_buffer)
|
||||
)
|
||||
wandb_metrics["train/percent_correct"] = sum(
|
||||
self.percent_correct_buffer
|
||||
) / max(1, len(self.percent_correct_buffer))
|
||||
except ZeroDivisionError:
|
||||
pass
|
||||
|
||||
|
|
@ -151,7 +150,9 @@ class TextReversalEnv(BaseEnv):
|
|||
- user_content
|
||||
- expected_assistant
|
||||
"""
|
||||
dataset_name = getattr(self.config, "dataset_name", "PrimeIntellect/Reverse-Text-SFT")
|
||||
dataset_name = getattr(
|
||||
self.config, "dataset_name", "PrimeIntellect/Reverse-Text-SFT"
|
||||
)
|
||||
eval_dataset_name = getattr(self.config, "eval_dataset_name", None)
|
||||
try:
|
||||
full_dataset = load_dataset(dataset_name, split="train")
|
||||
|
|
@ -229,7 +230,9 @@ class TextReversalEnv(BaseEnv):
|
|||
|
||||
self.iter = 0
|
||||
|
||||
def _extract_fields(self, row: Dict) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
def _extract_fields(
|
||||
self, row: Dict
|
||||
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
"""Extract (system_text, user_text, expected_text) from a raw dataset row.
|
||||
|
||||
Expected forms:
|
||||
|
|
@ -315,7 +318,9 @@ class TextReversalEnv(BaseEnv):
|
|||
completion = await self.server.completion(
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
max_tokens=getattr(self.config, "max_eval_token_length", self.config.max_token_length),
|
||||
max_tokens=getattr(
|
||||
self.config, "max_eval_token_length", self.config.max_token_length
|
||||
),
|
||||
temperature=0.2,
|
||||
split="eval",
|
||||
)
|
||||
|
|
@ -335,7 +340,9 @@ class TextReversalEnv(BaseEnv):
|
|||
percent_correct = sum(scores) / len(scores) if scores else 0.0
|
||||
self.eval_metrics.append(("eval/percent_correct", percent_correct))
|
||||
|
||||
async def collect_trajectories(self, item: Item) -> Tuple[Optional[ScoredDataGroup], List]:
|
||||
async def collect_trajectories(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[ScoredDataGroup], List]:
|
||||
# item: (prompt_messages_tuple, expected_text_string)
|
||||
prompt_messages_list = [dict(m) for m in item[0]]
|
||||
|
||||
|
|
@ -346,7 +353,9 @@ class TextReversalEnv(BaseEnv):
|
|||
completions = await self.server.completion(
|
||||
prompt=prompt,
|
||||
n=self.config.group_size,
|
||||
max_tokens=getattr(self.config, "max_train_token_length", self.config.max_token_length),
|
||||
max_tokens=getattr(
|
||||
self.config, "max_train_token_length", self.config.max_token_length
|
||||
),
|
||||
temperature=0.8,
|
||||
)
|
||||
|
||||
|
|
@ -415,7 +424,9 @@ class TextReversalEnv(BaseEnv):
|
|||
item = self.train[self.iter % len(self.train)]
|
||||
self.iter += 1
|
||||
|
||||
messages = self._build_messages(item.get("system_content", ""), item.get("user_content", ""))
|
||||
messages = self._build_messages(
|
||||
item.get("system_content", ""), item.get("user_content", "")
|
||||
)
|
||||
|
||||
prompt_tuple = tuple(frozenset(m.items()) for m in messages)
|
||||
answer_text = item.get("expected_assistant", "")
|
||||
|
|
@ -433,7 +444,9 @@ class TextReversalEnv(BaseEnv):
|
|||
expected_text = item[1] if item else ""
|
||||
group_rows = []
|
||||
for i in range(min(num_keep, len(scored_data["tokens"]))):
|
||||
decoded = self.tokenizer.decode(scored_data["tokens"][i], skip_special_tokens=False)
|
||||
decoded = self.tokenizer.decode(
|
||||
scored_data["tokens"][i], skip_special_tokens=False
|
||||
)
|
||||
score_val = scored_data["scores"][i]
|
||||
group_rows.append((decoded, score_val, expected_text))
|
||||
|
||||
|
|
@ -444,5 +457,3 @@ class TextReversalEnv(BaseEnv):
|
|||
|
||||
if __name__ == "__main__":
|
||||
TextReversalEnv.cli()
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue