mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
commit
b763b4e20d
6 changed files with 185 additions and 42 deletions
42
README.md
42
README.md
|
|
@ -256,6 +256,48 @@ Atropos repo contains an example trainer that should primarily be used as a refe
|
||||||
|
|
||||||
To use the example trainer, see this page: [training example guide](example_trainer/README.md)
|
To use the example trainer, see this page: [training example guide](example_trainer/README.md)
|
||||||
|
|
||||||
|
## On-Policy Distillation (API + ScoredDataGroup Contract)
|
||||||
|
|
||||||
|
Atropos now supports on-policy distillation (OPD) at the transport layer by carrying distillation arrays
|
||||||
|
through `ScoredDataGroup` and the API queue/batch endpoints.
|
||||||
|
|
||||||
|
### Scope of this change
|
||||||
|
|
||||||
|
- No teacher fetching/orchestration in `BaseEnv`.
|
||||||
|
- Environments or external pipelines are responsible for populating distillation arrays.
|
||||||
|
- API stores and returns those arrays unchanged.
|
||||||
|
|
||||||
|
### Distillation payload fields
|
||||||
|
|
||||||
|
Each scored group may include:
|
||||||
|
|
||||||
|
- `distill_token_ids`: shape `[sequence][position][top_k]`
|
||||||
|
- `distill_logprobs`: shape `[sequence][position][top_k]`
|
||||||
|
|
||||||
|
These fields are optional; when present, they are forwarded unchanged along this path:
|
||||||
|
|
||||||
|
- environment -> `/scored_data` or `/scored_data_list`
|
||||||
|
- API queue -> `/batch` -> trainer
|
||||||
|
|
||||||
|
### Minimal producer example (environment side)
|
||||||
|
|
||||||
|
```python
|
||||||
|
scores["distill_token_ids"] = distill_token_ids
|
||||||
|
scores["distill_logprobs"] = distill_logprobs
|
||||||
|
```
|
||||||
|
|
||||||
|
### Minimal consumer check (trainer/debug side)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -s http://localhost:8002/latest_example | jq '{has_ids:(.distill_token_ids!=null), has_lps:(.distill_logprobs!=null)}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Notes
|
||||||
|
|
||||||
|
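- The API does not validate cross-field semantics beyond schema typing (for example, it does not check that `distill_token_ids` and `distill_logprobs` have matching shapes).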
|
||||||
|
- Trainers should validate alignment assumptions they require (sequence length, per-position top-k, etc.).
|
||||||
|
- Teacher-side architecture and prompt/rendering strategy are intentionally out of scope for this change.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Testing and Debugging Tools
|
## Testing and Debugging Tools
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import gzip
|
import gzip
|
||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
@ -25,6 +26,7 @@ MIN_ENV_WEIGHT = (
|
||||||
# Message import removed - using Dict[str, Any] for more flexible validation
|
# Message import removed - using Dict[str, Any] for more flexible validation
|
||||||
|
|
||||||
app = FastAPI(title="AtroposLib API")
|
app = FastAPI(title="AtroposLib API")
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
|
|
@ -145,6 +147,10 @@ class ScoredData(BaseModel):
|
||||||
group_overrides: Optional[dict] = None
|
group_overrides: Optional[dict] = None
|
||||||
images: Optional[Any] = None
|
images: Optional[Any] = None
|
||||||
env_id: Optional[int] = None # ID of the environment that generated this data
|
env_id: Optional[int] = None # ID of the environment that generated this data
|
||||||
|
# On-policy distillation (new format): parallel token ids + logprobs.
|
||||||
|
# Shape for both: [sequence][position][top_k]
|
||||||
|
distill_token_ids: Optional[List[List[List[int]]]] = None
|
||||||
|
distill_logprobs: Optional[List[List[List[float]]]] = None
|
||||||
|
|
||||||
@field_validator("messages", mode="before")
|
@field_validator("messages", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -182,6 +188,8 @@ def _scored_data_to_dict(scored_data: ScoredData) -> Dict[str, Any]:
|
||||||
"group_overrides": scored_data.group_overrides,
|
"group_overrides": scored_data.group_overrides,
|
||||||
"images": scored_data.images,
|
"images": scored_data.images,
|
||||||
"env_id": scored_data.env_id,
|
"env_id": scored_data.env_id,
|
||||||
|
"distill_token_ids": scored_data.distill_token_ids,
|
||||||
|
"distill_logprobs": scored_data.distill_logprobs,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -385,7 +393,10 @@ async def get_batch():
|
||||||
app.state.curr_batch.append(batch)
|
app.state.curr_batch.append(batch)
|
||||||
curr_batch = app.state.curr_batch.pop()
|
curr_batch = app.state.curr_batch.pop()
|
||||||
# check length before sending
|
# check length before sending
|
||||||
print(f"Sending batch of {sum(len(x['tokens']) for x in curr_batch)} sequences")
|
logger.info(
|
||||||
|
"Sending batch of %s sequences",
|
||||||
|
sum(len(x["tokens"]) for x in curr_batch),
|
||||||
|
)
|
||||||
return {"batch": curr_batch}
|
return {"batch": curr_batch}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -66,6 +66,10 @@ class ScoredDataGroup(TypedDict):
|
||||||
group_overrides: Optional[Dict]
|
group_overrides: Optional[Dict]
|
||||||
overrides: Optional[List[Dict]]
|
overrides: Optional[List[Dict]]
|
||||||
images: Optional[Any]
|
images: Optional[Any]
|
||||||
|
# On-policy distillation (new format): parallel token ids + logprobs.
|
||||||
|
# distill_token_ids/distill_logprobs are [sequence][position][top_k]
|
||||||
|
distill_token_ids: Optional[List[List[List[int]]]]
|
||||||
|
distill_logprobs: Optional[List[List[List[float]]]]
|
||||||
|
|
||||||
|
|
||||||
class ScoredDataItem(TypedDict):
|
class ScoredDataItem(TypedDict):
|
||||||
|
|
@ -78,6 +82,9 @@ class ScoredDataItem(TypedDict):
|
||||||
group_overrides: Optional[Dict]
|
group_overrides: Optional[Dict]
|
||||||
overrides: Optional[Dict]
|
overrides: Optional[Dict]
|
||||||
images: Optional[Any]
|
images: Optional[Any]
|
||||||
|
# On-policy distillation (new format): parallel token ids + logprobs per position.
|
||||||
|
distill_token_ids: Optional[List[List[int]]]
|
||||||
|
distill_logprobs: Optional[List[List[float]]]
|
||||||
|
|
||||||
|
|
||||||
class EvalHandlingEnum(Enum):
|
class EvalHandlingEnum(Enum):
|
||||||
|
|
@ -279,7 +286,7 @@ class BaseEnv(ABC):
|
||||||
counter += 1
|
counter += 1
|
||||||
path_changed = True
|
path_changed = True
|
||||||
if path_changed:
|
if path_changed:
|
||||||
print(
|
logger.info(
|
||||||
f"Changed data path to {self.config.data_path_to_save_groups} because {original_path} already exists." # noqa: E501
|
f"Changed data path to {self.config.data_path_to_save_groups} because {original_path} already exists." # noqa: E501
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -357,7 +364,7 @@ class BaseEnv(ABC):
|
||||||
to_postprocess["group_overrides"] = {}
|
to_postprocess["group_overrides"] = {}
|
||||||
to_postprocess["overrides"] = []
|
to_postprocess["overrides"] = []
|
||||||
to_postprocess["images"] = []
|
to_postprocess["images"] = []
|
||||||
print("Processing results")
|
logger.debug("Processing results")
|
||||||
for result in results:
|
for result in results:
|
||||||
to_postprocess["tokens"].append(result[0]["tokens"])
|
to_postprocess["tokens"].append(result[0]["tokens"])
|
||||||
to_postprocess["masks"].append(result[0]["masks"])
|
to_postprocess["masks"].append(result[0]["masks"])
|
||||||
|
|
@ -437,7 +444,7 @@ class BaseEnv(ABC):
|
||||||
setattr(self, key, data[key])
|
setattr(self, key, data[key])
|
||||||
|
|
||||||
def save_checkpoint(self, step, data=None):
|
def save_checkpoint(self, step, data=None):
|
||||||
print(f"Saving checkpoint at step {step} with data {data}")
|
logger.info("Saving checkpoint at step %s with data %s", step, data)
|
||||||
if data is None:
|
if data is None:
|
||||||
# Don't have anything to save, abort
|
# Don't have anything to save, abort
|
||||||
return
|
return
|
||||||
|
|
@ -535,7 +542,7 @@ class BaseEnv(ABC):
|
||||||
self.config.total_steps = data["num_steps"]
|
self.config.total_steps = data["num_steps"]
|
||||||
if self.config.total_steps == -1:
|
if self.config.total_steps == -1:
|
||||||
raise ValueError("Total steps not set in config or server!")
|
raise ValueError("Total steps not set in config or server!")
|
||||||
print(
|
logger.info(
|
||||||
f"Initialized env with id {self.env_id}: "
|
f"Initialized env with id {self.env_id}: "
|
||||||
f"curr_step: {self.curr_step}, "
|
f"curr_step: {self.curr_step}, "
|
||||||
f"checkpoint_dir: {self.checkpoint_dir}, "
|
f"checkpoint_dir: {self.checkpoint_dir}, "
|
||||||
|
|
@ -772,7 +779,7 @@ class BaseEnv(ABC):
|
||||||
with open(filepath, "w") as f:
|
with open(filepath, "w") as f:
|
||||||
json.dump(eval_result, f, indent=2)
|
json.dump(eval_result, f, indent=2)
|
||||||
|
|
||||||
print(f"Evaluation results saved to {filepath}")
|
logger.info("Evaluation results saved to %s", filepath)
|
||||||
|
|
||||||
# Write samples to JSONL file if provided
|
# Write samples to JSONL file if provided
|
||||||
if samples:
|
if samples:
|
||||||
|
|
@ -782,7 +789,7 @@ class BaseEnv(ABC):
|
||||||
with jsonlines.open(samples_filepath, "w") as writer:
|
with jsonlines.open(samples_filepath, "w") as writer:
|
||||||
for sample in samples:
|
for sample in samples:
|
||||||
writer.write(sample)
|
writer.write(sample)
|
||||||
print(f"Evaluation samples saved to {samples_filepath}")
|
logger.info("Evaluation samples saved to %s", samples_filepath)
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
|
|
@ -816,7 +823,7 @@ class BaseEnv(ABC):
|
||||||
elif resp.status >= 400:
|
elif resp.status >= 400:
|
||||||
logging.error(f"Client error: {resp.status}, not retrying")
|
logging.error(f"Client error: {resp.status}, not retrying")
|
||||||
return
|
return
|
||||||
print(await resp.text())
|
logger.debug(await resp.text())
|
||||||
|
|
||||||
def _post_json_with_compression(
|
def _post_json_with_compression(
|
||||||
self,
|
self,
|
||||||
|
|
@ -888,6 +895,8 @@ class BaseEnv(ABC):
|
||||||
group.setdefault("ref_logprobs", None)
|
group.setdefault("ref_logprobs", None)
|
||||||
group.setdefault("overrides", None)
|
group.setdefault("overrides", None)
|
||||||
group.setdefault("group_overrides", None)
|
group.setdefault("group_overrides", None)
|
||||||
|
group.setdefault("distill_token_ids", None)
|
||||||
|
group.setdefault("distill_logprobs", None)
|
||||||
|
|
||||||
for mask in group["masks"]:
|
for mask in group["masks"]:
|
||||||
self.completion_lengths.append(sum(m != -100 for m in mask))
|
self.completion_lengths.append(sum(m != -100 for m in mask))
|
||||||
|
|
@ -905,7 +914,12 @@ class BaseEnv(ABC):
|
||||||
|
|
||||||
if self.config.include_messages and group.get("messages") is None:
|
if self.config.include_messages and group.get("messages") is None:
|
||||||
group["messages"] = [
|
group["messages"] = [
|
||||||
self.tokenizer.decode(group["tokens"][i])
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": self.tokenizer.decode(group["tokens"][i]),
|
||||||
|
}
|
||||||
|
]
|
||||||
for i in range(len(group["tokens"]))
|
for i in range(len(group["tokens"]))
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -913,7 +927,9 @@ class BaseEnv(ABC):
|
||||||
|
|
||||||
if self.jsonl_writer is not None:
|
if self.jsonl_writer is not None:
|
||||||
self.jsonl_writer.write(group)
|
self.jsonl_writer.write(group)
|
||||||
print(f"Wrote scored group to {self.config.data_path_to_save_groups}")
|
logger.info(
|
||||||
|
"Wrote scored group to %s", self.config.data_path_to_save_groups
|
||||||
|
)
|
||||||
|
|
||||||
valid_groups.append(group)
|
valid_groups.append(group)
|
||||||
|
|
||||||
|
|
@ -934,7 +950,7 @@ class BaseEnv(ABC):
|
||||||
if isinstance(data_to_send_to_api, dict)
|
if isinstance(data_to_send_to_api, dict)
|
||||||
else f"{len(data_to_send_to_api)} ScoredDataGroups"
|
else f"{len(data_to_send_to_api)} ScoredDataGroups"
|
||||||
)
|
)
|
||||||
print(f"Failed to send {data_type_str} after retries: {e}")
|
logger.error("Failed to send %s after retries: %s", data_type_str, e)
|
||||||
|
|
||||||
async def handle_env(
|
async def handle_env(
|
||||||
self, item_uuid: str
|
self, item_uuid: str
|
||||||
|
|
@ -944,7 +960,7 @@ class BaseEnv(ABC):
|
||||||
"""
|
"""
|
||||||
item = self.running_items.get(item_uuid)["item"]
|
item = self.running_items.get(item_uuid)["item"]
|
||||||
if item is None:
|
if item is None:
|
||||||
print(f"item {item_uuid} not found... returning")
|
logger.warning("item %s not found... returning", item_uuid)
|
||||||
return None
|
return None
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
logger.debug(f"handle_env: Starting with item: {item}")
|
logger.debug(f"handle_env: Starting with item: {item}")
|
||||||
|
|
@ -965,7 +981,7 @@ class BaseEnv(ABC):
|
||||||
to_postprocess = await self.postprocess_histories(to_postprocess)
|
to_postprocess = await self.postprocess_histories(to_postprocess)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in scoring: {item}")
|
logger.error(f"Error in scoring: {item}")
|
||||||
print(e)
|
logger.error("Scoring exception: %s", e)
|
||||||
to_postprocess = None
|
to_postprocess = None
|
||||||
self.running_items.pop(item_uuid, None)
|
self.running_items.pop(item_uuid, None)
|
||||||
duration = max(0.0, time.time() - start_time)
|
duration = max(0.0, time.time() - start_time)
|
||||||
|
|
@ -1106,10 +1122,9 @@ class BaseEnv(ABC):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
max_num_workers = max(max_num_workers, min_workers_to_fill_self_queue)
|
max_num_workers = max(max_num_workers, min_workers_to_fill_self_queue)
|
||||||
print(
|
logger.info(
|
||||||
f"max_num_workers: {max_num_workers}, queue size: {self.status_dict['queue_size']}, "
|
f"max_num_workers: {max_num_workers}, queue size: {self.status_dict['queue_size']}, "
|
||||||
f"workers: {len(self.workers)}, self_queue_size: {self.status_dict['self_queue_size']}",
|
f"workers: {len(self.workers)}, self_queue_size: {self.status_dict['self_queue_size']}"
|
||||||
flush=True,
|
|
||||||
)
|
)
|
||||||
if (self.curr_step == 0) and (len(self.workers) == 0):
|
if (self.curr_step == 0) and (len(self.workers) == 0):
|
||||||
# We are starting up, so we should just skip the append to the list
|
# We are starting up, so we should just skip the append to the list
|
||||||
|
|
@ -1117,10 +1132,9 @@ class BaseEnv(ABC):
|
||||||
else:
|
else:
|
||||||
self.workers_added_list.append(max_num_workers - len(self.workers))
|
self.workers_added_list.append(max_num_workers - len(self.workers))
|
||||||
if len(self.workers) > max_num_workers:
|
if len(self.workers) > max_num_workers:
|
||||||
print(
|
logger.info(
|
||||||
f"len(self.workers) > max_num_workers: {len(self.workers)} > {max_num_workers}, "
|
f"len(self.workers) > max_num_workers: {len(self.workers)} > {max_num_workers}, "
|
||||||
"sending workers to backlog",
|
"sending workers to backlog"
|
||||||
flush=True,
|
|
||||||
)
|
)
|
||||||
num_to_reduce = len(self.workers) - max_num_workers
|
num_to_reduce = len(self.workers) - max_num_workers
|
||||||
running_items_to_remove = list(self.running_items.keys())[:num_to_reduce]
|
running_items_to_remove = list(self.running_items.keys())[:num_to_reduce]
|
||||||
|
|
@ -1266,18 +1280,22 @@ class BaseEnv(ABC):
|
||||||
# Initialize the processing
|
# Initialize the processing
|
||||||
self.curr_step = 0
|
self.curr_step = 0
|
||||||
|
|
||||||
print(f"Starting to process {self.n_groups_to_process} groups...")
|
logger.info("Starting to process %s groups...", self.n_groups_to_process)
|
||||||
|
|
||||||
# Process the required number of groups
|
# Process the required number of groups
|
||||||
while self.curr_step < self.n_groups_to_process:
|
while self.curr_step < self.n_groups_to_process:
|
||||||
# Get an item to process
|
# Get an item to process
|
||||||
item = await self.get_next_item()
|
item = await self.get_next_item()
|
||||||
if item is None:
|
if item is None:
|
||||||
print("No more items to process")
|
logger.info("No more items to process")
|
||||||
break
|
break
|
||||||
|
|
||||||
# Process the group
|
# Process the group
|
||||||
print(f"Processing group {self.curr_step + 1}/{self.n_groups_to_process}")
|
logger.info(
|
||||||
|
"Processing group %s/%s",
|
||||||
|
self.curr_step + 1,
|
||||||
|
self.n_groups_to_process,
|
||||||
|
)
|
||||||
|
|
||||||
# Collect trajectories with the specified group size
|
# Collect trajectories with the specified group size
|
||||||
# Override the group_size temporarily
|
# Override the group_size temporarily
|
||||||
|
|
@ -1300,13 +1318,13 @@ class BaseEnv(ABC):
|
||||||
await self.wandb_log()
|
await self.wandb_log()
|
||||||
|
|
||||||
self.curr_step += 1
|
self.curr_step += 1
|
||||||
print(
|
logger.info(
|
||||||
f"Successfully processed group {self.curr_step}/{self.n_groups_to_process}"
|
f"Successfully processed group {self.curr_step}/{self.n_groups_to_process}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print("Failed to process group, retrying...")
|
logger.warning("Failed to process group, retrying...")
|
||||||
|
|
||||||
print(f"Completed processing {self.curr_step} groups")
|
logger.info("Completed processing %s groups", self.curr_step)
|
||||||
|
|
||||||
# Close the output file if it's open
|
# Close the output file if it's open
|
||||||
if self.jsonl_writer is not None:
|
if self.jsonl_writer is not None:
|
||||||
|
|
@ -1340,8 +1358,7 @@ class BaseEnv(ABC):
|
||||||
"""Handles exceptions with clean output for known error types."""
|
"""Handles exceptions with clean output for known error types."""
|
||||||
if isinstance(ex, FailedExecutionException):
|
if isinstance(ex, FailedExecutionException):
|
||||||
# Handle argparse errors (already printed by argparse)
|
# Handle argparse errors (already printed by argparse)
|
||||||
print()
|
logger.error(ex.message.split("error: ")[-1])
|
||||||
print(ex.message.split("error: ")[-1])
|
|
||||||
return 2
|
return 2
|
||||||
|
|
||||||
raise ex
|
raise ex
|
||||||
|
|
@ -1402,7 +1419,7 @@ class BaseEnv(ABC):
|
||||||
if self.config is not None:
|
if self.config is not None:
|
||||||
with open(self.config, "r") as f:
|
with open(self.config, "r") as f:
|
||||||
yaml_config = yaml.safe_load(f)
|
yaml_config = yaml.safe_load(f)
|
||||||
print(f"Loaded config from {self.config}")
|
logger.info("Loaded config from %s", self.config)
|
||||||
else:
|
else:
|
||||||
yaml_config = {}
|
yaml_config = {}
|
||||||
|
|
||||||
|
|
@ -1424,31 +1441,61 @@ class BaseEnv(ABC):
|
||||||
cli_passed_flags, openai_full_prefix
|
cli_passed_flags, openai_full_prefix
|
||||||
) # CLI args
|
) # CLI args
|
||||||
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
|
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
|
||||||
if isinstance(default_server_configs, ServerBaseline) and (
|
|
||||||
|
# Debug logging for CLI args
|
||||||
|
logger.debug("[CLI DEBUG] cli_passed_flags = %s", cli_passed_flags)
|
||||||
|
logger.debug("[CLI DEBUG] openai_full_prefix = %s", openai_full_prefix)
|
||||||
|
logger.debug(
|
||||||
|
"[CLI DEBUG] oai_cli_passed_args = %s", oai_cli_passed_args
|
||||||
|
)
|
||||||
|
logger.debug("[CLI DEBUG] yaml_oai_config = %s", yaml_oai_config)
|
||||||
|
|
||||||
|
# Auto-convert ServerBaseline to APIServerConfig when CLI/YAML overrides are provided
|
||||||
|
# This allows any environment to use --openai.* CLI args without modifying config_init
|
||||||
|
# Use a new variable to avoid UnboundLocalError from closure scoping
|
||||||
|
effective_server_configs = default_server_configs
|
||||||
|
if isinstance(effective_server_configs, ServerBaseline) and (
|
||||||
oai_cli_passed_args or yaml_oai_config
|
oai_cli_passed_args or yaml_oai_config
|
||||||
):
|
):
|
||||||
raise ValueError(
|
# Convert ServerBaseline to APIServerConfig, preserving common fields
|
||||||
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." # noqa: E501
|
baseline_dict = effective_server_configs.model_dump()
|
||||||
|
effective_server_configs = APIServerConfig(**baseline_dict)
|
||||||
|
logger.info(
|
||||||
|
"Auto-converted ServerBaseline to APIServerConfig for CLI/YAML overrides"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
isinstance(default_server_configs, list)
|
isinstance(effective_server_configs, list)
|
||||||
and len(default_server_configs) == 1
|
and len(effective_server_configs) == 1
|
||||||
):
|
):
|
||||||
# can't use the same var name because it shadows the class variable and we get an error
|
# can't use the same var name because it shadows the class variable and we get an error
|
||||||
default_openai_config_ = default_server_configs[0]
|
default_openai_config_ = effective_server_configs[0]
|
||||||
else:
|
else:
|
||||||
default_openai_config_ = default_server_configs
|
default_openai_config_ = effective_server_configs
|
||||||
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
|
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
|
||||||
yaml_oai_config = yaml_oai_config[0]
|
yaml_oai_config = yaml_oai_config[0]
|
||||||
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
|
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
|
||||||
yaml_oai_config, dict
|
yaml_oai_config, dict
|
||||||
):
|
):
|
||||||
|
logger.debug(
|
||||||
|
"[CLI DEBUG] default_openai_config_.model_dump() = %s",
|
||||||
|
default_openai_config_.model_dump(),
|
||||||
|
)
|
||||||
openai_config_dict = merge_dicts(
|
openai_config_dict = merge_dicts(
|
||||||
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
|
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
|
||||||
yaml_oai_config,
|
yaml_oai_config,
|
||||||
oai_cli_passed_args,
|
oai_cli_passed_args,
|
||||||
)
|
)
|
||||||
|
logger.debug(
|
||||||
|
"[CLI DEBUG] openai_config_dict after merge = %s",
|
||||||
|
openai_config_dict,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"[CLI DEBUG] Not merging: default_openai_config_ "
|
||||||
|
f"type={type(default_openai_config_)}, "
|
||||||
|
f"yaml_oai_config type={type(yaml_oai_config)}"
|
||||||
|
)
|
||||||
openai_config_dict = {}
|
openai_config_dict = {}
|
||||||
|
|
||||||
# 3. Server Manager Configuration (slurm, testing - not namespaced)
|
# 3. Server Manager Configuration (slurm, testing - not namespaced)
|
||||||
|
|
@ -1487,7 +1534,7 @@ class BaseEnv(ABC):
|
||||||
# Determine the final server_configs, handling single, multiple servers, and overrides.
|
# Determine the final server_configs, handling single, multiple servers, and overrides.
|
||||||
|
|
||||||
openai_configs = resolve_openai_configs(
|
openai_configs = resolve_openai_configs(
|
||||||
default_server_configs=default_server_configs,
|
default_server_configs=effective_server_configs,
|
||||||
openai_config_dict=openai_config_dict,
|
openai_config_dict=openai_config_dict,
|
||||||
yaml_config=yaml_config,
|
yaml_config=yaml_config,
|
||||||
cli_passed_flags=cli_passed_flags,
|
cli_passed_flags=cli_passed_flags,
|
||||||
|
|
@ -1597,7 +1644,7 @@ class BaseEnv(ABC):
|
||||||
if self.config is not None:
|
if self.config is not None:
|
||||||
with open(self.config, "r") as f:
|
with open(self.config, "r") as f:
|
||||||
yaml_config = yaml.safe_load(f)
|
yaml_config = yaml.safe_load(f)
|
||||||
print(f"Loaded config from {self.config}")
|
logger.info("Loaded config from %s", self.config)
|
||||||
else:
|
else:
|
||||||
yaml_config = {}
|
yaml_config = {}
|
||||||
|
|
||||||
|
|
@ -1770,7 +1817,7 @@ class BaseEnv(ABC):
|
||||||
"data_path_to_save_groups must be set for process mode"
|
"data_path_to_save_groups must be set for process mode"
|
||||||
)
|
)
|
||||||
|
|
||||||
print(
|
logger.info(
|
||||||
f"Processing {env_config.total_steps} groups of "
|
f"Processing {env_config.total_steps} groups of "
|
||||||
f"{env_config.group_size} responses and "
|
f"{env_config.group_size} responses and "
|
||||||
f"writing to {env_config.data_path_to_save_groups}"
|
f"writing to {env_config.data_path_to_save_groups}"
|
||||||
|
|
@ -1866,7 +1913,7 @@ class BaseEnv(ABC):
|
||||||
if self.config is not None:
|
if self.config is not None:
|
||||||
with open(self.config, "r") as f:
|
with open(self.config, "r") as f:
|
||||||
yaml_config = yaml.safe_load(f)
|
yaml_config = yaml.safe_load(f)
|
||||||
print(f"Loaded config from {self.config}")
|
logger.info("Loaded config from %s", self.config)
|
||||||
else:
|
else:
|
||||||
yaml_config = {}
|
yaml_config = {}
|
||||||
|
|
||||||
|
|
@ -2052,7 +2099,7 @@ class BaseEnv(ABC):
|
||||||
yaml.dump(
|
yaml.dump(
|
||||||
config_dict, f, default_flow_style=False, sort_keys=False
|
config_dict, f, default_flow_style=False, sort_keys=False
|
||||||
)
|
)
|
||||||
print(f"Dumped evaluate config to {config_filepath}")
|
logger.info("Dumped evaluate config to %s", config_filepath)
|
||||||
|
|
||||||
# --- Create and Run Environment ---
|
# --- Create and Run Environment ---
|
||||||
# Create the environment instance
|
# Create the environment instance
|
||||||
|
|
@ -2063,7 +2110,7 @@ class BaseEnv(ABC):
|
||||||
testing=server_manager_config.testing,
|
testing=server_manager_config.testing,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Running evaluation...")
|
logger.info("Running evaluation...")
|
||||||
# Handle the case where we might already be in an event loop
|
# Handle the case where we might already be in an event loop
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
@ -2072,6 +2119,6 @@ class BaseEnv(ABC):
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
asyncio.run(env._run_evaluate())
|
asyncio.run(env._run_evaluate())
|
||||||
|
|
||||||
print("Evaluation completed.")
|
logger.info("Evaluation completed.")
|
||||||
|
|
||||||
return CliEvaluateConfig
|
return CliEvaluateConfig
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,8 @@ This wrapper maintains a tree structure of sequences, where:
|
||||||
- Branching occurs organically from different contexts and n > 1 completions
|
- Branching occurs organically from different contexts and n > 1 completions
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
|
|
@ -22,6 +24,8 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from atroposlib.envs.server_handling.server_baseline import APIServer
|
from atroposlib.envs.server_handling.server_baseline import APIServer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SequenceNode(BaseModel):
|
class SequenceNode(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
@ -131,6 +135,10 @@ class ManagedServer:
|
||||||
# Fallback for tokenizers without chat template
|
# Fallback for tokenizers without chat template
|
||||||
return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
|
return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
|
||||||
|
|
||||||
|
def _debug_requests_enabled(self) -> bool:
|
||||||
|
"""Return True when verbose request-construction logging is enabled.

The flag is read from the ATROPOS_DEBUG_REQUESTS environment variable on
every call; only the exact value "1" enables it (unset defaults to "0").
"""
|
||||||
|
return os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
|
||||||
|
|
||||||
def _find_extending_node(self, input_text: str) -> Optional[SequenceNode]:
|
def _find_extending_node(self, input_text: str) -> Optional[SequenceNode]:
|
||||||
"""
|
"""
|
||||||
Find a node that this input extends (default mode).
|
Find a node that this input extends (default mode).
|
||||||
|
|
@ -284,6 +292,17 @@ class ManagedServer:
|
||||||
completion_kwargs = kwargs.copy()
|
completion_kwargs = kwargs.copy()
|
||||||
completion_kwargs["prompt"] = prompt
|
completion_kwargs["prompt"] = prompt
|
||||||
completion_kwargs.pop("messages", None)
|
completion_kwargs.pop("messages", None)
|
||||||
|
if self._debug_requests_enabled():
|
||||||
|
msg_count = len(messages)
|
||||||
|
prompt_preview = prompt.replace("\n", "\\n")[:600]
|
||||||
|
logger.debug(
|
||||||
|
"[ATROPOS_REQ_DEBUG] chat_completion messages=%s n=%s max_tokens=%s temperature=%s",
|
||||||
|
msg_count,
|
||||||
|
completion_kwargs.get("n"),
|
||||||
|
completion_kwargs.get("max_tokens"),
|
||||||
|
completion_kwargs.get("temperature"),
|
||||||
|
)
|
||||||
|
logger.debug("[ATROPOS_REQ_DEBUG] prompt_preview=%r", prompt_preview)
|
||||||
|
|
||||||
# Set model name if not provided
|
# Set model name if not provided
|
||||||
if "model" not in completion_kwargs:
|
if "model" not in completion_kwargs:
|
||||||
|
|
|
||||||
|
|
@ -198,7 +198,21 @@ def resolve_openai_configs(
|
||||||
raise FailedExecutionException(
|
raise FailedExecutionException(
|
||||||
f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}"
|
f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}"
|
||||||
) from e
|
) from e
|
||||||
|
elif isinstance(default_server_configs, APIServerConfig):
|
||||||
|
# Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline
|
||||||
|
logger.info(
|
||||||
|
"Using single OpenAI server configuration based on merged settings (default/YAML/CLI)."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
final_openai_config = APIServerConfig(**openai_config_dict)
|
||||||
|
except Exception as e:
|
||||||
|
raise FailedExecutionException(
|
||||||
|
f"Error creating final OpenAI configuration from merged settings: {e}\n"
|
||||||
|
f"Merged Dict: {openai_config_dict}"
|
||||||
|
) from e
|
||||||
|
server_configs = final_openai_config
|
||||||
elif isinstance(default_server_configs, ServerBaseline):
|
elif isinstance(default_server_configs, ServerBaseline):
|
||||||
|
# Pure ServerBaseline (not APIServerConfig) - no CLI overrides possible
|
||||||
logger.info("Using ServerBaseline configuration.")
|
logger.info("Using ServerBaseline configuration.")
|
||||||
server_configs = default_server_configs
|
server_configs = default_server_configs
|
||||||
elif is_multi_server_default:
|
elif is_multi_server_default:
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,16 @@ This example uses `vLLM` for efficient inference during the (simulated) data gen
|
||||||
|
|
||||||
**Note:** This script is intended as a *reference example* for API integration and basic training setup. It is not optimized for large-scale, efficient training.
|
**Note:** This script is intended as a *reference example* for API integration and basic training setup. It is not optimized for large-scale, efficient training.
|
||||||
|
|
||||||
|
## On-Policy Distillation Scope
|
||||||
|
|
||||||
|
The current OPD integration in Atropos is transport-only:
|
||||||
|
|
||||||
|
- `ScoredDataGroup` / API payloads support `distill_token_ids` and `distill_logprobs`.
|
||||||
|
- Atropos API stores and returns those fields through `/scored_data` and `/batch`.
|
||||||
|
- Teacher orchestration (teacher endpoint calls, prompt rendering, top-k fetching) is intentionally out of scope for this integration.
|
||||||
|
|
||||||
|
If you train with distillation, provide the two distill arrays from your environment or external data pipeline before posting to the API.
|
||||||
|
|
||||||
### Custom vLLM Server
|
### Custom vLLM Server
|
||||||
|
|
||||||
The `vllm_api_server.py` file in this directory provides a customized vLLM API server implementation based on vLLM's native API. This server exposes enhanced endpoints for token and logprob tracking. The `VLLMServer` class in `atroposlib/envs/server_handling/vllm_server.py` can connect to this server for direct access to vLLM's `/generate` endpoint with full token-level logprobs.
|
The `vllm_api_server.py` file in this directory provides a customized vLLM API server implementation based on vLLM's native API. This server exposes enhanced endpoints for token and logprob tracking. The `VLLMServer` class in `atroposlib/envs/server_handling/vllm_server.py` can connect to this server for direct access to vLLM's `/generate` endpoint with full token-level logprobs.
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue