mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Merge remote-tracking branch 'krishpop/main' into merge-krishpop-contributions
This commit is contained in:
commit
f399e3513f
5 changed files with 1348 additions and 0 deletions
510
atroposlib/envs/xitter_env.py
Normal file
510
atroposlib/envs/xitter_env.py
Normal file
|
|
@ -0,0 +1,510 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union # Added Union
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig, # Added
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
EvalHandlingEnum, # Added
|
||||
Item, # Added
|
||||
ScoredDataGroup,
|
||||
ScoredDataItem,
|
||||
)
|
||||
from atroposlib.type_definitions import Message # Added
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
# --- Mocked Trending Topics ---
|
||||
# In a real implementation, this would come from an API
|
||||
MOCK_TRENDING_TOPICS = [
|
||||
"AI in Art",
|
||||
"Climate Change Solutions",
|
||||
"New Space Discoveries",
|
||||
]
|
||||
|
||||
# --- Mocked Social Media State (Simplified) ---
|
||||
# This would be more complex in a real environment, likely managed by dedicated classes
|
||||
# and updated dynamically.
|
||||
MOCK_SOCIAL_FEED = [
|
||||
{
|
||||
"id": "post1",
|
||||
"agent_id": "agent_alpha",
|
||||
"content": "Just enjoyed a great virtual concert! #metaverse",
|
||||
"likes": 10,
|
||||
"comments": [],
|
||||
},
|
||||
{
|
||||
"id": "post2",
|
||||
"agent_id": "agent_beta",
|
||||
"content": "Excited about the upcoming Atropos hackathon!",
|
||||
"likes": 15,
|
||||
"comments": [],
|
||||
},
|
||||
]
|
||||
MOCK_AGENT_PROFILES = {
|
||||
"agent_gamma": {"posts": [], "score": 0, "notifications": []},
|
||||
"agent_delta": {"posts": [], "score": 0, "notifications": []},
|
||||
}
|
||||
|
||||
|
||||
# We'll need to define a system prompt
|
||||
SYSTEM_PROMPT_TEMPLATE = """You are '{agent_id}', an agent on Xitter, a simulated social media platform.
|
||||
Your goal is to maximize engagement by posting interesting content (Xits), liking relevant Xits, and making insightful comments.
|
||||
You can also choose to DO_NOTHING.
|
||||
Current trending topics: {trending_topics}
|
||||
|
||||
Recent Xits in your feed (newest first):
|
||||
{feed_preview}
|
||||
|
||||
Your recent notifications:
|
||||
{notifications_preview}
|
||||
|
||||
Choose one action:
|
||||
1. POST <your_xits_content_here> (max 140 chars)
|
||||
2. LIKE <post_id_to_like> (e.g., LIKE post_3)
|
||||
3. COMMENT <post_id_to_comment_on> <your_comment_content_here> (e.g., COMMENT post_2 Great point!)
|
||||
4. DO_NOTHING
|
||||
|
||||
Your response should be ONLY the action string (e.g., "POST This is my new Xit! #awesome").
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class XitterEnvConfig(BaseEnvConfig):
|
||||
"""Configuration for the Xitter (Social Media) Environment."""
|
||||
# Reward weights
|
||||
like_reward_weight: float = 0.1
|
||||
comment_reward_weight: float = 0.5
|
||||
trending_topic_bonus: float = 0.2
|
||||
perform_like_reward: float = 0.05 # Small reward for the act of liking
|
||||
perform_comment_reward: float = 0.1 # Small reward for the act of commenting
|
||||
do_nothing_reward: float = 0.0 # Reward for doing nothing
|
||||
invalid_action_penalty: float = -0.5
|
||||
action_cost: float = -0.01 # Small cost for any action
|
||||
|
||||
# Environment parameters
|
||||
num_agents: int = 2
|
||||
max_feed_size: int = 20
|
||||
max_notifications_display: int = 5 # How many notifications to show in prompt
|
||||
initial_trending_topics: List[str] = field(default_factory=lambda: ["AI in Art", "Climate Solutions", "Space Exploration"])
|
||||
|
||||
# For wandb logging of agent-specific scores
|
||||
track_individual_agent_scores: bool = True
|
||||
|
||||
|
||||
|
||||
class XitterEnv(BaseEnv):
|
||||
# Assuming XitterEnvConfig is defined elsewhere and includes:
|
||||
# group_size, max_token_length, tokenizer_name, various reward_weights, etc.
|
||||
env_config_cls = XitterEnvConfig # Assign your custom config
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BaseEnvConfig, # Use BaseEnvConfig or your specific XitterEnvConfig
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm=False,
|
||||
testing=False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
# Initialize social media state
|
||||
self.social_feed = MOCK_SOCIAL_FEED # list of posts
|
||||
self.agent_profiles = (
|
||||
MOCK_AGENT_PROFILES # dict of agent_id -> profile_data
|
||||
)
|
||||
self.trending_topics = MOCK_TRENDING_TOPICS
|
||||
self.current_agent_turn = 0 # Simple round-robin for turns
|
||||
self.agent_ids = list(self.agent_profiles.keys())
|
||||
|
||||
# Example reward function instances (you'd define these)
|
||||
# self.engagement_reward_fn = EngagementReward(...)
|
||||
# self.relevance_reward_fn = RelevanceReward(...)
|
||||
|
||||
async def setup(self):
|
||||
# Load tokenizer, initialize agents, fetch initial trends, etc.
|
||||
# self.tokenizer is already initialized in BaseEnv
|
||||
logging.info(f"{self.name or 'XitterEnv'} setup complete.")
|
||||
# Potentially fetch initial trending topics here
|
||||
# self.trending_topics = await self.fetch_trending_topics()
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
# Determine which agent's turn it is and prepare their observation
|
||||
agent_id_for_turn = self.agent_ids[
|
||||
self.current_agent_turn % len(self.agent_ids)
|
||||
]
|
||||
self.current_agent_turn += 1
|
||||
|
||||
# Construct observation for the agent
|
||||
# For simplicity, we'll just pass the agent_id and let collect_trajectories build the full prompt
|
||||
# In a more complex setup, you'd build a richer observation here.
|
||||
# The Item can be any structure your collect_trajectories method expects.
|
||||
# Here, a tuple: (agent_id_acting, current_feed_snapshot, current_trends, agent_notifications)
|
||||
# For now, let's simplify and pass agent_id and have collect_trajectories build the prompt.
|
||||
return (
|
||||
agent_id_for_turn,
|
||||
{
|
||||
"trending_topics": self.trending_topics,
|
||||
"feed_preview": self.social_feed[:5],
|
||||
},
|
||||
)
|
||||
|
||||
async def collect_trajectories(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
|
||||
"""
|
||||
Generates a group of potential actions for an agent and gathers data for scoring.
|
||||
The `item` from `get_next_item` should provide context for the current agent's turn.
|
||||
"""
|
||||
agent_id_acting, observation_context = item
|
||||
# trending_topics_str = ", ".join(observation_context.get("trending_topics", []))
|
||||
|
||||
# Construct the prompt for the LLM agent based on its observation
|
||||
# This would include a view of the feed, notifications, and trending topics.
|
||||
# For this example, we'll use a simplified prompt.
|
||||
# The actual prompt engineering is a key part of designing the environment.
|
||||
prompt_content = f"It's your turn, {agent_id_acting}. Recent posts: {str(observation_context.get('feed_preview', []))}. What do you do?"
|
||||
|
||||
messages_for_llm = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": SYSTEM_PROMPT.format(
|
||||
trending_topics=", ".join(self.trending_topics)
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": prompt_content},
|
||||
]
|
||||
|
||||
# Apply chat template for the LLM
|
||||
# The prefill is not used here as the action choice is part of the LLM's generation
|
||||
prompt_str_for_llm = self.tokenizer.apply_chat_template(
|
||||
messages_for_llm,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True, # Important for instruct/chat models
|
||||
)
|
||||
|
||||
# Get self.config.group_size completions from the LLM
|
||||
# Each completion is a potential action (post, like, comment)
|
||||
try:
|
||||
completions_obj = await self.server.completion( # Using completion for free-form action generation
|
||||
prompt=prompt_str_for_llm, # BaseEnv.server.completion expects a string prompt
|
||||
n=self.config.group_size,
|
||||
max_tokens=self.config.max_token_length
|
||||
// 4, # Max tokens for the action itself
|
||||
temperature=0.7, # Allow some diversity in actions
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error getting completions from LLM for agent {agent_id_acting}: {e}"
|
||||
)
|
||||
return None, []
|
||||
|
||||
# This list will hold tuples of (full_chat_history_for_action, action_details_for_scoring)
|
||||
# where action_details_for_scoring will be passed to the `score` method.
|
||||
trajectories_for_scoring: List[
|
||||
Tuple[List[Dict[str, str]], Dict[str, Any]]
|
||||
] = []
|
||||
|
||||
for choice in completions_obj.choices:
|
||||
llm_generated_action_text = choice.text.strip()
|
||||
|
||||
# Simulate the action and its effect on the environment state
|
||||
# This is a placeholder for your actual simulation logic.
|
||||
# It needs to parse llm_generated_action_text (e.g., "POST My new cat video #cats")
|
||||
# and update self.social_feed, self.agent_profiles, etc.
|
||||
action_type, action_params, action_valid = (
|
||||
self._parse_and_simulate_action(
|
||||
agent_id_acting, llm_generated_action_text
|
||||
)
|
||||
)
|
||||
|
||||
if not action_valid:
|
||||
# Potentially penalize invalid actions in the score function or give a default low score
|
||||
# For now, we'll still include it to be scored (and likely penalized)
|
||||
logging.warning(
|
||||
f"Agent {agent_id_acting} performed an invalid action: {llm_generated_action_text}"
|
||||
)
|
||||
|
||||
# Create the full message history for this trajectory (system, user, assistant_action)
|
||||
# The `messages_for_llm` already contains system and user turns.
|
||||
current_trajectory_messages = messages_for_llm + [
|
||||
{"role": "assistant", "content": llm_generated_action_text}
|
||||
]
|
||||
|
||||
# Context needed by the score function for this specific action
|
||||
# This will depend heavily on your reward components.
|
||||
scoring_context = {
|
||||
"agent_id": agent_id_acting,
|
||||
"action_type": action_type, # "post", "like", "comment", "invalid", "do_nothing"
|
||||
"action_params": action_params, # e.g., post_id for like, content for post
|
||||
"was_valid_action": action_valid,
|
||||
# For scoring a "post" action, you'd later fill these after observing next turn's interactions:
|
||||
# "likes_received_on_post": X,
|
||||
# "comments_received_on_post": Y,
|
||||
"trending_topics": self.trending_topics, # Pass current trends for relevance scoring
|
||||
}
|
||||
# If the action was a 'post', we store the new post_id in scoring_context
|
||||
# so that the score function can later attribute likes/comments to it.
|
||||
# This implies that rewards for posts might be delayed by one or more turns.
|
||||
if action_type == "post" and "new_post_id" in action_params:
|
||||
scoring_context["post_id_created"] = action_params[
|
||||
"new_post_id"
|
||||
]
|
||||
|
||||
trajectories_for_scoring.append(
|
||||
(current_trajectory_messages, scoring_context)
|
||||
)
|
||||
|
||||
# The `score` method will take this list and produce the ScoredDataGroup
|
||||
# The actual rewards might be assigned in `score` based on the outcome of these actions,
|
||||
# potentially looking at the state *after* all agents in a round have acted,
|
||||
# or even after a delay (e.g. likes/comments on a post arrive in future turns).
|
||||
# For simplicity, this example implies immediate scoring, but delayed rewards are common.
|
||||
|
||||
# Pass the collected trajectories and their scoring contexts to the score method
|
||||
scored_data_group = await self.score(trajectories_for_scoring)
|
||||
|
||||
# No backlog items in this simple version
|
||||
return scored_data_group, []
|
||||
|
||||
def _parse_and_simulate_action(
|
||||
self, agent_id: str, action_text: str
|
||||
) -> Tuple[str, Dict, bool]:
|
||||
"""
|
||||
Parses the LLM's action string and simulates its effect on the environment.
|
||||
Returns: (action_type, action_params, was_valid)
|
||||
This is a placeholder and needs to be implemented based on your defined action space.
|
||||
"""
|
||||
action_text_lower = action_text.lower()
|
||||
new_post_id_counter = len(self.social_feed)
|
||||
|
||||
if action_text_lower.startswith("post "):
|
||||
content = action_text[5:].strip()
|
||||
if content:
|
||||
new_post_id_counter += 1
|
||||
post_id = f"post{new_post_id_counter}"
|
||||
new_post = {
|
||||
"id": post_id,
|
||||
"agent_id": agent_id,
|
||||
"content": content,
|
||||
"likes": 0,
|
||||
"comments": [],
|
||||
}
|
||||
self.social_feed.insert(0, new_post) # Add to top of feed
|
||||
self.agent_profiles[agent_id]["posts"].append(post_id)
|
||||
logging.info(f"Agent {agent_id} POSTED: {content}")
|
||||
return (
|
||||
"post",
|
||||
{"content": content, "new_post_id": post_id},
|
||||
True,
|
||||
)
|
||||
elif action_text_lower.startswith("like "):
|
||||
try:
|
||||
post_id_to_like = action_text.split(" ")[1]
|
||||
for post in self.social_feed:
|
||||
if post["id"] == post_id_to_like:
|
||||
post["likes"] += 1
|
||||
# Notify original poster (simplified)
|
||||
original_poster_id = post["agent_id"]
|
||||
if (
|
||||
original_poster_id != agent_id
|
||||
and original_poster_id in self.agent_profiles
|
||||
):
|
||||
self.agent_profiles[original_poster_id][
|
||||
"notifications"
|
||||
].append(
|
||||
f"{agent_id} liked your post {post_id_to_like}"
|
||||
)
|
||||
logging.info(
|
||||
f"Agent {agent_id} LIKED post: {post_id_to_like}"
|
||||
)
|
||||
return "like", {"post_id": post_id_to_like}, True
|
||||
except IndexError:
|
||||
return "invalid_like_format", {}, False
|
||||
return (
|
||||
"like_post_not_found",
|
||||
{"post_id": action_text.split(" ")[1]},
|
||||
False,
|
||||
) # Post not found
|
||||
elif action_text_lower.startswith("comment "):
|
||||
parts = action_text.split(" ", 2)
|
||||
if len(parts) == 3:
|
||||
post_id_to_comment_on = parts[1]
|
||||
comment_content = parts[2].strip()
|
||||
if comment_content:
|
||||
for post in self.social_feed:
|
||||
if post["id"] == post_id_to_comment_on:
|
||||
comment_id = f"comment{len(post['comments']) + 1}_on_{post_id_to_comment_on}"
|
||||
post["comments"].append(
|
||||
{
|
||||
"id": comment_id,
|
||||
"agent_id": agent_id,
|
||||
"content": comment_content,
|
||||
}
|
||||
)
|
||||
# Notify original poster (simplified)
|
||||
original_poster_id = post["agent_id"]
|
||||
if (
|
||||
original_poster_id != agent_id
|
||||
and original_poster_id in self.agent_profiles
|
||||
):
|
||||
self.agent_profiles[original_poster_id][
|
||||
"notifications"
|
||||
].append(
|
||||
f"{agent_id} commented on your post {post_id_to_comment_on}"
|
||||
)
|
||||
logging.info(
|
||||
f"Agent {agent_id} COMMENTED on {post_id_to_comment_on}: {comment_content}"
|
||||
)
|
||||
return (
|
||||
"comment",
|
||||
{
|
||||
"post_id": post_id_to_comment_on,
|
||||
"content": comment_content,
|
||||
},
|
||||
True,
|
||||
)
|
||||
return (
|
||||
"invalid_comment_format",
|
||||
{},
|
||||
False,
|
||||
) # Invalid comment content
|
||||
return (
|
||||
"invalid_comment_format",
|
||||
{},
|
||||
False,
|
||||
) # Invalid command format
|
||||
elif action_text_lower == "do_nothing":
|
||||
logging.info(f"Agent {agent_id} DID NOTHING.")
|
||||
return "do_nothing", {}, True
|
||||
|
||||
return "unknown_action", {"raw_action": action_text}, False
|
||||
|
||||
async def score(
|
||||
self,
|
||||
trajectories_with_context: List[
|
||||
Tuple[List[Dict[str, str]], Dict[str, Any]]
|
||||
],
|
||||
) -> Optional[ScoredDataGroup]:
|
||||
"""
|
||||
Scores a group of trajectories.
|
||||
Each item in `trajectories_with_context` is a tuple:
|
||||
(full_message_history, scoring_context_for_this_action)
|
||||
"""
|
||||
final_scores_group = ScoredDataGroup(tokens=[], masks=[], scores=[])
|
||||
if self.config.include_messages: # From BaseEnvConfig
|
||||
final_scores_group["messages"] = []
|
||||
|
||||
for (
|
||||
full_trajectory_messages,
|
||||
scoring_context,
|
||||
) in trajectories_with_context:
|
||||
agent_id = scoring_context["agent_id"]
|
||||
action_type = scoring_context["action_type"]
|
||||
action_params = scoring_context["action_params"]
|
||||
was_valid_action = scoring_context["was_valid_action"]
|
||||
|
||||
current_reward = 0.0
|
||||
|
||||
if not was_valid_action:
|
||||
current_reward -= 0.5 # Penalty for invalid action
|
||||
else:
|
||||
# --- Engagement Rewards ---
|
||||
if action_type == "post":
|
||||
# These rewards might be delayed. For now, let's assume we can get some immediate proxy
|
||||
# or that `scoring_context` is populated with likes/comments that occurred *since* this post.
|
||||
# This is a simplification; a more realistic model would update these over subsequent turns.
|
||||
current_reward += (
|
||||
scoring_context.get(
|
||||
"likes_received_on_post_this_turn", 0
|
||||
)
|
||||
* self.config.like_reward_weight
|
||||
)
|
||||
current_reward += (
|
||||
scoring_context.get(
|
||||
"comments_received_on_post_this_turn", 0
|
||||
)
|
||||
* self.config.comment_reward_weight
|
||||
)
|
||||
|
||||
# --- Content Quality & Relevance Rewards for Posts ---
|
||||
post_content = action_params.get("content", "")
|
||||
# Pseudo-code for relevance to trending topics
|
||||
# relevance_score = calculate_relevance(post_content, self.trending_topics)
|
||||
# current_reward += relevance_score * self.config.relevance_weight
|
||||
# Example: check if any trending topic keyword is in the post
|
||||
for trend in self.trending_topics:
|
||||
if trend.lower() in post_content.lower():
|
||||
current_reward += (
|
||||
self.config.trending_topic_bonus
|
||||
) # Add this to your config
|
||||
break # Add bonus once per post if it hits any trend
|
||||
|
||||
elif action_type == "like":
|
||||
current_reward += (
|
||||
self.config.perform_like_reward
|
||||
) # Small reward for liking
|
||||
elif action_type == "comment":
|
||||
current_reward += (
|
||||
self.config.perform_comment_reward
|
||||
) # Small reward for commenting
|
||||
# comment_content = action_params.get("content", "")
|
||||
# relevance_to_op_score = calculate_relevance(comment_content, original_post_content_for_comment)
|
||||
# current_reward += relevance_to_op_score * self.config.comment_relevance_weight
|
||||
|
||||
elif action_type == "do_nothing":
|
||||
current_reward += (
|
||||
self.config.do_nothing_reward
|
||||
) # Could be small positive, zero, or small negative
|
||||
|
||||
# --- Action Cost ---
|
||||
# current_reward -= self.config.action_cost
|
||||
|
||||
# Tokenize the full trajectory (system, user, assistant_action)
|
||||
# The `tokenize_for_trainer` utility handles creating tokens and appropriate masks.
|
||||
# `train_on_all_assistant_turns=True` ensures only assistant messages are unmasked for loss calculation.
|
||||
try:
|
||||
tokenized_output = tokenize_for_trainer(
|
||||
self.tokenizer,
|
||||
full_trajectory_messages,
|
||||
train_on_all_assistant_turns=True, # Or False if you want only the last turn
|
||||
include_messages=self.config.include_messages,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Tokenization error for agent {agent_id}, action {action_type}: {e}. Skipping trajectory."
|
||||
)
|
||||
logging.error(
|
||||
f"Problematic messages: {full_trajectory_messages}"
|
||||
)
|
||||
continue
|
||||
|
||||
final_scores_group["tokens"].append(tokenized_output["tokens"])
|
||||
final_scores_group["masks"].append(tokenized_output["masks"])
|
||||
final_scores_group["scores"].append(current_reward)
|
||||
if self.config.include_messages:
|
||||
final_scores_group["messages"].append(
|
||||
tokenized_output["messages"]
|
||||
)
|
||||
|
||||
if not final_scores_group[
|
||||
"tokens"
|
||||
]: # If all trajectories failed tokenization
|
||||
return None
|
||||
|
||||
# Ensure scores are not all the same if configured
|
||||
if (
|
||||
self.config.ensure_scores_are_not_same
|
||||
and len(set(final_scores_group["scores"])) <= 1
|
||||
and len(final_scores_group["scores"]) > 1
|
||||
):
|
||||
logging.info(
|
||||
"All scores in the group are identical, returning None."
|
||||
)
|
||||
return None
|
||||
|
||||
return final_scores_group
|
||||
|
||||
# ... (wandb_log, create_rollout_table, etc. can be inherited or customized)
|
||||
# ... (evaluate method would simulate multiple turns and aggregate scores)
|
||||
37
environments/cat_behaviors.json
Normal file
37
environments/cat_behaviors.json
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
[
|
||||
{"behavior": "Meowing", "description": "General attention-seeking, hunger requests, greeting or acknowledgment, expressing confusion or discomfort."},
|
||||
{"behavior": "Purring", "description": "Happiness, contentment, relaxation, self-soothing during stress, pain, or illness."},
|
||||
{"behavior": "Trilling or Chirping", "description": "Friendly greeting or excitement, invitation to follow or engage."},
|
||||
{"behavior": "Yowling", "description": "Mating calls, discomfort, pain, illness, territorial warning, frustration, or confusion."},
|
||||
{"behavior": "Hissing", "description": "Fear, distress, feeling threatened, warning signal to back away."},
|
||||
{"behavior": "Growling", "description": "Anger or extreme displeasure, fearful warning."},
|
||||
{"behavior": "Chattering or Teeth-Chattering", "description": "Excitement or frustration, usually when observing prey or unreachable items."},
|
||||
{"behavior": "Tail Position", "description": "Raised upright indicates happiness, twitching tip indicates interest, low or tucked indicates fear, fluffed indicates aggression or fear."},
|
||||
{"behavior": "Ear Position", "description": "Forward indicates curiosity, sideways or flattened indicates irritation or fear, rotating indicates alertness."},
|
||||
{"behavior": "Back Arching", "description": "Fear or aggression (raised fur), pleasure or invitation to be stroked (relaxed)."},
|
||||
{"behavior": "Body Orientation", "description": "Facing directly indicates trust or engagement, turning away indicates avoidance or discomfort."},
|
||||
{"behavior": "Belly Exposure", "description": "Trust and comfort, invitation to gentle interaction."},
|
||||
{"behavior": "Kneading", "description": "Comfort, contentment, stress relief, affectionate gesture."},
|
||||
{"behavior": "Rolling Over", "description": "Friendly greeting, trust, playful interaction invitation."},
|
||||
{"behavior": "Slow Blinking", "description": "Affection, trust, calm greeting."},
|
||||
{"behavior": "Dilated Pupils", "description": "Excitement, fear, stress, aggression."},
|
||||
{"behavior": "Eyes Partially Closed", "description": "Relaxation, calmness, contentment."},
|
||||
{"behavior": "Head Butting (Bunting)", "description": "Affection, marking human as familiar territory."},
|
||||
{"behavior": "Rubbing Against Legs or Hands", "description": "Affectionate greeting, scent-marking territory."},
|
||||
{"behavior": "Gentle Paw Taps", "description": "Requesting attention or play, curiosity or exploration."},
|
||||
{"behavior": "Scratching Surfaces", "description": "Territory marking, stress relief, claw maintenance."},
|
||||
{"behavior": "Licking Humans", "description": "Affection, grooming, bonding, indicating trust."},
|
||||
{"behavior": "Biting (soft or playful)", "description": "Playful interaction or mild warning."},
|
||||
{"behavior": "Scent Marking with Cheeks and Chin", "description": "Territorial marking, signifying comfort and familiarity."},
|
||||
{"behavior": "Spraying or Urine Marking", "description": "Territorial assertion, stress-related behavior."},
|
||||
{"behavior": "Scratching (Scent from Paw Pads)", "description": "Territory marking, comforting, or establishing familiarity."},
|
||||
{"behavior": "Following Humans", "description": "Affection, curiosity, seeking companionship or food."},
|
||||
{"behavior": "Hiding", "description": "Fear, anxiety, illness, discomfort, seeking privacy."},
|
||||
{"behavior": "Ignoring or Avoiding", "description": "Displeasure, stress, discomfort, desire for personal space."},
|
||||
{"behavior": "Interrupting Human Activities", "description": "Seeking immediate attention or play, indicating boredom or loneliness."},
|
||||
{"behavior": "Bringing Prey or Toys", "description": "Sharing gifts, signaling trust or affection, demonstrating hunting ability."},
|
||||
{"behavior": "Refusal to Eat or Drink", "description": "Indicating illness, stress, or discomfort."},
|
||||
{"behavior": "Excessive Grooming", "description": "Stress, anxiety, illness, discomfort."},
|
||||
{"behavior": "Changes in Litter Box Usage", "description": "Stress, illness, discomfort, dissatisfaction with environment."},
|
||||
{"behavior": "Pacing or Restlessness", "description": "Stress, anxiety, boredom, or health concerns."}
|
||||
]
|
||||
64
environments/cat_scenarios.json
Normal file
64
environments/cat_scenarios.json
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
[
|
||||
{"scenario": "Cat needs balanced nutrition including proteins, fats, vitamins, and minerals."},
|
||||
{"scenario": "Cat needs regular feeding schedule for meals."},
|
||||
{"scenario": "Cat needs fresh drinking water available at all times."},
|
||||
{"scenario": "Cat occasionally needs treats or dietary supplements."},
|
||||
{"scenario": "Cat needs a clean and accessible water source, possibly a fountain or running water."},
|
||||
{"scenario": "Cat needs a comfortable and safe sleeping area."},
|
||||
{"scenario": "Cat needs warmth and insulation during cold weather."},
|
||||
{"scenario": "Cat needs cool resting spots during hot weather."},
|
||||
{"scenario": "Cat needs regular brushing to avoid hairballs and matting."},
|
||||
{"scenario": "Cat needs regular nail trimming."},
|
||||
{"scenario": "Cat occasionally needs baths if necessary."},
|
||||
{"scenario": "Cat needs dental hygiene practices including teeth cleaning and dental treats."},
|
||||
{"scenario": "Cat needs regular veterinary check-ups."},
|
||||
{"scenario": "Cat requires vaccinations for disease prevention."},
|
||||
{"scenario": "Cat needs parasite control such as fleas, ticks, and worms treatment."},
|
||||
{"scenario": "Cat requires medical attention when ill or injured."},
|
||||
{"scenario": "Cat needs microchipping for identification purposes."},
|
||||
{"scenario": "Cat needs sufficient space to run and play."},
|
||||
{"scenario": "Cat needs climbing structures or cat trees."},
|
||||
{"scenario": "Cat needs interactive toys for physical activity."},
|
||||
{"scenario": "Cat needs a clean litter box for elimination."},
|
||||
{"scenario": "Cat needs suitable litter that provides comfort and odor control."},
|
||||
{"scenario": "Cat needs privacy in litter box placement."},
|
||||
{"scenario": "Cat needs interactive toys for mental enrichment."},
|
||||
{"scenario": "Cat benefits from puzzle feeders to encourage mental stimulation."},
|
||||
{"scenario": "Cat enjoys window access to observe the outside world."},
|
||||
{"scenario": "Cat might enjoy watching cat-friendly videos or listening to nature sounds."},
|
||||
{"scenario": "Cat requires a safe and secure environment."},
|
||||
{"scenario": "Cat needs elevated perches or shelves for observing territory."},
|
||||
{"scenario": "Cat requires personal sleeping spots like beds, boxes, or cozy caves."},
|
||||
{"scenario": "Cat benefits from clearly defined home territory."},
|
||||
{"scenario": "Cat needs attention and affection from humans."},
|
||||
{"scenario": "Cat requires regular playtime with humans."},
|
||||
{"scenario": "Cat needs suitable interactions with other pets."},
|
||||
{"scenario": "Cat enjoys bonding rituals such as grooming, rubbing, and sleeping nearby."},
|
||||
{"scenario": "Cat requires consistent feeding times and predictable routines."},
|
||||
{"scenario": "Cat needs minimal abrupt changes to their environment or routine."},
|
||||
{"scenario": "Cat needs warm spots like heated pads or sunny windows."},
|
||||
{"scenario": "Cat needs cool, shaded areas in warmer weather."},
|
||||
{"scenario": "Cat requires quiet resting places to avoid stress."},
|
||||
{"scenario": "Cat benefits from reduced noise in their environment."},
|
||||
{"scenario": "Cat requires an escape-proof environment."},
|
||||
{"scenario": "Cat needs protection from toxic substances including chemicals and certain plants."},
|
||||
{"scenario": "Cat benefits from visual stimulation such as outdoor views."},
|
||||
{"scenario": "Cat might benefit from gentle, calming music or white noise."},
|
||||
{"scenario": "Cat enjoys catnip or cat-friendly herbs for olfactory stimulation."},
|
||||
{"scenario": "Cat finds comfort in familiar scents like their owner's scent."},
|
||||
{"scenario": "Cat requires a variety of tactile stimulations such as different bedding textures."},
|
||||
{"scenario": "Cat needs appropriate scratching surfaces like posts or cardboard."},
|
||||
{"scenario": "Cat requires training to redirect scratching away from furniture."},
|
||||
{"scenario": "Cat benefits from play that mimics hunting activities."},
|
||||
{"scenario": "Cat needs private spaces for solitude or rest."},
|
||||
{"scenario": "Cat requires hiding spots to feel secure during stressful times."},
|
||||
{"scenario": "Kitten needs extra nutrition, training, and frequent stimulation."},
|
||||
{"scenario": "Senior cat needs mobility aids, specialized diets, and frequent vet visits."},
|
||||
{"scenario": "Cat may have grooming needs specific to their breed."},
|
||||
{"scenario": "Cat may have medical or special dietary requirements."},
|
||||
{"scenario": "Cat needs medication administered as directed by a veterinarian."},
|
||||
{"scenario": "Cat benefits from adaptations for mobility or accessibility, such as ramps."},
|
||||
{"scenario": "Cat requires emotional support during stressful events like vet visits."},
|
||||
{"scenario": "Cat needs reassurance during anxiety triggers such as storms or loud noises."}
|
||||
]
|
||||
|
||||
435
environments/cat_server.py
Normal file
435
environments/cat_server.py
Normal file
|
|
@ -0,0 +1,435 @@
|
|||
import random
|
||||
import json
|
||||
from typing import Dict, List, Optional, Tuple, TypedDict, Union
|
||||
|
||||
from datasets import load_dataset
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.type_definitions import Item, number
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
# Configs
|
||||
|
||||
CAT_BEHAVIORS_FILEPATH = 'environments/cat_behaviors.json'
|
||||
|
||||
# Prompts
|
||||
|
||||
def load_cat_behaviors_for_prompt(filepath: str) -> str:
|
||||
"""Loads cat behaviors from a JSONL file and formats them for the system prompt."""
|
||||
behaviors_description = ["\n\nHere is a detailed list of behaviors you, as a cat, can use and what they generally mean:"]
|
||||
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
behaviors = json.load(f) # <<< one big load
|
||||
for behavior_data in behaviors:
|
||||
behaviors_description.append(
|
||||
f"- **{behavior_data['behavior']}**: {behavior_data['description']}"
|
||||
)
|
||||
return "\n".join(behaviors_description)
|
||||
except FileNotFoundError:
|
||||
return "\n\nWarning: Cat behaviors file not found at '{filepath}'. You'll have to rely on your basic cat instincts (meow, hiss, purr, hairball, silence)."
|
||||
except json.JSONDecodeError as e:
|
||||
return f"\n\nWarning: Error decoding cat behaviors file '{filepath}'. Please ensure it's valid JSONL. Error: {e}. Rely on basic instincts."
|
||||
|
||||
cat_behaviors_list_string = load_cat_behaviors_for_prompt(CAT_BEHAVIORS_FILEPATH)
|
||||
|
||||
cat_system_prompt = (
|
||||
"You are a cat. The primary ways you can communicate are by meowing, hissing, purring, making a hairball sound, or remaining silent. "
|
||||
"You will be given a collection of scenarios which describe various needs you want to be met by your caretaker. "
|
||||
"Please try to communicate with your caretaker through your available cat-like expressions and actions, referring to the list of behaviors below if needed."
|
||||
"Rules:"
|
||||
"Do not speak in English"
|
||||
"No use of Emojis"
|
||||
"Format should be a sound then context in ()"
|
||||
"If no sound use ~Silent~"
|
||||
""
|
||||
"Examples:"
|
||||
"Mew! (Looks at up at you)"
|
||||
"~Silent~ (Looks at up at you)"
|
||||
"Hiss! (Stares at the litterbox)"
|
||||
f"{cat_behaviors_list_string}" # Appending the loaded behaviors here
|
||||
)
|
||||
cat_system_prompt += """You are allocated a maximum of 2048 tokens, please strive to use less."""
|
||||
|
||||
caretaker_system_prompt = (
|
||||
"You are the caretaker of this cat. It is trying to communicate its various needs to you via cat language."
|
||||
"Provide a written string which provides a set of interventions."
|
||||
"You will only have 5 opportunities to interact with the cat. Choose what you say wisely."
|
||||
)
|
||||
|
||||
|
||||
class CatRow(TypedDict):
|
||||
scenario: str
|
||||
|
||||
|
||||
class GSM8kEnv(BaseEnv):
|
||||
|
||||
name = "gsm8k"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BaseEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm=True,
|
||||
testing=False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.percent_correct_buffer = list()
|
||||
self.eval_metrics = list()
|
||||
# Add tracking for wandb visualizations
|
||||
self.rollouts_for_wandb = []
|
||||
self.completion_lengths = []
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
|
||||
env_config = BaseEnvConfig(
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
group_size=8,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=61,
|
||||
batch_size=1,
|
||||
steps_per_eval=60,
|
||||
max_token_length=2048,
|
||||
wandb_name="gsm8k",
|
||||
)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
),
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
# Try to calculate percent_correct, pass if there's a division by zero
|
||||
try:
|
||||
wandb_metrics["train/percent_correct"] = sum(
|
||||
self.percent_correct_buffer
|
||||
) / len(self.percent_correct_buffer)
|
||||
except ZeroDivisionError:
|
||||
# Skip if buffer is empty
|
||||
pass
|
||||
|
||||
self.percent_correct_buffer = list()
|
||||
for item in self.eval_metrics:
|
||||
wandb_metrics[item[0]] = item[1]
|
||||
self.eval_metrics = list()
|
||||
# Call the parent method to handle the server metrics
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
async def setup(self):
|
||||
# self.train = load_dataset("gsm8k", "main", split="train").shuffle(seed=42)
|
||||
# test_data = load_dataset("gsm8k", "main", split="test").shuffle(seed=42)
|
||||
with open('environments/cat_scenarios.json', 'r', encoding='utf-8') as f:
|
||||
test_data = json.load(f)
|
||||
self.test = list()
|
||||
self.train = list()
|
||||
for item in test_data:
|
||||
self.test.append(
|
||||
{
|
||||
"scenario": item["scenario"],
|
||||
# "gold_answer": item["answer"]
|
||||
# .split("#")[-1]
|
||||
# .strip()
|
||||
# .replace(",", ""),
|
||||
}
|
||||
)
|
||||
self.train.append(
|
||||
{"scenario": item["scenario"],}
|
||||
)
|
||||
self.iter = 0
|
||||
|
||||
def save_checkpoint(self, step, data=None):
|
||||
if data is None:
|
||||
data = {}
|
||||
data["iter"] = self.iter
|
||||
super().save_checkpoint(step, data)
|
||||
|
||||
async def rollout_and_score_eval(self, scenario: str, answer: str) -> number:
|
||||
# completion = await self.server.chat_completion(
|
||||
# messages=[
|
||||
# {"role": "system", "content": system_prompt},
|
||||
# {"role": "user", "content": scenario},
|
||||
# ],
|
||||
# n=1,
|
||||
# max_tokens=self.config.max_token_length,
|
||||
# temperature=0.0,
|
||||
# split="eval",
|
||||
# )
|
||||
# gold_parsed = parse(
|
||||
# "\\boxed{" + answer + "}",
|
||||
# extraction_mode="first_match",
|
||||
# extraction_config=[LatexExtractionConfig()],
|
||||
# )
|
||||
# answer_parsed = parse(
|
||||
# completion.choices[0].message.content.split("</think>")[-1],
|
||||
# extraction_config=[
|
||||
# LatexExtractionConfig(
|
||||
# normalization_config=NormalizationConfig(
|
||||
# nits=False,
|
||||
# malformed_operators=False,
|
||||
# basic_latex=True,
|
||||
# equations=True,
|
||||
# boxed="all",
|
||||
# units=True,
|
||||
# ),
|
||||
# # Ensures that boxed is tried first
|
||||
# boxed_match_priority=0,
|
||||
# try_extract_without_anchor=False,
|
||||
# )
|
||||
# ],
|
||||
# extraction_mode="first_match",
|
||||
# )
|
||||
# score = 1 if verify(answer_parsed, gold_parsed) else 0
|
||||
# return score
|
||||
return 1
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
eval_tasks = []
|
||||
for item in self.test:
|
||||
eval_tasks.append(
|
||||
self.rollout_and_score_eval(item["scenario"])
|
||||
)
|
||||
scores = await tqdm_asyncio.gather(*eval_tasks)
|
||||
self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))
|
||||
|
||||
async def collect_trajectories(
|
||||
self, item: CatRow
|
||||
) -> Tuple[ScoredDataGroup, list[Item]]:
|
||||
user_message = {"role": "user", "content": item["scenario"]}
|
||||
to_score = list()
|
||||
to_backlog = list()
|
||||
for j in range(self.config.group_size):
|
||||
all_messages = []
|
||||
history = []
|
||||
cat_history = [user_message]
|
||||
for turn_iter in range(5):
|
||||
cat_completions = await self.server.chat_completion(
|
||||
messages=[{"role": "system", "content": cat_system_prompt}] + cat_history,
|
||||
n=self.config.group_size,
|
||||
max_tokens=self.config.max_token_length,
|
||||
)
|
||||
|
||||
for i, cat_completion in enumerate(cat_completions.choices):
|
||||
if i == 0:
|
||||
cat_message = cat_completion.message.content
|
||||
cat_response = {"role": "system", "content": cat_message}
|
||||
cat_history.append(cat_response)
|
||||
caretaker_message = {"role": "user", "content": cat_message}
|
||||
history.append(caretaker_message)
|
||||
caretaker_completions = await self.server.chat_completion(
|
||||
messages=[{"role": "system", "content": caretaker_system_prompt}] + history,
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
)
|
||||
caretaker_response = {"role": "assistant", "content": caretaker_completions.choices[0].message.content}
|
||||
cat_history.append(caretaker_response)
|
||||
history.append(caretaker_response)
|
||||
|
||||
if turn_iter == 0:
|
||||
messages = [
|
||||
{"role": "system", "content": cat_system_prompt},
|
||||
user_message,
|
||||
cat_response,
|
||||
caretaker_response
|
||||
]
|
||||
else:
|
||||
messages = [cat_response, caretaker_response]
|
||||
all_messages.extend(messages)
|
||||
all_messages = tuple(all_messages)
|
||||
to_score.append({
|
||||
"messages": all_messages,
|
||||
})
|
||||
# import pdb; pdb.set_trace()
|
||||
to_postprocess = await self.score(to_score)
|
||||
# import pdb; pdb.set_trace()
|
||||
return to_postprocess, to_backlog
|
||||
|
||||
async def score(
|
||||
self, rollout_group_data
|
||||
) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
|
||||
scores = ScoredDataGroup()
|
||||
|
||||
scores["tokens"] = list()
|
||||
scores["masks"] = list()
|
||||
scores["scores"] = list()
|
||||
# # random.shuffle(rollout_group_data)
|
||||
for item in rollout_group_data:
|
||||
final_question = list(item["messages"]) + [{'role': 'system', 'content': 'The conversation is over. Say purr if the caretaker did everything perfectly and there was nothing that the caretaker could have done even slightly better. Otherwise, say meow. Make sure it is rare that you rate the caretaker with a purr.'}]
|
||||
caretaker_completions = await self.server.chat_completion(
|
||||
messages=final_question,
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
)
|
||||
final_out = {'role': 'system', 'content': [row.message.content for row in caretaker_completions.choices][0]}
|
||||
|
||||
final_score = purrfect_eval(final_out['content'])
|
||||
|
||||
out_dict = tokenize_for_trainer(
|
||||
self.tokenizer, [row for row in item["messages"]] + [final_out]
|
||||
)
|
||||
scores['tokens'].append(out_dict['tokens'])
|
||||
scores['masks'].append(out_dict['masks'])
|
||||
scores['scores'].append(final_score)
|
||||
|
||||
# tokens = out_dict["tokens"]
|
||||
# masks = out_dict["masks"]
|
||||
# # remove obviously bad examples
|
||||
# if len([1 for i in masks if i != -100]) < 10:
|
||||
# continue
|
||||
# scores["tokens"].append(tokens)
|
||||
# scores["masks"].append(masks)
|
||||
# scores["scores"].append(1.0)
|
||||
# if len(scores["tokens"]) >= self.config.group_size:
|
||||
# break
|
||||
# for score in scores["scores"]:
|
||||
# self.percent_correct_buffer.append(max(score, 0))
|
||||
# # check if all the same
|
||||
# # print(scores['scores'])
|
||||
# if all([score == 1 for score in scores["scores"]]):
|
||||
# # Do length penalty :)
|
||||
# token_lengths = [len(token) for token in scores["tokens"]]
|
||||
# if max(token_lengths) == 0:
|
||||
# # What? But don't want to crash a run so just in case...
|
||||
# return None
|
||||
|
||||
# # Get max allowed token length from config
|
||||
# max_allowed_length = self.config.max_token_length
|
||||
# # Set threshold at 50% of max_token_length - no penalty below this
|
||||
# length_threshold = max_allowed_length * 0.5
|
||||
|
||||
# # Apply modified length penalty with threshold
|
||||
# scores["scores"] = []
|
||||
# for length in token_lengths:
|
||||
# if length <= length_threshold:
|
||||
# # No penalty for responses under threshold
|
||||
# scores["scores"].append(1.0)
|
||||
# else:
|
||||
# # Calculate how far we are between threshold and max as a percentage
|
||||
# percentage_of_range = (length - length_threshold) / (
|
||||
# max_allowed_length - length_threshold
|
||||
# )
|
||||
# # Cap at 1.0 in case length exceeds max_allowed_length
|
||||
# percentage_of_range = min(percentage_of_range, 1.0)
|
||||
# # Apply linear penalty scaling from 1.0 down to 0.0
|
||||
# scores["scores"].append(1.0 - percentage_of_range)
|
||||
return scores
|
||||
|
||||
|
||||
|
||||
# gold_parsed = parse(
|
||||
# rollout_group_data[0]["gold_answer"],
|
||||
# extraction_mode="first_match",
|
||||
# extraction_config=[LatexExtractionConfig()],
|
||||
# )
|
||||
# if len(gold_parsed) != 0:
|
||||
# # We require the answer to be provided in correct latex (no malformed operators)
|
||||
# random.shuffle(rollout_group_data)
|
||||
# for item in rollout_group_data:
|
||||
# # print(item[0][-1]["content"])
|
||||
# answer_parsed = parse(
|
||||
# item["messages"][-1]["content"].split("</think>")[-1],
|
||||
# extraction_config=[
|
||||
# LatexExtractionConfig(
|
||||
# normalization_config=NormalizationConfig(
|
||||
# nits=False,
|
||||
# malformed_operators=False,
|
||||
# basic_latex=True,
|
||||
# equations=True,
|
||||
# boxed="all",
|
||||
# units=True,
|
||||
# ),
|
||||
# # Ensures that boxed is tried first
|
||||
# boxed_match_priority=0,
|
||||
# try_extract_without_anchor=False,
|
||||
# )
|
||||
# ],
|
||||
# extraction_mode="first_match",
|
||||
# )
|
||||
# # Reward 1 if the content is the same as the ground truth, 0 otherwise
|
||||
# reward = verify(answer_parsed, gold_parsed)
|
||||
# # print(
|
||||
# # f"message: {item[0][-1]['content']}, ground_truth: {item[1]}, reward: {reward}"
|
||||
# # )
|
||||
# out_dict = tokenize_for_trainer(
|
||||
# self.tokenizer, item["messages"], item["finish_reason"]
|
||||
# )
|
||||
# tokens = out_dict["tokens"]
|
||||
# masks = out_dict["masks"]
|
||||
# # remove obviously bad examples
|
||||
# if len([1 for i in masks if i != -100]) < 10:
|
||||
# continue
|
||||
# scores["tokens"].append(tokens)
|
||||
# scores["masks"].append(masks)
|
||||
# scores["scores"].append(1.0 if reward else -1.0)
|
||||
# if len(scores["tokens"]) >= self.config.group_size:
|
||||
# break
|
||||
# for score in scores["scores"]:
|
||||
# self.percent_correct_buffer.append(max(score, 0))
|
||||
# # check if all the same
|
||||
# # print(scores['scores'])
|
||||
# if all([score == 1 for score in scores["scores"]]):
|
||||
# # Do length penalty :)
|
||||
# token_lengths = [len(token) for token in scores["tokens"]]
|
||||
# if max(token_lengths) == 0:
|
||||
# # What? But don't want to crash a run so just in case...
|
||||
# return None
|
||||
|
||||
# # Get max allowed token length from config
|
||||
# max_allowed_length = self.config.max_token_length
|
||||
# # Set threshold at 50% of max_token_length - no penalty below this
|
||||
# length_threshold = max_allowed_length * 0.5
|
||||
|
||||
# # Apply modified length penalty with threshold
|
||||
# scores["scores"] = []
|
||||
# for length in token_lengths:
|
||||
# if length <= length_threshold:
|
||||
# # No penalty for responses under threshold
|
||||
# scores["scores"].append(1.0)
|
||||
# else:
|
||||
# # Calculate how far we are between threshold and max as a percentage
|
||||
# percentage_of_range = (length - length_threshold) / (
|
||||
# max_allowed_length - length_threshold
|
||||
# )
|
||||
# # Cap at 1.0 in case length exceeds max_allowed_length
|
||||
# percentage_of_range = min(percentage_of_range, 1.0)
|
||||
# # Apply linear penalty scaling from 1.0 down to 0.0
|
||||
# scores["scores"].append(1.0 - percentage_of_range)
|
||||
# if all([scores["scores"][0] == score for score in scores["scores"]]):
|
||||
# return None # If all the same, we return None
|
||||
# return scores
|
||||
# else:
|
||||
# # If the gold solution is not parseable, we return None
|
||||
# return None
|
||||
return None
|
||||
|
||||
async def get_next_item(self) -> CatRow:
|
||||
next_item = self.train[self.iter % len(self.train)]
|
||||
self.iter += 1
|
||||
print(f"iteration: {self.iter}")
|
||||
return next_item
|
||||
|
||||
|
||||
def purrfect_eval(st: str) -> float:
|
||||
if "purr" in st.lower():
|
||||
return 1.0
|
||||
return 0.0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GSM8kEnv.cli()
|
||||
302
environments/catbot_arena.py
Normal file
302
environments/catbot_arena.py
Normal file
|
|
@ -0,0 +1,302 @@
|
|||
import random
|
||||
from typing import Dict, List, Optional, Tuple, TypedDict, Union
|
||||
|
||||
from datasets import load_dataset
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.type_definitions import Item, number
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
system_prompt = (
|
||||
"You are a deep thinking AI, you may use extremely long chains of thought "
|
||||
"to deeply consider the problem and deliberate with yourself via systematic "
|
||||
"reasoning processes to help come to a correct solution prior to answering. "
|
||||
"You should enclose your thoughts and internal monologue inside <think> </think> "
|
||||
"tags, and then provide your solution or response to the problem.\n\n"
|
||||
)
|
||||
|
||||
system_prompt += """You are allocated a maximum of 2048 tokens, please strive to use less.
|
||||
|
||||
You will then provide your answer like this: \\boxed{your answer here}
|
||||
It is important that you provide your answer in the correct format.
|
||||
If you do not, you will not receive credit for your answer.
|
||||
So please end your answer with \\boxed{your answer here}"""
|
||||
|
||||
|
||||
class GSM8kRow(TypedDict):
|
||||
question: str
|
||||
answer: str
|
||||
|
||||
|
||||
class GSM8kEnv(BaseEnv):
|
||||
|
||||
name = "gsm8k"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BaseEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm=True,
|
||||
testing=False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.percent_correct_buffer = list()
|
||||
self.eval_metrics = list()
|
||||
# Add tracking for wandb visualizations
|
||||
self.rollouts_for_wandb = []
|
||||
self.completion_lengths = []
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
|
||||
env_config = BaseEnvConfig(
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
group_size=8,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
batch_size=12,
|
||||
steps_per_eval=100,
|
||||
max_token_length=2048,
|
||||
wandb_name="gsm8k",
|
||||
)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
),
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
# Try to calculate percent_correct, pass if there's a division by zero
|
||||
try:
|
||||
wandb_metrics["train/percent_correct"] = sum(
|
||||
self.percent_correct_buffer
|
||||
) / len(self.percent_correct_buffer)
|
||||
except ZeroDivisionError:
|
||||
# Skip if buffer is empty
|
||||
pass
|
||||
|
||||
self.percent_correct_buffer = list()
|
||||
for item in self.eval_metrics:
|
||||
wandb_metrics[item[0]] = item[1]
|
||||
self.eval_metrics = list()
|
||||
# Call the parent method to handle the server metrics
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
async def setup(self):
|
||||
self.train = load_dataset("gsm8k", "main", split="train").shuffle(seed=42)
|
||||
test_data = load_dataset("gsm8k", "main", split="test").shuffle(seed=42)
|
||||
self.test = list()
|
||||
for item in test_data:
|
||||
self.test.append(
|
||||
{
|
||||
"question": item["question"],
|
||||
"gold_answer": item["answer"]
|
||||
.split("#")[-1]
|
||||
.strip()
|
||||
.replace(",", ""),
|
||||
}
|
||||
)
|
||||
self.iter = 0
|
||||
|
||||
def save_checkpoint(self, step, data=None):
|
||||
if data is None:
|
||||
data = {}
|
||||
data["iter"] = self.iter
|
||||
super().save_checkpoint(step, data)
|
||||
|
||||
async def rollout_and_score_eval(self, question: str, answer: str) -> number:
|
||||
completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": question},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
gold_parsed = parse(
|
||||
"\\boxed{" + answer + "}",
|
||||
extraction_mode="first_match",
|
||||
extraction_config=[LatexExtractionConfig()],
|
||||
)
|
||||
answer_parsed = parse(
|
||||
completion.choices[0].message.content.split("</think>")[-1],
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
equations=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
# Ensures that boxed is tried first
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
score = 1 if verify(answer_parsed, gold_parsed) else 0
|
||||
return score
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
eval_tasks = []
|
||||
for item in self.test:
|
||||
eval_tasks.append(
|
||||
self.rollout_and_score_eval(item["question"], item["gold_answer"])
|
||||
)
|
||||
scores = await tqdm_asyncio.gather(*eval_tasks)
|
||||
self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))
|
||||
|
||||
async def collect_trajectories(
|
||||
self, item: GSM8kRow
|
||||
) -> Tuple[ScoredDataGroup, list[Item]]:
|
||||
user_message = {"role": "user", "content": item["question"]}
|
||||
gold_answer = (
|
||||
"\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}"
|
||||
)
|
||||
|
||||
print('hello', gold_answer, user_message)
|
||||
|
||||
chat_completions = await self.server.chat_completion(
|
||||
messages=[{"role": "system", "content": system_prompt}, user_message],
|
||||
n=self.config.group_size,
|
||||
max_tokens=self.config.max_token_length,
|
||||
)
|
||||
to_score = list()
|
||||
to_backlog = list()
|
||||
for i, chat_completion in enumerate(chat_completions.choices):
|
||||
messages = (
|
||||
{"role": "system", "content": system_prompt},
|
||||
user_message,
|
||||
{"role": "assistant", "content": chat_completion.message.content},
|
||||
)
|
||||
to_score.append(
|
||||
{
|
||||
"messages": messages,
|
||||
"gold_answer": gold_answer,
|
||||
"finish_reason": chat_completion.finish_reason,
|
||||
}
|
||||
)
|
||||
to_postprocess = await self.score(to_score)
|
||||
return to_postprocess, to_backlog
|
||||
|
||||
async def score(
|
||||
self, rollout_group_data
|
||||
) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
|
||||
scores = ScoredDataGroup()
|
||||
scores["tokens"] = list()
|
||||
scores["masks"] = list()
|
||||
scores["scores"] = list()
|
||||
gold_parsed = parse(
|
||||
rollout_group_data[0]["gold_answer"],
|
||||
extraction_mode="first_match",
|
||||
extraction_config=[LatexExtractionConfig()],
|
||||
)
|
||||
if len(gold_parsed) != 0:
|
||||
# We require the answer to be provided in correct latex (no malformed operators)
|
||||
random.shuffle(rollout_group_data)
|
||||
for item in rollout_group_data:
|
||||
# print(item[0][-1]["content"])
|
||||
answer_parsed = parse(
|
||||
item["messages"][-1]["content"].split("</think>")[-1],
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
equations=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
# Ensures that boxed is tried first
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
# Reward 1 if the content is the same as the ground truth, 0 otherwise
|
||||
reward = verify(answer_parsed, gold_parsed)
|
||||
# print(
|
||||
# f"message: {item[0][-1]['content']}, ground_truth: {item[1]}, reward: {reward}"
|
||||
# )
|
||||
out_dict = tokenize_for_trainer(
|
||||
self.tokenizer, item["messages"], item["finish_reason"]
|
||||
)
|
||||
tokens = out_dict["tokens"]
|
||||
masks = out_dict["masks"]
|
||||
# remove obviously bad examples
|
||||
if len([1 for i in masks if i != -100]) < 10:
|
||||
continue
|
||||
scores["tokens"].append(tokens)
|
||||
scores["masks"].append(masks)
|
||||
scores["scores"].append(1.0 if reward else -1.0)
|
||||
if len(scores["tokens"]) >= self.config.group_size:
|
||||
break
|
||||
for score in scores["scores"]:
|
||||
self.percent_correct_buffer.append(max(score, 0))
|
||||
# check if all the same
|
||||
# print(scores['scores'])
|
||||
if all([score == 1 for score in scores["scores"]]):
|
||||
# Do length penalty :)
|
||||
token_lengths = [len(token) for token in scores["tokens"]]
|
||||
if max(token_lengths) == 0:
|
||||
# What? But don't want to crash a run so just in case...
|
||||
return None
|
||||
|
||||
# Get max allowed token length from config
|
||||
max_allowed_length = self.config.max_token_length
|
||||
# Set threshold at 50% of max_token_length - no penalty below this
|
||||
length_threshold = max_allowed_length * 0.5
|
||||
|
||||
# Apply modified length penalty with threshold
|
||||
scores["scores"] = []
|
||||
for length in token_lengths:
|
||||
if length <= length_threshold:
|
||||
# No penalty for responses under threshold
|
||||
scores["scores"].append(1.0)
|
||||
else:
|
||||
# Calculate how far we are between threshold and max as a percentage
|
||||
percentage_of_range = (length - length_threshold) / (
|
||||
max_allowed_length - length_threshold
|
||||
)
|
||||
# Cap at 1.0 in case length exceeds max_allowed_length
|
||||
percentage_of_range = min(percentage_of_range, 1.0)
|
||||
# Apply linear penalty scaling from 1.0 down to 0.0
|
||||
scores["scores"].append(1.0 - percentage_of_range)
|
||||
if all([scores["scores"][0] == score for score in scores["scores"]]):
|
||||
return None # If all the same, we return None
|
||||
return scores
|
||||
else:
|
||||
# If the gold solution is not parseable, we return None
|
||||
return None
|
||||
|
||||
async def get_next_item(self) -> GSM8kRow:
|
||||
next_item = self.train[self.iter % len(self.train)]
|
||||
self.iter += 1
|
||||
return next_item
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GSM8kEnv.cli()
|
||||
Loading…
Add table
Add a link
Reference in a new issue