atropos/environments/gsm8k_server_student_distill.py
Jai Suphavadeeprasit 01e25707b0 student
2026-03-17 12:02:48 -04:00

41 lines
1.3 KiB
Python

from typing import Tuple
from atroposlib.envs.base import APIServerConfig, ServerBaseline
from atroposlib.envs.student_distillation_env import (
StudentDistillationConfig,
StudentDistillationEnv,
)
from environments.gsm8k_server import GSM8kEnv
class GSM8kStudentDistillEnv(GSM8kEnv, StudentDistillationEnv):
    """GSM8k environment mixed with ``StudentDistillationEnv``.

    The class derives from both ``GSM8kEnv`` and ``StudentDistillationEnv``
    and itself only sets ``name``, ``env_config_cls`` and the default
    configuration returned by :meth:`config_init`; everything else is
    inherited from the parent classes.
    """

    name = "gsm8k_student_distill"
    env_config_cls = StudentDistillationConfig

    @classmethod
    def config_init(cls) -> Tuple[StudentDistillationConfig, ServerBaseline]:
        """Build the default environment and API-server configuration.

        Returns:
            A two-item tuple: the ``StudentDistillationConfig`` for this
            environment followed by the ``APIServerConfig`` describing the
            inference endpoint.
        """
        # The tokenizer and the served model use the same identifier.
        model_id = "NousResearch/DeepHermes-3-Llama-3-3B-Preview"

        distill_settings = StudentDistillationConfig(
            tokenizer_name=model_id,
            max_token_length=2048,
            group_size=8,
            batch_size=12,
            total_steps=1000,
            steps_per_eval=100,
            rollout_server_url="http://localhost:8000",
            use_wandb=True,
            wandb_name="gsm8k_student_distill",
            student_distill_enabled=True,
            student_top_k=4,
        )
        api_settings = APIServerConfig(
            model_name=model_id,
            base_url="http://localhost:9001/v1",
            api_key="x",
            num_requests_for_eval=256,
        )
        return distill_settings, api_settings
# Only start the environment's command-line interface when this file is run
# as a script, not when it is imported as a module.
if __name__ == "__main__":
    GSM8kStudentDistillEnv.cli()