convert fundamentals prediction env to use managed server

teknium 2025-11-14 09:48:56 +00:00
parent ff46cfff44
commit 4738fabd57


@@ -14,7 +14,6 @@ from atroposlib.envs.base import (
     Item,
     ScoredDataGroup,
 )
-from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

 # System prompt only contains thinking instructions
 system_prompt = """You are a deep thinking AI financial analyst.
@@ -174,17 +173,21 @@ class FundamentalPredictionEnv(BaseEnv):
             messages, add_generation_prompt=True, tokenize=False
         )

-        # Get completions from the model
-        completions = await self.server.completion(
-            prompt=prompt,
-            n=self.config.group_size,
-            max_tokens=1024 * 15,
-            temperature=0.8,  # Using higher temperature for diverse responses
-        )
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Get completions from the model
+            completions = await managed.completion(
+                prompt=prompt,
+                n=self.config.group_size,
+                max_tokens=1024 * 15,
+                temperature=0.8,  # Using higher temperature for diverse responses
+            )
+            state = managed.get_state()
+            nodes = state["nodes"]

         to_score = list()
-        for _, completion_choice in enumerate(completions.choices):
+        for i, completion_choice in enumerate(completions.choices):
             # Create a copy of the prompt messages
             trajectory_messages = []
             for role_dict in item[0]:
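The pattern this hunk introduces: the env opens a managed session around generation, and after the call managed.get_state() exposes one node per completion choice carrying the token ids, masks, and logprobs recorded at inference time. Below is a minimal runnable sketch of that contract; the method names (managed_server, completion, get_state) and the node fields come from this diff, while the stub classes are assumptions for illustration, not atroposlib's actual implementation.

import asyncio
from contextlib import asynccontextmanager
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class _Node:
    tokens: list          # prompt + completion token ids
    masked_tokens: list   # same length, prompt positions set to -100
    logprobs: list        # per-token logprobs from inference


class _ManagedHandle:
    """Stand-in for the handle yielded by managed_server()."""

    def __init__(self):
        self._nodes = []

    async def completion(self, prompt, n=1, max_tokens=128, temperature=1.0):
        # Fake generation: record one node per choice so get_state() can
        # return token/mask/logprob data aligned with completions.choices.
        choices = []
        for i in range(n):
            self._nodes.append(
                _Node(tokens=[101, 7, 8], masked_tokens=[-100, 7, 8], logprobs=[0.0, -0.3, -0.7])
            )
            choices.append(SimpleNamespace(text=f"choice {i}"))
        return SimpleNamespace(choices=choices)

    def get_state(self):
        return {"nodes": self._nodes}


class _StubServer:
    @asynccontextmanager
    async def managed_server(self, tokenizer=None):
        handle = _ManagedHandle()
        try:
            yield handle
        finally:
            pass  # the real implementation would release the session here


async def main():
    server = _StubServer()
    async with server.managed_server(tokenizer=None) as managed:
        completions = await managed.completion(prompt="...", n=2, max_tokens=32, temperature=0.8)
        nodes = managed.get_state()["nodes"]
    # nodes[i] lines up with completions.choices[i], which the loop above relies on
    for i, choice in enumerate(completions.choices):
        print(choice.text, nodes[i].tokens, nodes[i].masked_tokens, nodes[i].logprobs)


asyncio.run(main())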
@@ -197,12 +200,15 @@ class FundamentalPredictionEnv(BaseEnv):
             # Add to scoring queue with expected answer, magnitude, and fundamental metric
             to_score.append(
-                (
-                    tuple(trajectory_messages),
-                    item[1],  # answer (maintained/raised/reduced)
-                    item[2],  # magnitude
-                    item[3],  # fundamental_metric
-                )
+                {
+                    "messages": tuple(trajectory_messages),
+                    "answer": item[1],  # answer (maintained/raised/reduced)
+                    "magnitude": item[2],  # magnitude
+                    "fundamental_metric": item[3],  # fundamental_metric
+                    "tokens": nodes[i].tokens,
+                    "masks": nodes[i].masked_tokens,
+                    "logprobs": nodes[i].logprobs,
+                }
             )

         # Call score to get the scored data
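For reference, the shape of one to_score record after this change, written as a TypedDict purely for illustration (the env itself builds a plain dict). Field meanings follow the inline comments above; the -100 mask convention comes from the filter in score() further down.

from typing import TypedDict


class RolloutRecord(TypedDict):
    messages: tuple             # full trajectory: prompt messages + model response
    answer: str                 # expected label: "maintained", "raised", or "reduced"
    magnitude: float            # expected percentage change
    fundamental_metric: str     # which fundamental metric the prediction targets
    tokens: list                # token ids captured by the managed server
    masks: list                 # token ids with prompt positions masked to -100
    logprobs: list              # inference-time logprobs, parallel to tokens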
@@ -321,20 +327,21 @@ class FundamentalPredictionEnv(BaseEnv):
         scores["tokens"] = list()
         scores["masks"] = list()
         scores["scores"] = list()
+        scores["inference_logprobs"] = list()

         # Get the expected answer, magnitude, and fundamental metric
         expected_answer = rollout_group_data[0][
-            1
+            "answer"
         ]  # "maintained", "raised", or "reduced"
-        expected_magnitude = rollout_group_data[0][2]  # Expected percentage change
-        fundamental_metric = rollout_group_data[0][3]  # Type of fundamental metric
+        expected_magnitude = rollout_group_data[0]["magnitude"]  # Expected percentage change
+        fundamental_metric = rollout_group_data[0]["fundamental_metric"]  # Type of fundamental metric

         # Shuffle to avoid bias in selection
         random.shuffle(rollout_group_data)

         for item in rollout_group_data:
             # Extract the model's response
-            model_response = item[0][-1]["content"]
+            model_response = item["messages"][-1]["content"]

             # Extract the prediction and magnitude from the model's response
             prediction, magnitude = self._extract_prediction(
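The scored group is a set of parallel lists with one entry per kept completion; inference_logprobs is the bucket this commit adds. A sketch of the layout, hedged in that only the keys visible in the diff are certain:

# Parallel buckets built by score(); index j describes the j-th kept completion.
scores = {
    "tokens": [],              # list of token-id lists
    "masks": [],               # list of mask lists; -100 marks untrained positions
    "inference_logprobs": [],  # new: logprobs captured at generation time
    "scores": [],              # binary rewards, each 1.0 or -1.0
}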
@@ -364,10 +371,9 @@ class FundamentalPredictionEnv(BaseEnv):
             # For binary reward signal, any positive score gets +1, otherwise -1
             binary_reward = 1.0 if final_score > 0 else -1.0

-            # Tokenize the conversation for learning
-            out_dict = tokenize_for_trainer(self.tokenizer, item[0])
-            tokens = out_dict["tokens"]
-            masks = out_dict["masks"]
+            tokens = item["tokens"]
+            masks = item["masks"]
+            logprobs = item["logprobs"]

             # Remove examples with insufficient context
             if len([1 for i in masks if i != -100]) < 10:
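The insufficient-context filter carried over from the old code now runs on masks pulled straight off the record: an example is dropped unless at least 10 positions are unmasked. In isolation:

# Keep an example only if it has at least 10 trainable (non -100) positions.
masks = [-100, -100, -100, 11, 42, 7]
trainable_positions = len([1 for m in masks if m != -100])
print(trainable_positions)        # 3
print(trainable_positions >= 10)  # False -> this example would be skipped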
@@ -375,6 +381,7 @@ class FundamentalPredictionEnv(BaseEnv):
             scores["tokens"].append(tokens)
             scores["masks"].append(masks)
+            scores["inference_logprobs"].append(logprobs)
             scores["scores"].append(binary_reward)

             # For tracking metrics
@@ -429,8 +436,9 @@ class FundamentalPredictionEnv(BaseEnv):
             messages, add_generation_prompt=True, tokenize=False
         )

-        # Get model completion
-        completion = await self.server.completion(
-            prompt=prompt,
-            n=1,
-            max_tokens=1024 * 16,
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Get model completion
+            completion = await managed.completion(
+                prompt=prompt,
+                n=1,
+                max_tokens=1024 * 16,
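Both call sites wrap generation in async with, so the managed session is torn down even when a completion call fails. A small self-contained illustration of that guarantee, using a stand-in context manager rather than the real server:

import asyncio
from contextlib import asynccontextmanager


@asynccontextmanager
async def managed_session():
    print("session opened")      # e.g. attach to the inference server
    try:
        yield "handle"
    finally:
        print("session closed")  # runs even if the body raises


async def main():
    try:
        async with managed_session() as handle:
            raise RuntimeError("completion failed")
    except RuntimeError:
        pass                     # the finally block already closed the session


asyncio.run(main())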