diff --git a/environments/sft_loader_server.py b/environments/sft_loader_server.py index 55dfaeae..af0811d0 100644 --- a/environments/sft_loader_server.py +++ b/environments/sft_loader_server.py @@ -45,6 +45,10 @@ class SFTConfig(BaseEnvConfig): default=-1, description="The maximum number of SFTs to do per step, if -1 just sends all the data", ) + add_every_n_steps: int = Field( + default=1, + description="Only add SFT data every n steps", + ) class SFTEnv(BaseEnv): @@ -199,24 +203,26 @@ class SFTEnv(BaseEnv): async def env_step_checks(self): # Check if we need to run an eval or log... if self.curr_step != self.status_dict["current_step"]: - to_send = (self.config.max_batches_offpolicy * self.config.batch_size) - ( - self.status_dict["queue_size"] - ) - if self.config.max_sft_per_step != -1: - to_send = min(to_send, self.config.max_sft_per_step) - self.items_sent_this_step = to_send - if to_send > 0: - formatted_items = list() - for _ in range(to_send): - item = await self.get_next_item() - self.idx += 1 - formatted_items.append(await self.format_item(item)) - await asyncio.gather( - *[ - self.handle_send_to_api(formatted_item, None) - for formatted_item in formatted_items - ] - ) + next_step = self.status_dict["current_step"] + if next_step % self.config.add_every_n_steps == 0: + to_send = ( + self.config.max_batches_offpolicy * self.config.batch_size + ) - (self.status_dict["queue_size"]) + if self.config.max_sft_per_step != -1: + to_send = min(to_send, self.config.max_sft_per_step) + self.items_sent_this_step = to_send + if to_send > 0: + formatted_items = list() + for _ in range(to_send): + item = await self.get_next_item() + self.idx += 1 + formatted_items.append(await self.format_item(item)) + await asyncio.gather( + *[ + self.handle_send_to_api(formatted_item, None) + for formatted_item in formatted_items + ] + ) await super().env_step_checks()