diff --git a/environments/mcqa_thinking_env.py b/environments/mcqa_thinking_env.py
index f584f715..4304de4b 100644
--- a/environments/mcqa_thinking_env.py
+++ b/environments/mcqa_thinking_env.py
@@ -14,7 +14,6 @@ from atroposlib.envs.base import (
     Item,
     ScoredDataGroup,
 )
-from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 
 system_prompt = (
     "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
@@ -168,13 +167,17 @@ class MCQAThinkingEnv(BaseEnv):
             messages, add_generation_prompt=True, tokenize=False
         )
 
-        # Get completions from the model using completion() instead of chat_completion()
-        completions = await self.server.completion(
-            prompt=prompt,
-            n=self.config.group_size,
-            max_tokens=1024 * 15,
-            temperature=1.0,  # Using temperature to get diverse responses
-        )
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Get completions from the model using completion() instead of chat_completion()
+            completions = await managed.completion(
+                prompt=prompt,
+                n=self.config.group_size,
+                max_tokens=1024 * 15,
+                temperature=1.0,  # Using temperature to get diverse responses
+            )
+
+            state = managed.get_state()
+            nodes = state["nodes"]
 
         to_score = list()
 
@@ -191,12 +194,15 @@ class MCQAThinkingEnv(BaseEnv):
 
             # Add to scoring queue with expected answer, ground truth text, and stop reason
             to_score.append(
-                (
-                    tuple(trajectory_messages),
-                    item[1],  # Letter (A, B, C, D)
-                    item[2],  # Include the answer_string/ground_truth_text
-                    completion_choice.finish_reason,  # Add the stop reason
-                )
+                {
+                    "messages": tuple(trajectory_messages),
+                    "expected_answer": item[1],  # Letter (A, B, C, D)
+                    "ground_truth_text": item[2],  # Include the answer_string/ground_truth_text
+                    "finish_reason": completion_choice.finish_reason,  # Add the stop reason
+                    "tokens": nodes[i].tokens,
+                    "masks": nodes[i].masked_tokens,
+                    "logprobs": nodes[i].logprobs,
+                }
             )
 
         # Call score to get the scored data
@@ -313,18 +319,19 @@ class MCQAThinkingEnv(BaseEnv):
         scores["tokens"] = list()
         scores["masks"] = list()
         scores["scores"] = list()
+        scores["inference_logprobs"] = list()
 
         # Get the expected answer letter
-        expected_answer = rollout_group_data[0][1]  # Letter A, B, C, D
-        ground_truth_text = rollout_group_data[0][2]
+        expected_answer = rollout_group_data[0]["expected_answer"]  # Letter A, B, C, D
+        ground_truth_text = rollout_group_data[0]["ground_truth_text"]
 
         # Shuffle to avoid bias in selection
         random.shuffle(rollout_group_data)
 
         for item in rollout_group_data:
             # Extract the model's response
-            model_response = item[0][-1]["content"]
-            stop_reason = item[3]  # Get the stop reason
+            model_response = item["messages"][-1]["content"]
+            stop_reason = item["finish_reason"]  # Get the stop reason
 
             # If the response was cut off due to length, give it a score of 0
             if stop_reason == "length":
@@ -343,10 +350,9 @@ class MCQAThinkingEnv(BaseEnv):
             else:
                 reward = 0  # Wrong answer gets 0 reward
 
-            # Tokenize the conversation for learning
-            out_dict = tokenize_for_trainer(self.tokenizer, item[0])
-            tokens = out_dict["tokens"]
-            masks = out_dict["masks"]
+            tokens = item["tokens"]
+            masks = item["masks"]
+            logprobs = item["logprobs"]
 
             # Remove examples with insufficient context
             if len([1 for i in masks if i != -100]) < 10:
@@ -354,6 +360,7 @@ class MCQAThinkingEnv(BaseEnv):
 
             scores["tokens"].append(tokens)
             scores["masks"].append(masks)
+            scores["inference_logprobs"].append(logprobs)
             scores["scores"].append(1.0 if reward else -1.0)
 
             # Break once we have enough examples
@@ -448,8 +455,8 @@ class MCQAThinkingEnv(BaseEnv):
                 (
                     self.tokenizer.decode(scored_data["tokens"][i]),
                     scored_data["scores"][i],
-                    item[1],
-                    item[2],
+                    item[1] if isinstance(item, tuple) else item["expected_answer"],
+                    item[2] if isinstance(item, tuple) else item["ground_truth_text"],
                 )
                 for i in range(num_keep)
             ]
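
Below is a minimal sketch of the data-shape change this diff makes: to_score entries go from positional tuples to dicts that carry the inference-time tokens, masks, and logprobs alongside the existing fields, and score() forwards those logprobs as "inference_logprobs". The names Node, build_payload, collect_scores, and grade are illustrative only, and the three node fields (tokens, masked_tokens, logprobs) are read off this diff rather than a documented atroposlib API.

from dataclasses import dataclass


@dataclass
class Node:
    """Stand-in for the per-completion node that managed_server.get_state()["nodes"]
    yields in this diff; the real atroposlib object is assumed to expose exactly
    these three fields."""

    tokens: list         # token ids for the full prompt + completion
    masked_tokens: list  # parallel to tokens; -100 marks positions excluded from the loss
    logprobs: list       # per-token logprobs recorded at inference time


def build_payload(trajectory_messages, expected_answer, ground_truth_text,
                  finish_reason, node):
    """Shape of the dict appended to to_score in the new code path."""
    return {
        "messages": tuple(trajectory_messages),
        "expected_answer": expected_answer,  # letter (A, B, C, D)
        "ground_truth_text": ground_truth_text,
        "finish_reason": finish_reason,
        "tokens": node.tokens,
        "masks": node.masked_tokens,
        "logprobs": node.logprobs,
    }


def collect_scores(rollout_group_data, grade):
    """Consume the payloads the way the updated score() does. grade is a
    hypothetical callable standing in for the answer-checking logic the
    diff elides between hunks."""
    scores = {"tokens": [], "masks": [], "inference_logprobs": [], "scores": []}
    for item in rollout_group_data:
        if item["finish_reason"] == "length":
            reward = 0  # cut off by the token limit: treated as not correct
        else:
            reward = grade(item["messages"][-1]["content"], item["expected_answer"])
        # Drop examples with fewer than 10 unmasked (trainable) tokens.
        if sum(1 for m in item["masks"] if m != -100) < 10:
            continue
        scores["tokens"].append(item["tokens"])
        scores["masks"].append(item["masks"])
        scores["inference_logprobs"].append(item["logprobs"])
        scores["scores"].append(1.0 if reward else -1.0)
    return scores

Carrying tokens, masks, and logprobs straight out of the managed server, instead of re-tokenizing with tokenize_for_trainer after the fact, keeps the trained-on token ids aligned with what the inference engine actually sampled; the recorded logprobs presumably let the trainer correct for the inference-time sampling distribution. The isinstance(item, tuple) checks in the last hunk keep the wandb-logging path compatible with both the old tuple payloads and the new dicts.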