mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
feat: Add epoch-based deterministic shuffling to generate_batch endpoint
This commit is contained in:
parent
1864a54d53
commit
acd078b448
1 changed file with 24 additions and 2 deletions
|
|
@ -73,9 +73,25 @@ def create_app(config: ServerConfig) -> FastAPI:
|
|||
raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
|
||||
return {"status": "deleted"}
|
||||
|
||||
def permute_index(idx: int, epoch_seed: int, dataset_size: int) -> int:
    """Map ``idx`` to its position in a deterministic shuffle of the dataset.

    For a fixed ``epoch_seed`` the mapping is a true permutation of
    ``range(dataset_size)``: every index is produced exactly once, so one pass
    over ``0..dataset_size-1`` visits every entry without duplicates. The
    permutation is computed on the fly (keyed Feistel network plus
    cycle-walking), so no list of ``dataset_size`` elements is materialized.

    Args:
        idx: Original index to permute. Values outside ``[0, dataset_size)``
            are wrapped modulo ``dataset_size``.
        epoch_seed: Seed selecting this epoch's permutation.
        dataset_size: Size of the dataset; must be positive.

    Returns:
        Permuted index in range [0, dataset_size).

    Raises:
        ValueError: If ``dataset_size`` is not positive.
    """
    # Local import keeps the block self-contained; hashlib gives a keyed round
    # function that is stable across processes (unlike the builtin hash()).
    import hashlib

    if dataset_size <= 0:
        raise ValueError("dataset_size must be positive")

    # Split the index into two halves of `half_bits` bits each; the Feistel
    # domain 2**(2 * half_bits) is the smallest such power >= dataset_size.
    half_bits = max(1, ((dataset_size - 1).bit_length() + 1) // 2)
    half_mask = (1 << half_bits) - 1

    def round_value(value: int, round_no: int) -> int:
        payload = f"{epoch_seed}:{round_no}:{value}".encode()
        digest = hashlib.blake2b(payload, digest_size=8).digest()
        return int.from_bytes(digest, "big") & half_mask

    position = idx % dataset_size
    while True:
        left, right = position >> half_bits, position & half_mask
        # Each round (L, R) -> (R, L xor F(R)) is invertible, so the whole
        # network is a bijection on the power-of-two domain.
        for round_no in range(4):
            left, right = right, left ^ round_value(right, round_no)
        position = (left << half_bits) | right
        # Cycle-walking: re-apply until the value lands inside the dataset,
        # which restricts the bijection to [0, dataset_size).
        if position < dataset_size:
            return position
|
||||
|
||||
@app.get("/experiments/{name}/batch", response_model=BatchResponse)
|
||||
async def generate_batch(name: str, base_index: int, batch_size: int, epoch: int = 0):
|
||||
"""Generate a batch of raw entries"""
|
||||
"""Generate a batch of raw entries with epoch-based shuffling"""
|
||||
# Validate parameters
|
||||
if base_index < 0:
|
||||
raise HTTPException(status_code=400, detail="base_index must be non-negative")
|
||||
|
|
@ -87,9 +103,15 @@ def create_app(config: ServerConfig) -> FastAPI:
|
|||
raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
|
||||
|
||||
try:
|
||||
dataset_size = len(experiment.dataset)
|
||||
base_seed = experiment.config.seed if experiment.config.seed is not None else 0
|
||||
epoch_seed = base_seed + (epoch * dataset_size)
|
||||
|
||||
entries = []
|
||||
for i in range(base_index, base_index + batch_size):
|
||||
entry = experiment.dataset[i]
|
||||
# Get permuted index for this position
|
||||
shuffled_idx = permute_index(i, epoch_seed, dataset_size)
|
||||
entry = experiment.dataset[shuffled_idx]
|
||||
|
||||
# Create BatchEntry with minimal required data
|
||||
batch_entry = BatchEntry(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue