import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
from atroposlib.type_definitions import Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

# --- Mocked Trending Topics ---
# In a real implementation, this would come from an API
MOCK_TRENDING_TOPICS = [
    "AI in Art",
    "Climate Change Solutions",
    "New Space Discoveries",
]

# --- Mocked Social Media State (Simplified) ---
# This would be more complex in a real environment, likely managed by dedicated classes
# and updated dynamically.
MOCK_SOCIAL_FEED = [
    {
        "id": "post1",
        "agent_id": "agent_alpha",
        "content": "Just enjoyed a great virtual concert! #metaverse",
        "likes": 10,
        "comments": [],
    },
    {
        "id": "post2",
        "agent_id": "agent_beta",
        "content": "Excited about the upcoming Atropos hackathon!",
        "likes": 15,
        "comments": [],
    },
]
MOCK_AGENT_PROFILES = {
    "agent_gamma": {"posts": [], "score": 0, "notifications": []},
    "agent_delta": {"posts": [], "score": 0, "notifications": []},
}

# We'll need to define a system prompt
SYSTEM_PROMPT_TEMPLATE = """You are '{agent_id}', an agent on Xitter, a simulated social media platform.
Your goal is to maximize engagement by posting interesting content (Xits), liking relevant Xits,
and making insightful comments. You can also choose to DO_NOTHING.
Current trending topics: {trending_topics}

Recent Xits in your feed (newest first):
{feed_preview}

Your recent notifications:
{notifications_preview}

Choose one action:
1. POST <your_xit_content_here> (max 140 chars)
2. LIKE <post_id_to_like> (e.g., LIKE post1)
3. COMMENT <post_id_to_comment_on> <your_comment_content_here> (e.g., COMMENT post2 Great point!)
4. DO_NOTHING

Your response should be ONLY the action string (e.g., "POST This is my new Xit! #awesome").
"""


@dataclass
class XitterEnvConfig(BaseEnvConfig):
    """Configuration for the Xitter (Social Media) Environment."""

    # Reward weights
    like_reward_weight: float = 0.1
    comment_reward_weight: float = 0.5
    trending_topic_bonus: float = 0.2
    perform_like_reward: float = 0.05  # Small reward for the act of liking
    perform_comment_reward: float = 0.1  # Small reward for the act of commenting
    do_nothing_reward: float = 0.0  # Reward for doing nothing
    invalid_action_penalty: float = -0.5
    action_cost: float = -0.01  # Small cost for any action

    # Environment parameters
    num_agents: int = 2
    max_feed_size: int = 20
    max_notifications_display: int = 5  # How many notifications to show in prompt
    initial_trending_topics: List[str] = field(
        default_factory=lambda: ["AI in Art", "Climate Solutions", "Space Exploration"]
    )

    # For wandb logging of agent-specific scores
    track_individual_agent_scores: bool = True
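
# Example override at construction time (hypothetical values; assumes the
# inherited BaseEnvConfig fields all have defaults):
#
#   config = XitterEnvConfig(
#       like_reward_weight=0.2,
#       comment_reward_weight=0.6,
#       num_agents=4,
#   )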


class XitterEnv(BaseEnv):
    # XitterEnvConfig (defined above) adds the reward weights and environment
    # parameters; BaseEnvConfig supplies group_size, max_token_length,
    # tokenizer_name, etc.
    env_config_cls = XitterEnvConfig  # Assign your custom config

    def __init__(
        self,
        config: XitterEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=False,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        # Initialize social media state
        self.social_feed = MOCK_SOCIAL_FEED  # list of posts
        self.agent_profiles = MOCK_AGENT_PROFILES  # dict of agent_id -> profile_data
        # Use the configured topics rather than a second, divergent mock list
        self.trending_topics = list(config.initial_trending_topics)
        self.current_agent_turn = 0  # Simple round-robin for turns
        self.agent_ids = list(self.agent_profiles.keys())

        # Example reward function instances (you'd define these)
        # self.engagement_reward_fn = EngagementReward(...)
        # self.relevance_reward_fn = RelevanceReward(...)

    async def setup(self):
        # Load tokenizer, initialize agents, fetch initial trends, etc.
        # self.tokenizer is already initialized in BaseEnv
        logging.info(f"{self.name or 'XitterEnv'} setup complete.")
        # Potentially fetch initial trending topics here
        # self.trending_topics = await self.fetch_trending_topics()
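
    async def fetch_trending_topics(self) -> List[str]:
        """Hypothetical hook for pulling live trends, referenced by the
        commented-out call in setup() above. This sketch just returns the
        mocked list; a real version would call a trends API."""
        return MOCK_TRENDING_TOPICS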

    async def get_next_item(self) -> Item:
        # Determine which agent's turn it is and prepare their observation
        agent_id_for_turn = self.agent_ids[
            self.current_agent_turn % len(self.agent_ids)
        ]
        self.current_agent_turn += 1

        # The Item can be any structure your collect_trajectories method expects.
        # Here we keep the observation minimal: the acting agent's id plus a
        # snapshot of the trends and the top of the feed; collect_trajectories
        # builds the full prompt from it. A richer environment would also pass
        # notifications, follower graphs, etc.
        return (
            agent_id_for_turn,
            {
                "trending_topics": self.trending_topics,
                "feed_preview": self.social_feed[:5],
            },
        )

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
        """
        Generates a group of potential actions for an agent and gathers data for scoring.
        The `item` from `get_next_item` provides the context for the current agent's turn.
        """
        agent_id_acting, observation_context = item

        # Construct the prompt for the LLM agent based on its observation:
        # a view of the feed, notifications, and trending topics. The prompt
        # engineering here is a key part of designing the environment.
        recent_posts = str(observation_context.get("feed_preview", []))
        prompt_content = f"It's your turn, {agent_id_acting}. Recent posts: {recent_posts}. What do you do?"

        feed_preview_text = ", ".join(
            [post["content"] for post in observation_context.get("feed_preview", [])]
        )
        notifications_text = ", ".join(
            self.agent_profiles[agent_id_acting].get("notifications", [])
        )

        messages_for_llm = [
            {
                "role": "system",
                "content": SYSTEM_PROMPT_TEMPLATE.format(
                    agent_id=agent_id_acting,
                    trending_topics=", ".join(self.trending_topics),
                    feed_preview=feed_preview_text,
                    notifications_preview=notifications_text,
                ),
            },
            {"role": "user", "content": prompt_content},
        ]

        # Apply the chat template for the LLM. No prefill is used here because
        # the action choice is part of the LLM's generation.
        prompt_str_for_llm = self.tokenizer.apply_chat_template(
            messages_for_llm,
            tokenize=False,
            add_generation_prompt=True,  # Important for instruct/chat models
        )

        # Get self.config.group_size completions from the LLM.
        # Each completion is a potential action (post, like, comment).
        try:
            completions_obj = await self.server.completion(  # free-form action generation
                prompt=prompt_str_for_llm,  # BaseEnv.server.completion expects a string prompt
                n=self.config.group_size,
                max_tokens=self.config.max_token_length // 4,  # Max tokens for the action itself
                temperature=0.7,  # Allow some diversity in actions
            )
        except Exception as e:
            logging.error(
                f"Error getting completions from LLM for agent {agent_id_acting}: {e}"
            )
            return None, []

        # Each entry pairs the full chat history for an action with the
        # action details that the `score` method needs.
        trajectories_for_scoring: List[Tuple[List[Dict[str, str]], Dict[str, Any]]] = []

        for choice in completions_obj.choices:
            llm_generated_action_text = choice.text.strip()

            # Simulate the action and its effect on the environment state.
            # This parses llm_generated_action_text (e.g., "POST My new cat video #cats")
            # and updates self.social_feed, self.agent_profiles, etc.
            action_type, action_params, action_valid = self._parse_and_simulate_action(
                agent_id_acting, llm_generated_action_text
            )

            if not action_valid:
                # Invalid actions are still included and penalized in `score`.
                logging.warning(
                    f"Agent {agent_id_acting} performed an invalid action: {llm_generated_action_text}"
                )

            # Full message history for this trajectory (system, user, assistant_action).
            current_trajectory_messages = messages_for_llm + [
                {"role": "assistant", "content": llm_generated_action_text}
            ]

            # Context needed by the score function for this specific action.
            # This will depend heavily on your reward components.
            scoring_context = {
                "agent_id": agent_id_acting,
                "action_type": action_type,  # "post", "like", "comment", "invalid", "do_nothing"
                "action_params": action_params,  # e.g., post_id for like, content for post
                "was_valid_action": action_valid,
                # For scoring a "post" action, you'd later fill these after observing next turn's interactions:
                # "likes_received_on_post": X,
                # "comments_received_on_post": Y,
                "trending_topics": self.trending_topics,  # Pass current trends for relevance scoring
            }
            # If the action was a 'post', store the new post_id so the score
            # function can later attribute likes/comments to it. This implies
            # that rewards for posts may be delayed by one or more turns.
            if action_type == "post" and "new_post_id" in action_params:
                scoring_context["post_id_created"] = action_params["new_post_id"]

            trajectories_for_scoring.append(
                (current_trajectory_messages, scoring_context)
            )

        # The `score` method turns this list into a ScoredDataGroup. Rewards
        # could instead be assigned after all agents in a round have acted, or
        # after a delay (likes/comments on a post arrive in future turns); for
        # simplicity, this example scores immediately.
        scored_data_group = await self.score(trajectories_for_scoring)

        # No backlog items in this simple version
        return scored_data_group, []
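
    # For a generated action like "POST Thoughts on AI in Art? #AIArt", the
    # scoring context built above would look roughly like (hypothetical ids):
    #   {"agent_id": "agent_gamma", "action_type": "post",
    #    "action_params": {"content": "...", "new_post_id": "post3"},
    #    "was_valid_action": True, "trending_topics": [...],
    #    "post_id_created": "post3"}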

    def _parse_and_simulate_action(
        self, agent_id: str, action_text: str
    ) -> Tuple[str, Dict, bool]:
        """
        Parses the LLM's action string and simulates its effect on the environment.
        Returns: (action_type, action_params, was_valid)
        This is a placeholder and needs to be implemented based on your defined action space.
        """
        action_text_lower = action_text.lower()
        new_post_id_counter = len(self.social_feed)

        if action_text_lower.startswith("post "):
            content = action_text[5:].strip()
            if content:
                new_post_id_counter += 1
                post_id = f"post{new_post_id_counter}"
                new_post = {
                    "id": post_id,
                    "agent_id": agent_id,
                    "content": content,
                    "likes": 0,
                    "comments": [],
                }
                self.social_feed.insert(0, new_post)  # Add to top of feed
                self.agent_profiles[agent_id]["posts"].append(post_id)
                logging.info(f"Agent {agent_id} POSTED: {content}")
                return (
                    "post",
                    {"content": content, "new_post_id": post_id},
                    True,
                )
        elif action_text_lower.startswith("like "):
            try:
                post_id_to_like = action_text.split(" ")[1]
                for post in self.social_feed:
                    if post["id"] == post_id_to_like:
                        post["likes"] += 1
                        # Notify original poster (simplified)
                        original_poster_id = post["agent_id"]
                        if (
                            original_poster_id != agent_id
                            and original_poster_id in self.agent_profiles
                        ):
                            self.agent_profiles[original_poster_id][
                                "notifications"
                            ].append(f"{agent_id} liked your post {post_id_to_like}")
                        logging.info(f"Agent {agent_id} LIKED post: {post_id_to_like}")
                        return "like", {"post_id": post_id_to_like}, True
            except IndexError:
                return "invalid_like_format", {}, False
            return (
                "like_post_not_found",
                {"post_id": post_id_to_like},
                False,
            )  # Post not found
        elif action_text_lower.startswith("comment "):
            parts = action_text.split(" ", 2)
            if len(parts) == 3:
                post_id_to_comment_on = parts[1]
                comment_content = parts[2].strip()
                if comment_content:
                    for post in self.social_feed:
                        if post["id"] == post_id_to_comment_on:
                            comment_id = f"comment{len(post['comments']) + 1}_on_{post_id_to_comment_on}"
                            post["comments"].append(
                                {
                                    "id": comment_id,
                                    "agent_id": agent_id,
                                    "content": comment_content,
                                }
                            )
                            # Notify original poster (simplified)
                            original_poster_id = post["agent_id"]
                            if (
                                original_poster_id != agent_id
                                and original_poster_id in self.agent_profiles
                            ):
                                self.agent_profiles[original_poster_id][
                                    "notifications"
                                ].append(
                                    f"{agent_id} commented on your post {post_id_to_comment_on}"
                                )
                            logging.info(
                                f"Agent {agent_id} COMMENTED on {post_id_to_comment_on}: {comment_content}"
                            )
                            return (
                                "comment",
                                {
                                    "post_id": post_id_to_comment_on,
                                    "content": comment_content,
                                },
                                True,
                            )
                    return (
                        "comment_post_not_found",
                        {"post_id": post_id_to_comment_on},
                        False,
                    )  # Post not found
                return (
                    "invalid_comment_format",
                    {},
                    False,
                )  # Empty comment content
            return (
                "invalid_comment_format",
                {},
                False,
            )  # Invalid command format
        elif action_text_lower == "do_nothing":
            logging.info(f"Agent {agent_id} DID NOTHING.")
            return "do_nothing", {}, True

        return "unknown_action", {"raw_action": action_text}, False
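
    # Illustrative parser behaviour against the mock feed (hypothetical run):
    #   _parse_and_simulate_action("agent_gamma", "POST Hello #AI")
    #       -> ("post", {"content": "Hello #AI", "new_post_id": "post3"}, True)
    #   _parse_and_simulate_action("agent_gamma", "LIKE post1")
    #       -> ("like", {"post_id": "post1"}, True)
    #   _parse_and_simulate_action("agent_gamma", "dance")
    #       -> ("unknown_action", {"raw_action": "dance"}, False)

    @staticmethod
    def _calculate_relevance(text: str, topics: List[str]) -> float:
        """A minimal sketch of the `calculate_relevance` helper referenced in
        the pseudo-code in `score` below: the fraction of topics whose
        keywords appear in the text. A real version might use embeddings."""
        if not topics:
            return 0.0
        hits = sum(1 for topic in topics if topic.lower() in text.lower())
        return hits / len(topics)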

    async def score(
        self,
        trajectories_with_context: List[Tuple[List[Dict[str, str]], Dict[str, Any]]],
    ) -> Optional[ScoredDataGroup]:
        """
        Scores a group of trajectories.
        Each item in `trajectories_with_context` is a tuple:
        (full_message_history, scoring_context_for_this_action)
        """
        final_scores_group = ScoredDataGroup(tokens=[], masks=[], scores=[])
        if self.config.include_messages:  # From BaseEnvConfig
            final_scores_group["messages"] = []

        for (
            full_trajectory_messages,
            scoring_context,
        ) in trajectories_with_context:
            agent_id = scoring_context["agent_id"]
            action_type = scoring_context["action_type"]
            action_params = scoring_context["action_params"]
            was_valid_action = scoring_context["was_valid_action"]

            current_reward = 0.0

            if not was_valid_action:
                current_reward += self.config.invalid_action_penalty
            else:
                # --- Engagement Rewards ---
                if action_type == "post":
                    # These rewards might be delayed. For now, assume `scoring_context`
                    # is populated with likes/comments that occurred *since* this post.
                    # This is a simplification; a more realistic model would update
                    # these over subsequent turns.
                    current_reward += (
                        scoring_context.get("likes_received_on_post_this_turn", 0)
                        * self.config.like_reward_weight
                    )
                    current_reward += (
                        scoring_context.get("comments_received_on_post_this_turn", 0)
                        * self.config.comment_reward_weight
                    )

                    # --- Content Quality & Relevance Rewards for Posts ---
                    post_content = action_params.get("content", "")
                    # Pseudo-code for relevance to trending topics:
                    # relevance_score = calculate_relevance(post_content, self.trending_topics)
                    # current_reward += relevance_score * self.config.relevance_weight
                    # Example: check if any trending topic keyword is in the post
                    for trend in self.trending_topics:
                        if trend.lower() in post_content.lower():
                            current_reward += self.config.trending_topic_bonus
                            break  # Add the bonus once per post if it hits any trend

                elif action_type == "like":
                    current_reward += (
                        self.config.perform_like_reward
                    )  # Small reward for liking
                elif action_type == "comment":
                    current_reward += (
                        self.config.perform_comment_reward
                    )  # Small reward for commenting
                    # comment_content = action_params.get("content", "")
                    # relevance_to_op_score = calculate_relevance(comment_content, original_post_content_for_comment)
                    # current_reward += relevance_to_op_score * self.config.comment_relevance_weight

                elif action_type == "do_nothing":
                    current_reward += (
                        self.config.do_nothing_reward
                    )  # Could be small positive, zero, or small negative

                # --- Action Cost ---
                # current_reward += self.config.action_cost  # action_cost is already negative

            # Tokenize the full trajectory (system, user, assistant_action).
            # `tokenize_for_trainer` creates tokens and masks so that only
            # assistant messages contribute to the loss; with
            # `train_on_all_assistant_turns=True`, every assistant turn is
            # unmasked rather than just the last one.
            try:
                tokenized_output = tokenize_for_trainer(
                    self.tokenizer,
                    full_trajectory_messages,
                    train_on_all_assistant_turns=True,  # Or False to train only on the last turn
                    include_messages=self.config.include_messages,
                )
            except Exception as e:
                logging.error(
                    f"Tokenization error for agent {agent_id}, action {action_type}: {e}. Skipping trajectory."
                )
                logging.error(f"Problematic messages: {full_trajectory_messages}")
                continue

            final_scores_group["tokens"].append(tokenized_output["tokens"])
            final_scores_group["masks"].append(tokenized_output["masks"])
            final_scores_group["scores"].append(current_reward)
            if self.config.include_messages:
                final_scores_group["messages"].append(tokenized_output["messages"])

        if not final_scores_group["tokens"]:  # If all trajectories failed tokenization
            return None

        # Discard the group if all scores are identical (no learning signal)
        if (
            self.config.ensure_scores_are_not_same
            and len(set(final_scores_group["scores"])) <= 1
            and len(final_scores_group["scores"]) > 1
        ):
            logging.info("All scores in the group are identical, returning None.")
            return None

        return final_scores_group
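
    def _post_engagement_snapshot(self, post_id: str) -> Dict[str, int]:
        """Illustrative (hypothetical) helper for the delayed-reward idea
        discussed in `score`: a later turn could diff two snapshots of a post
        to fill `likes_received_on_post_this_turn` in the scoring context."""
        for post in self.social_feed:
            if post["id"] == post_id:
                return {"likes": post["likes"], "comments": len(post["comments"])}
        return {"likes": 0, "comments": 0}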

    # ... (wandb_log, create_rollout_table, etc. can be inherited or customized)
    # ... (evaluate method would simulate multiple turns and aggregate scores)
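

if __name__ == "__main__":
    # Typical atropos entrypoint (a sketch; assumes BaseEnv provides the
    # `cli()` classmethod used by the example environments in this repo).
    XitterEnv.cli()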