diff --git a/environments/.DS_Store b/environments/.DS_Store index d0cc1c07..3bb66c0f 100644 Binary files a/environments/.DS_Store and b/environments/.DS_Store differ diff --git a/environments/community/README.md b/environments/community/README.md index 51716168..614fcfec 100644 --- a/environments/community/README.md +++ b/environments/community/README.md @@ -1373,6 +1373,171 @@ environments/community/pytorch_optimizer_coding/ **Requirements**: modal, verdict, torch, lightning, transformers, datasets, pydantic, atroposlib +### 20. Helpful Doctors - Persona-Aware Medical QA Environment (`helpful_doctors/`) +**Author**: [tsadpbb](https://github.com/tsadpbb) with [AlxSp](https://github.com/AlxSp) +**Purpose**: Train LLMs to diagnose patients through multi-turn conversations while adapting to different patient communication styles and personalities + +A sophisticated medical AI environment that simulates realistic doctor-patient interactions by introducing persona-based variability in patient communication styles. The environment tests whether medical reasoning systems can consistently arrive at correct diagnoses regardless of how patients present their symptoms, addressing the real-world challenge of variable patient communication. + +**Features**: +- **Three Patient Personas**: Cooperative (verbose, informative), Reluctant (terse, evasive), and Neutral (brief, factual) +- **Multi-Turn Conversations**: Up to 10 follow-up questions before diagnosis requirement +- **MedQA Dataset Integration**: Uses `GBaker/MedQA-USMLE-4-options` for medical scenarios +- **Dual LLM Architecture**: Grok-3 for patient simulation, configurable model for doctor agent +- **Persona-Filtered Evaluation**: Tests diagnostic robustness across communication styles +- **Real-World Simulation**: Mirrors actual clinical variability in patient presentations + +**Patient Persona Characteristics**: + +**1. Cooperative Patient**: +- Open, verbose, and highly informative responses +- Provides detailed symptom descriptions +- Offers suggestions about potential diagnoses +- Answers with comprehensive information + +**2. Reluctant Patient**: +- Terse, vague, and occasionally evasive responses +- Avoids direct answers to medical questions +- Provides minimal information initially +- Requires skillful questioning to extract details + +**3. Neutral Patient**: +- Brief but factually consistent responses +- Provides accurate information without elaboration +- Straightforward communication style +- Balanced between cooperative and reluctant + +**Training Process**: +1. **Scenario Setup**: MedQA question loaded with correct answer +2. **Persona Assignment**: Random selection of patient communication style +3. **Patient Simulation**: Grok-3 generates initial symptom presentation +4. **Doctor Interaction**: Agent asks follow-up questions (max 10) +5. **Diagnosis Requirement**: Agent must provide final diagnosis in specified format +6. **Evaluation**: Reward based on diagnostic accuracy + +**Diagnostic Format**: +``` +The diagnosis is: {medical_condition} +``` + +**Reward System**: +- **Correct Diagnosis**: +1.0 reward when diagnosis contains the correct answer +- **Incorrect Diagnosis**: 0.0 reward for wrong or missing diagnosis +- **Timeout Penalty**: 0.0 reward if conversation exceeds token limits +- **Accuracy Tracking**: Percentage correct maintained across training batches + +**Example Interaction Flow**: +``` +Patient (Cooperative): "I've been experiencing severe headaches for the past three days, +particularly in the morning. The pain is throbbing and located primarily in my forehead. +I've also noticed some sensitivity to light and mild nausea." + +Doctor: "Can you describe the intensity of the pain on a scale of 1-10?" + +Patient (Cooperative): "I'd say it's about an 8 out of 10 when it's at its worst. +The pain seems to worsen when I move my head quickly or bend over." + +Doctor: "Have you experienced any visual disturbances or aura before the headaches?" + +Patient (Cooperative): "Yes, actually I sometimes see flashing lights about 20 minutes +before the headache starts. There's also a strange zigzag pattern in my vision." + +Doctor: "The diagnosis is: migraine with aura" +``` + +**Technical Implementation**: +- **Async Architecture**: Non-blocking patient simulation and doctor responses +- **Token Management**: Conversation length limits to prevent infinite loops +- **Dual API Integration**: Grok-3 for patients, configurable server for doctors +- **Message Threading**: Proper conversation state management +- **Evaluation Metrics**: WandB integration for training progress tracking + +**Medical Applications**: +- **Clinical Training**: Teaching AI systems to handle diverse patient communication styles +- **Diagnostic Robustness**: Testing medical reasoning under realistic variability +- **Bedside Manner**: Training empathetic and adaptive questioning strategies +- **Triage Systems**: Developing AI for patient intake and initial assessment +- **Medical Education**: Simulating patient encounters for training purposes + +**Research Contributions**: +- **Persona-Aware Benchmarking**: Novel approach to medical QA evaluation +- **Communication Variability**: Addresses gap between perfect narrators and real patients +- **Multi-Turn Reasoning**: Tests sustained diagnostic reasoning over conversations +- **Adaptive Questioning**: Encourages development of better follow-up strategies +- **Real-World Relevance**: Bridges gap between academic benchmarks and clinical practice + +**Dataset Features**: +- **MedQA Foundation**: USMLE-style medical questions with multiple choice answers +- **Synthetic Patient Interactions**: AI-generated persona-based symptom presentations +- **Diagnostic Diversity**: Wide range of medical conditions and scenarios +- **Training/Test Splits**: 128 test examples for efficient evaluation +- **Shuffled Dataset**: Randomized order for robust training + +**Configuration Options**: +- **Model Selection**: Configurable doctor model (default: NousResearch/DeepHermes-3-Llama-3-8B-Preview) +- **Conversation Limits**: Maximum 10 follow-up questions before diagnosis requirement +- **Token Management**: 15K token limit for conversation length +- **Batch Processing**: Group size 32 with 1024 batch size for training +- **Evaluation Frequency**: Steps per evaluation and limit ratios + +**WandB Integration**: +- **Accuracy Tracking**: `train/percent_correct` for diagnostic success rate +- **Conversation Logging**: Complete doctor-patient interaction histories +- **Performance Metrics**: Training progress and evaluation results +- **Persona Analysis**: Success rates across different patient communication styles + +**Future Enhancements**: + +**Extended Persona Development**: +- **Cultural Variations**: Different cultural approaches to medical communication +- **Age-Specific Patterns**: Pediatric, adult, and geriatric communication styles +- **Emotional States**: Anxious, confused, or distressed patient personas +- **Language Barriers**: Non-native speaker communication patterns + +**Advanced Medical Scenarios**: +- **Emergency Triage**: Time-critical diagnostic scenarios +- **Chronic Conditions**: Long-term patient management conversations +- **Mental Health**: Psychiatric evaluation and counseling scenarios +- **Specialist Consultations**: Domain-specific medical interactions + +**Evaluation Improvements**: +- **Diagnostic Confidence**: Uncertainty quantification in medical decisions +- **Question Quality**: Assessment of follow-up question effectiveness +- **Empathy Scoring**: Evaluation of bedside manner and patient rapport +- **Efficiency Metrics**: Diagnostic accuracy per question asked + +**Setup Requirements**: +1. **XAI API Key**: For Grok-3 patient simulation (`XAI_API_KEY` environment variable) +2. **Medical Dataset**: Automatic download of `GBaker/MedQA-USMLE-4-options` +3. **LLM Server**: Local or remote server for doctor agent inference +4. **WandB Account**: For training monitoring and experiment tracking + +**Command Line Usage**: +```bash +# Set up API key +export XAI_API_KEY="your_xai_api_key" + +# Start local LLM server for doctor agent +python -m vllm.entrypoints.openai.api_server \ + --model NousResearch/DeepHermes-3-Llama-3-8B-Preview \ + --port 9001 + +# Run helpful doctors environment +python environments/community/helpful_doctors/doctor.py process \ + --env.use_wandb true \ + --env.wandb_name helpful_doctors_training +``` + +**Demo Resources**: +- **YouTube Demo**: [Persona-Aware MedQA Benchmarking](https://youtube.com/shorts/02GEURik0PQ) +- **WandB Dashboard**: [Training Results](https://wandb.ai/nous-hackathon-2/atropos-environments_hack0_doctor_agent?nw=nwusertsadpbb) + +**Research Impact**: This environment addresses a critical gap in medical AI evaluation by introducing realistic patient communication variability. Unlike traditional QA benchmarks that assume perfect narrators, this approach tests diagnostic robustness under human-like communication patterns, leading to more reliable and empathetic medical AI systems. + +**Clinical Relevance**: Real doctors must interpret patient symptoms that are often incomplete, emotionally colored, or presented in various communication styles. This environment trains AI systems for these real-world challenges, potentially improving safety and effectiveness in clinical deployment. + +**Requirements**: datasets, openai, atroposlib, wandb + --- ## Support diff --git a/environments/hack0/doctor_agent/doctor.py b/environments/community/helpful_doctors/doctor.py similarity index 91% rename from environments/hack0/doctor_agent/doctor.py rename to environments/community/helpful_doctors/doctor.py index de99abde..30403461 100644 --- a/environments/hack0/doctor_agent/doctor.py +++ b/environments/community/helpful_doctors/doctor.py @@ -1,29 +1,21 @@ -import json +import os import random from typing import Dict, List, Optional, Sequence, Tuple, TypedDict -import os + from datasets import load_dataset from openai import OpenAI +from patient import patient_profiles from atroposlib.envs.base import ( APIServerConfig, BaseEnv, BaseEnvConfig, EvalHandlingEnum, - ScoredDataItem, ScoredDataGroup, + ScoredDataItem, ) -# from atroposlib.envs.base import ( -# BaseEnv, -# BaseEnvConfig, -# EvalHandlingEnum, -# Item, -# APIServerConfig, -# ) from atroposlib.type_definitions import Item -from environments.hack0.doctor_agent.patient import patient_profiles - DatasetItem = TypedDict( "DatasetItem", { @@ -37,18 +29,12 @@ DatasetItem = TypedDict( }, ) -with open("environments/hack0/doctor_agent/secrets.json", "r") as f: - keys = json.load(f) - xai_key = keys["xai"] - - client = OpenAI( - api_key=xai_key, + api_key=os.environ.get("XAI_API_KEY"), base_url="https://api.x.ai/v1", ) final_message = "The diagnosis is:" -final_message_prompt = final_message + " headache" doctor_system_prompt = """You are a doctor. You are interacting with a patient. You need to diagnose the patient based on the symptoms. @@ -76,6 +62,7 @@ doctor_model = "NousResearch/DeepHermes-3-Llama-3-8B-Preview" USER_TAG = "user" ASSISTANT_TAG = "assistant" + class DoctorEnv(BaseEnv): name = "doctor" @@ -112,17 +99,13 @@ class DoctorEnv(BaseEnv): data_path_to_save_groups=None, eval_handling=EvalHandlingEnum.LIMIT_TRAIN, eval_limit_ratio=0.1, - debug_mode=True + debug_mode=True, ) server_configs = [ APIServerConfig( - # model_name=doctor_model, - # base_url="http://localhost:9001/v1", - # api_key="x", - # num_requests_for_eval=256, - model_name="grok-3-latest", - base_url=None, - api_key=os.environ.get("OPENAI_API_KEY"), + model_name=doctor_model, + base_url="http://localhost:9001/v1", + api_key="EMPTY", num_requests_for_eval=256, ), ] @@ -191,7 +174,7 @@ class DoctorEnv(BaseEnv): patient_profile = random.choice(patient_profiles) symptoms = item["question"] - patient_system_prompt = patient_profile.format(symptoms = symptoms) + patient_system_prompt = patient_profile.format(symptoms=symptoms) patient_messages = [{"role": "system", "content": patient_system_prompt}] # print("before xai message") @@ -242,7 +225,7 @@ class DoctorEnv(BaseEnv): diagnosis = doctor_msg.strip(final_message) diagnosis = diagnosis.strip() - if diagnosis.contains(item["answer"]): + if item["answer"] in diagnosis: score = 1 else: score = 0 @@ -285,7 +268,7 @@ class DoctorEnv(BaseEnv): masks=masks, scores=score, ) - + for score in scores["scores"]: self.percent_correct_buffer.append(max(score, 0)) diff --git a/environments/hack0/doctor_agent/patient.py b/environments/community/helpful_doctors/patient.py similarity index 100% rename from environments/hack0/doctor_agent/patient.py rename to environments/community/helpful_doctors/patient.py diff --git a/environments/hack0/doctor_agent/readme.md b/environments/community/helpful_doctors/readme.md similarity index 100% rename from environments/hack0/doctor_agent/readme.md rename to environments/community/helpful_doctors/readme.md diff --git a/environments/hack0/.DS_Store b/environments/hack0/.DS_Store new file mode 100644 index 00000000..6cfe8d0f Binary files /dev/null and b/environments/hack0/.DS_Store differ diff --git a/environments/hack0/doctor_agent/doctor_22.jsonl.zip b/environments/hack0/doctor_agent/doctor_22.jsonl.zip deleted file mode 100644 index 20b1d440..00000000 Binary files a/environments/hack0/doctor_agent/doctor_22.jsonl.zip and /dev/null differ