diff --git a/environments/cat_scenarios.json b/environments/cat_scenarios.json new file mode 100644 index 00000000..c941127d --- /dev/null +++ b/environments/cat_scenarios.json @@ -0,0 +1,64 @@ +[ + {"scenario": "Cat needs balanced nutrition including proteins, fats, vitamins, and minerals."}, + {"scenario": "Cat needs regular feeding schedule for meals."}, + {"scenario": "Cat needs fresh drinking water available at all times."}, + {"scenario": "Cat occasionally needs treats or dietary supplements."}, + {"scenario": "Cat needs a clean and accessible water source, possibly a fountain or running water."}, + {"scenario": "Cat needs a comfortable and safe sleeping area."}, + {"scenario": "Cat needs warmth and insulation during cold weather."}, + {"scenario": "Cat needs cool resting spots during hot weather."}, + {"scenario": "Cat needs regular brushing to avoid hairballs and matting."}, + {"scenario": "Cat needs regular nail trimming."}, + {"scenario": "Cat occasionally needs baths if necessary."}, + {"scenario": "Cat needs dental hygiene practices including teeth cleaning and dental treats."}, + {"scenario": "Cat needs regular veterinary check-ups."}, + {"scenario": "Cat requires vaccinations for disease prevention."}, + {"scenario": "Cat needs parasite control such as fleas, ticks, and worms treatment."}, + {"scenario": "Cat requires medical attention when ill or injured."}, + {"scenario": "Cat needs microchipping for identification purposes."}, + {"scenario": "Cat needs sufficient space to run and play."}, + {"scenario": "Cat needs climbing structures or cat trees."}, + {"scenario": "Cat needs interactive toys for physical activity."}, + {"scenario": "Cat needs a clean litter box for elimination."}, + {"scenario": "Cat needs suitable litter that provides comfort and odor control."}, + {"scenario": "Cat needs privacy in litter box placement."}, + {"scenario": "Cat needs interactive toys for mental enrichment."}, + {"scenario": "Cat benefits from puzzle feeders to encourage mental stimulation."}, + {"scenario": "Cat enjoys window access to observe the outside world."}, + {"scenario": "Cat might enjoy watching cat-friendly videos or listening to nature sounds."}, + {"scenario": "Cat requires a safe and secure environment."}, + {"scenario": "Cat needs elevated perches or shelves for observing territory."}, + {"scenario": "Cat requires personal sleeping spots like beds, boxes, or cozy caves."}, + {"scenario": "Cat benefits from clearly defined home territory."}, + {"scenario": "Cat needs attention and affection from humans."}, + {"scenario": "Cat requires regular playtime with humans."}, + {"scenario": "Cat needs suitable interactions with other pets."}, + {"scenario": "Cat enjoys bonding rituals such as grooming, rubbing, and sleeping nearby."}, + {"scenario": "Cat requires consistent feeding times and predictable routines."}, + {"scenario": "Cat needs minimal abrupt changes to their environment or routine."}, + {"scenario": "Cat needs warm spots like heated pads or sunny windows."}, + {"scenario": "Cat needs cool, shaded areas in warmer weather."}, + {"scenario": "Cat requires quiet resting places to avoid stress."}, + {"scenario": "Cat benefits from reduced noise in their environment."}, + {"scenario": "Cat requires an escape-proof environment."}, + {"scenario": "Cat needs protection from toxic substances including chemicals and certain plants."}, + {"scenario": "Cat benefits from visual stimulation such as outdoor views."}, + {"scenario": "Cat might benefit from gentle, calming music or white noise."}, + {"scenario": "Cat enjoys catnip or cat-friendly herbs for olfactory stimulation."}, + {"scenario": "Cat finds comfort in familiar scents like their owner's scent."}, + {"scenario": "Cat requires a variety of tactile stimulations such as different bedding textures."}, + {"scenario": "Cat needs appropriate scratching surfaces like posts or cardboard."}, + {"scenario": "Cat requires training to redirect scratching away from furniture."}, + {"scenario": "Cat benefits from play that mimics hunting activities."}, + {"scenario": "Cat needs private spaces for solitude or rest."}, + {"scenario": "Cat requires hiding spots to feel secure during stressful times."}, + {"scenario": "Kitten needs extra nutrition, training, and frequent stimulation."}, + {"scenario": "Senior cat needs mobility aids, specialized diets, and frequent vet visits."}, + {"scenario": "Cat may have grooming needs specific to their breed."}, + {"scenario": "Cat may have medical or special dietary requirements."}, + {"scenario": "Cat needs medication administered as directed by a veterinarian."}, + {"scenario": "Cat benefits from adaptations for mobility or accessibility, such as ramps."}, + {"scenario": "Cat requires emotional support during stressful events like vet visits."}, + {"scenario": "Cat needs reassurance during anxiety triggers such as storms or loud noises."} + ] + \ No newline at end of file diff --git a/environments/cat_server.py b/environments/cat_server.py new file mode 100644 index 00000000..ebbe3e06 --- /dev/null +++ b/environments/cat_server.py @@ -0,0 +1,333 @@ +import random +import json +from typing import Dict, List, Optional, Tuple, TypedDict, Union + +from datasets import load_dataset +from latex2sympy2_extended import NormalizationConfig +from math_verify import LatexExtractionConfig, parse, verify +from tqdm.asyncio import tqdm_asyncio + +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, + ScoredDataGroup, +) +from atroposlib.type_definitions import Item, number +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer + +system_prompt = ( + "You are a deep thinking AI, you may use extremely long chains of thought " + "to deeply consider the problem and deliberate with yourself via systematic " + "reasoning processes to help come to a correct solution prior to answering. " + "You should enclose your thoughts and internal monologue inside " + "tags, and then provide your solution or response to the problem.\n\n" +) + +cat_system_prompt = ( + "You are a cat. The only way you can communicate is by meowing, hissing, purring, or making a hair ball, or silence." + "You will be given a collection of scenarios which describe various needs you want to be met by your caretaker." + "Please try to communicate with your caretaker through the modes outlined above." +) +cat_system_prompt += """You are allocated a maximum of 2048 tokens, please strive to use less.""" + +caretaker_system_prompt = ( + "You are the caretaker of this cat. It is trying to communicate its various needs to you via cat language." + "Provide a written string which provides a set of interventions." + "You will only have 5 opportunities to interact with the cat. Choose what you say wisely." +) + +system_prompt += """You are allocated a maximum of 2048 tokens, please strive to use less. + +You will then provide your answer like this: \\boxed{your answer here} +It is important that you provide your answer in the correct format. +If you do not, you will not receive credit for your answer. +So please end your answer with \\boxed{your answer here}""" + + +class CatRow(TypedDict): + scenario: str + + +class GSM8kEnv(BaseEnv): + + name = "gsm8k" + + def __init__( + self, + config: BaseEnvConfig, + server_configs: List[APIServerConfig], + slurm=True, + testing=False, + ): + super().__init__(config, server_configs, slurm, testing) + self.percent_correct_buffer = list() + self.eval_metrics = list() + # Add tracking for wandb visualizations + self.rollouts_for_wandb = [] + self.completion_lengths = [] + + @classmethod + def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: + env_config = BaseEnvConfig( + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview", + group_size=8, + use_wandb=True, + rollout_server_url="http://localhost:8000", + total_steps=61, + batch_size=12, + steps_per_eval=60, + max_token_length=2048, + wandb_name="gsm8k", + ) + server_configs = [ + APIServerConfig( + model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview", + base_url="http://localhost:9001/v1", + api_key="x", + num_requests_for_eval=256, + ), + ] + + return env_config, server_configs + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + if wandb_metrics is None: + wandb_metrics = {} + + # Try to calculate percent_correct, pass if there's a division by zero + try: + wandb_metrics["train/percent_correct"] = sum( + self.percent_correct_buffer + ) / len(self.percent_correct_buffer) + except ZeroDivisionError: + # Skip if buffer is empty + pass + + self.percent_correct_buffer = list() + for item in self.eval_metrics: + wandb_metrics[item[0]] = item[1] + self.eval_metrics = list() + # Call the parent method to handle the server metrics + await super().wandb_log(wandb_metrics) + + async def setup(self): + # self.train = load_dataset("gsm8k", "main", split="train").shuffle(seed=42) + # test_data = load_dataset("gsm8k", "main", split="test").shuffle(seed=42) + with open('environments/cat_scenarios.json', 'r', encoding='utf-8') as f: + test_data = json.load(f) + self.test = list() + self.train = list() + for item in test_data: + self.test.append( + { + "scenario": item["scenario"], + # "gold_answer": item["answer"] + # .split("#")[-1] + # .strip() + # .replace(",", ""), + } + ) + self.train.append( + {"scenario": item["scenario"],} + ) + self.iter = 0 + + def save_checkpoint(self, step, data=None): + if data is None: + data = {} + data["iter"] = self.iter + super().save_checkpoint(step, data) + + async def rollout_and_score_eval(self, scenario: str, answer: str) -> number: + # completion = await self.server.chat_completion( + # messages=[ + # {"role": "system", "content": system_prompt}, + # {"role": "user", "content": scenario}, + # ], + # n=1, + # max_tokens=self.config.max_token_length, + # temperature=0.0, + # split="eval", + # ) + # gold_parsed = parse( + # "\\boxed{" + answer + "}", + # extraction_mode="first_match", + # extraction_config=[LatexExtractionConfig()], + # ) + # answer_parsed = parse( + # completion.choices[0].message.content.split("")[-1], + # extraction_config=[ + # LatexExtractionConfig( + # normalization_config=NormalizationConfig( + # nits=False, + # malformed_operators=False, + # basic_latex=True, + # equations=True, + # boxed="all", + # units=True, + # ), + # # Ensures that boxed is tried first + # boxed_match_priority=0, + # try_extract_without_anchor=False, + # ) + # ], + # extraction_mode="first_match", + # ) + # score = 1 if verify(answer_parsed, gold_parsed) else 0 + # return score + return 1 + + async def evaluate(self, *args, **kwargs): + eval_tasks = [] + for item in self.test: + eval_tasks.append( + self.rollout_and_score_eval(item["scenario"]) + ) + scores = await tqdm_asyncio.gather(*eval_tasks) + self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores))) + + async def collect_trajectories( + self, item: CatRow + ) -> Tuple[ScoredDataGroup, list[Item]]: + user_message = {"role": "user", "content": item["scenario"]} + # gold_answer = ( + # "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}" + # ) + + cat_completions = await self.server.chat_completion( + messages=[{"role": "system", "content": cat_system_prompt}, user_message], + n=self.config.group_size, + max_tokens=self.config.max_token_length, + ) + + for i, cat_completion in enumerate(cat_completions.choices): + if i == 0: + cat_message = cat_completion.message.content + + caretaker_message = {"role": "user", "content": cat_message} + caretaker_completions = await self.server.chat_completion( + messages=[{"role": "system", "content": caretaker_system_prompt}, caretaker_message], + n=self.config.group_size, + max_tokens=self.config.max_token_length, + ) + + to_score = list() + to_backlog = list() + for i, caretaker_completion in enumerate(caretaker_completions.choices): + messages = ( + {"role": "system", "content": cat_system_prompt}, + user_message, + cat_message, + {"role": "assistant", "content": caretaker_completion.message.content}, + ) + to_score.append( + { + "messages": messages, + } + ) + to_postprocess = await self.score(to_score) + return to_postprocess, to_backlog + + async def score( + self, rollout_group_data + ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]: + # scores = ScoredDataGroup() + # scores["tokens"] = list() + # scores["masks"] = list() + # scores["scores"] = list() + # gold_parsed = parse( + # rollout_group_data[0]["gold_answer"], + # extraction_mode="first_match", + # extraction_config=[LatexExtractionConfig()], + # ) + # if len(gold_parsed) != 0: + # # We require the answer to be provided in correct latex (no malformed operators) + # random.shuffle(rollout_group_data) + # for item in rollout_group_data: + # # print(item[0][-1]["content"]) + # answer_parsed = parse( + # item["messages"][-1]["content"].split("")[-1], + # extraction_config=[ + # LatexExtractionConfig( + # normalization_config=NormalizationConfig( + # nits=False, + # malformed_operators=False, + # basic_latex=True, + # equations=True, + # boxed="all", + # units=True, + # ), + # # Ensures that boxed is tried first + # boxed_match_priority=0, + # try_extract_without_anchor=False, + # ) + # ], + # extraction_mode="first_match", + # ) + # # Reward 1 if the content is the same as the ground truth, 0 otherwise + # reward = verify(answer_parsed, gold_parsed) + # # print( + # # f"message: {item[0][-1]['content']}, ground_truth: {item[1]}, reward: {reward}" + # # ) + # out_dict = tokenize_for_trainer( + # self.tokenizer, item["messages"], item["finish_reason"] + # ) + # tokens = out_dict["tokens"] + # masks = out_dict["masks"] + # # remove obviously bad examples + # if len([1 for i in masks if i != -100]) < 10: + # continue + # scores["tokens"].append(tokens) + # scores["masks"].append(masks) + # scores["scores"].append(1.0 if reward else -1.0) + # if len(scores["tokens"]) >= self.config.group_size: + # break + # for score in scores["scores"]: + # self.percent_correct_buffer.append(max(score, 0)) + # # check if all the same + # # print(scores['scores']) + # if all([score == 1 for score in scores["scores"]]): + # # Do length penalty :) + # token_lengths = [len(token) for token in scores["tokens"]] + # if max(token_lengths) == 0: + # # What? But don't want to crash a run so just in case... + # return None + + # # Get max allowed token length from config + # max_allowed_length = self.config.max_token_length + # # Set threshold at 50% of max_token_length - no penalty below this + # length_threshold = max_allowed_length * 0.5 + + # # Apply modified length penalty with threshold + # scores["scores"] = [] + # for length in token_lengths: + # if length <= length_threshold: + # # No penalty for responses under threshold + # scores["scores"].append(1.0) + # else: + # # Calculate how far we are between threshold and max as a percentage + # percentage_of_range = (length - length_threshold) / ( + # max_allowed_length - length_threshold + # ) + # # Cap at 1.0 in case length exceeds max_allowed_length + # percentage_of_range = min(percentage_of_range, 1.0) + # # Apply linear penalty scaling from 1.0 down to 0.0 + # scores["scores"].append(1.0 - percentage_of_range) + # if all([scores["scores"][0] == score for score in scores["scores"]]): + # return None # If all the same, we return None + # return scores + # else: + # # If the gold solution is not parseable, we return None + # return None + return None + + async def get_next_item(self) -> CatRow: + next_item = self.train[self.iter % len(self.train)] + self.iter += 1 + print(f"iteration: {self.iter}") + return next_item + + +if __name__ == "__main__": + GSM8kEnv.cli()