mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
fix bugs in cat_server
This commit is contained in:
parent
6c18d04bce
commit
128dce55bc
1 changed files with 57 additions and 6 deletions
|
|
@ -75,7 +75,7 @@ class GSM8kEnv(BaseEnv):
|
|||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=61,
|
||||
batch_size=12,
|
||||
batch_size=1,
|
||||
steps_per_eval=60,
|
||||
max_token_length=2048,
|
||||
wandb_name="gsm8k",
|
||||
|
|
@ -218,7 +218,7 @@ class GSM8kEnv(BaseEnv):
|
|||
messages = (
|
||||
{"role": "system", "content": cat_system_prompt},
|
||||
user_message,
|
||||
cat_message,
|
||||
{"role": "system", "content": cat_message},
|
||||
{"role": "assistant", "content": caretaker_completion.message.content},
|
||||
)
|
||||
to_score.append(
|
||||
|
|
@ -227,15 +227,66 @@ class GSM8kEnv(BaseEnv):
|
|||
}
|
||||
)
|
||||
to_postprocess = await self.score(to_score)
|
||||
# import pdb; pdb.set_trace()
|
||||
return to_postprocess, to_backlog
|
||||
|
||||
async def score(
|
||||
self, rollout_group_data
|
||||
) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
|
||||
# scores = ScoredDataGroup()
|
||||
# scores["tokens"] = list()
|
||||
# scores["masks"] = list()
|
||||
# scores["scores"] = list()
|
||||
scores = ScoredDataGroup()
|
||||
scores["tokens"] = list()
|
||||
scores["masks"] = list()
|
||||
scores["scores"] = list()
|
||||
# random.shuffle(rollout_group_data)
|
||||
for item in rollout_group_data:
|
||||
out_dict = tokenize_for_trainer(
|
||||
self.tokenizer, item["messages"]
|
||||
)
|
||||
tokens = out_dict["tokens"]
|
||||
masks = out_dict["masks"]
|
||||
# remove obviously bad examples
|
||||
if len([1 for i in masks if i != -100]) < 10:
|
||||
continue
|
||||
scores["tokens"].append(tokens)
|
||||
scores["masks"].append(masks)
|
||||
scores["scores"].append(1.0)
|
||||
if len(scores["tokens"]) >= self.config.group_size:
|
||||
break
|
||||
for score in scores["scores"]:
|
||||
self.percent_correct_buffer.append(max(score, 0))
|
||||
# check if all the same
|
||||
# print(scores['scores'])
|
||||
if all([score == 1 for score in scores["scores"]]):
|
||||
# Do length penalty :)
|
||||
token_lengths = [len(token) for token in scores["tokens"]]
|
||||
if max(token_lengths) == 0:
|
||||
# What? But don't want to crash a run so just in case...
|
||||
return None
|
||||
|
||||
# Get max allowed token length from config
|
||||
max_allowed_length = self.config.max_token_length
|
||||
# Set threshold at 50% of max_token_length - no penalty below this
|
||||
length_threshold = max_allowed_length * 0.5
|
||||
|
||||
# Apply modified length penalty with threshold
|
||||
scores["scores"] = []
|
||||
for length in token_lengths:
|
||||
if length <= length_threshold:
|
||||
# No penalty for responses under threshold
|
||||
scores["scores"].append(1.0)
|
||||
else:
|
||||
# Calculate how far we are between threshold and max as a percentage
|
||||
percentage_of_range = (length - length_threshold) / (
|
||||
max_allowed_length - length_threshold
|
||||
)
|
||||
# Cap at 1.0 in case length exceeds max_allowed_length
|
||||
percentage_of_range = min(percentage_of_range, 1.0)
|
||||
# Apply linear penalty scaling from 1.0 down to 0.0
|
||||
scores["scores"].append(1.0 - percentage_of_range)
|
||||
return scores
|
||||
|
||||
|
||||
|
||||
# gold_parsed = parse(
|
||||
# rollout_group_data[0]["gold_answer"],
|
||||
# extraction_mode="first_match",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue