[aiw] remove output_formats style and change return type to a standard format

This commit is contained in:
rishabhranawat 2025-02-01 16:30:05 -08:00
parent 57a1b5c353
commit 3d42e84807
2 changed files with 50 additions and 65 deletions

View file

@ -29,7 +29,6 @@ class AliceInWonderlandConfig:
male_names (List[str]): List of male names to use in questions.
female_names (List[str]): List of female names to use in questions. Must include 'Alice'.
task_types (List[TaskType]): List of task types to include in dataset.
output_formats (List[OutputFormat]): List of output formats to include in dataset.
seed (Optional[int]): Seed for random number generation.
size (int): Number of samples in the dataset.
max_entities (int): Max number of siblings/friends/colleagues in questions.
@ -47,14 +46,8 @@ class AliceInWonderlandConfig:
]
)
task_types: List[TaskType] = field(
default_factory=lambda: [TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues
)
output_formats: List[OutputFormat] = field(
default_factory=lambda: [
OutputFormat.PLAIN,
OutputFormat.RESTRICTED,
OutputFormat.THINKING,
]
TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues
)
seed: Optional[int] = None
size: int = 10
@ -66,8 +59,6 @@ class AliceInWonderlandConfig:
assert len(self.female_names) > 0, "must provide female names"
assert "Alice" in self.female_names, "'Alice' must be in female names"
assert len(self.task_types) > 0, "must provide at least one task type"
assert len(
self.output_formats) > 0, "must provide at least one output format"
assert self.max_entities > 0, "max_entities must be positive"
@ -88,6 +79,7 @@ class AliceInWonderlandDataset(ProceduralDataset):
}
"""
def __init__(self, config: AliceInWonderlandConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
self.templates = {
@ -118,12 +110,12 @@ class AliceInWonderlandDataset(ProceduralDataset):
"$num_female_colleagues_alice_circle female colleagues. These are all colleagues that $female_name has. "
"All these mentioned persons around $female_name are colleagues of each other. "
"$male_name has $num_male_colleagues_bob_circle male colleagues "
"and $num_female_colleagues_bob_circle female colleagues in total. "
"and $num_female_colleagues_bob_circle female colleagues in total. "
"All these mentioned persons around $male_name are colleagues of each other. "
"The people in the circle around $male_name do not have "
"other colleagues aside - with the only exception of Matilda. "
"The people in the circle around $male_name do not have "
"other colleagues aside - with the only exception of Matilda. "
"She is colleague of $male_name and she is also colleague of $female_name, "
"being part of $female_name's circle. How many female colleagues does Matilda have?"
"being part of $female_name's circle. How many female colleagues does Matilda have?"
),
],
}
@ -153,7 +145,6 @@ class AliceInWonderlandDataset(ProceduralDataset):
and a description of the example.
"""
task_type = rng.choice(self.config.task_types)
output_format = rng.choice(self.config.output_formats)
female_name = rng.choice(self.config.female_names)
male_name = rng.choice(self.config.male_names)
@ -182,10 +173,14 @@ class AliceInWonderlandDataset(ProceduralDataset):
num_female=num_female,
)
elif task_type == TaskType.COLLEAGUES:
num_male_colleagues_alice_circle = rng.randint(1, self.config.max_entities)
num_female_colleagues_alice_circle = rng.randint(1, self.config.max_entities)
num_male_colleagues_bob_circle = rng.randint(1, self.config.max_entities)
num_female_colleagues_bob_circle = rng.randint(1, self.config.max_entities)
num_male_colleagues_alice_circle = rng.randint(
1, self.config.max_entities)
num_female_colleagues_alice_circle = rng.randint(
1, self.config.max_entities)
num_male_colleagues_bob_circle = rng.randint(
1, self.config.max_entities)
num_female_colleagues_bob_circle = rng.randint(
1, self.config.max_entities)
answer = num_female_colleagues_alice_circle + 1
template = rng.choice(self.templates[TaskType.COLLEAGUES])
@ -198,14 +193,12 @@ class AliceInWonderlandDataset(ProceduralDataset):
num_female_colleagues_bob_circle=num_female_colleagues_bob_circle
)
formatted_question = self.format_templates[output_format].substitute(
question=question
)
return {
"prompt": formatted_question,
"right_answer": str(answer),
"description": f"{task_type.value} variation, {output_format.value} format",
"question": question,
"answer": answer,
"metadata": {
"task_type": task_type.value
}
}
def __getitem__(self, idx: int) -> dict:
@ -213,4 +206,4 @@ class AliceInWonderlandDataset(ProceduralDataset):
return self._get_aiw(rng)
register_dataset("aiw", AliceInWonderlandDataset, AliceInWonderlandConfig)
register_dataset("aiw", AliceInWonderlandDataset, AliceInWonderlandConfig)