diff --git a/environments/hack0/doctor_agent/doctor.py b/environments/hack0/doctor_agent/doctor.py
index 6f7b3d06..22245438 100644
--- a/environments/hack0/doctor_agent/doctor.py
+++ b/environments/hack0/doctor_agent/doctor.py
@@ -40,32 +40,32 @@ client = OpenAI(
)
final_message = "The diagnosis is:"
-final_message_prompt = final_message + "headache"
+final_message_prompt = final_message + " headache"
-doctor_system_prompt = """
-You are a doctor. You are interacting with a patient.
+doctor_system_prompt = """You are a doctor. You are interacting with a patient.
You need to diagnose the patient based on the symptoms.
You will need to ask the patient follow up questions to diagnose them.
Once you are confident in your diagnosis, provide it in the format:
-The patient is diagnosed with {possible_illness}.
-
-For example,
-
-user: I have a headache.
-assistant: What is the severity of your headache?
-user: It's a 3/10.
-assistant: What is the location of your headache?
-user: It's in the front of my head.
-assistant: What is the duration of your headache?
-user: It's been going on for 2 days.
-assistant: The patient is diagnosed with headache
+The diagnosis is: {possible_illness}
"""
+# ## For example,
+
+# user: I have a headache.
+# assistant: What is the severity of your headache?
+# user: It's a 3/10.
+# assistant: What is the location of your headache?
+# user: It's in the front of my head.
+# assistant: What is the duration of your headache?
+# user: It's been going on for 2 days.
+# assistant: The patient is diagnosed with headache
+# """
doctor_model = "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
-wandb_name = "doctor"
+USER_TAG = "user"
+ASSISTANT_TAG = "assistant"
class DoctorEnv(BaseEnv):
@@ -93,7 +93,7 @@ class DoctorEnv(BaseEnv):
group_size=32,
use_wandb=True,
rollout_server_url="http://localhost:8000",
- wandb_name=wandb_name,
+ wandb_name="doctor",
max_num_workers=128,
total_steps=100,
batch_size=1024,
@@ -182,17 +182,18 @@ class DoctorEnv(BaseEnv):
patient_system_prompt = patient_profile.format(symptoms = symptoms)
patient_messages = [{"role": "system", "content": patient_system_prompt}]
- print("before xai message")
+ # print("before xai message")
completion = client.chat.completions.create(
model="grok-3-latest",
messages=patient_messages,
)
- patient_msg = completion.choices[0].message
+ patient_msg = completion.choices[0].message.content
+ # print("patient message", patient_msg)
- doctor_messages.append({"role": "user", "content": patient_msg})
- patient_messages.append({"role": "assistant", "content": patient_msg})
- print("after xai message")
+ doctor_messages.append({"role": USER_TAG, "content": patient_msg})
+ patient_messages.append({"role": ASSISTANT_TAG, "content": patient_msg})
+ # print("after xai message")
score = -1
while True:
if (
@@ -206,24 +207,26 @@ class DoctorEnv(BaseEnv):
doctor_messages, add_generation_prompt=True
)
)
- print("before doctor response")
+ # print("before doctor response")
+ # print("messages", doctor_messages)
doctor_completions = await server.chat_completion(
- messages=doctor_messages,
+ messages=[{"role" : USER_TAG, "content": "test"}],
n=1,
max_tokens=max_tokens,
)
doctor_msg = doctor_completions.choices[0].message.content
+ # print("doctor message", doctor_msg)
- doctor_messages.append({"role": "assistant", "content": doctor_msg})
- patient_messages.append({"role": "user", "content": doctor_msg})
- print("after doctor response")
+ doctor_messages.append({"role": ASSISTANT_TAG, "content": doctor_msg})
+ patient_messages.append({"role": USER_TAG, "content": doctor_msg})
+ # print("after doctor response")
# check output
- if doctor_msg.startwith(final_message):
+ if doctor_msg.startswith(final_message):
diagnosis = doctor_msg.strip(final_message)
- diagnosis = diagnosis.strip()
+ diagnosis = diagnosis.strip().lower()
- if diagnosis.contains(item["answer"]):
+ if diagnosis.contains(item["answer"].lower()):
score = 1
else:
score = 0
@@ -234,10 +237,10 @@ class DoctorEnv(BaseEnv):
messages=patient_messages,
)
- patient_msg = completion.choices[0].message
+ patient_msg = completion.choices[0].message.content
- doctor_messages.append({"role": "user", "content": patient_msg})
- patient_messages.append({"role": "assistant", "content": patient_msg})
+ doctor_messages.append({"role": USER_TAG, "content": patient_msg})
+ patient_messages.append({"role": ASSISTANT_TAG, "content": patient_msg})
self.percent_correct_buffer.append(max(score, 0))
tokens = self.tokenizer.apply_chat_template(doctor_messages)
@@ -250,9 +253,9 @@ class DoctorEnv(BaseEnv):
curr_tokens = self.tokenizer.apply_chat_template(
doctor_messages[: i + 1],
add_generation_prompt=doctor_messages[i + 1]["role"]
- == "assistant",
+ == ASSISTANT_TAG,
)
- if doctor_messages[i]["role"] == "user":
+ if doctor_messages[i]["role"] == USER_TAG:
masks.extend([-100] * (len(curr_tokens) - len(masks)))
else:
masks.extend(curr_tokens[len(masks) :])