mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
118 lines
4.5 KiB
Python
118 lines
4.5 KiB
Python
import pytest
|
|
|
|
from tinker.lib.chunked_fwdbwd_helpers import (
|
|
_metrics_reduction,
|
|
)
|
|
from tinker.types import ForwardBackwardOutput, TensorData
|
|
|
|
|
|
class TestMetricsReductionWithUnique:
    """Test the _metrics_reduction function with unique metrics.

    The metric keys used here have the form ``"name:reduction"``; the tests
    focus on the ``unique`` reduction and check that it coexists with the
    ``mean`` and ``max`` reductions and that an unknown reduction is rejected.
    """

    def create_forward_backward_output(
        self, metrics: dict, num_loss_fn_outputs: int = 1
    ) -> ForwardBackwardOutput:
        """Helper to create ForwardBackwardOutput for testing.

        Args:
            metrics: Metric key/value mapping passed through unchanged as the
                output's ``metrics``.
            num_loss_fn_outputs: How many placeholder loss-function outputs
                to attach to the result.

        Returns:
            A ``ForwardBackwardOutput`` with ``loss_fn_output_type="test"``,
            the given metrics, and ``num_loss_fn_outputs`` placeholder entries.
        """
        # LossFnOutput is Dict[str, TensorData], so create sample data
        tensor_data = TensorData(data=[0.0], dtype="float32", shape=[1])
        loss_fn_outputs = [{"loss": tensor_data} for _ in range(num_loss_fn_outputs)]
        return ForwardBackwardOutput(
            loss_fn_output_type="test",
            metrics=metrics,
            loss_fn_outputs=loss_fn_outputs,
        )

    def test_unique_reduction_single_result(self):
        """Test unique reduction with single result."""
        results = [self.create_forward_backward_output({"clock_cycle:unique": 12345})]

        reduced = _metrics_reduction(results)

        assert "clock_cycle:unique" in reduced
        assert reduced["clock_cycle:unique"] == 12345

    def test_unique_reduction_multiple_results(self):
        """Test unique reduction with multiple results creates additional keys."""
        results = [
            self.create_forward_backward_output({"clock_cycle:unique": 12345}),
            self.create_forward_backward_output({"clock_cycle:unique": 12346}),
            self.create_forward_backward_output({"clock_cycle:unique": 12347}),
        ]

        reduced = _metrics_reduction(results)

        # First value should be the main key
        assert reduced["clock_cycle:unique"] == 12345
        # Additional values should get suffixed keys
        assert reduced["clock_cycle:unique_1"] == 12346
        assert reduced["clock_cycle:unique_2"] == 12347

    def test_unique_reduction_with_other_metrics(self):
        """Test unique reduction alongside other metric types."""
        results = [
            self.create_forward_backward_output(
                {
                    "clock_cycle:unique": 100,
                    "loss:mean": 0.5,
                    "accuracy:max": 0.8,
                }
            ),
            self.create_forward_backward_output(
                {
                    "clock_cycle:unique": 101,
                    "loss:mean": 0.6,
                    "accuracy:max": 0.9,
                }
            ),
        ]

        reduced = _metrics_reduction(results)

        # Unique metric behavior
        assert reduced["clock_cycle:unique"] == 100
        assert reduced["clock_cycle:unique_1"] == 101

        # Other metrics should work as normal.
        # The mean is a computed float, so compare with a tolerance instead of
        # exact equality; the max is one of the inputs and can be compared exactly.
        assert reduced["loss:mean"] == pytest.approx(0.55)
        assert reduced["accuracy:max"] == 0.9

    def test_unique_reduction_empty_results(self):
        """Test unique reduction with empty results."""
        results = []
        reduced = _metrics_reduction(results)
        assert reduced == {}

    def test_unique_reduction_missing_metric(self):
        """Test unique reduction when some results don't have the metric."""
        results = [
            self.create_forward_backward_output({"clock_cycle:unique": 100}),
            self.create_forward_backward_output({"other_metric:mean": 0.5}),  # Missing clock_cycle
        ]

        reduced = _metrics_reduction(results)

        # Neither metric should be present since they're not in all results
        # This matches the actual behavior of _metrics_reduction which requires
        # all results to have a metric for it to be processed
        assert "clock_cycle:unique" not in reduced
        assert "other_metric:mean" not in reduced
        assert not reduced  # Should be empty

    def test_unique_reduction_with_float_values(self):
        """Test unique reduction with float values."""
        results = [
            self.create_forward_backward_output({"timestamp:unique": 1234567890.123}),
            self.create_forward_backward_output({"timestamp:unique": 1234567890.456}),
        ]

        reduced = _metrics_reduction(results)

        # Exact comparison is expected to hold here: these assertions compare
        # against the very literals that were passed in.
        assert reduced["timestamp:unique"] == 1234567890.123
        assert reduced["timestamp:unique_1"] == 1234567890.456

    def test_invalid_reduction_type_raises_assertion(self):
        """Test that invalid reduction types raise AssertionError."""
        results = [self.create_forward_backward_output({"invalid:nonexistent": 100})]

        with pytest.raises(AssertionError, match="Invalid reduction nonexistent"):
            _metrics_reduction(results)
|