atropos/environments/community/padres_spatial/spatial_env.py
Dakota e13526d308 Fix API to accept messages without reward field + comprehensive tests
- Made reward field truly optional in messages (no auto-addition)
- Accept custom roles (dog, cat, etc.) beyond standard ones
- Added 24 new tests for edge cases (tuples, unicode, large content)
- Reorganized test structure: moved from testing/ to atroposlib/tests/
- Fixed legacy API tests and removed tests requiring missing data files

All 43 tests pass! Fixes message handling for SFT use cases.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-06-09 14:03:08 -05:00

745 lines
29 KiB
Python

import argparse
import asyncio
import json
import math
import sys
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import pybullet as p
import pybullet_data
import wandb
import websockets
# LLM Service Import
from .llm_services import get_anthropic_completion
@dataclass
class ObjectState:
    """Spawn-time description of a single rigid body in the scene.

    ``type`` selects the geometry; 'cube' and 'sphere' are the shapes the
    simulator can build. ``scale`` holds full extents: a cube uses them as
    edge lengths, a sphere uses ``scale[0]`` as its diameter.
    """

    id: str
    type: str  # 'cube' or 'sphere'
    position: List[float]  # [x, y, z] of the body's base
    # [0, 0, 0, 1] is the identity rotation (w stored last, as pybullet expects).
    orientation_quaternion: List[float] = field(default_factory=lambda: [0.0, 0.0, 0.0, 1.0])
    scale: List[float] = field(default_factory=lambda: [1.0] * 3)
    color_rgba: List[float] = field(default_factory=lambda: [0.5, 0.5, 0.5, 1.0])
@dataclass
class SpatialTask:
    """One episode: the starting scene plus the instructions given to the LLM.

    The episode is scored on how close ``target_object_id`` ends up to
    ``reference_object_id`` relative to ``target_distance``.
    """

    task_id: str
    description: str  # task statement shown to the model
    initial_objects: List[ObjectState]  # scene spawned at episode start
    goal_description: str  # success criteria shown to the model
    target_object_id: str  # the only object the model's action may move
    reference_object_id: str  # the object the distance is measured against
    target_distance: float = 1.0  # desired center-to-center distance
class MVPPhysicsSimulator:
    """Headless pybullet world holding the objects of the current task.

    Objects are addressed by their string ``ObjectState.id``; the matching
    pybullet body handles and spawn configurations are kept in
    ``objects_pb_ids`` and ``object_configs``.
    """

    def __init__(self):
        self.client_id = -1  # -1 while no pybullet connection is open
        self.objects_pb_ids: Dict[str, int] = {}
        self.object_configs: Dict[str, ObjectState] = {}

    def initialize(self, objects: List[ObjectState]):
        """Start a fresh ``p.DIRECT`` world with a ground plane and ``objects``.

        An existing connection is closed first, so each call begins from an
        empty scene.
        """
        if self.client_id != -1:
            p.disconnect(physicsClientId=self.client_id)
        self.client_id = p.connect(p.DIRECT)
        # Scope the search path to this client like every other call here,
        # instead of relying on the default client id.
        p.setAdditionalSearchPath(
            pybullet_data.getDataPath(), physicsClientId=self.client_id
        )
        p.setGravity(0, 0, -9.8, physicsClientId=self.client_id)
        p.loadURDF("plane.urdf", physicsClientId=self.client_id)
        self.objects_pb_ids = {}
        self.object_configs = {}
        for obj_state in objects:
            self._add_object(obj_state)
        print(f"Physics initialized with {len(self.objects_pb_ids)} objects.")

    def _add_object(self, obj_state: ObjectState):
        """Spawn one body (mass 1) for ``obj_state`` and register it by id.

        Cubes use ``scale`` as edge lengths; spheres use ``scale[0]`` as the
        diameter. Unsupported types are skipped with a warning.
        """
        half_extents = [s / 2.0 for s in obj_state.scale]
        # Choose the geometry once; the collision and visual shapes share it.
        if obj_state.type == "cube":
            shape_type = p.GEOM_BOX
            geometry = {"halfExtents": half_extents}
        elif obj_state.type == "sphere":
            shape_type = p.GEOM_SPHERE
            geometry = {"radius": half_extents[0]}
        else:
            print(
                f"Warning: Unsupported object type '{obj_state.type}' for object ID '{obj_state.id}'"
            )
            return
        shape_id = p.createCollisionShape(
            shape_type, physicsClientId=self.client_id, **geometry
        )
        visual_shape_id = p.createVisualShape(
            shapeType=shape_type,
            rgbaColor=obj_state.color_rgba,
            physicsClientId=self.client_id,
            **geometry,
        )
        body_id = p.createMultiBody(
            baseMass=1,
            baseCollisionShapeIndex=shape_id,
            baseVisualShapeIndex=visual_shape_id,
            basePosition=obj_state.position,
            baseOrientation=obj_state.orientation_quaternion,
            physicsClientId=self.client_id,
        )
        self.objects_pb_ids[obj_state.id] = body_id
        self.object_configs[obj_state.id] = obj_state

    def move_object(
        self,
        object_id: str,
        target_position: List[float],
        target_orientation_quaternion: Optional[List[float]] = None,
    ):
        """Reset ``object_id``'s pose to ``target_position``.

        The current orientation is kept when no quaternion is given. Unknown
        ids are reported and ignored.
        """
        if object_id not in self.objects_pb_ids:
            print(f"Warning: Attempted to move unknown object ID '{object_id}'")
            return
        body_id = self.objects_pb_ids[object_id]
        if target_orientation_quaternion is None:
            _, current_orientation = p.getBasePositionAndOrientation(
                body_id, physicsClientId=self.client_id
            )
            target_orientation_quaternion = list(current_orientation)
        p.resetBasePositionAndOrientation(
            body_id,
            target_position,
            target_orientation_quaternion,
            physicsClientId=self.client_id,
        )

    def simulate_steps(self, steps: int = 10):
        """Advance the simulation by ``steps`` calls to ``p.stepSimulation``."""
        for _ in range(steps):
            p.stepSimulation(physicsClientId=self.client_id)

    def get_current_state_for_visualization(self) -> List[Dict[str, Any]]:
        """Return one JSON-serializable dict per object with its live pose.

        Position and orientation come from pybullet; type, scale and color
        come from the spawn configuration.
        """
        viz_state = []
        for obj_id, body_id in self.objects_pb_ids.items():
            pos, orn_quat = p.getBasePositionAndOrientation(
                body_id, physicsClientId=self.client_id
            )
            original_config = self.object_configs.get(obj_id)
            if original_config:
                viz_state.append(
                    {
                        "id": obj_id,
                        "type": original_config.type,
                        "position": list(pos),
                        "orientation_quaternion": list(orn_quat),
                        "scale": original_config.scale,
                        "color_rgba": original_config.color_rgba,
                    }
                )
        return viz_state

    def calculate_distance(self, obj1_id: str, obj2_id: str) -> float:
        """Euclidean distance between two objects' current base positions.

        Returns ``float("inf")`` when either id is unknown. Only the two
        requested bodies are queried.
        """
        body1 = self.objects_pb_ids.get(obj1_id)
        body2 = self.objects_pb_ids.get(obj2_id)
        if body1 is None or body2 is None:
            return float("inf")
        pos1, _ = p.getBasePositionAndOrientation(
            body1, physicsClientId=self.client_id
        )
        pos2, _ = p.getBasePositionAndOrientation(
            body2, physicsClientId=self.client_id
        )
        return math.sqrt(sum((a - b) ** 2 for a, b in zip(pos1, pos2)))

    def cleanup(self):
        """Disconnect from pybullet if connected (safe to call repeatedly)."""
        if self.client_id != -1:
            p.disconnect(physicsClientId=self.client_id)
            self.client_id = -1
        print("Physics simulation cleaned up.")
# Open websocket connections of browser visualizers; scene updates are
# broadcast to every member.
connected_visualization_clients = set()

# Simulator whose state is sent to a visualizer right after it connects.
global_physics_simulator_instance: Optional[MVPPhysicsSimulator] = None

# Demo runner the websocket handler uses to serve 'next_llm_task' in server
# mode. Quoted because MVPDemoRunner is defined later in this module.
shared_demo_runner_instance: Optional["MVPDemoRunner"] = None
async def notify_visualization_clients(
    scene_state: List[Dict[str, Any]], clients=None
):
    """Broadcast ``scene_state`` to visualizers as a ``scene_update`` message.

    Args:
        scene_state: JSON-serializable object descriptions (the list built by
            ``MVPPhysicsSimulator.get_current_state_for_visualization``).
        clients: Optional iterable of websocket-like objects exposing an async
            ``send``; defaults to all ``connected_visualization_clients``.

    A failed send (e.g. a client that disconnected mid-broadcast) is logged
    and skipped instead of propagating, so a flaky viewer cannot abort the
    episode that triggered the update.
    """
    # Snapshot the recipients: the handler may remove clients while we await.
    targets = list(connected_visualization_clients if clients is None else clients)
    if not targets:
        return
    message = json.dumps({"type": "scene_update", "payload": scene_state})
    outcomes = await asyncio.gather(
        *(client.send(message) for client in targets), return_exceptions=True
    )
    for client, outcome in zip(targets, outcomes):
        if isinstance(outcome, BaseException):
            peer = getattr(client, "remote_address", client)
            print(f"Warning: could not send scene update to {peer}: {outcome!r}")
# Strong references to fire-and-forget tasks started by the websocket handler.
# The event loop keeps only weak references to tasks, so an unreferenced task
# can be garbage-collected before it finishes.
_background_tasks = set()


