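More fixes: read `.message.content` from completions, correct `startswith` typo, use role-tag constants, match diagnosis case-insensitively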

This commit is contained in:
FIRST_NAME LAST_NAME 2025-05-19 00:22:06 +00:00
parent ebed149548
commit 8deeb3a339

View file

@ -40,32 +40,32 @@ client = OpenAI(
)
final_message = "The diagnosis is:"
final_message_prompt = final_message + "<diagnosis>headache</diagnosis>"
final_message_prompt = final_message + " headache"
doctor_system_prompt = """
You are a doctor. You are interacting with a patient.
doctor_system_prompt = """You are a doctor. You are interacting with a patient.
You need to diagnose the patient based on the symptoms.
You will need to ask the patient follow up questions to diagnose them.
Once you are confident in your diagnosis, provide it in the format:
The patient is diagnosed with <diagnosis>{possible_illness}.</diagnosis>
For example,
user: I have a headache.
assistant: What is the severity of your headache?
user: It's a 3/10.
assistant: What is the location of your headache?
user: It's in the front of my head.
assistant: What is the duration of your headache?
user: It's been going on for 2 days.
assistant: The patient is diagnosed with <diagnosis>headache</diagnosis>
The diagnosis is: {possible_illness}
"""
# ## For example,
# user: I have a headache.
# assistant: What is the severity of your headache?
# user: It's a 3/10.
# assistant: What is the location of your headache?
# user: It's in the front of my head.
# assistant: What is the duration of your headache?
# user: It's been going on for 2 days.
# assistant: The patient is diagnosed with <diagnosis>headache</diagnosis>
# """
doctor_model = "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
wandb_name = "doctor"
USER_TAG = "user"
ASSISTANT_TAG = "assistant"
class DoctorEnv(BaseEnv):
@ -93,7 +93,7 @@ class DoctorEnv(BaseEnv):
group_size=32,
use_wandb=True,
rollout_server_url="http://localhost:8000",
wandb_name=wandb_name,
wandb_name="doctor",
max_num_workers=128,
total_steps=100,
batch_size=1024,
@ -182,17 +182,18 @@ class DoctorEnv(BaseEnv):
patient_system_prompt = patient_profile.format(symptoms = symptoms)
patient_messages = [{"role": "system", "content": patient_system_prompt}]
print("before xai message")
# print("before xai message")
completion = client.chat.completions.create(
model="grok-3-latest",
messages=patient_messages,
)
patient_msg = completion.choices[0].message
patient_msg = completion.choices[0].message.content
# print("patient message", patient_msg)
doctor_messages.append({"role": "user", "content": patient_msg})
patient_messages.append({"role": "assistant", "content": patient_msg})
print("after xai message")
doctor_messages.append({"role": USER_TAG, "content": patient_msg})
patient_messages.append({"role": ASSISTANT_TAG, "content": patient_msg})
# print("after xai message")
score = -1
while True:
if (
@ -206,24 +207,26 @@ class DoctorEnv(BaseEnv):
doctor_messages, add_generation_prompt=True
)
)
print("before doctor response")
# print("before doctor response")
# print("messages", doctor_messages)
doctor_completions = await server.chat_completion(
messages=doctor_messages,
messages=[{"role" : USER_TAG, "content": "test"}],
n=1,
max_tokens=max_tokens,
)
doctor_msg = doctor_completions.choices[0].message.content
# print("doctor message", doctor_msg)
doctor_messages.append({"role": "assistant", "content": doctor_msg})
patient_messages.append({"role": "user", "content": doctor_msg})
print("after doctor response")
doctor_messages.append({"role": ASSISTANT_TAG, "content": doctor_msg})
patient_messages.append({"role": USER_TAG, "content": doctor_msg})
# print("after doctor response")
# check output
if doctor_msg.startwith(final_message):
if doctor_msg.startswith(final_message):
diagnosis = doctor_msg.strip(final_message)
diagnosis = diagnosis.strip()
diagnosis = diagnosis.strip().lower()
if diagnosis.contains(item["answer"]):
if diagnosis.contains(item["answer"].lower()):
score = 1
else:
score = 0
@ -234,10 +237,10 @@ class DoctorEnv(BaseEnv):
messages=patient_messages,
)
patient_msg = completion.choices[0].message
patient_msg = completion.choices[0].message.content
doctor_messages.append({"role": "user", "content": patient_msg})
patient_messages.append({"role": "assistant", "content": patient_msg})
doctor_messages.append({"role": USER_TAG, "content": patient_msg})
patient_messages.append({"role": ASSISTANT_TAG, "content": patient_msg})
self.percent_correct_buffer.append(max(score, 0))
tokens = self.tokenizer.apply_chat_template(doctor_messages)
@ -250,9 +253,9 @@ class DoctorEnv(BaseEnv):
curr_tokens = self.tokenizer.apply_chat_template(
doctor_messages[: i + 1],
add_generation_prompt=doctor_messages[i + 1]["role"]
== "assistant",
== ASSISTANT_TAG,
)
if doctor_messages[i]["role"] == "user":
if doctor_messages[i]["role"] == USER_TAG:
masks.extend([-100] * (len(curr_tokens) - len(masks)))
else:
masks.extend(curr_tokens[len(masks) :])