"""
|
|
Punchline VR-CLI Environment for Atropos
|
|
"""
|
|
|
|
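
# How the VR-CLI reward works (summarized from `_vrcli_reward` below):
#     reward = max(0, (ppl_base - ppl_reasoning) / ppl_base)
# where ppl_base is a frozen Qwen3-1.7B-Base model's perplexity on the gold
# punchline given only the question, and ppl_reasoning is its perplexity on
# the same punchline once the sampled reasoning is added to the prompt.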

from __future__ import annotations

import asyncio
import math
import random
from typing import Dict, List, Optional, Tuple, TypedDict

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
from atroposlib.type_definitions import Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer


# ───────────────────────────────────────────────────────
# Config & data row
# ───────────────────────────────────────────────────────
class PunchlineRow(TypedDict):
    question: str
    answer: str


class PunchEnvConfig(BaseEnvConfig):
    tokenizer_name: str = "Qwen/Qwen3-1.7B"
    group_size: int = 8
    use_wandb: bool = True
    rollout_server_url: str = "http://localhost:8000"
    total_steps: int = 1000
    batch_size: int = 12
    steps_per_eval: int = 100
    max_token_length: int = 3000
    wandb_name: str = "punchline_vrcli"
    gpu_device: int = 0
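
# NOTE: these fields override BaseEnvConfig defaults; in Atropos they can
# usually also be set at launch time through the CLI entry point at the
# bottom of this file (exact flag names depend on atroposlib's CLI plumbing).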


class PunchlineEnv(BaseEnv):
    name = "punchline_vrcli"

    # ───────────────────────────────────────────────
    # default config + server
    # ───────────────────────────────────────────────
    @classmethod
    def config_init(cls):
        cfg = PunchEnvConfig()
        servers = [
            APIServerConfig(
                model_name="Qwen/Qwen3-1.7B",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=64,
            )
        ]
        return cfg, servers
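
    # NOTE: the server block above assumes an OpenAI-compatible endpoint
    # (e.g. vLLM or SGLang) serving Qwen/Qwen3-1.7B at localhost:9001;
    # api_key="x" is a placeholder for servers that don't check credentials.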

    # ───────────────────────────────────────────────
    # wandb logging helper
    # ───────────────────────────────────────────────
    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        if wandb_metrics is None:
            wandb_metrics = {}
        # Report the mean reward accumulated since the last log, then reset
        # the buffer so each log covers a disjoint window of rollouts.
        if getattr(self, "_reward_buffer", None):
            wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / len(
                self._reward_buffer
            )
            self._reward_buffer = []
        await super().wandb_log(wandb_metrics)

    # ───────────────────────────────────────────────
    # setup
    # ───────────────────────────────────────────────
    async def setup(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_name)
        # Frozen reference model (and its tokenizer), used only to measure
        # how likely the gold punchline is under different prompts.
        self.reward_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B-Base")
        self._ref = (
            AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen3-1.7B-Base", torch_dtype=torch.bfloat16
            )
            .eval()
            .to(
                f"cuda:{self.config.gpu_device}" if torch.cuda.is_available() else "cpu"
            )
        )

        raw = load_dataset(
            "SocialGrep/one-million-reddit-jokes",
            split="train",
            trust_remote_code=True,
        )

        # Keep only clean question/punchline pairs: an upvoted post whose
        # title is a "What/Why/How ...?" question (no brackets) and whose
        # body is a single-line punchline that wasn't removed or deleted.
        self.data: List[PunchlineRow] = []
        for row in raw:
            if (
                row.get("selftext", "")
                and row.get("score", 0) >= 75
                and row.get("selftext", "") not in ["[removed]", "[deleted]"]
                and row.get("title", "").strip().endswith("?")
                and (
                    row.get("title", "").strip().startswith("What")
                    or row.get("title", "").strip().startswith("Why")
                    or row.get("title", "").strip().startswith("How")
                )
                and "(" not in row.get("title", "")
                and "[" not in row.get("title", "")
                and "\n" not in row.get("selftext", "")
                and "\r" not in row.get("selftext", "")
                and ";" not in row.get("selftext", "")
            ):
                q, a = row["title"], row["selftext"]
                if q and a:
                    self.data.append({"question": q, "answer": a})
        random.shuffle(self.data)
        self._idx = 0
        self._reward_buffer: List[float] = []
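
    # A surviving row looks like (hypothetical example, not from the dataset):
    #   {"question": "Why did the chicken cross the road?",
    #    "answer": "To get to the other side."}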

    # ───────────────────────────────────────────────
    # item iterator
    # ───────────────────────────────────────────────
    async def get_next_item(self) -> PunchlineRow:
        # Cycle through the shuffled jokes indefinitely.
        itm = self.data[self._idx % len(self.data)]
        self._idx += 1
        return itm

    # ───────────────────────────────────────────────
    # trajectory collection
    # ───────────────────────────────────────────────
    async def collect_trajectories(
        self, item: PunchlineRow
    ) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
        system_msg = {
            "role": "system",
            "content": (
                "You have a strong sense of humor and answer the user's question with a punchline. "
                "You always give the funniest answer, even if it could offend some people. "
                "Consider the aspects that make a joke funny, for example the answer is usually "
                "surprising to hear but makes sense in hindsight. You shouldn't need to explain "
                "your answer, it should stand on its own."
            ),
        }
        user_msg = {"role": "user", "content": item["question"]}

        chat_comps = await self.server.chat_completion(
            messages=[system_msg, user_msg],
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
        )

        group = ScoredDataGroup(tokens=[], masks=[], scores=[])
        for choice in chat_comps.choices:
            assistant_content = choice.message.content
            reasoning, answer = self._parse_completion(assistant_content)

            # Score the reasoning by how much it helps the reference model
            # predict the gold punchline.
            rew = self._vrcli_reward(item["question"], reasoning, item["answer"])
            self._reward_buffer.append(rew)

            # Only the <think> block is kept as the training target.
            msgs = (
                user_msg,
                {"role": "assistant", "content": f"<think>\n{reasoning}\n</think>"},
            )
            td = tokenize_for_trainer(self.tokenizer, msgs, choice.finish_reason)

            group["tokens"].append(td["tokens"])
            group["masks"].append(td["masks"])
            group["scores"].append(rew)

            if len(group["tokens"]) >= self.config.group_size:
                break

        # Drop incomplete groups and groups whose scores are all identical.
        if len(group["tokens"]) < self.config.group_size or all(
            s == group["scores"][0] for s in group["scores"]
        ):
            return None, []
        return group, []
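
    # Uniform-score groups are dropped because group-relative training
    # signals vanish: with advantages of the form a_i = r_i - mean(r),
    # identical rewards give a_i = 0 for every rollout (the exact
    # normalization lives in the downstream trainer, not here).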

    # ───────────────────────────────────────────────
    # evaluation (average reward of random samples)
    # ───────────────────────────────────────────────
    async def evaluate(self, *args, **kwargs):
        # Take 64 random jokes and see mean reward with greedy decoding.
        sample = random.sample(self.data, k=64)
        tasks = []
        for row in sample:
            msg = {"role": "user", "content": row["question"]}
            tasks.append(
                self.server.chat_completion(
                    messages=[msg],
                    n=1,
                    temperature=0.0,
                    max_tokens=self.config.max_token_length,
                    split="eval",
                )
            )
        completions = await asyncio.gather(*tasks)
        rewards = []
        for row, comp in zip(sample, completions):
            txt = comp.choices[0].message.content
            r, a = self._parse_completion(txt)
            rewards.append(self._vrcli_reward(row["question"], r, row["answer"]))
        self.eval_metrics.append(("eval/mean_reward", sum(rewards) / len(rewards)))
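
    # The k=64 here lines up with num_requests_for_eval=64 in config_init,
    # so one eval pass issues exactly that many greedy requests.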

    # ───────────────────────────────────────────────
    # helpers
    # ───────────────────────────────────────────────
    def _parse_completion(self, txt: str):
        # Split "<think>reasoning</think>answer" into its two parts; if no
        # think block is present, treat the whole completion as the answer.
        if "<think>" in txt and "</think>" in txt:
            reasoning = txt.split("<think>", 1)[1].split("</think>", 1)[0].strip()
            answer = txt.split("</think>", 1)[1].strip()
        else:
            reasoning, answer = "", txt.strip()
        return reasoning, answer
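
    # Reward helper: compares the reference model's perplexity on the gold
    # punchline with and without the sampled reasoning in the prompt. The
    # token bookkeeping below shifts the logits by one position so each
    # logit is scored against the token that follows it.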
    def _vrcli_reward(self, q: str, reasoning: str, gold: str) -> float:
        if not reasoning:
            return -1.0

        t = self.reward_tokenizer

        def ppl(prompt: str, comp: str):
            # Perplexity of `comp` given `prompt` under the frozen reference.
            ids = t(prompt + comp, return_tensors="pt").to(self._ref.device)
            p_len = t(prompt, return_tensors="pt").input_ids.size(1)
            with torch.no_grad():
                logits = self._ref(**ids).logits[:, :-1]
            targets = ids.input_ids[:, 1:]
            lp = (
                torch.log_softmax(logits, -1)
                .gather(2, targets.unsqueeze(-1))
                .squeeze(-1)
            )
            # Average log-prob over the completion tokens only.
            return math.exp(-lp[:, p_len:].mean().item())

        base = ppl(f"Question: {q}\nAnswer:", gold)
        plus = ppl(f"Question: {q}\nReasoning: {reasoning}\nAnswer:", gold)
        # Relative perplexity improvement, clamped at zero.
        return max(0.0, (base - plus) / base)
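
    # Illustrative (made-up) numbers: ppl 40.0 for the punchline given the
    # question alone vs 30.0 once the reasoning is prepended gives a reward
    # of (40.0 - 30.0) / 40.0 = 0.25; reasoning that makes the punchline
    # *less* likely is clamped to 0.0 rather than penalized (only an empty
    # reasoning block earns the -1.0 above).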


if __name__ == "__main__":
    PunchlineEnv.cli()