fix bugs in cat_server

This commit is contained in:
97hongjun 2025-05-18 16:20:42 -07:00
parent 6c18d04bce
commit 128dce55bc

View file

@ -75,7 +75,7 @@ class GSM8kEnv(BaseEnv):
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=61,
batch_size=12,
batch_size=1,
steps_per_eval=60,
max_token_length=2048,
wandb_name="gsm8k",
@ -218,7 +218,7 @@ class GSM8kEnv(BaseEnv):
messages = (
{"role": "system", "content": cat_system_prompt},
user_message,
cat_message,
{"role": "system", "content": cat_message},
{"role": "assistant", "content": caretaker_completion.message.content},
)
to_score.append(
@ -227,15 +227,66 @@ class GSM8kEnv(BaseEnv):
}
)
to_postprocess = await self.score(to_score)
# import pdb; pdb.set_trace()
return to_postprocess, to_backlog
async def score(
    self, rollout_group_data
) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
    """Tokenize a group of rollouts and score them with a length penalty.

    Each rollout's ``item["messages"]`` is tokenized with
    ``tokenize_for_trainer``. Rollouts with fewer than 10 mask entries
    different from -100 are dropped. Every kept rollout starts with a score
    of 1.0, which is then reduced linearly once its token count exceeds 50%
    of ``self.config.max_token_length`` (reaching 0.0 at or beyond the
    maximum).

    Args:
        rollout_group_data: Iterable of rollout items; each item must
            provide a ``"messages"`` entry.

    Returns:
        A ``ScoredDataGroup`` with parallel ``tokens``, ``masks`` and
        ``scores`` lists, or ``None`` when no rollout survives filtering or
        every kept token sequence is empty.
    """
    scores = ScoredDataGroup()
    scores["tokens"] = list()
    scores["masks"] = list()
    scores["scores"] = list()

    for item in rollout_group_data:
        out_dict = tokenize_for_trainer(self.tokenizer, item["messages"])
        tokens = out_dict["tokens"]
        masks = out_dict["masks"]
        # Remove obviously bad examples: too few entries that are not -100.
        if sum(1 for mask in masks if mask != -100) < 10:
            continue
        scores["tokens"].append(tokens)
        scores["masks"].append(masks)
        scores["scores"].append(1.0)
        # Stop as soon as the group is full.
        if len(scores["tokens"]) >= self.config.group_size:
            break

    for score in scores["scores"]:
        self.percent_correct_buffer.append(max(score, 0))

    # Every kept rollout was given 1.0 above, so this currently always holds;
    # the check is kept so the penalty only applies to all-correct groups.
    if all(score == 1 for score in scores["scores"]):
        token_lengths = [len(token) for token in scores["tokens"]]
        # An empty group (everything filtered out) would make max() raise
        # ValueError; treat it like the all-empty case instead of crashing
        # the run.
        if not token_lengths or max(token_lengths) == 0:
            return None

        # Get max allowed token length from config.
        max_allowed_length = self.config.max_token_length
        # No penalty below 50% of max_token_length.
        length_threshold = max_allowed_length * 0.5

        scores["scores"] = []
        for length in token_lengths:
            if length <= length_threshold:
                # No penalty for responses under the threshold.
                scores["scores"].append(1.0)
            else:
                # Position between threshold and max, as a fraction.
                percentage_of_range = (length - length_threshold) / (
                    max_allowed_length - length_threshold
                )
                # Cap at 1.0 in case length exceeds max_allowed_length.
                percentage_of_range = min(percentage_of_range, 1.0)
                # Linear penalty scaling from 1.0 down to 0.0.
                scores["scores"].append(1.0 - percentage_of_range)

    return scores
# gold_parsed = parse(
# rollout_group_data[0]["gold_answer"],
# extraction_mode="first_match",