Fixed typo, revising reward function

This commit is contained in:
Shannon Sands 2025-05-10 19:45:06 +10:00
parent 7fe1a40368
commit 840ff20921

View file

@ -159,6 +159,12 @@ class BlackjackEnv(BaseEnv):
f"EnvReward(raw): {env_reward:.4f}, EnvReward(adj for invalid): {current_env_reward:.4f} "
f"==> Final Score (from env): {final_score:.4f}"
)
# Try to get a valid tool call from the response
tool_call = self._parse_tool_call(response_text)
if tool_call == -1:
final_score -= 0.5
else:
final_score += 0.5
return final_score
def _parse_tool_call(self, response: str) -> int:
@ -521,12 +527,12 @@ class BlackjackEnv(BaseEnv):
seed=ep.seed,
tokens=alt_tokens,
masks=alt_masks,
scores=alt_advantages,
scores=alt_combined_rewards,
messages=alt_next_state_msgs,
parsed_actions=alt_parsed_actions,
)
)
# Get the best advantage index to use as the chosen action for the next step
best_advantage = -float("inf")
best_advantage_idx = -1
valid_indices_for_tiebreak = []