mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00
Intern bootcamp env (#146)
* Created registry and started off the env * Local testing works * process working but error in gen * removed old code * adding debug, it's still not progressing to collect trajectories * linting * removed redundant settings
This commit is contained in:
parent
ea304892ee
commit
283877dd88
8 changed files with 1218 additions and 0 deletions
266
environments/intern_bootcamp/bootcamp_registry.py
Normal file
266
environments/intern_bootcamp/bootcamp_registry.py
Normal file
|
|
@ -0,0 +1,266 @@
|
|||
"""
|
||||
Bootcamp Registry for InternBootcamp Environment
|
||||
|
||||
This module provides a registry system for dynamically discovering and managing
|
||||
InternBootcamp tasks without having to manually import each one.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, Dict, List, Type
|
||||
|
||||
# Module-level logger named after this module; the registry and RandomTask
# report discovery and instantiation diagnostics through it.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BootcampRegistry:
    """Registry of InternBootcamp task classes, filled by lazy discovery.

    Classes are looked up on the ``internbootcamp.bootcamp`` module the first
    time a query method needs them and are cached under their attribute name.
    """

    def __init__(self):
        # Attribute name on ``internbootcamp.bootcamp`` -> bootcamp class.
        self._registry: Dict[str, Type] = {}
        # True once a discovery pass has completed; later passes are no-ops.
        self._discovered = False

    def discover_bootcamps(self) -> None:
        """Populate the registry from ``internbootcamp.bootcamp``.

        An attribute is registered when its name ends with ``"bootcamp"``,
        does not start with an underscore, and it is a class exposing
        ``case_generator``, ``prompt_func`` and ``verify_score``. Attributes
        that raise while being inspected are logged and skipped.

        Raises:
            ImportError: If ``internbootcamp.bootcamp`` cannot be imported.
        """
        if self._discovered:
            return

        try:
            bootcamp_module = importlib.import_module("internbootcamp.bootcamp")
        except ImportError as e:
            logger.error(f"Failed to import internbootcamp.bootcamp: {e}")
            raise

        required_attrs = ("case_generator", "prompt_func", "verify_score")
        for name in dir(bootcamp_module):
            if name.startswith("_") or not name.endswith("bootcamp"):
                continue
            try:
                obj = getattr(bootcamp_module, name)
                if inspect.isclass(obj) and all(
                    hasattr(obj, attr) for attr in required_attrs
                ):
                    self._registry[name] = obj
                    logger.debug(f"Registered bootcamp: {name}")
            except Exception as e:
                # One broken attribute must not abort discovery of the rest.
                logger.warning(f"Failed to register {name}: {e}")

        self._discovered = True
        logger.info(f"Discovered {len(self._registry)} bootcamp tasks")

    def get_bootcamp_class(self, name: str) -> Type:
        """Return the bootcamp class registered under ``name``.

        Triggers discovery on first use.

        Raises:
            ValueError: If no bootcamp is registered under ``name``; the
                message lists up to ten available names and the total count.
        """
        if not self._discovered:
            self.discover_bootcamps()

        if name not in self._registry:
            available = self.list_available_bootcamps()
            raise ValueError(
                f"Unknown bootcamp: {name}. "
                f"Available bootcamps: {', '.join(available[:10])}..."
                f" ({len(available)} total)"
            )

        return self._registry[name]

    def create_bootcamp_instance(self, name: str, **params) -> Any:
        """Instantiate the bootcamp ``name``, passing the parameters it accepts.

        Parameters the class' ``__init__`` cannot take are dropped with a
        warning. If ``__init__`` declares ``**kwargs`` every parameter is
        forwarded. When construction with parameters fails for any reason the
        error is logged and construction is retried without arguments.

        Raises:
            ValueError: If ``name`` is not a registered bootcamp.
            Exception: Whatever the argument-less fallback construction raises.
        """
        bootcamp_class = self.get_bootcamp_class(name)

        try:
            valid_params = self._filter_init_params(bootcamp_class, name, params)
            return bootcamp_class(**valid_params)
        except Exception as e:
            logger.error(f"Failed to create instance of {name}: {e}")
            # Best-effort fallback: default construction. A failure here
            # propagates to the caller with the first error as its context.
            return bootcamp_class()

    @staticmethod
    def _filter_init_params(
        bootcamp_class: Type, name: str, params: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Return the subset of ``params`` that ``bootcamp_class.__init__`` accepts."""
        init = bootcamp_class.__init__
        if init is object.__init__:
            # No custom initializer: object.__init__ advertises **kwargs but
            # rejects any argument, so nothing can be forwarded.
            init_params = []
        else:
            # Skip the first parameter (the bound instance).
            init_params = list(inspect.signature(init).parameters.values())[1:]

        if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in init_params):
            # The initializer takes arbitrary keywords; forward everything.
            return dict(params)

        keyword_kinds = (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
        accepted = {p.name for p in init_params if p.kind in keyword_kinds}

        valid_params = {}
        for param_name, param_value in params.items():
            if param_name in accepted:
                valid_params[param_name] = param_value
            else:
                logger.warning(
                    f"Parameter '{param_name}' not accepted by {name}, ignoring"
                )
        return valid_params

    def list_available_bootcamps(self) -> List[str]:
        """Return the registered bootcamp names in sorted order.

        Triggers discovery on first use.
        """
        if not self._discovered:
            self.discover_bootcamps()
        return sorted(self._registry)

    def get_bootcamp_info(self, name: str) -> Dict[str, Any]:
        """Describe the bootcamp ``name``.

        Returns:
            A dict with ``name``, ``class``, ``docstring`` and ``parameters``.
            ``parameters`` maps every ``__init__`` parameter except ``self``
            to ``{"default": ..., "annotation": ...}``; either entry is
            ``None`` when the signature does not provide it.

        Raises:
            ValueError: If ``name`` is not a registered bootcamp.
        """
        bootcamp_class = self.get_bootcamp_class(name)

        info = {
            "name": name,
            "class": bootcamp_class,
            "docstring": inspect.getdoc(bootcamp_class) or "No documentation available",
            "parameters": {},
        }

        empty = inspect.Parameter.empty
        try:
            sig = inspect.signature(bootcamp_class.__init__)
            for param_name, param in sig.parameters.items():
                if param_name == "self":
                    continue
                # Identity comparison: ``!=`` would call the default's own
                # __ne__, which may raise or return a non-boolean value.
                info["parameters"][param_name] = {
                    "default": param.default if param.default is not empty else None,
                    "annotation": (
                        str(param.annotation)
                        if param.annotation is not empty
                        else None
                    ),
                }
        except Exception as e:
            logger.warning(f"Could not inspect parameters for {name}: {e}")

        return info
|
||||
|
||||
|
||||
class RandomTask:
    """Pseudo-bootcamp that delegates each case to a randomly chosen bootcamp.

    ``case_generator`` picks a registered bootcamp at random and tags the
    generated case with the bootcamp's name, so that ``prompt_func`` and
    ``verify_score`` can route the case back to the same bootcamp type.
    """

    # Metadata keys stored on generated cases; both are stripped again before
    # a case is handed to the underlying bootcamp.
    _NAME_KEY = "_bootcamp_name"
    # True when a non-dict case was wrapped as {"data": case}; False for cases
    # that were dicts to begin with.
    _WRAPPED_KEY = "_bootcamp_wrapped"

    def __init__(self, **params):
        """Discover bootcamps and remember ``params`` for later instantiation.

        Names containing ``base``, ``template`` or ``{puzzlename}``
        (case-insensitive) are excluded from the random pool.
        """
        self.registry = BootcampRegistry()
        self.registry.discover_bootcamps()
        # Remove base classes and template classes from the list
        excluded_markers = ("base", "template", "{puzzlename}")
        self.available_bootcamps = [
            name
            for name in self.registry.list_available_bootcamps()
            if not any(marker in name.lower() for marker in excluded_markers)
        ]
        self.params = params
        self.current_bootcamp = None
        self.current_bootcamp_name = None
        logger.info(
            f"RandomTask initialized with {len(self.available_bootcamps)} available bootcamps"
        )

    def case_generator(self) -> object:
        """Generate a case from a randomly selected bootcamp.

        Returns:
            A dict carrying the case plus routing metadata. A dict case is
            tagged in place; any other case is wrapped under ``"data"``.

        Raises:
            IndexError: If no bootcamps are available to choose from.
        """
        if not self.available_bootcamps:
            raise IndexError("RandomTask has no available bootcamps to choose from")

        self.current_bootcamp_name = random.choice(self.available_bootcamps)
        self.current_bootcamp = self.registry.create_bootcamp_instance(
            self.current_bootcamp_name, **self.params
        )

        case = self.current_bootcamp.case_generator()

        if isinstance(case, dict):
            case[self._NAME_KEY] = self.current_bootcamp_name
            case[self._WRAPPED_KEY] = False
        else:
            case = {
                "data": case,
                self._NAME_KEY: self.current_bootcamp_name,
                self._WRAPPED_KEY: True,
            }

        return case

    def _bootcamp_for(self, identity):
        """Return the bootcamp instance matching ``identity``, creating it if needed."""
        bootcamp_name = identity.get(self._NAME_KEY, self.current_bootcamp_name)
        if self.current_bootcamp is None or self.current_bootcamp_name != bootcamp_name:
            self.current_bootcamp_name = bootcamp_name
            self.current_bootcamp = self.registry.create_bootcamp_instance(
                bootcamp_name, **self.params
            )
        return self.current_bootcamp

    def _original_case(self, identity):
        """Return the case as the bootcamp produced it, without routing metadata."""
        payload = dict(identity)  # never mutate the caller's identity
        payload.pop(self._NAME_KEY, None)
        wrapped = payload.pop(self._WRAPPED_KEY, None)
        if wrapped is None:
            # Identity without the explicit marker: fall back to the shape
            # heuristic (a lone "data" key means the case was wrapped).
            wrapped = len(payload) == 1 and "data" in payload
        return payload["data"] if wrapped else payload

    def prompt_func(self, identity) -> str:
        """Build the prompt for ``identity`` using the bootcamp that produced it."""
        bootcamp = self._bootcamp_for(identity)
        return bootcamp.prompt_func(self._original_case(identity))

    @classmethod
    def extract_output(cls, output):
        """This should not be called directly for RandomTask."""
        raise NotImplementedError(
            "RandomTask does not implement extract_output directly"
        )

    @classmethod
    def _verify_correction(cls, solution, identity):
        """This should not be called directly for RandomTask."""
        raise NotImplementedError(
            "RandomTask does not implement _verify_correction directly"
        )

    def verify_score(
        self,
        model_output,
        identity,
        format_score=0,
        short_penalty=True,
        short_threshold=100,
        format_penalty=True,
    ) -> float:
        """Score ``model_output`` with the bootcamp that produced ``identity``.

        The scoring options are forwarded positionally, unchanged, to the
        underlying bootcamp's ``verify_score``.
        """
        bootcamp = self._bootcamp_for(identity)
        return bootcamp.verify_score(
            model_output,
            self._original_case(identity),
            format_score,
            short_penalty,
            short_threshold,
            format_penalty,
        )
|
||||
|
||||
|
||||
# Global registry instance
|
||||
# Shared by get_available_bootcamps() and create_bootcamp(). Constructing it
# does not import InternBootcamp; discovery runs on the first lookup.
bootcamp_registry = BootcampRegistry()
|
||||
|
||||
|
||||
def get_available_bootcamps() -> List[str]:
    """Return the names of every bootcamp known to the shared registry."""
    names = bootcamp_registry.list_available_bootcamps()
    return names
|
||||
|
||||
|
||||
def create_bootcamp(name: str, **params) -> Any:
    """Build the bootcamp called ``name``, forwarding ``params`` to it.

    The name ``"RandomTask"`` is not looked up in the registry; it yields a
    ``RandomTask`` constructed with the same parameters.
    """
    if name != "RandomTask":
        # Ordinary bootcamps are resolved through the shared registry.
        return bootcamp_registry.create_bootcamp_instance(name, **params)
    return RandomTask(**params)
|
||||
Loading…
Add table
Add a link
Reference in a new issue