def _spawn_background_task(coro):
    """Schedule ``coro`` as a task, keep it alive until done, and log failures."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)

    def _on_done(finished):
        _background_tasks.discard(finished)
        # Retrieve the exception so it is reported instead of silently lost.
        if not finished.cancelled() and finished.exception() is not None:
            print(f"Error in background task: {finished.exception()}")

    task.add_done_callback(_on_done)
    return task


def _handle_client_command(message_str):
    """Decode one visualizer message and act on its ``command`` field."""
    try:
        data = json.loads(message_str)
        # Only JSON objects can carry a command; anything else is "unknown".
        command = data.get("command") if isinstance(data, dict) else None
        if command == "next_llm_task":
            print("Received 'next_llm_task' command from client.")
            if shared_demo_runner_instance:
                _spawn_background_task(
                    shared_demo_runner_instance.run_single_turn_demo(use_real_llm=True)
                )
            else:
                print(
                    "Error: shared_demo_runner_instance not found to execute 'next_llm_task'"
                )
        else:
            print(f"Unknown command received: {command}")
    except json.JSONDecodeError:
        print(f"Invalid JSON from client: {message_str}")
    except Exception as e:
        print(f"Error processing client command: {e}")


async def visualization_websocket_handler(websocket):
    """Serve one visualizer connection for its whole lifetime.

    Registers ``websocket`` for scene broadcasts and, when a simulator
    exists, sends it the current scene as an ``initial_scene`` message. It
    then processes JSON commands; ``next_llm_task`` starts a demo turn on
    ``shared_demo_runner_instance`` in the background so this connection
    keeps receiving messages meanwhile.
    """
    connected_visualization_clients.add(websocket)
    print(
        f"Visualization client connected: {websocket.remote_address} (Total: {len(connected_visualization_clients)})"
    )
    try:
        if global_physics_simulator_instance is not None:
            initial_state = (
                global_physics_simulator_instance.get_current_state_for_visualization()
            )
            await websocket.send(
                json.dumps({"type": "initial_scene", "payload": initial_state})
            )
        async for message_str in websocket:
            print(f"Message from viz client: {message_str}")
            _handle_client_command(message_str)
    except websockets.exceptions.ConnectionClosed:
        # Abnormal close; the disconnect is reported in ``finally`` below,
        # which also covers a normal close (where the loop simply ends).
        pass
    except Exception as e:
        print(f"Error in visualization_websocket_handler: {e}")
    finally:
        connected_visualization_clients.discard(websocket)
        print(
            f"Visualization client disconnected. (Total: {len(connected_visualization_clients)})"
        )
class SpatialEnvironmentMVP:
    """Single-turn spatial-reasoning environment on top of ``MVPPhysicsSimulator``.

    ``get_next_item`` sets up a red cube and a blue sphere and returns an LLM
    prompt; ``collect_trajectories`` executes the model's ``move_object``
    action in the simulator and scores the resulting scene.
    """

    def __init__(self):
        global global_physics_simulator_instance
        self.simulator = MVPPhysicsSimulator()
        # Publish the simulator so newly connected visualizers receive its state.
        global_physics_simulator_instance = self.simulator
        self.current_task: Optional[SpatialTask] = None
        self.task_id_counter = 0

    async def get_next_item(self) -> Dict[str, Any]:
        """Start a new episode and return its ``task_id`` and ``llm_prompt``.

        Re-initializes the simulator with the task's objects and pushes the
        new scene to connected visualizers.
        """
        self.task_id_counter += 1
        task_id = f"conditional_task_{self.task_id_counter}_{uuid.uuid4().hex[:4]}"
        # Start objects on opposite sides, e.g., along the x-axis
        objects = [
            ObjectState(
                id="red_cube",
                type="cube",
                position=[2.0, 0.5, 0.5],
                scale=[1, 1, 1],
                color_rgba=[1, 0, 0, 1],
            ),
            ObjectState(
                id="blue_sphere",
                type="sphere",
                position=[-2.0, 0.5, 0.5],
                scale=[1, 1, 1],
                color_rgba=[0, 0, 1, 1],
            ),
        ]
        task_description = (
            "The red cube and blue sphere are on opposite sides of the YZ plane (different X signs). "
            "Move the red cube so it remains on the opposite side of the YZ plane from the blue sphere, "
            "but position it very close to the blue sphere (approximately 1.0 unit away)."
        )
        goal_description = (
            "The red_cube's final x-coordinate should have the opposite sign to the blue_sphere's x-coordinate. "
            "The distance between the center of the red_cube and the center of the blue_sphere "
            "should be approximately 1.0 unit."
        )
        task = SpatialTask(
            task_id=task_id,
            description=task_description,
            initial_objects=objects,
            goal_description=goal_description,
            target_object_id="red_cube",
            reference_object_id="blue_sphere",
            target_distance=1.0,  # This remains the target for proximity
        )
        self.current_task = task
        self.simulator.initialize(task.initial_objects)
        await notify_visualization_clients(
            self.simulator.get_current_state_for_visualization()
        )
        return {
            "task_id": task.task_id,
            # The objects just built are the initial scene described in the prompt.
            "llm_prompt": self._create_llm_prompt(task, objects),
        }

    def _create_llm_prompt(
        self, task: SpatialTask, initial_objects_state: List[ObjectState]
    ) -> str:
        """Render the instruction prompt for ``task``.

        Positions are taken from ``initial_objects_state``; object ids in the
        hint come from the task instead of being hard-coded.
        """

        def fmt_position(position: List[float]) -> str:
            return f"[{position[0]:.2f}, {position[1]:.2f}, {position[2]:.2f}]"

        objects_desc = "\n".join(
            f"- ID: {obj.id}, Type: {obj.type}, Current Position: {fmt_position(obj.position)}"
            for obj in initial_objects_state
        )
        ref_id = task.reference_object_id
        target_id = task.target_object_id
        ref_obj_pos_str = next(
            (
                fmt_position(obj.position)
                for obj in initial_objects_state
                if obj.id == ref_id
            ),
            "N/A",
        )
        hint = (
            f"Hint: The {ref_id} (reference object) is currently at {ref_obj_pos_str}. "
            f"To keep the {target_id} on the opposite side of the YZ plane, its x-coordinate "
            f"should generally have the opposite sign to the {ref_id}'s x-coordinate. "
            f"Adjust its position to be about {task.target_distance:.1f} unit away from the {ref_id}."
        )
        prompt_lines = [
            f"Task: {task.description}",
            f"Goal: {task.goal_description}",
            "Available Objects (initial state):",
            objects_desc,
            hint,
            f"You control: '{target_id}'.",
            "Your action MUST be a JSON object like:",
            "{",
            '"action_type": "move_object",',
            f'"object_id": "{target_id}",',
            '"target_position": [x_float, y_float, z_float] # New target coordinates',
            "}",
            "Only provide the JSON for the action. Do not add any other text or explanations.",
            "Your JSON action:",
        ]
        return "\n".join(prompt_lines)

    async def collect_trajectories(
        self, item_from_get_next: Dict[str, Any], llm_completion_raw: Optional[str]
    ) -> Dict[str, Any]:
        """Execute one action for the current task and score the outcome.

        Args:
            item_from_get_next: The dict returned by ``get_next_item``.
            llm_completion_raw: A model completion to evaluate. A non-blank
                string is used as-is; ``""``/``None`` makes this method request
                the completion from ``get_anthropic_completion`` (30 s timeout).

        Returns:
            A trajectory record with ``request_id``, ``prompt_used``,
            ``llm_completion_raw``, ``parsed_action``, ``score`` and
            ``metadata``. If no task is active, returns ``{"error", "score"}``
            instead. When the completion is missing, blank or ".", a fixed mock
            action is substituted and stored as ``llm_completion_raw``, and
            ``metadata["used_fallback_action"]`` is set so such records can be
            told apart from real model output.
        """
        if not self.current_task:
            return {
                "error": "No current task set. Call get_next_item first.",
                "score": 0.0,
            }
        task = self.current_task
        prompt = item_from_get_next["llm_prompt"]

        if isinstance(llm_completion_raw, str) and llm_completion_raw.strip():
            completion = llm_completion_raw
            print(f"DEBUG SPATIAL_ENV: Using caller-supplied completion: '{completion}'")
        else:
            completion = await self._request_completion(prompt)
            print(
                "DEBUG SPATIAL_ENV: llm_completion_raw received from "
                f"get_anthropic_completion: '{completion}'"
            )

        used_fallback_action = self._is_placeholder_completion(completion)
        if used_fallback_action:
            print(
                f"DEBUG SPATIAL_ENV: llm_completion_raw is invalid ('{completion}'). "
                f"Using internal mock action."
            )
            completion = json.dumps(
                {
                    "action_type": "move_object",
                    "object_id": task.target_object_id,
                    "target_position": [1.0, 0.5, 0.5],
                }
            )
            print(f"DEBUG SPATIAL_ENV: Substituted internal mock: {completion}")

        parsed_action = self._parse_action(completion, task.target_object_id)
        if parsed_action is not None:
            self.simulator.move_object(
                object_id=parsed_action["object_id"],
                target_position=parsed_action["target_position"],
            )
            self.simulator.simulate_steps(20)
            await notify_visualization_clients(
                self.simulator.get_current_state_for_visualization()
            )
            await asyncio.sleep(0.05)
        else:
            print("No valid action parsed or executed. Scoring based on current state.")
            self.simulator.simulate_steps(5)
            await notify_visualization_clients(
                self.simulator.get_current_state_for_visualization()
            )

        distance = self.simulator.calculate_distance(
            task.target_object_id, task.reference_object_id
        )
        final_sim_state_viz = self.simulator.get_current_state_for_visualization()
        side_condition_met = self._side_condition_met(task, final_sim_state_viz)
        score = self._distance_score(distance, task.target_distance)
        if side_condition_met:
            # Bonus for the side condition, independent of the distance band.
            score += 0.2
        score = round(min(score, 1.0), 2)  # Cap score at 1.0 and round
        return {
            "request_id": task.task_id,
            "prompt_used": prompt,
            "llm_completion_raw": completion,
            "parsed_action": parsed_action,
            "score": score,
            "metadata": {
                "task_description": task.description,
                "final_distance": round(distance, 2),
                "target_distance": task.target_distance,
                "side_condition_met": side_condition_met,
                "used_fallback_action": used_fallback_action,
                "final_sim_state_viz": final_sim_state_viz,
            },
        }

    @staticmethod
    async def _request_completion(prompt: str) -> Optional[str]:
        """Ask the LLM service for a completion; ``None`` on timeout or error."""
        print(
            "\nDEBUG SPATIAL_ENV: Calling get_anthropic_completion with prompt... Timeout in 30s"
        )
        try:
            # Bound the call so a stuck request cannot hang the episode.
            return await asyncio.wait_for(
                get_anthropic_completion(prompt), timeout=30.0
            )
        except asyncio.TimeoutError:
            print(
                "DEBUG SPATIAL_ENV: LLM call timed out after 30s. Using fallback mock."
            )
        except Exception as e:
            print(
                f"DEBUG SPATIAL_ENV: Error during get_anthropic_completion call: {e}. Using fallback mock."
            )
        return None

    @staticmethod
    def _is_placeholder_completion(completion: Any) -> bool:
        """True for a missing, non-string, blank or "." completion."""
        return not isinstance(completion, str) or completion.strip() in ("", ".")

    @staticmethod
    def _strip_hash_comments(text: str) -> str:
        """Drop trailing ``# ...`` comments, as in the prompt's own JSON example.

        A ``#`` starts a comment only if no double quote follows it on the
        same line, so ``#`` characters inside string values are kept.
        """
        cleaned_lines = []
        for line in text.splitlines():
            idx = line.find("#")
            while idx != -1:
                if '"' not in line[idx:]:
                    line = line[:idx].rstrip()
                    break
                idx = line.find("#", idx + 1)
            cleaned_lines.append(line)
        return "\n".join(cleaned_lines)

    @classmethod
    def _json_candidates(cls, raw: str) -> List[str]:
        """Progressively more lenient strings to try decoding as the action JSON."""
        text = raw.strip()
        if text.startswith("```json"):
            text = text[7:]
        if text.startswith("```"):
            text = text[3:]
        if text.endswith("```"):
            text = text[:-3]
        text = text.strip()
        candidates = [text]
        # Fall back to the outermost {...} span, for replies wrapped in prose
        # or an unusual code-fence language tag.
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            braced = text[start : end + 1]
            candidates.append(braced)
            candidates.append(cls._strip_hash_comments(braced))
        return candidates

    @staticmethod
    def _is_finite_number(value: Any) -> bool:
        """True for a finite int/float (bool excluded)."""
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return False
        try:
            return math.isfinite(value)
        except OverflowError:  # int too large to convert to float
            return False

    @classmethod
    def _parse_action(
        cls, raw: str, target_object_id: str
    ) -> Optional[Dict[str, Any]]:
        """Decode and validate a ``move_object`` action from an LLM reply.

        Returns the decoded dict, or ``None`` (after printing a warning) when
        no candidate decodes as JSON, or the result is not a move of
        ``target_object_id`` to exactly three finite numeric coordinates.
        Such coordinates could not be passed to the simulator safely.
        """
        missing = object()
        action_data: Any = missing
        first_error = None
        for candidate in cls._json_candidates(raw):
            try:
                action_data = json.loads(candidate)
                break
            except (ValueError, RecursionError) as e:
                if first_error is None:
                    first_error = e
        if action_data is missing:
            print(f"Warning: LLM response not valid JSON: {raw}. Error: {first_error}")
            return None
        if isinstance(action_data, dict):
            target = action_data.get("target_position")
            if (
                action_data.get("action_type") == "move_object"
                and action_data.get("object_id") == target_object_id
                and isinstance(target, list)
                and len(target) == 3
                and all(cls._is_finite_number(v) for v in target)
            ):
                return action_data
        print(f"Warning: LLM action malformed or targets wrong object: {action_data}")
        return None

    @staticmethod
    def _x_sign(x: float) -> float:
        """Sign of ``x`` as 1.0 / -1.0, or 0.0 for zero."""
        return math.copysign(1.0, x) if x != 0 else 0.0

    def _side_condition_met(
        self, task: SpatialTask, final_state: List[Dict[str, Any]]
    ) -> bool:
        """Check which side of the YZ plane the target object ended up on.

        Compares the target's final x sign with the reference object's
        *initial* x sign. If the reference started at x == 0, the target
        must end within 0.5 of x == 0.
        """
        initial_ref_pos = next(
            (
                obj.position
                for obj in task.initial_objects
                if obj.id == task.reference_object_id
            ),
            None,
        )
        final_target_pos = next(
            (
                obj["position"]
                for obj in final_state
                if obj["id"] == task.target_object_id
            ),
            None,
        )
        if not (initial_ref_pos and final_target_pos):
            print(
                "DEBUG SCORING: Could not determine initial/final positions for side condition check."
            )
            return False
        ref_sign = self._x_sign(initial_ref_pos[0])
        target_sign = self._x_sign(final_target_pos[0])
        # NOTE(review): the task description, goal and hint all ask for the
        # target's x to have the OPPOSITE sign to the reference's, but this
        # check (kept as originally written) rewards the SAME sign. With the
        # reference at x=-2 and a 1.0 target distance, the prompt's two
        # requirements cannot both hold (any x > 0 is over 2 units away).
        # Decide which rule is intended and align the prompt and the scorer.
        if ref_sign != 0:
            met = target_sign == ref_sign
        else:
            met = abs(final_target_pos[0]) < 0.5
        print(
            f"DEBUG SCORING: InitialRefXSign: {ref_sign}, "
            f"FinalTargetXSign: {target_sign}, SideConditionMet: {met}"
        )
        return met

    @staticmethod
    def _distance_score(distance: float, target_distance: float) -> float:
        """Banded proximity reward: up to 0.8, and 0.0 beyond 2.5x the target."""
        bands = ((1.0, 0.8), (1.25, 0.6), (1.75, 0.4), (2.5, 0.2))
        for multiplier, points in bands:
            if distance <= target_distance * multiplier:
                return points
        return 0.0
class MVPDemoRunner:
    """Runs one environment episode at a time and prints a short summary."""

    def __init__(self):
        self.env = SpatialEnvironmentMVP()

    async def run_single_turn_demo(self, use_real_llm: bool = True):
        """Run one task end to end and return the ``collect_trajectories`` result.

        ``use_real_llm`` is accepted for API compatibility but is not read
        here. An empty placeholder completion is passed on, so obtaining the
        model's answer is left to ``collect_trajectories``.
        """
        print("\n--- Running MVP Demo Turn ---")
        item = await self.env.get_next_item()
        tid = item["task_id"]
        print(f"Task ID: {tid}")

        outcome = await self.env.collect_trajectories(item, "")

        print(f"\n--- Result for Task {tid} ---")
        print(f"Final Score: {outcome['score']:.2f}")
        meta = outcome["metadata"]
        print(
            f"Achieved Distance: {meta['final_distance']:.2f} "
            f"(Target: {meta['target_distance']:.2f})"
        )
        return outcome
async def process_mode(args):
    """Generate ``args.num_turns`` scored trajectories and write them as JSONL.

    The output file is opened before the first turn, and each result is
    written and flushed as soon as its turn finishes. A bad path therefore
    fails before any LLM calls, and an interrupted run keeps every
    completed trajectory. Per-turn metrics go to Weights & Biases when
    ``wandb.init`` succeeds; otherwise the run continues without logging.
    The simulator is cleaned up and any W&B run is finished on exit.
    """
    print(
        f"Running in 'process' mode: generating {args.num_turns} trajectories to {args.output_file}"
    )
    run_name = f"padres_process_{args.num_turns}turns_{uuid.uuid4().hex[:4]}"
    wandb_is_initialized = False
    try:
        wandb.init(
            project="nous_hackathon_padres",  # Project name for W&B
            name=run_name,
            config=vars(args),  # Log command line arguments
        )
        # Flag the run as live immediately so the ``finally`` below always
        # finishes it, even if reading its URL fails.
        wandb_is_initialized = True
        print(f"W&B Run initialized: {run_name}. View at: {wandb.run.get_url()}")
    except Exception as e:
        if wandb_is_initialized:
            print(f"W&B run started, but its URL could not be read: {e}")
        else:
            print(f"W&B initialization failed: {e}. Proceeding without W&B logging.")

    demo_runner = MVPDemoRunner()
    written = 0
    try:
        with open(args.output_file, "w", encoding="utf-8") as out_file:
            for turn_num in range(1, args.num_turns + 1):
                print(f"\n--- Generating Trajectory Turn {turn_num}/{args.num_turns} ---")
                turn_result = await demo_runner.run_single_turn_demo()
                out_file.write(json.dumps(turn_result) + "\n")
                out_file.flush()  # persist this turn before starting the next
                written += 1

                if wandb_is_initialized and wandb.run:
                    metadata = turn_result.get("metadata", {})
                    wandb_log_data = {
                        "turn": turn_num,
                        "task_id": turn_result.get("request_id", "N/A"),
                        "score": turn_result.get("score", 0.0),
                        "final_distance": metadata.get("final_distance", float("inf")),
                        "target_distance": metadata.get("target_distance", 0.0),
                        "side_condition_met": int(
                            metadata.get("side_condition_met", False)
                        ),
                    }
                    parsed_action = turn_result.get("parsed_action")
                    if parsed_action:
                        wandb_log_data["action_object_id"] = parsed_action.get("object_id")
                        target_pos = parsed_action.get("target_position") or []
                        for axis_index, axis in enumerate("xyz"):
                            wandb_log_data[f"action_target_{axis}"] = (
                                target_pos[axis_index]
                                if len(target_pos) > axis_index
                                else None
                            )
                    wandb.log(wandb_log_data)
                    print(
                        f"Logged to W&B: Turn {turn_num}, Score: {turn_result.get('score')}"
                    )
                await asyncio.sleep(0.1)
        print(f"\nSuccessfully wrote {written} trajectories to {args.output_file}")
    finally:
        if demo_runner.env.simulator:
            demo_runner.env.simulator.cleanup()
        if wandb_is_initialized and wandb.run:
            wandb.finish()
            print("Processing complete. W&B run finished.")
        else:
            print(
                "Processing complete. (W&B was not fully initialized or did not start a run)"
            )
async def server_mode():
    """Serve visualizers on ws://localhost:8765 and run five automatic demo turns.

    After the automatic turns the server keeps running, and clients can
    request more turns with the 'next_llm_task' command. This continues
    until this coroutine is cancelled or interrupted (Ctrl+C). The websocket
    server and the physics simulation are always shut down on exit.
    """
    global shared_demo_runner_instance  # read by visualization_websocket_handler
    shared_demo_runner_instance = MVPDemoRunner()
    websocket_server = await websockets.serve(
        visualization_websocket_handler,
        "localhost",
        8765,
    )
    print("Visualization WebSocket Server started on ws://localhost:8765")
    print("Open visualization/index.html in your browser.")
    print("You can run multiple demo turns. Press Ctrl+C to stop everything.")
    try:
        # Five automatic turns, each requesting a real LLM completion.
        for i in range(5):
            print(f"\n--- Auto Demo Turn {i+1} ---")
            await shared_demo_runner_instance.run_single_turn_demo(use_real_llm=True)
            await asyncio.sleep(2)
        print(
            "\nAutomatic demo turns complete. Server is still running for manual interaction or further tests."
        )
        # Nothing in this module closes the server before ``finally``, so this
        # effectively waits until shutdown.
        await websocket_server.wait_closed()
    except KeyboardInterrupt:
        print("\nShutting down servers...")
    except asyncio.CancelledError:
        # Under asyncio.run(), Ctrl+C reaches this coroutine as cancellation
        # (directly on Python 3.11+), not as KeyboardInterrupt. Announce the
        # shutdown and let the cancellation propagate.
        print("\nShutting down servers...")
        raise
    finally:
        websocket_server.close()
        await websocket_server.wait_closed()
        if shared_demo_runner_instance.env.simulator:
            shared_demo_runner_instance.env.simulator.cleanup()
        print("Servers and physics simulation stopped.")
async def main():
    """Parse the command line and dispatch to ``process`` or server mode.

    With no subcommand and no stray arguments, the visualization server is
    started. Unrecognized arguments print help and exit with status 1.
    """
    parser = argparse.ArgumentParser(description="Spatial RL Environment MVP")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")
    process_parser = subparsers.add_parser("process", help="Generate trajectory data")
    process_options = (
        ("--num_turns", int, 5, "Number of trajectory turns to generate"),
        ("--output_file", str, "trajectories.jsonl", "File to save trajectory data"),
    )
    for flag, kind, default_value, help_text in process_options:
        process_parser.add_argument(
            flag, type=kind, default=default_value, help=help_text
        )

    args, unknown = parser.parse_known_args()
    if args.command == "process":
        await process_mode(args)
        return
    if args.command is None and not unknown:
        print("No command specified, running in default server mode.")
        await server_mode()
        return
    # Anything else is a usage error.
    if unknown:
        print(f"Unknown arguments or command: {unknown}")
    parser.print_help()
    sys.exit(1)
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C is the advertised way to stop server mode. Exit without a
        # traceback, using the conventional 128 + SIGINT status.
        print("\nInterrupted.")
        sys.exit(130)