mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
Data generation scripts to build a Hugging Face-compatible dataset
This commit is contained in:
parent
3fde5cbda8
commit
67a49f27b9
4 changed files with 251 additions and 0 deletions
30
environments/hack0/dataset_scr.py
Normal file
30
environments/hack0/dataset_scr.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
import os
import shutil


def build_dataset(base_dir="rendered_images",
                  stl_dir="selected_stls",
                  ds_stl_path="dataset/stls/",
                  ds_img_path="dataset/images/"):
    """Copy one STL and one rendered image per object into the dataset layout.

    For every subdirectory ``<name>`` of *base_dir* (one directory of renders
    per object), copy ``<stl_dir>/<name>.stl`` to ``<ds_stl_path>/<name>.stl``
    and ``<base_dir>/<name>/render_0.png`` to ``<ds_img_path>/<name>_0001.png``.

    Args:
        base_dir: directory containing one subdirectory of renders per object.
        stl_dir: directory holding the source ``<name>.stl`` files.
        ds_stl_path: destination directory for the STL files (must exist).
        ds_img_path: destination directory for the image files (must exist).
    """
    for name in os.listdir(base_dir):
        path = os.path.join(base_dir, name)
        # Skip stray files; only per-object render directories matter.
        if not os.path.isdir(path):
            continue

        stl_file_fpath = os.path.join(stl_dir, name) + ".stl"
        ds_stl_fpath = os.path.join(ds_stl_path, name) + ".stl"
        # Copy to the explicit destination path. (The original computed
        # ds_stl_fpath but then passed the *directory* to shutil.copy,
        # leaving ds_stl_fpath dead and relying on basename coincidence.)
        shutil.copy(stl_file_fpath, ds_stl_fpath)

        base_img_fpath = os.path.join(path, "render_0.png")
        ds_img_fpath = os.path.join(ds_img_path, name) + "_0001.png"
        shutil.copy(base_img_fpath, ds_img_fpath)


if __name__ == "__main__":
    build_dataset()
91
environments/hack0/llm_label.py
Normal file
91
environments/hack0/llm_label.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
"""Label STL meshes by combining trimesh geometry features with a BLIP image
caption, then asking a small instruction-tuned LLM for a 1-3 word label."""
import os
import json
import trimesh
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import AutoModelForCausalLM, AutoTokenizer

# Prefer the GPU for the BLIP captioner when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Dataset layout produced by the companion scripts:
#   one "<stem>.stl" per object, one "<stem>_0001.png" render per object,
#   and a single JSON file mapping stem -> label.
STL_DIR = "dataset/stls"
IMG_DIR = "dataset/images"
LABEL_FILE = "dataset/labels.json"

# Load BLIP for image captioning
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

# Load Mistral or other small LLM
llm_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
# NOTE(review): device_map="auto" lets accelerate place the LLM's weights
# across available devices, unlike the explicit .to(device) used for BLIP
# above — so the two models may live on different devices. Inputs are moved
# with .to(device) below; confirm this matches the LLM's placement.
llm_model = AutoModelForCausalLM.from_pretrained(llm_model_name, torch_dtype=torch.float16, device_map="auto")
def extract_trimesh_features(mesh):
    """Summarize a trimesh mesh as a flat dict of geometric properties.

    Returns volume, surface area, bounding-box extents (as a plain list),
    face/vertex counts, watertightness, and the Euler number.
    """
    extents = mesh.bounding_box.extents.tolist()
    features = dict(
        volume=mesh.volume,
        surface_area=mesh.area,
        bounding_box=extents,
        num_faces=len(mesh.faces),
        num_vertices=len(mesh.vertices),
        is_watertight=mesh.is_watertight,
        euler_number=mesh.euler_number,
    )
    return features
def caption_image(image_path):
    """Produce a short natural-language caption for the image at *image_path*.

    Runs the module-level BLIP processor/model on the RGB-converted image and
    returns the decoded caption with special tokens stripped.
    """
    image = Image.open(image_path).convert("RGB")
    encoded = blip_processor(image, return_tensors="pt").to(device)
    generated = blip_model.generate(**encoded)
    return blip_processor.decode(generated[0], skip_special_tokens=True)
