# Source: atropos/environments/community/lean_proof_env/lean_env.py
# Snapshot: 2025-10-14 12:28:13 +02:00 (516 lines, 23 KiB, Python)

import asyncio # For async operations if PyPantograph is async
import random
from typing import Dict, List, Optional, Tuple, TypedDict
import wandb # For wandb.Table
from datasets import load_dataset
from tqdm.asyncio import tqdm_asyncio # For progress bars in evaluate
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
)
from atroposlib.type_definitions import ( # Item might not be directly used if LeanProblemRow is self-contained
Item,
number,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
# The real integration would be `from pypantograph import PyPantograph`
# (assuming it is available as a package). Until then, the class below is a
# mock; replace it with your actual PyPantograph integration.
class PyPantograph:
    """Stand-in for a real Lean checker.

    Nothing is compiled here: the verdict comes from a handful of string
    heuristics and, when none of them applies, from a random draw.
    """

    @staticmethod
    async def check_lean_code(lean_code: str) -> Tuple[bool, Optional[str]]:
        """Pretend to compile ``lean_code``.

        Returns:
            ``(True, None)`` when the code is deemed to compile, otherwise
            ``(False, message)`` where ``message`` describes the failure.
        """
        # Stand in for the latency of a real compiler round trip.
        await asyncio.sleep(0.01)

        stripped = lean_code.strip()

        # Hard rejections, evaluated in a fixed order.
        if "sorry" in lean_code.lower():
            return False, "Proof contains 'sorry'."
        if len(stripped) < 10:  # also covers the empty string
            return False, "Proof is empty or too short."
        if stripped.endswith("begin") or "begin end" in lean_code:
            return False, "Incomplete or empty proof structure."

        # Marker names that force a deterministic verdict.
        if "simple_theorem_correct" in lean_code:
            return True, None
        if "simple_theorem_error" in lean_code:
            return False, "Mock Lean: Type mismatch on 'simple_theorem_error'."

        # Everything else: 60% chance of a mock success, otherwise one of the
        # canned error messages picked at random.
        if random.random() >= 0.6:
            canned_errors = [
                "Mock Lean: Unknown identifier 'xyz'.",
                "Mock Lean: Type mismatch, expected 'nat', got 'Prop'.",
                "Mock Lean: Tactic failed.",
                "Mock Lean: Universe level constraint violation.",
            ]
            return False, random.choice(canned_errors)
        return True, None
# System prompt sent as the "system" message of every Lean proof request.
# The line breaks are real newlines ("\n"): the previous double-escaped form
# ("\\n" / "\\\\n") put literal backslash sequences into the prompt text.
# NOTE(review): the worked example uses Lean 3 syntax (`begin ... end`, `nat`)
# while the problems are loaded from "brando/minif2f-lean4" -- confirm which
# Lean version the model is expected to emit.
lean_system_prompt = (
    "You are an expert Lean mathematician. Your mission is to complete the given Lean proof.\n"
    "You will be provided with a Lean theorem statement, often prefixed by necessary import and open commands. "
    "The theorem itself (e.g., starting with 'theorem ...') will typically end with ':= sorry'.\n"
    "Your task is to replace 'sorry' (and its surrounding ' := ' if necessary) with the correct proof steps "
    "within the provided theorem structure.\n"
    'Provide only the completed Lean theorem block (e.g., starting from "theorem ..." '
    'or "def ..." up to its final "end" or conclusion), '
    "including the statement and the proof, as a single Lean code block.\n"
    "Do not repeat the import or open commands that were part of the input.\n"
    "Do not include any other explanatory text, comments, "
    "or markdown code fences (```lean ... ```) around your response.\n"
    "Ensure your proof is self-contained (assuming the provided imports) and syntactically correct Lean code.\n\n"
    "Example of input you will receive from the user (header + formal statement):\n"
    "import Mathlib.Data.Nat.Basic\n"
    "open Nat\n\n"
    "theorem add_comm (a b : nat) : a + b = b + a := sorry\n\n"
    "Example of a correct full response from you (just the completed theorem block):\n"
    "theorem add_comm (a b : nat) : a + b = b + a :=\n"
    "begin\n"
    " rw nat.add_comm,\n"
    "end"
)
class LeanProblemRow(TypedDict):
    """One Lean problem as handed to the environment."""

    # Unique identifier for the problem.
    id: str
    # The import and open statements that precede the theorem.
    header: str
    # The statement to prove, e.g.
    # "theorem add_comm (a b : nat) : a + b = b + a := sorry".
    problem_statement: str
    # A ground-truth proof field (Optional[str]) could be added later for
    # reference or more advanced evaluation.
class LeanEnv(BaseEnv):
    """Environment that rewards LLM-written Lean proofs for compiling.

    For each problem the LLM is shown the problem's header plus its formal
    statement and asked for the completed theorem block. The header and the
    generated block are concatenated and passed to
    ``PyPantograph.check_lean_code``; training attempts score ``1.0`` when the
    check succeeds and ``-1.0`` otherwise, evaluation attempts score ``1``/``0``.
    """

    name = "lean_proof"  # Used for wandb naming unless overridden

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        """Initialise buffers; the datasets themselves are loaded in ``setup``.

        Args:
            config: Environment configuration, forwarded to ``BaseEnv``.
            server_configs: LLM API server configurations, forwarded to ``BaseEnv``.
            slurm: Forwarded to ``BaseEnv``.
            testing: Forwarded to ``BaseEnv``; when true, ``setup`` keeps only
                a small subset of the loaded data.
        """
        super().__init__(config, server_configs, slurm, testing)
        # 1.0 / 0.0 per training attempt; averaged and reset in wandb_log().
        self.percent_compiled_buffer = []
        # (metric_name, value) tuples produced by evaluate(); flushed in wandb_log().
        self.eval_metrics = []
        # self.rollouts_for_wandb is inherited. It is populated with groups of
        # (problem_statement, generated_proof, score, status_message) tuples.
        self.train_data: List[LeanProblemRow] = []
        self.test_data: List[LeanProblemRow] = []
        self.train_iter = 0
        self.testing_mode = testing  # Read by setup()
        # Details of the most recently scored group, handed over from
        # score_lean_attempts() to add_rollouts_for_wandb().
        self._current_group_rollout_details_log: List[
            Tuple[str, str, float, str]
        ] = []

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Return the default environment and API server configuration."""
        env_config = BaseEnvConfig(
            tokenizer_name="Qwen/Qwen3-235B-A22B",  # Choose appropriate tokenizer for your LLM
            group_size=4,  # Number of proofs to generate per problem for scoring diversity
            use_wandb=True,
            rollout_server_url="http://localhost:8000",  # Default Atropos server URL
            total_steps=1000,  # Total training steps
            batch_size=8,  # Number of scored groups to send to trainer at once
            steps_per_eval=50,  # How often to run evaluation
            max_token_length=1024,  # Max token length for generated Lean proofs
            wandb_name="lean_proof_rl_eval",  # Custom WandB project/group name prefix
            # For compile/no-compile, allowing same scores (e.g., all fail) is useful
            ensure_scores_are_not_same=False,
            num_rollouts_to_keep=32,  # Number of recent groups to show in WandB table
            num_rollouts_per_group_for_logging=1,  # Log 1 example from each group (-1 for all in group)
            include_messages=False,  # Messages for the wandb table are formatted by this class
        )
        # Configuration for the LLM API server (e.g., TGI, vLLM)
        server_configs = [
            APIServerConfig(
                model_name="Qwen/Qwen3-235B-A22B",  # Ensure this matches the deployed LLM and tokenizer
                base_url="http://localhost:9001/v1",  # Your LLM server endpoint (OpenAI compatible)
                api_key="EMPTY",  # API key if required by your server
                num_requests_for_eval=128,  # Max concurrent requests for evaluation pass
            ),
        ]
        return env_config, server_configs

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _rows_from_split(raw_split, id_prefix: str) -> List[LeanProblemRow]:
        """Convert dataset records into ``LeanProblemRow`` dicts.

        Records must provide "header" and "formal_statement"; a missing "id"
        falls back to ``f"{id_prefix}_{index}"``.
        """
        return [
            LeanProblemRow(
                id=str(record.get("id", f"{id_prefix}_{index}")),
                header=record["header"],
                problem_statement=record["formal_statement"],
            )
            for index, record in enumerate(raw_split)
        ]

    @staticmethod
    def _prompt_messages(item: LeanProblemRow) -> List[Dict[str, str]]:
        """Build the system + user messages the LLM sees for ``item``."""
        # The LLM sees header + formal statement.
        full_problem_context = item["header"] + "\n\n" + item["problem_statement"]
        return [
            {"role": "system", "content": lean_system_prompt},
            {"role": "user", "content": full_problem_context},
        ]

    @staticmethod
    def _code_to_check(item: LeanProblemRow, theorem_block: str) -> str:
        """Return the original header followed by the LLM's theorem block."""
        return item["header"] + "\n\n" + theorem_block

    @staticmethod
    def _truncate(text: str, limit: int) -> str:
        """Shorten ``text`` to ``limit`` characters plus "..." if it is longer."""
        return text[:limit] + "..." if len(text) > limit else text

    # -------------------------------------------------------------------- setup

    async def setup(self):
        """Load the Lean problems from Hugging Face.

        The "validation" split is used for training and the "test" split for
        evaluation. If loading fails, both data lists are left empty and the
        method returns after printing diagnostics.
        """
        hf_dataset_path = "brando/minif2f-lean4"
        try:
            # The dataset "brando/minif2f-lean4" doesn't have a canonical
            # "train" split from its viewer; it has "test" and "validation".
            # Adjust here if a different assignment is intended.
            print(
                f"Loading training data from Hugging Face: {hf_dataset_path} (split=validation)"
            )
            print(
                "Attempting to use 'validation' split for training and 'test' split for evaluation."
            )
            raw_train_data = load_dataset(hf_dataset_path, split="validation")
            print(
                f"Loading test data for evaluation from Hugging Face: {hf_dataset_path} (split=test)"
            )
            raw_eval_data = load_dataset(hf_dataset_path, split="test")
        except Exception as e:
            # Deliberate best-effort: report and continue with no data.
            print(f"Error loading dataset from Hugging Face: {e}")
            print(
                "Please ensure the dataset path is correct, splits exist, "
                "and you have internet access / the dataset is cached."
            )
            print(
                "Falling back to empty data. The environment will likely fail to run without data."
            )
            self.train_data = []
            self.test_data = []
            self.train_iter = 0
            return

        # Adapt this mapping based on your dataset's column names.
        self.train_data = self._rows_from_split(raw_train_data, "train_idx")
        self.test_data = self._rows_from_split(raw_eval_data, "eval_idx")

        if self.testing_mode:  # If in testing mode, use a small subset
            print("Testing mode active: Using a small subset of the loaded data.")
            self.train_data = self.train_data[:5]
            self.test_data = self.test_data[:3]
            if not self.train_data:
                print("Warning: Test subset for training data is empty.")
            if not self.test_data:
                print("Warning: Test subset for validation data is empty.")

        random.shuffle(self.train_data)
        self.train_iter = 0
        # Both segments need the f prefix; without it the second placeholder
        # would be printed literally.
        print(
            f"LeanEnv setup: Loaded {len(self.train_data)} train problems "
            f"(from validation split) and {len(self.test_data)} eval problems (from test split)."
        )
        if not self.train_data and self.config.total_steps > 0:
            print(
                "CRITICAL WARNING: No training data loaded. Environment will "
                "not be able to produce training trajectories."
            )
        if not self.test_data and self.config.steps_per_eval > 0:
            print(
                "WARNING: No eval data loaded. Evaluation steps will have nothing to evaluate."
            )

    async def get_next_item(self) -> LeanProblemRow:
        """Return the next training problem, cycling through the training set.

        Raises:
            RuntimeError: If no training data is loaded.
        """
        if not self.train_data:
            # Fail with an explicit message rather than a ZeroDivisionError
            # from the modulo below.
            raise RuntimeError(
                "LeanEnv has no training data; setup() must load the dataset "
                "before get_next_item() can be used."
            )
        item = self.train_data[self.train_iter % len(self.train_data)]
        self.train_iter += 1
        return item

    # --------------------------------------------------------------- evaluation

    async def _get_llm_proof_attempt(
        self, item: LeanProblemRow, split: str, temperature: float = 0.7
    ) -> str:
        """Generate a single proof attempt for ``item``.

        Args:
            item: The problem to solve.
            split: 'train' or 'eval', passed through to the server.
            temperature: Sampling temperature.

        Returns:
            The stripped text of the first completion (the completed theorem
            block); empty if the server returned no content.
        """
        completion = await self.server.chat_completion(
            messages=self._prompt_messages(item),
            n=1,  # Single completion for this helper
            max_tokens=self.config.max_token_length,
            temperature=temperature,
            split=split,
        )
        # content can be None for some responses; treat that as empty text.
        return (completion.choices[0].message.content or "").strip()

    async def rollout_and_score_eval(self, problem_item: LeanProblemRow) -> number:
        """Generate one low-temperature attempt and return 1 if it compiles, else 0."""
        # Lower temperature for more deterministic evaluation outputs.
        llm_generated_theorem_block = await self._get_llm_proof_attempt(
            problem_item, split="eval", temperature=0.1
        )
        compiles, _error_msg = await PyPantograph.check_lean_code(
            self._code_to_check(problem_item, llm_generated_theorem_block)
        )
        return 1 if compiles else 0

    async def evaluate(self, *args, **kwargs):
        """Evaluate the LLM on the test problems.

        Appends ``("eval/percent_compiled", fraction)`` to ``self.eval_metrics``;
        the fraction is 0.0 when there is no test data.
        """
        print(f"Starting evaluation on {len(self.test_data)} Lean problems...")
        if not self.test_data:
            print("No test data to evaluate.")
            self.eval_metrics.append(("eval/percent_compiled", 0.0))
            return
        eval_tasks = [self.rollout_and_score_eval(item) for item in self.test_data]
        scores = await tqdm_asyncio.gather(*eval_tasks, desc="Evaluating Lean Proofs")
        percent_compiled = (
            sum(s for s in scores if s > 0) / len(scores) if scores else 0.0
        )
        self.eval_metrics.append(("eval/percent_compiled", percent_compiled))
        print(f"Evaluation finished. Percent compiled: {percent_compiled:.2%}")

    # ----------------------------------------------------------------- training

    async def collect_trajectories(
        self, item: LeanProblemRow
    ) -> Tuple[Optional[ScoredDataGroup], list[Item]]:
        """Collect and score a group of proof attempts for one problem.

        Returns:
            The scored group (or ``None`` if nothing could be scored) and an
            empty backlog list.
        """
        chat_completions = await self.server.chat_completion(
            messages=self._prompt_messages(item),
            n=self.config.group_size,  # Generate n attempts
            max_tokens=self.config.max_token_length,
            temperature=0.7,  # Higher temperature for diversity in training
            split="train",
        )
        attempts_to_score_data = []
        for choice in chat_completions.choices:
            # content can be None for some responses; treat that as empty text.
            content = choice.message.content or ""
            # Message history for tokenization: what the LLM saw, followed by
            # its attempt at completing the theorem block. Fresh dicts are
            # built per attempt so histories do not share objects.
            messages_history = (
                *self._prompt_messages(item),
                {"role": "assistant", "content": content},
            )
            attempts_to_score_data.append(
                {
                    "messages_history": messages_history,
                    "llm_generated_theorem_block": content.strip(),
                    "finish_reason": choice.finish_reason,
                    "problem_item": item,
                }
            )
        scored_data_group = await self.score_lean_attempts(attempts_to_score_data)
        return scored_data_group, []  # No backlog items generated in this simple setup

    async def score_lean_attempts(
        self, attempts_data: List[Dict]
    ) -> Optional[ScoredDataGroup]:
        """Score a group of Lean proof attempts.

        Args:
            attempts_data: One dict per attempt with the keys
                'llm_generated_theorem_block', 'messages_history',
                'finish_reason' and 'problem_item' (the original
                ``LeanProblemRow`` holding the header and formal statement).

        Returns:
            A ``ScoredDataGroup`` with one entry per attempt (score 1.0 if the
            attempt compiled, -1.0 otherwise), or ``None`` if there was nothing
            to score.
        """
        # Run the compile checks for the whole group concurrently; results
        # come back in the same order as attempts_data.
        compile_results = await asyncio.gather(
            *(
                PyPantograph.check_lean_code(
                    self._code_to_check(
                        attempt["problem_item"],
                        attempt["llm_generated_theorem_block"],
                    )
                )
                for attempt in attempts_data
            )
        )

        scored_group = ScoredDataGroup(tokens=[], masks=[], scores=[])
        # Built locally and published in one step below, so scoring runs that
        # overlap at the await above cannot interleave their entries.
        rollout_details: List[Tuple[str, str, float, str]] = []

        for attempt, (compiles, error_msg) in zip(attempts_data, compile_results):
            problem_item: LeanProblemRow = attempt["problem_item"]
            reward = 1.0 if compiles else -1.0
            status_message = (
                "Compiled Successfully"
                if compiles
                else f"Compilation Failed: {error_msg or 'Unknown error'}"
            )
            # For WandB: the original theorem (with sorry) and what the LLM
            # produced to complete it.
            rollout_details.append(
                (
                    problem_item["problem_statement"],
                    attempt["llm_generated_theorem_block"],
                    reward,
                    status_message,
                )
            )
            out_dict = tokenize_for_trainer(
                self.tokenizer,
                list(attempt["messages_history"]),
                attempt["finish_reason"],
            )
            # Very short generations are not filtered out here; they are kept
            # and carry whatever reward the compile check gave them.
            scored_group["tokens"].append(out_dict["tokens"])
            scored_group["masks"].append(out_dict["masks"])
            scored_group["scores"].append(reward)
            # Feeds the average batch compilation rate in wandb_log().
            self.percent_compiled_buffer.append(1.0 if compiles else 0.0)

        if not scored_group["tokens"]:  # No attempts were provided
            self._current_group_rollout_details_log = []
            return None

        # NOTE(review): this is still a single slot per environment; if
        # another group is scored before add_rollouts_for_wandb() consumes it,
        # the logged details will belong to the later group.
        self._current_group_rollout_details_log = rollout_details
        return scored_group

    # ------------------------------------------------------------------ logging

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Add the Lean-specific metrics and delegate to ``BaseEnv.wandb_log``.

        Flushes ``percent_compiled_buffer`` into
        "train/batch_avg_percent_compiled" and copies the pending
        ``eval_metrics``; both buffers are reset afterwards.
        """
        if wandb_metrics is None:
            wandb_metrics = {}
        if self.percent_compiled_buffer:
            wandb_metrics["train/batch_avg_percent_compiled"] = sum(
                self.percent_compiled_buffer
            ) / len(self.percent_compiled_buffer)
        self.percent_compiled_buffer = []
        for metric_name, value in self.eval_metrics:
            wandb_metrics[metric_name] = value
        self.eval_metrics = []
        # The parent super().wandb_log will call self.create_rollout_table()
        await super().wandb_log(wandb_metrics)

    async def add_rollouts_for_wandb(
        self,
        scored_data: ScoredDataGroup,
        item: Optional[LeanProblemRow] = None,
    ):
        """Move the most recently scored group's details into ``rollouts_for_wandb``.

        Called by BaseEnv's handle_send_to_api. ``scored_data`` and ``item``
        are accepted for interface compatibility; the details come from
        ``_current_group_rollout_details_log``, which ``score_lean_attempts``
        filled with (problem_statement, generated_proof, score, status_message)
        tuples.
        """
        # Take ownership of the pending details and reset the slot.
        details = self._current_group_rollout_details_log
        self._current_group_rollout_details_log = []

        if details:
            num_to_log = self.config.num_rollouts_per_group_for_logging
            # -1 means "log every attempt in the group".
            rollouts_to_add = details if num_to_log == -1 else details[:num_to_log]
            # self.rollouts_for_wandb is a list of groups of these tuples.
            if rollouts_to_add:
                self.rollouts_for_wandb.append(rollouts_to_add)

        # Keep only the configured number of most recent groups.
        excess = len(self.rollouts_for_wandb) - self.config.num_rollouts_to_keep
        if excess > 0:
            del self.rollouts_for_wandb[:excess]

    async def create_rollout_table(self, wandb_metrics: Dict) -> Dict:
        """Add a WandB table of recent proof attempts to ``wandb_metrics``.

        Overrides the BaseEnv method. Nothing is added when there are no
        stored rollouts.

        Returns:
            ``wandb_metrics``, with the table added when rollouts exist.
        """
        if self.rollouts_for_wandb:
            table_name_key = "train/lean_proof_attempts"
            # When a prepend is set, the key becomes
            # "<prepend>_lean_proof_attempts" (the "train/" part is dropped).
            if self.wandb_prepend:
                table_name_key = f"{self.wandb_prepend}_{table_name_key.split('/')[-1]}"
            table = wandb.Table(
                columns=[
                    "Problem Statement",
                    "Generated Proof",
                    "Score",
                    "Compilation Status",
                ]
            )
            # A list of groups, each a list of
            # (problem_statement, generated_proof, score, status_message).
            for group_of_rollouts in self.rollouts_for_wandb:
                for problem_stmt, gen_proof, score, status_msg in group_of_rollouts:
                    # Truncate for display; WandB tables can be slow with very long text.
                    table.add_data(
                        self._truncate(problem_stmt, 250),
                        self._truncate(gen_proof, 400),
                        float(score),
                        status_msg,
                    )
            wandb_metrics[table_name_key] = table
        return wandb_metrics
if __name__ == "__main__":
    # Script entry point: hand control to the CLI provided by the environment
    # class. The original notes list two subcommands:
    #   python environments/lean_env.py serve    (to connect to Atropos trainer)
    #   python environments/lean_env.py process  (for local data generation/testing)
    # NOTE(review): adjust the path above to wherever this file actually lives.
    LeanEnv.cli()