formatting

This commit is contained in:
Andreas Koepf 2025-01-24 10:34:07 +01:00
parent 98988c8481
commit 20069b2a7d
37 changed files with 504 additions and 666 deletions

View file

@ -1,11 +1,8 @@
"""Tests for leg counting task generation"""
import pytest
from reasoning_gym.arithmetic.leg_counting import (
LegCountingConfig,
LegCountingDataset,
ANIMALS,
)
from reasoning_gym.arithmetic.leg_counting import ANIMALS, LegCountingConfig, LegCountingDataset
def test_leg_counting_config_validation():
@ -35,13 +32,7 @@ def test_leg_counting_dataset_deterministic():
def test_leg_counting_dataset_items():
"""Test basic properties of generated items"""
config = LegCountingConfig(
min_animals=2,
max_animals=4,
max_instances=2,
size=10,
seed=42
)
config = LegCountingConfig(min_animals=2, max_animals=4, max_instances=2, size=10, seed=42)
dataset = LegCountingDataset(config)
for i in range(len(dataset)):
@ -51,19 +42,19 @@ def test_leg_counting_dataset_items():
assert "question" in item
assert "answer" in item
assert "metadata" in item
# Check metadata
assert "animals" in item["metadata"]
assert "total_legs" in item["metadata"]
# Verify animal count constraints
animals = item["metadata"]["animals"]
assert len(animals) >= config.min_animals
assert len(animals) <= config.max_animals
# Verify instance count constraints
assert all(1 <= count <= config.max_instances for count in animals.values())
# Verify leg counting is correct
total_legs = sum(count * ANIMALS[animal] for animal, count in animals.items())
assert str(total_legs) == item["answer"]
@ -86,7 +77,7 @@ def test_leg_counting_animal_validation():
"""Test that all animals have valid leg counts"""
# Verify all animals have non-negative leg counts
assert all(legs >= 0 for legs in ANIMALS.values())
# Verify common animals have expected leg counts
assert ANIMALS["spider"] == 8
assert ANIMALS["insect"] == 6