mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-27 17:23:08 +00:00
linting and moved to community
This commit is contained in:
parent
8df34efc56
commit
a6ac7a3e42
46 changed files with 245 additions and 2314 deletions
62
environments/community/humor_generation/README.md
Normal file
62
environments/community/humor_generation/README.md
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
# Humor Generation Environment
|
||||
|
||||
## Overview
|
||||
A reinforcement learning environment for training language models to generate humor in the style of specific comedians and formats. The environment uses a multi-dimensional scoring rubric to evaluate joke quality across relevance, style consistency, creativity, humor effectiveness, virality, and cognitive coherence.
|
||||
|
||||
## Features
|
||||
- **Multi-Comedian Training**: Supports various comedian styles (Norm Macdonald, John Mulaney, Hasan Minhaj, Dave Chappelle, Ali Wong, Chris Rock)
|
||||
- **Format Diversity**: Trains on different humor formats (haiku, one-liner, q/a over SMS)
|
||||
- **Comprehensive Scoring**: 6-dimensional evaluation rubric for joke quality assessment
|
||||
- **Dataset Generation**: Automated dataset creation using GPT-4o-mini
|
||||
- **WandB Integration**: Comprehensive experiment tracking and visualization
|
||||
|
||||
## Environment Structure
|
||||
- `humor_env.py`: Main environment implementation with scoring logic
|
||||
- `generate_humor_dataset.py`: Script for creating training datasets
|
||||
- `humor_dataset.jsonl`: Pre-generated dataset with comedian/format combinations
|
||||
|
||||
## Scoring Rubric
|
||||
The environment evaluates generated jokes across six dimensions (0-2 points for the first two, 0-3 points for the remaining four):
|
||||
1. **Relevance to Format** (0-2): How well the joke fits the specified format
|
||||
2. **Style Consistency** (0-2): Adherence to the target comedian's style
|
||||
3. **Creativity** (0-3): Originality and inventiveness of the humor
|
||||
4. **Humor Effectiveness** (0-3): How funny and engaging the joke is
|
||||
5. **Virality** (0-3): Potential for widespread appeal and sharing
|
||||
6. **Cognitive Coherence** (0-3): Logical structure and comprehensibility
|
||||
|
||||
## Usage
|
||||
|
||||
### Running the Environment
|
||||
```bash
|
||||
python environments/community/humor_generation/humor_env.py serve
|
||||
```
|
||||
|
||||
### Generating New Datasets
|
||||
```bash
|
||||
cd environments/community/humor_generation/
|
||||
python generate_humor_dataset.py
|
||||
```
|
||||
|
||||
## Configuration
|
||||
- **Model**: GPT-4o-mini for both generation and evaluation
|
||||
- **Group Size**: 2 completions per prompt
|
||||
- **Max Tokens**: 2048 for joke generation, 512 for scoring
|
||||
- **Evaluation**: LLM-based scoring using detailed rubric prompts
|
||||
|
||||
## Requirements
|
||||
- OpenAI API key (set as `OPENAI_API_KEY` environment variable)
|
||||
- Standard Atropos dependencies
|
||||
- WandB account for experiment tracking
|
||||
|
||||
## Dataset Format
|
||||
Each record contains:
|
||||
- `comedian`: Target comedian style
|
||||
- `format`: Humor format (haiku, one-liner, q/a over SMS)
|
||||
- `question`: Prompt asking for model recommendations and example jokes
|
||||
- `response`: GPT-4o-mini generated response with explanations and examples
|
||||
|
||||
## Training Applications
|
||||
- **Style Transfer**: Learning to mimic specific comedian voices
|
||||
- **Format Adaptation**: Generating humor in constrained formats
|
||||
- **Quality Assessment**: Training models to evaluate humor effectiveness
|
||||
- **Creative Writing**: Developing AI systems for entertainment content creation
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from openai import OpenAI
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
    """Generate ``humor_dataset.jsonl`` by querying the OpenAI chat API.

    For every comedian/format pair, one question is sent to ``gpt-4o-mini``
    and the answer is written as a single JSON line with the keys
    ``comedian``, ``format``, ``question`` and ``response``. The output path
    is relative, so the file lands in the current working directory. After
    writing, the file is re-read and the number of lines is logged.

    The API key is read from the ``OPENAI_API_KEY`` environment variable
    after ``load_dotenv()`` has run.
    """
    load_dotenv()
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    comedians = [
        "Norm Macdonald",
        "John Mulaney",
        "Hasan Minhaj",
        "Dave Chappelle",
        "Ali Wong",
        "Chris Rock",
    ]
    formats = [
        "haiku",
        "one-liner",
        "q/a over sms",
    ]

    output_file = "humor_dataset.jsonl"
    model_name = "gpt-4o-mini"
    logger.info(f"Generating humor dataset to {output_file} using model {model_name}")

    with open(output_file, "w", encoding="utf-8") as fout:
        for comedian in comedians:
            for fmt in formats:
                question = (
                    f"What’s the best local LLM model to generate {fmt} jokes "
                    f"in the style of {comedian}? Please explain your reasoning step by step, "
                    f"and generate 3 example jokes."
                )
                response = client.chat.completions.create(
                    model=model_name,
                    messages=[{"role": "user", "content": question}],
                )
                # The message content is optional in the API response; treat a
                # missing body as an empty answer instead of crashing on
                # ``None.strip()`` halfway through the dataset.
                answer = (response.choices[0].message.content or "").strip()
                record = {
                    "comedian": comedian,
                    "format": fmt,
                    "question": question,
                    "response": answer,
                }
                fout.write(json.dumps(record, ensure_ascii=False) + "\n")
                logger.info(f"Wrote record: comedian={comedian}, format={fmt}")

    # Verify dataset count; the context manager closes the read handle, which
    # the previous bare ``open(...)`` inside ``sum`` never did.
    with open(output_file, encoding="utf-8") as fin:
        count = sum(1 for _ in fin)
    logger.info(f"Dataset {output_file} contains {count} records")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
127
environments/community/humor_generation/humor_env.py
Normal file
127
environments/community/humor_generation/humor_env.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
|
||||
class HumorEnvConfig(BaseEnvConfig):
    """Configuration for ``HumorEnv``: the base config plus a dataset path."""

    # Path of the JSONL dataset, written relative to the repository root.
    data_path: str = (
        "environments/community/humor_generation/humor_dataset.jsonl"
    )
|
||||
|
||||
|
||||
class HumorEnv(BaseEnv):
    """Environment that scores humorous completions with an LLM judge.

    Each dataset record provides a ``question`` prompt together with the
    ``comedian`` and ``format`` it targets. ``collect_trajectories`` samples
    ``group_size`` completions for the prompt, and ``score`` asks the same
    server to grade every completion against a six-line rubric; the mean of
    the parsed rubric numbers becomes the rollout's score.
    """

    env_config_cls = HumorEnvConfig
    name = "humor"

    @classmethod
    def config_init(cls) -> Tuple[HumorEnvConfig, List[APIServerConfig]]:
        """Return the default environment config and the API server list."""
        env_config = cls.env_config_cls(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
            group_size=2,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=1024,
            steps_per_eval=100,
            max_token_length=2048,
            wandb_name="humor",
        )
        server_configs = [
            APIServerConfig(
                model_name="gpt-4o-mini",
                base_url=None,
                api_key=os.environ.get("OPENAI_API_KEY"),
                num_requests_for_eval=256,
            )
        ]
        return env_config, server_configs

    async def setup(self):
        """Load the JSON dataset at ``config.data_path`` and reset the cursor."""
        ds = load_dataset("json", data_files=self.config.data_path, split="train")
        self.train = ds
        self.iter = 0

    async def get_next_item(self) -> Tuple[dict]:
        """Return the next record as a 1-tuple, wrapping around the dataset."""
        record = self.train[self.iter % len(self.train)]
        self.iter += 1
        return (record,)

    async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]:
        """Sample completions for ``item`` and score them.

        Args:
            item: 1-tuple whose only element is a dataset record with at
                least the keys ``question``, ``format`` and ``comedian``.

        Returns:
            The scored group from ``score`` and an empty list.
        """
        record = item[0]
        prompt = record["question"]
        chat_completions = await self.server.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
        )
        to_score = []
        for choice in chat_completions.choices:
            messages = [
                {"role": "user", "content": prompt},
                # Content may be missing on a completion; use an empty string
                # so scoring and tokenization receive text rather than None.
                {"role": "assistant", "content": choice.message.content or ""},
            ]
            to_score.append((tuple(messages), choice.finish_reason))
        # Hand the sampled record to the scorer so the rubric names the
        # comedian/format this prompt actually asked for.
        scored = await self.score(to_score, record=record)
        return scored, []

    @staticmethod
    def _parse_scores(text: Optional[str]) -> List[int]:
        """Extract the integer following ``Score:`` on each line of ``text``.

        Only the leading run of ASCII digits after the last ``Score:`` on a
        line is used, so ``Score: 2/3`` and ``Score: **2**`` both yield ``2``.
        Lines without a number there (for example an echoed ``Score: X``)
        are skipped instead of raising.
        """
        found: List[int] = []
        for line in (text or "").splitlines():
            if "Score:" not in line:
                continue
            tail = line.split("Score:")[-1].lstrip(" \t*")
            end = 0
            while end < len(tail) and tail[end] in "0123456789":
                end += 1
            if end:
                found.append(int(tail[:end]))
        return found

    async def score(
        self, rollout_group_data: List, record: Optional[dict] = None
    ) -> Optional[ScoredDataGroup]:
        """
        Score each generated joke using the detailed rubric via an LLM call.

        Args:
            rollout_group_data: Iterable of ``(messages, finish_reason)``
                pairs; the last message holds the generated joke.
            record: Dataset record the rollouts were generated from; its
                ``format`` and ``comedian`` fill the rubric. When omitted,
                the first dataset row is used, matching the earlier
                behaviour of this method.

        Returns:
            A ``ScoredDataGroup`` with one tokens/masks/scores entry per
            rollout. A rollout whose judge reply contains no parsable score
            receives ``0.0``.
        """
        scores = ScoredDataGroup(tokens=[], masks=[], scores=[])
        source = record if record is not None else self.train[0]
        fmt = source["format"]
        comedian = source["comedian"]
        for messages, _ in rollout_group_data:
            joke = (messages[-1]["content"] or "").strip()
            # Build the rubric prompt
            rubric_prompt = (
                f'1. Relevance to the format ({fmt}): Evaluate the joke "{joke}". Score: X (0-2)\n'
                f'2. Style consistency ({comedian}): Evaluate the joke "{joke}". Score: X (0-2)\n'
                f'3. Creativity: Evaluate the joke "{joke}". Score: X (0-3)\n'
                f'4. Humor effectiveness: Evaluate the joke "{joke}". Score: X (0-3)\n'
                f'5. Virality: Evaluate the joke "{joke}". Score: X (0-3)\n'
                f'6. Cognitive coherence: Evaluate the joke "{joke}". Score: X (0-3)\n'
                "Please provide each score on its own line as 'Score: <number>'."
            )
            judge = await self.server.chat_completion(
                messages=[{"role": "user", "content": rubric_prompt}],
                n=1,
                max_tokens=512,
            )
            # Parse out all Score: X lines, tolerating malformed judge output
            nums = self._parse_scores(judge.choices[0].message.content)
            avg_score = sum(nums) / len(nums) if nums else 0.0
            out = tokenize_for_trainer(self.tokenizer, messages)
            scores["tokens"].append(out["tokens"])
            scores["masks"].append(out["masks"])
            scores["scores"].append(avg_score)
        return scores

    async def wandb_log(self, wandb_metrics: Optional[dict] = None):
        """Forward metrics unchanged to the base class logger."""
        await super().wandb_log(wandb_metrics)

    async def evaluate(self, *args, **kwargs):
        """Do nothing and return ``None``."""
        # No-op evaluation; required by BaseEnv abstract interface
        return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import sys

    # Invoked with no subcommand: fall back to 'serve' before handing
    # control to the environment's command-line entry point.
    if len(sys.argv) == 1:
        sys.argv += ["serve"]
    HumorEnv.cli()
|
||||
Loading…
Add table
Add a link
Reference in a new issue