import glob
import os
import random
import re
import tempfile
from typing import Dict, List, Optional, Tuple, TypedDict, Union

import numpy as np
import trimesh
from tqdm.asyncio import tqdm_asyncio

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    EvalHandlingEnum,
    ScoredDataGroup,
)
from atroposlib.type_definitions import Item, number
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

# Fix the relative imports for running directly
try:
    from .judgement_model import CLIPScorer
    from .pyrender_utils import PyRenderOffline
except ImportError:
    from judgement_model import CLIPScorer
    from pyrender_utils import PyRenderOffline
# System prompt sent with every training and evaluation request.
# It names the tags explicitly: optional reasoning goes in <think> </think>,
# and the final ASCII STL goes in <stl> </stl> so the answer can be separated
# from any surrounding text.
system_prompt = (
    "You are an expert in 3D modeling and computer-aided design. Your task is to analyze the "
    "blueprints or wireframe views of objects and generate the corresponding STL file content. "
    "STL (stereolithography) files represent 3D models as a collection of triangular facets.\n\n"
    "You may use <think> </think> tags to work through your reasoning about the shape, "
    "dimensions, and geometric features of the model. Be methodical in your approach.\n\n"
    "STL files can be in ASCII or binary format. For this task, generate ASCII STL content that "
    "accurately represents the 3D model shown in the provided views.\n\n"
    "Your final output must be enclosed in <stl> </stl> tags, containing only the valid STL content "
    "and nothing else. The STL content should begin with 'solid' and end with 'endsolid'.\n\n"
    "Example of STL format:\n"
    "<stl>\n"
    "solid model\n"
    " facet normal nx ny nz\n"
    " outer loop\n"
    " vertex x1 y1 z1\n"
    " vertex x2 y2 z2\n"
    " vertex x3 y3 z3\n"
    " endloop\n"
    " endfacet\n"
    " ... more facets ...\n"
    "endsolid model\n"
    "</stl>"
)
class PhysicalRow(TypedDict):
    """One training item: a text prompt, a rendered view, and the STL path."""

    # Text request shown to the model as the user message.
    prompt: str
    # One rendered view of the reference mesh.
    image: np.ndarray
    # Path of the reference STL file; the path is stored rather than the file
    # contents. The key is ``stl_path`` because that is the key this file's
    # producers and consumers of the row actually use.
    stl_path: str
