# atropos/environments/community/openvla_robotics/open_robot_env.py
# 2025-06-02 13:56:48 +02:00
#
# 245 lines
# 8.2 KiB
# Python

from typing import Dict, List, Optional, Tuple, Union
import robosuite as suite
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
)
from atroposlib.type_definitions import Item, number
from .action_tokenizer import ActionTokenizer
# Text instruction handed to the processor together with every camera frame
# in RobotSimEnv.collect_trajectories; it is the same for every step.
prompt = "In: What action should the robot take to pick up the cube?\nOut:"
class RobotSimEnv(BaseEnv):
    """Atropos environment that rolls out an OpenVLA policy in robosuite.

    Each call to ``collect_trajectories`` resets the simulator, queries the
    locally loaded VLA model for up to ``self.max_steps`` actions and scores
    every step with the reward returned by the simulator.
    """

    name = "robot_sim"

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        """Forward the configuration to ``BaseEnv`` and create metric buffers."""
        super().__init__(config, server_configs, slurm, testing)
        self.percent_correct_buffer = []
        self.eval_metrics = []
        # Add tracking for wandb visualizations
        self.rollouts_for_wandb = []
        self.completion_lengths = []

    # Many methods here will have to be removed
    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Return the default environment config and API server configs."""
        # vllm support for VLA models is a bit iffy. Therefore we don't use vllm.
        # Ideally no model would be used below, but empty/None values throw error
        # therefore placeholder tokenizer_name and model_name is passed.
        # These are not needed for the VLA to run.
        env_config = BaseEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            max_token_length=2048,
            # NOTE(review): "gsm8k" looks like a leftover from the environment
            # this file was adapted from - confirm the intended run name.
            wandb_name="gsm8k",
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            ),
        ]
        return env_config, server_configs

    # Not sure what to do with this
    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Add buffered train/eval metrics to ``wandb_metrics`` and log them.

        Both buffers are cleared afterwards. ``train/percent_correct`` is only
        reported when the buffer holds at least one value.
        """
        if wandb_metrics is None:
            wandb_metrics = {}
        # An empty buffer would divide by zero; skip the metric in that case.
        if self.percent_correct_buffer:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
        self.percent_correct_buffer = []
        for item in self.eval_metrics:
            wandb_metrics[item[0]] = item[1]
        self.eval_metrics = []
        # Call the parent method to handle the server metrics
        await super().wandb_log(wandb_metrics)

    # There is NO test data for this environment
    async def setup(self):
        """Create the robosuite task and load the OpenVLA model and processor."""
        self.local_model_path = "openvla/openvla-7b"
        self.robosuite_env = suite.make(
            "Lift",  # This can be changed to any other environment like NutAssemblySquare
            robots="Panda",
            has_renderer=False,
            has_offscreen_renderer=True,
            use_camera_obs=True,
            camera_names="frontview",
            camera_heights=640,
            camera_widths=480,
        )
        self.processor = AutoProcessor.from_pretrained(
            self.local_model_path, trust_remote_code=True
        )
        self.vla = AutoModelForVision2Seq.from_pretrained(
            self.local_model_path,
            # attn_implementation="flash_attention_2", #Removed for now. #TODO: Add this
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        ).to("cuda:0")
        self.action_tokenizer = ActionTokenizer(
            self.processor.tokenizer, bins=256, min_action=-1, max_action=1
        )
        # Upper bound on simulator steps taken per collected trajectory.
        self.max_steps = 100
        self.iter = 0

    def save_checkpoint(self, step, data=None):
        """Persist the iteration counter alongside the parent's checkpoint data."""
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    # See where this is called from and then change. No eval data so commenting out for now.
    async def rollout_and_score_eval(self, question: str, answer: str) -> number:
        """Placeholder: there is no evaluation data, so nothing is scored."""
        pass

    async def evaluate(self, *args, **kwargs):
        """No-op: no evaluation data is available for this environment yet."""
        return None

    @staticmethod
    def _to_robosuite_action(action):
        """Rewrite a predicted action in place into the form sent to robosuite.

        The first three components are scaled by 100 and the third is negated;
        the last (gripper) component is mapped from [0, 1] to [-1, 1] and then
        negated. Returns the same array object for convenience.
        """
        action[..., 0:3] = action[..., 0:3] * 100  # because the sensitivity is 0.01
        action[..., 2] = action[..., 2] * -1.0
        # Gripper: OpenVLA emits [0, 1] while robosuite expects [-1, 1].
        orig_low, orig_high = 0.0, 1.0
        action[..., -1] = (
            2 * (action[..., -1] - orig_low) / (orig_high - orig_low) - 1
        )
        action[..., -1] = action[..., -1] * -1.0
        return action

    async def collect_trajectories(
        self, item: bool
    ) -> Tuple[ScoredDataGroup, List[Item]]:
        """Roll out the VLA policy for one episode and score every step.

        Args:
            item: Value produced by ``get_next_item``; it is not used here.

        Returns:
            The scored group built by ``score`` and an (always empty) backlog.
        """
        obs = self.robosuite_env.reset()
        to_score = []
        to_backlog = []
        print("TRYING TO COLLECT TRAJECTORIES")
        print(self.max_steps)
        for _ in range(self.max_steps):
            # NOTE(review): robosuite camera frames may be vertically flipped
            # depending on its image-convention setting - confirm orientation.
            image = Image.fromarray(obs["frontview_image"])
            inputs = self.processor(prompt, image).to("cuda:0", dtype=torch.bfloat16)
            action = self.vla.predict_action(
                **inputs, unnorm_key="bridge_orig", do_sample=False
            )
            action = self._to_robosuite_action(action)
            obs, reward, done, _info = self.robosuite_env.step(action)
            print(action)
            # The key must be "action": that is what ``score`` looks up.
            to_score.append({"action": action, "reward": reward})
            if done:
                # Stepping a terminated episode is an error in robosuite.
                break
        to_postprocess = await self.score(to_score)
        return to_postprocess, to_backlog

    async def score(
        self, rollout_group_data
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        """Build a ``ScoredDataGroup`` from per-step action/reward records.

        Args:
            rollout_group_data: Iterable of dicts with ``"action"`` and
                ``"reward"`` keys, one per simulator step.

        Returns:
            A group whose ``tokens``/``masks`` come from the action tokenizer
            and whose ``scores`` are the step rewards.
        """
        scores = ScoredDataGroup()
        scores["tokens"] = []
        scores["masks"] = []
        scores["scores"] = []
        for step in rollout_group_data:
            # presumably tokenize() returns a (tokens, mask) pair - verify
            # against ActionTokenizer.
            action_tokens, mask = self.action_tokenizer.tokenize(step["action"])
            scores["tokens"].append(action_tokens)
            scores["masks"].append(mask)
            scores["scores"].append(step["reward"])
        return scores

    # What should this return? A full trajectory of the robot consisting of N steps
    async def get_next_item(self) -> bool:
        """Advance the iteration counter and return a dummy item (``True``)."""
        # There is no next item, only a reset (or a random seed).
        # See how we can take a random seed. For now just reset.
        # Another limitation is that right now it has to be 1 by 1.
        # It can't do multiple of this. It has to wait for one to finish.
        self.iter += 1
        return True
if __name__ == "__main__":
    # Run the environment's command-line entry point when executed as a script.
    RobotSimEnv.cli()