import tinker
import torch
def sft_loss(data, logprobs_list):
    """Compute a weighted negative log-likelihood, normalized by total weight.

    For every sequence the per-token log-probabilities are dotted with that
    sequence's per-token weights. The negated sum over all sequences is then
    divided by the summed weights (floored at 1 so an all-zero weight mask
    yields a loss of 0 instead of a division by zero).

    Args:
        data: Indexable collection of datums. ``data[i]`` must expose
            ``loss_fn_inputs["weights"].data`` with the per-token weights
            for the i-th sequence.
        logprobs_list: Iterable of 1-D tensors of per-token
            log-probabilities, aligned by position with ``data``.

    Returns:
        A ``(loss, metrics)`` tuple. ``loss`` is a scalar tensor that stays
        attached to the autograd graph of the log-probabilities; ``metrics``
        is a dict with the keys ``"sft_loss"`` (the loss as a Python float)
        and ``"n_tokens"`` (the summed weights that were used).

    Raises:
        IndexError: If ``data`` has fewer entries than ``logprobs_list``.
    """
    total_loss = torch.tensor(0.0)
    n_tokens = 0
    for i, logprobs in enumerate(logprobs_list):
        # Build the weights on the same device as the log-probs so that
        # torch.dot never sees operands living on different devices.
        weights = torch.tensor(
            data[i].loss_fn_inputs["weights"].data,
            dtype=torch.float32,
            device=logprobs.device,
        )
        # The two sequences may differ in length; only the overlapping
        # prefix contributes to the loss and to the token count.
        min_len = min(len(logprobs), len(weights))
        used_weights = weights[:min_len]
        total_loss = total_loss - torch.dot(
            logprobs[:min_len].float(), used_weights,
        )
        n_tokens += used_weights.sum().item()
    # Floor the denominator at 1 to avoid dividing by zero when every
    # weight is zero (or when there are no sequences at all).
    loss = total_loss / max(n_tokens, 1)
    return loss, {"sft_loss": loss.item(), "n_tokens": n_tokens}
# Run ten training steps on a single-datum batch using the custom SFT loss.
batch = [datum]

# The optimizer hyperparameters are identical for every step, so build the
# parameter object once instead of on each iteration.
adam_params = tinker.AdamParams(
    learning_rate=1e-5,
    beta1=0.9,
    beta2=0.999,
    eps=1e-8,
    weight_decay=0.01,
)

for step in range(10):
    # Forward/backward pass with sft_loss as the loss function; .result()
    # waits for the call to finish (presumably a future -- confirm against
    # the Tinker API) and yields the object whose metrics are printed below.
    result = training_client.forward_backward_custom(batch, sft_loss).result()
    # Optimizer step with the shared Adam parameters; wait for completion
    # before starting the next iteration.
    training_client.optim_step(adam_params).result()
    print(f"Step {step}: {result.metrics}")