mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00

linting, moved to community

This commit is contained in commit 2efb690a24.
10 changed files with 1341 additions and 0 deletions
@@ -2675,6 +2675,58 @@ python environments/community/starmap_compression/visualize_starmap.py

---

### 28. Padres Spatial RL Environment (`padres_spatial/`)

**Contributors**: basedlsg
**PR**: [#75](https://github.com/NousResearch/atropos/pull/75)
**Integration Status**: ✅ Integrated

**Description**: A 3D spatial reasoning environment that challenges LLMs to understand and manipulate objects in a simulated 3D world using PyBullet physics simulation. The environment tests and improves LLMs' spatial reasoning capabilities through interactive tasks requiring understanding of relative positioning, object manipulation, and spatial relationships.

**Core Features**:

**3D Physics Simulation**:
- **PyBullet Integration**: Full 3D physics simulation with gravity and collision detection
- **Object Manipulation**: Support for cubes and spheres with position and orientation control
- **Real-time Visualization**: Three.js-based web interface for live 3D scene viewing
- **WebSocket Communication**: Real-time updates between simulation and visualization

**Spatial Reasoning Tasks**:
- **Conditional Positioning**: Tasks requiring objects to maintain spatial relationships (e.g., opposite sides of planes)
- **Distance Constraints**: Precise positioning within target distances between objects
- **Multi-objective Scoring**: Rewards both proximity accuracy and spatial relationship satisfaction
- **Dynamic Task Generation**: Procedurally generated spatial reasoning challenges

**LLM Integration**:
- **Anthropic Claude Integration**: Uses Claude-3.5-Sonnet for spatial reasoning
- **JSON Action Format**: Structured action space for object movement commands
- **Fallback Mock System**: Graceful degradation when the LLM API is unavailable
- **Prompt Engineering**: Detailed spatial context and constraint descriptions

**Training & Evaluation**:
- **W&B Integration**: Comprehensive metrics tracking and experiment logging
- **Trajectory Generation**: Batch processing mode for dataset creation
- **Interactive Demo Mode**: Real-time visualization with manual task triggering
- **Reward Function**: Balanced scoring for distance accuracy and constraint satisfaction

**Technical Architecture**:
- **Modular Design**: Separate physics simulator, environment wrapper, and visualization components
- **Async Processing**: Non-blocking LLM calls and WebSocket communication
- **Error Handling**: Robust fallback mechanisms for API failures and malformed responses
- **Extensible Framework**: Easy addition of new object types and spatial constraints

**Use Cases**:
- **Spatial Reasoning Research**: Benchmark LLM performance on 3D spatial tasks
- **Robotics Simulation**: Foundation for more complex manipulation scenarios
- **Educational Tool**: Interactive demonstration of spatial AI capabilities
- **RL Training**: Environment for training spatial reasoning policies

**Example Task**: Move a red cube to approximately 1.0 unit from a blue sphere while keeping the two objects on opposite sides of the YZ plane, testing both distance estimation and spatial relationship understanding.

**Requirements**: pybullet, numpy, websockets, python-dotenv, wandb, anthropic, atroposlib
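The example task above reduces to two geometric checks: opposite X signs (opposite sides of the YZ plane) and proximity to the target distance. A minimal standalone sketch (the function name and tolerance are illustrative, not the environment's actual scoring code):

```python
import math

def check_example_task(cube_pos, sphere_pos, target_distance=1.0):
    """Evaluate the two conditions of the example task:
    opposite X signs and proximity to the target distance."""
    distance = math.dist(cube_pos, sphere_pos)
    opposite_sides = cube_pos[0] * sphere_pos[0] < 0
    near_target = abs(distance - target_distance) <= 0.25
    return opposite_sides, near_target, distance

# Cube at x > 0, sphere at x < 0: opposite sides, 1.0 unit apart.
result = check_example_task([0.5, 0.5, 0.5], [-0.5, 0.5, 0.5])
```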
---
## Support
For questions or issues with community environments:

49  environments/community/padres_spatial/README.md  Normal file

@@ -0,0 +1,49 @@

# Padres: Spatial RL Environment

## Video Demo

[Watch the demo video](https://youtu.be/uuSur31U1Pc)

## Environment Design & Motivation

Padres is a 3D spatial reasoning environment that challenges LLMs to understand and manipulate objects in a simulated 3D world. The environment uses PyBullet for physics simulation and integrates with LLMs for task generation and execution. The primary goal is to test and improve LLMs' spatial reasoning capabilities through interactive tasks that require understanding of relative positioning, object manipulation, and spatial relationships.

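Concretely, each turn the LLM must answer with a single JSON move command; the field names below are the ones used by the environment code in this commit, and the environment validates them before executing a move:

```python
import json

# A raw completion of the kind the LLM is asked to produce each turn.
raw_completion = (
    '{"action_type": "move_object", "object_id": "red_cube", '
    '"target_position": [-1.0, 0.5, 0.5]}'
)

action = json.loads(raw_completion)

# The environment checks exactly these fields before moving the object.
valid = (
    action.get("action_type") == "move_object"
    and isinstance(action.get("target_position"), list)
    and len(action["target_position"]) == 3
)
```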
## Quickstart

1. Install dependencies:

```bash
pip install -r requirements.txt
```

2. Set up environment variables:

```bash
cp .env.example .env
# Add your Anthropic API key (ANTHROPIC_API_KEY) to .env
```

3. Run the environment:

```bash
python spatial_env.py
```

4. View the visualization:

```bash
cd visualization
python3 -m http.server 8080
```

Then visit http://localhost:8080

## W&B Integration & Metrics

View the latest run [here](https://wandb.ai/carlosgarcia/spatial_rl_mvp/runs/1q2w3e4r5t6y7u8i9o0p)

Key metrics tracked:
- Task completion score (0-1)
- Final object distance
- Spatial condition satisfaction
- Action success rate
- LLM response time

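Logging the metrics above follows the usual wandb pattern of one dict per episode; a sketch of the payload shape (the keys here are illustrative, not necessarily the exact names used in the run linked above):

```python
# Per-episode metrics payload of the kind sent to wandb.
# In the environment this would be logged with wandb.log(metrics)
# after each task is scored.
metrics = {
    "task_completion_score": 0.8,      # bounded to [0, 1]
    "final_object_distance": 1.05,     # world units
    "spatial_condition_satisfied": 1,  # 1 if the side constraint held
    "action_success_rate": 1.0,
    "llm_response_time_s": 2.3,
}
```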
## Additional Details

The environment implements a reward function that balances:
1. Proximity to the target position
2. Spatial relationship constraints
3. Task completion verification

The current implementation focuses on basic spatial tasks but is designed to be extensible to more complex scenarios. The reward function is structured to prevent common reward-hacking strategies by requiring both position accuracy and spatial relationship satisfaction.
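This balance is implemented as banded scoring in `spatial_env.py` in this commit: up to 0.8 points for proximity plus a 0.2 bonus for the side condition, capped at 1.0. Condensed into a standalone function (a restatement of that logic, not the file itself):

```python
def score_episode(distance, target_distance=1.0, side_condition_met=False):
    """Banded reward: up to 0.8 for proximity to the target distance,
    plus a 0.2 bonus when the side-of-plane condition holds, capped at 1.0."""
    if distance <= target_distance:
        score = 0.8
    elif distance <= target_distance * 1.25:
        score = 0.6
    elif distance <= target_distance * 1.75:
        score = 0.4
    elif distance <= target_distance * 2.5:
        score = 0.2
    else:
        score = 0.0
    if side_condition_met:
        score += 0.2
    return round(min(score, 1.0), 2)
```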
1  environments/community/padres_spatial/__init__.py  Normal file

@@ -0,0 +1 @@

# This file makes Python treat the directory as a package.

146  environments/community/padres_spatial/llm_services.py  Normal file

@@ -0,0 +1,146 @@

import asyncio
import json
import os
from pathlib import Path

import anthropic  # Ensure anthropic is imported
from dotenv import load_dotenv

print(
    "DEBUG LLM_SERVICES: Top of llm_services.py (v_robust_init)"
)  # Added version marker

dotenv_path = Path(__file__).resolve().parent.parent / ".env"
loaded_successfully = False
if dotenv_path.exists():
    # override=True ensures that values from .env file will replace existing env variables.
    # verbose=True will print messages about what it's doing.
    loaded_successfully = load_dotenv(
        dotenv_path=dotenv_path, override=True, verbose=True
    )
    print(
        f"DEBUG LLM_SERVICES: load_dotenv attempted from: {dotenv_path}. returned: {loaded_successfully}"
    )
else:
    print(f"DEBUG LLM_SERVICES: .env file not found at {dotenv_path}.")

API_KEY_FROM_ENV = os.getenv("ANTHROPIC_API_KEY")
print(f"DEBUG LLM_SERVICES: API_KEY_FROM_ENV raw value: '{API_KEY_FROM_ENV}'")

anthropic_client = None
IS_CLIENT_SUCCESSFULLY_INITIALIZED = False

if API_KEY_FROM_ENV:
    print(
        "DEBUG LLM_SERVICES: API_KEY_FROM_ENV is TRUTHY. Attempting to initialize Anthropic client."
    )
    try:
        anthropic_client = anthropic.Anthropic(api_key=API_KEY_FROM_ENV)
        IS_CLIENT_SUCCESSFULLY_INITIALIZED = True
        print("DEBUG LLM_SERVICES: Anthropic client INITIALIZED successfully.")
    except Exception as e:
        print(f"DEBUG LLM_SERVICES: FAILED to initialize Anthropic client. Error: {e}")
        # IS_CLIENT_SUCCESSFULLY_INITIALIZED remains False (or set explicitly)
        IS_CLIENT_SUCCESSFULLY_INITIALIZED = False
else:
    print(
        "DEBUG LLM_SERVICES: API_KEY_FROM_ENV is FALSY or None. Anthropic client will not be initialized."
    )
    IS_CLIENT_SUCCESSFULLY_INITIALIZED = False


async def get_anthropic_completion(
    prompt_text: str, model_name: str = "claude-3-5-sonnet-20240620"
):
    print("DEBUG LLM_SERVICES: Entered get_anthropic_completion function.")
    # Debug the state of client check variables right before the check
    print(
        f"DEBUG LLM_SERVICES: Inside get_anthropic_completion - "
        f"IS_CLIENT_SUCCESSFULLY_INITIALIZED: {IS_CLIENT_SUCCESSFULLY_INITIALIZED}, "
        f"anthropic_client is None: {anthropic_client is None}"
    )

    if not IS_CLIENT_SUCCESSFULLY_INITIALIZED or not anthropic_client:
        print(
            "DEBUG LLM_SERVICES: Anthropic client not available or not initialized. Returning mock response."
        )
        mock_action = {
            "action_type": "move_object",
            "object_id": "red_cube",
            "target_position": [0.1, 0.1, 0.1],
        }
        return json.dumps(mock_action)

    print(
        f"DEBUG LLM_SERVICES: Client seems OK. Proceeding to call Anthropic API with model: {model_name}"
    )
    try:
        loop = asyncio.get_event_loop()
        api_call_params = {
            "model": model_name,
            "max_tokens": 300,
            "messages": [
                {"role": "user", "content": prompt_text},
                {
                    "role": "assistant",
                    "content": "{",
                },  # Guide the model to start with JSON
            ],
        }
        print(
            f"DEBUG LLM_SERVICES: API Call Parameters: {json.dumps(api_call_params, indent=2)}"
        )

        response = await loop.run_in_executor(
            None, lambda: anthropic_client.messages.create(**api_call_params)
        )

        print(
            f"DEBUG LLM_SERVICES: Full API Raw Response Object: {str(response)}"
        )  # Use str(response) for safety

        if (
            response
            and response.content
            and isinstance(response.content, list)
            and len(response.content) > 0
            and hasattr(response.content[0], "text")
            and response.content[0].text
        ):
            raw_text_from_llm = response.content[0].text.strip()
            if raw_text_from_llm:
                llm_json_response = "{" + raw_text_from_llm
                print(f"DEBUG LLM_SERVICES: Raw LLM text part: '{raw_text_from_llm}'")
                print(
                    f"DEBUG LLM_SERVICES: Reconstructed LLM JSON: {llm_json_response}"
                )
                return llm_json_response
            else:
                print(
                    "DEBUG LLM_SERVICES: LLM response content[0].text was empty after stripping."
                )
                raise Exception("LLM returned empty text content.")
        else:
            print(
                f"DEBUG LLM_SERVICES: LLM response content was missing or malformed. "
                f"Response content: {str(response.content) if response else 'No response object'}"
            )
            raise Exception(
                "No valid content in LLM response or unexpected response structure."
            )

    except Exception as e:
        print(
            f"DEBUG LLM_SERVICES: Error during Anthropic API call or processing response: {e}"
        )
        mock_action = {
            "action_type": "move_object",
            "object_id": "red_cube",
            "target_position": [0.3, 0.3, 0.3],
        }  # Different mock for API error
        return json.dumps(mock_action)


print(
    "DEBUG LLM_SERVICES: End of llm_services.py execution during import (v_robust_init)."
)

9  environments/community/padres_spatial/requirements.txt  Normal file

@@ -0,0 +1,9 @@

pybullet>=3.2.5
numpy>=1.24.0
websockets>=11.0.3
python-dotenv>=1.0.0
wandb>=0.15.0
anthropic>=0.25.0
# pydantic (if using for data classes, otherwise standard dataclasses are fine for MVP)
# fastapi # Not needed for this combined MVP script approach
# uvicorn # Not needed for this combined MVP script approach

63  environments/community/padres_spatial/run_servers.py  Normal file

@@ -0,0 +1,63 @@

import asyncio
import http.server
import json
import os
import socketserver
import threading

import websockets


# Temporary websocket handler without pybullet dependency
async def visualization_websocket_handler(websocket):
    print(f"Client connected from {websocket.remote_address}")
    try:
        # Send a test scene
        test_scene = [
            {
                "id": "test_cube",
                "type": "cube",
                "position": [0, 0, 0],
                "orientation_quaternion": [0, 0, 0, 1],
                "scale": [1, 1, 1],
                "color_rgba": [1, 0, 0, 1],
            }
        ]
        await websocket.send(
            json.dumps({"type": "initial_scene", "payload": test_scene})
        )

        async for message in websocket:
            print(f"Received message: {message}")

    except websockets.exceptions.ConnectionClosed:
        print("Client disconnected")
    except Exception as e:
        print(f"Error in websocket handler: {e}")


def run_http_server():
    # Change to the visualization directory
    os.chdir(os.path.join(os.path.dirname(__file__), "visualization"))

    # Create an HTTP server
    Handler = http.server.SimpleHTTPRequestHandler
    with socketserver.TCPServer(("", 8080), Handler) as httpd:
        print("HTTP Server running on http://localhost:8080")
        httpd.serve_forever()


async def main():
    # Start WebSocket server
    async with websockets.serve(visualization_websocket_handler, "localhost", 8765):
        print("WebSocket Server running on ws://localhost:8765")
        await asyncio.Future()  # run forever


if __name__ == "__main__":
    # Start HTTP server in a separate thread
    http_thread = threading.Thread(target=run_http_server, daemon=True)
    http_thread.start()

    # Run WebSocket server in the main thread
    asyncio.run(main())

745  environments/community/padres_spatial/spatial_env.py  Normal file

@@ -0,0 +1,745 @@

import argparse
import asyncio
import json
import math
import sys
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import pybullet as p
import pybullet_data
import websockets

import wandb

# LLM Service Import
from .llm_services import get_anthropic_completion


@dataclass
class ObjectState:
    id: str
    type: str  # 'cube', 'sphere'
    position: List[float]
    orientation_quaternion: List[float] = field(
        default_factory=lambda: [0.0, 0.0, 0.0, 1.0]
    )
    scale: List[float] = field(default_factory=lambda: [1.0, 1.0, 1.0])
    color_rgba: List[float] = field(default_factory=lambda: [0.5, 0.5, 0.5, 1.0])


@dataclass
class SpatialTask:
    task_id: str
    description: str
    initial_objects: List[ObjectState]
    goal_description: str
    target_object_id: str
    reference_object_id: str
    target_distance: float = 1.0


class MVPPhysicsSimulator:
    def __init__(self):
        self.client_id = -1
        self.objects_pb_ids: Dict[str, int] = {}
        self.object_configs: Dict[str, ObjectState] = {}

    def initialize(self, objects: List[ObjectState]):
        if self.client_id != -1:
            p.disconnect(physicsClientId=self.client_id)

        self.client_id = p.connect(p.DIRECT)
        p.setAdditionalSearchPath(pybullet_data.getDataPath())
        p.setGravity(0, 0, -9.8, physicsClientId=self.client_id)
        p.loadURDF("plane.urdf", physicsClientId=self.client_id)

        self.objects_pb_ids = {}
        self.object_configs = {}
        for obj_state in objects:
            self._add_object(obj_state)
        print(f"Physics initialized with {len(self.objects_pb_ids)} objects.")

    def _add_object(self, obj_state: ObjectState):
        half_extents = [s / 2.0 for s in obj_state.scale]
        shape_id = -1
        if obj_state.type == "cube":
            shape_id = p.createCollisionShape(
                p.GEOM_BOX, halfExtents=half_extents, physicsClientId=self.client_id
            )
        elif obj_state.type == "sphere":
            shape_id = p.createCollisionShape(
                p.GEOM_SPHERE, radius=half_extents[0], physicsClientId=self.client_id
            )
        else:
            print(
                f"Warning: Unsupported object type '{obj_state.type}' for object ID '{obj_state.id}'"
            )
            return

        if obj_state.type == "cube":
            visual_shape_id = p.createVisualShape(
                shapeType=p.GEOM_BOX,
                halfExtents=half_extents,
                rgbaColor=obj_state.color_rgba,
                physicsClientId=self.client_id,
            )
        elif obj_state.type == "sphere":
            visual_shape_id = p.createVisualShape(
                shapeType=p.GEOM_SPHERE,
                radius=half_extents[0],
                rgbaColor=obj_state.color_rgba,
                physicsClientId=self.client_id,
            )
        else:
            print(
                f"Warning: Unsupported object type '{obj_state.type}' for object ID '{obj_state.id}'"
            )
            return

        body_id = p.createMultiBody(
            baseMass=1,
            baseCollisionShapeIndex=shape_id,
            baseVisualShapeIndex=visual_shape_id,
            basePosition=obj_state.position,
            baseOrientation=obj_state.orientation_quaternion,
            physicsClientId=self.client_id,
        )
        self.objects_pb_ids[obj_state.id] = body_id
        self.object_configs[obj_state.id] = obj_state

    def move_object(
        self,
        object_id: str,
        target_position: List[float],
        target_orientation_quaternion: Optional[List[float]] = None,
    ):
        if object_id in self.objects_pb_ids:
            body_id = self.objects_pb_ids[object_id]
            if target_orientation_quaternion is None:
                _, current_orientation = p.getBasePositionAndOrientation(
                    body_id, physicsClientId=self.client_id
                )
                target_orientation_quaternion = list(current_orientation)
            p.resetBasePositionAndOrientation(
                body_id,
                target_position,
                target_orientation_quaternion,
                physicsClientId=self.client_id,
            )
        else:
            print(f"Warning: Attempted to move unknown object ID '{object_id}'")

    def simulate_steps(self, steps: int = 10):
        for _ in range(steps):
            p.stepSimulation(physicsClientId=self.client_id)

    def get_current_state_for_visualization(self) -> List[Dict[str, Any]]:
        viz_state = []
        for obj_id, body_id in self.objects_pb_ids.items():
            pos, orn_quat = p.getBasePositionAndOrientation(
                body_id, physicsClientId=self.client_id
            )
            original_config = self.object_configs.get(obj_id)
            if original_config:
                viz_state.append(
                    {
                        "id": obj_id,
                        "type": original_config.type,
                        "position": list(pos),
                        "orientation_quaternion": list(orn_quat),
                        "scale": original_config.scale,
                        "color_rgba": original_config.color_rgba,
                    }
                )
        return viz_state

    def calculate_distance(self, obj1_id: str, obj2_id: str) -> float:
        pos1, pos2 = None, None
        current_state = self.get_current_state_for_visualization()
        for obj_data in current_state:
            if obj_data["id"] == obj1_id:
                pos1 = obj_data["position"]
            if obj_data["id"] == obj2_id:
                pos2 = obj_data["position"]

        if pos1 and pos2:
            return math.sqrt(sum((a - b) ** 2 for a, b in zip(pos1, pos2)))
        return float("inf")

    def cleanup(self):
        if self.client_id != -1:
            p.disconnect(physicsClientId=self.client_id)
            self.client_id = -1
            print("Physics simulation cleaned up.")


connected_visualization_clients = set()
global_physics_simulator_instance: Optional[MVPPhysicsSimulator] = None

# To make demo_runner accessible to the WebSocket handler in a cleaner way for server_mode
shared_demo_runner_instance: Optional["MVPDemoRunner"] = None


async def notify_visualization_clients(scene_state: List[Dict[str, Any]]):
    if connected_visualization_clients:
        message = json.dumps({"type": "scene_update", "payload": scene_state})
        await asyncio.gather(
            *[client.send(message) for client in connected_visualization_clients]
        )


async def visualization_websocket_handler(websocket):
    global global_physics_simulator_instance, shared_demo_runner_instance  # Make shared_demo_runner_instance accessible
    connected_visualization_clients.add(websocket)
    print(
        f"Visualization client connected: {websocket.remote_address} (Total: {len(connected_visualization_clients)})"
    )
    try:
        if global_physics_simulator_instance:
            initial_state = (
                global_physics_simulator_instance.get_current_state_for_visualization()
            )
            await websocket.send(
                json.dumps({"type": "initial_scene", "payload": initial_state})
            )

        async for message_str in websocket:
            print(f"Message from viz client: {message_str}")
            try:
                data = json.loads(message_str)
                command = data.get("command")

                if command == "next_llm_task":
                    print("Received 'next_llm_task' command from client.")
                    if shared_demo_runner_instance:
                        # Run a single turn, which now defaults to using the real LLM via llm_services
                        asyncio.create_task(
                            shared_demo_runner_instance.run_single_turn_demo(
                                use_real_llm=True
                            )
                        )
                    else:
                        print(
                            "Error: shared_demo_runner_instance not found to execute 'next_llm_task'"
                        )
                else:
                    print(f"Unknown command received: {command}")

            except json.JSONDecodeError:
                print(f"Invalid JSON from client: {message_str}")
            except Exception as e:
                print(f"Error processing client command: {e}")
    except websockets.exceptions.ConnectionClosed:
        print(
            f"Visualization client disconnected. (Total: {len(connected_visualization_clients)-1})"
        )
    except Exception as e:
        print(f"Error in visualization_websocket_handler: {e}")
    finally:
        connected_visualization_clients.remove(websocket)


class SpatialEnvironmentMVP:
    def __init__(self):
        global global_physics_simulator_instance
        self.simulator = MVPPhysicsSimulator()
        global_physics_simulator_instance = self.simulator
        self.current_task: Optional[SpatialTask] = None
        self.task_id_counter = 0

    async def get_next_item(self) -> Dict[str, Any]:
        self.task_id_counter += 1
        task_id = f"conditional_task_{self.task_id_counter}_{uuid.uuid4().hex[:4]}"

        # Start objects on opposite sides, e.g., along the x-axis
        objects = [
            ObjectState(
                id="red_cube",
                type="cube",
                position=[2.0, 0.5, 0.5],
                scale=[1, 1, 1],
                color_rgba=[1, 0, 0, 1],
            ),
            ObjectState(
                id="blue_sphere",
                type="sphere",
                position=[-2.0, 0.5, 0.5],
                scale=[1, 1, 1],
                color_rgba=[0, 0, 1, 1],
            ),
        ]

        task_description = (
            "The red cube and blue sphere are on opposite sides of the YZ plane (different X signs). "
            "Move the red cube so it remains on the opposite side of the YZ plane from the blue sphere, "
            "but position it very close to the blue sphere (approximately 1.0 unit away)."
        )
        goal_description = (
            "The red_cube's final x-coordinate should have the opposite sign to the blue_sphere's x-coordinate. "
            "The distance between the center of the red_cube and the center of the blue_sphere "
            "should be approximately 1.0 unit."
        )

        task = SpatialTask(
            task_id=task_id,
            description=task_description,
            initial_objects=objects,
            goal_description=goal_description,
            target_object_id="red_cube",
            reference_object_id="blue_sphere",
            target_distance=1.0,  # This remains the target for proximity
        )
        self.current_task = task
        self.simulator.initialize(task.initial_objects)

        await notify_visualization_clients(
            self.simulator.get_current_state_for_visualization()
        )

        return {
            "task_id": task.task_id,
            "llm_prompt": self._create_llm_prompt(
                task, objects
            ),  # Pass initial objects for the prompt
        }

    def _create_llm_prompt(
        self, task: SpatialTask, initial_objects_state: List[ObjectState]
    ) -> str:
        # Use the passed initial_objects_state for accurate current positions in the prompt
        objects_desc_parts = []
        for obj_state in initial_objects_state:
            # Find the current position from the simulator if it has been initialized and objects added
            # For the initial prompt, obj_state.position IS the current position.
            objects_desc_parts.append(
                f"- ID: {obj_state.id}, Type: {obj_state.type}, "
                f"Current Position: [{obj_state.position[0]:.2f}, {obj_state.position[1]:.2f}, "
                f"{obj_state.position[2]:.2f}]"
            )
        objects_desc = "\n".join(objects_desc_parts)

        # Reference object's current position for the hint
        ref_obj_pos_str = "N/A"
        for obj_state in initial_objects_state:
            if obj_state.id == task.reference_object_id:
                ref_obj_pos_str = (
                    f"[{obj_state.position[0]:.2f}, {obj_state.position[1]:.2f}, "
                    f"{obj_state.position[2]:.2f}]"
                )
                break

        hint = (
            f"Hint: The blue_sphere (reference object) is currently at {ref_obj_pos_str}. "
            f"To keep the red_cube on the opposite side of the YZ plane, its x-coordinate "
            f"should generally have the opposite sign to the blue_sphere's x-coordinate. "
            f"Adjust its position to be about {task.target_distance:.1f} unit away from the blue_sphere."
        )

        return f"""Task: {task.description}
Goal: {task.goal_description}

Available Objects (initial state):
{objects_desc}

{hint}

You control: '{task.target_object_id}'.
Your action MUST be a JSON object like:
{{
"action_type": "move_object",
"object_id": "{task.target_object_id}",
"target_position": [x_float, y_float, z_float]  # New target coordinates
}}
Only provide the JSON for the action. Do not add any other text or explanations.
Your JSON action:"""

    async def collect_trajectories(
        self, item_from_get_next: Dict[str, Any], llm_completion_raw: str
    ) -> Dict[str, Any]:
        if not self.current_task:
            return {
                "error": "No current task set. Call get_next_item first.",
                "score": 0.0,
            }

        llm_prompt_for_api = item_from_get_next["llm_prompt"]

        print(
            "\nDEBUG SPATIAL_ENV: Calling get_anthropic_completion with prompt... Timeout in 30s"
        )
        try:
            # Add a timeout to the LLM call to prevent indefinite hanging
            llm_completion_raw = await asyncio.wait_for(
                get_anthropic_completion(llm_prompt_for_api), timeout=30.0
            )
        except asyncio.TimeoutError:
            print(
                "DEBUG SPATIAL_ENV: LLM call timed out after 30s. Using fallback mock."
            )
            llm_completion_raw = None  # Indicate timeout
        except Exception as e:
            print(
                f"DEBUG SPATIAL_ENV: Error during get_anthropic_completion call: {e}. Using fallback mock."
            )
            llm_completion_raw = None  # Indicate other error

        print(
            f"DEBUG SPATIAL_ENV: llm_completion_raw received from get_anthropic_completion: '{llm_completion_raw}'"
        )

        parsed_action = None
        # Check if llm_completion_raw is valid before attempting to parse
        if (
            not llm_completion_raw
            or not isinstance(llm_completion_raw, str)
            or llm_completion_raw.strip() == ""
            or llm_completion_raw.strip() == "."
        ):
            print(
                f"DEBUG SPATIAL_ENV: llm_completion_raw is invalid ('{llm_completion_raw}'). "
                f"Using internal mock action."
            )
            # Define a valid mock action string here
            internal_mock_action_dict = {
                "action_type": "move_object",
                "object_id": self.current_task.target_object_id,
                "target_position": [1.0, 0.5, 0.5],
            }  # Example mock
            llm_completion_raw = json.dumps(
                internal_mock_action_dict
            )  # Ensure it's a JSON string for parsing
            print(f"DEBUG SPATIAL_ENV: Substituted internal mock: {llm_completion_raw}")

        try:
            json_str = llm_completion_raw.strip()
            if json_str.startswith("```json"):
                json_str = json_str[7:]
            if json_str.startswith("```"):
                json_str = json_str[3:]
            if json_str.endswith("```"):
                json_str = json_str[:-3]
            json_str = json_str.strip()

            action_data = json.loads(json_str)
            if (
                action_data.get("action_type") == "move_object"
                and action_data.get("object_id") == self.current_task.target_object_id
                and isinstance(action_data.get("target_position"), list)
                and len(action_data.get("target_position")) == 3
            ):
                parsed_action = action_data
            else:
                print(
                    f"Warning: LLM action malformed or targets wrong object: {action_data}"
                )
        except json.JSONDecodeError as e:
            print(
                f"Warning: LLM response not valid JSON: {llm_completion_raw}. Error: {e}"
            )
        except Exception as e:
            print(
                f"Warning: Unexpected error parsing LLM response: {e}. Response: {llm_completion_raw}"
            )

        if parsed_action:
            self.simulator.move_object(
                object_id=parsed_action["object_id"],
                target_position=parsed_action["target_position"],
            )
            self.simulator.simulate_steps(20)
            await notify_visualization_clients(
                self.simulator.get_current_state_for_visualization()
            )
            await asyncio.sleep(0.05)
        else:
            print("No valid action parsed or executed. Scoring based on current state.")
            self.simulator.simulate_steps(5)
            await notify_visualization_clients(
                self.simulator.get_current_state_for_visualization()
            )

        distance = self.simulator.calculate_distance(
            self.current_task.target_object_id, self.current_task.reference_object_id
        )

        initial_ref_pos = None
        # Find the initial position of the reference object (blue_sphere) from the stored task data
        for obj_state in self.current_task.initial_objects:
            if obj_state.id == self.current_task.reference_object_id:
                initial_ref_pos = obj_state.position
                break

        final_target_pos = None
        # final_sim_state_viz is needed for metadata anyway, and has current positions
        final_sim_state_viz = self.simulator.get_current_state_for_visualization()
        for obj_data in final_sim_state_viz:
            if obj_data["id"] == self.current_task.target_object_id:
                final_target_pos = obj_data["position"]
                break

        side_condition_met = False
        if initial_ref_pos and final_target_pos:
            # Task: red cube (target) starts at x=2, blue sphere (ref) at x=-2.
            # Goal: move red cube near blue sphere. It should end up with x < 0 (same side as blue sphere).
            initial_ref_x_sign = (
                math.copysign(1.0, initial_ref_pos[0])
                if initial_ref_pos[0] != 0
                else 0.0
            )
            final_target_x_sign = (
                math.copysign(1.0, final_target_pos[0])
                if final_target_pos[0] != 0
                else 0.0
            )

            if initial_ref_x_sign != 0:  # Avoid issues if ref_obj starts at x=0
                side_condition_met = final_target_x_sign == initial_ref_x_sign
            else:
                side_condition_met = (
                    abs(final_target_pos[0]) < 0.5
                )  # If ref is at x=0, target should also be near x=0
            print(
                f"DEBUG SCORING: InitialRefXSign: {initial_ref_x_sign}, "
                f"FinalTargetXSign: {final_target_x_sign}, SideConditionMet: {side_condition_met}"
            )
        else:
            print(
                "DEBUG SCORING: Could not determine initial/final positions for side condition check."
            )

        score = 0.0
        # Max 0.8 points for distance
        if distance <= self.current_task.target_distance:  # e.g., <= 1.0
            score = 0.8
        elif (
            distance <= self.current_task.target_distance * 1.25
        ):  # More lenient threshold for high score band
            score = 0.6
        elif distance <= self.current_task.target_distance * 1.75:
            score = 0.4
        elif distance <= self.current_task.target_distance * 2.5:
            score = 0.2

        # Bonus 0.2 points for correct side condition
        if (
            side_condition_met
        ):  # Give bonus if side condition is met, regardless of exact distance score (as long as it tried)
            score += 0.2

        score = round(min(score, 1.0), 2)  # Cap score at 1.0 and round

        return {
            "request_id": self.current_task.task_id,
            "prompt_used": item_from_get_next["llm_prompt"],
            "llm_completion_raw": llm_completion_raw,
|
||||
"parsed_action": parsed_action,
|
||||
"score": score,
|
||||
"metadata": {
|
||||
"task_description": self.current_task.description,
|
||||
"final_distance": round(distance, 2),
|
||||
"target_distance": self.current_task.target_distance,
|
||||
"side_condition_met": side_condition_met,
|
||||
"final_sim_state_viz": final_sim_state_viz,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class MVPDemoRunner:
    def __init__(self):
        self.env = SpatialEnvironmentMVP()

    async def run_single_turn_demo(self, use_real_llm: bool = True):
        print("\n--- Running MVP Demo Turn ---")
        next_item_data = await self.env.get_next_item()
        task_id = next_item_data["task_id"]
        print(f"Task ID: {task_id}")
        # The LLM prompt is printed by the llm_service when use_real_llm is True.

        # collect_trajectories now drives the LLM call itself, so its
        # llm_completion argument is effectively ignored; when no API key is
        # available, llm_services.py falls back to its own mock completion.
        # Pass a dummy string to satisfy the original signature.
        result = await self.env.collect_trajectories(next_item_data, "")

        print(f"\n--- Result for Task {task_id} ---")
        print(f"Final Score: {result['score']:.2f}")
        print(
            f"Achieved Distance: {result['metadata']['final_distance']:.2f} "
            f"(Target: {result['metadata']['target_distance']:.2f})"
        )
        return result


async def process_mode(args):
    print(
        f"Running in 'process' mode: generating {args.num_turns} trajectories to {args.output_file}"
    )

    run_name = f"padres_process_{args.num_turns}turns_{uuid.uuid4().hex[:4]}"
    wandb_is_initialized = False
    try:
        wandb.init(
            project="nous_hackathon_padres",  # Project name for W&B
            name=run_name,
            config=vars(args),  # Log command line arguments
        )
        print(f"W&B Run initialized: {run_name}. View at: {wandb.run.get_url()}")
        wandb_is_initialized = True
    except Exception as e:
        print(f"W&B initialization failed: {e}. Proceeding without W&B logging.")
        # Optionally, initialize in disabled mode: wandb.init(mode="disabled")

    demo_runner = MVPDemoRunner()
    results_to_write = []

    try:
        for i in range(args.num_turns):
            turn_num = i + 1
            print(f"\n--- Generating Trajectory Turn {turn_num}/{args.num_turns} ---")
            # run_single_turn_demo uses the real LLM by default
            turn_result = await demo_runner.run_single_turn_demo()
            results_to_write.append(turn_result)

            if wandb_is_initialized and wandb.run:
                wandb_log_data = {
                    "turn": turn_num,
                    "task_id": turn_result.get("request_id", "N/A"),
                    "score": turn_result.get("score", 0.0),
                    "final_distance": turn_result.get("metadata", {}).get(
                        "final_distance", float("inf")
                    ),
                    "target_distance": turn_result.get("metadata", {}).get(
                        "target_distance", 0.0
                    ),
                    "side_condition_met": int(
                        turn_result.get("metadata", {}).get("side_condition_met", False)
                    ),
                }
                if turn_result.get("parsed_action"):
                    parsed_action = turn_result["parsed_action"]
                    wandb_log_data["action_object_id"] = parsed_action.get("object_id")
                    target_pos = parsed_action.get(
                        "target_position", [None, None, None]
                    )
                    wandb_log_data["action_target_x"] = (
                        target_pos[0] if target_pos and len(target_pos) > 0 else None
                    )
                    wandb_log_data["action_target_y"] = (
                        target_pos[1] if target_pos and len(target_pos) > 1 else None
                    )
                    wandb_log_data["action_target_z"] = (
                        target_pos[2] if target_pos and len(target_pos) > 2 else None
                    )

                wandb.log(wandb_log_data)
                print(
                    f"Logged to W&B: Turn {turn_num}, Score: {turn_result.get('score')}"
                )

            await asyncio.sleep(0.1)

        with open(args.output_file, "w") as f:
            for result_item in results_to_write:
                f.write(json.dumps(result_item) + "\n")
        print(
            f"\nSuccessfully wrote {len(results_to_write)} trajectories to {args.output_file}"
        )

    finally:
        if demo_runner.env.simulator:
            demo_runner.env.simulator.cleanup()
        if wandb_is_initialized and wandb.run:
            wandb.finish()
            print("Processing complete. W&B run finished.")
        else:
            print(
                "Processing complete. (W&B was not fully initialized or did not start a run)"
            )


async def server_mode():
    global shared_demo_runner_instance  # Make demo_runner available to the handler via this global
    shared_demo_runner_instance = MVPDemoRunner()

    websocket_server = await websockets.serve(
        visualization_websocket_handler,  # The handler uses shared_demo_runner_instance
        "localhost",
        8765,
    )
    print("Visualization WebSocket Server started on ws://localhost:8765")
    print("Open visualization/index.html in your browser.")
    print("You can run multiple demo turns. Press Ctrl+C to stop everything.")

    try:
        # Run 5 automatic demo turns in server mode, using the real LLM by default
        for i in range(5):
            print(f"\n--- Auto Demo Turn {i + 1} ---")
            await shared_demo_runner_instance.run_single_turn_demo(use_real_llm=True)
            await asyncio.sleep(2)

        print(
            "\nAutomatic demo turns complete. Server is still running for manual interaction or further tests."
        )
        await websocket_server.wait_closed()
    except KeyboardInterrupt:
        print("\nShutting down servers...")
    finally:
        websocket_server.close()
        await websocket_server.wait_closed()
        if shared_demo_runner_instance.env.simulator:
            shared_demo_runner_instance.env.simulator.cleanup()
        print("Servers and physics simulation stopped.")


async def main():
    parser = argparse.ArgumentParser(description="Spatial RL Environment MVP")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    process_parser = subparsers.add_parser("process", help="Generate trajectory data")
    process_parser.add_argument(
        "--num_turns",
        type=int,
        default=5,
        help="Number of trajectory turns to generate",
    )
    process_parser.add_argument(
        "--output_file",
        type=str,
        default="trajectories.jsonl",
        help="File to save trajectory data",
    )

    args, unknown = parser.parse_known_args()

    if args.command == "process":
        await process_mode(args)
    elif args.command is None and not unknown:
        print("No command specified, running in default server mode.")
        await server_mode()
    elif unknown:
        print(f"Unknown arguments or command: {unknown}")
        parser.print_help()
        sys.exit(1)
    else:
        parser.print_help()
        sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main())

@@ -0,0 +1,25 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>MVP Spatial RL Visualization</title>
    <link rel="stylesheet" href="style.css">
</head>
<body>
    <div id="info">
        <h1>MVP Spatial RL Visualization</h1>
        <p>Status: <span id="status">Connecting...</span></p>
        <p>Task: <span id="taskDescription">N/A</span></p>
    </div>
    <div id="visualizationContainer">
        <!-- Canvas will be inserted here by Three.js -->
    </div>

    <!-- Three.js Library -->
    <script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
    <!-- Optional: OrbitControls for camera interaction -->
    <script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/controls/OrbitControls.js"></script>

    <script type="module" src="main.js"></script>
</body>
</html>
environments/community/padres_spatial/visualization/main.js
@@ -0,0 +1,186 @@
// Make sure OrbitControls is loaded from CDN or included if you use it
// If OrbitControls is loaded globally via script tag: const OrbitControls = window.OrbitControls;
// const TWEEN = window.TWEEN; // If using TWEEN CDN

// --- Global Variables ---
let scene, camera, renderer, controls;
const objectsInScene = new Map(); // Stores THREE.Mesh objects by their ID
const websocketUrl = "ws://localhost:8765";
let socket;

// --- DOM Elements ---
const statusElement = document.getElementById('status');
const taskDescriptionElement = document.getElementById('taskDescription');
const visualizationContainer = document.getElementById('visualizationContainer');

// --- Initialization ---
function init() {
    // Scene
    scene = new THREE.Scene();
    scene.background = new THREE.Color(0xdddddd);

    // Camera
    camera = new THREE.PerspectiveCamera(75, visualizationContainer.clientWidth / visualizationContainer.clientHeight, 0.1, 1000);
    camera.position.set(3, 4, 5); // Adjusted camera position
    camera.lookAt(0, 0, 0);

    // Renderer
    renderer = new THREE.WebGLRenderer({ antialias: true });
    renderer.setSize(visualizationContainer.clientWidth, visualizationContainer.clientHeight);
    renderer.setPixelRatio(window.devicePixelRatio);
    renderer.shadowMap.enabled = true; // Enable shadows
    visualizationContainer.appendChild(renderer.domElement);

    // Lights
    const ambientLight = new THREE.AmbientLight(0xffffff, 0.6);
    scene.add(ambientLight);
    const directionalLight = new THREE.DirectionalLight(0xffffff, 0.8);
    directionalLight.position.set(5, 10, 7);
    directionalLight.castShadow = true; // Enable shadow casting for this light
    // Configure shadow properties for better quality (optional)
    directionalLight.shadow.mapSize.width = 1024;
    directionalLight.shadow.mapSize.height = 1024;
    directionalLight.shadow.camera.near = 0.5;
    directionalLight.shadow.camera.far = 50;
    scene.add(directionalLight);

    // Ground Plane
    const planeGeometry = new THREE.PlaneGeometry(20, 20);
    const planeMaterial = new THREE.MeshStandardMaterial({ color: 0xaaaaaa, roughness: 0.8 });
    const groundPlane = new THREE.Mesh(planeGeometry, planeMaterial);
    groundPlane.rotation.x = -Math.PI / 2;
    groundPlane.receiveShadow = true; // Allow plane to receive shadows
    scene.add(groundPlane);

    // Controls (optional, if OrbitControls is loaded)
    if (typeof OrbitControls !== 'undefined') {
        controls = new OrbitControls(camera, renderer.domElement);
        controls.enableDamping = true;
        controls.dampingFactor = 0.05;
        controls.screenSpacePanning = false;
        controls.minDistance = 2;
        controls.maxDistance = 20;
        controls.maxPolarAngle = Math.PI / 2 - 0.05; // Prevent camera from going below ground
    }

    // Handle window resize
    window.addEventListener('resize', onWindowResize, false);

    // Start animation loop
    animate();

    // Connect to WebSocket
    connectWebSocket();
}

// --- WebSocket Handling ---
function connectWebSocket() {
    socket = new WebSocket(websocketUrl);
    statusElement.textContent = "Connecting to WebSocket...";

    socket.onopen = () => {
        statusElement.textContent = "Connected to Physics Server!";
        console.log("WebSocket connected.");
        // You could send a "client_ready" message if needed
    };

    socket.onmessage = (event) => {
        try {
            const message = JSON.parse(event.data);
            if (message.type === "scene_update" || message.type === "initial_scene") {
                updateScene(message.payload); // payload is List[ObjectData]
                if (message.type === "initial_scene" && message.task_description) {
                    taskDescriptionElement.textContent = message.task_description;
                }
            } else if (message.type === "task_info") { // Example for updating task description
                taskDescriptionElement.textContent = message.description || "N/A";
            }
        } catch (e) {
            console.error("Error processing message from server:", e, event.data);
        }
    };

    socket.onerror = (error) => {
        statusElement.textContent = "WebSocket Error!";
        console.error("WebSocket Error:", error);
    };

    socket.onclose = () => {
        statusElement.textContent = "Disconnected. Attempting to reconnect in 3s...";
        console.log("WebSocket disconnected. Reconnecting in 3 seconds...");
        setTimeout(connectWebSocket, 3000); // Simple reconnect logic
    };
}

// --- Three.js Scene Updates ---
function updateScene(objectStates) { // objectStates is a list of object dicts from the server
    const receivedIds = new Set();

    objectStates.forEach(objState => {
        receivedIds.add(objState.id);
        let threeObject = objectsInScene.get(objState.id);

        if (!threeObject) { // Object doesn't exist yet, create it
            let geometry;
            const scale = objState.scale || [1, 1, 1];
            if (objState.type === "cube") {
                geometry = new THREE.BoxGeometry(scale[0], scale[1], scale[2]);
            } else if (objState.type === "sphere") {
                geometry = new THREE.SphereGeometry(scale[0] / 2, 32, 16); // Assume scale[0] is the diameter
            } else {
                console.warn("Unsupported object type for visualization:", objState.type);
                geometry = new THREE.BoxGeometry(1, 1, 1); // Default placeholder
            }

            const color = new THREE.Color(...(objState.color_rgba ? objState.color_rgba.slice(0, 3) : [0.5, 0.5, 0.5]));
            const material = new THREE.MeshStandardMaterial({ color: color, roughness: 0.5, metalness: 0.1 });

            threeObject = new THREE.Mesh(geometry, material);
            threeObject.name = objState.id; // Useful for debugging
            threeObject.castShadow = true; // Object casts shadows
            threeObject.receiveShadow = false; // Objects usually don't receive shadows on themselves

            scene.add(threeObject);
            objectsInScene.set(objState.id, threeObject);
        }

        // Update position and orientation
        if (objState.position) {
            threeObject.position.set(...objState.position);
        }
        if (objState.orientation_quaternion) {
            threeObject.quaternion.set(...objState.orientation_quaternion);
        }
        // TODO: Update color or other properties if they can change dynamically
    });

    // Remove objects that are in the Three.js scene but not in the new state
    objectsInScene.forEach((obj, id) => {
        if (!receivedIds.has(id)) {
            scene.remove(obj);
            obj.geometry.dispose(); // Dispose geometry
            obj.material.dispose(); // Dispose material
            objectsInScene.delete(id);
            console.log(`Removed object ${id} from scene.`);
        }
    });
}

// --- Animation Loop & Resize ---
function animate() {
    requestAnimationFrame(animate);
    if (controls) {
        controls.update(); // Only if OrbitControls is used
    }
    renderer.render(scene, camera);
}

function onWindowResize() {
    camera.aspect = visualizationContainer.clientWidth / visualizationContainer.clientHeight;
    camera.updateProjectionMatrix();
    // Bug fix: setSize takes (width, height), not their ratio
    renderer.setSize(visualizationContainer.clientWidth, visualizationContainer.clientHeight);
}

// --- Start Everything ---
init();

@@ -0,0 +1,65 @@
body {
    margin: 0;
    padding: 0;
    font-family: Arial, sans-serif;
    overflow: hidden;
    background-color: #f0f0f0;
    color: #333;
}

#container {
    display: flex;
    width: 100vw;
    height: 100vh;
}

#canvas-container {
    flex: 0.7;
    height: 100vh;
}

#info-panel {
    flex: 0.3;
    padding: 20px;
    background-color: #f5f5f5;
    border-left: 1px solid #ddd;
    overflow-y: auto;
}

#score {
    font-size: 1.2em;
    font-weight: bold;
    margin: 10px 0;
    padding: 10px;
    background-color: #fff;
    border-radius: 4px;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}

#objects-info {
    margin-top: 20px;
    line-height: 1.6;
}

#info {
    padding: 10px;
    background-color: #fff;
    box-shadow: 0 2px 5px rgba(0,0,0,0.1);
}

#info h1 {
    margin-top: 0;
    font-size: 1.5em;
}

#visualizationContainer {
    width: 100vw;
    height: 80vh;
    display: flex;
    justify-content: center;
    align-items: center;
}

canvas {
    display: block;
}