class PhysicalEnv(BaseEnv):
    """Environment that rewards a model for reconstructing STL meshes.

    Reference meshes are read from ``sample_data/*.stl``.  For each rollout
    the model's ``<stl>...</stl>`` answer is parsed with trimesh and rewarded
    with an equal blend of a mesh-statistics similarity against the reference
    mesh and a CLIP score of renders of the generated mesh against the prompt.
    """

    name = "physical"

    # Upper bound on held-out files scored per ``evaluate`` call, to keep
    # evaluation manageable.
    _MAX_EVAL_FILES = 10

    # The model's final answer: everything between <stl> and </stl>.
    # DOTALL because STL content spans many lines.
    _STL_BLOCK = re.compile(r"<stl>(.*?)</stl>", re.DOTALL)

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        """Create metric buffers, the offline renderer and the CLIP scorer.

        Args:
            config: Environment configuration, forwarded to ``BaseEnv``.
            server_configs: Inference server configs, forwarded to ``BaseEnv``.
            slurm: Forwarded to ``BaseEnv``.
            testing: Forwarded to ``BaseEnv``.
        """
        super().__init__(config, server_configs, slurm, testing)
        self.percent_correct_buffer = []
        self.eval_metrics = []
        # Add tracking for wandb visualizations
        self.rollouts_for_wandb = []
        self.completion_lengths = []
        # Initialize renderer and CLIP scorer
        self.renderer = PyRenderOffline(width=224, height=224)
        self.clip_scorer = CLIPScorer()

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Return the default environment config and inference server list."""
        env_config = BaseEnvConfig(
            tokenizer_name="google/gemma-3-27b-it",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
            max_token_length=2048,
            wandb_name="physical",
        )
        server_configs = [
            APIServerConfig(
                model_name="google/gemma-3-27b-it",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            ),
        ]
        return env_config, server_configs

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Add this environment's metrics and delegate to the parent logger.

        Reports the mean of the buffered training rewards as
        ``train/percent_correct`` (skipped when the buffer is empty) together
        with every pending evaluation metric, then clears both buffers.

        Args:
            wandb_metrics: Metrics collected so far; a new dict is used when
                omitted.  The dict is updated in place.
        """
        if wandb_metrics is None:
            wandb_metrics = {}
        # An empty buffer simply contributes no training metric.
        if self.percent_correct_buffer:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
        self.percent_correct_buffer = []
        for item in self.eval_metrics:
            wandb_metrics[item[0]] = item[1]
        self.eval_metrics = []
        # Call the parent method to handle the server metrics
        await super().wandb_log(wandb_metrics)

    def load_stl_file(self, stl_path):
        """Load an STL file with trimesh.

        Args:
            stl_path: Path of the file to load.

        Returns:
            Whatever ``trimesh.load`` returns, or ``None`` if loading raised;
            the error is printed rather than propagated.
        """
        try:
            return trimesh.load(stl_path)
        except Exception as e:
            print(f"Error loading STL file {stl_path}: {e}")
            return None

    def generate_query_from_images(self, images):
        """Pick one of several fixed prompt templates at random.

        Args:
            images: Rendered views of the reference mesh.  Currently unused:
                a real implementation would describe them with a vision
                model, this simplified version only varies the wording.

        Returns:
            The chosen prompt string.
        """
        templates = [
            "Create a 3D model (STL file) for the object shown in these technical drawings. "
            "Be precise with the geometry.",
            "Based on these wireframe views, generate the STL code for this 3D object. "
            "Pay attention to all visible features.",
            "Using these blueprint images as reference, provide the STL file format data "
            "to recreate this 3D model.",
            "These are technical views of a 3D object. Generate the STL representation "
            "that would produce this exact shape.",
            "Reconstruct this 3D model from the provided wireframe views and output "
            "the STL file content.",
        ]
        return random.choice(templates)

    async def setup(self):
        """Discover the sample STL files and split them into train/test sets.

        Sets ``stl_files``, ``train_files``, ``test_files`` and resets
        ``iter`` to 0.

        Raises:
            ValueError: If ``sample_data`` contains no ``*.stl`` files.
        """
        # glob returns names in arbitrary order; sort first so the seeded
        # shuffle below produces the same split on every machine.
        self.stl_files = sorted(glob.glob(os.path.join("sample_data", "*.stl")))
        if not self.stl_files:
            raise ValueError("No STL files found in the sample_data directory")
        print(f"Found {len(self.stl_files)} STL files")
        # A private RNG keeps the split reproducible without reseeding the
        # module-level generator that the rest of the process draws from.
        random.Random(42).shuffle(self.stl_files)
        # 80/20 split, but always keep at least one file for training so
        # get_next_item has something to serve when the data set is tiny.
        split_idx = max(1, int(len(self.stl_files) * 0.8))
        self.train_files = self.stl_files[:split_idx]
        self.test_files = self.stl_files[split_idx:]
        self.iter = 0

    def save_checkpoint(self, step, data=None):
        """Store the dataset cursor in the checkpoint data and delegate.

        Args:
            step: Training step, forwarded to the parent implementation.
            data: Checkpoint dict to extend; a new dict is used when omitted.
                ``data["iter"]`` is set to the current cursor.
        """
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    def _load_generated_mesh(self, stl_content):
        """Parse model-produced STL text with trimesh.

        The text is written to a uniquely named temporary file, which is
        removed again whether or not loading succeeds.

        Args:
            stl_content: STL text extracted from a model response.

        Returns:
            Whatever ``trimesh.load`` returns for the temporary file.

        Raises:
            Exception: Anything raised while writing or loading the file.
        """
        # mkstemp guarantees a unique name, so concurrent workers cannot
        # clobber each other's files.  The ".stl" suffix is kept so the file
        # extension identifies the format.
        fd, temp_path = tempfile.mkstemp(prefix="temp_generated_", suffix=".stl")
        try:
            with os.fdopen(fd, "w") as f:
                f.write(stl_content)
            return trimesh.load(temp_path)
        finally:
            if os.path.exists(temp_path):
                os.remove(temp_path)

    async def rollout_and_score_eval(self, stl_path: str) -> number:
        """Generate one completion for ``stl_path`` and score it.

        Args:
            stl_path: Path of the reference STL file.

        Returns:
            The mesh similarity score, or 0 when the reference cannot be
            loaded, the response contains no STL block, or the generated STL
            cannot be parsed or scored.
        """
        mesh = self.load_stl_file(stl_path)
        if mesh is None:
            return 0
        images = self.renderer.render_mesh_to_images(mesh)
        query = self.generate_query_from_images(images)
        completion = await self.server.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": query},
            ],
            n=1,
            max_tokens=self.config.max_token_length,
            temperature=0.0,
            split="eval",
        )
        response_content = completion.choices[0].message.content
        stl_content = self.extract_stl_content(response_content)
        if not stl_content:
            # Nothing to parse, so skip the file round-trip entirely.
            return 0
        try:
            generated_mesh = self._load_generated_mesh(stl_content)
            return self.score_meshes_similarity(mesh, generated_mesh)
        except Exception as e:
            print(f"Error processing generated STL: {e}")
            return 0

    def extract_stl_content(self, response_content):
        """Extract the STL text from a model response.

        Args:
            response_content: Full response text; may be ``None`` or empty.

        Returns:
            The stripped text between the first ``<stl>`` / ``</stl>`` pair,
            or ``""`` when the response is empty or has no such pair.
        """
        if not response_content:
            return ""
        match = self._STL_BLOCK.search(response_content)
        if match:
            return match.group(1).strip()
        return ""

    def score_meshes_similarity(self, original_mesh, generated_mesh):
        """Score how closely a generated mesh matches the reference.

        This is a deliberately simple metric: the mean of four
        generated/original ratios (vertex count, face count, volume, surface
        area), each capped at 1.0.

        Args:
            original_mesh: Reference mesh exposing ``vertices``, ``faces``,
                ``volume`` and ``area``.
            generated_mesh: Mesh produced from the model output, same
                attributes.

        Returns:
            A float in ``[0.0, 1.0]``.
        """

        def capped_ratio(generated_value, original_value):
            # Magnitudes only: a negative volume must not produce a negative
            # ratio.  NOTE(review): presumably negative volumes come from
            # inverted face winding; confirm against trimesh behaviour.
            generated_value = abs(float(generated_value or 0.0))
            original_value = abs(float(original_value or 0.0))
            if original_value == 0.0:
                # The reference has nothing to match, so any output meets it.
                return 1.0
            # Divide by the true reference value.  Clamping the denominator
            # to 1 would make identical meshes with volume or area below 1
            # score less than 1.0.
            return min(generated_value / original_value, 1.0)

        # NOTE(review): because each ratio is capped, a generated mesh that
        # is larger than the reference is not penalised.
        ratios = (
            capped_ratio(len(generated_mesh.vertices), len(original_mesh.vertices)),
            capped_ratio(len(generated_mesh.faces), len(original_mesh.faces)),
            capped_ratio(generated_mesh.volume, original_mesh.volume),
            capped_ratio(generated_mesh.area, original_mesh.area),
        )
        return sum(ratios) / len(ratios)

    async def evaluate(self, *args, **kwargs):
        """Score a capped number of held-out files and queue the mean score.

        Appends ``("eval/similarity_score", mean)`` to ``eval_metrics``.
        Nothing is recorded when there are no held-out files.
        """
        eval_files = self.test_files[: self._MAX_EVAL_FILES]
        if not eval_files:
            # A tiny data set may leave the test split empty; there is no
            # mean to report in that case.
            return
        eval_tasks = [self.rollout_and_score_eval(stl_file) for stl_file in eval_files]
        scores = await tqdm_asyncio.gather(*eval_tasks)
        self.eval_metrics.append(("eval/similarity_score", sum(scores) / len(scores)))

    async def collect_trajectories(
        self, item: PhysicalRow
    ) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
        """Sample a group of completions for one item and score them.

        Args:
            item: Row produced by ``get_next_item``.

        Returns:
            ``(scored_group, backlog)``.  ``scored_group`` is ``None`` when
            the reference mesh cannot be loaded or ``score`` rejects the
            group; ``backlog`` is always empty.
        """
        to_backlog = []
        stl_path = item["stl_path"]
        mesh = self.load_stl_file(stl_path)
        if mesh is None:
            return None, to_backlog
        images = self.renderer.render_mesh_to_images(mesh)
        # Reuse the prompt already chosen for this item so the row and the
        # request agree; fall back to a fresh one for rows without a prompt.
        query = item.get("prompt") or self.generate_query_from_images(images)
        user_message = {"role": "user", "content": query}
        chat_completions = await self.server.chat_completion(
            messages=[{"role": "system", "content": system_prompt}, user_message],
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
        )
        to_score = []
        for chat_completion in chat_completions.choices:
            messages = (
                {"role": "system", "content": system_prompt},
                user_message,
                {"role": "assistant", "content": chat_completion.message.content},
            )
            to_score.append(
                {
                    "messages": messages,
                    # The path is stored instead of the file contents, since
                    # the files may be binary and can't be read as text.
                    "original_stl_path": stl_path,
                    "images": images,
                    "finish_reason": chat_completion.finish_reason,
                }
            )
        to_postprocess = await self.score(to_score)
        return to_postprocess, to_backlog

    def _score_rollout(self, item, reference_meshes):
        """Compute the reward for a single rollout.

        Args:
            item: Rollout dict with ``messages``, ``original_stl_path`` and
                ``images``.
            reference_meshes: Per-call cache mapping reference path to its
                loaded mesh; filled on demand.

        Returns:
            ``0.5 * mesh_similarity + 0.5 * images_reward``.

        Raises:
            ValueError: If the response contains no STL block.
            Exception: Anything raised while loading, rendering or scoring.
        """
        response_content = item["messages"][-1]["content"]
        stl_content = self.extract_stl_content(response_content)
        if not stl_content:
            raise ValueError("response contains no <stl> block")
        original_stl_path = item["original_stl_path"]
        if original_stl_path not in reference_meshes:
            reference_meshes[original_stl_path] = trimesh.load(original_stl_path)
        original_mesh = reference_meshes[original_stl_path]
        generated_mesh = self._load_generated_mesh(stl_content)
        mesh_similarity = self.score_meshes_similarity(original_mesh, generated_mesh)
        # Generate rendered images of the produced STL
        generated_images = self.renderer.render_mesh_to_images(generated_mesh)
        images_reward = 0.0
        if len(generated_images) > 0 and len(item["images"]) > 0:
            # The user message carries the query the renders are judged against.
            query = item["messages"][1]["content"]
            clip_scores = self.clip_scorer.score_images(generated_images, query)
            # Normalize to roughly 0-1
            images_reward = sum(clip_scores) / len(clip_scores) / 100.0
        return 0.5 * mesh_similarity + 0.5 * images_reward

    async def score(
        self, rollout_group_data
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        """Turn a group of rollouts into tokens, masks and rewards.

        Args:
            rollout_group_data: List of rollout dicts built by
                ``collect_trajectories``.  It is shuffled in place.

        Returns:
            The scored group, or ``None`` when no rollout could be scored or
            every final score is identical.
        """
        scores = ScoredDataGroup()
        scores["tokens"] = []
        scores["masks"] = []
        scores["scores"] = []
        random.shuffle(rollout_group_data)
        # Rollouts of one group normally share a reference file; cache it so
        # it is parsed once per path instead of once per rollout.
        reference_meshes = {}
        for item in rollout_group_data:
            try:
                out_dict = tokenize_for_trainer(
                    self.tokenizer, item["messages"], item["finish_reason"]
                )
                tokens = out_dict["tokens"]
                masks = out_dict["masks"]
                # Remove obviously bad examples.  This runs before any mesh
                # work so a skipped rollout leaves no temporary file behind.
                if sum(1 for mask in masks if mask != -100) < 10:
                    continue
                reward = self._score_rollout(item, reference_meshes)
            except Exception as e:
                # NOTE(review): rollouts that cannot be scored are dropped
                # rather than given a zero reward, so the group can shrink.
                print(f"Error in scoring: {e}")
                continue
            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(reward)
            self.percent_correct_buffer.append(reward)
            if len(scores["tokens"]) >= self.config.group_size:
                break
        if not scores["scores"]:
            # Every rollout was dropped; there is nothing to train on.
            return None
        # Apply length penalty if all scores are similar
        if all(abs(score - scores["scores"][0]) < 0.1 for score in scores["scores"]):
            token_lengths = [len(token) for token in scores["tokens"]]
            if max(token_lengths) > 0:
                max_allowed_length = self.config.max_token_length
                length_threshold = max_allowed_length * 0.5
                scores["scores"] = []
                for length in token_lengths:
                    if length <= length_threshold:
                        scores["scores"].append(1.0)
                    else:
                        # Linear falloff from 1.0 at the threshold to 0.0 at
                        # the maximum allowed length.
                        percentage_of_range = (length - length_threshold) / (
                            max_allowed_length - length_threshold
                        )
                        percentage_of_range = min(percentage_of_range, 1.0)
                        scores["scores"].append(1.0 - percentage_of_range)
        if all(scores["scores"][0] == score for score in scores["scores"]):
            return None  # If all the same, we return None
        return scores

    async def get_next_item(self) -> PhysicalRow:
        """Return the next loadable training item, cycling over the files.

        Returns:
            A row with the prompt, one rendered view (a black 224x224 image
            when the renderer produced none) and the STL path.

        Raises:
            RuntimeError: If there are no training files, or none of them
                can be loaded.
        """
        if not self.train_files:
            raise RuntimeError("No training STL files are available")
        # Try each file at most once per call so that a data set of
        # unreadable files fails fast instead of recursing without bound.
        for _ in range(len(self.train_files)):
            stl_path = self.train_files[self.iter % len(self.train_files)]
            self.iter += 1
            mesh = self.load_stl_file(stl_path)
            if mesh is None:
                # Skip this file and try the next one if there's an issue
                continue
            images = self.renderer.render_mesh_to_images(mesh)
            query = self.generate_query_from_images(images)
            # len() instead of truthiness: the truth value of a
            # multi-element numpy array is ambiguous and raises.
            has_image = images is not None and len(images) > 0
            return {
                "prompt": query,
                "image": (
                    images[0] if has_image else np.zeros((224, 224, 3), dtype=np.uint8)
                ),
                "stl_path": stl_path,
            }
        raise RuntimeError("None of the training STL files could be loaded")

    @classmethod
    async def test_sample_stl(cls):
        """Test loading and rendering a sample STL file.

        Builds a throwaway environment, loads the first file found in
        ``sample_data``, renders it and saves the first view to
        ``test_render.png``.  Progress and failures are printed.
        """
        # Create temporary environment instance
        env_config = BaseEnvConfig(
            tokenizer_name="google/gemma-3-27b-it",
            group_size=8,
            use_wandb=False,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
            max_token_length=2048,
            wandb_name="physical_test",
        )
        server_configs = [
            APIServerConfig(
                model_name="google/gemma-3-27b-it",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            ),
        ]
        env = cls(env_config, server_configs, slurm=False, testing=True)
        # Find sample STL files
        stl_files = glob.glob(os.path.join("sample_data", "*.stl"))
        if not stl_files:
            print("No STL files found in sample_data/")
            return
        # Test loading and rendering the first file
        print(f"Testing with STL file: {stl_files[0]}")
        mesh = env.load_stl_file(stl_files[0])
        if mesh is None:
            print("Failed to load STL file")
            return
        print(
            f"Loaded mesh with {len(mesh.vertices)} vertices and {len(mesh.faces)} faces"
        )
        # Render the mesh
        try:
            images = env.renderer.render_mesh_to_images(mesh)
            print(f"Successfully rendered {len(images)} images")
            # Save the first image for inspection.  PIL is imported here so
            # that a missing Pillow is reported by the handler below.
            from PIL import Image

            img = Image.fromarray(images[0])
            img.save("test_render.png")
            print("Saved test render to test_render.png")
        except Exception as e:
            print(f"Error rendering: {e}")
        print("Test completed")
# Script entry point: run the environment's command-line interface.
# NOTE(review): cli() is not defined in this file; presumably it is inherited
# from BaseEnv.
if __name__ == "__main__":
    PhysicalEnv.cli()