atropos/environments/rubiks_cube_environment.py
2025-05-18 15:27:37 -07:00

1120 lines
No EOL
42 KiB
Python

#!/usr/bin/env python3
"""
RubiksCubeEnv: Trainer environment for Rubik's Cube solving with multi-step reasoning
This environment implements a Rubik's cube solver that trains LLMs to solve cubes
through step-by-step reasoning and visualization. Extends BaseEnv.
"""
import asyncio
import copy
import json
import logging
import random
import re
from typing import Dict, List, Optional, Tuple, Any
import string
import numpy as np
# Define the face colors for visualization
UP_COLOR = 'W'  # White
DOWN_COLOR = 'Y'  # Yellow
RIGHT_COLOR = 'R'  # Red
LEFT_COLOR = 'O'  # Orange
FRONT_COLOR = 'G'  # Green
BACK_COLOR = 'B'  # Blue


class Cube:
    """
    A 3x3 Rubik's cube stored as six 3x3 grids of sticker colors.

    ``self.cube[face][row][col]`` uses the face order
    0=UP, 1=DOWN, 2=LEFT, 3=RIGHT, 4=FRONT, 5=BACK.  Every face is indexed the
    way it looks from outside the cube:

    * UP is seen from above, with BACK along row 0 and LEFT along column 0.
    * DOWN is seen from below, with FRONT along row 0 and LEFT along column 0.
    * LEFT, RIGHT, FRONT and BACK are seen with UP along row 0.

    With this layout a clockwise face turn is a clockwise rotation of the
    face's own grid plus a cyclic shift of the twelve stickers bordering it.
    """

    # Face letter -> index into self.cube.
    _FACE_INDEX = {'U': 0, 'D': 1, 'L': 2, 'R': 3, 'F': 4, 'B': 5}
    # Move suffix -> number of clockwise quarter turns it stands for.
    _QUARTER_TURNS = {"'": 3, "2": 2}
    # Sticker color of each face (in self.cube order) when the cube is solved.
    _SOLVED_COLORS = (
        UP_COLOR, DOWN_COLOR, LEFT_COLOR, RIGHT_COLOR, FRONT_COLOR, BACK_COLOR
    )

    def __init__(self):
        # Start from a solved cube.
        self.reset()

    def reset(self):
        """Reset the cube to the solved state."""
        self.cube = [
            [[color for _ in range(3)] for _ in range(3)]
            for color in self._SOLVED_COLORS
        ]

    def is_solved(self) -> bool:
        """Return True when every sticker matches the center of its face."""
        return all(
            sticker == face[1][1]  # Centers never move, so they define the face color
            for face in self.cube
            for row in face
            for sticker in row
        )

    def count_solved_cubies(self) -> float:
        """
        Return the fraction (0.0-1.0) of the 54 stickers that show the color
        their position has on a solved cube.

        Despite the name this is a normalized sticker score, not a cubie count.
        """
        total_stickers = 6 * 9  # 6 faces, 9 stickers per face
        match_count = sum(
            sticker == home_color
            for face, home_color in zip(self.cube, self._SOLVED_COLORS)
            for row in face
            for sticker in row
        )
        return match_count / total_stickers

    def rotate(self, move: str):
        """
        Perform a move on the cube using standard notation.

        U, D, L, R, F, B are clockwise quarter turns of the respective face
        (clockwise as seen when looking straight at that face), X' is the
        counterclockwise turn and X2 the half turn.

        Raises:
            ValueError: if ``move`` is empty, names an unknown face, carries an
                unknown modifier or is longer than two characters.  The cube
                is left untouched in that case.
        """
        if len(move) == 0:
            raise ValueError("Empty move")
        face = move[0]
        if face not in self._FACE_INDEX:
            raise ValueError(f"Invalid face: {face}")
        face_idx = self._FACE_INDEX[face]

        if len(move) == 1:
            count = 1
        elif len(move) == 2:
            if move[1] not in self._QUARTER_TURNS:
                raise ValueError(f"Invalid move modifier: {move[1]}")
            count = self._QUARTER_TURNS[move[1]]
        else:
            raise ValueError(f"Invalid move format: {move}")

        # A counterclockwise / half turn is just 3 / 2 clockwise quarter turns.
        for _ in range(count):
            self._rotate_face_clockwise(face_idx)
            self._rotate_adjacent_faces(face_idx)

    def _rotate_face_clockwise(self, face_idx: int):
        """Rotate the 3x3 grid of one face by 90 degrees clockwise."""
        face = self.cube[face_idx]
        # Reversing the rows and transposing maps old[i][j] to new[j][2 - i].
        self.cube[face_idx] = [list(column) for column in zip(*face[::-1])]

    def _rotate_adjacent_faces(self, face_idx: int):
        """
        Cycle the twelve stickers that border ``face_idx`` for one clockwise
        quarter turn of that face.

        The cycle direction has to agree with ``_rotate_face_clockwise``;
        otherwise the three stickers of a corner cubie end up on different
        corners and the state stops being a physically reachable cube.
        """
        c = self.cube
        if face_idx == 0:  # UP face
            # Seen from above, clockwise carries FRONT -> LEFT -> BACK -> RIGHT.
            temp = c[4][0][:]      # Save FRONT top edge
            c[4][0] = c[3][0][:]   # FRONT <- RIGHT
            c[3][0] = c[5][0][:]   # RIGHT <- BACK
            c[5][0] = c[2][0][:]   # BACK <- LEFT
            c[2][0] = temp         # LEFT <- FRONT
        elif face_idx == 1:  # DOWN face
            # Seen from below, clockwise carries FRONT -> RIGHT -> BACK -> LEFT.
            temp = c[4][2][:]      # Save FRONT bottom edge
            c[4][2] = c[2][2][:]   # FRONT <- LEFT
            c[2][2] = c[5][2][:]   # LEFT <- BACK
            c[5][2] = c[3][2][:]   # BACK <- RIGHT
            c[3][2] = temp         # RIGHT <- FRONT
        elif face_idx == 2:  # LEFT face
            # Columns, not rows: UP left -> FRONT left -> DOWN left -> BACK right.
            temp = [c[0][i][0] for i in range(3)]  # Save UP left column
            for i in range(3):
                c[0][i][0] = c[5][2 - i][2]   # UP left <- BACK right (reversed)
            for i in range(3):
                c[5][i][2] = c[1][2 - i][0]   # BACK right <- DOWN left (reversed)
            for i in range(3):
                c[1][i][0] = c[4][i][0]       # DOWN left <- FRONT left
            for i in range(3):
                c[4][i][0] = temp[i]          # FRONT left <- UP left
        elif face_idx == 3:  # RIGHT face
            # FRONT right -> UP right -> BACK left -> DOWN right.
            temp = [c[0][i][2] for i in range(3)]  # Save UP right column
            for i in range(3):
                c[0][i][2] = c[4][i][2]       # UP right <- FRONT right
            for i in range(3):
                c[4][i][2] = c[1][i][2]       # FRONT right <- DOWN right
            for i in range(3):
                c[1][i][2] = c[5][2 - i][0]   # DOWN right <- BACK left (reversed)
            for i in range(3):
                c[5][i][0] = temp[2 - i]      # BACK left <- UP right (reversed)
        elif face_idx == 4:  # FRONT face
            # UP bottom -> RIGHT left -> DOWN top -> LEFT right.
            temp = c[0][2][:]                 # Save UP bottom row
            for i in range(3):
                c[0][2][i] = c[2][2 - i][2]   # UP bottom <- LEFT right (reversed)
            for i in range(3):
                c[2][i][2] = c[1][0][i]       # LEFT right <- DOWN top
            for i in range(3):
                c[1][0][i] = c[3][2 - i][0]   # DOWN top <- RIGHT left (reversed)
            for i in range(3):
                c[3][i][0] = temp[i]          # RIGHT left <- UP bottom
        elif face_idx == 5:  # BACK face
            # UP top -> LEFT left -> DOWN bottom -> RIGHT right.  This is the
            # mirror image of FRONT, so the reversed strips are the *other*
            # two (RIGHT <- DOWN and LEFT <- UP).
            temp = c[0][0][:]                 # Save UP top row
            for i in range(3):
                c[0][0][i] = c[3][i][2]       # UP top <- RIGHT right
            for i in range(3):
                c[3][i][2] = c[1][2][2 - i]   # RIGHT right <- DOWN bottom (reversed)
            for i in range(3):
                c[1][2][i] = c[2][i][0]       # DOWN bottom <- LEFT left
            for i in range(3):
                c[2][i][0] = temp[2 - i]      # LEFT left <- UP top (reversed)

    def __str__(self) -> str:
        """Render the six faces as text, three lines per face, in U D L R F B order."""
        face_names = ['U', 'D', 'L', 'R', 'F', 'B']
        result = []
        for name, face in zip(face_names, self.cube):
            result.append(f"{name}: {' '.join(face[0])}")
            result.append(f" {' '.join(face[1])}")
            result.append(f" {' '.join(face[2])}")
        return "\n".join(result)
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
ScoredDataGroup,
)
from atroposlib.utils.message_history_utils import (
ensure_trajectory_token_limit,
truncate_thinking,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from atroposlib.utils.tool_call_parser import parse_tool_call
# Module-level logger; RubiksCubeEnv.__init__ adjusts its level based on debug_mode.
logger = logging.getLogger(__name__)
class RubiksCubeEnvConfig(BaseEnvConfig):
    # ---- Rollout / sampling settings ----
    # Turn budget per episode; collect_trajectories and the eval rollout use it as max_turns.
    max_steps: int = 20
    # Sampling parameters forwarded to the completion call in _sample_response.
    temperature: float = 0.7
    top_p: float = 0.9
    wandb_name: str = "rubiks_cube"
    # When True, each agent turn is prefilled with "<think>\n".
    thinking_active: bool = True
    # Number of rollouts launched by evaluate().
    eval_episodes: int = 100
    # Limit handed to truncate_thinking before a response is stored in the episode history.
    max_think_chars_history: int = 3000
    max_trajectory_tokens: int = 24576  # seq_len of RL trainer
    # Switches this module's logger to DEBUG.
    debug_mode: bool = False
    # Number of alternative responses sampled per turn.
    group_size: int = 16
    # NOTE(review): not read anywhere in the visible code.
    tiebreak_token_factor: float = 0.01

    # ---- Cube-specific settings ----
    # Number of random moves used to scramble the cube.
    scramble_moves: int = 7
    # 3x3 cube by default.  NOTE(review): not read in the visible code; Cube is hard-wired to 3x3.
    cube_size: int = 3
    # Multiplied by the 0-1 score returned by Cube.count_solved_cubies().
    reward_per_correctly_placed_cubie: float = 0.05
    # Subtracted once for every move already taken (small penalty for using more steps).
    reward_per_step_reduction: float = 0.01
class RubiksCubeScoredDataGroup(ScoredDataGroup):
    # One group holds the alternatives sampled for a single turn of one episode.
    # Seed of the episode the group belongs to.
    seed: int
    # Per-alternative trainer tokens and masks, as returned by tokenize_for_trainer.
    tokens: Optional[List[List[int]]] = None
    masks: Optional[List[List[int]]] = None
    # Per-alternative reward computed by RubiksCubeEnv._score_response.
    scores: Optional[List[float]] = None
    # Per-alternative message list that was tokenized.
    messages: Optional[List[List[Dict]]] = None
    # Per-alternative move extracted from the response (None when parsing failed).
    parsed_actions: Optional[List[str]] = None
class CubeState:
    """
    Mutable state of one episode: the scrambled cube, the chat history and
    the bookkeeping for moves and rewards.
    """

    # Every face turn the scrambler may pick from.
    _SCRAMBLE_MOVES = (
        "U", "D", "L", "R", "F", "B",
        "U'", "D'", "L'", "R'", "F'", "B'",
        "U2", "D2", "L2", "R2", "F2", "B2",
    )

    def __init__(self, seed: int, scramble_moves: int):
        """
        Build a solved cube and scramble it reproducibly.

        Args:
            seed: Seed of the private RNG that drives the scramble.
            scramble_moves: Number of random moves to apply.
        """
        self.seed = seed
        self.cube = Cube()
        self.message_history: List[Dict] = []
        self.actions: List[str] = []
        self.step_rewards: List[float] = []
        self.total_reward: float = 0.0
        self.num_steps: int = 0

        # Use a private generator instead of random.seed(seed): reseeding the
        # module-level RNG would make every seed later drawn from `random`
        # a deterministic function of this episode's seed.  Random(seed)
        # yields the same choices the reseeded global generator would.
        rng = random.Random(seed)

        # Reset cube to solved state, then scramble it.
        self.cube.reset()
        # Space-separated scramble, kept for debugging / reproduction.
        self.scramble_sequence: str = self._scramble_cube(scramble_moves, rng)

    def _scramble_cube(self, num_moves: int, rng=None):
        """
        Apply ``num_moves`` random moves to the cube.

        Args:
            num_moves: Number of random moves to apply.
            rng: Object offering ``choice`` (e.g. ``random.Random``); the
                ``random`` module itself is used when omitted.

        Returns:
            The applied moves joined by single spaces.
        """
        chooser = random if rng is None else rng
        scramble_sequence = []
        for _ in range(num_moves):
            move = chooser.choice(self._SCRAMBLE_MOVES)
            scramble_sequence.append(move)
            self.cube.rotate(move)

        # Random moves can cancel out (e.g. "U U'").  An episode must not
        # start solved, and one extra turn from solved is never solved.
        if num_moves > 0 and self.cube.is_solved():
            move = chooser.choice(self._SCRAMBLE_MOVES)
            scramble_sequence.append(move)
            self.cube.rotate(move)

        return " ".join(scramble_sequence)

    def apply_move(self, move: str) -> bool:
        """
        Apply a move and record it.

        Returns:
            True when the move was applied; False when the cube rejected it,
            in which case nothing is recorded.
        """
        try:
            self.cube.rotate(move)
        except (ValueError, TypeError) as e:
            # ValueError: malformed notation; TypeError: not a string at all.
            logger.error(f"Error applying move {move}: {e}")
            return False
        self.actions.append(move)
        self.num_steps += 1
        return True

    def is_solved(self) -> bool:
        """Check if the cube is solved."""
        return self.cube.is_solved()

    def get_cube_state_visualization(self) -> str:
        """Return the text rendering of the cube produced by ``str(self.cube)``."""
        return str(self.cube)
class RubiksCubeEnv(BaseEnv):
def __init__(
    self,
    config: RubiksCubeEnvConfig,
    server_configs: List[APIServerConfig],
    slurm: bool = True,
    testing: bool = False,
):
    """
    Set up episode bookkeeping, logging level, the tool schema and the
    system prompt shown to the model.
    """
    super().__init__(config, server_configs, slurm, testing)

    # Live episodes keyed by seed, plus summaries waiting for wandb_log.
    self.episodes: Dict[int, CubeState] = {}
    self.completed_episode_metrics_buffer: List[Dict[str, float]] = []

    self.debug_mode = config.debug_mode
    if self.debug_mode:
        logger.setLevel(logging.DEBUG)
    elif logger.level == logging.NOTSET or logger.level > logging.WARNING:
        # Only raise verbosity to WARNING; never override a more verbose level.
        logger.setLevel(logging.WARNING)

    # Schema of the single tool the model is allowed to call.
    move_argument = {
        "type": "string",
        "description": "The move to apply to the cube. Valid moves are U, D, L, R, F, B (clockwise), U', D', L', R', F', B' (counterclockwise), and U2, D2, L2, R2, F2, B2 (180 degrees).",
    }
    apply_move_tool = {
        "type": "function",
        "function": {
            "name": "apply_move",
            "description": "Apply a move to the Rubik's cube.",
            "parameters": {"move": move_argument},
        },
    }
    self.tools = [apply_move_tool]
    tools_json = json.dumps(self.tools)

    task_intro = (
        "You are an AI that solves Rubik's cubes step-by-step with clear reasoning. "
        "You will be given the current state of a Rubik's cube, and you need to provide "
        "moves to solve it.\n\n"
    )
    notation_help = (
        "The notation for cube moves follows the standard Rubik's cube notation:\n"
        "- U: rotate the up face clockwise\n"
        "- D: rotate the down face clockwise\n"
        "- L: rotate the left face clockwise\n"
        "- R: rotate the right face clockwise\n"
        "- F: rotate the front face clockwise\n"
        "- B: rotate the back face clockwise\n"
        "- U', D', L', R', F', B': rotate the corresponding face counterclockwise\n"
        "- U2, D2, L2, R2, F2, B2: rotate the corresponding face 180 degrees\n\n"
    )
    reasoning_rules = (
        "You should analyze the current state of the cube, identify patterns, "
        "and explain your reasoning step by step.\n\n"
        "You should enclose your thoughts and internal monologue inside <think> </think> tags, and then "
        "provide your move using the apply_move function call.\n\n"
    )
    tool_call_rules = (
        f"<tools>\n{tools_json}\n</tools>\n\n"
        "For your function call, return a JSON object with function name and arguments "
        "within <tool_call> </tool_call> tags with the following schema:\n"
        '<tool_call>\n{"arguments": {"move": "U"}, "name": "apply_move"}\n</tool_call>\n\n'
    )
    answer_format = (
        "Your full answer format should be:\n"
        "<think>\n[Your detailed reasoning about the current cube state and the best move to make]\n</think>\n\n"
        '<tool_call>\n{"arguments": {"move": "R"}, "name": "apply_move"}\n</tool_call>\n\n'
        "Remember to carefully analyze the cube state and work toward the solution step by step."
    )
    self.system_prompt = "".join(
        [task_intro, notation_help, reasoning_rules, tool_call_rules, answer_format]
    )
def _get_or_create_episode(self, seed: int) -> CubeState:
    """
    Return the episode registered under ``seed``, creating it on first use.

    A new episode starts with the system prompt followed by an
    "environment" message holding the initial cube observation.
    """
    if seed in self.episodes:
        return self.episodes[seed]

    episode = CubeState(seed, self.config.scramble_moves)
    episode.message_history = [{"role": "system", "content": self.system_prompt}]
    # Initial observation of the scrambled cube.
    first_observation = self._format_observation(episode)
    episode.message_history.append(
        {"role": "environment", "content": first_observation}
    )
    self.episodes[seed] = episode
    return episode
def _format_observation(self, cube_state: CubeState) -> str:
"""Format the cube state as a string observation for the LLM"""
cube_visualization = cube_state.get_cube_state_visualization()
moves_made = ", ".join(cube_state.actions) if cube_state.actions else "None"
steps_remaining = self.config.max_steps - cube_state.num_steps
message = (
f"Current state of the Rubik's cube:\n\n"
f"```\n{cube_visualization}\n```\n\n"
f"Previous moves: {moves_made}\n"
f"Steps remaining: {steps_remaining}\n"
)
if cube_state.is_solved():
message += "\nCongratulations! The cube is now solved."
return message
def _calculate_cube_state_score(self, cube_state: CubeState) -> float:
"""
Calculate a score based on how close the cube is to being solved.
Higher scores for cubes that are closer to being solved.
"""
# Base score
score = 0.0
# Reward for a solved cube
if cube_state.is_solved():
score += 1.0
# Get the current state
cube = cube_state.cube
# Count correctly positioned cubies
# This is a simplified approach - in a real implementation,
# we would calculate this from the cube's internal state
correctly_placed = cube.count_solved_cubies()
score += correctly_placed * self.config.reward_per_correctly_placed_cubie
# Small penalty for using more steps
steps_penalty = cube_state.num_steps * self.config.reward_per_step_reduction
score -= steps_penalty
return score
def _parse_move(self, response: str) -> Optional[str]:
    """
    Extract a single cube move from an LLM response.

    The ``<tool_call>`` block is tried first.  If the tool-call parser
    reports an error, the free text is searched for a move written in
    standard notation next to a keyword such as "move" or "rotate".

    Args:
        response: Full agent response text.

    Returns:
        One of the 18 valid moves (e.g. ``"R"``, ``"U'"``, ``"F2"``), or
        None when no valid move could be extracted.  If the model supplied
        several moves, only the first one is returned.
    """
    if not response:
        logger.warning(
            "Attempted to parse an empty response string. Returning None."
        )
        return None

    valid_moves = {
        face + suffix for face in "UDLRFB" for suffix in ("", "'", "2")
    }
    # Typographic primes are treated as the ASCII apostrophe of the notation.
    prime_table = str.maketrans({"\u2019": "'", "\u2032": "'"})

    # First try parsing with tool_call tags
    parsed_name, parsed_args, is_error = parse_tool_call(
        response, self.tools, ["tool_call"]
    )

    if is_error:
        error_detail = (
            parsed_name
            if isinstance(parsed_name, str) and parsed_name
            else "Parser indicated error, but no specific message was returned"
        )
        logger.warning(
            f"Failed to parse tool call. Full response: '{response}'. Error detail: {error_detail}"
        )

        # Fallback: look for a move token next to a keyword.  Keywords are
        # case-insensitive; the token itself must be upper-case notation and
        # must stand alone, so "move Forward" is not read as F and a
        # lower-case hit cannot hide a later valid one.
        token = "([UDLRFB](?:['\u2019\u2032]|2)?)"
        not_after_word = r"(?<![A-Za-z0-9])"
        not_before_word = r"(?![A-Za-z0-9])"
        leading_keywords = (
            "move",
            "applying",
            r"perform\w*",
            "rotate",
            "rotation",
            r"I\s*choose",
            "Execute",
        )
        trailing_keywords = ("rotation", "move")
        move_patterns = [
            rf"(?i:{keyword})\s+{token}{not_before_word}"
            for keyword in leading_keywords
        ] + [
            rf"{not_after_word}{token}\s+(?i:{keyword})"
            for keyword in trailing_keywords
        ]
        for pattern in move_patterns:
            match = re.search(pattern, response)
            if match:
                potential_move = match.group(1).translate(prime_table)
                if potential_move in valid_moves:
                    logger.warning(
                        f"Recovered move '{potential_move}' from text using regex"
                    )
                    return potential_move
        return None

    raw_move = parsed_args.get("move", "") if isinstance(parsed_args, dict) else ""

    # The argument is model output: it may be a string, a list of moves, or
    # something else (null, number).  Calling .strip() on it blindly raised
    # AttributeError for non-strings and aborted the whole step.
    if isinstance(raw_move, str):
        normalized = raw_move.translate(prime_table).strip()
        candidates = [part for part in re.split(r"[\s,]+", normalized) if part]
    elif isinstance(raw_move, (list, tuple)):
        candidates = [
            item.translate(prime_table).strip()
            for item in raw_move
            if isinstance(item, str)
        ]
    else:
        candidates = []

    if candidates and candidates[0] in valid_moves:
        first_move = candidates[0]
        if len(candidates) > 1:
            # The model sent a sequence such as "R U R'"; one move per turn.
            logger.warning(
                f"Got move sequence '{raw_move}' but taking only first move '{first_move}'"
            )
        return first_move

    # If we get here, the move is invalid
    logger.warning(
        f"Parsed invalid move: '{raw_move}'. "
        f"Full response: '{response}'. Parsed args: {parsed_args}"
    )
    return None
def _score_response(
self,
is_valid_move: bool,
response_text: str,
cube_state: CubeState,
is_solved: bool,
) -> float:
"""
Calculate a score for a single agent response based on:
1. Whether the move was valid
2. Whether the move helps solve the cube
3. Presence of thinking tags
4. If the cube is solved
"""
# Base score from cube state after the move
current_score = self._calculate_cube_state_score(cube_state)
# Bonus for valid moves
if is_valid_move:
current_score += 0.2
else:
current_score -= 0.2
# Bonus for solving the cube
if is_solved:
current_score += 1.0
# Check for thinking tags
try:
# Make sure response_text is a string
if not isinstance(response_text, str):
logger.warning(f"response_text is not a string: {type(response_text)}")
response_text = str(response_text) if response_text is not None else ""
match = re.search(r"<think>(.*?)</think>", response_text, re.DOTALL)
if match:
thinking_content = match.group(1)
# Make sure thinking_content is a string before calling strip()
if isinstance(thinking_content, str):
if thinking_content.strip(): # Not empty
current_score += 0.2
else: # Empty thinking tags
current_score -= 0.1
else:
logger.warning(f"thinking_content is not a string: {type(thinking_content)}")
current_score -= 0.1
else: # No thinking tags
current_score -= 0.2
except Exception as e:
logger.warning(f"Error processing thinking tags: {e}")
current_score -= 0.1
return current_score
async def _sample_response(self, messages: List[Dict], n: int = 1) -> List[str]:
    """
    Render ``messages`` with the tokenizer's chat template and request
    ``n`` completions from the server.

    Returns:
        The completion texts, or an empty list when the request fails.
    """
    prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
    try:
        cfg = self.config
        result = await self.server.completion(
            prompt=prompt,
            n=n,
            max_tokens=cfg.max_token_length,
            temperature=cfg.temperature,
            top_p=cfg.top_p,
        )
        texts = []
        for choice in result.choices:
            texts.append(choice.text)
        return texts
    except Exception as e:
        # Best effort: callers treat an empty list as "no response".
        logger.error(f"API error during completion: {e}")
        return []
async def _next_step(
    self, ep: CubeState, current_turn: int, max_turns: int
) -> Tuple[Optional[RubiksCubeScoredDataGroup], bool]:
    """
    Play one turn of an episode with best-of-G sampling.

    G responses are sampled for the current history.  Each one is parsed,
    simulated on a deep copy of the episode, scored and tokenized.  The
    highest-scoring response is then applied to the real episode.

    Returns:
        ``(group, is_terminal)``: the scored group for the trainer (None
        if the step had to be aborted) and whether the episode is over.
    """
    G = self.config.group_size
    turn_number = current_turn + 1
    turn_tag = f"[Next Step Seed: {ep.seed} Turn: {turn_number}]"

    # Snapshot of the history every alternative is generated from.
    history_snapshot = ep.message_history.copy()
    logger.debug(
        f"[Next Step Seed: {ep.seed} Turn: {turn_number}/{max_turns}] "
        f"Current state history length: {len(history_snapshot)}"
    )

    agent_prefill = "<think>\n" if self.config.thinking_active else ""
    messages_for_llm = history_snapshot + [
        {"role": "agent", "content": agent_prefill}
    ]

    # Sample G alternative responses.
    try:
        responses = await self._sample_response(messages_for_llm, n=G)
        if not responses or len(responses) != G:
            logger.error(
                f"{turn_tag} "
                f"Expected {G} responses, got {len(responses) if responses else 0}. "
                f"Aborting step."
            )
            return None, True  # Episode termination
    except Exception as e_sample:
        logger.error(
            f"{turn_tag} Error sampling responses: {e_sample}",
            exc_info=True,
        )
        return None, True  # Episode termination

    out_of_turns = turn_number >= max_turns

    # Per-alternative results, index-aligned with `responses`.
    full_responses: List[str] = []
    parsed_moves: List[Optional[str]] = []
    rewards: List[float] = []
    next_state_messages: List[List[Dict]] = []
    token_lists: List[List[int]] = []
    mask_lists: List[List[int]] = []

    for completion_text in responses:
        full_response = agent_prefill + completion_text
        full_responses.append(full_response)

        candidate_move = self._parse_move(full_response)
        parsed_moves.append(candidate_move)

        # Simulate on a copy so the alternatives do not affect each other.
        trial_episode = copy.deepcopy(ep)
        move_applied = False
        if candidate_move is not None:
            move_applied = trial_episode.apply_move(candidate_move)

        solved_after_move = trial_episode.is_solved()
        rewards.append(
            self._score_response(
                move_applied, full_response, trial_episode, solved_after_move
            )
        )

        # The alternative ends the episode when it solves the cube, uses
        # the last turn or fails to produce a valid move.
        ends_episode = solved_after_move or out_of_turns or not move_applied

        transcript = history_snapshot + [
            {"role": "agent", "content": full_response}
        ]
        if not ends_episode:
            transcript = transcript + [
                {
                    "role": "environment",
                    "content": self._format_observation(trial_episode),
                }
            ]
        next_state_messages.append(transcript)

        tokenized = tokenize_for_trainer(self.tokenizer, transcript)
        token_lists.append(tokenized["tokens"])
        mask_lists.append(tokenized["masks"])

    # Sanity check before packaging the group.
    collected_lengths = (
        len(token_lists),
        len(mask_lists),
        len(rewards),
        len(next_state_messages),
        len(parsed_moves),
    )
    if any(length != G for length in collected_lengths):
        logger.error(
            f"{turn_tag} "
            f"Mismatch in alternative list lengths before creating ScoredDataGroup. "
            f"Tokens:{len(token_lists)}, Masks:{len(mask_lists)}, Rewards:{len(rewards)}, "
            f"Msgs:{len(next_state_messages)}, ParsedMoves:{len(parsed_moves)}. Expected {G} for all. "
            f"Aborting step."
        )
        return None, True

    current_step_data = RubiksCubeScoredDataGroup(
        seed=ep.seed,
        tokens=token_lists,
        masks=mask_lists,
        scores=rewards,
        messages=next_state_messages,
        parsed_actions=parsed_moves,
    )

    # Continue the real episode with the highest-scoring alternative.
    best_idx = np.argmax(rewards)
    best_reward = rewards[best_idx]
    logger.debug(
        f"{turn_tag} "
        f"Selected Alt {best_idx} "
        f"(Reward: {best_reward}) "
        f"from {G} alternatives."
    )

    chosen_move = parsed_moves[best_idx]
    chosen_full_response = full_responses[best_idx]
    logger.info(
        f"{turn_tag} Chosen move: "
        f"{chosen_move} (from Alt {best_idx} with "
        f"Reward {best_reward})"
    )

    # The stored history keeps a truncated version of the thinking block.
    ep.message_history.append(
        {
            "role": "agent",
            "content": truncate_thinking(
                chosen_full_response,
                self.tokenizer,
                self.config.max_think_chars_history,
            ),
        }
    )

    # Replay the chosen move on the real episode and score it.
    chosen_move_applied = False
    if chosen_move is not None:
        chosen_move_applied = ep.apply_move(chosen_move)
    episode_solved = ep.is_solved()
    ep.step_rewards.append(
        self._score_response(
            chosen_move_applied, chosen_full_response, ep, episode_solved
        )
    )

    is_episode_terminal = (
        episode_solved or out_of_turns or not chosen_move_applied
    )

    # Show the new cube state only if the episode goes on.
    if not is_episode_terminal:
        ep.message_history.append(
            {"role": "environment", "content": self._format_observation(ep)}
        )

    return current_step_data, is_episode_terminal
async def collect_trajectories(
    self, item: Tuple[int, int]
) -> Tuple[List[RubiksCubeScoredDataGroup], List[Tuple[int, int]]]:
    """
    Run one full episode and return its per-turn scored groups.

    Args:
        item: ``(seed, _)`` as produced by ``get_next_item``; only the
            seed is used.

    Returns:
        The token-limited list of scored groups and an empty list as the
        second element.
    """
    seed, _ = item
    G_config = self.config.group_size
    max_turns = self.config.max_steps
    log_tag = f"[Collect Trajectories Seed: {seed}]"
    trajectory: List[RubiksCubeScoredDataGroup] = []

    logger.info(
        f"{log_tag} Starting new trajectory. "
        f"Group size G={G_config}, Max turns={max_turns}."
    )

    try:
        ep = self._get_or_create_episode(seed)
    except Exception as e:
        logger.error(
            f"{log_tag} Fatal error creating/getting episode: {e}",
            exc_info=True,
        )
        return [], []

    for turn_idx in range(max_turns):
        turn_number = turn_idx + 1
        logger.debug(f"{log_tag} Attempting turn {turn_number}/{max_turns}.")

        try:
            step_data, episode_over = await self._next_step(ep, turn_idx, max_turns)
        except Exception as e:
            # Any failure inside the step ends the episode; the known
            # "'list' ... 'strip'" error is just logged without a traceback.
            if "'list' object has no attribute 'strip'" in str(e):
                logger.error(f"{log_tag} Non-fatal error in _next_step: {e}")
            else:
                logger.error(
                    f"{log_tag} Error in _next_step: {e}",
                    exc_info=True,
                )
            step_data, episode_over = None, True

        if not step_data:
            logger.error(
                f"{log_tag} Turn {turn_number} failed to produce data."
                " Terminating episode."
            )
            episode_over = True
        else:
            trajectory.append(step_data)

        if episode_over:
            reward_so_far = sum(ep.step_rewards) if ep.step_rewards else 0.0
            logger.info(
                f"{log_tag} Episode ended at turn {turn_number}. "
                f"Reason: step reported terminal. Total reward: {reward_so_far:.2f}"
            )
            break
    else:
        # Reached only when the loop finished without a terminal step.
        logger.info(f"{log_tag} Episode reached max_turns ({max_turns}).")

    # Record a summary for the next wandb_log call.
    self.completed_episode_metrics_buffer.append(
        {
            "seed": ep.seed,
            "total_reward": sum(ep.step_rewards) if ep.step_rewards else 0.0,
            "num_steps": ep.num_steps,
            "is_solved": ep.is_solved(),
        }
    )

    # The episode is finished; drop it from the registry.
    self.episodes.pop(seed, None)

    # Keep the trajectory within the trainer's token budget.
    limited_trajectory = ensure_trajectory_token_limit(
        trajectory,
        self.tokenizer,
        self.config.max_trajectory_tokens,
    )
    return limited_trajectory, []
async def setup(self):
"""Initialize the environment"""
# Nothing to do here as we don't need any special setup
pass
async def get_next_item(self) -> Tuple[int, int]:
    """Return ``(seed, 0)`` where seed is a fresh random integer in [0, 1000000]."""
    episode_seed = random.randint(0, 1000000)
    return episode_seed, 0
async def rollout_and_score_eval(self, seed: int) -> Dict[str, float]:
    """
    Run one evaluation episode, sampling a single response per turn.

    The episode stops when the cube is solved, a move is invalid or cannot
    be parsed, the server returns no response, or ``max_steps`` turns have
    been played.

    Returns:
        Metrics with the keys ``seed``, ``total_reward``, ``num_turns``,
        ``num_valid_moves``, ``num_invalid_moves`` and ``is_solved``.
    """
    ep = self._get_or_create_episode(seed)
    max_turns = self.config.max_steps
    agent_prompt_content = "<think>\n" if self.config.thinking_active else ""
    metrics = {
        "seed": seed,
        "total_reward": 0.0,
        "num_turns": 0,
        "num_valid_moves": 0,
        "num_invalid_moves": 0,
        "is_solved": False,
    }

    try:
        for turn in range(max_turns):
            messages = ep.message_history + [
                {"role": "agent", "content": agent_prompt_content}
            ]

            # Get a single response
            responses = await self._sample_response(messages, n=1)
            if not responses:
                logger.error(
                    f"[Eval Seed: {seed}, Turn: {turn+1}] No response. Aborting."
                )
                break

            full_agent_response = agent_prompt_content + responses[0]

            # Parse and apply the move
            move = self._parse_move(full_agent_response)
            is_valid_move = move is not None and ep.apply_move(move)
            if is_valid_move:
                metrics["num_valid_moves"] += 1
            else:
                metrics["num_invalid_moves"] += 1

            # Calculate reward
            is_solved = ep.is_solved()
            metrics["total_reward"] += self._score_response(
                is_valid_move, full_agent_response, ep, is_solved
            )
            metrics["num_turns"] += 1

            # Add response to history (thinking truncated)
            ep.message_history.append(
                {
                    "role": "agent",
                    "content": truncate_thinking(
                        full_agent_response,
                        self.tokenizer,
                        self.config.max_think_chars_history,
                    ),
                }
            )

            if is_solved or not is_valid_move:
                logger.info(f"[Eval Seed: {seed}] Episode ended. Solved: {is_solved}")
                break

            # The episode continues: show the new cube state.
            ep.message_history.append(
                {"role": "environment", "content": self._format_observation(ep)}
            )

        # Whatever ended the loop, the cube itself says if it was solved.
        metrics["is_solved"] = ep.is_solved()
    finally:
        # Always release the episode, even when a turn raised.  pop() with a
        # default avoids a KeyError if the entry has already been removed.
        self.episodes.pop(seed, None)

    return metrics
async def evaluate(self, *args, **kwargs):
    """
    Run ``eval_episodes`` evaluation rollouts concurrently and store the
    aggregated results in ``self.eval_metrics``.

    Evaluation is skipped entirely when ``use_wandb`` is disabled.
    """
    if not self.config.use_wandb:
        logger.info("Skipping evaluation as wandb is not enabled.")
        return

    num_eval_episodes = self.config.eval_episodes
    logger.info(f"Starting evaluation for {num_eval_episodes} episodes.")

    # Draw the seeds WITHOUT replacement.  Episodes are stored per seed in
    # self.episodes, so two concurrent rollouts with the same seed would
    # share (and corrupt) one CubeState.  The range is disjoint from the
    # training seeds produced by get_next_item.
    seed_pool = range(1000001, 2000001)
    episode_count = max(0, min(num_eval_episodes, len(seed_pool)))
    eval_seeds = random.sample(seed_pool, episode_count)

    eval_tasks = [self.rollout_and_score_eval(seed) for seed in eval_seeds]
    all_metrics = await tqdm_asyncio.gather(*eval_tasks)

    valid_metrics = [m for m in all_metrics if m]
    if not valid_metrics:
        logger.warning("No valid evaluation metrics.")
        return

    # Aggregate across all completed episodes
    num_completed = len(valid_metrics)
    avg_total_reward = sum(m["total_reward"] for m in valid_metrics) / num_completed
    avg_num_turns = sum(m["num_turns"] for m in valid_metrics) / num_completed

    total_valid_moves = sum(m["num_valid_moves"] for m in valid_metrics)
    total_invalid_moves = sum(m["num_invalid_moves"] for m in valid_metrics)
    total_moves = total_valid_moves + total_invalid_moves
    move_validity_rate = total_valid_moves / total_moves if total_moves > 0 else 0

    solve_rate = sum(1 for m in valid_metrics if m["is_solved"]) / num_completed

    self.eval_metrics = [
        ("eval/avg_total_reward", avg_total_reward),
        ("eval/avg_num_turns", avg_num_turns),
        ("eval/move_validity_rate", move_validity_rate),
        ("eval/solve_rate", solve_rate),
        ("eval/num_completed_episodes", num_completed),
    ]
    logger.info(f"Evaluation completed. Metrics: {self.eval_metrics}")
async def wandb_log(self, wandb_metrics: Optional[Dict[str, float]] = None):
    """
    Add training-episode averages to ``wandb_metrics`` and pass the result
    to the parent class's ``wandb_log``.

    Averages are computed over the episodes finished since the previous
    call; the buffer is emptied afterwards.
    """
    if wandb_metrics is None:
        wandb_metrics = {}

    finished = self.completed_episode_metrics_buffer
    if finished:
        num_episodes = len(finished)
        prefix = f"{self.wandb_prepend or 'rubiks'}_train"

        reward_sum = sum(m["total_reward"] for m in finished)
        step_sum = sum(m["num_steps"] for m in finished)
        solved_count = sum(1 for m in finished if m["is_solved"])

        wandb_metrics[f"{prefix}/avg_episode_reward"] = reward_sum / num_episodes
        wandb_metrics[f"{prefix}/avg_episode_steps"] = step_sum / num_episodes
        wandb_metrics[f"{prefix}/solve_rate"] = solved_count / num_episodes
        wandb_metrics[f"{prefix}/num_episodes"] = num_episodes

        # Start a fresh buffer for the next logging interval.
        self.completed_episode_metrics_buffer = []

    await super().wandb_log(wandb_metrics)
@classmethod
def config_init(cls) -> Tuple[RubiksCubeEnvConfig, List[APIServerConfig]]:
    """
    Build the default environment config and inference-server config.

    Returns:
        ``(env_config, server_configs)``.
    """
    # The tokenizer renders the chat template for the prompts sent to the
    # inference server (see _sample_response) and produces the trainer
    # tokens, so it has to belong to the served model.  The previous value,
    # "openai/gpt-4-turbo-preview", is not that model's tokenizer.
    # NOTE(review): assumes tokenizer_name is resolved like a Hugging Face
    # model id by the base class - confirm against BaseEnv.
    model_name = "NousResearch/DeepHermes-3-Llama-3-8B-Preview"

    env_config = RubiksCubeEnvConfig(
        tokenizer_name=model_name,
        group_size=16,
        use_wandb=True,
        max_num_workers=128,
        rollout_server_url="http://localhost:9000",
        total_steps=2000,
        batch_size=1024,
        steps_per_eval=20,
        max_token_length=1024 * 16,
        inference_weight=1.0,
        wandb_name="rubiks_cube",
        data_path_to_save_groups=None,
        eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
        eval_limit_ratio=0.1,
        max_steps=20,
        temperature=0.7,
        top_p=0.9,
        thinking_active=True,
        eval_episodes=100,
        max_think_chars_history=3000,
        max_trajectory_tokens=24576,
        debug_mode=False,
        tiebreak_token_factor=0.01,
        scramble_moves=7,
        cube_size=3,
        reward_per_correctly_placed_cubie=0.05,
        reward_per_step_reduction=0.01,
    )
    server_configs = [
        APIServerConfig(
            model_name=model_name,
            base_url="http://localhost:9004/v1",
            num_requests_for_eval=256,
        )
    ]
    return env_config, server_configs
@classmethod
def cli(cls):
    """Entry point for command-line use; defers entirely to the parent class's ``cli``."""
    # No environment-specific options: the inherited implementation does all the work.
    super().cli()
if __name__ == "__main__":
    # Executed as a script: hand control to the environment's command-line interface.
    RubiksCubeEnv.cli()