mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
Patient Doctor Loop
This commit is contained in:
parent
4e7583be43
commit
02ff663ebe
3 changed files with 113 additions and 348 deletions
|
|
@ -1,185 +0,0 @@
|
|||
import re
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
# System prompt given to the doctor model. The answer template and the worked
# example use the same tag layout (no punctuation inside <diagnosis>), because
# the text between the tags is what gets compared with the expected diagnosis.
# `{possible_illness}` is a literal placeholder shown to the model; this string
# is never passed through str.format in this module.
system_prompt = """
You are a doctor. You are interacting with a patient.
You need to diagnose the patient based on the symptoms.
You will need to ask the patient follow up questions to diagnose them.
Once you are confident in your diagnosis, provide it in the format:

The patient is diagnosed with <diagnosis>{possible_illness}</diagnosis>.

For example,

user: I have a headache.
assistant: What is the severity of your headache?
user: It's a 3/10.
assistant: What is the location of your headache?
user: It's in the front of my head.
assistant: What is the duration of your headache?
user: It's been going on for 2 days.
assistant: The patient is diagnosed with <diagnosis>headache</diagnosis>
"""
|
||||
|
||||
|
||||
class DoctorEnv(BaseEnv):
    """Multi-turn environment in which the policy model plays a doctor.

    Each trajectory is a chat: a system prompt, then alternating patient
    (``user``) and doctor (``assistant``) turns. The episode ends when the
    doctor emits ``<diagnosis>...</diagnosis>`` or the conversation no longer
    fits in ``config.max_token_length``. The score is 1 for a matching
    diagnosis and 0 otherwise.
    """

    name = "doctor"

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        # Per-trajectory 0/1 outcomes since the last wandb_log call.
        self.percent_correct_buffer = list()
        # (metric_name, value) pairs flushed on the next wandb_log call.
        self.eval_metrics = list()
        # Add tracking for wandb visualizations
        self.rollouts_for_wandb = []
        self.completion_lengths = []
        self.print_this_env = False

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Return the default environment config and inference server list."""
        env_config = BaseEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
            group_size=32,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            max_token_length=8192,
            # Previously the copy-pasted "gym_taxi"; log under this env's name.
            wandb_name=cls.name,
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            ),
        ]

        return env_config, server_configs

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Add this env's metrics to ``wandb_metrics`` and delegate to the parent.

        ``train/percent_correct`` is only reported when at least one
        trajectory finished since the previous call. Both buffers are reset.
        """
        if wandb_metrics is None:
            wandb_metrics = {}

        # An empty buffer simply means there is nothing to report yet.
        if self.percent_correct_buffer:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)

        self.percent_correct_buffer = list()
        for metric_name, metric_value in self.eval_metrics:
            wandb_metrics[metric_name] = metric_value
        self.eval_metrics = list()
        # Call the parent method to handle the server metrics
        await super().wandb_log(wandb_metrics)

    async def setup(self):
        """Initialise the counter used by ``get_next_item``."""
        self.iter = 0

    async def evaluate(self, *args, **kwargs):
        """No evaluation is implemented for this environment."""
        pass

    async def get_patient_msg(self, env: gym.Env) -> str:
        """Return the next patient message by rendering ``env``.

        NOTE(review): ``collect_trajectory`` passes its ``item`` here, and
        ``get_next_item`` produces a plain dict, which has no ``render``
        method. The patient simulator still needs to be wired in.
        """
        # Call xAI to get a patient message
        return env.render()

    @staticmethod
    def _normalize_diagnosis(text: str) -> str:
        """Canonicalise a diagnosis for comparison: trim, drop a trailing period, casefold."""
        return text.strip().rstrip(".").strip().casefold()

    def _build_masks(self, messages: List[Dict], tokens: List[int]) -> List[int]:
        """Build loss masks aligned with ``tokens``.

        Assistant turns keep their token ids; every other role (system and
        user) is filled with -100, including a trailing user turn left over
        when the episode stopped on the length limit.
        """
        masks = []
        last_index = len(messages) - 1
        for idx, msg in enumerate(messages):
            if idx == last_index:
                span = tokens[len(masks) :]
            else:
                # Include the generation prompt in this span when the next
                # turn is the assistant's, so its header tokens stay masked.
                prefix_tokens = self.tokenizer.apply_chat_template(
                    messages[: idx + 1],
                    add_generation_prompt=messages[idx + 1]["role"] == "assistant",
                )
                span = prefix_tokens[len(masks) :]
            if msg["role"] == "assistant":
                masks.extend(span)
            else:
                masks.extend([-100] * len(span))
        return masks

    async def collect_trajectory(
        self, item: Item
    ) -> Tuple[Optional[ScoredDataItem], List[Item]]:
        """Run one doctor/patient conversation and score the final diagnosis.

        Returns the scored trajectory and an empty backlog list.

        NOTE(review): ``item["diagnosis"]`` is read below, but
        ``get_next_item`` only sets ``"seed"``; the item source must supply
        the expected diagnosis.
        """
        # Grab a dedicated llm server to take advantage of caching
        async with self.server.dedicated_server() as server:
            init_msg = f"{system_prompt}\n\n"
            messages = [{"role": "system", "content": init_msg}]
            patient_msg = await self.get_patient_msg(item)
            messages.append({"role": "user", "content": patient_msg})
            score = -1
            while True:
                # Stop once the conversation (almost) fills the context.
                if (
                    len(self.tokenizer.apply_chat_template(messages))
                    > self.config.max_token_length - 10
                ):
                    score = 0
                    break
                max_tokens = self.config.max_token_length - len(
                    self.tokenizer.apply_chat_template(
                        messages, add_generation_prompt=True
                    )
                )
                chat_completions = await server.chat_completion(
                    messages=messages,
                    n=1,
                    max_tokens=max_tokens,
                )
                # Guard against a null content so the regex search cannot
                # raise TypeError.
                doctor_msg = chat_completions.choices[0].message.content or ""
                messages.append({"role": "assistant", "content": doctor_msg})
                diagnosis_match = re.search(
                    r"<diagnosis>(.*?)</diagnosis>",
                    doctor_msg,
                    re.DOTALL,
                )
                if diagnosis_match:
                    diagnosis = self._normalize_diagnosis(diagnosis_match.group(1))
                    expected = self._normalize_diagnosis(str(item["diagnosis"]))
                    # Check if the diagnosis is correct
                    score = 1 if diagnosis == expected else 0
                    break

                next_patient_msg = await self.get_patient_msg(item)
                messages.append({"role": "user", "content": next_patient_msg})

        self.percent_correct_buffer.append(max(score, 0))
        tokens = self.tokenizer.apply_chat_template(messages)
        masks = self._build_masks(messages, tokens)
        scored_data_item = ScoredDataItem(
            messages=messages,
            finish_reason=score,
            tokens=tokens,
            masks=masks,
            scores=score,
        )
        return scored_data_item, []

    async def get_next_item(self):
        """Return the next work item: a dict carrying an incrementing seed."""
        next_item = {"seed": self.iter}
        self.iter += 1
        return next_item
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# GymTaxiEnv.cli()
|
||||
|
|
@ -1,103 +1,60 @@
|
|||
import json
|
||||
import random
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem
|
||||
from atroposlib.type_definitions import Item
|
||||
from atropos.environments.hack0.doctor_agent.patient import patient_profiles
|
||||
from atropos.environments.hack0.doctor_agent.datasets import dataset
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import os
|
||||
from openai import OpenAI
|
||||
import gymnasium as gym
|
||||
|
||||
start_msg = """### Description
|
||||
You are a doctor tasked to diagnose a patient symptoms. Your task is to ask the patient enough questions until you are confident about your answer
|
||||
from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
When you are confident about the illness/disease the patient has respond with. The diagnosis is {illness}
|
||||
""" # noqa: E501
|
||||
with open("environments/hack0/doctor_agent", "r") as f:
|
||||
keys = json.load(f)
|
||||
xai_key = keys["xai"]
|
||||
|
||||
|
||||
def decode(i):
    """Split a flat state index into (taxi_row, taxi_col, pass_idx, dest_idx).

    The index is peeled apart from the least significant component upward:
    destination (base 4), passenger (base 5), taxi column (base 5), and the
    remainder is the taxi row, which must lie in ``range(5)``.
    """
    rest, dest_idx = divmod(i, 4)
    rest, pass_idx = divmod(rest, 5)
    taxi_row, taxi_col = divmod(rest, 5)
    assert 0 <= taxi_row < 5
    # Making it explicit so I don't have to look into gym code
    return taxi_row, taxi_col, pass_idx, dest_idx
|
||||
client = OpenAI(
|
||||
api_key=xai_key,
|
||||
base_url="https://api.x.ai/v1",
|
||||
)
|
||||
|
||||
final_message = "The diagnosis is:"
|
||||
final_message_prompt = final_message + "<diagnosis>headache</diagnosis>"
|
||||
|
||||
# System prompt given to the doctor model. The answer template and the worked
# example use the same tag layout (no punctuation inside <diagnosis>), so a
# model that follows the template produces the same tag content as the example.
# `{possible_illness}` is a literal placeholder shown to the model.
doctor_system_prompt = """
You are a doctor. You are interacting with a patient.
You need to diagnose the patient based on the symptoms.
You will need to ask the patient follow up questions to diagnose them.
Once you are confident in your diagnosis, provide it in the format:

The patient is diagnosed with <diagnosis>{possible_illness}</diagnosis>.

For example,

user: I have a headache.
assistant: What is the severity of your headache?
user: It's a 3/10.
assistant: What is the location of your headache?
user: It's in the front of my head.
assistant: What is the duration of your headache?
user: It's been going on for 2 days.
assistant: The patient is diagnosed with <diagnosis>headache</diagnosis>
"""
|
||||
|
||||
|
||||
# Note: Works for both the passenger and the destination
# Landmark coordinates follow Gymnasium's Taxi-v3 layout, whose location list
# is [(0, 0), (0, 4), (4, 0), (4, 3)] for Red, Green, Yellow, Blue (row, col).
# Index 4 only applies to the passenger and means "riding in the taxi".
TO_LOC_MAP = {
    0: "R (Row 0, Col 0)",
    1: "G (Row 0, Col 4)",
    2: "Y (Row 4, Col 0)",
    3: "B (Row 4, Col 3)",
    4: "in taxi",
}
# Same landmarks as (row, col) tuples, keyed by location index.
MAP_LOC = {0: (0, 0), 1: (0, 4), 2: (4, 0), 3: (4, 3)}
# Human-readable names for the six Taxi-v3 action ids.
TO_ACTION_MAP = {
    0: "south",
    1: "north",
    2: "east",
    3: "west",
    4: "pickup",
    5: "dropoff",
}
|
||||
doctor_model = "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
|
||||
gym_name = "gym_doctor"
|
||||
|
||||
class DoctorEnv(BaseEnv):
|
||||
|
||||
def state_render_to_user_msg(last_state, state, action_mask, render):
    """Compose the user-turn text describing the current taxi state.

    Args:
        last_state: previous flat state index, or None on the first turn.
        state: current flat state index (decoded with ``decode``).
        action_mask: indexable by action id 0..5; an entry equal to 1 marks
            the action as available.
        render: text rendering of the map, embedded verbatim.

    Returns:
        A string with the optional previous taxi location, the current
        state, the map, the available actions and a one-line instruction.
    """
    taxi_row, taxi_col, pass_idx, dest_idx = decode(state)

    header = ""
    if last_state is not None:
        prev_row, prev_col, _prev_pass, _prev_dest = decode(last_state)
        header = f"Previous Taxi Location: Row: {prev_row}, Col: {prev_col}\n"

    def taxi_is_at(loc_idx):
        # True when the taxi stands on the landmark with this index.
        return (taxi_row == MAP_LOC[loc_idx][0]) and (taxi_col == MAP_LOC[loc_idx][1])

    action_lines = []
    for action_id in range(6):
        if not (action_mask[action_id] == 1):
            continue
        # Drop-off (5) is only offered while standing on the destination.
        if action_id == 5 and not taxi_is_at(dest_idx):
            continue
        action_lines.append(f"- {action_id}: {TO_ACTION_MAP[action_id]}")
    available_actions = "\n".join(action_lines)

    body = (
        f"Current state:\nTaxi: Row: {taxi_row}, Col: {taxi_col}\nPassenger: {TO_LOC_MAP[pass_idx]}\n"
        f"Destination: {TO_LOC_MAP[dest_idx]}\n\n"
        f"Map:\n{render}\n\n"
        f"Available actions:\n{available_actions}"
    )

    # pass_idx == 4 means the passenger is already on board.
    if pass_idx == 4:
        if taxi_is_at(dest_idx):
            hint = "\n\nPlease drop off the passenger."
        else:
            hint = f"\n\nPlease move the taxi to {TO_LOC_MAP[dest_idx]} to drop off the passenger."
    elif taxi_is_at(pass_idx):
        hint = "\n\nPlease pick up the passenger."
    else:
        hint = f"\n\nPlease move the taxi to {TO_LOC_MAP[pass_idx]} to pick up the passenger."

    return header + body + hint
|
||||
|
||||
model = "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
|
||||
name = "gym_doctor"
|
||||
|
||||
class GymDoctorEnv(BaseEnv):
|
||||
|
||||
name = "gym_doctor"
|
||||
name = "doctor"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -108,7 +65,6 @@ class GymDoctorEnv(BaseEnv):
|
|||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.percent_correct_buffer = list()
|
||||
self.percent_picked_up_passenger_buffer = list()
|
||||
self.eval_metrics = list()
|
||||
# Add tracking for wandb visualizations
|
||||
self.rollouts_for_wandb = []
|
||||
|
|
@ -118,16 +74,16 @@ class GymDoctorEnv(BaseEnv):
|
|||
@classmethod
|
||||
def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
|
||||
env_config = BaseEnvConfig(
|
||||
tokenizer_name=model,
|
||||
tokenizer_name=doctor_model,
|
||||
group_size=32,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
max_token_length=8192,
|
||||
wandb_name=name,
|
||||
wandb_name=gym_name,
|
||||
)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
model_name=doctor_model,
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
|
|
@ -148,16 +104,8 @@ class GymDoctorEnv(BaseEnv):
|
|||
except ZeroDivisionError:
|
||||
# Skip if buffer is empty
|
||||
pass
|
||||
try:
|
||||
wandb_metrics["train/percent_picked_up_passenger"] = sum(
|
||||
self.percent_picked_up_passenger_buffer
|
||||
) / len(self.percent_picked_up_passenger_buffer)
|
||||
except ZeroDivisionError:
|
||||
# Skip if buffer is empty
|
||||
pass
|
||||
|
||||
self.percent_correct_buffer = list()
|
||||
self.percent_picked_up_passenger_buffer = list()
|
||||
for item in self.eval_metrics:
|
||||
wandb_metrics[item[0]] = item[1]
|
||||
self.eval_metrics = list()
|
||||
|
|
@ -170,107 +118,107 @@ class GymDoctorEnv(BaseEnv):
|
|||
async def evaluate(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def get_patient_msg(self, env: gym.Env) -> str:
|
||||
# Call xAI to get a patient message
|
||||
return env.render()
|
||||
|
||||
async def collect_trajectory(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[ScoredDataItem], List[Item]]:
|
||||
# Grab a dedicated llm server to take advantage of caching
|
||||
async with self.server.dedicated_server() as server:
|
||||
# env = gym.make(name, render_mode="ansi")
|
||||
# state, info = env.reset(seed=item["seed"]) #FIXME:
|
||||
last_state = None
|
||||
patient_state = []
|
||||
|
||||
patient_messages = []
|
||||
doctor_messages = [
|
||||
{
|
||||
"role" : "system",
|
||||
"content" : doctor_system_prompt
|
||||
}
|
||||
]
|
||||
|
||||
patient_profile = random.choice(patient_profiles)
|
||||
|
||||
symptoms = dataset[0]
|
||||
patient_system_prompt = patient_profile.format(symptoms)
|
||||
|
||||
patient_messages = [{"role": "system", "content": patient_system_prompt}]
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model="grok-3-latest",
|
||||
messages=patient_messages,
|
||||
)
|
||||
|
||||
patient_msg = completion.choices[0].message
|
||||
|
||||
doctor_state = []
|
||||
doctor_messages.append({"role": "user", "content": patient_msg})
|
||||
patient_messages.append({"role": "assistant", "content": patient_msg})
|
||||
|
||||
|
||||
# taxi_row, taxi_col, pass_idx, dest_idx = decode(state)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
init_msg
|
||||
|
||||
# init_msg = f"{start_msg}\n\n" + state_render_to_user_msg(
|
||||
# last_state, state, info["action_mask"], env.render()
|
||||
# )
|
||||
messages = [{"role": "user", "content": init_msg}]
|
||||
score = -1
|
||||
while True:
|
||||
if (
|
||||
len(self.tokenizer.apply_chat_template(messages))
|
||||
len(self.tokenizer.apply_chat_template(doctor_messages))
|
||||
> self.config.max_token_length - 10
|
||||
):
|
||||
score = 0
|
||||
break
|
||||
max_tokens = self.config.max_token_length - len(
|
||||
self.tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True
|
||||
doctor_messages, add_generation_prompt=True
|
||||
)
|
||||
)
|
||||
chat_completions = await server.chat_completion(
|
||||
messages=messages,
|
||||
doctor_completions = await server.chat_completion(
|
||||
messages=doctor_messages,
|
||||
n=1,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
choice = (
|
||||
chat_completions.choices[0]
|
||||
.message.content.strip()
|
||||
.replace(".", "")[-1]
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": chat_completions.choices[0].message.content,
|
||||
}
|
||||
)
|
||||
if choice.isdigit() and 0 <= int(choice) <= 5:
|
||||
action = int(choice)
|
||||
else:
|
||||
|
||||
doctor_msg = doctor_completions.choices[0].message.content
|
||||
|
||||
doctor_messages.append({"role": "assistant", "content": doctor_msg})
|
||||
patient_messages.append({"role": "user", "content": doctor_msg})
|
||||
|
||||
# check output
|
||||
if doctor_msg.startwith(final_message):
|
||||
diagnosis = doctor_msg.strip(final_message)
|
||||
diagnosis = diagnosis.strip()
|
||||
|
||||
if diagnosis == item["diagnosis"]:
|
||||
score = 1
|
||||
else:
|
||||
score = 0
|
||||
break
|
||||
if info["action_mask"][action] == 0:
|
||||
break
|
||||
if action == 3:
|
||||
# picked up passenger
|
||||
score = 0
|
||||
next_state, reward, terminated, truncated, info = env.step(action)
|
||||
last_state = state
|
||||
state = next_state
|
||||
if terminated:
|
||||
score = 1
|
||||
break
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": state_render_to_user_msg(
|
||||
last_state, state, info["action_mask"], env.render()
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model="grok-3-latest",
|
||||
messages=patient_messages,
|
||||
)
|
||||
|
||||
patient_msg = completion.choices[0].message
|
||||
|
||||
doctor_messages.append({"role": "user", "content": patient_msg})
|
||||
patient_messages.append({"role": "assistant", "content": patient_msg})
|
||||
|
||||
|
||||
self.percent_correct_buffer.append(max(score, 0))
|
||||
self.percent_picked_up_passenger_buffer.append(1 if score >= 0 else 0)
|
||||
tokens = self.tokenizer.apply_chat_template(messages)
|
||||
tokens = self.tokenizer.apply_chat_template(doctor_messages)
|
||||
|
||||
|
||||
masks = []
|
||||
for i, msg in enumerate(messages):
|
||||
if i == len(messages) - 1:
|
||||
for i, msg in enumerate(doctor_messages):
|
||||
if i == len(doctor_messages) - 1:
|
||||
masks.extend(tokens[len(masks) :])
|
||||
else:
|
||||
curr_tokens = self.tokenizer.apply_chat_template(
|
||||
messages[: i + 1],
|
||||
add_generation_prompt=messages[i + 1]["role"] == "assistant",
|
||||
doctor_messages[: i + 1],
|
||||
add_generation_prompt=doctor_messages[i + 1]["role"] == "assistant",
|
||||
)
|
||||
if messages[i]["role"] == "user":
|
||||
if doctor_messages[i]["role"] == "user":
|
||||
masks.extend([-100] * (len(curr_tokens) - len(masks)))
|
||||
else:
|
||||
masks.extend(curr_tokens[len(masks) :])
|
||||
|
||||
scored_data_item = ScoredDataItem(
|
||||
messages=messages,
|
||||
messages=doctor_messages,
|
||||
finish_reason=score,
|
||||
tokens=tokens,
|
||||
masks=masks,
|
||||
|
|
@ -284,5 +232,5 @@ class GymDoctorEnv(BaseEnv):
|
|||
return next_item
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GymDoctorEnv.cli()
|
||||
# if __name__ == "__main__":
|
||||
# GymTaxiEnv.cli()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue