mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
Merge branch 'main' of https://github.com/ecsbeats/atropos
This commit is contained in:
commit
332a1025ba
6 changed files with 830 additions and 0 deletions
14
environments/hack0/README.md
Normal file
14
environments/hack0/README.md
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# Physical Environment
|
||||
|
||||
Our project is a physical environment to train LLMs to generate STL files, the same files used in physical CAD designs.
|
||||
|
||||
## Setup
|
||||
|
||||
```sh
|
||||
$ pip install pyrender trimesh pyglet matplotlib torch transformers pydantic vllm numpy requests tenacity wandb
|
||||
```
|
||||
|
||||
Install the shared libraries required for headless GL rendering on Ubuntu:
|
||||
```sh
|
||||
$ sudo apt-get install libglfw3-dev libgles2-mesa-dev libnvidia-gl-570-server
|
||||
```
|
||||
1
environments/hack0/__init__.py
Normal file
1
environments/hack0/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
# for packages
|
||||
544
environments/hack0/grpo.py
Normal file
544
environments/hack0/grpo.py
Normal file
|
|
@ -0,0 +1,544 @@
|
|||
# ADAPTED FROM THE SAMPLE TRAINER
|
||||
|
||||
import atexit
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import string
|
||||
import subprocess
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import wandb # Added for logging
|
||||
from pydantic import BaseModel, Field
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from torch.optim import AdamW
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
# Global variable to keep track of the vLLM process
|
||||
vllm_process = None
|
||||
|
||||
|
||||
def cleanup_vllm():
    """Shut down the global vLLM server process, escalating to SIGKILL if needed."""
    global vllm_process
    proc = vllm_process
    if not proc:
        return
    print("\nTerminating vLLM process...")
    proc.terminate()
    try:
        # Give the server a short grace period to exit cleanly.
        proc.wait(timeout=5)
        print("vLLM process terminated.")
    except subprocess.TimeoutExpired:
        print("vLLM process did not terminate gracefully, killing.")
        proc.kill()
        proc.wait()
        print("vLLM process killed.")
    vllm_process = None
|
||||
|
||||
|
||||
# Register the cleanup function to be called on script exit
|
||||
atexit.register(cleanup_vllm)
|
||||
|
||||
|
||||
class TrainingConfig(BaseModel):
    """
    Hyperparameters and runtime settings for the GRPO trainer.

    Covers the base model, optimization schedule, checkpointing, the vLLM
    inference-server lifecycle, and optional Weights & Biases logging.
    """

    model_name: str = Field(..., description="Name of the base model to train")
    lr: float = Field(1e-5, description="Learning rate for the optimizer")
    training_steps: int = Field(
        10, description="Number of training steps"
    )  # Renamed from epochs
    batch_size: int = Field(
        2, description="Batch size for training (will be handled by get_data)"
    )
    seq_len: int = Field(2048, description="Sequence length for training")
    gradient_accumulation_steps: int = Field(
        32, description="Number of gradient accumulation steps"
    )
    # Evaluated once at class-definition time, not per instance.
    device: str = Field(
        "cuda" if torch.cuda.is_available() else "cpu", description="Device to train on"
    )
    save_path: str = Field(
        "trained_model_checkpoints", description="Base path to save model checkpoints"
    )
    # Restarting vLLM is how the (simple) policy update reaches inference.
    vllm_restart_interval: int = Field(
        3, description="Restart vLLM every N training steps"
    )
    vllm_port: int = Field(9001, description="Port for the vLLM server")

    # Wandb configuration
    use_wandb: bool = Field(
        False, description="Whether to use Weights & Biases for logging"
    )
    wandb_project: Optional[str] = Field(None, description="Wandb project name")
    wandb_group: Optional[str] = Field(None, description="Wandb group name")
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15))
def register_trainer(config: TrainingConfig):
    """
    Register the trainer with the Atropos API.

    Announces batch geometry and checkpoint settings so the rollout server can
    size its queues. Raises ``requests.HTTPError`` on a non-2xx response so the
    ``@retry`` decorator actually re-attempts registration — without the status
    check, requests would swallow server-side failures silently.
    """
    response = requests.post(
        "http://localhost:8000/register",
        json={
            "wandb_group": config.wandb_group,
            "wandb_project": config.wandb_project,
            # Effective batch seen by the server is per-step tokens, i.e.
            # micro-batch size times accumulation steps.
            "batch_size": config.batch_size * config.gradient_accumulation_steps,
            "max_token_len": config.seq_len,
            "starting_step": 0,
            "checkpoint_dir": config.save_path,
            "save_checkpoint_interval": config.training_steps,
            "num_steps": config.training_steps,
        },
        timeout=10,
    )
    response.raise_for_status()
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15))
def get_batch():
    """
    Fetch the next rollout batch from the Atropos API.

    Returns the decoded JSON payload; ``data["batch"]`` is ``None`` when no
    batch is ready yet (see ``get_data``). Raises ``requests.HTTPError`` on a
    non-2xx response so the ``@retry`` decorator re-attempts the fetch instead
    of returning an error page as if it were data.
    """
    response = requests.get("http://localhost:8000/batch", timeout=10)
    response.raise_for_status()
    return response.json()
