mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
[aiw] remove output_formats style and change return type to a standard format
This commit is contained in:
parent
57a1b5c353
commit
3d42e84807
2 changed files with 50 additions and 65 deletions
|
|
@ -29,7 +29,6 @@ class AliceInWonderlandConfig:
|
|||
male_names (List[str]): List of male names to use in questions.
|
||||
female_names (List[str]): List of female names to use in questions. Must include 'Alice'.
|
||||
task_types (List[TaskType]): List of task types to include in dataset.
|
||||
output_formats (List[OutputFormat]): List of output formats to include in dataset.
|
||||
seed (Optional[int]): Seed for random number generation.
|
||||
size (int): Number of samples in the dataset.
|
||||
max_entities (int): Max number of siblings/friends/colleagues in questions.
|
||||
|
|
@ -47,14 +46,8 @@ class AliceInWonderlandConfig:
|
|||
]
|
||||
)
|
||||
task_types: List[TaskType] = field(
|
||||
default_factory=lambda: [TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues
|
||||
)
|
||||
output_formats: List[OutputFormat] = field(
|
||||
default_factory=lambda: [
|
||||
OutputFormat.PLAIN,
|
||||
OutputFormat.RESTRICTED,
|
||||
OutputFormat.THINKING,
|
||||
]
|
||||
TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues
|
||||
)
|
||||
seed: Optional[int] = None
|
||||
size: int = 10
|
||||
|
|
@ -66,8 +59,6 @@ class AliceInWonderlandConfig:
|
|||
assert len(self.female_names) > 0, "must provide female names"
|
||||
assert "Alice" in self.female_names, "'Alice' must be in female names"
|
||||
assert len(self.task_types) > 0, "must provide at least one task type"
|
||||
assert len(
|
||||
self.output_formats) > 0, "must provide at least one output format"
|
||||
assert self.max_entities > 0, "max_entities must be positive"
|
||||
|
||||
|
||||
|
|
@ -88,6 +79,7 @@ class AliceInWonderlandDataset(ProceduralDataset):
|
|||
}
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config: AliceInWonderlandConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
self.templates = {
|
||||
|
|
@ -118,12 +110,12 @@ class AliceInWonderlandDataset(ProceduralDataset):
|
|||
"$num_female_colleagues_alice_circle female colleagues. These are all colleagues that $female_name has. "
|
||||
"All these mentioned persons around $female_name are colleagues of each other. "
|
||||
"$male_name has $num_male_colleagues_bob_circle male colleagues "
|
||||
"and $num_female_colleagues_bob_circle female colleagues in total. "
|
||||
"and $num_female_colleagues_bob_circle female colleagues in total. "
|
||||
"All these mentioned persons around $male_name are colleagues of each other. "
|
||||
"The people in the circle around $male_name do not have "
|
||||
"other colleagues aside - with the only exception of Matilda. "
|
||||
"The people in the circle around $male_name do not have "
|
||||
"other colleagues aside - with the only exception of Matilda. "
|
||||
"She is colleague of $male_name and she is also colleague of $female_name, "
|
||||
"being part of $female_name's circle. How many female colleagues does Matilda have?"
|
||||
"being part of $female_name's circle. How many female colleagues does Matilda have?"
|
||||
),
|
||||
],
|
||||
}
|
||||
|
|
@ -153,7 +145,6 @@ class AliceInWonderlandDataset(ProceduralDataset):
|
|||
and a description of the example.
|
||||
"""
|
||||
task_type = rng.choice(self.config.task_types)
|
||||
output_format = rng.choice(self.config.output_formats)
|
||||
female_name = rng.choice(self.config.female_names)
|
||||
male_name = rng.choice(self.config.male_names)
|
||||
|
||||
|
|
@ -182,10 +173,14 @@ class AliceInWonderlandDataset(ProceduralDataset):
|
|||
num_female=num_female,
|
||||
)
|
||||
elif task_type == TaskType.COLLEAGUES:
|
||||
num_male_colleagues_alice_circle = rng.randint(1, self.config.max_entities)
|
||||
num_female_colleagues_alice_circle = rng.randint(1, self.config.max_entities)
|
||||
num_male_colleagues_bob_circle = rng.randint(1, self.config.max_entities)
|
||||
num_female_colleagues_bob_circle = rng.randint(1, self.config.max_entities)
|
||||
num_male_colleagues_alice_circle = rng.randint(
|
||||
1, self.config.max_entities)
|
||||
num_female_colleagues_alice_circle = rng.randint(
|
||||
1, self.config.max_entities)
|
||||
num_male_colleagues_bob_circle = rng.randint(
|
||||
1, self.config.max_entities)
|
||||
num_female_colleagues_bob_circle = rng.randint(
|
||||
1, self.config.max_entities)
|
||||
|
||||
answer = num_female_colleagues_alice_circle + 1
|
||||
template = rng.choice(self.templates[TaskType.COLLEAGUES])
|
||||
|
|
@ -198,14 +193,12 @@ class AliceInWonderlandDataset(ProceduralDataset):
|
|||
num_female_colleagues_bob_circle=num_female_colleagues_bob_circle
|
||||
)
|
||||
|
||||
formatted_question = self.format_templates[output_format].substitute(
|
||||
question=question
|
||||
)
|
||||
|
||||
return {
|
||||
"prompt": formatted_question,
|
||||
"right_answer": str(answer),
|
||||
"description": f"{task_type.value} variation, {output_format.value} format",
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"metadata": {
|
||||
"task_type": task_type.value
|
||||
}
|
||||
}
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
|
|
@ -213,4 +206,4 @@ class AliceInWonderlandDataset(ProceduralDataset):
|
|||
return self._get_aiw(rng)
|
||||
|
||||
|
||||
register_dataset("aiw", AliceInWonderlandDataset, AliceInWonderlandConfig)
|
||||
register_dataset("aiw", AliceInWonderlandDataset, AliceInWonderlandConfig)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import pytest
|
|||
|
||||
from reasoning_gym.logic.aiw import AliceInWonderlandConfig, AliceInWonderlandDataset, TaskType, OutputFormat
|
||||
|
||||
|
||||
def test_aiw_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
with pytest.raises(AssertionError):
|
||||
|
|
@ -11,18 +12,16 @@ def test_aiw_config_validation():
|
|||
with pytest.raises(AssertionError):
|
||||
config = AliceInWonderlandConfig(female_names=[]) # Empty female names
|
||||
config.validate()
|
||||
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = AliceInWonderlandConfig(female_names=["Mary", "Jane"]) # No Alice
|
||||
config = AliceInWonderlandConfig(
|
||||
female_names=["Mary", "Jane"]) # No Alice
|
||||
config.validate()
|
||||
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = AliceInWonderlandConfig(task_types=[]) # No task types
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = AliceInWonderlandConfig(output_formats=[]) # No output formats
|
||||
config.validate()
|
||||
|
||||
|
||||
def test_aiw_deterministic():
|
||||
"""Test that dataset generates same items with same seed"""
|
||||
|
|
@ -33,6 +32,7 @@ def test_aiw_deterministic():
|
|||
for i in range(len(dataset1)):
|
||||
assert dataset1[i] == dataset2[i]
|
||||
|
||||
|
||||
def test_aiw_items():
|
||||
"""Test basic properties of generated items"""
|
||||
config = AliceInWonderlandConfig(size=50, seed=42)
|
||||
|
|
@ -41,29 +41,28 @@ def test_aiw_items():
|
|||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
assert isinstance(item, dict)
|
||||
assert "prompt" in item
|
||||
assert "right_answer" in item
|
||||
assert "description" in item
|
||||
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Verify answer is numeric and positive
|
||||
answer = int(item["right_answer"])
|
||||
answer = int(item["answer"])
|
||||
assert answer > 0
|
||||
|
||||
|
||||
# Verify question contains at least one female name
|
||||
female_names = config.female_names
|
||||
assert any(name in item["prompt"] for name in female_names)
|
||||
assert any(name in item["question"] for name in female_names)
|
||||
|
||||
# Verify question task type characteristics
|
||||
task_type = item["metadata"]["task_type"]
|
||||
if task_type == TaskType.SIBLINGS.value:
|
||||
assert any(phrase in item["question"]
|
||||
for phrase in ["brothers", "sisters"])
|
||||
elif task_type == TaskType.FRIENDS.value:
|
||||
assert "friends" in item["question"]
|
||||
elif task_type == TaskType.COLLEAGUES:
|
||||
assert "colleagues" in item["question"]
|
||||
|
||||
# Verify question format
|
||||
if TaskType.SIBLINGS.value in item["description"]:
|
||||
assert any(phrase in item["prompt"] for phrase in ["brothers", "sisters"])
|
||||
elif TaskType.FRIENDS.value in item["description"]:
|
||||
assert "friends" in item["prompt"]
|
||||
|
||||
# Verify output format
|
||||
if OutputFormat.RESTRICTED.value in item["description"]:
|
||||
assert "DO NOT OUTPUT ANY TEXT EXCEPT" in item["prompt"]
|
||||
elif OutputFormat.THINKING.value in item["description"]:
|
||||
assert "think carefully step by step" in item["prompt"]
|
||||
|
||||
def test_aiw_iteration():
|
||||
"""Test that iteration works correctly"""
|
||||
|
|
@ -85,23 +84,16 @@ def test_aiw_iteration():
|
|||
second_items = list(dataset)
|
||||
assert first_items == second_items
|
||||
|
||||
|
||||
def test_aiw_random_ranges():
|
||||
"""Test that generated numbers stay within expected ranges"""
|
||||
config = AliceInWonderlandConfig(size=30, seed=42, max_entities=12)
|
||||
dataset = AliceInWonderlandDataset(config)
|
||||
|
||||
for item in dataset:
|
||||
prompt = item["prompt"]
|
||||
numbers = [int(n) for n in prompt.split() if n.isdigit()]
|
||||
|
||||
question = item["question"]
|
||||
numbers = [int(n) for n in question.split() if n.isdigit()]
|
||||
|
||||
# Check all numbers are in reasonable range (1-6 as per implementation)
|
||||
assert all(1 <= n <= 12 for n in numbers), f"Numbers out of range: {numbers}"
|
||||
|
||||
def test_output_format_is_correct():
|
||||
"""Test that the output format adheres to the user input"""
|
||||
config = AliceInWonderlandConfig(size=30, seed=42, output_formats=[OutputFormat.THINKING])
|
||||
dataset = AliceInWonderlandDataset(config)
|
||||
|
||||
for item in dataset:
|
||||
prompt = item["prompt"]
|
||||
assert "think carefully step by step" in item["prompt"]
|
||||
assert all(
|
||||
1 <= n <= 12 for n in numbers), f"Numbers out of range: {numbers}"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue