mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add tasks_per_step arg to multiply by group_size for bs calculation
This commit is contained in:
parent
a26794afd2
commit
6d9523fe0b
1 changed files with 21 additions and 3 deletions
|
|
@ -33,7 +33,9 @@ def find_common_prefix(strings):
|
|||
return prefix
|
||||
|
||||
|
||||
async def register_to_api(group_size, max_token_len, api_url, num_steps):
|
||||
async def register_to_api(
|
||||
group_size, max_token_len, api_url, num_steps, tasks_per_step
|
||||
):
|
||||
"""
|
||||
Registers this data grabber instance with the Atropos API.
|
||||
|
||||
|
|
@ -45,6 +47,7 @@ async def register_to_api(group_size, max_token_len, api_url, num_steps):
|
|||
max_token_len: The maximum token length for sequences.
|
||||
api_url: The base URL of the Atropos API server.
|
||||
num_steps: The number of steps to run the API for.
|
||||
tasks_per_step: The number of tasks per step for batch size calculation.
|
||||
"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Reset data on the API server before registering
|
||||
|
|
@ -56,7 +59,7 @@ async def register_to_api(group_size, max_token_len, api_url, num_steps):
|
|||
json={
|
||||
"wandb_group": "test",
|
||||
"wandb_project": "test",
|
||||
"batch_size": group_size * 8,
|
||||
"batch_size": group_size * tasks_per_step,
|
||||
"max_token_len": max_token_len,
|
||||
"checkpoint_dir": "checkpoints",
|
||||
"save_checkpoint_interval": 10,
|
||||
|
|
@ -166,6 +169,7 @@ async def sft_data_grabber(
|
|||
allow_negative_scores,
|
||||
minimum_score_diff_max_min,
|
||||
append_to_previous,
|
||||
tasks_per_step,
|
||||
):
|
||||
"""
|
||||
Main asynchronous function to grab SFT data from the Atropos API.
|
||||
|
|
@ -186,6 +190,7 @@ async def sft_data_grabber(
|
|||
allow_negative_scores: Whether to allow negative scores.
|
||||
minimum_score_diff_max_min: Min score difference from group minimum.
|
||||
append_to_previous: Whether to append to an existing file or overwrite.
|
||||
tasks_per_step: Number of tasks per step for batch size calculation.
|
||||
"""
|
||||
tok = AutoTokenizer.from_pretrained(tokenizer)
|
||||
total_count = 0
|
||||
|
|
@ -208,7 +213,13 @@ async def sft_data_grabber(
|
|||
return count
|
||||
|
||||
# Register with the API first
|
||||
await register_to_api(group_size, max_token_len, api_url, num_steps=total_count)
|
||||
await register_to_api(
|
||||
group_size,
|
||||
max_token_len,
|
||||
api_url,
|
||||
num_steps=total_count,
|
||||
tasks_per_step=tasks_per_step,
|
||||
)
|
||||
|
||||
# Check for file existence before opening
|
||||
if os.path.exists(filepath) and not append_to_previous:
|
||||
|
|
@ -296,6 +307,12 @@ def main():
|
|||
action="store_true",
|
||||
help="Append to the previous file instead of overwriting it.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tasks-per-step",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Number of tasks per step for batch size calculation (batch_size = group_size * tasks_per_step).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run the main async function
|
||||
|
|
@ -312,6 +329,7 @@ def main():
|
|||
args.allow_negative_scores,
|
||||
args.minimum_score_diff_max_min,
|
||||
args.append_to_previous,
|
||||
args.tasks_per_step,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue