diff --git a/environments/hack0/doctor_agent/doctor.py b/environments/hack0/doctor_agent/doctor.py index 6cf213f7..6f7b3d06 100644 --- a/environments/hack0/doctor_agent/doctor.py +++ b/environments/hack0/doctor_agent/doctor.py @@ -1,7 +1,7 @@ import json import random from typing import Dict, List, Optional, Sequence, Tuple, TypedDict - +import os from datasets import load_dataset from openai import OpenAI @@ -14,7 +14,7 @@ from atroposlib.envs.base import ( ) from atroposlib.type_definitions import Item -from .patient import patient_profiles +from environments.hack0.doctor_agent.patient import patient_profiles DatasetItem = TypedDict( "DatasetItem", @@ -29,7 +29,7 @@ DatasetItem = TypedDict( }, ) -with open("environments/hack0/doctor_agent", "r") as f: +with open("environments/hack0/doctor_agent/secrets.json", "r") as f: keys = json.load(f) xai_key = keys["xai"] @@ -103,12 +103,17 @@ class DoctorEnv(BaseEnv): data_path_to_save_groups=None, eval_handling=EvalHandlingEnum.LIMIT_TRAIN, eval_limit_ratio=0.1, + debug_mode=True ) server_configs = [ APIServerConfig( - model_name=doctor_model, - base_url="http://localhost:9001/v1", - api_key="x", + # model_name=doctor_model, + # base_url="http://localhost:9001/v1", + # api_key="x", + # num_requests_for_eval=256, + model_name="grok-3-latest", + base_url=None, + api_key=os.environ.get("OPENAI_API_KEY"), num_requests_for_eval=256, ), ] @@ -174,10 +179,10 @@ class DoctorEnv(BaseEnv): patient_profile = random.choice(patient_profiles) symptoms = item["question"] - patient_system_prompt = patient_profile.format(symptoms) + patient_system_prompt = patient_profile.format(symptoms = symptoms) patient_messages = [{"role": "system", "content": patient_system_prompt}] - + print("before xai message") completion = client.chat.completions.create( model="grok-3-latest", messages=patient_messages, @@ -187,7 +192,7 @@ class DoctorEnv(BaseEnv): doctor_messages.append({"role": "user", "content": patient_msg}) patient_messages.append({"role": "assistant", "content": patient_msg}) - + print("after xai message") score = -1 while True: if ( @@ -201,6 +206,7 @@ class DoctorEnv(BaseEnv): doctor_messages, add_generation_prompt=True ) ) + print("before doctor response") doctor_completions = await server.chat_completion( messages=doctor_messages, n=1, @@ -211,7 +217,7 @@ class DoctorEnv(BaseEnv): doctor_messages.append({"role": "assistant", "content": doctor_msg}) patient_messages.append({"role": "user", "content": doctor_msg}) - + print("after doctor response") # check output if doctor_msg.startwith(final_message): diagnosis = doctor_msg.strip(final_message)