mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
base changes
This commit is contained in:
parent
5351bcccea
commit
4e7583be43
3 changed files with 53 additions and 81 deletions
|
|
@ -0,0 +1 @@
|
|||
dataset = []
|
||||
|
|
@ -1,79 +1,15 @@
|
|||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem
|
||||
from atroposlib.type_definitions import Item
|
||||
from atropos.environments.hack0.doctor_agent.patient import patient_profiles
|
||||
from atropos.environments.hack0.doctor_agent.datasets import dataset
|
||||
|
||||
start_msg = """### Description
|
||||
There are four designated locations in the grid world indicated by R(ed),
|
||||
G(reen), Y(ellow), and B(lue). When the episode starts, the taxi starts off
|
||||
at a random square and the passenger is at a random location. The taxi
|
||||
drives to the passenger's location, picks up the passenger, drives to the
|
||||
passenger's destination (another one of the four specified locations), and
|
||||
then drops off the passenger. Once the passenger is dropped off, the episode ends.
|
||||
You are a doctor tasked to diagnose a patient symptoms. Your task is to ask the patient enough questions until you are confident about your answer
|
||||
|
||||
Map:
|
||||
|
||||
+---------+
|
||||
|R: | : :G|
|
||||
| : | : : |
|
||||
| : : : : |
|
||||
| | : | : |
|
||||
|Y| : |B: |
|
||||
+---------+
|
||||
|
||||
### Actions
|
||||
There are 6 discrete deterministic actions:
|
||||
- 0: move south (increases row index)
|
||||
- 1: move north (decreases row index)
|
||||
- 2: move east (increases column index)
|
||||
- 3: move west (decreases column index)
|
||||
- 4: pickup passenger (IF on a letter location, AND passenger is located at the same location, pickup passenger)
|
||||
- 5: drop off passenger
|
||||
|
||||
### Observations
|
||||
|
||||
Passenger locations:
|
||||
- 0: R(ed)
|
||||
- 1: G(reen)
|
||||
- 2: Y(ellow)
|
||||
- 3: B(lue)
|
||||
- 4: in taxi
|
||||
|
||||
Destinations:
|
||||
- 0: R(ed) (Row 0, Col 0)
|
||||
- 1: G(reen) (Row 4, Col 4)
|
||||
- 2: Y(ellow) (Row 0, Col 4)
|
||||
- 3: B(lue) (Row 3, Col 3)
|
||||
|
||||
### Instructions
|
||||
Please perform the actions that will let you pick up and/or drop off the passenger.
|
||||
Please respond with the action number only.
|
||||
You cannot move the taxi into walls, which are displayed as | in the map. : means you are free to move through that column.
|
||||
|
||||
|
||||
For an example, if the passenger is at R, and the destination is G, and the taxi is at (2, 2), then here are the following actions to solve this in the correct order:
|
||||
|
||||
3 (move west)
|
||||
3 (move west)
|
||||
1 (move north)
|
||||
1 (move north)
|
||||
4 (pickup passenger)
|
||||
0 (move south)
|
||||
0 (move south)
|
||||
2 (move east)
|
||||
2 (move east)
|
||||
2 (move east)
|
||||
2 (move east)
|
||||
0 (move south)
|
||||
0 (move south)
|
||||
5 (drop off passenger)
|
||||
|
||||
If you are stuck, try moving to row idx 2, as there are no walls there.
|
||||
|
||||
Submit your response as a number between 0 and 5 only to perform the discrete action.
|
||||
Each turn we will give you the current state of the environment, and you will need to respond with the action number only from the available actions.""" # noqa: E501
|
||||
When you are confident about the illness/disease the patient has respond with. The diagnosis is {illness}
|
||||
""" # noqa: E501
|
||||
|
||||
|
||||
def decode(i):
|
||||
|
|
@ -156,10 +92,12 @@ def state_render_to_user_msg(last_state, state, action_mask, render):
|
|||
ret_str += f"\n\nPlease move the taxi to {TO_LOC_MAP[pass_idx]} to pick up the passenger."
|
||||
return ret_str
|
||||
|
||||
model = "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
|
||||
name = "gym_doctor"
|
||||
|
||||
class GymTaxiEnv(BaseEnv):
|
||||
class GymDoctorEnv(BaseEnv):
|
||||
|
||||
name = "gym_taxi"
|
||||
name = "gym_doctor"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -180,16 +118,16 @@ class GymTaxiEnv(BaseEnv):
|
|||
@classmethod
|
||||
def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
|
||||
env_config = BaseEnvConfig(
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
||||
tokenizer_name=model,
|
||||
group_size=32,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
max_token_length=8192,
|
||||
wandb_name="gym_taxi",
|
||||
wandb_name=name,
|
||||
)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
||||
model_name=model,
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
|
|
@ -237,13 +175,32 @@ class GymTaxiEnv(BaseEnv):
|
|||
) -> Tuple[Optional[ScoredDataItem], List[Item]]:
|
||||
# Grab a dedicated llm server to take advantage of caching
|
||||
async with self.server.dedicated_server() as server:
|
||||
env = gym.make("Taxi-v3", render_mode="ansi")
|
||||
state, info = env.reset(seed=item["seed"])
|
||||
# env = gym.make(name, render_mode="ansi")
|
||||
# state, info = env.reset(seed=item["seed"]) #FIXME:
|
||||
last_state = None
|
||||
taxi_row, taxi_col, pass_idx, dest_idx = decode(state)
|
||||
init_msg = f"{start_msg}\n\n" + state_render_to_user_msg(
|
||||
last_state, state, info["action_mask"], env.render()
|
||||
)
|
||||
patient_state = []
|
||||
|
||||
patient_profile = random.choice(patient_profiles)
|
||||
|
||||
symptoms = dataset[0]
|
||||
|
||||
|
||||
|
||||
|
||||
doctor_state = []
|
||||
|
||||
|
||||
# taxi_row, taxi_col, pass_idx, dest_idx = decode(state)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
init_msg
|
||||
|
||||
# init_msg = f"{start_msg}\n\n" + state_render_to_user_msg(
|
||||
# last_state, state, info["action_mask"], env.render()
|
||||
# )
|
||||
messages = [{"role": "user", "content": init_msg}]
|
||||
score = -1
|
||||
while True:
|
||||
|
|
@ -328,4 +285,4 @@ class GymTaxiEnv(BaseEnv):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GymTaxiEnv.cli()
|
||||
GymDoctorEnv.cli()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,14 @@
|
|||
|
||||
|
||||
patient_profiles = [
|
||||
|
||||
]
|
||||
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
api_key=XAI_API_KEY,
|
||||
base_url="https://api.x.ai/v1",
|
||||
)
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue