diff --git a/tools/server/server.py b/tools/server/server.py
index 849db89c..c6748d0a 100644
--- a/tools/server/server.py
+++ b/tools/server/server.py
@@ -73,9 +73,57 @@ def create_app(config: ServerConfig) -> FastAPI:
             raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
         return {"status": "deleted"}
 
+    def permute_index(idx: int, epoch_seed: int, dataset_size: int) -> int:
+        """Map a position to a dataset index via a seeded bijective shuffle.
+
+        A small Feistel network keyed by ``epoch_seed`` permutes the smallest
+        power-of-four domain covering ``dataset_size``; cycle-walking feeds
+        out-of-range results back in until one lands inside the dataset. Unlike
+        drawing an independent random number per position, this is a true
+        permutation: positions 0..dataset_size-1 map to distinct entries, and
+        the permutation is never materialized.
+
+        Args:
+            idx: Position to permute; wrapped modulo ``dataset_size``
+            epoch_seed: Seed selecting this epoch's permutation
+            dataset_size: Size of the dataset, must be positive
+
+        Returns:
+            Permuted index in range [0, dataset_size)
+
+        Raises:
+            ValueError: If ``dataset_size`` is not positive
+        """
+        if dataset_size <= 0:
+            raise ValueError("dataset_size must be positive")
+
+        mask64 = (1 << 64) - 1
+        # Split the index into two equal-width halves for the Feistel rounds
+        half_bits = max(1, ((dataset_size - 1).bit_length() + 1) // 2)
+        half_mask = (1 << half_bits) - 1
+
+        def mix(value: int, round_no: int) -> int:
+            # splitmix64-style finalizer keyed by the epoch seed and the round
+            key = epoch_seed * 0x9E3779B97F4A7C15 + (round_no + 1) * 0xD1B54A32D192ED03
+            x = (value + key) & mask64
+            x = ((x ^ (x >> 30)) * 0xBF58476D1CE4E5B9) & mask64
+            x = ((x ^ (x >> 27)) * 0x94D049BB133111EB) & mask64
+            return (x ^ (x >> 31)) & half_mask
+
+        position = idx % dataset_size
+        while True:
+            left, right = position >> half_bits, position & half_mask
+            for round_no in range(4):
+                left, right = right, left ^ mix(right, round_no)
+            position = (left << half_bits) | right
+            # Cycle-walk: an out-of-range result is permuted again. The walk
+            # started inside the dataset range, so it must re-enter it.
+            if position < dataset_size:
+                return position
+
     @app.get("/experiments/{name}/batch", response_model=BatchResponse)
     async def generate_batch(name: str, base_index: int, batch_size: int, epoch: int = 0):
-        """Generate a batch of raw entries"""
+        """Generate a batch of raw entries with epoch-based shuffling"""
         # Validate parameters
         if base_index < 0:
             raise HTTPException(status_code=400, detail="base_index must be non-negative")
@@ -87,9 +135,15 @@ def create_app(config: ServerConfig) -> FastAPI:
             raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
 
         try:
+            dataset_size = len(experiment.dataset)
+            base_seed = experiment.config.seed if experiment.config.seed is not None else 0
+            epoch_seed = base_seed + (epoch * dataset_size)
+
             entries = []
             for i in range(base_index, base_index + batch_size):
-                entry = experiment.dataset[i]
+                # Get permuted index for this position
+                shuffled_idx = permute_index(i, epoch_seed, dataset_size)
+                entry = experiment.dataset[shuffled_idx]
 
                 # Create BatchEntry with minimal required data
                 batch_entry = BatchEntry(