Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-19 12:57:58 +00:00)

initial commit

parent 81b2d4daab
commit cc9b891eba
2 changed files with 234 additions and 0 deletions
@@ -66,6 +66,9 @@ class ScoredDataGroup(TypedDict):
    group_overrides: Optional[Dict]
    overrides: Optional[List[Dict]]
    images: Optional[Any]
    # On-policy distillation: top-K logprobs from teacher model
    # Structure: List[List[List[List[Union[int, float]]]]] = [sequence][position][top_k] = [token_id, logprob]
    onpolicydistill_logprobs: Optional[List[List[List[List]]]]


class ScoredDataItem(TypedDict):
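For reference, a minimal sketch of the nested layout this field uses (all values are illustrative placeholders, not output from a real teacher):

# Group-level field: [sequence][position][top_k] = [token_id, logprob]
onpolicydistill_logprobs = [
    [  # sequence 0, two positions
        [[1234, -0.05], [987, -3.21]],  # position 0: top-2 teacher candidates
        [[42, -0.10], [7, -2.80]],      # position 1
    ],
    [  # sequence 1
        [[15, -0.01], [99, -4.50]],
        [[2048, -0.30], [11, -1.90]],
    ],
]

# The per-item field in ScoredDataItem (next hunk) drops the leading [sequence] axis:
item_logprobs = onpolicydistill_logprobs[0]  # [position][top_k] = [token_id, logprob]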
@@ -78,6 +81,8 @@ class ScoredDataItem(TypedDict):
    group_overrides: Optional[Dict]
    overrides: Optional[Dict]
    images: Optional[Any]
    # On-policy distillation: top-K logprobs from teacher model per position
    onpolicydistill_logprobs: Optional[List[List[List]]]


class EvalHandlingEnum(Enum):
@@ -205,6 +210,31 @@ class BaseEnvConfig(BaseModel):
        "eval_helpers for the standard Hermes reasoning prompt.",
    )

    # On-policy distillation settings
    teacher_base_url: Optional[str] = Field(
        default=None,
        description="Base URL of teacher model for distillation. Supports any OpenAI-compatible API "
        "(vLLM, OpenAI, Together, etc.). Examples: 'http://localhost:8001/v1', 'https://api.openai.com/v1'",
    )
    teacher_model_name: Optional[str] = Field(
        default=None,
        description="Model name for teacher API calls (e.g., 'gpt-4o', 'meta-llama/Llama-3-70b'). "
        "If None, uses 'default' which works for single-model vLLM servers.",
    )
    teacher_api_key: Optional[str] = Field(
        default=None,
        description="API key for teacher model. Can also be set via TEACHER_API_KEY env var.",
    )
    teacher_top_k: int = Field(
        default=10,
        description="Number of top logprobs to fetch from teacher model per position.",
    )
    distillation_enabled: bool = Field(
        default=False,
        description="Enable on-policy distillation. When True, automatically fetches teacher logprobs "
        "after scoring and includes them in data sent to trainer.",
    )


class BaseEnv(ABC):
    name: Optional[str] = None
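Taken together, a distillation-enabled config might look like the following sketch (the URL and model name are placeholders, and any other fields a particular environment requires are omitted):

config = BaseEnvConfig(
    distillation_enabled=True,
    teacher_base_url="http://localhost:8001/v1",  # placeholder: a local vLLM server
    teacher_model_name="meta-llama/Llama-3-70b",  # or None for single-model vLLM servers
    teacher_top_k=10,                             # top-K teacher logprobs per position
    # teacher_api_key omitted: falls back to the TEACHER_API_KEY env var
)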
@@ -309,6 +339,190 @@ class BaseEnv(ABC):
        # Calculate derived batch size
        return int(self.config.batch_size * effective_fraction)

    async def get_teacher_logprobs(
        self,
        token_sequences: List[List[int]],
        messages_list: Optional[List[List[Dict]]] = None,
        top_k: Optional[int] = None,
    ) -> List[List[List[List]]]:
        """
        Fetch top-K logprobs from teacher model for given sequences.

        Supports any OpenAI-compatible API (vLLM, OpenAI, Together, etc.).

        Args:
            token_sequences: List of token ID sequences to get logprobs for
            messages_list: Optional list of message histories (for chat APIs).
                If provided, uses chat/completions with logprobs.
            top_k: Number of top logprobs to fetch (defaults to config.teacher_top_k)

        Returns:
            List of top-K logprobs per position per sequence
            Structure: [batch][position][top_k] = [token_id, logprob]
            Returns empty list if teacher_base_url is not configured.
        """
        if not self.config.teacher_base_url:
            return []

        if top_k is None:
            top_k = self.config.teacher_top_k

        # Get API key from config or environment
        api_key = self.config.teacher_api_key or os.environ.get("TEACHER_API_KEY", "")
        model_name = self.config.teacher_model_name or "default"

        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        results = []

        try:
            async with aiohttp.ClientSession() as session:
                for i, tokens in enumerate(token_sequences):
                    # Decode tokens to text
                    full_text = self.tokenizer.decode(tokens, skip_special_tokens=False)

                    # Try vLLM-style completions first (supports prompt_logprobs)
                    # This is most efficient as it doesn't generate new tokens
                    request_data = {
                        "model": model_name,
                        "prompt": full_text,
                        "max_tokens": 1,
                        "temperature": 1.0,
                        "logprobs": top_k,
                        "echo": True,  # Include prompt in response with logprobs
                    }

                    try:
                        async with session.post(
                            f"{self.config.teacher_base_url}/completions",
                            json=request_data,
                            headers=headers,
                            timeout=aiohttp.ClientTimeout(total=120),
                        ) as response:
                            if response.status == 200:
                                data = await response.json()
                                seq_result = self._parse_completion_logprobs(data, top_k)
                                if seq_result:
                                    results.append(seq_result)
                                    continue
                    except Exception:
                        pass  # Fall through to chat completions

                    # Fallback: Use chat/completions with logprobs (OpenAI style)
                    # This requires messages format
                    if messages_list and i < len(messages_list):
                        messages = messages_list[i]
                    else:
                        # Convert text to simple message format
                        messages = [{"role": "user", "content": full_text}]

                    chat_request = {
                        "model": model_name,
                        "messages": messages,
                        "max_tokens": 1,
                        "temperature": 1.0,
                        "logprobs": True,
                        "top_logprobs": top_k,
                    }

                    try:
                        async with session.post(
                            f"{self.config.teacher_base_url}/chat/completions",
                            json=chat_request,
                            headers=headers,
                            timeout=aiohttp.ClientTimeout(total=120),
                        ) as response:
                            if response.status == 200:
                                data = await response.json()
                                seq_result = self._parse_chat_logprobs(data, top_k)
                                results.append(seq_result)
                            else:
                                logger.warning(f"Teacher API returned {response.status}")
                                results.append([])
                    except Exception as e:
                        logger.warning(f"Teacher chat request failed: {e}")
                        results.append([])

            return results

        except Exception as e:
            logger.error(f"Error fetching teacher logprobs: {e}")
            return []

    def _parse_completion_logprobs(
        self, data: Dict, top_k: int
    ) -> List[List[List]]:
        """Parse logprobs from vLLM-style completion response."""
        try:
            choice = data.get("choices", [{}])[0]
            logprobs_data = choice.get("logprobs", {})

            # vLLM returns top_logprobs as list of dicts
            top_logprobs = logprobs_data.get("top_logprobs", [])
            tokens = logprobs_data.get("tokens", [])

            if not top_logprobs:
                return []

            seq_result = []
            for pos_logprobs in top_logprobs:
                if pos_logprobs is None:
                    seq_result.append([])
                elif isinstance(pos_logprobs, dict):
                    # Format: {token_str: logprob, ...}
                    sorted_items = sorted(
                        pos_logprobs.items(),
                        key=lambda x: x[1],
                        reverse=True
                    )[:top_k]
                    pos_result = []
                    for token_str, logprob in sorted_items:
                        # Convert token string to ID
                        token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
                        if token_ids:
                            pos_result.append([token_ids[0], float(logprob)])
                    seq_result.append(pos_result)
                else:
                    seq_result.append([])

            return seq_result
        except Exception as e:
            logger.warning(f"Error parsing completion logprobs: {e}")
            return []

    def _parse_chat_logprobs(
        self, data: Dict, top_k: int
    ) -> List[List[List]]:
        """Parse logprobs from OpenAI-style chat completion response."""
        try:
            choice = data.get("choices", [{}])[0]
            logprobs_data = choice.get("logprobs", {})

            if not logprobs_data:
                return []

            content = logprobs_data.get("content", [])
            seq_result = []

            for token_data in content:
                top_logprobs = token_data.get("top_logprobs", [])
                pos_result = []
                for item in top_logprobs[:top_k]:
                    token_str = item.get("token", "")
                    logprob = item.get("logprob", 0.0)
                    # Convert token string to ID
                    token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
                    if token_ids:
                        pos_result.append([token_ids[0], float(logprob)])
                seq_result.append(pos_result)

            return seq_result
        except Exception as e:
            logger.warning(f"Error parsing chat logprobs: {e}")
            return []

    @classmethod
    def config_init(
        cls,
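A usage sketch for the method above, assuming it is called from an async method of a BaseEnv subclass with a tokenizer loaded and a teacher configured (token IDs are placeholders):

# Inside an async method of a BaseEnv subclass:
token_sequences = [[101, 2023, 2003], [101, 2178, 7099]]  # placeholder token IDs
teacher_logprobs = await self.get_teacher_logprobs(token_sequences)
for seq_idx, positions in enumerate(teacher_logprobs):
    for pos_idx, candidates in enumerate(positions):
        # candidates is a list of up to teacher_top_k [token_id, logprob] pairs
        for token_id, logprob in candidates:
            print(seq_idx, pos_idx, token_id, logprob)

Note the asymmetry between the two paths: the vLLM-style completions call with echo=True returns logprobs for every prompt position in a single request without generating new tokens, while the chat fallback with max_tokens=1 only yields logprobs for the one newly generated token, so the completions path is preferred whenever the teacher endpoint supports it.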
@@ -909,6 +1123,22 @@ class BaseEnv(ABC):
            for i in range(len(group["tokens"]))
        ]

        # Automatic on-policy distillation: fetch teacher logprobs if enabled
        if self.config.distillation_enabled and self.config.teacher_base_url:
            if group.get("onpolicydistill_logprobs") is None:
                try:
                    teacher_logprobs = await self.get_teacher_logprobs(
                        token_sequences=group["tokens"],
                        messages_list=group.get("messages"),
                    )
                    if teacher_logprobs:
                        group["onpolicydistill_logprobs"] = teacher_logprobs
                        logger.debug(
                            f"Added teacher logprobs for {len(teacher_logprobs)} sequences"
                        )
                except Exception as e:
                    logger.warning(f"Failed to fetch teacher logprobs: {e}")

        await self.add_rollouts_for_wandb(group, item)

        if self.jsonl_writer is not None:
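With distillation_enabled=True this hook runs automatically after scoring, so no environment code changes are needed: the group sent onward carries teacher logprobs alongside its existing fields. A sketch of the resulting shape (field values are placeholders, and only the fields relevant to this hook are shown):

group = {
    "tokens": [[101, 2023], [101, 2178]],  # placeholder token IDs per sequence
    "messages": None,                      # optional chat histories, if the env tracks them
    "onpolicydistill_logprobs": [          # filled in by the hook above
        [[[2023, -0.04], [2003, -3.10]], [[102, -0.50], [0, -5.00]]],
        [[[2178, -0.02], [7099, -2.70]], [[102, -0.40], [0, -4.80]]],
    ],
}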