mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
1814 lines
66 KiB
Python
1814 lines
66 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
RubiksCubeEnv: Trainer environment for Rubik's Cube solving with multi-step reasoning
|
|
|
|
This environment implements a Rubik's cube solver that trains LLMs to solve cubes
|
|
through step-by-step reasoning and visualization. Extends BaseEnv.
|
|
"""
|
|
|
|
import copy
|
|
import datetime
|
|
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
import re
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
|
|
# Selcube-specific imports
|
|
from rubiks_cube_curriculum import RubiksCubeCurriculum
|
|
from rubiks_enhanced_visualizer import save_enhanced_visualization
|
|
from rubiks_strategies import get_strategy_prompt_for_level
|
|
from rubiks_token_rewards import (
|
|
calculate_advantage_token_weights,
|
|
calculate_token_level_rewards,
|
|
)
|
|
from tqdm.asyncio import tqdm_asyncio
|
|
|
|
# Atropos imports
|
|
from atroposlib.envs.base import (
|
|
APIServerConfig,
|
|
BaseEnv,
|
|
BaseEnvConfig,
|
|
EvalHandlingEnum,
|
|
ScoredDataGroup,
|
|
)
|
|
from atroposlib.utils.message_history_utils import (
|
|
ensure_trajectory_token_limit,
|
|
truncate_thinking,
|
|
)
|
|
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
|
from atroposlib.utils.tool_call_parser import parse_tool_call
|
|
|
|
# Sticker colors, one letter per face of the solved cube.
UP_COLOR = "W"  # White
DOWN_COLOR = "Y"  # Yellow
RIGHT_COLOR = "R"  # Red
LEFT_COLOR = "O"  # Orange
FRONT_COLOR = "G"  # Green
BACK_COLOR = "B"  # Blue


class Cube:
    """
    A 3x3 Rubik's cube stored as six 3x3 grids of color letters.

    ``self.cube[face][row][col]`` holds one sticker. Face indices are
    0=UP, 1=DOWN, 2=LEFT, 3=RIGHT, 4=FRONT, 5=BACK. Each grid is written as
    the face looks from outside the cube, using the usual unfolded-net
    orientation:

    * UP:    row 0 touches BACK, row 2 touches FRONT, col 0 touches LEFT
    * DOWN:  row 0 touches FRONT, row 2 touches BACK, col 0 touches LEFT
    * LEFT:  row 0 touches UP, col 0 touches BACK, col 2 touches FRONT
    * RIGHT: row 0 touches UP, col 0 touches FRONT, col 2 touches BACK
    * FRONT: row 0 touches UP, col 0 touches LEFT, col 2 touches RIGHT
    * BACK:  row 0 touches UP, col 0 touches RIGHT, col 2 touches LEFT
    """

    # Color of every sticker on each face when solved, indexed like self.cube.
    _SOLVED_FACE_COLORS = (
        UP_COLOR,
        DOWN_COLOR,
        LEFT_COLOR,
        RIGHT_COLOR,
        FRONT_COLOR,
        BACK_COLOR,
    )

    # (row, col) sticker positions grouped by the kind of piece they belong to.
    _CENTER_POSITIONS = ((1, 1),)
    _CORNER_POSITIONS = ((0, 0), (0, 2), (2, 0), (2, 2))
    _EDGE_POSITIONS = ((0, 1), (1, 0), (1, 2), (2, 1))

    def __init__(self):
        """Create a cube in the solved state."""
        self.reset()

    def reset(self):
        """Reset the cube to the solved state."""
        self.cube = [
            [[color for _ in range(3)] for _ in range(3)]
            for color in self._SOLVED_FACE_COLORS
        ]

    def _count_matching(self, positions) -> int:
        """Count stickers at ``positions`` (on every face) that show the solved color."""
        return sum(
            1
            for face, solved_color in zip(self.cube, self._SOLVED_FACE_COLORS)
            for row, col in positions
            if face[row][col] == solved_color
        )

    def is_solved(self) -> bool:
        """Return True when every sticker on each face matches that face's center."""
        return all(
            color == face[1][1] for face in self.cube for row in face for color in row
        )

    def count_solved_cubies(self) -> float:
        """
        Return the fraction of stickers that show their solved color.

        Despite the name this counts stickers, not pieces: 54 stickers are
        compared against the solved cube and the result is in [0, 1].
        """
        total_stickers = 6 * 9  # 6 faces, 9 stickers per face
        every_position = [(row, col) for row in range(3) for col in range(3)]
        return self._count_matching(every_position) / total_stickers

    def count_correct_centers(self) -> int:
        """
        Count face centers that show their solved color.

        Face turns never move centers, so this is 6 for any state reached
        through ``rotate``; it is kept for the reward computation.
        """
        return self._count_matching(self._CENTER_POSITIONS)

    def count_solved_faces(self) -> int:
        """Count faces whose nine stickers all match the face's center."""
        return sum(
            1
            for face in self.cube
            if all(color == face[1][1] for row in face for color in row)
        )

    def has_cross_on_face(self) -> bool:
        """Return True if any face has its four edge stickers matching its center."""
        for face in self.cube:
            center_color = face[1][1]
            if (
                face[0][1] == center_color  # Top middle
                and face[1][0] == center_color  # Left middle
                and face[1][2] == center_color  # Right middle
                and face[2][1] == center_color  # Bottom middle
            ):
                return True
        return False

    def count_correct_corners(self) -> int:
        """
        Estimate the number of correctly placed corner pieces.

        Counts corner stickers that show their solved color and floor-divides
        by 3 (three stickers per corner piece). This is a sticker-based
        heuristic; it does not check that the three stickers belong to the
        same piece.
        """
        return self._count_matching(self._CORNER_POSITIONS) // 3

    def count_correct_edges(self) -> int:
        """
        Estimate the number of correctly placed edge pieces.

        Counts edge stickers that show their solved color and floor-divides
        by 2 (two stickers per edge piece). Sticker-based heuristic, as in
        ``count_correct_corners``.
        """
        return self._count_matching(self._EDGE_POSITIONS) // 2

    def partial_face_completion(self) -> float:
        """
        Sum, over all faces, the fraction of non-center stickers matching the center.

        Each face contributes a value in [0, 1], so the result is in [0, 6].
        """
        total_score = 0.0
        for face in self.cube:
            center_color = face[1][1]
            face_matches = sum(1 for row in face for color in row if color == center_color)
            # -1 because the center always matches itself
            total_score += (face_matches - 1) / 8.0
        return total_score

    def rotate(self, move: str):
        """
        Perform a move on the cube using standard notation.

        U, D, L, R, F, B are clockwise quarter turns of the respective face
        (clockwise as seen when looking at that face from outside the cube),
        U', D', ... are counterclockwise quarter turns and U2, D2, ... are
        half turns.

        Raises:
            ValueError: if the move is empty, names an unknown face, uses an
                unknown modifier, or is longer than two characters. The cube
                is left unchanged in that case.
        """
        face_map = {"U": 0, "D": 1, "L": 2, "R": 3, "F": 4, "B": 5}
        # Number of clockwise quarter turns for each modifier.
        turns_for_modifier = {"'": 3, "2": 2}

        if not move:
            raise ValueError("Empty move")

        face = move[0]
        if face not in face_map:
            raise ValueError(f"Invalid face: {face}")
        face_idx = face_map[face]

        if len(move) == 1:
            count = 1
        elif len(move) == 2:
            if move[1] not in turns_for_modifier:
                raise ValueError(f"Invalid move modifier: {move[1]}")
            count = turns_for_modifier[move[1]]
        else:
            raise ValueError(f"Invalid move format: {move}")

        # A counterclockwise turn is three clockwise turns.
        for _ in range(count):
            self._rotate_face_clockwise(face_idx)
            self._rotate_adjacent_faces(face_idx)

    def _rotate_face_clockwise(self, face_idx: int):
        """Rotate the nine stickers of one face by 90 degrees clockwise."""
        face = self.cube[face_idx]
        new_face = [[None for _ in range(3)] for _ in range(3)]
        for i in range(3):
            for j in range(3):
                new_face[j][2 - i] = face[i][j]
        self.cube[face_idx] = new_face

    def _rotate_adjacent_faces(self, face_idx: int):
        """
        Cycle the twelve side stickers that border the turned face.

        Each branch moves stickers in the same rotational sense as
        ``_rotate_face_clockwise`` so that corner and edge pieces stay
        intact. Indices are reversed wherever the two strips involved run in
        opposite directions in the grid layout described on the class.
        """
        cube = self.cube

        if face_idx == 0:  # UP: FRONT -> LEFT -> BACK -> RIGHT -> FRONT
            temp = cube[4][0][:]  # Save FRONT top row
            cube[4][0] = cube[3][0][:]  # FRONT <- RIGHT
            cube[3][0] = cube[5][0][:]  # RIGHT <- BACK
            cube[5][0] = cube[2][0][:]  # BACK <- LEFT
            cube[2][0] = temp  # LEFT <- FRONT

        elif face_idx == 1:  # DOWN: FRONT -> RIGHT -> BACK -> LEFT -> FRONT
            temp = cube[4][2][:]  # Save FRONT bottom row
            cube[4][2] = cube[2][2][:]  # FRONT <- LEFT
            cube[2][2] = cube[5][2][:]  # LEFT <- BACK
            cube[5][2] = cube[3][2][:]  # BACK <- RIGHT
            cube[3][2] = temp  # RIGHT <- FRONT

        elif face_idx == 2:  # LEFT: UP -> FRONT -> DOWN -> BACK -> UP
            temp = [cube[0][i][0] for i in range(3)]  # Save UP left column
            for i in range(3):
                cube[0][i][0] = cube[5][2 - i][2]  # UP left <- BACK right (reversed)
            for i in range(3):
                cube[5][i][2] = cube[1][2 - i][0]  # BACK right <- DOWN left (reversed)
            for i in range(3):
                cube[1][i][0] = cube[4][i][0]  # DOWN left <- FRONT left
            for i in range(3):
                cube[4][i][0] = temp[i]  # FRONT left <- UP left

        elif face_idx == 3:  # RIGHT: FRONT -> UP -> BACK -> DOWN -> FRONT
            temp = [cube[0][i][2] for i in range(3)]  # Save UP right column
            for i in range(3):
                cube[0][i][2] = cube[4][i][2]  # UP right <- FRONT right
            for i in range(3):
                cube[4][i][2] = cube[1][i][2]  # FRONT right <- DOWN right
            for i in range(3):
                cube[1][i][2] = cube[5][2 - i][0]  # DOWN right <- BACK left (reversed)
            for i in range(3):
                cube[5][i][0] = temp[2 - i]  # BACK left <- UP right (reversed)

        elif face_idx == 4:  # FRONT: UP -> RIGHT -> DOWN -> LEFT -> UP
            temp = cube[0][2][:]  # Save UP bottom row
            for i in range(3):
                cube[0][2][i] = cube[2][2 - i][2]  # UP bottom <- LEFT right (reversed)
            for i in range(3):
                cube[2][i][2] = cube[1][0][i]  # LEFT right <- DOWN top
            for i in range(3):
                cube[1][0][i] = cube[3][2 - i][0]  # DOWN top <- RIGHT left (reversed)
            for i in range(3):
                cube[3][i][0] = temp[i]  # RIGHT left <- UP bottom

        elif face_idx == 5:  # BACK: UP -> LEFT -> DOWN -> RIGHT -> UP
            temp = cube[0][0][:]  # Save UP top row
            for i in range(3):
                cube[0][0][i] = cube[3][i][2]  # UP top <- RIGHT back column
            for i in range(3):
                cube[3][i][2] = cube[1][2][2 - i]  # RIGHT back <- DOWN bottom (reversed)
            for i in range(3):
                cube[1][2][i] = cube[2][i][0]  # DOWN bottom <- LEFT back column
            for i in range(3):
                cube[2][i][0] = temp[2 - i]  # LEFT back <- UP top (reversed)

    def __str__(self) -> str:
        """Render the six faces as text, three lines per face, labelled U/D/L/R/F/B."""
        face_names = ["U", "D", "L", "R", "F", "B"]
        result = []
        for i, face in enumerate(self.cube):
            result.append(f"{face_names[i]}: {' '.join(face[0])}")
            result.append(f" {' '.join(face[1])}")
            result.append(f" {' '.join(face[2])}")
        return "\n".join(result)
