Re-add multistep masking

This commit is contained in:
Shannon Sands 2025-05-10 09:24:55 +10:00
parent 4d0f919fd1
commit 7fe1a40368
3 changed files with 109 additions and 6 deletions

View file

@@ -24,7 +24,7 @@ from atroposlib.envs.base import (
OpenaiConfig,
ScoredDataGroup,
)
-from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
+from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer_multistep
from atroposlib.utils.tool_call_parser import parse_tool_call
logger = logging.getLogger(__name__)
@@ -446,7 +446,7 @@ class BlackjackEnv(BaseEnv):
next_state_msgs_i = current_state_plus_response
alt_next_state_msgs.append(next_state_msgs_i)
-tokenized_i = tokenize_for_trainer(
+tokenized_i = tokenize_for_trainer_multistep(
self.tokenizer, next_state_msgs_i
)
alt_tokens.append(tokenized_i["tokens"])
@@ -1087,7 +1087,7 @@ class BlackjackEnv(BaseEnv):
break
try:
-tokenized_alt = tokenize_for_trainer(
+tokenized_alt = tokenize_for_trainer_multistep(
self.tokenizer, working_messages[alt_idx]
)
temp_new_alt_tokens.append(tokenized_alt["tokens"])