|
||||
|
||||
|
||||
def pad_data_to_good_offset(data, batch_size: int):
    """
    Turn one Atropos API batch into padded causal training tensors.

    Pads every sequence so that, after the causal shift, its length is a
    multiple of 64 (GPU-friendly shapes), normalizes per-group scores into
    GRPO advantages, and slices everything into mini-batches.

    Args:
        data: API payload with ``data["batch"]`` — a list of items, each with
            ``tokens`` (list of token-id lists), ``masks`` (parallel label
            lists where -100 marks positions to ignore), ``scores`` (one float
            per sequence), and optional ``overrides``.
        batch_size: Sequences per returned mini-batch. Trailing sequences that
            do not fill a whole mini-batch are dropped.

    Returns:
        Tuple of three parallel lists: input-id tensors of shape
        (batch_size, L), label tensors of the same shape, and advantage
        tensors of shape (batch_size, 1).
    """
    max_token_len = max(
        max(len(x) for x in item["tokens"]) for item in data["batch"]
    )
    # Pad the post-shift length to a multiple of 64: aligned shapes avoid
    # non-weird scaling behavior on GPUs.
    good_multiple = 64
    if (max_token_len - 1) % good_multiple != 0:
        # Round the causal length up to the alignment boundary, then add 1 so
        # the shift below lands exactly on the aligned length.
        token_setup_len = (
            math.ceil((max_token_len - 1) / good_multiple) * good_multiple + 1
        )
    else:
        token_setup_len = max_token_len
    input_ids = []
    labels = []
    advantages = []
    for item in data["batch"]:
        scores = np.array(item["scores"])
        # Normalize within the group (GRPO advantage); only meaningful when
        # there is more than one rollout to compare against.
        if len(scores) > 1:
            scores = scores - scores.mean()
            scores = scores / max(scores.std(), 1e-8)
        item["scores"] = scores
        if item["overrides"] is not None:
            for i in range(len(item["overrides"])):
                if item["overrides"][i].get("set_advantage_to_zero", False):
                    item["scores"][i] = 0
        for i in range(len(item["tokens"])):
            # Pad masks with the cross-entropy ignore_index; assumes masks[i]
            # is the same length as tokens[i].
            label_item = np.concatenate(
                [
                    np.array(item["masks"][i]),
                    np.full(
                        max(0, token_setup_len - len(item["tokens"][i])),
                        -100,
                        dtype=np.int32,
                    ),
                ]
            )
            item["tokens"][i] = np.concatenate(
                [
                    np.array(item["tokens"][i]),
                    np.zeros(
                        max(0, token_setup_len - len(item["tokens"][i])), dtype=np.int32
                    ),
                ]
            )
            # Causal shift: inputs drop the last token, labels drop the first.
            input_ids.append(item["tokens"][i][:-1])
            labels.append(label_item[1:])
            advantages.append(item["scores"][i])
    # Stack the flat lists into mini-batches; a remainder smaller than
    # batch_size is intentionally dropped.
    token_batches = []
    label_batches = []
    advantage_batches = []
    for i in range(len(input_ids) // batch_size):
        lo, hi = i * batch_size, (i + 1) * batch_size
        token_batches.append(torch.tensor(np.stack(input_ids[lo:hi], axis=0)))
        label_batches.append(torch.tensor(np.stack(labels[lo:hi], axis=0)))
        advantage_batches.append(
            torch.tensor(np.stack(advantages[lo:hi], axis=0)).view(-1, 1)
        )
    return token_batches, label_batches, advantage_batches
|
||||
|
||||
|
||||
def get_data(
    batch_size: int, seq_len: int
) -> List[Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]:
    """
    Block until at least one rollout batch is available from the API, then
    drain everything queued so far and return it as padded tensor batches.

    Args:
        batch_size: Mini-batch size forwarded to ``pad_data_to_good_offset``.
        seq_len: Unused here; padding length is derived from the data itself.
            Kept for interface compatibility with the caller.

    Returns:
        List of ``(token_batches, label_batches, advantage_batches)`` tuples,
        one per API batch drained.
    """
    batches = []
    while True:
        data = get_batch()
        if data["batch"] is not None:
            # Save the batch (debug artifact; overwritten on every batch)
            with open("temp.json", "w", encoding="utf-8") as f:
                json.dump(data, f)
            # In case the inference runs ahead of the training, we loop until we don't have any more data
            batches.append(pad_data_to_good_offset(data, batch_size))
        elif len(batches) > 0:
            # Return the batches
            return batches
        else:
            # No data yet at all: poll again shortly.
            time.sleep(1)
|
||||
|
||||
|
||||
def train(config: TrainingConfig):
    """
    Set up and run GRPO training, restarting vLLM periodically, with wandb logging.

    Flow per step: drain rollout batches from the API, accumulate the GRPO
    loss over the micro-batches, take one optimizer step, then (every
    ``vllm_restart_interval`` steps and on the last step) save a checkpoint
    and relaunch the vLLM server on it so inference serves the updated policy.

    Args:
        config: All hyperparameters and runtime settings; ``use_wandb`` and
            ``vllm_restart_interval`` may be mutated here on failure paths.
    """
    global vllm_process  # Declare intention to modify the global variable

    # --- Wandb Setup ---
    if config.use_wandb:
        if not config.wandb_project:
            print("Warning: wandb_project not set, disabling wandb.")
            config.use_wandb = False
        else:
            if not config.wandb_group:
                # Set group to random 8 character string
                config.wandb_group = "".join(
                    random.choices(string.ascii_letters + string.digits, k=8)
                )
            try:
                wandb.init(
                    project=config.wandb_project,
                    group=config.wandb_group,
                    config=config.dict(),  # Log config parameters
                )
                print(
                    f"Wandb logging enabled. Run: {wandb.run.name} (Project: {config.wandb_project}) "
                )
            except Exception as e:
                print(f"Error initializing wandb: {e}. Disabling wandb.")
                config.use_wandb = False
    # --- End Wandb Setup ---

    # Initialize model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    model = AutoModelForCausalLM.from_pretrained(
        config.model_name, torch_dtype=torch.bfloat16
    )

    model.to(config.device)
    model.gradient_checkpointing_enable()
    model.train()

    # Setup optimizer
    optimizer = AdamW(model.parameters(), lr=config.lr)

    print(
        f"Starting training for {config.training_steps} steps on device: {config.device}"
    )
    print(
        f"vLLM will be restarted every {config.vllm_restart_interval} steps on port {config.vllm_port}"
    )

    os.makedirs(config.save_path, exist_ok=True)  # Ensure base save directory exists
    register_trainer(config)

    # Init vllm serving the base model until the first checkpoint exists.
    vllm_command = [
        "python",
        "-m",
        "vllm.entrypoints.openai.api_server",
        "--model",
        config.model_name,
        "--port",
        str(config.vllm_port),
        "--dtype",
        "auto",
        "--gpu-memory-utilization",
        "0.45",
        "--disable-log-requests",
    ]
    print(f" Launching vLLM server: {' '.join(vllm_command)}")
    try:
        vllm_process = subprocess.Popen(vllm_command)
        print(f" vLLM server launched with PID: {vllm_process.pid}")
        # Check immediate errors. NOTE: Popen was created without pipes, so
        # communicate() returns (None, None) — we can only report the exit
        # code here, not captured stderr (decoding None used to crash).
        try:
            vllm_process.communicate(timeout=2)
            if vllm_process.returncode is not None and vllm_process.returncode != 0:
                print(f" Error starting vLLM: exit code {vllm_process.returncode}")
                vllm_process = None
                # Maybe raise error or just warn?
                print(" WARNING: Failed to start vLLM server after checkpoint.")
        except subprocess.TimeoutExpired:
            # Still running after the probe window: assume a healthy start.
            print(" vLLM process started (check logs for details).")
    except FileNotFoundError:
        print(
            "\n *** ERROR: 'python -m vllm...' command not found. Make sure vLLM is installed and accessible. ***\n"
        )
        # Potentially stop training or just disable further vLLM restarts
        print(" Disabling further vLLM restarts.")
        config.vllm_restart_interval = (
            config.training_steps + 1
        )  # Prevent further restarts
    except Exception as e:
        print(f"\n *** ERROR: Failed to launch vLLM: {e} ***\n")
        print(" Disabling further vLLM restarts.")
        config.vllm_restart_interval = (
            config.training_steps + 1
        )  # Prevent further restarts

    batches = list()
    for step in range(config.training_steps):
        total_loss = 0
        print(f"Step {step+1}/{config.training_steps}")
        total_pos_logp = 0
        total_neg_logp = 0
        total_logp = 0
        total_pos = 0
        total_neg = 0
        if len(batches) == 0:
            batches = get_data(config.batch_size, config.seq_len)
        token_batches, label_batches, advantage_batches = batches.pop(0)
        # Terminate the vLLM server BEFORE the training pass on restart steps
        # to free its GPU memory for the backward pass.
        if (
            step + 1
        ) % config.vllm_restart_interval == 0 or step == config.training_steps - 1:  # Also restart/save on last step
            if vllm_process:
                print(" Terminating existing vLLM process...")
                vllm_process.terminate()
                try:
                    vllm_process.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    print(
                        " Existing vLLM process did not terminate gracefully, killing."
                    )
                    vllm_process.kill()
                    vllm_process.wait()
                vllm_process = None
        for tokens, labels, advantages in zip(
            token_batches, label_batches, advantage_batches
        ):
            tokens, labels, advantages = (
                tokens.to(config.device),
                labels.to(config.device),
                advantages.to(config.device),
            )

            # Forward pass; tokens/labels are already causally shifted and
            # padded by get_data, so logits align with labels index-for-index.
            outputs = model(tokens)
            logits = outputs.logits

            # Per-token log-probability of the taken action; -100 positions
            # (padding / unmasked prompt) contribute zero.
            logp_per_token = -F.cross_entropy(
                logits.view(-1, logits.size(-1)),  # Flatten logits
                labels.view(-1),  # Flatten labels
                reduction="none",
                ignore_index=-100,
            ).view(
                labels.shape
            )  # Reshape back to (batch, seq_len)

            # Masking based on labels != -100
            mask = (labels != -100).float()
            with torch.no_grad():
                pos = (advantages > 0).float()
                neg = (advantages <= 0).float()
                avg_logp = (logp_per_token * mask).sum(-1) / mask.sum(-1)
                pos_logp = (logp_per_token * pos).mean().item()
                neg_logp = (logp_per_token * neg).mean().item()
                total_pos_logp += pos_logp
                total_neg_logp += neg_logp
                # Reduce to a Python float; accumulating the (batch,) tensor
                # directly used to hand wandb.log a non-scalar value.
                total_logp += avg_logp.mean().item()
                total_pos += pos.sum().item()
                total_neg += neg.sum().item()

            # Importance ratio is exp(logp - logp.detach()) == 1 in value but
            # carries the gradient of logp (standard surrogate-loss trick).
            grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach())
            grpo_loss = (
                ((-grpo_loss_term * mask).sum(-1) / mask.sum(-1))
                * advantages.to(logp_per_token.device)
            ).mean() / config.gradient_accumulation_steps
            grpo_loss.backward()
            total_loss += grpo_loss.item()
        # One optimizer step per outer step, after accumulating all micro-batches.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
        if total_pos > 0:
            total_pos_logp /= total_pos
        if total_neg > 0:
            total_neg_logp /= total_neg
        # --- Wandb Logging ---
        if config.use_wandb:
            wandb.log(
                {
                    "train/loss": total_loss,
                    "train/learning_rate": optimizer.param_groups[0]["lr"],
                    "train/grad_norm": grad_norm.item(),
                    "train/pos_logp": total_pos_logp,
                    "train/neg_logp": total_neg_logp,
                    "train/logp": total_logp,
                },
                step=step + 1,
            )
        # --- End Wandb Logging ---

        print(f" Step Loss: {grpo_loss.item():.4f}")

        # --- vLLM Restart Logic (Moved AFTER optimizer step) ---
        # Note: There are much better ways of updating the policy, this is just a very simple example
        if (
            step + 1
        ) % config.vllm_restart_interval == 0 or step == config.training_steps - 1:  # Also restart/save on last step
            checkpoint_path = os.path.join(
                config.save_path, f"step_{step+1}"
            )  # Save as step+1 since it's after step completion
            print(f" Saving checkpoint to {checkpoint_path}...")
            # Ensure fresh directory for saving
            if os.path.exists(checkpoint_path):
                shutil.rmtree(checkpoint_path)  # Remove old checkpoint if it exists
            os.makedirs(checkpoint_path, exist_ok=True)
            model.save_pretrained(checkpoint_path)
            tokenizer.save_pretrained(checkpoint_path)
            print(" Checkpoint saved.")

            # Terminate existing vLLM process if running (usually already done
            # by the pre-step termination above; kept as a safety net).
            if vllm_process:
                print(" Terminating existing vLLM process...")
                vllm_process.terminate()
                try:
                    vllm_process.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    print(
                        " Existing vLLM process did not terminate gracefully, killing."
                    )
                    vllm_process.kill()
                    vllm_process.wait()
                vllm_process = None

            # Launch new vLLM process serving the fresh checkpoint, but keep
            # advertising the base model name so clients need no change.
            vllm_command = [
                "python",
                "-m",
                "vllm.entrypoints.openai.api_server",
                "--model",
                os.path.join(config.save_path, f"step_{step+1}"),
                "--port",
                str(config.vllm_port),
                "--dtype",
                "auto",
                "--gpu-memory-utilization",
                "0.45",
                "--disable-log-requests",
                "--served-model-name",
                config.model_name,
            ]
            print(f" Launching vLLM server: {' '.join(vllm_command)}")
            torch.cuda.empty_cache()
            try:
                vllm_process = subprocess.Popen(vllm_command)
                print(f" vLLM server launched with PID: {vllm_process.pid}")
                # Check immediate errors (no pipes attached — see note above).
                try:
                    vllm_process.communicate(timeout=2)
                    if (
                        vllm_process.returncode is not None
                        and vllm_process.returncode != 0
                    ):
                        print(
                            f" Error starting vLLM: exit code {vllm_process.returncode}"
                        )
                        vllm_process = None
                        # Maybe raise error or just warn?
                        print(
                            " WARNING: Failed to start vLLM server after checkpoint."
                        )
                except subprocess.TimeoutExpired:
                    print(" vLLM process started (check logs for details).")
            except FileNotFoundError:
                print(
                    "\n *** ERROR: 'python -m vllm...' command not found. ",
                    "Make sure vLLM is installed and accessible. ***\n",
                )
                # Potentially stop training or just disable further vLLM restarts
                print(" Disabling further vLLM restarts.")
                config.vllm_restart_interval = (
                    config.training_steps + 1
                )  # Prevent further restarts
            except Exception as e:
                print(f"\n *** ERROR: Failed to launch vLLM: {e} ***\n")
                print(" Disabling further vLLM restarts.")
                config.vllm_restart_interval = (
                    config.training_steps + 1
                )  # Prevent further restarts
        # --- End vLLM Restart Logic ---

        # Basic check if vLLM process terminated unexpectedly (outside interval check)
        if vllm_process and vllm_process.poll() is not None:
            print(
                f"\n *** WARNING: vLLM process terminated unexpectedly (return code: {vllm_process.returncode}). ",
                "Check vLLM logs. ***\n",
            )
            stderr_output = (
                vllm_process.stderr.read().decode()
                if vllm_process.stderr
                else "No stderr"
            )
            print(f"vLLM stderr: {stderr_output}")
            vllm_process = None  # Reset so it relaunches next interval

    print("Training finished.")
    # --- Wandb Finish ---
    if config.use_wandb:
        wandb.finish()
    # --- End Wandb Finish ---
    # Final cleanup (vLLM termination) is handled by atexit

    # --- Final model save ---
    final_save_path = os.path.join(config.save_path, "final_model")
    print(f"Saving final model to {final_save_path}")
    if os.path.exists(final_save_path):
        shutil.rmtree(final_save_path)
    os.makedirs(final_save_path, exist_ok=True)
    model.save_pretrained(final_save_path)
    tokenizer.save_pretrained(final_save_path)
    print("Final model saved.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Example entry point: GRPO fine-tuning of Gemma-3 27B, checkpointing and
    # restarting the vLLM inference server every 3 optimizer steps, with
    # wandb logging to the "grpo-physical-trainer" project.
    training_config = TrainingConfig(
        model_name="google/gemma-3-27b-it",
        training_steps=20,
        vllm_restart_interval=3,
        use_wandb=True,
        wandb_project="grpo-physical-trainer"
    )

    train(training_config)
|
||||
225
environments/hack0/pyrender_utils.py
Normal file
225
environments/hack0/pyrender_utils.py
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import trimesh
|
||||
import pyrender
|
||||
|
||||
# Headless rendering with GPU acceleration (EGL); in non-GPU environments OSMesa will be faster
|
||||
os.environ['PYOPENGL_PLATFORM'] = 'egl'
|
||||
|
||||
def create_look_at_matrix(eye: np.ndarray, target: np.ndarray, up: np.ndarray) -> np.ndarray:
    """
    Create a camera pose (camera-to-world) matrix looking from ``eye`` at ``target``.

    eye - position of the camera
    target - position the camera is looking at
    up - approximate up direction for the camera (re-orthogonalized internally)

    Returns the 4x4 pose matrix suitable for pyrender, which uses the OpenGL
    convention: the camera looks down its local -Z axis with +Y up.
    """
    eye = np.asarray(eye, dtype=float)
    target = np.asarray(target, dtype=float)
    up = np.asarray(up, dtype=float)

    # Forward vector (from eye to target)
    forward = target - eye
    forward = forward / np.linalg.norm(forward)

    # Side vector (right vector)
    side = np.cross(forward, up)
    side = side / np.linalg.norm(side)

    # Recompute up vector to ensure orthogonality
    up = np.cross(side, forward)
    up = up / np.linalg.norm(up)

    # Camera-to-world rotation: the COLUMNS are the camera's basis vectors
    # expressed in world coordinates; the camera Z axis is -forward (OpenGL).
    # The previous implementation returned T @ R with side/up/-forward as
    # ROWS, i.e. the world-to-camera (view) rotation — the inverse of the
    # pose pyrender expects.
    pose = np.eye(4)
    pose[:3, 0] = side
    pose[:3, 1] = up
    pose[:3, 2] = -forward
    pose[:3, 3] = eye
    return pose
|
||||
|
||||
class PyRenderOffline:
    """
    Headless pyrender wrapper that renders a trimesh mesh as blueprint-style
    wireframe images from three fixed viewpoints (front, top, diagonal).
    """

    def __init__(self, width=224, height=224):  # Standard CLIP image size
        """
        Create the offscreen renderer, fixed camera poses, camera, and light.

        Raises whatever pyrender.OffscreenRenderer raises when no GL backend
        (display server / EGL / OSMesa) is available.
        """
        self.width = width
        self.height = height
        try:
            self.renderer = pyrender.OffscreenRenderer(viewport_width=self.width, viewport_height=self.height, point_size=1.0)
        except Exception as e:
            print(f"Failed to initialize OffscreenRenderer (is a display server/EGL/OSMesa available?): {e}")
            print("Try: pip install pyglet; or for headless: export PYOPENGL_PLATFORM=osmesa (or egl)")
            self.renderer = None  # Fallback or raise error
            raise

        # Create camera poses using explicit transformation matrices
        # (camera-to-world; pyrender/OpenGL cameras look down local -Z).

        # Front view (looking along the -Z axis)
        front_pose = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, 5],
            [0, 0, 0, 1]
        ])

        # Top view (looking along the -Y axis)
        top_pose = np.array([
            [1, 0, 0, 0],
            [0, 0, 1, 5],
            [0, -1, 0, 0],
            [0, 0, 0, 1]
        ])

        # Diagonal view (from upper corner)
        # NOTE(review): rows are labeled as camera basis vectors, but a pose
        # matrix holds those vectors in its COLUMNS — confirm orientation.
        side_pose = np.array([
            [0.866, -0.25, 0.433, 3],  # Camera right vector
            [0.0, 0.866, 0.5, 3],  # Camera up vector
            [-0.5, -0.433, 0.75, 5],  # Camera forward vector (pointing at origin)
            [0, 0, 0, 1]
        ])

        # Store camera poses
        self.camera_poses = [front_pose, top_pose, side_pose]

        # Debug print for the camera poses cause it doesn't look right
        # and I'm not a game developer lol
        print("Camera poses:")
        for i, pose in enumerate(self.camera_poses):
            print(f"Camera {i}:\n{pose}")

        # slightly wider field of view to ensure objects are visible
        self.camera = pyrender.PerspectiveCamera(yfov=np.pi / 2.5, aspectRatio=1.0)

        # Bright point light
        self.light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=10.0)

    def render_mesh_to_images(self, mesh_obj: trimesh.Trimesh):
        """
        Render ``mesh_obj`` as a wireframe blueprint from each stored camera pose.

        The mesh is copied, centered, and scaled before rendering, so the
        caller's mesh is not modified. Returns a list of three HxWx3 uint8
        images; a view that fails to render (or an uninitialized renderer)
        yields an all-black image of the same shape instead of raising.
        """
        if not self.renderer:
            print("Renderer not initialized, cannot render.")
            return [np.zeros((self.height, self.width, 3), dtype=np.uint8) for _ in range(3)]

        print(f"Rendering mesh with {len(mesh_obj.vertices)} vertices and {len(mesh_obj.faces)} faces")

        images = []

        # Make a copy to avoid modifying original mesh
        render_mesh = mesh_obj.copy()

        # Center and scale the mesh for visibility
        render_mesh.apply_translation(-render_mesh.centroid)
        scale_factor = 0.8 / np.max(render_mesh.extents)  # Scale to unit size but slightly smaller
        render_mesh.apply_scale(scale_factor)

        # Create a ground plane
        ground_plane = trimesh.creation.box([4.0, 4.0, 0.01])
        ground_plane.apply_translation([0, 0, -0.5])

        # Color scheme for wireframe blueprint
        blueprint_bg = [0.98, 0.98, 1.0, 1.0]  # Light blue background
        wireframe_color = [0.0, 0.4, 0.8, 1.0]  # Medium bright blue for wireframes
        grid_color = [0.0, 0.2, 0.4, 1.0]  # Darker blue for grid

        # Wireframe material for the edges - completely opaque
        wireframe_material = pyrender.MetallicRoughnessMaterial(
            baseColorFactor=wireframe_color,
            metallicFactor=0.0,
            roughnessFactor=1.0,
            wireframe=True
        )

        # Grid material for the ground plane
        grid_material = pyrender.MetallicRoughnessMaterial(
            baseColorFactor=grid_color,
            metallicFactor=0.0,
            roughnessFactor=1.0,
            wireframe=True
        )

        for i, pose in enumerate(self.camera_poses):
            # Create a fresh scene with blueprint background
            scene = pyrender.Scene(ambient_light=[0.8, 0.8, 0.95], bg_color=blueprint_bg)

            # Add ground plane as grid
            plane_mesh = pyrender.Mesh.from_trimesh(ground_plane, material=grid_material)
            scene.add(plane_mesh)

            # For wireframe rendering, we'll create line segments directly
            edges = render_mesh.edges_unique
            edge_vertices = []
            edge_indices = []

            # Extract all edges for wireframe rendering: every unique edge
            # contributes two vertices and one (start, end) index pair.
            for _, edge in enumerate(edges):
                v0_idx = len(edge_vertices)
                edge_vertices.append(render_mesh.vertices[edge[0]])
                edge_vertices.append(render_mesh.vertices[edge[1]])
                edge_indices.append([v0_idx, v0_idx+1])

            # Create lines primitive for edges
            if len(edge_vertices) > 0:
                edge_verts = np.array(edge_vertices, dtype=np.float32)
                edge_indices = np.array(edge_indices, dtype=np.uint32)

                # Create a primitive for the lines (GLTF LINES draw mode)
                primitive = pyrender.Primitive(
                    positions=edge_verts,
                    indices=edge_indices,
                    mode=pyrender.constants.GLTF.LINES,
                    material=wireframe_material
                )

                # Create a mesh with just the line primitive
                edge_mesh = pyrender.Mesh(primitives=[primitive])
                scene.add(edge_mesh)

            # Add camera
            scene.add(self.camera, pose=pose)

            # Add light from camera direction (key light)
            scene.add(self.light, pose=pose)

            # Add second light from above for better visibility
            top_light_pose = np.eye(4)
            top_light_pose[1, 3] = 3.0
            scene.add(self.light, pose=top_light_pose)

            try:
                color, _ = self.renderer.render(scene)

                # Post-process to enhance blueprint effect
                # Convert to float for processing
                img_float = color.astype(np.float32) / 255.0

                # Add a subtle grid pattern to the background
                grid_size = 40
                grid_intensity = 0.02

                # Draw faint horizontal grid lines
                for y in range(0, color.shape[0], grid_size):
                    img_float[y:y+1, :, :] = np.minimum(img_float[y:y+1, :, :] + grid_intensity, 1.0)

                # Draw faint vertical grid lines
                for x in range(0, color.shape[1], grid_size):
                    img_float[:, x:x+1, :] = np.minimum(img_float[:, x:x+1, :] + grid_intensity, 1.0)

                # Convert back to uint8
                processed_img = (img_float * 255).astype(np.uint8)

                images.append(processed_img)
                print(f"Rendered wireframe view {i}")

            except Exception as e:
                # Best-effort: substitute a black frame so callers always get
                # exactly three images.
                print(f"Error during rendering view {i}: {e}")
                images.append(np.zeros((self.height, self.width, 3), dtype=np.uint8))

        return images

    def __del__(self):
        # Release the GL context; may run during interpreter shutdown.
        if self.renderer:
            self.renderer.delete()