def generate_label_with_llm(features, caption):
    """Ask the module-level LLM for a short (1-3 word) object label.

    Builds a prompt from the numeric mesh features and the BLIP caption,
    generates greedily (up to 10 new tokens), and returns whatever follows
    the final "Label:" marker in the decoded output.
    """
    prompt = f"""You are a 3D object classifier.

Mesh features:
{json.dumps(features, indent=2)}

Rendered Image Caption:
"{caption}"

Based on this information, return a short label (1-3 words) that describes the object.
Label:"""

    encoded = llm_tokenizer(prompt, return_tensors="pt").to(device)
    generated = llm_model.generate(**encoded, max_new_tokens=10, do_sample=False)
    decoded = llm_tokenizer.decode(generated[0], skip_special_tokens=True)
    return decoded.split("Label:")[-1].strip()
def main():
    """Label every mesh in STL_DIR and write a stem -> label map to LABEL_FILE.

    Skips non-STL files and meshes without a matching "<stem>_0001.png"
    render; per-item failures are reported and do not stop the run.
    """
    labels = {}

    for fname in os.listdir(STL_DIR):
        if not fname.endswith(".stl"):
            continue

        stem, _ = os.path.splitext(fname)
        mesh_path = os.path.join(STL_DIR, fname)
        image_path = os.path.join(IMG_DIR, f"{stem}_0001.png")

        if not os.path.exists(image_path):
            print(f"Missing image for {stem}, skipping...")
            continue

        try:
            # Geometry features + visual caption feed the LLM labeler.
            mesh = trimesh.load_mesh(mesh_path)
            feats = extract_trimesh_features(mesh)
            cap = caption_image(image_path)
            label = generate_label_with_llm(feats, cap)
        except Exception as e:
            print(f"Error processing {stem}: {e}")
        else:
            labels[stem] = label
            print(f"Labeled {stem}: {label}")

    with open(LABEL_FILE, "w") as f:
        json.dump(labels, f, indent=2)

    print(f"\nSaved labels to {LABEL_FILE}")


if __name__ == "__main__":
    main()
69
environments/hack0/prepare_push_hf_dataset.py
Normal file
69
environments/hack0/prepare_push_hf_dataset.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
"""Assemble the image + STL-feature dataset and push it to the HF Hub."""
import os
import json
import numpy as np
import trimesh
from datasets import Dataset, Features, Value, Image
from huggingface_hub import login

# Paths
image_dir = "dataset/images"
stl_dir = "dataset/stls"
labels_path = "dataset/labels.json"


def _stl_feature_vector(stl_path):
    """Return centroid (3) + bounding-box extents (3) + volume (1) = 7 floats.

    Falls back to a zero vector of the same length when the STL is missing
    or fails to load, so every row has a uniform shape. (The original
    fallback was ``[0.0] * 9`` while the computed vector has 7 entries.)
    """
    if os.path.exists(stl_path):
        try:
            mesh = trimesh.load(stl_path, force="mesh")
            return list(mesh.centroid) + list(mesh.bounding_box.extents) + [mesh.volume]
        except Exception as e:
            print(f"⚠️ Failed to process {stl_path}: {e}")
    return [0.0] * 7


def _build_rows(labels):
    """Build one dataset row per PNG in image_dir.

    Each row carries the image path, the label (or "unknown"), and the STL
    feature vector serialized as a JSON string for schema simplicity.
    """
    rows = []
    for image_filename in os.listdir(image_dir):
        if not image_filename.endswith(".png"):
            continue
        image_path = os.path.join(image_dir, image_filename)

        # Images are named "<id>_0001.png"; split on the LAST underscore so
        # ids that themselves contain underscores survive. (The original
        # split("_")[0] truncated such ids, losing the labels/STL match.)
        base_id = image_filename.rsplit("_", 1)[0]

        stl_path = os.path.join(stl_dir, f"{base_id}.stl")
        rows.append({
            "id": base_id,
            "image": image_path,
            "label": labels.get(base_id, "unknown"),
            "stl_features": json.dumps(_stl_feature_vector(stl_path)),
        })
    return rows


