diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index 2740afea..6c103b36 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -10,7 +10,7 @@ This wrapper maintains a tree structure of sequences, where: import time import uuid import warnings -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from openai.types.chat.chat_completion import ( ChatCompletion, @@ -50,7 +50,7 @@ class ManagedServer: new branches. Provides proper masking for training (prompt tokens masked with -100, logprobs set to 0.0). - Uses the clean _tokens_and_logprobs_completion_wrapper interface internally. + Uses the clean tokens_and_logprobs_completion interface internally. """ def __init__( @@ -300,9 +300,7 @@ class ManagedServer: output_tokens_list, output_logprobs_list, finish_reasons, - ) = await self.server._tokens_and_logprobs_completion_wrapper( - **completion_kwargs - ) + ) = await self.server.tokens_and_logprobs_completion(**completion_kwargs) # Track each completion and build choices n = len(output_tokens_list) @@ -360,7 +358,9 @@ class ManagedServer: choice = Choice( finish_reason=finish_reason, index=i, - message=ChatCompletionMessage(content=completion_text, role="assistant"), + message=ChatCompletionMessage( + content=completion_text, role="assistant" + ), ) choices.append(choice) @@ -414,7 +414,7 @@ class ManagedServer: output_tokens_list, output_logprobs_list, finish_reasons, - ) = await self.server._tokens_and_logprobs_completion_wrapper(**kwargs) + ) = await self.server.tokens_and_logprobs_completion(**kwargs) # Track each completion and build choices n = len(output_tokens_list) diff --git a/environments/math_server.py b/environments/math_server.py index 27189c2e..c6dd18b8 100644 --- a/environments/math_server.py +++ b/environments/math_server.py @@ -20,7 +20,6 @@ from atroposlib.envs.base import ( EvalHandlingEnum, ScoredDataGroup, ) -from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer system_prompt = ( "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the " @@ -186,15 +185,15 @@ class MathEnv(BaseEnv): @classmethod def config_init(self) -> Tuple[RSConfig, List[APIServerConfig]]: env_config = RSConfig( - tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + tokenizer_name="NousResearch/Hermes-4-14B", group_size=16, use_wandb=True, rollout_server_url="http://localhost:8000", total_steps=1000, batch_size=1024, - max_num_workers_per_node=24, + max_num_workers_per_node=12, steps_per_eval=25, - max_token_length=8192, # 22000 // (2 ** i), + max_token_length=16384, # 22000 // (2 ** i), wandb_name="math", eval_handling=EvalHandlingEnum.LIMIT_TRAIN, eval_limit_ratio=0.1, @@ -203,10 +202,11 @@ class MathEnv(BaseEnv): ) server_configs = [ APIServerConfig( - model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + model_name="NousResearch/Hermes-4-14B", base_url="http://localhost:9004/v1", api_key="x", num_requests_for_eval=256, # since evaling only on one... + server_type="sglang", ), ] @@ -373,17 +373,25 @@ class MathEnv(BaseEnv): if thinking_len < 1024: print("thinking_len is less than 1024, skipping", flush=True) return None, [] - chat_completions = await self.server.chat_completion( - messages=chat, - n=self.config.group_size, - max_tokens=thinking_len, - temperature=1.0, - top_p=0.95, - ) + # Use managed server for automatic token/logprob tracking + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + chat_completions = await managed.chat_completion( + messages=chat, + n=self.config.group_size, + max_tokens=thinking_len, + temperature=1.0, + top_p=0.95, + ) + # Get tracked sequences with aligned tokens and logprobs + state = managed.get_state() + nodes = state["nodes"] + print("Finished generation", flush=True) to_score = list() to_backlog = list() - for i, chat_completion in enumerate(chat_completions.choices): + for i, (chat_completion, node) in enumerate( + zip(chat_completions.choices, nodes) + ): messages = ( {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, @@ -394,6 +402,9 @@ class MathEnv(BaseEnv): messages, item[1], chat_completion.finish_reason, + node.tokens, + node.masked_tokens, + node.logprobs, ) ) print("scoring normal", flush=True) @@ -447,6 +458,14 @@ class MathEnv(BaseEnv): ] ), most_dissimilar_score, + # Pass tokens/masks/logprobs for solution 1 + to_postprocess["tokens"][most_dissimilar[0]], + to_postprocess["masks"][most_dissimilar[0]], + to_postprocess["inference_logprobs"][most_dissimilar[0]], + # Pass tokens/masks/logprobs for solution 2 + to_postprocess["tokens"][most_dissimilar[1]], + to_postprocess["masks"][most_dissimilar[1]], + to_postprocess["inference_logprobs"][most_dissimilar[1]], ) ) print( @@ -571,12 +590,16 @@ class MathEnv(BaseEnv): group = item[3] scores = item[4] finish_reasons = item[5] + tokens_list = item[6] + masks_list = item[7] + logprobs_list = item[8] to_postprocess = ScoredDataGroup() to_postprocess["tokens"] = list() to_postprocess["masks"] = list() to_postprocess["scores"] = list() to_postprocess["overrides"] = list() to_postprocess["messages"] = list() + to_postprocess["inference_logprobs"] = list() for i in range(len(group)): # convert from frozen set to dict conv = [dict(x) for x in group[i]] @@ -594,21 +617,21 @@ class MathEnv(BaseEnv): >= self.config.num_rollouts_to_keep ): self.selfcorrect_rollouts.pop(0) - out_dict = tokenize_for_trainer( - tokenizer=self.tokenizer, - chat=conv, - finish_reason=finish_reasons[i], - include_messages=True, - ) - to_postprocess["tokens"].append(out_dict["tokens"]) - to_postprocess["masks"].append(out_dict["masks"]) + # Use pre-computed tokens/masks/logprobs from managed_server + assert len(logprobs_list[i]) == len( + masks_list[i] + ), f"{len(logprobs_list[i])}, {len(masks_list[i])} mismatch" + to_postprocess["tokens"].append(tokens_list[i]) + to_postprocess["masks"].append(masks_list[i]) + to_postprocess["inference_logprobs"].append(logprobs_list[i]) to_postprocess["scores"].append(scores[i]) to_postprocess["overrides"].append(dict()) if (finish_reasons[i] == "length") and ( self.config.mask_too_long_completions ): to_postprocess["overrides"][-1]["set_advantage_to_zero"] = True - to_postprocess["messages"].append(out_dict["messages"]) + # Convert back to messages format for consistency + to_postprocess["messages"].append(conv) print("selfcorrect done, sending batch off") return to_postprocess, [] else: @@ -621,13 +644,20 @@ class MathEnv(BaseEnv): scores["scores"] = list() scores["overrides"] = list() scores["messages"] = list() + scores["inference_logprobs"] = list() gold = rollout_group_data[0][1] loop = asyncio.get_event_loop() random.shuffle(rollout_group_data) for item in rollout_group_data: resp = item[0][-1]["content"].split("")[-1] scores["overrides"].append(dict()) - if item[2] == "length": + # Extract pre-computed data from managed_server + tokens = item[3] + masks = item[4] + logprobs = item[5] + finish_reason = item[2] + + if finish_reason == "length": reward = False if self.config.mask_too_long_completions: scores["overrides"][-1]["set_advantage_to_zero"] = True @@ -636,19 +666,17 @@ class MathEnv(BaseEnv): reward = await task if reward is None: return None - out_dict = tokenize_for_trainer( - tokenizer=self.tokenizer, - chat=item[0], - finish_reason=item[2], - include_messages=True, - ) - tokens = out_dict["tokens"] - masks = out_dict["masks"] - messages = out_dict["messages"] + + assert len(logprobs) == len( + masks + ), f"{len(logprobs)}, {len(masks)} mismatch" + # Use messages from item[0] + messages = item[0] + # remove obviously bad examples if len([1 for i in masks if i != -100]) < 10: continue - if item[2] == "length": + if finish_reason == "length": # Note we set it here so we can filter out the examples that are too long # for the Judge loop. IF you set the config to not do this we fix it # in the collect_trajectories_normal function. @@ -657,6 +685,7 @@ class MathEnv(BaseEnv): scores["masks"].append(masks) scores["scores"].append(1.0 if reward else -1.0) scores["messages"].append(messages) + scores["inference_logprobs"].append(logprobs) if len(scores["tokens"]) >= self.config.group_size: break if any([score == 1.0 for score in scores["scores"]]): @@ -729,39 +758,48 @@ class MathEnv(BaseEnv): solution2=item[3][-1]["content"].split("")[-1], ) print("Sending to server") - chat = [ + chat_fwd = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt_fwd}, ] - max_token_length = self.config.max_token_length - len( - self.tokenizer.apply_chat_template(chat, add_generation_prompt=True) - ) - chat_completions_fwd = self.server.chat_completion( - messages=chat, - n=3, - max_tokens=max_token_length, - temperature=1.0, - top_p=0.95, + max_token_length_fwd = self.config.max_token_length - len( + self.tokenizer.apply_chat_template(chat_fwd, add_generation_prompt=True) ) + print("Sending to server") # Should be the same token length as the fwd but tokenizers are cursed - chat = [ + chat_bwd = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt_bwd}, ] - max_token_length = self.config.max_token_length - len( - self.tokenizer.apply_chat_template(chat, add_generation_prompt=True) - ) - chat_completions_bwd = self.server.chat_completion( - messages=chat, - n=3, - max_tokens=self.config.max_token_length, - temperature=1.0, - top_p=0.95, + max_token_length_bwd = self.config.max_token_length - len( + self.tokenizer.apply_chat_template(chat_bwd, add_generation_prompt=True) ) + + # Use managed server for both forward and backward completions + async def get_fwd_completion(): + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + return await managed.chat_completion( + messages=chat_fwd, + n=3, + max_tokens=max_token_length_fwd, + temperature=1.0, + top_p=0.95, + ) + + async def get_bwd_completion(): + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + return await managed.chat_completion( + messages=chat_bwd, + n=3, + max_tokens=max_token_length_bwd, + temperature=1.0, + top_p=0.95, + ) + print("Gathering completions") chat_completions_fwd, chat_completions_bwd = await asyncio.gather( - chat_completions_fwd, chat_completions_bwd + get_fwd_completion(), get_bwd_completion() ) print("Grabbed RLAIF completions") # Check for correct answers @@ -810,25 +848,35 @@ class MathEnv(BaseEnv): to_postprocess["scores"] = list() to_postprocess["overrides"] = list() to_postprocess["messages"] = list() + to_postprocess["inference_logprobs"] = list() + # Extract pre-computed tokens/masks/logprobs from backlog + tokens_1 = item[6] + masks_1 = item[7] + logprobs_1 = item[8] + tokens_2 = item[9] + masks_2 = item[10] + logprobs_2 = item[11] + # Add assertions to verify data integrity + assert len(logprobs_1) == len( + masks_1 + ), f"{len(logprobs_1)}, {len(masks_1)} mismatch" + assert len(logprobs_2) == len( + masks_2 + ), f"{len(logprobs_2)}, {len(masks_2)} mismatch" # add the first message in - out_dict = tokenize_for_trainer( - tokenizer=self.tokenizer, chat=item[3], include_messages=True - ) - tokens = out_dict["tokens"] - masks = out_dict["masks"] - to_postprocess["tokens"].append(tokens) - to_postprocess["masks"].append(masks) + to_postprocess["tokens"].append(tokens_1) + to_postprocess["masks"].append(masks_1) to_postprocess["scores"].append(1.0 if score_1 > score_2 else -1.0) - to_postprocess["messages"].append(out_dict["messages"]) - out_dict = tokenize_for_trainer( - tokenizer=self.tokenizer, chat=item[4], include_messages=True - ) - tokens = out_dict["tokens"] - masks = out_dict["masks"] - to_postprocess["tokens"].append(tokens) - to_postprocess["masks"].append(masks) + to_postprocess["messages"].append(item[3]) # Already converted to dicts + to_postprocess["inference_logprobs"].append(logprobs_1) + to_postprocess["overrides"].append(dict()) + # add the second message in + to_postprocess["tokens"].append(tokens_2) + to_postprocess["masks"].append(masks_2) to_postprocess["scores"].append(1.0 if score_2 > score_1 else -1.0) - to_postprocess["messages"].append(out_dict["messages"]) + to_postprocess["messages"].append(item[4]) # Already converted to dicts + to_postprocess["inference_logprobs"].append(logprobs_2) + to_postprocess["overrides"].append(dict()) to_postprocess["group_overrides"] = { "group_size": 2, } @@ -848,13 +896,19 @@ class MathEnv(BaseEnv): max_token_length = self.config.max_token_length - len( self.tokenizer.apply_chat_template(chat, add_generation_prompt=True) ) - chat_completions = await self.server.chat_completion( - messages=chat, - n=self.config.group_size, - max_tokens=max_token_length, - temperature=1.0, - top_p=0.95, - ) + # Use managed server for judge completions + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + chat_completions = await managed.chat_completion( + messages=chat, + n=self.config.group_size, + max_tokens=max_token_length, + temperature=1.0, + top_p=0.95, + ) + # Get tracked sequences with aligned tokens and logprobs + state = managed.get_state() + nodes = state["nodes"] + is_correct = [ ( chat_completion.message.content.split("")[-1] @@ -878,25 +932,29 @@ class MathEnv(BaseEnv): scores["scores"] = [] scores["overrides"] = [] scores["messages"] = [] + scores["inference_logprobs"] = [] for_table = [] - for i, chat_completion in enumerate(chat_completions.choices): - out_dict = tokenize_for_trainer( - tokenizer=self.tokenizer, - chat=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - {"role": "assistant", "content": chat_completion.message.content}, - ], - include_messages=True, - ) - tokens = out_dict["tokens"] - masks = out_dict["masks"] - messages = out_dict["messages"] + for i, (chat_completion, node) in enumerate( + zip(chat_completions.choices, nodes) + ): + # Extract pre-computed data from managed_server + tokens = node.tokens + masks = node.masked_tokens + logprobs = node.logprobs + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + {"role": "assistant", "content": chat_completion.message.content}, + ] + assert len(logprobs) == len( + masks + ), f"{len(logprobs)}, {len(masks)} mismatch" if not is_correct[i]: scores["tokens"].append(tokens) scores["masks"].append(masks) scores["scores"].append(-1.0) scores["messages"].append(messages) + scores["inference_logprobs"].append(logprobs) scores["overrides"].append(dict()) if (chat_completion.finish_reason == "length") and ( self.config.mask_too_long_completions @@ -932,20 +990,31 @@ class MathEnv(BaseEnv): retry_messages, add_generation_prompt=True ) ) - retry_chat_completions = await self.server.chat_completion( - messages=retry_messages, - n=self.config.group_size, - max_tokens=max_token_length, - temperature=1.0, - top_p=0.95, - ) + # Use managed server for retry completions + async with self.server.managed_server( + tokenizer=self.tokenizer + ) as managed: + retry_chat_completions = await managed.chat_completion( + messages=retry_messages, + n=self.config.group_size, + max_tokens=max_token_length, + temperature=1.0, + top_p=0.95, + ) + # Get tracked sequences with aligned tokens and logprobs + retry_state = managed.get_state() + retry_nodes = retry_state["nodes"] + print("Gathering completions") scoring_data = [] backlog_scores = [] backlog_reasons = [] backlog_messages = [] - for j, retry_chat_completion in enumerate( - retry_chat_completions.choices + backlog_tokens = [] + backlog_masks = [] + backlog_logprobs = [] + for j, (retry_chat_completion, retry_node) in enumerate( + zip(retry_chat_completions.choices, retry_nodes) ): print(f"Scoring generation {j} for retry...") backlog_messages.append( @@ -962,6 +1031,10 @@ class MathEnv(BaseEnv): ) ) backlog_reasons.append(retry_chat_completion.finish_reason) + # Store tokens, masks, and logprobs from managed_server + backlog_tokens.append(retry_node.tokens) + backlog_masks.append(retry_node.masked_tokens) + backlog_logprobs.append(retry_node.logprobs) if retry_chat_completion.finish_reason == "length": scoring_data.append(0) backlog_scores.append(0) @@ -998,10 +1071,18 @@ class MathEnv(BaseEnv): tuple(backlog_messages), tuple(backlog_scores), tuple(backlog_reasons), + tuple(backlog_tokens), + tuple(backlog_masks), + tuple(backlog_logprobs), ) ) print(f"Sending to selfcorrect, {len(to_backlog)} in backlog") scores["scores"].append(sum(scoring_data) / len(scoring_data)) + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["messages"].append(messages) + scores["inference_logprobs"].append(logprobs) + scores["overrides"].append(dict()) self.judge_success_rate.append( sum(scoring_data) / len(scoring_data) ) @@ -1012,6 +1093,7 @@ class MathEnv(BaseEnv): scores["tokens"].append(tokens) scores["masks"].append(masks) scores["messages"].append(messages) + scores["inference_logprobs"].append(logprobs) scores["overrides"].append(dict()) if all([score == 1.0 for score in scores["scores"]]) and ( random.random() < self.config.percent_length_penalty