mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
rest of submission
This commit is contained in:
parent
36afc0da59
commit
012090307e
9 changed files with 3537 additions and 1770 deletions
1
environments/hack0/.gitignore
vendored
1
environments/hack0/.gitignore
vendored
|
|
@ -1,2 +1,3 @@
|
|||
sample_data/
|
||||
venv/
|
||||
.env
|
||||
|
|
|
|||
|
|
@ -29,4 +29,8 @@ $ sudo apt-get install libglfw3-dev libgles2-mesa-dev libnvidia-gl-570-server
|
|||
```
|
||||
- Use `render_stl.py` to generate images from STL files.
|
||||
- Use `llm_label.py` to label the STL and image files.
|
||||
- Use `prepare_push_hf_dataset.py` to push the dataset to Hugging Face.
|
||||
- Use `prepare_push_hf_dataset.py` to push the dataset to Hugging Face.
|
||||
|
||||
Generated run: https://wandb.ai/csxl/atropos-environments_hack0/runs/dlexyg5r
|
||||
Training run (ran out of memory): https://wandb.ai/csxl/grpo-physical-trainer/runs/t61am7gu
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1,3 @@
|
|||
# for packages
|
||||
from .physical_env import PhysicalEnv
|
||||
|
||||
__all__ = ["PhysicalEnv"]
|
||||
1125
environments/hack0/data/physical.html
Normal file
1125
environments/hack0/data/physical.html
Normal file
File diff suppressed because it is too large
Load diff
46
environments/hack0/judgement_model.py
Normal file
46
environments/hack0/judgement_model.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
import torch
|
||||
from transformers import CLIPProcessor, CLIPModel
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
class CLIPScorer:
    """Rate how well images match a text description using a CLIP model.

    The model and its processor are loaded once in the constructor and kept
    on ``self.device`` ("cuda" when available, otherwise "cpu").
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        """Load the CLIP checkpoint ``model_name``.

        If loading fails the error is printed, ``model`` and ``processor``
        are reset to ``None`` and the original exception is re-raised.
        """
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

        try:
            clip_model = CLIPModel.from_pretrained(model_name)
            self.model = clip_model.to(self.device)
            self.processor = CLIPProcessor.from_pretrained(model_name)
            print(f"CLIPScorer initialized on {self.device} with {model_name}")
        except Exception as e:
            print(f"Error initializing CLIPModel: {e}. Ensure model name is correct and you have internet.")
            self.model = None
            self.processor = None
            raise

    @torch.no_grad()  # inference only, gradients are never needed here
    def score_images(self, images_np_list: list, target_text_description: str):
        """Return one CLIP image-text logit per image in ``images_np_list``.

        Each array is cast to ``uint8`` and converted to a PIL image before
        being handed to the processor together with the single text prompt.
        On any failure (including an uninitialised scorer) a list of ``0.0``
        with one entry per input image is returned instead of raising.
        """
        if not (self.model and self.processor):
            print("CLIPScorer not properly initialized.")
            return [0.0] * len(images_np_list)  # neutral score for every image

        try:
            pil_images = []
            for frame in images_np_list:
                pil_images.append(Image.fromarray(frame.astype(np.uint8)))

            batch = self.processor(
                text=[target_text_description],  # one prompt shared by all images
                images=pil_images,
                return_tensors="pt",
                padding=True,
                truncation=True
            )
            batch = batch.to(self.device)

            outputs = self.model(**batch)
            # logits_per_image has a trailing text axis of size one; drop it.
            similarity = outputs.logits_per_image.squeeze().tolist()

            # With a single image squeeze() yields a scalar, so re-wrap it.
            if isinstance(similarity, list):
                return similarity
            return [similarity]

        except Exception as e:
            print(f"Error in CLIP scoring: {e}")
            return [0.0] * len(images_np_list)  # neutral score for every image
|
||||
481
environments/hack0/physical_env.py
Normal file
481
environments/hack0/physical_env.py
Normal file
|
|
@ -0,0 +1,481 @@
|
|||
import numpy as np
|
||||
import trimesh
|
||||
import io
|
||||
import re
|
||||
import wandb
|
||||
import os
|
||||
import glob
|
||||
import random
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
from typing import *
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
EvalHandlingEnum
|
||||
)
|
||||
from atroposlib.type_definitions import Item, number
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
# Fix the relative imports for running directly
|
||||
try:
|
||||
from .pyrender_utils import PyRenderOffline
|
||||
from .judgement_model import CLIPScorer
|
||||
except ImportError:
|
||||
from pyrender_utils import PyRenderOffline
|
||||
from judgement_model import CLIPScorer
|
||||
|
||||
# System message sent with every request: describes the task, the optional
# <think> scratchpad, and the required <stl>...</stl> ASCII STL output format.
system_prompt = "".join(
    [
        # Role and task.
        "You are an expert in 3D modeling and computer-aided design. Your task is to analyze the ",
        "blueprints or wireframe views of objects and generate the corresponding STL file content. ",
        "STL (stereolithography) files represent 3D models as a collection of triangular facets.\n\n",
        # Optional reasoning section.
        "You may use <think> </think> tags to work through your reasoning about the shape, ",
        "dimensions, and geometric features of the model. Be methodical in your approach.\n\n",
        # Output format requirements.
        "STL files can be in ASCII or binary format. For this task, generate ASCII STL content that ",
        "accurately represents the 3D model shown in the provided views.\n\n",
        "Your final output must be enclosed in <stl> </stl> tags, containing only the valid STL content ",
        "and nothing else. The STL content should begin with 'solid' and end with 'endsolid'.\n\n",
        # Worked example of the expected layout.
        "Example of STL format:\n",
        "<stl>\n",
        "solid model\n",
        "  facet normal nx ny nz\n",
        "    outer loop\n",
        "      vertex x1 y1 z1\n",
        "      vertex x2 y2 z2\n",
        "      vertex x3 y3 z3\n",
        "    endloop\n",
        "  endfacet\n",
        "  ... more facets ...\n",
        "endsolid model\n",
        "</stl>",
    ]
)
|
||||
|
||||
class PhysicalRow(TypedDict):
    """Shape of one dataset item as built by ``PhysicalEnv.get_next_item``.

    The key is ``stl_path`` (not ``stl``): ``get_next_item`` stores the path
    of the reference file and ``collect_trajectories`` reads
    ``item["stl_path"]``.
    """

    # Instruction text shown to the model.
    prompt: str
    # First rendered view of the reference mesh, or a blank 224x224x3 image.
    image: np.ndarray
    # Filesystem path of the reference STL file.
    stl_path: str
|
||||
|
||||
class PhysicalEnv(BaseEnv):
|
||||
name = "physical"
|
||||
|
||||
def __init__(
    self,
    config: BaseEnvConfig,
    server_configs: List[APIServerConfig],
    slurm=True,
    testing=False
):
    """Initialise the base environment, metric buffers, renderer and scorer.

    All four arguments are forwarded unchanged to the parent constructor.
    """
    super().__init__(config, server_configs, slurm, testing)

    # Rewards collected since the last wandb_log call, and pending eval metrics.
    self.percent_correct_buffer = []
    self.eval_metrics = []

    # Storage reserved for wandb visualisations.
    self.rollouts_for_wandb = []
    self.completion_lengths = []

    # Offline renderer producing 224x224 views, and the CLIP-based image scorer.
    self.renderer = PyRenderOffline(width=224, height=224)
    self.clip_scorer = CLIPScorer()
|
||||
|
||||
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
    """Build the default environment config and inference-server list.

    Returns:
        A ``(env_config, server_configs)`` pair; the server list holds a
        single entry pointing at ``http://localhost:9001/v1``.
    """
    # The same checkpoint is used for tokenisation and for generation.
    model_id = "google/gemma-3-27b-it"

    env_config = BaseEnvConfig(
        tokenizer_name=model_id,
        group_size=8,
        use_wandb=True,
        rollout_server_url="http://localhost:8000",
        total_steps=1000,
        batch_size=12,
        steps_per_eval=100,
        eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
        max_token_length=2048,
        wandb_name="physical",
    )

    inference_server = APIServerConfig(
        model_name=model_id,
        base_url="http://localhost:9001/v1",
        api_key="x",
        num_requests_for_eval=256,
    )

    return env_config, [inference_server]
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
    """Add this environment's metrics to ``wandb_metrics`` and log them.

    Publishes the mean of ``percent_correct_buffer`` as
    ``train/percent_correct`` (skipped when the buffer is empty), copies
    every queued ``(name, value)`` evaluation metric, clears both buffers,
    and then delegates to the parent implementation.
    """
    metrics = {} if wandb_metrics is None else wandb_metrics

    # Only report the mean when there is something to average.
    rewards = self.percent_correct_buffer
    if len(rewards) != 0:
        metrics["train/percent_correct"] = sum(rewards) / len(rewards)
    self.percent_correct_buffer = []

    for entry in self.eval_metrics:
        metrics[entry[0]] = entry[1]
    self.eval_metrics = []

    # The parent adds the server metrics and performs the actual logging.
    await super().wandb_log(metrics)
|
||||
|
||||
def load_stl_file(self, stl_path):
    """Load ``stl_path`` with trimesh.

    Returns:
        Whatever ``trimesh.load`` returns, or ``None`` if loading raised
        (the error is printed, not propagated).
    """
    try:
        return trimesh.load(stl_path)
    except Exception as e:
        print(f"Error loading STL file {stl_path}: {e}")
    return None
|
||||
|
||||
def generate_query_from_images(self, images):
    """Pick the user prompt for a set of rendered views.

    ``images`` is currently unused: instead of describing the views with a
    vision model, one of five fixed instruction templates is chosen at
    random to add some variety.
    """
    prompt_variants = (
        "Create a 3D model (STL file) for the object shown in these technical drawings. Be precise with the geometry.",
        "Based on these wireframe views, generate the STL code for this 3D object. Pay attention to all visible features.",
        "Using these blueprint images as reference, provide the STL file format data to recreate this 3D model.",
        "These are technical views of a 3D object. Generate the STL representation that would produce this exact shape.",
        "Reconstruct this 3D model from the provided wireframe views and output the STL file content.",
    )
    chosen = random.choice(prompt_variants)
    return chosen
|
||||
|
||||
async def setup(self):
|
||||
# Load all STL files from sample_data
|
||||
self.stl_files = glob.glob(os.path.join('sample_data', '*.stl'))
|
||||
if not self.stl_files:
|
||||
raise ValueError("No STL files found in the sample_data directory")
|
||||
|
||||
print(f"Found {len(self.stl_files)} STL files")
|
||||
|
||||
# Split files into train and test sets (80/20 split)
|
||||
random.seed(42)
|
||||
random.shuffle(self.stl_files)
|
||||
split_idx = int(len(self.stl_files) * 0.8)
|
||||
self.train_files = self.stl_files[:split_idx]
|
||||
self.test_files = self.stl_files[split_idx:]
|
||||
|
||||
self.iter = 0
|
||||
|
||||
def save_checkpoint(self, step, data=None):
    """Store the dataset cursor in the checkpoint payload, then delegate.

    ``data`` (a fresh dict when omitted) gets ``"iter"`` set to the current
    ``self.iter`` before being passed to the parent implementation.
    """
    payload = {} if data is None else data
    payload["iter"] = self.iter
    super().save_checkpoint(step, payload)
|
||||
|
||||
async def rollout_and_score_eval(self, stl_path: str) -> number:
    """Generate one STL for ``stl_path`` and score it against the reference.

    The reference mesh is loaded and rendered, a prompt is chosen, and a
    single completion is requested with temperature 0 on the "eval" split.
    The ``<stl>`` section of the reply is written to a temporary file,
    loaded with trimesh and compared via ``score_meshes_similarity``.

    Returns:
        The similarity score, or 0 when the reference cannot be loaded, the
        reply contains no STL section, or the generated STL fails to load.
    """
    # Imported locally so this method does not depend on the module header.
    import tempfile

    mesh = self.load_stl_file(stl_path)
    if mesh is None:
        return 0

    images = self.renderer.render_mesh_to_images(mesh)
    query = self.generate_query_from_images(images)

    completion = await self.server.chat_completion(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query},
        ],
        n=1,
        max_tokens=self.config.max_token_length,
        temperature=0.0,
        split="eval",
    )

    # The message content may be missing; treat that like an empty reply
    # instead of letting the regex search raise outside the try block.
    response_content = completion.choices[0].message.content or ""
    stl_content = self.extract_stl_content(response_content)
    if not stl_content:
        return 0

    temp_path = None
    try:
        # A uniquely named temp file avoids the collisions possible with a
        # four-digit random suffix in the working directory.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".stl", prefix="temp_generated_", delete=False
        ) as handle:
            temp_path = handle.name
            handle.write(stl_content)

        generated_mesh = trimesh.load(temp_path)
        return self.score_meshes_similarity(mesh, generated_mesh)
    except Exception as e:
        print(f"Error processing generated STL: {e}")
        return 0
    finally:
        # Single cleanup point for both the success and the failure path.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
|
||||
|
||||
def extract_stl_content(self, response_content):
    """Return the text inside the first ``<stl>...</stl>`` pair.

    Args:
        response_content: The model reply. ``None`` or any non-string value
            is treated as a reply without an STL section.

    Returns:
        The enclosed text with surrounding whitespace stripped, or ``""``
        when no complete tag pair is present.
    """
    # A completion may carry no text at all; re.search would raise on None.
    if not isinstance(response_content, str):
        return ""

    # DOTALL lets the non-greedy group span the multi-line STL body.
    match = re.search(r'<stl>(.*?)</stl>', response_content, re.DOTALL)
    return match.group(1).strip() if match else ""
|
||||
|
||||
def score_meshes_similarity(self, original_mesh, generated_mesh):
    """Score how closely ``generated_mesh`` matches ``original_mesh``.

    This is a deliberately simple metric: the vertex count, face count,
    volume and surface area of the generated mesh are each divided by the
    reference value, capped at 1.0 (a generated value above the reference
    is not penalised), and the four ratios are averaged.

    Returns:
        A float in ``[0.0, 1.0]``.
    """

    def capped_ratio(generated_value, original_value):
        """Return generated/original limited to [0, 1], using magnitudes."""
        # Volume is signed (inverted facets give a negative value), so
        # compare magnitudes; a missing value counts as zero.
        generated_value = abs(float(generated_value or 0.0))
        original_value = abs(float(original_value or 0.0))
        if original_value == 0.0:
            # Nothing to match: any generated value reaches the cap.
            return 1.0
        # Divide by the real reference value. Clamping the denominator to
        # at least 1 (as was done before) makes identical meshes with a
        # volume or area below 1 score less than 1.
        return min(generated_value / original_value, 1.0)

    vertex_ratio = capped_ratio(len(generated_mesh.vertices), len(original_mesh.vertices))
    face_ratio = capped_ratio(len(generated_mesh.faces), len(original_mesh.faces))
    volume_ratio = capped_ratio(generated_mesh.volume, original_mesh.volume)
    area_ratio = capped_ratio(generated_mesh.area, original_mesh.area)

    # Average the ratios for a final score
    return (vertex_ratio + face_ratio + volume_ratio + area_ratio) / 4.0
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
    """Score up to ten test files and queue the mean as an eval metric.

    Each file is processed by ``rollout_and_score_eval``; the average is
    appended to ``self.eval_metrics`` under ``eval/similarity_score``.
    Nothing is recorded when the test set is empty.
    """
    # Limit to 10 files for evaluation to keep it manageable.
    eval_files = self.test_files[:10]
    if not eval_files:
        # Avoid dividing by zero below when there is nothing to evaluate.
        return

    eval_tasks = [self.rollout_and_score_eval(stl_file) for stl_file in eval_files]
    scores = await tqdm_asyncio.gather(*eval_tasks)
    self.eval_metrics.append(("eval/similarity_score", sum(scores) / len(scores)))
|
||||
|
||||
async def collect_trajectories(self, item: PhysicalRow) -> Tuple[ScoredDataGroup, list[Item]]:
    """Sample a group of completions for ``item`` and score them.

    The reference STL named by ``item["stl_path"]`` is loaded and rendered,
    a prompt is chosen, and ``config.group_size`` completions are requested.
    Each completion is bundled with the reference path, the reference
    renders and its finish reason, and the bundle is handed to ``score``.

    Returns:
        ``(scored_group, backlog)``. ``scored_group`` is whatever ``score``
        returns, or ``None`` when the reference mesh cannot be loaded; the
        backlog is always empty.
    """
    to_backlog = list()

    stl_path = item["stl_path"]
    mesh = self.load_stl_file(stl_path)
    if mesh is None:
        # load_stl_file already printed the error; rendering None would
        # fail, so skip this item the same way an unscorable group is skipped.
        return None, to_backlog

    images = self.renderer.render_mesh_to_images(mesh)
    query = self.generate_query_from_images(images)

    system_message = {"role": "system", "content": system_prompt}
    user_message = {"role": "user", "content": query}

    chat_completions = await self.server.chat_completion(
        messages=[system_message, user_message],
        n=self.config.group_size,
        max_tokens=self.config.max_token_length,
    )

    # Only the file path of the reference is kept, not its content: the
    # files may be binary and cannot simply be read as text.
    to_score = [
        {
            "messages": (
                {"role": "system", "content": system_prompt},
                user_message,
                {"role": "assistant", "content": chat_completion.message.content},
            ),
            "original_stl_path": stl_path,
            "images": images,
            "finish_reason": chat_completion.finish_reason,
        }
        for chat_completion in chat_completions.choices
    ]

    to_postprocess = await self.score(to_score)
    return to_postprocess, to_backlog
|
||||
|
||||
async def score(self, rollout_group_data) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
    """Turn a group of rollouts into tokens, masks and rewards.

    For every rollout (processed in shuffled order; the input list is
    shuffled in place) the ``<stl>`` section of the assistant message is
    written to a temporary file and loaded with trimesh. The reward is
    ``0.5 * mesh_similarity + 0.5 * images_reward``, where
    ``images_reward`` is the mean CLIP logit of the rendered views against
    the user prompt divided by 100. Rollouts that raise anywhere in this
    pipeline, or that have fewer than 10 unmasked tokens, are dropped.

    Returns:
        The populated ``ScoredDataGroup``, or ``None`` when no rollout
        could be scored or when all final scores are identical.
    """
    # Imported locally so this method does not depend on the module header.
    import tempfile

    scores = ScoredDataGroup()
    scores["tokens"] = list()
    scores["masks"] = list()
    scores["scores"] = list()

    random.shuffle(rollout_group_data)

    # Reference meshes keyed by path, so a file shared by the whole group
    # is parsed once instead of once per rollout.
    reference_meshes = {}

    for item in rollout_group_data:
        # The assistant content may be None; treat it as an empty reply.
        response_content = item["messages"][-1]["content"] or ""
        stl_content = self.extract_stl_content(response_content)

        temp_file = None
        try:
            # A uniquely named temp file avoids the collisions possible
            # with a four-digit random suffix in the working directory.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".stl", prefix="temp_generated_", delete=False
            ) as handle:
                temp_file = handle.name
                handle.write(stl_content)

            original_stl_path = item["original_stl_path"]
            if original_stl_path not in reference_meshes:
                reference_meshes[original_stl_path] = trimesh.load(original_stl_path)
            original_mesh = reference_meshes[original_stl_path]

            generated_mesh = trimesh.load(temp_file)

            mesh_similarity = self.score_meshes_similarity(original_mesh, generated_mesh)

            # Render the produced STL and let CLIP rate the views.
            generated_images = self.renderer.render_mesh_to_images(generated_mesh)

            images_reward = 0.0
            if len(generated_images) > 0 and len(item["images"]) > 0:
                query = item["messages"][1]["content"]
                clip_scores = self.clip_scorer.score_images(generated_images, query)
                images_reward = sum(clip_scores) / len(clip_scores) / 100.0  # Normalize to roughly 0-1

            # Combine mesh similarity and image similarity for final reward
            reward = 0.5 * mesh_similarity + 0.5 * images_reward

            out_dict = tokenize_for_trainer(
                self.tokenizer, item["messages"], item["finish_reason"]
            )
            tokens = out_dict["tokens"]
            masks = out_dict["masks"]

            # Remove obviously bad examples (almost nothing left to train on).
            if sum(1 for mask in masks if mask != -100) < 10:
                continue

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(reward)

            self.percent_correct_buffer.append(reward)

            if len(scores["tokens"]) >= self.config.group_size:
                break

        except Exception as e:
            print(f"Error in scoring: {e}")
        finally:
            # Runs on success, on `continue`/`break` and on error, so the
            # temporary file is never left behind.
            if temp_file is not None and os.path.exists(temp_file):
                os.remove(temp_file)

    # Every rollout failed: there is nothing to train on, and the length
    # penalty below would call max() on an empty list.
    if not scores["scores"]:
        return None

    # Apply length penalty if all scores are similar.
    # NOTE(review): this replaces the rewards with a purely length-based
    # score (1.0 up to half of max_token_length, then falling linearly to
    # 0.0), even when the group is uniformly poor - confirm this is intended.
    if all(abs(score - scores["scores"][0]) < 0.1 for score in scores["scores"]):
        token_lengths = [len(token) for token in scores["tokens"]]
        if max(token_lengths) > 0:
            max_allowed_length = self.config.max_token_length
            length_threshold = max_allowed_length * 0.5

            scores["scores"] = []
            for length in token_lengths:
                if length <= length_threshold:
                    scores["scores"].append(1.0)
                else:
                    percentage_of_range = (length - length_threshold) / (
                        max_allowed_length - length_threshold
                    )
                    percentage_of_range = min(percentage_of_range, 1.0)
                    scores["scores"].append(1.0 - percentage_of_range)

    if all(scores["scores"][0] == score for score in scores["scores"]):
        return None  # If all the same, we return None

    return scores
|
||||
|
||||
async def get_next_item(self) -> PhysicalRow:
    """Return the next training item, skipping files that fail to load.

    Files are visited round-robin via ``self.iter``. At most one full pass
    over ``self.train_files`` is attempted per call.

    Returns:
        A dict with ``prompt`` (the chosen instruction), ``image`` (the
        first rendered view, or a blank 224x224x3 ``uint8`` array when
        rendering produced nothing) and ``stl_path``.

    Raises:
        RuntimeError: if there are no training files, or none can be loaded.
    """
    total = len(self.train_files)
    if total == 0:
        # Previously this surfaced as a bare ZeroDivisionError from the modulo.
        raise RuntimeError("No training STL files available")

    # Bounded loop instead of unbounded recursion: if every file is
    # unreadable we stop after one pass rather than recursing forever.
    for _ in range(total):
        stl_path = self.train_files[self.iter % total]
        self.iter += 1
        mesh = self.load_stl_file(stl_path)
        if mesh is not None:
            break
    else:
        raise RuntimeError("None of the training STL files could be loaded")

    images = self.renderer.render_mesh_to_images(mesh)
    query = self.generate_query_from_images(images)

    # Use len() rather than truthiness so an array-valued result does not
    # raise on the emptiness check.
    if images is not None and len(images) > 0:
        first_image = images[0]
    else:
        first_image = np.zeros((224, 224, 3), dtype=np.uint8)

    return {
        "prompt": query,
        "image": first_image,
        "stl_path": stl_path
    }
|
||||
|
||||
@classmethod
async def test_sample_stl(cls):
    """Smoke-test loading and rendering of the first sample STL file.

    Builds a throwaway environment (wandb disabled, ``testing=True``),
    loads the first ``sample_data/*.stl`` match, renders it and saves the
    first view to ``test_render.png``. Progress and errors are printed;
    nothing is returned.
    """
    model_id = "google/gemma-3-27b-it"

    env_config = BaseEnvConfig(
        tokenizer_name=model_id,
        group_size=8,
        use_wandb=False,
        rollout_server_url="http://localhost:8000",
        total_steps=1000,
        batch_size=12,
        steps_per_eval=100,
        eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
        max_token_length=2048,
        wandb_name="physical_test",
    )
    server_configs = [
        APIServerConfig(
            model_name=model_id,
            base_url="http://localhost:9001/v1",
            api_key="x",
            num_requests_for_eval=256,
        ),
    ]
    env = cls(env_config, server_configs, slurm=False, testing=True)

    # Locate the sample files; bail out early when there are none.
    stl_files = glob.glob(os.path.join('sample_data', '*.stl'))
    if len(stl_files) == 0:
        print("No STL files found in sample_data/")
        return

    sample_path = stl_files[0]
    print(f"Testing with STL file: {sample_path}")
    mesh = env.load_stl_file(sample_path)
    if mesh is None:
        print("Failed to load STL file")
        return

    print(f"Loaded mesh with {len(mesh.vertices)} vertices and {len(mesh.faces)} faces")

    try:
        images = env.renderer.render_mesh_to_images(mesh)
        print(f"Successfully rendered {len(images)} images")

        # Keep the first view on disk for manual inspection.
        from PIL import Image
        Image.fromarray(images[0]).save("test_render.png")
        print("Saved test render to test_render.png")
    except Exception as e:
        print(f"Error rendering: {e}")

    print("Test completed")
|
||||
|
||||
# Script entry point: hand control to the command-line interface exposed by
# the environment class.
if __name__ == "__main__":
    PhysicalEnv.cli()
|
||||
0
environments/hack0/physical_server.py
Normal file
0
environments/hack0/physical_server.py
Normal file
54
environments/hack0/test_stl_env.py
Normal file
54
environments/hack0/test_stl_env.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for the PhysicalEnv that loads and processes STL files
|
||||
"""
|
||||
import os
|
||||
import asyncio
|
||||
import random
|
||||
from physical_env import PhysicalEnv, BaseEnvConfig, APIServerConfig, EvalHandlingEnum
|
||||
|
||||
async def test_render_stl():
    """Create a test environment, run ``setup`` and fetch one item.

    Prints the prompt, image shape and STL path of the item returned by
    ``get_next_item``. The success message is printed only when that call
    succeeded; on failure the error is reported and the function returns.
    """
    env_config = BaseEnvConfig(
        tokenizer_name="google/gemma-3-27b-it",
        group_size=8,
        use_wandb=False,
        rollout_server_url="http://localhost:8000",
        total_steps=1000,
        batch_size=12,
        steps_per_eval=100,
        eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
        max_token_length=2048,
        wandb_name="physical_test",
    )

    server_configs = [
        APIServerConfig(
            model_name="google/gemma-3-27b-it",
            base_url="http://localhost:9001/v1",
            api_key="x",
            num_requests_for_eval=256,
        ),
    ]

    print("Creating test environment")
    env = PhysicalEnv(env_config, server_configs, slurm=False, testing=True)

    print("Setting up environment")
    await env.setup()

    print("Testing get_next_item")
    try:
        item = await env.get_next_item()
        print(f"Got item: {item['prompt']}")
        print(f"Image shape: {item['image'].shape}")
        print(f"STL path: {item['stl_path']}")
    except Exception as e:
        print(f"Error getting next item: {e}")
        # Do not fall through to the success message after a failure.
        print("Test failed")
        return

    print("Test completed successfully")
|
||||
|
||||
# Run the async smoke test when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(test_render_stl())
|
||||
Loading…
Add table
Add a link
Reference in a new issue