mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Linting, move env to community
This commit is contained in:
parent
67e057b13c
commit
8b09ace467
18 changed files with 945 additions and 646 deletions
|
|
@ -1807,6 +1807,239 @@ Access the game at `http://localhost:3000` when running the server.
|
|||
|
||||
---
|
||||
|
||||
### 22. Physical Space STL CAD RL Environment (`physical_space_stl/`)
|
||||
|
||||
**Contributors**: ecsbeats, venkatacrc
|
||||
**PR**: [#76](https://github.com/NousResearch/atropos/pull/76)
|
||||
**Integration Status**: ✅ Integrated
|
||||
|
||||
**Description**: A reinforcement learning environment for training language models to generate STL (stereolithography) files from 3D wireframe views and technical drawings. This environment bridges the gap between visual 3D understanding and CAD file generation, enabling AI systems to learn computer-aided design skills.
|
||||
|
||||
**Core Features**:
|
||||
|
||||
**3D Rendering Pipeline**:
|
||||
- **PyRender Integration**: Offline 3D rendering with GPU acceleration (EGL) or CPU fallback (OSMesa)
|
||||
- **Multi-View Generation**: Automatic generation of front, top, and diagonal wireframe views
|
||||
- **Blueprint Styling**: Technical drawing aesthetics with blue wireframes on light backgrounds
|
||||
- **Mesh Processing**: Support for complex STL files with automatic centering and scaling
|
||||
|
||||
**STL Generation Training**:
|
||||
- **ASCII STL Format**: Focus on human-readable STL file generation
|
||||
- **Template Variety**: Multiple query templates to encourage diverse reasoning approaches
|
||||
- **Geometric Understanding**: Training on shape analysis, dimensions, and 3D spatial relationships
|
||||
- **Quality Assessment**: Multi-metric evaluation of generated vs. original meshes
|
||||
|
||||
**Evaluation System**:
|
||||
- **CLIP-Based Scoring**: Visual similarity assessment between rendered views
|
||||
- **Geometric Metrics**: Comparison of vertices, faces, volume, and surface area
|
||||
- **Mesh Validation**: Automatic validation of generated STL file structure
|
||||
- **Progressive Difficulty**: Adaptive training with increasing geometric complexity
|
||||
|
||||
**Technical Architecture**:
|
||||
|
||||
**Environment Interface**:
|
||||
```python
|
||||
from environments.community.physical_space_stl.physical_env import PhysicalEnv
|
||||
|
||||
# Initialize environment
|
||||
env_config, server_configs = PhysicalEnv.config_init()
|
||||
env = PhysicalEnv(env_config, server_configs)
|
||||
|
||||
# Training loop
|
||||
await env.setup()
|
||||
item = await env.get_next_item()
|
||||
# item contains: prompt, image (rendered views), stl_path
|
||||
```
|
||||
|
||||
**Data Pipeline**:
|
||||
- **STL File Loading**: Automatic discovery and loading of STL files from sample_data directory
|
||||
- **Train/Test Split**: 80/20 split with reproducible random seeding
|
||||
- **Image Rendering**: Real-time generation of wireframe views for each STL file
|
||||
- **Query Generation**: Dynamic prompt creation with multiple template variations
|
||||
|
||||
**Rendering System**:
|
||||
```python
|
||||
from environments.community.physical_space_stl.pyrender_utils import PyRenderOffline
|
||||
|
||||
# Initialize renderer
|
||||
renderer = PyRenderOffline(width=224, height=224) # CLIP-compatible size
|
||||
|
||||
# Render mesh to multiple views
|
||||
images = renderer.render_mesh_to_images(mesh)
|
||||
# Returns: [front_view, top_view, diagonal_view]
|
||||
```
|
||||
|
||||
**Camera Configuration**:
|
||||
- **Front View**: Standard orthographic projection along Z-axis
|
||||
- **Top View**: Overhead perspective for plan view understanding
|
||||
- **Diagonal View**: 3D perspective for spatial relationship comprehension
|
||||
- **Lighting Setup**: Multi-point lighting for clear wireframe visibility
|
||||
|
||||
**STL Processing**:
|
||||
|
||||
**File Format Support**:
|
||||
- **ASCII STL**: Primary focus for human-readable generation
|
||||
- **Binary STL**: Loading support for existing files
|
||||
- **Mesh Validation**: Automatic checking of facet normals and vertex ordering
|
||||
- **Error Handling**: Graceful degradation for malformed files
|
||||
|
||||
**Quality Metrics**:
|
||||
```python
|
||||
def score_meshes_similarity(original_mesh, generated_mesh):
|
||||
# Multi-dimensional similarity assessment
|
||||
metrics = {
|
||||
"vertex_ratio": min(gen_vertices / orig_vertices, 1.0),
|
||||
"face_ratio": min(gen_faces / orig_faces, 1.0),
|
||||
"volume_ratio": min(gen_volume / orig_volume, 1.0),
|
||||
"area_ratio": min(gen_area / orig_area, 1.0)
|
||||
}
|
||||
return sum(metrics.values()) / len(metrics)
|
||||
```
|
||||
|
||||
**Training Data Generation**:
|
||||
|
||||
**Dataset Creation Pipeline**:
|
||||
```bash
|
||||
# Generate training dataset
|
||||
python dataset_scr.py # Create directory structure
|
||||
python render_stl.py # Generate images from STL files
|
||||
python llm_label.py # Create text descriptions
|
||||
python prepare_push_hf_dataset.py # Upload to Hugging Face
|
||||
```
|
||||
|
||||
**Data Structure**:
|
||||
```
|
||||
dataset/
|
||||
├── stls/ # Original STL files
|
||||
│ ├── model_0001.stl
|
||||
│ └── model_0002.stl
|
||||
├── images/ # Rendered wireframe views
|
||||
│ ├── model_0001.png
|
||||
│ └── model_0002.png
|
||||
└── labels.json # Text descriptions and metadata
|
||||
```
|
||||
|
||||
**Hugging Face Integration**:
|
||||
- **Dataset Upload**: Automatic preparation and upload to HF Hub
|
||||
- **Feature Extraction**: STL geometric features (centroid, bounding box, volume)
|
||||
- **Image Processing**: Standardized image formats for training
|
||||
- **Metadata Storage**: JSON-serialized geometric properties
|
||||
|
||||
**System Prompt Design**:
|
||||
|
||||
**Expert Persona**: "You are an expert in 3D modeling and computer-aided design..."
|
||||
|
||||
**Task Specification**:
|
||||
- **Input**: Wireframe views and technical drawings
|
||||
- **Output**: Valid ASCII STL file content
|
||||
- **Reasoning**: Encouraged use of `<think>` tags for geometric analysis
|
||||
- **Format**: Strict `<stl>` tag enclosure for generated content
|
||||
|
||||
**Example Templates**:
|
||||
- "Create a 3D model (STL file) for the object shown in these technical drawings. Be precise with the geometry."
|
||||
- "Based on these wireframe views, generate the STL code for this 3D object. Pay attention to all visible features."
|
||||
- "Using these blueprint images as reference, provide the STL file format data to recreate this 3D model."
|
||||
|
||||
**Performance Optimization**:
|
||||
|
||||
**Rendering Efficiency**:
|
||||
- **Headless Operation**: EGL/OSMesa for server environments
|
||||
- **GPU Acceleration**: Automatic detection and utilization
|
||||
- **Memory Management**: Efficient mesh processing and cleanup
|
||||
- **Batch Processing**: Support for multiple STL files
|
||||
|
||||
**Computational Requirements**:
|
||||
- **Dependencies**: pyrender, trimesh, pyglet, matplotlib, torch, transformers
|
||||
- **System Libraries**: libglfw3-dev, libgles2-mesa-dev (Ubuntu)
|
||||
- **GPU Support**: Optional but recommended for rendering performance
|
||||
- **Memory Usage**: Scales with STL file complexity and batch size
|
||||
|
||||
**Research Applications**:
|
||||
|
||||
**3D Understanding**:
|
||||
- **Spatial Reasoning**: Training models to understand 3D geometry from 2D projections
|
||||
- **CAD Generation**: Automated creation of manufacturable 3D models
|
||||
- **Design Iteration**: Rapid prototyping through AI-assisted design
|
||||
- **Geometric Constraints**: Learning physical and manufacturing constraints
|
||||
|
||||
**Vision-Language Integration**:
|
||||
- **Multi-Modal Learning**: Combining visual and textual understanding of 3D objects
|
||||
- **Technical Communication**: Bridging natural language and CAD representations
|
||||
- **Design Documentation**: Automatic generation of technical specifications
|
||||
- **Educational Tools**: Interactive learning for 3D modeling concepts
|
||||
|
||||
**Manufacturing Applications**:
|
||||
- **Rapid Prototyping**: AI-assisted design for 3D printing
|
||||
- **Quality Control**: Automated verification of CAD file accuracy
|
||||
- **Design Optimization**: Iterative improvement of 3D models
|
||||
- **Accessibility**: Democratizing CAD design through natural language interfaces
|
||||
|
||||
**Setup Instructions**:
|
||||
|
||||
**Environment Setup**:
|
||||
```bash
|
||||
# Install Python dependencies
|
||||
pip install pyrender trimesh pyglet matplotlib torch transformers pydantic vllm numpy requests tenacity wandb
|
||||
|
||||
# Ubuntu system dependencies for rendering
|
||||
sudo apt-get install libglfw3-dev libgles2-mesa-dev libnvidia-gl-570-server
|
||||
|
||||
# Set rendering backend (choose one)
|
||||
export PYOPENGL_PLATFORM=egl # For GPU acceleration
|
||||
export PYOPENGL_PLATFORM=osmesa # For CPU-only environments
|
||||
```
|
||||
|
||||
**Data Preparation**:
|
||||
```bash
|
||||
# Create sample data directory
|
||||
mkdir sample_data
|
||||
# Add STL files to sample_data/ directory
|
||||
|
||||
# Test rendering system
|
||||
python test_renderer_example.py
|
||||
python test_stl_env.py
|
||||
```
|
||||
|
||||
**Training Configuration**:
|
||||
- **Model**: google/gemma-3-27b-it (configurable)
|
||||
- **Batch Size**: 12 (adjustable based on memory)
|
||||
- **Max Tokens**: 2048 (sufficient for complex STL files)
|
||||
- **Evaluation**: Every 100 steps with 10 test files
|
||||
|
||||
**Demo Resources**:
|
||||
- **Training Run**: [W&B Run dlexyg5r](https://wandb.ai/csxl/atropos-environments_hack0/runs/dlexyg5r)
|
||||
- **GRPO Training**: [W&B Run t61am7gu](https://wandb.ai/csxl/grpo-physical-trainer/runs/t61am7gu)
|
||||
- **Test Images**: Rendered sphere views demonstrating wireframe quality
|
||||
- **Sample Data**: HTML visualization of training conversations
|
||||
|
||||
**Future Enhancements**:
|
||||
|
||||
**Advanced Rendering**:
|
||||
- **Texture Support**: Material and surface property visualization
|
||||
- **Animation**: Time-series rendering for dynamic objects
|
||||
- **Cross-Sections**: Internal structure visualization
|
||||
- **Assembly Views**: Multi-part object rendering
|
||||
|
||||
**Enhanced Evaluation**:
|
||||
- **Geometric Accuracy**: More sophisticated similarity metrics
|
||||
- **Manufacturing Constraints**: Validation of printability and structural integrity
|
||||
- **User Studies**: Human evaluation of generated designs
|
||||
- **Benchmark Datasets**: Standardized test suites for CAD generation
|
||||
|
||||
**Integration Opportunities**:
|
||||
- **CAD Software**: Direct integration with professional design tools
|
||||
- **3D Printing**: Seamless workflow to physical prototypes
|
||||
- **Simulation**: Physics-based validation of generated designs
|
||||
- **Collaborative Design**: Multi-agent design environments
|
||||
|
||||
**Research Impact**: This environment represents a significant step toward AI-assisted computer-aided design, potentially revolutionizing how 3D models are created and iterated. The combination of visual understanding and structured output generation opens new possibilities for democratizing design tools and accelerating product development cycles.
|
||||
|
||||
**Educational Value**: The environment serves as an excellent introduction to 3D graphics programming, mesh processing, and the intersection of AI with traditional engineering disciplines. The clear separation between rendering, evaluation, and generation components makes it suitable for educational use and research extension.
|
||||
|
||||
**Requirements**: pyrender, trimesh, pyglet, matplotlib, torch, transformers, pydantic, vllm, numpy, requests, tenacity, wandb, atroposlib
|
||||
|
||||
---
|
||||
|
||||
## Support
|
||||
|
||||
For questions or issues with community environments:
|
||||
|
|
|
|||
|
|
@ -33,4 +33,3 @@ $ sudo apt-get install libglfw3-dev libgles2-mesa-dev libnvidia-gl-570-server
|
|||
|
||||
Generated run: https://wandb.ai/csxl/atropos-environments_hack0/runs/dlexyg5r
|
||||
Training run (ran out of memory): https://wandb.ai/csxl/grpo-physical-trainer/runs/t61am7gu
|
||||
|
||||
|
|
@ -1,3 +1,3 @@
|
|||
from .physical_env import PhysicalEnv
|
||||
|
||||
__all__ = ["PhysicalEnv"]
|
||||
__all__ = ["PhysicalEnv"]
|
||||
|
|
@ -16,7 +16,7 @@ for name in os.listdir(base_dir):
|
|||
# print(f"Found directory: {name}")
|
||||
stl_file_fpath = os.path.join(stl_dir, name)
|
||||
stl_file_fpath += ".stl"
|
||||
#print(stl_file_path)
|
||||
# print(stl_file_path)
|
||||
ds_stl_fpath = os.path.join(ds_stl_path, name)
|
||||
ds_stl_fpath += ".stl"
|
||||
shutil.copy(stl_file_fpath, ds_stl_path)
|
||||
|
|
@ -24,7 +24,4 @@ for name in os.listdir(base_dir):
|
|||
ds_img_fpath = os.path.join(ds_img_path, name)
|
||||
ds_img_fpath += "_0001.png"
|
||||
shutil.copy(base_img_fpath, ds_img_fpath)
|
||||
#print(base_img_fpath, ds_img_fpath)
|
||||
|
||||
|
||||
|
||||
# print(base_img_fpath, ds_img_fpath)
|
||||
|
|
@ -15,12 +15,13 @@ import numpy as np
|
|||
import requests
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import wandb # Added for logging
|
||||
from pydantic import BaseModel, Field
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from torch.optim import AdamW
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import wandb # Added for logging
|
||||
|
||||
# Global variable to keep track of the vLLM process
|
||||
vllm_process = None
|
||||
|
||||
|
|
@ -538,7 +539,7 @@ if __name__ == "__main__":
|
|||
training_steps=20,
|
||||
vllm_restart_interval=3,
|
||||
use_wandb=True,
|
||||
wandb_project="grpo-physical-trainer"
|
||||
wandb_project="grpo-physical-trainer",
|
||||
)
|
||||
|
||||
train(training_config)
|
||||
|
|
@ -1,7 +1,8 @@
|
|||
import torch
|
||||
from transformers import CLIPProcessor, CLIPModel
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
|
||||
class CLIPScorer:
|
||||
def __init__(self, model_name="openai/clip-vit-base-patch32"):
|
||||
|
|
@ -11,36 +12,42 @@ class CLIPScorer:
|
|||
self.processor = CLIPProcessor.from_pretrained(model_name)
|
||||
print(f"CLIPScorer initialized on {self.device} with {model_name}")
|
||||
except Exception as e:
|
||||
print(f"Error initializing CLIPModel: {e}. Ensure model name is correct and you have internet.")
|
||||
print(
|
||||
f"Error initializing CLIPModel: {e}. Ensure model name is correct and you have internet."
|
||||
)
|
||||
self.model = None
|
||||
self.processor = None
|
||||
raise
|
||||
|
||||
@torch.no_grad() # Ensure no gradients are computed during inference
|
||||
@torch.no_grad() # Ensure no gradients are computed during inference
|
||||
def score_images(self, images_np_list: list, target_text_description: str):
|
||||
if not self.model or not self.processor:
|
||||
print("CLIPScorer not properly initialized.")
|
||||
return [0.0] * len(images_np_list) # Low score on error
|
||||
return [0.0] * len(images_np_list) # Low score on error
|
||||
|
||||
try:
|
||||
pil_images = [Image.fromarray(img_arr.astype(np.uint8)) for img_arr in images_np_list]
|
||||
|
||||
pil_images = [
|
||||
Image.fromarray(img_arr.astype(np.uint8)) for img_arr in images_np_list
|
||||
]
|
||||
|
||||
inputs = self.processor(
|
||||
text=[target_text_description], # Single text prompt
|
||||
text=[target_text_description], # Single text prompt
|
||||
images=pil_images,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True
|
||||
truncation=True,
|
||||
).to(self.device)
|
||||
|
||||
outputs = self.model(**inputs)
|
||||
image_text_similarity_scores = outputs.logits_per_image.squeeze().tolist() # Squeeze to remove the text dim
|
||||
image_text_similarity_scores = (
|
||||
outputs.logits_per_image.squeeze().tolist()
|
||||
) # Squeeze to remove the text dim
|
||||
|
||||
if not isinstance(image_text_similarity_scores, list): # If only one image
|
||||
if not isinstance(image_text_similarity_scores, list): # If only one image
|
||||
image_text_similarity_scores = [image_text_similarity_scores]
|
||||
|
||||
return image_text_similarity_scores
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in CLIP scoring: {e}")
|
||||
return [0.0] * len(images_np_list) # Low score on error
|
||||
return [0.0] * len(images_np_list) # Low score on error
|
||||
|
|
@ -1,10 +1,15 @@
|
|||
import os
|
||||
import json
|
||||
import trimesh
|
||||
import os
|
||||
|
||||
import torch
|
||||
import trimesh
|
||||
from PIL import Image
|
||||
from transformers import BlipProcessor, BlipForConditionalGeneration
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BlipForConditionalGeneration,
|
||||
BlipProcessor,
|
||||
)
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
|
@ -14,12 +19,17 @@ LABEL_FILE = "dataset/labels.json"
|
|||
|
||||
# Load BLIP for image captioning
|
||||
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
||||
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
|
||||
blip_model = BlipForConditionalGeneration.from_pretrained(
|
||||
"Salesforce/blip-image-captioning-base"
|
||||
).to(device)
|
||||
|
||||
# Load Mistral or other small LLM
|
||||
llm_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
|
||||
llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
|
||||
llm_model = AutoModelForCausalLM.from_pretrained(llm_model_name, torch_dtype=torch.float16, device_map="auto")
|
||||
llm_model = AutoModelForCausalLM.from_pretrained(
|
||||
llm_model_name, torch_dtype=torch.float16, device_map="auto"
|
||||
)
|
||||
|
||||
|
||||
def extract_trimesh_features(mesh):
|
||||
return {
|
||||
|
|
@ -32,6 +42,7 @@ def extract_trimesh_features(mesh):
|
|||
"euler_number": mesh.euler_number,
|
||||
}
|
||||
|
||||
|
||||
def caption_image(image_path):
|
||||
raw_image = Image.open(image_path).convert("RGB")
|
||||
inputs = blip_processor(raw_image, return_tensors="pt").to(device)
|
||||
|
|
@ -39,6 +50,7 @@ def caption_image(image_path):
|
|||
caption = blip_processor.decode(out[0], skip_special_tokens=True)
|
||||
return caption
|
||||
|
||||
|
||||
def generate_label_with_llm(features, caption):
|
||||
prompt = f"""You are a 3D object classifier.
|
||||
|
||||
|
|
@ -56,6 +68,7 @@ Label:"""
|
|||
output_text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
return output_text.split("Label:")[-1].strip()
|
||||
|
||||
|
||||
def main():
|
||||
labels = {}
|
||||
|
||||
|
|
@ -86,6 +99,6 @@ def main():
|
|||
|
||||
print(f"\nSaved labels to {LABEL_FILE}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
503
environments/community/physical_space_stl/physical_env.py
Normal file
503
environments/community/physical_space_stl/physical_env.py
Normal file
|
|
@ -0,0 +1,503 @@
|
|||
import glob
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import trimesh
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
EvalHandlingEnum,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.type_definitions import Item, number
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
# Fix the relative imports for running directly
|
||||
try:
|
||||
from .judgement_model import CLIPScorer
|
||||
from .pyrender_utils import PyRenderOffline
|
||||
except ImportError:
|
||||
from judgement_model import CLIPScorer
|
||||
from pyrender_utils import PyRenderOffline
|
||||
|
||||
system_prompt = (
|
||||
"You are an expert in 3D modeling and computer-aided design. Your task is to analyze the "
|
||||
"blueprints or wireframe views of objects and generate the corresponding STL file content. "
|
||||
"STL (stereolithography) files represent 3D models as a collection of triangular facets.\n\n"
|
||||
"You may use <think> </think> tags to work through your reasoning about the shape, "
|
||||
"dimensions, and geometric features of the model. Be methodical in your approach.\n\n"
|
||||
"STL files can be in ASCII or binary format. For this task, generate ASCII STL content that "
|
||||
"accurately represents the 3D model shown in the provided views.\n\n"
|
||||
"Your final output must be enclosed in <stl> </stl> tags, containing only the valid STL content "
|
||||
"and nothing else. The STL content should begin with 'solid' and end with 'endsolid'.\n\n"
|
||||
"Example of STL format:\n"
|
||||
"<stl>\n"
|
||||
"solid model\n"
|
||||
" facet normal nx ny nz\n"
|
||||
" outer loop\n"
|
||||
" vertex x1 y1 z1\n"
|
||||
" vertex x2 y2 z2\n"
|
||||
" vertex x3 y3 z3\n"
|
||||
" endloop\n"
|
||||
" endfacet\n"
|
||||
" ... more facets ...\n"
|
||||
"endsolid model\n"
|
||||
"</stl>"
|
||||
)
|
||||
|
||||
|
||||
class PhysicalRow(TypedDict):
|
||||
prompt: str
|
||||
image: np.ndarray
|
||||
stl: str
|
||||
|
||||
|
||||
class PhysicalEnv(BaseEnv):
|
||||
name = "physical"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BaseEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm=True,
|
||||
testing=False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.percent_correct_buffer = list()
|
||||
self.eval_metrics = list()
|
||||
# Add tracking for wandb visualizations
|
||||
self.rollouts_for_wandb = []
|
||||
self.completion_lengths = []
|
||||
# Initialize renderer and CLIP scorer
|
||||
self.renderer = PyRenderOffline(width=224, height=224)
|
||||
self.clip_scorer = CLIPScorer()
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
|
||||
env_config = BaseEnvConfig(
|
||||
tokenizer_name="google/gemma-3-27b-it",
|
||||
group_size=8,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
batch_size=12,
|
||||
steps_per_eval=100,
|
||||
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
|
||||
max_token_length=2048,
|
||||
wandb_name="physical",
|
||||
)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="google/gemma-3-27b-it",
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
),
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
# Try to calculate percent_correct, pass if there's a division by zero
|
||||
try:
|
||||
wandb_metrics["train/percent_correct"] = sum(
|
||||
self.percent_correct_buffer
|
||||
) / len(self.percent_correct_buffer)
|
||||
except ZeroDivisionError:
|
||||
# Skip if buffer is empty
|
||||
pass
|
||||
|
||||
self.percent_correct_buffer = list()
|
||||
for item in self.eval_metrics:
|
||||
wandb_metrics[item[0]] = item[1]
|
||||
self.eval_metrics = list()
|
||||
# Call the parent method to handle the server metrics
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
def load_stl_file(self, stl_path):
|
||||
"""Load an STL file into a trimesh object"""
|
||||
try:
|
||||
mesh = trimesh.load(stl_path)
|
||||
return mesh
|
||||
except Exception as e:
|
||||
print(f"Error loading STL file {stl_path}: {e}")
|
||||
return None
|
||||
|
||||
def generate_query_from_images(self, images):
|
||||
"""Generate a query based on the rendered images of the STL file"""
|
||||
# In a real implementation, this would use a vision model to generate a description
|
||||
# For this simplified version, we'll use different templates to add variety
|
||||
templates = [
|
||||
"Create a 3D model (STL file) for the object shown in these technical drawings. "
|
||||
"Be precise with the geometry.",
|
||||
"Based on these wireframe views, generate the STL code for this 3D object. "
|
||||
"Pay attention to all visible features.",
|
||||
"Using these blueprint images as reference, provide the STL file format data "
|
||||
"to recreate this 3D model.",
|
||||
"These are technical views of a 3D object. Generate the STL representation "
|
||||
"that would produce this exact shape.",
|
||||
"Reconstruct this 3D model from the provided wireframe views and output "
|
||||
"the STL file content.",
|
||||
]
|
||||
return random.choice(templates)
|
||||
|
||||
async def setup(self):
|
||||
# Load all STL files from sample_data
|
||||
self.stl_files = glob.glob(os.path.join("sample_data", "*.stl"))
|
||||
if not self.stl_files:
|
||||
raise ValueError("No STL files found in the sample_data directory")
|
||||
|
||||
print(f"Found {len(self.stl_files)} STL files")
|
||||
|
||||
# Split files into train and test sets (80/20 split)
|
||||
random.seed(42)
|
||||
random.shuffle(self.stl_files)
|
||||
split_idx = int(len(self.stl_files) * 0.8)
|
||||
self.train_files = self.stl_files[:split_idx]
|
||||
self.test_files = self.stl_files[split_idx:]
|
||||
|
||||
self.iter = 0
|
||||
|
||||
def save_checkpoint(self, step, data=None):
|
||||
if data is None:
|
||||
data = {}
|
||||
data["iter"] = self.iter
|
||||
super().save_checkpoint(step, data)
|
||||
|
||||
async def rollout_and_score_eval(self, stl_path: str) -> number:
|
||||
# Load the STL file
|
||||
mesh = self.load_stl_file(stl_path)
|
||||
if mesh is None:
|
||||
return 0
|
||||
|
||||
# Render the images
|
||||
images = self.renderer.render_mesh_to_images(mesh)
|
||||
|
||||
# Generate a query from the images
|
||||
query = self.generate_query_from_images(images)
|
||||
|
||||
# Get a completion from the model
|
||||
completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": query},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
|
||||
# Extract the STL content from the completion
|
||||
response_content = completion.choices[0].message.content
|
||||
stl_content = self.extract_stl_content(response_content)
|
||||
|
||||
# Load the original mesh directly
|
||||
original_mesh = mesh
|
||||
|
||||
# Save the generated STL content to a temporary file
|
||||
temp_file = f"temp_generated_{random.randint(1000, 9999)}.stl"
|
||||
try:
|
||||
with open(temp_file, "w") as f:
|
||||
f.write(stl_content)
|
||||
|
||||
# Load the generated mesh
|
||||
generated_mesh = trimesh.load(temp_file)
|
||||
|
||||
# Score the generated mesh against the original
|
||||
score = self.score_meshes_similarity(original_mesh, generated_mesh)
|
||||
|
||||
# Cleanup
|
||||
os.remove(temp_file)
|
||||
|
||||
return score
|
||||
except Exception as e:
|
||||
print(f"Error processing generated STL: {e}")
|
||||
if os.path.exists(temp_file):
|
||||
os.remove(temp_file)
|
||||
return 0
|
||||
|
||||
def extract_stl_content(self, response_content):
|
||||
"""Extract STL content from the model's response"""
|
||||
# Find content between <stl> and </stl> tags
|
||||
match = re.search(r"<stl>(.*?)</stl>", response_content, re.DOTALL)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
return ""
|
||||
|
||||
def score_meshes_similarity(self, original_mesh, generated_mesh):
|
||||
"""Score the similarity between two meshes"""
|
||||
# This is a simple implementation - in practice you'd want more sophisticated metrics
|
||||
# Compare basic properties like number of vertices, faces, and volume
|
||||
orig_stats = {
|
||||
"vertices": len(original_mesh.vertices),
|
||||
"faces": len(original_mesh.faces),
|
||||
"volume": original_mesh.volume or 1.0,
|
||||
"surface_area": original_mesh.area or 1.0,
|
||||
}
|
||||
|
||||
gen_stats = {
|
||||
"vertices": len(generated_mesh.vertices),
|
||||
"faces": len(generated_mesh.faces),
|
||||
"volume": generated_mesh.volume or 1.0,
|
||||
"surface_area": generated_mesh.area or 1.0,
|
||||
}
|
||||
|
||||
# Calculate ratios (capped at 1.0 for when generated > original)
|
||||
vertex_ratio = min(gen_stats["vertices"] / max(orig_stats["vertices"], 1), 1.0)
|
||||
face_ratio = min(gen_stats["faces"] / max(orig_stats["faces"], 1), 1.0)
|
||||
volume_ratio = min(gen_stats["volume"] / max(orig_stats["volume"], 1), 1.0)
|
||||
area_ratio = min(
|
||||
gen_stats["surface_area"] / max(orig_stats["surface_area"], 1), 1.0
|
||||
)
|
||||
|
||||
# Average the ratios for a final score
|
||||
score = (vertex_ratio + face_ratio + volume_ratio + area_ratio) / 4.0
|
||||
return score
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
eval_tasks = []
|
||||
for stl_file in self.test_files[
|
||||
:10
|
||||
]: # Limit to 10 files for evaluation to keep it manageable
|
||||
eval_tasks.append(self.rollout_and_score_eval(stl_file))
|
||||
scores = await tqdm_asyncio.gather(*eval_tasks)
|
||||
self.eval_metrics.append(("eval/similarity_score", sum(scores) / len(scores)))
|
||||
|
||||
async def collect_trajectories(
|
||||
self, item: PhysicalRow
|
||||
) -> Tuple[ScoredDataGroup, list[Item]]:
|
||||
stl_path = item["stl_path"]
|
||||
mesh = self.load_stl_file(stl_path)
|
||||
images = self.renderer.render_mesh_to_images(mesh)
|
||||
query = self.generate_query_from_images(images)
|
||||
|
||||
# For original STL content, we'll just store the file path instead of the content
|
||||
# as the files may be binary and can't be simply read as text
|
||||
original_stl_path = stl_path
|
||||
|
||||
user_message = {"role": "user", "content": query}
|
||||
|
||||
chat_completions = await self.server.chat_completion(
|
||||
messages=[{"role": "system", "content": system_prompt}, user_message],
|
||||
n=self.config.group_size,
|
||||
max_tokens=self.config.max_token_length,
|
||||
)
|
||||
|
||||
to_score = list()
|
||||
to_backlog = list()
|
||||
|
||||
for i, chat_completion in enumerate(chat_completions.choices):
|
||||
messages = (
|
||||
{"role": "system", "content": system_prompt},
|
||||
user_message,
|
||||
{"role": "assistant", "content": chat_completion.message.content},
|
||||
)
|
||||
to_score.append(
|
||||
{
|
||||
"messages": messages,
|
||||
"original_stl_path": original_stl_path,
|
||||
"images": images,
|
||||
"finish_reason": chat_completion.finish_reason,
|
||||
}
|
||||
)
|
||||
|
||||
to_postprocess = await self.score(to_score)
|
||||
return to_postprocess, to_backlog
|
||||
|
||||
async def score(
|
||||
self, rollout_group_data
|
||||
) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
|
||||
scores = ScoredDataGroup()
|
||||
scores["tokens"] = list()
|
||||
scores["masks"] = list()
|
||||
scores["scores"] = list()
|
||||
|
||||
random.shuffle(rollout_group_data)
|
||||
|
||||
for item in rollout_group_data:
|
||||
response_content = item["messages"][-1]["content"]
|
||||
stl_content = self.extract_stl_content(response_content)
|
||||
|
||||
# Save the generated STL content to a temporary file
|
||||
temp_file = f"temp_generated_{random.randint(1000, 9999)}.stl"
|
||||
try:
|
||||
with open(temp_file, "w") as f:
|
||||
f.write(stl_content)
|
||||
|
||||
# Load the original STL directly from its path
|
||||
original_stl_path = item["original_stl_path"]
|
||||
original_mesh = trimesh.load(original_stl_path)
|
||||
|
||||
# Load the generated mesh
|
||||
generated_mesh = trimesh.load(temp_file)
|
||||
|
||||
# Score the generated mesh against the original
|
||||
mesh_similarity = self.score_meshes_similarity(
|
||||
original_mesh, generated_mesh
|
||||
)
|
||||
|
||||
# Generate rendered images of the produced STL
|
||||
generated_images = self.renderer.render_mesh_to_images(generated_mesh)
|
||||
|
||||
# Use CLIP to score the visual similarity
|
||||
images_reward = 0.0
|
||||
if len(generated_images) > 0 and len(item["images"]) > 0:
|
||||
# Extract query from the user message
|
||||
query = item["messages"][1]["content"]
|
||||
|
||||
# Score the visual similarity using CLIP
|
||||
clip_scores = self.clip_scorer.score_images(generated_images, query)
|
||||
images_reward = (
|
||||
sum(clip_scores) / len(clip_scores) / 100.0
|
||||
) # Normalize to roughly 0-1
|
||||
|
||||
# Combine mesh similarity and image similarity for final reward
|
||||
reward = 0.5 * mesh_similarity + 0.5 * images_reward
|
||||
|
||||
out_dict = tokenize_for_trainer(
|
||||
self.tokenizer, item["messages"], item["finish_reason"]
|
||||
)
|
||||
tokens = out_dict["tokens"]
|
||||
masks = out_dict["masks"]
|
||||
|
||||
# Remove obviously bad examples
|
||||
if len([1 for i in masks if i != -100]) < 10:
|
||||
continue
|
||||
|
||||
scores["tokens"].append(tokens)
|
||||
scores["masks"].append(masks)
|
||||
scores["scores"].append(reward)
|
||||
|
||||
self.percent_correct_buffer.append(reward)
|
||||
|
||||
# Clean up temporary files
|
||||
os.remove(temp_file)
|
||||
|
||||
if len(scores["tokens"]) >= self.config.group_size:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in scoring: {e}")
|
||||
if os.path.exists(temp_file):
|
||||
os.remove(temp_file)
|
||||
|
||||
# Apply length penalty if all scores are similar
|
||||
if all(abs(score - scores["scores"][0]) < 0.1 for score in scores["scores"]):
|
||||
token_lengths = [len(token) for token in scores["tokens"]]
|
||||
if max(token_lengths) > 0:
|
||||
max_allowed_length = self.config.max_token_length
|
||||
length_threshold = max_allowed_length * 0.5
|
||||
|
||||
scores["scores"] = []
|
||||
for length in token_lengths:
|
||||
if length <= length_threshold:
|
||||
scores["scores"].append(1.0)
|
||||
else:
|
||||
percentage_of_range = (length - length_threshold) / (
|
||||
max_allowed_length - length_threshold
|
||||
)
|
||||
percentage_of_range = min(percentage_of_range, 1.0)
|
||||
scores["scores"].append(1.0 - percentage_of_range)
|
||||
|
||||
if all([scores["scores"][0] == score for score in scores["scores"]]):
|
||||
return None # If all the same, we return None
|
||||
|
||||
return scores
|
||||
|
||||
async def get_next_item(self) -> PhysicalRow:
|
||||
stl_path = self.train_files[self.iter % len(self.train_files)]
|
||||
self.iter += 1
|
||||
|
||||
# Load the STL file and render it
|
||||
mesh = self.load_stl_file(stl_path)
|
||||
if mesh is None:
|
||||
# Skip this file and try the next one if there's an issue
|
||||
return await self.get_next_item()
|
||||
|
||||
# Render the mesh to get images
|
||||
images = self.renderer.render_mesh_to_images(mesh)
|
||||
|
||||
# Generate a query from the images
|
||||
query = self.generate_query_from_images(images)
|
||||
|
||||
# Return a row with the prompt, image, and path to the STL file
|
||||
return {
|
||||
"prompt": query,
|
||||
"image": images[0] if images else np.zeros((224, 224, 3), dtype=np.uint8),
|
||||
"stl_path": stl_path,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def test_sample_stl(cls):
|
||||
"""Test loading and rendering a sample STL file"""
|
||||
# Create temporary environment instance
|
||||
env_config = BaseEnvConfig(
|
||||
tokenizer_name="google/gemma-3-27b-it",
|
||||
group_size=8,
|
||||
use_wandb=False,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
batch_size=12,
|
||||
steps_per_eval=100,
|
||||
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
|
||||
max_token_length=2048,
|
||||
wandb_name="physical_test",
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="google/gemma-3-27b-it",
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
),
|
||||
]
|
||||
|
||||
env = cls(env_config, server_configs, slurm=False, testing=True)
|
||||
|
||||
# Find sample STL files
|
||||
stl_files = glob.glob(os.path.join("sample_data", "*.stl"))
|
||||
if not stl_files:
|
||||
print("No STL files found in sample_data/")
|
||||
return
|
||||
|
||||
# Test loading and rendering the first file
|
||||
print(f"Testing with STL file: {stl_files[0]}")
|
||||
mesh = env.load_stl_file(stl_files[0])
|
||||
if mesh is None:
|
||||
print("Failed to load STL file")
|
||||
return
|
||||
|
||||
print(
|
||||
f"Loaded mesh with {len(mesh.vertices)} vertices and {len(mesh.faces)} faces"
|
||||
)
|
||||
|
||||
# Render the mesh
|
||||
try:
|
||||
images = env.renderer.render_mesh_to_images(mesh)
|
||||
print(f"Successfully rendered {len(images)} images")
|
||||
|
||||
# Save the first image for inspection
|
||||
from PIL import Image
|
||||
|
||||
img = Image.fromarray(images[0])
|
||||
img.save("test_render.png")
|
||||
print("Saved test render to test_render.png")
|
||||
except Exception as e:
|
||||
print(f"Error rendering: {e}")
|
||||
|
||||
print("Test completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
PhysicalEnv.cli()
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
import os
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
import trimesh
|
||||
from datasets import Dataset, Features, Value, Image
|
||||
from datasets import Dataset, Features, Image, Value
|
||||
from huggingface_hub import login
|
||||
|
||||
# Log in to HF Hub (optional if you've already done `huggingface-cli login`)
|
||||
|
|
@ -23,10 +23,10 @@ for image_filename in os.listdir(image_dir):
|
|||
if not image_filename.endswith(".png"):
|
||||
continue
|
||||
image_path = os.path.join(image_dir, image_filename)
|
||||
|
||||
|
||||
# Extract base ID
|
||||
base_id = image_filename.split("_")[0]
|
||||
|
||||
|
||||
stl_path = os.path.join(stl_dir, f"{base_id}.stl")
|
||||
label = labels.get(base_id, "unknown")
|
||||
|
||||
|
|
@ -42,20 +42,24 @@ for image_filename in os.listdir(image_dir):
|
|||
except Exception as e:
|
||||
print(f"⚠️ Failed to process {stl_path}: {e}")
|
||||
|
||||
data.append({
|
||||
"image": image_path,
|
||||
"label": label,
|
||||
"stl_features": stl_features,
|
||||
"id": base_id,
|
||||
})
|
||||
data.append(
|
||||
{
|
||||
"image": image_path,
|
||||
"label": label,
|
||||
"stl_features": stl_features,
|
||||
"id": base_id,
|
||||
}
|
||||
)
|
||||
|
||||
# Define dataset schema
|
||||
features = Features({
|
||||
"id": Value("string"),
|
||||
"image": Image(), # Load images from file paths
|
||||
"label": Value("string"),
|
||||
"stl_features": Value("string"), # Store as JSON string for simplicity
|
||||
})
|
||||
features = Features(
|
||||
{
|
||||
"id": Value("string"),
|
||||
"image": Image(), # Load images from file paths
|
||||
"label": Value("string"),
|
||||
"stl_features": Value("string"), # Store as JSON string for simplicity
|
||||
}
|
||||
)
|
||||
|
||||
# Convert stl_features to JSON strings for compatibility
|
||||
for item in data:
|
||||
|
|
@ -66,4 +70,3 @@ dataset = Dataset.from_list(data).cast(features)
|
|||
|
||||
# Push to Hub
|
||||
dataset.push_to_hub("venkatacrc/stl-image-dataset", private=True)
|
||||
|
||||
|
|
@ -1,223 +1,240 @@
|
|||
import os
|
||||
|
||||
import numpy as np
|
||||
import trimesh
|
||||
import pyrender
|
||||
import trimesh
|
||||
|
||||
# Headless rendering with GPU acceleration (egl), for non-GPU environments omesa will be faster
|
||||
os.environ['PYOPENGL_PLATFORM'] = 'egl'
|
||||
os.environ["PYOPENGL_PLATFORM"] = "egl"
|
||||
|
||||
def create_look_at_matrix(eye: np.ndarray, target: np.ndarray, up: np.ndarray) -> np.ndarray:
|
||||
|
||||
def create_look_at_matrix(
|
||||
eye: np.ndarray, target: np.ndarray, up: np.ndarray
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Create a look-at transformation matrix for a camera.
|
||||
|
||||
|
||||
eye - position of the camera
|
||||
target - position the camera is looking at
|
||||
up - up direction for the camera
|
||||
|
||||
|
||||
returns the 4x4 transformation matrix
|
||||
"""
|
||||
eye = np.asarray(eye)
|
||||
target = np.asarray(target)
|
||||
up = np.asarray(up)
|
||||
|
||||
|
||||
# Forward vector (from eye to target)
|
||||
forward = target - eye
|
||||
forward = forward / np.linalg.norm(forward)
|
||||
|
||||
|
||||
# Side vector (right vector)
|
||||
side = np.cross(forward, up)
|
||||
side = side / np.linalg.norm(side)
|
||||
|
||||
|
||||
# Recompute up vector to ensure orthogonality
|
||||
up = np.cross(side, forward)
|
||||
up = up / np.linalg.norm(up)
|
||||
|
||||
|
||||
# Create the rotation matrix | NOTE: PyRender uses OpenGL convention
|
||||
# where camera looks down negative z-axis
|
||||
R = np.eye(4)
|
||||
R[0, :3] = side
|
||||
R[1, :3] = up
|
||||
R[2, :3] = -forward
|
||||
|
||||
T = np.eye(4) # translation matrix
|
||||
|
||||
T = np.eye(4) # translation matrix
|
||||
T[:3, 3] = eye
|
||||
|
||||
|
||||
# The camera pose matrix is the inverse of the view matrix
|
||||
# but we need to return the pose directly for pyrender
|
||||
return T @ R
|
||||
|
||||
|
||||
class PyRenderOffline:
|
||||
def __init__(self, width=224, height=224): # Standard CLIP image size
|
||||
def __init__(self, width=224, height=224): # Standard CLIP image size
|
||||
self.width = width
|
||||
self.height = height
|
||||
try:
|
||||
self.renderer = pyrender.OffscreenRenderer(viewport_width=self.width, viewport_height=self.height, point_size=1.0)
|
||||
self.renderer = pyrender.OffscreenRenderer(
|
||||
viewport_width=self.width, viewport_height=self.height, point_size=1.0
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to initialize OffscreenRenderer (is a display server/EGL/OSMesa available?): {e}")
|
||||
print("Try: pip install pyglet; or for headless: export PYOPENGL_PLATFORM=osmesa (or egl)")
|
||||
self.renderer = None # Fallback or raise error
|
||||
print(
|
||||
f"Failed to initialize OffscreenRenderer (is a display server/EGL/OSMesa available?): {e}"
|
||||
)
|
||||
print(
|
||||
"Try: pip install pyglet; or for headless: export PYOPENGL_PLATFORM=osmesa (or egl)"
|
||||
)
|
||||
self.renderer = None # Fallback or raise error
|
||||
raise
|
||||
|
||||
# Create camera poses using explicit transformation matrices
|
||||
|
||||
|
||||
# Front view (looking along the -Z axis)
|
||||
front_pose = np.array([
|
||||
[1, 0, 0, 0],
|
||||
[0, 1, 0, 0],
|
||||
[0, 0, 1, 5],
|
||||
[0, 0, 0, 1]
|
||||
])
|
||||
|
||||
front_pose = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 5], [0, 0, 0, 1]])
|
||||
|
||||
# Top view (looking along the -Y axis)
|
||||
top_pose = np.array([
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 1, 5],
|
||||
[0, -1, 0, 0],
|
||||
[0, 0, 0, 1]
|
||||
])
|
||||
|
||||
top_pose = np.array([[1, 0, 0, 0], [0, 0, 1, 5], [0, -1, 0, 0], [0, 0, 0, 1]])
|
||||
|
||||
# Diagonal view (from upper corner)
|
||||
side_pose = np.array([
|
||||
[0.866, -0.25, 0.433, 3], # Camera right vector
|
||||
[0.0, 0.866, 0.5, 3], # Camera up vector
|
||||
[-0.5, -0.433, 0.75, 5], # Camera forward vector (pointing at origin)
|
||||
[0, 0, 0, 1]
|
||||
])
|
||||
|
||||
side_pose = np.array(
|
||||
[
|
||||
[0.866, -0.25, 0.433, 3], # Camera right vector
|
||||
[0.0, 0.866, 0.5, 3], # Camera up vector
|
||||
[-0.5, -0.433, 0.75, 5], # Camera forward vector (pointing at origin)
|
||||
[0, 0, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
# Store camera poses
|
||||
self.camera_poses = [front_pose, top_pose, side_pose]
|
||||
|
||||
|
||||
# Debug print for the camera poses cause it doesn't look right
|
||||
# and I'm not a game developer lol
|
||||
print("Camera poses:")
|
||||
for i, pose in enumerate(self.camera_poses):
|
||||
print(f"Camera {i}:\n{pose}")
|
||||
|
||||
|
||||
# slightly wider field of view to ensure objects are visible
|
||||
self.camera = pyrender.PerspectiveCamera(yfov=np.pi / 2.5, aspectRatio=1.0)
|
||||
|
||||
|
||||
# Bright point light
|
||||
self.light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=10.0)
|
||||
|
||||
def render_mesh_to_images(self, mesh_obj: trimesh.Trimesh):
|
||||
if not self.renderer:
|
||||
print("Renderer not initialized, cannot render.")
|
||||
return [np.zeros((self.height, self.width, 3), dtype=np.uint8) for _ in range(3)]
|
||||
return [
|
||||
np.zeros((self.height, self.width, 3), dtype=np.uint8) for _ in range(3)
|
||||
]
|
||||
|
||||
print(
|
||||
f"Rendering mesh with {len(mesh_obj.vertices)} vertices and {len(mesh_obj.faces)} faces"
|
||||
)
|
||||
|
||||
print(f"Rendering mesh with {len(mesh_obj.vertices)} vertices and {len(mesh_obj.faces)} faces")
|
||||
|
||||
images = []
|
||||
|
||||
|
||||
# Make a copy to avoid modifying original mesh
|
||||
render_mesh = mesh_obj.copy()
|
||||
|
||||
|
||||
# Center and scale the mesh for visibility
|
||||
render_mesh.apply_translation(-render_mesh.centroid)
|
||||
scale_factor = 0.8 / np.max(render_mesh.extents) # Scale to unit size but slightly smaller
|
||||
scale_factor = 0.8 / np.max(
|
||||
render_mesh.extents
|
||||
) # Scale to unit size but slightly smaller
|
||||
render_mesh.apply_scale(scale_factor)
|
||||
|
||||
|
||||
# Create a ground plane
|
||||
ground_plane = trimesh.creation.box([4.0, 4.0, 0.01])
|
||||
ground_plane.apply_translation([0, 0, -0.5])
|
||||
|
||||
|
||||
# Color scheme for wireframe blueprint
|
||||
blueprint_bg = [0.98, 0.98, 1.0, 1.0] # Light blue background
|
||||
wireframe_color = [0.0, 0.4, 0.8, 1.0] # Medium bright blue for wireframes
|
||||
grid_color = [0.0, 0.2, 0.4, 1.0] # Darker blue for grid
|
||||
|
||||
blueprint_bg = [0.98, 0.98, 1.0, 1.0] # Light blue background
|
||||
wireframe_color = [0.0, 0.4, 0.8, 1.0] # Medium bright blue for wireframes
|
||||
grid_color = [0.0, 0.2, 0.4, 1.0] # Darker blue for grid
|
||||
|
||||
# Wireframe material for the edges - completely opaque
|
||||
wireframe_material = pyrender.MetallicRoughnessMaterial(
|
||||
baseColorFactor=wireframe_color,
|
||||
metallicFactor=0.0,
|
||||
roughnessFactor=1.0,
|
||||
wireframe=True
|
||||
wireframe=True,
|
||||
)
|
||||
|
||||
|
||||
# Grid material for the ground plane
|
||||
grid_material = pyrender.MetallicRoughnessMaterial(
|
||||
baseColorFactor=grid_color,
|
||||
metallicFactor=0.0,
|
||||
roughnessFactor=1.0,
|
||||
wireframe=True
|
||||
wireframe=True,
|
||||
)
|
||||
|
||||
|
||||
for i, pose in enumerate(self.camera_poses):
|
||||
# Create a fresh scene with blueprint background
|
||||
scene = pyrender.Scene(ambient_light=[0.8, 0.8, 0.95], bg_color=blueprint_bg)
|
||||
|
||||
scene = pyrender.Scene(
|
||||
ambient_light=[0.8, 0.8, 0.95], bg_color=blueprint_bg
|
||||
)
|
||||
|
||||
# Add ground plane as grid
|
||||
plane_mesh = pyrender.Mesh.from_trimesh(ground_plane, material=grid_material)
|
||||
plane_mesh = pyrender.Mesh.from_trimesh(
|
||||
ground_plane, material=grid_material
|
||||
)
|
||||
scene.add(plane_mesh)
|
||||
|
||||
|
||||
# For wireframe rendering, we'll create line segments directly
|
||||
edges = render_mesh.edges_unique
|
||||
edge_vertices = []
|
||||
edge_indices = []
|
||||
|
||||
|
||||
# Extract all edges for wireframe rendering
|
||||
for _, edge in enumerate(edges):
|
||||
v0_idx = len(edge_vertices)
|
||||
edge_vertices.append(render_mesh.vertices[edge[0]])
|
||||
edge_vertices.append(render_mesh.vertices[edge[1]])
|
||||
edge_indices.append([v0_idx, v0_idx+1])
|
||||
|
||||
edge_indices.append([v0_idx, v0_idx + 1])
|
||||
|
||||
# Create lines primitive for edges
|
||||
if len(edge_vertices) > 0:
|
||||
edge_verts = np.array(edge_vertices, dtype=np.float32)
|
||||
edge_indices = np.array(edge_indices, dtype=np.uint32)
|
||||
|
||||
|
||||
# Create a primitive for the lines
|
||||
primitive = pyrender.Primitive(
|
||||
positions=edge_verts,
|
||||
indices=edge_indices,
|
||||
mode=pyrender.constants.GLTF.LINES,
|
||||
material=wireframe_material
|
||||
material=wireframe_material,
|
||||
)
|
||||
|
||||
|
||||
# Create a mesh with just the line primitive
|
||||
edge_mesh = pyrender.Mesh(primitives=[primitive])
|
||||
scene.add(edge_mesh)
|
||||
|
||||
|
||||
# Add camera
|
||||
scene.add(self.camera, pose=pose)
|
||||
|
||||
|
||||
# Add light from camera direction (key light)
|
||||
scene.add(self.light, pose=pose)
|
||||
|
||||
|
||||
# Add second light from above for better visibility
|
||||
top_light_pose = np.eye(4)
|
||||
top_light_pose[1, 3] = 3.0
|
||||
scene.add(self.light, pose=top_light_pose)
|
||||
|
||||
|
||||
try:
|
||||
color, _ = self.renderer.render(scene)
|
||||
|
||||
|
||||
# Post-process to enhance blueprint effect
|
||||
# Convert to float for processing
|
||||
img_float = color.astype(np.float32) / 255.0
|
||||
|
||||
|
||||
# Add a subtle grid pattern to the background
|
||||
grid_size = 40
|
||||
grid_intensity = 0.02
|
||||
|
||||
|
||||
# Draw faint horizontal grid lines
|
||||
for y in range(0, color.shape[0], grid_size):
|
||||
img_float[y:y+1, :, :] = np.minimum(img_float[y:y+1, :, :] + grid_intensity, 1.0)
|
||||
|
||||
img_float[y : y + 1, :, :] = np.minimum(
|
||||
img_float[y : y + 1, :, :] + grid_intensity, 1.0
|
||||
)
|
||||
|
||||
# Draw faint vertical grid lines
|
||||
for x in range(0, color.shape[1], grid_size):
|
||||
img_float[:, x:x+1, :] = np.minimum(img_float[:, x:x+1, :] + grid_intensity, 1.0)
|
||||
|
||||
img_float[:, x : x + 1, :] = np.minimum(
|
||||
img_float[:, x : x + 1, :] + grid_intensity, 1.0
|
||||
)
|
||||
|
||||
# Convert back to uint8
|
||||
processed_img = (img_float * 255).astype(np.uint8)
|
||||
|
||||
|
||||
images.append(processed_img)
|
||||
print(f"Rendered wireframe view {i}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during rendering view {i}: {e}")
|
||||
images.append(np.zeros((self.height, self.width, 3), dtype=np.uint8))
|
||||
|
||||
|
||||
return images
|
||||
|
||||
def __del__(self):
|
||||
|
|
@ -1,17 +1,18 @@
|
|||
import bpy
|
||||
import sys
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
|
||||
import bpy
|
||||
|
||||
# Get args after --
|
||||
argv = sys.argv
|
||||
argv = argv[argv.index("--") + 1:] # args after --
|
||||
argv = argv[argv.index("--") + 1 :] # args after --
|
||||
|
||||
input_stl = argv[0]
|
||||
output_dir = argv[1]
|
||||
|
||||
# Clear existing objects
|
||||
bpy.ops.object.select_all(action='SELECT')
|
||||
bpy.ops.object.select_all(action="SELECT")
|
||||
bpy.ops.object.delete(use_global=False)
|
||||
|
||||
# Import STL
|
||||
|
|
@ -19,11 +20,11 @@ bpy.ops.import_mesh.stl(filepath=input_stl)
|
|||
obj = bpy.context.selected_objects[0]
|
||||
|
||||
# Center the object at origin
|
||||
bpy.ops.object.origin_set(type='ORIGIN_CENTER_OF_MASS', center='MEDIAN')
|
||||
bpy.ops.object.origin_set(type="ORIGIN_CENTER_OF_MASS", center="MEDIAN")
|
||||
obj.location = (0, 0, 0)
|
||||
|
||||
# Add Sun light
|
||||
sun_light_data = bpy.data.lights.new(name="SunLight", type='SUN')
|
||||
sun_light_data = bpy.data.lights.new(name="SunLight", type="SUN")
|
||||
sun_light_object = bpy.data.objects.new(name="SunLight", object_data=sun_light_data)
|
||||
sun_light_object.location = (10, 10, 10)
|
||||
bpy.context.collection.objects.link(sun_light_object)
|
||||
|
|
@ -52,10 +53,9 @@ for i, angle in enumerate(angles):
|
|||
|
||||
# Point camera to object center (0,0,0)
|
||||
direction = -cam_obj.location
|
||||
rot_quat = direction.to_track_quat('-Z', 'Y')
|
||||
rot_quat = direction.to_track_quat("-Z", "Y")
|
||||
cam_obj.rotation_euler = rot_quat.to_euler()
|
||||
|
||||
# Render
|
||||
bpy.context.scene.render.filepath = os.path.join(output_dir, f"render_{i}.png")
|
||||
bpy.ops.render.render(write_still=True)
|
||||
|
||||
|
Before Width: | Height: | Size: 47 KiB After Width: | Height: | Size: 47 KiB |
|
|
@ -1,9 +1,8 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import trimesh
|
||||
import matplotlib.pyplot as plt
|
||||
import trimesh
|
||||
from pyrender_utils import PyRenderOffline
|
||||
|
||||
|
||||
def test_render_example():
|
||||
"""
|
||||
Example script to test the PyRenderOffline class with a real mesh.
|
||||
|
|
@ -13,34 +12,41 @@ def test_render_example():
|
|||
# Create a high-quality sphere for wireframe rendering
|
||||
# Using more subdivisions for smoother appearance
|
||||
sphere = trimesh.creation.icosphere(subdivisions=4, radius=1.0)
|
||||
|
||||
|
||||
# Create a perfect sphere instead of a noisy one
|
||||
# This will look better as a wireframe/blueprint for our test ;)
|
||||
print(f"Created sphere with {len(sphere.vertices)} vertices and {len(sphere.faces)} faces")
|
||||
print(f"Sphere has {len(sphere.edges_unique)} unique edges for wireframe rendering")
|
||||
|
||||
renderer = PyRenderOffline(width=512, height=512) # larger dimensions than the CLIP size for better detail
|
||||
|
||||
print(
|
||||
f"Created sphere with {len(sphere.vertices)} vertices and {len(sphere.faces)} faces"
|
||||
)
|
||||
print(
|
||||
f"Sphere has {len(sphere.edges_unique)} unique edges for wireframe rendering"
|
||||
)
|
||||
|
||||
renderer = PyRenderOffline(
|
||||
width=512, height=512
|
||||
) # larger dimensions than the CLIP size for better detail
|
||||
|
||||
images = renderer.render_mesh_to_images(sphere)
|
||||
|
||||
|
||||
# Display the results
|
||||
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
||||
view_names = ['Front', 'Top', 'Diagonal']
|
||||
|
||||
view_names = ["Front", "Top", "Diagonal"]
|
||||
|
||||
for i, (img, name) in enumerate(zip(images, view_names)):
|
||||
axes[i].imshow(img)
|
||||
axes[i].set_title(f"{name} View")
|
||||
axes[i].axis('off')
|
||||
axes[i].axis("off")
|
||||
|
||||
plt.savefig('test_rendered_sphere_views.png')
|
||||
plt.savefig("test_rendered_sphere_views.png")
|
||||
plt.close()
|
||||
|
||||
print(f"Successfully rendered sphere from 3 viewpoints")
|
||||
print(f"Images saved to rendered_sphere_views.png")
|
||||
|
||||
print("Successfully rendered sphere from 3 viewpoints")
|
||||
print("Images saved to rendered_sphere_views.png")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to run renderer example: {e}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_render_example()
|
||||
test_render_example()
|
||||
|
|
@ -2,10 +2,10 @@
|
|||
"""
|
||||
Test script for the PhysicalEnv that loads and processes STL files
|
||||
"""
|
||||
import os
|
||||
import asyncio
|
||||
import random
|
||||
from physical_env import PhysicalEnv, BaseEnvConfig, APIServerConfig, EvalHandlingEnum
|
||||
|
||||
from physical_env import APIServerConfig, BaseEnvConfig, EvalHandlingEnum, PhysicalEnv
|
||||
|
||||
|
||||
async def test_render_stl():
|
||||
"""Test loading and rendering an STL file"""
|
||||
|
|
@ -22,7 +22,7 @@ async def test_render_stl():
|
|||
max_token_length=2048,
|
||||
wandb_name="physical_test",
|
||||
)
|
||||
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="google/gemma-3-27b-it",
|
||||
|
|
@ -31,13 +31,13 @@ async def test_render_stl():
|
|||
num_requests_for_eval=256,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
print("Creating test environment")
|
||||
env = PhysicalEnv(env_config, server_configs, slurm=False, testing=True)
|
||||
|
||||
|
||||
print("Setting up environment")
|
||||
await env.setup()
|
||||
|
||||
|
||||
# Test get_next_item
|
||||
print("Testing get_next_item")
|
||||
try:
|
||||
|
|
@ -47,8 +47,9 @@ async def test_render_stl():
|
|||
print(f"STL path: {item['stl_path']}")
|
||||
except Exception as e:
|
||||
print(f"Error getting next item: {e}")
|
||||
|
||||
|
||||
print("Test completed successfully")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_render_stl())
|
||||
asyncio.run(test_render_stl())
|
||||
|
|
@ -1,481 +0,0 @@
|
|||
import numpy as np
|
||||
import trimesh
|
||||
import io
|
||||
import re
|
||||
import wandb
|
||||
import os
|
||||
import glob
|
||||
import random
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
from typing import *
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
EvalHandlingEnum
|
||||
)
|
||||
from atroposlib.type_definitions import Item, number
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
# Fix the relative imports for running directly
|
||||
try:
|
||||
from .pyrender_utils import PyRenderOffline
|
||||
from .judgement_model import CLIPScorer
|
||||
except ImportError:
|
||||
from pyrender_utils import PyRenderOffline
|
||||
from judgement_model import CLIPScorer
|
||||
|
||||
system_prompt = (
|
||||
"You are an expert in 3D modeling and computer-aided design. Your task is to analyze the "
|
||||
"blueprints or wireframe views of objects and generate the corresponding STL file content. "
|
||||
"STL (stereolithography) files represent 3D models as a collection of triangular facets.\n\n"
|
||||
"You may use <think> </think> tags to work through your reasoning about the shape, "
|
||||
"dimensions, and geometric features of the model. Be methodical in your approach.\n\n"
|
||||
"STL files can be in ASCII or binary format. For this task, generate ASCII STL content that "
|
||||
"accurately represents the 3D model shown in the provided views.\n\n"
|
||||
"Your final output must be enclosed in <stl> </stl> tags, containing only the valid STL content "
|
||||
"and nothing else. The STL content should begin with 'solid' and end with 'endsolid'.\n\n"
|
||||
"Example of STL format:\n"
|
||||
"<stl>\n"
|
||||
"solid model\n"
|
||||
" facet normal nx ny nz\n"
|
||||
" outer loop\n"
|
||||
" vertex x1 y1 z1\n"
|
||||
" vertex x2 y2 z2\n"
|
||||
" vertex x3 y3 z3\n"
|
||||
" endloop\n"
|
||||
" endfacet\n"
|
||||
" ... more facets ...\n"
|
||||
"endsolid model\n"
|
||||
"</stl>"
|
||||
)
|
||||
|
||||
class PhysicalRow(TypedDict):
|
||||
prompt: str
|
||||
image: np.ndarray
|
||||
stl: str
|
||||
|
||||
class PhysicalEnv(BaseEnv):
|
||||
name = "physical"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BaseEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm=True,
|
||||
testing=False
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.percent_correct_buffer = list()
|
||||
self.eval_metrics = list()
|
||||
# Add tracking for wandb visualizations
|
||||
self.rollouts_for_wandb = []
|
||||
self.completion_lengths = []
|
||||
# Initialize renderer and CLIP scorer
|
||||
self.renderer = PyRenderOffline(width=224, height=224)
|
||||
self.clip_scorer = CLIPScorer()
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
|
||||
env_config = BaseEnvConfig(
|
||||
tokenizer_name="google/gemma-3-27b-it",
|
||||
group_size=8,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
batch_size=12,
|
||||
steps_per_eval=100,
|
||||
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
|
||||
max_token_length=2048,
|
||||
wandb_name="physical",
|
||||
)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="google/gemma-3-27b-it",
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
),
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
# Try to calculate percent_correct, pass if there's a division by zero
|
||||
try:
|
||||
wandb_metrics["train/percent_correct"] = sum(
|
||||
self.percent_correct_buffer
|
||||
) / len(self.percent_correct_buffer)
|
||||
except ZeroDivisionError:
|
||||
# Skip if buffer is empty
|
||||
pass
|
||||
|
||||
self.percent_correct_buffer = list()
|
||||
for item in self.eval_metrics:
|
||||
wandb_metrics[item[0]] = item[1]
|
||||
self.eval_metrics = list()
|
||||
# Call the parent method to handle the server metrics
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
def load_stl_file(self, stl_path):
|
||||
"""Load an STL file into a trimesh object"""
|
||||
try:
|
||||
mesh = trimesh.load(stl_path)
|
||||
return mesh
|
||||
except Exception as e:
|
||||
print(f"Error loading STL file {stl_path}: {e}")
|
||||
return None
|
||||
|
||||
def generate_query_from_images(self, images):
|
||||
"""Generate a query based on the rendered images of the STL file"""
|
||||
# In a real implementation, this would use a vision model to generate a description
|
||||
# For this simplified version, we'll use different templates to add variety
|
||||
templates = [
|
||||
"Create a 3D model (STL file) for the object shown in these technical drawings. Be precise with the geometry.",
|
||||
"Based on these wireframe views, generate the STL code for this 3D object. Pay attention to all visible features.",
|
||||
"Using these blueprint images as reference, provide the STL file format data to recreate this 3D model.",
|
||||
"These are technical views of a 3D object. Generate the STL representation that would produce this exact shape.",
|
||||
"Reconstruct this 3D model from the provided wireframe views and output the STL file content."
|
||||
]
|
||||
return random.choice(templates)
|
||||
|
||||
async def setup(self):
|
||||
# Load all STL files from sample_data
|
||||
self.stl_files = glob.glob(os.path.join('sample_data', '*.stl'))
|
||||
if not self.stl_files:
|
||||
raise ValueError("No STL files found in the sample_data directory")
|
||||
|
||||
print(f"Found {len(self.stl_files)} STL files")
|
||||
|
||||
# Split files into train and test sets (80/20 split)
|
||||
random.seed(42)
|
||||
random.shuffle(self.stl_files)
|
||||
split_idx = int(len(self.stl_files) * 0.8)
|
||||
self.train_files = self.stl_files[:split_idx]
|
||||
self.test_files = self.stl_files[split_idx:]
|
||||
|
||||
self.iter = 0
|
||||
|
||||
def save_checkpoint(self, step, data=None):
|
||||
if data is None:
|
||||
data = {}
|
||||
data["iter"] = self.iter
|
||||
super().save_checkpoint(step, data)
|
||||
|
||||
async def rollout_and_score_eval(self, stl_path: str) -> number:
|
||||
# Load the STL file
|
||||
mesh = self.load_stl_file(stl_path)
|
||||
if mesh is None:
|
||||
return 0
|
||||
|
||||
# Render the images
|
||||
images = self.renderer.render_mesh_to_images(mesh)
|
||||
|
||||
# Generate a query from the images
|
||||
query = self.generate_query_from_images(images)
|
||||
|
||||
# Get a completion from the model
|
||||
completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": query},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
|
||||
# Extract the STL content from the completion
|
||||
response_content = completion.choices[0].message.content
|
||||
stl_content = self.extract_stl_content(response_content)
|
||||
|
||||
# Load the original mesh directly
|
||||
original_mesh = mesh
|
||||
|
||||
# Save the generated STL content to a temporary file
|
||||
temp_file = f"temp_generated_{random.randint(1000, 9999)}.stl"
|
||||
try:
|
||||
with open(temp_file, "w") as f:
|
||||
f.write(stl_content)
|
||||
|
||||
# Load the generated mesh
|
||||
generated_mesh = trimesh.load(temp_file)
|
||||
|
||||
# Score the generated mesh against the original
|
||||
score = self.score_meshes_similarity(original_mesh, generated_mesh)
|
||||
|
||||
# Cleanup
|
||||
os.remove(temp_file)
|
||||
|
||||
return score
|
||||
except Exception as e:
|
||||
print(f"Error processing generated STL: {e}")
|
||||
if os.path.exists(temp_file):
|
||||
os.remove(temp_file)
|
||||
return 0
|
||||
|
||||
def extract_stl_content(self, response_content):
|
||||
"""Extract STL content from the model's response"""
|
||||
# Find content between <stl> and </stl> tags
|
||||
match = re.search(r'<stl>(.*?)</stl>', response_content, re.DOTALL)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
return ""
|
||||
|
||||
def score_meshes_similarity(self, original_mesh, generated_mesh):
|
||||
"""Score the similarity between two meshes"""
|
||||
# This is a simple implementation - in practice you'd want more sophisticated metrics
|
||||
# Compare basic properties like number of vertices, faces, and volume
|
||||
orig_stats = {
|
||||
'vertices': len(original_mesh.vertices),
|
||||
'faces': len(original_mesh.faces),
|
||||
'volume': original_mesh.volume or 1.0,
|
||||
'surface_area': original_mesh.area or 1.0
|
||||
}
|
||||
|
||||
gen_stats = {
|
||||
'vertices': len(generated_mesh.vertices),
|
||||
'faces': len(generated_mesh.faces),
|
||||
'volume': generated_mesh.volume or 1.0,
|
||||
'surface_area': generated_mesh.area or 1.0
|
||||
}
|
||||
|
||||
# Calculate ratios (capped at 1.0 for when generated > original)
|
||||
vertex_ratio = min(gen_stats['vertices'] / max(orig_stats['vertices'], 1), 1.0)
|
||||
face_ratio = min(gen_stats['faces'] / max(orig_stats['faces'], 1), 1.0)
|
||||
volume_ratio = min(gen_stats['volume'] / max(orig_stats['volume'], 1), 1.0)
|
||||
area_ratio = min(gen_stats['surface_area'] / max(orig_stats['surface_area'], 1), 1.0)
|
||||
|
||||
# Average the ratios for a final score
|
||||
score = (vertex_ratio + face_ratio + volume_ratio + area_ratio) / 4.0
|
||||
return score
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
eval_tasks = []
|
||||
for stl_file in self.test_files[:10]: # Limit to 10 files for evaluation to keep it manageable
|
||||
eval_tasks.append(self.rollout_and_score_eval(stl_file))
|
||||
scores = await tqdm_asyncio.gather(*eval_tasks)
|
||||
self.eval_metrics.append(("eval/similarity_score", sum(scores) / len(scores)))
|
||||
|
||||
async def collect_trajectories(self, item: PhysicalRow) -> Tuple[ScoredDataGroup, list[Item]]:
|
||||
stl_path = item["stl_path"]
|
||||
mesh = self.load_stl_file(stl_path)
|
||||
images = self.renderer.render_mesh_to_images(mesh)
|
||||
query = self.generate_query_from_images(images)
|
||||
|
||||
# For original STL content, we'll just store the file path instead of the content
|
||||
# as the files may be binary and can't be simply read as text
|
||||
original_stl_path = stl_path
|
||||
|
||||
user_message = {"role": "user", "content": query}
|
||||
|
||||
chat_completions = await self.server.chat_completion(
|
||||
messages=[{"role": "system", "content": system_prompt}, user_message],
|
||||
n=self.config.group_size,
|
||||
max_tokens=self.config.max_token_length,
|
||||
)
|
||||
|
||||
to_score = list()
|
||||
to_backlog = list()
|
||||
|
||||
for i, chat_completion in enumerate(chat_completions.choices):
|
||||
messages = (
|
||||
{"role": "system", "content": system_prompt},
|
||||
user_message,
|
||||
{"role": "assistant", "content": chat_completion.message.content},
|
||||
)
|
||||
to_score.append(
|
||||
{
|
||||
"messages": messages,
|
||||
"original_stl_path": original_stl_path,
|
||||
"images": images,
|
||||
"finish_reason": chat_completion.finish_reason,
|
||||
}
|
||||
)
|
||||
|
||||
to_postprocess = await self.score(to_score)
|
||||
return to_postprocess, to_backlog
|
||||
|
||||
async def score(self, rollout_group_data) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
|
||||
scores = ScoredDataGroup()
|
||||
scores["tokens"] = list()
|
||||
scores["masks"] = list()
|
||||
scores["scores"] = list()
|
||||
|
||||
random.shuffle(rollout_group_data)
|
||||
|
||||
for item in rollout_group_data:
|
||||
response_content = item["messages"][-1]["content"]
|
||||
stl_content = self.extract_stl_content(response_content)
|
||||
|
||||
# Save the generated STL content to a temporary file
|
||||
temp_file = f"temp_generated_{random.randint(1000, 9999)}.stl"
|
||||
try:
|
||||
with open(temp_file, "w") as f:
|
||||
f.write(stl_content)
|
||||
|
||||
# Load the original STL directly from its path
|
||||
original_stl_path = item["original_stl_path"]
|
||||
original_mesh = trimesh.load(original_stl_path)
|
||||
|
||||
# Load the generated mesh
|
||||
generated_mesh = trimesh.load(temp_file)
|
||||
|
||||
# Score the generated mesh against the original
|
||||
mesh_similarity = self.score_meshes_similarity(original_mesh, generated_mesh)
|
||||
|
||||
# Generate rendered images of the produced STL
|
||||
generated_images = self.renderer.render_mesh_to_images(generated_mesh)
|
||||
|
||||
# Use CLIP to score the visual similarity
|
||||
images_reward = 0.0
|
||||
if len(generated_images) > 0 and len(item["images"]) > 0:
|
||||
# Extract query from the user message
|
||||
query = item["messages"][1]["content"]
|
||||
|
||||
# Score the visual similarity using CLIP
|
||||
clip_scores = self.clip_scorer.score_images(generated_images, query)
|
||||
images_reward = sum(clip_scores) / len(clip_scores) / 100.0 # Normalize to roughly 0-1
|
||||
|
||||
# Combine mesh similarity and image similarity for final reward
|
||||
reward = 0.5 * mesh_similarity + 0.5 * images_reward
|
||||
|
||||
out_dict = tokenize_for_trainer(
|
||||
self.tokenizer, item["messages"], item["finish_reason"]
|
||||
)
|
||||
tokens = out_dict["tokens"]
|
||||
masks = out_dict["masks"]
|
||||
|
||||
# Remove obviously bad examples
|
||||
if len([1 for i in masks if i != -100]) < 10:
|
||||
continue
|
||||
|
||||
scores["tokens"].append(tokens)
|
||||
scores["masks"].append(masks)
|
||||
scores["scores"].append(reward)
|
||||
|
||||
self.percent_correct_buffer.append(reward)
|
||||
|
||||
# Clean up temporary files
|
||||
os.remove(temp_file)
|
||||
|
||||
if len(scores["tokens"]) >= self.config.group_size:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in scoring: {e}")
|
||||
if os.path.exists(temp_file):
|
||||
os.remove(temp_file)
|
||||
|
||||
# Apply length penalty if all scores are similar
|
||||
if all(abs(score - scores["scores"][0]) < 0.1 for score in scores["scores"]):
|
||||
token_lengths = [len(token) for token in scores["tokens"]]
|
||||
if max(token_lengths) > 0:
|
||||
max_allowed_length = self.config.max_token_length
|
||||
length_threshold = max_allowed_length * 0.5
|
||||
|
||||
scores["scores"] = []
|
||||
for length in token_lengths:
|
||||
if length <= length_threshold:
|
||||
scores["scores"].append(1.0)
|
||||
else:
|
||||
percentage_of_range = (length - length_threshold) / (
|
||||
max_allowed_length - length_threshold
|
||||
)
|
||||
percentage_of_range = min(percentage_of_range, 1.0)
|
||||
scores["scores"].append(1.0 - percentage_of_range)
|
||||
|
||||
if all([scores["scores"][0] == score for score in scores["scores"]]):
|
||||
return None # If all the same, we return None
|
||||
|
||||
return scores
|
||||
|
||||
async def get_next_item(self) -> PhysicalRow:
|
||||
stl_path = self.train_files[self.iter % len(self.train_files)]
|
||||
self.iter += 1
|
||||
|
||||
# Load the STL file and render it
|
||||
mesh = self.load_stl_file(stl_path)
|
||||
if mesh is None:
|
||||
# Skip this file and try the next one if there's an issue
|
||||
return await self.get_next_item()
|
||||
|
||||
# Render the mesh to get images
|
||||
images = self.renderer.render_mesh_to_images(mesh)
|
||||
|
||||
# Generate a query from the images
|
||||
query = self.generate_query_from_images(images)
|
||||
|
||||
# Return a row with the prompt, image, and path to the STL file
|
||||
return {
|
||||
"prompt": query,
|
||||
"image": images[0] if images else np.zeros((224, 224, 3), dtype=np.uint8),
|
||||
"stl_path": stl_path
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def test_sample_stl(cls):
|
||||
"""Test loading and rendering a sample STL file"""
|
||||
# Create temporary environment instance
|
||||
env_config = BaseEnvConfig(
|
||||
tokenizer_name="google/gemma-3-27b-it",
|
||||
group_size=8,
|
||||
use_wandb=False,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
batch_size=12,
|
||||
steps_per_eval=100,
|
||||
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
|
||||
max_token_length=2048,
|
||||
wandb_name="physical_test",
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="google/gemma-3-27b-it",
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
),
|
||||
]
|
||||
|
||||
env = cls(env_config, server_configs, slurm=False, testing=True)
|
||||
|
||||
# Find sample STL files
|
||||
stl_files = glob.glob(os.path.join('sample_data', '*.stl'))
|
||||
if not stl_files:
|
||||
print("No STL files found in sample_data/")
|
||||
return
|
||||
|
||||
# Test loading and rendering the first file
|
||||
print(f"Testing with STL file: {stl_files[0]}")
|
||||
mesh = env.load_stl_file(stl_files[0])
|
||||
if mesh is None:
|
||||
print("Failed to load STL file")
|
||||
return
|
||||
|
||||
print(f"Loaded mesh with {len(mesh.vertices)} vertices and {len(mesh.faces)} faces")
|
||||
|
||||
# Render the mesh
|
||||
try:
|
||||
images = env.renderer.render_mesh_to_images(mesh)
|
||||
print(f"Successfully rendered {len(images)} images")
|
||||
|
||||
# Save the first image for inspection
|
||||
from PIL import Image
|
||||
img = Image.fromarray(images[0])
|
||||
img.save("test_render.png")
|
||||
print("Saved test render to test_render.png")
|
||||
except Exception as e:
|
||||
print(f"Error rendering: {e}")
|
||||
|
||||
print("Test completed")
|
||||
|
||||
if __name__ == "__main__":
|
||||
PhysicalEnv.cli()
|
||||
Loading…
Add table
Add a link
Reference in a new issue