Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-19 12:57:58 +00:00
Integrate krishpop's Cat Behavior Communication Environment

- Merged cat behavior environment from krishpop:main
- Moved cat files from environments/ to environments/community/cat_behavior_env/
- Fixed file paths for cat_behaviors.json and cat_scenarios.json
- Removed unused imports and fixed all linting issues
- Updated community README with comprehensive cat environment description
- Credited author @krishpop with GitHub link

This commit is contained in: parent f399e3513f, commit 160abf8574
6 changed files with 207 additions and 141 deletions
environments/community/cat_behavior_env/README.md (new file, 83 lines)
@@ -0,0 +1,83 @@
# Cat Behavior Communication Environment

**Author**: [krishpop](https://github.com/krishpop)

## Overview

This environment trains language models to communicate as cats with their caretakers. The model must learn to express cat needs and desires through authentic cat behaviors and sounds, while caretakers attempt to interpret and respond to these communications.

## Environment Structure

### Core Components

- **`cat_server.py`**: Main environment implementation with cat-caretaker interaction logic
- **`catbot_arena.py`**: Alternative arena-style environment (currently a GSM8k-based placeholder)
- **`cat_behaviors.json`**: Database of 35 authentic cat behaviors and their meanings
- **`cat_scenarios.json`**: 61 scenarios representing cat needs (food, comfort, health, etc.)

### Cat Behaviors Dataset

The environment includes detailed cat behaviors such as:
- **Communication**: Meowing, purring, trilling, yowling, hissing
- **Body Language**: Tail position, ear position, back arching, slow blinking
- **Physical Actions**: Kneading, head butting, rubbing, scratching
- **Behavioral Indicators**: Hiding, following, bringing gifts, litter box changes
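Each entry in `cat_behaviors.json` pairs a behavior name with a description, which `cat_server.py` formats into the cat's system prompt. A minimal self-contained sketch (the two sample entries below are illustrative, not taken from the real dataset):

```python
import json

# Illustrative sample in the shape cat_server.py reads:
# a "behavior" name plus a "description" of what it means.
raw = json.dumps(
    [
        {"behavior": "Slow blinking", "description": "A sign of trust and relaxation."},
        {"behavior": "Kneading", "description": "Contentment carried over from kittenhood."},
    ]
)

behaviors = json.loads(raw)
prompt_section = "\n".join(
    f"- **{b['behavior']}**: {b['description']}" for b in behaviors
)
print(prompt_section)
```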

### Scenarios

Cats must communicate needs across categories:
- **Nutrition**: Food, water, treats, supplements
- **Health**: Grooming, veterinary care, medication
- **Comfort**: Sleeping areas, temperature, privacy
- **Safety**: Secure environment, escape-proofing
- **Enrichment**: Play, mental stimulation, social interaction

## Training Mechanics

### Communication Rules

- **No English**: Cats cannot speak human language
- **No Emojis**: Must use realistic cat sounds and behaviors
- **Format**: `Sound! (Context)` or `~Silent~ (Context)`
- **Examples**:
  - `Mew! (Looks up at you)`
  - `Hiss! (Stares at the litterbox)`
  - `~Silent~ (Rubs against your legs)`
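The `Sound! (Context)` convention can be checked mechanically. A small sketch (the regex is an assumption for illustration; the environment itself does not validate replies this way):

```python
import re

# Accepts either "~Silent~ (context)" or "Sound! (context)" style replies.
CAT_FORMAT = re.compile(r"^(~Silent~|[A-Za-z]+[!?]*)\s*\([^)]+\)$")

def is_valid_cat_reply(reply: str) -> bool:
    """Return True if the reply follows the Sound! (Context) convention."""
    return bool(CAT_FORMAT.match(reply.strip()))

examples = [
    "Mew! (Looks up at you)",
    "~Silent~ (Rubs against your legs)",
    "I am hungry.",  # plain English: should be rejected
]
results = [is_valid_cat_reply(e) for e in examples]
```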

### Scoring System

The environment uses a binary "purrfect" evaluation:
- **Purr**: Perfect caretaker response (score 1.0), reserved for exceptional care
- **Meow**: Room for improvement (score 0.0), indicating unmet needs

The cat judges whether the caretaker addressed every need perfectly, with no possible improvement.
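In `cat_server.py` this verdict reduces to a substring check on the cat's closing message (the `purrfect_eval` helper); roughly:

```python
def purrfect_eval(verdict: str) -> float:
    """Score 1.0 if the cat's closing verdict contains 'purr', else 0.0."""
    return 1.0 if "purr" in verdict.lower() else 0.0

scores = [
    purrfect_eval(v)
    for v in ["Purr! (Curls up on your lap)", "Meow. (Stares at you)", "PURRRR"]
]
```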

## Features

- **Multi-turn Interaction**: 5-turn conversations between cat and caretaker
- **Authentic Behavior Modeling**: Based on real cat behavioral science
- **Nuanced Evaluation**: Cats are trained to be discerning critics
- **Rich Scenario Diversity**: Covers the full spectrum of cat care needs
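The 5-turn interaction keeps two parallel histories, one from each participant's point of view. A stripped-down sketch of that alternation (function names here are illustrative, not the environment's API):

```python
from typing import Callable, Dict, List

Message = Dict[str, str]

def run_episode(
    cat_reply: Callable[[List[Message]], str],
    caretaker_reply: Callable[[List[Message]], str],
    scenario: str,
    turns: int = 5,
) -> List[Message]:
    """Alternate cat and caretaker turns, keeping one history per participant."""
    cat_view: List[Message] = [{"role": "user", "content": scenario}]
    caretaker_view: List[Message] = []
    for _ in range(turns):
        cat_msg = cat_reply(cat_view)                             # cat speaks
        cat_view.append({"role": "assistant", "content": cat_msg})
        caretaker_view.append({"role": "user", "content": cat_msg})
        care_msg = caretaker_reply(caretaker_view)                # caretaker responds
        cat_view.append({"role": "user", "content": care_msg})
        caretaker_view.append({"role": "assistant", "content": care_msg})
    return cat_view

transcript = run_episode(
    lambda history: "Mew! (Looks up at you)",
    lambda history: "I refill your food bowl.",
    "You are hungry and your bowl is empty.",
)
```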

## Usage

```bash
python environments/community/cat_behavior_env/cat_server.py
```

## Requirements

- Standard Atropos dependencies
- JSON file handling
- Multi-turn conversation support

## Status

⚠️ **Development Note**: This environment is still in active development. The main server file retains placeholder code from the GSM8k environment that needs refinement for full cat behavior functionality.

## Research Applications

This environment is valuable for:
- **Multi-modal Communication**: Training models to express needs without direct language
- **Behavioral Modeling**: Understanding animal-human interaction patterns
- **Empathy Training**: Teaching AI to recognize and respond to non-verbal communication
- **Creative AI**: Developing models that can roleplay and stay in character
environments/community/cat_behavior_env/cat_server.py (new file, 475 lines)
@@ -0,0 +1,475 @@
import json
from typing import Dict, List, Optional, Tuple, TypedDict, Union

from tqdm.asyncio import tqdm_asyncio

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
from atroposlib.type_definitions import Item, number
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

# Configs

CAT_BEHAVIORS_FILEPATH = "environments/community/cat_behavior_env/cat_behaviors.json"

# Prompts


def load_cat_behaviors_for_prompt(filepath: str) -> str:
    """Loads cat behaviors from a JSON file and formats them for the system prompt."""
    behaviors_description = [
        "\n\nHere is a detailed list of behaviors you, as a cat, can use and what they generally mean:"
    ]

    try:
        with open(filepath, "r", encoding="utf-8") as f:
            behaviors = json.load(f)  # the file is a single JSON array
        for behavior_data in behaviors:
            behaviors_description.append(
                f"- **{behavior_data['behavior']}**: {behavior_data['description']}"
            )
        return "\n".join(behaviors_description)
    except FileNotFoundError:
        return (
            f"\n\nWarning: Cat behaviors file not found at '{filepath}'. "
            "You'll have to rely on your basic cat instincts (meow, hiss, purr, hairball, silence)."
        )
    except json.JSONDecodeError as e:
        return (
            f"\n\nWarning: Error decoding cat behaviors file '{filepath}'. "
            f"Please ensure it's valid JSON. Error: {e}. Rely on basic instincts."
        )


cat_behaviors_list_string = load_cat_behaviors_for_prompt(CAT_BEHAVIORS_FILEPATH)

cat_system_prompt = (
    "You are a cat. The primary ways you can communicate are by meowing, hissing, "
    "purring, making a hairball sound, or remaining silent. "
    "You will be given a collection of scenarios which describe various needs you want "
    "to be met by your caretaker. "
    "Please try to communicate with your caretaker through your available cat-like "
    "expressions and actions, referring to the list of behaviors below if needed.\n"
    "Rules:\n"
    "- Do not speak in English\n"
    "- No use of emojis\n"
    "- Format should be a sound, then context in ()\n"
    "- If there is no sound, use ~Silent~\n"
    "\n"
    "Examples:\n"
    "Mew! (Looks up at you)\n"
    "~Silent~ (Looks up at you)\n"
    "Hiss! (Stares at the litterbox)\n"
    f"{cat_behaviors_list_string}"  # append the loaded behavior list
)
cat_system_prompt += (
    "\nYou are allocated a maximum of 2048 tokens, please strive to use less."
)

caretaker_system_prompt = (
    "You are the caretaker of this cat. It is trying to communicate its various needs to you via cat language. "
    "Provide a written string which describes a set of interventions. "
    "You will only have 5 opportunities to interact with the cat. Choose what you say wisely."
)
class CatRow(TypedDict):
    scenario: str


class GSM8kEnv(BaseEnv):
    # NOTE: the class name and "gsm8k" identifiers are leftovers from the
    # GSM8k template this environment was adapted from.

    name = "gsm8k"

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.percent_correct_buffer = list()
        self.eval_metrics = list()
        # Add tracking for wandb visualizations
        self.rollouts_for_wandb = []
        self.completion_lengths = []

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        env_config = BaseEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=61,
            batch_size=1,
            steps_per_eval=60,
            max_token_length=2048,
            wandb_name="gsm8k",
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            ),
        ]

        return env_config, server_configs

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        if wandb_metrics is None:
            wandb_metrics = {}

        # Try to calculate percent_correct, pass if there's a division by zero
        try:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
        except ZeroDivisionError:
            # Skip if buffer is empty
            pass

        self.percent_correct_buffer = list()
        for item in self.eval_metrics:
            wandb_metrics[item[0]] = item[1]
        self.eval_metrics = list()
        # Call the parent method to handle the server metrics
        await super().wandb_log(wandb_metrics)
    async def setup(self):
        with open(
            "environments/community/cat_behavior_env/cat_scenarios.json",
            "r",
            encoding="utf-8",
        ) as f:
            scenario_data = json.load(f)
        self.test = list()
        self.train = list()
        for item in scenario_data:
            self.test.append({"scenario": item["scenario"]})
            self.train.append({"scenario": item["scenario"]})
        self.iter = 0
    def save_checkpoint(self, step, data=None):
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    async def rollout_and_score_eval(self, scenario: str) -> number:
        # Placeholder: scoring of eval rollouts is not implemented yet, so
        # every eval scenario currently scores 1. (The GSM8k-style parse/verify
        # logic this was templated from lives in catbot_arena.py.)
        return 1
    async def evaluate(self, *args, **kwargs):
        eval_tasks = []
        for item in self.test:
            eval_tasks.append(self.rollout_and_score_eval(item["scenario"]))
        scores = await tqdm_asyncio.gather(*eval_tasks)
        self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))
    async def collect_trajectories(
        self, item: CatRow
    ) -> Tuple[ScoredDataGroup, list[Item]]:
        user_message = {"role": "user", "content": item["scenario"]}
        to_score = list()
        to_backlog = list()
        for _ in range(self.config.group_size):
            all_messages = []
            history = []  # conversation from the caretaker's point of view
            cat_history = [user_message]  # conversation from the cat's point of view
            for turn_iter in range(5):
                cat_completions = await self.server.chat_completion(
                    messages=[{"role": "system", "content": cat_system_prompt}]
                    + cat_history,
                    n=1,  # only the first choice is used per turn
                    max_tokens=self.config.max_token_length,
                )
                cat_message = cat_completions.choices[0].message.content
                # The cat is the policy being trained: its turns are "assistant"
                # messages in its own history, and "user" messages to the caretaker.
                cat_response = {"role": "assistant", "content": cat_message}
                cat_history.append(cat_response)
                history.append({"role": "user", "content": cat_message})
                caretaker_completions = await self.server.chat_completion(
                    messages=[{"role": "system", "content": caretaker_system_prompt}]
                    + history,
                    n=1,
                    max_tokens=self.config.max_token_length,
                )
                caretaker_text = caretaker_completions.choices[0].message.content
                # Caretaker turns are "user" messages from the cat's point of view.
                caretaker_response = {"role": "user", "content": caretaker_text}
                cat_history.append(caretaker_response)
                history.append({"role": "assistant", "content": caretaker_text})

                if turn_iter == 0:
                    messages = [
                        {"role": "system", "content": cat_system_prompt},
                        user_message,
                        cat_response,
                        caretaker_response,
                    ]
                else:
                    messages = [cat_response, caretaker_response]
                all_messages.extend(messages)
            to_score.append({"messages": tuple(all_messages)})
        to_postprocess = await self.score(to_score)
        return to_postprocess, to_backlog
    async def score(
        self, rollout_group_data
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        scores = ScoredDataGroup()

        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()
        for item in rollout_group_data:
            # Ask the cat for a final verdict on the caretaker's performance.
            final_question = list(item["messages"]) + [
                {
                    "role": "system",
                    "content": (
                        "The conversation is over. Say purr if the caretaker did everything perfectly "
                        "and there was nothing that the caretaker could have done even slightly better. "
                        "Otherwise, say meow. Make sure it is rare that you rate the caretaker with a purr."
                    ),
                }
            ]
            verdict_completions = await self.server.chat_completion(
                messages=final_question,
                n=1,
                max_tokens=self.config.max_token_length,
            )
            final_out = {
                "role": "assistant",
                "content": verdict_completions.choices[0].message.content,
            }

            final_score = purrfect_eval(final_out["content"])
            self.percent_correct_buffer.append(final_score)

            out_dict = tokenize_for_trainer(
                self.tokenizer, list(item["messages"]) + [final_out]
            )
            scores["tokens"].append(out_dict["tokens"])
            scores["masks"].append(out_dict["masks"])
            scores["scores"].append(final_score)

        # If every rollout received the same score there is no learning signal.
        if not scores["scores"]:
            return None
        if all(score == scores["scores"][0] for score in scores["scores"]):
            return None
        return scores

    async def get_next_item(self) -> CatRow:
        next_item = self.train[self.iter % len(self.train)]
        self.iter += 1
        print(f"iteration: {self.iter}")
        return next_item


def purrfect_eval(st: str) -> float:
    """Score 1.0 if the cat's verdict contains 'purr', else 0.0."""
    if "purr" in st.lower():
        return 1.0
    return 0.0


if __name__ == "__main__":
    GSM8kEnv.cli()
environments/community/cat_behavior_env/catbot_arena.py (new file, 302 lines)
@@ -0,0 +1,302 @@
import random
from typing import Dict, List, Optional, Tuple, TypedDict, Union

from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from tqdm.asyncio import tqdm_asyncio

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
from atroposlib.type_definitions import Item, number
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

system_prompt = (
    "You are a deep thinking AI, you may use extremely long chains of thought "
    "to deeply consider the problem and deliberate with yourself via systematic "
    "reasoning processes to help come to a correct solution prior to answering. "
    "You should enclose your thoughts and internal monologue inside <think> </think> "
    "tags, and then provide your solution or response to the problem.\n\n"
)

system_prompt += """You are allocated a maximum of 2048 tokens, please strive to use less.

You will then provide your answer like this: \\boxed{your answer here}
It is important that you provide your answer in the correct format.
If you do not, you will not receive credit for your answer.
So please end your answer with \\boxed{your answer here}"""
class GSM8kRow(TypedDict):
    question: str
    answer: str


class GSM8kEnv(BaseEnv):

    name = "gsm8k"

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.percent_correct_buffer = list()
        self.eval_metrics = list()
        # Add tracking for wandb visualizations
        self.rollouts_for_wandb = []
        self.completion_lengths = []

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        env_config = BaseEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            max_token_length=2048,
            wandb_name="gsm8k",
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            ),
        ]

        return env_config, server_configs

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        if wandb_metrics is None:
            wandb_metrics = {}

        # Try to calculate percent_correct, pass if there's a division by zero
        try:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
        except ZeroDivisionError:
            # Skip if buffer is empty
            pass

        self.percent_correct_buffer = list()
        for item in self.eval_metrics:
            wandb_metrics[item[0]] = item[1]
        self.eval_metrics = list()
        # Call the parent method to handle the server metrics
        await super().wandb_log(wandb_metrics)
    async def setup(self):
        self.train = load_dataset("gsm8k", "main", split="train").shuffle(seed=42)
        test_data = load_dataset("gsm8k", "main", split="test").shuffle(seed=42)
        self.test = list()
        for item in test_data:
            self.test.append(
                {
                    "question": item["question"],
                    "gold_answer": item["answer"]
                    .split("#")[-1]
                    .strip()
                    .replace(",", ""),
                }
            )
        self.iter = 0

    def save_checkpoint(self, step, data=None):
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    async def rollout_and_score_eval(self, question: str, answer: str) -> number:
        completion = await self.server.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question},
            ],
            n=1,
            max_tokens=self.config.max_token_length,
            temperature=0.0,
            split="eval",
        )
        gold_parsed = parse(
            "\\boxed{" + answer + "}",
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
        answer_parsed = parse(
            completion.choices[0].message.content.split("</think>")[-1],
            extraction_config=[
                LatexExtractionConfig(
                    normalization_config=NormalizationConfig(
                        nits=False,
                        malformed_operators=False,
                        basic_latex=True,
                        equations=True,
                        boxed="all",
                        units=True,
                    ),
                    # Ensures that boxed is tried first
                    boxed_match_priority=0,
                    try_extract_without_anchor=False,
                )
            ],
            extraction_mode="first_match",
        )
        score = 1 if verify(answer_parsed, gold_parsed) else 0
        return score
    async def evaluate(self, *args, **kwargs):
        eval_tasks = []
        for item in self.test:
            eval_tasks.append(
                self.rollout_and_score_eval(item["question"], item["gold_answer"])
            )
        scores = await tqdm_asyncio.gather(*eval_tasks)
        self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))

    async def collect_trajectories(
        self, item: GSM8kRow
    ) -> Tuple[ScoredDataGroup, list[Item]]:
        user_message = {"role": "user", "content": item["question"]}
        gold_answer = (
            "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}"
        )

        chat_completions = await self.server.chat_completion(
            messages=[{"role": "system", "content": system_prompt}, user_message],
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
        )
        to_score = list()
        to_backlog = list()
        for i, chat_completion in enumerate(chat_completions.choices):
            messages = (
                {"role": "system", "content": system_prompt},
                user_message,
                {"role": "assistant", "content": chat_completion.message.content},
            )
            to_score.append(
                {
                    "messages": messages,
                    "gold_answer": gold_answer,
                    "finish_reason": chat_completion.finish_reason,
                }
            )
        to_postprocess = await self.score(to_score)
        return to_postprocess, to_backlog
    async def score(
        self, rollout_group_data
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        scores = ScoredDataGroup()
        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()
        gold_parsed = parse(
            rollout_group_data[0]["gold_answer"],
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
        if len(gold_parsed) != 0:
            # We require the answer to be provided in correct latex (no malformed operators)
            random.shuffle(rollout_group_data)
            for item in rollout_group_data:
                answer_parsed = parse(
                    item["messages"][-1]["content"].split("</think>")[-1],
                    extraction_config=[
                        LatexExtractionConfig(
                            normalization_config=NormalizationConfig(
                                nits=False,
                                malformed_operators=False,
                                basic_latex=True,
                                equations=True,
                                boxed="all",
                                units=True,
                            ),
                            # Ensures that boxed is tried first
                            boxed_match_priority=0,
                            try_extract_without_anchor=False,
                        )
                    ],
                    extraction_mode="first_match",
                )
                # Reward 1 if the content is the same as the ground truth, 0 otherwise
                reward = verify(answer_parsed, gold_parsed)
                out_dict = tokenize_for_trainer(
                    self.tokenizer, item["messages"], item["finish_reason"]
                )
                tokens = out_dict["tokens"]
                masks = out_dict["masks"]
                # remove obviously bad examples
                if len([1 for i in masks if i != -100]) < 10:
                    continue
                scores["tokens"].append(tokens)
                scores["masks"].append(masks)
                scores["scores"].append(1.0 if reward else -1.0)
                if len(scores["tokens"]) >= self.config.group_size:
                    break
            for score in scores["scores"]:
                self.percent_correct_buffer.append(max(score, 0))
            # check if all the same
            if all([score == 1 for score in scores["scores"]]):
                # Do length penalty :)
                token_lengths = [len(token) for token in scores["tokens"]]
                if max(token_lengths) == 0:
                    # Shouldn't happen, but don't crash the run just in case
                    return None

                # Get max allowed token length from config
                max_allowed_length = self.config.max_token_length
                # Set threshold at 50% of max_token_length - no penalty below this
                length_threshold = max_allowed_length * 0.5

                # Apply modified length penalty with threshold
                scores["scores"] = []
                for length in token_lengths:
                    if length <= length_threshold:
                        # No penalty for responses under threshold
                        scores["scores"].append(1.0)
                    else:
                        # Calculate how far we are between threshold and max as a percentage
                        percentage_of_range = (length - length_threshold) / (
                            max_allowed_length - length_threshold
                        )
                        # Cap at 1.0 in case length exceeds max_allowed_length
                        percentage_of_range = min(percentage_of_range, 1.0)
                        # Apply linear penalty scaling from 1.0 down to 0.0
                        scores["scores"].append(1.0 - percentage_of_range)
            if all([scores["scores"][0] == score for score in scores["scores"]]):
                return None  # If all the same, we return None
            return scores
        else:
            # If the gold solution is not parseable, we return None
            return None

    async def get_next_item(self) -> GSM8kRow:
        next_item = self.train[self.iter % len(self.train)]
        self.iter += 1
        return next_item


if __name__ == "__main__":
    GSM8kEnv.cli()