Add dataset (#3)

This commit is contained in:
tsadpbb 2025-05-18 16:01:54 -07:00 committed by GitHub
parent 02ff663ebe
commit 189d7a3dd1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 77 additions and 50 deletions

View file

@ -1 +0,0 @@
dataset = []

View file

@ -1,20 +1,34 @@
import json
import random
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Sequence, Tuple, TypedDict
from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem
from atroposlib.type_definitions import Item
from atropos.environments.hack0.doctor_agent.patient import patient_profiles
from atropos.environments.hack0.doctor_agent.datasets import dataset
import re
from typing import Dict, List, Optional, Tuple
import os
from datasets import load_dataset
from openai import OpenAI
import gymnasium as gym
from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
ScoredDataItem,
)
from atroposlib.type_definitions import Item
from .patient import patient_profiles
class DatasetItem(TypedDict):
    """Typed view of a single dataset row handled by this environment.

    NOTE(review): nothing here validates rows at runtime; confirm that these
    keys match the columns of the dataset that is actually loaded.
    """

    question: str
    answer: str
    options: Dict[str, str]
    meta_info: str
    answer_idx: str
    diagnosis: str
    metamap_sequence: Sequence[str]
# Read a JSON mapping from disk and keep its "xai" entry (presumably the xAI
# API key — confirm against where xai_key is consumed).
# NOTE(review): the path is relative to the current working directory and has
# no file extension — confirm it resolves to a JSON file rather than the
# package directory.
with open("environments/hack0/doctor_agent", "r") as f:
keys = json.load(f)
xai_key = keys["xai"]
@ -50,7 +64,8 @@ assistant: The patient is diagnosed with <diagnosis>headache</diagnosis>
# Model identifier for the doctor side of the conversation (presumably a
# Hugging Face repo id — confirm against the server config that consumes it).
doctor_model = "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
# Name constants referenced when the environment config is built; wandb_name
# is supplied there as the config's `wandb_name` value.
gym_name = "gym_doctor"
wandb_name = "doctor"
class DoctorEnv(BaseEnv):
@ -60,7 +75,7 @@ class DoctorEnv(BaseEnv):
self,
config: BaseEnvConfig,
server_configs: List[APIServerConfig],
slurm=True,
slurm=False,
testing=False,
):
super().__init__(config, server_configs, slurm, testing)
@ -78,8 +93,16 @@ class DoctorEnv(BaseEnv):
group_size=32,
use_wandb=True,
rollout_server_url="http://localhost:8000",
max_token_length=8192,
wandb_name=gym_name,
wandb_name=wandb_name,
max_num_workers=128,
total_steps=100,
batch_size=1024,
steps_per_eval=1,
max_token_length=1024 * 15,
inference_weight=1.0,
data_path_to_save_groups=None,
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
)
server_configs = [
APIServerConfig(
@ -113,15 +136,33 @@ class DoctorEnv(BaseEnv):
await super().wandb_log(wandb_metrics)
async def setup(self):
    """
    Load the MedQA dataset and prepare the splits used by this environment.

    After this call ``self.train`` holds the shuffled "train" split,
    ``self.test`` holds at most 128 rows of the shuffled "test" split, and
    ``self.iter`` is reset to 0.
    """
    # Load everything, then apply one fixed-seed shuffle to the result.
    splits = load_dataset("GBaker/MedQA-USMLE-4-options").shuffle(seed=42)

    # The training rows are used in the format they were loaded in.
    self.train = splits["train"]

    # Cap the test split so an evaluation pass does not take too long.
    test_split = splits["test"]
    eval_rows = min(128, len(test_split))
    self.test = test_split.select(range(eval_rows))

    # Report the split sizes and show one raw training row.
    print(
        f"Loaded dataset with {len(self.train)} training examples and {len(self.test)} test examples"
    )
    print(f"Example item format: {self.train[0]}")

    # Start the iteration counter from zero.
    self.iter = 0
async def evaluate(self, *args, **kwargs):
    """No-op evaluation hook: accepts any arguments and does nothing."""
    return None
async def get_patient_msg(self, env: gym.Env) -> str:
    """Return the patient message, taken directly from ``env.render()``.

    No remote model call is made here; the value is whatever the given
    environment renders.
    """
    rendered = env.render()
    return rendered
async def collect_trajectory(
self, item: Item
) -> Tuple[Optional[ScoredDataItem], List[Item]]:
@ -129,15 +170,10 @@ class DoctorEnv(BaseEnv):
async with self.server.dedicated_server() as server:
patient_messages = []
doctor_messages = [
{
"role" : "system",
"content" : doctor_system_prompt
}
]
doctor_messages = [{"role": "system", "content": doctor_system_prompt}]
patient_profile = random.choice(patient_profiles)
symptoms = dataset[0]
symptoms = item["question"]
patient_system_prompt = patient_profile.format(symptoms)
patient_messages = [{"role": "system", "content": patient_system_prompt}]
@ -181,13 +217,12 @@ class DoctorEnv(BaseEnv):
diagnosis = doctor_msg.strip(final_message)
diagnosis = diagnosis.strip()
if diagnosis == item["diagnosis"]:
if diagnosis.contains(item["answer"]):
score = 1
else:
score = 0
break
completion = client.chat.completions.create(
model="grok-3-latest",
messages=patient_messages,
@ -198,11 +233,9 @@ class DoctorEnv(BaseEnv):
doctor_messages.append({"role": "user", "content": patient_msg})
patient_messages.append({"role": "assistant", "content": patient_msg})
self.percent_correct_buffer.append(max(score, 0))
tokens = self.tokenizer.apply_chat_template(doctor_messages)
masks = []
for i, msg in enumerate(doctor_messages):
if i == len(doctor_messages) - 1:
@ -210,13 +243,14 @@ class DoctorEnv(BaseEnv):
else:
curr_tokens = self.tokenizer.apply_chat_template(
doctor_messages[: i + 1],
add_generation_prompt=doctor_messages[i + 1]["role"] == "assistant",
add_generation_prompt=doctor_messages[i + 1]["role"]
== "assistant",
)
if doctor_messages[i]["role"] == "user":
masks.extend([-100] * (len(curr_tokens) - len(masks)))
else:
masks.extend(curr_tokens[len(masks) :])
scored_data_item = ScoredDataItem(
messages=doctor_messages,
finish_reason=score,
@ -227,10 +261,17 @@ class DoctorEnv(BaseEnv):
return scored_data_item, []
async def get_next_item(self):
    """
    Get the next training item from the dataset.

    Rows are served in order from ``self.train``. The index is taken modulo
    the dataset length, so iteration wraps around to the first row once
    every row has been handed out. ``self.iter`` is advanced by one on each
    call.

    Returns:
        The dataset row at the current position (a ``DatasetItem``), not a
        tuple.
    """
    # The row is read before the counter is bumped, so the first call
    # after setup returns row 0.
    next_item: DatasetItem = self.train[self.iter % len(self.train)]
    self.iter += 1
    return next_item
# if __name__ == "__main__":
# GymTaxiEnv.cli()
# Script entry point: when this module is run directly, invoke the
# environment class's `cli()` method.
if __name__ == "__main__":
DoctorEnv.cli()

View file

@ -1,14 +1 @@
patient_profiles = [
]
from openai import OpenAI
client = OpenAI(
api_key=XAI_API_KEY,
base_url="https://api.x.ai/v1",
)
# Prompt templates for the simulated patient. The doctor environment picks one
# with random.choice(patient_profiles) and fills it in by calling
# .format(symptoms) with the symptoms text as a single positional argument.
# NOTE(review): the list is empty here, and random.choice raises IndexError on
# an empty sequence — templates must be added before trajectories are collected.
patient_profiles = []