set prompt logprobs to a masked value

dmahan93 2025-10-26 11:58:55 -07:00
parent c22f8ca81b
commit c483840f59
2 changed files with 6 additions and 4 deletions

@@ -1007,8 +1007,10 @@ class BaseEnv(ABC):
                * self.derived_batch_size
                // self.config.group_size
            )
-            (self.status_dict["queue_size"]),
+            (self.status_dict["self_queue_size"]),
        )
        # now minimum num workers based on allocation
+        # Now if we have a minimum batch allocation, we need to add workers to fill the self queue, in case of
+        # overruns by other environments
        if self.config.min_batch_allocation is not None:
@@ -1124,7 +1126,7 @@ class BaseEnv(ABC):
            if (
                self.status_dict["current_step"]
                + (
-                    self.status_dict["queue_size"]
+                    self.status_dict["self_queue_size"]
                    * self.config.group_size
                    // self.config.batch_size
                )
@@ -1134,7 +1136,7 @@ class BaseEnv(ABC):
                break
            if (
                (
-                    self.status_dict["queue_size"] * self.config.group_size
+                    self.status_dict["self_queue_size"] * self.config.group_size
                    >= self.config.max_batches_offpolicy * self.config.batch_size
                )
                and (self.config.max_batches_offpolicy > 0)

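Both checks above throttle generation on the environment's own queue ("self_queue_size") rather than the shared "queue_size", so another environment overrunning the shared queue no longer skews this environment's off-policy accounting. Below is a minimal sketch of the second check under that reading, assuming a status_dict and config shaped like the ones in the diff; the function name and the standalone setup are hypothetical.

from types import SimpleNamespace

# Hypothetical config mirroring the fields used in the diff.
config = SimpleNamespace(group_size=8, batch_size=64, max_batches_offpolicy=4)

# status_dict as the diff reads it: this environment's own queue depth, in groups.
status_dict = {"self_queue_size": 40}

def too_far_offpolicy(status_dict, config) -> bool:
    # Pending samples this environment alone would still contribute.
    pending_samples = status_dict["self_queue_size"] * config.group_size
    # A value of 0 disables the limit, matching the explicit > 0 guard above.
    return (
        config.max_batches_offpolicy > 0
        and pending_samples >= config.max_batches_offpolicy * config.batch_size
    )

print(too_far_offpolicy(status_dict, config))  # True: 40 * 8 = 320 >= 4 * 64 = 256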

@@ -243,7 +243,7 @@ class ManagedServer:
        elif len(output_logprobs) > len(output_tokens):
            output_logprobs = output_logprobs[: len(output_tokens)]
-        full_logprobs = [0.0] * prompt_len + output_logprobs
+        full_logprobs = [1.0] * prompt_len + output_logprobs
        return SequenceNode(
            full_text=full_text,
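For context on the sentinel: a genuine log-probability is always <= 0.0, so the old 0.0 fill was indistinguishable from the logprob of a token predicted with probability 1, whereas 1.0 can only mean "masked prompt position". A small illustration under that reading; the values are made up.

# Illustration only: real logprobs are <= 0.0, so 1.0 unambiguously marks
# prompt positions, while 0.0 could also be a legitimate output logprob.
prompt_len = 3
output_logprobs = [-0.7, 0.0, -2.3]  # the 0.0 here is a real value, not padding
full_logprobs = [1.0] * prompt_len + output_logprobs

prompt_mask = [lp == 1.0 for lp in full_logprobs]
assert prompt_mask == [True, True, True, False, False, False]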