Textworld minimal (#225)

* minimal implementation, simplified challenge registry

* need game save logic

* fixed challenge gen, works with local test

* updated challenge gen with wider ranges, working with local script

* runs working correctly, wandb stats look ok

* linting

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* removed unused imports

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
shannonsands 2025-08-01 10:16:35 +10:00 committed by GitHub
parent 1900a577d7
commit 47cb15745c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 1002 additions and 0 deletions

View file

@ -0,0 +1,648 @@
#!/usr/bin/env python3
"""
TextWorldEnv: Minimalist trainer environment for Microsoft TextWorld
A simple trainer environment that wraps TextWorld game generator and Gym interface
to train LLMs. The LLM outputs actions in plain text and receives only environment rewards.
No thinking tokens, memory, format rewards, or complex scoring - just pure environment interaction.
"""
import logging
import os
import random
import tempfile
import traceback
from typing import Any, Dict, List, Optional, Tuple
import gymnasium as gym # noqa: F401
import textworld
import textworld.challenges
import textworld.gym
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
ScoredDataItem,
)
from atroposlib.type_definitions import Item, Message
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from environments.game_environments.textworld_env.textworld_registry import (
create_textworld_registry,
) # noqa: F401
logger = logging.getLogger(__name__)
class TextWorldEnvConfig(BaseEnvConfig):
    """Configuration for the minimalist TextWorld environment trainer."""

    # Identifiers used for display and wandb run naming.
    env_name: str = "TextWorld"
    wandb_name: str = "textworld-trainer-minimal"
    # Number of parallel trajectories collected per generated game.
    group_size: int = 16
    max_num_workers: int = 16
    total_steps: int = 500
    max_steps: int = 300  # max steps per episode (matches coin_collector max)
    # Hard cap on conversation length; episodes end early near this limit.
    max_token_length: int = 32768
    # Challenge settings
    # Pre-built TextWorld challenges games are sampled from (see
    # textworld_registry for the per-challenge settings ranges).
    challenge_names: List[str] = [
        "tw-simple",
        "tw-cooking",
        "tw-coin_collector",
        "tw-treasure_hunter",
    ]
    randomize_challenge_settings: bool = (
        True  # Randomize settings within each challenge
    )
class TextWorldEnv(BaseEnv):
    """Minimalist TextWorld environment for training LLMs."""

    name = "textworld_minimal"
    env_config_cls = TextWorldEnvConfig

    def __init__(
        self,
        config: TextWorldEnvConfig,
        server_configs: List[APIServerConfig],
        slurm: bool = True,
        testing: bool = False,
    ):
        """Initialize the environment and its per-run bookkeeping state."""
        super().__init__(config, server_configs, slurm, testing)
        self.config: TextWorldEnvConfig = config
        # Challenge registry; populated in setup(), None until then.
        self.challenge_registry = None
        # Track generated game files for cleanup
        self._generated_files = set()
        # Create temp directory for game files
        self._temp_dir = tempfile.mkdtemp(prefix="textworld_minimal_")
        # wandb logging
        # Per-episode stats, filled by collect_trajectories and flushed
        # (logged then cleared) by wandb_log.
        self.episode_outcomes_buffer = []  # 1.0 win / -1.0 loss / 0.0 draw
        self.episode_rewards_buffer = []  # final game score per episode
        self.episode_steps_buffer = []  # moves per episode
        self.episode_challenge_types = []  # challenge name per episode
        self.eval_metrics_custom = []
        self.system_prompt = (
            "You are playing a text-based adventure game. "
            "Read the game state and respond with ONLY the action you want to take. "
            "Do not include any explanation or reasoning, just the action command itself.\n\n"
            "Examples of valid actions:\n"
            "- go north\n"
            "- take key\n"
            "- open door\n"
            "- examine table\n"
            "- inventory\n\n"
            "Respond with a single action only."
        )
async def setup(self):
"""Initialize the environment and challenge registry."""
# Import registry creation from local module
try:
self.challenge_registry = create_textworld_registry()
logger.warning(
f"Initialized TextWorld challenge registry with challenges: {self.config.challenge_names}"
)
logger.warning("TextWorldEnv setup completed successfully")
except Exception as e:
logger.error(f"Failed to create TextWorld registry: {e}")
logger.error(traceback.format_exc())
raise
async def get_next_item(self) -> Item:
"""Get the next game configuration."""
# Randomly select a challenge
if len(self.config.challenge_names) == 1:
challenge_name = self.config.challenge_names[0]
else:
challenge_name = random.choice(self.config.challenge_names)
# Get challenge settings
challenge_name, settings = self.challenge_registry.get_challenge(
challenge_name, randomize_settings=self.config.randomize_challenge_settings
)
return {
"challenge_name": challenge_name,
"settings": settings,
"game_type": "challenge",
}
def _create_game(self, challenge_name: str, settings: Dict[str, Any]) -> str:
"""Create a TextWorld game and save it to a file.
Returns:
Path to the saved game file (.z8)
"""
# Create default options
options = textworld.GameOptions()
options.seeds = settings.get("seed", random.randint(0, 1000000))
if challenge_name == "tw-simple":
game_settings = {
"rewards": settings.get("rewards", "balanced"),
"goal": settings.get("goal", "detailed"),
"test": str(settings.get("test", False)).lower(),
}
game = textworld.challenges.simple.make(game_settings, options=options)
elif challenge_name == "tw-cooking":
game_settings = {
"recipe": settings.get("recipe", 1), # Number of ingredients
"take": settings.get("take", 1), # Number to find
"cook": settings.get("cook", False), # Whether to cook
"open": settings.get("open", False), # Whether to open containers
"drop": settings.get("drop", False), # Whether limited inventory
"go": settings.get("go", 1), # Number of locations
"recipe_seed": settings.get(
"recipe-seed",
settings.get("recipe_seed", random.randint(0, 1000000)),
),
"split": "train",
}
logger.debug(f"Cooking game settings: {game_settings}")
game = textworld.challenges.cooking.make(game_settings, options=options)
elif challenge_name == "tw-coin_collector":
game_settings = {"level": settings.get("level", 1)}
game = textworld.challenges.coin_collector.make(
game_settings, options=options
)
elif challenge_name == "tw-treasure_hunter":
game_settings = {"level": settings.get("level", 1)}
game = textworld.challenges.treasure_hunter.make(
game_settings, options=options
)
else:
raise ValueError(f"Unknown challenge: {challenge_name}")
# Save gamefile
game_file = os.path.join(
self._temp_dir,
f"{challenge_name}_{settings.get('seed', random.randint(0, 1000000))}.z8",
)
options.path = game_file
options.file_ext = ".z8"
game_file = textworld.generator.compile_game(game, options)
# Track for cleanup
self._generated_files.add(game_file)
return game_file
async def collect_trajectories(
self, item: Item
) -> Tuple[ScoredDataGroup, List[Item]]:
"""Collect parallel trajectories from the same game."""
challenge_name = item["challenge_name"]
settings = item["settings"]
game_file = self._create_game(challenge_name, settings)
# Register the gamefile
request_infos = textworld.EnvInfos(
description=True,
inventory=True,
objective=True,
admissible_commands=True,
won=True,
lost=True,
score=True,
moves=True,
max_score=True,
)
env_id = textworld.gym.register_game(game_file, request_infos)
scored_items = []
for i in range(self.config.group_size):
try:
scored_item = await self._collect_single_trajectory(
env_id, i, challenge_name
)
if scored_item:
scored_items.append(scored_item)
except Exception as e:
logger.error(f"Error collecting trajectory {i}: {e}")
continue
if not scored_items:
return (
ScoredDataGroup(
tokens=[],
masks=[],
scores=[],
messages=[],
advantages=None,
ref_logprobs=None,
group_overrides={},
overrides=None,
images=None,
),
[],
)
sdg = ScoredDataGroup(
tokens=[],
masks=[],
scores=[],
messages=[],
advantages=None,
ref_logprobs=None,
group_overrides={},
overrides=None,
images=None,
)
for scored_item in scored_items:
sdg["tokens"].append(scored_item["tokens"])
sdg["masks"].append(scored_item["masks"])
sdg["scores"].append(scored_item["scores"])
if self.config.include_messages and scored_item.get("messages"):
sdg["messages"].append(scored_item["messages"])
metadata = scored_item.get("metadata", {})
final_score = metadata.get("final_score", 0)
won = metadata.get("won", False)
lost = metadata.get("lost", False)
moves = metadata.get("moves", 0)
if won:
outcome = 1.0
elif lost:
outcome = -1.0
else:
outcome = 0.0
self.episode_outcomes_buffer.append(outcome)
self.episode_rewards_buffer.append(final_score)
self.episode_steps_buffer.append(moves)
self.episode_challenge_types.append(
metadata.get("challenge_name", "unknown")
)
self._cleanup_game_file(game_file)
return sdg, []
async def _collect_single_trajectory(
self, env_id: str, trajectory_idx: int, challenge_name: str = "unknown"
) -> Optional[ScoredDataItem]:
"""Collect a single trajectory for the game."""
messages: List[Message] = []
try:
env = textworld.gym.make(env_id)
obs, info = env.reset()
messages.append({"role": "system", "content": self.system_prompt})
obs_text = self._format_observation(obs, info)
messages.append({"role": "user", "content": obs_text})
done = False
total_reward = 0.0
async with self.server.dedicated_server() as server:
while not done and len(messages) < self.config.max_steps * 2:
current_tokens = len(
self.tokenizer.apply_chat_template(messages, tokenize=True)
)
if current_tokens > self.config.max_token_length - 50:
logger.warning(
f"Trajectory {trajectory_idx}: Approaching token limit, ending episode"
)
logger.info(f"Token usage: {current_tokens} tokens")
logger.info(f"Number of messages: {len(messages)}")
logger.info(
f"Last observation length: {len(messages[-1]['content']) if messages else 0}"
)
break
try:
response = await server.chat_completion(
messages=messages,
n=1,
max_tokens=50, # short actions, no thinking
temperature=0.7,
)
action = response.choices[0].message.content.strip()
logger.debug(f"Trajectory {trajectory_idx}: Action: {action}")
except Exception as e:
logger.error(f"Trajectory {trajectory_idx}: LLM error: {e}")
break
messages.append({"role": "assistant", "content": action})
try:
obs, reward, done, info = env.step(action)
total_reward += reward
except Exception as e:
logger.error(
f"Trajectory {trajectory_idx}: Environment error: {e}"
)
break
if not done:
obs_text = self._format_observation(obs, info)
messages.append({"role": "user", "content": obs_text})
env.close()
tokenization_result = tokenize_for_trainer(
tokenizer=self.tokenizer,
chat=messages,
train_on_all_assistant_turns=True,
)
return ScoredDataItem(
messages=messages if self.config.include_messages else None,
tokens=tokenization_result["tokens"],
masks=tokenization_result["masks"],
scores=total_reward,
metadata={
"trajectory_idx": trajectory_idx,
"final_score": total_reward,
"won": info.get("won", False),
"lost": info.get("lost", False),
"moves": info.get("moves", 0),
"challenge_name": challenge_name,
},
)
except Exception as e:
logger.error(f"Trajectory {trajectory_idx}: Fatal error: {e}")
logger.error(traceback.format_exc())
return None
def _format_observation(self, obs: str, info: Dict[str, Any]) -> str:
"""Format game observation for the LLM."""
parts = []
# Main observation
parts.append(obs)
# Add objective if available
if "objective" in info and info["objective"]:
parts.append(f"\nObjective: {info['objective']}")
# Add score info
if "score" in info and "max_score" in info:
parts.append(f"\nScore: {info['score']}/{info['max_score']}")
# Add inventory if not empty
if "inventory" in info and info["inventory"]:
parts.append(f"\nInventory: {info['inventory']}")
return "\n".join(parts)
@classmethod
def config_init(cls) -> Tuple[TextWorldEnvConfig, List[APIServerConfig]]:
"""Initialize default configuration."""
env_config = TextWorldEnvConfig(
tokenizer_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
group_size=16,
use_wandb=True,
wandb_name=cls.name,
max_token_length=32768,
total_steps=500,
challenge_names=[
"tw-simple",
"tw-cooking",
"tw-coin_collector",
"tw-treasure_hunter",
],
randomize_challenge_settings=True,
)
server_configs = [
APIServerConfig(
model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
base_url="http://localhost:9004/v1",
api_key="x",
num_requests_for_eval=128,
),
APIServerConfig(
model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
base_url="http://localhost:9005/v1",
api_key="x",
num_requests_for_eval=128,
),
APIServerConfig(
model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
base_url="http://localhost:9006/v1",
api_key="x",
num_requests_for_eval=128,
),
APIServerConfig(
model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
base_url="http://localhost:9007/v1",
api_key="x",
num_requests_for_eval=128,
),
]
return env_config, server_configs
    # TODO: implement evaluation properly re eval changes
    async def evaluate(self, num_items: int) -> Dict[str, Any]:
        """Evaluate the model - not implemented for this minimal environment."""
        # Deliberate stub: logs and returns a marker dict instead of raising,
        # presumably so callers may invoke evaluate() unconditionally — verify.
        logger.warning("Evaluation not implemented in minimal TextWorld environment")
        return {"message": "Evaluation not implemented"}
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
"""Log episode statistics to wandb."""
if wandb_metrics is None:
wandb_metrics = {}
# Log training episode outcomes
if self.episode_outcomes_buffer:
wins = sum(1 for outcome in self.episode_outcomes_buffer if outcome > 0)
losses = sum(1 for outcome in self.episode_outcomes_buffer if outcome < 0)
draws = sum(1 for outcome in self.episode_outcomes_buffer if outcome == 0)
total_episodes = len(self.episode_outcomes_buffer)
win_rate = (wins / total_episodes) * 100 if total_episodes > 0 else 0.0
loss_rate = (losses / total_episodes) * 100 if total_episodes > 0 else 0.0
draw_rate = (draws / total_episodes) * 100 if total_episodes > 0 else 0.0
avg_steps = (
sum(self.episode_steps_buffer) / len(self.episode_steps_buffer)
if self.episode_steps_buffer
else 0
)
avg_outcome = (
sum(self.episode_outcomes_buffer) / len(self.episode_outcomes_buffer)
if self.episode_outcomes_buffer
else 0
)
avg_reward = (
sum(self.episode_rewards_buffer) / len(self.episode_rewards_buffer)
if self.episode_rewards_buffer
else 0
)
max_reward = (
max(self.episode_rewards_buffer) if self.episode_rewards_buffer else 0
)
min_reward = (
min(self.episode_rewards_buffer) if self.episode_rewards_buffer else 0
)
wandb_metrics[f"{self.name}/train/total_episodes"] = total_episodes
wandb_metrics[f"{self.name}/train/win_count_absolute"] = wins
wandb_metrics[f"{self.name}/train/loss_count_absolute"] = losses
wandb_metrics[f"{self.name}/train/draw_count_absolute"] = draws
wandb_metrics[f"{self.name}/train/win_rate_percent"] = win_rate
wandb_metrics[f"{self.name}/train/loss_rate_percent"] = loss_rate
wandb_metrics[f"{self.name}/train/draw_rate_percent"] = draw_rate
wandb_metrics[f"{self.name}/train/avg_episode_steps"] = avg_steps
wandb_metrics[f"{self.name}/train/avg_outcome"] = (
avg_outcome # -1, 0, 1 average
)
wandb_metrics[f"{self.name}/train/avg_reward_score"] = (
avg_reward # Actual game score
)
wandb_metrics[f"{self.name}/train/max_reward_score"] = max_reward
wandb_metrics[f"{self.name}/train/min_reward_score"] = min_reward
# Per-challenge statistics
challenge_stats = {}
for i, (outcome, reward, steps, challenge) in enumerate(
zip(
self.episode_outcomes_buffer,
self.episode_rewards_buffer,
self.episode_steps_buffer,
self.episode_challenge_types,
)
):
if challenge not in challenge_stats:
challenge_stats[challenge] = {
"outcomes": [],
"rewards": [],
"steps": [],
"count": 0,
}
challenge_stats[challenge]["outcomes"].append(outcome)
challenge_stats[challenge]["rewards"].append(reward)
challenge_stats[challenge]["steps"].append(steps)
challenge_stats[challenge]["count"] += 1
for challenge, stats in challenge_stats.items():
challenge_wins = sum(1 for o in stats["outcomes"] if o > 0)
challenge_losses = sum(1 for o in stats["outcomes"] if o < 0)
challenge_draws = sum(1 for o in stats["outcomes"] if o == 0)
challenge_total = stats["count"]
wandb_metrics[f"{self.name}/train/{challenge}/episodes_count"] = (
challenge_total
)
wandb_metrics[f"{self.name}/train/{challenge}/wins_count"] = (
challenge_wins
)
wandb_metrics[f"{self.name}/train/{challenge}/losses_count"] = (
challenge_losses
)
wandb_metrics[f"{self.name}/train/{challenge}/draws_count"] = (
challenge_draws
)
wandb_metrics[f"{self.name}/train/{challenge}/win_rate_percent"] = (
(challenge_wins / challenge_total) * 100
if challenge_total > 0
else 0
)
wandb_metrics[f"{self.name}/train/{challenge}/loss_rate_percent"] = (
(challenge_losses / challenge_total) * 100
if challenge_total > 0
else 0
)
wandb_metrics[f"{self.name}/train/{challenge}/draw_rate_percent"] = (
(challenge_draws / challenge_total) * 100
if challenge_total > 0
else 0
)
wandb_metrics[f"{self.name}/train/{challenge}/avg_steps"] = (
sum(stats["steps"]) / len(stats["steps"]) if stats["steps"] else 0
)
wandb_metrics[f"{self.name}/train/{challenge}/avg_outcome"] = (
sum(stats["outcomes"]) / len(stats["outcomes"])
if stats["outcomes"]
else 0
)
wandb_metrics[f"{self.name}/train/{challenge}/avg_reward_score"] = (
sum(stats["rewards"]) / len(stats["rewards"])
if stats["rewards"]
else 0
)
wandb_metrics[f"{self.name}/train/{challenge}/max_reward_score"] = (
max(stats["rewards"]) if stats["rewards"] else 0
)
wandb_metrics[f"{self.name}/train/{challenge}/min_reward_score"] = (
min(stats["rewards"]) if stats["rewards"] else 0
)
self.episode_outcomes_buffer = []
self.episode_rewards_buffer = []
self.episode_steps_buffer = []
self.episode_challenge_types = []
# Log eval metrics if any
for key, value in self.eval_metrics_custom:
wandb_metrics[key] = value
self.eval_metrics_custom = []
await super().wandb_log(wandb_metrics)
def _cleanup_game_file(self, game_file: str):
"""Clean up a generated game file and its associated files."""
if game_file in self._generated_files:
try:
# Remove .z8 file
if os.path.exists(game_file):
os.remove(game_file)
logger.debug(f"Removed game file: {game_file}")
# Remove .ni file if it exists
ni_file = game_file.replace(".z8", ".ni")
if os.path.exists(ni_file):
os.remove(ni_file)
logger.debug(f"Removed ni file: {ni_file}")
# Remove .json file if it exists
json_file = game_file.replace(".z8", ".json")
if os.path.exists(json_file):
os.remove(json_file)
logger.debug(f"Removed json file: {json_file}")
# TODO: handle .z5's from manually downloaded games when that's added
self._generated_files.remove(game_file)
except OSError as e:
logger.warning(f"Failed to clean up game file {game_file}: {e}")
def __del__(self):
"""Ensure cleanup on deletion."""
# Clean up any remaining game files
files_to_clean = list(self._generated_files)
for game_file in files_to_clean:
self._cleanup_game_file(game_file)
# Clean up local temp directory
if hasattr(self, "_temp_dir") and os.path.exists(self._temp_dir):
import shutil
try:
shutil.rmtree(self._temp_dir)
logger.debug(f"Cleaned up temp directory: {self._temp_dir}")
except Exception:
pass
if __name__ == "__main__":
    # Script entry point: launch the environment's CLI (provided by the base
    # class — presumably parses args and starts serving; verify in BaseEnv).
    TextWorldEnv.cli()

View file

@ -0,0 +1,221 @@
#!/usr/bin/env python3
"""
Local test server for the minimalist TextWorld environment.
"""
import asyncio
import logging
import os
from dotenv import load_dotenv
from atroposlib.envs.base import APIServerConfig, EvalHandlingEnum
from environments.game_environments.textworld_env.textworld_env import (
TextWorldEnv,
TextWorldEnvConfig,
)
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def main():
    """Run multiple TextWorld episodes for testing the minimalist environment.

    Optional CLI argument: a single challenge name to test (runs 1 episode).
    With no arguments, runs 20 episodes sampled across all challenges.
    """
    logger.info("Starting TextWorld (No Thinking) environment local debug runner")
    # Configure environment - matching blackjack_no_thinking settings
    env_config = TextWorldEnvConfig(
        tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
        group_size=1,
        use_wandb=False,
        wandb_name="textworld_no_thinking_local_debug",
        max_num_workers=1,
        rollout_server_url="http://localhost:8000",
        total_steps=1,
        batch_size=1,
        steps_per_eval=0,
        max_token_length=32768,
        inference_weight=1.0,
        data_path_to_save_groups=None,
        eval_handling=EvalHandlingEnum.NONE,
        eval_limit_ratio=0.0,
        max_steps=10,  # Max steps per episode
        include_messages=True,  # Include messages for debugging
        eval_episodes=0,
    )
    # Configure server - using same model as blackjack example
    server_configs = [
        APIServerConfig(
            model_name="gpt-4.1-nano",
            base_url="https://api.openai.com/v1",
            api_key=os.getenv("OPENAI_API_KEY"),
            num_requests_for_eval=0,
        )
    ]
    logger.info("Using hardcoded debug configuration for No Thinking TextWorld.")
    logger.debug(f"Env Config: {env_config}")
    logger.debug(f"Server Configs: {server_configs}")
    try:
        env = TextWorldEnv(
            config=env_config,
            server_configs=server_configs,
            slurm=False,
            testing=False,
        )
    except Exception as e:
        logger.exception(f"Failed to initialize TextWorldEnv: {e}")
        return
    import sys

    # Check if specific challenge requested
    challenge_to_test = sys.argv[1] if len(sys.argv) > 1 else None
    num_episodes = 20 if not challenge_to_test else 1
    # Report the actual episode count (was hardcoded to "20 episodes" even
    # when a single-challenge run only does 1).
    logger.info(f"Running {num_episodes} episodes across all challenges")
    try:
        await env.setup()
        # Track statistics
        episode_results = []
        challenge_counts = {
            "tw-simple": 0,
            "tw-cooking": 0,
            "tw-coin_collector": 0,
            "tw-treasure_hunter": 0,
        }
        for episode_num in range(num_episodes):
            if challenge_to_test:
                # Override config to test specific challenge
                env.config.challenge_names = [challenge_to_test]
            item = await env.get_next_item()
            challenge_name = item["challenge_name"]
            # .get() fix: a CLI-supplied challenge outside the preset dict
            # previously raised KeyError here.
            challenge_counts[challenge_name] = (
                challenge_counts.get(challenge_name, 0) + 1
            )
            logger.info(f"\n===== Episode {episode_num + 1}/{num_episodes} =====")
            logger.info(f"Using game: {item}")
            # Collect trajectories (group_size=1 so just one trajectory)
            sdg, _ = await env.collect_trajectories(item)
            scored_data_item = None
            if (
                sdg
                and hasattr(sdg, "scored_data_items")
                and sdg.scored_data_items
                and len(sdg.scored_data_items) > 0
            ):
                scored_data_item = sdg.scored_data_items[0]
            elif (
                sdg
                and isinstance(sdg, dict)
                and "scored_data_items" in sdg
                and len(sdg["scored_data_items"]) > 0
            ):
                scored_data_item = sdg["scored_data_items"][0]
            if (
                scored_data_item is None
                and isinstance(sdg, dict)
                and sdg.get("scores")
            ):
                # Fallback fix: collect_trajectories actually returns a
                # ScoredDataGroup with parallel "tokens"/"masks"/"scores"
                # lists and no "scored_data_items" key, so the branches above
                # can never match. Reconstruct a minimal record from the
                # first trajectory (won/moves metadata is not available here).
                scored_data_item = {"scores": sdg["scores"][0], "metadata": {}}
            if scored_data_item:
                # Handle both object and dict access patterns
                scores = (
                    scored_data_item.scores
                    if hasattr(scored_data_item, "scores")
                    else scored_data_item.get("scores")
                )
                metadata = (
                    scored_data_item.metadata
                    if hasattr(scored_data_item, "metadata")
                    else scored_data_item.get("metadata", {})
                )
                # Log brief summary
                outcome_str = "Loss"
                if metadata.get("won"):
                    outcome_str = "Win"
                elif scores > 0:
                    outcome_str = "Partial Success"
                moves = metadata.get("moves", 0)
                logger.info(
                    f"Result: {outcome_str}, Score: {scores:.2f}, Moves: {moves}"
                )
                # Collect statistics
                episode_results.append(
                    {
                        "episode": episode_num + 1,
                        "challenge": challenge_name,
                        "score": scores,
                        "won": metadata.get("won", False),
                        "moves": moves,
                        "difficulty": item.get("settings", {}),
                    }
                )
            else:
                logger.error("Trajectory collection did not return a ScoredDataItem.")
                episode_results.append(
                    {
                        "episode": episode_num + 1,
                        "challenge": challenge_name,
                        "score": 0.0,
                        "won": False,
                        "moves": 0,
                        "difficulty": item.get("settings", {}),
                    }
                )
        # Print overall statistics
        logger.info("\n" + "=" * 60)
        logger.info("OVERALL RESULTS SUMMARY")
        logger.info("=" * 60)
        logger.info(f"Total episodes: {num_episodes}")
        logger.info(f"Challenge distribution: {challenge_counts}")
        # Calculate win rates per challenge
        for challenge in challenge_counts:
            challenge_episodes = [
                ep for ep in episode_results if ep["challenge"] == challenge
            ]
            if challenge_episodes:
                wins = sum(1 for ep in challenge_episodes if ep["won"])
                avg_score = sum(ep["score"] for ep in challenge_episodes) / len(
                    challenge_episodes
                )
                avg_moves = sum(ep["moves"] for ep in challenge_episodes) / len(
                    challenge_episodes
                )
                logger.info(f"\n{challenge}:")
                logger.info(f"  Episodes: {len(challenge_episodes)}")
                logger.info(
                    f"  Win rate: {wins}/{len(challenge_episodes)} ({100*wins/len(challenge_episodes):.1f}%)"
                )
                logger.info(f"  Avg score: {avg_score:.2f}")
                logger.info(f"  Avg moves: {avg_moves:.1f}")
        # Overall stats
        total_wins = sum(1 for ep in episode_results if ep["won"])
        total_avg_score = (
            sum(ep["score"] for ep in episode_results) / len(episode_results)
            if episode_results
            else 0
        )
        logger.info(
            f"\nOverall win rate: {total_wins}/{len(episode_results)} ({100*total_wins/len(episode_results):.1f}%)"
        )
        logger.info(f"Overall avg score: {total_avg_score:.2f}")
    except Exception as e:
        logger.exception(
            f"An error occurred during trajectory collection or summary: {e}"
        )
if __name__ == "__main__":
    # Script entry point: run the local debug episode loop.
    asyncio.run(main())

View file

@ -0,0 +1,128 @@
#!/usr/bin/env python3
"""
TextWorld Challenge Registry
Provides a simple registry for the pre-built TextWorld challenges.
"""
import logging
import random
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
class TextWorldChallengeRegistry:
    """Registry for pre-built TextWorld challenges.

    Each entry maps a challenge name to its settings ranges; get_challenge
    resolves those ranges into one concrete settings dict per game.
    """

    # Pre-built challenges with their settings ranges for randomization
    CHALLENGES = {
        "tw-simple": {
            "rewards": ["sparse", "balanced", "dense"],
            "goal": ["detailed", "brief", "none"],
            "test": [False],
        },
        "tw-cooking": {
            "recipe": [1, 2, 3, 4],  # Number of ingredients in recipe
            "take": [1, 2, 3, 4],  # Number of ingredients to find (will be constrained)
            "cook": [False, True],  # Whether ingredients need cooking
            "open": [False, True],  # Whether containers/doors need opening
            "drop": [False, True],  # Whether inventory has limited capacity
            "go": [1, 6, 9, 12],  # Number of locations
        },
        "tw-coin_collector": {
            "level": list(range(1, 301)),  # Levels 1-300 (full range)
        },
        "tw-treasure_hunter": {
            "level": list(range(1, 31)),  # Levels 1-30 (full range)
        },
    }

    # All available challenge names
    ALL_CHALLENGES = list(CHALLENGES.keys())

    def __init__(self, seed: Optional[int] = None):
        # Instance-level (shallow) copy so per-instance tweaks never rebind
        # entries on the class-level table.
        self._challenges = self.CHALLENGES.copy()
        self.rng = random.Random(seed)
        # Cache for all possible combinations (reserved; not used yet).
        self._all_combinations = None
        self._combination_index = 0

    def list_challenges(self) -> List[str]:
        """List all available pre-built challenges."""
        return list(self._challenges.keys())

    def get_random_challenge(
        self, randomize_settings: bool = True
    ) -> Tuple[str, Dict[str, Any]]:
        """Get a random challenge with optionally randomized settings.

        Args:
            randomize_settings: Whether to randomize settings from available options

        Returns:
            Tuple of (challenge_name, settings_dict)
        """
        picked = self.rng.choice(self.list_challenges())
        return self.get_challenge(picked, randomize_settings)

    def get_challenge(
        self, name: str, randomize_settings: bool = True
    ) -> Tuple[str, Dict[str, Any]]:
        """Get challenge name and settings (optionally randomized).

        Args:
            name: Challenge name
            randomize_settings: Whether to randomize settings from available options

        Returns:
            Tuple of (challenge_name, settings_dict)

        Raises:
            ValueError: If ``name`` is not a registered challenge.
        """
        if name not in self._challenges:
            raise ValueError(
                f"Unknown challenge: {name}. Available: {self.list_challenges()}"
            )
        ranges = self._challenges[name]
        # Pick one option per setting: random when requested and there is a
        # real choice, otherwise the first (default) option.
        settings = {
            key: (
                self.rng.choice(opts)
                if randomize_settings and len(opts) > 1
                else opts[0]
            )
            for key, opts in ranges.items()
        }
        # Special handling for tw-cooking: ensure take <= recipe
        if name == "tw-cooking" and randomize_settings:
            recipe_value = settings["recipe"]
            candidates = [t for t in ranges["take"] if t <= recipe_value]
            settings["take"] = self.rng.choice(candidates) if candidates else 1
        # Generate a seed for this specific game instance
        settings["seed"] = self.rng.randint(0, 0xFFFFFFFF)
        # For tw-cooking, add recipe-seed
        if name == "tw-cooking":
            settings["recipe-seed"] = self.rng.randint(0, 0xFFFFFFFF)
        return name, settings
def create_textworld_registry(seed: Optional[int] = None) -> TextWorldChallengeRegistry:
    """Factory for a TextWorld challenge registry.

    Args:
        seed: Random seed for reproducibility

    Returns:
        TextWorldChallengeRegistry instance
    """
    return TextWorldChallengeRegistry(seed=seed)