diff --git a/environments/cat_server.py b/environments/cat_server.py index 989837c2..10e30ae4 100644 --- a/environments/cat_server.py +++ b/environments/cat_server.py @@ -234,55 +234,70 @@ class GSM8kEnv(BaseEnv): self, rollout_group_data ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]: scores = ScoredDataGroup() + scores["tokens"] = list() scores["masks"] = list() scores["scores"] = list() - # random.shuffle(rollout_group_data) + # # random.shuffle(rollout_group_data) for item in rollout_group_data: - out_dict = tokenize_for_trainer( - self.tokenizer, item["messages"] + final_question = list(item["messages"]) + [{'role': 'system', 'content': 'The conversation is over. Say meow if the caretaker did a bad job, or purr if the caretaker did a good job.'}] + caretaker_completions = await self.server.chat_completion( + messages=final_question, + n=1, + max_tokens=self.config.max_token_length, ) - tokens = out_dict["tokens"] - masks = out_dict["masks"] - # remove obviously bad examples - if len([1 for i in masks if i != -100]) < 10: - continue - scores["tokens"].append(tokens) - scores["masks"].append(masks) - scores["scores"].append(1.0) - if len(scores["tokens"]) >= self.config.group_size: - break - for score in scores["scores"]: - self.percent_correct_buffer.append(max(score, 0)) - # check if all the same - # print(scores['scores']) - if all([score == 1 for score in scores["scores"]]): - # Do length penalty :) - token_lengths = [len(token) for token in scores["tokens"]] - if max(token_lengths) == 0: - # What? But don't want to crash a run so just in case... 
- return None + final_out = {'role': 'system', 'content': [row.message.content for row in caretaker_completions.choices][0]} - # Get max allowed token length from config - max_allowed_length = self.config.max_token_length - # Set threshold at 50% of max_token_length - no penalty below this - length_threshold = max_allowed_length * 0.5 + final_score = purrfect_eval(final_out['content']) - # Apply modified length penalty with threshold - scores["scores"] = [] - for length in token_lengths: - if length <= length_threshold: - # No penalty for responses under threshold - scores["scores"].append(1.0) - else: - # Calculate how far we are between threshold and max as a percentage - percentage_of_range = (length - length_threshold) / ( - max_allowed_length - length_threshold - ) - # Cap at 1.0 in case length exceeds max_allowed_length - percentage_of_range = min(percentage_of_range, 1.0) - # Apply linear penalty scaling from 1.0 down to 0.0 - scores["scores"].append(1.0 - percentage_of_range) + out_dict = tokenize_for_trainer( + self.tokenizer, [row for row in item["messages"]] + [final_out] + ) + scores['tokens'].append(out_dict['tokens']) + scores['masks'].append(out_dict['masks']) + scores['scores'].append(final_score) + + # tokens = out_dict["tokens"] + # masks = out_dict["masks"] + # # remove obviously bad examples + # if len([1 for i in masks if i != -100]) < 10: + # continue + # scores["tokens"].append(tokens) + # scores["masks"].append(masks) + # scores["scores"].append(1.0) + # if len(scores["tokens"]) >= self.config.group_size: + # break + # for score in scores["scores"]: + # self.percent_correct_buffer.append(max(score, 0)) + # # check if all the same + # # print(scores['scores']) + # if all([score == 1 for score in scores["scores"]]): + # # Do length penalty :) + # token_lengths = [len(token) for token in scores["tokens"]] + # if max(token_lengths) == 0: + # # What? But don't want to crash a run so just in case... 
def purrfect_eval(st: Optional[str]) -> float:
    """Score a judge reply: 1.0 if it contains "purr", otherwise 0.0.

    The check is a case-insensitive substring test, so "Purr", "PURRING" and
    "purrfect" all count as a positive verdict.

    Args:
        st: Text of the reply to evaluate. ``None`` is tolerated because a
            chat-completion message may carry no text content.

    Returns:
        1.0 when "purr" occurs anywhere in ``st``; 0.0 for any other text,
        an empty string, or ``None``.
    """
    # A missing or empty reply cannot express approval. Score it 0.0 rather
    # than letting ``None.lower()`` raise AttributeError mid-rollout.
    if not st:
        return 0.0
    return 1.0 if "purr" in st.lower() else 0.0