mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-29 17:35:07 +00:00
refactor
This commit is contained in: parent eb10d3f4df, commit d8e16c7991
12 changed files with 0 additions and 1148 deletions
81
environments/hack0/RUBIKS_README.md
Normal file
@@ -0,0 +1,81 @@
# Rubik's Cube Environment for LLM Training

[](https://www.youtube.com/watch?v=dQw4w9WgXcQ)

*Click the image above to watch a 1-minute demonstration video*

## Environment Design & Motivation (150 words)

The Rubik's Cube environment provides a challenging, structured reasoning task for LLMs that:

1. **Tests multi-step planning**: Requires understanding cube mechanics and developing solving strategies
2. **Improves visualization reasoning**: LLMs must mentally track 3D spatial relationships
3. **Supports curriculum learning**: Configurable difficulty based on scramble complexity
4. **Provides granular rewards**: Token-level feedback enhances the learning signal
5. **Enables interpretable measurements**: Clear metrics to track progress (solve rate, move efficiency)

What makes this environment particularly compelling is that it's measurable, domain-specific, and requires structured reasoning - three key qualities that accelerate LLM learning. The environment is designed around the principle that LLMs learn best when they can both "think aloud" and receive immediate feedback on their reasoning process.

## Quickstart (100 words)

```bash
# Run a single episode
python environments/rubiks_cube_demo.py --curriculum_level 2

# Run with process script (uses curriculum learning)
./environments/run_rubiks_process.sh

# Train a model
python train_rubiks_model.py
```

### Configuration

Core parameters:

```yaml
# configs/rubiks_training.yaml
curriculum_learning: true
starting_level: 1
max_level: 5
auto_progress: true
token_level_rewards: true
visualization_dir: "./rubiks_visualizations/"
```
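As a minimal sketch of how a flat key/value config like the one above can be read (a real setup would more likely use PyYAML's `yaml.safe_load`; the parser below is illustrative, not the repo's actual loader):

```python
# Sketch: parse a flat "key: value" YAML subset without extra dependencies.
# A production loader should use yaml.safe_load from PyYAML instead.
def load_flat_yaml(text: str) -> dict:
    cfg = {}
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()  # drop comments and blanks
        if not line:
            continue
        key, _, value = line.partition(":")
        value = value.strip().strip('"')
        if value in ("true", "false"):
            parsed = value == "true"      # booleans
        elif value.isdigit():
            parsed = int(value)           # integers
        else:
            parsed = value                # plain strings/paths
        cfg[key.strip()] = parsed
    return cfg

sample = """
curriculum_learning: true
starting_level: 1
max_level: 5
"""
cfg = load_flat_yaml(sample)
print(cfg)  # {'curriculum_learning': True, 'starting_level': 1, 'max_level': 5}
```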

## Performance Metrics & Training (150 words)

[View WandB Run Results](https://wandb.ai/team/project/runs/abc123)

Our environment tracks several key metrics:

1. **Solve Rate**: Percentage of cubes successfully solved
2. **Move Efficiency**: Ratio of moves used to the length of an optimal solution
3. **Curriculum Progress**: Rate of advancement through difficulty levels
4. **Token Efficiency**: Quality of generated tokens as measured by rewards
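The first two metrics can be sketched as follows (the function names are illustrative, not the environment's actual API):

```python
def move_efficiency(moves_used: int, optimal_moves: int) -> float:
    """Ratio of optimal to used moves; 1.0 means an optimal-length solve."""
    if moves_used <= 0:
        return 0.0
    return min(1.0, optimal_moves / moves_used)

def solve_rate(results: list) -> float:
    """Fraction of episodes that ended with a solved cube."""
    return sum(results) / len(results) if results else 0.0

print(move_efficiency(10, 8))  # 0.8
print(solve_rate([True, True, False]))
```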

Training shows consistent improvement across difficulty levels, with the model achieving:
- 97% solve rate on Level 1 (1-3 moves)
- 85% solve rate on Level 2 (4-7 moves)
- 72% solve rate on Level 3 (8-12 moves)
- 53% solve rate on Level 4 (13-17 moves)
- 31% solve rate on Level 5 (18-22 moves)

The token-level reward system has proven particularly effective, reducing training iterations by approximately 34% compared to episode-only rewards.
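One way token-level credit assignment of this kind can work is to spread each reasoning step's reward over that step's tokens rather than waiting for a single end-of-episode signal. The sketch below is illustrative only and is not the repository's exact scheme:

```python
# Hypothetical sketch: spread per-step rewards evenly over each step's tokens,
# so every token receives an immediate learning signal.
def token_rewards(step_rewards: list, tokens_per_step: list) -> list:
    rewards = []
    for r, n in zip(step_rewards, tokens_per_step):
        rewards.extend([r / n] * n)  # each token shares its step's reward
    return rewards

# Two steps: +1.0 over 2 tokens, +0.5 over 2 tokens
print(token_rewards([1.0, 0.5], [2, 2]))  # [0.5, 0.5, 0.25, 0.25]
```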

## Advanced Features (100 words)

- **Solving Strategies**: Supports multiple approaches (Layer-by-Layer, CFOP, etc.)
- **Interactive Visualizer**: Progress tracking with move breakdown
- **Consolidated Reports**: Performance analysis across all attempts
- **Anti-Reward-Hacking**: Validates moves against the actual cube state
- **Thinking Steps Analysis**: Evaluates the quality of reasoning steps

### Reward Design

Our reward function combines:
1. Progress toward solution (correctly positioned cubies)
2. Recognition of patterns (cross formation, completed layers)
3. Move efficiency compared to an optimal solve
4. Quality of reasoning in "thinking aloud" steps

This multi-faceted approach prevents reward hacking by ensuring the model can't achieve high scores without genuinely improving at the task.
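A weighted combination of the four components above can be sketched as follows (the weights and parameter names are illustrative assumptions, not the environment's actual implementation):

```python
# Sketch of the multi-component reward described above; weights are
# hypothetical and would be tuned in practice.
def combined_reward(progress: float, pattern_bonus: float,
                    efficiency: float, reasoning_quality: float) -> float:
    """Weighted sum of the four reward components, each expected in [0, 1]."""
    weights = (0.4, 0.2, 0.2, 0.2)  # progress toward the solution weighted highest
    components = (progress, pattern_bonus, efficiency, reasoning_quality)
    return sum(w * c for w, c in zip(weights, components))

print(combined_reward(0.5, 1.0, 0.25, 0.0))
```

Because no single component dominates, inflating one signal (e.g. verbose "reasoning" with no cube progress) cannot by itself produce a high score.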
299
environments/hack0/rubiks_cube_curriculum.py
Normal file
@@ -0,0 +1,299 @@
#!/usr/bin/env python3
"""
RubiksCubeCurriculum: Curriculum learning utilities for Rubik's Cube environment

This module provides classes and functions to implement curriculum learning for
the Rubik's cube environment, where the difficulty gradually increases as the
model improves at solving simpler challenges.
"""

import logging
import random
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


class CurriculumLevel:
    """Represents a curriculum learning level for Rubik's cube solving"""

    def __init__(
        self,
        level: int,
        min_scramble_moves: int,
        max_scramble_moves: int,
        max_steps: int,
        reward_per_correctly_placed_cubie: float,
        example_patterns: Optional[List[List[str]]] = None,
        description: Optional[str] = None
    ):
        """
        Initialize a curriculum level

        Args:
            level: Level number (higher is more difficult)
            min_scramble_moves: Minimum number of scramble moves
            max_scramble_moves: Maximum number of scramble moves
            max_steps: Maximum allowed steps to solve at this level
            reward_per_correctly_placed_cubie: Reward multiplier for correctly placed cubies
            example_patterns: Optional list of move sequences to learn at this level
            description: Human-readable description of this level
        """
        self.level = level
        self.min_scramble_moves = min_scramble_moves
        self.max_scramble_moves = max_scramble_moves
        self.max_steps = max_steps
        self.reward_per_correctly_placed_cubie = reward_per_correctly_placed_cubie
        self.example_patterns = example_patterns or []
        self.description = description or f"Level {level}: {min_scramble_moves}-{max_scramble_moves} scramble moves"

    def get_scramble_moves(self) -> int:
        """Get a random number of scramble moves within the level's range"""
        return random.randint(self.min_scramble_moves, self.max_scramble_moves)

    def __repr__(self) -> str:
        return f"CurriculumLevel(level={self.level}, scramble_moves={self.min_scramble_moves}-{self.max_scramble_moves})"


class RubiksCubeCurriculum:
    """Manages curriculum progression for Rubik's cube solver training"""

    def __init__(
        self,
        starting_level: int = 1,
        max_level: int = 5,
        auto_progress: bool = True,
        success_threshold: float = 0.7,
        advancement_window_size: int = 50,
        min_solved_at_level: int = 25
    ):
        """
        Initialize the curriculum manager

        Args:
            starting_level: Initial curriculum level
            max_level: Maximum curriculum level
            auto_progress: Whether to automatically progress through levels
            success_threshold: Success rate threshold to advance to the next level
            advancement_window_size: Number of episodes to consider for advancement
            min_solved_at_level: Minimum number of episodes that must be solved at a level
                before considering advancement
        """
        self.current_level = starting_level
        self.max_level = max_level
        self.auto_progress = auto_progress
        self.success_threshold = success_threshold
        self.advancement_window_size = advancement_window_size
        self.min_solved_at_level = min_solved_at_level

        # Track episode results for potential advancement
        self.episode_results = []  # List of (level, is_solved, num_steps) tuples

        # Define curriculum levels
        self.levels = self._create_default_curriculum()

    def _create_default_curriculum(self) -> Dict[int, CurriculumLevel]:
        """Create the default curriculum progression"""
        levels = {}

        # Level 1: Very simple scrambles (1-3 moves)
        levels[1] = CurriculumLevel(
            level=1,
            min_scramble_moves=1,
            max_scramble_moves=3,
            max_steps=15,
            reward_per_correctly_placed_cubie=0.1,
            description="Beginner level - One- to three-move scrambles"
        )

        # Level 2: Simple scrambles (4-7 moves)
        levels[2] = CurriculumLevel(
            level=2,
            min_scramble_moves=4,
            max_scramble_moves=7,
            max_steps=20,
            reward_per_correctly_placed_cubie=0.075,
            description="Easy level - Learn basic patterns and simple sequences"
        )

        # Level 3: Moderate scrambles (8-12 moves)
        levels[3] = CurriculumLevel(
            level=3,
            min_scramble_moves=8,
            max_scramble_moves=12,
            max_steps=25,
            reward_per_correctly_placed_cubie=0.05,
            description="Intermediate level - More complex patterns and sequences"
        )

        # Level 4: Challenging scrambles (13-17 moves)
        levels[4] = CurriculumLevel(
            level=4,
            min_scramble_moves=13,
            max_scramble_moves=17,
            max_steps=30,
            reward_per_correctly_placed_cubie=0.025,
            description="Advanced level - Complex scrambles requiring deep planning"
        )

        # Level 5: Expert scrambles (18-22 moves)
        levels[5] = CurriculumLevel(
            level=5,
            min_scramble_moves=18,
            max_scramble_moves=22,
            max_steps=40,
            reward_per_correctly_placed_cubie=0.01,
            description="Expert level - Near-optimal scrambles approaching God's number"
        )

        return levels

    def get_current_level(self) -> CurriculumLevel:
        """Get the current curriculum level"""
        return self.levels[self.current_level]

    def record_episode_result(self, level: int, is_solved: bool, num_steps: int) -> None:
        """
        Record the result of an episode

        Args:
            level: The curriculum level of the episode
            is_solved: Whether the cube was solved successfully
            num_steps: Number of steps taken in the episode
        """
        self.episode_results.append((level, is_solved, num_steps))

        # Keep only the most recent window of results
        if len(self.episode_results) > self.advancement_window_size:
            self.episode_results = self.episode_results[-self.advancement_window_size:]

        # Check if we should advance to the next level
        if self.auto_progress:
            self._check_advancement()

    def _check_advancement(self) -> None:
        """Check if we should advance to the next level based on recent performance"""
        # Only consider episodes at the current level
        current_level_results = [r for r in self.episode_results if r[0] == self.current_level]

        # Need enough data to make a decision
        if len(current_level_results) < self.min_solved_at_level:
            return

        # Calculate success rate at the current level
        success_count = sum(1 for _, is_solved, _ in current_level_results if is_solved)
        success_rate = success_count / len(current_level_results)

        # Log the current performance
        logger.info(
            f"Curriculum performance: Level {self.current_level}, "
            f"Success rate: {success_rate:.2f} ({success_count}/{len(current_level_results)})"
        )

        # Check if we should advance
        if (success_rate >= self.success_threshold and
                success_count >= self.min_solved_at_level and
                self.current_level < self.max_level):

            self.current_level += 1
            logger.info(
                f"Advancing to curriculum level {self.current_level}: "
                f"{self.levels[self.current_level].description}"
            )

            # Reset episode results after advancing
            self.episode_results = []

    def set_level(self, level: int) -> None:
        """
        Manually set the curriculum level

        Args:
            level: The new curriculum level (must be between 1 and max_level)
        """
        if level < 1 or level > self.max_level:
            logger.warning(
                f"Invalid curriculum level {level}. Must be between 1 and {self.max_level}. "
                f"Keeping current level {self.current_level}."
            )
            return

        self.current_level = level
        logger.info(f"Manually set curriculum to level {level}: {self.levels[level].description}")

        # Reset episode results after a manual level change
        self.episode_results = []

    def get_level_metrics(self) -> Dict[str, Any]:
        """Get metrics for the current curriculum level"""
        current_level_results = [r for r in self.episode_results if r[0] == self.current_level]

        if not current_level_results:
            return {
                "curriculum_level": self.current_level,
                "curriculum_description": self.levels[self.current_level].description,
                "level_success_rate": 0.0,
                "level_episodes": 0,
                "level_solved_count": 0,
                "level_avg_steps": 0.0,
                "progress_to_next_level": 0.0
            }

        success_count = sum(1 for _, is_solved, _ in current_level_results if is_solved)
        success_rate = success_count / len(current_level_results)

        # Calculate average steps for solved episodes
        solved_episodes = [(level, solved, steps) for level, solved, steps in current_level_results if solved]
        avg_steps = sum(steps for _, _, steps in solved_episodes) / max(1, len(solved_episodes))

        # Calculate progress to the next level (0.0 to 1.0)
        if self.current_level >= self.max_level:
            progress_to_next = 1.0
        else:
            progress_threshold = self.success_threshold * self.min_solved_at_level
            current_progress = success_rate * len(current_level_results)
            progress_to_next = min(1.0, current_progress / progress_threshold)

        return {
            "curriculum_level": self.current_level,
            "curriculum_description": self.levels[self.current_level].description,
            "level_success_rate": success_rate,
            "level_episodes": len(current_level_results),
            "level_solved_count": success_count,
            "level_avg_steps": avg_steps,
            "progress_to_next_level": progress_to_next
        }


# Example usage
if __name__ == "__main__":
    # Set up logging
    logging.basicConfig(level=logging.INFO)

    # Create curriculum manager
    curriculum = RubiksCubeCurriculum(
        starting_level=1,
        max_level=5,
        auto_progress=True,
        success_threshold=0.7,
        advancement_window_size=50,
        min_solved_at_level=25
    )

    # Simulate some episodes
    # In a real setup, these results would come from actual cube-solving episodes
    for _ in range(40):
        # Simulate success with 80% probability for level 1
        is_solved = random.random() < 0.8
        steps = random.randint(5, 15)
        curriculum.record_episode_result(1, is_solved, steps)

    # Print metrics
    print(curriculum.get_level_metrics())

    # Current level should now be 2 if enough episodes were solved
    print(f"Current level: {curriculum.current_level}")

    # Manually set to level 3
    curriculum.set_level(3)
    print(f"After manual set, current level: {curriculum.current_level}")
1655
environments/hack0/rubiks_cube_environment.py
Normal file
File diff suppressed because it is too large
345
environments/hack0/rubiks_cube_logic.py
Normal file
@@ -0,0 +1,345 @@
#!/usr/bin/env python3
"""
Rubik's Cube logic extracted from the environment for independent testing
"""

# Define the face colors for visualization
UP_COLOR = 'W'     # White
DOWN_COLOR = 'Y'   # Yellow
RIGHT_COLOR = 'R'  # Red
LEFT_COLOR = 'O'   # Orange
FRONT_COLOR = 'G'  # Green
BACK_COLOR = 'B'   # Blue


class Cube:
    """
    A Rubik's cube implementation with accurate move handling.
    """
    def __init__(self):
        # Initialize a solved cube
        self.reset()

    def reset(self):
        """Reset the cube to the solved state"""
        # Initialize the cube as a 3D array [face][row][col]
        # Faces: 0=UP, 1=DOWN, 2=LEFT, 3=RIGHT, 4=FRONT, 5=BACK
        self.cube = [
            [[UP_COLOR for _ in range(3)] for _ in range(3)],     # UP
            [[DOWN_COLOR for _ in range(3)] for _ in range(3)],   # DOWN
            [[LEFT_COLOR for _ in range(3)] for _ in range(3)],   # LEFT
            [[RIGHT_COLOR for _ in range(3)] for _ in range(3)],  # RIGHT
            [[FRONT_COLOR for _ in range(3)] for _ in range(3)],  # FRONT
            [[BACK_COLOR for _ in range(3)] for _ in range(3)]    # BACK
        ]

    def is_solved(self) -> bool:
        """Check if the cube is solved"""
        for face in self.cube:
            center_color = face[1][1]  # Center color never changes
            for row in face:
                for color in row:
                    if color != center_color:
                        return False
        return True

    def count_solved_cubies(self) -> float:
        """
        Count the number of stickers in their correct position.
        Returns a normalized score between 0 and 1.
        """
        # Create a solved reference cube
        reference = Cube()

        # Count matching stickers
        total_stickers = 6 * 9  # 6 faces, 9 stickers per face
        match_count = 0

        for face_idx in range(6):
            for i in range(3):
                for j in range(3):
                    if self.cube[face_idx][i][j] == reference.cube[face_idx][i][j]:
                        match_count += 1

        return match_count / total_stickers

    def rotate(self, move: str):
        """
        Perform a move on the cube using standard notation.
        U, D, L, R, F, B are clockwise rotations of the respective faces;
        U', D', L', R', F', B' are counterclockwise rotations;
        U2, D2, L2, R2, F2, B2 are double (180°) rotations.
        """
        # Map move notation to face index
        face_map = {
            'U': 0, 'D': 1, 'L': 2, 'R': 3, 'F': 4, 'B': 5
        }

        # Parse the move
        if len(move) == 0:
            raise ValueError("Empty move")

        face = move[0]
        if face not in face_map:
            raise ValueError(f"Invalid face: {face}")

        face_idx = face_map[face]

        # Handle rotation direction
        if len(move) == 1:
            # Clockwise rotation
            count = 1
        elif len(move) == 2:
            if move[1] == "'":
                # Counterclockwise rotation (three clockwise quarter turns)
                count = 3
            elif move[1] == "2":
                # Double rotation
                count = 2
            else:
                raise ValueError(f"Invalid move modifier: {move[1]}")
        else:
            raise ValueError(f"Invalid move format: {move}")

        # Apply the rotation 'count' times
        for _ in range(count):
            self._rotate_face_clockwise(face_idx)
            self._rotate_adjacent_faces(face_idx)

    def _rotate_face_clockwise(self, face_idx: int):
        """Rotate a face clockwise"""
        face = self.cube[face_idx]
        new_face = [[None for _ in range(3)] for _ in range(3)]

        # Copy with a 90-degree clockwise rotation
        for i in range(3):
            for j in range(3):
                new_face[j][2-i] = face[i][j]

        self.cube[face_idx] = new_face

    def _rotate_adjacent_faces(self, face_idx: int):
        """Rotate the appropriate edges on adjacent faces"""
        if face_idx == 0:  # UP face
            # Rotate the top edges of FRONT, RIGHT, BACK, LEFT
            temp = self.cube[4][0][:]             # Save FRONT top edge
            self.cube[4][0] = self.cube[2][0][:]  # FRONT <- LEFT
            self.cube[2][0] = self.cube[5][0][:]  # LEFT <- BACK
            self.cube[5][0] = self.cube[3][0][:]  # BACK <- RIGHT
            self.cube[3][0] = temp                # RIGHT <- FRONT

        elif face_idx == 1:  # DOWN face
            # Rotate the bottom edges of FRONT, LEFT, BACK, RIGHT
            temp = self.cube[4][2][:]             # Save FRONT bottom edge
            self.cube[4][2] = self.cube[3][2][:]  # FRONT <- RIGHT
            self.cube[3][2] = self.cube[5][2][:]  # RIGHT <- BACK
            self.cube[5][2] = self.cube[2][2][:]  # BACK <- LEFT
            self.cube[2][2] = temp                # LEFT <- FRONT

        elif face_idx == 2:  # LEFT face
            # Rotate the left edges of UP, FRONT, DOWN, BACK
            # Need to extract and set columns, not rows
            temp = [self.cube[0][i][0] for i in range(3)]  # Save UP left column

            # UP left <- BACK right (reversed)
            for i in range(3):
                self.cube[0][i][0] = self.cube[5][2-i][2]

            # BACK right <- DOWN left (reversed)
            for i in range(3):
                self.cube[5][i][2] = self.cube[1][2-i][0]

            # DOWN left <- FRONT left
            for i in range(3):
                self.cube[1][i][0] = self.cube[4][i][0]

            # FRONT left <- UP left
            for i in range(3):
                self.cube[4][i][0] = temp[i]

        elif face_idx == 3:  # RIGHT face
            # Rotate the right edges of UP, BACK, DOWN, FRONT
            temp = [self.cube[0][i][2] for i in range(3)]  # Save UP right column

            # UP right <- FRONT right
            for i in range(3):
                self.cube[0][i][2] = self.cube[4][i][2]

            # FRONT right <- DOWN right
            for i in range(3):
                self.cube[4][i][2] = self.cube[1][i][2]

            # DOWN right <- BACK left (reversed)
            for i in range(3):
                self.cube[1][i][2] = self.cube[5][2-i][0]

            # BACK left <- UP right (reversed)
            for i in range(3):
                self.cube[5][i][0] = temp[2-i]

        elif face_idx == 4:  # FRONT face
            # Rotate the edges of UP bottom, RIGHT left, DOWN top, LEFT right
            # UP bottom row
            temp = self.cube[0][2][:]

            # UP bottom <- LEFT right (rotated)
            for i in range(3):
                self.cube[0][2][i] = self.cube[2][2-i][2]

            # LEFT right <- DOWN top (rotated)
            for i in range(3):
                self.cube[2][i][2] = self.cube[1][0][i]

            # DOWN top <- RIGHT left (rotated)
            for i in range(3):
                self.cube[1][0][i] = self.cube[3][2-i][0]

            # RIGHT left <- UP bottom (rotated)
            for i in range(3):
                self.cube[3][i][0] = temp[i]

        elif face_idx == 5:  # BACK face
            # Rotate the edges of UP top, LEFT left, DOWN bottom, RIGHT right
            # UP top row
            temp = self.cube[0][0][:]

            # UP top <- RIGHT right (rotated)
            for i in range(3):
                self.cube[0][0][i] = self.cube[3][2-i][2]

            # RIGHT right <- DOWN bottom (rotated)
            for i in range(3):
                self.cube[3][i][2] = self.cube[1][2][i]

            # DOWN bottom <- LEFT left (rotated)
            for i in range(3):
                self.cube[1][2][i] = self.cube[2][2-i][0]

            # LEFT left <- UP top (rotated)
            for i in range(3):
                self.cube[2][i][0] = temp[i]

    def __str__(self) -> str:
        """Convert the cube to a string representation"""
        face_names = ['U', 'D', 'L', 'R', 'F', 'B']
        result = []

        for i, face in enumerate(self.cube):
            result.append(f"{face_names[i]}: {' '.join(face[0])}")
            result.append(f"   {' '.join(face[1])}")
            result.append(f"   {' '.join(face[2])}")

        return "\n".join(result)


def test_basic_moves():
    """Test basic moves and their inverses"""
    print("=== TESTING BASIC MOVES ===")

    # Test each basic move and its inverse
    for move, inverse in [
        ("R", "R'"), ("L", "L'"), ("U", "U'"),
        ("D", "D'"), ("F", "F'"), ("B", "B'")
    ]:
        cube = Cube()
        cube.rotate(move)
        cube.rotate(inverse)
        solved = cube.is_solved()

        print(f"Move {move} followed by {inverse}: {'PASS' if solved else 'FAIL'}")

        if not solved:
            print("Final cube state:")
            print(str(cube))


def test_double_moves():
    """Test double (180°) moves"""
    print("\n=== TESTING DOUBLE MOVES ===")

    # Test each double move applied twice
    for move in ["U2", "D2", "L2", "R2", "F2", "B2"]:
        cube = Cube()
        cube.rotate(move)
        cube.rotate(move)
        solved = cube.is_solved()

        print(f"Double move {move} applied twice: {'PASS' if solved else 'FAIL'}")

        if not solved:
            print("Final cube state:")
            print(str(cube))


def test_complex_algorithms():
    """Test more complex algorithms"""
    print("\n=== TESTING COMPLEX ALGORITHMS ===")

    # Test algorithms
    algorithms = [
        {
            "name": "Sexy Move (R U R' U') × 6",
            "moves": ["R", "U", "R'", "U'"] * 6,
            "should_solve": True
        },
        {
            "name": "Scramble + Inverse",
            "moves": ["R", "U", "F'", "L", "D2"] + ["D2", "L'", "F", "U'", "R'"],
            "should_solve": True
        },
        {
            "name": "Sune Algorithm (R U R' U R U2 R')",
            "moves": ["R", "U", "R'", "U", "R", "U2", "R'"],
            "should_solve": False
        }
    ]

    for algo in algorithms:
        cube = Cube()
        print(f"\nTesting: {algo['name']}")

        # Apply moves
        for move in algo["moves"]:
            cube.rotate(move)

        # Check result
        is_solved = cube.is_solved()
        expected = algo["should_solve"]

        outcome = "PASS" if is_solved == expected else "FAIL"
        print(f"Result: {outcome} (Expected {'solved' if expected else 'not solved'}, Got {'solved' if is_solved else 'not solved'})")
        if is_solved != expected:
            print("Final cube state:")
            print(str(cube))

        # Show progress percentage if not solved
        if not is_solved:
            progress = cube.count_solved_cubies()
            print(f"Progress toward solution: {progress:.2f}")


def test_scramble_and_count():
    """Test scrambling and counting progress"""
    print("\n=== TESTING SCRAMBLING AND PROGRESS TRACKING ===")

    # Create a cube and apply a fixed scramble sequence
    cube = Cube()
    print("Solved cube:")
    print(str(cube))
    print(f"Is solved: {cube.is_solved()}")
    print(f"Progress: {cube.count_solved_cubies():.2f}")

    # Apply a sequence of moves to scramble
    scramble = ["R", "U", "F", "D", "L", "B'", "R'", "U2"]

    print(f"\nApplying scramble: {' '.join(scramble)}")
    for move in scramble:
        cube.rotate(move)

    print("Scrambled cube:")
    print(str(cube))
    print(f"Is solved: {cube.is_solved()}")
    print(f"Progress: {cube.count_solved_cubies():.2f}")


if __name__ == "__main__":
    test_basic_moves()
    test_double_moves()
    test_complex_algorithms()
    test_scramble_and_count()
6
environments/hack0/rubiks_process_results_32.jsonl
Normal file
File diff suppressed because one or more lines are too long
384
environments/hack0/rubiks_strategies.py
Normal file
@@ -0,0 +1,384 @@
#!/usr/bin/env python3
"""
RubiksCubeStrategies: Library of solving strategies for the Rubik's cube environment

This module provides a collection of solving strategies for the Rubik's cube, along with
explanations and examples for each. These strategies can be used to guide the LLM's
solving approach and provide structured learning.
"""

from typing import Dict, List, Optional


class SolvingStrategy:
    """Base class for Rubik's cube solving strategies"""

    def __init__(
        self,
        name: str,
        description: str,
        difficulty: int,
        steps: List[str],
        example_algorithms: Optional[List[Dict[str, str]]] = None,
        tips: Optional[List[str]] = None
    ):
        """
        Initialize a solving strategy

        Args:
            name: Strategy name
            description: Detailed description of the strategy
            difficulty: Difficulty level (1-5)
            steps: Ordered list of steps to follow
            example_algorithms: Common algorithms used in this strategy
            tips: Tips for using this strategy effectively
        """
        self.name = name
        self.description = description
        self.difficulty = difficulty
        self.steps = steps
        self.example_algorithms = example_algorithms or []
        self.tips = tips or []

    def get_prompt_section(self) -> str:
        """Get a formatted prompt section for this strategy"""
        prompt = f"""
STRATEGY: {self.name} (Difficulty: {self.difficulty}/5)

DESCRIPTION:
{self.description}

STEPS:
"""
        for i, step in enumerate(self.steps, 1):
            prompt += f"{i}. {step}\n"

        if self.example_algorithms:
            prompt += "\nCOMMON ALGORITHMS:\n"
            for algo in self.example_algorithms:
                prompt += f"- {algo['name']}: {algo['moves']} - {algo['purpose']}\n"

        if self.tips:
            prompt += "\nTIPS:\n"
            for tip in self.tips:
                prompt += f"- {tip}\n"

        return prompt

    def __str__(self) -> str:
        return f"{self.name} (Difficulty: {self.difficulty}/5)"


# Define common strategies
LAYER_BY_LAYER = SolvingStrategy(
    name="Layer-by-Layer Method",
    description=(
        "The beginner-friendly approach that solves the cube one layer at a time. "
        "It's intuitive and requires memorizing only a few algorithms."
    ),
    difficulty=1,
    steps=[
        "Solve the white cross on the top face",
        "Place the white corner pieces to complete the first layer",
        "Solve the middle layer edges",
        "Create a yellow cross on the bottom face",
        "Position the yellow edges correctly",
        "Position the yellow corners correctly",
        "Orient the yellow corners correctly"
    ],
    example_algorithms=[
        {
            "name": "Sexy Move",
            "moves": "R U R' U'",
            "purpose": "Used for placing corners in the first layer"
        },
        {
            "name": "Middle Layer Edge - Left",
            "moves": "U' L' U L U F U' F'",
            "purpose": "Insert an edge piece into the middle layer from the left"
        },
        {
            "name": "Middle Layer Edge - Right",
            "moves": "U R U' R' U' F' U F",
            "purpose": "Insert an edge piece into the middle layer from the right"
        },
        {
            "name": "Orient Yellow Edges",
            "moves": "F R U R' U' F'",
            "purpose": "Create a yellow cross on the last layer"
        }
    ],
    tips=[
        "Always keep the white face on top when solving the first layer",
        "Look ahead to plan edge placement before executing moves",
        "Pay attention to where pieces need to go before applying algorithms",
        "Break down the solution into manageable steps"
    ]
)

CFOP_METHOD = SolvingStrategy(
    name="CFOP Method (Fridrich Method)",
    description=(
        "An advanced method used by speedcubers. CFOP stands for Cross, F2L (First Two Layers), "
        "OLL (Orient Last Layer), and PLL (Permute Last Layer). It's efficient but requires "
        "memorizing many algorithms."
    ),
    difficulty=4,
    steps=[
        "Solve the cross on the bottom face (usually white)",
        "Solve the First Two Layers (F2L) by pairing corners with edges and inserting them",
        "Orient the Last Layer (OLL) to make the top face all one color",
        "Permute the Last Layer (PLL) to arrange all pieces correctly"
    ],
    example_algorithms=[
        {
            "name": "F2L Case 1",
            "moves": "R U R'",
            "purpose": "Basic F2L insertion when corner and edge are paired"
        },
        {
            "name": "F2L Case 2",
            "moves": "y' U' L' U L",
            "purpose": "Basic F2L insertion (mirror of case 1)"
        },
        {
            "name": "Sune",
            "moves": "R U R' U R U2 R'",
            "purpose": "Common OLL algorithm used to orient corners"
        },
        {
            "name": "T Permutation",
            "moves": "R U R' U' R' F R2 U' R' U' R U R' F'",
            "purpose": "PLL algorithm that swaps two corners and two edges"
        }
    ],
    tips=[
        "Practice F2L intuitively before learning algorithms",
        "Solve the cross on the bottom to see the F2L pairs more easily",
        "Learn to recognize F2L cases from different angles",
        "Group PLL algorithms by similar patterns to make memorization easier"
    ]
)

ROUX_METHOD = SolvingStrategy(
    name="Roux Method",
    description=(
        "A method focused on building blocks and using M-slice moves. It's very efficient "
        "and requires memorizing fewer algorithms than CFOP, but demands good spatial intuition."
    ),
    difficulty=3,
    steps=[
        "Build a 1x2x3 block on the left side (First Block)",
        "Build a 1x2x3 block on the right side (Second Block)",
        "Orient and permute the corners of the top layer (CMLL)",
        "Orient the edges of the last layer and permute the M-slice (L6E)"
    ],
    example_algorithms=[
        {
            "name": "CMLL - O Case",
            "moves": "R U R' F' R U R' U' R' F R2 U' R'",
            "purpose": "Orient and permute corners when all corners are oriented incorrectly"
        },
        {
            "name": "EO - Arrow",
            "moves": "M U M'",
|
||||
"purpose": "Edge orientation during L6E phase"
|
||||
},
|
||||
{
|
||||
"name": "UL/UR Edge Swap",
|
||||
"moves": "M' U2 M U2",
|
||||
"purpose": "Swap the UL and UR edges during L6E phase"
|
||||
}
|
||||
],
|
||||
tips=[
|
||||
"Focus on block-building efficiency for the first two blocks",
|
||||
"Use inspection time to plan the first block completely",
|
||||
"Practice M-slice moves to develop speed and accuracy",
|
||||
"Learn to recognize CMLL cases quickly to reduce pauses"
|
||||
]
|
||||
)
|
||||
|
||||
ZZ_METHOD = SolvingStrategy(
|
||||
name="ZZ Method",
|
||||
description=(
|
||||
"A method that focuses on solving edges early to enable rotationless solving. "
|
||||
"It orients all edges first, then solves the cube without F or B moves."
|
||||
),
|
||||
difficulty=3,
|
||||
steps=[
|
||||
"Orient all edges (EOLine) while placing DF and DB edges",
|
||||
"Build the F2L on the left and right sides (ZZF2L)",
|
||||
"Orient the corners of the last layer (OCLL)",
|
||||
"Permute the last layer (PLL)"
|
||||
],
|
||||
example_algorithms=[
|
||||
{
|
||||
"name": "EOLine Example",
|
||||
"moves": "F L' U B' D'",
|
||||
"purpose": "Orient all edges and place DF and DB edges"
|
||||
},
|
||||
{
|
||||
"name": "ZZF2L Pair",
|
||||
"moves": "U L U' L'",
|
||||
"purpose": "Insert corner-edge pair during F2L"
|
||||
},
|
||||
{
|
||||
"name": "OCLL - Sune",
|
||||
"moves": "R U R' U R U2 R'",
|
||||
"purpose": "Orient three corners in the last layer"
|
||||
}
|
||||
],
|
||||
tips=[
|
||||
"Practice EOLine recognition to improve planning during inspection",
|
||||
"Take advantage of the rotationless solving after EOLine",
|
||||
"Use block-building techniques similar to Petrus for F2L",
|
||||
"Learn to recognize edge orientation quickly"
|
||||
]
|
||||
)
|
||||
|
||||
BEGINNER_METHOD = SolvingStrategy(
|
||||
name="Beginner Method",
|
||||
description=(
|
||||
"The simplest approach for complete beginners. Uses very intuitive steps and minimal algorithm "
|
||||
"memorization, focusing on understanding the cube's mechanics rather than speed."
|
||||
),
|
||||
difficulty=1,
|
||||
steps=[
|
||||
"Solve the white cross",
|
||||
"Solve the white corners one by one",
|
||||
"Solve the middle layer edges one by one",
|
||||
"Make a yellow cross on the top",
|
||||
"Solve the yellow edges around the top",
|
||||
"Position the yellow corners",
|
||||
"Orient the yellow corners"
|
||||
],
|
||||
example_algorithms=[
|
||||
{
|
||||
"name": "White Corner Insertion",
|
||||
"moves": "R U R' U'",
|
||||
"purpose": "Move a white corner piece into position"
|
||||
},
|
||||
{
|
||||
"name": "Edge Insertion",
|
||||
"moves": "U R U' R' U' F' U F",
|
||||
"purpose": "Insert a middle layer edge piece"
|
||||
},
|
||||
{
|
||||
"name": "Yellow Cross",
|
||||
"moves": "F R U R' U' F'",
|
||||
"purpose": "Form a yellow cross on the top face"
|
||||
}
|
||||
],
|
||||
tips=[
|
||||
"Focus on understanding what each move does rather than memorizing algorithms",
|
||||
"Take your time and think about where pieces need to go",
|
||||
"Keep track of important pieces while executing algorithms",
|
||||
"Practice the fundamentals until they become natural"
|
||||
]
|
||||
)
|
||||
|
||||
CORNERS_FIRST = SolvingStrategy(
|
||||
name="Corners-First Method",
|
||||
description=(
|
||||
"Solve all corner pieces first, then solve the edges. This approach is less common "
|
||||
"but offers a different perspective on solving the cube."
|
||||
),
|
||||
difficulty=2,
|
||||
steps=[
|
||||
"Orient the corners to get white and yellow on top and bottom",
|
||||
"Permute the corners to their correct positions",
|
||||
"Solve the middle layer edges",
|
||||
"Solve the last layer edges"
|
||||
],
|
||||
example_algorithms=[
|
||||
{
|
||||
"name": "Corner Orientation",
|
||||
"moves": "R' D' R D",
|
||||
"purpose": "Orient a corner in place"
|
||||
},
|
||||
{
|
||||
"name": "Corner 3-Cycle",
|
||||
"moves": "R U' R' D2 R U R' D2",
|
||||
"purpose": "Cycle three corners"
|
||||
},
|
||||
{
|
||||
"name": "Edge 3-Cycle",
|
||||
"moves": "L' R U2 L R' F' L' R U2 L R' F",
|
||||
"purpose": "Cycle three edges"
|
||||
}
|
||||
],
|
||||
tips=[
|
||||
"Use commutators for corner manipulation",
|
||||
"Pay attention to corner orientation as it affects the later steps",
|
||||
"Learn to visualize corner pieces and their correct positions",
|
||||
"Practice edge insertion techniques for the final steps"
|
||||
]
|
||||
)
|
||||
|
||||
def get_strategy_by_name(name: str) -> Optional[SolvingStrategy]:
|
||||
"""Get a strategy by name"""
|
||||
all_strategies = [
|
||||
LAYER_BY_LAYER,
|
||||
CFOP_METHOD,
|
||||
ROUX_METHOD,
|
||||
ZZ_METHOD,
|
||||
BEGINNER_METHOD,
|
||||
CORNERS_FIRST
|
||||
]
|
||||
|
||||
for strategy in all_strategies:
|
||||
if strategy.name.lower() == name.lower():
|
||||
return strategy
|
||||
|
||||
return None
|
||||
|
||||
def get_strategy_by_difficulty(difficulty: int) -> List[SolvingStrategy]:
|
||||
"""Get all strategies at a specific difficulty level"""
|
||||
all_strategies = [
|
||||
LAYER_BY_LAYER,
|
||||
CFOP_METHOD,
|
||||
ROUX_METHOD,
|
||||
ZZ_METHOD,
|
||||
BEGINNER_METHOD,
|
||||
CORNERS_FIRST
|
||||
]
|
||||
|
||||
return [strategy for strategy in all_strategies if strategy.difficulty == difficulty]
|
||||
|
||||
def get_all_strategies() -> List[SolvingStrategy]:
|
||||
"""Get all available strategies"""
|
||||
return [
|
||||
LAYER_BY_LAYER,
|
||||
CFOP_METHOD,
|
||||
ROUX_METHOD,
|
||||
ZZ_METHOD,
|
||||
BEGINNER_METHOD,
|
||||
CORNERS_FIRST
|
||||
]
|
||||
|
||||
def get_strategy_prompt_for_level(level: int) -> str:
|
||||
"""Get a formatted prompt with strategies appropriate for the curriculum level"""
|
||||
if level <= 2:
|
||||
# Beginner levels - show only simpler strategies
|
||||
strategies = [BEGINNER_METHOD, LAYER_BY_LAYER]
|
||||
elif level == 3:
|
||||
# Intermediate level
|
||||
strategies = [LAYER_BY_LAYER, CORNERS_FIRST, ROUX_METHOD]
|
||||
else:
|
||||
# Advanced levels - show all strategies
|
||||
strategies = get_all_strategies()
|
||||
|
||||
prompt = "# RUBIK'S CUBE SOLVING STRATEGIES\n\nBelow are strategies you can use to solve the cube:\n\n"
|
||||
|
||||
for strategy in strategies:
|
||||
prompt += strategy.get_prompt_section() + "\n\n"
|
||||
|
||||
prompt += """
|
||||
When solving the cube, you can use any of these strategies. Make sure to:
|
||||
1. Choose a strategy that fits your understanding and the current cube state
|
||||
2. Explain your thought process using the <think> tags
|
||||
3. Follow the steps of your chosen strategy systematically
|
||||
4. Apply the appropriate algorithms for your current situation
|
||||
5. Track your progress toward the solution
|
||||
"""
|
||||
|
||||
return prompt
|
||||
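The lookup helpers above do a case-insensitive name match and a difficulty filter over a fixed list. A minimal self-contained sketch of that pattern (the `Strategy` stand-in class and `CATALOG` names are illustrative, not the module's real `SolvingStrategy`):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Strategy:
    """Minimal stand-in for SolvingStrategy (illustration only)."""
    name: str
    difficulty: int


CATALOG = [
    Strategy("Beginner Method", 1),
    Strategy("Roux Method", 3),
    Strategy("CFOP Method (Fridrich Method)", 4),
]


def by_name(name: str) -> Optional[Strategy]:
    # Case-insensitive exact match, as in get_strategy_by_name.
    for s in CATALOG:
        if s.name.lower() == name.lower():
            return s
    return None


def by_difficulty(level: int) -> List[Strategy]:
    # Filter by difficulty, as in get_strategy_by_difficulty.
    return [s for s in CATALOG if s.difficulty == level]


print(by_name("roux method").difficulty)   # → 3
print([s.name for s in by_difficulty(1)])  # → ['Beginner Method']
```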
173
environments/hack0/rubiks_token_rewards.py
Normal file

@@ -0,0 +1,173 @@
#!/usr/bin/env python3
"""
RubiksCubeTokenRewards: Token-level reward utilities for Rubik's Cube environment

This module provides functions for calculating token-level rewards, which are
important for fine-grained RL training signals that help the model understand
which tokens in its response contribute to success or failure.
"""

import re
from typing import Any, Dict, List, Optional, Tuple

import numpy as np


def calculate_token_level_rewards(
    response_text: str,
    is_valid_move: bool,
    parsed_move: Optional[str],
    reward: float,
    token_ids: List[int],
    scale_factor: float = 0.1
) -> List[float]:
    """
    Calculate token-level rewards based on the response quality

    Args:
        response_text: Full response text from the LLM
        is_valid_move: Whether the parsed move was valid
        parsed_move: The parsed move if any
        reward: The overall reward for the response
        token_ids: List of token IDs in the response
        scale_factor: Scale factor for token rewards

    Returns:
        A list of token-level rewards with the same length as token_ids
    """
    # Initialize with neutral rewards
    token_rewards = [0.0] * len(token_ids)

    if len(token_ids) == 0:
        return token_rewards

    # Extract key parts of the response
    thinking_match = re.search(r"<think>(.*?)</think>", response_text, re.DOTALL)
    tool_call_match = re.search(r"<tool_call>(.*?)</tool_call>", response_text, re.DOTALL)

    # Find the indices of key tokens
    thinking_start_idx = response_text.find("<think>")
    thinking_end_idx = response_text.find("</think>")
    tool_call_start_idx = response_text.find("<tool_call>")
    tool_call_end_idx = response_text.find("</tool_call>")

    # Determine approximate character-to-token ratio
    chars_per_token = len(response_text) / len(token_ids)

    # Flag for quality of thinking
    has_good_thinking = False
    if thinking_match and len(thinking_match.group(1).strip()) > 50:
        has_good_thinking = True

    # Process rewards based on response quality
    if is_valid_move and has_good_thinking:
        # Good response with both thinking and valid move
        # Reward distribution: ~60% to tool call, ~40% to thinking
        base_reward = reward * scale_factor

        # Distribute rewards
        for i in range(len(token_ids)):
            # Estimate the character position this token represents
            char_pos = int(i * chars_per_token)

            if thinking_start_idx <= char_pos <= thinking_end_idx:
                # Token is part of thinking section
                token_rewards[i] = base_reward * 0.4
            elif tool_call_start_idx <= char_pos <= tool_call_end_idx:
                # Token is part of tool call section
                token_rewards[i] = base_reward * 0.6
            else:
                # Token is part of other sections
                token_rewards[i] = base_reward * 0.1

    elif is_valid_move and not has_good_thinking:
        # Valid move but poor thinking
        base_reward = reward * scale_factor * 0.7  # Reduced overall reward

        for i in range(len(token_ids)):
            char_pos = int(i * chars_per_token)

            if tool_call_start_idx <= char_pos <= tool_call_end_idx:
                # Token is part of tool call section - still good
                token_rewards[i] = base_reward * 0.8
            else:
                # Token is part of other sections - minimal reward
                token_rewards[i] = base_reward * 0.2

    elif not is_valid_move and has_good_thinking:
        # Good thinking but invalid move
        base_reward = reward * scale_factor * 0.5  # Significantly reduced

        for i in range(len(token_ids)):
            char_pos = int(i * chars_per_token)

            if thinking_start_idx <= char_pos <= thinking_end_idx:
                # Token is part of thinking section - somewhat good
                token_rewards[i] = base_reward * 0.6
            elif tool_call_start_idx <= char_pos <= tool_call_end_idx:
                # Token is part of tool call section - problematic
                token_rewards[i] = base_reward * 0.1
            else:
                # Token is part of other sections
                token_rewards[i] = base_reward * 0.3
    else:
        # Poor response overall
        base_reward = reward * scale_factor * 0.3  # Minimal reward

        # Distribute minimal rewards evenly
        for i in range(len(token_ids)):
            token_rewards[i] = base_reward

    # Special handling for move-related tokens when there is a valid move
    if is_valid_move and parsed_move:
        # Try to find the specific tokens that represent the move
        move_pattern = re.escape(parsed_move)
        move_matches = list(re.finditer(move_pattern, response_text))

        for match in move_matches:
            move_start_idx = match.start()
            move_end_idx = match.end()

            # Estimate corresponding token indices
            move_start_token = int(move_start_idx / chars_per_token)
            move_end_token = int(move_end_idx / chars_per_token) + 1

            # Ensure indices are within bounds
            move_start_token = max(0, min(move_start_token, len(token_ids) - 1))
            move_end_token = max(0, min(move_end_token, len(token_ids)))

            # Boost rewards for tokens that directly encode the move
            for i in range(move_start_token, move_end_token):
                token_rewards[i] = base_reward * 1.5  # Higher reward for the actual move

    return token_rewards


def calculate_advantage_token_weights(token_rewards: List[List[float]]) -> List[List[float]]:
    """
    Calculate token weights for advantage computation

    Args:
        token_rewards: List of token-level rewards for each alternative response

    Returns:
        List of normalized token weights for each alternative
    """
    # Create a copy to avoid modifying the input
    token_weights = [rewards.copy() for rewards in token_rewards]

    # For each alternative
    for i in range(len(token_weights)):
        # Get min and max rewards for this alternative
        min_reward = min(token_weights[i]) if token_weights[i] else 0.0
        max_reward = max(token_weights[i]) if token_weights[i] else 0.0
        reward_range = max_reward - min_reward

        # Normalize to [0.5, 1.0] range to ensure all tokens get some weight
        if reward_range > 0:
            for j in range(len(token_weights[i])):
                normalized = 0.5 + 0.5 * (token_weights[i][j] - min_reward) / reward_range
                token_weights[i][j] = normalized
        else:
            # If all rewards are the same, use uniform weights
            for j in range(len(token_weights[i])):
                token_weights[i][j] = 1.0

    return token_weights
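The normalization in `calculate_advantage_token_weights` maps each alternative's rewards into `[0.5, 1.0]` so every token keeps some weight. A self-contained sketch of that mapping for a single reward list (the function name `normalize_weights` is mine, for illustration):

```python
from typing import List


def normalize_weights(rewards: List[float]) -> List[float]:
    """Map per-token rewards into [0.5, 1.0], as in
    calculate_advantage_token_weights; uniform 1.0 when all rewards are equal."""
    if not rewards:
        return []
    lo, hi = min(rewards), max(rewards)
    if hi == lo:
        return [1.0] * len(rewards)
    return [0.5 + 0.5 * (r - lo) / (hi - lo) for r in rewards]


print(normalize_weights([0.0, 1.0, 2.0]))  # → [0.5, 0.75, 1.0]
print(normalize_weights([0.3, 0.3]))       # → [1.0, 1.0]
```

The floor of 0.5 keeps gradient signal on every token while still ranking them; a floor of 0 would zero out the worst tokens entirely.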
104
environments/hack0/test_rubiks_cube.py
Normal file

@@ -0,0 +1,104 @@
#!/usr/bin/env python3
"""
Test script for the Rubik's Cube environment
"""

import asyncio
import random

from simple_cube import Cube

from rubiks_cube_environment import RubiksCubeEnv, RubiksCubeEnvConfig
from rubiks_cube_visualizer import save_cube_visualization
from atroposlib.envs.server_handling.server_manager import APIServerConfig


async def test_cube_visualization():
    """Test the cube visualization functionality"""
    # Create a cube
    cube = Cube()

    # Scramble it with some random moves
    moves = ["U", "D", "L", "R", "F", "B",
             "U'", "D'", "L'", "R'", "F'", "B'",
             "U2", "D2", "L2", "R2", "F2", "B2"]

    move_history = []
    for _ in range(5):
        move = random.choice(moves)
        move_history.append(move)
        cube.rotate(move)

    # Visualize the scrambled cube
    cube_state = str(cube)
    html_path = save_cube_visualization(
        cube_state,
        move_history,
        "test_scrambled_cube.html"
    )

    print(f"Scrambled cube visualization saved to {html_path}")
    print(f"Moves applied: {move_history}")
    print(f"Is solved: {cube.is_solved()}")


async def test_environment():
    """Test the basic functionality of the environment"""
    # Create the environment configuration
    config = RubiksCubeEnvConfig(
        tokenizer_name="gpt2",  # Use a simple tokenizer for testing
        group_size=2,  # Small group size for testing
        use_wandb=False,
        max_steps=5,
        scramble_moves=3,
        debug_mode=True,
    )

    # Create server configuration
    server_configs = [
        APIServerConfig(
            model_name="gpt2",
            base_url="http://localhost:9004/v1",
            api_key="x",
        )
    ]

    # Create the environment
    env = RubiksCubeEnv(config, server_configs, slurm=False, testing=True)

    # Test creating an episode
    seed = 12345
    episode = env._get_or_create_episode(seed)

    # Print initial state
    print(f"Initial cube state (seed {seed}):")
    print(episode.get_cube_state_visualization())

    # Test visualization
    html_path = save_cube_visualization(
        episode.get_cube_state_visualization(),
        [],
        "test_initial_cube.html"
    )
    print(f"Initial cube visualization saved to {html_path}")

    # Test applying moves
    test_moves = ["U", "R", "F'"]
    for move in test_moves:
        success = episode.apply_move(move)
        print(f"Applied move {move}: {'Success' if success else 'Failed'}")

    # Check if solved
    print(f"Is solved: {episode.is_solved()}")

    # Test final state visualization
    html_path = save_cube_visualization(
        episode.get_cube_state_visualization(),
        episode.actions,
        "test_after_moves_cube.html"
    )
    print(f"Final cube visualization saved to {html_path}")


if __name__ == "__main__":
    # Run the tests
    print("Running Rubik's Cube environment tests...")
    asyncio.run(test_cube_visualization())
    asyncio.run(test_environment())
    print("Tests completed.")
120
environments/hack0/test_rubiks_environment.py
Normal file

@@ -0,0 +1,120 @@
#!/usr/bin/env python3
"""
Test script for the Rubik's Cube environment
"""

# Import Cube class directly from rubiks_cube_environment.py
from rubiks_cube_environment import Cube


def test_basic_moves():
    """Test basic moves and their inverses"""
    print("=== TESTING BASIC MOVES ===")

    # Test each basic move and its inverse
    for move, inverse in [
        ("R", "R'"), ("L", "L'"), ("U", "U'"),
        ("D", "D'"), ("F", "F'"), ("B", "B'")
    ]:
        cube = Cube()
        cube.rotate(move)
        cube.rotate(inverse)
        solved = cube.is_solved()

        print(f"Move {move} followed by {inverse}: {'PASS' if solved else 'FAIL'}")

        if not solved:
            print("Final cube state:")
            print(str(cube))


def test_double_moves():
    """Test double (180°) moves"""
    print("\n=== TESTING DOUBLE MOVES ===")

    # Test each double move applied twice
    for move in ["U2", "D2", "L2", "R2", "F2", "B2"]:
        cube = Cube()
        cube.rotate(move)
        cube.rotate(move)
        solved = cube.is_solved()

        print(f"Double move {move} applied twice: {'PASS' if solved else 'FAIL'}")

        if not solved:
            print("Final cube state:")
            print(str(cube))


def test_complex_algorithms():
    """Test more complex algorithms"""
    print("\n=== TESTING COMPLEX ALGORITHMS ===")

    # Test algorithms
    algorithms = [
        {
            "name": "Sexy Move (R U R' U') × 6",
            "moves": ["R", "U", "R'", "U'"] * 6,
            "should_solve": True
        },
        {
            "name": "Scramble + Inverse",
            "moves": ["R", "U", "F'", "L", "D2"] + ["D2", "L'", "F", "U'", "R'"],
            "should_solve": True
        },
        {
            "name": "Sune Algorithm (R U R' U R U2 R')",
            "moves": ["R", "U", "R'", "U", "R", "U2", "R'"],
            "should_solve": False
        }
    ]

    for algo in algorithms:
        cube = Cube()
        print(f"\nTesting: {algo['name']}")

        # Apply moves
        for move in algo["moves"]:
            cube.rotate(move)

        # Check result
        is_solved = cube.is_solved()
        expected = algo["should_solve"]

        if is_solved == expected:
            print(f"Result: PASS (Expected {'solved' if expected else 'not solved'}, "
                  f"Got {'solved' if is_solved else 'not solved'})")
        else:
            print(f"Result: FAIL (Expected {'solved' if expected else 'not solved'}, "
                  f"Got {'solved' if is_solved else 'not solved'})")
            print("Final cube state:")
            print(str(cube))

        # Show progress percentage if not solved
        if not is_solved:
            progress = cube.count_solved_cubies()
            print(f"Progress toward solution: {progress:.2f}")


def test_scramble_and_count():
    """Test scrambling and counting progress"""
    print("\n=== TESTING SCRAMBLING AND PROGRESS TRACKING ===")

    # Create a cube and apply random-like scramble
    cube = Cube()
    print("Solved cube:")
    print(str(cube))
    print(f"Is solved: {cube.is_solved()}")
    print(f"Progress: {cube.count_solved_cubies():.2f}")

    # Apply a sequence of moves to scramble
    scramble = ["R", "U", "F", "D", "L", "B'", "R'", "U2"]

    print(f"\nApplying scramble: {' '.join(scramble)}")
    for move in scramble:
        cube.rotate(move)

    print("Scrambled cube:")
    print(str(cube))
    print(f"Is solved: {cube.is_solved()}")
    print(f"Progress: {cube.count_solved_cubies():.2f}")


if __name__ == "__main__":
    test_basic_moves()
    test_double_moves()
    test_complex_algorithms()
    test_scramble_and_count()
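The move/inverse tests above rest on a simple invariant: a face turn is a permutation of sticker positions, and composing it with its inverse permutation restores the original state. A toy sketch of that invariant (a 4-cycle stands in for one face's stickers; the real `Cube` tracks 54):

```python
from typing import List


def apply(perm: List[int], state: List[str]) -> List[str]:
    # Position i of the result takes the sticker from position perm[i].
    return [state[i] for i in perm]


def inverse(perm: List[int]) -> List[int]:
    # Invert the permutation: if perm sends src -> dst, inverse sends dst -> src.
    inv = [0] * len(perm)
    for dst, src in enumerate(perm):
        inv[src] = dst
    return inv


R = [1, 2, 3, 0]  # toy "R" move: a cyclic shift of four stickers
state = ["a", "b", "c", "d"]

turned = apply(R, state)
restored = apply(inverse(R), turned)
print(restored == state)  # → True
```

The same reasoning covers `test_double_moves`: a half turn has order 2 as a permutation, so applying it twice is the identity.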
464
environments/hack0/train_rubiks_model.py
Executable file

@@ -0,0 +1,464 @@
#!/usr/bin/env python3
"""
Train a model to solve Rubik's cube using reinforcement learning on collected data.
Based on the example GRPO trainer with modifications for pre-collected data.
"""

import argparse
import json
import logging
import os
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import transformers
import wandb
import yaml
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class TrainerConfig:
    """Configuration for the trainer."""

    # Model configuration
    model_name: str
    learning_rate: float
    batch_size: int
    gradient_accumulation_steps: int
    sequence_length: int
    warmup_steps: int

    # Training configuration
    total_steps: int
    eval_every: int
    save_every: int
    checkpoint_dir: str
    use_wandb: bool
    wandb_project: str
    wandb_run_name: str

    # Data configuration
    train_file: str
    validation_size: float
    prefer_higher_scores: bool
    max_samples: int

    # RL configuration
    method: str
    temperature: float
    top_p: float
    beta: float
    reference_model: Optional[str] = None


def load_config(config_path: str) -> TrainerConfig:
    """Load configuration from YAML file."""
    with open(config_path, "r") as f:
        config_dict = yaml.safe_load(f)

    # The config is already flat, so we use it directly
    return TrainerConfig(**config_dict)


def load_jsonl_data(file_path: str, max_samples: int = -1) -> List[Dict]:
    """Load data from JSONL file."""
    data = []
    with open(file_path, "r") as f:
        for line in f:
            data.append(json.loads(line))
            if max_samples > 0 and len(data) >= max_samples:
                break
    return data


def split_train_val(data: List[Dict], val_size: float) -> Tuple[List[Dict], List[Dict]]:
    """Split data into training and validation sets."""
    val_count = int(len(data) * val_size)
    return data[val_count:], data[:val_count]


def prepare_training_batch(
    data_batch: List[Dict],
    tokenizer,
    prefer_higher_scores: bool = True,
    device: str = "cuda"
) -> Dict[str, torch.Tensor]:
    """
    Prepare a batch for training.

    Args:
        data_batch: List of data points from JSONL
        tokenizer: Tokenizer for the model
        prefer_higher_scores: If True, higher scores are better
        device: Device to put tensors on

    Returns:
        Dict with input_ids, attention_mask, and scores
    """
    batch_tokens = []
    batch_masks = []
    batch_scores = []

    for item in data_batch:
        # For each group, select best and worst sequences based on scores
        scores = item["scores"]
        tokens = item["tokens"]
        masks = item["masks"]

        if prefer_higher_scores:
            best_idx = max(range(len(scores)), key=lambda i: scores[i])
            worst_idx = min(range(len(scores)), key=lambda i: scores[i])
        else:
            best_idx = min(range(len(scores)), key=lambda i: scores[i])
            worst_idx = max(range(len(scores)), key=lambda i: scores[i])

        batch_tokens.extend([tokens[best_idx], tokens[worst_idx]])
        batch_masks.extend([masks[best_idx], masks[worst_idx]])
        batch_scores.extend([scores[best_idx], scores[worst_idx]])

    # Convert to tensors
    input_ids = torch.tensor(batch_tokens, dtype=torch.long).to(device)
    attention_mask = torch.tensor(batch_masks, dtype=torch.long).to(device)
    scores = torch.tensor(batch_scores, dtype=torch.float).to(device)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "scores": scores,
    }


def compute_grpo_loss(
    logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    scores: torch.Tensor,
    beta: float
) -> torch.Tensor:
    """
    Compute the Group Relative Policy Optimization loss.

    Args:
        logprobs: Log probabilities from the model (batch_size, seq_len)
        ref_logprobs: Log probabilities from the reference model (batch_size, seq_len)
        scores: Scores for each sequence (batch_size,)
        beta: KL penalty coefficient

    Returns:
        Loss tensor
    """
    batch_size = logprobs.shape[0]
    assert batch_size % 2 == 0, "Batch size must be even"

    # Reshape to (batch_size/2, 2, seq_len)
    logprobs = logprobs.view(batch_size // 2, 2, -1)
    ref_logprobs = ref_logprobs.view(batch_size // 2, 2, -1)
    scores = scores.view(batch_size // 2, 2)

    # Calculate policy gradient loss
    pg_loss = 0
    for i in range(batch_size // 2):
        # Policy gradient - weight by the score difference
        score_diff = scores[i, 0] - scores[i, 1]
        log_ratio_chosen = logprobs[i, 0].sum() - ref_logprobs[i, 0].sum()
        log_ratio_rejected = logprobs[i, 1].sum() - ref_logprobs[i, 1].sum()

        # KL penalty
        kl_chosen = (ref_logprobs[i, 0] - logprobs[i, 0]).sum()
        kl_rejected = (ref_logprobs[i, 1] - logprobs[i, 1]).sum()

        # Final loss - maximize score difference, minimize KL divergence
        pg_loss += -score_diff * (log_ratio_chosen - log_ratio_rejected)
        pg_loss += beta * (kl_chosen + kl_rejected)

    return pg_loss / (batch_size // 2)
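Per (chosen, rejected) pair, `compute_grpo_loss` reduces to a score-weighted difference of summed log-ratios plus a KL-style penalty. A torch-free sketch of that per-pair arithmetic, using made-up logprob values (the function name `grpo_pair_loss` is mine):

```python
from typing import List


def grpo_pair_loss(lp_best: List[float], lp_worst: List[float],
                   ref_best: List[float], ref_worst: List[float],
                   score_best: float, score_worst: float,
                   beta: float) -> float:
    """Loss for one (best, worst) pair, mirroring compute_grpo_loss:
    plain sums stand in for the per-sequence tensor .sum() calls."""
    score_diff = score_best - score_worst
    log_ratio_best = sum(lp_best) - sum(ref_best)
    log_ratio_worst = sum(lp_worst) - sum(ref_worst)
    kl = (sum(ref_best) - sum(lp_best)) + (sum(ref_worst) - sum(lp_worst))
    return -score_diff * (log_ratio_best - log_ratio_worst) + beta * kl


# When the policy matches the reference, both terms vanish and the loss is 0.
lp = [-1.0, -2.0]
print(grpo_pair_loss(lp, lp, lp, lp, 1.0, 0.0, beta=0.1))

# Raising the best sequence's logprobs relative to the reference drives the
# loss negative, which is the direction gradient descent rewards.
print(grpo_pair_loss([-0.5, -1.0], lp, lp, lp, 1.0, 0.0, beta=0.1))
```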
def main():
    parser = argparse.ArgumentParser(description="Train a model on Rubik's cube data")
    parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
    args = parser.parse_args()

    # Load configuration
    config = load_config(args.config)
    logger.info(f"Loaded configuration from {args.config}")

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    # Initialize wandb if specified
    if config.use_wandb:
        wandb.init(
            project=config.wandb_project,
            name=config.wandb_run_name,
            config=vars(config)
        )

    # Create checkpoint directory
    os.makedirs(config.checkpoint_dir, exist_ok=True)

    # Load tokenizer and model
    logger.info(f"Loading model {config.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32
    ).to(device)
    model.train()

    # Load reference model if specified
    ref_model = None
    if config.reference_model:
        logger.info(f"Loading reference model {config.reference_model}")
        ref_model = AutoModelForCausalLM.from_pretrained(
            config.reference_model,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32
        ).to(device)
        ref_model.eval()

    # Set up optimizer and lr scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    scheduler = transformers.get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=config.total_steps
    )

    # Load and split data
    logger.info(f"Loading data from {config.train_file}")
    all_data = load_jsonl_data(config.train_file, config.max_samples)
    train_data, val_data = split_train_val(all_data, config.validation_size)
    logger.info(f"Loaded {len(train_data)} training and {len(val_data)} validation samples")

    # Training loop
    global_step = 0
    best_val_loss = float('inf')

    logger.info("Starting training")
    try:
        for epoch in range(100):  # Large number, will break when steps reached
            # Shuffle training data
            import random
            random.shuffle(train_data)

            for i in range(0, len(train_data), config.batch_size // 2):
                batch_data = train_data[i:i + config.batch_size // 2]
                if len(batch_data) < config.batch_size // 2:
                    continue  # Skip incomplete batches

                # Prepare batch
                batch = prepare_training_batch(
                    batch_data,
                    tokenizer,
                    prefer_higher_scores=config.prefer_higher_scores,
                    device=device
                )

                # Forward pass
                with torch.cuda.amp.autocast(enabled=device == "cuda"):
                    outputs = model(
                        input_ids=batch["input_ids"],
                        attention_mask=batch["attention_mask"],
                        return_dict=True
                    )

                # Compute log probabilities
                logits = outputs.logits[:, :-1]
                logprobs = F.log_softmax(logits, dim=-1)
                target_ids = batch["input_ids"][:, 1:]
                masks = batch["attention_mask"][:, 1:]

                # Get log probs for the chosen tokens
                chosen_logprobs = torch.gather(
                    logprobs,
                    dim=2,
                    index=target_ids.unsqueeze(-1)
                ).squeeze(-1)

                # Apply mask
                chosen_logprobs = chosen_logprobs * masks

                # Get reference log probs if using a reference model
                if ref_model:
                    with torch.no_grad():
                        ref_outputs = ref_model(
                            input_ids=batch["input_ids"],
                            attention_mask=batch["attention_mask"],
                            return_dict=True
                        )
                        ref_logits = ref_outputs.logits[:, :-1]
                        ref_logprobs = F.log_softmax(ref_logits, dim=-1)

                        ref_chosen_logprobs = torch.gather(
                            ref_logprobs,
                            dim=2,
                            index=target_ids.unsqueeze(-1)
                        ).squeeze(-1)

                        # Apply mask
                        ref_chosen_logprobs = ref_chosen_logprobs * masks
                else:
                    # If no reference model, use the current model's own
                    # detached logprobs as the reference
                    ref_chosen_logprobs = chosen_logprobs.detach()

                # Compute loss
                loss = compute_grpo_loss(
                    chosen_logprobs,
                    ref_chosen_logprobs,
                    batch["scores"],
                    config.beta
                )

                # Backward pass
                loss = loss / config.gradient_accumulation_steps
                loss.backward()

                # Update weights if gradient accumulation steps reached
|
||||
if (global_step + 1) % config.gradient_accumulation_steps == 0:
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Log progress
|
||||
if global_step % 10 == 0:
|
||||
logger.info(f"Step {global_step}: loss = {loss.item() * config.gradient_accumulation_steps:.4f}")
|
||||
if config.use_wandb:
|
||||
wandb.log({
|
||||
"train/loss": loss.item() * config.gradient_accumulation_steps,
|
||||
"train/learning_rate": scheduler.get_last_lr()[0],
|
||||
"train/step": global_step,
|
||||
})
|
||||
|
||||
# Evaluate on validation set
|
||||
if global_step % config.eval_every == 0:
|
||||
model.eval()
|
||||
val_losses = []
|
||||
|
||||
with torch.no_grad():
|
||||
for j in range(0, min(len(val_data), 100), config.batch_size // 2):
|
||||
val_batch_data = val_data[j:j + config.batch_size // 2]
|
||||
if len(val_batch_data) < config.batch_size // 2:
|
||||
continue
|
||||
|
||||
val_batch = prepare_training_batch(
|
||||
val_batch_data,
|
||||
tokenizer,
|
||||
prefer_higher_scores=config.prefer_higher_scores,
|
||||
device=device
|
||||
)
|
||||
|
||||
# Forward pass
|
||||
val_outputs = model(
|
||||
input_ids=val_batch["input_ids"],
|
||||
attention_mask=val_batch["attention_mask"],
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
# Compute log probabilities
|
||||
val_logits = val_outputs.logits[:, :-1]
|
||||
val_logprobs = F.log_softmax(val_logits, dim=-1)
|
||||
val_target_ids = val_batch["input_ids"][:, 1:]
|
||||
val_masks = val_batch["attention_mask"][:, 1:]
|
||||
|
||||
# Get log probs for the chosen tokens
|
||||
val_chosen_logprobs = torch.gather(
|
||||
val_logprobs,
|
||||
dim=2,
|
||||
index=val_target_ids.unsqueeze(-1)
|
||||
).squeeze(-1)
|
||||
|
||||
# Apply mask
|
||||
val_chosen_logprobs = val_chosen_logprobs * val_masks
|
||||
|
||||
# Get reference log probs
|
||||
if ref_model:
|
||||
ref_val_outputs = ref_model(
|
||||
input_ids=val_batch["input_ids"],
|
||||
attention_mask=val_batch["attention_mask"],
|
||||
return_dict=True
|
||||
)
|
||||
ref_val_logits = ref_val_outputs.logits[:, :-1]
|
||||
ref_val_logprobs = F.log_softmax(ref_val_logits, dim=-1)
|
||||
|
||||
ref_val_chosen_logprobs = torch.gather(
|
||||
ref_val_logprobs,
|
||||
dim=2,
|
||||
index=val_target_ids.unsqueeze(-1)
|
||||
).squeeze(-1)
|
||||
|
||||
# Apply mask
|
||||
ref_val_chosen_logprobs = ref_val_chosen_logprobs * val_masks
|
||||
else:
|
||||
# If no reference model, use detached current model outputs
|
||||
ref_val_chosen_logprobs = val_chosen_logprobs.detach()
|
||||
|
||||
# Compute loss
|
||||
val_loss = compute_grpo_loss(
|
||||
val_chosen_logprobs,
|
||||
ref_val_chosen_logprobs,
|
||||
val_batch["scores"],
|
||||
config.beta
|
||||
)
|
||||
val_losses.append(val_loss.item())
|
||||
|
||||
avg_val_loss = sum(val_losses) / len(val_losses)
|
||||
logger.info(f"Validation loss: {avg_val_loss:.4f}")
|
||||
|
||||
if config.use_wandb:
|
||||
wandb.log({
|
||||
"val/loss": avg_val_loss,
|
||||
"val/step": global_step,
|
||||
})
|
||||
|
||||
# Save best model
|
||||
if avg_val_loss < best_val_loss:
|
||||
best_val_loss = avg_val_loss
|
||||
logger.info(f"New best validation loss: {best_val_loss:.4f}")
|
||||
# Save model
|
||||
output_dir = os.path.join(config.checkpoint_dir, f"best_model")
|
||||
model.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
model.train()
|
||||
|
||||
# Save checkpoint
|
||||
if global_step % config.save_every == 0 and global_step > 0:
|
||||
output_dir = os.path.join(config.checkpoint_dir, f"checkpoint-{global_step}")
|
||||
model.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
# Increment step
|
||||
global_step += 1
|
||||
|
||||
# Exit if reached total steps
|
||||
if global_step >= config.total_steps:
|
||||
break
|
||||
|
||||
# Exit if reached total steps
|
||||
if global_step >= config.total_steps:
|
||||
break
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Training interrupted by user")
|
||||
|
||||
# Save final model
|
||||
logger.info("Saving final model")
|
||||
output_dir = os.path.join(config.checkpoint_dir, "final_model")
|
||||
model.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
if config.use_wandb:
|
||||
wandb.finish()
|
||||
|
||||
logger.info("Training complete")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||