diff --git a/.gitignore b/.gitignore index 1aad045e..2923929d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +# keys +**/secrets.json # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/environments/hack0/doctor.py b/environments/hack0/doctor.py deleted file mode 100644 index 0399aa17..00000000 --- a/environments/hack0/doctor.py +++ /dev/null @@ -1,185 +0,0 @@ -import re -from typing import Dict, List, Optional, Tuple - -import gymnasium as gym - -from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem -from atroposlib.type_definitions import Item - -system_prompt = """ -You are a doctor. You are interacting with a patient. -You need to diagnose the patient based on the symptoms. -You will need to ask the patient follow up questions to diagnose them. -Once you are confident in your diagnosis, provide it in the format: - -The patient is diagnosed with {possible_illness}. - -For example, - -user: I have a headache. -assistant: What is the severity of your headache? -user: It's a 3/10. -assistant: What is the location of your headache? -user: It's in the front of my head. -assistant: What is the duration of your headache? -user: It's been going on for 2 days. 
-assistant: The patient is diagnosed with headache -""" - - -class DoctorEnv(BaseEnv): - - name = "doctor" - - def __init__( - self, - config: BaseEnvConfig, - server_configs: List[APIServerConfig], - slurm=True, - testing=False, - ): - super().__init__(config, server_configs, slurm, testing) - self.percent_correct_buffer = list() - self.eval_metrics = list() - # Add tracking for wandb visualizations - self.rollouts_for_wandb = [] - self.completion_lengths = [] - self.print_this_env = False - - @classmethod - def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: - env_config = BaseEnvConfig( - tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", - group_size=32, - use_wandb=True, - rollout_server_url="http://localhost:8000", - max_token_length=8192, - wandb_name="gym_taxi", - ) - server_configs = [ - APIServerConfig( - model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", - base_url="http://localhost:9001/v1", - api_key="x", - num_requests_for_eval=256, - ), - ] - - return env_config, server_configs - - async def wandb_log(self, wandb_metrics: Optional[Dict] = None): - if wandb_metrics is None: - wandb_metrics = {} - - # Try to calculate percent_correct, pass if there's a division by zero - try: - wandb_metrics["train/percent_correct"] = sum( - self.percent_correct_buffer - ) / len(self.percent_correct_buffer) - except ZeroDivisionError: - # Skip if buffer is empty - pass - - self.percent_correct_buffer = list() - for item in self.eval_metrics: - wandb_metrics[item[0]] = item[1] - self.eval_metrics = list() - # Call the parent method to handle the server metrics - await super().wandb_log(wandb_metrics) - - async def setup(self): - self.iter = 0 - - async def evaluate(self, *args, **kwargs): - pass - - async def get_patient_msg(self, env: gym.Env) -> str: - # Call xAI to get a patient message - return env.render() - - async def collect_trajectory( - self, item: Item - ) -> Tuple[Optional[ScoredDataItem], List[Item]]: - # Grab a 
dedicated llm server to take advantage of caching - async with self.server.dedicated_server() as server: - init_msg = f"{system_prompt}\n\n" - messages = [{"role": "system", "content": init_msg}] - patient_msg = await self.get_patient_msg(item) - messages.append({"role": "user", "content": patient_msg}) - score = -1 - while True: - if ( - len(self.tokenizer.apply_chat_template(messages)) - > self.config.max_token_length - 10 - ): - score = 0 - break - max_tokens = self.config.max_token_length - len( - self.tokenizer.apply_chat_template( - messages, add_generation_prompt=True - ) - ) - chat_completions = await server.chat_completion( - messages=messages, - n=1, - max_tokens=max_tokens, - ) - messages.append( - { - "role": "assistant", - "content": chat_completions.choices[0].message.content, - } - ) - diagnosis_match = re.search( - r"(.*?)", - chat_completions.choices[0].message.content, - re.DOTALL, - ) - if diagnosis_match: - diagnosis = diagnosis_match.group(1).strip() - # Check if the diagnosis is correct - if diagnosis == item["diagnosis"]: - score = 1 - else: - score = 0 - break - - next_patient_msg = await self.get_patient_msg(item) - messages.append( - { - "role": "user", - "content": next_patient_msg, - } - ) - self.percent_correct_buffer.append(max(score, 0)) - tokens = self.tokenizer.apply_chat_template(messages) - masks = [] - for i, msg in enumerate(messages): - if i == len(messages) - 1: - masks.extend(tokens[len(masks) :]) - else: - curr_tokens = self.tokenizer.apply_chat_template( - messages[: i + 1], - add_generation_prompt=messages[i + 1]["role"] == "assistant", - ) - if messages[i]["role"] == "user": - masks.extend([-100] * (len(curr_tokens) - len(masks))) - else: - masks.extend(curr_tokens[len(masks) :]) - scored_data_item = ScoredDataItem( - messages=messages, - finish_reason=score, - tokens=tokens, - masks=masks, - scores=score, - ) - return scored_data_item, [] - - async def get_next_item(self): - next_item = {"seed": self.iter} - self.iter += 
1 - return next_item - - -# if __name__ == "__main__": -# GymTaxiEnv.cli() diff --git a/environments/hack0/doctor_agent/doctor_gym.py b/environments/hack0/doctor_agent/doctor_gym.py index 87384722..48fee2a3 100644 --- a/environments/hack0/doctor_agent/doctor_gym.py +++ b/environments/hack0/doctor_agent/doctor_gym.py @@ -1,103 +1,60 @@ +import json +import random from typing import Dict, List, Optional, Tuple from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem from atroposlib.type_definitions import Item from atropos.environments.hack0.doctor_agent.patient import patient_profiles from atropos.environments.hack0.doctor_agent.datasets import dataset +import re +from typing import Dict, List, Optional, Tuple +import os +from openai import OpenAI +import gymnasium as gym -start_msg = """### Description -You are a doctor tasked to diagnose a patient symptoms. Your task is to ask the patient enough questions until you are confident about your answer +from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem +from atroposlib.type_definitions import Item -When you are confident about the illness/disease the patient has respond with. The diagnosis is {illness} -""" # noqa: E501 +with open("environments/hack0/doctor_agent/secrets.json", "r") as f: + keys = json.load(f) + xai_key = keys["xai"] -def decode(i): - out = [] - out.append(i % 4) - i = i // 4 - out.append(i % 5) - i = i // 5 - out.append(i % 5) - i = i // 5 - out.append(i) - assert 0 <= i < 5 - x = reversed(out) - # Making it explicit so I don't have to look into gym code - taxi_row, taxi_col, pass_idx, dest_idx = x - return taxi_row, taxi_col, pass_idx, dest_idx +client = OpenAI( + api_key=xai_key, + base_url="https://api.x.ai/v1", +) + +final_message = "The patient is diagnosed with" +final_message_prompt = final_message + " headache" + +doctor_system_prompt = """ +You are a doctor. You are interacting with a patient. +You need to diagnose the patient based on the symptoms. 
+You will need to ask the patient follow up questions to diagnose them. +Once you are confident in your diagnosis, provide it in the format: + +The patient is diagnosed with {possible_illness}. + +For example, + +user: I have a headache. +assistant: What is the severity of your headache? +user: It's a 3/10. +assistant: What is the location of your headache? +user: It's in the front of my head. +assistant: What is the duration of your headache? +user: It's been going on for 2 days. +assistant: The patient is diagnosed with headache +""" -# Note: Works for both the passenger and the destination -TO_LOC_MAP = { - 0: "R(Row 0, Col 0)", - 1: "G (Row 4, Col 4)", - 2: "Y (Row 0, Col 4)", - 3: "B (Row 3, Col 3)", - 4: "in taxi", -} -MAP_LOC = {0: (0, 0), 1: (4, 4), 2: (0, 4), 3: (3, 3)} -TO_ACTION_MAP = { - 0: "south", - 1: "north", - 2: "east", - 3: "west", - 4: "pickup", - 5: "dropoff", -} +doctor_model = "NousResearch/DeepHermes-3-Llama-3-8B-Preview" +gym_name = "gym_doctor" +class DoctorEnv(BaseEnv): -def state_render_to_user_msg(last_state, state, action_mask, render): - taxi_row, taxi_col, pass_idx, dest_idx = decode(state) - if last_state is not None: - last_taxi_row, last_taxi_col, last_pass_idx, last_dest_idx = decode(last_state) - available_actions = "\n".join( - [ - f"- {i}: {TO_ACTION_MAP[i]}" - for i in range(6) - if (action_mask[i] == 1) - and ( - (i != 5) - or ( - (i == 5) - and (taxi_row == MAP_LOC[dest_idx][0]) - and (taxi_col == MAP_LOC[dest_idx][1]) - ) - ) - ] - ) - if last_state is not None: - ret_str = ( - f"Previous Taxi Location: Row: {last_taxi_row}, Col: {last_taxi_col}\n" - ) - else: - ret_str = "" - ret_str += ( - f"Current state:\nTaxi: Row: {taxi_row}, Col: {taxi_col}\nPassenger: {TO_LOC_MAP[pass_idx]}\n" - f"Destination: {TO_LOC_MAP[dest_idx]}\n\n" - f"Map:\n{render}\n\n" - f"Available actions:\n{available_actions}" - ) - if ( - (pass_idx == 4) - and (taxi_row == MAP_LOC[dest_idx][0]) - and (taxi_col == MAP_LOC[dest_idx][1]) - ): - ret_str += 
"\n\nPlease drop off the passenger." - elif pass_idx == 4: - ret_str += f"\n\nPlease move the taxi to {TO_LOC_MAP[dest_idx]} to drop off the passenger." - elif (taxi_row == MAP_LOC[pass_idx][0]) and (taxi_col == MAP_LOC[pass_idx][1]): - ret_str += "\n\nPlease pick up the passenger." - else: - ret_str += f"\n\nPlease move the taxi to {TO_LOC_MAP[pass_idx]} to pick up the passenger." - return ret_str - -model = "NousResearch/DeepHermes-3-Llama-3-8B-Preview" -name = "gym_doctor" - -class GymDoctorEnv(BaseEnv): - - name = "gym_doctor" + name = "doctor" def __init__( self, @@ -108,7 +65,6 @@ class GymDoctorEnv(BaseEnv): ): super().__init__(config, server_configs, slurm, testing) self.percent_correct_buffer = list() - self.percent_picked_up_passenger_buffer = list() self.eval_metrics = list() # Add tracking for wandb visualizations self.rollouts_for_wandb = [] @@ -118,16 +74,16 @@ class GymDoctorEnv(BaseEnv): @classmethod def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: env_config = BaseEnvConfig( - tokenizer_name=model, + tokenizer_name=doctor_model, group_size=32, use_wandb=True, rollout_server_url="http://localhost:8000", max_token_length=8192, - wandb_name=name, + wandb_name=gym_name, ) server_configs = [ APIServerConfig( - model_name=model, + model_name=doctor_model, base_url="http://localhost:9001/v1", api_key="x", num_requests_for_eval=256, @@ -148,16 +104,8 @@ class GymDoctorEnv(BaseEnv): except ZeroDivisionError: # Skip if buffer is empty pass - try: - wandb_metrics["train/percent_picked_up_passenger"] = sum( - self.percent_picked_up_passenger_buffer - ) / len(self.percent_picked_up_passenger_buffer) - except ZeroDivisionError: - # Skip if buffer is empty - pass self.percent_correct_buffer = list() - self.percent_picked_up_passenger_buffer = list() for item in self.eval_metrics: wandb_metrics[item[0]] = item[1] self.eval_metrics = list() @@ -170,107 +118,107 @@ class GymDoctorEnv(BaseEnv): async def evaluate(self, *args, **kwargs): pass + 
async def get_patient_msg(self, env: gym.Env) -> str: + # Call xAI to get a patient message + return env.render() + async def collect_trajectory( self, item: Item ) -> Tuple[Optional[ScoredDataItem], List[Item]]: # Grab a dedicated llm server to take advantage of caching async with self.server.dedicated_server() as server: - # env = gym.make(name, render_mode="ansi") - # state, info = env.reset(seed=item["seed"]) #FIXME: - last_state = None - patient_state = [] + + patient_messages = [] + doctor_messages = [ + { + "role" : "system", + "content" : doctor_system_prompt + } + ] patient_profile = random.choice(patient_profiles) - symptoms = dataset[0] + patient_system_prompt = patient_profile.format(dataset[0]) + patient_messages = [{"role": "system", "content": patient_system_prompt}] + completion = client.chat.completions.create( + model="grok-3-latest", + messages=patient_messages, + ) + patient_msg = completion.choices[0].message.content - doctor_state = [] + doctor_messages.append({"role": "user", "content": patient_msg}) + patient_messages.append({"role": "assistant", "content": patient_msg}) - - # taxi_row, taxi_col, pass_idx, dest_idx = decode(state) - - - - - - init_msg - - # init_msg = f"{start_msg}\n\n" + state_render_to_user_msg( - # last_state, state, info["action_mask"], env.render() - # ) - messages = [{"role": "user", "content": init_msg}] score = -1 while True: if ( - len(self.tokenizer.apply_chat_template(messages)) + len(self.tokenizer.apply_chat_template(doctor_messages)) > self.config.max_token_length - 10 ): + score = 0 break max_tokens = self.config.max_token_length - len( self.tokenizer.apply_chat_template( - messages, add_generation_prompt=True + doctor_messages, add_generation_prompt=True ) ) - chat_completions = await server.chat_completion( - messages=messages, + doctor_completions = await server.chat_completion( + messages=doctor_messages, n=1, max_tokens=max_tokens, ) - choice = ( - chat_completions.choices[0] - .message.content.strip() 
.replace(".", "")[-1] - ) - messages.append( - { - "role": "assistant", - "content": chat_completions.choices[0].message.content, - } - ) - if choice.isdigit() and 0 <= int(choice) <= 5: - action = int(choice) - else: + + doctor_msg = doctor_completions.choices[0].message.content + + doctor_messages.append({"role": "assistant", "content": doctor_msg}) + patient_messages.append({"role": "user", "content": doctor_msg}) + + # check output + if doctor_msg.startswith(final_message): + diagnosis = doctor_msg.removeprefix(final_message) + diagnosis = diagnosis.strip() + + if diagnosis == item["diagnosis"]: + score = 1 + else: + score = 0 break - if info["action_mask"][action] == 0: - break - if action == 3: - # picked up passenger - score = 0 - next_state, reward, terminated, truncated, info = env.step(action) - last_state = state - state = next_state - if terminated: - score = 1 - break - messages.append( - { - "role": "user", - "content": state_render_to_user_msg( - last_state, state, info["action_mask"], env.render() - ), - } + + + completion = client.chat.completions.create( + model="grok-3-latest", + messages=patient_messages, ) + + patient_msg = completion.choices[0].message.content + + doctor_messages.append({"role": "user", "content": patient_msg}) + patient_messages.append({"role": "assistant", "content": patient_msg}) + + self.percent_correct_buffer.append(max(score, 0)) - self.percent_picked_up_passenger_buffer.append(1 if score >= 0 else 0) - tokens = self.tokenizer.apply_chat_template(messages) + tokens = self.tokenizer.apply_chat_template(doctor_messages) + + masks = [] - for i, msg in enumerate(messages): - if i == len(messages) - 1: + for i, msg in enumerate(doctor_messages): + if i == len(doctor_messages) - 1: masks.extend(tokens[len(masks) :]) else: curr_tokens = self.tokenizer.apply_chat_template( - messages[: i + 1], - add_generation_prompt=messages[i + 1]["role"] == "assistant", + doctor_messages[: i + 1], + add_generation_prompt=doctor_messages[i + 1]["role"] == 
"assistant", ) - if messages[i]["role"] == "user": + if doctor_messages[i]["role"] == "user": masks.extend([-100] * (len(curr_tokens) - len(masks))) else: masks.extend(curr_tokens[len(masks) :]) + scored_data_item = ScoredDataItem( - messages=messages, + messages=doctor_messages, finish_reason=score, tokens=tokens, masks=masks, @@ -284,5 +232,5 @@ class GymDoctorEnv(BaseEnv): return next_item -if __name__ == "__main__": - GymDoctorEnv.cli() +# if __name__ == "__main__": +# DoctorEnv.cli()