Update wrapper.py

ParsaIdp 2025-05-18 17:49:31 -07:00 committed by GitHub
parent 3054e5f608
commit 856b437b3a


@@ -1,5 +1,6 @@
 from modal.functions import Function
 from evaluator import OptimizerEvaluator
+from FOB.optimizer_benchmark_env import OptimizerBenchmarkEnv
 
 def score_optimizer(optimizer_code: str, architecture: str):
     """
@@ -21,20 +22,31 @@ def score_optimizer(optimizer_code: str, architecture: str):
     validity = evaluator.check_validity(optimizer_code=optimizer_code, stdout=response_obj["stdout"], stderr=response_obj["stderr"])
-    if not validity:
+    if validity:
         return 0
-    else:
-        score = evaluator.score_optimizer(optimizer_code=optimizer_code, architecture=architecture)
-        return score
+    # Get evaluator score
+    score = evaluator.score(optimizer_code=optimizer_code, architecture=architecture)
+    # Use OptimizerBenchmarkEnv to get the reward
+    env = OptimizerBenchmarkEnv()
+    env.submit_optimizer(optimizer_code, 'custom_optimizer')
+    env.generate_experiment_yaml()
+    env.run_benchmark()
+    reward = env.get_reward()
+    # Return the sum of both
+    return score + reward
 
 if __name__ == "__main__":
     test_code = """
 import torch
 import math
 from lightning.pytorch.utilities.types import OptimizerLRScheduler
 from torch.optim import SGD
 from pytorch_fob.engine.parameter_groups import GroupedModel
 from pytorch_fob.engine.configs import OptimizerConfig
 
 class CustomOptimizer(torch.optim.Optimizer):
     def __init__(self, params, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
         defaults = dict(lr=lr, beta1=beta1, beta2=beta2, eps=eps)
 
 def configure_optimizers(model: GroupedModel, config: OptimizerConfig) -> OptimizerLRScheduler:
     lr = config.learning_rate
     optimizer = SGD(model.grouped_parameters(lr=lr), lr=lr)
     return {"optimizer": optimizer}
 """
     result = score_optimizer(test_code, "mnist")
     print(result)
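
Note: the CustomOptimizer stub inside test_code never calls super().__init__ and defines no step() method, so it is incomplete as written. For reference, a complete torch.optim.Optimizer subclass follows the shape below; this is a minimal sketch with a plain gradient-descent update as a placeholder, not part of this commit.

import torch

class CustomOptimizer(torch.optim.Optimizer):
    def __init__(self, params, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
        defaults = dict(lr=lr, beta1=beta1, beta2=beta2, eps=eps)
        # The stub in the test string stops before this call; registering
        # defaults with the base class is required to populate param_groups.
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        # Re-evaluate the loss if a closure is provided, per the standard
        # torch.optim.Optimizer.step contract.
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                # Placeholder update rule: plain gradient descent. A real
                # submission would implement its own rule using the beta1,
                # beta2, and eps values stored in `group`.
                p.add_(p.grad, alpha=-group["lr"])
        return loss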