mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-27 17:23:08 +00:00
more fixes
This commit is contained in:
parent
ebed149548
commit
8deeb3a339
1 changed file with 38 additions and 35 deletions
|
|
@@ -40,32 +40,32 @@ client = OpenAI(
|
|||
)
|
||||
|
||||
final_message = "The diagnosis is:"
|
||||
final_message_prompt = final_message + "<diagnosis>headache</diagnosis>"
|
||||
final_message_prompt = final_message + " headache"
|
||||
|
||||
doctor_system_prompt = """
|
||||
You are a doctor. You are interacting with a patient.
|
||||
doctor_system_prompt = """You are a doctor. You are interacting with a patient.
|
||||
You need to diagnose the patient based on the symptoms.
|
||||
You will need to ask the patient follow up questions to diagnose them.
|
||||
Once you are confident in your diagnosis, provide it in the format:
|
||||
|
||||
The patient is diagnosed with <diagnosis>{possible_illness}.</diagnosis>
|
||||
|
||||
For example,
|
||||
|
||||
user: I have a headache.
|
||||
assistant: What is the severity of your headache?
|
||||
user: It's a 3/10.
|
||||
assistant: What is the location of your headache?
|
||||
user: It's in the front of my head.
|
||||
assistant: What is the duration of your headache?
|
||||
user: It's been going on for 2 days.
|
||||
assistant: The patient is diagnosed with <diagnosis>headache</diagnosis>
|
||||
The diagnosis is: {possible_illness}
|
||||
"""
|
||||
# ## For example,
|
||||
|
||||
# user: I have a headache.
|
||||
# assistant: What is the severity of your headache?
|
||||
# user: It's a 3/10.
|
||||
# assistant: What is the location of your headache?
|
||||
# user: It's in the front of my head.
|
||||
# assistant: What is the duration of your headache?
|
||||
# user: It's been going on for 2 days.
|
||||
# assistant: The patient is diagnosed with <diagnosis>headache</diagnosis>
|
||||
# """
|
||||
|
||||
|
||||
doctor_model = "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
|
||||
wandb_name = "doctor"
|
||||
|
||||
USER_TAG = "user"
|
||||
ASSISTANT_TAG = "assistant"
|
||||
|
||||
class DoctorEnv(BaseEnv):
|
||||
|
||||
|
|
@@ -93,7 +93,7 @@ class DoctorEnv(BaseEnv):
|
|||
group_size=32,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
wandb_name=wandb_name,
|
||||
wandb_name="doctor",
|
||||
max_num_workers=128,
|
||||
total_steps=100,
|
||||
batch_size=1024,
|
||||
|
|
@@ -182,17 +182,18 @@ class DoctorEnv(BaseEnv):
|
|||
patient_system_prompt = patient_profile.format(symptoms = symptoms)
|
||||
|
||||
patient_messages = [{"role": "system", "content": patient_system_prompt}]
|
||||
print("before xai message")
|
||||
# print("before xai message")
|
||||
completion = client.chat.completions.create(
|
||||
model="grok-3-latest",
|
||||
messages=patient_messages,
|
||||
)
|
||||
|
||||
patient_msg = completion.choices[0].message
|
||||
patient_msg = completion.choices[0].message.content
|
||||
# print("patient message", patient_msg)
|
||||
|
||||
doctor_messages.append({"role": "user", "content": patient_msg})
|
||||
patient_messages.append({"role": "assistant", "content": patient_msg})
|
||||
print("after xai message")
|
||||
doctor_messages.append({"role": USER_TAG, "content": patient_msg})
|
||||
patient_messages.append({"role": ASSISTANT_TAG, "content": patient_msg})
|
||||
# print("after xai message")
|
||||
score = -1
|
||||
while True:
|
||||
if (
|
||||
|
|
@@ -206,24 +207,26 @@ class DoctorEnv(BaseEnv):
|
|||
doctor_messages, add_generation_prompt=True
|
||||
)
|
||||
)
|
||||
print("before doctor response")
|
||||
# print("before doctor response")
|
||||
# print("messages", doctor_messages)
|
||||
doctor_completions = await server.chat_completion(
|
||||
messages=doctor_messages,
|
||||
messages=[{"role" : USER_TAG, "content": "test"}],
|
||||
n=1,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
doctor_msg = doctor_completions.choices[0].message.content
|
||||
# print("doctor message", doctor_msg)
|
||||
|
||||
doctor_messages.append({"role": "assistant", "content": doctor_msg})
|
||||
patient_messages.append({"role": "user", "content": doctor_msg})
|
||||
print("after doctor response")
|
||||
doctor_messages.append({"role": ASSISTANT_TAG, "content": doctor_msg})
|
||||
patient_messages.append({"role": USER_TAG, "content": doctor_msg})
|
||||
# print("after doctor response")
|
||||
# check output
|
||||
if doctor_msg.startwith(final_message):
|
||||
if doctor_msg.startswith(final_message):
|
||||
diagnosis = doctor_msg.strip(final_message)
|
||||
diagnosis = diagnosis.strip()
|
||||
diagnosis = diagnosis.strip().lower()
|
||||
|
||||
if diagnosis.contains(item["answer"]):
|
||||
if diagnosis.contains(item["answer"].lower()):
|
||||
score = 1
|
||||
else:
|
||||
score = 0
|
||||
|
|
@@ -234,10 +237,10 @@ class DoctorEnv(BaseEnv):
|
|||
messages=patient_messages,
|
||||
)
|
||||
|
||||
patient_msg = completion.choices[0].message
|
||||
patient_msg = completion.choices[0].message.content
|
||||
|
||||
doctor_messages.append({"role": "user", "content": patient_msg})
|
||||
patient_messages.append({"role": "assistant", "content": patient_msg})
|
||||
doctor_messages.append({"role": USER_TAG, "content": patient_msg})
|
||||
patient_messages.append({"role": ASSISTANT_TAG, "content": patient_msg})
|
||||
|
||||
self.percent_correct_buffer.append(max(score, 0))
|
||||
tokens = self.tokenizer.apply_chat_template(doctor_messages)
|
||||
|
|
@@ -250,9 +253,9 @@ class DoctorEnv(BaseEnv):
|
|||
curr_tokens = self.tokenizer.apply_chat_template(
|
||||
doctor_messages[: i + 1],
|
||||
add_generation_prompt=doctor_messages[i + 1]["role"]
|
||||
== "assistant",
|
||||
== ASSISTANT_TAG,
|
||||
)
|
||||
if doctor_messages[i]["role"] == "user":
|
||||
if doctor_messages[i]["role"] == USER_TAG:
|
||||
masks.extend([-100] * (len(curr_tokens) - len(masks)))
|
||||
else:
|
||||
masks.extend(curr_tokens[len(masks) :])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue