base changes

This commit is contained in:
Alexander Speicher 2025-05-18 15:08:08 -07:00 committed by tsadpbb
parent 5351bcccea
commit 4e7583be43
3 changed files with 53 additions and 81 deletions

View file

@ -0,0 +1 @@
# Placeholder container for the doctor-agent samples; it is empty for now.
# NOTE(review): the environment indexes `dataset[0]`, which raises IndexError
# on an empty list -- populate this before running rollouts.
dataset = []

View file

@ -1,79 +1,15 @@
from typing import Dict, List, Optional, Tuple
import gymnasium as gym
from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem
from atroposlib.type_definitions import Item
from atropos.environments.hack0.doctor_agent.patient import patient_profiles
from atropos.environments.hack0.doctor_agent.datasets import dataset
start_msg = """### Description
There are four designated locations in the grid world indicated by R(ed),
G(reen), Y(ellow), and B(lue). When the episode starts, the taxi starts off
at a random square and the passenger is at a random location. The taxi
drives to the passenger's location, picks up the passenger, drives to the
passenger's destination (another one of the four specified locations), and
then drops off the passenger. Once the passenger is dropped off, the episode ends.
You are a doctor tasked to diagnose a patient symptoms. Your task is to ask the patient enough questions until you are confident about your answer
Map:
+---------+
|R: | : :G|
| : | : : |
| : : : : |
| | : | : |
|Y| : |B: |
+---------+
### Actions
There are 6 discrete deterministic actions:
- 0: move south (increases row index)
- 1: move north (decreases row index)
- 2: move east (increases column index)
- 3: move west (decreases column index)
- 4: pickup passenger (IF on a letter location, AND passenger is located at the same location, pickup passenger)
- 5: drop off passenger
### Observations
Passenger locations:
- 0: R(ed)
- 1: G(reen)
- 2: Y(ellow)
- 3: B(lue)
- 4: in taxi
Destinations:
- 0: R(ed) (Row 0, Col 0)
- 1: G(reen) (Row 4, Col 4)
- 2: Y(ellow) (Row 0, Col 4)
- 3: B(lue) (Row 3, Col 3)
### Instructions
Please perform the actions that will let you pick up and/or drop off the passenger.
Please respond with the action number only.
You cannot move the taxi into walls, which are displayed as | in the map. : means you are free to move through that column.
For an example, if the passenger is at R, and the destination is G, and the taxi is at (2, 2), then here are the following actions to solve this in the correct order:
3 (move west)
3 (move west)
1 (move north)
1 (move north)
4 (pickup passenger)
0 (move south)
0 (move south)
2 (move east)
2 (move east)
2 (move east)
2 (move east)
0 (move south)
0 (move south)
5 (drop off passenger)
If you are stuck, try moving to row idx 2, as there are no walls there.
Submit your response as a number between 0 and 5 only to perform the discrete action.
Each turn we will give you the current state of the environment, and you will need to respond with the action number only from the available actions.""" # noqa: E501
When you are confident about the illness/disease the patient has respond with. The diagnosis is {illness}
""" # noqa: E501
def decode(i):
@ -156,10 +92,12 @@ def state_render_to_user_msg(last_state, state, action_mask, render):
ret_str += f"\n\nPlease move the taxi to {TO_LOC_MAP[pass_idx]} to pick up the passenger."
return ret_str
model = "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
name = "gym_doctor"
class GymTaxiEnv(BaseEnv):
class GymDoctorEnv(BaseEnv):
name = "gym_taxi"
name = "gym_doctor"
def __init__(
self,
@ -180,16 +118,16 @@ class GymTaxiEnv(BaseEnv):
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
env_config = BaseEnvConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
tokenizer_name=model,
group_size=32,
use_wandb=True,
rollout_server_url="http://localhost:8000",
max_token_length=8192,
wandb_name="gym_taxi",
wandb_name=name,
)
server_configs = [
APIServerConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
model_name=model,
base_url="http://localhost:9001/v1",
api_key="x",
num_requests_for_eval=256,
@ -237,13 +175,32 @@ class GymTaxiEnv(BaseEnv):
) -> Tuple[Optional[ScoredDataItem], List[Item]]:
# Grab a dedicated llm server to take advantage of caching
async with self.server.dedicated_server() as server:
env = gym.make("Taxi-v3", render_mode="ansi")
state, info = env.reset(seed=item["seed"])
# env = gym.make(name, render_mode="ansi")
# state, info = env.reset(seed=item["seed"]) #FIXME:
last_state = None
taxi_row, taxi_col, pass_idx, dest_idx = decode(state)
init_msg = f"{start_msg}\n\n" + state_render_to_user_msg(
last_state, state, info["action_mask"], env.render()
)
patient_state = []
patient_profile = random.choice(patient_profiles)
symptoms = dataset[0]
doctor_state = []
# taxi_row, taxi_col, pass_idx, dest_idx = decode(state)
init_msg
# init_msg = f"{start_msg}\n\n" + state_render_to_user_msg(
# last_state, state, info["action_mask"], env.render()
# )
messages = [{"role": "user", "content": init_msg}]
score = -1
while True:
@ -328,4 +285,4 @@ class GymTaxiEnv(BaseEnv):
if __name__ == "__main__":
GymTaxiEnv.cli()
GymDoctorEnv.cli()

View file

@ -0,0 +1,14 @@
"""Patient-side resources for the doctor-agent environment.

Exposes:
    patient_profiles: list of patient personas (currently empty).
    client: an OpenAI-SDK client pointed at the xAI endpoint.
"""
import os

from openai import OpenAI

# Patient personas the environment samples from. Still a placeholder.
# NOTE(review): the caller appears to use random.choice() on this list, which
# raises IndexError while it is empty -- confirm and populate.
patient_profiles = []

# The original code referenced XAI_API_KEY without defining it, so importing
# this module raised NameError. Read the secret from the environment instead
# of hard-coding it; a missing variable fails fast with a KeyError that names
# the variable.
XAI_API_KEY = os.environ["XAI_API_KEY"]

# xAI is reached through the OpenAI SDK by overriding the base URL.
client = OpenAI(
    api_key=XAI_API_KEY,
    base_url="https://api.x.ai/v1",
)