diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py
index af74fa55..a8e97077 100644
--- a/atroposlib/envs/server_handling/managed_server.py
+++ b/atroposlib/envs/server_handling/managed_server.py
@@ -460,7 +460,9 @@ class ManagedServer:
         )
 
         # Call the tokens and logprobs wrapper directly
-        logger.warning("managed_server chat_completion calling backend completion wrapper")
+        logger.warning(
+            "managed_server chat_completion calling backend completion wrapper"
+        )
         (
             prompt_tokens,
             output_tokens_list,
diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py
index 12c16079..1c88ab62 100644
--- a/atroposlib/envs/teacher_distillation_env.py
+++ b/atroposlib/envs/teacher_distillation_env.py
@@ -151,7 +151,11 @@ class TeacherDistillationEnv(BaseEnv, ABC):
         # and Qwen3-30B) share the same vocabulary, so even though the
         # name_or_path strings differ they should use the fast path.
         student_tok_name = getattr(self.tokenizer, "name_or_path", None) or ""
-        if student_tok_name and teacher_tok_name and student_tok_name != teacher_tok_name:
+        if (
+            student_tok_name
+            and teacher_tok_name
+            and student_tok_name != teacher_tok_name
+        ):
             try:
                 from transformers import AutoTokenizer
 
@@ -367,7 +371,9 @@ class TeacherDistillationEnv(BaseEnv, ABC):
         teacher_topk_lps = payload["prompt_topk_logprobs"]
 
         alignment = self._build_student_teacher_alignment(text, token_ids, teacher_ids)
-        return self._align_and_remap(token_ids, teacher_topk_ids, teacher_topk_lps, alignment)
+        return self._align_and_remap(
+            token_ids, teacher_topk_ids, teacher_topk_lps, alignment
+        )
 
     # ------------------------------------------------------------------
     # Group enrichment
diff --git a/atroposlib/tests/test_server_logprobs.py b/atroposlib/tests/test_server_logprobs.py
index 2da50b42..8cbd84ad 100644
--- a/atroposlib/tests/test_server_logprobs.py
+++ b/atroposlib/tests/test_server_logprobs.py
@@ -41,7 +41,9 @@ class _FakeAPIServer(APIServer):
 
 
 class _FakeRoutedServer:
-    def __init__(self, name: str, train_slots: int, eval_slots: int, healthy: bool = True):
+    def __init__(
+        self, name: str, train_slots: int, eval_slots: int, healthy: bool = True
+    ):
         self.name = name
         self.server_healthy = healthy
         self.sem = AsyncSemWithAdaptiveWeight(4)
diff --git a/environments/gsm8k_server_teacher_distill.py b/environments/gsm8k_server_teacher_distill.py
index 159fa4d3..8276436b 100644
--- a/environments/gsm8k_server_teacher_distill.py
+++ b/environments/gsm8k_server_teacher_distill.py
@@ -47,5 +47,6 @@ class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv):
         )
         return env_config, server_config
 
+
 if __name__ == "__main__":
     GSM8kTeacherDistillEnv.cli()
diff --git a/example_trainer/data.py b/example_trainer/data.py
index bf7f5b19..4823eb64 100644
--- a/example_trainer/data.py
+++ b/example_trainer/data.py
@@ -189,23 +189,27 @@ def pad_data_to_good_offset(
 
                     rows = max(0, token_setup_len - 1)
                     token_mat = np.full((rows, local_k), -1, dtype=np.int64)
-                    logprob_mat = np.full(
-                        (rows, local_k), -1e9, dtype=np.float32
-                    )
+                    logprob_mat = np.full((rows, local_k), -1e9, dtype=np.float32)
 
                     # Shift by one to align with causal labels like inference_logprobs.
                     copy_positions = min(
-                        len(per_pos_token_ids), len(per_pos_logprobs), token_setup_len
+                        len(per_pos_token_ids),
+                        len(per_pos_logprobs),
+                        token_setup_len,
                     )
                     for pos in range(1, copy_positions):
                         src_ids = per_pos_token_ids[pos]
                         src_lps = per_pos_logprobs[pos]
-                        if not isinstance(src_ids, list) or not isinstance(src_lps, list):
+                        if not isinstance(src_ids, list) or not isinstance(
+                            src_lps, list
+                        ):
                             continue
                         topk = min(local_k, len(src_ids), len(src_lps))
                         if topk <= 0:
                             continue
-                        token_mat[pos - 1, :topk] = np.array(src_ids[:topk], dtype=np.int64)
+                        token_mat[pos - 1, :topk] = np.array(
+                            src_ids[:topk], dtype=np.int64
+                        )
                         logprob_mat[pos - 1, :topk] = np.array(
                             src_lps[:topk], dtype=np.float32
                         )
@@ -222,14 +226,18 @@ def pad_data_to_good_offset(
                     )
                 else:
                     rows = max(0, token_setup_len - 1)
-                    distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64))
+                    distill_token_ids_padded.append(
+                        np.full((rows, 1), -1, dtype=np.int64)
+                    )
                     distill_logprobs_padded.append(
                         np.full((rows, 1), -1e9, dtype=np.float32)
                     )
             else:
                 rows = max(0, token_setup_len - 1)
                 distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64))
-                distill_logprobs_padded.append(np.full((rows, 1), -1e9, dtype=np.float32))
+                distill_logprobs_padded.append(
+                    np.full((rows, 1), -1e9, dtype=np.float32)
+                )
 
         # Extract temperature (priority: override > generation_params > group_overrides > 1.0)
         t = 1.0
@@ -310,10 +318,14 @@ def pad_data_to_good_offset(
         else None
     )
     final_distill_token_id_batches = (
-        distill_token_id_batches if (has_any_distill and distill_token_id_batches) else None
+        distill_token_id_batches
+        if (has_any_distill and distill_token_id_batches)
+        else None
    )
    final_distill_logprob_batches = (
-        distill_logprob_batches if (has_any_distill and distill_logprob_batches) else None
+        distill_logprob_batches
+        if (has_any_distill and distill_logprob_batches)
+        else None
    )
 
    return (
diff --git a/example_trainer/training.py b/example_trainer/training.py
index 673ed795..b7cab944 100644
--- a/example_trainer/training.py
+++ b/example_trainer/training.py
@@ -311,7 +311,10 @@ def compute_distillation_loss(
     if distill_token_ids.dim() != 3 or distill_logprobs.dim() != 3:
         return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0
 
-    if distill_token_ids.shape[:2] != labels.shape or distill_logprobs.shape != distill_token_ids.shape:
+    if (
+        distill_token_ids.shape[:2] != labels.shape
+        or distill_logprobs.shape != distill_token_ids.shape
+    ):
         return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0
 
     temp = max(1e-6, float(temperature))
@@ -330,7 +333,9 @@ def compute_distillation_loss(
     # Output tensors are [batch, seq_len, k] (tiny) not [batch, seq_len, vocab].
     scaled_logits = logits / temp
     log_normalizer = torch.logsumexp(scaled_logits, dim=-1, keepdim=True)  # [b, s, 1]
-    student_logp_topk = torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer
+    student_logp_topk = (
+        torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer
+    )
 
     masked_teacher_logprobs = distill_logprobs.masked_fill(~valid_ids, -1e9)
     teacher_probs = F.softmax(masked_teacher_logprobs / temp, dim=-1)
diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py
index 131861d5..b20704e2 100644
--- a/example_trainer/vllm_api_server.py
+++ b/example_trainer/vllm_api_server.py
@@ -360,10 +360,16 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
 
     ret["prompt_token_ids"] = final_output.prompt_token_ids
     ret["token_ids"] = [x.token_ids for x in final_output.outputs]
-    if sampling_params.prompt_logprobs is not None and final_output.prompt_logprobs is not None:
+    if (
+        sampling_params.prompt_logprobs is not None
+        and final_output.prompt_logprobs is not None
+    ):
         ret["prompt_logprobs"] = [
-            {int(tok_id): lp.logprob for tok_id, lp in pos.items()}
-            if pos is not None else None
+            (
+                {int(tok_id): lp.logprob for tok_id, lp in pos.items()}
+                if pos is not None
+                else None
+            )
             for pos in final_output.prompt_logprobs
         ]
 
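
Note on the distillation math touched by the example_trainer/training.py hunks: the code
gathers the student's log-probabilities at the teacher's top-k token ids, normalizing with
a logsumexp over the full vocabulary, and scores them against the teacher's renormalized
top-k distribution. Below is a minimal standalone sketch of that top-k cross-entropy term.
The shapes and variable names are inferred from the hunks above; the helper name
topk_distill_term is hypothetical, and this is an illustration under those assumptions,
not the repo's exact implementation:

    import torch
    import torch.nn.functional as F

    def topk_distill_term(logits, distill_token_ids, distill_logprobs, temperature=1.0):
        # Hypothetical inputs (shapes assumed from the hunks above):
        #   logits:            [batch, seq, vocab] student logits
        #   distill_token_ids: [batch, seq, k] teacher top-k ids, -1 marks padding
        #   distill_logprobs:  [batch, seq, k] teacher top-k logprobs, -1e9 marks padding
        temp = max(1e-6, float(temperature))
        valid_ids = distill_token_ids >= 0            # mask for real (non-padding) slots
        gather_ids = distill_token_ids.clamp(min=0)   # make -1 padding safe for gather

        # Student log-probs at the teacher's top-k ids, normalized over the full vocab.
        scaled_logits = logits / temp
        log_normalizer = torch.logsumexp(scaled_logits, dim=-1, keepdim=True)  # [b, s, 1]
        student_logp_topk = (
            torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer
        )

        # Teacher distribution renormalized over its own top-k slots only.
        masked_teacher_logprobs = distill_logprobs.masked_fill(~valid_ids, -1e9)
        teacher_probs = F.softmax(masked_teacher_logprobs / temp, dim=-1)

        # Per-position cross-entropy of the teacher top-k against the student.
        return -(teacher_probs * student_logp_topk).sum(dim=-1)  # [b, s]

    # Example: logits [2, 5, 100] with top-8 teacher targets -> per-position loss [2, 5].
    loss = topk_distill_term(
        torch.randn(2, 5, 100),
        torch.randint(-1, 100, (2, 5, 8)),
        F.log_softmax(torch.randn(2, 5, 8), dim=-1),
    )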