diff --git a/environments/optimizer/evaluator.py b/environments/optimizer/evaluator.py
new file mode 100644
index 00000000..7feb463b
--- /dev/null
+++ b/environments/optimizer/evaluator.py
@@ -0,0 +1,66 @@
+from verdict.schema import Schema
+from verdict import Pipeline, Layer
+from verdict.common.judge import JudgeUnit
+from verdict.scale import ContinuousScale
+from verdict.transform import MaxPoolUnit
+
+
+class OptimizerEvaluator:
+    """Scores optimizer code for novelty/impact with an LLM judge pipeline."""
+
+    def __init__(self):
+        # Three independent judge samples (repeat=3) are max-pooled into one score.
+        self.pipeline = (
+            Pipeline()
+            >> Layer(
+                JudgeUnit(scale=ContinuousScale(1, 10)).prompt(
+                    (
+                        "You are a judge that is an expert at evaluating optimizers for their novelty "
+                        "as they will be accepted to a prestigious research conference. Given the following "
+                        "optimizer code and its architecture/use-case, you must rate it on a scale of 1 to 10 "
+                        "based on how novel it is and its impactfulness in speeding up model training. "
+                        "Here is the code: {source.optimizer_code}\n"
+                        "Here is the architecture: {source.architecture}"
+                    )
+                ),
+                repeat=3,
+            ).via("xai/grok-3-latest")
+            >> MaxPoolUnit()
+        )
+
+    def run(self, optimizer_code: str, architecture: str) -> float:
+        """Judge `optimizer_code` for `architecture`; return a 1-10 score (0.0 if absent)."""
+        schema = Schema.of(
+            optimizer_code=optimizer_code,
+            architecture=architecture,
+        )
+        response, _ = self.pipeline.run(schema)
+        return self.__get_final_score(response)
+
+    def __get_final_score(self, response: dict) -> float:
+        # Key is the fully-qualified name verdict assigns to the MaxPool unit's output.
+        return response.get("Pipeline_root.block.block.unit[Map MaxPool]_score", 0.0)
+
+
+if __name__ == "__main__":
+    evaluator = OptimizerEvaluator()
+
+    optimizer_code = """
+import torch
+
+# Define parameter (requires_grad=True)
+x = torch.tensor([0.0], requires_grad=True)
+optimizer = torch.optim.SGD([x], lr=0.1)
+
+for step in range(20):
+    optimizer.zero_grad()
+    loss = (x - 3) ** 2
+    loss.backward()
+    optimizer.step()
+    print(f"Step {step + 1}: x = {x.item():.4f}, loss = {loss.item():.4f}")
+
+print(f"\\nOptimal x: {x.item():.4f}")
+    """
+
+    score = evaluator.run(optimizer_code=optimizer_code, architecture="MLP")
+    print(score)
diff --git a/environments/optimizer/requirements.txt b/environments/optimizer/requirements.txt
new file mode 100644
index 00000000..0c6e57d5
--- /dev/null
+++ b/environments/optimizer/requirements.txt
@@ -0,0 +1,247 @@
+# This file was autogenerated by uv via the following command:
+#    uv pip compile pyproject.toml -o requirements.txt
+aiohappyeyeballs==2.4.6
+    # via aiohttp
+aiohttp==3.11.12
+    # via
+    #   datasets
+    #   fsspec
+    #   instructor
+    #   litellm
+aiosignal==1.3.2
+    # via aiohttp
+annotated-types==0.7.0
+    # via pydantic
+anyio==4.8.0
+    # via
+    #   httpx
+    #   openai
+async-timeout==5.0.1
+    # via aiohttp
+attrs==25.1.0
+    # via
+    #   aiohttp
+    #   jsonschema
+    #   referencing
+certifi==2025.1.31
+    # via
+    #   httpcore
+    #   httpx
+    #   requests
+charset-normalizer==3.4.1
+    # via requests
+click==8.1.8
+    # via
+    #   litellm
+    #   typer
+datasets==2.2.1
+    # via verdict (pyproject.toml)
+dill==0.3.9
+    # via
+    #   verdict (pyproject.toml)
+    #   datasets
+    #   multiprocess
+distro==1.9.0
+    # via openai
+docstring-parser==0.16
+    # via instructor
+eval-type-backport==0.2.2
+    # via verdict (pyproject.toml)
+exceptiongroup==1.2.2
+    # via anyio
+filelock==3.17.0
+    # via huggingface-hub
+frozenlist==1.5.0
+    # via
+    #   aiohttp
+    #   aiosignal
+fsspec==2025.2.0
+    # via
+    #   datasets
+    #   huggingface-hub
+graphviz==0.20.3
+    # via verdict (pyproject.toml)
+h11==0.14.0
+    # via httpcore
+httpcore==1.0.7
+    # via httpx
+httpx==0.28.1
+    # via
+    #   litellm
+    #   openai
+huggingface-hub==0.28.1
+    # via
+    #   datasets
+    #   tokenizers
+idna==3.10
+    # via
+    #   anyio
+    #   httpx
+    #   requests
+    #   yarl
+importlib-metadata==8.6.1
+    # via litellm
+instructor==1.7.2
+    # via verdict (pyproject.toml)
+jinja2==3.1.5
+    # via
+    #   instructor
+    #   litellm
+jiter==0.8.2
+    # via
+    #   instructor
+    #   openai
+joblib==1.4.2
+    # via scikit-learn
+jsonschema==4.23.0
+    # via litellm
+jsonschema-specifications==2024.10.1
+    # via jsonschema
+krippendorff==0.8.1
+    # via verdict (pyproject.toml)
+litellm==1.61.9
+    # via verdict (pyproject.toml)
+loguru==0.7.3
+    # via verdict (pyproject.toml)
+markdown-it-py==3.0.0
+    # via rich
+markupsafe==3.0.2
+    # via jinja2
+mdurl==0.1.2
+    # via markdown-it-py
+multidict==6.1.0
+    # via
+    #   aiohttp
+    #   yarl
+multiprocess==0.70.17
+    # via datasets
+networkx==3.2.1
+    # via verdict (pyproject.toml)
+numpy==2.0.2
+    # via
+    #   verdict (pyproject.toml)
+    #   datasets
+    #   krippendorff
+    #   pandas
+    #   scikit-learn
+    #   scipy
+openai==1.63.2
+    # via
+    #   verdict (pyproject.toml)
+    #   instructor
+    #   litellm
+packaging==24.2
+    # via
+    #   datasets
+    #   huggingface-hub
+pandas==2.2.3
+    # via
+    #   verdict (pyproject.toml)
+    #   datasets
+pillow==11.1.0
+    # via verdict (pyproject.toml)
+propcache==0.2.1
+    # via
+    #   aiohttp
+    #   yarl
+pyarrow==19.0.1
+    # via datasets
+pydantic==2.10.6
+    # via
+    #   verdict (pyproject.toml)
+    #   instructor
+    #   litellm
+    #   openai
+pydantic-core==2.27.2
+    # via
+    #   instructor
+    #   pydantic
+pygments==2.19.1
+    # via rich
+python-dateutil==2.9.0.post0
+    # via pandas
+python-dotenv==1.0.1
+    # via litellm
+pytz==2025.1
+    # via pandas
+pyyaml==6.0.2
+    # via huggingface-hub
+referencing==0.36.2
+    # via
+    #   jsonschema
+    #   jsonschema-specifications
+regex==2024.11.6
+    # via tiktoken
+requests==2.32.3
+    # via
+    #   datasets
+    #   huggingface-hub
+    #   instructor
+    #   responses
+    #   tiktoken
+responses==0.18.0
+    # via datasets
+rich==13.9.4
+    # via
+    #   verdict (pyproject.toml)
+    #   instructor
+    #   typer
+rpds-py==0.22.3
+    # via
+    #   jsonschema
+    #   referencing
+scikit-learn==1.6.1
+    # via verdict (pyproject.toml)
+scipy==1.13.1
+    # via
+    #   verdict (pyproject.toml)
+    #   scikit-learn
+shellingham==1.5.4
+    # via typer
+six==1.17.0
+    # via python-dateutil
+sniffio==1.3.1
+    # via
+    #   anyio
+    #   openai
+tenacity==9.0.0
+    # via instructor
+threadpoolctl==3.5.0
+    # via scikit-learn
+tiktoken==0.9.0
+    # via litellm
+tokenizers==0.21.0
+    # via
+    #   verdict (pyproject.toml)
+    #   litellm
+tqdm==4.67.1
+    # via
+    #   datasets
+    #   huggingface-hub
+    #   openai
+typer==0.15.1
+    # via instructor
+typing-extensions==4.12.2
+    # via
+    #   verdict (pyproject.toml)
+    #   anyio
+    #   huggingface-hub
+    #   multidict
+    #   openai
+    #   pydantic
+    #   pydantic-core
+    #   referencing
+    #   rich
+    #   typer
+tzdata==2025.1
+    # via pandas
+urllib3==2.3.0
+    # via
+    #   requests
+    #   responses
+xxhash==3.5.0
+    # via datasets
+yarl==1.18.3
+    # via aiohttp
+zipp==3.21.0
+    # via importlib-metadata
\ No newline at end of file