tinker/tests/test_chunked_fwdbwd_helpers.py
2026-01-16 05:18:34 +00:00

118 lines
4.5 KiB
Python

import pytest
from tinker.lib.chunked_fwdbwd_helpers import (
_metrics_reduction,
)
from tinker.types import ForwardBackwardOutput, TensorData
class TestMetricsReductionWithUnique:
    """Tests for ``_metrics_reduction`` focused on the ``:unique`` reduction.

    Metric keys have the form ``"<name>:<reduction>"``. These tests pin the
    expected output for ``unique`` on its own, next to ``mean``/``max``
    metrics, and for the empty, missing-key and invalid-reduction cases.
    """

    def create_forward_backward_output(
        self, metrics: dict, num_loss_fn_outputs: int = 1
    ) -> ForwardBackwardOutput:
        """Build a minimal ``ForwardBackwardOutput`` carrying ``metrics``.

        Args:
            metrics: Metric mapping to attach to the output.
            num_loss_fn_outputs: Number of placeholder loss-fn outputs to
                generate.

        Returns:
            A ``ForwardBackwardOutput`` whose ``loss_fn_outputs`` contains
            ``num_loss_fn_outputs`` identical placeholder entries.
        """
        # LossFnOutput is Dict[str, TensorData], so create sample data.
        tensor_data = TensorData(data=[0.0], dtype="float32", shape=[1])
        loss_fn_outputs = [{"loss": tensor_data} for _ in range(num_loss_fn_outputs)]
        return ForwardBackwardOutput(
            loss_fn_output_type="test",
            metrics=metrics,
            loss_fn_outputs=loss_fn_outputs,
        )

    def test_unique_reduction_single_result(self) -> None:
        """A single result keeps its unique metric under the original key."""
        results = [self.create_forward_backward_output({"clock_cycle:unique": 12345})]

        reduced = _metrics_reduction(results)

        assert "clock_cycle:unique" in reduced
        assert reduced["clock_cycle:unique"] == 12345

    def test_unique_reduction_multiple_results(self) -> None:
        """Multiple results produce the base key plus ``_1``, ``_2``, ... keys."""
        results = [
            self.create_forward_backward_output({"clock_cycle:unique": 12345}),
            self.create_forward_backward_output({"clock_cycle:unique": 12346}),
            self.create_forward_backward_output({"clock_cycle:unique": 12347}),
        ]

        reduced = _metrics_reduction(results)

        # First value should be the main key.
        assert reduced["clock_cycle:unique"] == 12345
        # Additional values should get suffixed keys.
        assert reduced["clock_cycle:unique_1"] == 12346
        assert reduced["clock_cycle:unique_2"] == 12347

    def test_unique_reduction_with_other_metrics(self) -> None:
        """Unique metrics are reduced alongside ``mean`` and ``max`` metrics."""
        results = [
            self.create_forward_backward_output(
                {
                    "clock_cycle:unique": 100,
                    "loss:mean": 0.5,
                    "accuracy:max": 0.8,
                }
            ),
            self.create_forward_backward_output(
                {
                    "clock_cycle:unique": 101,
                    "loss:mean": 0.6,
                    "accuracy:max": 0.9,
                }
            ),
        ]

        reduced = _metrics_reduction(results)

        # Unique metric behavior.
        assert reduced["clock_cycle:unique"] == 100
        assert reduced["clock_cycle:unique_1"] == 101
        # Other metrics should work as normal. The mean is the result of
        # floating-point arithmetic, so compare with a tolerance instead of
        # relying on bit-exact equality with the literal 0.55.
        assert reduced["loss:mean"] == pytest.approx(0.55)
        # The max is expected to be one of the inputs, so exact equality is fine.
        assert reduced["accuracy:max"] == 0.9

    def test_unique_reduction_empty_results(self) -> None:
        """An empty result list reduces to an empty mapping."""
        results = []

        reduced = _metrics_reduction(results)

        assert reduced == {}

    def test_unique_reduction_missing_metric(self) -> None:
        """Metrics that are not present in every result are dropped."""
        results = [
            self.create_forward_backward_output({"clock_cycle:unique": 100}),
            # This result is missing clock_cycle.
            self.create_forward_backward_output({"other_metric:mean": 0.5}),
        ]

        reduced = _metrics_reduction(results)

        # Neither metric should be present since they're not in all results:
        # the expectation pinned here is that a metric must appear in every
        # result to be processed.
        assert "clock_cycle:unique" not in reduced
        assert "other_metric:mean" not in reduced
        assert len(reduced) == 0  # Should be empty

    def test_unique_reduction_with_float_values(self) -> None:
        """Float unique values are passed through unchanged."""
        results = [
            self.create_forward_backward_output({"timestamp:unique": 1234567890.123}),
            self.create_forward_backward_output({"timestamp:unique": 1234567890.456}),
        ]

        reduced = _metrics_reduction(results)

        # No arithmetic is expected on unique values, so exact float equality
        # is the right check here.
        assert reduced["timestamp:unique"] == 1234567890.123
        assert reduced["timestamp:unique_1"] == 1234567890.456

    def test_invalid_reduction_type_raises_assertion(self) -> None:
        """An unknown reduction suffix raises ``AssertionError``."""
        results = [self.create_forward_backward_output({"invalid:nonexistent": 100})]

        with pytest.raises(AssertionError, match="Invalid reduction nonexistent"):
            _metrics_reduction(results)