Create optimizer_benchmark_environmenr.py

This commit is contained in:
ParsaIdp 2025-05-18 18:14:10 -07:00 committed by GitHub
parent f9a444b6f2
commit 71f6d48e87
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -0,0 +1,16 @@
from atroposlib.envs.base import BaseEnvConfig, BaseEnv
from atropos.environments.optimizer.wrapper import score_optimizer
class OptimizerBenchmarkEnvConfig(BaseEnvConfig):
architecture: str = "mnist"
class OptimizerBenchmarkEnvironment(BaseEnv):
env_config_cls = OptimizerBenchmarkEnvConfig
def __init__(self, config: OptimizerBenchmarkEnvConfig, server_configs=None, slurm=False, testing=False):
super().__init__(config, server_configs, slurm, testing)
self.architecture = config.architecture
async def evaluate(self, optimizer_code: str, *args, **kwargs):
# This method is required by BaseEnv
return score_optimizer(optimizer_code, self.architecture)