mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00
add jsonl file
This commit is contained in:
parent
e018440d66
commit
7e1de80695
3 changed files with 98 additions and 21 deletions
|
|
@ -474,10 +474,12 @@ class RubiksCubeEnv(BaseEnv):
|
|||
)
|
||||
return None
|
||||
|
||||
# First try parsing with tool_call tags
|
||||
parsed_name, parsed_args, is_error = parse_tool_call(
|
||||
response, self.tools, ["tool_call"]
|
||||
)
|
||||
|
||||
# If that fails, try looking for direct text mentions of moves
|
||||
if is_error:
|
||||
error_detail = (
|
||||
parsed_name
|
||||
|
|
@ -487,21 +489,59 @@ class RubiksCubeEnv(BaseEnv):
|
|||
logger.warning(
|
||||
f"Failed to parse tool call. Full response: '{response}'. Error detail: {error_detail}"
|
||||
)
|
||||
|
||||
# Fallback: Look for direct mentions of moves in the text
|
||||
valid_moves = ["U", "D", "L", "R", "F", "B",
|
||||
"U'", "D'", "L'", "R'", "F'", "B'",
|
||||
"U2", "D2", "L2", "R2", "F2", "B2"]
|
||||
|
||||
# Look for patterns like "I'll apply move X" or "Performing X rotation"
|
||||
move_patterns = [
|
||||
r'move\s+([UDLRFB][\'2]?)',
|
||||
r'applying\s+([UDLRFB][\'2]?)',
|
||||
r'perform\w*\s+([UDLRFB][\'2]?)',
|
||||
r'rotate\s+([UDLRFB][\'2]?)',
|
||||
r'rotation\s+([UDLRFB][\'2]?)',
|
||||
r'I\s*choose\s+([UDLRFB][\'2]?)',
|
||||
r'Execute\s+([UDLRFB][\'2]?)',
|
||||
r'([UDLRFB][\'2]?)\s+rotation',
|
||||
r'([UDLRFB][\'2]?)\s+move'
|
||||
]
|
||||
|
||||
for pattern in move_patterns:
|
||||
match = re.search(pattern, response, re.IGNORECASE)
|
||||
if match:
|
||||
potential_move = match.group(1).strip()
|
||||
if potential_move in valid_moves:
|
||||
logger.warning(f"Recovered move '{potential_move}' from text using regex")
|
||||
return potential_move
|
||||
|
||||
return None
|
||||
|
||||
move = parsed_args.get("move", "").strip()
|
||||
move = parsed_args.get("move", "").strip() if isinstance(parsed_args, dict) else ""
|
||||
valid_moves = ["U", "D", "L", "R", "F", "B",
|
||||
"U'", "D'", "L'", "R'", "F'", "B'",
|
||||
"U2", "D2", "L2", "R2", "F2", "B2"]
|
||||
|
||||
# First check if the move is directly valid
|
||||
if move in valid_moves:
|
||||
return move
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed invalid move: '{move}'. "
|
||||
f"Full response: '{response}'. Parsed args: {parsed_args}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Check if the move is a sequence containing valid moves
|
||||
# (when LLM outputs move sequences like "R U R'")
|
||||
if " " in move:
|
||||
# Take only the first move in the sequence
|
||||
first_move = move.split()[0].strip()
|
||||
if first_move in valid_moves:
|
||||
logger.warning(f"Got move sequence '{move}' but taking only first move '{first_move}'")
|
||||
return first_move
|
||||
|
||||
# If we get here, the move is invalid
|
||||
logger.warning(
|
||||
f"Parsed invalid move: '{move}'. "
|
||||
f"Full response: '{response}'. Parsed args: {parsed_args}"
|
||||
)
|
||||
return None
|
||||
|
||||
def _score_response(
|
||||
self,
|
||||
|
|
@ -531,15 +571,29 @@ class RubiksCubeEnv(BaseEnv):
|
|||
current_score += 1.0
|
||||
|
||||
# Check for thinking tags
|
||||
match = re.search(r"<think>(.*?)</think>", response_text, re.DOTALL)
|
||||
if match:
|
||||
thinking_content = match.group(1)
|
||||
if thinking_content.strip(): # Not empty
|
||||
current_score += 0.2
|
||||
else: # Empty thinking tags
|
||||
current_score -= 0.1
|
||||
else: # No thinking tags
|
||||
current_score -= 0.2
|
||||
try:
|
||||
# Make sure response_text is a string
|
||||
if not isinstance(response_text, str):
|
||||
logger.warning(f"response_text is not a string: {type(response_text)}")
|
||||
response_text = str(response_text) if response_text is not None else ""
|
||||
|
||||
match = re.search(r"<think>(.*?)</think>", response_text, re.DOTALL)
|
||||
if match:
|
||||
thinking_content = match.group(1)
|
||||
# Make sure thinking_content is a string before calling strip()
|
||||
if isinstance(thinking_content, str):
|
||||
if thinking_content.strip(): # Not empty
|
||||
current_score += 0.2
|
||||
else: # Empty thinking tags
|
||||
current_score -= 0.1
|
||||
else:
|
||||
logger.warning(f"thinking_content is not a string: {type(thinking_content)}")
|
||||
current_score -= 0.1
|
||||
else: # No thinking tags
|
||||
current_score -= 0.2
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing thinking tags: {e}")
|
||||
current_score -= 0.1
|
||||
|
||||
return current_score
|
||||
|
||||
|
|
@ -782,9 +836,26 @@ class RubiksCubeEnv(BaseEnv):
|
|||
f"[Collect Trajectories Seed: {seed}] Attempting turn {turn_idx + 1}/{max_turns}."
|
||||
)
|
||||
|
||||
step_data, is_terminal_this_step = await self._next_step(
|
||||
ep, turn_idx, max_turns
|
||||
)
|
||||
try:
|
||||
step_data, is_terminal_this_step = await self._next_step(
|
||||
ep, turn_idx, max_turns
|
||||
)
|
||||
except Exception as e:
|
||||
if "'list' object has no attribute 'strip'" in str(e):
|
||||
# Special handling for this common error which doesn't actually
|
||||
# prevent the overall process from working
|
||||
logger.error(f"[Collect Trajectories Seed: {seed}] Non-fatal error in _next_step: {e}")
|
||||
# Since we can't recover the step data, mark this step as failed
|
||||
step_data = None
|
||||
is_terminal_this_step = True
|
||||
else:
|
||||
# For other errors, log and terminate
|
||||
logger.error(
|
||||
f"[Collect Trajectories Seed: {seed}] Error in _next_step: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
step_data = None
|
||||
is_terminal_this_step = True
|
||||
|
||||
if step_data:
|
||||
trajectory_data_for_trainer.append(step_data)
|
||||
|
|
@ -1000,7 +1071,7 @@ class RubiksCubeEnv(BaseEnv):
|
|||
def config_init(cls) -> Tuple[RubiksCubeEnvConfig, List[APIServerConfig]]:
|
||||
"""Initialize the configuration"""
|
||||
env_config = RubiksCubeEnvConfig(
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
||||
tokenizer_name="openai/gpt-4-turbo-preview",
|
||||
group_size=16,
|
||||
use_wandb=True,
|
||||
max_num_workers=128,
|
||||
|
|
@ -1033,7 +1104,6 @@ class RubiksCubeEnv(BaseEnv):
|
|||
APIServerConfig(
|
||||
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
||||
base_url="http://localhost:9004/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
)
|
||||
]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue