diff --git a/.gitignore b/.gitignore
index 1aad045e..2923929d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+# keys (API secrets; "**/" ignores the file at any directory depth, "*/" only matched one level down)
+**/secrets.json
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
diff --git a/environments/hack0/doctor.py b/environments/hack0/doctor.py
deleted file mode 100644
index 0399aa17..00000000
--- a/environments/hack0/doctor.py
+++ /dev/null
@@ -1,185 +0,0 @@
-import re
-from typing import Dict, List, Optional, Tuple
-
-import gymnasium as gym
-
-from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem
-from atroposlib.type_definitions import Item
-
-system_prompt = """
-You are a doctor. You are interacting with a patient.
-You need to diagnose the patient based on the symptoms.
-You will need to ask the patient follow up questions to diagnose them.
-Once you are confident in your diagnosis, provide it in the format:
-
-The patient is diagnosed with {possible_illness}.
-
-For example,
-
-user: I have a headache.
-assistant: What is the severity of your headache?
-user: It's a 3/10.
-assistant: What is the location of your headache?
-user: It's in the front of my head.
-assistant: What is the duration of your headache?
-user: It's been going on for 2 days.
-assistant: The patient is diagnosed with headache
-"""
-
-
-class DoctorEnv(BaseEnv):
-
- name = "doctor"
-
- def __init__(
- self,
- config: BaseEnvConfig,
- server_configs: List[APIServerConfig],
- slurm=True,
- testing=False,
- ):
- super().__init__(config, server_configs, slurm, testing)
- self.percent_correct_buffer = list()
- self.eval_metrics = list()
- # Add tracking for wandb visualizations
- self.rollouts_for_wandb = []
- self.completion_lengths = []
- self.print_this_env = False
-
- @classmethod
- def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
- env_config = BaseEnvConfig(
- tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
- group_size=32,
- use_wandb=True,
- rollout_server_url="http://localhost:8000",
- max_token_length=8192,
- wandb_name="gym_taxi",
- )
- server_configs = [
- APIServerConfig(
- model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
- base_url="http://localhost:9001/v1",
- api_key="x",
- num_requests_for_eval=256,
- ),
- ]
-
- return env_config, server_configs
-
- async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
- if wandb_metrics is None:
- wandb_metrics = {}
-
- # Try to calculate percent_correct, pass if there's a division by zero
- try:
- wandb_metrics["train/percent_correct"] = sum(
- self.percent_correct_buffer
- ) / len(self.percent_correct_buffer)
- except ZeroDivisionError:
- # Skip if buffer is empty
- pass
-
- self.percent_correct_buffer = list()
- for item in self.eval_metrics:
- wandb_metrics[item[0]] = item[1]
- self.eval_metrics = list()
- # Call the parent method to handle the server metrics
- await super().wandb_log(wandb_metrics)
-
- async def setup(self):
- self.iter = 0
-
- async def evaluate(self, *args, **kwargs):
- pass
-
- async def get_patient_msg(self, env: gym.Env) -> str:
- # Call xAI to get a patient message
- return env.render()
-
- async def collect_trajectory(
- self, item: Item
- ) -> Tuple[Optional[ScoredDataItem], List[Item]]:
- # Grab a dedicated llm server to take advantage of caching
- async with self.server.dedicated_server() as server:
- init_msg = f"{system_prompt}\n\n"
- messages = [{"role": "system", "content": init_msg}]
- patient_msg = await self.get_patient_msg(item)
- messages.append({"role": "user", "content": patient_msg})
- score = -1
- while True:
- if (
- len(self.tokenizer.apply_chat_template(messages))
- > self.config.max_token_length - 10
- ):
- score = 0
- break
- max_tokens = self.config.max_token_length - len(
- self.tokenizer.apply_chat_template(
- messages, add_generation_prompt=True
- )
- )
- chat_completions = await server.chat_completion(
- messages=messages,
- n=1,
- max_tokens=max_tokens,
- )
- messages.append(
- {
- "role": "assistant",
- "content": chat_completions.choices[0].message.content,
- }
- )
- diagnosis_match = re.search(
- r"(.*?)",
- chat_completions.choices[0].message.content,
- re.DOTALL,
- )
- if diagnosis_match:
- diagnosis = diagnosis_match.group(1).strip()
- # Check if the diagnosis is correct
- if diagnosis == item["diagnosis"]:
- score = 1
- else:
- score = 0
- break
-
- next_patient_msg = await self.get_patient_msg(item)
- messages.append(
- {
- "role": "user",
- "content": next_patient_msg,
- }
- )
- self.percent_correct_buffer.append(max(score, 0))
- tokens = self.tokenizer.apply_chat_template(messages)
- masks = []
- for i, msg in enumerate(messages):
- if i == len(messages) - 1:
- masks.extend(tokens[len(masks) :])
- else:
- curr_tokens = self.tokenizer.apply_chat_template(
- messages[: i + 1],
- add_generation_prompt=messages[i + 1]["role"] == "assistant",
- )
- if messages[i]["role"] == "user":
- masks.extend([-100] * (len(curr_tokens) - len(masks)))
- else:
- masks.extend(curr_tokens[len(masks) :])
- scored_data_item = ScoredDataItem(
- messages=messages,
- finish_reason=score,
- tokens=tokens,
- masks=masks,
- scores=score,
- )
- return scored_data_item, []
-
- async def get_next_item(self):
- next_item = {"seed": self.iter}
- self.iter += 1
- return next_item
-
-
-# if __name__ == "__main__":
-# GymTaxiEnv.cli()
diff --git a/environments/hack0/doctor_agent/doctor_gym.py b/environments/hack0/doctor_agent/doctor_gym.py
index 87384722..48fee2a3 100644
--- a/environments/hack0/doctor_agent/doctor_gym.py
+++ b/environments/hack0/doctor_agent/doctor_gym.py
@@ -1,103 +1,60 @@
+import json
+import random
from typing import Dict, List, Optional, Tuple
from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem
from atroposlib.type_definitions import Item
from atropos.environments.hack0.doctor_agent.patient import patient_profiles
from atropos.environments.hack0.doctor_agent.datasets import dataset
+import re
+from typing import Dict, List, Optional, Tuple
+import os
+from openai import OpenAI
+import gymnasium as gym
-start_msg = """### Description
-You are a doctor tasked to diagnose a patient symptoms. Your task is to ask the patient enough questions until you are confident about your answer
+from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem
+from atroposlib.type_definitions import Item
-When you are confident about the illness/disease the patient has respond with. The diagnosis is {illness}
-""" # noqa: E501
+with open("environments/hack0/doctor_agent/secrets.json", "r") as f:  # git-ignored key file; the bare directory cannot be opened
+ keys = json.load(f)
+ xai_key = keys["xai"]
-def decode(i):
- out = []
- out.append(i % 4)
- i = i // 4
- out.append(i % 5)
- i = i // 5
- out.append(i % 5)
- i = i // 5
- out.append(i)
- assert 0 <= i < 5
- x = reversed(out)
- # Making it explicit so I don't have to look into gym code
- taxi_row, taxi_col, pass_idx, dest_idx = x
- return taxi_row, taxi_col, pass_idx, dest_idx
+client = OpenAI(
+ api_key=xai_key,
+ base_url="https://api.x.ai/v1",
+)
+
+final_message = "The patient is diagnosed with"  # must match the answer format requested in doctor_system_prompt
+final_message_prompt = final_message + " headache"
+
+doctor_system_prompt = """
+You are a doctor. You are interacting with a patient.
+You need to diagnose the patient based on the symptoms.
+You will need to ask the patient follow up questions to diagnose them.
+Once you are confident in your diagnosis, provide it in the format:
+
+The patient is diagnosed with {possible_illness}.
+
+For example,
+
+user: I have a headache.
+assistant: What is the severity of your headache?
+user: It's a 3/10.
+assistant: What is the location of your headache?
+user: It's in the front of my head.
+assistant: What is the duration of your headache?
+user: It's been going on for 2 days.
+assistant: The patient is diagnosed with headache
+"""
-# Note: Works for both the passenger and the destination
-TO_LOC_MAP = {
- 0: "R(Row 0, Col 0)",
- 1: "G (Row 4, Col 4)",
- 2: "Y (Row 0, Col 4)",
- 3: "B (Row 3, Col 3)",
- 4: "in taxi",
-}
-MAP_LOC = {0: (0, 0), 1: (4, 4), 2: (0, 4), 3: (3, 3)}
-TO_ACTION_MAP = {
- 0: "south",
- 1: "north",
- 2: "east",
- 3: "west",
- 4: "pickup",
- 5: "dropoff",
-}
+doctor_model = "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
+gym_name = "gym_doctor"
+class DoctorEnv(BaseEnv):
-def state_render_to_user_msg(last_state, state, action_mask, render):
- taxi_row, taxi_col, pass_idx, dest_idx = decode(state)
- if last_state is not None:
- last_taxi_row, last_taxi_col, last_pass_idx, last_dest_idx = decode(last_state)
- available_actions = "\n".join(
- [
- f"- {i}: {TO_ACTION_MAP[i]}"
- for i in range(6)
- if (action_mask[i] == 1)
- and (
- (i != 5)
- or (
- (i == 5)
- and (taxi_row == MAP_LOC[dest_idx][0])
- and (taxi_col == MAP_LOC[dest_idx][1])
- )
- )
- ]
- )
- if last_state is not None:
- ret_str = (
- f"Previous Taxi Location: Row: {last_taxi_row}, Col: {last_taxi_col}\n"
- )
- else:
- ret_str = ""
- ret_str += (
- f"Current state:\nTaxi: Row: {taxi_row}, Col: {taxi_col}\nPassenger: {TO_LOC_MAP[pass_idx]}\n"
- f"Destination: {TO_LOC_MAP[dest_idx]}\n\n"
- f"Map:\n{render}\n\n"
- f"Available actions:\n{available_actions}"
- )
- if (
- (pass_idx == 4)
- and (taxi_row == MAP_LOC[dest_idx][0])
- and (taxi_col == MAP_LOC[dest_idx][1])
- ):
- ret_str += "\n\nPlease drop off the passenger."
- elif pass_idx == 4:
- ret_str += f"\n\nPlease move the taxi to {TO_LOC_MAP[dest_idx]} to drop off the passenger."
- elif (taxi_row == MAP_LOC[pass_idx][0]) and (taxi_col == MAP_LOC[pass_idx][1]):
- ret_str += "\n\nPlease pick up the passenger."
- else:
- ret_str += f"\n\nPlease move the taxi to {TO_LOC_MAP[pass_idx]} to pick up the passenger."
- return ret_str
-
-model = "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
-name = "gym_doctor"
-
-class GymDoctorEnv(BaseEnv):
-
- name = "gym_doctor"
+ name = "doctor"
def __init__(
self,
@@ -108,7 +65,6 @@ class GymDoctorEnv(BaseEnv):
):
super().__init__(config, server_configs, slurm, testing)
self.percent_correct_buffer = list()
- self.percent_picked_up_passenger_buffer = list()
self.eval_metrics = list()
# Add tracking for wandb visualizations
self.rollouts_for_wandb = []
@@ -118,16 +74,16 @@ class GymDoctorEnv(BaseEnv):
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
env_config = BaseEnvConfig(
- tokenizer_name=model,
+ tokenizer_name=doctor_model,
group_size=32,
use_wandb=True,
rollout_server_url="http://localhost:8000",
max_token_length=8192,
- wandb_name=name,
+ wandb_name=gym_name,
)
server_configs = [
APIServerConfig(
- model_name=model,
+ model_name=doctor_model,
base_url="http://localhost:9001/v1",
api_key="x",
num_requests_for_eval=256,
@@ -148,16 +104,8 @@ class GymDoctorEnv(BaseEnv):
except ZeroDivisionError:
# Skip if buffer is empty
pass
- try:
- wandb_metrics["train/percent_picked_up_passenger"] = sum(
- self.percent_picked_up_passenger_buffer
- ) / len(self.percent_picked_up_passenger_buffer)
- except ZeroDivisionError:
- # Skip if buffer is empty
- pass
self.percent_correct_buffer = list()
- self.percent_picked_up_passenger_buffer = list()
for item in self.eval_metrics:
wandb_metrics[item[0]] = item[1]
self.eval_metrics = list()
@@ -170,107 +118,107 @@ class GymDoctorEnv(BaseEnv):
async def evaluate(self, *args, **kwargs):
pass
+ async def get_patient_msg(self, env: gym.Env) -> str:
+ # TODO: call xAI to get a patient message; currently this just returns env.render()
+ return env.render()
+
async def collect_trajectory(
self, item: Item
) -> Tuple[Optional[ScoredDataItem], List[Item]]:
# Grab a dedicated llm server to take advantage of caching
async with self.server.dedicated_server() as server:
- # env = gym.make(name, render_mode="ansi")
- # state, info = env.reset(seed=item["seed"]) #FIXME:
- last_state = None
- patient_state = []
+
+ patient_messages = []
+ doctor_messages = [
+ {
+ "role" : "system",
+ "content" : doctor_system_prompt
+ }
+ ]
patient_profile = random.choice(patient_profiles)
-
symptoms = dataset[0]
+ patient_system_prompt = patient_profile.format(symptoms)
+ patient_messages = [{"role": "system", "content": patient_system_prompt}]
+ completion = client.chat.completions.create(
+ model="grok-3-latest",
+ messages=patient_messages,
+ )
+ patient_msg = completion.choices[0].message.content  # reply text, not the message object
- doctor_state = []
+ doctor_messages.append({"role": "user", "content": patient_msg})
+ patient_messages.append({"role": "assistant", "content": patient_msg})
-
- # taxi_row, taxi_col, pass_idx, dest_idx = decode(state)
-
-
-
-
-
- init_msg
-
- # init_msg = f"{start_msg}\n\n" + state_render_to_user_msg(
- # last_state, state, info["action_mask"], env.render()
- # )
- messages = [{"role": "user", "content": init_msg}]
score = -1
while True:
if (
- len(self.tokenizer.apply_chat_template(messages))
+ len(self.tokenizer.apply_chat_template(doctor_messages))
> self.config.max_token_length - 10
):
+ score = 0
break
max_tokens = self.config.max_token_length - len(
self.tokenizer.apply_chat_template(
- messages, add_generation_prompt=True
+ doctor_messages, add_generation_prompt=True
)
)
- chat_completions = await server.chat_completion(
- messages=messages,
+ doctor_completions = await server.chat_completion(
+ messages=doctor_messages,
n=1,
max_tokens=max_tokens,
)
- choice = (
- chat_completions.choices[0]
- .message.content.strip()
- .replace(".", "")[-1]
- )
- messages.append(
- {
- "role": "assistant",
- "content": chat_completions.choices[0].message.content,
- }
- )
- if choice.isdigit() and 0 <= int(choice) <= 5:
- action = int(choice)
- else:
+
+ doctor_msg = doctor_completions.choices[0].message.content
+
+ doctor_messages.append({"role": "assistant", "content": doctor_msg})
+ patient_messages.append({"role": "user", "content": doctor_msg})
+
+ # Check whether the doctor has committed to a final diagnosis
+ if doctor_msg.startswith(final_message):
+ diagnosis = doctor_msg[len(final_message) :]  # drop the prefix only; str.strip(prefix) strips a character set
+ diagnosis = diagnosis.strip().rstrip(".")  # tolerate surrounding whitespace and a trailing period
+
+ if diagnosis == item["diagnosis"]:
+ score = 1
+ else:
+ score = 0
break
- if info["action_mask"][action] == 0:
- break
- if action == 3:
- # picked up passenger
- score = 0
- next_state, reward, terminated, truncated, info = env.step(action)
- last_state = state
- state = next_state
- if terminated:
- score = 1
- break
- messages.append(
- {
- "role": "user",
- "content": state_render_to_user_msg(
- last_state, state, info["action_mask"], env.render()
- ),
- }
+
+
+ completion = client.chat.completions.create(
+ model="grok-3-latest",
+ messages=patient_messages,
)
+
+ patient_msg = completion.choices[0].message.content  # reply text, not the message object
+
+ doctor_messages.append({"role": "user", "content": patient_msg})
+ patient_messages.append({"role": "assistant", "content": patient_msg})
+
+
self.percent_correct_buffer.append(max(score, 0))
- self.percent_picked_up_passenger_buffer.append(1 if score >= 0 else 0)
- tokens = self.tokenizer.apply_chat_template(messages)
+ tokens = self.tokenizer.apply_chat_template(doctor_messages)
+
+
masks = []
- for i, msg in enumerate(messages):
- if i == len(messages) - 1:
+ for i, msg in enumerate(doctor_messages):
+ if i == len(doctor_messages) - 1:
masks.extend(tokens[len(masks) :])
else:
curr_tokens = self.tokenizer.apply_chat_template(
- messages[: i + 1],
- add_generation_prompt=messages[i + 1]["role"] == "assistant",
+ doctor_messages[: i + 1],
+ add_generation_prompt=doctor_messages[i + 1]["role"] == "assistant",
)
- if messages[i]["role"] == "user":
+ if doctor_messages[i]["role"] == "user":
masks.extend([-100] * (len(curr_tokens) - len(masks)))
else:
masks.extend(curr_tokens[len(masks) :])
+
scored_data_item = ScoredDataItem(
- messages=messages,
+ messages=doctor_messages,
finish_reason=score,
tokens=tokens,
masks=masks,
@@ -284,5 +232,5 @@ class GymDoctorEnv(BaseEnv):
return next_item
-if __name__ == "__main__":
- GymDoctorEnv.cli()
+# if __name__ == "__main__":
+# DoctorEnv.cli()