initial commit

Jai Suphavadeeprasit 2026-02-16 11:46:20 -05:00
parent 81b2d4daab
commit cc9b891eba
2 changed files with 234 additions and 0 deletions

@@ -66,6 +66,9 @@ class ScoredDataGroup(TypedDict):
    group_overrides: Optional[Dict]
    overrides: Optional[List[Dict]]
    images: Optional[Any]
    # On-policy distillation: top-K logprobs from teacher model.
    # Structure: List[List[List[List[Union[int, float]]]]], i.e. [sequence][position][top_k] = [token_id, logprob]
    onpolicydistill_logprobs: Optional[List[List[List[List]]]]


class ScoredDataItem(TypedDict):
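
For intuition, here is a minimal sketch of the nested layout this field carries; all token IDs and logprob values below are made up for illustration:

# Hypothetical example: one sequence, two positions, top-2 teacher candidates each.
# [sequence][position][top_k] = [token_id, logprob]
onpolicydistill_logprobs = [
    [  # sequence 0
        [[1542, -0.12], [318, -2.31]],  # position 0
        [[262, -0.05], [257, -3.87]],   # position 1
    ]
]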
@@ -78,6 +81,8 @@ class ScoredDataItem(TypedDict):
    group_overrides: Optional[Dict]
    overrides: Optional[Dict]
    images: Optional[Any]
    # On-policy distillation: top-K logprobs from teacher model per position
    onpolicydistill_logprobs: Optional[List[List[List]]]


class EvalHandlingEnum(Enum):
@@ -205,6 +210,31 @@ class BaseEnvConfig(BaseModel):
        "eval_helpers for the standard Hermes reasoning prompt.",
    )
    # On-policy distillation settings
    teacher_base_url: Optional[str] = Field(
        default=None,
        description="Base URL of the teacher model used for distillation. Supports any OpenAI-compatible API "
        "(vLLM, OpenAI, Together, etc.). Examples: 'http://localhost:8001/v1', 'https://api.openai.com/v1'.",
    )
    teacher_model_name: Optional[str] = Field(
        default=None,
        description="Model name to use in teacher API calls (e.g., 'gpt-4o', 'meta-llama/Llama-3-70b'). "
        "If None, uses 'default', which works for single-model vLLM servers.",
    )
    teacher_api_key: Optional[str] = Field(
        default=None,
        description="API key for the teacher model. Can also be set via the TEACHER_API_KEY env var.",
    )
    teacher_top_k: int = Field(
        default=10,
        description="Number of top logprobs to fetch from the teacher model per token position.",
    )
    distillation_enabled: bool = Field(
        default=False,
        description="Enable on-policy distillation. When True, teacher logprobs are fetched automatically "
        "after scoring and included in the data sent to the trainer.",
    )


class BaseEnv(ABC):
    name: Optional[str] = None
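
As a usage sketch, a config enabling distillation might look like the following; the URL and model name are illustrative assumptions, and any other required fields of the concrete config are omitted:

# Hypothetical sketch: point an environment at a locally served teacher model.
config = BaseEnvConfig(
    distillation_enabled=True,
    teacher_base_url="http://localhost:8001/v1",  # e.g. a vLLM OpenAI-compatible server
    teacher_model_name="meta-llama/Llama-3-70b",  # omit for single-model vLLM servers
    teacher_top_k=10,
)
# The API key may instead come from the environment: export TEACHER_API_KEY=...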
@@ -309,6 +339,190 @@ class BaseEnv(ABC):
        # Calculate derived batch size
        return int(self.config.batch_size * effective_fraction)

    async def get_teacher_logprobs(
        self,
        token_sequences: List[List[int]],
        messages_list: Optional[List[List[Dict]]] = None,
        top_k: Optional[int] = None,
    ) -> List[List[List[List]]]:
        """
        Fetch top-K logprobs from the teacher model for the given sequences.

        Supports any OpenAI-compatible API (vLLM, OpenAI, Together, etc.).

        Args:
            token_sequences: List of token ID sequences to get logprobs for.
            messages_list: Optional list of message histories (for chat APIs).
                If provided, used for the chat/completions logprobs fallback.
            top_k: Number of top logprobs to fetch (defaults to config.teacher_top_k).

        Returns:
            Top-K logprobs per position per sequence.
            Structure: [batch][position][top_k] = [token_id, logprob].
            Returns an empty list if teacher_base_url is not configured.
        """
        if not self.config.teacher_base_url:
            return []
        if top_k is None:
            top_k = self.config.teacher_top_k

        # Get API key from config or environment
        api_key = self.config.teacher_api_key or os.environ.get("TEACHER_API_KEY", "")
        model_name = self.config.teacher_model_name or "default"

        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        results = []
        try:
            async with aiohttp.ClientSession() as session:
                for i, tokens in enumerate(token_sequences):
                    # Decode tokens to text
                    full_text = self.tokenizer.decode(tokens, skip_special_tokens=False)

                    # Try vLLM-style completions first. With echo=True the server
                    # scores the prompt itself, so this is the most efficient path:
                    # it returns prompt logprobs without generating new tokens.
                    request_data = {
                        "model": model_name,
                        "prompt": full_text,
                        "max_tokens": 1,
                        "temperature": 1.0,
                        "logprobs": top_k,
                        "echo": True,  # Include prompt in response with logprobs
                    }
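                    # Note (assumption): OpenAI-legacy/vLLM completions with
                    # echo + logprobs are expected to return
                    # choices[0].logprobs.top_logprobs as one {token_str: logprob}
                    # dict per prompt token (None for the first position), which
                    # is the shape _parse_completion_logprobs consumes below.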
                    try:
                        async with session.post(
                            f"{self.config.teacher_base_url}/completions",
                            json=request_data,
                            headers=headers,
                            timeout=aiohttp.ClientTimeout(total=120),
                        ) as response:
                            if response.status == 200:
                                data = await response.json()
                                seq_result = self._parse_completion_logprobs(data, top_k)
                                if seq_result:
                                    results.append(seq_result)
                                    continue
                    except Exception:
                        pass  # Fall through to chat completions

                    # Fallback: use chat/completions with logprobs (OpenAI style).
                    # This requires messages-format input.
                    if messages_list and i < len(messages_list):
                        messages = messages_list[i]
                    else:
                        # Convert text to a simple message format
                        messages = [{"role": "user", "content": full_text}]

                    chat_request = {
                        "model": model_name,
                        "messages": messages,
                        "max_tokens": 1,
                        "temperature": 1.0,
                        "logprobs": True,
                        "top_logprobs": top_k,
                    }
                    try:
                        async with session.post(
                            f"{self.config.teacher_base_url}/chat/completions",
                            json=chat_request,
                            headers=headers,
                            timeout=aiohttp.ClientTimeout(total=120),
                        ) as response:
                            if response.status == 200:
                                data = await response.json()
                                seq_result = self._parse_chat_logprobs(data, top_k)
                                results.append(seq_result)
                            else:
                                logger.warning(f"Teacher API returned {response.status}")
                                results.append([])
                    except Exception as e:
                        logger.warning(f"Teacher chat request failed: {e}")
                        results.append([])

            return results
        except Exception as e:
            logger.error(f"Error fetching teacher logprobs: {e}")
            return []
    def _parse_completion_logprobs(self, data: Dict, top_k: int) -> List[List[List]]:
        """Parse logprobs from a vLLM-style completion response."""
        try:
            choice = data.get("choices", [{}])[0]
            logprobs_data = choice.get("logprobs", {})
            # vLLM returns top_logprobs as a list of dicts, one per token
            top_logprobs = logprobs_data.get("top_logprobs", [])
            if not top_logprobs:
                return []

            seq_result = []
            for pos_logprobs in top_logprobs:
                if pos_logprobs is None:
                    seq_result.append([])
                elif isinstance(pos_logprobs, dict):
                    # Format: {token_str: logprob, ...}
                    sorted_items = sorted(
                        pos_logprobs.items(), key=lambda x: x[1], reverse=True
                    )[:top_k]
                    pos_result = []
                    for token_str, logprob in sorted_items:
                        # Convert the token string back to an ID; if it re-encodes
                        # to several IDs, only the first is kept (an approximation)
                        token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
                        if token_ids:
                            pos_result.append([token_ids[0], float(logprob)])
                    seq_result.append(pos_result)
                else:
                    seq_result.append([])
            return seq_result
        except Exception as e:
            logger.warning(f"Error parsing completion logprobs: {e}")
            return []
    def _parse_chat_logprobs(self, data: Dict, top_k: int) -> List[List[List]]:
        """Parse logprobs from an OpenAI-style chat completion response."""
        try:
            choice = data.get("choices", [{}])[0]
            logprobs_data = choice.get("logprobs", {})
            if not logprobs_data:
                return []

            content = logprobs_data.get("content", [])
            seq_result = []
            for token_data in content:
                top_logprobs = token_data.get("top_logprobs", [])
                pos_result = []
                for item in top_logprobs[:top_k]:
                    token_str = item.get("token", "")
                    logprob = item.get("logprob", 0.0)
                    # Convert token string to ID
                    token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
                    if token_ids:
                        pos_result.append([token_ids[0], float(logprob)])
                seq_result.append(pos_result)
            return seq_result
        except Exception as e:
            logger.warning(f"Error parsing chat logprobs: {e}")
            return []
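
    # Hedged usage sketch: `env` and the token IDs below are illustrative, not
    # part of this commit; in normal operation the call happens automatically
    # when distillation_enabled is set.
    #
    #     teacher_lps = await env.get_teacher_logprobs(
    #         token_sequences=[[101, 2023, 2003, 1037, 3231, 102]],  # made-up IDs
    #     )
    #
    # teacher_lps[0][p] is then a list of up to teacher_top_k [token_id, logprob]
    # pairs for position p of sequence 0; [] marks positions where the teacher
    # returned no logprobs (e.g., the first echoed prompt token).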
    @classmethod
    def config_init(
        cls,
@@ -909,6 +1123,22 @@ class BaseEnv(ABC):
                for i in range(len(group["tokens"]))
            ]

            # Automatic on-policy distillation: fetch teacher logprobs if enabled
            if self.config.distillation_enabled and self.config.teacher_base_url:
                if group.get("onpolicydistill_logprobs") is None:
                    try:
                        teacher_logprobs = await self.get_teacher_logprobs(
                            token_sequences=group["tokens"],
                            messages_list=group.get("messages"),
                        )
                        if teacher_logprobs:
                            group["onpolicydistill_logprobs"] = teacher_logprobs
                            logger.debug(
                                f"Added teacher logprobs for {len(teacher_logprobs)} sequences"
                            )
                    except Exception as e:
                        logger.warning(f"Failed to fetch teacher logprobs: {e}")

            await self.add_rollouts_for_wandb(group, item)
            if self.jsonl_writer is not None: