mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
Added openVLA environment for robotics
This commit is contained in:
parent c189fc3351
commit aa110fcbbe
2 changed files with 309 additions and 0 deletions
72
environments/hack0/action_tokenizer.py
Normal file
@@ -0,0 +1,72 @@
"""
|
||||
action_tokenizer.py
|
||||
|
||||
Extension class; wraps base LLM/VLM tokenizer with logic to discretize and tokenize continuous robot actions.
|
||||
"""
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
|
||||
class ActionTokenizer:
|
||||
def __init__(
|
||||
self, tokenizer: PreTrainedTokenizerBase, bins: int = 256, min_action: int = -1, max_action: int = 1
|
||||
) -> None:
|
||||
"""
|
||||
Discretizes continuous robot actions into N bins per dimension and maps to the least used tokens.
|
||||
|
||||
NOTE =>> by default, assumes a BPE-style tokenizer akin to the LlamaTokenizer, where *the least used tokens*
|
||||
appear at the end of the vocabulary!
|
||||
|
||||
:param tokenizer: Base LLM/VLM tokenizer to extend.
|
||||
:param bins: Number of bins for each continuous value; we'll adopt a uniform binning strategy.
|
||||
:param min_action: Minimum action value (for clipping, setting lower bound on bin interval).
|
||||
:param max_action: Maximum action value (for clipping, setting upper bound on bin interval).
|
||||
"""
|
||||
self.tokenizer, self.n_bins, self.min_action, self.max_action = tokenizer, bins, min_action, max_action
|
||||
|
||||
# Create Uniform Bins + Compute Bin Centers
|
||||
self.bins = np.linspace(min_action, max_action, self.n_bins)
|
||||
self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
|
||||
|
||||
# [Contract] Set "action_token_begin_idx" based on `self.tokenizer.vocab_size - (self.n_bins + 1)`
|
||||
# =>> Assumes we're always overwriting the final `n_bins` tokens of the vocabulary!
|
||||
self.action_token_begin_idx: int = int(self.tokenizer.vocab_size - (self.n_bins + 1))
|
||||
|
||||
def __call__(self, action: np.ndarray) -> Union[str, List[str]]:
|
||||
"""Clip & bin actions to *the last `n_bins` tokens* of the vocabulary (e.g., tokenizer.vocab[-256:])."""
|
||||
action = np.clip(action, a_min=float(self.min_action), a_max=float(self.max_action))
|
||||
discretized_action = np.digitize(action, self.bins)
|
||||
|
||||
# Handle single element vs. batch
|
||||
if len(discretized_action.shape) == 1:
|
||||
return self.tokenizer.decode(list(self.tokenizer.vocab_size - discretized_action))
|
||||
else:
|
||||
return self.tokenizer.batch_decode((self.tokenizer.vocab_size - discretized_action).tolist())
|
||||
|
||||
def decode_token_ids_to_actions(self, action_token_ids: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Returns continuous actions for discrete action token IDs.
|
||||
|
||||
NOTE =>> Because of the way the actions are discretized w.r.t. the bins (and not the bin centers), the
|
||||
digitization returns bin indices between [1, # bins], inclusive, when there are actually only
|
||||
(# bins - 1) bin intervals.
|
||||
|
||||
Therefore, if the digitization returns the last possible index, we map this to the last bin interval.
|
||||
|
||||
EXAMPLE =>> Let's say self._bins has 256 values. Then self._bin_centers has 255 values. Digitization returns
|
||||
indices between [1, 256]. We subtract 1 from all indices so that they are between [0, 255]. There
|
||||
is still one index (i==255) that would cause an out-of-bounds error if used to index into
|
||||
self._bin_centers. Therefore, if i==255, we subtract 1 from it so that it just becomes the index of
|
||||
the last bin center. We implement this simply via clipping between [0, 255 - 1].
|
||||
"""
|
||||
discretized_actions = self.tokenizer.vocab_size - action_token_ids
|
||||
discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
|
||||
|
||||
return self.bin_centers[discretized_actions]
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self.n_bins
|
||||
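For orientation, here is a minimal round-trip sketch of the tokenizer above (not part of the commit; it assumes a Llama-style BPE tokenizer whose least-used tokens sit at the end of the vocabulary, and reuses the tokenizer name that appears in open_robot_env.py):

    import numpy as np
    from transformers import AutoTokenizer

    from action_tokenizer import ActionTokenizer

    # Assumed checkpoint; any Llama-style tokenizer should behave the same way here.
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/DeepHermes-3-Llama-3-3B-Preview")
    action_tokenizer = ActionTokenizer(tokenizer, bins=256, min_action=-1, max_action=1)

    # A 7-DoF action (6 pose deltas + gripper), as emitted by OpenVLA-style policies.
    action = np.array([0.03, -0.12, 0.55, 0.0, -0.9, 0.2, 1.0])

    # Continuous values -> token IDs, mirroring the arithmetic in __call__.
    token_ids = tokenizer.vocab_size - np.digitize(np.clip(action, -1.0, 1.0), action_tokenizer.bins)

    # Token IDs -> bin centers; the reconstruction error is at most half a bin width,
    # i.e. (max_action - min_action) / (2 * (bins - 1)) ~= 0.004 here.
    recovered = action_tokenizer.decode_token_ids_to_actions(token_ids)
    print(np.abs(recovered - action).max())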
237
environments/hack0/open_robot_env.py
Normal file
@@ -0,0 +1,237 @@
import random
from typing import Dict, List, Optional, Tuple, TypedDict, Union

import numpy as np  # used by score() to discretize actions
from tqdm.asyncio import tqdm_asyncio
import robosuite as suite
from robosuite.controllers import load_part_controller_config
from transformers import AutoModelForVision2Seq, AutoProcessor
import torch
from PIL import Image

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
from atroposlib.type_definitions import Item, number
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from action_tokenizer import ActionTokenizer

prompt = "In: What action should the robot take to pick up the cube?\nOut:"


class RobotSimEnv(BaseEnv):

    name = "robot_sim"

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.percent_correct_buffer = list()
        self.eval_metrics = list()
        # Add tracking for wandb visualizations
        self.rollouts_for_wandb = []
        self.completion_lengths = []

    #! Many methods here will have to be removed
    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        #! vLLM support for VLA models is a bit iffy, so we don't use vLLM here. Ideally no model would be
        #! configured below, but empty/None values throw an error, so placeholder tokenizer_name and
        #! model_name values are passed. They are not needed for the VLA to run.
        env_config = BaseEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            max_token_length=2048,
            wandb_name="gsm8k",  # leftover placeholder name from the GSM8K template
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            ),
        ]

        return env_config, server_configs

    #! Not sure what to do with this
    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        if wandb_metrics is None:
            wandb_metrics = {}

        # Try to calculate percent_correct; pass if there's a division by zero
        try:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
        except ZeroDivisionError:
            # Skip if buffer is empty
            pass

        self.percent_correct_buffer = list()
        for item in self.eval_metrics:
            wandb_metrics[item[0]] = item[1]
        self.eval_metrics = list()
        # Call the parent method to handle the server metrics
        await super().wandb_log(wandb_metrics)

    #! There is NO test data for this environment
    async def setup(self):

        self.local_model_path = "openvla/openvla-7b"

        self.robosuite_env = suite.make(
            "Lift",  # This can be changed to any other environment, e.g. NutAssemblySquare
            robots="Panda",
            has_renderer=False,
            has_offscreen_renderer=True,
            use_camera_obs=True,
            camera_names="frontview",
            camera_heights=640,
            camera_widths=480,
        )

        self.processor = AutoProcessor.from_pretrained(self.local_model_path, trust_remote_code=True)

        self.vla = AutoModelForVision2Seq.from_pretrained(
            self.local_model_path,
            # attn_implementation="flash_attention_2",  # Removed for now. TODO: Add this
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        ).to("cuda:0")

        self.action_tokenizer = ActionTokenizer(self.processor.tokenizer, bins=256, min_action=-1, max_action=1)

        self.max_steps = 100

        self.iter = 0

    def save_checkpoint(self, step, data=None):
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    #! See where this is called from and then change. There is no eval data, so this is a no-op for now.
    async def rollout_and_score_eval(self, question: str, answer: str) -> number:
        pass

    async def evaluate(self, *args, **kwargs):
        eval_tasks = []
        # for item in self.test:
        #     eval_tasks.append(
        #         self.rollout_and_score_eval(item["question"], item["gold_answer"])
        #     )
        # scores = await tqdm_asyncio.gather(*eval_tasks)
        # self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))
        return None

    async def collect_trajectories(
        self, item: bool
    ) -> Tuple[ScoredDataGroup, list[Item]]:

        obs = self.robosuite_env.reset()

        to_score = list()
        to_backlog = list()

        print("TRYING TO COLLECT TRAJECTORIES")

        print(self.max_steps)

        for i in range(self.max_steps):

            image = Image.fromarray(obs["frontview_image"])

            inputs = self.processor(prompt, image).to("cuda:0", dtype=torch.bfloat16)
            action = self.vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)

            #! The adjustments below map OpenVLA's action conventions onto robosuite's; revisit these.
            action[0:3] = action[0:3] * 100  # because the OSC position sensitivity is 0.01
            action[..., 2] = action[..., 2] * -1.0
            # action[..., -1] = np.sign(action[..., -1])

            # Gripper: OpenVLA emits [0, 1]; rescale to robosuite's [-1, 1] (sign flipped below).
            orig_low, orig_high = 0.0, 1.0
            action[..., -1] = 2 * (action[..., -1] - orig_low) / (orig_high - orig_low) - 1
            # action[..., -1] = np.sign(action[..., -1])
            action[..., -1] = action[..., -1] * -1.0

            obs, reward, done, info = self.robosuite_env.step(action)

            print(action)

            to_score.append({"action": action, "reward": reward})

        # Leftover GSM8K-template rollout code, kept for reference:
        # user_message = {"role": "user", "content": item["question"]}
        # gold_answer = (
        #     "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}"
        # )

        # chat_completions = await self.server.chat_completion(
        #     messages=[{"role": "system", "content": system_prompt}, user_message],
        #     n=self.config.group_size,
        #     max_tokens=self.config.max_token_length,
        # )

        # for i, chat_completion in enumerate(chat_completions.choices):
        #     messages = (
        #         {"role": "system", "content": system_prompt},
        #         user_message,
        #         {"role": "assistant", "content": chat_completion.message.content},
        #     )
        #     to_score.append(
        #         {
        #             "messages": messages,
        #             "gold_answer": gold_answer,
        #             "finish_reason": chat_completion.finish_reason,
        #         }
        #     )

        to_postprocess = await self.score(to_score)
        return to_postprocess, to_backlog

    async def score(
        self, rollout_group_data
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        scores = ScoredDataGroup()
        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()

        for item in rollout_group_data:
            # ActionTokenizer has no `tokenize` method, so reproduce its clip/digitize mapping here to get
            # the action token IDs (the last `n_bins` vocabulary entries) and mask in every action token.
            discretized = np.digitize(
                np.clip(item["action"], self.action_tokenizer.min_action, self.action_tokenizer.max_action),
                self.action_tokenizer.bins,
            )
            action_tokens = (self.processor.tokenizer.vocab_size - discretized).tolist()
            mask = [1] * len(action_tokens)
            scores["tokens"].append(action_tokens)
            scores["masks"].append(mask)
            scores["scores"].append(item["reward"])
        return scores


    #! What should this return? A full trajectory of the robot consisting of N steps.
    async def get_next_item(self) -> bool:
        #! There is no next item, only a reset (or a random seed). See how we can take a random seed. For now, just reset.

        #! Another limitation: rollouts currently run one at a time; we can't run several in parallel and have to wait for each to finish.

        self.iter += 1
        return True


if __name__ == "__main__":
    RobotSimEnv.cli()
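The OpenVLA-to-robosuite remapping inside collect_trajectories is easy to get backwards, so here is a small self-contained sanity check of that logic (a sketch, not part of the commit; remap_openvla_to_robosuite is a hypothetical helper that mirrors the in-loop code):

    import numpy as np

    def remap_openvla_to_robosuite(action: np.ndarray) -> np.ndarray:
        # Hypothetical helper duplicating the per-step remapping in collect_trajectories.
        action = action.copy()
        action[0:3] = action[0:3] * 100         # OSC position sensitivity is 0.01
        action[..., 2] = action[..., 2] * -1.0  # flip the z delta
        # Gripper: OpenVLA emits [0, 1]; rescale to robosuite's [-1, 1], then flip the sign.
        action[..., -1] = (2.0 * action[..., -1] - 1.0) * -1.0
        return action

    a = np.array([0.001, -0.002, 0.003, 0.0, 0.0, 0.0, 1.0])
    print(remap_openvla_to_robosuite(a))  # -> [0.1, -0.2, -0.3, 0.0, 0.0, 0.0, -1.0]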