add ability to only send every n steps

This commit is contained in:
dmahan93 2025-05-07 15:19:26 -05:00
parent def97fc38c
commit 46d3a6032a

View file

@ -45,6 +45,10 @@ class SFTConfig(BaseEnvConfig):
default=-1,
description="The maximum number of SFTs to do per step, if -1 just sends all the data",
)
add_every_n_steps: int = Field(
default=1,
description="Only add SFT data every n steps",
)
class SFTEnv(BaseEnv):
@ -199,24 +203,26 @@ class SFTEnv(BaseEnv):
async def env_step_checks(self):
# Check if we need to run an eval or log...
if self.curr_step != self.status_dict["current_step"]:
to_send = (self.config.max_batches_offpolicy * self.config.batch_size) - (
self.status_dict["queue_size"]
)
if self.config.max_sft_per_step != -1:
to_send = min(to_send, self.config.max_sft_per_step)
self.items_sent_this_step = to_send
if to_send > 0:
formatted_items = list()
for _ in range(to_send):
item = await self.get_next_item()
self.idx += 1
formatted_items.append(await self.format_item(item))
await asyncio.gather(
*[
self.handle_send_to_api(formatted_item, None)
for formatted_item in formatted_items
]
)
next_step = self.status_dict["current_step"]
if next_step % self.config.add_every_n_steps == 0:
to_send = (
self.config.max_batches_offpolicy * self.config.batch_size
) - (self.status_dict["queue_size"])
if self.config.max_sft_per_step != -1:
to_send = min(to_send, self.config.max_sft_per_step)
self.items_sent_this_step = to_send
if to_send > 0:
formatted_items = list()
for _ in range(to_send):
item = await self.get_next_item()
self.idx += 1
formatted_items.append(await self.format_item(item))
await asyncio.gather(
*[
self.handle_send_to_api(formatted_item, None)
for formatted_item in formatted_items
]
)
await super().env_step_checks()