diff --git a/reasoning_gym/algorithmic/letter_counting.py b/reasoning_gym/algorithmic/letter_counting.py
index 5d95851c..509a6171 100644
--- a/reasoning_gym/algorithmic/letter_counting.py
+++ b/reasoning_gym/algorithmic/letter_counting.py
@@ -86,7 +86,7 @@ class LetterCountingCurriculum(BaseCurriculum):
         self._define_attributes(
             RangeAttributeDefinition(
                 name="words",
-                levels=[10, 50, 100, 1000],
+                levels=list(range(5, 20, 2)),
                 description="Number of words in the span",
                 lower_field_name="min_words",
                 upper_field_name="max_words",
diff --git a/reasoning_gym/algorithmic/number_sorting.py b/reasoning_gym/algorithmic/number_sorting.py
index cc3e06d2..cf572668 100644
--- a/reasoning_gym/algorithmic/number_sorting.py
+++ b/reasoning_gym/algorithmic/number_sorting.py
@@ -177,7 +177,7 @@ class NumberSortingCurriculum(BaseCurriculum):
         self._define_attributes(
             RangeAttributeDefinition(
                 name="numbers",
-                levels=[10, 100, 500, 1000],
+                levels=list(range(5, 20, 2)),
                 description="How many numbers to sort",
                 lower_field_name="min_numbers",
                 upper_field_name="max_numbers",
@@ -185,7 +185,7 @@ class NumberSortingCurriculum(BaseCurriculum):
             ),
             RangeAttributeDefinition(
                 name="decimals",
-                levels=[0, 2, 4, 6],
+                levels=list(range(0, 8)),
                 description="Number of decimal places",
                 lower_field_name="min_decimals",
                 upper_field_name="max_decimals",
diff --git a/reasoning_gym/algorithmic/spell_backward.py b/reasoning_gym/algorithmic/spell_backward.py
index 2fed1d22..f7f4d843 100644
--- a/reasoning_gym/algorithmic/spell_backward.py
+++ b/reasoning_gym/algorithmic/spell_backward.py
@@ -17,7 +17,7 @@ class SpellBackwardConfig:
     """Configuration for spelling words backward task generation"""

     min_word_len: int = 3  # Minimum word length
-    max_word_len: int = 20  # Maximum word length
+    max_word_len: int = 10  # Maximum word length
     seed: Optional[int] = None
     size: int = 500  # Virtual dataset size
@@ -34,12 +34,11 @@ class SpellBackwardDataset(ProceduralDataset):
         super().__init__(config=config, seed=config.seed, size=config.size)

         # Load and preprocess text
-        text = read_data_file("in_the_year_2889.txt")
-        # Extract words and clean them to contain only alphanumeric characters
+        text = read_data_file("words3to10.txt")
         self.words = [
-            word
-            for word in re.findall(r"\b\w+\b", text)
-            if word.isalnum() and config.min_word_len <= len(word) <= config.max_word_len
+            word.strip()
+            for word in text.splitlines()
+            if word.strip().isalnum() and config.min_word_len <= len(word.strip()) <= config.max_word_len
         ]

     def __getitem__(self, idx: int) -> dict:
@@ -71,6 +70,8 @@ class SpellBackwardDataset(ProceduralDataset):
         try:
             if expected_answer.lower() == answer.lower():
                 reward = 1.0
+            elif sorted(expected_answer.lower()) == sorted(answer.lower()):
+                reward = 0.2
             else:
                 reward = 0.05
         except:
@@ -86,11 +87,11 @@ class SpellBackwardCurriculum(BaseCurriculum):
         self._define_attributes(
             RangeAttributeDefinition(
                 name="word_len",
-                levels=[5, 10, 20, 30],
+                levels=list(range(3, 11)),
                 description="Word length",
                 lower_field_name="min_word_len",
                 upper_field_name="max_word_len",
-                ensure_interval=True,
+                ensure_interval=False,
             ),
         )
diff --git a/reasoning_gym/coaching/base_curriculum.py b/reasoning_gym/coaching/base_curriculum.py
index 15b725cb..f64f1234 100644
--- a/reasoning_gym/coaching/base_curriculum.py
+++ b/reasoning_gym/coaching/base_curriculum.py
@@ -239,3 +239,16 @@ class BaseCurriculum:
             self.set_attr_level(attr_name, target_level)
             return True
         return False
+
+    def get_global_level(self) -> dict:
+        """Return the current value of each curriculum attribute, keyed by its config field name."""
+        attr_dict = {}
+        if not self._attributes:
+            return attr_dict
+        for attr_name in self._attributes:
+            attr = self.get_attribute(attr_name)
+            if isinstance(attr, RangeAttributeDefinition):
+                attr_dict[attr.upper_field_name] = self.get_attr_value(attr_name)
+            elif isinstance(attr, ScalarAttributeDefinition):
+                attr_dict[attr.field_name] = self.get_attr_value(attr_name)
+        return attr_dict
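+
+    # Illustrative example (assumption, not part of this patch): for a
+    # curriculum whose only attribute is the "words" range attribute defined
+    # in letter_counting.py above (upper_field_name="max_words"), a call to
+    # get_global_level() yields a dict such as {"max_words": 7}, not a
+    # single integer level.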
Determine start index for this dataset + dataset_len = len(dataset_scores) + start_idx = max(0, dataset_len - last_n) if last_n is not None else 0 - # Count total scores - total_scores = sum(len(scores) for scores in result.values()) + # Create OrderedDict for this dataset's parameter groupings + dataset_groups = OrderedDict() - return GroupedScores(scores=result, total_scores=total_scores) + # Process scores for this dataset + for i in range(start_idx, dataset_len): + # Get metadata for this score + metadata = self.metadata[dataset_name][i] + params = self._metadata_to_key(metadata) + + if params not in dataset_groups: + dataset_groups[params] = [] + + dataset_groups[params].append(dataset_scores[i]) + + # Create a GroupedScores object for this dataset + total_scores = sum(len(scores) for scores in dataset_groups.values()) + result[dataset_name] = GroupedScores(scores=dataset_groups, total_scores=total_scores) + + return result class Coach(ProceduralDataset): diff --git a/reasoning_gym/composite.py b/reasoning_gym/composite.py index ff6a14ce..a96af343 100644 --- a/reasoning_gym/composite.py +++ b/reasoning_gym/composite.py @@ -47,6 +47,14 @@ class CompositeConfig: for ds in self.datasets: ds.validate() + def get_dataset_weight(self, dataset_name: str) -> float: + """Get the weight for a specific dataset by name.""" + for ds in self.datasets: + if ds.name == dataset_name: + return ds.weight + + raise ValueError(f"Dataset '{dataset_name}' not found in composite configuration") + @classmethod def from_yaml_stream(cls, stream) -> "CompositeConfig": """Load configuration from a YAML stream diff --git a/tests/test_coaching.py b/tests/test_coaching.py deleted file mode 100644 index 1741e87a..00000000 --- a/tests/test_coaching.py +++ /dev/null @@ -1,238 +0,0 @@ -import json -import math -from collections import OrderedDict - -import pytest - -from reasoning_gym.arithmetic.chain_sum import ChainSumConfig, ChainSumDataset -from reasoning_gym.arithmetic.leg_counting import LegCountingConfig -from reasoning_gym.coaching import Coach, GroupedScores -from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec - - -def test_coach_with_chain_sum(): - # Create a small ChainSum dataset - config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42) - dataset = ChainSumDataset(config) - coach = Coach(dataset) - - # Simulate an agent working on tasks - for i in range(5): - item = coach[i] - - # Simulate some correct and incorrect answers - if i % 2 == 0: - # Correct answer - score = coach.score_answer( - answer=item["answer"], - entry=item, - conversation=[ - {"role": "user", "content": item["question"]}, - {"role": "assistant", "content": item["answer"]}, - ], - ) - assert score == 1.0 - else: - # Incorrect answer (None) - score = coach.score_answer( - answer=None, - entry=item, - conversation=[ - {"role": "user", "content": item["question"]}, - {"role": "assistant", "content": "I don't know"}, - ], - ) - assert score == 0.0 - - # Test score aggregation - aggregated = coach.score_board.aggregate() - - # Verify we have scores grouped by difficulty parameters - assert len(aggregated.scores) > 0 - - # Each key should be a tuple of tuples containing difficulty parameters - for key in aggregated.scores: - assert isinstance(key, tuple) - # Each inner tuple should be (param_name, value) or (param_name, (min_value, max_value)) - for param in key: - assert isinstance(param, tuple) - assert param[0] in ("source", "idx", "num_terms", "num_digits") - - # Test 
aggregation with last_n - last_3 = coach.score_board.aggregate(last_n=3) - assert len(last_3.scores) > 0 - - # Verify total scores count - assert last_3.total_scores == 3 - - # Verify conversation tracking - assert len(coach.score_board.conversations) == 5 - for conv in coach.score_board.conversations: - assert len(conv) == 2 # user question and assistant response - assert conv[0]["role"] == "user" - assert conv[1]["role"] == "assistant" - - # Test stats calculation - stats = aggregated.stats() - - for key, values in stats.scores.items(): - assert isinstance(values, tuple) - assert len(values) == 5 # (count, mean, std, min, max) - assert isinstance(values[0], int) # count should be int - assert all(isinstance(v, float) for v in values[1:]) # stats should be floats - - # Test stats with empty scores - empty_stats = GroupedScores(scores=OrderedDict(), total_scores=0).stats() - assert len(empty_stats.scores) == 0 - - # Test stats with ignore_empty=False - empty_group = OrderedDict({(("test", 1),): []}) - non_ignoring_stats = GroupedScores(scores=empty_group, total_scores=0).stats(ignore_empty=False) - assert len(non_ignoring_stats.scores) == 1 - stats_tuple = next(iter(non_ignoring_stats.scores.values())) - assert stats_tuple[0] == 0 # count should be 0 for empty list - assert all(math.isnan(v) for v in stats_tuple[1:]) # stats should be NaN - - print(aggregated) - print(stats) - - # Test clear functionality - coach.score_board.clear() - assert len(coach.score_board.scores) == 0 - assert len(coach.score_board.metadata) == 0 - assert len(coach.score_board.conversations) == 0 - assert len(coach.score_board.aggregate().scores) == 0 - - -def test_coach_with_composite(): - # Create configs for both datasets - chain_sum_config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10) - leg_counting_config = LegCountingConfig(min_animals=2, max_animals=3, size=10) - - # Create composite config - composite_config = CompositeConfig( - size=20, - seed=42, - datasets=[ - DatasetSpec(name="chain_sum", weight=1.0, config=chain_sum_config.__dict__), - DatasetSpec(name="leg_counting", weight=1.0, config=leg_counting_config.__dict__), - ], - ) - - # Create composite dataset and coach - dataset = CompositeDataset(composite_config) - coach = Coach(dataset) - - # Score some answers - for i in range(5): - item = coach[i] - # Correct answers for even indices - score = coach.score_answer( - answer=item["answer"] if i % 2 == 0 else None, - entry=item, - conversation=[ - {"role": "user", "content": item["question"]}, - {"role": "assistant", "content": item["answer"] if i % 2 == 0 else "I don't know"}, - ], - ) - assert score in (0.0, 1.0) - - # Test aggregation - aggregated = coach.score_board.aggregate() - assert len(aggregated.scores) > 0 - - # Verify source dataset info is first in keys - for key in aggregated.scores: - assert key[0][0] == "source" # First tuple should be ("source", dataset_name) - assert key[1][0] == "idx" # Second tuple should be ("idx", index) - - # Test stats - stats = aggregated.stats() - for key, values in stats.scores.items(): - assert isinstance(values, tuple) - assert len(values) == 5 # (count, mean, std, min, max) - assert isinstance(values[0], int) - assert all(isinstance(v, float) for v in values[1:]) - - print("\nComposite Dataset Stats:") - print(stats) - - # Test config update - coach.dataset.update_dataset_config("chain_sum", {"min_terms": 4, "max_terms": 5}) - - # Verify the config was updated - chain_sum_dataset = coach.dataset.datasets["chain_sum"] - 
assert chain_sum_dataset.config.min_terms == 4 - assert chain_sum_dataset.config.max_terms == 5 - - # Score some more items to verify new config is in effect - for i in range(3): - item = coach[i + 5] # Use different indices - if "chain_sum" in item["metadata"]["source_dataset"]: - metadata = item["metadata"] - assert metadata["num_terms"] >= 4 - - -def test_grouped_scores_str(): - # Test raw scores string representation - scores = OrderedDict() - scores[(("num_terms", 2), ("num_digits", 1))] = [1.0, 0.0, 1.0] - scores[(("num_terms", 3), ("num_digits", 2))] = [0.5, 0.5] - grouped = GroupedScores(scores=scores, total_scores=5) - - report = str(grouped) - assert "Total scores: 5" in report - assert "(num_terms=2, num_digits=1): n=3" in report - assert "(num_terms=3, num_digits=2): n=2" in report - assert "Values: 1.00, 0.00, 1.00" in report - assert "Values: 0.50, 0.50" in report - - # Test stats string representation - stats = grouped.stats() - stats_report = str(stats) - assert "μ=" in stats_report - assert "σ=" in stats_report - assert "min=" in stats_report - assert "max=" in stats_report - - # Test empty scores - empty = GroupedScores(scores=OrderedDict(), total_scores=0) - assert str(empty) == "No scores recorded" - - -def test_coach_score_logging(tmp_path): - # Create a log file in the temporary directory - log_file = tmp_path / "scores.jsonl" - - # Create dataset and coach with logging - config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42) - dataset = ChainSumDataset(config) - coach = Coach(dataset, score_log=log_file) - - # Score a few answers - for i in range(3): - item = coach[i] - coach.score_answer( - answer=item["answer"] if i % 2 == 0 else None, - entry=item, - conversation=[ - {"role": "user", "content": item["question"]}, - {"role": "assistant", "content": item["answer"] if i % 2 == 0 else "I don't know"}, - ], - ) - - # Verify log file contents - assert log_file.exists() - - # Read and parse log entries - log_entries = [json.loads(line) for line in log_file.open()] - assert len(log_entries) == 3 - - # Verify log entry structure - for i, entry in enumerate(log_entries): - assert "score" in entry - assert "entry" in entry - assert "metadata" in entry["entry"] - assert "conversation" in entry - assert entry["score"] == (1.0 if i % 2 == 0 else 0.0) - assert len(entry["conversation"]) == 2 diff --git a/tests/test_letter_counting.py b/tests/test_letter_counting.py index 6484d553..832e5fed 100644 --- a/tests/test_letter_counting.py +++ b/tests/test_letter_counting.py @@ -90,14 +90,14 @@ def test_letter_counting_curriculum(): base_cfg: LetterCountingConfig = curriculum.generate_configuration(base_value) assert base_cfg.seed == 1 assert base_cfg.size == 150 - assert base_cfg.min_words == 10 and base_cfg.max_words == 50 + assert base_cfg.min_words == 5 and base_cfg.max_words == 7 # test incrementing attribute levels curriculum.increment_attr_level("words") increased_cfg = curriculum.generate_configuration(base_value) - assert increased_cfg.min_words == 10 and increased_cfg.max_words == 100 + assert increased_cfg.min_words == 5 and increased_cfg.max_words == 9 # test decrementing attribute level for words again curriculum.decrement_attr_level("words") partially_decreased_cfg = curriculum.generate_configuration(base_value) - assert partially_decreased_cfg.min_words == 10 and partially_decreased_cfg.max_words == 50 + assert partially_decreased_cfg.min_words == 5 and partially_decreased_cfg.max_words == 7 diff --git 
a/tests/test_number_sorting.py b/tests/test_number_sorting.py index bd88345f..7da35c96 100644 --- a/tests/test_number_sorting.py +++ b/tests/test_number_sorting.py @@ -99,23 +99,23 @@ def test_number_sorting_curriculum(): base_cfg: NumberSortingConfig = curriculum.generate_configuration(base_value) assert base_cfg.seed == 1 assert base_cfg.size == 150 - assert base_cfg.min_numbers == 10 and base_cfg.max_numbers == 100 - assert base_cfg.min_decimals == 0 and base_cfg.max_decimals == 2 + assert base_cfg.min_numbers == 5 and base_cfg.max_numbers == 7 + assert base_cfg.min_decimals == 0 and base_cfg.max_decimals == 1 assert base_cfg.min_value == -10_000 and base_cfg.max_value == 10_000 # test incrementing some attribute levels curriculum.increment_attr_level("numbers") curriculum.increment_attr_level("decimals") increased_cfg = curriculum.generate_configuration(base_value) - assert increased_cfg.min_numbers == 10 and increased_cfg.max_numbers == 500 - assert increased_cfg.min_decimals == 0 and increased_cfg.max_decimals == 4 + assert increased_cfg.min_numbers == 5 and increased_cfg.max_numbers == 9 + assert increased_cfg.min_decimals == 0 and increased_cfg.max_decimals == 2 assert increased_cfg.min_value == -10_000 and increased_cfg.max_value == 10_000 # test decrementing attribute level for numbers again curriculum.decrement_attr_level("numbers") partially_decreased_cfg = curriculum.generate_configuration(base_value) - assert partially_decreased_cfg.min_numbers == 10 and partially_decreased_cfg.max_numbers == 100 - assert partially_decreased_cfg.min_decimals == 0 and partially_decreased_cfg.max_decimals == 4 + assert partially_decreased_cfg.min_numbers == 5 and partially_decreased_cfg.max_numbers == 7 + assert partially_decreased_cfg.min_decimals == 0 and partially_decreased_cfg.max_decimals == 2 assert partially_decreased_cfg.min_value == -10_000 and partially_decreased_cfg.max_value == 10_000 diff --git a/tests/test_spell_backward.py b/tests/test_spell_backward.py index 64d091a7..022b8228 100644 --- a/tests/test_spell_backward.py +++ b/tests/test_spell_backward.py @@ -71,14 +71,14 @@ def test_spell_backward_curriculum(): base_cfg: SpellBackwardConfig = curriculum.generate_configuration(base_value) assert base_cfg.seed == 1 assert base_cfg.size == 150 - assert base_cfg.min_word_len == 5 and base_cfg.max_word_len == 10 + assert base_cfg.min_word_len == 3 and base_cfg.max_word_len == 3 # test incrementing attribute levels curriculum.increment_attr_level("word_len") increased_cfg = curriculum.generate_configuration(base_value) - assert increased_cfg.min_word_len == 5 and increased_cfg.max_word_len == 20 + assert increased_cfg.min_word_len == 3 and increased_cfg.max_word_len == 4 # test decrementing attribute levels curriculum.decrement_attr_level("word_len") partially_decreased_cfg = curriculum.generate_configuration(base_value) - assert partially_decreased_cfg.min_word_len == 5 and partially_decreased_cfg.max_word_len == 10 + assert partially_decreased_cfg.min_word_len == 3 and partially_decreased_cfg.max_word_len == 3 diff --git a/training/configs/llama3.1_1b_grpo.yaml b/training/configs/qwen2.5_1.5b_grpo_curr.yaml similarity index 88% rename from training/configs/llama3.1_1b_grpo.yaml rename to training/configs/qwen2.5_1.5b_grpo_curr.yaml index 74200cad..31c70423 100644 --- a/training/configs/llama3.1_1b_grpo.yaml +++ b/training/configs/qwen2.5_1.5b_grpo_curr.yaml @@ -1,43 +1,20 @@ reasoning_gym: dataset_size: 10000 + enable_curriculum_learning: True developer_prompt: DeepSeekZero - 
enable_curriculum_learning: False - datasets: # Used if enable_curriculum_learning is False - mini_sudoku: - weight: 0.33 - config: - min_empty: 6 - futoshiki: - weight: 0.33 - config: - max_board_size: 5 - sudoku: - weight: 0.34 - config: - min_empty: 20 - curricula: - leg_counting: - attribute_levels: - num_animals: 2 - weight: 1.0 - products: - attribute_levels: - num_terms: 4 - num_digits: 4 - weight: 1.0 - chain_sum: - attribute_levels: - num_terms: 4 - num_digits: 4 - weight: 1.0 - reward: - format_reward: - enable: True - scaling_factor: 0.2 - length_reward: - enable: True - scaling_factor: 0.2 + secondary_rewards: + - name: format + scaling_factor: 0.5 +curriculum: + enabled: True + last_k: 30 + success_threshold: 0.7 + failure_threshold: 0.1 + curricula: + spell_backward: + attribute_levels: + word_len: 0 data: tokenizer: null @@ -54,7 +31,7 @@ data: actor_rollout_ref: hybrid_engine: True model: - path: meta-llama/Llama-3.2-1B-Instruct + path: Qwen/Qwen2.5-1.5B-Instruct external_lib: null override_config: { } enable_gradient_checkpointing: True @@ -101,6 +78,7 @@ actor_rollout_ref: ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size rollout: name: vllm + max_model_len: 512 temperature: 1.0 top_k: -1 # 0 for hf rollout, -1 for vllm rollout top_p: 1 @@ -138,13 +116,13 @@ algorithm: kl_ctrl: type: fixed kl_coef: 0.001 - +verbose: True trainer: balance_batch: True total_epochs: 10 total_training_steps: null project_name: rg-test - experiment_name: verl_grpo_llama3.1_1b + experiment_name: verl_grpo_qwen_curr logger: [ 'console', 'wandb' ] val_generations_to_log_to_wandb: 0 nnodes: 1 diff --git a/training/configs/qwen2.5_1.5b_grpo.yaml b/training/configs/qwen2.5_3b_grpo.yaml similarity index 78% rename from training/configs/qwen2.5_1.5b_grpo.yaml rename to training/configs/qwen2.5_3b_grpo.yaml index 3ad49d60..077d6d1f 100644 --- a/training/configs/qwen2.5_1.5b_grpo.yaml +++ b/training/configs/qwen2.5_3b_grpo.yaml @@ -1,43 +1,31 @@ reasoning_gym: dataset_size: 10000 developer_prompt: DeepSeekZero - enable_curriculum_learning: False - datasets: # Used if enable_curriculum_learning is False - mini_sudoku: - weight: 0.33 + datasets: + spell_backward: + weight: 1 config: - min_empty: 6 - futoshiki: - weight: 0.33 - config: - max_board_size: 5 - sudoku: - weight: 0.34 - config: - min_empty: 20 - curricula: - leg_counting: - attribute_levels: - num_animals: 2 - weight: 1.0 - products: - attribute_levels: - num_terms: 4 - num_digits: 4 - weight: 1.0 - chain_sum: - attribute_levels: - num_terms: 4 - num_digits: 4 - weight: 1.0 - + min_word_len: 3 + max_word_len: 10 +curriculum: + enabled: False + schedule: + automatic: True + update_steps: 30 # automatic curriculum updating after 50 steps + last_k: 20 + success_threshold: 0.7 + failure_threshold: 0.1 + curricula: + spell_backward: + attribute_levels: + word_len: 0 reward: - format_reward: - enable: True - scaling_factor: 0.2 - length_reward: - enable: True - scaling_factor: 0.2 + use_accuracy: false + secondary_rewards: + - name: cosine + scaling_factor: 2 + - name: format + scaling_factor: 0.5 data: tokenizer: null @@ -46,22 +34,22 @@ data: prompt_key: prompt max_prompt_length: 512 max_response_length: 1024 - train_batch_size: 16 - val_batch_size: 16 - return_raw_input_ids: True # This should be set to true when the tokenizer between policy and rm differs + train_batch_size: 128 + val_batch_size: 128 return_raw_chat: True + return_raw_input_ids: True actor_rollout_ref: hybrid_engine: True model: - 
path: Qwen/Qwen2.5-1.5B-Instruct + path: Qwen/Qwen2.5-3B-Instruct external_lib: null override_config: { } enable_gradient_checkpointing: True use_remove_padding: True actor: strategy: fsdp # This is for backward-compatibility - ppo_mini_batch_size: 16 + ppo_mini_batch_size: 32 ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu ppo_micro_batch_size_per_gpu: 8 use_dynamic_bsz: False @@ -77,9 +65,9 @@ actor_rollout_ref: ulysses_sequence_parallel_size: 1 # sp size optim: lr: 1e-6 - lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime - min_lr_ratio: null # only useful for warmup with cosine - warmup_style: constant # select from constant/cosine + lr_warmup_steps_ratio: 0.1 # the total steps will be injected during runtime + min_lr_ratio: 0.1 # only useful for warmup with cosine + warmup_style: cosine # select from constant/cosine total_training_steps: -1 # must be override by program fsdp_config: wrap_policy: @@ -95,13 +83,14 @@ actor_rollout_ref: # transformer_layer_cls_to_wrap: None min_num_params: 0 log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu - log_prob_micro_batch_size_per_gpu: 160 + log_prob_micro_batch_size_per_gpu: 16 log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size rollout: name: vllm - temperature: 1.0 + max_model_len: 1024 + temperature: 0.7 top_k: -1 # 0 for hf rollout, -1 for vllm rollout top_p: 1 prompt_length: ${data.max_prompt_length} # not use for opensource @@ -117,7 +106,7 @@ actor_rollout_ref: max_num_batched_tokens: 8192 max_num_seqs: 1024 log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu - log_prob_micro_batch_size_per_gpu: 160 + log_prob_micro_batch_size_per_gpu: 16 log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} disable_log_stats: True @@ -138,22 +127,22 @@ algorithm: kl_ctrl: type: fixed kl_coef: 0.001 - +verbose: True trainer: balance_batch: True - total_epochs: 10 + total_epochs: 5 total_training_steps: null project_name: rg-test - experiment_name: verl_grpo_llama3.1_1b + experiment_name: verl_grpo_qwen_curr logger: [ 'console', 'wandb' ] val_generations_to_log_to_wandb: 0 nnodes: 1 n_gpus_per_node: 2 - save_freq: 100 + save_freq: 50 # auto: find the last ckpt to resume. If can't find, start from scratch resume_mode: auto # or auto or resume_path if resume_from_path: False - test_freq: 100 + test_freq: 300 critic_warmup: 0 default_hdfs_dir: null remove_previous_ckpt_in_save: False @@ -163,10 +152,10 @@ trainer: critic: strategy: fsdp optim: - lr: 1e-5 - lr_warmup_steps_ratio: 0. 
# the total steps will be injected during runtime + lr: 1e-6 + lr_warmup_steps_ratio: 0.1 # the total steps will be injected during runtime min_lr_ratio: null # only useful for warmup with cosine - warmup_style: constant # select from constant/cosine + warmup_style: cosine # select from constant/cosine total_training_steps: -1 # must be override by program model: path: ~/models/deepseek-llm-7b-chat diff --git a/training/configs/qwen2.5_3b_grpo_composite.yaml b/training/configs/qwen2.5_3b_grpo_composite.yaml new file mode 100644 index 00000000..d29b0a5a --- /dev/null +++ b/training/configs/qwen2.5_3b_grpo_composite.yaml @@ -0,0 +1,221 @@ +reasoning_gym: + dataset_size: 20000 + developer_prompt: DeepSeekZero + datasets: + spell_backward: + weight: 0.33 + config: + min_word_len: 3 + max_word_len: 10 + letter_counting: + weight: 0.33 + config: + min_words: 5 + max_words: 20 + number_sorting: + weight: 0.33 + config: + min_numbers: 5 + max_numbers: 10 + min_decimals: 0 + max_decimals: 8 + min_value: -10000 + max_value: 10000 + +curriculum: + enabled: False + schedule: + automatic: True + update_steps: 30 # automatic curriculum updating after 50 steps + last_k: 20 + success_threshold: 0.7 + failure_threshold: 0.1 + curricula: + spell_backward: + attribute_levels: + word_len: 0 +reward: + use_accuracy: false + secondary_rewards: + - name: cosine + scaling_factor: 2 + - name: format + scaling_factor: 0.5 + +data: + tokenizer: null + train_files: train.parquet + val_files: test.parquet + prompt_key: prompt + max_prompt_length: 512 + max_response_length: 1024 + train_batch_size: 128 + val_batch_size: 128 + return_raw_chat: True + return_raw_input_ids: True + +actor_rollout_ref: + hybrid_engine: True + model: + path: Qwen/Qwen2.5-3B-Instruct + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True + actor: + strategy: fsdp # This is for backward-compatibility + ppo_mini_batch_size: 32 + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: 8 + use_dynamic_bsz: False + ppo_max_token_len_per_gpu: 12288 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + clip_ratio: 0.2 + entropy_coeff: 0.001 + use_kl_loss: True # True for GRPO + kl_loss_coef: 0.001 # for grpo + kl_loss_type: low_var_kl # for grpo + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 1e-6 + lr_warmup_steps_ratio: 0.1 # the total steps will be injected during runtime + min_lr_ratio: 0.1 # only useful for warmup with cosine + warmup_style: cosine # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: True + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: 16 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + rollout: + name: vllm + max_model_len: 1024 + temperature: 0.7 + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1 + prompt_length: ${data.max_prompt_length} # not 
use for opensource + response_length: ${data.max_response_length} + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.6 + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_dtensor + tensor_model_parallel_size: 4 + max_num_batched_tokens: 8192 + max_num_seqs: 1024 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: 16 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + disable_log_stats: True + enable_chunked_prefill: True # could get higher throughput + # for hf rollout + do_sample: True + use_fire_sampling: False + # number of responses (i.e. num sample times) + n: 8 # > 1 for grpo + val_kwargs: + do_sample: True + +algorithm: + gamma: 1.0 + lam: 1.0 + adv_estimator: grpo + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + type: fixed + kl_coef: 0.001 +verbose: True +trainer: + balance_batch: True + total_epochs: 5 + total_training_steps: null + project_name: rg-test + experiment_name: verl_grpo_qwen_composite + logger: [ 'console', 'wandb' ] + val_generations_to_log_to_wandb: 0 + nnodes: 1 + n_gpus_per_node: 4 + save_freq: 50 + # auto: find the last ckpt to resume. If can't find, start from scratch + resume_mode: auto # or auto or resume_path if + resume_from_path: False + test_freq: 300 + critic_warmup: 0 + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + +critic: + strategy: fsdp + optim: + lr: 1e-6 + lr_warmup_steps_ratio: 0.1 # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + warmup_style: cosine # select from constant/cosine + total_training_steps: -1 # must be override by program + model: + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${actor_rollout_ref.model.path} + override_config: { } + external_lib: ${actor_rollout_ref.model.external_lib} + enable_gradient_checkpointing: True + use_remove_padding: False + fsdp_config: + param_offload: False + optimizer_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + fsdp_size: -1 + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null + forward_micro_batch_size: ${critic.ppo_micro_batch_size} + forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} + use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 + forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: 1 # sp size + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + shuffle: ${actor_rollout_ref.actor.shuffle} + grad_clip: 1.0 + cliprange_value: 0.5 + +# Reward model not used for GRPO +reward_model: + enable: False + strategy: fsdp + model: + input_tokenizer: ${actor_rollout_ref.model.path} + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + use_remove_padding: False + fsdp_config: + min_num_params: 0 + param_offload: False + fsdp_size: -1 + micro_batch_size: null + micro_batch_size_per_gpu: null + max_length: null + 
ulysses_sequence_parallel_size: 1 + use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} diff --git a/training/configs/qwen2.5_3b_grpo_curr.yaml b/training/configs/qwen2.5_3b_grpo_curr.yaml new file mode 100644 index 00000000..405b9170 --- /dev/null +++ b/training/configs/qwen2.5_3b_grpo_curr.yaml @@ -0,0 +1,201 @@ +reasoning_gym: + dataset_size: 10000 + enable_curriculum_learning: True + developer_prompt: DeepSeekZero +curriculum: + enabled: True + schedule: + automatic: True + update_steps: 30 # automatic curriculum updating after 50 steps + last_k: 20 + success_threshold: 0.7 + failure_threshold: 0.1 + curricula: + spell_backward: + attribute_levels: + word_len: 0 +reward: + use_accuracy: false + secondary_rewards: + - name: cosine + scaling_factor: 2 + - name: format + scaling_factor: 0.5 + +data: + tokenizer: null + train_files: train.parquet + val_files: test.parquet + prompt_key: prompt + max_prompt_length: 512 + max_response_length: 1024 + train_batch_size: 128 + val_batch_size: 128 + return_raw_chat: True + return_raw_input_ids: True + +actor_rollout_ref: + hybrid_engine: True + model: + path: Qwen/Qwen2.5-3B-Instruct + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True + actor: + strategy: fsdp # This is for backward-compatibility + ppo_mini_batch_size: 32 + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: 8 + use_dynamic_bsz: False + ppo_max_token_len_per_gpu: 12288 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + clip_ratio: 0.2 + entropy_coeff: 0.001 + use_kl_loss: True # True for GRPO + kl_loss_coef: 0.001 # for grpo + kl_loss_type: low_var_kl # for grpo + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 1e-6 + lr_warmup_steps_ratio: 0.1 # the total steps will be injected during runtime + min_lr_ratio: 0.1 # only useful for warmup with cosine + warmup_style: cosine # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: True + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: 16 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + rollout: + name: vllm + max_model_len: 1024 + temperature: 0.7 + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1 + prompt_length: ${data.max_prompt_length} # not use for opensource + response_length: ${data.max_response_length} + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.6 + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_dtensor + tensor_model_parallel_size: 4 + max_num_batched_tokens: 8192 + max_num_seqs: 1024 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: 16 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + 
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + disable_log_stats: True + enable_chunked_prefill: True # could get higher throughput + # for hf rollout + do_sample: True + use_fire_sampling: False + # number of responses (i.e. num sample times) + n: 8 # > 1 for grpo + val_kwargs: + do_sample: True + +algorithm: + gamma: 1.0 + lam: 1.0 + adv_estimator: grpo + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + type: fixed + kl_coef: 0.001 +verbose: True +trainer: + balance_batch: True + total_epochs: 5 + total_training_steps: null + project_name: rg-test + experiment_name: verl_grpo_qwen_curr + logger: [ 'console', 'wandb' ] + val_generations_to_log_to_wandb: 0 + nnodes: 1 + n_gpus_per_node: 4 + save_freq: 50 + # auto: find the last ckpt to resume. If can't find, start from scratch + resume_mode: auto # or auto or resume_path if + resume_from_path: False + test_freq: 300 + critic_warmup: 0 + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + +critic: + strategy: fsdp + optim: + lr: 1e-6 + lr_warmup_steps_ratio: 0.1 # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + warmup_style: cosine # select from constant/cosine + total_training_steps: -1 # must be override by program + model: + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${actor_rollout_ref.model.path} + override_config: { } + external_lib: ${actor_rollout_ref.model.external_lib} + enable_gradient_checkpointing: True + use_remove_padding: False + fsdp_config: + param_offload: False + optimizer_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + fsdp_size: -1 + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null + forward_micro_batch_size: ${critic.ppo_micro_batch_size} + forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} + use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 + forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: 1 # sp size + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + shuffle: ${actor_rollout_ref.actor.shuffle} + grad_clip: 1.0 + cliprange_value: 0.5 + +# Reward model not used for GRPO +reward_model: + enable: False + strategy: fsdp + model: + input_tokenizer: ${actor_rollout_ref.model.path} + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + use_remove_padding: False + fsdp_config: + min_num_params: 0 + param_offload: False + fsdp_size: -1 + micro_batch_size: null + micro_batch_size_per_gpu: null + max_length: null + ulysses_sequence_parallel_size: 1 + use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} diff --git a/training/rewards/__init__.py b/training/rewards/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/training/rewards/reward.py b/training/rewards/reward.py new file mode 100644 index 00000000..1e115b10 --- /dev/null +++ b/training/rewards/reward.py @@ -0,0 +1,94 @@ +import math +import re +from typing import Any, Callable, Dict + + +class RewardRegistry: + """Simple registry for 
secondary reward functions."""
+
+    def __init__(self):
+        self.reward_functions = {}
+
+    def register(self, name: str):
+        """Decorator that registers a reward function under the given name."""
+
+        def decorator(func):
+            self.reward_functions[name] = func
+            return func
+
+        return decorator
+
+    def get(self, name: str):
+        """Get a reward function by name."""
+        return self.reward_functions.get(name)
+
+    def list_functions(self):
+        """List available reward function names."""
+        return list(self.reward_functions.keys())
+
+
+reward_registry = RewardRegistry()
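+
+# Illustrative note (an assumption, matching how ray_grpo_trainer.py in this
+# PR consumes the registry): each entry under `reward.secondary_rewards` in
+# the training YAML is resolved by name, e.g.
+#
+#   reward:
+#     secondary_rewards:
+#       - name: format
+#         scaling_factor: 0.5
+#
+# maps to reward_registry.get("format"), which the trainer then calls as
+# func(response_str, scaling_factor, **kwargs).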
+
+
+@reward_registry.register("cosine")
+def cosine_scaled_reward(solution_str, scaling_factor, **kwargs):
+    """Reward function that scales based on completion length using a cosine schedule."""
+    min_value_wrong = -1.0
+    max_value_wrong = -0.5
+    min_value_correct = 0.5
+    max_value_correct = 1.0
+    max_len = 1000
+
+    is_correct = kwargs.get("is_correct", False)
+    gen_len = len(solution_str)
+
+    # Apply cosine scaling based on length
+    progress = gen_len / max_len
+    cosine = math.cos(progress * math.pi)
+
+    if is_correct:
+        min_value = min_value_correct
+        max_value = max_value_correct
+    else:
+        # Swap the wrong-answer bounds so that a short wrong answer scores
+        # -1.0 while a max-length wrong answer scores -0.5.
+        min_value = max_value_wrong
+        max_value = min_value_wrong
+
+    cosine_scaled_reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
+    return cosine_scaled_reward * scaling_factor
+
+
+@reward_registry.register("format")
+def compute_format_reward(solution_str: str, scaling_factor: float = 0.2, **kwargs) -> float:
+    """Reward use of exactly one correctly structured <think> and <answer> block."""
+    pattern = r"<think>\s*.*?</think>\s*<answer>.*?</answer>"
+    if not re.match(pattern, solution_str, re.DOTALL):
+        return 0.0
+    think_matches = list(re.finditer(r"<think>(.*?)</think>", solution_str, re.DOTALL))
+    answer_matches = list(re.finditer(r"<answer>(.*?)</answer>", solution_str, re.DOTALL))
+    if len(think_matches) != 1 or len(answer_matches) != 1:
+        return 0.0
+    think_content = think_matches[0].group(1)
+    if "<think>" in think_content or "<answer>" in think_content:
+        return 0.0
+    answer_content = answer_matches[0].group(1)
+    if "<think>" in answer_content or "<answer>" in answer_content:
+        return 0.0
+    return 1.0 * scaling_factor
+
+
+@reward_registry.register("length")
+def length_reward(solution_str, scaling_factor, **kwargs):
+    """Reward length appropriately based on correctness."""
+    epsilon = 1e-6
+    # Pulled from kwargs so that the trainer can call every secondary reward
+    # uniformly as func(solution_str, scaling_factor, **kwargs).
+    correctness_score = kwargs.get("correctness_score", 0.0)
+    max_score = kwargs.get("max_score", 1.0)
+    max_output_length = kwargs.get("max_output_length", 1024)
+
+    generation_len = len(solution_str)
+    progress = min(generation_len / max_output_length, 1.0)
+
+    if correctness_score < max_score - epsilon:
+        # imperfect answers: incentivise longer attempts
+        length_reward = (max_score - correctness_score) * progress
+    else:
+        # perfect answers: penalise unnecessary length
+        length_reward = -progress
+
+    return length_reward * scaling_factor
diff --git a/training/train_grpo.py b/training/train_grpo.py
index 9eec6ff7..9da5d16d 100644
--- a/training/train_grpo.py
+++ b/training/train_grpo.py
@@ -21,15 +21,14 @@ def prepare_datasets(config, tokenizer) -> tuple[ReasoningGymDataset, ReasoningGymDataset]:
     developer_prompt_setting = config.reasoning_gym.developer_prompt
     developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS[developer_prompt_setting]

-    if config.reasoning_gym.enable_curriculum_learning:
-        curricula = config.reasoning_gym.curricula
+    if config.curriculum.enabled:
+        curricula = config.curriculum.curricula
         curriculum_config = CurriculumExperimentConfig(
             curricula={
                 curriculum_name: CurriculumAttributeConfig(**curriculum_config)
                 for curriculum_name, curriculum_config in curricula.items()
             }
         )
-        curriculum_config.validate()
         train_data_source = CurriculumExperiment(
             name=config.trainer.experiment_name, config=curriculum_config, size=dataset_size, seed=1
         )
@@ -42,9 +41,8 @@ def prepare_datasets(config, tokenizer) -> tuple[ReasoningGymDataset, ReasoningGymDataset]:
         ]
         train_data_source = reasoning_gym.create_dataset("composite", seed=1, size=dataset_size, datasets=dataset_specs)
         val_data_source = reasoning_gym.create_dataset("composite", seed=2, size=dataset_size, datasets=dataset_specs)
-
-    train_dataset = make_dataset(tokenizer, train_data_source, developer_prompt)
-    val_dataset = make_dataset(tokenizer, val_data_source, developer_prompt)
+    train_dataset = make_dataset(tokenizer, train_data_source, "composite", developer_prompt)
+    val_dataset = make_dataset(tokenizer, val_data_source, "composite", developer_prompt)

     return train_dataset, val_dataset
diff --git a/training/trainers/ray_grpo_trainer.py b/training/trainers/ray_grpo_trainer.py
index a1d17824..6dbe8b0b 100644
--- a/training/trainers/ray_grpo_trainer.py
+++ b/training/trainers/ray_grpo_trainer.py
@@ -1,14 +1,23 @@
 # Adapted version of Bytedance code:
 # https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/verl/trainer/main_ppo.py
-import re
+import uuid
+import numpy as np
 import torch
 from omegaconf import OmegaConf, open_dict
+from reward import reward_registry
 from torchdata.stateful_dataloader import StatefulDataLoader
 from utils import ReasoningGymDataset
 from verl import DataProto
-from verl.trainer.ppo.ray_trainer import RayPPOTrainer
+from verl.trainer.ppo.ray_trainer import (
+    RayPPOTrainer,
+    _timer,
+    apply_kl_penalty,
+    compute_advantage,
+    compute_data_metrics,
+    compute_timing_metrics,
+)
 from verl.utils.dataset.rl_dataset import collate_fn

 from reasoning_gym.utils import extract_answer
@@ -30,8 +39,26 @@
         self.val_dataset = val_dataset
         self.max_output_length = max_output_length
-        self.format_reward_scaling_factor = config.reward.format_reward.scaling_factor
-        self.length_reward_scaling_factor = config.reward.length_reward.scaling_factor
+        if config.curriculum.enabled:
+            self.last_k = config.curriculum.last_k
+        else:
+            self.last_k = None
+
+        self.reward_functions = []
+        if hasattr(config, "reward") and hasattr(config.reward, "secondary_rewards"):
+            for func_config in config.reward.secondary_rewards:
+                func_name = func_config.name
+                scaling_factor = func_config.get("scaling_factor", 1.0)
+                func = reward_registry.get(func_name)
+                if func:
+                    # Store both the function and its call arguments
+                    self.reward_functions.append(
+                        {
+                            "function": func,
+                            "name": func_name,
+                            "scaling_factor": scaling_factor,
+                            "kwargs": func_config.get("kwargs", {}),
+                        }
+                    )

         train_reward_fn = lambda data: self._score_output(data, num_examine=0)
         val_reward_fn = lambda data: self._score_output(data, num_examine=1)
@@ -69,81 +96,46 @@
             sequences_str = prompt_str + response_str

             index = data_item.non_tensor_batch["index"]
-
-            reward = score = self._compute_correctness_score(
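+            # Compose the final reward: the correctness score (only when
+            # config.reward.use_accuracy is set) plus each configured
+            # secondary reward, each pre-scaled by its own scaling_factor.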
+            correctness_score = self._compute_correctness_score(
                 solution_str=response_str,
                 index=index,
             )
-
-            if self.config.reward.format_reward.enable:
-                format_reward = self._compute_format_reward(response_str)
-                reward += format_reward
+            if self.config.reward.use_accuracy:
+                reward_components = {"correctness": correctness_score}
+                total_reward = correctness_score
             else:
-                format_reward = 0.0
+                reward_components = {}
+                total_reward = 0.0

-            if self.config.reward.length_reward.enable:
-                length_reward = self._compute_length_reward(response_str, score)
-                reward += length_reward
-            else:
-                length_reward = 0.0
+            for reward_fn in self.reward_functions:
+                func = reward_fn["function"]
+                name = reward_fn["name"]
+                scaling_factor = reward_fn["scaling_factor"]
+                kwargs = reward_fn["kwargs"]
+                if name == "cosine":
+                    is_correct = correctness_score == 1.0
+                    reward = func(response_str, scaling_factor, is_correct=is_correct, **kwargs)
+                else:
+                    reward = func(response_str, scaling_factor, **kwargs)
+                reward_components[name] = reward
+                total_reward += reward

-            reward_tensor[i, valid_response_length - 1] = reward
+            reward_tensor[i, valid_response_length - 1] = total_reward

             if num_printed < num_examine:
-                print(
-                    f"reward={reward} (score={score}, format={format_reward}, length={length_reward}), seq={sequences_str}"
-                )
+                components = ", ".join([f"{k}={v:.2f}" for k, v in reward_components.items()])
+                print(f"reward={total_reward:.2f} ({components}), seq={sequences_str}")
                 num_printed += 1

         return reward_tensor

-    def _compute_format_reward(self, solution_str: str) -> float:
-        """Reward use of exactly one correctly structured <think> and <answer> block."""
-        scaling_factor = self.format_reward_scaling_factor
-        # check <think> and <answer> blocks are present
-        pattern = r"<think>\s*.*?</think>\s*<answer>.*?</answer>"
-        if not re.match(pattern, solution_str, re.DOTALL):
-            return 0.0
-        # check exactly one properly structured <think> block and one <answer> block
-        think_matches = list(re.finditer(r"<think>(.*?)</think>", solution_str, re.DOTALL))
-        answer_matches = list(re.finditer(r"<answer>(.*?)</answer>", solution_str, re.DOTALL))
-        if len(think_matches) != 1 or len(answer_matches) != 1:
-            return 0.0
-        # check for <think> or <answer> inside <think>
-        think_content = think_matches[0].group(1)
-        if "<think>" in think_content or "<answer>" in think_content:
-            return 0.0
-        # check for nested <think> or <answer> inside <answer>
-        answer_content = answer_matches[0].group(1)
-        if "<think>" in answer_content or "<answer>" in answer_content:
-            return 0.0
-        return 1.0 * scaling_factor
-
-    def _compute_length_reward(
-        self,
-        solution_str: str,
-        correctness_score: float,
-        max_score: float = 1.0,
-    ) -> float:
-        """
-        Reward shorter solutions for perfect answers, longer solutions for imperfect answers.
-        The scaling factor for this should be set far below 1.0, to avoid dominating the reward signal over correctness.
- """ - epsilon = 1e-6 - scaling_factor = self.length_reward_scaling_factor - generation_len = len(solution_str) - progress = min(generation_len / self.max_output_length, 1.0) - if correctness_score < max_score - epsilon: - # for imperfect answers, incentivise longer ones - length_reward = (max_score - correctness_score) * progress - else: - # for perfect answers, penalise longer ones - length_reward = -progress - return length_reward * scaling_factor - def _compute_correctness_score(self, solution_str: str, index: int) -> float: found_answer = extract_answer(solution_str, tag_name="answer") data = self.train_dataset.data + entry = data[index] if self.train_dataset.experiment: experiment = self.train_dataset.experiment @@ -187,3 +200,323 @@ class RayGRPOTrainer(RayPPOTrainer): with open_dict(self.config): self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps self.config.critic.optim.total_training_steps = total_training_steps + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + print(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + timing_raw = {} + + batch: DataProto = DataProto.from_single_dict(batch_dict) + + gen_batch = batch.pop( + batch_keys=["input_ids", "attention_mask", "position_ids"], + non_tensor_batch_keys=["raw_prompt_ids"], + ) + + is_last_step = self.global_steps >= self.total_training_steps + + with _timer("step", timing_raw): + # generate a batch + with _timer("gen", timing_raw): + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + # balance the number of valid tokens on each dp rank. + # Note that this breaks the order of data inside the batch. 
+ # Please take care when you implement group based adv computation such as GRPO and rloo + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + # recompute old_log_probs + with _timer("old_log_prob", timing_raw): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + batch = batch.union(old_log_prob) + + if self.use_reference_policy: + # compute reference log_prob + with _timer("ref", timing_raw): + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + with _timer("adv", timing_raw): + # compute scores. Support both model and function-based. + # We first compute the scores using reward model. Then, we call reward_fn to combine + # the results from reward model and rule-based results. + if self.use_rm: + # we first compute reward model score + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + # we combine with rule-based rm + reward_tensor = self.reward_fn(batch) + batch.batch["token_level_scores"] = reward_tensor + + # compute rewards. apply_kl_penalty if available + if not self.config.actor_rollout_ref.actor.get("use_kl_loss", False): + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update(kl_metrics) + else: + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + + # compute advantages, executed on the driver process + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + ) + + # validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with _timer("testing", timing_raw): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + ): + with _timer("save_checkpoint", timing_raw): + self._save_checkpoint() + + # collect metrics + if self.config.curriculum.enabled: + grouped_scores = self.train_dataset.aggregate(last_n=self.config.curriculum.last_k) + if self.config.curriculum.schedule.automatic: + for dataset_name in grouped_scores.keys(): + if self.global_steps % self.config.curriculum.schedule.update_steps == 0: + self.train_dataset.experiment.update_difficulty(dataset_name, method="increment") + else: + for dataset_name in grouped_scores.keys(): + if ( + grouped_scores[dataset_name]["results"] > self.config.curriculum.success_threshold + ) and (grouped_scores[dataset_name]["total_samples"] > self.config.curriculum.last_k): + self.train_dataset.experiment.update_difficulty(dataset_name, method="increment") + elif ( + grouped_scores[dataset_name]["results"] < self.config.curriculum.failure_threshold + ) and (grouped_scores[dataset_name]["total_samples"] > self.config.curriculum.last_k): + self.train_dataset.update_difficulty(dataset_name, method="decrement") + + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + + # 
+                # TODO: make a canonical logger that supports various backend
+                logger.log(data=metrics, step=self.global_steps)
+
+                if is_last_step:
+                    print(f"Final validation metrics: {last_val_metrics}")
+                    return
+
+                self.global_steps += 1
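Aside: the threshold branch of the curriculum schedule above reduces to a small per-dataset decision rule. The standalone sketch below restates it for clarity; `mean_score`, `n_samples`, and the default arguments are hypothetical stand-ins for `grouped_scores[name]["results"]`, `grouped_scores[name]["total_samples"]`, and the `config.curriculum.*` values, and nothing in it is part of the patch itself.

    from typing import Optional

    def curriculum_action(
        mean_score: float,         # stand-in for grouped_scores[name]["results"]
        n_samples: int,            # stand-in for grouped_scores[name]["total_samples"]
        success_threshold: float,  # stand-in for config.curriculum.success_threshold
        failure_threshold: float,  # stand-in for config.curriculum.failure_threshold
        last_k: int,               # stand-in for config.curriculum.last_k
    ) -> Optional[str]:
        """Decide how to adjust one dataset's difficulty, mirroring the manual schedule."""
        if n_samples <= last_k:
            return None  # too few recent samples to judge either way
        if mean_score > success_threshold:
            return "increment"  # consistently solved: make the tasks harder
        if mean_score < failure_threshold:
            return "decrement"  # consistently failed: make the tasks easier
        return None  # inside the band: hold the current level

Under this rule a dataset averaging 0.9 over more than `last_k` recent attempts is stepped up one level, one averaging 0.1 is stepped down, and anything in between is left alone until more evidence accumulates.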
diff --git a/training/utils/datasets.py b/training/utils/datasets.py
index 80da6b96..973ae2f0 100644
--- a/training/utils/datasets.py
+++ b/training/utils/datasets.py
@@ -1,5 +1,6 @@
-from typing import Optional
+from typing import Literal, Optional
 
+import numpy as np
 import verl.utils.torch_functional as verl_F
 from torch.utils.data import Dataset
 from transformers import PreTrainedTokenizer
@@ -25,6 +26,8 @@ class ReasoningGymDataset(Dataset):
             procedural_dataset is None or experiment is None
         ), "Only one of `procedural_dataset` or `experiment` may be provided"
 
+        assert procedural_dataset or experiment, "One of `procedural_dataset` or `experiment` must be provided"
+
         self.tokenizer = tokenizer
         self.data = procedural_dataset or experiment.composite
         self.experiment = experiment
@@ -67,10 +73,39 @@ class ReasoningGymDataset(Dataset):
         row_dict["index"] = index
         return row_dict
 
+    def update_experiment_difficulty(self, dataset_name: str, method: Literal["increment", "decrement"]):
+        """Update the difficulty of the underlying dataset."""
+        if self.experiment is None:
+            raise ValueError("Cannot update difficulty: dataset is not a CurriculumExperiment")
+        if method not in ["increment", "decrement"]:
+            raise ValueError("Invalid method: must be 'increment' or 'decrement'")
+        self.experiment.score_board.clear()
+        self.experiment.update_difficulty(dataset_name, method)
+        self.data = self.experiment.composite
+        return True
+
+    def aggregate(self, last_n: Optional[int] = None):
+        """Aggregate scores from the underlying experiment."""
+        if self.experiment is None:
+            raise ValueError("Cannot aggregate scores: dataset is not a CurriculumExperiment")
+
+        results = self.experiment.score_board.aggregate(last_n=last_n)
+        output_results = {}
+
+        for key, value in results.items():
+            output_results[key] = {}
+            # average over every difficulty group for this dataset, not just the first one
+            all_scores = [score for group in value.scores.values() for score in group]
+            output_results[key]["results"] = np.mean(all_scores)
+            output_results[key]["total_samples"] = value.total_scores
+
+        return output_results
+
 
 def make_dataset(
     tokenizer,
     data_source: Experiment | ProceduralDataset,
+    dataset_name: str,
     developer_prompt: str,
 ) -> ReasoningGymDataset:
     """
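For orientation, a rough usage sketch of the new dataset helpers follows. `build_curriculum_experiment`, the model id, and the developer prompt are hypothetical stand-ins (the real wiring lives in the training entrypoint, which this patch does not show); only `make_dataset` and the `ReasoningGymDataset` methods come from the patch.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")  # stand-in model id
    experiment = build_curriculum_experiment()                             # stand-in constructor
    dataset = make_dataset(tokenizer, experiment, "spell_backward", developer_prompt="You are a careful solver.")

    # After some rollouts have been scored into the experiment's score board:
    stats = dataset.aggregate(last_n=32)
    # e.g. {"spell_backward": {"results": 0.85, "total_samples": 40}}

    for name, entry in stats.items():
        if entry["results"] > 0.8 and entry["total_samples"] > 32:
            # refreshes dataset.data and clears the score board internally
            dataset.update_experiment_difficulty(name, method="increment")

Routing difficulty updates through `update_experiment_difficulty` rather than the raw `Experiment` keeps the dataloader's view (`dataset.data`) and the score history consistent with the new difficulty level.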