diff --git a/environments/hack0/doctor_agent/datasets.py b/environments/hack0/doctor_agent/datasets.py
deleted file mode 100644
index f8b1548c..00000000
--- a/environments/hack0/doctor_agent/datasets.py
+++ /dev/null
@@ -1 +0,0 @@
-dataset = []
\ No newline at end of file
diff --git a/environments/hack0/doctor_agent/doctor_gym.py b/environments/hack0/doctor_agent/doctor.py
similarity index 74%
rename from environments/hack0/doctor_agent/doctor_gym.py
rename to environments/hack0/doctor_agent/doctor.py
index 48fee2a3..6cf213f7 100644
--- a/environments/hack0/doctor_agent/doctor_gym.py
+++ b/environments/hack0/doctor_agent/doctor.py
@@ -1,20 +1,34 @@
 import json
 import random
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Sequence, Tuple, TypedDict
 
-from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem
-from atroposlib.type_definitions import Item
-from atropos.environments.hack0.doctor_agent.patient import patient_profiles
-from atropos.environments.hack0.doctor_agent.datasets import dataset
-import re
-from typing import Dict, List, Optional, Tuple
-import os
+from datasets import load_dataset
 from openai import OpenAI
-import gymnasium as gym
 
-from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem
+from atroposlib.envs.base import (
+    APIServerConfig,
+    BaseEnv,
+    BaseEnvConfig,
+    EvalHandlingEnum,
+    ScoredDataItem,
+)
 from atroposlib.type_definitions import Item
 
+from .patient import patient_profiles
+
+DatasetItem = TypedDict(
+    "DatasetItem",
+    {
+        "question": str,
+        "answer": str,
+        "options": Dict[str, str],
+        "meta_info": str,
+        "answer_idx": str,
+        "diagnosis": str,
+        "metamap_sequence": Sequence[str],
+    },
+)
+
 with open("environments/hack0/doctor_agent", "r") as f:
     keys = json.load(f)
 xai_key = keys["xai"]
@@ -50,7 +64,8 @@ assistant: The patient is diagnosed with headache
 
 
 doctor_model = "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
-gym_name = "gym_doctor"
+wandb_name = "doctor"
+
 
 class DoctorEnv(BaseEnv):
 
@@ -60,7 +75,7 @@ class DoctorEnv(BaseEnv):
         self,
         config: BaseEnvConfig,
         server_configs: List[APIServerConfig],
-        slurm=True,
+        slurm=False,
         testing=False,
     ):
         super().__init__(config, server_configs, slurm, testing)
@@ -78,8 +93,16 @@ class DoctorEnv(BaseEnv):
             group_size=32,
             use_wandb=True,
             rollout_server_url="http://localhost:8000",
-            max_token_length=8192,
-            wandb_name=gym_name,
+            wandb_name=wandb_name,
+            max_num_workers=128,
+            total_steps=100,
+            batch_size=1024,
+            steps_per_eval=1,
+            max_token_length=1024 * 15,
+            inference_weight=1.0,
+            data_path_to_save_groups=None,
+            eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
+            eval_limit_ratio=0.1,
         )
         server_configs = [
             APIServerConfig(
@@ -113,15 +136,33 @@ class DoctorEnv(BaseEnv):
         await super().wandb_log(wandb_metrics)
 
     async def setup(self):
+        """
+        Set up the environment by loading and preparing the dataset.
+        """
+        # Load the full dataset
+        full_dataset = load_dataset("GBaker/MedQA-USMLE-4-options")
+
+        full_dataset = full_dataset.shuffle(seed=42)
+
+        # Keep the splits as is - no need to reformat
+        self.train = full_dataset["train"]
+        # Limit test set size to prevent evaluation from taking too long
+        self.test = full_dataset["test"].select(
+            range(min(128, len(full_dataset["test"])))
+        )
+
+        # Print some dataset statistics
+        print(
+            f"Loaded dataset with {len(self.train)} training examples and {len(self.test)} test examples"
+        )
+        print(f"Example item format: {self.train[0]}")
+
+        # Initialize iteration counter
         self.iter = 0
 
     async def evaluate(self, *args, **kwargs):
         pass
 
-    async def get_patient_msg(self, env: gym.Env) -> str:
-        # Call xAI to get a patient message
-        return env.render()
-
     async def collect_trajectory(
         self, item: Item
     ) -> Tuple[Optional[ScoredDataItem], List[Item]]:
@@ -129,15 +170,10 @@ class DoctorEnv(BaseEnv):
         async with self.server.dedicated_server() as server:
 
             patient_messages = []
-            doctor_messages = [
-                {
-                    "role" : "system",
-                    "content" : doctor_system_prompt
-                }
-            ]
+            doctor_messages = [{"role": "system", "content": doctor_system_prompt}]
 
             patient_profile = random.choice(patient_profiles)
-            symptoms = dataset[0]
+            symptoms = item["question"]
             patient_system_prompt = patient_profile.format(symptoms)
             patient_messages = [{"role": "system", "content": patient_system_prompt}]
 
@@ -181,13 +217,12 @@ class DoctorEnv(BaseEnv):
                     diagnosis = doctor_msg.strip(final_message)
                     diagnosis = diagnosis.strip()
 
-                    if diagnosis == item["diagnosis"]:
+                    if item["answer"] in diagnosis:  # substring test; str has no .contains()
                         score = 1
                     else:
                         score = 0
                     break
 
-
                 completion = client.chat.completions.create(
                     model="grok-3-latest",
                     messages=patient_messages,
@@ -198,11 +233,9 @@ class DoctorEnv(BaseEnv):
 
                 doctor_messages.append({"role": "user", "content": patient_msg})
                 patient_messages.append({"role": "assistant", "content": patient_msg})
-
         self.percent_correct_buffer.append(max(score, 0))
         tokens = self.tokenizer.apply_chat_template(doctor_messages)
-
-
+
         masks = []
         for i, msg in enumerate(doctor_messages):
             if i == len(doctor_messages) - 1:
@@ -210,13 +243,14 @@ class DoctorEnv(BaseEnv):
             else:
                 curr_tokens = self.tokenizer.apply_chat_template(
                     doctor_messages[: i + 1],
-                    add_generation_prompt=doctor_messages[i + 1]["role"] == "assistant",
+                    add_generation_prompt=doctor_messages[i + 1]["role"]
+                    == "assistant",
                 )
                 if doctor_messages[i]["role"] == "user":
                     masks.extend([-100] * (len(curr_tokens) - len(masks)))
                 else:
                     masks.extend(curr_tokens[len(masks) :])
-
+
         scored_data_item = ScoredDataItem(
             messages=doctor_messages,
             finish_reason=score,
@@ -227,10 +261,17 @@ class DoctorEnv(BaseEnv):
         return scored_data_item, []
 
     async def get_next_item(self):
-        next_item = {"seed": self.iter}
+        """
+        Get the next training item from the dataset.
+
+        Returns:
+            The next record of the training split (a DatasetItem), wrapping around at the end
+        """
+        next_item: DatasetItem = self.train[self.iter % len(self.train)]
         self.iter += 1
+
         return next_item
 
 
-# if __name__ == "__main__":
-#     GymTaxiEnv.cli()
+if __name__ == "__main__":
+    DoctorEnv.cli()
diff --git a/environments/hack0/doctor_agent/patient.py b/environments/hack0/doctor_agent/patient.py
index be93352d..e5fecdfb 100644
--- a/environments/hack0/doctor_agent/patient.py
+++ b/environments/hack0/doctor_agent/patient.py
@@ -1,14 +1 @@
-
-
-patient_profiles = [
-
-]
-
-
-from openai import OpenAI
-
-client = OpenAI(
-    api_key=XAI_API_KEY,
-    base_url="https://api.x.ai/v1",
-)
-
+patient_profiles = []