mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add ability to only send every n steps
This commit is contained in:
parent
def97fc38c
commit
46d3a6032a
1 changed files with 24 additions and 18 deletions
|
|
@ -45,6 +45,10 @@ class SFTConfig(BaseEnvConfig):
|
|||
default=-1,
|
||||
description="The maximum number of SFTs to do per step, if -1 just sends all the data",
|
||||
)
|
||||
add_every_n_steps: int = Field(
|
||||
default=1,
|
||||
description="Only add SFT data every n steps",
|
||||
)
|
||||
|
||||
|
||||
class SFTEnv(BaseEnv):
|
||||
|
|
@ -199,24 +203,26 @@ class SFTEnv(BaseEnv):
|
|||
async def env_step_checks(self):
|
||||
# Check if we need to run an eval or log...
|
||||
if self.curr_step != self.status_dict["current_step"]:
|
||||
to_send = (self.config.max_batches_offpolicy * self.config.batch_size) - (
|
||||
self.status_dict["queue_size"]
|
||||
)
|
||||
if self.config.max_sft_per_step != -1:
|
||||
to_send = min(to_send, self.config.max_sft_per_step)
|
||||
self.items_sent_this_step = to_send
|
||||
if to_send > 0:
|
||||
formatted_items = list()
|
||||
for _ in range(to_send):
|
||||
item = await self.get_next_item()
|
||||
self.idx += 1
|
||||
formatted_items.append(await self.format_item(item))
|
||||
await asyncio.gather(
|
||||
*[
|
||||
self.handle_send_to_api(formatted_item, None)
|
||||
for formatted_item in formatted_items
|
||||
]
|
||||
)
|
||||
next_step = self.status_dict["current_step"]
|
||||
if next_step % self.config.add_every_n_steps == 0:
|
||||
to_send = (
|
||||
self.config.max_batches_offpolicy * self.config.batch_size
|
||||
) - (self.status_dict["queue_size"])
|
||||
if self.config.max_sft_per_step != -1:
|
||||
to_send = min(to_send, self.config.max_sft_per_step)
|
||||
self.items_sent_this_step = to_send
|
||||
if to_send > 0:
|
||||
formatted_items = list()
|
||||
for _ in range(to_send):
|
||||
item = await self.get_next_item()
|
||||
self.idx += 1
|
||||
formatted_items.append(await self.format_item(item))
|
||||
await asyncio.gather(
|
||||
*[
|
||||
self.handle_send_to_api(formatted_item, None)
|
||||
for formatted_item in formatted_items
|
||||
]
|
||||
)
|
||||
await super().env_step_checks()
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue