"""Atropos environment that rewards a model for generating ASCII STL meshes.

Each training item is an STL file from ``sample_data/``.  The mesh is rendered
to images, the model is prompted to emit STL text inside ``<stl>`` tags, and
the answer is scored by combining simple geometric statistics with a CLIP
image/text score of the rendered result.
"""

import glob
import io  # noqa: F401  (kept: may be relied on by importers of this module)
import os
import random
import re
import tempfile
from typing import Dict, List, Optional, Tuple, TypedDict, Union

import numpy as np
import trimesh
import wandb  # noqa: F401  (kept: may be relied on by importers of this module)
from tqdm.asyncio import tqdm_asyncio

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    EvalHandlingEnum,
    ScoredDataGroup,
)
from atroposlib.type_definitions import Item, number
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

# Fix the relative imports for running directly
try:
    from .judgement_model import CLIPScorer
    from .pyrender_utils import PyRenderOffline
except ImportError:
    from judgement_model import CLIPScorer
    from pyrender_utils import PyRenderOffline

# NOTE(review): the tag names below (<think>, <stl>) were restored; the prompt
# and the extraction regex must always agree on the answer tag.
system_prompt = (
    "You are an expert in 3D modeling and computer-aided design. Your task is to analyze the "
    "blueprints or wireframe views of objects and generate the corresponding STL file content. "
    "STL (stereolithography) files represent 3D models as a collection of triangular facets.\n\n"
    "You may use <think> </think> tags to work through your reasoning about the shape, "
    "dimensions, and geometric features of the model. Be methodical in your approach.\n\n"
    "STL files can be in ASCII or binary format. For this task, generate ASCII STL content that "
    "accurately represents the 3D model shown in the provided views.\n\n"
    "Your final output must be enclosed in <stl> </stl> tags, containing only the valid STL content "
    "and nothing else. The STL content should begin with 'solid' and end with 'endsolid'.\n\n"
    "Example of STL format:\n"
    "<stl>\n"
    "solid model\n"
    " facet normal nx ny nz\n"
    " outer loop\n"
    " vertex x1 y1 z1\n"
    " vertex x2 y2 z2\n"
    " vertex x3 y3 z3\n"
    " endloop\n"
    " endfacet\n"
    " ... more facets ...\n"
    "endsolid model\n"
    "</stl>"
)

# Matches the answer block requested by ``system_prompt``.
_STL_BLOCK_RE = re.compile(r"<stl>(.*?)</stl>", re.DOTALL)


class PhysicalRow(TypedDict):
    """One training item as produced by ``PhysicalEnv.get_next_item``."""

    prompt: str
    image: np.ndarray
    stl_path: str


class PhysicalEnv(BaseEnv):
    """Environment that scores model-generated STL text against reference meshes."""

    name = "physical"

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.percent_correct_buffer = list()
        self.eval_metrics = list()
        # Add tracking for wandb visualizations
        self.rollouts_for_wandb = []
        self.completion_lengths = []
        # Initialize renderer and CLIP scorer
        self.renderer = PyRenderOffline(width=224, height=224)
        self.clip_scorer = CLIPScorer()

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Return the default environment config and inference-server configs."""
        env_config = BaseEnvConfig(
            tokenizer_name="google/gemma-3-27b-it",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
            max_token_length=2048,
            wandb_name="physical",
        )
        server_configs = [
            APIServerConfig(
                model_name="google/gemma-3-27b-it",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            ),
        ]
        return env_config, server_configs

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Add buffered train/eval metrics to ``wandb_metrics`` and forward them.

        Both buffers are cleared after being read.
        """
        if wandb_metrics is None:
            wandb_metrics = {}

        # Skip the train metric entirely when nothing was scored since last log.
        if self.percent_correct_buffer:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
        self.percent_correct_buffer = list()

        for metric_name, metric_value in self.eval_metrics:
            wandb_metrics[metric_name] = metric_value
        self.eval_metrics = list()

        # Call the parent method to handle the server metrics
        await super().wandb_log(wandb_metrics)

    def load_stl_file(self, stl_path):
        """Load an STL file into a trimesh object; return None on any failure."""
        try:
            return trimesh.load(stl_path)
        except Exception as e:  # best-effort: callers treat None as "skip"
            print(f"Error loading STL file {stl_path}: {e}")
            return None

    def generate_query_from_images(self, images):
        """Return a randomly chosen user prompt.

        ``images`` is currently unused; the prompt comes from a fixed list of
        templates rather than from a description of the rendered views.
        """
        templates = [
            "Create a 3D model (STL file) for the object shown in these technical drawings. Be precise with the geometry.",
            "Based on these wireframe views, generate the STL code for this 3D object. Pay attention to all visible features.",
            "Using these blueprint images as reference, provide the STL file format data to recreate this 3D model.",
            "These are technical views of a 3D object. Generate the STL representation that would produce this exact shape.",
            "Reconstruct this 3D model from the provided wireframe views and output the STL file content.",
        ]
        return random.choice(templates)

    async def setup(self):
        """Discover STL files under ``sample_data`` and split them 80/20.

        Raises:
            ValueError: if no ``*.stl`` file is found.
        """
        # Sort first: glob order is filesystem dependent, and the seeded
        # shuffle below is only reproducible on a deterministic input order.
        self.stl_files = sorted(glob.glob(os.path.join("sample_data", "*.stl")))
        if not self.stl_files:
            raise ValueError("No STL files found in the sample_data directory")

        print(f"Found {len(self.stl_files)} STL files")

        # Split files into train and test sets (80/20 split).
        # Note: this reseeds the module-level RNG, as the original code did.
        random.seed(42)
        random.shuffle(self.stl_files)
        split_idx = int(len(self.stl_files) * 0.8)
        self.train_files = self.stl_files[:split_idx]
        self.test_files = self.stl_files[split_idx:]

        self.iter = 0

    def save_checkpoint(self, step, data=None):
        """Persist the training iterator position alongside the parent's data."""
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    @staticmethod
    def _load_generated_mesh(stl_content):
        """Parse model-produced STL text into a mesh via a private temp file.

        The temp file is always removed, whether or not parsing succeeds.

        Raises:
            ValueError: if ``stl_content`` is empty.
            Exception: whatever ``trimesh.load`` raises for malformed content.
        """
        if not stl_content:
            raise ValueError("No STL content found in the model response")

        # delete=False so the file can be reopened by path on every platform;
        # a unique name avoids collisions between concurrent workers.
        tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".stl", delete=False)
        try:
            with tmp:
                tmp.write(stl_content)
            return trimesh.load(tmp.name)
        finally:
            if os.path.exists(tmp.name):
                os.remove(tmp.name)

    async def rollout_and_score_eval(self, stl_path: str) -> number:
        """Run one greedy completion for ``stl_path`` and return its mesh score.

        Returns 0 when the reference file cannot be loaded or the generated
        STL cannot be parsed or scored.
        """
        original_mesh = self.load_stl_file(stl_path)
        if original_mesh is None:
            return 0

        images = self.renderer.render_mesh_to_images(original_mesh)
        query = self.generate_query_from_images(images)

        completion = await self.server.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": query},
            ],
            n=1,
            max_tokens=self.config.max_token_length,
            temperature=0.0,
            split="eval",
        )

        response_content = completion.choices[0].message.content
        stl_content = self.extract_stl_content(response_content)

        try:
            generated_mesh = self._load_generated_mesh(stl_content)
            return self.score_meshes_similarity(original_mesh, generated_mesh)
        except Exception as e:  # malformed model output must not abort eval
            print(f"Error processing generated STL: {e}")
            return 0

    def extract_stl_content(self, response_content):
        """Return the text between ``<stl>`` and ``</stl>``, stripped.

        Returns an empty string when the response is empty/None or contains
        no complete ``<stl>`` block.
        """
        if not response_content:
            return ""
        match = _STL_BLOCK_RE.search(response_content)
        if match:
            return match.group(1).strip()
        return ""

    def score_meshes_similarity(self, original_mesh, generated_mesh):
        """Score two meshes by comparing coarse statistics.

        The score is the mean of four generated/original ratios (vertex
        count, face count, volume, surface area), each capped at 1.0.  This
        is a simple heuristic: a generated mesh that is *larger* than the
        original on some statistic saturates that ratio at 1.0.
        """

        def stats(mesh):
            return {
                "vertices": len(mesh.vertices),
                "faces": len(mesh.faces),
                # abs(): the reported volume is signed and may be negative
                # (e.g. inverted winding), which would yield a negative ratio.
                "volume": abs(mesh.volume) or 1.0,
                "surface_area": mesh.area or 1.0,
            }

        orig_stats = stats(original_mesh)
        gen_stats = stats(generated_mesh)

        # Calculate ratios (capped at 1.0 for when generated > original)
        ratios = [
            min(gen_stats[key] / max(orig_stats[key], 1), 1.0)
            for key in ("vertices", "faces", "volume", "surface_area")
        ]

        # Average the ratios for a final score
        return sum(ratios) / 4.0

    async def evaluate(self, *args, **kwargs):
        """Score up to 10 held-out files and queue the mean for logging."""
        # Limit to 10 files for evaluation to keep it manageable
        eval_tasks = [
            self.rollout_and_score_eval(stl_file) for stl_file in self.test_files[:10]
        ]
        if not eval_tasks:
            return
        scores = await tqdm_asyncio.gather(*eval_tasks)
        self.eval_metrics.append(("eval/similarity_score", sum(scores) / len(scores)))

    async def collect_trajectories(
        self, item: PhysicalRow
    ) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
        """Sample ``group_size`` completions for ``item`` and score them.

        Returns ``(None, [])`` when the reference mesh cannot be loaded;
        otherwise ``(result of self.score(...), [])``.
        """
        stl_path = item["stl_path"]
        mesh = self.load_stl_file(stl_path)
        if mesh is None:
            # Nothing to render or compare against; yield no training data.
            return None, []

        images = self.renderer.render_mesh_to_images(mesh)
        query = self.generate_query_from_images(images)

        user_message = {"role": "user", "content": query}

        chat_completions = await self.server.chat_completion(
            messages=[{"role": "system", "content": system_prompt}, user_message],
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
        )

        # The reference is carried by path rather than content because the
        # file may be binary STL and cannot simply be read as text.
        to_score = [
            {
                "messages": (
                    {"role": "system", "content": system_prompt},
                    user_message,
                    {"role": "assistant", "content": choice.message.content},
                ),
                "original_stl_path": stl_path,
                "images": images,
                "finish_reason": choice.finish_reason,
            }
            for choice in chat_completions.choices
        ]
        to_backlog = list()

        to_postprocess = await self.score(to_score)
        return to_postprocess, to_backlog

    async def score(
        self, rollout_group_data
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        """Turn a group of rollouts into tokens, masks and rewards.

        Reward per rollout is ``0.5 * mesh_similarity + 0.5 * images_reward``.
        Rollouts whose STL cannot be parsed/rendered/tokenized, or that have
        fewer than 10 unmasked tokens, are dropped.  If all remaining rewards
        are within 0.1 of the first one, rewards are replaced by a length
        penalty.  Returns None when nothing could be scored or when all final
        rewards are identical.

        Note: ``rollout_group_data`` is shuffled in place.
        """
        scores = ScoredDataGroup()
        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()

        random.shuffle(rollout_group_data)
        # Every rollout in a group normally shares one reference mesh; cache
        # by path so it is parsed once instead of once per rollout.
        original_meshes = {}

        for item in rollout_group_data:
            response_content = item["messages"][-1]["content"]
            stl_content = self.extract_stl_content(response_content)

            try:
                original_stl_path = item["original_stl_path"]
                if original_stl_path not in original_meshes:
                    original_meshes[original_stl_path] = trimesh.load(original_stl_path)
                original_mesh = original_meshes[original_stl_path]

                generated_mesh = self._load_generated_mesh(stl_content)
                mesh_similarity = self.score_meshes_similarity(
                    original_mesh, generated_mesh
                )

                # Generate rendered images of the produced STL
                generated_images = self.renderer.render_mesh_to_images(generated_mesh)

                # Use CLIP to score the rendered result against the user query
                images_reward = 0.0
                if len(generated_images) > 0 and len(item["images"]) > 0:
                    query = item["messages"][1]["content"]
                    clip_scores = self.clip_scorer.score_images(generated_images, query)
                    # Normalize to roughly 0-1
                    # NOTE(review): assumes CLIP scores are on a ~0-100 scale — confirm.
                    images_reward = sum(clip_scores) / len(clip_scores) / 100.0

                # Combine mesh similarity and image similarity for final reward
                reward = 0.5 * mesh_similarity + 0.5 * images_reward

                out_dict = tokenize_for_trainer(
                    self.tokenizer, item["messages"], item["finish_reason"]
                )
            except Exception as e:  # a bad rollout is dropped, not fatal
                print(f"Error in scoring: {e}")
                continue

            tokens = out_dict["tokens"]
            masks = out_dict["masks"]

            # Remove obviously bad examples
            if sum(1 for mask in masks if mask != -100) < 10:
                continue

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(reward)
            self.percent_correct_buffer.append(reward)

            if len(scores["tokens"]) >= self.config.group_size:
                break

        # Nothing survived scoring: there is no group to train on.
        if not scores["scores"]:
            return None

        # Apply length penalty if all scores are similar
        first_score = scores["scores"][0]
        if all(abs(score - first_score) < 0.1 for score in scores["scores"]):
            token_lengths = [len(token) for token in scores["tokens"]]
            if max(token_lengths) > 0:
                max_allowed_length = self.config.max_token_length
                length_threshold = max_allowed_length * 0.5

                scores["scores"] = []
                for length in token_lengths:
                    if length <= length_threshold:
                        scores["scores"].append(1.0)
                    else:
                        percentage_of_range = (length - length_threshold) / (
                            max_allowed_length - length_threshold
                        )
                        percentage_of_range = min(percentage_of_range, 1.0)
                        scores["scores"].append(1.0 - percentage_of_range)

        if all(scores["scores"][0] == score for score in scores["scores"]):
            return None  # If all the same, we return None
        return scores

    async def get_next_item(self) -> PhysicalRow:
        """Return the next loadable training item, cycling through the files.

        Files that fail to load are skipped.  At most one full pass over the
        training set is attempted per call.

        Raises:
            RuntimeError: if the training split is empty or no file in it
                can be loaded.
        """
        if not self.train_files:
            raise RuntimeError(
                "No training STL files available (the 80/20 split needs at "
                "least two files in sample_data)"
            )

        for _ in range(len(self.train_files)):
            stl_path = self.train_files[self.iter % len(self.train_files)]
            self.iter += 1

            mesh = self.load_stl_file(stl_path)
            if mesh is None:
                # Skip this file and try the next one if there's an issue
                continue

            images = self.renderer.render_mesh_to_images(mesh)
            query = self.generate_query_from_images(images)

            # len() rather than truthiness: safe for lists and arrays alike.
            first_image = (
                images[0]
                if len(images) > 0
                else np.zeros((224, 224, 3), dtype=np.uint8)
            )
            return {
                "prompt": query,
                "image": first_image,
                "stl_path": stl_path,
            }

        raise RuntimeError("None of the training STL files could be loaded")

    @classmethod
    async def test_sample_stl(cls):
        """Smoke test: load and render the first STL file in ``sample_data``.

        Writes the first rendered view to ``test_render.png`` on success.
        """
        # Create temporary environment instance
        env_config = BaseEnvConfig(
            tokenizer_name="google/gemma-3-27b-it",
            group_size=8,
            use_wandb=False,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
            max_token_length=2048,
            wandb_name="physical_test",
        )
        server_configs = [
            APIServerConfig(
                model_name="google/gemma-3-27b-it",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            ),
        ]

        env = cls(env_config, server_configs, slurm=False, testing=True)

        # Find sample STL files
        stl_files = glob.glob(os.path.join("sample_data", "*.stl"))
        if not stl_files:
            print("No STL files found in sample_data/")
            return

        # Test loading and rendering the first file
        print(f"Testing with STL file: {stl_files[0]}")
        mesh = env.load_stl_file(stl_files[0])
        if mesh is None:
            print("Failed to load STL file")
            return

        print(f"Loaded mesh with {len(mesh.vertices)} vertices and {len(mesh.faces)} faces")

        # Render the mesh
        try:
            images = env.renderer.render_mesh_to_images(mesh)
            print(f"Successfully rendered {len(images)} images")

            # Save the first image for inspection
            from PIL import Image

            img = Image.fromarray(images[0])
            img.save("test_render.png")
            print("Saved test render to test_render.png")
        except Exception as e:
            print(f"Error rendering: {e}")

        print("Test completed")


if __name__ == "__main__":
    PhysicalEnv.cli()