mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
student
This commit is contained in:
parent
c421582b6f
commit
01e25707b0
3 changed files with 480 additions and 0 deletions
41
environments/gsm8k_server_student_distill.py
Normal file
41
environments/gsm8k_server_student_distill.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
from typing import Tuple
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, ServerBaseline
|
||||
from atroposlib.envs.student_distillation_env import (
|
||||
StudentDistillationConfig,
|
||||
StudentDistillationEnv,
|
||||
)
|
||||
|
||||
from environments.gsm8k_server import GSM8kEnv
|
||||
|
||||
|
||||
class GSM8kStudentDistillEnv(GSM8kEnv, StudentDistillationEnv):
|
||||
name = "gsm8k_student_distill"
|
||||
env_config_cls = StudentDistillationConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[StudentDistillationConfig, ServerBaseline]:
|
||||
env_config = StudentDistillationConfig(
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
group_size=8,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
batch_size=12,
|
||||
steps_per_eval=100,
|
||||
max_token_length=2048,
|
||||
wandb_name="gsm8k_student_distill",
|
||||
student_distill_enabled=True,
|
||||
student_top_k=4,
|
||||
)
|
||||
server_config = APIServerConfig(
|
||||
model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
)
|
||||
return env_config, server_config
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GSM8kStudentDistillEnv.cli()
|
||||
Loading…
Add table
Add a link
Reference in a new issue