mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
set prompt logprobs to a masked value
This commit is contained in:
parent
c22f8ca81b
commit
c483840f59
2 changed files with 6 additions and 4 deletions
|
|
@ -1007,8 +1007,10 @@ class BaseEnv(ABC):
|
|||
* self.derived_batch_size
|
||||
// self.config.group_size
|
||||
)
|
||||
- (self.status_dict["queue_size"]),
|
||||
- (self.status_dict["self_queue_size"]),
|
||||
)
|
||||
# now minimum num workers based on allocation
|
||||
|
||||
# Now if we have a minimum batch allocation, we need to add workers to fill the self queue, in case of
|
||||
# overruns by other environments
|
||||
if self.config.min_batch_allocation is not None:
|
||||
|
|
@ -1124,7 +1126,7 @@ class BaseEnv(ABC):
|
|||
if (
|
||||
self.status_dict["current_step"]
|
||||
+ (
|
||||
self.status_dict["queue_size"]
|
||||
self.status_dict["self_queue_size"]
|
||||
* self.config.group_size
|
||||
// self.config.batch_size
|
||||
)
|
||||
|
|
@ -1134,7 +1136,7 @@ class BaseEnv(ABC):
|
|||
break
|
||||
if (
|
||||
(
|
||||
self.status_dict["queue_size"] * self.config.group_size
|
||||
self.status_dict["self_queue_size"] * self.config.group_size
|
||||
>= self.config.max_batches_offpolicy * self.config.batch_size
|
||||
)
|
||||
and (self.config.max_batches_offpolicy > 0)
|
||||
|
|
|
|||
|
|
@ -243,7 +243,7 @@ class ManagedServer:
|
|||
elif len(output_logprobs) > len(output_tokens):
|
||||
output_logprobs = output_logprobs[: len(output_tokens)]
|
||||
|
||||
full_logprobs = [0.0] * prompt_len + output_logprobs
|
||||
full_logprobs = [1.0] * prompt_len + output_logprobs
|
||||
|
||||
return SequenceNode(
|
||||
full_text=full_text,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue