add jsonl file

This commit is contained in:
Joshua Jerin 2025-05-18 15:27:37 -07:00
parent e018440d66
commit 7e1de80695
3 changed files with 98 additions and 21 deletions

View file

@ -474,10 +474,12 @@ class RubiksCubeEnv(BaseEnv):
)
return None
# First try parsing with tool_call tags
parsed_name, parsed_args, is_error = parse_tool_call(
response, self.tools, ["tool_call"]
)
# If that fails, try looking for direct text mentions of moves
if is_error:
error_detail = (
parsed_name
@ -487,21 +489,59 @@ class RubiksCubeEnv(BaseEnv):
logger.warning(
f"Failed to parse tool call. Full response: '{response}'. Error detail: {error_detail}"
)
# Fallback: Look for direct mentions of moves in the text
valid_moves = ["U", "D", "L", "R", "F", "B",
"U'", "D'", "L'", "R'", "F'", "B'",
"U2", "D2", "L2", "R2", "F2", "B2"]
# Look for patterns like "I'll apply move X" or "Performing X rotation"
move_patterns = [
r'move\s+([UDLRFB][\'2]?)',
r'applying\s+([UDLRFB][\'2]?)',
r'perform\w*\s+([UDLRFB][\'2]?)',
r'rotate\s+([UDLRFB][\'2]?)',
r'rotation\s+([UDLRFB][\'2]?)',
r'I\s*choose\s+([UDLRFB][\'2]?)',
r'Execute\s+([UDLRFB][\'2]?)',
r'([UDLRFB][\'2]?)\s+rotation',
r'([UDLRFB][\'2]?)\s+move'
]
for pattern in move_patterns:
match = re.search(pattern, response, re.IGNORECASE)
if match:
potential_move = match.group(1).strip()
if potential_move in valid_moves:
logger.warning(f"Recovered move '{potential_move}' from text using regex")
return potential_move
return None
move = parsed_args.get("move", "").strip()
move = parsed_args.get("move", "").strip() if isinstance(parsed_args, dict) else ""
valid_moves = ["U", "D", "L", "R", "F", "B",
"U'", "D'", "L'", "R'", "F'", "B'",
"U2", "D2", "L2", "R2", "F2", "B2"]
# First check if the move is directly valid
if move in valid_moves:
return move
else:
logger.warning(
f"Parsed invalid move: '{move}'. "
f"Full response: '{response}'. Parsed args: {parsed_args}"
)
return None
# Check if the move is a sequence containing valid moves
# (when LLM outputs move sequences like "R U R'")
if " " in move:
# Take only the first move in the sequence
first_move = move.split()[0].strip()
if first_move in valid_moves:
logger.warning(f"Got move sequence '{move}' but taking only first move '{first_move}'")
return first_move
# If we get here, the move is invalid
logger.warning(
f"Parsed invalid move: '{move}'. "
f"Full response: '{response}'. Parsed args: {parsed_args}"
)
return None
def _score_response(
self,
@ -531,15 +571,29 @@ class RubiksCubeEnv(BaseEnv):
current_score += 1.0
# Check for thinking tags
match = re.search(r"<think>(.*?)</think>", response_text, re.DOTALL)
if match:
thinking_content = match.group(1)
if thinking_content.strip(): # Not empty
current_score += 0.2
else: # Empty thinking tags
current_score -= 0.1
else: # No thinking tags
current_score -= 0.2
try:
# Make sure response_text is a string
if not isinstance(response_text, str):
logger.warning(f"response_text is not a string: {type(response_text)}")
response_text = str(response_text) if response_text is not None else ""
match = re.search(r"<think>(.*?)</think>", response_text, re.DOTALL)
if match:
thinking_content = match.group(1)
# Make sure thinking_content is a string before calling strip()
if isinstance(thinking_content, str):
if thinking_content.strip(): # Not empty
current_score += 0.2
else: # Empty thinking tags
current_score -= 0.1
else:
logger.warning(f"thinking_content is not a string: {type(thinking_content)}")
current_score -= 0.1
else: # No thinking tags
current_score -= 0.2
except Exception as e:
logger.warning(f"Error processing thinking tags: {e}")
current_score -= 0.1
return current_score
@ -782,9 +836,26 @@ class RubiksCubeEnv(BaseEnv):
f"[Collect Trajectories Seed: {seed}] Attempting turn {turn_idx + 1}/{max_turns}."
)
step_data, is_terminal_this_step = await self._next_step(
ep, turn_idx, max_turns
)
try:
step_data, is_terminal_this_step = await self._next_step(
ep, turn_idx, max_turns
)
except Exception as e:
if "'list' object has no attribute 'strip'" in str(e):
# Special handling for this common error which doesn't actually
# prevent the overall process from working
logger.error(f"[Collect Trajectories Seed: {seed}] Non-fatal error in _next_step: {e}")
# Since we can't recover the step data, mark this step as failed
step_data = None
is_terminal_this_step = True
else:
# For other errors, log and terminate
logger.error(
f"[Collect Trajectories Seed: {seed}] Error in _next_step: {e}",
exc_info=True
)
step_data = None
is_terminal_this_step = True
if step_data:
trajectory_data_for_trainer.append(step_data)
@ -1000,7 +1071,7 @@ class RubiksCubeEnv(BaseEnv):
def config_init(cls) -> Tuple[RubiksCubeEnvConfig, List[APIServerConfig]]:
"""Initialize the configuration"""
env_config = RubiksCubeEnvConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
tokenizer_name="openai/gpt-4-turbo-preview",
group_size=16,
use_wandb=True,
max_num_workers=128,
@ -1033,7 +1104,6 @@ class RubiksCubeEnv(BaseEnv):
APIServerConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
base_url="http://localhost:9004/v1",
api_key="x",
num_requests_for_eval=256,
)
]