data generation scripts to make hugging face compatible dataset

This commit is contained in:
venkatacrc 2025-05-18 23:54:12 +00:00
parent 3fde5cbda8
commit 67a49f27b9
4 changed files with 251 additions and 0 deletions

View file

@ -0,0 +1,30 @@
import os
import shutil
# Path to your rendered_images directory
base_dir = "rendered_images"
# Path to your rendered_images directory
stl_dir = "selected_stls"
ds_stl_path = "dataset/stls/"
ds_img_path = "dataset/images/"
# List and loop through all subdirectories
for name in os.listdir(base_dir):
path = os.path.join(base_dir, name)
if os.path.isdir(path):
# print(f"Found directory: {name}")
stl_file_fpath = os.path.join(stl_dir, name)
stl_file_fpath += ".stl"
#print(stl_file_path)
ds_stl_fpath = os.path.join(ds_stl_path, name)
ds_stl_fpath += ".stl"
shutil.copy(stl_file_fpath, ds_stl_path)
base_img_fpath = path + "/render_0.png"
ds_img_fpath = os.path.join(ds_img_path, name)
ds_img_fpath += "_0001.png"
shutil.copy(base_img_fpath, ds_img_fpath)
#print(base_img_fpath, ds_img_fpath)

View file

@ -0,0 +1,91 @@
import os
import json
import trimesh
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
STL_DIR = "dataset/stls"
IMG_DIR = "dataset/images"
LABEL_FILE = "dataset/labels.json"
# Load BLIP for image captioning
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
# Load Mistral or other small LLM
llm_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
llm_model = AutoModelForCausalLM.from_pretrained(llm_model_name, torch_dtype=torch.float16, device_map="auto")
def extract_trimesh_features(mesh):
return {
"volume": mesh.volume,
"surface_area": mesh.area,
"bounding_box": mesh.bounding_box.extents.tolist(),
"num_faces": len(mesh.faces),
"num_vertices": len(mesh.vertices),
"is_watertight": mesh.is_watertight,
"euler_number": mesh.euler_number,
}
def caption_image(image_path):
raw_image = Image.open(image_path).convert("RGB")
inputs = blip_processor(raw_image, return_tensors="pt").to(device)
out = blip_model.generate(**inputs)
caption = blip_processor.decode(out[0], skip_special_tokens=True)
return caption
def generate_label_with_llm(features, caption):
prompt = f"""You are a 3D object classifier.
Mesh features:
{json.dumps(features, indent=2)}
Rendered Image Caption:
"{caption}"
Based on this information, return a short label (1-3 words) that describes the object.
Label:"""
inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
outputs = llm_model.generate(**inputs, max_new_tokens=10, do_sample=False)
output_text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
return output_text.split("Label:")[-1].strip()
def main():
labels = {}
for filename in os.listdir(STL_DIR):
if not filename.endswith(".stl"):
continue
stem = os.path.splitext(filename)[0]
stl_path = os.path.join(STL_DIR, filename)
img_path = os.path.join(IMG_DIR, f"{stem}_0001.png")
if not os.path.exists(img_path):
print(f"Missing image for {stem}, skipping...")
continue
try:
mesh = trimesh.load_mesh(stl_path)
features = extract_trimesh_features(mesh)
caption = caption_image(img_path)
label = generate_label_with_llm(features, caption)
labels[stem] = label
print(f"Labeled {stem}: {label}")
except Exception as e:
print(f"Error processing {stem}: {e}")
with open(LABEL_FILE, "w") as f:
json.dump(labels, f, indent=2)
print(f"\nSaved labels to {LABEL_FILE}")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,69 @@
import os
import json
import numpy as np
import trimesh
from datasets import Dataset, Features, Value, Image
from huggingface_hub import login
# Log in to HF Hub (optional if you've already done `huggingface-cli login`)
login(token=os.getenv("HF_TOKEN")) # Or replace with string token
# Paths
image_dir = "dataset/images"
stl_dir = "dataset/stls"
labels_path = "dataset/labels.json"
# Load labels
with open(labels_path, "r") as f:
labels = json.load(f)
# Build data entries
data = []
for image_filename in os.listdir(image_dir):
if not image_filename.endswith(".png"):
continue
image_path = os.path.join(image_dir, image_filename)
# Extract base ID
base_id = image_filename.split("_")[0]
stl_path = os.path.join(stl_dir, f"{base_id}.stl")
label = labels.get(base_id, "unknown")
# Load STL features (e.g., centroid + bounding box + volume as 9 floats)
stl_features = [0.0] * 9
if os.path.exists(stl_path):
try:
mesh = trimesh.load(stl_path, force="mesh")
bbox = mesh.bounding_box.extents
centroid = mesh.centroid
volume = mesh.volume
stl_features = list(centroid) + list(bbox) + [volume]
except Exception as e:
print(f"⚠️ Failed to process {stl_path}: {e}")
data.append({
"image": image_path,
"label": label,
"stl_features": stl_features,
"id": base_id,
})
# Define dataset schema
features = Features({
"id": Value("string"),
"image": Image(), # Load images from file paths
"label": Value("string"),
"stl_features": Value("string"), # Store as JSON string for simplicity
})
# Convert stl_features to JSON strings for compatibility
for item in data:
item["stl_features"] = json.dumps(item["stl_features"])
# Create Dataset
dataset = Dataset.from_list(data).cast(features)
# Push to Hub
dataset.push_to_hub("venkatacrc/stl-image-dataset", private=True)

View file

@ -0,0 +1,61 @@
import bpy
import sys
import math
import os
# Get args after --
argv = sys.argv
argv = argv[argv.index("--") + 1:] # args after --
input_stl = argv[0]
output_dir = argv[1]
# Clear existing objects
bpy.ops.object.select_all(action='SELECT')
bpy.ops.object.delete(use_global=False)
# Import STL
bpy.ops.import_mesh.stl(filepath=input_stl)
obj = bpy.context.selected_objects[0]
# Center the object at origin
bpy.ops.object.origin_set(type='ORIGIN_CENTER_OF_MASS', center='MEDIAN')
obj.location = (0, 0, 0)
# Add Sun light
sun_light_data = bpy.data.lights.new(name="SunLight", type='SUN')
sun_light_object = bpy.data.objects.new(name="SunLight", object_data=sun_light_data)
sun_light_object.location = (10, 10, 10)
bpy.context.collection.objects.link(sun_light_object)
# Create camera
cam_data = bpy.data.cameras.new("Camera")
cam_obj = bpy.data.objects.new("Camera", cam_data)
bpy.context.collection.objects.link(cam_obj)
bpy.context.scene.camera = cam_obj
# Set render resolution
bpy.context.scene.render.resolution_x = 512
bpy.context.scene.render.resolution_y = 512
# Rendering parameters
angles = [0, 120, 240] # degrees around Z axis
radius = 10
elevation = 5
for i, angle in enumerate(angles):
rad = math.radians(angle)
cam_x = radius * math.cos(rad)
cam_y = radius * math.sin(rad)
cam_z = elevation
cam_obj.location = (cam_x, cam_y, cam_z)
# Point camera to object center (0,0,0)
direction = -cam_obj.location
rot_quat = direction.to_track_quat('-Z', 'Y')
cam_obj.rotation_euler = rot_quat.to_euler()
# Render
bpy.context.scene.render.filepath = os.path.join(output_dir, f"render_{i}.png")
bpy.ops.render.render(write_still=True)