#!/usr/bin/env python3
"""
RubiksCubeTokenRewards: Token-level reward utilities for Rubik's Cube environment
This module provides functions for calculating token-level rewards, which are
important for fine-grained RL training signals that help the model understand
which tokens in its response contribute to success or failure.
"""
import re
from typing import List, Optional

def calculate_token_level_rewards(
response_text: str,
is_valid_move: bool,
parsed_move: Optional[str],
reward: float,
token_ids: List[int],
scale_factor: float = 0.1
) -> List[float]:
"""
    Calculate token-level rewards based on response quality.

    Args:
        response_text: Full response text from the LLM
        is_valid_move: Whether the parsed move was valid
        parsed_move: The parsed move, if one was extracted
        reward: The overall scalar reward for the response
        token_ids: Token IDs of the response
        scale_factor: Scale factor applied to all token rewards

    Returns:
        A list of token-level rewards with the same length as token_ids
"""
    if not token_ids:
        return []
    # Initialize with neutral rewards
    token_rewards = [0.0] * len(token_ids)
    # Extract the thinking section; the tag literals below assume a
    # <think>...</think> / <tool_call>...</tool_call> response format
    thinking_match = re.search(r"<think>(.*?)</think>", response_text, re.DOTALL)
    # Find the character offsets of the section delimiters; str.find returns
    # -1 for a missing tag, in which case the range checks below never match
    thinking_start_idx = response_text.find("<think>")
    thinking_end_idx = response_text.find("</think>")
    tool_call_start_idx = response_text.find("<tool_call>")
    tool_call_end_idx = response_text.find("</tool_call>")
    # Approximate character-to-token ratio; the token/character mapping below
    # assumes tokens cover the text roughly uniformly, which is only a heuristic
    chars_per_token = len(response_text) / len(token_ids)
    # Thinking counts as "good" when the section is non-trivial (> 50 chars)
    has_good_thinking = bool(
        thinking_match and len(thinking_match.group(1).strip()) > 50
    )
# Process rewards based on response quality
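    # Four cases, ordered best to worst: valid move with good thinking,
    # valid move with poor thinking, good thinking with an invalid move,
    # and neither.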
if is_valid_move and has_good_thinking:
        # Good response with both thinking and a valid move.
        # Per-token weighting: 0.6x for tool-call tokens, 0.4x for thinking
        # tokens, 0.1x elsewhere.
base_reward = reward * scale_factor
# Distribute rewards
for i in range(len(token_ids)):
# Estimate the character position this token represents
char_pos = int(i * chars_per_token)
if thinking_start_idx <= char_pos <= thinking_end_idx:
# Token is part of thinking section
token_rewards[i] = base_reward * 0.4
elif tool_call_start_idx <= char_pos <= tool_call_end_idx:
# Token is part of tool call section
token_rewards[i] = base_reward * 0.6
else:
# Token is part of other sections
token_rewards[i] = base_reward * 0.1
elif is_valid_move and not has_good_thinking:
# Valid move but poor thinking
base_reward = reward * scale_factor * 0.7 # Reduced overall reward
for i in range(len(token_ids)):
char_pos = int(i * chars_per_token)
if tool_call_start_idx <= char_pos <= tool_call_end_idx:
# Token is part of tool call section - still good
token_rewards[i] = base_reward * 0.8
else:
# Token is part of other sections - minimal reward
token_rewards[i] = base_reward * 0.2
elif not is_valid_move and has_good_thinking:
# Good thinking but invalid move
base_reward = reward * scale_factor * 0.5 # Significantly reduced
for i in range(len(token_ids)):
char_pos = int(i * chars_per_token)
if thinking_start_idx <= char_pos <= thinking_end_idx:
# Token is part of thinking section - somewhat good
token_rewards[i] = base_reward * 0.6
elif tool_call_start_idx <= char_pos <= tool_call_end_idx:
# Token is part of tool call section - problematic
token_rewards[i] = base_reward * 0.1
else:
# Token is part of other sections
token_rewards[i] = base_reward * 0.3
else:
# Poor response overall
base_reward = reward * scale_factor * 0.3 # Minimal reward
        # Distribute the minimal reward uniformly
        token_rewards = [base_reward] * len(token_ids)
# Special handling for move-related tokens when there is a valid move
if is_valid_move and parsed_move:
        # Locate occurrences of the move string itself (a plain substring
        # match, so a short move like "R" can also match inside unrelated
        # text)
        move_pattern = re.escape(parsed_move)
        move_matches = list(re.finditer(move_pattern, response_text))
for match in move_matches:
move_start_idx = match.start()
move_end_idx = match.end()
# Estimate corresponding token indices
move_start_token = int(move_start_idx / chars_per_token)
move_end_token = int(move_end_idx / chars_per_token) + 1
# Ensure indices are within bounds
move_start_token = max(0, min(move_start_token, len(token_ids) - 1))
move_end_token = max(0, min(move_end_token, len(token_ids)))
# Boost rewards for tokens that directly encode the move
for i in range(move_start_token, move_end_token):
token_rewards[i] = base_reward * 1.5 # Higher reward for the actual move
return token_rewards


def calculate_advantage_token_weights(token_rewards: List[List[float]]) -> List[List[float]]:
"""
    Calculate per-token weights for advantage computation.

    Args:
        token_rewards: Token-level rewards for each alternative response

    Returns:
        Token weights for each alternative, normalized into [0.5, 1.0]
        (uniform 1.0 when all rewards within an alternative are equal)
"""
# Create a copy to avoid modifying the input
token_weights = [rewards.copy() for rewards in token_rewards]
    # Normalize each alternative independently
    for weights in token_weights:
        if not weights:
            continue
        min_reward, max_reward = min(weights), max(weights)
        reward_range = max_reward - min_reward
        if reward_range > 0:
            # Normalize into [0.5, 1.0] so every token keeps some weight
            for j, r in enumerate(weights):
                weights[j] = 0.5 + 0.5 * (r - min_reward) / reward_range
        else:
            # All rewards equal: fall back to uniform weights
            for j in range(len(weights)):
                weights[j] = 1.0
return token_weights
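

if __name__ == "__main__":
    # Minimal usage sketch. The response text and token IDs below are
    # hypothetical placeholders; in a real pipeline the token IDs would come
    # from the tokenizer that produced the response.
    sample_response = (
        "<think>The white cross is done, so I rotate the right face to set "
        "up the corner piece without disturbing the solved layer.</think>\n"
        '<tool_call>{"move": "R"}</tool_call>'
    )
    fake_token_ids = list(range(40))  # placeholder IDs, one per pseudo-token

    rewards = calculate_token_level_rewards(
        response_text=sample_response,
        is_valid_move=True,
        parsed_move="R",
        reward=1.0,
        token_ids=fake_token_ids,
    )
    print(f"token rewards: min={min(rewards):.4f}, max={max(rewards):.4f}")

    # Weights for a batch of two alternatives (the second is a flat baseline)
    weights = calculate_advantage_token_weights([rewards, [0.0] * 40])
    print(f"weights (alt 0): min={min(weights[0]):.2f}, max={max(weights[0]):.2f}")
    print(f"weights (alt 1, uniform): {weights[1][0]:.2f}")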