|
|
|
|
|
|
# Module-level logger; RubiksCubeEnv.__init__ adjusts its level from config.debug_mode.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RubiksCubeEnvConfig(BaseEnvConfig):
    # Settings for RubiksCubeEnv, layered on top of BaseEnvConfig.
    # NOTE(review): how BaseEnvConfig consumes these fields (validation, CLI
    # exposure) is not visible here; keep field order and defaults as they are.

    # Environment configuration
    max_steps: int = 20  # Maximum steps allowed to solve the cube
    # Sampling / generation settings; presumably passed to the model server by
    # code outside this chunk — TODO confirm.
    temperature: float = 0.7
    top_p: float = 0.9
    wandb_name: str = "rubiks_cube"
    thinking_active: bool = True
    eval_episodes: int = 100
    max_think_chars_history: int = 3000
    max_trajectory_tokens: int = 24576  # seq_len of RL trainer
    debug_mode: bool = False  # When True, RubiksCubeEnv sets the module logger to DEBUG
    group_size: int = 16
    tiebreak_token_factor: float = 0.01

    # Cube-specific configuration
    # The three values below are used directly when use_curriculum is False;
    # with a curriculum the current level supplies scramble depth, step budget
    # and per-cubie reward instead.
    scramble_moves: int = 7  # Number of random moves to scramble the cube
    # NOTE(review): the Cube class in this file is hard-coded to 3x3 and no
    # visible code reads cube_size.
    cube_size: int = 3  # 3x3 cube by default
    reward_per_correctly_placed_cubie: float = 0.05
    # Multiplied by the number of steps taken and subtracted from the state score.
    reward_per_step_reduction: float = 0.01  # Small penalty for using more steps

    # Curriculum learning configuration
    # Forwarded to RubiksCubeCurriculum(...) when use_curriculum is True.
    use_curriculum: bool = True
    curriculum_starting_level: int = 1
    curriculum_max_level: int = 5
    curriculum_auto_progress: bool = True
    curriculum_success_threshold: float = 0.7
    curriculum_advancement_window: int = 50
    curriculum_min_solved: int = 25

    # Token-level reward configuration
    use_token_level_rewards: bool = True
    token_reward_scale: float = 0.1  # Scale factor for token-level rewards

    # Visualization configuration
    generate_visualizations: bool = True
    # Created at environment construction when generate_visualizations is True.
    visualizations_dir: str = "./rubiks_visualizations"
    save_best_episodes: bool = True  # Save visualizations for best episodes

    # Solving strategies configuration
    provide_solving_strategies: bool = (
        True  # Include solving strategies in system prompt
    )
    strategy_explanation_reward: float = (
        0.1  # Bonus reward for using strategy explanations
    )
|
|
|
|
|
|
class RubiksCubeScoredDataGroup(ScoredDataGroup):
    # Scored rollout group for one scramble seed, extending ScoredDataGroup
    # with token-level reward fields.
    # NOTE(review): the outer list in each field presumably has one entry per
    # sampled alternative in the group — confirm against the producing code.
    seed: int
    tokens: Optional[List[List[int]]] = None
    masks: Optional[List[List[int]]] = None
    scores: Optional[List[float]] = None
    token_scores: Optional[List[List[float]]] = None  # Token-level rewards
    token_weights: Optional[List[List[float]]] = None  # Token weights for advantages
    messages: Optional[List[List[Dict]]] = None
    parsed_actions: Optional[List[str]] = None
