diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index bf6491f9..933a3ee0 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -1007,8 +1007,10 @@ class BaseEnv(ABC): * self.derived_batch_size // self.config.group_size ) - - (self.status_dict["queue_size"]), + - (self.status_dict["self_queue_size"]), ) + # Next, handle the minimum number of workers based on allocation. + # Now if we have a minimum batch allocation, we need to add workers to fill the self queue, in case of # overruns by other environments if self.config.min_batch_allocation is not None: @@ -1124,7 +1126,7 @@ class BaseEnv(ABC): if ( self.status_dict["current_step"] + ( - self.status_dict["queue_size"] + self.status_dict["self_queue_size"] * self.config.group_size // self.config.batch_size ) @@ -1134,7 +1136,7 @@ class BaseEnv(ABC): break if ( ( - self.status_dict["queue_size"] * self.config.group_size + self.status_dict["self_queue_size"] * self.config.group_size >= self.config.max_batches_offpolicy * self.config.batch_size ) and (self.config.max_batches_offpolicy > 0) diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index 6c103b36..23ef0b73 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -243,7 +243,7 @@ class ManagedServer: elif len(output_logprobs) > len(output_tokens): output_logprobs = output_logprobs[: len(output_tokens)] - full_logprobs = [0.0] * prompt_len + output_logprobs + full_logprobs = [1.0] * prompt_len + output_logprobs return SequenceNode( full_text=full_text,