Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-23 16:54:56 +00:00
convert fundamentals prediction env to use managed server
This commit is contained in:
parent ff46cfff44
commit 4738fabd57
1 changed file with 39 additions and 31 deletions
@@ -14,7 +14,6 @@ from atroposlib.envs.base import (
     Item,
     ScoredDataGroup,
 )
-from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 
 # System prompt only contains thinking instructions
 system_prompt = """You are a deep thinking AI financial analyst.
@@ -174,17 +173,21 @@ class FundamentalPredictionEnv(BaseEnv):
             messages, add_generation_prompt=True, tokenize=False
         )
 
-        # Get completions from the model
-        completions = await self.server.completion(
-            prompt=prompt,
-            n=self.config.group_size,
-            max_tokens=1024 * 15,
-            temperature=0.8,  # Using higher temperature for diverse responses
-        )
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Get completions from the model
+            completions = await managed.completion(
+                prompt=prompt,
+                n=self.config.group_size,
+                max_tokens=1024 * 15,
+                temperature=0.8,  # Using higher temperature for diverse responses
+            )
 
+            state = managed.get_state()
+            nodes = state["nodes"]
+
         to_score = list()
 
-        for _, completion_choice in enumerate(completions.choices):
+        for i, completion_choice in enumerate(completions.choices):
             # Create a copy of the prompt messages
             trajectory_messages = []
             for role_dict in item[0]:
@@ -197,12 +200,15 @@ class FundamentalPredictionEnv(BaseEnv):
 
             # Add to scoring queue with expected answer, magnitude, and fundamental metric
             to_score.append(
-                (
-                    tuple(trajectory_messages),
-                    item[1],  # answer (maintained/raised/reduced)
-                    item[2],  # magnitude
-                    item[3],  # fundamental_metric
-                )
+                {
+                    "messages": tuple(trajectory_messages),
+                    "answer": item[1],  # answer (maintained/raised/reduced)
+                    "magnitude": item[2],  # magnitude
+                    "fundamental_metric": item[3],  # fundamental_metric
+                    "tokens": nodes[i].tokens,
+                    "masks": nodes[i].masked_tokens,
+                    "logprobs": nodes[i].logprobs,
+                }
             )
 
         # Call score to get the scored data
@@ -321,20 +327,21 @@ class FundamentalPredictionEnv(BaseEnv):
         scores["tokens"] = list()
         scores["masks"] = list()
         scores["scores"] = list()
+        scores["inference_logprobs"] = list()
 
         # Get the expected answer, magnitude, and fundamental metric
         expected_answer = rollout_group_data[0][
-            1
+            "answer"
         ]  # "maintained", "raised", or "reduced"
-        expected_magnitude = rollout_group_data[0][2]  # Expected percentage change
-        fundamental_metric = rollout_group_data[0][3]  # Type of fundamental metric
+        expected_magnitude = rollout_group_data[0]["magnitude"]  # Expected percentage change
+        fundamental_metric = rollout_group_data[0]["fundamental_metric"]  # Type of fundamental metric
 
         # Shuffle to avoid bias in selection
         random.shuffle(rollout_group_data)
 
         for item in rollout_group_data:
             # Extract the model's response
-            model_response = item[0][-1]["content"]
+            model_response = item["messages"][-1]["content"]
 
             # Extract the prediction and magnitude from the model's response
             prediction, magnitude = self._extract_prediction(
@@ -364,10 +371,9 @@ class FundamentalPredictionEnv(BaseEnv):
             # For binary reward signal, any positive score gets +1, otherwise -1
             binary_reward = 1.0 if final_score > 0 else -1.0
 
-            # Tokenize the conversation for learning
-            out_dict = tokenize_for_trainer(self.tokenizer, item[0])
-            tokens = out_dict["tokens"]
-            masks = out_dict["masks"]
+            tokens = item["tokens"]
+            masks = item["masks"]
+            logprobs = item["logprobs"]
 
             # Remove examples with insufficient context
             if len([1 for i in masks if i != -100]) < 10:
@@ -375,6 +381,7 @@ class FundamentalPredictionEnv(BaseEnv):
 
             scores["tokens"].append(tokens)
             scores["masks"].append(masks)
+            scores["inference_logprobs"].append(logprobs)
             scores["scores"].append(binary_reward)
 
             # For tracking metrics
@@ -429,14 +436,15 @@ class FundamentalPredictionEnv(BaseEnv):
             messages, add_generation_prompt=True, tokenize=False
        )
 
-        # Get model completion
-        completion = await self.server.completion(
-            prompt=prompt,
-            n=1,
-            max_tokens=1024 * 16,
-            temperature=0.2,  # Lower for eval
-            split="eval",
-        )
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Get model completion
+            completion = await managed.completion(
+                prompt=prompt,
+                n=1,
+                max_tokens=1024 * 16,
+                temperature=0.2,  # Lower for eval
+                split="eval",
+            )
 
         # Extract the model's response
         model_response = completion.choices[0].text
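
Taken together, the hunks above convert the rollout and eval paths from direct self.server.completion() calls plus a post-hoc tokenize_for_trainer() pass to a managed-server pattern in which the server itself records tokens, masks, and inference-time logprobs per completion. Below is a minimal sketch of the resulting rollout flow, assuming the managed_server / get_state API and the node fields (tokens, masked_tokens, logprobs) exactly as they appear in the diff; the surrounding scaffolding (prompt construction, message copying, the score() call) is simplified for illustration and not taken from the repo.

    # Sketch only: managed_server, get_state, and the node fields come from
    # the diff above; the helper structure around them is assumed.
    async def collect_trajectories(self, item):
        prompt = self.tokenizer.apply_chat_template(
            item[0], add_generation_prompt=True, tokenize=False
        )
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            completions = await managed.completion(
                prompt=prompt,
                n=self.config.group_size,
                max_tokens=1024 * 15,
                temperature=0.8,
            )
            # The managed server tracks tokenization per choice, so score()
            # can reuse its tokens/masks/logprobs instead of re-tokenizing.
            nodes = managed.get_state()["nodes"]

        to_score = []
        for i, choice in enumerate(completions.choices):
            trajectory_messages = [dict(m) for m in item[0]]
            trajectory_messages.append({"role": "assistant", "content": choice.text})
            to_score.append(
                {
                    "messages": tuple(trajectory_messages),
                    "answer": item[1],
                    "magnitude": item[2],
                    "fundamental_metric": item[3],
                    "tokens": nodes[i].tokens,
                    "masks": nodes[i].masked_tokens,
                    "logprobs": nodes[i].logprobs,
                }
            )
        return await self.score(to_score)

The removed tokenize_for_trainer import follows directly from this: since each to_score entry now carries the tokens and masks captured at inference time, score() reads them off the dict (and additionally threads the logprobs through as scores["inference_logprobs"]) rather than re-tokenizing the conversation.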