mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Linting done
This commit is contained in:
parent
a58562447f
commit
65108d12b2
264 changed files with 606 additions and 142874 deletions
104
environments/community/selcube/test_rubiks_cube.py
Normal file
104
environments/community/selcube/test_rubiks_cube.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for the Rubik's Cube environment
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
from simple_cube import Cube
|
||||
|
||||
from rubiks_cube_environment import RubiksCubeEnv, RubiksCubeEnvConfig
|
||||
from rubiks_cube_visualizer import save_cube_visualization
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
|
||||
async def test_cube_visualization():
|
||||
"""Test the cube visualization functionality"""
|
||||
# Create a cube
|
||||
cube = Cube()
|
||||
|
||||
# Scramble it with some random moves
|
||||
moves = ["U", "D", "L", "R", "F", "B",
|
||||
"U'", "D'", "L'", "R'", "F'", "B'",
|
||||
"U2", "D2", "L2", "R2", "F2", "B2"]
|
||||
|
||||
move_history = []
|
||||
for _ in range(5):
|
||||
move = random.choice(moves)
|
||||
move_history.append(move)
|
||||
cube.rotate(move)
|
||||
|
||||
# Visualize the scrambled cube
|
||||
cube_state = str(cube)
|
||||
html_path = save_cube_visualization(
|
||||
cube_state,
|
||||
move_history,
|
||||
"test_scrambled_cube.html"
|
||||
)
|
||||
|
||||
print(f"Scrambled cube visualization saved to {html_path}")
|
||||
print(f"Moves applied: {move_history}")
|
||||
print(f"Is solved: {cube.is_solved()}")
|
||||
|
||||
async def test_environment():
|
||||
"""Test the basic functionality of the environment"""
|
||||
# Create the environment configuration
|
||||
config = RubiksCubeEnvConfig(
|
||||
tokenizer_name="gpt2", # Use a simple tokenizer for testing
|
||||
group_size=2, # Small group size for testing
|
||||
use_wandb=False,
|
||||
max_steps=5,
|
||||
scramble_moves=3,
|
||||
debug_mode=True,
|
||||
)
|
||||
|
||||
# Create server configuration
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="gpt2",
|
||||
base_url="http://localhost:9004/v1",
|
||||
api_key="x",
|
||||
)
|
||||
]
|
||||
|
||||
# Create the environment
|
||||
env = RubiksCubeEnv(config, server_configs, slurm=False, testing=True)
|
||||
|
||||
# Test creating an episode
|
||||
seed = 12345
|
||||
episode = env._get_or_create_episode(seed)
|
||||
|
||||
# Print initial state
|
||||
print(f"Initial cube state (seed {seed}):")
|
||||
print(episode.get_cube_state_visualization())
|
||||
|
||||
# Test visualization
|
||||
html_path = save_cube_visualization(
|
||||
episode.get_cube_state_visualization(),
|
||||
[],
|
||||
"test_initial_cube.html"
|
||||
)
|
||||
print(f"Initial cube visualization saved to {html_path}")
|
||||
|
||||
# Test applying moves
|
||||
test_moves = ["U", "R", "F'"]
|
||||
for move in test_moves:
|
||||
success = episode.apply_move(move)
|
||||
print(f"Applied move {move}: {'Success' if success else 'Failed'}")
|
||||
|
||||
# Check if solved
|
||||
print(f"Is solved: {episode.is_solved()}")
|
||||
|
||||
# Test final state visualization
|
||||
html_path = save_cube_visualization(
|
||||
episode.get_cube_state_visualization(),
|
||||
episode.actions,
|
||||
"test_after_moves_cube.html"
|
||||
)
|
||||
print(f"Final cube visualization saved to {html_path}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the tests
|
||||
print("Running Rubik's Cube environment tests...")
|
||||
asyncio.run(test_cube_visualization())
|
||||
asyncio.run(test_environment())
|
||||
print("Tests completed.")
|
||||
Loading…
Add table
Add a link
Reference in a new issue