mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-29 17:35:07 +00:00

Move to punchlines folder

This commit is contained in:
parent 9b920bc7c6
commit cf9429f47f

3 changed files with 0 additions and 0 deletions
73 environments/hack0/punchlines/README.md Normal file
@@ -0,0 +1,73 @@

# Punchline VR-CLI Environment

This environment trains an LLM to generate humorous punchlines for jokes, using a reinforcement learning (RL) technique called Verifiable Rewards via Completion Likelihood Improvement (VR-CLI), from the paper "Learning to Reason for Long-Form Story Generation" (Gurung & Lapata, 2025) [https://arxiv.org/html/2503.22828v1](https://arxiv.org/html/2503.22828v1).

## Design

The core idea is to teach a model not just to produce a punchline, but to first generate "reasoning" that leads to a good punchline. The quality of this reasoning is then "verified" by measuring how much it improves the likelihood (reduces the perplexity) of the *actual* punchline from the dataset, as assessed by a separate, fixed reference model. This greatly reduces overfitting, because the model never sees the ground-truth answer. Typical fine-tuning fails here: it makes the model memorize the jokes rather than gain an understanding of why they are funny.
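
Concretely, writing `x` for the joke setup, `r` for the model-generated reasoning, and `y` for the golden punchline, the reward (for completions that produce a reasoning block) is:

```
reward = max(0, (PPL(y | x) - PPL(y | x, r)) / PPL(y | x))
```
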
## Example

Here's an example of one trajectory:

```
Question: What do you call a herd of cows masturbating?

<think>
Okay, the user is asking, "What do you call a herd of cows masturbating?" Hmm, that's a play on words. Let me think.
The key here is the word "masturbating" and the animal "cows." The answer needs to be a pun or a play on words.

First, I need to connect "herd" with "masturbating." A herd of cows is a group, so maybe something related to a group.
"Masturbating" is a term that's not typically used for animals, but maybe there's a word that combines the two.........
</think>

Beef strokin off!
```

## Impact

How is this applicable to anything useful? VR-CLI extends to many domains beyond jokes: the paper's authors used it for creative writing, and it could equally be applied to code generation (without needing to execute the code in a VM) or to custom business tasks where worked examples already exist.

## Key Components

* **Dataset:** The environment uses the `"SocialGrep/one-million-reddit-jokes"` dataset, filtered for jokes with a question-answer format (setup and punchline) and a minimum number of upvotes.
* **Task:** Given the setup of a joke (the "question"), the model `Qwen/Qwen3-1.7B` is prompted to generate a thinking process `<think>...</think>` followed by the punchline.
* **Reward (VR-CLI):** computed in three steps (see the sketch after this list):
  1. A base perplexity of the "golden" punchline is calculated given only the joke's setup, using a reference LLM `Qwen/Qwen3-1.7B-Base`.
  2. A new perplexity of the golden punchline is calculated, this time conditioned on both the joke's setup AND the model-generated reasoning.
  3. The reward is the relative improvement in perplexity, `(base_perplexity - plus_perplexity) / base_perplexity`, clipped at zero. A positive reward indicates the reasoning was helpful.
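
A minimal sketch of the scoring step, mirroring `_vrcli_reward` in `punchline_env.py` (the two perplexities come from the fixed reference model):

```python
def vrcli_reward(base_perplexity: float, plus_perplexity: float) -> float:
    # Relative perplexity improvement of the golden punchline, clipped at
    # zero so unhelpful reasoning earns no reward rather than a penalty.
    return max(0.0, (base_perplexity - plus_perplexity) / base_perplexity)
```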

## Metrics

The `evaluate` function selects 64 random jokes and computes their average reward using greedy decoding. This is a simple, effective way to measure the model's progress over time.

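Greedy decoding here just means requesting a single completion with `temperature=0.0`. For reference, an equivalent standalone request against the vLLM server from the Quickstart below (this uses the `openai` client, which is not in `requirements.txt`; `pip install openai` first):

```python
from openai import OpenAI

# Assumes vLLM is serving Qwen/Qwen3-1.7B on port 9001 (see Quickstart).
client = OpenAI(base_url="http://localhost:9001/v1", api_key="x")

resp = client.chat.completions.create(
    model="Qwen/Qwen3-1.7B",
    messages=[{"role": "user", "content": "What do you call a herd of cows masturbating?"}],
    temperature=0.0,  # greedy decoding
    max_tokens=3000,
)
print(resp.choices[0].message.content)
```
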
## Quickstart

### Prerequisites

Ensure you have Python installed, then install the dependencies:

```bash
pip install "vllm>=0.8.5" torch transformers datasets wandb tenacity atroposlib pydantic
```

### Processing the Dataset

To process the dataset, use the `punchline_env.py` script with the `process` command:

```bash
python punchline_env.py process \
  --env.data_path_to_save_groups jokes.jsonl \
  --env.wandb_name punchline_vrcli \
  --openai.base_url http://localhost:9001/v1 \
  --openai.api_key x \
  --openai.model_name Qwen/Qwen3-1.7B
```

You will need to have vLLM serving the model on port 9001 for this to work.

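A typical way to launch that (a sketch assuming vLLM's OpenAI-compatible `vllm serve` entrypoint; adjust GPU and parallelism flags for your hardware):

```bash
vllm serve Qwen/Qwen3-1.7B --port 9001
```
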
[Weights & Biases link](https://wandb.ai/jaboggs-nous-hackathon-nc-state-university/uncategorized/runs/0vly0u4p)

#### Output

[jokes.zip](https://github.com/user-attachments/files/20275345/jokes.zip)

229 environments/hack0/punchlines/punchline_env.py Normal file
@@ -0,0 +1,229 @@

"""
|
||||
Punchline VR-CLI Environment for Atropos
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import math, random, asyncio
|
||||
from typing import List, Optional, Tuple, TypedDict, Dict
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.type_definitions import Item
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
|
||||
# ───────────────────────────────────────────────────────
|
||||
# Config & data row
|
||||
# ───────────────────────────────────────────────────────
|
||||
class PunchlineRow(TypedDict):
    question: str
    answer: str


class PunchEnvConfig(BaseEnvConfig):
    tokenizer_name: str = "Qwen/Qwen3-1.7B"
    group_size: int = 8
    use_wandb: bool = True
    rollout_server_url: str = "http://localhost:8000"
    total_steps: int = 1000
    batch_size: int = 12
    steps_per_eval: int = 100
    max_token_length: int = 3000
    wandb_name: str = "punchline_vrcli"
    gpu_device: int = 0

class PunchlineEnv(BaseEnv):
    name = "punchline_vrcli"

    # ───────────────────────────────────────────────
    # default config + server
    # ───────────────────────────────────────────────
    @classmethod
    def config_init(cls):
        cfg = PunchEnvConfig()
        servers = [
            APIServerConfig(
                model_name="Qwen/Qwen3-1.7B",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=64,
            )
        ]
        return cfg, servers

    # ───────────────────────────────────────────────
    # wandb logging helper
    # ───────────────────────────────────────────────
    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        if wandb_metrics is None:
            wandb_metrics = {}
        # Log the mean training reward accumulated since the last log,
        # then reset the buffer.
        if getattr(self, "_reward_buffer", None):
            wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / len(
                self._reward_buffer
            )
            self._reward_buffer = []
        await super().wandb_log(wandb_metrics)

    # ───────────────────────────────────────────────
    # setup
    # ───────────────────────────────────────────────
    async def setup(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_name)
        # Fixed reference model used only for scoring; it is never trained.
        self.reward_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B-Base")
        self._ref = (
            AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen3-1.7B-Base", torch_dtype=torch.bfloat16
            )
            .eval()
            .to(f"cuda:{self.config.gpu_device}" if torch.cuda.is_available() else "cpu")
        )

        raw = load_dataset(
            "SocialGrep/one-million-reddit-jokes",
            split="train",
            trust_remote_code=True,
        )

        # Keep single-line question/answer jokes with enough upvotes.
        self.data: List[PunchlineRow] = []
        for row in raw:
            if (
                row.get("selftext", "")
                and row.get("score", 0) >= 75
                and row.get("selftext", "") not in ["[removed]", "[deleted]"]
                and row.get("title", "").strip().endswith("?")
                and (
                    row.get("title", "").strip().startswith("What")
                    or row.get("title", "").strip().startswith("Why")
                    or row.get("title", "").strip().startswith("How")
                )
                and "(" not in row.get("title", "")
                and "[" not in row.get("title", "")
                and "\n" not in row.get("selftext", "")
                and "\r" not in row.get("selftext", "")
                and ";" not in row.get("selftext", "")
            ):
                q, a = row["title"], row["selftext"]
                if q and a:
                    self.data.append({"question": q, "answer": a})
        random.shuffle(self.data)
        self._idx = 0
        self._reward_buffer: List[float] = []

    # ───────────────────────────────────────────────
    # item iterator
    # ───────────────────────────────────────────────
    async def get_next_item(self) -> PunchlineRow:
        # Cycle through the shuffled dataset indefinitely.
        itm = self.data[self._idx % len(self.data)]
        self._idx += 1
        return itm

    # ───────────────────────────────────────────────
    # trajectory collection
    # ───────────────────────────────────────────────
    async def collect_trajectories(
        self, item: PunchlineRow
    ) -> Tuple[ScoredDataGroup, List[Item]]:
        system_msg = {
            "role": "system",
            "content": (
                "You have a strong sense of humor and answer the user's question "
                "with a punchline. You always give the funniest answer, even if it "
                "could offend some people. Consider the aspects that make a joke "
                "funny, for example the answer is usually surprising to hear but "
                "makes sense in hindsight. You shouldn't need to explain your "
                "answer, it should stand on its own."
            ),
        }
        user_msg = {"role": "user", "content": item["question"]}

        chat_comps = await self.server.chat_completion(
            messages=[system_msg, user_msg],
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
        )

        group = ScoredDataGroup(tokens=[], masks=[], scores=[])
        for choice in chat_comps.choices:
            assistant_content = choice.message.content
            reasoning, answer = self._parse_completion(assistant_content)

            rew = self._vrcli_reward(item["question"], reasoning, item["answer"])
            self._reward_buffer.append(rew)

            # Train only on the reasoning trace; the punchline tokens are
            # deliberately excluded, as VR-CLI requires.
            msgs = (
                user_msg,
                {"role": "assistant", "content": f"<think>\n{reasoning}\n</think>"},
            )
            td = tokenize_for_trainer(self.tokenizer, msgs, choice.finish_reason)

            group["tokens"].append(td["tokens"])
            group["masks"].append(td["masks"])
            group["scores"].append(rew)

            if len(group["tokens"]) >= self.config.group_size:
                break

        # Skip degenerate groups: incomplete, or zero advantage (all scores equal).
        if len(group["tokens"]) < self.config.group_size or all(
            s == group["scores"][0] for s in group["scores"]
        ):
            return None, []
        return group, []

    # ───────────────────────────────────────────────
    # evaluation (average reward of random samples)
    # ───────────────────────────────────────────────
    async def evaluate(self, *args, **kwargs):
        # Take 64 random jokes and see mean reward with greedy decoding.
        sample = random.sample(self.data, k=64)
        tasks = []
        for row in sample:
            msg = {"role": "user", "content": row["question"]}
            tasks.append(
                self.server.chat_completion(
                    messages=[msg],
                    n=1,
                    temperature=0.0,
                    max_tokens=self.config.max_token_length,
                    split="eval",
                )
            )
        completions = await asyncio.gather(*tasks)
        rewards = []
        for row, comp in zip(sample, completions):
            txt = comp.choices[0].message.content
            r, a = self._parse_completion(txt)
            rewards.append(self._vrcli_reward(row["question"], r, row["answer"]))
        self.eval_metrics.append(("eval/mean_reward", sum(rewards) / len(rewards)))

    # ───────────────────────────────────────────────
    # helpers
    # ───────────────────────────────────────────────
    def _parse_completion(self, txt: str):
        # Split a completion into its <think> reasoning and the punchline.
        if "<think>" in txt and "</think>" in txt:
            reasoning = txt.split("<think>", 1)[1].split("</think>", 1)[0].strip()
            answer = txt.split("</think>", 1)[1].strip()
        else:
            reasoning, answer = "", txt.strip()
        return reasoning, answer

    def _vrcli_reward(self, q: str, reasoning: str, gold: str) -> float:
        # Completions that produced no <think> block get a flat penalty.
        if not reasoning:
            return -1.0

        t = self.reward_tokenizer

        def ppl(prompt: str, comp: str) -> float:
            # Perplexity of `comp` under the reference model, conditioned on
            # `prompt`. Assumes the tokenization of `prompt` is a prefix of
            # the tokenization of `prompt + comp`.
            ids = t(prompt + comp, return_tensors="pt").to(self._ref.device)
            p_len = t(prompt, return_tensors="pt").input_ids.size(1)
            with torch.no_grad():
                logits = self._ref(**ids).logits[:, :-1]
            targets = ids.input_ids[:, 1:]
            lp = (
                torch.log_softmax(logits, -1)
                .gather(2, targets.unsqueeze(-1))
                .squeeze(-1)
            )
            # lp[:, j] scores token j + 1, so the completion's log-probs
            # start at index p_len - 1.
            return math.exp(-lp[:, p_len - 1 :].mean().item())

        base = ppl(f"Question: {q}\nAnswer:", gold)
        plus = ppl(f"Question: {q}\nReasoning: {reasoning}\nAnswer:", gold)
        # Relative perplexity improvement, clipped at zero.
        return max(0.0, (base - plus) / base)


if __name__ == "__main__":
    PunchlineEnv.cli()

8 environments/hack0/punchlines/requirements.txt Normal file
@@ -0,0 +1,8 @@

vllm>=0.8.5
torch
transformers
datasets
wandb
tenacity
atroposlib
pydantic