This commit is contained in:
FIRST_NAME LAST_NAME 2025-05-19 00:03:25 +00:00
parent 1489b02bbb
commit ebed149548

View file

@ -1,7 +1,7 @@
import json
import random
from typing import Dict, List, Optional, Sequence, Tuple, TypedDict
import os
from datasets import load_dataset
from openai import OpenAI
@ -14,7 +14,7 @@ from atroposlib.envs.base import (
)
from atroposlib.type_definitions import Item
from .patient import patient_profiles
from environments.hack0.doctor_agent.patient import patient_profiles
DatasetItem = TypedDict(
"DatasetItem",
@ -29,7 +29,7 @@ DatasetItem = TypedDict(
},
)
with open("environments/hack0/doctor_agent", "r") as f:
with open("environments/hack0/doctor_agent/secrets.json", "r") as f:
keys = json.load(f)
xai_key = keys["xai"]
@ -103,12 +103,17 @@ class DoctorEnv(BaseEnv):
data_path_to_save_groups=None,
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
debug_mode=True
)
server_configs = [
APIServerConfig(
model_name=doctor_model,
base_url="http://localhost:9001/v1",
api_key="x",
# model_name=doctor_model,
# base_url="http://localhost:9001/v1",
# api_key="x",
# num_requests_for_eval=256,
model_name="grok-3-latest",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=256,
),
]
@ -174,10 +179,10 @@ class DoctorEnv(BaseEnv):
patient_profile = random.choice(patient_profiles)
symptoms = item["question"]
patient_system_prompt = patient_profile.format(symptoms)
patient_system_prompt = patient_profile.format(symptoms = symptoms)
patient_messages = [{"role": "system", "content": patient_system_prompt}]
print("before xai message")
completion = client.chat.completions.create(
model="grok-3-latest",
messages=patient_messages,
@ -187,7 +192,7 @@ class DoctorEnv(BaseEnv):
doctor_messages.append({"role": "user", "content": patient_msg})
patient_messages.append({"role": "assistant", "content": patient_msg})
print("after xai message")
score = -1
while True:
if (
@ -201,6 +206,7 @@ class DoctorEnv(BaseEnv):
doctor_messages, add_generation_prompt=True
)
)
print("before doctor response")
doctor_completions = await server.chat_completion(
messages=doctor_messages,
n=1,
@ -211,7 +217,7 @@ class DoctorEnv(BaseEnv):
doctor_messages.append({"role": "assistant", "content": doctor_msg})
patient_messages.append({"role": "user", "content": doctor_msg})
print("after doctor response")
# check output
if doctor_msg.startwith(final_message):
diagnosis = doctor_msg.strip(final_message)