allow inf (<= 0 max_token_len) generations if trainer requests it, but raise a warning so that users can check their logs and get info if their trainers are doing something weird

This commit is contained in:
Dakota 2025-07-01 09:52:10 -05:00
parent 0ff966c276
commit 683559afd2

View file

@ -675,7 +675,12 @@ class BaseEnv(ABC):
for mask in group["masks"]:
self.completion_lengths.append(len(mask))
if abort_on_any_max_length_exceeded and any(
if self.max_token_len <= 0:
warnings.warn(
f"Trainer requested to ignore max length by setting max_token_len to {self.max_token_len}, "
"ensure your trainer handles this appropriately."
)
elif abort_on_any_max_length_exceeded and any(
[len(x) >= self.max_token_len for x in group["tokens"]]
):
logger.warning("Token length is too long in a group, skipping...")