patch custom RL loss with key detection

Michael Zhang 2025-10-09 22:13:09 -07:00
parent 9ba155a34d
commit 882fefd30e

@@ -263,7 +263,24 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
         import torch
         # First do a forward pass and get logprobs
-        forward_future = await self.forward_async(data, "cross_entropy")
+        _loss_fn_input_keys = data[0].loss_fn_inputs.keys()
+        if "advantages" in _loss_fn_input_keys:
+            # Check for RL (`importance_sampling`) loss inputs
+            assert "advantages" in _loss_fn_input_keys, "advantages must be in loss_fn_inputs"
+            assert "target_tokens" in _loss_fn_input_keys, "target_tokens must be in loss_fn_inputs"
+            assert "logprobs" in _loss_fn_input_keys, "logprobs must be in loss_fn_inputs"
+            _loss_fn = "importance_sampling"
+        elif "weights" in _loss_fn_input_keys:
+            # Check for supervised learning loss inputs
+            assert "weights" in _loss_fn_input_keys, "weights must be in loss_fn_inputs"
+            assert "target_tokens" in _loss_fn_input_keys, "target_tokens must be in loss_fn_inputs"
+            _loss_fn = "cross_entropy"
+        else:
+            assert False, "Invalid loss function inputs"
+        # Compute on-policy logprobs
+        forward_future = await self.forward_async(data, _loss_fn)
         forward_result = await forward_future.result_async()
         logprobs_list: List[torch.Tensor] = []
         for out in forward_result.loss_fn_outputs:
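For reference outside the diff, here is a minimal standalone sketch of the key-detection dispatch added above. `Datum` is an illustrative stand-in for the client's datum type, and the key names mirror the assertions in the hunk; nothing here is the actual TrainingClient API.

from dataclasses import dataclass
from typing import Any, Dict

@dataclass
class Datum:
    # Illustrative container; the real datum type lives in the training client.
    loss_fn_inputs: Dict[str, Any]

def detect_loss_fn(datum: Datum) -> str:
    """Pick the loss function from the keys present in loss_fn_inputs."""
    keys = datum.loss_fn_inputs.keys()
    if "advantages" in keys:
        # RL path: importance sampling needs advantages, target tokens,
        # and the sampling-time logprobs.
        for key in ("advantages", "target_tokens", "logprobs"):
            assert key in keys, f"{key} must be in loss_fn_inputs"
        return "importance_sampling"
    if "weights" in keys:
        # Supervised path: weighted cross-entropy over target tokens.
        for key in ("weights", "target_tokens"):
            assert key in keys, f"{key} must be in loss_fn_inputs"
        return "cross_entropy"
    raise ValueError("Invalid loss function inputs")

# Usage: an RL-shaped datum dispatches to importance sampling,
# a supervised-shaped datum to cross-entropy.
rl = Datum({"advantages": [0.7], "target_tokens": [101], "logprobs": [-1.3]})
sl = Datum({"weights": [1.0], "target_tokens": [101]})
assert detect_loss_fn(rl) == "importance_sampling"
assert detect_loss_fn(sl) == "cross_entropy"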
@@ -280,7 +297,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
             grads.append(logprob.grad)
         linear_loss_data = []
-        for datum, grad in zip(data, grads):
+        for datum, grad in zip(data, grads, strict=True):
             loss_fn_inputs: Any = {
                 "target_tokens": datum.loss_fn_inputs["target_tokens"],
                 "weights": -grad,  # Pass PyTorch tensor directly (will be converted to TensorData)