mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
patch custom RL loss with key detection
This commit is contained in:
parent
9ba155a34d
commit
882fefd30e
1 changed files with 19 additions and 2 deletions
|
|
@ -263,7 +263,24 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
import torch
|
||||
|
||||
# First do a forward pass and get logprobs
|
||||
forward_future = await self.forward_async(data, "cross_entropy")
|
||||
_loss_fn_input_keys = data[0].loss_fn_inputs.keys()
|
||||
if "advantages" in _loss_fn_input_keys:
|
||||
# Check for RL (`importance_sampling`) loss inputs
|
||||
assert "advantages" in _loss_fn_input_keys, "advantages must be in loss_fn_inputs"
|
||||
assert "target_tokens" in _loss_fn_input_keys, "target_tokens must be in loss_fn_inputs"
|
||||
assert "logprobs" in _loss_fn_input_keys, "logprobs must be in loss_fn_inputs"
|
||||
_loss_fn = "importance_sampling"
|
||||
|
||||
elif "weights" in _loss_fn_input_keys:
|
||||
# Check for supervised learning loss inputs
|
||||
assert "weights" in _loss_fn_input_keys, "weights must be in loss_fn_inputs"
|
||||
assert "target_tokens" in _loss_fn_input_keys, "target_tokens must be in loss_fn_inputs"
|
||||
_loss_fn = "cross_entropy"
|
||||
else:
|
||||
assert False, "Invalid loss function inputs"
|
||||
|
||||
# Compute on-policy logprobs
|
||||
forward_future = await self.forward_async(data, _loss_fn)
|
||||
forward_result = await forward_future.result_async()
|
||||
logprobs_list: List[torch.Tensor] = []
|
||||
for out in forward_result.loss_fn_outputs:
|
||||
|
|
@ -280,7 +297,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
grads.append(logprob.grad)
|
||||
|
||||
linear_loss_data = []
|
||||
for datum, grad in zip(data, grads):
|
||||
for datum, grad in zip(data, grads, strict=True):
|
||||
loss_fn_inputs: Any = {
|
||||
"target_tokens": datum.loss_fn_inputs["target_tokens"],
|
||||
"weights": -grad, # Pass PyTorch tensor directly (will be converted to TensorData)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue