mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
add reward
This commit is contained in:
parent
128dce55bc
commit
00599f2e4b
1 changed files with 63 additions and 42 deletions
|
|
@ -234,55 +234,70 @@ class GSM8kEnv(BaseEnv):
|
|||
self, rollout_group_data
|
||||
) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
|
||||
scores = ScoredDataGroup()
|
||||
|
||||
scores["tokens"] = list()
|
||||
scores["masks"] = list()
|
||||
scores["scores"] = list()
|
||||
# random.shuffle(rollout_group_data)
|
||||
# # random.shuffle(rollout_group_data)
|
||||
for item in rollout_group_data:
|
||||
out_dict = tokenize_for_trainer(
|
||||
self.tokenizer, item["messages"]
|
||||
final_question = list(item["messages"]) + [{'role': 'system', 'content': 'The conversation is over. Say meow if the caretaker did a bad job, or purr if the caretaker did a good job.'}]
|
||||
caretaker_completions = await self.server.chat_completion(
|
||||
messages=final_question,
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
)
|
||||
tokens = out_dict["tokens"]
|
||||
masks = out_dict["masks"]
|
||||
# remove obviously bad examples
|
||||
if len([1 for i in masks if i != -100]) < 10:
|
||||
continue
|
||||
scores["tokens"].append(tokens)
|
||||
scores["masks"].append(masks)
|
||||
scores["scores"].append(1.0)
|
||||
if len(scores["tokens"]) >= self.config.group_size:
|
||||
break
|
||||
for score in scores["scores"]:
|
||||
self.percent_correct_buffer.append(max(score, 0))
|
||||
# check if all the same
|
||||
# print(scores['scores'])
|
||||
if all([score == 1 for score in scores["scores"]]):
|
||||
# Do length penalty :)
|
||||
token_lengths = [len(token) for token in scores["tokens"]]
|
||||
if max(token_lengths) == 0:
|
||||
# What? But don't want to crash a run so just in case...
|
||||
return None
|
||||
final_out = {'role': 'system', 'content': [row.message.content for row in caretaker_completions.choices][0]}
|
||||
|
||||
# Get max allowed token length from config
|
||||
max_allowed_length = self.config.max_token_length
|
||||
# Set threshold at 50% of max_token_length - no penalty below this
|
||||
length_threshold = max_allowed_length * 0.5
|
||||
final_score = purrfect_eval(final_out['content'])
|
||||
|
||||
# Apply modified length penalty with threshold
|
||||
scores["scores"] = []
|
||||
for length in token_lengths:
|
||||
if length <= length_threshold:
|
||||
# No penalty for responses under threshold
|
||||
scores["scores"].append(1.0)
|
||||
else:
|
||||
# Calculate how far we are between threshold and max as a percentage
|
||||
percentage_of_range = (length - length_threshold) / (
|
||||
max_allowed_length - length_threshold
|
||||
)
|
||||
# Cap at 1.0 in case length exceeds max_allowed_length
|
||||
percentage_of_range = min(percentage_of_range, 1.0)
|
||||
# Apply linear penalty scaling from 1.0 down to 0.0
|
||||
scores["scores"].append(1.0 - percentage_of_range)
|
||||
out_dict = tokenize_for_trainer(
|
||||
self.tokenizer, [row for row in item["messages"]] + [final_out]
|
||||
)
|
||||
scores['tokens'].append(out_dict['tokens'])
|
||||
scores['masks'].append(out_dict['masks'])
|
||||
scores['scores'].append(final_score)
|
||||
|
||||
# tokens = out_dict["tokens"]
|
||||
# masks = out_dict["masks"]
|
||||
# # remove obviously bad examples
|
||||
# if len([1 for i in masks if i != -100]) < 10:
|
||||
# continue
|
||||
# scores["tokens"].append(tokens)
|
||||
# scores["masks"].append(masks)
|
||||
# scores["scores"].append(1.0)
|
||||
# if len(scores["tokens"]) >= self.config.group_size:
|
||||
# break
|
||||
# for score in scores["scores"]:
|
||||
# self.percent_correct_buffer.append(max(score, 0))
|
||||
# # check if all the same
|
||||
# # print(scores['scores'])
|
||||
# if all([score == 1 for score in scores["scores"]]):
|
||||
# # Do length penalty :)
|
||||
# token_lengths = [len(token) for token in scores["tokens"]]
|
||||
# if max(token_lengths) == 0:
|
||||
# # What? But don't want to crash a run so just in case...
|
||||
# return None
|
||||
|
||||
# # Get max allowed token length from config
|
||||
# max_allowed_length = self.config.max_token_length
|
||||
# # Set threshold at 50% of max_token_length - no penalty below this
|
||||
# length_threshold = max_allowed_length * 0.5
|
||||
|
||||
# # Apply modified length penalty with threshold
|
||||
# scores["scores"] = []
|
||||
# for length in token_lengths:
|
||||
# if length <= length_threshold:
|
||||
# # No penalty for responses under threshold
|
||||
# scores["scores"].append(1.0)
|
||||
# else:
|
||||
# # Calculate how far we are between threshold and max as a percentage
|
||||
# percentage_of_range = (length - length_threshold) / (
|
||||
# max_allowed_length - length_threshold
|
||||
# )
|
||||
# # Cap at 1.0 in case length exceeds max_allowed_length
|
||||
# percentage_of_range = min(percentage_of_range, 1.0)
|
||||
# # Apply linear penalty scaling from 1.0 down to 0.0
|
||||
# scores["scores"].append(1.0 - percentage_of_range)
|
||||
return scores
|
||||
|
||||
|
||||
|
|
@ -380,5 +395,11 @@ class GSM8kEnv(BaseEnv):
|
|||
return next_item
|
||||
|
||||
|
||||
def purrfect_eval(st: str) -> float:
|
||||
if "purr" in st.lower():
|
||||
return 1.0
|
||||
return 0.0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GSM8kEnv.cli()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue