atropos/environments/community/physical_space_stl/llm_label.py
2025-05-27 08:53:06 +10:00

104 lines
2.9 KiB
Python

import json
import os
import torch
import trimesh
from PIL import Image
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BlipForConditionalGeneration,
BlipProcessor,
)
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Dataset layout: STL meshes are read from STL_DIR, their rendered images
# (named "<stem>_0001.png") from IMG_DIR, and the resulting stem -> label
# mapping is written to LABEL_FILE.
STL_DIR = "dataset/stls"
IMG_DIR = "dataset/images"
LABEL_FILE = "dataset/labels.json"

# BLIP image-captioning model: the processor and the weights come from the
# same checkpoint, and the model is moved onto the selected device.
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
blip_processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT).to(
    device
)

# Instruction-tuned LLM (Mistral by default; change llm_model_name to use a
# different checkpoint). Weights are loaded in half precision and placement
# is left to device_map="auto".
llm_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
llm_model = AutoModelForCausalLM.from_pretrained(
    llm_model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)
def extract_trimesh_features(mesh):
    """Summarize a mesh as a dictionary of built-in Python values.

    Scalar attributes are coerced to ``float``/``int``/``bool`` so that the
    result can always be passed to ``json.dumps``, even when the mesh object
    hands back NumPy scalars (which the ``json`` module refuses to encode).

    Args:
        mesh: Object exposing ``volume``, ``area``, ``bounding_box.extents``
            (with a ``tolist()`` method), ``faces``, ``vertices``,
            ``is_watertight`` and ``euler_number``.

    Returns:
        dict with the keys ``volume``, ``surface_area``, ``bounding_box``
        (list of extents), ``num_faces``, ``num_vertices``,
        ``is_watertight`` and ``euler_number``.
    """
    return {
        "volume": float(mesh.volume),
        "surface_area": float(mesh.area),
        # tolist() already converts the array elements to built-in numbers.
        "bounding_box": mesh.bounding_box.extents.tolist(),
        "num_faces": len(mesh.faces),
        "num_vertices": len(mesh.vertices),
        "is_watertight": bool(mesh.is_watertight),
        "euler_number": int(mesh.euler_number),
    }
def caption_image(image_path):
    """Return the BLIP-generated caption for the image at ``image_path``.

    The image is opened with PIL and converted to RGB, preprocessed by the
    module-level ``blip_processor``, and captioned by ``blip_model`` on the
    module-level ``device``.

    Args:
        image_path: Path of the image file to caption.

    Returns:
        The decoded caption string with special tokens removed.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    model_inputs = blip_processor(rgb_image, return_tensors="pt").to(device)
    generated_ids = blip_model.generate(**model_inputs)
    # Only the first (and only) sequence in the batch is decoded.
    return blip_processor.decode(generated_ids[0], skip_special_tokens=True)
def generate_label_with_llm(features, caption):
    """Ask the LLM for a short object label given mesh features and a caption.

    Args:
        features: JSON-serializable mapping of mesh statistics; it is
            embedded in the prompt via ``json.dumps``.
        caption: Caption text describing the rendered image of the mesh.

    Returns:
        The first non-empty line of the model's completion, stripped of
        surrounding whitespace, or ``""`` if the model produced no text.
    """
    prompt = f"""You are a 3D object classifier.
Mesh features:
{json.dumps(features, indent=2)}
Rendered Image Caption:
"{caption}"
Based on this information, return a short label (1-3 words) that describes the object.
Label:"""
    # The model is loaded with device_map="auto", so it does not necessarily
    # live on the module-level `device`; send the inputs to wherever the
    # model's first parameters actually are.
    inputs = llm_tokenizer(prompt, return_tensors="pt").to(llm_model.device)
    outputs = llm_model.generate(**inputs, max_new_tokens=10, do_sample=False)

    # For a causal LM the returned sequence starts with the prompt tokens.
    # Decode only the continuation instead of re-finding "Label:" in the
    # echoed prompt text.
    prompt_length = inputs["input_ids"].shape[1]
    completion = llm_tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )

    # The model often keeps writing after the label (explanations, a new
    # section, ...); keep just the first line that has any content.
    for line in completion.splitlines():
        candidate = line.strip()
        if candidate:
            return candidate
    return ""
def main():
    """Label every STL mesh in ``STL_DIR`` and save the results as JSON.

    For each ``*.stl`` file (extension matched case-insensitively) that has a
    rendered image ``<stem>_0001.png`` in ``IMG_DIR``, the mesh features and
    the image caption are fed to the LLM and the returned label is stored
    under the file's stem. Files without an image are skipped with a message.
    A failure on one file is reported and does not stop the run. The
    collected mapping is written to ``LABEL_FILE`` once all files have been
    visited.
    """
    labels = {}
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # processing order and the layout of the output file reproducible.
    for filename in sorted(os.listdir(STL_DIR)):
        # Accept ".STL" / ".Stl" as well as ".stl".
        if not filename.lower().endswith(".stl"):
            continue

        stem = os.path.splitext(filename)[0]
        stl_path = os.path.join(STL_DIR, filename)
        img_path = os.path.join(IMG_DIR, f"{stem}_0001.png")
        if not os.path.exists(img_path):
            print(f"Missing image for {stem}, skipping...")
            continue

        # Best-effort batch job: any per-file failure (unreadable mesh,
        # captioning or generation error) is reported and the loop moves on.
        try:
            mesh = trimesh.load_mesh(stl_path)
            features = extract_trimesh_features(mesh)
            caption = caption_image(img_path)
            label = generate_label_with_llm(features, caption)
            labels[stem] = label
            print(f"Labeled {stem}: {label}")
        except Exception as e:
            print(f"Error processing {stem}: {e}")

    with open(LABEL_FILE, "w", encoding="utf-8") as f:
        json.dump(labels, f, indent=2)
    print(f"\nSaved labels to {LABEL_FILE}")


# Run the labeling pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()