diff --git a/CONFIG.md b/CONFIG.md index a1c502fe..5434f0e5 100644 --- a/CONFIG.md +++ b/CONFIG.md @@ -30,6 +30,8 @@ Basic environment configuration settings. | `ensure_scores_are_not_same` | `bool` | `True` | Ensure that scores within a group are not identical (usually `True`). | | `data_path_to_save_groups` | `str | None` | `None` | Path to save generated groups as a JSONL file. If set, groups will be written here. | | `min_items_sent_before_logging` | `int` | `2` | Minimum number of items sent to the API before logging metrics. `0` or less logs every time. | +| `include_messages` | `bool` | `False` | Whether to include messages in the output transmitted to the trainer. | +| `min_batch_allocation` | `float | None` | `None` | Minimum proportion of a batch this environment should be allocated (0.0-1.0). Ensures this env contributes at least this fraction to each training batch. | ## Server Manager Configuration (`atroposlib.envs.server_handling.server_manager.ServerManagerConfig`) diff --git a/atroposlib/api/README.md b/atroposlib/api/README.md index 6430ab0a..4f7de5bb 100644 --- a/atroposlib/api/README.md +++ b/atroposlib/api/README.md @@ -18,6 +18,7 @@ This service specifically handles the **experience data pathway**: * Registration endpoints for Trainers and Rollout Handlers. * Serves batches of aggregated experience data to Trainers. * Supports heterogeneous environments with weighting (via `/register-env` weight and internal batching). +* Minimum batch allocation support to guarantee certain environments contribute a minimum proportion to each batch. * Provides status endpoints for monitoring queue size and training step count. * Basic integration with Weights & Biases (W&B) project/group info. * Endpoints for Rollout Handlers to disconnect gracefully. 
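The new `group_size` and `min_batch_allocation` registration fields introduced above can be sanity-checked client-side before hitting `/register-env`. A minimal sketch; the `validate_registration` helper is hypothetical (only the payload field names come from this diff):

```python
from typing import Optional

def validate_registration(payload: dict) -> dict:
    """Hypothetical pre-flight check mirroring the new RegisterEnv fields."""
    if payload["group_size"] < 1:
        raise ValueError("group_size must be a positive int")
    mba: Optional[float] = payload.get("min_batch_allocation")
    if mba is not None and not (0.0 <= mba <= 1.0):
        raise ValueError("min_batch_allocation must be in [0.0, 1.0]")
    return payload

env = validate_registration({
    "max_token_length": 512,
    "desired_name": "critical_env",
    "weight": 1.0,
    "group_size": 4,              # sequences per data submission
    "min_batch_allocation": 0.3,  # claim at least 30% of every batch
})
```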
@@ -97,6 +98,8 @@ The API documentation (Swagger UI) will be available at `http:// max_token_length: int # Max length this env produces desired_name: str # Base name for identification/logging weight: float # Weight for sampling/batching (e.g., 1.0) + group_size: int # Expected number of sequences per data submission + min_batch_allocation: Optional[float] = None # Minimum proportion of batch (0.0-1.0) ``` * **Response:** Provides assigned ID, unique W&B name, checkpoint info. ```json @@ -135,14 +138,17 @@ The API documentation (Swagger UI) will be available at `http:// overrides: Optional[List[dict]] = None # Per-item logging overrides group_overrides: Optional[dict] = None # Group logging overrides images: Optional[Any] = None # Image data (if applicable) + env_id: Optional[int] = None # ID of the environment that generated this data ``` - * **Response:** `{"status": "received"}` + * **Response:** + * Normal submission: `{"status": "received"}` + * Mixed-size group buffered: `{"status": "buffered", "buffer_size": <count>}` * `POST /scored_data_list` * **Description:** Endpoint for Rollout Handlers to push a list of `ScoredData` chunks. * **Request Body:** `List[ScoredData]` * **Response:** `{"status": "received", "groups_processed": <count>}` * `GET /batch` - * **Description:** Called by the Trainer to request a batch of data for training. The server uses internal logic (`grab_exact_from_heterogeneous_queue`) to form a batch of the configured size from the available data in the queue, potentially respecting environment weights. The server increments its internal step counter when a batch is successfully formed and returned. + * **Description:** Called by the Trainer to request a batch of data for training. The server uses internal logic to form a batch of the configured size from the available data in the queue.
If any environments have minimum batch allocations specified, it uses `grab_batch_with_minimum_allocations` to ensure each environment gets at least its minimum proportion of the batch. Otherwise, it uses `grab_exact_from_heterogeneous_queue` to form batches respecting environment weights. The server increments its internal step counter when a batch is successfully formed and returned. * **Response:** * Success: `{"batch": [<data_item>, ..., <data_item>]}` where each `data_item` matches the structure pushed via `/scored_data`. * Not enough data: `{"batch": null}` @@ -178,6 +184,57 @@ The API documentation (Swagger UI) will be available at `http:// * mermaid diagram of how a rollout handler interacts with the api is located [here](env_interaction.md). 6. **Shutdown:** Handlers may call `POST /disconnect-env`. +## Minimum Batch Allocation Feature + +The API supports ensuring minimum batch allocations for specific environments. This feature is useful when you want to guarantee that certain environments contribute at least a minimum proportion of sequences to each training batch. + +### How It Works + +1. **Environment Registration**: When registering an environment via `/register-env`, you can specify: + - `min_batch_allocation` (Optional[float]): A value between 0.0 and 1.0 representing the minimum proportion of the batch this environment should contribute + - `group_size` (int): The expected number of sequences per data submission from this environment + +2. **Batch Formation**: When the trainer requests a batch via `/batch`: + - If any environment has a `min_batch_allocation` specified, the system uses special logic to ensure minimums are met + - The system attempts to allocate at least `min_batch_allocation * batch_size` sequences from each environment with a minimum + - If the sum of all minimum allocations exceeds 1.0, they are proportionally scaled down + - If an environment with a minimum allocation has no data available, the batch formation fails (returns null) + +3.
**Mixed-Size Group Handling**: When an environment submits data with a different number of sequences than its declared `group_size`: + - The data is buffered separately for that environment + - The system attempts to combine buffered groups to match the expected `group_size` + - Once combined, the data is added to the main queue + - Response includes `{"status": "buffered", "buffer_size": <count>}` + +### Example Configuration + +```python +# Environment 1: Requires at least 30% of each batch +{ + "max_token_length": 512, + "desired_name": "critical_env", + "weight": 1.0, + "group_size": 4, + "min_batch_allocation": 0.3 # 30% minimum +} + +# Environment 2: No minimum requirement +{ + "max_token_length": 512, + "desired_name": "standard_env", + "weight": 1.0, + "group_size": 2, + "min_batch_allocation": None # No minimum +} +``` + +### Important Notes + +- Minimum allocations are enforced per batch, not globally +- If minimum allocations cannot be satisfied (e.g., not enough data from a required environment), batch formation fails +- Environments without `min_batch_allocation` fill the remaining batch space after minimums are satisfied +- The feature respects heterogeneous packing constraints when forming batches + ## Limitations & TODOs * **In-Memory State:** The primary limitation is that all queues, configurations, and states are stored in the FastAPI application's memory (`app.state`).
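The proportional scale-down described in the batch-formation steps can be sketched as a standalone illustration (this is not the server's actual implementation; the server additionally guarantees at least one full pack per environment, which is omitted here):

```python
def scale_minimums(minimums: dict) -> dict:
    """If summed minimum allocations exceed 1.0, shrink each one
    proportionally so that together they exactly fill the batch."""
    total = sum(minimums.values())
    if total <= 1.0:
        return dict(minimums)
    return {env_id: frac / total for env_id, frac in minimums.items()}

# Two envs both asking for 60% cannot fit; each is scaled to 50%.
scaled = scale_minimums({"env_a": 0.6, "env_b": 0.6})

# Minimum sequences per env then follow from the batch size.
batch_size = 64
min_seqs = {env: int(batch_size * frac) for env, frac in scaled.items()}
# each env is guaranteed 32 of the 64 sequences
```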
diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py index 37064bc9..3a0a9b15 100644 --- a/atroposlib/api/server.py +++ b/atroposlib/api/server.py @@ -7,7 +7,11 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import PlainTextResponse from pydantic import BaseModel, field_validator -from atroposlib.api.utils import grab_exact_from_heterogeneous_queue +from atroposlib.api.utils import ( + find_groups_summing_to_target, + grab_batch_with_minimum_allocations, + grab_exact_from_heterogeneous_queue, +) # Message import removed - using Dict[str, Any] for more flexible validation @@ -42,6 +46,10 @@ class RegisterEnv(BaseModel): max_token_length: int desired_name: str weight: float + group_size: int + min_batch_allocation: Optional[float] = ( + None # Minimum proportion of a batch this env should be allocated (0.0-1.0) + ) class EnvIdentifier(BaseModel): @@ -60,6 +68,7 @@ class ScoredData(BaseModel): overrides: Optional[List[dict]] = None group_overrides: Optional[dict] = None images: Optional[Any] = None + env_id: Optional[int] = None # ID of the environment that generated this data @field_validator("messages", mode="before") @classmethod @@ -115,6 +124,7 @@ async def register(registration: Registration): app.state.curr_batch = [] app.state.started = False app.state.envs = [] + app.state.buffer = {} # Buffer for mixed-size groups per environment try: app.state.requesters.append(uuid.uuid4().int) except AttributeError: @@ -157,6 +167,8 @@ async def register_env_url(register_env: RegisterEnv): "registered_id": registered_id, "last_update": time.time(), "connected": True, + "min_batch_allocation": register_env.min_batch_allocation, + "group_size": register_env.group_size, } ) return { @@ -207,14 +219,31 @@ async def get_batch(): return {"batch": app.state.curr_batch.pop()} else: new_batches = [] - batch, app.state.queue = grab_exact_from_heterogeneous_queue( - app.state.queue, app.state.batchsize + # Check if any envs have minimum 
allocations + has_min_allocations = any( + env.get("min_batch_allocation") is not None + for env in getattr(app.state, "envs", []) ) - while batch is not None: - new_batches.append(batch) + + if has_min_allocations: + batch, app.state.queue = grab_batch_with_minimum_allocations( + app.state.queue, app.state.batchsize, app.state.envs + ) + else: batch, app.state.queue = grab_exact_from_heterogeneous_queue( app.state.queue, app.state.batchsize ) + + while batch is not None: + new_batches.append(batch) + if has_min_allocations: + batch, app.state.queue = grab_batch_with_minimum_allocations( + app.state.queue, app.state.batchsize, app.state.envs + ) + else: + batch, app.state.queue = grab_exact_from_heterogeneous_queue( + app.state.queue, app.state.batchsize + ) steps_to_take = len(new_batches) if steps_to_take == 0: return {"batch": None} @@ -224,7 +253,7 @@ async def get_batch(): app.state.curr_batch.append(batch) curr_batch = app.state.curr_batch.pop() # check length before sending - print(f"Sending batch of length {sum(len(x['tokens']) for x in curr_batch)}") + print(f"Sending batch of {sum(len(x['tokens']) for x in curr_batch)} sequences") return {"batch": curr_batch} @@ -246,20 +275,58 @@ async def get_latest_example(): @app.post("/scored_data") async def scored_data(scored_data: ScoredData): - app.state.queue.append( - { - "tokens": scored_data.tokens, - "masks": scored_data.masks, - "scores": scored_data.scores, - "advantages": scored_data.advantages, - "ref_logprobs": scored_data.ref_logprobs, - "messages": scored_data.messages, - "overrides": scored_data.overrides, - "group_overrides": scored_data.group_overrides, - "images": scored_data.images, - } - ) - app.state.latest = app.state.queue[-1] + data_dict = { + "tokens": scored_data.tokens, + "masks": scored_data.masks, + "scores": scored_data.scores, + "advantages": scored_data.advantages, + "ref_logprobs": scored_data.ref_logprobs, + "messages": scored_data.messages, + "overrides": scored_data.overrides, + 
"group_overrides": scored_data.group_overrides, + "images": scored_data.images, + "env_id": scored_data.env_id, + } + + # Check if this is a mixed-size group + env_id = scored_data.env_id + if env_id is not None and env_id < len(app.state.envs): + expected_group_size = app.state.envs[env_id].get("group_size", 1) + actual_group_size = len(scored_data.tokens) + + if actual_group_size != expected_group_size: + # Mixed size group - add to buffer + if env_id not in app.state.buffer: + app.state.buffer[env_id] = [] + + app.state.buffer[env_id].append(data_dict) + + # Try to find groups that sum to expected_group_size + indices = find_groups_summing_to_target( + app.state.buffer[env_id], expected_group_size + ) + + if indices: + # Add these groups to queue in order + groups_to_add = [] + for idx in sorted(indices, reverse=True): + groups_to_add.append(app.state.buffer[env_id].pop(idx)) + + # Add in FIFO order + for group in reversed(groups_to_add): + app.state.queue.append(group) + app.state.latest = group + + return { + "status": "buffered", + "buffer_size": sum( + len(g["tokens"]) for g in app.state.buffer.get(env_id, []) + ), + } + + # Normal path - correct size or no env info + app.state.queue.append(data_dict) + app.state.latest = data_dict return {"status": "received"} @@ -267,24 +334,57 @@ async def scored_data(scored_data: ScoredData): async def scored_data_list(scored_data_list: List[ScoredData]): """Handle a list of ScoredData objects for step-based learning""" - for idx, scored_data in enumerate(scored_data_list): + # Process each scored data item + for scored_data in scored_data_list: + data_dict = { + "tokens": scored_data.tokens, + "masks": scored_data.masks, + "scores": scored_data.scores, + "advantages": scored_data.advantages, + "ref_logprobs": scored_data.ref_logprobs, + "images": scored_data.images, + "messages": scored_data.messages, + "overrides": scored_data.overrides, + "group_overrides": scored_data.group_overrides, + "env_id": scored_data.env_id, 
+ } - app.state.queue.append( - { - "tokens": scored_data.tokens, - "masks": scored_data.masks, - "scores": scored_data.scores, - "advantages": scored_data.advantages, - "ref_logprobs": scored_data.ref_logprobs, - "images": scored_data.images, - "messages": scored_data.messages, - "overrides": scored_data.overrides, - "group_overrides": scored_data.group_overrides, - } - ) + # Check if this is a mixed-size group + env_id = scored_data.env_id + if env_id is not None and env_id < len(app.state.envs): + expected_group_size = app.state.envs[env_id].get("group_size", 1) + actual_group_size = len(scored_data.tokens) - if scored_data_list: - app.state.latest = app.state.queue[-1] + if actual_group_size != expected_group_size: + # Mixed size group - add to buffer + if env_id not in app.state.buffer: + app.state.buffer[env_id] = [] + + app.state.buffer[env_id].append(data_dict) + + # Try to find groups that sum to expected_group_size + indices = find_groups_summing_to_target( + app.state.buffer[env_id], expected_group_size + ) + + if indices: + # Add these groups to queue in order + groups_to_add = [] + for idx in sorted(indices, reverse=True): + groups_to_add.append(app.state.buffer[env_id].pop(idx)) + + # Add in FIFO order + for group in reversed(groups_to_add): + app.state.queue.append(group) + app.state.latest = group + else: + # Normal size - add directly to queue + app.state.queue.append(data_dict) + app.state.latest = data_dict + else: + # No env info or normal path - add directly to queue + app.state.queue.append(data_dict) + app.state.latest = data_dict return {"status": "received", "groups_processed": len(scored_data_list)} @@ -309,6 +409,7 @@ async def get_status_env(env: EnvIdentifier): if x["connected"] ] ) + env_group_size = app.state.envs[env.env_id]["group_size"] env_weight = ( app.state.envs[env.env_id]["max_context_len"] * app.state.envs[env.env_id]["weight"] @@ -318,13 +419,95 @@ async def get_status_env(env: EnvIdentifier): 0.01, env_weight ) # Minimum 
weight of 0.01 :) TODO: try to figure out a better way to do this + # Calculate total minimum allocations + total_min_allocation = 0.0 + for env_config in app.state.envs: + if ( + env_config.get("connected", False) + and env_config.get("min_batch_allocation") is not None + ): + total_min_allocation += env_config["min_batch_allocation"] + + # Calculate unallocated fraction + unallocated_fraction = 1.0 - min(total_min_allocation, 1.0) + + # Find the maximum group size across all items in queue + queue = getattr(app.state, "queue", []) + max_group_size = 1 + num_self_sequences_in_queue = 0 + for item in queue: + group_size = len(item.get("tokens", [])) + if group_size > max_group_size: + max_group_size = group_size + if item.get("env_id") == env.env_id: + # update the group size for the requesting env, handle cases where the group size may be dynamic with max + env_group_size = max(env_group_size, group_size) + num_self_sequences_in_queue += group_size + + # update the group size for the requesting env + app.state.envs[env.env_id]["group_size"] = env_group_size + + # Calculate minimum sequences allocated to each environment + batch_size = getattr(app.state, "batchsize", 0) + min_sequences_by_env = {} + for env_config in app.state.envs: + if ( + env_config.get("connected", False) + and env_config.get("min_batch_allocation") is not None + ): + env_id = env_config["registered_id"] + min_sequences = int(batch_size * env_config["min_batch_allocation"]) + min_sequences_by_env[env_id] = min_sequences + + # Count sequences and calculate packed groups for each environment + import math + + sequences_by_env = {} + packed_groups_by_env = {} + curr_env_total_sequences = 0 + + for item in queue: + env_id = item.get("env_id") + seq_count = len(item.get("tokens", [])) + + # Special handling for the requesting environment + if env_id == env.env_id: + curr_env_total_sequences += seq_count + else: + if env_id not in sequences_by_env: + sequences_by_env[env_id] = 0 + 
sequences_by_env[env_id] += seq_count + + # Calculate packed groups for each environment (excluding the requesting env) + if max_group_size > 1: + for env_id, seq_count in sequences_by_env.items(): + packed_groups_by_env[env_id] = math.ceil(seq_count / max_group_size) + + # Calculate adjusted queue size + # (curr_env_total_sequences + sum of available sequences from other envs after their minimums) + available_from_others = 0 + for env_id in packed_groups_by_env: + packed_sequences = packed_groups_by_env[env_id] * max_group_size + min_sequences = min_sequences_by_env.get(env_id, 0) + available_from_others += max(0, packed_sequences - min_sequences) + + env_queue_size = curr_env_total_sequences + available_from_others + try: ret_dict = { "current_step": app.state.status_dict["step"], - "queue_size": len(app.state.queue), + "queue_size": env_queue_size // env_group_size, + "unallocated_fraction": unallocated_fraction, + "self_queue_size": num_self_sequences_in_queue // env_group_size, + "max_group_size": max_group_size, } except AttributeError: - ret_dict = {"current_step": 0, "queue_size": 0} + ret_dict = { + "current_step": 0, + "queue_size": 0, + "unallocated_fraction": 1.0, + "self_queue_size": 0, + "max_group_size": 1, + } ret_dict["env_weight"] = env_weight return ret_dict @@ -342,6 +525,7 @@ async def reset_data(): app.state.started = False app.state.requesters = [] app.state.envs = [] + app.state.buffer = {} except KeyError: pass return PlainTextResponse("Reset successful", status_code=status.HTTP_200_OK) diff --git a/atroposlib/api/utils.py b/atroposlib/api/utils.py index bf89e4e6..f2d0bc50 100644 --- a/atroposlib/api/utils.py +++ b/atroposlib/api/utils.py @@ -1,47 +1,104 @@ from typing import Dict, List, Optional, Tuple +def find_groups_summing_to_target(buffer: List[Dict], target_size: int) -> List[int]: + """ + Find indices of groups in buffer that sum exactly to target_size. + Prioritizes FIFO order.
+ + :param buffer: Buffer of groups from same env + :param target_size: Target sum of group sizes + :return: List of indices that sum to target_size, or empty list if impossible + """ + if not buffer: + return [] + + # First try simple FIFO + current_sum = 0 + indices = [] + + for i, group in enumerate(buffer): + size = len(group["tokens"]) + if current_sum + size <= target_size: + indices.append(i) + current_sum += size + if current_sum == target_size: + return indices + + # If FIFO doesn't work exactly, try dynamic programming + # to find any valid combination (still preferring earlier indices) + n = len(buffer) + sizes = [len(g["tokens"]) for g in buffer] + + # dp[i][j] = can we make sum j using first i groups + dp = [[False] * (target_size + 1) for _ in range(n + 1)] + dp[0][0] = True + + for i in range(1, n + 1): + for j in range(target_size + 1): + # Don't take group i-1 + dp[i][j] = dp[i - 1][j] + # Take group i-1 if possible + if j >= sizes[i - 1]: + dp[i][j] = dp[i][j] or dp[i - 1][j - sizes[i - 1]] + + if not dp[n][target_size]: + return [] + + # Backtrack to find indices, preferring earlier ones + indices = [] + j = target_size + for i in range(n, 0, -1): + if j >= sizes[i - 1] and dp[i - 1][j - sizes[i - 1]]: + indices.append(i - 1) + j -= sizes[i - 1] + + return sorted(indices) # Return in FIFO order + + def grab_exact_from_heterogeneous_queue( queue: List[Dict[str, List]], batch_size: int ) -> Tuple[Optional[List], List]: """ - Grabs a batch of size batchsize from a queue of different sized items + Grabs a batch of exactly batch_size sequences from a queue of items with different group sizes. + Each item in the queue has a 'tokens' field containing a list of sequences. e.g. queue = [{"tokens": [[1, 2, 3],[4, 5, 6, 7, 8]]}, {"tokens": [[9, 10]]}] + where the first item has 2 sequences and the second has 1 sequence. - without going over the batchsize. This function will return a batch of size batchsize, and the new queue. 
+ This function returns a batch containing exactly batch_size sequences total, and the remaining queue. Because all groups are a common denominator of the batchsize, and all groups are a power of 2, we can simplify a bit by assuming we can grab groups of groups to be equal to the maximum group size. - Note that we cannot drop items from groups, so we must grab the entire group if we grab it. + Note that we cannot split items, so we must take the entire item with all its sequences. There may be a more efficient clearing mechanism by grouping these smaller groups heterogeneously, but forcing them all into powers of two groups is a simple way to ensure we can grab a batch of the correct size. - :param queue: - :param batch_size: + :param queue: List of items, each with a 'tokens' field containing sequences + :param batch_size: Target number of sequences for the batch :return: batch, new_queue """ - # Pass 1: precompute group sizes, total tokens and early exit if not enough tokens. + # Pass 1: precompute group sizes, total sequences and early exit if not enough sequences. 
total_groups = len(queue) if total_groups == 0: return None, queue group_sizes = [] lengths = [] - total_tokens = 0 + total_sequences = 0 max_group_size = 0 for item in queue: - length = len(item["tokens"]) + length = len(item["tokens"]) # Number of sequences in this group lengths.append(length) group_sizes.append(length) - total_tokens += length + total_sequences += length if length > max_group_size: max_group_size = length - if total_tokens < batch_size: + if total_sequences < batch_size: return None, queue group_sizes_set = set(group_sizes) @@ -55,30 +112,168 @@ def grab_exact_from_heterogeneous_queue( potential_batch_indices.extend(group_batching_storage[group_size]) group_batching_storage[group_size].clear() # much faster than = [] - # Calculate total batch tokens only once (avoid repeated sums) - potential_batch_token_total = sum(lengths[i] for i in potential_batch_indices) - if potential_batch_token_total < batch_size: + # Calculate total sequences in potential batch only once (avoid repeated sums) + potential_batch_sequences_total = sum(lengths[i] for i in potential_batch_indices) + if potential_batch_sequences_total < batch_size: return None, queue # Batch selection batch = [] batch_indices = [] - running_tokens = 0 + running_seqs = 0 for idx in potential_batch_indices: group = queue[idx] batch.append(group) batch_indices.append(idx) - running_tokens += lengths[idx] - if running_tokens == batch_size: + running_seqs += lengths[idx] + if running_seqs == batch_size: break - elif running_tokens > batch_size: + elif running_seqs > batch_size: # Should never happen due to problem constraints, but sanity check return None, queue - if running_tokens != batch_size: + if running_seqs != batch_size: return None, queue # Construct new_queue with a single pass, using a set for O(1) lookup batch_indices_set = set(batch_indices) new_queue = [item for i, item in enumerate(queue) if i not in batch_indices_set] return batch, new_queue + + +def 
grab_batch_with_minimum_allocations( + queue: List[Dict[str, any]], batch_size: int, env_configs: List[Dict[str, any]] +) -> Tuple[Optional[List], List]: + """ + Grabs a batch from the queue while respecting minimum allocation requirements for environments. + This function works with groups where each group contains multiple sequences. + + :param queue: List of groups with env_id field and sequences (stored in 'tokens' field) + :param batch_size: Target batch size in sequences + :param env_configs: List of environment configs with min_batch_allocation field + :return: batch, new_queue + """ + if not queue: + return None, queue + + # Build env_id to min allocation mapping + env_min_allocations = {} + for env in env_configs: + if env.get("connected", False) and env.get("min_batch_allocation") is not None: + env_min_allocations[env["registered_id"]] = env["min_batch_allocation"] + + # If no minimum allocations, fall back to original function + if not env_min_allocations: + return grab_exact_from_heterogeneous_queue(queue, batch_size) + + # First, find the maximum group size across all items + max_group_size = 0 + for item in queue: + group_size = len(item.get("tokens", [])) + if group_size > max_group_size: + max_group_size = group_size + + # Group queue items by env_id and calculate which can form complete packs + items_by_env = {} + packable_items_by_env = {} + + for i, item in enumerate(queue): + env_id = item.get("env_id") + group_size = len(item.get("tokens", [])) + + if env_id is not None: + if env_id not in items_by_env: + items_by_env[env_id] = {} + packable_items_by_env[env_id] = [] + + if group_size not in items_by_env[env_id]: + items_by_env[env_id][group_size] = [] + + items_by_env[env_id][group_size].append((i, item, group_size)) + + # Check if we can form a complete pack + items_of_size = items_by_env[env_id][group_size] + if len(items_of_size) * group_size == max_group_size: + # We have a complete pack! 
+ packable_items_by_env[env_id].extend(items_of_size) + items_by_env[env_id][group_size] = [] + + # Calculate minimum sequences needed per env + min_sequences_per_env = {} + total_min_sequences = 0 + for env_id, min_proportion in env_min_allocations.items(): + min_sequences = int(batch_size * min_proportion) + if min_sequences > 0: + # Check if this env has any items in the queue at all + if env_id not in items_by_env: + # This env has a minimum but no items - can't satisfy minimum + return None, queue + # Check if this env has any packable items + if env_id not in packable_items_by_env or not packable_items_by_env[env_id]: + # This env has items but no packable items - can't satisfy minimum + return None, queue + min_sequences_per_env[env_id] = min_sequences + total_min_sequences += min_sequences + + # If minimums exceed batch size, scale them down proportionally + if total_min_sequences > batch_size: + scale_factor = batch_size / total_min_sequences + for env_id in min_sequences_per_env: + # Ensure at least one pack from each env with minimum + if packable_items_by_env.get(env_id): + min_group_size = min(g[2] for g in packable_items_by_env[env_id]) + min_sequences_per_env[env_id] = max( + min_group_size, + int(min_sequences_per_env[env_id] * scale_factor), + ) + + # Build batch ensuring minimums are met + batch = [] + batch_indices = [] + sequences_taken_per_env = {env_id: 0 for env_id in packable_items_by_env} + total_sequences = 0 + + # First pass: satisfy minimum requirements using packable items + for env_id, min_sequences in min_sequences_per_env.items(): + if env_id in packable_items_by_env: + # Take packable items in order (FIFO) + for idx, item, group_size in packable_items_by_env[env_id]: + if sequences_taken_per_env[env_id] >= min_sequences: + break + if total_sequences + group_size <= batch_size: + batch.append(item) + batch_indices.append(idx) + sequences_taken_per_env[env_id] += group_size + total_sequences += group_size + + # Second pass: fill 
remaining slots with packable items from any env + if total_sequences < batch_size: + # Collect all remaining packable items in queue order + all_packable = [] + for i, item in enumerate(queue): + if i not in batch_indices: + # Check if this item is in any env's packable list + env_id = item.get("env_id") + if env_id in packable_items_by_env: + for idx, packable_item, size in packable_items_by_env[env_id]: + if idx == i: + all_packable.append((i, item, size)) + break + + # Take packable items in queue order + for idx, item, group_size in all_packable: + if total_sequences + group_size <= batch_size: + batch.append(item) + batch_indices.append(idx) + total_sequences += group_size + if total_sequences == batch_size: + break + + # If we couldn't form a full batch, return None + if total_sequences != batch_size: + return None, queue + + # Construct new queue + batch_indices_set = set(batch_indices) + new_queue = [item for i, item in enumerate(queue) if i not in batch_indices_set] + return batch, new_queue diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 28c043a8..99917546 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -1,6 +1,7 @@ import asyncio import json import logging +import math import os import random import string @@ -162,6 +163,14 @@ class BaseEnvConfig(BaseModel): default=False, description="Whether to include messages in the output transmitted to the trainer", ) + min_batch_allocation: Optional[float] = Field( + default=None, + description="Minimum proportion of a batch this environment should be allocated (0.0-1.0)", + ) + worker_timeout: float = Field( + default=600, + description="Timeout for a task, in seconds; if -1, no timeout", + ) class BaseEnv(ABC): @@ -237,6 +246,26 @@ class BaseEnv(ABC): else: self.jsonl_writer = None + @property + def derived_batch_size(self): + """Calculate the effective batch size for this environment based on minimum allocations.""" + # If batch_size is not set or no status yet,
return the config batch_size + if not hasattr(self, "status_dict") or self.config.batch_size == -1: + return self.config.batch_size + + # Get unallocated fraction from status + unallocated_fraction = self.status_dict.get("unallocated_fraction", 1.0) + + # If this env has a minimum allocation, add it to the unallocated portion + if self.config.min_batch_allocation is not None: + effective_fraction = unallocated_fraction + self.config.min_batch_allocation + else: + # This env competes for the unallocated portion based on its weight + effective_fraction = unallocated_fraction + + # Calculate derived batch size + return int(self.config.batch_size * effective_fraction) + @classmethod def config_init( cls, @@ -434,6 +463,8 @@ class BaseEnv(ABC): "max_token_length": self.config.max_token_length, "desired_name": self.config.wandb_name, "weight": self.config.inference_weight, + "min_batch_allocation": self.config.min_batch_allocation, + "group_size": self.config.group_size, }, ) as resp: data = await parse_http_response(resp, logger) @@ -614,6 +645,13 @@ class BaseEnv(ABC): """ Send scored data to the API with retry logic for timeouts and server errors. """ + # Add env_id to the data + if isinstance(scored_data, list): + for item in scored_data: + item["env_id"] = getattr(self, "env_id", None) + else: + scored_data["env_id"] = getattr(self, "env_id", None) + url = ( f"{self.config.rollout_server_url}/scored_data_list" if isinstance(scored_data, list) @@ -736,7 +774,7 @@ class BaseEnv(ABC): """ Handle the rollout of an item """ - item = self.running_items.get(item_uuid) + item = (self.running_items.get(item_uuid) or {}).get("item") if item is None: print(f"item {item_uuid} not found...
returning") return None @@ -813,7 +851,9 @@ class BaseEnv(ABC): self.eval_runner = eval_task if self.config.eval_handling == EvalHandlingEnum.STOP_TRAIN: # Stop training if eval is running - self.backlog.extend(self.running_items.values()) + self.backlog.extend( + [x["item"] for x in self.running_items.values()] + ) for worker in self.workers: worker.cancel() self.workers = set() @@ -852,16 +892,72 @@ class BaseEnv(ABC): max_num_workers, ( self.config.max_batches_offpolicy - * self.config.batch_size + * self.derived_batch_size // self.config.group_size ) - (self.status_dict["queue_size"]), ) + # Now if we have a minimum batch allocation, we need to add workers to fill the self queue, in case of + # overruns by other environments + if self.config.min_batch_allocation is not None: + min_workers_to_fill_self_queue = max( + 0, + math.ceil( + ( + ( + ( + math.ceil( + self.config.min_batch_allocation + * self.config.batch_size + * self.config.max_batches_offpolicy + / self.status_dict["max_group_size"] + ) + + ( + self.status_dict["max_group_size"] + // self.config.group_size + ) + ) + * self.status_dict["max_group_size"] + ) + - ( + ( + self.status_dict["max_group_size"] + * self.status_dict["self_queue_size"] + // ( + self.status_dict["max_group_size"] + / self.config.group_size + ) + ) + ) + ) + / self.config.group_size + ), + ) + max_num_workers = max(max_num_workers, min_workers_to_fill_self_queue) + print( + f"max_num_workers: {max_num_workers}, queue size: {self.status_dict['queue_size']}, " + f"workers: {len(self.workers)}, self_queue_size: {self.status_dict['self_queue_size']}", + flush=True, + ) if (self.curr_step == 0) and (len(self.workers) == 0): # We are starting up, so we should just skip the append to the list pass else: self.workers_added_list.append(max_num_workers - len(self.workers)) + if len(self.workers) > max_num_workers: + print( + f"len(self.workers) > max_num_workers: {len(self.workers)} > {max_num_workers}, " + "sending workers to backlog", + 
flush=True, + ) + num_to_reduce = len(self.workers) - max_num_workers + running_items_to_remove = list(self.running_items.keys())[:num_to_reduce] + for item_uuid in running_items_to_remove: + self.backlog.append(self.running_items[item_uuid]["item"]) + self.running_items[item_uuid]["worker"].cancel() + self.workers.discard(self.running_items[item_uuid]["worker"]) + self.running_items.pop(item_uuid) + while len(self.workers) < max_num_workers: # Generate a UUID for tracking this item item_uuid = str(uuid.uuid4()) @@ -871,8 +967,12 @@ class BaseEnv(ABC): item = await self.get_next_item() if item is None: break - self.running_items[item_uuid] = item worker = asyncio.create_task(self.handle_env(item_uuid)) + self.running_items[item_uuid] = { + "item": item, + "worker": worker, + "start_time": time.time(), + } self.workers.add(worker) worker.add_done_callback( lambda fut, i=item: ( @@ -926,9 +1026,32 @@ class BaseEnv(ABC): >= self.config.max_batches_offpolicy * self.config.batch_size ) and (self.config.max_batches_offpolicy > 0) - ) or (self.config.batch_size == -1): + and ( + (self.config.min_batch_allocation is None) + or ( + ( + ( + ( + math.ceil( + self.config.min_batch_allocation + * self.config.batch_size + * self.config.max_batches_offpolicy + / self.status_dict["max_group_size"] + ) + * ( + self.status_dict["max_group_size"] + // self.config.group_size + ) + ) + ) + - (self.status_dict["self_queue_size"]) + ) + <= 0 + ) + ) + ) or (self.derived_batch_size == -1): # We have too many, lets cleanup the tasks and wait a bit - self.backlog.extend(self.running_items.values()) + self.backlog.extend([x["item"] for x in self.running_items.values()]) for worker in self.workers: worker.cancel() self.running_items = dict() @@ -937,6 +1060,18 @@ class BaseEnv(ABC): pass else: await self.add_train_workers() + # cleanup workers that have timed out + if self.config.worker_timeout > 0: + for item_uuid, item in list(self.running_items.items()): + if time.time() - 
item["start_time"] > self.config.worker_timeout: + logger.warning( + f"Worker {item_uuid} has timed out after {time.time() - item['start_time']} seconds" + ) + item["worker"].cancel() + self.workers.discard(item["worker"]) + self.running_items.pop(item_uuid) + # Do we want to retry? probably not... + # self.backlog.append(item["item"]) await asyncio.sleep(0.1) async def process_manager(self): diff --git a/atroposlib/tests/test_utils/test_heterogeneous_packing.py b/atroposlib/tests/test_utils/test_heterogeneous_packing.py new file mode 100644 index 00000000..395fe610 --- /dev/null +++ b/atroposlib/tests/test_utils/test_heterogeneous_packing.py @@ -0,0 +1,154 @@ +"""Tests for heterogeneous group packing utility.""" + +import pytest + +from atroposlib.api.utils import find_groups_summing_to_target + + +class TestHeterogeneousPacking: + """Test cases for finding groups that sum to target size.""" + + def test_simple_fifo_exact_match(self): + """Test when FIFO order gives exact match.""" + buffer = [ + {"tokens": [[1, 2]], "scores": [0.5]}, # size 1 + {"tokens": [[3, 4], [5, 6]], "scores": [0.6, 0.7]}, # size 2 + {"tokens": [[7, 8]], "scores": [0.8]}, # size 1 + ] + + indices = find_groups_summing_to_target(buffer, 4) + assert indices == [0, 1, 2] + + def test_fifo_partial_match(self): + """Test when FIFO can match with subset.""" + buffer = [ + {"tokens": [[1, 2], [3, 4]], "scores": [0.5, 0.6]}, # size 2 + {"tokens": [[5, 6], [7, 8]], "scores": [0.7, 0.8]}, # size 2 + { + "tokens": [[9, 10], [11, 12], [13, 14], [15, 16]], + "scores": [0.9, 1.0, 1.1, 1.2], + }, # size 4 + ] + + indices = find_groups_summing_to_target(buffer, 4) + assert indices == [0, 1] # First two groups sum to 4 + + def test_need_dynamic_programming(self): + """Test when FIFO doesn't work but other combinations do.""" + buffer = [ + {"tokens": [[1, 2], [3, 4], [5, 6]], "scores": [0.5, 0.6, 0.7]}, # size 3 + {"tokens": [[7, 8]], "scores": [0.8]}, # size 1 + { + "tokens": [[9, 10], [11, 12], [13, 14], 
[15, 16]], + "scores": [0.9, 1.0, 1.1, 1.2], + }, # size 4 + ] + + indices = find_groups_summing_to_target(buffer, 5) + assert indices == [1, 2] # Groups at index 1 (size 1) and 2 (size 4) + + def test_impossible_target(self): + """Test when no combination can reach target.""" + buffer = [ + {"tokens": [[1, 2], [3, 4]], "scores": [0.5, 0.6]}, # size 2 + { + "tokens": [[5, 6], [7, 8], [9, 10], [11, 12]], + "scores": [0.7, 0.8, 0.9, 1.0], + }, # size 4 + ] + + indices = find_groups_summing_to_target(buffer, 3) + assert indices == [] # Can't make 3 from groups of size 2 and 4 + + def test_empty_buffer(self): + """Test with empty buffer.""" + indices = find_groups_summing_to_target([], 4) + assert indices == [] + + def test_single_group_exact(self): + """Test when single group matches exactly.""" + buffer = [ + { + "tokens": [[1, 2], [3, 4], [5, 6], [7, 8]], + "scores": [0.5, 0.6, 0.7, 0.8], + }, # size 4 + ] + + indices = find_groups_summing_to_target(buffer, 4) + assert indices == [0] + + def test_bradley_terry_pairs(self): + """Test RLAIF use case with Bradley-Terry pairs.""" + buffer = [ + {"tokens": [[1, 2], [3, 4]], "scores": [0.7, 0.3]}, # size 2 (BT pair) + {"tokens": [[5, 6], [7, 8]], "scores": [0.6, 0.4]}, # size 2 (BT pair) + {"tokens": [[9, 10], [11, 12]], "scores": [0.8, 0.2]}, # size 2 (BT pair) + {"tokens": [[13, 14], [15, 16]], "scores": [0.5, 0.5]}, # size 2 (BT pair) + ] + + indices = find_groups_summing_to_target(buffer, 8) + assert indices == [0, 1, 2, 3] # All 4 pairs + + def test_mixed_sizes_complex(self): + """Test with various power-of-2 sizes.""" + buffer = [ + {"tokens": [[1]], "scores": [0.5]}, # size 1 + {"tokens": [[2], [3]], "scores": [0.6, 0.7]}, # size 2 + {"tokens": [[4]], "scores": [0.8]}, # size 1 + {"tokens": [[5], [6], [7], [8]], "scores": [0.9, 1.0, 1.1, 1.2]}, # size 4 + {"tokens": [[9], [10]], "scores": [1.3, 1.4]}, # size 2 + ] + + # Target 8: should find combination that sums to 8 + indices = 
find_groups_summing_to_target(buffer, 8) + assert len(indices) > 0 + assert sum(len(buffer[i]["tokens"]) for i in indices) == 8 + + def test_large_groups(self): + """Test with larger group sizes.""" + buffer = [ + {"tokens": [[i] for i in range(16)], "scores": [0.5] * 16}, # size 16 + {"tokens": [[i] for i in range(8)], "scores": [0.6] * 8}, # size 8 + {"tokens": [[i] for i in range(8)], "scores": [0.7] * 8}, # size 8 + ] + + indices = find_groups_summing_to_target(buffer, 32) + assert indices == [0, 1, 2] # All groups needed + + def test_prefer_earlier_indices(self): + """Test that algorithm prefers earlier indices when multiple solutions exist.""" + buffer = [ + {"tokens": [[1], [2]], "scores": [0.5, 0.6]}, # size 2 + {"tokens": [[3], [4]], "scores": [0.7, 0.8]}, # size 2 + {"tokens": [[5], [6], [7], [8]], "scores": [0.9, 1.0, 1.1, 1.2]}, # size 4 + {"tokens": [[9], [10]], "scores": [1.3, 1.4]}, # size 2 + {"tokens": [[11], [12]], "scores": [1.5, 1.6]}, # size 2 + ] + + indices = find_groups_summing_to_target(buffer, 4) + assert indices == [0, 1] # Should prefer first two groups over later ones + + def test_exact_fit_with_remainder(self): + """Test when we can form exact target but have leftover groups.""" + buffer = [ + {"tokens": [[1], [2]], "scores": [0.5, 0.6]}, # size 2 + {"tokens": [[3], [4], [5], [6]], "scores": [0.7, 0.8, 0.9, 1.0]}, # size 4 + {"tokens": [[7], [8]], "scores": [1.1, 1.2]}, # size 2 + {"tokens": [[9]], "scores": [1.3]}, # size 1 + ] + + indices = find_groups_summing_to_target(buffer, 6) + assert sorted(indices) == [0, 1] # First two groups sum to 6 + + def test_stress_many_small_groups(self): + """Test with many small groups.""" + # Create 16 groups of size 1 + buffer = [{"tokens": [[i]], "scores": [i * 0.1]} for i in range(16)] + + indices = find_groups_summing_to_target(buffer, 8) + assert len(indices) == 8 + assert indices == list(range(8)) # Should take first 8 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git 
a/atroposlib/tests/test_utils/test_min_batch_allocation.py b/atroposlib/tests/test_utils/test_min_batch_allocation.py new file mode 100644 index 00000000..43fe81ac --- /dev/null +++ b/atroposlib/tests/test_utils/test_min_batch_allocation.py @@ -0,0 +1,806 @@ +"""Tests for minimum batch allocation functionality.""" + +import random + +from atroposlib.api.utils import grab_batch_with_minimum_allocations + + +class TestMinBatchAllocation: + """Test cases for minimum batch allocation feature.""" + + def test_basic_minimum_allocation(self): + """Test that basic minimum allocations are respected.""" + # Each item represents a group with multiple token sequences + queue = [ + { + "tokens": [[1, 2], [3, 4]], + "env_id": 0, + "masks": [[1, 1], [1, 1]], + "scores": [0.5, 0.6], + }, # 2 groups + { + "tokens": [[5, 6], [7, 8]], + "env_id": 1, + "masks": [[1, 1], [1, 1]], + "scores": [0.7, 0.8], + }, # 2 groups + { + "tokens": [[9, 10], [11, 12]], + "env_id": 0, + "masks": [[1, 1], [1, 1]], + "scores": [0.5, 0.6], + }, # 2 groups + { + "tokens": [[13, 14], [15, 16]], + "env_id": 1, + "masks": [[1, 1], [1, 1]], + "scores": [0.7, 0.8], + }, # 2 groups + ] + + env_configs = [ + { + "registered_id": 0, + "connected": True, + "min_batch_allocation": 0.25, + }, # 25% = 2 groups min + { + "registered_id": 1, + "connected": True, + "min_batch_allocation": 0.5, + }, # 50% = 4 groups min + ] + + batch_size = 8 # 8 token groups total + batch, new_queue = grab_batch_with_minimum_allocations( + queue, batch_size, env_configs + ) + + assert batch is not None + + # Count groups (not items) per environment + env_groups = {} + total_groups = 0 + for item in batch: + env_id = item["env_id"] + groups = len(item["tokens"]) + env_groups[env_id] = env_groups.get(env_id, 0) + groups + total_groups += groups + + assert total_groups == batch_size + + # Env 1 should have at least 50% (4 groups) + assert env_groups.get(1, 0) >= 4 + # Env 0 should have at least 25% (2 groups) + assert env_groups.get(0, 0) 
>= 2 + + def test_no_minimum_allocation_fallback(self): + """Test fallback to original function when no minimums specified.""" + queue = [ + {"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5, 0.6]}, + {"tokens": [[3, 4]], "env_id": 1, "masks": [[1, 1]], "scores": [0.7, 0.8]}, + {"tokens": [[5, 6]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5, 0.6]}, + {"tokens": [[7, 8]], "env_id": 1, "masks": [[1, 1]], "scores": [0.7, 0.8]}, + ] + + env_configs = [ + {"registered_id": 0, "connected": True}, + {"registered_id": 1, "connected": True}, + ] + + batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs) + + # Should still form a batch using original logic + assert batch is not None + assert len(new_queue) < len(queue) + + def test_conflicting_minimums_scale_down(self): + """Test that conflicting minimums > 100% are scaled down.""" + queue = [ + { + "tokens": [[1, 2], [3, 4]], + "env_id": 0, + "masks": [[1, 1], [1, 1]], + "scores": [0.5, 0.6], + }, # 2 groups + { + "tokens": [[5, 6], [7, 8]], + "env_id": 1, + "masks": [[1, 1], [1, 1]], + "scores": [0.7, 0.8], + }, # 2 groups + ] + + env_configs = [ + {"registered_id": 0, "connected": True, "min_batch_allocation": 0.7}, # 70% + { + "registered_id": 1, + "connected": True, + "min_batch_allocation": 0.6, + }, # 60% = 130% total + ] + + batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs) + + # Should still form a batch with scaled allocations + assert batch is not None + assert len(batch) == 2 # Both items needed to form batch of 4 groups + + def test_missing_env_in_queue(self): + """Test handling when an env has minimum but no items in queue.""" + queue = [ + { + "tokens": [[1, 2], [3, 4]], + "env_id": 0, + "masks": [[1, 1], [1, 1]], + "scores": [0.5, 0.6], + }, # 2 groups + { + "tokens": [[5, 6], [7, 8]], + "env_id": 0, + "masks": [[1, 1], [1, 1]], + "scores": [0.5, 0.6], + }, # 2 groups + ] + + env_configs = [ + {"registered_id": 0, "connected": True, 
"min_batch_allocation": 0.3}, + { + "registered_id": 1, + "connected": True, + "min_batch_allocation": 0.5, + }, # No items! + ] + + batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs) + + # Should return None because env 1 has minimum allocation but no items + assert batch is None + + def test_disconnected_env_ignored(self): + """Test that disconnected environments are ignored.""" + queue = [ + {"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5, 0.6]}, + {"tokens": [[3, 4]], "env_id": 1, "masks": [[1, 1]], "scores": [0.7, 0.8]}, + ] + + env_configs = [ + {"registered_id": 0, "connected": True, "min_batch_allocation": 0.25}, + { + "registered_id": 1, + "connected": False, + "min_batch_allocation": 0.75, + }, # Disconnected! + ] + + batch, new_queue = grab_batch_with_minimum_allocations(queue, 2, env_configs) + + # Should only consider connected env + assert batch is not None + # May include env 1 items but won't enforce its minimum + + def test_mixed_group_sizes(self): + """Test handling of different group sizes.""" + queue = [ + {"tokens": [[1]], "env_id": 0, "masks": [[1]], "scores": [0.5]}, # size 1 + { + "tokens": [[2, 3, 4, 5]], + "env_id": 0, + "masks": [[1, 1, 1, 1]], + "scores": [0.6, 0.7, 0.8, 0.9], + }, # size 4 + { + "tokens": [[6, 7]], + "env_id": 1, + "masks": [[1, 1]], + "scores": [0.5, 0.6], + }, # size 2 + ] + + env_configs = [ + {"registered_id": 0, "connected": True, "min_batch_allocation": 0.5}, + {"registered_id": 1, "connected": True, "min_batch_allocation": 0.25}, + ] + + # Try to form batch of size 7 (which would need all items) + batch, new_queue = grab_batch_with_minimum_allocations(queue, 7, env_configs) + + if batch is not None: + total_tokens = sum(len(item["tokens"]) for item in batch) + assert total_tokens == 7 + + def test_empty_queue(self): + """Test handling of empty queue.""" + queue = [] + env_configs = [ + {"registered_id": 0, "connected": True, "min_batch_allocation": 0.5}, + ] + + 
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs) + + assert batch is None + assert new_queue == [] + + def test_insufficient_items_for_batch(self): + """Test when there aren't enough items to form a full batch.""" + queue = [ + {"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5, 0.6]}, + ] + + env_configs = [ + {"registered_id": 0, "connected": True, "min_batch_allocation": 0.5}, + ] + + # Request batch size 4 but only have 2 tokens + batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs) + + assert batch is None + assert len(new_queue) == 1 # Original queue unchanged + + def test_heterogeneous_envs(self): + """Test envs with individual group sizes.""" + # Env 0: all groups have size 2 + # Env 1: all groups have size 4 + # Env 2: all groups have size 8 + queue = [] + + # Add items for env 0 (group size 2) + for i in range(1): + queue.append( + { + "tokens": [[i * 2, i * 2 + 1] for _ in range(2)], + "env_id": 0, + "masks": [[1, 1] for _ in range(2)], + "scores": [0.5, 0.6], + } + ) + # for i in range(1): + # queue.append( + # { + # "tokens": [[i * 2, i * 2 + 1] for _ in range(2)], + # "env_id": 1, + # "masks": [[1, 1] for _ in range(2)], + # "scores": [0.5, 0.6], + # } + # ) + # Add 3 items of group size 2 to show why greedy packing doesn't work + for i in range(3): + queue.append( + { + "tokens": [[i * 2, i * 2 + 1] for _ in range(2)], + "env_id": 6, + "masks": [[1, 1] for _ in range(2)], + "scores": [0.5, 0.6], + } + ) + + # Add items for env 1 (group size 4) + for i in range(5): + queue.append( + { + "tokens": [ + [i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3] for _ in range(4) + ], + "env_id": 2, + "masks": [[1, 1, 1, 1] for _ in range(4)], + "scores": [0.7, 0.8, 0.9, 1.0], + } + ) + + # Add items for env 2 (group size 8) + for i in range(3): + queue.append( + { + "tokens": [[i * 8 + j] for j in range(8)], + "env_id": 3, + "masks": [[1] for _ in range(8)], + "scores": [0.5] * 8, + } + ) + + 
env_configs = [ + { + "registered_id": 0, + "connected": True, + "min_batch_allocation": 0.1, + }, # min 2 sequences + # { + # "registered_id": 1, + # "connected": True, + # "min_batch_allocation": 0.1, + # }, # min 2 sequences + # { + # "registered_id": 2, + # "connected": True, + # "min_batch_allocation": 0.25, + # }, # min 4 sequences + { + "registered_id": 3, + "connected": True, + "min_batch_allocation": 0.5, + }, # min 8 sequences + ] + + batch_size = 16 + batch, new_queue = grab_batch_with_minimum_allocations( + queue, batch_size, env_configs + ) + + # Since env 0 has min allocation of 10% but can't form any complete packs + # (has 1 item of size 2, needs 4 to make pack of 8), the function should + # return None as it cannot satisfy the minimum allocation requirement + assert batch is None + + # Queue should be unchanged + assert len(new_queue) == len(queue) + + def test_packing_constraint_enforcement(self): + """Test that packing to max group size is properly enforced.""" + # Create queue with items that can't form complete packs + queue = [ + { + "tokens": [[1, 2]], + "env_id": 0, + "masks": [[1, 1]], + "scores": [0.5], + }, # size 1 + { + "tokens": [[3, 4]], + "env_id": 0, + "masks": [[1, 1]], + "scores": [0.5], + }, # size 1 + { + "tokens": [[5, 6]], + "env_id": 0, + "masks": [[1, 1]], + "scores": [0.5], + }, # size 1 + # Need 4 items of size 1 to make a pack of 4, only have 3 + { + "tokens": [[7, 8], [9, 10], [11, 12], [13, 14]], + "env_id": 1, + "masks": [[1, 1]] * 4, + "scores": [0.7] * 4, + }, # size 4 + ] + + env_configs = [ + {"registered_id": 0, "connected": True, "min_batch_allocation": 0.25}, + {"registered_id": 1, "connected": True, "min_batch_allocation": None}, + ] + + batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs) + + # Should return None because env 0 can't form complete packs + assert batch is None + + def test_fifo_order_preservation(self): + """Test that FIFO order is preserved when forming batches.""" + 
queue = [] + # Add items with sequential scores to track order + for i in range(8): + queue.append( + { + "tokens": [[i, i + 1]], + "env_id": 0, + "masks": [[1, 1]], + "scores": [float(i)], # Use score to track original order + } + ) + + env_configs = [ + {"registered_id": 0, "connected": True, "min_batch_allocation": None}, + ] + + batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs) + + if batch is not None: + # Check that we got the first 4 items (scores 0-3) + batch_scores = [item["scores"][0] for item in batch] + assert sorted(batch_scores) == [0.0, 1.0, 2.0, 3.0] + + def test_exact_minimum_boundary(self): + """Test behavior at exact minimum allocation boundaries.""" + queue = [ + { + "tokens": [[1, 2], [3, 4]], + "env_id": 0, + "masks": [[1, 1], [1, 1]], + "scores": [0.5, 0.6], + }, + { + "tokens": [[5, 6], [7, 8]], + "env_id": 0, + "masks": [[1, 1], [1, 1]], + "scores": [0.5, 0.6], + }, + { + "tokens": [[9, 10], [11, 12]], + "env_id": 1, + "masks": [[1, 1], [1, 1]], + "scores": [0.7, 0.8], + }, + { + "tokens": [[13, 14], [15, 16]], + "env_id": 1, + "masks": [[1, 1], [1, 1]], + "scores": [0.7, 0.8], + }, + ] + + env_configs = [ + { + "registered_id": 0, + "connected": True, + "min_batch_allocation": 0.5, + }, # Exactly 50% + { + "registered_id": 1, + "connected": True, + "min_batch_allocation": 0.5, + }, # Exactly 50% + ] + + batch, new_queue = grab_batch_with_minimum_allocations(queue, 8, env_configs) + + assert batch is not None + env_counts = {} + for item in batch: + env_id = item["env_id"] + count = len(item["tokens"]) + env_counts[env_id] = env_counts.get(env_id, 0) + count + + # Both envs should get exactly 4 sequences (50%) + assert env_counts[0] == 4 + assert env_counts[1] == 4 + + def test_zero_minimum_allocation(self): + """Test that zero minimum allocation is handled correctly.""" + queue = [ + {"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5]}, + {"tokens": [[3, 4]], "env_id": 1, "masks": [[1, 1]], 
"scores": [0.7]}, + ] + + env_configs = [ + {"registered_id": 0, "connected": True, "min_batch_allocation": 0.0}, # 0% + {"registered_id": 1, "connected": True, "min_batch_allocation": 0.5}, # 50% + ] + + batch, new_queue = grab_batch_with_minimum_allocations(queue, 2, env_configs) + + # Should work fine - env 0 has no minimum requirement + assert batch is not None + + def test_multiple_complete_packs(self): + """Test forming multiple complete packs from same environment.""" + queue = [] + # Add 16 items of size 1 from env 0 (can form 4 complete packs of 4) + for i in range(16): + queue.append( + { + "tokens": [[i]], + "env_id": 0, + "masks": [[1]], + "scores": [0.5], + } + ) + + # Add 2 items of size 4 from env 1 + for i in range(2): + queue.append( + { + "tokens": [[100 + i * 4, 101 + i * 4, 102 + i * 4, 103 + i * 4]], + "env_id": 1, + "masks": [[1, 1, 1, 1]], + "scores": [0.7], + } + ) + + env_configs = [ + { + "registered_id": 0, + "connected": True, + "min_batch_allocation": 0.75, + }, # 12 sequences + {"registered_id": 1, "connected": True, "min_batch_allocation": None}, + ] + + batch, new_queue = grab_batch_with_minimum_allocations(queue, 16, env_configs) + + assert batch is not None + env_counts = {} + for item in batch: + env_id = item["env_id"] + count = len(item["tokens"]) + env_counts[env_id] = env_counts.get(env_id, 0) + count + + # Env 0 should get at least 12 sequences + assert env_counts.get(0, 0) >= 12 + assert sum(env_counts.values()) == 16 + + def test_no_packable_items(self): + """Test when no items can form complete packs.""" + queue = [ + { + "tokens": [[1, 2], [3, 4]], + "env_id": 0, + "masks": [[1, 1], [1, 1]], + "scores": [0.5, 0.6], + }, # size 2 + { + "tokens": [[5, 6], [7, 8]], + "env_id": 0, + "masks": [[1, 1], [1, 1]], + "scores": [0.5, 0.6], + }, # size 2 + { + "tokens": [[9, 10], [11, 12]], + "env_id": 0, + "masks": [[1, 1], [1, 1]], + "scores": [0.5, 0.6], + }, # size 2 + # Only 3 items of size 2, need 4 to make complete pack of 8 + 
{ + "tokens": [ + [13, 14, 15, 16], + [17, 18, 19, 20], + [21, 22, 23, 24], + [25, 26, 27, 28], + [29, 30, 31, 32], + [33, 34, 35, 36], + [37, 38, 39, 40], + [41, 42, 43, 44], + ], + "env_id": 1, + "masks": [[1, 1, 1, 1]] * 8, + "scores": [0.7] * 8, + }, # size 8 + ] + + env_configs = [ + { + "registered_id": 0, + "connected": True, + "min_batch_allocation": 0.25, + }, # Can't form complete packs + {"registered_id": 1, "connected": True, "min_batch_allocation": None}, + ] + + batch, new_queue = grab_batch_with_minimum_allocations(queue, 8, env_configs) + + # Env 0 can't form complete packs (has 3 items, needs 4) + assert batch is None + + def test_env_without_items(self): + """Test env config without any items in queue.""" + queue = [ + {"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5]}, + {"tokens": [[3, 4]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5]}, + ] + + env_configs = [ + {"registered_id": 0, "connected": True, "min_batch_allocation": 0.5}, + { + "registered_id": 1, + "connected": True, + "min_batch_allocation": None, + }, # No items + { + "registered_id": 2, + "connected": True, + "min_batch_allocation": 0.3, + }, # No items + ] + + batch, new_queue = grab_batch_with_minimum_allocations(queue, 2, env_configs) + + # Should work - env 2 has no items so its minimum is ignored + assert batch is not None + + def test_scaling_with_single_env(self): + """Test scaling behavior with only one env having minimum.""" + queue = [] + for i in range(8): + queue.append( + { + "tokens": [[i, i + 1]], + "env_id": 0, + "masks": [[1, 1]], + "scores": [0.5], + } + ) + + env_configs = [ + { + "registered_id": 0, + "connected": True, + "min_batch_allocation": 1.5, + }, # 150% - impossible + ] + + batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs) + + # Should scale down to 100% and work + assert batch is not None + assert len(batch) == 4 + + def test_mixed_null_and_set_minimums(self): + """Test mix of environments with and 
without minimum allocations.""" + queue = [] + # Env 0: 4 items of size 2 + for i in range(4): + queue.append( + { + "tokens": [[i * 2, i * 2 + 1], [i * 2 + 10, i * 2 + 11]], + "env_id": 0, + "masks": [[1, 1], [1, 1]], + "scores": [0.5, 0.6], + } + ) + # Env 1: 2 items of size 2 + for i in range(2): + queue.append( + { + "tokens": [[i * 2 + 20, i * 2 + 21], [i * 2 + 30, i * 2 + 31]], + "env_id": 1, + "masks": [[1, 1], [1, 1]], + "scores": [0.7, 0.8], + } + ) + # Env 2: 2 items of size 2 (no minimum) + for i in range(2): + queue.append( + { + "tokens": [[i * 2 + 40, i * 2 + 41], [i * 2 + 50, i * 2 + 51]], + "env_id": 2, + "masks": [[1, 1], [1, 1]], + "scores": [0.9, 1.0], + } + ) + + env_configs = [ + { + "registered_id": 0, + "connected": True, + "min_batch_allocation": 0.4, + }, # 40% = 6.4 ≈ 6 + { + "registered_id": 1, + "connected": True, + "min_batch_allocation": 0.2, + }, # 20% = 3.2 ≈ 3 + { + "registered_id": 2, + "connected": True, + "min_batch_allocation": None, + }, # No minimum + ] + + batch, new_queue = grab_batch_with_minimum_allocations(queue, 16, env_configs) + + assert batch is not None + env_counts = {} + for item in batch: + env_id = item["env_id"] + count = len(item["tokens"]) + env_counts[env_id] = env_counts.get(env_id, 0) + count + + # Check minimums are satisfied + assert env_counts.get(0, 0) >= 6 # At least 40% of 16 + assert env_counts.get(1, 0) >= 3 # At least 20% of 16 + assert sum(env_counts.values()) == 16 + + def test_random_consistent_group_sizes(self): + """Random test where each env has a consistent power-of-2 group size.""" + for _ in range(100): + batch_size = 64 * random.randint(1, 4) + num_envs = random.randint(2, 4) + + # Assign each env a consistent group size + env_group_sizes = {} + for env_id in range(num_envs): + env_group_sizes[env_id] = 2 ** random.randint(0, 3) # 1, 2, 4, or 8 + + # Create queue + queue = [] + for env_id in range(num_envs): + group_size = env_group_sizes[env_id] + num_items = random.randint(5, 20) + for 
i in range(num_items): + queue.append( + { + "tokens": [ + [env_id * 1000 + i * 10 + j] for j in range(group_size) + ], + "env_id": env_id, + "masks": [[1] for _ in range(group_size)], + "scores": [0.5 + env_id * 0.1] * group_size, + } + ) + + # Random minimum allocations that sum to less than 1.0 + env_configs = [] + remaining = 0.9 + for env_id in range(num_envs): + if env_id == num_envs - 1: + min_alloc = remaining + else: + min_alloc = random.uniform(0.1, min(0.4, remaining)) + remaining -= min_alloc + + env_configs.append( + { + "registered_id": env_id, + "connected": True, + "min_batch_allocation": min_alloc, + } + ) + + batch, new_queue = grab_batch_with_minimum_allocations( + queue, batch_size, env_configs + ) + + if batch is not None: + # Verify batch size + total_sequences = sum(len(item["tokens"]) for item in batch) + assert total_sequences == batch_size + + # Verify all items from same env have same group size + env_group_sizes_seen = {} + for item in batch: + env_id = item["env_id"] + group_size = len(item["tokens"]) + if env_id in env_group_sizes_seen: + assert group_size == env_group_sizes_seen[env_id] + else: + env_group_sizes_seen[env_id] = group_size + + def test_queue_dominated_by_one_env(self): + """Test minimum allocation when one env dominates the queue.""" + queue = [] + + # Only env 1 items in queue + for i in range(100): + queue.append( + { + "tokens": [[1000 + i, 1001 + i]], + "env_id": 1, + "masks": [[1, 1]], + "scores": [0.7], + } + ) + + env_configs = [ + { + "registered_id": 0, + "connected": True, + "min_batch_allocation": 0.5, + }, # 50% but no items! 
+ {"registered_id": 1, "connected": True, "min_batch_allocation": 0.3}, # 30% + ] + + batch, new_queue = grab_batch_with_minimum_allocations(queue, 10, env_configs) + + # Should return None because env 0 has minimum allocation but no items + assert batch is None + + # Test with env 0 having no minimum - should work + env_configs[0]["min_batch_allocation"] = None + batch, new_queue = grab_batch_with_minimum_allocations(queue, 10, env_configs) + + assert batch is not None + env_counts = {} + for item in batch: + env_id = item["env_id"] + count = len(item["tokens"]) + env_counts[env_id] = env_counts.get(env_id, 0) + count + + # Should all be from env 1 + assert env_counts.get(1, 0) == 10 + assert sum(env_counts.values()) == 10 + + +if __name__ == "__main__": + test = TestMinBatchAllocation() + test.test_queue_dominated_by_one_env() diff --git a/environments/intern_bootcamp/intern_bootcamp_env.py b/environments/intern_bootcamp/intern_bootcamp_env.py index 325a95d6..c3e17321 100644 --- a/environments/intern_bootcamp/intern_bootcamp_env.py +++ b/environments/intern_bootcamp/intern_bootcamp_env.py @@ -367,12 +367,12 @@ class InternBootcampEnv(BaseEnv): tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", group_size=8, use_wandb=True, - max_num_workers=64, + max_num_workers_per_node=16, rollout_server_url="http://localhost:8000", total_steps=10000, batch_size=1024, steps_per_eval=100, - max_token_length=16384, + max_token_length=8192, inference_weight=1.0, wandb_name="intern_bootcamp_random_tasks", data_path_to_save_groups="data/intern_bootcamp_random_tasks.jsonl", @@ -385,6 +385,7 @@ class InternBootcampEnv(BaseEnv): format_bonus=0.2, # Training parameters require_reasoning=True, + min_batch_allocation=0.1, min_reasoning_length=50, temperature=0.7, top_p=0.9, diff --git a/environments/math_server.py b/environments/math_server.py index 8235f0e5..27189c2e 100644 --- a/environments/math_server.py +++ b/environments/math_server.py @@ -22,6 +22,13 @@ from 
atroposlib.envs.base import ( ) from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer +system_prompt = ( + "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the " + "problem and deliberate with yourself via systematic reasoning processes to help come to a correct " + "solution prior to answering. You should enclose your thoughts and internal monologue inside " + " tags, and then provide your solution or response to the problem." +) + problem_format = "{problem}" judge_format = """Here is a math problem and a proposed solution: @@ -85,7 +92,7 @@ class RSConfig(BaseEnvConfig): ) percent_to_judge: float = Field(0.3, description="The percentage of items to judge") percent_length_penalty: float = Field( - 0.0, description="The percentage of items to have length penalty" + 0.1, description="The percentage of items to have length penalty" ) @@ -179,21 +186,24 @@ class MathEnv(BaseEnv): @classmethod def config_init(self) -> Tuple[RSConfig, List[APIServerConfig]]: env_config = RSConfig( - tokenizer_name="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", - group_size=8, + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + group_size=16, use_wandb=True, rollout_server_url="http://localhost:8000", total_steps=1000, batch_size=1024, + max_num_workers_per_node=24, steps_per_eval=25, - max_token_length=31000, # 22000 // (2 ** i), + max_token_length=8192, # 22000 // (2 ** i), wandb_name="math", eval_handling=EvalHandlingEnum.LIMIT_TRAIN, eval_limit_ratio=0.1, + inference_weight=4, + min_batch_allocation=0.1, ) server_configs = [ APIServerConfig( - model_name="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", + model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", base_url="http://localhost:9004/v1", api_key="x", num_requests_for_eval=256, # since evaling only on one... 
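The config changes above opt `math_server.py` into the new `min_batch_allocation` machinery. As a minimal sketch of the fraction arithmetic behind the derived batch size (a standalone function with illustrative numbers, not code from this PR): an env is sized by the fraction of the batch no env has reserved, plus its own guaranteed slice if it has one.

```python
from typing import Optional


def derived_batch_size(
    batch_size: int,
    unallocated_fraction: float,
    min_batch_allocation: Optional[float],
) -> int:
    # Sentinel -1 means "unbounded batch"; pass it through untouched.
    if batch_size == -1:
        return batch_size
    # An env with a guaranteed minimum may fill its reserved slice plus
    # whatever fraction of the batch no env has reserved.
    effective_fraction = unallocated_fraction + (min_batch_allocation or 0.0)
    return int(batch_size * effective_fraction)


# 60% of the batch unreserved plus a 10% guarantee -> 70% of 1024
print(derived_batch_size(1024, 0.6, 0.1))  # -> 716
```

With `min_batch_allocation=0.1` and `batch_size=1024` as configured here, the env targets at least roughly a tenth of every training batch, which is why `max_num_workers` bookkeeping in `base.py` also grows to keep that slice of the queue populated.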
@@ -306,6 +316,7 @@ class MathEnv(BaseEnv): completion = await self.server.chat_completion( messages=[ + {"role": "system", "content": system_prompt}, {"role": "user", "content": question}, ], n=1, @@ -352,11 +363,16 @@ class MathEnv(BaseEnv): thinking_len = self.config.max_token_length user_prompt = problem_format.format(problem=item[0]) chat = [ + {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ] thinking_len = thinking_len - len( self.tokenizer.apply_chat_template(chat, add_generation_prompt=True) ) + print(f"thinking_len: {thinking_len}", flush=True) + if thinking_len < 1024: + print("thinking_len is less than 1024, skipping", flush=True) + return None, [] chat_completions = await self.server.chat_completion( messages=chat, n=self.config.group_size, @@ -369,6 +385,7 @@ class MathEnv(BaseEnv): to_backlog = list() for i, chat_completion in enumerate(chat_completions.choices): messages = ( + {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, {"role": "assistant", "content": chat_completion.message.content}, ) @@ -379,8 +396,9 @@ class MathEnv(BaseEnv): chat_completion.finish_reason, ) ) - + print("scoring normal", flush=True) to_postprocess = await self.score_normal(to_score) + print("scoring normal done", flush=True) if to_postprocess is None: return None, to_backlog if all( @@ -712,6 +730,7 @@ class MathEnv(BaseEnv): ) print("Sending to server") chat = [ + {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt_fwd}, ] max_token_length = self.config.max_token_length - len( @@ -727,6 +746,7 @@ class MathEnv(BaseEnv): print("Sending to server") # Should be the same token length as the fwd but tokenizers are cursed chat = [ + {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt_bwd}, ] max_token_length = self.config.max_token_length - len( @@ -822,6 +842,7 @@ class MathEnv(BaseEnv): ) to_backlog = list() chat = [ + {"role": 
"system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ] max_token_length = self.config.max_token_length - len( @@ -862,6 +883,7 @@ class MathEnv(BaseEnv): out_dict = tokenize_for_trainer( tokenizer=self.tokenizer, chat=[ + {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, {"role": "assistant", "content": chat_completion.message.content}, ], @@ -902,6 +924,7 @@ class MathEnv(BaseEnv): ) print("Sending to server") retry_messages = [ + {"role": "system", "content": system_prompt}, {"role": "user", "content": retry_prompt}, ] max_token_length = self.config.max_token_length - len( diff --git a/environments/reasoning_gym_environment/reasoning_gym_environment.py b/environments/reasoning_gym_environment/reasoning_gym_environment.py index 124a5ac1..95e80337 100644 --- a/environments/reasoning_gym_environment/reasoning_gym_environment.py +++ b/environments/reasoning_gym_environment/reasoning_gym_environment.py @@ -34,7 +34,8 @@ if _SUBMODULE_DIR not in sys.path: try: import reasoning_gym from reasoning_gym.utils import extract_answer -except ImportError: +except ImportError as e: + print(e) reasoning_gym = None extract_answer = None @@ -74,6 +75,10 @@ class ReasoningGymEnvConfig(BaseEnvConfig): default=True, description="Suppress verbose base environment logs (like status dict updates).", ) + mask_too_long_completions: bool = Field( + default=True, + description="Whether to mask too long completions.", + ) rollout_save_score_threshold: float = Field( default=0.7, description="Minimum score threshold for saving rollouts to data dumps. 
Only groups with at least one rollout above this threshold will be saved.", # noqa: E501 @@ -149,7 +154,6 @@ class ReasoningGymEnv(BaseEnv): base_logger.setLevel(logging.WARNING) # Set max_token_len for base class compatibility - self.max_token_len = self.config.max_token_length self.percent_correct_buffer = list() self.eval_metrics = list() @@ -193,17 +197,17 @@ class ReasoningGymEnv(BaseEnv): def config_init(cls) -> Tuple[ReasoningGymEnvConfig, List[APIServerConfig]]: env_config = ReasoningGymEnvConfig( tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", - group_size=16, + group_size=8, use_wandb=True, rollout_server_url="http://localhost:8000", total_steps=250, seed=1918, batch_size=1024, steps_per_eval=25, - max_token_length=1024 * 16, - inference_weight=1.0, + max_token_length=1024 * 8, + inference_weight=4.0, wandb_name="reasoning_gym_think", # Specific name for reasoning gym - eval_handling=EvalHandlingEnum.LIMIT_TRAIN, + eval_handling=EvalHandlingEnum.NONE, eval_limit_ratio=0.1, num_rollouts_per_group_for_logging=4, num_rollouts_to_keep=50, @@ -216,13 +220,14 @@ class ReasoningGymEnv(BaseEnv): eval_seed=123, complexity_mode="random", # Options: None, "curriculum", "random" curriculum_target_accuracy=0.7, + min_batch_allocation=0.1, ) server_configs = [ APIServerConfig( model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", base_url="http://localhost:9004/v1", api_key="x", - num_max_requests_at_once=32, + num_max_requests_at_once=128, num_requests_for_eval=256, ), ] @@ -735,6 +740,7 @@ class ReasoningGymEnv(BaseEnv): scores_container["tokens"] = list() scores_container["masks"] = list() scores_container["scores"] = list() + scores_container["overrides"] = list() if not rollout_group_data: return None @@ -745,7 +751,7 @@ class ReasoningGymEnv(BaseEnv): # Shuffle to avoid bias in selection random.shuffle(rollout_group_data) - for trajectory_messages, _, _ in rollout_group_data: + for trajectory_messages, _, _, finish_reason in 
rollout_group_data: model_full_response = trajectory_messages[-1]["content"] # Extract the part of the response that should be the answer @@ -779,6 +785,10 @@ class ReasoningGymEnv(BaseEnv): scores_container["tokens"].append(tokens) scores_container["masks"].append(masks) scores_container["scores"].append(reward_0_to_1) + scores_container["overrides"].append(dict()) + if finish_reason == "length": + if self.config.mask_too_long_completions: + scores_container["overrides"][-1]["set_advantage_to_zero"] = True if len(scores_container["tokens"]) >= self.config.group_size: break @@ -968,13 +978,16 @@ class ReasoningGymEnv(BaseEnv): # Calculate max_tokens like tool_calling_server prompt_tokens = len(self.tokenizer.encode(prompt_str)) - max_tokens = min(1024 * 15, self.config.max_token_length - prompt_tokens) + max_tokens = self.config.max_token_length - prompt_tokens + if max_tokens <= 0: + return None, [] completions = await self.server.completion( prompt=prompt_str, n=self.config.group_size, max_tokens=max_tokens, - temperature=0.8, + temperature=1.0, + top_p=0.95, ) to_score_list = [] @@ -988,7 +1001,12 @@ class ReasoningGymEnv(BaseEnv): ) to_score_list.append( - (tuple(current_trajectory_messages), rg_item, dataset_obj) + ( + tuple(current_trajectory_messages), + rg_item, + dataset_obj, + choice.finish_reason, + ) ) scored_data_group = await self.score(to_score_list) @@ -1159,7 +1177,9 @@ class ReasoningGymEnv(BaseEnv): # Calculate max_tokens like tool_calling_server prompt_tokens = len(self.tokenizer.encode(prompt_str)) - max_tokens = min(1024 * 15, self.config.max_token_length - prompt_tokens) + max_tokens = (2 * self.config.max_token_length) - prompt_tokens + if max_tokens < 0: + return 0.0 completion = await self.server.completion( prompt=prompt_str, @@ -2061,7 +2081,7 @@ class ReasoningGymEnv(BaseEnv): if new_complexity != current_complexity: self.task_complexity_levels[task_name] = new_complexity self.logger.info( - f"↑ {task_name}: complexity 
{current_complexity:.2f} -> {new_complexity:.2f} " + f"↑ {task_name}: complexity {current_complexity:.2f} -> {new_complexity:.2f} " f"(accuracy: {recent_accuracy:.2f}, stability: {stability_factor:.2f}, groups: {group_count})" ) @@ -2073,7 +2093,7 @@ class ReasoningGymEnv(BaseEnv): if new_complexity != current_complexity: self.task_complexity_levels[task_name] = new_complexity self.logger.info( - f"↓ {task_name}: complexity {current_complexity:.2f} -> {new_complexity:.2f} " + f"↓ {task_name}: complexity {current_complexity:.2f} -> {new_complexity:.2f} " f"(accuracy: {recent_accuracy:.2f}, stability: {stability_factor:.2f}, groups: {group_count})" ) @@ -2089,7 +2109,7 @@ class ReasoningGymEnv(BaseEnv): if new_complexity != current_complexity: self.task_complexity_levels[task_name] = new_complexity self.logger.info( - f"⚡ {task_name}: fast complexity jump {current_complexity:.2f} -> {new_complexity:.2f} " + f"âš¡ {task_name}: fast complexity jump {current_complexity:.2f} -> {new_complexity:.2f} " f"(high accuracy: {recent_accuracy:.2f}, stable performance)" ) @@ -2105,7 +2125,7 @@ class ReasoningGymEnv(BaseEnv): if new_complexity != current_complexity: self.task_complexity_levels[task_name] = new_complexity self.logger.info( - f"🔻 {task_name}: fast complexity drop {current_complexity:.2f} -> {new_complexity:.2f} " + f"🔻 {task_name}: fast complexity drop {current_complexity:.2f} -> {new_complexity:.2f} " f"(low accuracy: {recent_accuracy:.2f}, stable performance)" ) @@ -2116,7 +2136,7 @@ class ReasoningGymEnv(BaseEnv): ): if group_count % 10 == 0: # Log every 10 groups when stable self.logger.debug( - f"🎯 {task_name}: stable at complexity {current_complexity:.2f} " + f"🎯 {task_name}: stable at complexity {current_complexity:.2f} " f"(accuracy: {recent_accuracy:.2f}, target: {target_accuracy:.2f})" ) diff --git a/environments/tool_calling_server.py b/environments/tool_calling_server.py index 7de921aa..e492413a 100644 --- a/environments/tool_calling_server.py +++ 
b/environments/tool_calling_server.py @@ -46,15 +46,17 @@ class SingleToolCallingEnv(BaseEnv): tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", group_size=16, use_wandb=True, + max_num_workers_per_node=16, rollout_server_url="http://localhost:8000", total_steps=2000, batch_size=1024, - steps_per_eval=20, + steps_per_eval=25, max_token_length=1024 * 16, inference_weight=1.0, wandb_name="toolcall_think", eval_handling=EvalHandlingEnum.LIMIT_TRAIN, eval_limit_ratio=0.1, + min_batch_allocation=0.1, ) server_configs = [ APIServerConfig( @@ -113,7 +115,7 @@ class SingleToolCallingEnv(BaseEnv): full_dataset = full_dataset.shuffle(seed=42) # Create train/test split on the fly (e.g., 95% train, 5% test) - split_dataset = full_dataset.train_test_split(test_size=0.02, seed=42) + split_dataset = full_dataset.train_test_split(test_size=100, seed=42) # Keep the splits as is - no need to reformat self.train = split_dataset["train"]
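The reasoning_gym hunks above thread each completion's `finish_reason` through to the scoring loop so that length-truncated rollouts can have their advantage zeroed via a per-item override. A minimal sketch of that pattern follows; `build_overrides` is a hypothetical helper (the real code builds the list inline inside `score`), but the `mask_too_long_completions` flag and `set_advantage_to_zero` key mirror the diff:

```python
from typing import Dict, List, Tuple


def build_overrides(
    rollouts: List[Tuple[float, str]],
    mask_too_long_completions: bool = True,
) -> List[Dict[str, bool]]:
    """Build one override dict per (score, finish_reason) rollout.

    Completions cut off by the token limit report finish_reason == "length";
    when masking is enabled, they get set_advantage_to_zero=True so the
    trainer excludes them from the advantage computation.
    """
    overrides: List[Dict[str, bool]] = []
    for _score, finish_reason in rollouts:
        override: Dict[str, bool] = {}
        if finish_reason == "length" and mask_too_long_completions:
            override["set_advantage_to_zero"] = True
        overrides.append(override)
    return overrides


# One completed rollout and one truncated rollout: only the truncated
# one is flagged for masking.
print(build_overrides([(0.9, "stop"), (0.4, "length")]))
# -> [{}, {'set_advantage_to_zero': True}]
```

Keeping truncated rollouts in the group (with zero advantage) rather than dropping them preserves the group size the trainer expects while preventing cut-off chains of thought from being reinforced.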