run pre-commit on all files

This commit is contained in:
dmahan93 2025-05-09 09:54:20 -05:00
parent b959c30ebf
commit 40b12dae60
17 changed files with 169 additions and 118 deletions

View file

@ -32,7 +32,9 @@ class DatasetEnvConfig(BaseEnvConfig):
None, description="Field in dataset containing canonical correct answer"
)
system_prompt: Optional[str] = Field(None, description="System prompt to use")
prefill: Optional[str] = Field(None, description="Text to prefill the completion with (e.g. '<think>')")
prefill: Optional[str] = Field(
None, description="Text to prefill the completion with (e.g. '<think>')"
)
shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset")
max_generations_per_prompt: int = Field(
1, description="Number of generations per prompt for collection"
@ -137,21 +139,21 @@ class DatasetEnv(BaseEnv):
# Extract user prompt and answer from item
user_content = dict(item[0][0])["content"]
answer = item[1] if len(item) > 1 else None
# Create messages list
messages = []
if self.config.system_prompt:
messages.append({"role": "system", "content": self.config.system_prompt})
messages.append({"role": "user", "content": user_content})
# Add prefill as assistant message if configured
if self.config.prefill:
messages.append({"role": "assistant", "content": self.config.prefill})
# Convert messages to a prompt string using the tokenizer
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
# Calculate max tokens for generation (with optional warmup)
max_tokens = self.config.max_tokens
if self.config.length_warmup_steps > 0:
@ -160,7 +162,7 @@ class DatasetEnv(BaseEnv):
self.config.min_tokens
+ warmup_progress * (self.config.max_tokens - self.config.min_tokens)
)
# Generate completion using completions API
completions = await self.server.completion(
prompt=prompt,
@ -169,34 +171,38 @@ class DatasetEnv(BaseEnv):
temperature=self.config.temperature,
top_p=self.config.top_p,
)
to_score = []
to_backlog = []
# Process completions
for completion in completions.choices:
# Get the completion text
completion_text = completion.text if hasattr(completion, "text") else completion.message.content
completion_text = (
completion.text
if hasattr(completion, "text")
else completion.message.content
)
# Build full message sequence for scoring
full_messages = []
if self.config.system_prompt:
full_messages.append({"role": "system", "content": self.config.system_prompt})
full_messages.append(
{"role": "system", "content": self.config.system_prompt}
)
full_messages.append({"role": "user", "content": user_content})
# Combine prefill with completion if prefill was used
response_content = completion_text
if self.config.prefill:
response_content = self.config.prefill + completion_text
full_messages.append({"role": "assistant", "content": response_content})
# Add to scoring list with answer and ground truth
to_score.append(
(full_messages, answer, item[2] if len(item) > 2 else None)
)
to_score.append((full_messages, answer, item[2] if len(item) > 2 else None))
return to_score, to_backlog
async def postprocess_histories(self, trajectories: List) -> Tuple[List, List]:
@ -204,27 +210,27 @@ class DatasetEnv(BaseEnv):
async def collect_trajectories(self, item: Item) -> Tuple[List, List]:
self.current_item = item
# Extract user prompt from item
user_content = dict(item[0][0])["content"]
# Create messages list
messages = []
if self.config.system_prompt:
messages.append({"role": "system", "content": self.config.system_prompt})
messages.append({"role": "user", "content": user_content})
# Add prefill as assistant message if configured
if self.config.prefill:
messages.append({"role": "assistant", "content": self.config.prefill})
# Convert messages to a prompt string using the tokenizer
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
# Calculate max tokens for generation (with optional warmup)
max_tokens = self.config.max_tokens
# Generate completions
completions = await self.server.completion(
prompt=prompt,
@ -233,30 +239,36 @@ class DatasetEnv(BaseEnv):
temperature=self.config.temperature,
top_p=self.config.top_p,
)
print(f"Completions: {completions}")
# Process completions
trajectories = []
for completion in completions.choices:
# Get the completion text
completion_text = completion.text if hasattr(completion, "text") else completion.message.content
completion_text = (
completion.text
if hasattr(completion, "text")
else completion.message.content
)
# Build complete message sequence
full_messages = []
if self.config.system_prompt:
full_messages.append({"role": "system", "content": self.config.system_prompt})
full_messages.append(
{"role": "system", "content": self.config.system_prompt}
)
full_messages.append({"role": "user", "content": user_content})
# Combine prefill with completion if prefill was used
response_content = completion_text
if self.config.prefill:
response_content = self.config.prefill + completion_text
full_messages.append({"role": "assistant", "content": response_content})
trajectories.append(full_messages)
return trajectories, []
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
@ -402,6 +414,7 @@ class DatasetEnv(BaseEnv):
await super().wandb_log(metrics)
# Script entry point: the guard below is true only when this module is
# executed directly, so importing the module does not start the CLI.
if __name__ == "__main__":
# Launch the DatasetEnv via the BaseEnv CLI (serve or process)
# NOTE(review): `cli` is not defined in the visible hunks; presumably it is
# inherited from BaseEnv and parses the command line itself - confirm.
DatasetEnv.cli()