mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-30 17:40:36 +00:00
feat: optimizer evaluator
This commit is contained in:
parent
c189fc3351
commit
fa42662c54
2 changed files with 309 additions and 0 deletions
62
environments/optimizer/evaluator.py
Normal file
62
environments/optimizer/evaluator.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
from verdict.schema import Schema
|
||||
from verdict import Pipeline, Layer
|
||||
from verdict.common.judge import JudgeUnit
|
||||
from verdict.scale import ContinuousScale
|
||||
from verdict.transform import MaxPoolUnit
|
||||
|
||||
|
||||
class OptimizerEvaluator:
    """Rate optimizer source code with an LLM judge pipeline.

    The pipeline asks a judge model to score an optimizer on a 1-10 scale for
    novelty and for its impact on training speed, then pools the judge outputs
    with ``MaxPoolUnit``.
    """

    # Key in the pipeline response dict that holds the pooled score.
    # NOTE(review): this name is derived from the pipeline layout by the
    # verdict library; confirm it still matches if the pipeline is changed.
    _SCORE_KEY = "Pipeline_root.block.block.unit[Map MaxPool]_score"

    def __init__(self, model: str = "xai/grok-3-latest"):
        """Build the judge pipeline.

        Args:
            model: Model identifier handed to ``Layer.via``. Defaults to the
                model that was previously hard-coded.
        """
        # Plain (non f-) string on purpose: the ``{source.*}`` placeholders are
        # left for the judge unit's prompt templating to fill in from the
        # schema passed to ``run``.
        judge_prompt = (
            "You are a judge that is an expert at evaluating optimizers for their novelty "
            "as they will be accepted to a prestigious research conference. Given the following "
            "optimizer code and its architecture/use-case, you must rate it on a scale of 1 to 10 "
            "based on how novel it is and its impactfulness in speeding up model training. "
            "Here is the code: {source.optimizer_code}\n"
            "Here is the architecture: {source.architecture}"
        )
        self.pipeline = (
            Pipeline()
            >> Layer(
                JudgeUnit(scale=ContinuousScale(1, 10)).prompt(judge_prompt),
                # Three judge units in the layer; their outputs feed the
                # pooling unit below.
                repeat=3,
            ).via(model)
            >> MaxPoolUnit()
        )

    def run(self, optimizer_code: str, architecture: str) -> float:
        """Score one optimizer.

        Args:
            optimizer_code: Source code of the optimizer to be judged.
            architecture: Description of the architecture / use-case the
                optimizer targets.

        Returns:
            The pooled judge score, or ``0.0`` when the expected score key is
            absent from the pipeline response.
        """
        schema = Schema.of(
            optimizer_code=optimizer_code,
            architecture=architecture,
        )
        # The second element of the pipeline result is not needed here.
        response, _ = self.pipeline.run(schema)
        return self.__get_final_score(response)

    def __get_final_score(self, response: dict) -> float:
        """Extract the pooled score from a pipeline response.

        Falls back to ``0.0`` (best effort) when the key is missing.
        """
        return response.get(self._SCORE_KEY, 0.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: score a trivial SGD example and print the result.
    # NOTE(review): this calls the configured judge model, so it presumably
    # needs network access and provider credentials - confirm before running.
    evaluator = OptimizerEvaluator()

    # Raw string: the sample's final print() contains the escape `\n`. In a
    # non-raw literal that escape became a real line break, which split the
    # f-string and made the sample code sent to the judge syntactically invalid.
    optimizer_code = r"""
import torch

# Define parameter (requires_grad=True)
x = torch.tensor([0.0], requires_grad=True)
optimizer = torch.optim.SGD([x], lr=0.1)

for step in range(20):
    optimizer.zero_grad()
    loss = (x - 3) ** 2
    loss.backward()
    optimizer.step()
    print(f"Step {step + 1}: x = {x.item():.4f}, loss = {loss.item():.4f}")

print(f"\nOptimal x: {x.item():.4f}")
"""

    score = evaluator.run(optimizer_code=optimizer_code, architecture="MLP")
    print(score)
|
||||
247
environments/optimizer/requirements.txt
Normal file
247
environments/optimizer/requirements.txt
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile pyproject.toml -o requirements.txt
|
||||
aiohappyeyeballs==2.4.6
|
||||
# via aiohttp
|
||||
aiohttp==3.11.12
|
||||
# via
|
||||
# datasets
|
||||
# fsspec
|
||||
# instructor
|
||||
# litellm
|
||||
aiosignal==1.3.2
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anyio==4.8.0
|
||||
# via
|
||||
# httpx
|
||||
# openai
|
||||
async-timeout==5.0.1
|
||||
# via aiohttp
|
||||
attrs==25.1.0
|
||||
# via
|
||||
# aiohttp
|
||||
# jsonschema
|
||||
# referencing
|
||||
certifi==2025.1.31
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# requests
|
||||
charset-normalizer==3.4.1
|
||||
# via requests
|
||||
click==8.1.8
|
||||
# via
|
||||
# litellm
|
||||
# typer
|
||||
datasets==2.2.1
|
||||
# via verdict (pyproject.toml)
|
||||
dill==0.3.9
|
||||
# via
|
||||
# verdict (pyproject.toml)
|
||||
# datasets
|
||||
# multiprocess
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
docstring-parser==0.16
|
||||
# via instructor
|
||||
eval-type-backport==0.2.2
|
||||
# via verdict (pyproject.toml)
|
||||
exceptiongroup==1.2.2
|
||||
# via anyio
|
||||
filelock==3.17.0
|
||||
# via huggingface-hub
|
||||
frozenlist==1.5.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec==2025.2.0
|
||||
# via
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
graphviz==0.20.3
|
||||
# via verdict (pyproject.toml)
|
||||
h11==0.14.0
|
||||
# via httpcore
|
||||
httpcore==1.0.7
|
||||
# via httpx
|
||||
httpx==0.28.1
|
||||
# via
|
||||
# litellm
|
||||
# openai
|
||||
huggingface-hub==0.28.1
|
||||
# via
|
||||
# datasets
|
||||
# tokenizers
|
||||
idna==3.10
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
importlib-metadata==8.6.1
|
||||
# via litellm
|
||||
instructor==1.7.2
|
||||
# via verdict (pyproject.toml)
|
||||
jinja2==3.1.5
|
||||
# via
|
||||
# instructor
|
||||
# litellm
|
||||
jiter==0.8.2
|
||||
# via
|
||||
# instructor
|
||||
# openai
|
||||
joblib==1.4.2
|
||||
# via scikit-learn
|
||||
jsonschema==4.23.0
|
||||
# via litellm
|
||||
jsonschema-specifications==2024.10.1
|
||||
# via jsonschema
|
||||
krippendorff==0.8.1
|
||||
# via verdict (pyproject.toml)
|
||||
litellm==1.61.9
|
||||
# via verdict (pyproject.toml)
|
||||
loguru==0.7.3
|
||||
# via verdict (pyproject.toml)
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
markupsafe==3.0.2
|
||||
# via jinja2
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
multidict==6.1.0
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
multiprocess==0.70.17
|
||||
# via datasets
|
||||
networkx==3.2.1
|
||||
# via verdict (pyproject.toml)
|
||||
numpy==2.0.2
|
||||
# via
|
||||
# verdict (pyproject.toml)
|
||||
# datasets
|
||||
# krippendorff
|
||||
# pandas
|
||||
# scikit-learn
|
||||
# scipy
|
||||
openai==1.63.2
|
||||
# via
|
||||
# verdict (pyproject.toml)
|
||||
# instructor
|
||||
# litellm
|
||||
packaging==24.2
|
||||
# via
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
pandas==2.2.3
|
||||
# via
|
||||
# verdict (pyproject.toml)
|
||||
# datasets
|
||||
pillow==11.1.0
|
||||
# via verdict (pyproject.toml)
|
||||
propcache==0.2.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
pyarrow==19.0.1
|
||||
# via datasets
|
||||
pydantic==2.10.6
|
||||
# via
|
||||
# verdict (pyproject.toml)
|
||||
# instructor
|
||||
# litellm
|
||||
# openai
|
||||
pydantic-core==2.27.2
|
||||
# via
|
||||
# instructor
|
||||
# pydantic
|
||||
pygments==2.19.1
|
||||
# via rich
|
||||
python-dateutil==2.9.0.post0
|
||||
# via pandas
|
||||
python-dotenv==1.0.1
|
||||
# via litellm
|
||||
pytz==2025.1
|
||||
# via pandas
|
||||
pyyaml==6.0.2
|
||||
# via huggingface-hub
|
||||
referencing==0.36.2
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2024.11.6
|
||||
# via tiktoken
|
||||
requests==2.32.3
|
||||
# via
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# instructor
|
||||
# responses
|
||||
# tiktoken
|
||||
responses==0.18.0
|
||||
# via datasets
|
||||
rich==13.9.4
|
||||
# via
|
||||
# verdict (pyproject.toml)
|
||||
# instructor
|
||||
# typer
|
||||
rpds-py==0.22.3
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
scikit-learn==1.6.1
|
||||
# via verdict (pyproject.toml)
|
||||
scipy==1.13.1
|
||||
# via
|
||||
# verdict (pyproject.toml)
|
||||
# scikit-learn
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
sniffio==1.3.1
|
||||
# via
|
||||
# anyio
|
||||
# openai
|
||||
tenacity==9.0.0
|
||||
# via instructor
|
||||
threadpoolctl==3.5.0
|
||||
# via scikit-learn
|
||||
tiktoken==0.9.0
|
||||
# via litellm
|
||||
tokenizers==0.21.0
|
||||
# via
|
||||
# verdict (pyproject.toml)
|
||||
# litellm
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# openai
|
||||
typer==0.15.1
|
||||
# via instructor
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# verdict (pyproject.toml)
|
||||
# anyio
|
||||
# huggingface-hub
|
||||
# multidict
|
||||
# openai
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# referencing
|
||||
# rich
|
||||
# typer
|
||||
tzdata==2025.1
|
||||
# via pandas
|
||||
urllib3==2.3.0
|
||||
# via
|
||||
# requests
|
||||
# responses
|
||||
xxhash==3.5.0
|
||||
# via datasets
|
||||
yarl==1.18.3
|
||||
# via aiohttp
|
||||
zipp==3.21.0
|
||||
# via importlib-metadata
|
||||
Loading…
Add table
Add a link
Reference in a new issue