diff --git a/.gitignore b/.gitignore
index 4320e139..bae0815d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+# keys (ignore at any depth; '*/secrets.json' only matched one level deep)
+**/secrets.json
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/environments/hack0/doctor_agent/doctor.py b/environments/hack0/doctor_agent/doctor.py
new file mode 100644
index 00000000..de99abde
--- /dev/null
+++ b/environments/hack0/doctor_agent/doctor.py
@@ -0,0 +1,277 @@
+import json
+import os
+import random
+from typing import Dict, List, Optional, Sequence, Tuple, TypedDict
+
+from datasets import load_dataset
+from openai import OpenAI
+
+from atroposlib.envs.base import (
+    APIServerConfig,
+    BaseEnv,
+    BaseEnvConfig,
+    EvalHandlingEnum,
+    ScoredDataItem,
+    ScoredDataGroup,
+)
+from atroposlib.type_definitions import Item
+
+from environments.hack0.doctor_agent.patient import patient_profiles
+
+# Schema of one MedQA-USMLE example as loaded from the dataset.
+DatasetItem = TypedDict(
+    "DatasetItem",
+    {
+        "question": str,
+        "answer": str,
+        "options": Dict[str, str],
+        "meta_info": str,
+        "answer_idx": str,
+        "diagnosis": str,
+        "metamap_sequence": Sequence[str],
+    },
+)
+
+# The xAI key lives in an untracked secrets file (see .gitignore).
+with open("environments/hack0/doctor_agent/secrets.json", "r") as f:
+    keys = json.load(f)
+    xai_key = keys["xai"]
+
+# Client used to simulate the patient; the doctor policy is served by atropos.
+client = OpenAI(
+    api_key=xai_key,
+    base_url="https://api.x.ai/v1",
+)
+
+# Sentinel prefix the doctor must emit to end the dialogue with a diagnosis.
+final_message = "The diagnosis is:"
+
+doctor_system_prompt = """You are a doctor. You are interacting with a patient.
+You need to diagnose the patient based on the symptoms.
+You will need to ask the patient follow up questions to diagnose them.
+Ask up to 10 follow up questions. After that make your diagnosis.
+Once you are confident in your diagnosis, provide it in the format:
+
+The diagnosis is: {possible_illness}
+"""
+
+doctor_model = "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
+
+USER_TAG = "user"
+ASSISTANT_TAG = "assistant"
+
+
+class DoctorEnv(BaseEnv):
+    """Multi-turn RL environment where the policy plays a doctor who must
+    diagnose a simulated patient (an xAI model given a random persona)."""
+
+    name = "doctor"
+
+    def __init__(
+        self,
+        config: BaseEnvConfig,
+        server_configs: List[APIServerConfig],
+        slurm=False,
+        testing=False,
+    ):
+        super().__init__(config, server_configs, slurm, testing)
+        self.percent_correct_buffer = list()
+        self.eval_metrics = list()
+        # Add tracking for wandb visualizations
+        self.rollouts_for_wandb = []
+        self.completion_lengths = []
+        self.print_this_env = False
+
+    @classmethod
+    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
+        """Return the default environment and inference-server configs."""
+        env_config = BaseEnvConfig(
+            tokenizer_name=doctor_model,
+            group_size=32,
+            use_wandb=True,
+            rollout_server_url="http://localhost:8000",
+            wandb_name="doctor",
+            max_num_workers=128,
+            total_steps=100,
+            batch_size=1024,
+            steps_per_eval=1,
+            max_token_length=1024 * 15,
+            inference_weight=1.0,
+            data_path_to_save_groups=None,
+            eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
+            eval_limit_ratio=0.1,
+            debug_mode=True,
+        )
+        server_configs = [
+            APIServerConfig(
+                model_name="grok-3-latest",
+                base_url=None,
+                api_key=os.environ.get("OPENAI_API_KEY"),
+                num_requests_for_eval=256,
+            ),
+        ]
+
+        return env_config, server_configs
+
+    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
+        """Log rolling accuracy and any queued eval metrics to wandb."""
+        if wandb_metrics is None:
+            wandb_metrics = {}
+
+        # Try to calculate percent_correct; skip if the buffer is empty.
+        try:
+            wandb_metrics["train/percent_correct"] = sum(
+                self.percent_correct_buffer
+            ) / len(self.percent_correct_buffer)
+        except ZeroDivisionError:
+            pass
+
+        self.percent_correct_buffer = list()
+        for item in self.eval_metrics:
+            wandb_metrics[item[0]] = item[1]
+        self.eval_metrics = list()
+        # Call the parent method to handle the server metrics
+        await super().wandb_log(wandb_metrics)
+
+    async def setup(self):
+        """
+        Set up the environment by loading and preparing the dataset.
+        """
+        # Load the full dataset
+        full_dataset = load_dataset("GBaker/MedQA-USMLE-4-options")
+        full_dataset = full_dataset.shuffle(seed=42)
+
+        # Keep the splits as is - no need to reformat
+        self.train = full_dataset["train"]
+        # Limit test set size to prevent evaluation from taking too long
+        self.test = full_dataset["test"].select(
+            range(min(128, len(full_dataset["test"])))
+        )
+
+        # Print some dataset statistics
+        print(
+            f"Loaded dataset with {len(self.train)} training examples and {len(self.test)} test examples"
+        )
+        print(f"Example item format: {self.train[0]}")
+
+        # Initialize iteration counter
+        self.iter = 0
+
+    async def evaluate(self, *args, **kwargs):
+        pass
+
+    async def collect_trajectory(
+        self, item: Item
+    ) -> Tuple[Optional[ScoredDataItem], List[Item]]:
+        """Run one doctor/patient dialogue and score the final diagnosis.
+
+        Returns a ScoredDataItem (score 1 for a correct diagnosis, 0 for an
+        incorrect one or a context overflow) and an empty backlog.
+        """
+        # Grab a dedicated llm server to take advantage of caching
+        async with self.server.dedicated_server() as server:
+            scores = ScoredDataGroup()
+            scores["scores"] = list()
+
+            doctor_messages = [{"role": "system", "content": doctor_system_prompt}]
+
+            patient_profile = random.choice(patient_profiles)
+            symptoms = item["question"]
+            patient_system_prompt = patient_profile.format(symptoms=symptoms)
+            patient_messages = [{"role": "system", "content": patient_system_prompt}]
+
+            completion = client.chat.completions.create(
+                model="grok-3-latest",
+                messages=patient_messages,
+            )
+            patient_msg = completion.choices[0].message.content
+
+            doctor_messages.append({"role": USER_TAG, "content": patient_msg})
+            patient_messages.append({"role": ASSISTANT_TAG, "content": patient_msg})
+
+            score = -1
+            while True:
+                # Stop when the dialogue no longer fits the context window.
+                if (
+                    len(self.tokenizer.apply_chat_template(doctor_messages))
+                    > self.config.max_token_length - 10
+                ):
+                    score = 0
+                    break
+                max_tokens = self.config.max_token_length - len(
+                    self.tokenizer.apply_chat_template(
+                        doctor_messages, add_generation_prompt=True
+                    )
+                )
+                doctor_completions = await server.chat_completion(
+                    messages=doctor_messages,
+                    n=1,
+                    max_tokens=max_tokens,
+                )
+                doctor_msg = doctor_completions.choices[0].message.content
+
+                doctor_messages.append({"role": ASSISTANT_TAG, "content": doctor_msg})
+                patient_messages.append({"role": USER_TAG, "content": doctor_msg})
+
+                if doctor_msg.startswith(final_message):
+                    # BUG FIX: str.strip(x) strips a character *set*, not a
+                    # prefix, and str has no .contains() method. Slice off the
+                    # sentinel and use a case-insensitive containment check.
+                    diagnosis = doctor_msg[len(final_message):].strip()
+                    score = 1 if item["answer"].lower() in diagnosis.lower() else 0
+                    break
+
+                completion = client.chat.completions.create(
+                    model="grok-3-latest",
+                    messages=patient_messages,
+                )
+                patient_msg = completion.choices[0].message.content
+
+                doctor_messages.append({"role": USER_TAG, "content": patient_msg})
+                patient_messages.append({"role": ASSISTANT_TAG, "content": patient_msg})
+
+            self.percent_correct_buffer.append(max(score, 0))
+            tokens = self.tokenizer.apply_chat_template(doctor_messages)
+
+            # Mask non-assistant tokens with -100 so only doctor turns train.
+            masks = []
+            for i, msg in enumerate(doctor_messages):
+                if i == len(doctor_messages) - 1:
+                    masks.extend(tokens[len(masks):])
+                else:
+                    curr_tokens = self.tokenizer.apply_chat_template(
+                        doctor_messages[: i + 1],
+                        add_generation_prompt=doctor_messages[i + 1]["role"]
+                        == ASSISTANT_TAG,
+                    )
+                    if doctor_messages[i]["role"] == USER_TAG:
+                        masks.extend([-100] * (len(curr_tokens) - len(masks)))
+                    else:
+                        masks.extend(curr_tokens[len(masks):])
+
+            # BUG FIX: `1.0 if score else -1.0` mapped the overflow sentinel
+            # -1 (truthy) to +1.0; compare against 1 explicitly.
+            scores["scores"].append(1.0 if score == 1 else -1.0)
+
+            # NOTE(review): ScoredDataItem.scores receives the raw 0/1/-1
+            # score here — confirm against atroposlib's expected type.
+            scored_data_item = ScoredDataItem(
+                messages=doctor_messages,
+                finish_reason=score,
+                tokens=tokens,
+                masks=masks,
+                scores=score,
+            )
+
+            return scored_data_item, []
+
+    async def get_next_item(self):
+        """
+        Get the next training item from the dataset.
+
+        Returns:
+            A tuple containing prompt and expected answer
+        """
+        next_item: DatasetItem = self.train[self.iter % len(self.train)]
+        self.iter += 1
+
+        return next_item
+
+
+if __name__ == "__main__":
+    DoctorEnv.cli()
diff --git a/environments/hack0/doctor_agent/doctor_22.jsonl.zip b/environments/hack0/doctor_agent/doctor_22.jsonl.zip
new file mode 100644
index 00000000..20b1d440
Binary files /dev/null and b/environments/hack0/doctor_agent/doctor_22.jsonl.zip differ
diff --git a/environments/hack0/doctor_agent/patient.py b/environments/hack0/doctor_agent/patient.py
new file mode 100644
index 00000000..beb044d1
--- /dev/null
+++ b/environments/hack0/doctor_agent/patient.py
@@ -0,0 +1,61 @@
+patient_profiles = [
+    """
+    You are an uneasy patient interacting with a doctor.
+
+    Here are your symptoms:
+    {symptoms}.
+
+    Do not give the symptoms directly to the doctor in a single answer.
+
+    You are trying to get a diagnosis for your symptoms.
+
+    The doctor will ask you follow up questions to diagnose you.
+
+    You will need to answer the doctor's questions to get a diagnosis.
+
+    Since you are uneasy, you will not answer the doctor's questions directly.
+
+    You will answer the doctor's questions in a way that is not too direct,
+    but still gives the doctor enough information to diagnose you.
+
+    You will also not answer the doctor's questions with a yes or no.
+
+    You will answer the doctor's questions with a short answer.
+    """,
+    """
+    You are a brief but factually consistent patient interacting with a doctor.
+
+    Here are your symptoms:
+    {symptoms}.
+
+    You are trying to get a diagnosis for your symptoms.
+
+    The doctor will ask you follow up questions to diagnose you.
+
+    You will answer the doctor's questions in a way that is not too direct,
+    but still gives the doctor enough information to diagnose you.
+
+    You will also not answer the doctor's questions with a yes or no.
+
+    You will answer the doctor's questions with a short answer.
+    """,
+    """
+    You are an open, verbose, and highly informative patient interacting with a doctor.
+
+    Here are your symptoms:
+    {symptoms}.
+
+    You are trying to get a diagnosis for your symptoms.
+
+    The doctor will ask you follow up questions to diagnose you.
+
+    You will provide the doctor with some suggestions as to what you think the diagnosis is.
+
+    You will answer the doctor's questions in a way that is not too direct,
+    but still gives the doctor enough information to diagnose you.
+
+    You will also not answer the doctor's questions with a yes or no.
+
+    You will answer the doctor's questions with a long answer.
+    """,
+]
diff --git a/environments/hack0/doctor_agent/readme.md b/environments/hack0/doctor_agent/readme.md
new file mode 100644
index 00000000..14a37b69
--- /dev/null
+++ b/environments/hack0/doctor_agent/readme.md
@@ -0,0 +1,32 @@
+Persona-Aware MedQA Benchmarking
+https://youtube.com/shorts/02GEURik0PQ
+
+Wandb: https://wandb.ai/nous-hackathon-2/atropos-environments_hack0_doctor_agent?nw=nwusertsadpbb
+We intended to add a simple percent-accuracy score but couldn't get it done in time :(
+
+In this project, we reimagined medical QA evaluation by introducing a persona filter—a novel layer that simulates real-world variability in patient communication styles. Leveraging the MedQA dataset as our foundation, we infused each scenario with distinct personas generated via xAI's language models:
+
+1. The Cooperative Patient – open, verbose, and highly informative.
+2. The Reluctant Patient – terse, vague, and occasionally evasive.
+3. The Neutral Patient – brief but factually consistent.
+
+The clinical challenge we explored is simple but critical: Can a medical reasoning system consistently arrive at the correct diagnosis or treatment recommendation regardless of how the patient presents information?
+
+Our pipeline works as follows:
+
+Each original MedQA item (stem + multiple choice answers) is enriched with a synthetic patient interaction that simulates one of the three personas.
+We maintain the original clinical question and choices.
+Only the narrative context—the patient's communication—changes, testing robustness against dialogue variability.
+This mirrors how real doctors must interpret patient symptoms, which are often incomplete or colored by personality, emotion, or context.
+
+Why this matters:
+Most QA benchmarks assume a perfect narrator. But in the real world, AI systems in healthcare will need to make decisions with varying degrees of input clarity.
+
+Our approach stress-tests reasoning models under more human-like variability, offering a path toward safer and more empathetic medical AI.
+
+Future Potential:
+
+Extendable to reinforcement learning pipelines where the agent adapts its questioning strategy based on persona.
+Can be used to benchmark bedside AI assistants, triage bots, or LLMs deployed in low-resource clinics.
+Encourages development of models that ask better follow-up questions, not just give answers.
+By combining structured medical QA with naturalistic persona variation, our project brings a crucial human dimension to the next generation of AI-health interfaces.