mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
convert fundamentals prediction env to use managed server
This commit is contained in:
parent
ff46cfff44
commit
4738fabd57
1 changed files with 39 additions and 31 deletions
|
|
@ -14,7 +14,6 @@ from atroposlib.envs.base import (
|
|||
Item,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
# System prompt only contains thinking instructions
|
||||
system_prompt = """You are a deep thinking AI financial analyst.
|
||||
|
|
@ -174,17 +173,21 @@ class FundamentalPredictionEnv(BaseEnv):
|
|||
messages, add_generation_prompt=True, tokenize=False
|
||||
)
|
||||
|
||||
# Get completions from the model
|
||||
completions = await self.server.completion(
|
||||
prompt=prompt,
|
||||
n=self.config.group_size,
|
||||
max_tokens=1024 * 15,
|
||||
temperature=0.8, # Using higher temperature for diverse responses
|
||||
)
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
# Get completions from the model
|
||||
completions = await managed.completion(
|
||||
prompt=prompt,
|
||||
n=self.config.group_size,
|
||||
max_tokens=1024 * 15,
|
||||
temperature=0.8, # Using higher temperature for diverse responses
|
||||
)
|
||||
|
||||
state = managed.get_state()
|
||||
nodes = state["nodes"]
|
||||
|
||||
to_score = list()
|
||||
|
||||
for _, completion_choice in enumerate(completions.choices):
|
||||
for i, completion_choice in enumerate(completions.choices):
|
||||
# Create a copy of the prompt messages
|
||||
trajectory_messages = []
|
||||
for role_dict in item[0]:
|
||||
|
|
@ -197,12 +200,15 @@ class FundamentalPredictionEnv(BaseEnv):
|
|||
|
||||
# Add to scoring queue with expected answer, magnitude, and fundamental metric
|
||||
to_score.append(
|
||||
(
|
||||
tuple(trajectory_messages),
|
||||
item[1], # answer (maintained/raised/reduced)
|
||||
item[2], # magnitude
|
||||
item[3], # fundamental_metric
|
||||
)
|
||||
{
|
||||
"messages": tuple(trajectory_messages),
|
||||
"answer": item[1], # answer (maintained/raised/reduced)
|
||||
"magnitude": item[2], # magnitude
|
||||
"fundamental_metric": item[3], # fundamental_metric
|
||||
"tokens": nodes[i].tokens,
|
||||
"masks": nodes[i].masked_tokens,
|
||||
"logprobs": nodes[i].logprobs,
|
||||
}
|
||||
)
|
||||
|
||||
# Call score to get the scored data
|
||||
|
|
@ -321,20 +327,21 @@ class FundamentalPredictionEnv(BaseEnv):
|
|||
scores["tokens"] = list()
|
||||
scores["masks"] = list()
|
||||
scores["scores"] = list()
|
||||
scores["inference_logprobs"] = list()
|
||||
|
||||
# Get the expected answer, magnitude, and fundamental metric
|
||||
expected_answer = rollout_group_data[0][
|
||||
1
|
||||
"answer"
|
||||
] # "maintained", "raised", or "reduced"
|
||||
expected_magnitude = rollout_group_data[0][2] # Expected percentage change
|
||||
fundamental_metric = rollout_group_data[0][3] # Type of fundamental metric
|
||||
expected_magnitude = rollout_group_data[0]["magnitude"] # Expected percentage change
|
||||
fundamental_metric = rollout_group_data[0]["fundamental_metric"] # Type of fundamental metric
|
||||
|
||||
# Shuffle to avoid bias in selection
|
||||
random.shuffle(rollout_group_data)
|
||||
|
||||
for item in rollout_group_data:
|
||||
# Extract the model's response
|
||||
model_response = item[0][-1]["content"]
|
||||
model_response = item["messages"][-1]["content"]
|
||||
|
||||
# Extract the prediction and magnitude from the model's response
|
||||
prediction, magnitude = self._extract_prediction(
|
||||
|
|
@ -364,10 +371,9 @@ class FundamentalPredictionEnv(BaseEnv):
|
|||
# For binary reward signal, any positive score gets +1, otherwise -1
|
||||
binary_reward = 1.0 if final_score > 0 else -1.0
|
||||
|
||||
# Tokenize the conversation for learning
|
||||
out_dict = tokenize_for_trainer(self.tokenizer, item[0])
|
||||
tokens = out_dict["tokens"]
|
||||
masks = out_dict["masks"]
|
||||
tokens = item["tokens"]
|
||||
masks = item["masks"]
|
||||
logprobs = item["logprobs"]
|
||||
|
||||
# Remove examples with insufficient context
|
||||
if len([1 for i in masks if i != -100]) < 10:
|
||||
|
|
@ -375,6 +381,7 @@ class FundamentalPredictionEnv(BaseEnv):
|
|||
|
||||
scores["tokens"].append(tokens)
|
||||
scores["masks"].append(masks)
|
||||
scores["inference_logprobs"].append(logprobs)
|
||||
scores["scores"].append(binary_reward)
|
||||
|
||||
# For tracking metrics
|
||||
|
|
@ -429,14 +436,15 @@ class FundamentalPredictionEnv(BaseEnv):
|
|||
messages, add_generation_prompt=True, tokenize=False
|
||||
)
|
||||
|
||||
# Get model completion
|
||||
completion = await self.server.completion(
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
max_tokens=1024 * 16,
|
||||
temperature=0.2, # Lower for eval
|
||||
split="eval",
|
||||
)
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
# Get model completion
|
||||
completion = await managed.completion(
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
max_tokens=1024 * 16,
|
||||
temperature=0.2, # Lower for eval
|
||||
split="eval",
|
||||
)
|
||||
|
||||
# Extract the model's response
|
||||
model_response = completion.choices[0].text
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue