convert fundamentals prediction env to use managed server
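Replaces the direct self.server.completion(...) calls in both the rollout and eval paths with an async with self.server.managed_server(tokenizer=self.tokenizer) context, passes each rollout's tokens, masks, and inference logprobs through scoring as dict fields instead of positional tuple entries, and drops the now-unused tokenize_for_trainer import.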

teknium 2025-11-14 09:48:56 +00:00
parent ff46cfff44
commit 4738fabd57

@@ -14,7 +14,6 @@ from atroposlib.envs.base import (
     Item,
     ScoredDataGroup,
 )
-from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 
 # System prompt only contains thinking instructions
 system_prompt = """You are a deep thinking AI financial analyst.
@@ -174,17 +173,21 @@ class FundamentalPredictionEnv(BaseEnv):
             messages, add_generation_prompt=True, tokenize=False
         )
 
-        # Get completions from the model
-        completions = await self.server.completion(
-            prompt=prompt,
-            n=self.config.group_size,
-            max_tokens=1024 * 15,
-            temperature=0.8,  # Using higher temperature for diverse responses
-        )
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Get completions from the model
+            completions = await managed.completion(
+                prompt=prompt,
+                n=self.config.group_size,
+                max_tokens=1024 * 15,
+                temperature=0.8,  # Using higher temperature for diverse responses
+            )
+            state = managed.get_state()
+            nodes = state["nodes"]
 
         to_score = list()
-        for _, completion_choice in enumerate(completions.choices):
+        for i, completion_choice in enumerate(completions.choices):
             # Create a copy of the prompt messages
             trajectory_messages = []
             for role_dict in item[0]:
@@ -197,12 +200,15 @@ class FundamentalPredictionEnv(BaseEnv):
             # Add to scoring queue with expected answer, magnitude, and fundamental metric
             to_score.append(
-                (
-                    tuple(trajectory_messages),
-                    item[1],  # answer (maintained/raised/reduced)
-                    item[2],  # magnitude
-                    item[3],  # fundamental_metric
-                )
+                {
+                    "messages": tuple(trajectory_messages),
+                    "answer": item[1],  # answer (maintained/raised/reduced)
+                    "magnitude": item[2],  # magnitude
+                    "fundamental_metric": item[3],  # fundamental_metric
+                    "tokens": nodes[i].tokens,
+                    "masks": nodes[i].masked_tokens,
+                    "logprobs": nodes[i].logprobs,
+                }
             )
 
         # Call score to get the scored data
@@ -321,20 +327,21 @@ class FundamentalPredictionEnv(BaseEnv):
         scores["tokens"] = list()
         scores["masks"] = list()
         scores["scores"] = list()
+        scores["inference_logprobs"] = list()
 
         # Get the expected answer, magnitude, and fundamental metric
         expected_answer = rollout_group_data[0][
-            1
+            "answer"
         ]  # "maintained", "raised", or "reduced"
-        expected_magnitude = rollout_group_data[0][2]  # Expected percentage change
-        fundamental_metric = rollout_group_data[0][3]  # Type of fundamental metric
+        expected_magnitude = rollout_group_data[0]["magnitude"]  # Expected percentage change
+        fundamental_metric = rollout_group_data[0]["fundamental_metric"]  # Type of fundamental metric
 
         # Shuffle to avoid bias in selection
         random.shuffle(rollout_group_data)
 
         for item in rollout_group_data:
             # Extract the model's response
-            model_response = item[0][-1]["content"]
+            model_response = item["messages"][-1]["content"]
 
             # Extract the prediction and magnitude from the model's response
             prediction, magnitude = self._extract_prediction(
@@ -364,10 +371,9 @@ class FundamentalPredictionEnv(BaseEnv):
             # For binary reward signal, any positive score gets +1, otherwise -1
             binary_reward = 1.0 if final_score > 0 else -1.0
 
-            # Tokenize the conversation for learning
-            out_dict = tokenize_for_trainer(self.tokenizer, item[0])
-            tokens = out_dict["tokens"]
-            masks = out_dict["masks"]
+            tokens = item["tokens"]
+            masks = item["masks"]
+            logprobs = item["logprobs"]
 
             # Remove examples with insufficient context
             if len([1 for i in masks if i != -100]) < 10:
@@ -375,6 +381,7 @@ class FundamentalPredictionEnv(BaseEnv):
 
             scores["tokens"].append(tokens)
             scores["masks"].append(masks)
+            scores["inference_logprobs"].append(logprobs)
             scores["scores"].append(binary_reward)
 
             # For tracking metrics
@@ -429,14 +436,15 @@ class FundamentalPredictionEnv(BaseEnv):
             messages, add_generation_prompt=True, tokenize=False
         )
 
-        # Get model completion
-        completion = await self.server.completion(
-            prompt=prompt,
-            n=1,
-            max_tokens=1024 * 16,
-            temperature=0.2,  # Lower for eval
-            split="eval",
-        )
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Get model completion
+            completion = await managed.completion(
+                prompt=prompt,
+                n=1,
+                max_tokens=1024 * 16,
+                temperature=0.2,  # Lower for eval
+                split="eval",
+            )
 
         # Extract the model's response
         model_response = completion.choices[0].text
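
For reference, a minimal sketch of the converted rollout flow in one piece. This is not the repository's code: it assumes only the managed-server surface visible in this diff (managed.completion(...) mirroring server.completion(...), and managed.get_state() returning {"nodes": [...]} with per-choice .tokens, .masked_tokens, and .logprobs); the helper name collect_rollouts is illustrative.

# A minimal sketch of the converted rollout flow; names outside this diff
# (e.g. collect_rollouts) are illustrative, not part of atroposlib.
async def collect_rollouts(self, item, messages):
    prompt = self.tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completions = await managed.completion(
            prompt=prompt,
            n=self.config.group_size,
            max_tokens=1024 * 15,
            temperature=0.8,
        )
        # Per-choice token/mask/logprob state is captured at inference time,
        # which is why the tokenize_for_trainer pass is no longer needed.
        nodes = managed.get_state()["nodes"]
    to_score = []
    for i, choice in enumerate(completions.choices):
        # Copy the prompt messages and append this choice as the assistant turn
        trajectory_messages = [dict(m) for m in item[0]]
        trajectory_messages.append({"role": "assistant", "content": choice.text})
        to_score.append(
            {
                "messages": tuple(trajectory_messages),
                "answer": item[1],
                "magnitude": item[2],
                "fundamental_metric": item[3],
                "tokens": nodes[i].tokens,
                "masks": nodes[i].masked_tokens,
                "logprobs": nodes[i].logprobs,
            }
        )
    return to_score

Carrying tokens, masks, and logprobs out of the managed server lets score() forward inference-time logprobs to the trainer through the new scores["inference_logprobs"] list, instead of re-tokenizing each conversation with tokenize_for_trainer.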