Linting, move env to community

This commit is contained in:
Shannon Sands 2025-05-27 08:53:06 +10:00
parent 67e057b13c
commit 8b09ace467
18 changed files with 945 additions and 646 deletions

View file

@ -1807,6 +1807,239 @@ Access the game at `http://localhost:3000` when running the server.
---
### 22. Physical Space STL CAD RL Environment (`physical_space_stl/`)
**Contributors**: ecsbeats, venkatacrc
**PR**: [#76](https://github.com/NousResearch/atropos/pull/76)
**Integration Status**: ✅ Integrated
**Description**: A reinforcement learning environment for training language models to generate STL (stereolithography) files from 3D wireframe views and technical drawings. This environment bridges the gap between visual 3D understanding and CAD file generation, enabling AI systems to learn computer-aided design skills.
**Core Features**:
**3D Rendering Pipeline**:
- **PyRender Integration**: Offline 3D rendering with GPU acceleration (EGL) or CPU fallback (OSMesa)
- **Multi-View Generation**: Automatic generation of front, top, and diagonal wireframe views
- **Blueprint Styling**: Technical drawing aesthetics with blue wireframes on light backgrounds
- **Mesh Processing**: Support for complex STL files with automatic centering and scaling
**STL Generation Training**:
- **ASCII STL Format**: Focus on human-readable STL file generation
- **Template Variety**: Multiple query templates to encourage diverse reasoning approaches
- **Geometric Understanding**: Training on shape analysis, dimensions, and 3D spatial relationships
- **Quality Assessment**: Multi-metric evaluation of generated vs. original meshes
**Evaluation System**:
- **CLIP-Based Scoring**: Visual similarity assessment between rendered views
- **Geometric Metrics**: Comparison of vertices, faces, volume, and surface area
- **Mesh Validation**: Automatic validation of generated STL file structure
- **Progressive Difficulty**: Adaptive training with increasing geometric complexity
**Technical Architecture**:
**Environment Interface**:
```python
from environments.community.physical_space_stl.physical_env import PhysicalEnv
# Initialize environment
env_config, server_configs = PhysicalEnv.config_init()
env = PhysicalEnv(env_config, server_configs)
# Training loop
await env.setup()
item = await env.get_next_item()
# item contains: prompt, image (rendered views), stl_path
```
**Data Pipeline**:
- **STL File Loading**: Automatic discovery and loading of STL files from sample_data directory
- **Train/Test Split**: 80/20 split with reproducible random seeding
- **Image Rendering**: Real-time generation of wireframe views for each STL file
- **Query Generation**: Dynamic prompt creation with multiple template variations
**Rendering System**:
```python
from environments.community.physical_space_stl.pyrender_utils import PyRenderOffline
# Initialize renderer
renderer = PyRenderOffline(width=224, height=224) # CLIP-compatible size
# Render mesh to multiple views
images = renderer.render_mesh_to_images(mesh)
# Returns: [front_view, top_view, diagonal_view]
```
**Camera Configuration**:
- **Front View**: Standard orthographic projection along Z-axis
- **Top View**: Overhead perspective for plan view understanding
- **Diagonal View**: 3D perspective for spatial relationship comprehension
- **Lighting Setup**: Multi-point lighting for clear wireframe visibility
**STL Processing**:
**File Format Support**:
- **ASCII STL**: Primary focus for human-readable generation
- **Binary STL**: Loading support for existing files
- **Mesh Validation**: Automatic checking of facet normals and vertex ordering
- **Error Handling**: Graceful degradation for malformed files
**Quality Metrics**:
```python
def score_meshes_similarity(original_mesh, generated_mesh):
# Multi-dimensional similarity assessment
metrics = {
"vertex_ratio": min(gen_vertices / orig_vertices, 1.0),
"face_ratio": min(gen_faces / orig_faces, 1.0),
"volume_ratio": min(gen_volume / orig_volume, 1.0),
"area_ratio": min(gen_area / orig_area, 1.0)
}
return sum(metrics.values()) / len(metrics)
```
**Training Data Generation**:
**Dataset Creation Pipeline**:
```bash
# Generate training dataset
python dataset_scr.py # Create directory structure
python render_stl.py # Generate images from STL files
python llm_label.py # Create text descriptions
python prepare_push_hf_dataset.py # Upload to Hugging Face
```
**Data Structure**:
```
dataset/
├── stls/ # Original STL files
│ ├── model_0001.stl
│ └── model_0002.stl
├── images/ # Rendered wireframe views
│ ├── model_0001.png
│ └── model_0002.png
└── labels.json # Text descriptions and metadata
```
**Hugging Face Integration**:
- **Dataset Upload**: Automatic preparation and upload to HF Hub
- **Feature Extraction**: STL geometric features (centroid, bounding box, volume)
- **Image Processing**: Standardized image formats for training
- **Metadata Storage**: JSON-serialized geometric properties
**System Prompt Design**:
**Expert Persona**: "You are an expert in 3D modeling and computer-aided design..."
**Task Specification**:
- **Input**: Wireframe views and technical drawings
- **Output**: Valid ASCII STL file content
- **Reasoning**: Encouraged use of `<think>` tags for geometric analysis
- **Format**: Strict `<stl>` tag enclosure for generated content
**Example Templates**:
- "Create a 3D model (STL file) for the object shown in these technical drawings. Be precise with the geometry."
- "Based on these wireframe views, generate the STL code for this 3D object. Pay attention to all visible features."
- "Using these blueprint images as reference, provide the STL file format data to recreate this 3D model."
**Performance Optimization**:
**Rendering Efficiency**:
- **Headless Operation**: EGL/OSMesa for server environments
- **GPU Acceleration**: Automatic detection and utilization
- **Memory Management**: Efficient mesh processing and cleanup
- **Batch Processing**: Support for multiple STL files
**Computational Requirements**:
- **Dependencies**: pyrender, trimesh, pyglet, matplotlib, torch, transformers
- **System Libraries**: libglfw3-dev, libgles2-mesa-dev (Ubuntu)
- **GPU Support**: Optional but recommended for rendering performance
- **Memory Usage**: Scales with STL file complexity and batch size
**Research Applications**:
**3D Understanding**:
- **Spatial Reasoning**: Training models to understand 3D geometry from 2D projections
- **CAD Generation**: Automated creation of manufacturable 3D models
- **Design Iteration**: Rapid prototyping through AI-assisted design
- **Geometric Constraints**: Learning physical and manufacturing constraints
**Vision-Language Integration**:
- **Multi-Modal Learning**: Combining visual and textual understanding of 3D objects
- **Technical Communication**: Bridging natural language and CAD representations
- **Design Documentation**: Automatic generation of technical specifications
- **Educational Tools**: Interactive learning for 3D modeling concepts
**Manufacturing Applications**:
- **Rapid Prototyping**: AI-assisted design for 3D printing
- **Quality Control**: Automated verification of CAD file accuracy
- **Design Optimization**: Iterative improvement of 3D models
- **Accessibility**: Democratizing CAD design through natural language interfaces
**Setup Instructions**:
**Environment Setup**:
```bash
# Install Python dependencies
pip install pyrender trimesh pyglet matplotlib torch transformers pydantic vllm numpy requests tenacity wandb
# Ubuntu system dependencies for rendering
sudo apt-get install libglfw3-dev libgles2-mesa-dev libnvidia-gl-570-server
# Set rendering backend (choose one)
export PYOPENGL_PLATFORM=egl # For GPU acceleration
export PYOPENGL_PLATFORM=osmesa # For CPU-only environments
```
**Data Preparation**:
```bash
# Create sample data directory
mkdir sample_data
# Add STL files to sample_data/ directory
# Test rendering system
python test_renderer_example.py
python test_stl_env.py
```
**Training Configuration**:
- **Model**: google/gemma-3-27b-it (configurable)
- **Batch Size**: 12 (adjustable based on memory)
- **Max Tokens**: 2048 (sufficient for complex STL files)
- **Evaluation**: Every 100 steps with 10 test files
**Demo Resources**:
- **Training Run**: [W&B Run dlexyg5r](https://wandb.ai/csxl/atropos-environments_hack0/runs/dlexyg5r)
- **GRPO Training**: [W&B Run t61am7gu](https://wandb.ai/csxl/grpo-physical-trainer/runs/t61am7gu)
- **Test Images**: Rendered sphere views demonstrating wireframe quality
- **Sample Data**: HTML visualization of training conversations
**Future Enhancements**:
**Advanced Rendering**:
- **Texture Support**: Material and surface property visualization
- **Animation**: Time-series rendering for dynamic objects
- **Cross-Sections**: Internal structure visualization
- **Assembly Views**: Multi-part object rendering
**Enhanced Evaluation**:
- **Geometric Accuracy**: More sophisticated similarity metrics
- **Manufacturing Constraints**: Validation of printability and structural integrity
- **User Studies**: Human evaluation of generated designs
- **Benchmark Datasets**: Standardized test suites for CAD generation
**Integration Opportunities**:
- **CAD Software**: Direct integration with professional design tools
- **3D Printing**: Seamless workflow to physical prototypes
- **Simulation**: Physics-based validation of generated designs
- **Collaborative Design**: Multi-agent design environments
**Research Impact**: This environment represents a significant step toward AI-assisted computer-aided design, potentially revolutionizing how 3D models are created and iterated. The combination of visual understanding and structured output generation opens new possibilities for democratizing design tools and accelerating product development cycles.
**Educational Value**: The environment serves as an excellent introduction to 3D graphics programming, mesh processing, and the intersection of AI with traditional engineering disciplines. The clear separation between rendering, evaluation, and generation components makes it suitable for educational use and research extension.
**Requirements**: pyrender, trimesh, pyglet, matplotlib, torch, transformers, pydantic, vllm, numpy, requests, tenacity, wandb, atroposlib
---
## Support
For questions or issues with community environments:

View file

@ -33,4 +33,3 @@ $ sudo apt-get install libglfw3-dev libgles2-mesa-dev libnvidia-gl-570-server
Generated run: https://wandb.ai/csxl/atropos-environments_hack0/runs/dlexyg5r
Training run (ran out of memory): https://wandb.ai/csxl/grpo-physical-trainer/runs/t61am7gu

View file

@ -1,3 +1,3 @@
from .physical_env import PhysicalEnv
__all__ = ["PhysicalEnv"]
__all__ = ["PhysicalEnv"]

View file

@ -16,7 +16,7 @@ for name in os.listdir(base_dir):
# print(f"Found directory: {name}")
stl_file_fpath = os.path.join(stl_dir, name)
stl_file_fpath += ".stl"
#print(stl_file_path)
# print(stl_file_path)
ds_stl_fpath = os.path.join(ds_stl_path, name)
ds_stl_fpath += ".stl"
shutil.copy(stl_file_fpath, ds_stl_path)
@ -24,7 +24,4 @@ for name in os.listdir(base_dir):
ds_img_fpath = os.path.join(ds_img_path, name)
ds_img_fpath += "_0001.png"
shutil.copy(base_img_fpath, ds_img_fpath)
#print(base_img_fpath, ds_img_fpath)
# print(base_img_fpath, ds_img_fpath)

View file

@ -15,12 +15,13 @@ import numpy as np
import requests
import torch
import torch.nn.functional as F
import wandb # Added for logging
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_exponential
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer
import wandb # Added for logging
# Global variable to keep track of the vLLM process
vllm_process = None
@ -538,7 +539,7 @@ if __name__ == "__main__":
training_steps=20,
vllm_restart_interval=3,
use_wandb=True,
wandb_project="grpo-physical-trainer"
wandb_project="grpo-physical-trainer",
)
train(training_config)

View file

@ -1,7 +1,8 @@
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
class CLIPScorer:
def __init__(self, model_name="openai/clip-vit-base-patch32"):
@ -11,36 +12,42 @@ class CLIPScorer:
self.processor = CLIPProcessor.from_pretrained(model_name)
print(f"CLIPScorer initialized on {self.device} with {model_name}")
except Exception as e:
print(f"Error initializing CLIPModel: {e}. Ensure model name is correct and you have internet.")
print(
f"Error initializing CLIPModel: {e}. Ensure model name is correct and you have internet."
)
self.model = None
self.processor = None
raise
@torch.no_grad() # Ensure no gradients are computed during inference
@torch.no_grad() # Ensure no gradients are computed during inference
def score_images(self, images_np_list: list, target_text_description: str):
if not self.model or not self.processor:
print("CLIPScorer not properly initialized.")
return [0.0] * len(images_np_list) # Low score on error
return [0.0] * len(images_np_list) # Low score on error
try:
pil_images = [Image.fromarray(img_arr.astype(np.uint8)) for img_arr in images_np_list]
pil_images = [
Image.fromarray(img_arr.astype(np.uint8)) for img_arr in images_np_list
]
inputs = self.processor(
text=[target_text_description], # Single text prompt
text=[target_text_description], # Single text prompt
images=pil_images,
return_tensors="pt",
padding=True,
truncation=True
truncation=True,
).to(self.device)
outputs = self.model(**inputs)
image_text_similarity_scores = outputs.logits_per_image.squeeze().tolist() # Squeeze to remove the text dim
image_text_similarity_scores = (
outputs.logits_per_image.squeeze().tolist()
) # Squeeze to remove the text dim
if not isinstance(image_text_similarity_scores, list): # If only one image
if not isinstance(image_text_similarity_scores, list): # If only one image
image_text_similarity_scores = [image_text_similarity_scores]
return image_text_similarity_scores
except Exception as e:
print(f"Error in CLIP scoring: {e}")
return [0.0] * len(images_np_list) # Low score on error
return [0.0] * len(images_np_list) # Low score on error

View file

@ -1,10 +1,15 @@
import os
import json
import trimesh
import os
import torch
import trimesh
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BlipForConditionalGeneration,
BlipProcessor,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
@ -14,12 +19,17 @@ LABEL_FILE = "dataset/labels.json"
# Load BLIP for image captioning
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
blip_model = BlipForConditionalGeneration.from_pretrained(
"Salesforce/blip-image-captioning-base"
).to(device)
# Load Mistral or other small LLM
llm_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
llm_model = AutoModelForCausalLM.from_pretrained(llm_model_name, torch_dtype=torch.float16, device_map="auto")
llm_model = AutoModelForCausalLM.from_pretrained(
llm_model_name, torch_dtype=torch.float16, device_map="auto"
)
def extract_trimesh_features(mesh):
return {
@ -32,6 +42,7 @@ def extract_trimesh_features(mesh):
"euler_number": mesh.euler_number,
}
def caption_image(image_path):
raw_image = Image.open(image_path).convert("RGB")
inputs = blip_processor(raw_image, return_tensors="pt").to(device)
@ -39,6 +50,7 @@ def caption_image(image_path):
caption = blip_processor.decode(out[0], skip_special_tokens=True)
return caption
def generate_label_with_llm(features, caption):
prompt = f"""You are a 3D object classifier.
@ -56,6 +68,7 @@ Label:"""
output_text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
return output_text.split("Label:")[-1].strip()
def main():
labels = {}
@ -86,6 +99,6 @@ def main():
print(f"\nSaved labels to {LABEL_FILE}")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,503 @@
import glob
import os
import random
import re
from typing import Dict, List, Optional, Tuple, TypedDict, Union
import numpy as np
import trimesh
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
ScoredDataGroup,
)
from atroposlib.type_definitions import Item, number
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
# Fix the relative imports for running directly
try:
from .judgement_model import CLIPScorer
from .pyrender_utils import PyRenderOffline
except ImportError:
from judgement_model import CLIPScorer
from pyrender_utils import PyRenderOffline
system_prompt = (
"You are an expert in 3D modeling and computer-aided design. Your task is to analyze the "
"blueprints or wireframe views of objects and generate the corresponding STL file content. "
"STL (stereolithography) files represent 3D models as a collection of triangular facets.\n\n"
"You may use <think> </think> tags to work through your reasoning about the shape, "
"dimensions, and geometric features of the model. Be methodical in your approach.\n\n"
"STL files can be in ASCII or binary format. For this task, generate ASCII STL content that "
"accurately represents the 3D model shown in the provided views.\n\n"
"Your final output must be enclosed in <stl> </stl> tags, containing only the valid STL content "
"and nothing else. The STL content should begin with 'solid' and end with 'endsolid'.\n\n"
"Example of STL format:\n"
"<stl>\n"
"solid model\n"
" facet normal nx ny nz\n"
" outer loop\n"
" vertex x1 y1 z1\n"
" vertex x2 y2 z2\n"
" vertex x3 y3 z3\n"
" endloop\n"
" endfacet\n"
" ... more facets ...\n"
"endsolid model\n"
"</stl>"
)
class PhysicalRow(TypedDict):
prompt: str
image: np.ndarray
stl: str
class PhysicalEnv(BaseEnv):
name = "physical"
def __init__(
self,
config: BaseEnvConfig,
server_configs: List[APIServerConfig],
slurm=True,
testing=False,
):
super().__init__(config, server_configs, slurm, testing)
self.percent_correct_buffer = list()
self.eval_metrics = list()
# Add tracking for wandb visualizations
self.rollouts_for_wandb = []
self.completion_lengths = []
# Initialize renderer and CLIP scorer
self.renderer = PyRenderOffline(width=224, height=224)
self.clip_scorer = CLIPScorer()
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
env_config = BaseEnvConfig(
tokenizer_name="google/gemma-3-27b-it",
group_size=8,
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=12,
steps_per_eval=100,
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
max_token_length=2048,
wandb_name="physical",
)
server_configs = [
APIServerConfig(
model_name="google/gemma-3-27b-it",
base_url="http://localhost:9001/v1",
api_key="x",
num_requests_for_eval=256,
),
]
return env_config, server_configs
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
if wandb_metrics is None:
wandb_metrics = {}
# Try to calculate percent_correct, pass if there's a division by zero
try:
wandb_metrics["train/percent_correct"] = sum(
self.percent_correct_buffer
) / len(self.percent_correct_buffer)
except ZeroDivisionError:
# Skip if buffer is empty
pass
self.percent_correct_buffer = list()
for item in self.eval_metrics:
wandb_metrics[item[0]] = item[1]
self.eval_metrics = list()
# Call the parent method to handle the server metrics
await super().wandb_log(wandb_metrics)
def load_stl_file(self, stl_path):
"""Load an STL file into a trimesh object"""
try:
mesh = trimesh.load(stl_path)
return mesh
except Exception as e:
print(f"Error loading STL file {stl_path}: {e}")
return None
def generate_query_from_images(self, images):
"""Generate a query based on the rendered images of the STL file"""
# In a real implementation, this would use a vision model to generate a description
# For this simplified version, we'll use different templates to add variety
templates = [
"Create a 3D model (STL file) for the object shown in these technical drawings. "
"Be precise with the geometry.",
"Based on these wireframe views, generate the STL code for this 3D object. "
"Pay attention to all visible features.",
"Using these blueprint images as reference, provide the STL file format data "
"to recreate this 3D model.",
"These are technical views of a 3D object. Generate the STL representation "
"that would produce this exact shape.",
"Reconstruct this 3D model from the provided wireframe views and output "
"the STL file content.",
]
return random.choice(templates)
async def setup(self):
# Load all STL files from sample_data
self.stl_files = glob.glob(os.path.join("sample_data", "*.stl"))
if not self.stl_files:
raise ValueError("No STL files found in the sample_data directory")
print(f"Found {len(self.stl_files)} STL files")
# Split files into train and test sets (80/20 split)
random.seed(42)
random.shuffle(self.stl_files)
split_idx = int(len(self.stl_files) * 0.8)
self.train_files = self.stl_files[:split_idx]
self.test_files = self.stl_files[split_idx:]
self.iter = 0
def save_checkpoint(self, step, data=None):
if data is None:
data = {}
data["iter"] = self.iter
super().save_checkpoint(step, data)
async def rollout_and_score_eval(self, stl_path: str) -> number:
# Load the STL file
mesh = self.load_stl_file(stl_path)
if mesh is None:
return 0
# Render the images
images = self.renderer.render_mesh_to_images(mesh)
# Generate a query from the images
query = self.generate_query_from_images(images)
# Get a completion from the model
completion = await self.server.chat_completion(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": query},
],
n=1,
max_tokens=self.config.max_token_length,
temperature=0.0,
split="eval",
)
# Extract the STL content from the completion
response_content = completion.choices[0].message.content
stl_content = self.extract_stl_content(response_content)
# Load the original mesh directly
original_mesh = mesh
# Save the generated STL content to a temporary file
temp_file = f"temp_generated_{random.randint(1000, 9999)}.stl"
try:
with open(temp_file, "w") as f:
f.write(stl_content)
# Load the generated mesh
generated_mesh = trimesh.load(temp_file)
# Score the generated mesh against the original
score = self.score_meshes_similarity(original_mesh, generated_mesh)
# Cleanup
os.remove(temp_file)
return score
except Exception as e:
print(f"Error processing generated STL: {e}")
if os.path.exists(temp_file):
os.remove(temp_file)
return 0
def extract_stl_content(self, response_content):
"""Extract STL content from the model's response"""
# Find content between <stl> and </stl> tags
match = re.search(r"<stl>(.*?)</stl>", response_content, re.DOTALL)
if match:
return match.group(1).strip()
return ""
def score_meshes_similarity(self, original_mesh, generated_mesh):
"""Score the similarity between two meshes"""
# This is a simple implementation - in practice you'd want more sophisticated metrics
# Compare basic properties like number of vertices, faces, and volume
orig_stats = {
"vertices": len(original_mesh.vertices),
"faces": len(original_mesh.faces),
"volume": original_mesh.volume or 1.0,
"surface_area": original_mesh.area or 1.0,
}
gen_stats = {
"vertices": len(generated_mesh.vertices),
"faces": len(generated_mesh.faces),
"volume": generated_mesh.volume or 1.0,
"surface_area": generated_mesh.area or 1.0,
}
# Calculate ratios (capped at 1.0 for when generated > original)
vertex_ratio = min(gen_stats["vertices"] / max(orig_stats["vertices"], 1), 1.0)
face_ratio = min(gen_stats["faces"] / max(orig_stats["faces"], 1), 1.0)
volume_ratio = min(gen_stats["volume"] / max(orig_stats["volume"], 1), 1.0)
area_ratio = min(
gen_stats["surface_area"] / max(orig_stats["surface_area"], 1), 1.0
)
# Average the ratios for a final score
score = (vertex_ratio + face_ratio + volume_ratio + area_ratio) / 4.0
return score
async def evaluate(self, *args, **kwargs):
eval_tasks = []
for stl_file in self.test_files[
:10
]: # Limit to 10 files for evaluation to keep it manageable
eval_tasks.append(self.rollout_and_score_eval(stl_file))
scores = await tqdm_asyncio.gather(*eval_tasks)
self.eval_metrics.append(("eval/similarity_score", sum(scores) / len(scores)))
async def collect_trajectories(
self, item: PhysicalRow
) -> Tuple[ScoredDataGroup, list[Item]]:
stl_path = item["stl_path"]
mesh = self.load_stl_file(stl_path)
images = self.renderer.render_mesh_to_images(mesh)
query = self.generate_query_from_images(images)
# For original STL content, we'll just store the file path instead of the content
# as the files may be binary and can't be simply read as text
original_stl_path = stl_path
user_message = {"role": "user", "content": query}
chat_completions = await self.server.chat_completion(
messages=[{"role": "system", "content": system_prompt}, user_message],
n=self.config.group_size,
max_tokens=self.config.max_token_length,
)
to_score = list()
to_backlog = list()
for i, chat_completion in enumerate(chat_completions.choices):
messages = (
{"role": "system", "content": system_prompt},
user_message,
{"role": "assistant", "content": chat_completion.message.content},
)
to_score.append(
{
"messages": messages,
"original_stl_path": original_stl_path,
"images": images,
"finish_reason": chat_completion.finish_reason,
}
)
to_postprocess = await self.score(to_score)
return to_postprocess, to_backlog
async def score(
self, rollout_group_data
) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
scores = ScoredDataGroup()
scores["tokens"] = list()
scores["masks"] = list()
scores["scores"] = list()
random.shuffle(rollout_group_data)
for item in rollout_group_data:
response_content = item["messages"][-1]["content"]
stl_content = self.extract_stl_content(response_content)
# Save the generated STL content to a temporary file
temp_file = f"temp_generated_{random.randint(1000, 9999)}.stl"
try:
with open(temp_file, "w") as f:
f.write(stl_content)
# Load the original STL directly from its path
original_stl_path = item["original_stl_path"]
original_mesh = trimesh.load(original_stl_path)
# Load the generated mesh
generated_mesh = trimesh.load(temp_file)
# Score the generated mesh against the original
mesh_similarity = self.score_meshes_similarity(
original_mesh, generated_mesh
)
# Generate rendered images of the produced STL
generated_images = self.renderer.render_mesh_to_images(generated_mesh)
# Use CLIP to score the visual similarity
images_reward = 0.0
if len(generated_images) > 0 and len(item["images"]) > 0:
# Extract query from the user message
query = item["messages"][1]["content"]
# Score the visual similarity using CLIP
clip_scores = self.clip_scorer.score_images(generated_images, query)
images_reward = (
sum(clip_scores) / len(clip_scores) / 100.0
) # Normalize to roughly 0-1
# Combine mesh similarity and image similarity for final reward
reward = 0.5 * mesh_similarity + 0.5 * images_reward
out_dict = tokenize_for_trainer(
self.tokenizer, item["messages"], item["finish_reason"]
)
tokens = out_dict["tokens"]
masks = out_dict["masks"]
# Remove obviously bad examples
if len([1 for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(reward)
self.percent_correct_buffer.append(reward)
# Clean up temporary files
os.remove(temp_file)
if len(scores["tokens"]) >= self.config.group_size:
break
except Exception as e:
print(f"Error in scoring: {e}")
if os.path.exists(temp_file):
os.remove(temp_file)
# Apply length penalty if all scores are similar
if all(abs(score - scores["scores"][0]) < 0.1 for score in scores["scores"]):
token_lengths = [len(token) for token in scores["tokens"]]
if max(token_lengths) > 0:
max_allowed_length = self.config.max_token_length
length_threshold = max_allowed_length * 0.5
scores["scores"] = []
for length in token_lengths:
if length <= length_threshold:
scores["scores"].append(1.0)
else:
percentage_of_range = (length - length_threshold) / (
max_allowed_length - length_threshold
)
percentage_of_range = min(percentage_of_range, 1.0)
scores["scores"].append(1.0 - percentage_of_range)
if all([scores["scores"][0] == score for score in scores["scores"]]):
return None # If all the same, we return None
return scores
async def get_next_item(self) -> PhysicalRow:
stl_path = self.train_files[self.iter % len(self.train_files)]
self.iter += 1
# Load the STL file and render it
mesh = self.load_stl_file(stl_path)
if mesh is None:
# Skip this file and try the next one if there's an issue
return await self.get_next_item()
# Render the mesh to get images
images = self.renderer.render_mesh_to_images(mesh)
# Generate a query from the images
query = self.generate_query_from_images(images)
# Return a row with the prompt, image, and path to the STL file
return {
"prompt": query,
"image": images[0] if images else np.zeros((224, 224, 3), dtype=np.uint8),
"stl_path": stl_path,
}
@classmethod
async def test_sample_stl(cls):
"""Test loading and rendering a sample STL file"""
# Create temporary environment instance
env_config = BaseEnvConfig(
tokenizer_name="google/gemma-3-27b-it",
group_size=8,
use_wandb=False,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=12,
steps_per_eval=100,
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
max_token_length=2048,
wandb_name="physical_test",
)
server_configs = [
APIServerConfig(
model_name="google/gemma-3-27b-it",
base_url="http://localhost:9001/v1",
api_key="x",
num_requests_for_eval=256,
),
]
env = cls(env_config, server_configs, slurm=False, testing=True)
# Find sample STL files
stl_files = glob.glob(os.path.join("sample_data", "*.stl"))
if not stl_files:
print("No STL files found in sample_data/")
return
# Test loading and rendering the first file
print(f"Testing with STL file: {stl_files[0]}")
mesh = env.load_stl_file(stl_files[0])
if mesh is None:
print("Failed to load STL file")
return
print(
f"Loaded mesh with {len(mesh.vertices)} vertices and {len(mesh.faces)} faces"
)
# Render the mesh
try:
images = env.renderer.render_mesh_to_images(mesh)
print(f"Successfully rendered {len(images)} images")
# Save the first image for inspection
from PIL import Image
img = Image.fromarray(images[0])
img.save("test_render.png")
print("Saved test render to test_render.png")
except Exception as e:
print(f"Error rendering: {e}")
print("Test completed")
if __name__ == "__main__":
PhysicalEnv.cli()

View file

@ -1,8 +1,8 @@
import os
import json
import numpy as np
import os
import trimesh
from datasets import Dataset, Features, Value, Image
from datasets import Dataset, Features, Image, Value
from huggingface_hub import login
# Log in to HF Hub (optional if you've already done `huggingface-cli login`)
@ -23,10 +23,10 @@ for image_filename in os.listdir(image_dir):
if not image_filename.endswith(".png"):
continue
image_path = os.path.join(image_dir, image_filename)
# Extract base ID
base_id = image_filename.split("_")[0]
stl_path = os.path.join(stl_dir, f"{base_id}.stl")
label = labels.get(base_id, "unknown")
@ -42,20 +42,24 @@ for image_filename in os.listdir(image_dir):
except Exception as e:
print(f"⚠️ Failed to process {stl_path}: {e}")
data.append({
"image": image_path,
"label": label,
"stl_features": stl_features,
"id": base_id,
})
data.append(
{
"image": image_path,
"label": label,
"stl_features": stl_features,
"id": base_id,
}
)
# Define dataset schema
features = Features({
"id": Value("string"),
"image": Image(), # Load images from file paths
"label": Value("string"),
"stl_features": Value("string"), # Store as JSON string for simplicity
})
features = Features(
{
"id": Value("string"),
"image": Image(), # Load images from file paths
"label": Value("string"),
"stl_features": Value("string"), # Store as JSON string for simplicity
}
)
# Convert stl_features to JSON strings for compatibility
for item in data:
@ -66,4 +70,3 @@ dataset = Dataset.from_list(data).cast(features)
# Push to Hub
dataset.push_to_hub("venkatacrc/stl-image-dataset", private=True)

View file

@ -1,223 +1,240 @@
import os
import numpy as np
import trimesh
import pyrender
import trimesh
# Headless rendering with GPU acceleration (egl), for non-GPU environments omesa will be faster
os.environ['PYOPENGL_PLATFORM'] = 'egl'
os.environ["PYOPENGL_PLATFORM"] = "egl"
def create_look_at_matrix(eye: np.ndarray, target: np.ndarray, up: np.ndarray) -> np.ndarray:
def create_look_at_matrix(
eye: np.ndarray, target: np.ndarray, up: np.ndarray
) -> np.ndarray:
"""
Create a look-at transformation matrix for a camera.
eye - position of the camera
target - position the camera is looking at
up - up direction for the camera
returns the 4x4 transformation matrix
"""
eye = np.asarray(eye)
target = np.asarray(target)
up = np.asarray(up)
# Forward vector (from eye to target)
forward = target - eye
forward = forward / np.linalg.norm(forward)
# Side vector (right vector)
side = np.cross(forward, up)
side = side / np.linalg.norm(side)
# Recompute up vector to ensure orthogonality
up = np.cross(side, forward)
up = up / np.linalg.norm(up)
# Create the rotation matrix | NOTE: PyRender uses OpenGL convention
# where camera looks down negative z-axis
R = np.eye(4)
R[0, :3] = side
R[1, :3] = up
R[2, :3] = -forward
T = np.eye(4) # translation matrix
T = np.eye(4) # translation matrix
T[:3, 3] = eye
# The camera pose matrix is the inverse of the view matrix
# but we need to return the pose directly for pyrender
return T @ R
class PyRenderOffline:
def __init__(self, width=224, height=224): # Standard CLIP image size
def __init__(self, width=224, height=224): # Standard CLIP image size
self.width = width
self.height = height
try:
self.renderer = pyrender.OffscreenRenderer(viewport_width=self.width, viewport_height=self.height, point_size=1.0)
self.renderer = pyrender.OffscreenRenderer(
viewport_width=self.width, viewport_height=self.height, point_size=1.0
)
except Exception as e:
print(f"Failed to initialize OffscreenRenderer (is a display server/EGL/OSMesa available?): {e}")
print("Try: pip install pyglet; or for headless: export PYOPENGL_PLATFORM=osmesa (or egl)")
self.renderer = None # Fallback or raise error
print(
f"Failed to initialize OffscreenRenderer (is a display server/EGL/OSMesa available?): {e}"
)
print(
"Try: pip install pyglet; or for headless: export PYOPENGL_PLATFORM=osmesa (or egl)"
)
self.renderer = None # Fallback or raise error
raise
# Create camera poses using explicit transformation matrices
# Front view (looking along the -Z axis)
front_pose = np.array([
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 5],
[0, 0, 0, 1]
])
front_pose = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 5], [0, 0, 0, 1]])
# Top view (looking along the -Y axis)
top_pose = np.array([
[1, 0, 0, 0],
[0, 0, 1, 5],
[0, -1, 0, 0],
[0, 0, 0, 1]
])
top_pose = np.array([[1, 0, 0, 0], [0, 0, 1, 5], [0, -1, 0, 0], [0, 0, 0, 1]])
# Diagonal view (from upper corner)
side_pose = np.array([
[0.866, -0.25, 0.433, 3], # Camera right vector
[0.0, 0.866, 0.5, 3], # Camera up vector
[-0.5, -0.433, 0.75, 5], # Camera forward vector (pointing at origin)
[0, 0, 0, 1]
])
side_pose = np.array(
[
[0.866, -0.25, 0.433, 3], # Camera right vector
[0.0, 0.866, 0.5, 3], # Camera up vector
[-0.5, -0.433, 0.75, 5], # Camera forward vector (pointing at origin)
[0, 0, 0, 1],
]
)
# Store camera poses
self.camera_poses = [front_pose, top_pose, side_pose]
# Debug print for the camera poses cause it doesn't look right
# and I'm not a game developer lol
print("Camera poses:")
for i, pose in enumerate(self.camera_poses):
print(f"Camera {i}:\n{pose}")
# slightly wider field of view to ensure objects are visible
self.camera = pyrender.PerspectiveCamera(yfov=np.pi / 2.5, aspectRatio=1.0)
# Bright point light
self.light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=10.0)
def render_mesh_to_images(self, mesh_obj: trimesh.Trimesh):
if not self.renderer:
print("Renderer not initialized, cannot render.")
return [np.zeros((self.height, self.width, 3), dtype=np.uint8) for _ in range(3)]
return [
np.zeros((self.height, self.width, 3), dtype=np.uint8) for _ in range(3)
]
print(
f"Rendering mesh with {len(mesh_obj.vertices)} vertices and {len(mesh_obj.faces)} faces"
)
print(f"Rendering mesh with {len(mesh_obj.vertices)} vertices and {len(mesh_obj.faces)} faces")
images = []
# Make a copy to avoid modifying original mesh
render_mesh = mesh_obj.copy()
# Center and scale the mesh for visibility
render_mesh.apply_translation(-render_mesh.centroid)
scale_factor = 0.8 / np.max(render_mesh.extents) # Scale to unit size but slightly smaller
scale_factor = 0.8 / np.max(
render_mesh.extents
) # Scale to unit size but slightly smaller
render_mesh.apply_scale(scale_factor)
# Create a ground plane
ground_plane = trimesh.creation.box([4.0, 4.0, 0.01])
ground_plane.apply_translation([0, 0, -0.5])
# Color scheme for wireframe blueprint
blueprint_bg = [0.98, 0.98, 1.0, 1.0] # Light blue background
wireframe_color = [0.0, 0.4, 0.8, 1.0] # Medium bright blue for wireframes
grid_color = [0.0, 0.2, 0.4, 1.0] # Darker blue for grid
blueprint_bg = [0.98, 0.98, 1.0, 1.0] # Light blue background
wireframe_color = [0.0, 0.4, 0.8, 1.0] # Medium bright blue for wireframes
grid_color = [0.0, 0.2, 0.4, 1.0] # Darker blue for grid
# Wireframe material for the edges - completely opaque
wireframe_material = pyrender.MetallicRoughnessMaterial(
baseColorFactor=wireframe_color,
metallicFactor=0.0,
roughnessFactor=1.0,
wireframe=True
wireframe=True,
)
# Grid material for the ground plane
grid_material = pyrender.MetallicRoughnessMaterial(
baseColorFactor=grid_color,
metallicFactor=0.0,
roughnessFactor=1.0,
wireframe=True
wireframe=True,
)
for i, pose in enumerate(self.camera_poses):
# Create a fresh scene with blueprint background
scene = pyrender.Scene(ambient_light=[0.8, 0.8, 0.95], bg_color=blueprint_bg)
scene = pyrender.Scene(
ambient_light=[0.8, 0.8, 0.95], bg_color=blueprint_bg
)
# Add ground plane as grid
plane_mesh = pyrender.Mesh.from_trimesh(ground_plane, material=grid_material)
plane_mesh = pyrender.Mesh.from_trimesh(
ground_plane, material=grid_material
)
scene.add(plane_mesh)
# For wireframe rendering, we'll create line segments directly
edges = render_mesh.edges_unique
edge_vertices = []
edge_indices = []
# Extract all edges for wireframe rendering
for _, edge in enumerate(edges):
v0_idx = len(edge_vertices)
edge_vertices.append(render_mesh.vertices[edge[0]])
edge_vertices.append(render_mesh.vertices[edge[1]])
edge_indices.append([v0_idx, v0_idx+1])
edge_indices.append([v0_idx, v0_idx + 1])
# Create lines primitive for edges
if len(edge_vertices) > 0:
edge_verts = np.array(edge_vertices, dtype=np.float32)
edge_indices = np.array(edge_indices, dtype=np.uint32)
# Create a primitive for the lines
primitive = pyrender.Primitive(
positions=edge_verts,
indices=edge_indices,
mode=pyrender.constants.GLTF.LINES,
material=wireframe_material
material=wireframe_material,
)
# Create a mesh with just the line primitive
edge_mesh = pyrender.Mesh(primitives=[primitive])
scene.add(edge_mesh)
# Add camera
scene.add(self.camera, pose=pose)
# Add light from camera direction (key light)
scene.add(self.light, pose=pose)
# Add second light from above for better visibility
top_light_pose = np.eye(4)
top_light_pose[1, 3] = 3.0
scene.add(self.light, pose=top_light_pose)
try:
color, _ = self.renderer.render(scene)
# Post-process to enhance blueprint effect
# Convert to float for processing
img_float = color.astype(np.float32) / 255.0
# Add a subtle grid pattern to the background
grid_size = 40
grid_intensity = 0.02
# Draw faint horizontal grid lines
for y in range(0, color.shape[0], grid_size):
img_float[y:y+1, :, :] = np.minimum(img_float[y:y+1, :, :] + grid_intensity, 1.0)
img_float[y : y + 1, :, :] = np.minimum(
img_float[y : y + 1, :, :] + grid_intensity, 1.0
)
# Draw faint vertical grid lines
for x in range(0, color.shape[1], grid_size):
img_float[:, x:x+1, :] = np.minimum(img_float[:, x:x+1, :] + grid_intensity, 1.0)
img_float[:, x : x + 1, :] = np.minimum(
img_float[:, x : x + 1, :] + grid_intensity, 1.0
)
# Convert back to uint8
processed_img = (img_float * 255).astype(np.uint8)
images.append(processed_img)
print(f"Rendered wireframe view {i}")
except Exception as e:
print(f"Error during rendering view {i}: {e}")
images.append(np.zeros((self.height, self.width, 3), dtype=np.uint8))
return images
def __del__(self):

View file

@ -1,17 +1,18 @@
import bpy
import sys
import math
import os
import sys
import bpy
# Get args after --
argv = sys.argv
argv = argv[argv.index("--") + 1:] # args after --
argv = argv[argv.index("--") + 1 :] # args after --
input_stl = argv[0]
output_dir = argv[1]
# Clear existing objects
bpy.ops.object.select_all(action='SELECT')
bpy.ops.object.select_all(action="SELECT")
bpy.ops.object.delete(use_global=False)
# Import STL
@ -19,11 +20,11 @@ bpy.ops.import_mesh.stl(filepath=input_stl)
obj = bpy.context.selected_objects[0]
# Center the object at origin
bpy.ops.object.origin_set(type='ORIGIN_CENTER_OF_MASS', center='MEDIAN')
bpy.ops.object.origin_set(type="ORIGIN_CENTER_OF_MASS", center="MEDIAN")
obj.location = (0, 0, 0)
# Add Sun light
sun_light_data = bpy.data.lights.new(name="SunLight", type='SUN')
sun_light_data = bpy.data.lights.new(name="SunLight", type="SUN")
sun_light_object = bpy.data.objects.new(name="SunLight", object_data=sun_light_data)
sun_light_object.location = (10, 10, 10)
bpy.context.collection.objects.link(sun_light_object)
@ -52,10 +53,9 @@ for i, angle in enumerate(angles):
# Point camera to object center (0,0,0)
direction = -cam_obj.location
rot_quat = direction.to_track_quat('-Z', 'Y')
rot_quat = direction.to_track_quat("-Z", "Y")
cam_obj.rotation_euler = rot_quat.to_euler()
# Render
bpy.context.scene.render.filepath = os.path.join(output_dir, f"render_{i}.png")
bpy.ops.render.render(write_still=True)

View file

Before

Width:  |  Height:  |  Size: 47 KiB

After

Width:  |  Height:  |  Size: 47 KiB

Before After
Before After

View file

@ -1,9 +1,8 @@
import os
import numpy as np
import trimesh
import matplotlib.pyplot as plt
import trimesh
from pyrender_utils import PyRenderOffline
def test_render_example():
"""
Example script to test the PyRenderOffline class with a real mesh.
@ -13,34 +12,41 @@ def test_render_example():
# Create a high-quality sphere for wireframe rendering
# Using more subdivisions for smoother appearance
sphere = trimesh.creation.icosphere(subdivisions=4, radius=1.0)
# Create a perfect sphere instead of a noisy one
# This will look better as a wireframe/blueprint for our test ;)
print(f"Created sphere with {len(sphere.vertices)} vertices and {len(sphere.faces)} faces")
print(f"Sphere has {len(sphere.edges_unique)} unique edges for wireframe rendering")
renderer = PyRenderOffline(width=512, height=512) # larger dimensions than the CLIP size for better detail
print(
f"Created sphere with {len(sphere.vertices)} vertices and {len(sphere.faces)} faces"
)
print(
f"Sphere has {len(sphere.edges_unique)} unique edges for wireframe rendering"
)
renderer = PyRenderOffline(
width=512, height=512
) # larger dimensions than the CLIP size for better detail
images = renderer.render_mesh_to_images(sphere)
# Display the results
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
view_names = ['Front', 'Top', 'Diagonal']
view_names = ["Front", "Top", "Diagonal"]
for i, (img, name) in enumerate(zip(images, view_names)):
axes[i].imshow(img)
axes[i].set_title(f"{name} View")
axes[i].axis('off')
axes[i].axis("off")
plt.savefig('test_rendered_sphere_views.png')
plt.savefig("test_rendered_sphere_views.png")
plt.close()
print(f"Successfully rendered sphere from 3 viewpoints")
print(f"Images saved to rendered_sphere_views.png")
print("Successfully rendered sphere from 3 viewpoints")
print("Images saved to rendered_sphere_views.png")
return True
except Exception as e:
print(f"Failed to run renderer example: {e}")
return False
if __name__ == "__main__":
test_render_example()
test_render_example()

View file

@ -2,10 +2,10 @@
"""
Test script for the PhysicalEnv that loads and processes STL files
"""
import os
import asyncio
import random
from physical_env import PhysicalEnv, BaseEnvConfig, APIServerConfig, EvalHandlingEnum
from physical_env import APIServerConfig, BaseEnvConfig, EvalHandlingEnum, PhysicalEnv
async def test_render_stl():
"""Test loading and rendering an STL file"""
@ -22,7 +22,7 @@ async def test_render_stl():
max_token_length=2048,
wandb_name="physical_test",
)
server_configs = [
APIServerConfig(
model_name="google/gemma-3-27b-it",
@ -31,13 +31,13 @@ async def test_render_stl():
num_requests_for_eval=256,
),
]
print("Creating test environment")
env = PhysicalEnv(env_config, server_configs, slurm=False, testing=True)
print("Setting up environment")
await env.setup()
# Test get_next_item
print("Testing get_next_item")
try:
@ -47,8 +47,9 @@ async def test_render_stl():
print(f"STL path: {item['stl_path']}")
except Exception as e:
print(f"Error getting next item: {e}")
print("Test completed successfully")
if __name__ == "__main__":
asyncio.run(test_render_stl())
asyncio.run(test_render_stl())

View file

@ -1,481 +0,0 @@
import numpy as np
import trimesh
import io
import re
import wandb
import os
import glob
import random
from tqdm.asyncio import tqdm_asyncio
from typing import *
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
EvalHandlingEnum
)
from atroposlib.type_definitions import Item, number
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
# Fix the relative imports for running directly
try:
from .pyrender_utils import PyRenderOffline
from .judgement_model import CLIPScorer
except ImportError:
from pyrender_utils import PyRenderOffline
from judgement_model import CLIPScorer
system_prompt = (
"You are an expert in 3D modeling and computer-aided design. Your task is to analyze the "
"blueprints or wireframe views of objects and generate the corresponding STL file content. "
"STL (stereolithography) files represent 3D models as a collection of triangular facets.\n\n"
"You may use <think> </think> tags to work through your reasoning about the shape, "
"dimensions, and geometric features of the model. Be methodical in your approach.\n\n"
"STL files can be in ASCII or binary format. For this task, generate ASCII STL content that "
"accurately represents the 3D model shown in the provided views.\n\n"
"Your final output must be enclosed in <stl> </stl> tags, containing only the valid STL content "
"and nothing else. The STL content should begin with 'solid' and end with 'endsolid'.\n\n"
"Example of STL format:\n"
"<stl>\n"
"solid model\n"
" facet normal nx ny nz\n"
" outer loop\n"
" vertex x1 y1 z1\n"
" vertex x2 y2 z2\n"
" vertex x3 y3 z3\n"
" endloop\n"
" endfacet\n"
" ... more facets ...\n"
"endsolid model\n"
"</stl>"
)
class PhysicalRow(TypedDict):
prompt: str
image: np.ndarray
stl: str
class PhysicalEnv(BaseEnv):
name = "physical"
def __init__(
self,
config: BaseEnvConfig,
server_configs: List[APIServerConfig],
slurm=True,
testing=False
):
super().__init__(config, server_configs, slurm, testing)
self.percent_correct_buffer = list()
self.eval_metrics = list()
# Add tracking for wandb visualizations
self.rollouts_for_wandb = []
self.completion_lengths = []
# Initialize renderer and CLIP scorer
self.renderer = PyRenderOffline(width=224, height=224)
self.clip_scorer = CLIPScorer()
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
env_config = BaseEnvConfig(
tokenizer_name="google/gemma-3-27b-it",
group_size=8,
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=12,
steps_per_eval=100,
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
max_token_length=2048,
wandb_name="physical",
)
server_configs = [
APIServerConfig(
model_name="google/gemma-3-27b-it",
base_url="http://localhost:9001/v1",
api_key="x",
num_requests_for_eval=256,
),
]
return env_config, server_configs
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
if wandb_metrics is None:
wandb_metrics = {}
# Try to calculate percent_correct, pass if there's a division by zero
try:
wandb_metrics["train/percent_correct"] = sum(
self.percent_correct_buffer
) / len(self.percent_correct_buffer)
except ZeroDivisionError:
# Skip if buffer is empty
pass
self.percent_correct_buffer = list()
for item in self.eval_metrics:
wandb_metrics[item[0]] = item[1]
self.eval_metrics = list()
# Call the parent method to handle the server metrics
await super().wandb_log(wandb_metrics)
def load_stl_file(self, stl_path):
"""Load an STL file into a trimesh object"""
try:
mesh = trimesh.load(stl_path)
return mesh
except Exception as e:
print(f"Error loading STL file {stl_path}: {e}")
return None
def generate_query_from_images(self, images):
"""Generate a query based on the rendered images of the STL file"""
# In a real implementation, this would use a vision model to generate a description
# For this simplified version, we'll use different templates to add variety
templates = [
"Create a 3D model (STL file) for the object shown in these technical drawings. Be precise with the geometry.",
"Based on these wireframe views, generate the STL code for this 3D object. Pay attention to all visible features.",
"Using these blueprint images as reference, provide the STL file format data to recreate this 3D model.",
"These are technical views of a 3D object. Generate the STL representation that would produce this exact shape.",
"Reconstruct this 3D model from the provided wireframe views and output the STL file content."
]
return random.choice(templates)
async def setup(self):
# Load all STL files from sample_data
self.stl_files = glob.glob(os.path.join('sample_data', '*.stl'))
if not self.stl_files:
raise ValueError("No STL files found in the sample_data directory")
print(f"Found {len(self.stl_files)} STL files")
# Split files into train and test sets (80/20 split)
random.seed(42)
random.shuffle(self.stl_files)
split_idx = int(len(self.stl_files) * 0.8)
self.train_files = self.stl_files[:split_idx]
self.test_files = self.stl_files[split_idx:]
self.iter = 0
def save_checkpoint(self, step, data=None):
if data is None:
data = {}
data["iter"] = self.iter
super().save_checkpoint(step, data)
async def rollout_and_score_eval(self, stl_path: str) -> number:
# Load the STL file
mesh = self.load_stl_file(stl_path)
if mesh is None:
return 0
# Render the images
images = self.renderer.render_mesh_to_images(mesh)
# Generate a query from the images
query = self.generate_query_from_images(images)
# Get a completion from the model
completion = await self.server.chat_completion(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": query},
],
n=1,
max_tokens=self.config.max_token_length,
temperature=0.0,
split="eval",
)
# Extract the STL content from the completion
response_content = completion.choices[0].message.content
stl_content = self.extract_stl_content(response_content)
# Load the original mesh directly
original_mesh = mesh
# Save the generated STL content to a temporary file
temp_file = f"temp_generated_{random.randint(1000, 9999)}.stl"
try:
with open(temp_file, "w") as f:
f.write(stl_content)
# Load the generated mesh
generated_mesh = trimesh.load(temp_file)
# Score the generated mesh against the original
score = self.score_meshes_similarity(original_mesh, generated_mesh)
# Cleanup
os.remove(temp_file)
return score
except Exception as e:
print(f"Error processing generated STL: {e}")
if os.path.exists(temp_file):
os.remove(temp_file)
return 0
def extract_stl_content(self, response_content):
"""Extract STL content from the model's response"""
# Find content between <stl> and </stl> tags
match = re.search(r'<stl>(.*?)</stl>', response_content, re.DOTALL)
if match:
return match.group(1).strip()
return ""
def score_meshes_similarity(self, original_mesh, generated_mesh):
"""Score the similarity between two meshes"""
# This is a simple implementation - in practice you'd want more sophisticated metrics
# Compare basic properties like number of vertices, faces, and volume
orig_stats = {
'vertices': len(original_mesh.vertices),
'faces': len(original_mesh.faces),
'volume': original_mesh.volume or 1.0,
'surface_area': original_mesh.area or 1.0
}
gen_stats = {
'vertices': len(generated_mesh.vertices),
'faces': len(generated_mesh.faces),
'volume': generated_mesh.volume or 1.0,
'surface_area': generated_mesh.area or 1.0
}
# Calculate ratios (capped at 1.0 for when generated > original)
vertex_ratio = min(gen_stats['vertices'] / max(orig_stats['vertices'], 1), 1.0)
face_ratio = min(gen_stats['faces'] / max(orig_stats['faces'], 1), 1.0)
volume_ratio = min(gen_stats['volume'] / max(orig_stats['volume'], 1), 1.0)
area_ratio = min(gen_stats['surface_area'] / max(orig_stats['surface_area'], 1), 1.0)
# Average the ratios for a final score
score = (vertex_ratio + face_ratio + volume_ratio + area_ratio) / 4.0
return score
async def evaluate(self, *args, **kwargs):
eval_tasks = []
for stl_file in self.test_files[:10]: # Limit to 10 files for evaluation to keep it manageable
eval_tasks.append(self.rollout_and_score_eval(stl_file))
scores = await tqdm_asyncio.gather(*eval_tasks)
self.eval_metrics.append(("eval/similarity_score", sum(scores) / len(scores)))
async def collect_trajectories(self, item: PhysicalRow) -> Tuple[ScoredDataGroup, list[Item]]:
stl_path = item["stl_path"]
mesh = self.load_stl_file(stl_path)
images = self.renderer.render_mesh_to_images(mesh)
query = self.generate_query_from_images(images)
# For original STL content, we'll just store the file path instead of the content
# as the files may be binary and can't be simply read as text
original_stl_path = stl_path
user_message = {"role": "user", "content": query}
chat_completions = await self.server.chat_completion(
messages=[{"role": "system", "content": system_prompt}, user_message],
n=self.config.group_size,
max_tokens=self.config.max_token_length,
)
to_score = list()
to_backlog = list()
for i, chat_completion in enumerate(chat_completions.choices):
messages = (
{"role": "system", "content": system_prompt},
user_message,
{"role": "assistant", "content": chat_completion.message.content},
)
to_score.append(
{
"messages": messages,
"original_stl_path": original_stl_path,
"images": images,
"finish_reason": chat_completion.finish_reason,
}
)
to_postprocess = await self.score(to_score)
return to_postprocess, to_backlog
async def score(self, rollout_group_data) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
scores = ScoredDataGroup()
scores["tokens"] = list()
scores["masks"] = list()
scores["scores"] = list()
random.shuffle(rollout_group_data)
for item in rollout_group_data:
response_content = item["messages"][-1]["content"]
stl_content = self.extract_stl_content(response_content)
# Save the generated STL content to a temporary file
temp_file = f"temp_generated_{random.randint(1000, 9999)}.stl"
try:
with open(temp_file, "w") as f:
f.write(stl_content)
# Load the original STL directly from its path
original_stl_path = item["original_stl_path"]
original_mesh = trimesh.load(original_stl_path)
# Load the generated mesh
generated_mesh = trimesh.load(temp_file)
# Score the generated mesh against the original
mesh_similarity = self.score_meshes_similarity(original_mesh, generated_mesh)
# Generate rendered images of the produced STL
generated_images = self.renderer.render_mesh_to_images(generated_mesh)
# Use CLIP to score the visual similarity
images_reward = 0.0
if len(generated_images) > 0 and len(item["images"]) > 0:
# Extract query from the user message
query = item["messages"][1]["content"]
# Score the visual similarity using CLIP
clip_scores = self.clip_scorer.score_images(generated_images, query)
images_reward = sum(clip_scores) / len(clip_scores) / 100.0 # Normalize to roughly 0-1
# Combine mesh similarity and image similarity for final reward
reward = 0.5 * mesh_similarity + 0.5 * images_reward
out_dict = tokenize_for_trainer(
self.tokenizer, item["messages"], item["finish_reason"]
)
tokens = out_dict["tokens"]
masks = out_dict["masks"]
# Remove obviously bad examples
if len([1 for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(reward)
self.percent_correct_buffer.append(reward)
# Clean up temporary files
os.remove(temp_file)
if len(scores["tokens"]) >= self.config.group_size:
break
except Exception as e:
print(f"Error in scoring: {e}")
if os.path.exists(temp_file):
os.remove(temp_file)
# Apply length penalty if all scores are similar
if all(abs(score - scores["scores"][0]) < 0.1 for score in scores["scores"]):
token_lengths = [len(token) for token in scores["tokens"]]
if max(token_lengths) > 0:
max_allowed_length = self.config.max_token_length
length_threshold = max_allowed_length * 0.5
scores["scores"] = []
for length in token_lengths:
if length <= length_threshold:
scores["scores"].append(1.0)
else:
percentage_of_range = (length - length_threshold) / (
max_allowed_length - length_threshold
)
percentage_of_range = min(percentage_of_range, 1.0)
scores["scores"].append(1.0 - percentage_of_range)
if all([scores["scores"][0] == score for score in scores["scores"]]):
return None # If all the same, we return None
return scores
async def get_next_item(self) -> PhysicalRow:
stl_path = self.train_files[self.iter % len(self.train_files)]
self.iter += 1
# Load the STL file and render it
mesh = self.load_stl_file(stl_path)
if mesh is None:
# Skip this file and try the next one if there's an issue
return await self.get_next_item()
# Render the mesh to get images
images = self.renderer.render_mesh_to_images(mesh)
# Generate a query from the images
query = self.generate_query_from_images(images)
# Return a row with the prompt, image, and path to the STL file
return {
"prompt": query,
"image": images[0] if images else np.zeros((224, 224, 3), dtype=np.uint8),
"stl_path": stl_path
}
@classmethod
async def test_sample_stl(cls):
"""Test loading and rendering a sample STL file"""
# Create temporary environment instance
env_config = BaseEnvConfig(
tokenizer_name="google/gemma-3-27b-it",
group_size=8,
use_wandb=False,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=12,
steps_per_eval=100,
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
max_token_length=2048,
wandb_name="physical_test",
)
server_configs = [
APIServerConfig(
model_name="google/gemma-3-27b-it",
base_url="http://localhost:9001/v1",
api_key="x",
num_requests_for_eval=256,
),
]
env = cls(env_config, server_configs, slurm=False, testing=True)
# Find sample STL files
stl_files = glob.glob(os.path.join('sample_data', '*.stl'))
if not stl_files:
print("No STL files found in sample_data/")
return
# Test loading and rendering the first file
print(f"Testing with STL file: {stl_files[0]}")
mesh = env.load_stl_file(stl_files[0])
if mesh is None:
print("Failed to load STL file")
return
print(f"Loaded mesh with {len(mesh.vertices)} vertices and {len(mesh.faces)} faces")
# Render the mesh
try:
images = env.renderer.render_mesh_to_images(mesh)
print(f"Successfully rendered {len(images)} images")
# Save the first image for inspection
from PIL import Image
img = Image.fromarray(images[0])
img.save("test_render.png")
print("Saved test render to test_render.png")
except Exception as e:
print(f"Error rendering: {e}")
print("Test completed")
if __name__ == "__main__":
PhysicalEnv.cli()