mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-27 17:23:08 +00:00
readd multistep masking
This commit is contained in:
parent
4d0f919fd1
commit
7fe1a40368
3 changed files with 109 additions and 6 deletions
|
|
@@ -24,7 +24,7 @@ from atroposlib.envs.base import (
|
|||
OpenaiConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
-from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
+from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer_multistep
|
||||
from atroposlib.utils.tool_call_parser import parse_tool_call
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@@ -446,7 +446,7 @@ class BlackjackEnv(BaseEnv):
|
|||
next_state_msgs_i = current_state_plus_response
|
||||
alt_next_state_msgs.append(next_state_msgs_i)
|
||||
|
||||
-tokenized_i = tokenize_for_trainer(
|
||||
+tokenized_i = tokenize_for_trainer_multistep(
|
||||
self.tokenizer, next_state_msgs_i
|
||||
)
|
||||
alt_tokens.append(tokenized_i["tokens"])
|
||||
|
|
@@ -1087,7 +1087,7 @@ class BlackjackEnv(BaseEnv):
|
|||
break
|
||||
|
||||
try:
|
||||
-tokenized_alt = tokenize_for_trainer(
|
||||
+tokenized_alt = tokenize_for_trainer_multistep(
|
||||
self.tokenizer, working_messages[alt_idx]
|
||||
)
|
||||
temp_new_alt_tokens.append(tokenized_alt["tokens"])
|
||||
|
|
|
|||
|
|
@@ -25,7 +25,10 @@ from atroposlib.envs.base import (
|
|||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.type_definitions import Message
|
||||
-from atroposlib.utils.tokenize_for_trainer import UNMASKED_ROLES, tokenize_for_trainer
|
||||
+from atroposlib.utils.tokenize_for_trainer import (
|
||||
+UNMASKED_ROLES,
|
||||
+tokenize_for_trainer_multistep,
|
||||
+)
|
||||
from atroposlib.utils.tool_call_parser import parse_tool_call
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@@ -491,7 +494,7 @@ class BlackjackEnv(BaseEnv):
|
|||
step_msgs.append({"role": "agent", "content": response})
|
||||
|
||||
try:
|
||||
-out = tokenize_for_trainer(self.tokenizer, step_msgs)
|
||||
+out = tokenize_for_trainer_multistep(self.tokenizer, step_msgs)
|
||||
alt_tokens.append(out["tokens"])
|
||||
alt_masks.append(out["masks"])
|
||||
alt_messages.append(step_msgs)
|
||||
|
|
@@ -1598,7 +1601,7 @@ class BlackjackEnv(BaseEnv):
|
|||
break
|
||||
|
||||
try:
|
||||
-tokenized_alt = tokenize_for_trainer(
|
||||
+tokenized_alt = tokenize_for_trainer_multistep(
|
||||
self.tokenizer, working_messages[alt_idx]
|
||||
)
|
||||
temp_new_alt_tokens.append(tokenized_alt["tokens"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue