feat: Add epoch-based deterministic shuffling to generate_batch endpoint

This commit is contained in:
Andreas Koepf (aider) 2025-02-22 20:24:34 +00:00
parent 1864a54d53
commit acd078b448

View file

@ -73,9 +73,25 @@ def create_app(config: ServerConfig) -> FastAPI:
raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
return {"status": "deleted"}
def permute_index(idx: int, epoch_seed: int, dataset_size: int) -> int:
    """Map a position to its shuffled dataset index without materializing the permutation.

    For a fixed ``epoch_seed`` and ``dataset_size`` the mapping is a true
    permutation (bijection) of ``[0, dataset_size)``: iterating ``idx`` over the
    whole range yields every dataset index exactly once. Drawing an independent
    random number per position would not give this guarantee, because some
    entries would repeat and others would never be served within an epoch.

    The permutation is a small balanced Feistel network keyed by ``epoch_seed``
    and combined with cycle-walking, so it needs O(1) memory and is stable
    across processes and platforms.

    Args:
        idx: Original index to permute. Values outside ``[0, dataset_size)``
            are wrapped modulo ``dataset_size``.
        epoch_seed: Seed for this epoch's permutation
        dataset_size: Size of the dataset (must be positive)

    Returns:
        Permuted index in range [0, dataset_size)

    Raises:
        ValueError: If ``dataset_size`` is not positive
    """
    # Local import: only this helper needs hashlib, so the surrounding module's
    # import block does not have to change.
    import hashlib

    if dataset_size <= 0:
        raise ValueError("dataset_size must be positive")
    if dataset_size == 1:
        return 0

    # The Feistel network permutes a power-of-four domain [0, 2 ** (2 * half_bits))
    # that is the smallest such domain covering the dataset (less than 4x its size).
    half_bits = ((dataset_size - 1).bit_length() + 1) // 2
    half_mask = (1 << half_bits) - 1

    def round_value(round_no: int, half: int) -> int:
        # Keyed pseudo-random round function. A Feistel network is invertible
        # for any round function, so this only has to be deterministic.
        payload = f"{epoch_seed}:{round_no}:{half}".encode("ascii")
        digest = hashlib.blake2b(payload, digest_size=8).digest()
        return int.from_bytes(digest, "big") & half_mask

    value = idx % dataset_size
    while True:
        left, right = value >> half_bits, value & half_mask
        for round_no in range(4):
            left, right = right, left ^ round_value(round_no, right)
        value = (left << half_bits) | right
        # Cycle-walking: re-encrypt until the result lands inside the dataset.
        # The start value is in range and the network is a bijection, so the
        # walk always terminates and the restricted mapping stays a bijection.
        if value < dataset_size:
            return value
@app.get("/experiments/{name}/batch", response_model=BatchResponse)
async def generate_batch(name: str, base_index: int, batch_size: int, epoch: int = 0):
"""Generate a batch of raw entries"""
"""Generate a batch of raw entries with epoch-based shuffling"""
# Validate parameters
if base_index < 0:
raise HTTPException(status_code=400, detail="base_index must be non-negative")
@ -87,9 +103,15 @@ def create_app(config: ServerConfig) -> FastAPI:
raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
try:
dataset_size = len(experiment.dataset)
base_seed = experiment.config.seed if experiment.config.seed is not None else 0
epoch_seed = base_seed + (epoch * dataset_size)
entries = []
for i in range(base_index, base_index + batch_size):
entry = experiment.dataset[i]
# Get permuted index for this position
shuffled_idx = permute_index(i, epoch_seed, dataset_size)
entry = experiment.dataset[shuffled_idx]
# Create BatchEntry with minimal required data
batch_entry = BatchEntry(