def cleanup_vllm():
    """Stop the tracked vLLM server process, if there is one.

    Asks the process stored in the module-level ``vllm_process`` handle to
    terminate, waits up to five seconds for it to exit, and falls back to a
    hard kill when it does not.

    The handle is reset to ``None`` in a ``finally`` clause, so it is cleared
    even when signalling or waiting raises. That keeps a later call (this
    function is also registered as an ``atexit`` hook) from acting on a
    stale handle.
    """
    global vllm_process
    if not vllm_process:
        # Nothing is being tracked; nothing to stop.
        return
    try:
        print("\nTerminating vLLM process...")
        vllm_process.terminate()
        try:
            vllm_process.wait(timeout=5)  # Wait a bit for graceful shutdown
            print("vLLM process terminated.")
        except subprocess.TimeoutExpired:
            print("vLLM process did not terminate gracefully, killing.")
            vllm_process.kill()
            vllm_process.wait()
            print("vLLM process killed.")
    finally:
        # Always drop the handle, whatever happened above.
        vllm_process = None
def pad_data_to_good_offset(data, batch_size: int):
    """Pad the sequences of an API batch to one length and split into micro-batches.

    ``data["batch"]`` is a list of groups. Each group provides ``"tokens"``
    and ``"masks"`` (one list per sequence), ``"scores"`` (one score per
    sequence) and optionally ``"overrides"`` (one dict per sequence, or
    ``None``).

    Processing steps:

    * Scores are normalised per group (zero mean, unit standard deviation)
      when the group has more than one score. Any sequence whose override
      has ``set_advantage_to_zero`` set gets an advantage of ``0``.
    * Every sequence is padded with ``0`` (tokens) / ``-100`` (labels) to
      ``k * 64 + 1`` entries, where ``k * 64`` is the smallest multiple of
      64 that is >= the longest sequence length minus one.
    * Inputs drop the last position and labels drop the first, so the model
      input has a length that is a multiple of 64 and labels are shifted by
      one for next-token prediction.

    The input ``data`` is not modified.

    Args:
        data: Decoded batch payload as described above.
        batch_size: Number of sequences per returned micro-batch.

    Returns:
        ``(token_batches, label_batches, advantage_batches)``: three lists
        of tensors of equal length. Token and label tensors have shape
        ``(batch_size, padded_len)``; advantage tensors have shape
        ``(batch_size, 1)``. Sequences left over after the last full
        micro-batch are dropped.
    """
    # usually 64 is a good choice to ensure nonweird scaling behavior on GPUS
    # so we pad to the nearest multiple of 64
    good_multiple = 64
    longest = max(len(seq) for item in data["batch"] for seq in item["tokens"])
    # One extra position so that, after the causal shift below, the model
    # input length is exactly a multiple of `good_multiple`.
    token_setup_len = math.ceil((longest - 1) / good_multiple) * good_multiple + 1

    input_ids = []
    labels = []
    advantages = []
    for item in data["batch"]:
        # np.array copies, so the caller's score list is left untouched.
        scores = np.array(item["scores"])
        # Normalising a single score would always give 0, so skip it.
        if len(scores) > 1:
            scores = scores - scores.mean()
            scores = scores / max(scores.std(), 1e-8)

        overrides = item.get("overrides")
        if overrides is not None:
            for i, override in enumerate(overrides):
                if override.get("set_advantage_to_zero", False):
                    scores[i] = 0

        for i, seq in enumerate(item["tokens"]):
            pad = max(0, token_setup_len - len(seq))
            padded_tokens = np.concatenate(
                [np.array(seq), np.zeros(pad, dtype=np.int32)]
            )
            padded_labels = np.concatenate(
                [np.array(item["masks"][i]), np.full(pad, -100, dtype=np.int32)]
            )
            # Causal shift: predict position t+1 from positions <= t.
            input_ids.append(padded_tokens[:-1])
            labels.append(padded_labels[1:])
            advantages.append(scores[i])

    # combine all lists into tensors
    token_batches = []
    label_batches = []
    advantage_batches = []
    for b in range(len(input_ids) // batch_size):
        window = slice(b * batch_size, (b + 1) * batch_size)
        token_batches.append(torch.tensor(np.stack(input_ids[window], axis=0)))
        label_batches.append(torch.tensor(np.stack(labels[window], axis=0)))
        advantage_batches.append(
            torch.tensor(np.stack(advantages[window], axis=0)).view(-1, 1)
        )
    return token_batches, label_batches, advantage_batches
return batches + else: + time.sleep(1) + + +def train(config: TrainingConfig): + """ + Setups and runs GRPO training, restarting vLLM periodically, with wandb logging. + """ + global vllm_process # Declare intention to modify the global variable + + # --- Wandb Setup --- + if config.use_wandb: + if not config.wandb_project: + print("Warning: wandb_project not set, disabling wandb.") + config.use_wandb = False + else: + if not config.wandb_group: + # Set group to random 8 character string + config.wandb_group = "".join( + random.choices(string.ascii_letters + string.digits, k=8) + ) + try: + wandb.init( + project=config.wandb_project, + group=config.wandb_group, + config=config.dict(), # Log config parameters + ) + print( + f"Wandb logging enabled. Run: {wandb.run.name} (Project: {config.wandb_project}) " + ) + except Exception as e: + print(f"Error initializing wandb: {e}. Disabling wandb.") + config.use_wandb = False + # --- End Wandb Setup --- + + # Initialize model and tokenizer + tokenizer = AutoTokenizer.from_pretrained(config.model_name) + model = AutoModelForCausalLM.from_pretrained( + config.model_name, torch_dtype=torch.bfloat16 + ) + + model.to(config.device) + model.gradient_checkpointing_enable() + model.train() + + # Setup optimizer + optimizer = AdamW(model.parameters(), lr=config.lr) + + print( + f"Starting training for {config.training_steps} steps on device: {config.device}" + ) + print( + f"vLLM will be restarted every {config.vllm_restart_interval} steps on port {config.vllm_port}" + ) + + os.makedirs(config.save_path, exist_ok=True) # Ensure base save directory exists + register_trainer(config) + + # Init vllm + vllm_command = [ + "python", + "-m", + "vllm.entrypoints.openai.api_server", + "--model", + config.model_name, + "--port", + str(config.vllm_port), + "--dtype", + "auto", + "--gpu-memory-utilization", + "0.45", + "--disable-log-requests", + ] + print(f" Launching vLLM server: {' '.join(vllm_command)}") + try: + vllm_process = 
subprocess.Popen(vllm_command) + print(f" vLLM server launched with PID: {vllm_process.pid}") + # Check immediate errors + try: + stdout, stderr = vllm_process.communicate(timeout=2) + if vllm_process.returncode is not None and vllm_process.returncode != 0: + print(f" Error starting vLLM: {stderr.decode()}") + vllm_process = None + # Maybe raise error or just warn? + print(" WARNING: Failed to start vLLM server after checkpoint.") + except subprocess.TimeoutExpired: + print(" vLLM process started (check logs for details).") + except FileNotFoundError: + print( + "\n *** ERROR: 'python -m vllm...' command not found. Make sure vLLM is installed and accessible. ***\n" + ) + # Potentially stop training or just disable further vLLM restarts + print(" Disabling further vLLM restarts.") + config.vllm_restart_interval = ( + config.training_steps + 1 + ) # Prevent further restarts + except Exception as e: + print(f"\n *** ERROR: Failed to launch vLLM: {e} ***\n") + print(" Disabling further vLLM restarts.") + config.vllm_restart_interval = ( + config.training_steps + 1 + ) # Prevent further restarts + + batches = list() + for step in range(config.training_steps): + total_loss = 0 + print(f"Step {step+1}/{config.training_steps}") + total_pos_logp = 0 + total_neg_logp = 0 + total_logp = 0 + total_pos = 0 + total_neg = 0 + if len(batches) == 0: + batches = get_data(config.batch_size, config.seq_len) + token_batches, label_batches, advantage_batches = batches.pop(0) + # Terminate existing vLLM process if running + if ( + step + 1 + ) % config.vllm_restart_interval == 0 or step == config.training_steps - 1: # Also restart/save on last step + # Terminate existing vLLM process if running + if vllm_process: + print(" Terminating existing vLLM process...") + vllm_process.terminate() + try: + vllm_process.wait(timeout=5) + except subprocess.TimeoutExpired: + print( + " Existing vLLM process did not terminate gracefully, killing." 
+ ) + vllm_process.kill() + vllm_process.wait() + vllm_process = None + for tokens, labels, advantages in zip( + token_batches, label_batches, advantage_batches + ): + + tokens, labels, advantages = ( + tokens.to(config.device), + labels.to(config.device), + advantages.to(config.device), + ) + + # Forward pass + # User specified that tokens/labels are already prepared by get_data + outputs = model(tokens) # Assuming model just needs tokens + logits = outputs.logits # Assuming this is the structure + + # Calculate GRPO loss (reverting to user's previous logic) + # User stated ignore_index is -100 and tokens/labels are aligned by get_data + # Assuming logits correspond directly to labels indices (no shift needed here) + logp_per_token = -F.cross_entropy( + logits.view(-1, logits.size(-1)), # Flatten logits + labels.view(-1), # Flatten labels + reduction="none", + ignore_index=-100, # User specified ignore index + ).view( + labels.shape + ) # Reshape back to (batch, seq_len) + + # Masking based on labels != -100 + mask = (labels != -100).float() + with torch.no_grad(): + pos = (advantages > 0).float() + neg = (advantages <= 0).float() + avg_logp = (logp_per_token * mask).sum(-1) / mask.sum(-1) + pos_logp = (logp_per_token * pos).mean().item() + neg_logp = (logp_per_token * neg).mean().item() + total_pos_logp += pos_logp + total_neg_logp += neg_logp + total_logp += avg_logp + total_pos += pos.sum().item() + total_neg += neg.sum().item() + + grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach()) + grpo_loss = ( + ((-grpo_loss_term * mask).sum(-1) / mask.sum(-1)) + * advantages.to(logp_per_token.device) + ).mean() / config.gradient_accumulation_steps + grpo_loss.backward() + total_loss += grpo_loss.item() + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + optimizer.zero_grad() + if total_pos > 0: + total_pos_logp /= total_pos + if total_neg > 0: + total_neg_logp /= total_neg + # --- Wandb Logging --- + if 
config.use_wandb: + wandb.log( + { + "train/loss": total_loss, + "train/learning_rate": optimizer.param_groups[0]["lr"], + "train/grad_norm": grad_norm.item(), + "train/pos_logp": total_pos_logp, + "train/neg_logp": total_neg_logp, + "train/logp": total_logp, + }, + step=step + 1, + ) + # --- End Wandb Logging --- + + print(f" Step Loss: {grpo_loss.item():.4f}") + + # --- vLLM Restart Logic (Moved AFTER optimizer step) --- + # Note: There are much better ways of updating the policy, this is just a very simple example + if ( + step + 1 + ) % config.vllm_restart_interval == 0 or step == config.training_steps - 1: # Also restart/save on last step + checkpoint_path = os.path.join( + config.save_path, f"step_{step+1}" + ) # Save as step+1 since it's after step completion + print(f" Saving checkpoint to {checkpoint_path}...") + # Ensure fresh directory for saving + if os.path.exists(checkpoint_path): + shutil.rmtree(checkpoint_path) # Remove old checkpoint if it exists + os.makedirs(checkpoint_path, exist_ok=True) + model.save_pretrained(checkpoint_path) + tokenizer.save_pretrained(checkpoint_path) + print(" Checkpoint saved.") + + # Terminate existing vLLM process if running + if vllm_process: + print(" Terminating existing vLLM process...") + vllm_process.terminate() + try: + vllm_process.wait(timeout=5) + except subprocess.TimeoutExpired: + print( + " Existing vLLM process did not terminate gracefully, killing." + ) + vllm_process.kill() + vllm_process.wait() + vllm_process = None + + # Launch new vLLM process (only if not the very last step, maybe? depends on use case) + # Let's still launch it on the last step for consistency, cleanup will handle it. 
+ vllm_command = [ + "python", + "-m", + "vllm.entrypoints.openai.api_server", + "--model", + os.path.join(config.save_path, f"step_{step+1}"), + "--port", + str(config.vllm_port), + "--dtype", + "auto", + "--gpu-memory-utilization", + "0.45", + "--disable-log-requests", + "--served-model-name", + config.model_name, + ] + print(f" Launching vLLM server: {' '.join(vllm_command)}") + torch.cuda.empty_cache() + try: + vllm_process = subprocess.Popen(vllm_command) + print(f" vLLM server launched with PID: {vllm_process.pid}") + # Check immediate errors + try: + stdout, stderr = vllm_process.communicate(timeout=2) + if ( + vllm_process.returncode is not None + and vllm_process.returncode != 0 + ): + print(f" Error starting vLLM: {stderr.decode()}") + vllm_process = None + # Maybe raise error or just warn? + print( + " WARNING: Failed to start vLLM server after checkpoint." + ) + except subprocess.TimeoutExpired: + print(" vLLM process started (check logs for details).") + except FileNotFoundError: + print( + "\n *** ERROR: 'python -m vllm...' command not found. ", + "Make sure vLLM is installed and accessible. ***\n", + ) + # Potentially stop training or just disable further vLLM restarts + print(" Disabling further vLLM restarts.") + config.vllm_restart_interval = ( + config.training_steps + 1 + ) # Prevent further restarts + except Exception as e: + print(f"\n *** ERROR: Failed to launch vLLM: {e} ***\n") + print(" Disabling further vLLM restarts.") + config.vllm_restart_interval = ( + config.training_steps + 1 + ) # Prevent further restarts + # --- End vLLM Restart Logic --- + + # Basic check if vLLM process terminated unexpectedly (outside interval check) + if vllm_process and vllm_process.poll() is not None: + print( + f"\n *** WARNING: vLLM process terminated unexpectedly (return code: {vllm_process.returncode}). ", + "Check vLLM logs. 
def create_look_at_matrix(eye: np.ndarray, target: np.ndarray, up: np.ndarray) -> np.ndarray:
    """
    Create a look-at camera pose matrix for PyRender.

    eye - position of the camera
    target - position the camera is looking at
    up - approximate up direction for the camera (must not be parallel to
         the viewing direction)

    returns the 4x4 camera pose (camera-to-world transform)

    raises ValueError if eye and target coincide, or if up is parallel to
    the viewing direction (no unique orientation exists in either case)
    """
    eye = np.asarray(eye, dtype=float)
    target = np.asarray(target, dtype=float)
    up = np.asarray(up, dtype=float)

    # Forward vector (from eye to target)
    forward = target - eye
    distance = np.linalg.norm(forward)
    if np.isclose(distance, 0.0):
        raise ValueError("eye and target must be different points")
    forward = forward / distance

    # Side vector (right vector)
    side = np.cross(forward, up)
    side_length = np.linalg.norm(side)
    if np.isclose(side_length, 0.0):
        raise ValueError("up must not be parallel to the viewing direction")
    side = side / side_length

    # Recompute up vector to ensure orthogonality (already unit length,
    # because side and forward are orthonormal)
    true_up = np.cross(side, forward)

    # NOTE: PyRender uses the OpenGL convention where the camera looks down
    # its negative z-axis. A pose maps camera coordinates to world
    # coordinates, so the camera axes expressed in world space are the
    # COLUMNS of the rotation part. Writing them as rows would give the
    # inverse (view) rotation instead.
    pose = np.eye(4)
    pose[:3, 0] = side
    pose[:3, 1] = true_up
    pose[:3, 2] = -forward
    pose[:3, 3] = eye
    return pose