|
|
|
|
|
|
class CubeState:
    """
    State of one solving episode: the scrambled cube plus bookkeeping.

    The scramble is fully determined by ``seed``. A private random generator
    is used, so creating an episode does not reseed the process-wide
    ``random`` module; the move sequence for a given seed is the same one
    ``random.seed(seed)`` followed by ``random.choice`` would produce.
    """

    def __init__(self, seed: int, scramble_moves: int):
        """
        Create a solved cube and scramble it.

        Args:
            seed: Seed that determines the scramble sequence.
            scramble_moves: Number of random face turns applied to the solved cube.
        """
        self.seed = seed
        self.cube = Cube()  # Starts solved
        self.message_history: List[Dict] = []
        self.actions: List[str] = []
        self.step_rewards: List[float] = []
        self.total_reward: float = 0.0
        self.num_steps: int = 0

        # Per-episode settings; the environment overwrites these after
        # construction with config or curriculum values.
        self.max_steps = 20
        self.reward_per_correctly_placed_cubie = 0.05
        self.curriculum_level = 0  # 0 means "not using curriculum"

        # Strategy most recently named by the model. Initialised here so that
        # readers never hit a missing attribute on a freshly built state.
        self.current_strategy: Optional[str] = None

        # Scramble sequence, kept for visualization and analysis
        self.scramble_sequence: List[str] = []
        self.scramble_sequence_length: int = 0

        # Fraction of correctly colored stickers after the scramble and after each move
        self.progress_history: List[float] = []

        # Private generator: reproducible scramble without touching global RNG state
        self._rng = random.Random(seed)

        self._scramble_cube(scramble_moves)

        # Record initial progress
        self.progress_history.append(self.cube.count_solved_cubies())

    def _scramble_cube(self, num_moves: int):
        """
        Apply ``num_moves`` random face turns and record them.

        Returns:
            The scramble as a space-separated move string.
        """
        # Order matters: it must stay U..B, U'..B', U2..B2 so that a given
        # seed keeps producing the same scramble.
        moves = [face + modifier for modifier in ("", "'", "2") for face in "UDLRFB"]

        self.scramble_sequence = []
        for _ in range(num_moves):
            move = self._rng.choice(moves)
            self.scramble_sequence.append(move)
            self.cube.rotate(move)

        self.scramble_sequence_length = len(self.scramble_sequence)
        return " ".join(self.scramble_sequence)

    def apply_move(self, move: str) -> bool:
        """
        Apply a move to the cube.

        Returns:
            True if the move was applied; False if the cube rejected it, in
            which case nothing is recorded and the step counter is unchanged.
        """
        try:
            self.cube.rotate(move)
        except Exception as e:
            # Deliberately broad: any malformed model output must count as an
            # invalid move rather than abort the rollout.
            logger.error(f"Error applying move {move}: {e}")
            return False

        self.actions.append(move)
        self.num_steps += 1
        # Record progress after move
        self.progress_history.append(self.cube.count_solved_cubies())
        return True

    def is_solved(self) -> bool:
        """Check if the cube is solved."""
        return self.cube.is_solved()

    def get_cube_state_visualization(self) -> str:
        """Return the cube's text rendering (``str(self.cube)``)."""
        return str(self.cube)
|
|
|
|
|
|
class RubiksCubeEnv(BaseEnv):
|
|
def __init__(
    self,
    config: RubiksCubeEnvConfig,
    server_configs: List[APIServerConfig],
    slurm: bool = True,
    testing: bool = False,
):
    """
    Set up the environment: logging level, optional curriculum, reward and
    visualization settings, the ``apply_move`` tool description and the base
    system prompt.

    Args:
        config: Environment configuration.
        server_configs: Forwarded unchanged to ``BaseEnv.__init__``.
        slurm: Forwarded unchanged to ``BaseEnv.__init__``.
        testing: Forwarded unchanged to ``BaseEnv.__init__``.
    """
    super().__init__(config, server_configs, slurm, testing)
    # Episode states keyed by scramble seed (filled by _get_or_create_episode).
    self.episodes: Dict[int, CubeState] = {}
    self.debug_mode = config.debug_mode
    self.completed_episode_metrics_buffer: List[Dict[str, float]] = []

    if self.debug_mode:
        logger.setLevel(logging.DEBUG)
    else:
        # Only raise verbosity to WARNING when the level is unset or quieter
        # than WARNING; an explicitly configured INFO/DEBUG level is kept.
        if logger.level == logging.NOTSET or logger.level > logging.WARNING:
            logger.setLevel(logging.WARNING)

    # Initialize curriculum learning if enabled
    self.use_curriculum = config.use_curriculum
    if self.use_curriculum:
        self.curriculum = RubiksCubeCurriculum(
            starting_level=config.curriculum_starting_level,
            max_level=config.curriculum_max_level,
            auto_progress=config.curriculum_auto_progress,
            success_threshold=config.curriculum_success_threshold,
            advancement_window_size=config.curriculum_advancement_window,
            min_solved_at_level=config.curriculum_min_solved,
        )
        logger.info(
            f"Initialized curriculum learning at level {self.curriculum.current_level}"
        )

    # Token-level reward tracking
    self.use_token_level_rewards = config.use_token_level_rewards
    self.token_reward_scale = config.token_reward_scale

    # Visualization settings
    self.generate_visualizations = config.generate_visualizations
    self.visualizations_dir = config.visualizations_dir
    self.save_best_episodes = config.save_best_episodes

    # Create visualizations directory if it doesn't exist
    if self.generate_visualizations and self.visualizations_dir:
        # NOTE(review): os is already imported at module level; this local
        # import is redundant.
        import os

        os.makedirs(self.visualizations_dir, exist_ok=True)
        logger.info(
            f"Created visualizations directory at {self.visualizations_dir}"
        )

    # Track best episodes for visualization
    self.best_episode_reward = -float("inf")  # Track best reward
    self.best_episode = None

    # Tool description: embedded verbatim in the system prompt below and
    # passed to parse_tool_call when parsing model responses.
    self.tools = [
        {
            "type": "function",
            "function": {
                "name": "apply_move",
                "description": "Apply a move to the Rubik's cube.",
                "parameters": {
                    "move": {
                        "type": "string",
                        "description": (
                            "The move to apply to the cube. Valid moves are U, D, L, R, F, B "
                            "(clockwise), U', D', L', R', F', B' (counterclockwise), and "
                            "U2, D2, L2, R2, F2, B2 (180 degrees)."
                        ),
                    },
                    "strategy": {
                        "type": "string",
                        "description": (
                            "Optional: The solving strategy you're using "
                            "(e.g., 'Layer-by-Layer', 'CFOP', 'Beginner Method')"
                        ),
                    },
                },
            },
        }
    ]

    tools_json = json.dumps(self.tools)

    # Base system prompt
    base_prompt = (
        "You are an AI that solves Rubik's cubes step-by-step with clear reasoning. "
        "You will be given the current state of a Rubik's cube, and you need to provide "
        "moves to solve it.\n\n"
        "The notation for cube moves follows the standard Rubik's cube notation:\n"
        "- U: rotate the up face clockwise\n"
        "- D: rotate the down face clockwise\n"
        "- L: rotate the left face clockwise\n"
        "- R: rotate the right face clockwise\n"
        "- F: rotate the front face clockwise\n"
        "- B: rotate the back face clockwise\n"
        "- U', D', L', R', F', B': rotate the corresponding face counterclockwise\n"
        "- U2, D2, L2, R2, F2, B2: rotate the corresponding face 180 degrees\n\n"
        "You should analyze the current state of the cube, identify patterns, "
        "and explain your reasoning step by step.\n\n"
        "You should enclose your thoughts and internal monologue inside <think> </think> "
        "tags, and then provide your move using the apply_move function call.\n\n"
        f"<tools>\n{tools_json}\n</tools>\n\n"
        "For your function call, return a JSON object with function name and "
        "arguments within <tool_call> </tool_call> tags with the following schema:\n"
        '<tool_call>\n{"arguments": {"move": "U", "strategy": "Layer-by-Layer"}, '
        '"name": "apply_move"}\n</tool_call>\n\n'
        "Your full answer format should be:\n"
        "<think>\n[Your detailed reasoning about the current cube state and the "
        "best move to make]\n</think>\n\n"
        '<tool_call>\n{"arguments": {"move": "R", "strategy": "Layer-by-Layer"}, '
        '"name": "apply_move"}\n</tool_call>\n\n'
        "Remember to carefully analyze the cube state and work toward the solution step by step.\n\n"
        "When solving the cube:\n"
        "1. Clearly state which strategy you're using\n"
        "2. Explain how each move contributes to your solving strategy\n"
        "3. Identify specific patterns or cases you recognize\n"
        "4. Mention which step of your chosen method you're working on\n"
    )

    # Initialize the system prompt
    self.system_prompt = base_prompt

    # Strategy settings; level-specific strategy text is appended to the
    # prompt per episode in _get_or_create_episode, not here.
    self.provide_solving_strategies = config.provide_solving_strategies
    self.strategy_explanation_reward = config.strategy_explanation_reward
|
|
|
|
def _get_or_create_episode(self, seed: int) -> CubeState:
|
|
if seed not in self.episodes:
|
|
# Determine scramble moves based on curriculum if enabled
|
|
if self.use_curriculum:
|
|
current_level = self.curriculum.get_current_level()
|
|
scramble_moves = current_level.get_scramble_moves()
|
|
max_steps = current_level.max_steps
|
|
reward_per_cubie = current_level.reward_per_correctly_placed_cubie
|
|
|
|
# Update config with curriculum settings for this episode
|
|
# Note: We're not modifying the original config, just using curriculum values
|
|
logger.debug(
|
|
f"Using curriculum level {current_level.level} settings: "
|
|
f"{scramble_moves} scramble moves, {max_steps} max steps"
|
|
)
|
|
|
|
# Add solving strategies appropriate for this level if enabled
|
|
if self.provide_solving_strategies:
|
|
strategies_prompt = get_strategy_prompt_for_level(
|
|
current_level.level
|
|
)
|
|
# Combine base prompt with strategies
|
|
complete_prompt = self.system_prompt + "\n\n" + strategies_prompt
|
|
else:
|
|
complete_prompt = self.system_prompt
|
|
else:
|
|
scramble_moves = self.config.scramble_moves
|
|
max_steps = self.config.max_steps
|
|
reward_per_cubie = self.config.reward_per_correctly_placed_cubie
|
|
complete_prompt = self.system_prompt
|
|
|
|
# Create the episode
|
|
ep = CubeState(seed, scramble_moves)
|
|
ep.message_history = [{"role": "system", "content": complete_prompt}]
|
|
|
|
# Store curriculum-specific settings for this episode
|
|
ep.max_steps = max_steps
|
|
ep.reward_per_correctly_placed_cubie = reward_per_cubie
|
|
ep.curriculum_level = (
|
|
self.curriculum.current_level if self.use_curriculum else 0
|
|
)
|
|
|
|
# Track the solving strategy used
|
|
ep.current_strategy = None
|
|
|
|
# Add initial observation
|
|
ep.message_history.append(
|
|
{"role": "environment", "content": self._format_observation(ep)}
|
|
)
|
|
self.episodes[seed] = ep
|
|
return self.episodes[seed]
|
|
|
|
def _format_observation(self, cube_state: CubeState) -> str:
|
|
"""Format the cube state as a string observation for the LLM"""
|
|
cube_visualization = cube_state.get_cube_state_visualization()
|
|
|
|
moves_made = ", ".join(cube_state.actions) if cube_state.actions else "None"
|
|
steps_remaining = cube_state.max_steps - cube_state.num_steps
|
|
|
|
# Add scramble info for debugging and learning
|
|
scramble_length = len(cube_state.scramble_sequence)
|
|
|
|
message = (
|
|
f"Current state of the Rubik's cube:\n\n"
|
|
f"```\n{cube_visualization}\n```\n\n"
|
|
f"Previous moves: {moves_made}\n"
|
|
f"Steps remaining: {steps_remaining}\n"
|
|
)
|
|
|
|
# Add curriculum level info if using curriculum
|
|
if self.use_curriculum and cube_state.curriculum_level > 0:
|
|
current_level = self.curriculum.levels[cube_state.curriculum_level]
|
|
message += f"\nDifficulty level: {current_level.description}\n"
|
|
message += f"Scramble depth: {scramble_length} moves\n"
|
|
|
|
# Show the current solved percentage to help the LLM
|
|
solved_percentage = cube_state.cube.count_solved_cubies() * 100
|
|
message += (
|
|
f"\nCurrent progress: {solved_percentage:.1f}% of cubies correctly placed\n"
|
|
)
|
|
|
|
if cube_state.is_solved():
|
|
message += "\nCongratulations! The cube is now solved."
|
|
|
|
return message
|
|
|
|
def _calculate_cube_state_score(self, cube_state: CubeState) -> float:
|
|
"""
|
|
Calculate a score based on how close the cube is to being solved.
|
|
Higher scores for cubes that are closer to being solved.
|
|
|
|
Uses curriculum-specific reward settings if available.
|
|
"""
|
|
# Base score
|
|
score = 0.0
|
|
|
|
# Get the current state
|
|
cube = cube_state.cube
|
|
|
|
# Get the progress (correctly positioned cubies)
|
|
correctly_placed = cube.count_solved_cubies()
|
|
|
|
# Calculate how many face centers are correctly placed
|
|
face_centers_correct = cube.count_correct_centers()
|
|
center_bonus = (
|
|
face_centers_correct * 0.05
|
|
) # Small bonus for each correct center
|
|
|
|
# Calculate how many entire faces are solved
|
|
faces_solved = cube.count_solved_faces()
|
|
face_bonus = faces_solved * 0.1 # Significant bonus for each solved face
|
|
|
|
# Use the episode's specific reward per cubie setting (from curriculum)
|
|
progress_reward = (
|
|
correctly_placed * cube_state.reward_per_correctly_placed_cubie
|
|
)
|
|
|
|
# Calculate pattern-based rewards (recognizing common patterns like crosses)
|
|
cross_bonus = 0.0
|
|
if cube.has_cross_on_face():
|
|
cross_bonus = 0.15 # Bonus for having a cross on any face
|
|
|
|
# Calculate corner reward
|
|
corners_correct = cube.count_correct_corners()
|
|
corner_bonus = (
|
|
corners_correct * 0.03
|
|
) # Bonus for each correctly positioned corner
|
|
|
|
# Calculate edge reward
|
|
edges_correct = cube.count_correct_edges()
|
|
edge_bonus = edges_correct * 0.02 # Bonus for each correctly positioned edge
|
|
|
|
# Partial face reward
|
|
partial_faces = cube.partial_face_completion()
|
|
partial_face_bonus = (
|
|
partial_faces * 0.01
|
|
) # Small bonus for partial face completion
|
|
|
|
# Add base progress rewards
|
|
score += (
|
|
progress_reward
|
|
+ center_bonus
|
|
+ face_bonus
|
|
+ cross_bonus
|
|
+ corner_bonus
|
|
+ edge_bonus
|
|
+ partial_face_bonus
|
|
)
|
|
|
|
# Small penalty for using more steps
|
|
steps_penalty = cube_state.num_steps * self.config.reward_per_step_reduction
|
|
score -= steps_penalty
|
|
|
|
# Progressive reward for improvement
|
|
if len(cube_state.step_rewards) > 0:
|
|
last_progress = cube_state.step_rewards[-1]
|
|
if score > last_progress:
|
|
# Give bonus for improvement
|
|
improvement_bonus = (score - last_progress) * 0.5
|
|
score += improvement_bonus
|
|
|
|
# Reward for a solved cube - higher reward for more difficult curriculum levels
|
|
if cube_state.is_solved():
|
|
# Base reward for solving
|
|
score += 2.0
|
|
|
|
# Bonus for solving with fewer moves
|
|
efficiency_bonus = 2.0 * (
|
|
1.0 - (cube_state.num_steps / cube_state.max_steps)
|
|
)
|
|
score += efficiency_bonus
|
|
|
|
# Curriculum level bonus (higher levels get higher rewards)
|
|
if self.use_curriculum and cube_state.curriculum_level > 0:
|
|
level_bonus = cube_state.curriculum_level * 0.3 # 0.3 per level
|
|
score += level_bonus
|
|
|
|
# Huge bonus for optimal solution (close to minimum moves)
|
|
if cube_state.num_steps <= cube_state.scramble_sequence_length + 2:
|
|
score += 3.0 # Exceptional solution bonus
|
|
|
|
return score
|
|
|
|
def _parse_move(
    self, response: str, ep: CubeState = None
) -> Tuple[Optional[str], Optional[str]]:
    """
    Extract the move and the optional strategy from an LLM response.

    The response is first parsed as an ``apply_move`` tool call. If that
    fails, the free text is searched for phrases such as "move R'" or
    "U2 rotation". Only upper-case face letters that form a complete token
    are accepted, so ordinary words ("rotate Left", "move Forward") are not
    mistaken for moves.

    Args:
        response: Raw model output.
        ep: Episode whose ``current_strategy`` is updated when the tool call
            names a new strategy. Optional.

    Returns:
        ``(move, strategy)``. ``move`` is None when no valid move was found.
        ``strategy`` is None for empty responses and for moves recovered by
        the text fallback; otherwise it is the (possibly empty) strategy
        string from the tool call.
    """
    if not response:
        logger.warning(
            "Attempted to parse an empty response string. Returning None."
        )
        return None, None

    valid_moves = {
        face + modifier for face in "UDLRFB" for modifier in ("", "'", "2")
    }

    # First try parsing with tool_call tags
    parsed_name, parsed_args, is_error = parse_tool_call(
        response, self.tools, ["tool_call"]
    )

    if is_error:
        error_detail = (
            parsed_name
            if isinstance(parsed_name, str) and parsed_name
            else "Parser indicated error, but no specific message was returned"
        )
        logger.warning(
            f"Failed to parse tool call. Full response: '{response}'. Error detail: {error_detail}"
        )

        # Fallback: look for direct mentions of a move in the text.
        # The move token is case-sensitive and must not be glued to other
        # word characters; only the surrounding keyword ignores case.
        move_token = r"([UDLRFB]['2]?)(?![A-Za-z0-9'])"
        token_start = r"(?<![A-Za-z0-9'])"
        keywords_before_move = (
            "move",
            "applying",
            r"perform\w*",
            "rotate",
            "rotation",
            r"I\s*choose",
            "Execute",
        )
        keywords_after_move = ("rotation", "move")

        move_patterns = [
            rf"(?i:{keyword})\s+{move_token}" for keyword in keywords_before_move
        ] + [
            rf"{token_start}([UDLRFB]['2]?)\s+(?i:{keyword})"
            for keyword in keywords_after_move
        ]

        for pattern in move_patterns:
            match = re.search(pattern, response)
            if match and match.group(1) in valid_moves:
                recovered = match.group(1)
                logger.warning(f"Recovered move '{recovered}' from text using regex")
                return recovered, None

        return None, None

    # Extract move and strategy. Model output is untrusted: either argument
    # may be missing or a non-string (null, number, list), which is treated
    # as empty instead of raising.
    args = parsed_args if isinstance(parsed_args, dict) else {}
    raw_move = args.get("move", "")
    raw_strategy = args.get("strategy", "")
    move = raw_move.strip() if isinstance(raw_move, str) else ""
    strategy = raw_strategy.strip() if isinstance(raw_strategy, str) else ""

    # Track the strategy in the episode if provided
    if (
        ep is not None
        and strategy
        and strategy != getattr(ep, "current_strategy", None)
    ):
        logger.info(f"Strategy changed to: {strategy}")
        ep.current_strategy = strategy

    # First check if the move is directly valid
    if move in valid_moves:
        return move, strategy

    # The model sometimes returns a whole sequence such as "R U R'";
    # only the first move is applied in that case.
    tokens = move.split()
    if len(tokens) > 1 and tokens[0] in valid_moves:
        first_move = tokens[0]
        logger.warning(
            f"Got move sequence '{move}' but taking only first move '{first_move}'"
        )
        return first_move, strategy

    # If we get here, the move is invalid
    logger.warning(
        f"Parsed invalid move: '{move}'. "
        f"Full response: '{response}'. Parsed args: {parsed_args}"
    )
    return None, strategy
|
|
|
|
def _score_response(
|
|
self,
|
|
is_valid_move: bool,
|
|
response_text: str,
|
|
cube_state: CubeState,
|
|
is_solved: bool,
|
|
) -> float:
|
|
"""
|
|
Calculate a score for a single agent response based on:
|
|
1. Whether the move was valid
|
|
2. Whether the move helps solve the cube
|
|
3. Presence of thinking tags
|
|
4. If the cube is solved
|
|
5. Strategy explanation quality
|
|
"""
|
|
# Base score from cube state after the move
|
|
current_score = self._calculate_cube_state_score(cube_state)
|
|
|
|
# Bonus for valid moves
|
|
if is_valid_move:
|
|
current_score += 0.2
|
|
else:
|
|
current_score -= 0.2
|
|
|
|
# Bonus for solving the cube
|
|
if is_solved:
|
|
current_score += 1.0
|
|
|
|
# Check for thinking tags
|
|
try:
|
|
# Make sure response_text is a string
|
|
if not isinstance(response_text, str):
|
|
logger.warning(f"response_text is not a string: {type(response_text)}")
|
|
response_text = str(response_text) if response_text is not None else ""
|
|
|
|
match = re.search(r"<think>(.*?)</think>", response_text, re.DOTALL)
|
|
if match:
|
|
thinking_content = match.group(1)
|
|
# Make sure thinking_content is a string before calling strip()
|
|
if isinstance(thinking_content, str):
|
|
if thinking_content.strip(): # Not empty
|
|
thinking_quality = self._evaluate_thinking_quality(
|
|
thinking_content, cube_state
|
|
)
|
|
current_score += thinking_quality
|
|
else: # Empty thinking tags
|
|
current_score -= 0.1
|
|
else:
|
|
logger.warning(
|
|
f"thinking_content is not a string: {type(thinking_content)}"
|
|
)
|
|
current_score -= 0.1
|
|
else: # No thinking tags
|
|
current_score -= 0.2
|
|
except Exception as e:
|
|
logger.warning(f"Error processing thinking tags: {e}")
|
|
current_score -= 0.1
|
|
|
|
return current_score
|
|
|
|
def _evaluate_thinking_quality(
    self, thinking_content: str, cube_state: CubeState
) -> float:
    """
    Evaluate the quality of thinking content to assign a reward

    The reward starts at 0.2 and grows with keyword-based heuristics:
    mentioning the current strategy (+0.1), strategy vocabulary (up to +0.1),
    cube/color vocabulary when a cube is attached (up to +0.1), length above
    200 characters (+0.05) and step-by-step wording (+0.05).

    Args:
        thinking_content: The content inside <think> tags
        cube_state: The current cube state; only ``current_strategy`` and
            ``cube`` are read

    Returns:
        A score between 0.2 and 0.5 based on thinking quality
    """
    # Lower-case once; every keyword check below is a case-insensitive
    # substring test against this value.
    lowered = thinking_content.lower()

    # Base score for having thinking content
    score = 0.2

    # Reward mentioning the strategy currently in use
    if cube_state.current_strategy:
        if cube_state.current_strategy.lower() in lowered:
            score += 0.1

    # Common strategy-related terms to look for
    strategy_terms = [
        "layer",
        "cross",
        "corners",
        "edges",
        "algorithm",
        "orient",
        "permute",
        "f2l",
        "oll",
        "pll",
        "cfop",
        "beginner",
        "method",
        "technique",
        "sequence",
        "pattern",
    ]
    term_count = sum(1 for term in strategy_terms if term in lowered)
    score += min(0.1, term_count * 0.02)  # Cap at 0.1

    # Check for specificity of thinking
    if cube_state.cube is not None:
        color_terms = [
            "white",
            "yellow",
            "red",
            "orange",
            "blue",
            "green",
            "face",
            "center",
            "corner",
            "edge",
        ]
        color_count = sum(1 for term in color_terms if term in lowered)
        score += min(0.1, color_count * 0.02)  # Cap at 0.1

    # Check for detailed explanations
    if len(thinking_content) > 200:
        score += 0.05

    # Check for step-by-step reasoning
    if re.search(
        r"(step|first|second|third|next|then|after)",
        thinking_content,
        re.IGNORECASE,
    ):
        score += 0.05

    return min(0.5, score)  # Cap the total at 0.5
|
|
|
|
async def _sample_response(self, messages: List[Dict], n: int = 1) -> List[str]:
    """
    Request ``n`` completions for a chat history.

    The messages are rendered to a prompt string with the tokenizer's chat
    template and sent to the completion server using the configured token
    limit, temperature and top_p. Any exception raised by the request is
    logged and results in an empty list.
    """
    rendered_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
    try:
        result = await self.server.completion(
            prompt=rendered_prompt,
            n=n,
            max_tokens=self.config.max_token_length,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
        )
        texts = []
        for choice in result.choices:
            texts.append(choice.text)
        return texts
    except Exception as e:
        logger.error(f"API error during completion: {e}")
        return []
|
|
|
|
async def _next_step(
    self, ep: CubeState, current_turn: int, max_turns: int
) -> Tuple[Optional[RubiksCubeScoredDataGroup], bool]:
    """
    Run one turn of an episode.

    Samples ``group_size`` alternative responses for the current history,
    simulates each one on a deep copy of the episode to score and tokenize
    it, packages all alternatives into one scored group, then applies the
    highest-reward alternative to the real episode.

    Returns:
        ``(group, is_terminal)``; ``group`` is ``None`` (with
        ``is_terminal`` True) when sampling fails or the collected lists are
        inconsistent.
    """
    group_size = self.config.group_size

    # Snapshot of the history shared by every alternative of this turn
    history_snapshot = ep.message_history.copy()
    logger.debug(
        f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}/{max_turns}] "
        f"Current state history length: {len(history_snapshot)}"
    )

    # The agent turn is pre-filled with an opening think tag when enabled
    prefill = "<think>\n" if self.config.thinking_active else ""
    llm_messages = history_snapshot.copy()
    llm_messages.append({"role": "agent", "content": prefill})

    # Generate the alternative responses
    try:
        responses = await self._sample_response(llm_messages, n=group_size)
        if not responses or len(responses) != group_size:
            logger.error(
                f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] "
                f"Expected {group_size} responses, got {len(responses) if responses else 0}. "
                f"Aborting step."
            )
            return None, True  # Episode termination
    except Exception as e_sample:
        logger.error(
            f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] Error sampling responses: {e_sample}",
            exc_info=True,
        )
        return None, True  # Episode termination

    # Per-alternative accumulators (index-aligned)
    full_responses: List[str] = []
    parsed_moves: List[Optional[str]] = []
    valid_flags: List[bool] = []
    rewards: List[float] = []
    next_state_msgs: List[List[Dict]] = []
    terminal_flags: List[bool] = []
    solved_flags: List[bool] = []
    token_lists: List[List[int]] = []
    mask_lists: List[List[int]] = []
    token_reward_lists: List[List[float]] = []  # Token-level rewards

    for raw_completion in responses:
        full_response = prefill + raw_completion
        full_responses.append(full_response)

        # Parse the move and strategy from the response
        move, strategy = self._parse_move(full_response, ep)
        parsed_moves.append(move)

        # Simulate on a copy so the real episode is untouched
        sim_ep = copy.deepcopy(ep)

        # NOTE(review): the parsed strategy is recorded on the simulated copy
        # only; confirm whether the chosen strategy should also be stored on
        # `ep` further below.
        if strategy:
            sim_ep.current_strategy = strategy

        move_ok = False
        if move is not None:
            move_ok = sim_ep.apply_move(move)
        valid_flags.append(move_ok)

        solved = sim_ep.is_solved()
        solved_flags.append(solved)

        reward = self._score_response(move_ok, full_response, sim_ep, solved)
        rewards.append(reward)

        # The alternative ends the episode when solved, out of turns, or invalid
        terminal = solved or (current_turn + 1 >= max_turns) or not move_ok
        terminal_flags.append(terminal)

        # Messages representing the state after this alternative
        msgs_after = history_snapshot + [
            {"role": "agent", "content": full_response}
        ]
        if not terminal:
            msgs_after.append(
                {
                    "role": "environment",
                    "content": self._format_observation(sim_ep),
                }
            )
        next_state_msgs.append(msgs_after)

        # Tokenize the next state for the trainer
        tokenized = tokenize_for_trainer(self.tokenizer, msgs_after)
        token_lists.append(tokenized["tokens"])
        mask_lists.append(tokenized["masks"])

        if self.use_token_level_rewards:
            per_token = calculate_token_level_rewards(
                response_text=full_response,
                is_valid_move=move_ok,
                parsed_move=move,
                reward=reward,
                token_ids=tokenized["tokens"],
                scale_factor=self.token_reward_scale,
            )
        else:
            # Default to uniform token rewards if not enabled
            per_token = [reward * 0.1] * len(tokenized["tokens"])
        token_reward_lists.append(per_token)

    # Every accumulator must hold exactly one entry per alternative
    lengths_ok = all(
        len(seq) == group_size
        for seq in (
            token_lists,
            mask_lists,
            rewards,
            next_state_msgs,
            parsed_moves,
            token_reward_lists,
        )
    )
    if not lengths_ok:
        logger.error(
            f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] "
            f"Mismatch in alternative list lengths before creating ScoredDataGroup. "
            f"Tokens:{len(token_lists)}, Masks:{len(mask_lists)}, Rewards:{len(rewards)}, "
            f"Token Rewards:{len(token_reward_lists)}, Msgs:{len(next_state_msgs)}, "
            f"ParsedMoves:{len(parsed_moves)}. Expected {group_size} for all. "
            f"Aborting step."
        )
        return None, True

    # Token weights for advantage computation
    if self.use_token_level_rewards:
        token_weights = calculate_advantage_token_weights(token_reward_lists)
    else:
        # Default to uniform weights
        token_weights = [[1.0] * len(tokens) for tokens in token_lists]

    step_group = RubiksCubeScoredDataGroup(
        seed=ep.seed,
        tokens=token_lists,
        masks=mask_lists,
        scores=rewards,
        token_scores=token_reward_lists,
        token_weights=token_weights,
        messages=next_state_msgs,
        parsed_actions=parsed_moves,
    )

    # Pick the alternative with the highest reward
    best_idx = np.argmax(rewards)
    best_reward = rewards[best_idx]
    logger.debug(
        f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] "
        f"Selected Alt {best_idx} "
        f"(Reward: {best_reward}) "
        f"from {group_size} alternatives."
    )

    chosen_move = parsed_moves[best_idx]
    chosen_response = full_responses[best_idx]
    logger.info(
        f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] Chosen move: "
        f"{chosen_move} (from Alt {best_idx} with "
        f"Reward {best_reward})"
    )

    # Record the (possibly truncated) chosen response in the real history
    history_entry = truncate_thinking(
        chosen_response,
        self.tokenizer,
        self.config.max_think_chars_history,
    )
    ep.message_history.append({"role": "agent", "content": history_entry})

    # Apply the chosen move to the real episode
    chosen_ok = False
    if chosen_move is not None:
        chosen_ok = ep.apply_move(chosen_move)

    episode_solved = ep.is_solved()

    # Re-score against the real episode state and record the step reward
    step_reward = self._score_response(
        chosen_ok, chosen_response, ep, episode_solved
    )
    ep.step_rewards.append(step_reward)

    episode_terminal = (
        episode_solved or (current_turn + 1 >= max_turns) or not chosen_ok
    )

    # Add the next observation if the episode continues
    if not episode_terminal:
        ep.message_history.append(
            {"role": "environment", "content": self._format_observation(ep)}
        )

    return step_group, episode_terminal
|
|
|
|
async def collect_trajectories(
    self, item: Tuple[int, int]
) -> Tuple[List[RubiksCubeScoredDataGroup], List[Tuple[int, int]]]:
    """
    Collect the scored step groups for one complete episode.

    Steps the episode identified by the seed in ``item`` until a step
    reports termination or fails, records summary (and curriculum) metrics,
    optionally writes visualizations, removes the episode from
    ``self.episodes`` and returns the token-limited trajectory together
    with an empty backlog list.
    """
    seed, _ = item
    group_size = self.config.group_size
    max_turns = self.config.max_steps
    tag = f"[Collect Trajectories Seed: {seed}]"

    collected_steps: List[RubiksCubeScoredDataGroup] = []

    logger.info(
        f"{tag} Starting new trajectory. "
        f"Group size G={group_size}, Max turns={max_turns}."
    )

    try:
        ep = self._get_or_create_episode(seed)
    except Exception as e:
        logger.error(
            f"{tag} Fatal error creating/getting episode: {e}",
            exc_info=True,
        )
        return [], []

    for turn_idx in range(max_turns):
        logger.debug(f"{tag} Attempting turn {turn_idx + 1}/{max_turns}.")

        try:
            step_data, step_is_terminal = await self._next_step(
                ep, turn_idx, max_turns
            )
        except Exception as e:
            if "'list' object has no attribute 'strip'" in str(e):
                # Known error: logged without a traceback
                logger.error(f"{tag} Non-fatal error in _next_step: {e}")
            else:
                # Any other error: logged with the traceback
                logger.error(
                    f"{tag} Error in _next_step: {e}",
                    exc_info=True,
                )
            # Either way the step produced no data and the episode ends
            step_data, step_is_terminal = None, True

        if step_data:
            collected_steps.append(step_data)
        else:
            logger.error(
                f"{tag} Turn {turn_idx + 1} failed to produce data."
                " Terminating episode."
            )
            step_is_terminal = True

        if step_is_terminal:
            reward_so_far = sum(ep.step_rewards) if ep.step_rewards else 0.0
            logger.info(
                f"{tag} Episode ended at turn {turn_idx + 1}. "
                f"Reason: step reported terminal. Total reward: {reward_so_far:.2f}"
            )
            break
    else:
        # Only reached when the loop finished without a break
        logger.info(f"{tag} Episode reached max_turns ({max_turns}).")

    final_reward = sum(ep.step_rewards) if ep.step_rewards else 0.0
    is_solved = ep.is_solved()

    # Store metrics for this episode
    summary = {
        "seed": ep.seed,
        "total_reward": final_reward,
        "num_steps": ep.num_steps,
        "is_solved": is_solved,
        "scramble_length": len(ep.scramble_sequence),
    }

    if self.use_curriculum:
        summary["curriculum_level"] = ep.curriculum_level

        # Report the outcome to the curriculum, then merge its level metrics
        self.curriculum.record_episode_result(
            level=ep.curriculum_level, is_solved=is_solved, num_steps=ep.num_steps
        )
        summary.update(self.curriculum.get_level_metrics())

    self.completed_episode_metrics_buffer.append(summary)

    # Generate visualization before cleaning up if enabled
    if self.generate_visualizations:
        self._generate_episode_visualization(ep)

    # Track (and visualize) the best episode seen so far
    if self.save_best_episodes and final_reward > self.best_episode_reward:
        self.best_episode_reward = final_reward
        self.best_episode = ep
        self._generate_episode_visualization(ep, is_best=True)

    # Clean up the episode
    self.episodes.pop(seed, None)

    # Ensure the trajectory doesn't exceed token limits
    limited_steps = ensure_trajectory_token_limit(
        collected_steps,
        self.tokenizer,
        self.config.max_trajectory_tokens,
    )

    return limited_steps, []
|
|
|
|
async def setup(self):
|
|
"""Initialize the environment"""
|
|
# Nothing to do here as we don't need any special setup
|
|
pass
|
|
|
|
async def get_next_item(self) -> Tuple[int, int]:
    """Return ``(seed, 0)`` where seed is a random integer in [0, 1000000]."""
    episode_seed = random.randint(0, 1000000)
    return (episode_seed, 0)
|
|
|
|
async def rollout_and_score_eval(self, seed: int) -> Dict[str, float]:
    """
    Run a single episode for evaluation with a single response per step.

    Each turn samples one response, parses and applies its move, scores it
    with ``_score_response`` and appends it to the episode history. The
    episode stops when the cube is solved, a move is invalid, no response is
    returned, or ``config.max_steps`` turns have been played.

    Args:
        seed: Seed identifying the episode to create and roll out

    Returns:
        A metrics dict with keys ``seed``, ``total_reward``, ``num_turns``,
        ``num_valid_moves``, ``num_invalid_moves`` and ``is_solved``
    """
    ep = self._get_or_create_episode(seed)
    max_turns = self.config.max_steps
    metrics = {
        "seed": seed,
        "total_reward": 0.0,
        "num_turns": 0,
        "num_valid_moves": 0,
        "num_invalid_moves": 0,
        "is_solved": False,
    }

    try:
        for turn in range(max_turns):
            messages = ep.message_history.copy()
            agent_prompt_content = "<think>\n" if self.config.thinking_active else ""
            messages.append({"role": "agent", "content": agent_prompt_content})

            # Get a single response
            responses = await self._sample_response(messages, n=1)
            if not responses:
                logger.error(
                    f"[Eval Seed: {seed}, Turn: {turn+1}] No response. Aborting."
                )
                break

            llm_output_only = responses[0]
            full_agent_response = agent_prompt_content + llm_output_only

            # _parse_move returns a (move, strategy) pair; unpack it so that
            # `move` is the move itself (or None) rather than the whole tuple,
            # and pass the episode like the training path does.
            move, _strategy = self._parse_move(full_agent_response, ep)
            is_valid_move = False

            if move is not None:
                is_valid_move = ep.apply_move(move)

            if is_valid_move:
                metrics["num_valid_moves"] += 1
            else:
                metrics["num_invalid_moves"] += 1

            # Calculate reward
            is_solved = ep.is_solved()
            reward = self._score_response(
                is_valid_move, full_agent_response, ep, is_solved
            )
            metrics["total_reward"] += reward
            metrics["num_turns"] += 1

            # Add response to history
            response_for_history = truncate_thinking(
                full_agent_response, self.tokenizer, self.config.max_think_chars_history
            )
            ep.message_history.append(
                {"role": "agent", "content": response_for_history}
            )

            # Add next observation if not terminal
            is_terminal = is_solved or not is_valid_move
            if not is_terminal:
                ep.message_history.append(
                    {"role": "environment", "content": self._format_observation(ep)}
                )

            # Check for termination
            if is_terminal:
                metrics["is_solved"] = is_solved
                logger.info(f"[Eval Seed: {seed}] Episode ended. Solved: {is_solved}")
                break

        # If we reached max turns, check final state
        if metrics["num_turns"] == max_turns:
            metrics["is_solved"] = ep.is_solved()
    finally:
        # Always drop the episode, even if the rollout raised, and tolerate
        # the seed already being absent.
        self.episodes.pop(seed, None)

    return metrics
|
|
|
|
async def evaluate(self, *args, **kwargs):
    """
    Run the evaluation episodes and store aggregate metrics.

    Does nothing (besides logging) when wandb is disabled. Otherwise rolls
    out ``config.eval_episodes`` episodes with random seeds in
    [1000001, 2000000] and writes the averages to ``self.eval_metrics``.
    """
    if not self.config.use_wandb:
        logger.info("Skipping evaluation as wandb is not enabled.")
        return

    episode_count = self.config.eval_episodes
    logger.info(f"Starting evaluation for {episode_count} episodes.")

    pending = []
    for _ in range(episode_count):
        pending.append(self.rollout_and_score_eval(random.randint(1000001, 2000000)))

    results = await tqdm_asyncio.gather(*pending)
    usable = [m for m in results if m]

    if not usable:
        logger.warning("No valid evaluation metrics.")
        return

    # Accumulate totals across all usable episodes in one pass
    num_completed = len(usable)
    reward_total = 0
    turns_total = 0
    valid_total = 0
    invalid_total = 0
    solved_total = 0
    for m in usable:
        reward_total += m["total_reward"]
        turns_total += m["num_turns"]
        valid_total += m["num_valid_moves"]
        invalid_total += m["num_invalid_moves"]
        if m["is_solved"]:
            solved_total += 1

    avg_total_reward = reward_total / num_completed
    avg_num_turns = turns_total / num_completed

    moves_total = valid_total + invalid_total
    if moves_total > 0:
        move_validity_rate = valid_total / moves_total
    else:
        move_validity_rate = 0

    solve_rate = solved_total / num_completed

    self.eval_metrics = [
        ("eval/avg_total_reward", avg_total_reward),
        ("eval/avg_num_turns", avg_num_turns),
        ("eval/move_validity_rate", move_validity_rate),
        ("eval/solve_rate", solve_rate),
        ("eval/num_completed_episodes", num_completed),
    ]

    logger.info(f"Evaluation completed. Metrics: {self.eval_metrics}")
|
|
|
|
async def wandb_log(self, wandb_metrics: Optional[Dict[str, float]] = None):
    """
    Add training-episode aggregates to ``wandb_metrics`` and forward them.

    When completed-episode summaries are buffered, their averages (and,
    with curriculum enabled, curriculum and per-level statistics) are added
    under the ``<prefix>_train/`` and ``<prefix>_curriculum/`` keys and the
    buffer is reset. The metrics dict is then passed to the parent
    ``wandb_log``.
    """
    wandb_metrics = {} if wandb_metrics is None else wandb_metrics

    buffered = self.completed_episode_metrics_buffer
    if buffered:
        num_episodes = len(buffered)

        reward_total = 0
        steps_total = 0
        solved_total = 0
        for m in buffered:
            reward_total += m["total_reward"]
        for m in buffered:
            steps_total += m["num_steps"]
        for m in buffered:
            if m["is_solved"]:
                solved_total += 1

        # Log basic metrics
        prefix = self.wandb_prepend or "rubiks"
        wandb_metrics[f"{prefix}_train/avg_episode_reward"] = reward_total / num_episodes
        wandb_metrics[f"{prefix}_train/avg_episode_steps"] = steps_total / num_episodes
        wandb_metrics[f"{prefix}_train/solve_rate"] = solved_total / num_episodes
        wandb_metrics[f"{prefix}_train/num_episodes"] = num_episodes

        # Log average scramble length
        scramble_total = 0
        for m in buffered:
            scramble_total += m.get("scramble_length", 0)
        wandb_metrics[f"{prefix}_train/avg_scramble_length"] = (
            scramble_total / num_episodes
        )

        # Log curriculum learning metrics if enabled
        if self.use_curriculum:
            level_metrics = self.curriculum.get_level_metrics()

            curriculum_key = f"{prefix}_curriculum"
            wandb_metrics[f"{curriculum_key}/level"] = level_metrics["curriculum_level"]
            wandb_metrics[f"{curriculum_key}/success_rate"] = level_metrics[
                "level_success_rate"
            ]
            wandb_metrics[f"{curriculum_key}/progress_to_next"] = level_metrics[
                "progress_to_next_level"
            ]
            wandb_metrics[f"{curriculum_key}/solved_count"] = level_metrics[
                "level_solved_count"
            ]
            wandb_metrics[f"{curriculum_key}/episodes"] = level_metrics["level_episodes"]

            # Per-level episode and success counts from the buffered episodes
            episodes_by_level = {}
            solved_by_level = {}
            for m in buffered:
                if "curriculum_level" not in m:
                    continue
                level = m["curriculum_level"]
                episodes_by_level[level] = episodes_by_level.get(level, 0) + 1
                if m["is_solved"]:
                    solved_by_level[level] = solved_by_level.get(level, 0) + 1

            for level, count in episodes_by_level.items():
                solved = solved_by_level.get(level, 0)
                rate = solved / count if count > 0 else 0
                wandb_metrics[f"{curriculum_key}/level_{level}_episodes"] = count
                wandb_metrics[f"{curriculum_key}/level_{level}_success_rate"] = rate

        # Clear buffer
        self.completed_episode_metrics_buffer = []

    await super().wandb_log(wandb_metrics)
|
|
|
|
@classmethod
def config_init(cls) -> Tuple[RubiksCubeEnvConfig, List[APIServerConfig]]:
    """Build the default environment config and API server config list."""
    env_settings = dict(
        tokenizer_name="openai/gpt-4-turbo-preview",
        group_size=16,
        use_wandb=True,
        max_num_workers=128,
        rollout_server_url="http://localhost:9000",
        total_steps=2000,
        batch_size=1024,
        steps_per_eval=20,
        max_token_length=1024 * 16,
        inference_weight=1.0,
        wandb_name="rubiks_cube",
        data_path_to_save_groups=None,
        eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
        eval_limit_ratio=0.1,
        max_steps=20,
        temperature=0.7,
        top_p=0.9,
        thinking_active=True,
        eval_episodes=100,
        max_think_chars_history=3000,
        max_trajectory_tokens=24576,
        debug_mode=False,
        tiebreak_token_factor=0.01,
        scramble_moves=7,
        cube_size=3,
        reward_per_correctly_placed_cubie=0.05,
        reward_per_step_reduction=0.01,
    )
    env_config = RubiksCubeEnvConfig(**env_settings)

    # A single inference server entry
    inference_server = APIServerConfig(
        model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
        base_url="http://localhost:9004/v1",
        num_requests_for_eval=256,
    )

    return env_config, [inference_server]
|
|
|
|
def _generate_episode_visualization(
    self, ep: CubeState, is_best: bool = False
) -> Optional[str]:
    """
    Write an HTML visualization for an episode.

    Returns the path reported by ``save_enhanced_visualization``, or
    ``None`` when visualizations are disabled, no output directory is set,
    or generation raises.
    """
    enabled = self.generate_visualizations and self.visualizations_dir
    if not enabled:
        return None

    try:
        # Collect the non-blank <think> contents of every agent message
        thinking_history = []
        for message in ep.message_history:
            if message.get("role") != "agent":
                continue
            if not isinstance(message.get("content"), str):
                continue
            found = re.search(r"<think>(.*?)</think>", message["content"], re.DOTALL)
            if not found:
                continue
            stripped = found.group(1).strip()
            if stripped:
                thinking_history.append(stripped)

        # Build the filename from the episode attributes
        if ep.is_solved():
            solved_status = "solved"
        else:
            solved_status = "unsolved"
        level_part = f"L{ep.curriculum_level}_" if self.use_curriculum else ""
        best_part = "BEST_" if is_best else ""
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

        filename = f"{best_part}{level_part}{solved_status}_steps{ep.num_steps}_seed{ep.seed}_{timestamp}.html"
        output_path = os.path.join(self.visualizations_dir, filename)

        # Curriculum description is only looked up for levels above 0
        curriculum_description = None
        if self.use_curriculum and ep.curriculum_level > 0:
            curriculum_description = self.curriculum.levels[
                ep.curriculum_level
            ].description

        # Save the visualization
        html_path = save_enhanced_visualization(
            cube_state=str(ep.cube),
            move_history=ep.actions,
            progress_history=ep.progress_history,
            rewards_history=ep.step_rewards,
            thinking_history=thinking_history,
            scramble_sequence=ep.scramble_sequence,
            is_solved=ep.is_solved(),
            curriculum_level=ep.curriculum_level if self.use_curriculum else None,
            curriculum_description=curriculum_description,
            output_path=output_path,
        )

        logger.info(f"Generated visualization at {html_path}")
        return html_path

    except Exception as e:
        logger.error(f"Error generating visualization: {e}")
        return None
|
|
|
|
@classmethod
def cli(cls):
    """
    Run the parent CLI, then try to build a consolidated report.

    The report step is best-effort: any exception (including a failed
    import of the report module) is printed and otherwise ignored. The
    parent CLI's result is returned unchanged.
    """
    result = super().cli()

    try:
        from rubiks_consolidated_report import generate_consolidated_report_from_dir

        # Only report when the class carries a config with visualizations on
        wants_report = (
            hasattr(cls, "config")
            and cls.config.generate_visualizations
            and cls.config.visualizations_dir
        )
        if wants_report:
            print("\nGenerating consolidated report of all solving attempts...")
            report_path = generate_consolidated_report_from_dir(
                cls.config.visualizations_dir
            )
            print(f"Consolidated report generated at: {report_path}")
    except Exception as e:
        print(f"Error generating consolidated report: {e}")

    return result
|
|
|
|
|
|
# Script entry point: hand control to the environment's command-line interface.
if __name__ == "__main__":
    RubiksCubeEnv.cli()
|