Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-19 12:57:58 +00:00)

Commit 79e392c446 ("post merge changes"), parent c89854a350.
3 changed files with 200 additions and 88 deletions.
@@ -145,9 +145,10 @@ class ScoredData(BaseModel):
     group_overrides: Optional[dict] = None
     images: Optional[Any] = None
     env_id: Optional[int] = None  # ID of the environment that generated this data
     # On-policy distillation: top-K logprobs from teacher model
     # Structure: [sequence][position][top_k] = [token_id, logprob]
     onpolicydistill_logprobs: Optional[List[List[List[List]]]] = None
+    # On-policy distillation (new format): parallel token ids + logprobs.
+    # Shape for both: [sequence][position][top_k]
+    distill_token_ids: Optional[List[List[List[int]]]] = None
+    distill_logprobs: Optional[List[List[List[float]]]] = None

     @field_validator("messages", mode="before")
     @classmethod
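For orientation, a minimal sketch (hypothetical values, not from the commit) of how the two new parallel arrays line up: for every sequence s and position p, distill_token_ids[s][p][k] pairs with distill_logprobs[s][p][k].

# One sequence, two positions, top_k = 2 (illustrative numbers only).
distill_token_ids = [[[1996, 2023], [2003, 2001]]]     # [sequence][position][top_k]
distill_logprobs = [[[-0.31, -1.77], [-0.12, -2.54]]]  # same shape, same ordering

# The teacher's top-ranked token and its logprob at sequence 0, position 1:
best_id = distill_token_ids[0][1][0]   # 2003
best_lp = distill_logprobs[0][1][0]    # -0.12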
@@ -185,7 +186,8 @@ def _scored_data_to_dict(scored_data: ScoredData) -> Dict[str, Any]:
         "group_overrides": scored_data.group_overrides,
         "images": scored_data.images,
         "env_id": scored_data.env_id,
         "onpolicydistill_logprobs": scored_data.onpolicydistill_logprobs,
+        "distill_token_ids": scored_data.distill_token_ids,
+        "distill_logprobs": scored_data.distill_logprobs,
     }
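The serializer keeps the legacy onpolicydistill_logprobs key alongside the new arrays, so a trainer can consume either format during the transition. A sketch of the resulting payload fragment (values are placeholders):

payload = {
    "group_overrides": None,
    "images": None,
    "env_id": 0,
    "onpolicydistill_logprobs": None,          # legacy combined [token_id, logprob] rows
    "distill_token_ids": [[[1996], [2003]]],   # new format: ids only
    "distill_logprobs": [[[-0.31], [-0.12]]],  # new format: logprobs only
}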
@@ -66,9 +66,10 @@ class ScoredDataGroup(TypedDict):
     group_overrides: Optional[Dict]
     overrides: Optional[List[Dict]]
     images: Optional[Any]
     # On-policy distillation: top-K logprobs from teacher model
     # Structure: List[List[List[List[Union[int, float]]]]] = [sequence][position][top_k] = [token_id, logprob]
     onpolicydistill_logprobs: Optional[List[List[List[List]]]]
+    # On-policy distillation (new format): parallel token ids + logprobs.
+    # distill_token_ids/distill_logprobs are [sequence][position][top_k]
+    distill_token_ids: Optional[List[List[List[int]]]]
+    distill_logprobs: Optional[List[List[List[float]]]]


 class ScoredDataItem(TypedDict):
@@ -81,8 +82,9 @@ class ScoredDataItem(TypedDict):
     group_overrides: Optional[Dict]
     overrides: Optional[Dict]
     images: Optional[Any]
     # On-policy distillation: top-K logprobs from teacher model per position
     onpolicydistill_logprobs: Optional[List[List[List]]]
+    # On-policy distillation (new format): parallel token ids + logprobs per position.
+    distill_token_ids: Optional[List[List[int]]]
+    distill_logprobs: Optional[List[List[float]]]


 class EvalHandlingEnum(Enum):
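Note the rank difference between the group and item types: a ScoredDataItem holds a single sequence, so its distill arrays drop the leading [sequence] axis. A small illustrative sketch (hypothetical values):

item_ids = [[1996, 2023], [2003]]   # ScoredDataItem: [position][top_k]
group_ids = [item_ids]              # ScoredDataGroup: [sequence][position][top_k]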
@@ -233,6 +235,18 @@ class BaseEnvConfig(BaseModel):
         default=20,
         description="Number of top logprobs to fetch from teacher model per position.",
     )
+    teacher_prefix_text: Optional[str] = Field(
+        default=None,
+        description="Optional text prefix prepended to teacher scoring prompt. "
+        "Useful for behavior steering. Prefix token positions are trimmed out "
+        "before sending distillation arrays to the trainer, preserving alignment.",
+    )
+    teacher_system_prompt: Optional[str] = Field(
+        default=None,
+        description="Optional teacher system prompt. For completion-style teacher APIs, "
+        "this is converted to a textual prefix. For chat fallback, this is injected "
+        "as a leading system message.",
+    )


 class BaseEnv(ABC):
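A sketch of how the new steering knobs might be set when building a config; the field names come from this diff, while the values (and any omitted required fields) are hypothetical:

config = BaseEnvConfig(
    distillation_enabled=True,
    teacher_base_url="http://localhost:9000/v1",   # hypothetical teacher endpoint
    teacher_top_k=20,
    teacher_system_prompt="You are a meticulous math tutor.",
    teacher_prefix_text="Reason step by step.\n\n",
)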
@@ -343,7 +357,7 @@ class BaseEnv(ABC):
         token_sequences: List[List[int]],
         messages_list: Optional[List[List[Dict]]] = None,
         top_k: Optional[int] = None,
-    ) -> List[List[List[List]]]:
+    ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
         """
         Fetch top-K logprobs from teacher model for given sequences.
@@ -356,16 +370,16 @@ class BaseEnv(ABC):
             top_k: Number of top logprobs to fetch (defaults to config.teacher_top_k)

         Returns:
-            List of top-K logprobs per position per sequence
-            Structure: [batch][position][top_k] = [token_id, logprob]
-            Returns empty list if teacher_base_url is not configured.
+            Tuple of (distill_token_ids, distill_logprobs), both shaped as:
+            [batch][position][top_k].
+            Returns ([], []) if teacher_base_url is not configured.
         """
         logger.info(f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences")
         logger.info(f"[TEACHER] teacher_base_url={self.config.teacher_base_url}")

         if not self.config.teacher_base_url:
             logger.warning("[TEACHER] No teacher_base_url configured, returning empty")
-            return []
+            return [], []

         if top_k is None:
             top_k = self.config.teacher_top_k
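With the tuple return, a caller unpacks two parallel lists and can treat ([], []) as "no teacher configured". A minimal usage sketch inside a BaseEnv subclass, mirroring the call site later in this commit:

token_ids, logprobs = await self.get_teacher_logprobs(
    token_sequences=group["tokens"],
    messages_list=group.get("messages"),
)
if token_ids and logprobs:
    group["distill_token_ids"] = token_ids
    group["distill_logprobs"] = logprobs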
@@ -380,14 +394,29 @@ class BaseEnv(ABC):
         if api_key:
             headers["Authorization"] = f"Bearer {api_key}"

-        results = []
+        token_id_results: List[List[List[int]]] = []
+        logprob_results: List[List[List[float]]] = []

         try:
             async with aiohttp.ClientSession() as session:
                 for i, tokens in enumerate(token_sequences):
                     logger.info(f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens")
-                    # Decode tokens to text
-                    full_text = self.tokenizer.decode(tokens, skip_special_tokens=False)
+                    # Decode original sequence and optionally prepend teacher steering text.
+                    base_text = self.tokenizer.decode(tokens, skip_special_tokens=False)
+                    steering_prefix = ""
+                    if self.config.teacher_system_prompt:
+                        steering_prefix += (
+                            "System instruction:\n"
+                            f"{self.config.teacher_system_prompt.strip()}\n\n"
+                        )
+                    if self.config.teacher_prefix_text:
+                        steering_prefix += self.config.teacher_prefix_text
+                    full_text = steering_prefix + base_text
+                    prefix_token_len = (
+                        len(self.tokenizer.encode(steering_prefix, add_special_tokens=False))
+                        if steering_prefix
+                        else 0
+                    )

                     # Try vLLM-style completions first (supports prompt_logprobs)
                     # This is most efficient as it doesn't generate new tokens
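The prefix bookkeeping matters because the teacher scores steering_prefix + base_text: its first prefix_token_len output positions describe the steering text, not the student's tokens. A toy trace (tokenizer counts are hypothetical):

# Say the steering prefix encodes to 5 tokens and the sequence to 120.
prefix_token_len = 5
teacher_positions = 5 + 120   # teacher returns logprobs for prefix + sequence
# _align_teacher_topk_to_tokens later drops the first 5 rows, so row p of the
# aligned result corresponds exactly to student token p.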
@@ -409,9 +438,18 @@ class BaseEnv(ABC):
                         ) as response:
                             if response.status == 200:
                                 data = await response.json()
-                                seq_result = self._parse_completion_logprobs(data, top_k)
-                                if seq_result:
-                                    results.append(seq_result)
+                                seq_token_ids, seq_logprobs = self._parse_completion_logprobs(
+                                    data, top_k
+                                )
+                                if seq_token_ids and seq_logprobs:
+                                    aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens(
+                                        seq_token_ids,
+                                        seq_logprobs,
+                                        target_token_len=len(tokens),
+                                        prefix_token_len=prefix_token_len,
+                                    )
+                                    token_id_results.append(aligned_ids)
+                                    logprob_results.append(aligned_lps)
                                     continue
                     except Exception:
                         pass  # Fall through to chat completions
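The completions request body itself is elided between the previous hunk and this one; for vLLM-style servers, prompt logprobs are typically obtained with an echo-style request that generates nothing. A hedged sketch of what such a body could look like (the field choices are an assumption, informed by the "max_tokens>0 with echo=True" comment later in this diff):

completion_request = {
    "model": model_name,   # assumed to be resolved earlier in the method
    "prompt": full_text,
    "max_tokens": 0,       # score the prompt only, generate nothing
    "echo": True,          # return logprobs for the prompt tokens
    "logprobs": top_k,
}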
@@ -419,10 +457,25 @@ class BaseEnv(ABC):
                     # Fallback: Use chat/completions with logprobs (OpenAI style)
                     # This requires messages format
                     if messages_list and i < len(messages_list):
-                        messages = messages_list[i]
+                        messages = list(messages_list[i])
+                        if self.config.teacher_system_prompt:
+                            messages = [
+                                {
+                                    "role": "system",
+                                    "content": self.config.teacher_system_prompt,
+                                }
+                            ] + messages
                     else:
                         # Convert text to simple message format
-                        messages = [{"role": "user", "content": full_text}]
+                        messages = []
+                        if self.config.teacher_system_prompt:
+                            messages.append(
+                                {
+                                    "role": "system",
+                                    "content": self.config.teacher_system_prompt,
+                                }
+                            )
+                        messages.append({"role": "user", "content": full_text})

                     chat_request = {
                         "model": model_name,
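The chat_request body is truncated in this diff after "model". An OpenAI-style chat request would carry its logprobs flags roughly as below; this is a sketch, not the commit's actual payload:

chat_request = {
    "model": model_name,
    "messages": messages,
    "max_tokens": 1,        # chat APIs must generate at least one token
    "logprobs": True,
    "top_logprobs": top_k,  # OpenAI-style APIs typically cap this at 20
}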
@@ -442,40 +495,88 @@ class BaseEnv(ABC):
                         ) as response:
                             if response.status == 200:
                                 data = await response.json()
-                                seq_result = self._parse_chat_logprobs(data, top_k)
-                                results.append(seq_result)
+                                seq_token_ids, seq_logprobs = self._parse_chat_logprobs(
+                                    data, top_k
+                                )
+                                # Chat fallback logprobs are for generated tokens, not prompt tokens.
+                                # To keep alignment correct for distillation, return empty per-position rows.
+                                if seq_token_ids and len(seq_token_ids) >= len(tokens):
+                                    aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens(
+                                        seq_token_ids,
+                                        seq_logprobs,
+                                        target_token_len=len(tokens),
+                                        prefix_token_len=0,
+                                    )
+                                else:
+                                    aligned_ids = [[] for _ in range(len(tokens))]
+                                    aligned_lps = [[] for _ in range(len(tokens))]
+                                token_id_results.append(aligned_ids)
+                                logprob_results.append(aligned_lps)
                             else:
                                 logger.warning(f"Teacher API returned {response.status}")
-                                results.append([])
+                                token_id_results.append([[] for _ in range(len(tokens))])
+                                logprob_results.append([[] for _ in range(len(tokens))])
                     except Exception as e:
                         logger.warning(f"Teacher chat request failed: {e}")
-                        results.append([])
+                        token_id_results.append([[] for _ in range(len(tokens))])
+                        logprob_results.append([[] for _ in range(len(tokens))])

-            return results
+            return token_id_results, logprob_results

         except Exception as e:
             logger.error(f"Error fetching teacher logprobs: {e}")
-            return []
+            return [], []
+
+    def _align_teacher_topk_to_tokens(
+        self,
+        seq_token_ids: List[List[int]],
+        seq_logprobs: List[List[float]],
+        target_token_len: int,
+        prefix_token_len: int = 0,
+    ) -> Tuple[List[List[int]], List[List[float]]]:
+        """
+        Trim teacher prefix positions and enforce exact length alignment with source tokens.
+        """
+        n = min(len(seq_token_ids), len(seq_logprobs))
+        aligned_ids = list(seq_token_ids[:n])
+        aligned_lps = list(seq_logprobs[:n])
+
+        if prefix_token_len > 0:
+            aligned_ids = aligned_ids[prefix_token_len:]
+            aligned_lps = aligned_lps[prefix_token_len:]
+
+        # Truncate any tail token (e.g., generated token when max_tokens>0 with echo=True)
+        aligned_ids = aligned_ids[:target_token_len]
+        aligned_lps = aligned_lps[:target_token_len]
+
+        # Pad missing positions to preserve strict [seq][position][top_k] shape.
+        if len(aligned_ids) < target_token_len:
+            pad_count = target_token_len - len(aligned_ids)
+            aligned_ids.extend([[] for _ in range(pad_count)])
+            aligned_lps.extend([[] for _ in range(pad_count)])
+
+        return aligned_ids, aligned_lps

     def _parse_completion_logprobs(
         self, data: Dict, top_k: int
-    ) -> List[List[List]]:
-        """Parse logprobs from vLLM-style completion response."""
+    ) -> Tuple[List[List[int]], List[List[float]]]:
+        """Parse token ids + logprobs from vLLM-style completion response."""
         try:
             choice = data.get("choices", [{}])[0]
             logprobs_data = choice.get("logprobs", {})

             # vLLM returns top_logprobs as list of dicts
             top_logprobs = logprobs_data.get("top_logprobs", [])
             tokens = logprobs_data.get("tokens", [])

             if not top_logprobs:
-                return []
-
-            seq_result = []
+                return [], []
+
+            seq_token_ids: List[List[int]] = []
+            seq_logprobs: List[List[float]] = []
             for pos_logprobs in top_logprobs:
                 if pos_logprobs is None:
-                    seq_result.append([])
+                    seq_token_ids.append([])
+                    seq_logprobs.append([])
                 elif isinstance(pos_logprobs, dict):
                     # Format: {token_str: logprob, ...}
                     sorted_items = sorted(
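A standalone trace of the new helper's trim-then-pad contract, with hypothetical inputs (env is any BaseEnv instance): a 2-row prefix is dropped, the tail is truncated to the target length, and any shortfall is padded with empty rows so the [position][top_k] shape stays strict.

ids = [[10], [11], [12], [13], [14], [15]]             # 6 teacher positions, top_k = 1
lps = [[-0.1], [-0.2], [-0.3], [-0.4], [-0.5], [-0.6]]
aligned_ids, aligned_lps = env._align_teacher_topk_to_tokens(
    ids, lps, target_token_len=5, prefix_token_len=2
)
# aligned_ids == [[12], [13], [14], [15], []]            # 4 real rows + 1 padded row
# aligned_lps == [[-0.3], [-0.4], [-0.5], [-0.6], []]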
@@ -483,51 +584,59 @@ class BaseEnv(ABC):
                         key=lambda x: x[1],
                         reverse=True
                     )[:top_k]
-                    pos_result = []
+                    pos_ids: List[int] = []
+                    pos_lps: List[float] = []
                     for token_str, logprob in sorted_items:
                         # Convert token string to ID
                         token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
                         if token_ids:
-                            pos_result.append([token_ids[0], float(logprob)])
-                    seq_result.append(pos_result)
+                            pos_ids.append(int(token_ids[0]))
+                            pos_lps.append(float(logprob))
+                    seq_token_ids.append(pos_ids)
+                    seq_logprobs.append(pos_lps)
                 else:
-                    seq_result.append([])
-
-            return seq_result
+                    seq_token_ids.append([])
+                    seq_logprobs.append([])
+
+            return seq_token_ids, seq_logprobs
         except Exception as e:
             logger.warning(f"Error parsing completion logprobs: {e}")
-            return []
+            return [], []

     def _parse_chat_logprobs(
         self, data: Dict, top_k: int
-    ) -> List[List[List]]:
-        """Parse logprobs from OpenAI-style chat completion response."""
+    ) -> Tuple[List[List[int]], List[List[float]]]:
+        """Parse token ids + logprobs from OpenAI-style chat completion response."""
         try:
             choice = data.get("choices", [{}])[0]
             logprobs_data = choice.get("logprobs", {})

             if not logprobs_data:
-                return []
+                return [], []

             content = logprobs_data.get("content", [])
-            seq_result = []
+            seq_token_ids: List[List[int]] = []
+            seq_logprobs: List[List[float]] = []

             for token_data in content:
                 top_logprobs = token_data.get("top_logprobs", [])
-                pos_result = []
+                pos_ids: List[int] = []
+                pos_lps: List[float] = []
                 for item in top_logprobs[:top_k]:
                     token_str = item.get("token", "")
                     logprob = item.get("logprob", 0.0)
                     # Convert token string to ID
                     token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
                     if token_ids:
-                        pos_result.append([token_ids[0], float(logprob)])
-                seq_result.append(pos_result)
-
-            return seq_result
+                        pos_ids.append(int(token_ids[0]))
+                        pos_lps.append(float(logprob))
+                seq_token_ids.append(pos_ids)
+                seq_logprobs.append(pos_lps)

+            return seq_token_ids, seq_logprobs
         except Exception as e:
             logger.warning(f"Error parsing chat logprobs: {e}")
-            return []
+            return [], []

     @classmethod
     def config_init(
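For reference, a fragment of the vLLM-style response shape the completion parser consumes, and what the reworked parser yields for it (values are hypothetical; token-to-id conversion depends on the tokenizer):

data = {"choices": [{"logprobs": {
    "tokens": ["The", " cat"],
    "top_logprobs": [None, {" cat": -0.3, " dog": -1.9}],
}}]}
seq_token_ids, seq_logprobs = env._parse_completion_logprobs(data, top_k=2)
# Position 0 has no logprobs (None), so both outputs start with an empty row:
# seq_token_ids == [[], [<id of " cat">, <id of " dog">]]
# seq_logprobs  == [[], [-0.3, -1.9]]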
@@ -1108,6 +1217,8 @@ class BaseEnv(ABC):
             group.setdefault("ref_logprobs", None)
             group.setdefault("overrides", None)
             group.setdefault("group_overrides", None)
+            group.setdefault("distill_token_ids", None)
+            group.setdefault("distill_logprobs", None)

             for mask in group["masks"]:
                 self.completion_lengths.append(sum(m != -100 for m in mask))
@@ -1129,31 +1240,6 @@ class BaseEnv(ABC):
                 for i in range(len(group["tokens"]))
             ]

-            # Automatic on-policy distillation: fetch teacher logprobs if enabled
-            logger.info(f"[DISTILL DEBUG] distillation_enabled={self.config.distillation_enabled}, teacher_base_url={self.config.teacher_base_url}")
-            if self.config.distillation_enabled and self.config.teacher_base_url:
-                logger.info(f"[DISTILL DEBUG] Distillation is enabled! Checking for existing logprobs...")
-                if group.get("onpolicydistill_logprobs") is None:
-                    logger.info(f"[DISTILL DEBUG] No existing logprobs, fetching from teacher...")
-                    try:
-                        teacher_logprobs = await self.get_teacher_logprobs(
-                            token_sequences=group["tokens"],
-                            messages_list=group.get("messages"),
-                        )
-                        if teacher_logprobs:
-                            group["onpolicydistill_logprobs"] = teacher_logprobs
-                            logger.info(
-                                f"[DISTILL DEBUG] Added teacher logprobs for {len(teacher_logprobs)} sequences"
-                            )
-                        else:
-                            logger.warning("[DISTILL DEBUG] get_teacher_logprobs returned empty!")
-                    except Exception as e:
-                        logger.error(f"[DISTILL DEBUG] Failed to fetch teacher logprobs: {e}")
-                        import traceback
-                        logger.error(traceback.format_exc())
-            else:
-                logger.debug(f"[DISTILL DEBUG] Distillation skipped - not enabled or no teacher URL")
-
             await self.add_rollouts_for_wandb(group, item)

             if self.jsonl_writer is not None:
@@ -1167,15 +1253,22 @@ class BaseEnv(ABC):
         if self.config.distillation_enabled and self.config.teacher_base_url:
             logger.info(f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups")
             for group in valid_groups:
-                if group.get("onpolicydistill_logprobs") is None:
+                has_new_format = (
+                    group.get("distill_token_ids") is not None
+                    and group.get("distill_logprobs") is not None
+                )
+                if not has_new_format:
                     try:
-                        teacher_logprobs = await self.get_teacher_logprobs(
+                        teacher_token_ids, teacher_logprobs = await self.get_teacher_logprobs(
                             token_sequences=group["tokens"],
                             messages_list=group.get("messages"),
                         )
-                        if teacher_logprobs:
-                            group["onpolicydistill_logprobs"] = teacher_logprobs
-                            logger.info(f"[DISTILL] Added teacher logprobs for {len(teacher_logprobs)} sequences")
+                        if teacher_token_ids and teacher_logprobs:
+                            group["distill_token_ids"] = teacher_token_ids
+                            group["distill_logprobs"] = teacher_logprobs
+                            logger.info(
+                                f"[DISTILL] Added teacher distill arrays for {len(teacher_token_ids)} sequences"
+                            )
+                        else:
+                            logger.warning("[DISTILL] get_teacher_logprobs returned empty")
                     except Exception as e:
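The has_new_format guard makes the fetch idempotent with respect to upstream producers: groups that already carry the new arrays are left untouched, and only groups missing them trigger a teacher call. The condition, restated as a standalone predicate (sketch):

def needs_teacher_fetch(group: dict) -> bool:
    # Fetch only when the new-format arrays are absent.
    return (
        group.get("distill_token_ids") is None
        or group.get("distill_logprobs") is None
    )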