add reward

This commit is contained in:
Jonah Philion 2025-05-18 17:23:11 -07:00
parent 128dce55bc
commit 00599f2e4b

View file

@ -234,55 +234,70 @@ class GSM8kEnv(BaseEnv):
self, rollout_group_data
) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
scores = ScoredDataGroup()
scores["tokens"] = list()
scores["masks"] = list()
scores["scores"] = list()
# random.shuffle(rollout_group_data)
# # random.shuffle(rollout_group_data)
for item in rollout_group_data:
out_dict = tokenize_for_trainer(
self.tokenizer, item["messages"]
final_question = list(item["messages"]) + [{'role': 'system', 'content': 'The conversation is over. Say meow if the caretaker did a bad job, or purr if the caretaker did a good job.'}]
caretaker_completions = await self.server.chat_completion(
messages=final_question,
n=1,
max_tokens=self.config.max_token_length,
)
tokens = out_dict["tokens"]
masks = out_dict["masks"]
# remove obviously bad examples
if len([1 for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0)
if len(scores["tokens"]) >= self.config.group_size:
break
for score in scores["scores"]:
self.percent_correct_buffer.append(max(score, 0))
# check if all the same
# print(scores['scores'])
if all([score == 1 for score in scores["scores"]]):
# Do length penalty :)
token_lengths = [len(token) for token in scores["tokens"]]
if max(token_lengths) == 0:
# What? But don't want to crash a run so just in case...
return None
final_out = {'role': 'system', 'content': [row.message.content for row in caretaker_completions.choices][0]}
# Get max allowed token length from config
max_allowed_length = self.config.max_token_length
# Set threshold at 50% of max_token_length - no penalty below this
length_threshold = max_allowed_length * 0.5
final_score = purrfect_eval(final_out['content'])
# Apply modified length penalty with threshold
scores["scores"] = []
for length in token_lengths:
if length <= length_threshold:
# No penalty for responses under threshold
scores["scores"].append(1.0)
else:
# Calculate how far we are between threshold and max as a percentage
percentage_of_range = (length - length_threshold) / (
max_allowed_length - length_threshold
)
# Cap at 1.0 in case length exceeds max_allowed_length
percentage_of_range = min(percentage_of_range, 1.0)
# Apply linear penalty scaling from 1.0 down to 0.0
scores["scores"].append(1.0 - percentage_of_range)
out_dict = tokenize_for_trainer(
self.tokenizer, [row for row in item["messages"]] + [final_out]
)
scores['tokens'].append(out_dict['tokens'])
scores['masks'].append(out_dict['masks'])
scores['scores'].append(final_score)
# tokens = out_dict["tokens"]
# masks = out_dict["masks"]
# # remove obviously bad examples
# if len([1 for i in masks if i != -100]) < 10:
# continue
# scores["tokens"].append(tokens)
# scores["masks"].append(masks)
# scores["scores"].append(1.0)
# if len(scores["tokens"]) >= self.config.group_size:
# break
# for score in scores["scores"]:
# self.percent_correct_buffer.append(max(score, 0))
# # check if all the same
# # print(scores['scores'])
# if all([score == 1 for score in scores["scores"]]):
# # Do length penalty :)
# token_lengths = [len(token) for token in scores["tokens"]]
# if max(token_lengths) == 0:
# # What? But don't want to crash a run so just in case...
# return None
# # Get max allowed token length from config
# max_allowed_length = self.config.max_token_length
# # Set threshold at 50% of max_token_length - no penalty below this
# length_threshold = max_allowed_length * 0.5
# # Apply modified length penalty with threshold
# scores["scores"] = []
# for length in token_lengths:
# if length <= length_threshold:
# # No penalty for responses under threshold
# scores["scores"].append(1.0)
# else:
# # Calculate how far we are between threshold and max as a percentage
# percentage_of_range = (length - length_threshold) / (
# max_allowed_length - length_threshold
# )
# # Cap at 1.0 in case length exceeds max_allowed_length
# percentage_of_range = min(percentage_of_range, 1.0)
# # Apply linear penalty scaling from 1.0 down to 0.0
# scores["scores"].append(1.0 - percentage_of_range)
return scores
@ -380,5 +395,11 @@ class GSM8kEnv(BaseEnv):
return next_item
def purrfect_eval(st: Optional[str]) -> float:
    """Return 1.0 if the text contains "purr" (case-insensitive), else 0.0.

    The check is a plain substring test on the lower-cased text, so
    "Purr", "PURRING", and a reply containing both "meow" and "purr" all
    score 1.0.

    Args:
        st: Text to score. ``None`` or any other non-string value (for
            example a completion that came back without text content) is
            scored 0.0 instead of raising ``AttributeError`` on ``.lower()``.

    Returns:
        1.0 when "purr" occurs in ``st``, otherwise 0.0.
    """
    # A missing reply carries no verdict; score it as "no purr" rather than
    # letting one empty completion crash the caller.
    if not isinstance(st, str):
        return 0.0
    return 1.0 if "purr" in st.lower() else 0.0
# Script entry point: only when this file is executed directly (not when it
# is imported) hand control to the environment's `cli()` classmethod.
if __name__ == "__main__":
    GSM8kEnv.cli()