mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
feat: add HumorEnv environment for humor dataset in hack0 directory
This commit is contained in:
parent
0944f5aa81
commit
24a350bc71
1 changed files with 93 additions and 0 deletions
93
environments/hack0/llm_humor_server.py
Normal file
93
environments/hack0/llm_humor_server.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
import os
|
||||
import asyncio
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
|
||||
class HumorEnvConfig(BaseEnvConfig):
    """Settings for ``HumorEnv``: the base environment options plus a dataset path."""

    # Location of the JSON-lines dataset file that HumorEnv.setup hands to
    # ``load_dataset`` as ``data_files``.
    data_path: str = "environments/hack0/humor_dataset.jsonl"
|
||||
|
||||
|
||||
class HumorEnv(BaseEnv):
    """Environment that rewards completions which exactly match a reference.

    Each dataset record supplies a ``question`` (used as the user prompt) and
    a ``response`` (the reference answer). A completion scores 1.0 when its
    stripped text equals the stripped reference of the *same* record, and 0.0
    otherwise.
    """

    env_config_cls = HumorEnvConfig
    name = "humor"

    @classmethod
    def config_init(cls) -> Tuple[HumorEnvConfig, List[APIServerConfig]]:
        """Build the default environment config and API server configs.

        Returns:
            A ``(env_config, server_configs)`` pair. The API key is read from
            the ``OPENAI_API_KEY`` environment variable and is ``None`` when
            that variable is unset.
        """
        env_config = cls.env_config_cls(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
            group_size=2,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=1024,
            steps_per_eval=100,
            max_token_length=2048,
            wandb_name="humor",
        )
        server_configs = [
            APIServerConfig(
                model_name="gpt-4o-mini",
                base_url=None,
                api_key=os.environ.get("OPENAI_API_KEY"),
                num_requests_for_eval=256,
            )
        ]
        return env_config, server_configs

    async def setup(self):
        """Load the dataset from ``config.data_path`` and reset the cursor.

        Raises:
            ValueError: If the loaded dataset contains no records. Failing
                here avoids a later ``ZeroDivisionError`` in
                ``get_next_item``.
        """
        ds = load_dataset("json", data_files=self.config.data_path, split="train")
        if len(ds) == 0:
            raise ValueError(
                f"Humor dataset at {self.config.data_path!r} contains no records"
            )
        self.train = ds
        self.iter = 0

    async def get_next_item(self) -> Tuple[dict]:
        """Return the next dataset record as a 1-tuple, cycling when exhausted."""
        record = self.train[self.iter % len(self.train)]
        self.iter += 1
        return (record,)

    async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]:
        """Sample ``group_size`` completions for one record and score them.

        Args:
            item: A 1-tuple whose only element is a record with ``question``
                and ``response`` keys, as produced by ``get_next_item``.

        Returns:
            ``(scored_group, [])``. The second element is always an empty
            list.
        """
        record = item[0]
        prompt = record["question"]
        chat_completions = await self.server.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
        )
        to_score = [
            (
                (
                    {"role": "user", "content": prompt},
                    # A completion without text content is treated as an
                    # empty answer instead of crashing on ``None.strip()``.
                    {"role": "assistant", "content": choice.message.content or ""},
                ),
                choice.finish_reason,
                # Carry the reference answer of THIS record so that scoring
                # compares against the right row of the dataset.
                record["response"],
            )
            for choice in chat_completions.choices
        ]
        scored = await self.score(to_score)
        return scored, []

    def _expected_for(self, prompt) -> Optional[str]:
        """Return the ``response`` of the first record whose ``question`` equals ``prompt``."""
        for row in self.train:
            if row["question"] == prompt:
                return row["response"]
        return None

    async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
        """Tokenize each rollout and assign an exact-match score.

        Args:
            rollout_group_data: Entries of the form
                ``(messages, finish_reason, expected)``. Two-element entries
                ``(messages, finish_reason)`` are still accepted; their
                reference answer is looked up in the dataset using the
                content of the first message as the question.

        Returns:
            A ``ScoredDataGroup`` with one ``tokens``/``masks``/``scores``
            entry per rollout. A rollout with no reference answer scores 0.0.
        """
        scores = ScoredDataGroup(tokens=[], masks=[], scores=[])
        for messages, _finish_reason, *extra in rollout_group_data:
            if extra:
                expected = extra[0]
            else:
                expected = self._expected_for(messages[0]["content"])
            output = (messages[-1]["content"] or "").strip()
            matched = expected is not None and output == expected.strip()
            out = tokenize_for_trainer(self.tokenizer, list(messages))
            scores["tokens"].append(out["tokens"])
            scores["masks"].append(out["masks"])
            scores["scores"].append(1.0 if matched else 0.0)
        return scores

    async def evaluate(self, *args, **kwargs):
        """No-op evaluation hook; this environment defines no eval split.

        NOTE(review): added on the understanding that ``BaseEnv`` expects
        subclasses to provide ``evaluate`` (``steps_per_eval`` is configured
        in ``config_init``) -- confirm against atroposlib.
        """
        return None

    async def wandb_log(self, wandb_metrics: Optional[dict] = None):
        """Forward ``wandb_metrics`` unchanged to the base implementation."""
        await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
# Script entry point: hand control to the environment's command-line interface.
if __name__ == "__main__":
    HumorEnv.cli()
|
||||
Loading…
Add table
Add a link
Reference in a new issue