0.75, 5], # Camera forward vector (pointing at origin) + [0, 0, 0, 1] + ]) + + # Store camera poses + self.camera_poses = [front_pose, top_pose, side_pose] + + # Debug print for the camera poses cause it doesn't look right + # and I'm not a game developer lol + print("Camera poses:") + for i, pose in enumerate(self.camera_poses): + print(f"Camera {i}:\n{pose}") + + # slightly wider field of view to ensure objects are visible + self.camera = pyrender.PerspectiveCamera(yfov=np.pi / 2.5, aspectRatio=1.0) + + # Bright point light + self.light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=10.0) + + def render_mesh_to_images(self, mesh_obj: trimesh.Trimesh): + if not self.renderer: + print("Renderer not initialized, cannot render.") + return [np.zeros((self.height, self.width, 3), dtype=np.uint8) for _ in range(3)] + + print(f"Rendering mesh with {len(mesh_obj.vertices)} vertices and {len(mesh_obj.faces)} faces") + + images = [] + + # Make a copy to avoid modifying original mesh + render_mesh = mesh_obj.copy() + + # Center and scale the mesh for visibility + render_mesh.apply_translation(-render_mesh.centroid) + scale_factor = 0.8 / np.max(render_mesh.extents) # Scale to unit size but slightly smaller + render_mesh.apply_scale(scale_factor) + + # Create a ground plane + ground_plane = trimesh.creation.box([4.0, 4.0, 0.01]) + ground_plane.apply_translation([0, 0, -0.5]) + + # Color scheme for wireframe blueprint + blueprint_bg = [0.98, 0.98, 1.0, 1.0] # Light blue background + wireframe_color = [0.0, 0.4, 0.8, 1.0] # Medium bright blue for wireframes + grid_color = [0.0, 0.2, 0.4, 1.0] # Darker blue for grid + + # Wireframe material for the edges - completely opaque + wireframe_material = pyrender.MetallicRoughnessMaterial( + baseColorFactor=wireframe_color, + metallicFactor=0.0, + roughnessFactor=1.0, + wireframe=True + ) + + # Grid material for the ground plane + grid_material = pyrender.MetallicRoughnessMaterial( + baseColorFactor=grid_color, + 
metallicFactor=0.0, + roughnessFactor=1.0, + wireframe=True + ) + + for i, pose in enumerate(self.camera_poses): + # Create a fresh scene with blueprint background + scene = pyrender.Scene(ambient_light=[0.8, 0.8, 0.95], bg_color=blueprint_bg) + + # Add ground plane as grid + plane_mesh = pyrender.Mesh.from_trimesh(ground_plane, material=grid_material) + scene.add(plane_mesh) + + # For wireframe rendering, we'll create line segments directly + edges = render_mesh.edges_unique + edge_vertices = [] + edge_indices = [] + + # Extract all edges for wireframe rendering + for _, edge in enumerate(edges): + v0_idx = len(edge_vertices) + edge_vertices.append(render_mesh.vertices[edge[0]]) + edge_vertices.append(render_mesh.vertices[edge[1]]) + edge_indices.append([v0_idx, v0_idx+1]) + + # Create lines primitive for edges + if len(edge_vertices) > 0: + edge_verts = np.array(edge_vertices, dtype=np.float32) + edge_indices = np.array(edge_indices, dtype=np.uint32) + + # Create a primitive for the lines + primitive = pyrender.Primitive( + positions=edge_verts, + indices=edge_indices, + mode=pyrender.constants.GLTF.LINES, + material=wireframe_material + ) + + # Create a mesh with just the line primitive + edge_mesh = pyrender.Mesh(primitives=[primitive]) + scene.add(edge_mesh) + + # Add camera + scene.add(self.camera, pose=pose) + + # Add light from camera direction (key light) + scene.add(self.light, pose=pose) + + # Add second light from above for better visibility + top_light_pose = np.eye(4) + top_light_pose[1, 3] = 3.0 + scene.add(self.light, pose=top_light_pose) + + try: + color, _ = self.renderer.render(scene) + + # Post-process to enhance blueprint effect + # Convert to float for processing + img_float = color.astype(np.float32) / 255.0 + + # Add a subtle grid pattern to the background + grid_size = 40 + grid_intensity = 0.02 + + # Draw faint horizontal grid lines + for y in range(0, color.shape[0], grid_size): + img_float[y:y+1, :, :] = np.minimum(img_float[y:y+1, :, :] 
def test_render_example():
    """
    Example script to test the PyRenderOffline class with a real mesh.
    Only runs if a GPU/rendering environment is available. See README.md for more details.

    Renders an icosphere from the renderer's three camera poses, lays the
    images out side by side and writes them to
    ``test_rendered_sphere_views.png`` in the current working directory.

    Returns True when the image was written. Returns False (after printing
    the error) when anything in the pipeline fails, e.g. because no
    rendering backend is available.
    """
    # Single source of truth for the output file, so the success message
    # always reports the path that was actually written.
    output_path = 'test_rendered_sphere_views.png'
    try:
        # Create a high-quality sphere for wireframe rendering
        # Using more subdivisions for smoother appearance
        sphere = trimesh.creation.icosphere(subdivisions=4, radius=1.0)

        print(f"Created sphere with {len(sphere.vertices)} vertices and {len(sphere.faces)} faces")
        print(f"Sphere has {len(sphere.edges_unique)} unique edges for wireframe rendering")

        # larger dimensions than the renderer's default for better detail
        renderer = PyRenderOffline(width=512, height=512)

        images = renderer.render_mesh_to_images(sphere)

        # Lay the three views out side by side
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        view_names = ['Front', 'Top', 'Diagonal']

        for ax, img, name in zip(axes, images, view_names):
            ax.imshow(img)
            ax.set_title(f"{name} View")
            ax.axis('off')

        plt.savefig(output_path)
        plt.close(fig)

        print("Successfully rendered sphere from 3 viewpoints")
        print(f"Images saved to {output_path}")
        return True
    except Exception as e:
        # Best-effort example: report the failure instead of crashing on
        # machines without a rendering environment.
        print(f"Failed to run renderer example: {e}")
        return False