def main():
    """Load labels, build the rows, cast to the schema, and push to the Hub."""
    # Log in to HF Hub (optional if you've already done `huggingface-cli login`)
    login(token=os.getenv("HF_TOKEN"))

    # Load labels
    with open(labels_path, "r") as f:
        labels = json.load(f)

    data = _build_rows(labels)

    # Define dataset schema
    features = Features({
        "id": Value("string"),
        "image": Image(),  # Load images from file paths
        "label": Value("string"),
        "stl_features": Value("string"),  # JSON-encoded list of floats
    })

    # Create Dataset and push to Hub
    dataset = Dataset.from_list(data).cast(features)
    dataset.push_to_hub("venkatacrc/stl-image-dataset", private=True)


if __name__ == "__main__":
    main()
61
environments/hack0/render_stl.py
Normal file
61
environments/hack0/render_stl.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
# Blender batch script: render one STL from three azimuth angles.
# Invoke as:  blender --background --python render_stl.py -- <input.stl> <output_dir>
# (everything after "--" is ignored by Blender and passed through to us).
import bpy
import sys
import math
import os

# Get args after --
argv = sys.argv
argv = argv[argv.index("--") + 1:]  # args after --

input_stl = argv[0]   # path of the STL file to render
output_dir = argv[1]  # directory receiving render_<i>.png files

# Clear existing objects
# (the default scene ships with objects; deleting everything guarantees that
# selected_objects[0] after the import below is the imported mesh)
bpy.ops.object.select_all(action='SELECT')
bpy.ops.object.delete(use_global=False)

# Import STL
bpy.ops.import_mesh.stl(filepath=input_stl)
obj = bpy.context.selected_objects[0]  # freshly imported mesh is selected

# Center the object at origin
bpy.ops.object.origin_set(type='ORIGIN_CENTER_OF_MASS', center='MEDIAN')
obj.location = (0, 0, 0)

# Add Sun light
sun_light_data = bpy.data.lights.new(name="SunLight", type='SUN')
sun_light_object = bpy.data.objects.new(name="SunLight", object_data=sun_light_data)
sun_light_object.location = (10, 10, 10)
bpy.context.collection.objects.link(sun_light_object)

# Create camera (data-block + object, linked into the scene collection and
# set as the active render camera)
cam_data = bpy.data.cameras.new("Camera")
cam_obj = bpy.data.objects.new("Camera", cam_data)
bpy.context.collection.objects.link(cam_obj)
bpy.context.scene.camera = cam_obj

# Set render resolution
bpy.context.scene.render.resolution_x = 512
bpy.context.scene.render.resolution_y = 512

# Rendering parameters
angles = [0, 120, 240]  # degrees around Z axis
radius = 10     # camera distance from the origin in the XY plane
elevation = 5   # fixed camera height (Z)

# Orbit the camera around the object, one still render per angle.
for i, angle in enumerate(angles):
    rad = math.radians(angle)
    cam_x = radius * math.cos(rad)
    cam_y = radius * math.sin(rad)
    cam_z = elevation
    cam_obj.location = (cam_x, cam_y, cam_z)

    # Point camera to object center (0,0,0)
    # (track the camera's -Z view axis along `direction`, Y up)
    direction = -cam_obj.location
    rot_quat = direction.to_track_quat('-Z', 'Y')
    cam_obj.rotation_euler = rot_quat.to_euler()

    # Render
    # NOTE(review): output_dir is assumed to exist or be creatable by
    # Blender's render-output handling — confirm for the target version.
    bpy.context.scene.render.filepath = os.path.join(output_dir, f"render_{i}.png")
    bpy.ops.render.render(write_still=True)
Loading…
Add table
Add a link
Reference in a new issue