mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Textworld minimal (#225)
* minimal implementation, simplified challenge registry * need game save logic * fixed challenge gen, works with local test * updated challenge gen with wider ranges, working with local script * runs working correctly, wandb stats look ok * linting * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * removed unused imports --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
1900a577d7
commit
47cb15745c
4 changed files with 1002 additions and 0 deletions
648
environments/game_environments/textworld_env/textworld_env.py
Normal file
648
environments/game_environments/textworld_env/textworld_env.py
Normal file
|
|
@ -0,0 +1,648 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
TextWorldEnv: Minimalist trainer environment for Microsoft TextWorld
|
||||
|
||||
A simple trainer environment that wraps TextWorld game generator and Gym interface
|
||||
to train LLMs. The LLM outputs actions in plain text and receives only environment rewards.
|
||||
No thinking tokens, memory, format rewards, or complex scoring - just pure environment interaction.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import gymnasium as gym # noqa: F401
|
||||
import textworld
|
||||
import textworld.challenges
|
||||
import textworld.gym
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
ScoredDataItem,
|
||||
)
|
||||
from atroposlib.type_definitions import Item, Message
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
from environments.game_environments.textworld_env.textworld_registry import (
|
||||
create_textworld_registry,
|
||||
) # noqa: F401
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextWorldEnvConfig(BaseEnvConfig):
    """Configuration for the minimalist TextWorld environment trainer."""

    # Identifier used for the environment and the wandb run name.
    env_name: str = "TextWorld"
    wandb_name: str = "textworld-trainer-minimal"
    # Number of trajectories collected per generated game (see collect_trajectories).
    group_size: int = 16
    # NOTE(review): consumed by BaseEnv, presumably caps parallel rollout workers — verify.
    max_num_workers: int = 16
    # NOTE(review): consumed by BaseEnv, presumably total trainer steps for the run — verify.
    total_steps: int = 500
    max_steps: int = 300  # max steps per episode (matches coin_collector max)
    # Token budget for the rollout; episodes end early when the chat nears this length.
    max_token_length: int = 32768

    # Challenge settings
    challenge_names: List[str] = [
        "tw-simple",
        "tw-cooking",
        "tw-coin_collector",
        "tw-treasure_hunter",
    ]
    randomize_challenge_settings: bool = (
        True  # Randomize settings within each challenge
    )
