mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-28 17:29:33 +00:00
patch custom RL loss with key detection
This commit is contained in:
parent
9ba155a34d
commit
882fefd30e
1 changed file with 19 additions and 2 deletions
|
|
@@ -263,7 +263,24 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# First do a forward pass and get logprobs
|
# First do a forward pass and get logprobs
|
||||||
forward_future = await self.forward_async(data, "cross_entropy")
|
_loss_fn_input_keys = data[0].loss_fn_inputs.keys()
|
||||||
|
if "advantages" in _loss_fn_input_keys:
|
||||||
|
# Check for RL (`importance_sampling`) loss inputs
|
||||||
|
assert "advantages" in _loss_fn_input_keys, "advantages must be in loss_fn_inputs"
|
||||||
|
assert "target_tokens" in _loss_fn_input_keys, "target_tokens must be in loss_fn_inputs"
|
||||||
|
assert "logprobs" in _loss_fn_input_keys, "logprobs must be in loss_fn_inputs"
|
||||||
|
_loss_fn = "importance_sampling"
|
||||||
|
|
||||||
|
elif "weights" in _loss_fn_input_keys:
|
||||||
|
# Check for supervised learning loss inputs
|
||||||
|
assert "weights" in _loss_fn_input_keys, "weights must be in loss_fn_inputs"
|
||||||
|
assert "target_tokens" in _loss_fn_input_keys, "target_tokens must be in loss_fn_inputs"
|
||||||
|
_loss_fn = "cross_entropy"
|
||||||
|
else:
|
||||||
|
assert False, "Invalid loss function inputs"
|
||||||
|
|
||||||
|
# Compute on-policy logprobs
|
||||||
|
forward_future = await self.forward_async(data, _loss_fn)
|
||||||
forward_result = await forward_future.result_async()
|
forward_result = await forward_future.result_async()
|
||||||
logprobs_list: List[torch.Tensor] = []
|
logprobs_list: List[torch.Tensor] = []
|
||||||
for out in forward_result.loss_fn_outputs:
|
for out in forward_result.loss_fn_outputs:
|
||||||
|
|
@@ -280,7 +297,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
||||||
grads.append(logprob.grad)
|
grads.append(logprob.grad)
|
||||||
|
|
||||||
linear_loss_data = []
|
linear_loss_data = []
|
||||||
for datum, grad in zip(data, grads):
|
for datum, grad in zip(data, grads, strict=True):
|
||||||
loss_fn_inputs: Any = {
|
loss_fn_inputs: Any = {
|
||||||
"target_tokens": datum.loss_fn_inputs["target_tokens"],
|
"target_tokens": datum.loss_fn_inputs["target_tokens"],
|
||||||
"weights": -grad, # Pass PyTorch tensor directly (will be converted to TensorData)
|
"weights": -grad, # Pass PyTorch tensor directly (will be converted to TensorData)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue