mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00

run pre-commit on all files

parent b959c30ebf
commit 40b12dae60

17 changed files with 169 additions and 118 deletions
@@ -32,7 +32,9 @@ class DatasetEnvConfig(BaseEnvConfig):
         None, description="Field in dataset containing canonical correct answer"
     )
     system_prompt: Optional[str] = Field(None, description="System prompt to use")
-    prefill: Optional[str] = Field(None, description="Text to prefill the completion with (e.g. '<think>')")
+    prefill: Optional[str] = Field(
+        None, description="Text to prefill the completion with (e.g. '<think>')"
+    )
     shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset")
     max_generations_per_prompt: int = Field(
         1, description="Number of generations per prompt for collection"
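The pattern in this hunk repeats across the whole commit: the pre-commit formatter (black-style, judging by the 88-column wraps) explodes any call over 88 characters into a parenthesized multi-line form, and collapses multi-line calls that fit back onto one line. A toy illustration, assuming the pydantic-style Field seen in the diff; ExampleConfig and example_field are hypothetical, not from this file:

from typing import Optional

from pydantic import BaseModel, Field


class ExampleConfig(BaseModel):
    # A single-line declaration over 88 characters gets wrapped by the
    # formatter into exactly this parenthesized shape.
    example_field: Optional[str] = Field(
        None, description="An illustrative description long enough to force a wrap"
    )


print(ExampleConfig().example_field)  # -> None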
@@ -137,21 +139,21 @@ class DatasetEnv(BaseEnv):
         # Extract user prompt and answer from item
         user_content = dict(item[0][0])["content"]
         answer = item[1] if len(item) > 1 else None

         # Create messages list
         messages = []
         if self.config.system_prompt:
             messages.append({"role": "system", "content": self.config.system_prompt})

         messages.append({"role": "user", "content": user_content})

         # Add prefill as assistant message if configured
         if self.config.prefill:
             messages.append({"role": "assistant", "content": self.config.prefill})

         # Convert messages to a prompt string using the tokenizer
         prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)

         # Calculate max tokens for generation (with optional warmup)
         max_tokens = self.config.max_tokens
         if self.config.length_warmup_steps > 0:
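A note on the unpacking at the top of this hunk: dict(item[0][0]) implies the prompt turn is stored in some hashable, dict-convertible form (e.g. a frozenset of key/value pairs), with the answer and an optional ground truth trailing behind it. The reconstruction below is guessed from the indexing, not a documented contract:

# Hypothetical Item layout: (prompt_turns, answer, ground_truth?),
# where each turn is a frozenset of (key, value) pairs so it stays hashable.
turn = frozenset({("role", "user"), ("content", "What is 17 * 24?")})
item = ((turn,), "408")

user_content = dict(item[0][0])["content"]
answer = item[1] if len(item) > 1 else None

assert user_content == "What is 17 * 24?"
assert answer == "408"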
@@ -160,7 +162,7 @@ class DatasetEnv(BaseEnv):
             self.config.min_tokens
             + warmup_progress * (self.config.max_tokens - self.config.min_tokens)
         )

         # Generate completion using completions API
         completions = await self.server.completion(
             prompt=prompt,
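The truncated hunk shows the tail of a linear length warmup: while length_warmup_steps is active, the token budget is interpolated between min_tokens and max_tokens. A standalone sketch of that arithmetic; the helper function is mine, only the config field names come from the diff:

def warmup_max_tokens(
    step: int, warmup_steps: int, min_tokens: int, max_tokens: int
) -> int:
    """Linearly ramp the generation budget from min_tokens to max_tokens."""
    if warmup_steps <= 0:
        return max_tokens
    warmup_progress = min(step / warmup_steps, 1.0)
    return int(min_tokens + warmup_progress * (max_tokens - min_tokens))


# Halfway through a 100-step warmup from 256 to 1024 tokens:
assert warmup_max_tokens(50, 100, 256, 1024) == 640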
@@ -169,34 +171,38 @@ class DatasetEnv(BaseEnv):
             temperature=self.config.temperature,
             top_p=self.config.top_p,
         )

         to_score = []
         to_backlog = []

         # Process completions
         for completion in completions.choices:
             # Get the completion text
-            completion_text = completion.text if hasattr(completion, "text") else completion.message.content
+            completion_text = (
+                completion.text
+                if hasattr(completion, "text")
+                else completion.message.content
+            )

             # Build full message sequence for scoring
             full_messages = []
             if self.config.system_prompt:
-                full_messages.append({"role": "system", "content": self.config.system_prompt})
+                full_messages.append(
+                    {"role": "system", "content": self.config.system_prompt}
+                )

             full_messages.append({"role": "user", "content": user_content})

             # Combine prefill with completion if prefill was used
             response_content = completion_text
             if self.config.prefill:
                 response_content = self.config.prefill + completion_text

             full_messages.append({"role": "assistant", "content": response_content})

             # Add to scoring list with answer and ground truth
-            to_score.append(
-                (full_messages, answer, item[2] if len(item) > 2 else None)
-            )
+            to_score.append((full_messages, answer, item[2] if len(item) > 2 else None))

         return to_score, to_backlog

     async def postprocess_histories(self, trajectories: List) -> Tuple[List, List]:
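The hasattr(completion, "text") branch that the formatter rewraps here guards against two response shapes: completions-API choices expose .text, while chat-API choices nest the text under .message.content. A self-contained sketch with stand-in types (not the real client classes):

from dataclasses import dataclass


@dataclass
class _Message:
    content: str


@dataclass
class _ChatChoice:
    message: _Message


@dataclass
class _TextChoice:
    text: str


def extract_text(choice) -> str:
    # Same expression as in the diff, minus the formatting churn.
    return choice.text if hasattr(choice, "text") else choice.message.content


assert extract_text(_TextChoice(text="408")) == "408"
assert extract_text(_ChatChoice(message=_Message(content="408"))) == "408"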
@@ -204,27 +210,27 @@ class DatasetEnv(BaseEnv):

     async def collect_trajectories(self, item: Item) -> Tuple[List, List]:
         self.current_item = item

         # Extract user prompt from item
         user_content = dict(item[0][0])["content"]

         # Create messages list
         messages = []
         if self.config.system_prompt:
             messages.append({"role": "system", "content": self.config.system_prompt})

         messages.append({"role": "user", "content": user_content})

         # Add prefill as assistant message if configured
         if self.config.prefill:
             messages.append({"role": "assistant", "content": self.config.prefill})

         # Convert messages to a prompt string using the tokenizer
         prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)

         # Calculate max tokens for generation (with optional warmup)
         max_tokens = self.config.max_tokens

         # Generate completions
         completions = await self.server.completion(
             prompt=prompt,
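One subtlety in the prompt construction repeated in this hunk: the prefill is appended as a trailing assistant message before apply_chat_template(..., tokenize=False). With many chat templates the rendered string will close that assistant turn unless told otherwise; recent transformers versions expose continue_final_message=True for exactly this case. A hedged sketch (the model name is illustrative, and whether this environment relies on that flag is not visible in the diff):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

messages = [
    {"role": "system", "content": "You are a careful solver."},
    {"role": "user", "content": "What is 17 * 24?"},
    {"role": "assistant", "content": "<think>"},  # prefill, as in the config
]

# continue_final_message=True leaves the assistant turn open so the model
# keeps generating after "<think>" instead of starting a fresh turn.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, continue_final_message=True
)
print(prompt)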
@@ -233,30 +239,36 @@ class DatasetEnv(BaseEnv):
             temperature=self.config.temperature,
             top_p=self.config.top_p,
         )

         print(f"Completions: {completions}")
         # Process completions
         trajectories = []
         for completion in completions.choices:
             # Get the completion text
-            completion_text = completion.text if hasattr(completion, "text") else completion.message.content
+            completion_text = (
+                completion.text
+                if hasattr(completion, "text")
+                else completion.message.content
+            )

             # Build complete message sequence
             full_messages = []
             if self.config.system_prompt:
-                full_messages.append({"role": "system", "content": self.config.system_prompt})
+                full_messages.append(
+                    {"role": "system", "content": self.config.system_prompt}
+                )

             full_messages.append({"role": "user", "content": user_content})

             # Combine prefill with completion if prefill was used
             response_content = completion_text
             if self.config.prefill:
                 response_content = self.config.prefill + completion_text

             full_messages.append({"role": "assistant", "content": response_content})

             trajectories.append(full_messages)

         return trajectories, []

     async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
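Worth noting in both processing loops: because the prefill was baked into the prompt, the returned completion does not contain it, so the code re-attaches it before storing the assistant message. A trivial sketch of that recombination:

prefill = "<think>"  # e.g. from config.prefill
completion_text = " 17 * 24 = 408 </think> The answer is 408."

response_content = completion_text
if prefill:
    response_content = prefill + completion_text

assert response_content == "<think> 17 * 24 = 408 </think> The answer is 408."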
@@ -402,6 +414,7 @@ class DatasetEnv(BaseEnv):

         await super().wandb_log(metrics)

+
 if __name__ == "__main__":
     # Launch the DatasetEnv via the BaseEnv CLI (serve or process)
     DatasetEnv.cli()
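Finally, the entry point defers to the BaseEnv command-line launcher; the serve and process subcommands come from the comment in the diff, while the script path below is a placeholder, not something this page shows:

# Hypothetical invocation (path assumed, subcommands per the comment above):
#   python dataset_env.py serve      # run the environment as a rollout server
#   python dataset_env.py process    # run offline collection/processing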