mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-27 17:23:08 +00:00
fixes
This commit is contained in:
parent
1489b02bbb
commit
ebed149548
1 changed files with 16 additions and 10 deletions
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import random
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, TypedDict
|
||||
|
||||
import os
|
||||
from datasets import load_dataset
|
||||
from openai import OpenAI
|
||||
|
||||
|
|
@ -14,7 +14,7 @@ from atroposlib.envs.base import (
|
|||
)
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from .patient import patient_profiles
|
||||
from environments.hack0.doctor_agent.patient import patient_profiles
|
||||
|
||||
DatasetItem = TypedDict(
|
||||
"DatasetItem",
|
||||
|
|
@ -29,7 +29,7 @@ DatasetItem = TypedDict(
|
|||
},
|
||||
)
|
||||
|
||||
with open("environments/hack0/doctor_agent", "r") as f:
|
||||
with open("environments/hack0/doctor_agent/secrets.json", "r") as f:
|
||||
keys = json.load(f)
|
||||
xai_key = keys["xai"]
|
||||
|
||||
|
|
@ -103,12 +103,17 @@ class DoctorEnv(BaseEnv):
|
|||
data_path_to_save_groups=None,
|
||||
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
|
||||
eval_limit_ratio=0.1,
|
||||
debug_mode=True
|
||||
)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=doctor_model,
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
# model_name=doctor_model,
|
||||
# base_url="http://localhost:9001/v1",
|
||||
# api_key="x",
|
||||
# num_requests_for_eval=256,
|
||||
model_name="grok-3-latest",
|
||||
base_url=None,
|
||||
api_key=os.environ.get("OPENAI_API_KEY"),
|
||||
num_requests_for_eval=256,
|
||||
),
|
||||
]
|
||||
|
|
@ -174,10 +179,10 @@ class DoctorEnv(BaseEnv):
|
|||
|
||||
patient_profile = random.choice(patient_profiles)
|
||||
symptoms = item["question"]
|
||||
patient_system_prompt = patient_profile.format(symptoms)
|
||||
patient_system_prompt = patient_profile.format(symptoms = symptoms)
|
||||
|
||||
patient_messages = [{"role": "system", "content": patient_system_prompt}]
|
||||
|
||||
print("before xai message")
|
||||
completion = client.chat.completions.create(
|
||||
model="grok-3-latest",
|
||||
messages=patient_messages,
|
||||
|
|
@ -187,7 +192,7 @@ class DoctorEnv(BaseEnv):
|
|||
|
||||
doctor_messages.append({"role": "user", "content": patient_msg})
|
||||
patient_messages.append({"role": "assistant", "content": patient_msg})
|
||||
|
||||
print("after xai message")
|
||||
score = -1
|
||||
while True:
|
||||
if (
|
||||
|
|
@ -201,6 +206,7 @@ class DoctorEnv(BaseEnv):
|
|||
doctor_messages, add_generation_prompt=True
|
||||
)
|
||||
)
|
||||
print("before doctor response")
|
||||
doctor_completions = await server.chat_completion(
|
||||
messages=doctor_messages,
|
||||
n=1,
|
||||
|
|
@ -211,7 +217,7 @@ class DoctorEnv(BaseEnv):
|
|||
|
||||
doctor_messages.append({"role": "assistant", "content": doctor_msg})
|
||||
patient_messages.append({"role": "user", "content": doctor_msg})
|
||||
|
||||
print("after doctor response")
|
||||
# check output
|
||||
if doctor_msg.startwith(final_message):
|
||||
diagnosis = doctor_msg.strip(final_message)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue