mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
Add dataset (#3)
This commit is contained in:
parent
02ff663ebe
commit
189d7a3dd1
3 changed files with 77 additions and 50 deletions
|
|
@ -1 +0,0 @@
|
|||
dataset = []
|
||||
|
|
@ -1,20 +1,34 @@
|
|||
import json
|
||||
import random
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, TypedDict
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem
|
||||
from atroposlib.type_definitions import Item
|
||||
from atropos.environments.hack0.doctor_agent.patient import patient_profiles
|
||||
from atropos.environments.hack0.doctor_agent.datasets import dataset
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import os
|
||||
from datasets import load_dataset
|
||||
from openai import OpenAI
|
||||
import gymnasium as gym
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
EvalHandlingEnum,
|
||||
ScoredDataItem,
|
||||
)
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from .patient import patient_profiles
|
||||
|
||||
DatasetItem = TypedDict(
|
||||
"DatasetItem",
|
||||
{
|
||||
"question": str,
|
||||
"answer": str,
|
||||
"options": Dict[str, str],
|
||||
"meta_info": str,
|
||||
"answer_idx": str,
|
||||
"diagnosis": str,
|
||||
"metamap_sequence": Sequence[str],
|
||||
},
|
||||
)
|
||||
|
||||
with open("environments/hack0/doctor_agent", "r") as f:
|
||||
keys = json.load(f)
|
||||
xai_key = keys["xai"]
|
||||
|
|
@ -50,7 +64,8 @@ assistant: The patient is diagnosed with <diagnosis>headache</diagnosis>
|
|||
|
||||
|
||||
doctor_model = "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
|
||||
gym_name = "gym_doctor"
|
||||
wandb_name = "doctor"
|
||||
|
||||
|
||||
class DoctorEnv(BaseEnv):
|
||||
|
||||
|
|
@ -60,7 +75,7 @@ class DoctorEnv(BaseEnv):
|
|||
self,
|
||||
config: BaseEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm=True,
|
||||
slurm=False,
|
||||
testing=False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
|
|
@ -78,8 +93,16 @@ class DoctorEnv(BaseEnv):
|
|||
group_size=32,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
max_token_length=8192,
|
||||
wandb_name=gym_name,
|
||||
wandb_name=wandb_name,
|
||||
max_num_workers=128,
|
||||
total_steps=100,
|
||||
batch_size=1024,
|
||||
steps_per_eval=1,
|
||||
max_token_length=1024 * 15,
|
||||
inference_weight=1.0,
|
||||
data_path_to_save_groups=None,
|
||||
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
|
||||
eval_limit_ratio=0.1,
|
||||
)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
|
|
@ -113,15 +136,33 @@ class DoctorEnv(BaseEnv):
|
|||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Set up the environment by loading and preparing the dataset.
|
||||
"""
|
||||
# Load the full dataset
|
||||
full_dataset = load_dataset("GBaker/MedQA-USMLE-4-options")
|
||||
|
||||
full_dataset = full_dataset.shuffle(seed=42)
|
||||
|
||||
# Keep the splits as is - no need to reformat
|
||||
self.train = full_dataset["train"]
|
||||
# Limit test set size to prevent evaluation from taking too long
|
||||
self.test = full_dataset["test"].select(
|
||||
range(min(128, len(full_dataset["test"])))
|
||||
)
|
||||
|
||||
# Print some dataset statistics
|
||||
print(
|
||||
f"Loaded dataset with {len(self.train)} training examples and {len(self.test)} test examples"
|
||||
)
|
||||
print(f"Example item format: {self.train[0]}")
|
||||
|
||||
# Initialize iteration counter
|
||||
self.iter = 0
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def get_patient_msg(self, env: gym.Env) -> str:
|
||||
# Call xAI to get a patient message
|
||||
return env.render()
|
||||
|
||||
async def collect_trajectory(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[ScoredDataItem], List[Item]]:
|
||||
|
|
@ -129,15 +170,10 @@ class DoctorEnv(BaseEnv):
|
|||
async with self.server.dedicated_server() as server:
|
||||
|
||||
patient_messages = []
|
||||
doctor_messages = [
|
||||
{
|
||||
"role" : "system",
|
||||
"content" : doctor_system_prompt
|
||||
}
|
||||
]
|
||||
doctor_messages = [{"role": "system", "content": doctor_system_prompt}]
|
||||
|
||||
patient_profile = random.choice(patient_profiles)
|
||||
symptoms = dataset[0]
|
||||
symptoms = item["question"]
|
||||
patient_system_prompt = patient_profile.format(symptoms)
|
||||
|
||||
patient_messages = [{"role": "system", "content": patient_system_prompt}]
|
||||
|
|
@ -181,13 +217,12 @@ class DoctorEnv(BaseEnv):
|
|||
diagnosis = doctor_msg.strip(final_message)
|
||||
diagnosis = diagnosis.strip()
|
||||
|
||||
if diagnosis == item["diagnosis"]:
|
||||
if diagnosis.contains(item["answer"]):
|
||||
score = 1
|
||||
else:
|
||||
score = 0
|
||||
break
|
||||
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model="grok-3-latest",
|
||||
messages=patient_messages,
|
||||
|
|
@ -198,11 +233,9 @@ class DoctorEnv(BaseEnv):
|
|||
doctor_messages.append({"role": "user", "content": patient_msg})
|
||||
patient_messages.append({"role": "assistant", "content": patient_msg})
|
||||
|
||||
|
||||
self.percent_correct_buffer.append(max(score, 0))
|
||||
tokens = self.tokenizer.apply_chat_template(doctor_messages)
|
||||
|
||||
|
||||
|
||||
masks = []
|
||||
for i, msg in enumerate(doctor_messages):
|
||||
if i == len(doctor_messages) - 1:
|
||||
|
|
@ -210,13 +243,14 @@ class DoctorEnv(BaseEnv):
|
|||
else:
|
||||
curr_tokens = self.tokenizer.apply_chat_template(
|
||||
doctor_messages[: i + 1],
|
||||
add_generation_prompt=doctor_messages[i + 1]["role"] == "assistant",
|
||||
add_generation_prompt=doctor_messages[i + 1]["role"]
|
||||
== "assistant",
|
||||
)
|
||||
if doctor_messages[i]["role"] == "user":
|
||||
masks.extend([-100] * (len(curr_tokens) - len(masks)))
|
||||
else:
|
||||
masks.extend(curr_tokens[len(masks) :])
|
||||
|
||||
|
||||
scored_data_item = ScoredDataItem(
|
||||
messages=doctor_messages,
|
||||
finish_reason=score,
|
||||
|
|
@ -227,10 +261,17 @@ class DoctorEnv(BaseEnv):
|
|||
return scored_data_item, []
|
||||
|
||||
async def get_next_item(self):
|
||||
next_item = {"seed": self.iter}
|
||||
"""
|
||||
Get the next training item from the dataset.
|
||||
|
||||
Returns:
|
||||
A tuple containing prompt and expected answer
|
||||
"""
|
||||
next_item: DatasetItem = self.train[self.iter % len(self.train)]
|
||||
self.iter += 1
|
||||
|
||||
return next_item
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# GymTaxiEnv.cli()
|
||||
if __name__ == "__main__":
|
||||
DoctorEnv.cli()
|
||||
|
|
@ -1,14 +1 @@
|
|||
|
||||
|
||||
patient_profiles = [
|
||||
|
||||
]
|
||||
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
api_key=XAI_API_KEY,
|
||||
base_url="https://api.x.ai/v1",
|
||||
)
|
||||
|
||||
patient_profiles = []
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue