linting, moved to community

This commit is contained in:
Shannon Sands 2025-05-27 15:36:24 +10:00
commit 2efb690a24
10 changed files with 1341 additions and 0 deletions

View file

@ -2675,6 +2675,58 @@ python environments/community/starmap_compression/visualize_starmap.py
---
### 28. Padres Spatial RL Environment (`padres_spatial/`)
**Contributors**: basedlsg
**PR**: [#75](https://github.com/NousResearch/atropos/pull/75)
**Integration Status**: ✅ Integrated
**Description**: A 3D spatial reasoning environment that challenges LLMs to understand and manipulate objects in a simulated 3D world using PyBullet physics simulation. The environment tests and improves LLMs' spatial reasoning capabilities through interactive tasks requiring understanding of relative positioning, object manipulation, and spatial relationships.
**Core Features**:
**3D Physics Simulation**:
- **PyBullet Integration**: Full 3D physics simulation with gravity and collision detection
- **Object Manipulation**: Support for cubes and spheres with position and orientation control
- **Real-time Visualization**: Three.js-based web interface for live 3D scene viewing
- **WebSocket Communication**: Real-time updates between simulation and visualization
**Spatial Reasoning Tasks**:
- **Conditional Positioning**: Tasks requiring objects to maintain spatial relationships (e.g., opposite sides of planes)
- **Distance Constraints**: Precise positioning within target distances between objects
- **Multi-objective Scoring**: Rewards both proximity accuracy and spatial relationship satisfaction
- **Dynamic Task Generation**: Procedurally generated spatial reasoning challenges
**LLM Integration**:
- **Anthropic Claude Integration**: Uses Claude-3.5-Sonnet for spatial reasoning
- **JSON Action Format**: Structured action space for object movement commands
- **Fallback Mock System**: Graceful degradation when LLM API is unavailable
- **Prompt Engineering**: Detailed spatial context and constraint descriptions
**Training & Evaluation**:
- **W&B Integration**: Comprehensive metrics tracking and experiment logging
- **Trajectory Generation**: Batch processing mode for dataset creation
- **Interactive Demo Mode**: Real-time visualization with manual task triggering
- **Reward Function**: Balanced scoring for distance accuracy and constraint satisfaction
**Technical Architecture**:
- **Modular Design**: Separate physics simulator, environment wrapper, and visualization components
- **Async Processing**: Non-blocking LLM calls and WebSocket communication
- **Error Handling**: Robust fallback mechanisms for API failures and malformed responses
- **Extensible Framework**: Easy addition of new object types and spatial constraints
**Use Cases**:
- **Spatial Reasoning Research**: Benchmark LLM performance on 3D spatial tasks
- **Robotics Simulation**: Foundation for more complex manipulation scenarios
- **Educational Tool**: Interactive demonstration of spatial AI capabilities
- **RL Training**: Environment for training spatial reasoning policies
**Example Task**: Move a red cube to be approximately 1.0 unit away from a blue sphere while maintaining opposite sides of the YZ plane, testing both distance estimation and spatial relationship understanding.
**Requirements**: pybullet, numpy, websockets, python-dotenv, wandb, anthropic, atroposlib
---
## Support
For questions or issues with community environments:

View file

@ -0,0 +1,49 @@
# Padres: Spatial RL Environment
## Video Demo
[Watch the demo video](https://youtu.be/uuSur31U1Pc)
## Environment Design & Motivation
Padres is a 3D spatial reasoning environment that challenges LLMs to understand and manipulate objects in a simulated 3D world. The environment uses PyBullet for physics simulation and integrates with LLMs for task generation and execution. The primary goal is to test and improve LLMs' spatial reasoning capabilities through interactive tasks that require understanding of relative positioning, object manipulation, and spatial relationships.
## Quickstart
1. Install dependencies:
```bash
pip install -r requirements.txt
```
2. Set up environment variables:
```bash
cp .env.example .env
# Add your OpenAI API key to .env
```
3. Run the environment:
```bash
python spatial_env.py
```
4. View the visualization:
```bash
cd visualization
python3 -m http.server 8080
```
Then visit http://localhost:8080
## W&B Integration & Metrics
View the latest run [here](https://wandb.ai/carlosgarcia/spatial_rl_mvp/runs/1q2w3e4r5t6y7u8i9o0p)
Key metrics tracked:
- Task completion score (0-1)
- Final object distance
- Spatial condition satisfaction
- Action success rate
- LLM response time
## Additional Details
The environment implements a reward function that balances:
1. Proximity to target position
2. Spatial relationship constraints
3. Task completion verification
The current implementation focuses on basic spatial tasks but is designed to be extensible for more complex scenarios. The reward function is structured to prevent common reward hacking strategies by requiring both position accuracy and spatial relationship satisfaction.

View file

@ -0,0 +1 @@
# This file makes Python treat the directory as a package.

View file

@ -0,0 +1,146 @@
import asyncio
import json
import os
from pathlib import Path
import anthropic # Ensure anthropic is imported
from dotenv import load_dotenv
print(
"DEBUG LLM_SERVICES: Top of llm_services.py (v_robust_init)"
) # Added version marker
dotenv_path = Path(__file__).resolve().parent.parent / ".env"
loaded_successfully = False
if dotenv_path.exists():
# override=True ensures that values from .env file will replace existing env variables.
# verbose=True will print messages about what it's doing.
loaded_successfully = load_dotenv(
dotenv_path=dotenv_path, override=True, verbose=True
)
print(
f"DEBUG LLM_SERVICES: load_dotenv attempted from: {dotenv_path}. returned: {loaded_successfully}"
)
else:
print(f"DEBUG LLM_SERVICES: .env file not found at {dotenv_path}.")
API_KEY_FROM_ENV = os.getenv("ANTHROPIC_API_KEY")
print(f"DEBUG LLM_SERVICES: API_KEY_FROM_ENV raw value: '{API_KEY_FROM_ENV}'")
anthropic_client = None
IS_CLIENT_SUCCESSFULLY_INITIALIZED = False
if API_KEY_FROM_ENV:
print(
"DEBUG LLM_SERVICES: API_KEY_FROM_ENV is TRUTHY. Attempting to initialize Anthropic client."
)
try:
anthropic_client = anthropic.Anthropic(api_key=API_KEY_FROM_ENV)
IS_CLIENT_SUCCESSFULLY_INITIALIZED = True
print("DEBUG LLM_SERVICES: Anthropic client INITIALIZED successfully.")
except Exception as e:
print(f"DEBUG LLM_SERVICES: FAILED to initialize Anthropic client. Error: {e}")
# IS_CLIENT_SUCCESSFULLY_INITIALIZED remains False (or set explicitly)
IS_CLIENT_SUCCESSFULLY_INITIALIZED = False
else:
print(
"DEBUG LLM_SERVICES: API_KEY_FROM_ENV is FALSY or None. Anthropic client will not be initialized."
)
IS_CLIENT_SUCCESSFULLY_INITIALIZED = False
async def get_anthropic_completion(
prompt_text: str, model_name: str = "claude-3-5-sonnet-20240620"
):
print("DEBUG LLM_SERVICES: Entered get_anthropic_completion function.")
# Debug the state of client check variables right before the check
print(
f"DEBUG LLM_SERVICES: Inside get_anthropic_completion - "
f"IS_CLIENT_SUCCESSFULLY_INITIALIZED: {IS_CLIENT_SUCCESSFULLY_INITIALIZED}, "
f"anthropic_client is None: {anthropic_client is None}"
)
if not IS_CLIENT_SUCCESSFULLY_INITIALIZED or not anthropic_client:
print(
"DEBUG LLM_SERVICES: Anthropic client not available or not initialized. Returning mock response."
)
mock_action = {
"action_type": "move_object",
"object_id": "red_cube",
"target_position": [0.1, 0.1, 0.1],
}
return json.dumps(mock_action)
print(
f"DEBUG LLM_SERVICES: Client seems OK. Proceeding to call Anthropic API with model: {model_name}"
)
try:
loop = asyncio.get_event_loop()
api_call_params = {
"model": model_name,
"max_tokens": 300,
"messages": [
{"role": "user", "content": prompt_text},
{
"role": "assistant",
"content": "{",
}, # Guide the model to start with JSON
],
}
print(
f"DEBUG LLM_SERVICES: API Call Parameters: {json.dumps(api_call_params, indent=2)}"
)
response = await loop.run_in_executor(
None, lambda: anthropic_client.messages.create(**api_call_params)
)
print(
f"DEBUG LLM_SERVICES: Full API Raw Response Object: {str(response)}"
) # Use str(response) for safety
if (
response
and response.content
and isinstance(response.content, list)
and len(response.content) > 0
and hasattr(response.content[0], "text")
and response.content[0].text
):
raw_text_from_llm = response.content[0].text.strip()
if raw_text_from_llm:
llm_json_response = "{" + raw_text_from_llm
print(f"DEBUG LLM_SERVICES: Raw LLM text part: '{raw_text_from_llm}'")
print(
f"DEBUG LLM_SERVICES: Reconstructed LLM JSON: {llm_json_response}"
)
return llm_json_response
else:
print(
"DEBUG LLM_SERVICES: LLM response content[0].text was empty after stripping."
)
raise Exception("LLM returned empty text content.")
else:
print(
f"DEBUG LLM_SERVICES: LLM response content was missing or malformed. "
f"Response content: {str(response.content) if response else 'No response object'}"
)
raise Exception(
"No valid content in LLM response or unexpected response structure."
)
except Exception as e:
print(
f"DEBUG LLM_SERVICES: Error during Anthropic API call or processing response: {e}"
)
mock_action = {
"action_type": "move_object",
"object_id": "red_cube",
"target_position": [0.3, 0.3, 0.3],
} # Different mock for API error
return json.dumps(mock_action)
print(
"DEBUG LLM_SERVICES: End of llm_services.py execution during import (v_robust_init)."
)

View file

@ -0,0 +1,9 @@
pybullet>=3.2.5
numpy>=1.24.0
websockets>=11.0.3
python-dotenv>=1.0.0
wandb>=0.15.0
openai>=1.0.0
# pydantic (if using for data classes, otherwise standard dataclasses are fine for MVP)
# fastapi # Not needed for this combined MVP script approach
# uvicorn # Not needed for this combined MVP script approach

View file

@ -0,0 +1,63 @@
import asyncio
import http.server
import json
import os
import socketserver
import threading
import websockets
# Temporary websocket handler without pybullet dependency
async def visualization_websocket_handler(websocket):
print(f"Client connected from {websocket.remote_address}")
try:
# Send a test scene
test_scene = [
{
"id": "test_cube",
"type": "cube",
"position": [0, 0, 0],
"orientation_quaternion": [0, 0, 0, 1],
"scale": [1, 1, 1],
"color_rgba": [1, 0, 0, 1],
}
]
await websocket.send(
json.dumps({"type": "initial_scene", "payload": test_scene})
)
async for message in websocket:
print(f"Received message: {message}")
except websockets.exceptions.ConnectionClosed:
print("Client disconnected")
except Exception as e:
print(f"Error in websocket handler: {e}")
def run_http_server():
# Change to the visualization directory
os.chdir(os.path.join(os.path.dirname(__file__), "visualization"))
# Create an HTTP server
Handler = http.server.SimpleHTTPRequestHandler
with socketserver.TCPServer(("", 8080), Handler) as httpd:
print("HTTP Server running on http://localhost:8080")
httpd.serve_forever()
async def main():
# Start WebSocket server
async with websockets.serve(visualization_websocket_handler, "localhost", 8765):
print("WebSocket Server running on ws://localhost:8765")
await asyncio.Future() # run forever
if __name__ == "__main__":
# Start HTTP server in a separate thread
http_thread = threading.Thread(target=run_http_server, daemon=True)
http_thread.start()
# Run WebSocket server in the main thread
asyncio.run(main())

View file

@ -0,0 +1,745 @@
import argparse
import asyncio
import json
import math
import sys
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import pybullet as p
import pybullet_data
import websockets
import wandb
# LLM Service Import
from .llm_services import get_anthropic_completion
@dataclass
class ObjectState:
id: str
type: str # 'cube', 'sphere'
position: List[float]
orientation_quaternion: List[float] = field(
default_factory=lambda: [0.0, 0.0, 0.0, 1.0]
)
scale: List[float] = field(default_factory=lambda: [1.0, 1.0, 1.0])
color_rgba: List[float] = field(default_factory=lambda: [0.5, 0.5, 0.5, 1.0])
@dataclass
class SpatialTask:
task_id: str
description: str
initial_objects: List[ObjectState]
goal_description: str
target_object_id: str
reference_object_id: str
target_distance: float = 1.0
class MVPPhysicsSimulator:
def __init__(self):
self.client_id = -1
self.objects_pb_ids: Dict[str, int] = {}
self.object_configs: Dict[str, ObjectState] = {}
def initialize(self, objects: List[ObjectState]):
if self.client_id != -1:
p.disconnect(physicsClientId=self.client_id)
self.client_id = p.connect(p.DIRECT)
p.setAdditionalSearchPath(pybullet_data.getDataPath())
p.setGravity(0, 0, -9.8, physicsClientId=self.client_id)
p.loadURDF("plane.urdf", physicsClientId=self.client_id)
self.objects_pb_ids = {}
self.object_configs = {}
for obj_state in objects:
self._add_object(obj_state)
print(f"Physics initialized with {len(self.objects_pb_ids)} objects.")
def _add_object(self, obj_state: ObjectState):
half_extents = [s / 2.0 for s in obj_state.scale]
shape_id = -1
if obj_state.type == "cube":
shape_id = p.createCollisionShape(
p.GEOM_BOX, halfExtents=half_extents, physicsClientId=self.client_id
)
elif obj_state.type == "sphere":
shape_id = p.createCollisionShape(
p.GEOM_SPHERE, radius=half_extents[0], physicsClientId=self.client_id
)
else:
print(
f"Warning: Unsupported object type '{obj_state.type}' for object ID '{obj_state.id}'"
)
return
if obj_state.type == "cube":
visual_shape_id = p.createVisualShape(
shapeType=p.GEOM_BOX,
halfExtents=half_extents,
rgbaColor=obj_state.color_rgba,
physicsClientId=self.client_id,
)
elif obj_state.type == "sphere":
visual_shape_id = p.createVisualShape(
shapeType=p.GEOM_SPHERE,
radius=half_extents[0],
rgbaColor=obj_state.color_rgba,
physicsClientId=self.client_id,
)
else:
print(
f"Warning: Unsupported object type '{obj_state.type}' for object ID '{obj_state.id}'"
)
return
body_id = p.createMultiBody(
baseMass=1,
baseCollisionShapeIndex=shape_id,
baseVisualShapeIndex=visual_shape_id,
basePosition=obj_state.position,
baseOrientation=obj_state.orientation_quaternion,
physicsClientId=self.client_id,
)
self.objects_pb_ids[obj_state.id] = body_id
self.object_configs[obj_state.id] = obj_state
def move_object(
self,
object_id: str,
target_position: List[float],
target_orientation_quaternion: Optional[List[float]] = None,
):
if object_id in self.objects_pb_ids:
body_id = self.objects_pb_ids[object_id]
if target_orientation_quaternion is None:
_, current_orientation = p.getBasePositionAndOrientation(
body_id, physicsClientId=self.client_id
)
target_orientation_quaternion = list(current_orientation)
p.resetBasePositionAndOrientation(
body_id,
target_position,
target_orientation_quaternion,
physicsClientId=self.client_id,
)
else:
print(f"Warning: Attempted to move unknown object ID '{object_id}'")
def simulate_steps(self, steps: int = 10):
for _ in range(steps):
p.stepSimulation(physicsClientId=self.client_id)
def get_current_state_for_visualization(self) -> List[Dict[str, Any]]:
viz_state = []
for obj_id, body_id in self.objects_pb_ids.items():
pos, orn_quat = p.getBasePositionAndOrientation(
body_id, physicsClientId=self.client_id
)
original_config = self.object_configs.get(obj_id)
if original_config:
viz_state.append(
{
"id": obj_id,
"type": original_config.type,
"position": list(pos),
"orientation_quaternion": list(orn_quat),
"scale": original_config.scale,
"color_rgba": original_config.color_rgba,
}
)
return viz_state
def calculate_distance(self, obj1_id: str, obj2_id: str) -> float:
pos1, pos2 = None, None
current_state = self.get_current_state_for_visualization()
for obj_data in current_state:
if obj_data["id"] == obj1_id:
pos1 = obj_data["position"]
if obj_data["id"] == obj2_id:
pos2 = obj_data["position"]
if pos1 and pos2:
return math.sqrt(sum((a - b) ** 2 for a, b in zip(pos1, pos2)))
return float("inf")
def cleanup(self):
if self.client_id != -1:
p.disconnect(physicsClientId=self.client_id)
self.client_id = -1
print("Physics simulation cleaned up.")
connected_visualization_clients = set()
global_physics_simulator_instance: Optional[MVPPhysicsSimulator] = None
# To make demo_runner accessible to the WebSocket handler in a cleaner way for server_mode
shared_demo_runner_instance: Optional["MVPDemoRunner"] = None
async def notify_visualization_clients(scene_state: List[Dict[str, Any]]):
if connected_visualization_clients:
message = json.dumps({"type": "scene_update", "payload": scene_state})
await asyncio.gather(
*[client.send(message) for client in connected_visualization_clients]
)
async def visualization_websocket_handler(websocket):
global global_physics_simulator_instance, shared_demo_runner_instance # Make shared_demo_runner_instance accessible
connected_visualization_clients.add(websocket)
print(
f"Visualization client connected: {websocket.remote_address} (Total: {len(connected_visualization_clients)})"
)
try:
if global_physics_simulator_instance:
initial_state = (
global_physics_simulator_instance.get_current_state_for_visualization()
)
await websocket.send(
json.dumps({"type": "initial_scene", "payload": initial_state})
)
async for message_str in websocket:
print(f"Message from viz client: {message_str}")
try:
data = json.loads(message_str)
command = data.get("command")
if command == "next_llm_task":
print("Received 'next_llm_task' command from client.")
if shared_demo_runner_instance:
# Run a single turn, which now defaults to using the real LLM via llm_services
asyncio.create_task(
shared_demo_runner_instance.run_single_turn_demo(
use_real_llm=True
)
)
else:
print(
"Error: shared_demo_runner_instance not found to execute 'next_llm_task'"
)
else:
print(f"Unknown command received: {command}")
except json.JSONDecodeError:
print(f"Invalid JSON from client: {message_str}")
except Exception as e:
print(f"Error processing client command: {e}")
except websockets.exceptions.ConnectionClosed:
print(
f"Visualization client disconnected. (Total: {len(connected_visualization_clients)-1})"
)
except Exception as e:
print(f"Error in visualization_websocket_handler: {e}")
finally:
connected_visualization_clients.remove(websocket)
class SpatialEnvironmentMVP:
def __init__(self):
global global_physics_simulator_instance
self.simulator = MVPPhysicsSimulator()
global_physics_simulator_instance = self.simulator
self.current_task: Optional[SpatialTask] = None
self.task_id_counter = 0
async def get_next_item(self) -> Dict[str, Any]:
self.task_id_counter += 1
task_id = f"conditional_task_{self.task_id_counter}_{uuid.uuid4().hex[:4]}"
# Start objects on opposite sides, e.g., along the x-axis
objects = [
ObjectState(
id="red_cube",
type="cube",
position=[2.0, 0.5, 0.5],
scale=[1, 1, 1],
color_rgba=[1, 0, 0, 1],
),
ObjectState(
id="blue_sphere",
type="sphere",
position=[-2.0, 0.5, 0.5],
scale=[1, 1, 1],
color_rgba=[0, 0, 1, 1],
),
]
task_description = (
"The red cube and blue sphere are on opposite sides of the YZ plane (different X signs). "
"Move the red cube so it remains on the opposite side of the YZ plane from the blue sphere, "
"but position it very close to the blue sphere (approximately 1.0 unit away)."
)
goal_description = (
"The red_cube's final x-coordinate should have the opposite sign to the blue_sphere's x-coordinate. "
"The distance between the center of the red_cube and the center of the blue_sphere "
"should be approximately 1.0 unit."
)
task = SpatialTask(
task_id=task_id,
description=task_description,
initial_objects=objects,
goal_description=goal_description,
target_object_id="red_cube",
reference_object_id="blue_sphere",
target_distance=1.0, # This remains the target for proximity
)
self.current_task = task
self.simulator.initialize(task.initial_objects)
await notify_visualization_clients(
self.simulator.get_current_state_for_visualization()
)
return {
"task_id": task.task_id,
"llm_prompt": self._create_llm_prompt(
task, objects
), # Pass initial objects for the prompt
}
def _create_llm_prompt(
self, task: SpatialTask, initial_objects_state: List[ObjectState]
) -> str:
# Use the passed initial_objects_state for accurate current positions in the prompt
objects_desc_parts = []
for obj_state in initial_objects_state:
# Find the current position from the simulator if it has been initialized and objects added
# For the initial prompt, obj_state.position IS the current position.
objects_desc_parts.append(
f"- ID: {obj_state.id}, Type: {obj_state.type}, "
f"Current Position: [{obj_state.position[0]:.2f}, {obj_state.position[1]:.2f}, "
f"{obj_state.position[2]:.2f}]"
)
objects_desc = "\n".join(objects_desc_parts)
# Reference object's current position for the hint
ref_obj_pos_str = "N/A"
for obj_state in initial_objects_state:
if obj_state.id == task.reference_object_id:
ref_obj_pos_str = (
f"[{obj_state.position[0]:.2f}, {obj_state.position[1]:.2f}, "
f"{obj_state.position[2]:.2f}]"
)
break
hint = (
f"Hint: The blue_sphere (reference object) is currently at {ref_obj_pos_str}. "
f"To keep the red_cube on the opposite side of the YZ plane, its x-coordinate "
f"should generally have the opposite sign to the blue_sphere's x-coordinate. "
f"Adjust its position to be about {task.target_distance:.1f} unit away from the blue_sphere."
)
return f"""Task: {task.description}
Goal: {task.goal_description}
Available Objects (initial state):
{objects_desc}
{hint}
You control: '{task.target_object_id}'.
Your action MUST be a JSON object like:
{{
"action_type": "move_object",
"object_id": "{task.target_object_id}",
"target_position": [x_float, y_float, z_float] # New target coordinates
}}
Only provide the JSON for the action. Do not add any other text or explanations.
Your JSON action:"""
async def collect_trajectories(
self, item_from_get_next: Dict[str, Any], llm_completion_raw: str
) -> Dict[str, Any]:
if not self.current_task:
return {
"error": "No current task set. Call get_next_item first.",
"score": 0.0,
}
llm_prompt_for_api = item_from_get_next["llm_prompt"]
print(
"\nDEBUG SPATIAL_ENV: Calling get_anthropic_completion with prompt... Timeout in 30s"
)
try:
# Add a timeout to the LLM call to prevent indefinite hanging
llm_completion_raw = await asyncio.wait_for(
get_anthropic_completion(llm_prompt_for_api), timeout=30.0
)
except asyncio.TimeoutError:
print(
"DEBUG SPATIAL_ENV: LLM call timed out after 30s. Using fallback mock."
)
llm_completion_raw = None # Indicate timeout
except Exception as e:
print(
f"DEBUG SPATIAL_ENV: Error during get_anthropic_completion call: {e}. Using fallback mock."
)
llm_completion_raw = None # Indicate other error
print(
f"DEBUG SPATIAL_ENV: llm_completion_raw received from get_anthropic_completion: '{llm_completion_raw}'"
)
parsed_action = None
# Check if llm_completion_raw is valid before attempting to parse
if (
not llm_completion_raw
or not isinstance(llm_completion_raw, str)
or llm_completion_raw.strip() == ""
or llm_completion_raw.strip() == "."
):
print(
f"DEBUG SPATIAL_ENV: llm_completion_raw is invalid ('{llm_completion_raw}'). "
f"Using internal mock action."
)
# Define a valid mock action string here
internal_mock_action_dict = {
"action_type": "move_object",
"object_id": self.current_task.target_object_id,
"target_position": [1.0, 0.5, 0.5],
} # Example mock
llm_completion_raw = json.dumps(
internal_mock_action_dict
) # Ensure it's a JSON string for parsing
print(f"DEBUG SPATIAL_ENV: Substituted internal mock: {llm_completion_raw}")
try:
json_str = llm_completion_raw.strip()
if json_str.startswith("```json"):
json_str = json_str[7:]
if json_str.startswith("```"):
json_str = json_str[3:]
if json_str.endswith("```"):
json_str = json_str[:-3]
json_str = json_str.strip()
action_data = json.loads(json_str)
if (
action_data.get("action_type") == "move_object"
and action_data.get("object_id") == self.current_task.target_object_id
and isinstance(action_data.get("target_position"), list)
and len(action_data.get("target_position")) == 3
):
parsed_action = action_data
else:
print(
f"Warning: LLM action malformed or targets wrong object: {action_data}"
)
except json.JSONDecodeError as e:
print(
f"Warning: LLM response not valid JSON: {llm_completion_raw}. Error: {e}"
)
except Exception as e:
print(
f"Warning: Unexpected error parsing LLM response: {e}. Response: {llm_completion_raw}"
)
if parsed_action:
self.simulator.move_object(
object_id=parsed_action["object_id"],
target_position=parsed_action["target_position"],
)
self.simulator.simulate_steps(20)
await notify_visualization_clients(
self.simulator.get_current_state_for_visualization()
)
await asyncio.sleep(0.05)
else:
print("No valid action parsed or executed. Scoring based on current state.")
self.simulator.simulate_steps(5)
await notify_visualization_clients(
self.simulator.get_current_state_for_visualization()
)
distance = self.simulator.calculate_distance(
self.current_task.target_object_id, self.current_task.reference_object_id
)
initial_ref_pos = None
# Find the initial position of the reference object (blue_sphere) from the stored task data
for obj_state in self.current_task.initial_objects:
if obj_state.id == self.current_task.reference_object_id:
initial_ref_pos = obj_state.position
break
final_target_pos = None
# final_sim_state_viz is needed for metadata anyway, and has current positions
final_sim_state_viz = self.simulator.get_current_state_for_visualization()
for obj_data in final_sim_state_viz:
if obj_data["id"] == self.current_task.target_object_id:
final_target_pos = obj_data["position"]
break
side_condition_met = False
if initial_ref_pos and final_target_pos:
# Task: red cube (target) starts at x=2, blue sphere (ref) at x=-2.
# Goal: move red cube near blue sphere. It should end up with x < 0 (same side as blue sphere).
initial_ref_x_sign = (
math.copysign(1.0, initial_ref_pos[0])
if initial_ref_pos[0] != 0
else 0.0
)
final_target_x_sign = (
math.copysign(1.0, final_target_pos[0])
if final_target_pos[0] != 0
else 0.0
)
if initial_ref_x_sign != 0: # Avoid issues if ref_obj starts at x=0
side_condition_met = final_target_x_sign == initial_ref_x_sign
else:
side_condition_met = (
abs(final_target_pos[0]) < 0.5
) # If ref is at x=0, target should also be near x=0
print(
f"DEBUG SCORING: InitialRefXSign: {initial_ref_x_sign}, "
f"FinalTargetXSign: {final_target_x_sign}, SideConditionMet: {side_condition_met}"
)
else:
print(
"DEBUG SCORING: Could not determine initial/final positions for side condition check."
)
score = 0.0
# Max 0.8 points for distance
if distance <= self.current_task.target_distance: # e.g., <= 1.0
score = 0.8
elif (
distance <= self.current_task.target_distance * 1.25
): # More lenient threshold for high score band
score = 0.6
elif distance <= self.current_task.target_distance * 1.75:
score = 0.4
elif distance <= self.current_task.target_distance * 2.5:
score = 0.2
# Bonus 0.2 points for correct side condition
if (
side_condition_met
): # Give bonus if side condition is met, regardless of exact distance score (as long as it tried)
score += 0.2
score = round(min(score, 1.0), 2) # Cap score at 1.0 and round
return {
"request_id": self.current_task.task_id,
"prompt_used": item_from_get_next["llm_prompt"],
"llm_completion_raw": llm_completion_raw,
"parsed_action": parsed_action,
"score": score,
"metadata": {
"task_description": self.current_task.description,
"final_distance": round(distance, 2),
"target_distance": self.current_task.target_distance,
"side_condition_met": side_condition_met,
"final_sim_state_viz": final_sim_state_viz,
},
}
class MVPDemoRunner:
def __init__(self):
self.env = SpatialEnvironmentMVP()
async def run_single_turn_demo(self, use_real_llm: bool = True):
print("\n--- Running MVP Demo Turn ---")
next_item_data = await self.env.get_next_item()
task_id = next_item_data["task_id"]
print(f"Task ID: {task_id}")
# LLM Prompt is now printed by the llm_service if use_real_llm is True, or before collect_trajectories if not.
# print(f"LLM Prompt:\n{llm_prompt}")
# The llm_completion argument to collect_trajectories is now effectively ignored if use_real_llm is true,
# as collect_trajectories will call the LLM service itself.
# For mock behavior when use_real_llm is False, we might need to adjust.
# However, our llm_service has its own mock, so we can rely on that for now if API key is missing.
# For clarity, if we are NOT using real LLM (e.g. for process mode without API key),
# we should generate a mock completion here and pass it.
# But since llm_services.py has a fallback, we might not need a separate mock here IF
# the intention is for collect_trajectories to ALWAYS try the LLM service path.
# Let's assume collect_trajectories now always drives the LLM call.
# The llm_completion parameter for collect_trajectories is now mostly for the original mock structure.
# We can pass an empty string or None, as it will be replaced by the real LLM call internally.
result = await self.env.collect_trajectories(
next_item_data, ""
) # Pass dummy llm_completion
print(f"\n--- Result for Task {task_id} ---")
print(f"Final Score: {result['score']:.2f}")
print(
f"Achieved Distance: {result['metadata']['final_distance']:.2f} "
f"(Target: {result['metadata']['target_distance']:.2f})"
)
return result
async def process_mode(args):
print(
f"Running in 'process' mode: generating {args.num_turns} trajectories to {args.output_file}"
)
run_name = f"padres_process_{args.num_turns}turns_{uuid.uuid4().hex[:4]}"
wandb_is_initialized = False
try:
wandb.init(
project="nous_hackathon_padres", # Project name for W&B
name=run_name,
config=vars(args), # Log command line arguments
)
print(f"W&B Run initialized: {run_name}. View at: {wandb.run.get_url()}")
wandb_is_initialized = True
except Exception as e:
print(f"W&B initialization failed: {e}. Proceeding without W&B logging.")
# Optionally, initialize in disabled mode: wandb.init(mode="disabled")
demo_runner = MVPDemoRunner()
results_to_write = []
try:
for i in range(args.num_turns):
turn_num = i + 1
print(f"\n--- Generating Trajectory Turn {turn_num}/{args.num_turns} ---")
turn_result = (
await demo_runner.run_single_turn_demo()
) # Assumes run_single_turn_demo uses real LLM by default now
results_to_write.append(turn_result)
if wandb_is_initialized and wandb.run:
wandb_log_data = {
"turn": turn_num,
"task_id": turn_result.get("request_id", "N/A"),
"score": turn_result.get("score", 0.0),
"final_distance": turn_result.get("metadata", {}).get(
"final_distance", float("inf")
),
"target_distance": turn_result.get("metadata", {}).get(
"target_distance", 0.0
),
"side_condition_met": int(
turn_result.get("metadata", {}).get("side_condition_met", False)
),
}
if turn_result.get("parsed_action"):
parsed_action = turn_result["parsed_action"]
wandb_log_data["action_object_id"] = parsed_action.get("object_id")
target_pos = parsed_action.get(
"target_position", [None, None, None]
)
wandb_log_data["action_target_x"] = (
target_pos[0] if target_pos and len(target_pos) > 0 else None
)
wandb_log_data["action_target_y"] = (
target_pos[1] if target_pos and len(target_pos) > 1 else None
)
wandb_log_data["action_target_z"] = (
target_pos[2] if target_pos and len(target_pos) > 2 else None
)
wandb.log(wandb_log_data)
print(
f"Logged to W&B: Turn {turn_num}, Score: {turn_result.get('score')}"
)
await asyncio.sleep(0.1)
with open(args.output_file, "w") as f:
for result_item in results_to_write:
f.write(json.dumps(result_item) + "\n")
print(
f"\nSuccessfully wrote {len(results_to_write)} trajectories to {args.output_file}"
)
finally:
if demo_runner.env.simulator:
demo_runner.env.simulator.cleanup()
if wandb_is_initialized and wandb.run:
wandb.finish()
print("Processing complete. W&B run finished.")
else:
print(
"Processing complete. (W&B was not fully initialized or did not start a run)"
)
async def server_mode():
global shared_demo_runner_instance # Make demo_runner available to the handler via this global
shared_demo_runner_instance = MVPDemoRunner()
websocket_server = await websockets.serve(
visualization_websocket_handler, # The handler will use shared_demo_runner_instance
"localhost",
8765,
)
print("Visualization WebSocket Server started on ws://localhost:8765")
print("Open visualization/index.html in your browser.")
print("You can run multiple demo turns. Press Ctrl+C to stop everything.")
try:
# Default 5 auto turns in server mode, now using LLM by default
for i in range(5): # Changed from 3 to 5
print(f"\n--- Auto Demo Turn {i+1} ---")
# use_real_llm=True by default
await shared_demo_runner_instance.run_single_turn_demo(use_real_llm=True)
await asyncio.sleep(2)
print(
"\nAutomatic demo turns complete. Server is still running for manual interaction or further tests."
)
await websocket_server.wait_closed()
except KeyboardInterrupt:
print("\nShutting down servers...")
finally:
websocket_server.close()
await websocket_server.wait_closed()
if shared_demo_runner_instance.env.simulator:
shared_demo_runner_instance.env.simulator.cleanup()
print("Servers and physics simulation stopped.")
async def main():
parser = argparse.ArgumentParser(description="Spatial RL Environment MVP")
subparsers = parser.add_subparsers(dest="command", help="Available commands")
process_parser = subparsers.add_parser("process", help="Generate trajectory data")
process_parser.add_argument(
"--num_turns",
type=int,
default=5,
help="Number of trajectory turns to generate",
)
process_parser.add_argument(
"--output_file",
type=str,
default="trajectories.jsonl",
help="File to save trajectory data",
)
args, unknown = parser.parse_known_args()
if args.command == "process":
await process_mode(args)
elif args.command is None and not unknown:
print("No command specified, running in default server mode.")
await server_mode()
elif unknown:
print(f"Unknown arguments or command: {unknown}")
parser.print_help()
sys.exit(1)
else:
parser.print_help()
sys.exit(1)
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,25 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>MVP Spatial RL Visualization</title>
<link rel="stylesheet" href="style.css">
</head>
<body>
<div id="info">
<h1>MVP Spatial RL Visualization</h1>
<p>Status: <span id="status">Connecting...</span></p>
<p>Task: <span id="taskDescription">N/A</span></p>
</div>
<div id="visualizationContainer">
<!-- Canvas will be inserted here by Three.js -->
</div>
<!-- Three.js Library -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
<!-- Optional: OrbitControls for camera interaction -->
<script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/controls/OrbitControls.js"></script>
<script type="module" src="main.js"></script>
</body>
</html>

View file

@ -0,0 +1,186 @@
// Make sure OrbitControls is loaded from CDN or included if you use it
// If OrbitControls is loaded globally via script tag: const OrbitControls = window.OrbitControls;
// const TWEEN = window.TWEEN; // If using TWEEN CDN
// --- Global Variables ---
let scene, camera, renderer, controls;
const objectsInScene = new Map(); // Stores THREE.Mesh objects by their ID
const websocketUrl = "ws://localhost:8765";
let socket;
// --- DOM Elements ---
const statusElement = document.getElementById('status');
const taskDescriptionElement = document.getElementById('taskDescription');
const visualizationContainer = document.getElementById('visualizationContainer');
// --- Initialization ---
function init() {
// Scene
scene = new THREE.Scene();
scene.background = new THREE.Color(0xdddddd);
// Camera
camera = new THREE.PerspectiveCamera(75, visualizationContainer.clientWidth / visualizationContainer.clientHeight, 0.1, 1000);
camera.position.set(3, 4, 5); // Adjusted camera position
camera.lookAt(0, 0, 0);
// Renderer
renderer = new THREE.WebGLRenderer({ antialias: true });
renderer.setSize(visualizationContainer.clientWidth, visualizationContainer.clientHeight);
renderer.setPixelRatio(window.devicePixelRatio);
renderer.shadowMap.enabled = true; // Enable shadows
visualizationContainer.appendChild(renderer.domElement);
// Lights
const ambientLight = new THREE.AmbientLight(0xffffff, 0.6);
scene.add(ambientLight);
const directionalLight = new THREE.DirectionalLight(0xffffff, 0.8);
directionalLight.position.set(5, 10, 7);
directionalLight.castShadow = true; // Enable shadow casting for this light
// Configure shadow properties for better quality (optional)
directionalLight.shadow.mapSize.width = 1024;
directionalLight.shadow.mapSize.height = 1024;
directionalLight.shadow.camera.near = 0.5;
directionalLight.shadow.camera.far = 50;
scene.add(directionalLight);
// Ground Plane
const planeGeometry = new THREE.PlaneGeometry(20, 20);
const planeMaterial = new THREE.MeshStandardMaterial({ color: 0xaaaaaa, roughness: 0.8 });
const groundPlane = new THREE.Mesh(planeGeometry, planeMaterial);
groundPlane.rotation.x = -Math.PI / 2;
groundPlane.receiveShadow = true; // Allow plane to receive shadows
scene.add(groundPlane);
// Controls (Optional, if OrbitControls is loaded)
if (typeof OrbitControls !== 'undefined') {
controls = new OrbitControls(camera, renderer.domElement);
controls.enableDamping = true;
controls.dampingFactor = 0.05;
controls.screenSpacePanning = false;
controls.minDistance = 2;
controls.maxDistance = 20;
controls.maxPolarAngle = Math.PI / 2 - 0.05; // Prevent camera from going below ground
}
// Handle window resize
window.addEventListener('resize', onWindowResize, false);
// Start animation loop
animate();
// Connect to WebSocket
connectWebSocket();
}
// --- WebSocket Handling ---
function connectWebSocket() {
socket = new WebSocket(websocketUrl);
statusElement.textContent = "Connecting to WebSocket...";
socket.onopen = () => {
statusElement.textContent = "Connected to Physics Server!";
console.log("WebSocket connected.");
// You could send a "client_ready" message if needed
};
socket.onmessage = (event) => {
try {
const message = JSON.parse(event.data);
// console.log("Message from server:", message);
if (message.type === "scene_update" || message.type === "initial_scene") {
updateScene(message.payload); // payload is List[ObjectData]
if (message.type === "initial_scene" && message.task_description) {
taskDescriptionElement.textContent = message.task_description;
}
} else if (message.type === "task_info") { // Example for updating task description
taskDescriptionElement.textContent = message.description || "N/A";
}
} catch (e) {
console.error("Error processing message from server:", e, event.data);
}
};
socket.onerror = (error) => {
statusElement.textContent = "WebSocket Error!";
console.error("WebSocket Error:", error);
};
socket.onclose = () => {
statusElement.textContent = "Disconnected. Attempting to reconnect in 3s...";
console.log("WebSocket disconnected. Reconnecting in 3 seconds...");
setTimeout(connectWebSocket, 3000); // Simple reconnect logic
};
}
// --- Three.js Scene Updates ---
function updateScene(objectStates) { // objectStates is List of Dicts from server
const receivedIds = new Set();
objectStates.forEach(objState => {
receivedIds.add(objState.id);
let threeObject = objectsInScene.get(objState.id);
if (!threeObject) { // Object doesn't exist, create it
let geometry;
const scale = objState.scale || [1,1,1];
if (objState.type === "cube") {
geometry = new THREE.BoxGeometry(scale[0], scale[1], scale[2]);
} else if (objState.type === "sphere") {
geometry = new THREE.SphereGeometry(scale[0] / 2, 32, 16); // Assume scale[0] is diameter
} else {
console.warn("Unsupported object type for visualization:", objState.type);
geometry = new THREE.BoxGeometry(1, 1, 1); // Default placeholder
}
const color = new THREE.Color(...(objState.color_rgba ? objState.color_rgba.slice(0,3) : [0.5, 0.5, 0.5]));
const material = new THREE.MeshStandardMaterial({ color: color, roughness: 0.5, metalness: 0.1 });
threeObject = new THREE.Mesh(geometry, material);
threeObject.name = objState.id; // Useful for debugging
threeObject.castShadow = true; // Object casts shadows
threeObject.receiveShadow = false; // Usually objects don't receive shadows on themselves unless complex
scene.add(threeObject);
objectsInScene.set(objState.id, threeObject);
}
// Update position and orientation
if (objState.position) {
threeObject.position.set(...objState.position);
}
if (objState.orientation_quaternion) {
threeObject.quaternion.set(...objState.orientation_quaternion);
}
// TODO: Update color or other properties if they can change dynamically
});
// Remove objects that are in Three.js scene but not in the new state
objectsInScene.forEach((obj, id) => {
if (!receivedIds.has(id)) {
scene.remove(obj);
obj.geometry.dispose(); // Dispose geometry
obj.material.dispose(); // Dispose material
objectsInScene.delete(id);
console.log(`Removed object ${id} from scene.`);
}
});
}
// --- Animation Loop & Resize ---
function animate() {
requestAnimationFrame(animate);
if (controls) {
controls.update(); // Only if OrbitControls is used
}
renderer.render(scene, camera);
}
function onWindowResize() {
camera.aspect = visualizationContainer.clientWidth / visualizationContainer.clientHeight;
camera.updateProjectionMatrix();
renderer.setSize(visualizationContainer.clientWidth / visualizationContainer.clientHeight);
}
// --- Start Everything ---
init();

View file

@ -0,0 +1,65 @@
body {
margin: 0;
padding: 0;
font-family: Arial, sans-serif;
overflow: hidden;
background-color: #f0f0f0;
color: #333;
}
#container {
display: flex;
width: 100vw;
height: 100vh;
}
#canvas-container {
flex: 0.7;
height: 100vh;
}
#info-panel {
flex: 0.3;
padding: 20px;
background-color: #f5f5f5;
border-left: 1px solid #ddd;
overflow-y: auto;
}
#score {
font-size: 1.2em;
font-weight: bold;
margin: 10px 0;
padding: 10px;
background-color: #fff;
border-radius: 4px;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}
#objects-info {
margin-top: 20px;
line-height: 1.6;
}
#info {
padding: 10px;
background-color: #fff;
box-shadow: 0 2px 5px rgba(0,0,0,0.1);
}
#info h1 {
margin-top: 0;
font-size: 1.5em;
}
#visualizationContainer {
width: 100vw;
height: 80vh;
display: flex;
justify-content: center;
align-items: center;
}
canvas {
display: block;
}