|
||||
BIN
environments/hack0/test_rendered_sphere_views.png
Normal file
BIN
environments/hack0/test_rendered_sphere_views.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 47 KiB |
46
environments/hack0/test_renderer_example.py
Normal file
46
environments/hack0/test_renderer_example.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import trimesh
|
||||
import matplotlib.pyplot as plt
|
||||
from pyrender_utils import PyRenderOffline
|
||||
|
||||
def test_render_example():
    """
    Example script to test the PyRenderOffline class with a real mesh.

    Only runs if a GPU/rendering environment is available. See README.md for
    more details.

    Returns:
        True when rendering and saving the preview image succeeded,
        False on any failure (errors are printed, not raised).
    """
    try:
        # Create a high-quality sphere for wireframe rendering
        # Using more subdivisions for smoother appearance
        sphere = trimesh.creation.icosphere(subdivisions=4, radius=1.0)

        # A clean sphere (no noise) renders best as a wireframe/blueprint
        print(f"Created sphere with {len(sphere.vertices)} vertices and {len(sphere.faces)} faces")
        print(f"Sphere has {len(sphere.edges_unique)} unique edges for wireframe rendering")

        # Larger dimensions than the CLIP size for better detail
        renderer = PyRenderOffline(width=512, height=512)

        images = renderer.render_mesh_to_images(sphere)

        # Display the results: one panel per camera view
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        view_names = ['Front', 'Top', 'Diagonal']

        for i, (img, name) in enumerate(zip(images, view_names)):
            axes[i].imshow(img)
            axes[i].set_title(f"{name} View")
            axes[i].axis('off')

        plt.savefig('test_rendered_sphere_views.png')
        plt.close()

        print("Successfully rendered sphere from 3 viewpoints")
        # Keep this message in sync with the savefig() filename above — it
        # previously pointed at a non-existent "rendered_sphere_views.png".
        print("Images saved to test_rendered_sphere_views.png")
        return True
    except Exception as e:
        print(f"Failed to run renderer example: {e}")
        return False


if __name__ == "__main__":
    test_render_example()
|
||||
Loading…
Add table
Add a link
Reference in a new issue