mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
refactor:final production-ready audit; remove debug artifacts and non-ASCII characters
This commit is contained in:
parent
287b7e7250
commit
8cd30c3703
3 changed files with 14 additions and 16 deletions
|
|
@ -68,7 +68,7 @@ class OpenAIServer(APIServer):
|
|||
kwargs.get("messages", None) is not None
|
||||
), "Messages are required for chat completion!"
|
||||
# DEBUG: Print the request being sent to vLLM
|
||||
# print(f"\n🚀 DEBUG: OpenAI Request Keywords: {kwargs}")
|
||||
# OpenAI request keywords for completion
|
||||
|
||||
try:
|
||||
if self.config.n_kwarg_is_ignored:
|
||||
|
|
@ -106,13 +106,8 @@ class OpenAIServer(APIServer):
|
|||
completions.choices.extend(c.choices)
|
||||
return completions
|
||||
except Exception as e:
|
||||
print(f"\n❌ DEBUG: OpenAI API Error: {type(e).__name__}: {e}")
|
||||
if hasattr(e, "response"):
|
||||
try:
|
||||
print(f"DEBUG: Response Body: {e.response.text}")
|
||||
except:
|
||||
pass
|
||||
raise e
|
||||
self.logger.error("OpenAI API Error: %s: %s", type(e).__name__, e)
|
||||
raise
|
||||
|
||||
async def _completion_wrapper(self, **kwargs) -> Completion:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ Environment pattern follows sql_query_env for consistency.
|
|||
"""
|
||||
|
||||
import random
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
|
||||
|
||||
from code_executor import (
|
||||
|
|
@ -26,6 +27,8 @@ from atroposlib.envs.base import (
|
|||
)
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# System prompt following established Atropos patterns
|
||||
SYSTEM_PROMPT = (
|
||||
"You are a deep thinking AI, you may use extremely long chains of thought "
|
||||
|
|
@ -165,7 +168,7 @@ class CodeDebugEnv(BaseEnv):
|
|||
"""Load the HumanEvalPack dataset (HumanEvalFix) and prepare train/test splits."""
|
||||
from datasets import load_dataset
|
||||
|
||||
print("Loading HumanEvalPack (python) dataset...")
|
||||
logger.info("Loading HumanEvalPack (python) dataset...")
|
||||
dataset = load_dataset("bigcode/humanevalpack", "python", split="test")
|
||||
|
||||
all_items: List[CodeDebugItem] = []
|
||||
|
|
@ -181,7 +184,7 @@ class CodeDebugEnv(BaseEnv):
|
|||
}
|
||||
)
|
||||
|
||||
print(f"Loaded {len(all_items)} problems")
|
||||
logger.info("Loaded %d problems", len(all_items))
|
||||
|
||||
# Verify a few items actually work with canonical solutions
|
||||
verified = 0
|
||||
|
|
@ -192,7 +195,7 @@ class CodeDebugEnv(BaseEnv):
|
|||
)
|
||||
if passed:
|
||||
verified += 1
|
||||
print(f"Verified {verified}/10 canonical solutions execute correctly")
|
||||
logger.info("Verified %d/10 canonical solutions execute correctly", verified)
|
||||
|
||||
# Split 80/20 train/test
|
||||
random.shuffle(all_items)
|
||||
|
|
@ -200,7 +203,7 @@ class CodeDebugEnv(BaseEnv):
|
|||
self.train = all_items[:split_idx]
|
||||
self.test = all_items[split_idx:]
|
||||
|
||||
print(f"Train: {len(self.train)}, Test: {len(self.test)}")
|
||||
logger.info("Train: %d, Test: %d", len(self.train), len(self.test))
|
||||
self.iter = 0
|
||||
|
||||
def save_checkpoint(self, step, data=None):
|
||||
|
|
@ -231,7 +234,7 @@ class CodeDebugEnv(BaseEnv):
|
|||
if all_passed:
|
||||
return 1.0, False
|
||||
|
||||
# Check for partial credit — how many tests pass?
|
||||
# Check for partial credit - how many tests pass?
|
||||
passed, total = count_test_results(
|
||||
generated_code, item["test"], item["entry_point"]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -143,7 +143,7 @@ def _counting_check(candidate):
|
|||
try:
|
||||
src = inspect.getsource(_orig_check)
|
||||
except (TypeError, OSError):
|
||||
# Can't inspect — just run it
|
||||
# Can't inspect - just run it
|
||||
try:
|
||||
_orig_check(candidate)
|
||||
print("1/1")
|
||||
|
|
@ -155,12 +155,12 @@ def _counting_check(candidate):
|
|||
assert_lines = [l.strip() for l in src.split('\\n') if l.strip().startswith('assert')]
|
||||
_total = max(len(assert_lines), 1)
|
||||
|
||||
# Run the full check — if it passes, all assertions passed
|
||||
# Run the full check - if it passes, all assertions passed
|
||||
try:
|
||||
_orig_check(candidate)
|
||||
_passed = _total
|
||||
except AssertionError:
|
||||
# Some failed — try to count
|
||||
# Some failed - try to count
|
||||
_passed = 0
|
||||
for line in assert_lines:
|
||||
try:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue