atropos/environments/game_environments/textworld_env/textworld_env.py
shannonsands 47cb15745c
Textworld minimal (#225)
* minimal implementation, simplified challenge registry

* need game save logic

* fixed challenge gen, works with local test

* updated challenge gen with wider ranges, working with local script

* runs working correctly, wandb stats look ok

* linting

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* removed unused imports

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-08-01 10:16:35 +10:00

648 lines
24 KiB
Python

#!/usr/bin/env python3
"""
TextWorldEnv: Minimalist trainer environment for Microsoft TextWorld
A simple trainer environment that wraps TextWorld game generator and Gym interface
to train LLMs. The LLM outputs actions in plain text and receives only environment rewards.
No thinking tokens, memory, format rewards, or complex scoring - just pure environment interaction.
"""
import logging
import os
import random
import tempfile
import traceback
from typing import Any, Dict, List, Optional, Tuple
import gymnasium as gym # noqa: F401
import textworld
import textworld.challenges
import textworld.gym
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
ScoredDataItem,
)
from atroposlib.type_definitions import Item, Message
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from environments.game_environments.textworld_env.textworld_registry import (
create_textworld_registry,
) # noqa: F401
logger = logging.getLogger(__name__)
class TextWorldEnvConfig(BaseEnvConfig):
"""Configuration for the minimalist TextWorld environment trainer."""
env_name: str = "TextWorld"
wandb_name: str = "textworld-trainer-minimal"
group_size: int = 16
max_num_workers: int = 16
total_steps: int = 500
max_steps: int = 300 # max steps per episode (matches coin_collector max)
max_token_length: int = 32768
# Challenge settings
challenge_names: List[str] = [
"tw-simple",
"tw-cooking",
"tw-coin_collector",
"tw-treasure_hunter",
]
randomize_challenge_settings: bool = (
True # Randomize settings within each challenge
)
class TextWorldEnv(BaseEnv):
"""Minimalist TextWorld environment for training LLMs."""
name = "textworld_minimal"
env_config_cls = TextWorldEnvConfig
def __init__(
self,
config: TextWorldEnvConfig,
server_configs: List[APIServerConfig],
slurm: bool = True,
testing: bool = False,
):
super().__init__(config, server_configs, slurm, testing)
self.config: TextWorldEnvConfig = config
self.challenge_registry = None
# Track generated game files for cleanup
self._generated_files = set()
# Create temp directory for game files
self._temp_dir = tempfile.mkdtemp(prefix="textworld_minimal_")
# wandb logging
self.episode_outcomes_buffer = []
self.episode_rewards_buffer = []
self.episode_steps_buffer = []
self.episode_challenge_types = []
self.eval_metrics_custom = []
self.system_prompt = (
"You are playing a text-based adventure game. "
"Read the game state and respond with ONLY the action you want to take. "
"Do not include any explanation or reasoning, just the action command itself.\n\n"
"Examples of valid actions:\n"
"- go north\n"
"- take key\n"
"- open door\n"
"- examine table\n"
"- inventory\n\n"
"Respond with a single action only."
)
async def setup(self):
"""Initialize the environment and challenge registry."""
# Import registry creation from local module
try:
self.challenge_registry = create_textworld_registry()
logger.warning(
f"Initialized TextWorld challenge registry with challenges: {self.config.challenge_names}"
)
logger.warning("TextWorldEnv setup completed successfully")
except Exception as e:
logger.error(f"Failed to create TextWorld registry: {e}")
logger.error(traceback.format_exc())
raise
async def get_next_item(self) -> Item:
"""Get the next game configuration."""
# Randomly select a challenge
if len(self.config.challenge_names) == 1:
challenge_name = self.config.challenge_names[0]
else:
challenge_name = random.choice(self.config.challenge_names)
# Get challenge settings
challenge_name, settings = self.challenge_registry.get_challenge(
challenge_name, randomize_settings=self.config.randomize_challenge_settings
)
return {
"challenge_name": challenge_name,
"settings": settings,
"game_type": "challenge",
}
def _create_game(self, challenge_name: str, settings: Dict[str, Any]) -> str:
"""Create a TextWorld game and save it to a file.
Returns:
Path to the saved game file (.z8)
"""
# Create default options
options = textworld.GameOptions()
options.seeds = settings.get("seed", random.randint(0, 1000000))
if challenge_name == "tw-simple":
game_settings = {
"rewards": settings.get("rewards", "balanced"),
"goal": settings.get("goal", "detailed"),
"test": str(settings.get("test", False)).lower(),
}
game = textworld.challenges.simple.make(game_settings, options=options)
elif challenge_name == "tw-cooking":
game_settings = {
"recipe": settings.get("recipe", 1), # Number of ingredients
"take": settings.get("take", 1), # Number to find
"cook": settings.get("cook", False), # Whether to cook
"open": settings.get("open", False), # Whether to open containers
"drop": settings.get("drop", False), # Whether limited inventory
"go": settings.get("go", 1), # Number of locations
"recipe_seed": settings.get(
"recipe-seed",
settings.get("recipe_seed", random.randint(0, 1000000)),
),
"split": "train",
}
logger.debug(f"Cooking game settings: {game_settings}")
game = textworld.challenges.cooking.make(game_settings, options=options)
elif challenge_name == "tw-coin_collector":
game_settings = {"level": settings.get("level", 1)}
game = textworld.challenges.coin_collector.make(
game_settings, options=options
)
elif challenge_name == "tw-treasure_hunter":
game_settings = {"level": settings.get("level", 1)}
game = textworld.challenges.treasure_hunter.make(
game_settings, options=options
)
else:
raise ValueError(f"Unknown challenge: {challenge_name}")
# Save gamefile
game_file = os.path.join(
self._temp_dir,
f"{challenge_name}_{settings.get('seed', random.randint(0, 1000000))}.z8",
)
options.path = game_file
options.file_ext = ".z8"
game_file = textworld.generator.compile_game(game, options)
# Track for cleanup
self._generated_files.add(game_file)
return game_file
async def collect_trajectories(
self, item: Item
) -> Tuple[ScoredDataGroup, List[Item]]:
"""Collect parallel trajectories from the same game."""
challenge_name = item["challenge_name"]
settings = item["settings"]
game_file = self._create_game(challenge_name, settings)
# Register the gamefile
request_infos = textworld.EnvInfos(
description=True,
inventory=True,
objective=True,
admissible_commands=True,
won=True,
lost=True,
score=True,
moves=True,
max_score=True,
)
env_id = textworld.gym.register_game(game_file, request_infos)
scored_items = []
for i in range(self.config.group_size):
try:
scored_item = await self._collect_single_trajectory(
env_id, i, challenge_name
)
if scored_item:
scored_items.append(scored_item)
except Exception as e:
logger.error(f"Error collecting trajectory {i}: {e}")
continue
if not scored_items:
return (
ScoredDataGroup(
tokens=[],
masks=[],
scores=[],
messages=[],
advantages=None,
ref_logprobs=None,
group_overrides={},
overrides=None,
images=None,
),
[],
)
sdg = ScoredDataGroup(
tokens=[],
masks=[],
scores=[],
messages=[],
advantages=None,
ref_logprobs=None,
group_overrides={},
overrides=None,
images=None,
)
for scored_item in scored_items:
sdg["tokens"].append(scored_item["tokens"])
sdg["masks"].append(scored_item["masks"])
sdg["scores"].append(scored_item["scores"])
if self.config.include_messages and scored_item.get("messages"):
sdg["messages"].append(scored_item["messages"])
metadata = scored_item.get("metadata", {})
final_score = metadata.get("final_score", 0)
won = metadata.get("won", False)
lost = metadata.get("lost", False)
moves = metadata.get("moves", 0)
if won:
outcome = 1.0
elif lost:
outcome = -1.0
else:
outcome = 0.0
self.episode_outcomes_buffer.append(outcome)
self.episode_rewards_buffer.append(final_score)
self.episode_steps_buffer.append(moves)
self.episode_challenge_types.append(
metadata.get("challenge_name", "unknown")
)
self._cleanup_game_file(game_file)
return sdg, []
async def _collect_single_trajectory(
self, env_id: str, trajectory_idx: int, challenge_name: str = "unknown"
) -> Optional[ScoredDataItem]:
"""Collect a single trajectory for the game."""
messages: List[Message] = []
try:
env = textworld.gym.make(env_id)
obs, info = env.reset()
messages.append({"role": "system", "content": self.system_prompt})
obs_text = self._format_observation(obs, info)
messages.append({"role": "user", "content": obs_text})
done = False
total_reward = 0.0
async with self.server.dedicated_server() as server:
while not done and len(messages) < self.config.max_steps * 2:
current_tokens = len(
self.tokenizer.apply_chat_template(messages, tokenize=True)
)
if current_tokens > self.config.max_token_length - 50:
logger.warning(
f"Trajectory {trajectory_idx}: Approaching token limit, ending episode"
)
logger.info(f"Token usage: {current_tokens} tokens")
logger.info(f"Number of messages: {len(messages)}")
logger.info(
f"Last observation length: {len(messages[-1]['content']) if messages else 0}"
)
break
try:
response = await server.chat_completion(
messages=messages,
n=1,
max_tokens=50, # short actions, no thinking
temperature=0.7,
)
action = response.choices[0].message.content.strip()
logger.debug(f"Trajectory {trajectory_idx}: Action: {action}")
except Exception as e:
logger.error(f"Trajectory {trajectory_idx}: LLM error: {e}")
break
messages.append({"role": "assistant", "content": action})
try:
obs, reward, done, info = env.step(action)
total_reward += reward
except Exception as e:
logger.error(
f"Trajectory {trajectory_idx}: Environment error: {e}"
)
break
if not done:
obs_text = self._format_observation(obs, info)
messages.append({"role": "user", "content": obs_text})
env.close()
tokenization_result = tokenize_for_trainer(
tokenizer=self.tokenizer,
chat=messages,
train_on_all_assistant_turns=True,
)
return ScoredDataItem(
messages=messages if self.config.include_messages else None,
tokens=tokenization_result["tokens"],
masks=tokenization_result["masks"],
scores=total_reward,
metadata={
"trajectory_idx": trajectory_idx,
"final_score": total_reward,
"won": info.get("won", False),
"lost": info.get("lost", False),
"moves": info.get("moves", 0),
"challenge_name": challenge_name,
},
)
except Exception as e:
logger.error(f"Trajectory {trajectory_idx}: Fatal error: {e}")
logger.error(traceback.format_exc())
return None
def _format_observation(self, obs: str, info: Dict[str, Any]) -> str:
"""Format game observation for the LLM."""
parts = []
# Main observation
parts.append(obs)
# Add objective if available
if "objective" in info and info["objective"]:
parts.append(f"\nObjective: {info['objective']}")
# Add score info
if "score" in info and "max_score" in info:
parts.append(f"\nScore: {info['score']}/{info['max_score']}")
# Add inventory if not empty
if "inventory" in info and info["inventory"]:
parts.append(f"\nInventory: {info['inventory']}")
return "\n".join(parts)
@classmethod
def config_init(cls) -> Tuple[TextWorldEnvConfig, List[APIServerConfig]]:
"""Initialize default configuration."""
env_config = TextWorldEnvConfig(
tokenizer_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
group_size=16,
use_wandb=True,
wandb_name=cls.name,
max_token_length=32768,
total_steps=500,
challenge_names=[
"tw-simple",
"tw-cooking",
"tw-coin_collector",
"tw-treasure_hunter",
],
randomize_challenge_settings=True,
)
server_configs = [
APIServerConfig(
model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
base_url="http://localhost:9004/v1",
api_key="x",
num_requests_for_eval=128,
),
APIServerConfig(
model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
base_url="http://localhost:9005/v1",
api_key="x",
num_requests_for_eval=128,
),
APIServerConfig(
model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
base_url="http://localhost:9006/v1",
api_key="x",
num_requests_for_eval=128,
),
APIServerConfig(
model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
base_url="http://localhost:9007/v1",
api_key="x",
num_requests_for_eval=128,
),
]
return env_config, server_configs
# TODO: implement evaluation properly re eval changes
async def evaluate(self, num_items: int) -> Dict[str, Any]:
"""Evaluate the model - not implemented for this minimal environment."""
logger.warning("Evaluation not implemented in minimal TextWorld environment")
return {"message": "Evaluation not implemented"}
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
"""Log episode statistics to wandb."""
if wandb_metrics is None:
wandb_metrics = {}
# Log training episode outcomes
if self.episode_outcomes_buffer:
wins = sum(1 for outcome in self.episode_outcomes_buffer if outcome > 0)
losses = sum(1 for outcome in self.episode_outcomes_buffer if outcome < 0)
draws = sum(1 for outcome in self.episode_outcomes_buffer if outcome == 0)
total_episodes = len(self.episode_outcomes_buffer)
win_rate = (wins / total_episodes) * 100 if total_episodes > 0 else 0.0
loss_rate = (losses / total_episodes) * 100 if total_episodes > 0 else 0.0
draw_rate = (draws / total_episodes) * 100 if total_episodes > 0 else 0.0
avg_steps = (
sum(self.episode_steps_buffer) / len(self.episode_steps_buffer)
if self.episode_steps_buffer
else 0
)
avg_outcome = (
sum(self.episode_outcomes_buffer) / len(self.episode_outcomes_buffer)
if self.episode_outcomes_buffer
else 0
)
avg_reward = (
sum(self.episode_rewards_buffer) / len(self.episode_rewards_buffer)
if self.episode_rewards_buffer
else 0
)
max_reward = (
max(self.episode_rewards_buffer) if self.episode_rewards_buffer else 0
)
min_reward = (
min(self.episode_rewards_buffer) if self.episode_rewards_buffer else 0
)
wandb_metrics[f"{self.name}/train/total_episodes"] = total_episodes
wandb_metrics[f"{self.name}/train/win_count_absolute"] = wins
wandb_metrics[f"{self.name}/train/loss_count_absolute"] = losses
wandb_metrics[f"{self.name}/train/draw_count_absolute"] = draws
wandb_metrics[f"{self.name}/train/win_rate_percent"] = win_rate
wandb_metrics[f"{self.name}/train/loss_rate_percent"] = loss_rate
wandb_metrics[f"{self.name}/train/draw_rate_percent"] = draw_rate
wandb_metrics[f"{self.name}/train/avg_episode_steps"] = avg_steps
wandb_metrics[f"{self.name}/train/avg_outcome"] = (
avg_outcome # -1, 0, 1 average
)
wandb_metrics[f"{self.name}/train/avg_reward_score"] = (
avg_reward # Actual game score
)
wandb_metrics[f"{self.name}/train/max_reward_score"] = max_reward
wandb_metrics[f"{self.name}/train/min_reward_score"] = min_reward
# Per-challenge statistics
challenge_stats = {}
for i, (outcome, reward, steps, challenge) in enumerate(
zip(
self.episode_outcomes_buffer,
self.episode_rewards_buffer,
self.episode_steps_buffer,
self.episode_challenge_types,
)
):
if challenge not in challenge_stats:
challenge_stats[challenge] = {
"outcomes": [],
"rewards": [],
"steps": [],
"count": 0,
}
challenge_stats[challenge]["outcomes"].append(outcome)
challenge_stats[challenge]["rewards"].append(reward)
challenge_stats[challenge]["steps"].append(steps)
challenge_stats[challenge]["count"] += 1
for challenge, stats in challenge_stats.items():
challenge_wins = sum(1 for o in stats["outcomes"] if o > 0)
challenge_losses = sum(1 for o in stats["outcomes"] if o < 0)
challenge_draws = sum(1 for o in stats["outcomes"] if o == 0)
challenge_total = stats["count"]
wandb_metrics[f"{self.name}/train/{challenge}/episodes_count"] = (
challenge_total
)
wandb_metrics[f"{self.name}/train/{challenge}/wins_count"] = (
challenge_wins
)
wandb_metrics[f"{self.name}/train/{challenge}/losses_count"] = (
challenge_losses
)
wandb_metrics[f"{self.name}/train/{challenge}/draws_count"] = (
challenge_draws
)
wandb_metrics[f"{self.name}/train/{challenge}/win_rate_percent"] = (
(challenge_wins / challenge_total) * 100
if challenge_total > 0
else 0
)
wandb_metrics[f"{self.name}/train/{challenge}/loss_rate_percent"] = (
(challenge_losses / challenge_total) * 100
if challenge_total > 0
else 0
)
wandb_metrics[f"{self.name}/train/{challenge}/draw_rate_percent"] = (
(challenge_draws / challenge_total) * 100
if challenge_total > 0
else 0
)
wandb_metrics[f"{self.name}/train/{challenge}/avg_steps"] = (
sum(stats["steps"]) / len(stats["steps"]) if stats["steps"] else 0
)
wandb_metrics[f"{self.name}/train/{challenge}/avg_outcome"] = (
sum(stats["outcomes"]) / len(stats["outcomes"])
if stats["outcomes"]
else 0
)
wandb_metrics[f"{self.name}/train/{challenge}/avg_reward_score"] = (
sum(stats["rewards"]) / len(stats["rewards"])
if stats["rewards"]
else 0
)
wandb_metrics[f"{self.name}/train/{challenge}/max_reward_score"] = (
max(stats["rewards"]) if stats["rewards"] else 0
)
wandb_metrics[f"{self.name}/train/{challenge}/min_reward_score"] = (
min(stats["rewards"]) if stats["rewards"] else 0
)
self.episode_outcomes_buffer = []
self.episode_rewards_buffer = []
self.episode_steps_buffer = []
self.episode_challenge_types = []
# Log eval metrics if any
for key, value in self.eval_metrics_custom:
wandb_metrics[key] = value
self.eval_metrics_custom = []
await super().wandb_log(wandb_metrics)
def _cleanup_game_file(self, game_file: str):
"""Clean up a generated game file and its associated files."""
if game_file in self._generated_files:
try:
# Remove .z8 file
if os.path.exists(game_file):
os.remove(game_file)
logger.debug(f"Removed game file: {game_file}")
# Remove .ni file if it exists
ni_file = game_file.replace(".z8", ".ni")
if os.path.exists(ni_file):
os.remove(ni_file)
logger.debug(f"Removed ni file: {ni_file}")
# Remove .json file if it exists
json_file = game_file.replace(".z8", ".json")
if os.path.exists(json_file):
os.remove(json_file)
logger.debug(f"Removed json file: {json_file}")
# TODO: handle .z5's from manually downloaded games when that's added
self._generated_files.remove(game_file)
except OSError as e:
logger.warning(f"Failed to clean up game file {game_file}: {e}")
def __del__(self):
"""Ensure cleanup on deletion."""
# Clean up any remaining game files
files_to_clean = list(self._generated_files)
for game_file in files_to_clean:
self._cleanup_game_file(game_file)
# Clean up local temp directory
if hasattr(self, "_temp_dir") and os.path.exists(self._temp_dir):
import shutil
try:
shutil.rmtree(self._temp_dir)
logger.debug(f"Cleaned up temp directory: {self._temp_dir}")
except Exception:
pass
if __name__ == "__main__":
TextWorldEnv.cli()