mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
refactor:final production-ready audit; remove debug artifacts and non-ASCII characters
This commit is contained in:
parent
287b7e7250
commit
8cd30c3703
3 changed files with 14 additions and 16 deletions
|
|
@ -9,6 +9,7 @@ Environment pattern follows sql_query_env for consistency.
|
|||
"""
|
||||
|
||||
import random
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
|
||||
|
||||
from code_executor import (
|
||||
|
|
@ -26,6 +27,8 @@ from atroposlib.envs.base import (
|
|||
)
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# System prompt following established Atropos patterns
|
||||
SYSTEM_PROMPT = (
|
||||
"You are a deep thinking AI, you may use extremely long chains of thought "
|
||||
|
|
@ -165,7 +168,7 @@ class CodeDebugEnv(BaseEnv):
|
|||
"""Load the HumanEvalPack dataset (HumanEvalFix) and prepare train/test splits."""
|
||||
from datasets import load_dataset
|
||||
|
||||
print("Loading HumanEvalPack (python) dataset...")
|
||||
logger.info("Loading HumanEvalPack (python) dataset...")
|
||||
dataset = load_dataset("bigcode/humanevalpack", "python", split="test")
|
||||
|
||||
all_items: List[CodeDebugItem] = []
|
||||
|
|
@ -181,7 +184,7 @@ class CodeDebugEnv(BaseEnv):
|
|||
}
|
||||
)
|
||||
|
||||
print(f"Loaded {len(all_items)} problems")
|
||||
logger.info("Loaded %d problems", len(all_items))
|
||||
|
||||
# Verify a few items actually work with canonical solutions
|
||||
verified = 0
|
||||
|
|
@ -192,7 +195,7 @@ class CodeDebugEnv(BaseEnv):
|
|||
)
|
||||
if passed:
|
||||
verified += 1
|
||||
print(f"Verified {verified}/10 canonical solutions execute correctly")
|
||||
logger.info("Verified %d/10 canonical solutions execute correctly", verified)
|
||||
|
||||
# Split 80/20 train/test
|
||||
random.shuffle(all_items)
|
||||
|
|
@ -200,7 +203,7 @@ class CodeDebugEnv(BaseEnv):
|
|||
self.train = all_items[:split_idx]
|
||||
self.test = all_items[split_idx:]
|
||||
|
||||
print(f"Train: {len(self.train)}, Test: {len(self.test)}")
|
||||
logger.info("Train: %d, Test: %d", len(self.train), len(self.test))
|
||||
self.iter = 0
|
||||
|
||||
def save_checkpoint(self, step, data=None):
|
||||
|
|
@ -231,7 +234,7 @@ class CodeDebugEnv(BaseEnv):
|
|||
if all_passed:
|
||||
return 1.0, False
|
||||
|
||||
# Check for partial credit — how many tests pass?
|
||||
# Check for partial credit - how many tests pass?
|
||||
passed, total = count_test_results(
|
||||
generated_code, item["test"], item["entry_point"]
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue