diff --git a/environments/fundamental_prediction_environment.py b/environments/fundamental_prediction_environment.py
index 7eda4110..14e42880 100644
--- a/environments/fundamental_prediction_environment.py
+++ b/environments/fundamental_prediction_environment.py
@@ -14,7 +14,6 @@ from atroposlib.envs.base import (
     Item,
     ScoredDataGroup,
 )
-from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 
 # System prompt only contains thinking instructions
 system_prompt = """You are a deep thinking AI financial analyst.
@@ -174,17 +173,21 @@ class FundamentalPredictionEnv(BaseEnv):
             messages, add_generation_prompt=True, tokenize=False
         )
 
-        # Get completions from the model
-        completions = await self.server.completion(
-            prompt=prompt,
-            n=self.config.group_size,
-            max_tokens=1024 * 15,
-            temperature=0.8,  # Using higher temperature for diverse responses
-        )
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Get completions from the model
+            completions = await managed.completion(
+                prompt=prompt,
+                n=self.config.group_size,
+                max_tokens=1024 * 15,
+                temperature=0.8,  # Using higher temperature for diverse responses
+            )
+
+            state = managed.get_state()
+            nodes = state["nodes"]
 
         to_score = list()
 
-        for _, completion_choice in enumerate(completions.choices):
+        for i, completion_choice in enumerate(completions.choices):
             # Create a copy of the prompt messages
             trajectory_messages = []
             for role_dict in item[0]:
@@ -197,12 +200,15 @@
 
             # Add to scoring queue with expected answer, magnitude, and fundamental metric
             to_score.append(
-                (
-                    tuple(trajectory_messages),
-                    item[1],  # answer (maintained/raised/reduced)
-                    item[2],  # magnitude
-                    item[3],  # fundamental_metric
-                )
+                {
+                    "messages": tuple(trajectory_messages),
+                    "answer": item[1],  # answer (maintained/raised/reduced)
+                    "magnitude": item[2],  # magnitude
+                    "fundamental_metric": item[3],  # fundamental_metric
+                    "tokens": nodes[i].tokens,
+                    "masks": nodes[i].masked_tokens,
+                    "logprobs": nodes[i].logprobs,
+                }
             )
 
         # Call score to get the scored data
@@ -321,20 +327,21 @@ class FundamentalPredictionEnv(BaseEnv):
         scores["tokens"] = list()
         scores["masks"] = list()
         scores["scores"] = list()
+        scores["inference_logprobs"] = list()
 
         # Get the expected answer, magnitude, and fundamental metric
         expected_answer = rollout_group_data[0][
-            1
+            "answer"
         ]  # "maintained", "raised", or "reduced"
-        expected_magnitude = rollout_group_data[0][2]  # Expected percentage change
-        fundamental_metric = rollout_group_data[0][3]  # Type of fundamental metric
+        expected_magnitude = rollout_group_data[0]["magnitude"]  # Expected percentage change
+        fundamental_metric = rollout_group_data[0]["fundamental_metric"]  # Type of fundamental metric
 
         # Shuffle to avoid bias in selection
         random.shuffle(rollout_group_data)
 
         for item in rollout_group_data:
             # Extract the model's response
-            model_response = item[0][-1]["content"]
+            model_response = item["messages"][-1]["content"]
 
             # Extract the prediction and magnitude from the model's response
             prediction, magnitude = self._extract_prediction(
@@ -364,10 +371,9 @@
             # For binary reward signal, any positive score gets +1, otherwise -1
             binary_reward = 1.0 if final_score > 0 else -1.0
 
-            # Tokenize the conversation for learning
-            out_dict = tokenize_for_trainer(self.tokenizer, item[0])
-            tokens = out_dict["tokens"]
-            masks = out_dict["masks"]
+            tokens = item["tokens"]
+            masks = item["masks"]
+            logprobs = item["logprobs"]
 
             # Remove examples with insufficient context
             if len([1 for i in masks if i != -100]) < 10:
@@ -375,6 +381,7 @@
 
             scores["tokens"].append(tokens)
             scores["masks"].append(masks)
+            scores["inference_logprobs"].append(logprobs)
             scores["scores"].append(binary_reward)
 
             # For tracking metrics
@@ -429,14 +436,15 @@ class FundamentalPredictionEnv(BaseEnv):
             messages, add_generation_prompt=True, tokenize=False
         )
 
-        # Get model completion
-        completion = await self.server.completion(
-            prompt=prompt,
-            n=1,
-            max_tokens=1024 * 16,
-            temperature=0.2,  # Lower for eval
-            split="eval",
-        )
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Get model completion
+            completion = await managed.completion(
+                prompt=prompt,
+                n=1,
+                max_tokens=1024 * 16,
+                temperature=0.2,  # Lower for eval
+                split="eval",
+            )
 
         # Extract the model's response
         model_response = completion.choices[0].text