Add tasks_per_step argument: batch size is now group_size * tasks_per_step instead of the hard-coded group_size * 8

This commit is contained in:
teknium1 2025-06-10 01:54:52 -07:00
parent a26794afd2
commit 6d9523fe0b

View file

@ -33,7 +33,9 @@ def find_common_prefix(strings):
return prefix
async def register_to_api(group_size, max_token_len, api_url, num_steps):
async def register_to_api(
group_size, max_token_len, api_url, num_steps, tasks_per_step
):
"""
Registers this data grabber instance with the Atropos API.
@ -45,6 +47,7 @@ async def register_to_api(group_size, max_token_len, api_url, num_steps):
max_token_len: The maximum token length for sequences.
api_url: The base URL of the Atropos API server.
num_steps: The number of steps to run the API for.
tasks_per_step: The number of tasks per step for batch size calculation.
"""
async with aiohttp.ClientSession() as session:
# Reset data on the API server before registering
@ -56,7 +59,7 @@ async def register_to_api(group_size, max_token_len, api_url, num_steps):
json={
"wandb_group": "test",
"wandb_project": "test",
"batch_size": group_size * 8,
"batch_size": group_size * tasks_per_step,
"max_token_len": max_token_len,
"checkpoint_dir": "checkpoints",
"save_checkpoint_interval": 10,
@ -166,6 +169,7 @@ async def sft_data_grabber(
allow_negative_scores,
minimum_score_diff_max_min,
append_to_previous,
tasks_per_step,
):
"""
Main asynchronous function to grab SFT data from the Atropos API.
@ -186,6 +190,7 @@ async def sft_data_grabber(
allow_negative_scores: Whether to allow negative scores.
minimum_score_diff_max_min: Min score difference from group minimum.
append_to_previous: Whether to append to an existing file or overwrite.
tasks_per_step: Number of tasks per step for batch size calculation.
"""
tok = AutoTokenizer.from_pretrained(tokenizer)
total_count = 0
@ -208,7 +213,13 @@ async def sft_data_grabber(
return count
# Register with the API first
await register_to_api(group_size, max_token_len, api_url, num_steps=total_count)
await register_to_api(
group_size,
max_token_len,
api_url,
num_steps=total_count,
tasks_per_step=tasks_per_step,
)
# Check for file existence before opening
if os.path.exists(filepath) and not append_to_previous:
@ -296,6 +307,12 @@ def main():
action="store_true",
help="Append to the previous file instead of overwriting it.",
)
parser.add_argument(
"--tasks-per-step",
type=int,
default=64,
help="Number of tasks per step for batch size calculation (batch_size = group_size * tasks_per_step).",
)
args = parser.parse_args()
# Run the main async function
@ -312,6 +329,7 @@ def main():
args.allow_negative_scores,
args.minimum_score_diff_max_min,
args.append_to_previous,
args.tasks_per_step,
)
)