|
||||
|
||||
|
||||
class TextWorldEnv(BaseEnv):
    """Minimalist TextWorld environment for training LLMs.

    Generates TextWorld challenge games on the fly, rolls out
    ``group_size`` trajectories per game in which the policy emits
    plain-text actions, and scores each trajectory with the raw
    environment reward only — no thinking tokens, memory, or format
    rewards.
    """

    name = "textworld_minimal"
    env_config_cls = TextWorldEnvConfig

    def __init__(
        self,
        config: TextWorldEnvConfig,
        server_configs: List[APIServerConfig],
        slurm: bool = True,
        testing: bool = False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.config: TextWorldEnvConfig = config
        self.challenge_registry = None

        # Track generated game files for cleanup
        self._generated_files = set()

        # Create temp directory for game files
        self._temp_dir = tempfile.mkdtemp(prefix="textworld_minimal_")

        # Rolling per-episode stats, flushed on every wandb_log() call.
        self.episode_outcomes_buffer = []
        self.episode_rewards_buffer = []
        self.episode_steps_buffer = []
        self.episode_challenge_types = []
        self.eval_metrics_custom = []

        self.system_prompt = (
            "You are playing a text-based adventure game. "
            "Read the game state and respond with ONLY the action you want to take. "
            "Do not include any explanation or reasoning, just the action command itself.\n\n"
            "Examples of valid actions:\n"
            "- go north\n"
            "- take key\n"
            "- open door\n"
            "- examine table\n"
            "- inventory\n\n"
            "Respond with a single action only."
        )

    async def setup(self):
        """Initialize the environment and challenge registry."""
        try:
            self.challenge_registry = create_textworld_registry()
            logger.warning(
                f"Initialized TextWorld challenge registry with challenges: {self.config.challenge_names}"
            )
            logger.warning("TextWorldEnv setup completed successfully")
        except Exception as e:
            logger.error(f"Failed to create TextWorld registry: {e}")
            logger.error(traceback.format_exc())
            raise

    async def get_next_item(self) -> Item:
        """Return the next game configuration.

        Picks a challenge uniformly at random (random.choice already
        handles the single-element case) and resolves its — possibly
        randomized — settings via the registry.
        """
        challenge_name = random.choice(self.config.challenge_names)

        # Get challenge settings (registry may normalize the name).
        challenge_name, settings = self.challenge_registry.get_challenge(
            challenge_name, randomize_settings=self.config.randomize_challenge_settings
        )

        return {
            "challenge_name": challenge_name,
            "settings": settings,
            "game_type": "challenge",
        }

    def _create_game(self, challenge_name: str, settings: Dict[str, Any]) -> str:
        """Create a TextWorld game and save it to a file.

        Args:
            challenge_name: One of the supported "tw-*" challenge ids.
            settings: Challenge-specific settings from the registry.

        Returns:
            Path to the compiled game file (.z8).

        Raises:
            ValueError: If ``challenge_name`` is not a known challenge.
        """
        # Draw the seed once so the game options and the output filename
        # agree (previously a missing "seed" key produced two independent
        # random values for options.seeds and the filename).
        seed = settings.get("seed", random.randint(0, 1000000))

        options = textworld.GameOptions()
        options.seeds = seed

        if challenge_name == "tw-simple":
            game_settings = {
                "rewards": settings.get("rewards", "balanced"),
                "goal": settings.get("goal", "detailed"),
                "test": str(settings.get("test", False)).lower(),
            }
            game = textworld.challenges.simple.make(game_settings, options=options)
        elif challenge_name == "tw-cooking":
            game_settings = {
                "recipe": settings.get("recipe", 1),  # Number of ingredients
                "take": settings.get("take", 1),  # Number to find
                "cook": settings.get("cook", False),  # Whether to cook
                "open": settings.get("open", False),  # Whether to open containers
                "drop": settings.get("drop", False),  # Whether limited inventory
                "go": settings.get("go", 1),  # Number of locations
                # Accept either spelling of the recipe-seed key.
                "recipe_seed": settings.get(
                    "recipe-seed",
                    settings.get("recipe_seed", random.randint(0, 1000000)),
                ),
                "split": "train",
            }
            logger.debug(f"Cooking game settings: {game_settings}")
            game = textworld.challenges.cooking.make(game_settings, options=options)
        elif challenge_name == "tw-coin_collector":
            game_settings = {"level": settings.get("level", 1)}
            game = textworld.challenges.coin_collector.make(
                game_settings, options=options
            )
        elif challenge_name == "tw-treasure_hunter":
            game_settings = {"level": settings.get("level", 1)}
            game = textworld.challenges.treasure_hunter.make(
                game_settings, options=options
            )
        else:
            raise ValueError(f"Unknown challenge: {challenge_name}")

        # Compile the game to a z-machine file inside the temp directory.
        game_file = os.path.join(self._temp_dir, f"{challenge_name}_{seed}.z8")
        options.path = game_file
        options.file_ext = ".z8"
        game_file = textworld.generator.compile_game(game, options)

        # Track for cleanup
        self._generated_files.add(game_file)

        return game_file

    @staticmethod
    def _empty_scored_group() -> ScoredDataGroup:
        """Return a ScoredDataGroup with all per-trajectory fields empty."""
        return ScoredDataGroup(
            tokens=[],
            masks=[],
            scores=[],
            messages=[],
            advantages=None,
            ref_logprobs=None,
            group_overrides={},
            overrides=None,
            images=None,
        )

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[ScoredDataGroup, List[Item]]:
        """Collect ``group_size`` trajectories from the same generated game.

        Args:
            item: Game configuration produced by ``get_next_item``.

        Returns:
            A ``(ScoredDataGroup, backlog)`` tuple; the backlog is always empty.
        """
        challenge_name = item["challenge_name"]
        settings = item["settings"]

        game_file = self._create_game(challenge_name, settings)

        # Register the game with the gym wrapper, requesting the extra
        # per-step info fields used for prompting and scoring below.
        request_infos = textworld.EnvInfos(
            description=True,
            inventory=True,
            objective=True,
            admissible_commands=True,
            won=True,
            lost=True,
            score=True,
            moves=True,
            max_score=True,
        )
        env_id = textworld.gym.register_game(game_file, request_infos)

        scored_items = []

        for i in range(self.config.group_size):
            try:
                scored_item = await self._collect_single_trajectory(
                    env_id, i, challenge_name
                )
                if scored_item:
                    scored_items.append(scored_item)
            except Exception as e:
                logger.error(f"Error collecting trajectory {i}: {e}")
                continue

        if not scored_items:
            return self._empty_scored_group(), []

        sdg = self._empty_scored_group()

        for scored_item in scored_items:
            sdg["tokens"].append(scored_item["tokens"])
            sdg["masks"].append(scored_item["masks"])
            sdg["scores"].append(scored_item["scores"])
            if self.config.include_messages and scored_item.get("messages"):
                sdg["messages"].append(scored_item["messages"])

            metadata = scored_item.get("metadata", {})
            final_score = metadata.get("final_score", 0)
            won = metadata.get("won", False)
            lost = metadata.get("lost", False)
            moves = metadata.get("moves", 0)

            # Collapse win/lose/neither into a ternary outcome for logging.
            if won:
                outcome = 1.0
            elif lost:
                outcome = -1.0
            else:
                outcome = 0.0

            self.episode_outcomes_buffer.append(outcome)
            self.episode_rewards_buffer.append(final_score)
            self.episode_steps_buffer.append(moves)
            self.episode_challenge_types.append(
                metadata.get("challenge_name", "unknown")
            )

        self._cleanup_game_file(game_file)

        return sdg, []

    async def _collect_single_trajectory(
        self, env_id: str, trajectory_idx: int, challenge_name: str = "unknown"
    ) -> Optional[ScoredDataItem]:
        """Roll out a single episode and return the tokenized trajectory.

        Args:
            env_id: Gym id returned by ``textworld.gym.register_game``.
            trajectory_idx: Index within the group (used for logging only).
            challenge_name: Challenge id recorded in the item metadata.

        Returns:
            The scored trajectory, or ``None`` on a fatal error.
        """
        messages: List[Message] = []

        try:
            env = textworld.gym.make(env_id)
            try:
                obs, info = env.reset()

                messages.append({"role": "system", "content": self.system_prompt})

                obs_text = self._format_observation(obs, info)
                messages.append({"role": "user", "content": obs_text})

                done = False
                total_reward = 0.0
                # Count environment steps explicitly instead of inferring
                # them from len(messages), which also counted the system
                # prompt and stopped one step short of max_steps.
                steps_taken = 0

                async with self.server.dedicated_server() as server:
                    while not done and steps_taken < self.config.max_steps:
                        current_tokens = len(
                            self.tokenizer.apply_chat_template(messages, tokenize=True)
                        )
                        # Leave headroom for the (<=50-token) action completion.
                        if current_tokens > self.config.max_token_length - 50:
                            logger.warning(
                                f"Trajectory {trajectory_idx}: Approaching token limit, ending episode"
                            )
                            logger.info(f"Token usage: {current_tokens} tokens")
                            logger.info(f"Number of messages: {len(messages)}")
                            logger.info(
                                f"Last observation length: {len(messages[-1]['content']) if messages else 0}"
                            )
                            break

                        try:
                            response = await server.chat_completion(
                                messages=messages,
                                n=1,
                                max_tokens=50,  # short actions, no thinking
                                temperature=0.7,
                            )
                            action = response.choices[0].message.content.strip()
                            logger.debug(
                                f"Trajectory {trajectory_idx}: Action: {action}"
                            )
                        except Exception as e:
                            logger.error(f"Trajectory {trajectory_idx}: LLM error: {e}")
                            break

                        messages.append({"role": "assistant", "content": action})

                        try:
                            obs, reward, done, info = env.step(action)
                            total_reward += reward
                            steps_taken += 1
                        except Exception as e:
                            logger.error(
                                f"Trajectory {trajectory_idx}: Environment error: {e}"
                            )
                            break

                        if not done:
                            obs_text = self._format_observation(obs, info)
                            messages.append({"role": "user", "content": obs_text})
            finally:
                # Close the env even if the rollout above raised, so the
                # underlying interpreter process is always released.
                env.close()

            tokenization_result = tokenize_for_trainer(
                tokenizer=self.tokenizer,
                chat=messages,
                train_on_all_assistant_turns=True,
            )

            return ScoredDataItem(
                messages=messages if self.config.include_messages else None,
                tokens=tokenization_result["tokens"],
                masks=tokenization_result["masks"],
                scores=total_reward,
                metadata={
                    "trajectory_idx": trajectory_idx,
                    "final_score": total_reward,
                    "won": info.get("won", False),
                    "lost": info.get("lost", False),
                    "moves": info.get("moves", 0),
                    "challenge_name": challenge_name,
                },
            )

        except Exception as e:
            logger.error(f"Trajectory {trajectory_idx}: Fatal error: {e}")
            logger.error(traceback.format_exc())
            return None

    def _format_observation(self, obs: str, info: Dict[str, Any]) -> str:
        """Format a game observation (plus objective/score/inventory) for the LLM."""
        parts = [obs]

        # Add objective if available
        if "objective" in info and info["objective"]:
            parts.append(f"\nObjective: {info['objective']}")

        # Add score info
        if "score" in info and "max_score" in info:
            parts.append(f"\nScore: {info['score']}/{info['max_score']}")

        # Add inventory if not empty
        if "inventory" in info and info["inventory"]:
            parts.append(f"\nInventory: {info['inventory']}")

        return "\n".join(parts)

    @classmethod
    def config_init(cls) -> Tuple[TextWorldEnvConfig, List[APIServerConfig]]:
        """Initialize the default configuration and inference endpoints."""
        env_config = TextWorldEnvConfig(
            tokenizer_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
            group_size=16,
            use_wandb=True,
            wandb_name=cls.name,
            max_token_length=32768,
            total_steps=500,
            challenge_names=[
                "tw-simple",
                "tw-cooking",
                "tw-coin_collector",
                "tw-treasure_hunter",
            ],
            randomize_challenge_settings=True,
        )

        # Four identical local inference endpoints on consecutive ports.
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
                base_url=f"http://localhost:{port}/v1",
                api_key="x",
                num_requests_for_eval=128,
            )
            for port in (9004, 9005, 9006, 9007)
        ]

        return env_config, server_configs

    # TODO: implement evaluation properly re eval changes
    async def evaluate(self, num_items: int) -> Dict[str, Any]:
        """Evaluate the model - not implemented for this minimal environment."""
        logger.warning("Evaluation not implemented in minimal TextWorld environment")
        return {"message": "Evaluation not implemented"}

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log buffered episode statistics to wandb, then reset the buffers."""
        if wandb_metrics is None:
            wandb_metrics = {}

        # Log training episode outcomes
        if self.episode_outcomes_buffer:
            wins = sum(1 for outcome in self.episode_outcomes_buffer if outcome > 0)
            losses = sum(1 for outcome in self.episode_outcomes_buffer if outcome < 0)
            draws = sum(1 for outcome in self.episode_outcomes_buffer if outcome == 0)
            total_episodes = len(self.episode_outcomes_buffer)

            win_rate = (wins / total_episodes) * 100 if total_episodes > 0 else 0.0
            loss_rate = (losses / total_episodes) * 100 if total_episodes > 0 else 0.0
            draw_rate = (draws / total_episodes) * 100 if total_episodes > 0 else 0.0

            avg_steps = (
                sum(self.episode_steps_buffer) / len(self.episode_steps_buffer)
                if self.episode_steps_buffer
                else 0
            )
            avg_outcome = (
                sum(self.episode_outcomes_buffer) / len(self.episode_outcomes_buffer)
                if self.episode_outcomes_buffer
                else 0
            )
            avg_reward = (
                sum(self.episode_rewards_buffer) / len(self.episode_rewards_buffer)
                if self.episode_rewards_buffer
                else 0
            )
            max_reward = (
                max(self.episode_rewards_buffer) if self.episode_rewards_buffer else 0
            )
            min_reward = (
                min(self.episode_rewards_buffer) if self.episode_rewards_buffer else 0
            )

            wandb_metrics[f"{self.name}/train/total_episodes"] = total_episodes
            wandb_metrics[f"{self.name}/train/win_count_absolute"] = wins
            wandb_metrics[f"{self.name}/train/loss_count_absolute"] = losses
            wandb_metrics[f"{self.name}/train/draw_count_absolute"] = draws
            wandb_metrics[f"{self.name}/train/win_rate_percent"] = win_rate
            wandb_metrics[f"{self.name}/train/loss_rate_percent"] = loss_rate
            wandb_metrics[f"{self.name}/train/draw_rate_percent"] = draw_rate
            wandb_metrics[f"{self.name}/train/avg_episode_steps"] = avg_steps
            wandb_metrics[f"{self.name}/train/avg_outcome"] = (
                avg_outcome  # -1, 0, 1 average
            )
            wandb_metrics[f"{self.name}/train/avg_reward_score"] = (
                avg_reward  # Actual game score
            )
            wandb_metrics[f"{self.name}/train/max_reward_score"] = max_reward
            wandb_metrics[f"{self.name}/train/min_reward_score"] = min_reward

            # Per-challenge statistics (group the parallel buffers by challenge).
            challenge_stats = {}
            for outcome, reward, steps, challenge in zip(
                self.episode_outcomes_buffer,
                self.episode_rewards_buffer,
                self.episode_steps_buffer,
                self.episode_challenge_types,
            ):
                if challenge not in challenge_stats:
                    challenge_stats[challenge] = {
                        "outcomes": [],
                        "rewards": [],
                        "steps": [],
                        "count": 0,
                    }
                challenge_stats[challenge]["outcomes"].append(outcome)
                challenge_stats[challenge]["rewards"].append(reward)
                challenge_stats[challenge]["steps"].append(steps)
                challenge_stats[challenge]["count"] += 1

            for challenge, stats in challenge_stats.items():
                challenge_wins = sum(1 for o in stats["outcomes"] if o > 0)
                challenge_losses = sum(1 for o in stats["outcomes"] if o < 0)
                challenge_draws = sum(1 for o in stats["outcomes"] if o == 0)
                challenge_total = stats["count"]

                prefix = f"{self.name}/train/{challenge}"
                wandb_metrics[f"{prefix}/episodes_count"] = challenge_total
                wandb_metrics[f"{prefix}/wins_count"] = challenge_wins
                wandb_metrics[f"{prefix}/losses_count"] = challenge_losses
                wandb_metrics[f"{prefix}/draws_count"] = challenge_draws

                wandb_metrics[f"{prefix}/win_rate_percent"] = (
                    (challenge_wins / challenge_total) * 100
                    if challenge_total > 0
                    else 0
                )
                wandb_metrics[f"{prefix}/loss_rate_percent"] = (
                    (challenge_losses / challenge_total) * 100
                    if challenge_total > 0
                    else 0
                )
                wandb_metrics[f"{prefix}/draw_rate_percent"] = (
                    (challenge_draws / challenge_total) * 100
                    if challenge_total > 0
                    else 0
                )

                wandb_metrics[f"{prefix}/avg_steps"] = (
                    sum(stats["steps"]) / len(stats["steps"]) if stats["steps"] else 0
                )
                wandb_metrics[f"{prefix}/avg_outcome"] = (
                    sum(stats["outcomes"]) / len(stats["outcomes"])
                    if stats["outcomes"]
                    else 0
                )
                wandb_metrics[f"{prefix}/avg_reward_score"] = (
                    sum(stats["rewards"]) / len(stats["rewards"])
                    if stats["rewards"]
                    else 0
                )
                wandb_metrics[f"{prefix}/max_reward_score"] = (
                    max(stats["rewards"]) if stats["rewards"] else 0
                )
                wandb_metrics[f"{prefix}/min_reward_score"] = (
                    min(stats["rewards"]) if stats["rewards"] else 0
                )

        # Reset the rolling buffers (no-op when they were already empty).
        self.episode_outcomes_buffer = []
        self.episode_rewards_buffer = []
        self.episode_steps_buffer = []
        self.episode_challenge_types = []

        # Log eval metrics if any
        for key, value in self.eval_metrics_custom:
            wandb_metrics[key] = value
        self.eval_metrics_custom = []

        await super().wandb_log(wandb_metrics)

    def _cleanup_game_file(self, game_file: str):
        """Clean up a generated game file and its associated files."""
        if game_file in self._generated_files:
            try:
                # Remove the .z8 game plus the .ni/.json side-car files the
                # generator writes next to it.
                for artifact in (
                    game_file,
                    game_file.replace(".z8", ".ni"),
                    game_file.replace(".z8", ".json"),
                ):
                    if os.path.exists(artifact):
                        os.remove(artifact)
                        logger.debug(f"Removed game artifact: {artifact}")

                # TODO: handle .z5's from manually downloaded games when that's added

                self._generated_files.remove(game_file)
            except OSError as e:
                logger.warning(f"Failed to clean up game file {game_file}: {e}")

    def __del__(self):
        """Best-effort cleanup of game files and the temp directory.

        __del__ can run on a partially constructed instance (e.g. when
        __init__ raised before setting attributes), so every attribute
        access is guarded with hasattr.
        """
        if hasattr(self, "_generated_files"):
            for game_file in list(self._generated_files):
                self._cleanup_game_file(game_file)

        if hasattr(self, "_temp_dir") and os.path.exists(self._temp_dir):
            import shutil

            try:
                shutil.rmtree(self._temp_dir)
                logger.debug(f"Cleaned up temp directory: {self._temp_dir}")
            except Exception:
                pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Launch the Atropos command-line entry point provided by BaseEnv.
    TextWorldEnv.cli()
|
||||
Loading…
Add table
Add a link
Reference in a new issue