## What this is
This guide walks through DPO (Direct Preference Optimization) training using the cookbook. DPO learns from preference pairs (chosen vs. rejected responses) without a separate reward model.
## How DPO differs from GRPO

| | DPO | GRPO |
|---|---|---|
| Trainer jobs | 2 (policy + frozen reference) | 2 (policy + frozen reference) |
| Data | Preference pairs (chosen/rejected) | Prompts + reward function |
| Reference logprobs | Cached once at initialization | Computed every step |
| Loss | `-log(sigmoid(beta * margin))` | Advantage-weighted policy gradient + KL |
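Written out, the loss in the table is the standard DPO objective: for a prompt $x$ with chosen response $y_w$ and rejected response $y_l$,

$$
\mathcal{L}_{\text{DPO}} = -\log \sigma\Big(\beta\big[(\log \pi_\theta(y_w \mid x) - \log \pi_{\text{ref}}(y_w \mid x)) - (\log \pi_\theta(y_l \mid x) - \log \pi_{\text{ref}}(y_l \mid x))\big]\Big)
$$

The bracketed term is the `margin` computed in the loss function below; the $\pi_{\text{ref}}$ logprobs are the values cached once at initialization.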
## Architecture

DPO runs two trainer jobs: a policy trainer that receives gradient updates, and a forward-only reference trainer that is queried once at startup to cache reference logprobs. No inference deployment is involved, because DPO performs no rollouts.
## Using the recipe

```python
from training.recipes.dpo_loop import Config, main
from training.utils import DeployConfig, InfraConfig, WandBConfig

cfg = Config(
    log_path="./dpo_logs",
    base_model="accounts/fireworks/models/qwen3-8b",
    dataset="/path/to/preference_data.jsonl",
    tokenizer_model="Qwen/Qwen3-8B",
    beta=0.1,
    epochs=1,
    batch_size=4,
    max_seq_len=4096,
    infra=InfraConfig(
        training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200",
        ref_training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200-forward",
    ),
    deployment=DeployConfig(
        deployment_id="dpo-serving",
    ),
    wandb=WandBConfig(entity="my-team", project="dpo-experiment"),
)
main(cfg)
```
DPO expects preference pairs. Supported formats:

Format 1 — chosen/rejected messages:

```json
{
  "chosen": {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "good response"}]},
  "rejected": {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "bad response"}]}
}
```

Format 2 — input/output split:

```json
{
  "input": {"messages": [{"role": "user", "content": "..."}]},
  "preferred_output": [{"role": "assistant", "content": "good"}],
  "non_preferred_output": [{"role": "assistant", "content": "bad"}]
}
```
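Since both formats carry the same information, it can help to normalize Format 2 into Format 1 before tokenization. The helper below is a minimal sketch (the name `normalize_record` is ours, not part of the cookbook):

```python
def normalize_record(rec: dict) -> dict:
    """Normalize either supported format into chosen/rejected message lists."""
    if "chosen" in rec and "rejected" in rec:  # Format 1: already normalized
        return rec
    if "input" in rec:  # Format 2: prepend the shared input messages to each output
        prompt = rec["input"]["messages"]
        return {
            "chosen": {"messages": prompt + rec["preferred_output"]},
            "rejected": {"messages": prompt + rec["non_preferred_output"]},
        }
    raise ValueError(f"Unrecognized preference record with keys: {list(rec)}")
```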
## Step-by-step (API-level)

### Provision trainers with setup_infra

DPO needs both a policy trainer and a forward-only reference trainer. `training.utils.rl.setup_infra` handles shape resolution, parallel provisioning of both trainers, and the LoRA shared-reference shortcut: when `lora_rank > 0`, no separate reference trainer is needed, because the reference comes from the policy session's base handle.
```python
import os

from fireworks.training.sdk import TrainerJobManager
from training.utils import InfraConfig, ResourceCleanup
from training.utils.rl import setup_infra

api_key = os.environ["FIREWORKS_API_KEY"]
base_url = os.environ.get("FIREWORKS_BASE_URL", "https://api.fireworks.ai")

rlor_mgr = TrainerJobManager(api_key=api_key, base_url=base_url)

base_model = "accounts/fireworks/models/qwen3-8b"
infra_cfg = InfraConfig(
    training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200",
    ref_training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200-forward",
)

with ResourceCleanup(rlor_mgr) as cleanup:
    infra = setup_infra(
        rlor_mgr=rlor_mgr,
        deploy_mgr=None,  # no DeploymentManager needed: DPO performs no rollouts
        base_model=base_model,
        infra_cfg=infra_cfg,
        lora_rank=0,
        needs_reference=True,   # DPO always needs a reference
        needs_inference=False,  # no rollouts in DPO
        role_prefix="dpo",
        api_key=api_key,
        cleanup=cleanup,
    )
    policy_client = infra.policy
    reference_client = infra.reference
```
### Cache reference logprobs

Reference logprobs are computed once at initialization and reused throughout training:

```python
ref_cache = {}
for i, (chosen_tokens, rejected_tokens, prompt_len) in enumerate(dataset):
    chosen_datum, rejected_datum = build_dpo_datums(
        chosen_tokens, rejected_tokens, prompt_len, max_seq_len=4096,
    )
    # One forward pass on the frozen reference covers both sequences
    fwd = reference_client.forward([chosen_datum, rejected_datum], "cross_entropy")
    ref_cache[i] = {
        "ref_chosen": fwd.loss_fn_outputs[0]["logprobs"].data,
        "ref_rejected": fwd.loss_fn_outputs[1]["logprobs"].data,
        "chosen_tokens": chosen_tokens,
        "rejected_tokens": rejected_tokens,
        "prompt_len": prompt_len,
    }
```
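`build_dpo_datums` comes from the cookbook and is not shown here. For orientation, the following is a hypothetical sketch of the contract the loss function below relies on: each datum carries the token sequence plus a `weights` vector that is 0 over the prompt and 1 over the response tokens, so dot products against per-token logprobs sum only the response. The `FakeDatum`/`FakeTensor` types are illustrative stand-ins, not the SDK's real classes:

```python
from dataclasses import dataclass, field

@dataclass
class FakeTensor:
    """Stand-in for the SDK's tensor wrapper (exposes `.data`)."""
    data: list[float]

@dataclass
class FakeDatum:
    """Illustrative stand-in for the datum type consumed by the trainers."""
    tokens: list[int]
    loss_fn_inputs: dict = field(default_factory=dict)

def build_dpo_datums_sketch(chosen_tokens, rejected_tokens, prompt_len, max_seq_len):
    """Hypothetical equivalent of build_dpo_datums: mask prompt tokens out of the loss."""
    def make(tokens):
        tokens = tokens[:max_seq_len]
        # 0.0 over the prompt, 1.0 over the response: only response logprobs count
        weights = ([0.0] * min(prompt_len, len(tokens))
                   + [1.0] * max(len(tokens) - prompt_len, 0))
        return FakeDatum(tokens=tokens, loss_fn_inputs={"weights": FakeTensor(weights)})
    return make(chosen_tokens), make(rejected_tokens)
```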
### DPO loss function

```python
import torch
import torch.nn.functional as F

def make_dpo_loss_fn(ref_chosen_logprobs, ref_rejected_logprobs, beta=0.1):
    ref_chosen_t = torch.tensor(ref_chosen_logprobs, dtype=torch.float32)
    ref_rejected_t = torch.tensor(ref_rejected_logprobs, dtype=torch.float32)

    def loss_fn(data, logprobs_list):
        pi_chosen, pi_rejected = logprobs_list[0], logprobs_list[1]
        chosen_weights = torch.tensor(data[0].loss_fn_inputs["weights"].data, dtype=torch.float32)
        rejected_weights = torch.tensor(data[1].loss_fn_inputs["weights"].data, dtype=torch.float32)

        # Weighted sums give sequence-level logprobs over response tokens only
        pi_chosen_sum = torch.dot(pi_chosen.float(), chosen_weights)
        pi_rejected_sum = torch.dot(pi_rejected.float(), rejected_weights)
        ref_chosen_sum = torch.dot(ref_chosen_t, chosen_weights)
        ref_rejected_sum = torch.dot(ref_rejected_t, rejected_weights)

        # margin = (policy - reference) log-ratio gap between chosen and rejected
        margin = (pi_chosen_sum - ref_chosen_sum) - (pi_rejected_sum - ref_rejected_sum)
        dpo_loss = -F.logsigmoid(beta * margin)

        with torch.no_grad():
            accuracy = 1.0 if margin.item() > 0 else 0.0
        return dpo_loss, {"dpo_loss": dpo_loss.item(), "margin": margin.item(), "accuracy": accuracy}

    return loss_fn
```
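As a quick sanity check, the closure can be exercised with dummy tensors before wiring it into the trainer. The `FakeDatum`/`FakeTensor` stand-ins from the sketch above substitute for the SDK types:

```python
chosen = FakeDatum(tokens=[1, 2, 3, 4],
                   loss_fn_inputs={"weights": FakeTensor([0.0, 0.0, 1.0, 1.0])})
rejected = FakeDatum(tokens=[1, 2, 5, 6],
                     loss_fn_inputs={"weights": FakeTensor([0.0, 0.0, 1.0, 1.0])})

loss_fn = make_dpo_loss_fn(
    ref_chosen_logprobs=[-0.5, -0.5, -1.0, -1.0],
    ref_rejected_logprobs=[-0.5, -0.5, -1.0, -1.0],
    beta=0.1,
)

# Policy slightly prefers the chosen response relative to the reference
pi_chosen = torch.tensor([-0.5, -0.5, -0.8, -0.8])
pi_rejected = torch.tensor([-0.5, -0.5, -1.2, -1.2])

loss, metrics = loss_fn([chosen, rejected], [pi_chosen, pi_rejected])
print(metrics)  # margin = 0.8 > 0, so accuracy == 1.0; loss is just under log(2)
```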
### Training loop

```python
import tinker  # provides AdamParams for optim_step

step = 0
accum_count = 0
grad_accum = 4  # micro-batches per optimizer step

for idx in ref_cache:
    cached = ref_cache[idx]
    chosen_datum, rejected_datum = build_dpo_datums(
        cached["chosen_tokens"], cached["rejected_tokens"],
        cached["prompt_len"], max_seq_len=4096,
    )
    loss_fn = make_dpo_loss_fn(
        ref_chosen_logprobs=cached["ref_chosen"],
        ref_rejected_logprobs=cached["ref_rejected"],
        beta=0.1,
    )
    # Gradients accumulate across forward_backward_custom calls
    result = policy_client.forward_backward_custom([chosen_datum, rejected_datum], loss_fn)
    accum_count += 1
    if accum_count >= grad_accum:
        policy_client.optim_step(
            tinker.AdamParams(learning_rate=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01)
        )
        step += 1
        accum_count = 0
        print(f"Step {step}: {result.metrics}")
```
## Operational guidance

- Set `infra.training_shape_id` and `infra.ref_training_shape_id` — DPO launches both a policy trainer and a reference trainer.
- DPO uses 2 RLOR jobs: policy trainer + frozen reference trainer.
- DPO defaults to `weight_sync_interval=0` (no weight sync by default), unlike GRPO.
- Keep a versioned reference cache tied to the tokenizer + base model revision. If the base model changes, recompute reference logprobs.
- Monitor margin statistics: increasing margins indicate the policy is learning preferences (see the monitoring sketch after this list).
- DCP checkpoints are disabled by default (`dcp_save_interval=0`). If you need to resume training from a checkpoint, explicitly set `dcp_save_interval` to a positive value in your `WeightSyncConfig`.
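One lightweight way to track margin statistics is a running window over the per-step metrics emitted by the loss function. The helper below is a sketch (the `MarginMonitor` name is ours), assuming the `metrics` dict shown earlier:

```python
from collections import deque

class MarginMonitor:
    """Running window over DPO metrics; margins should trend up, accuracy toward 1.0."""

    def __init__(self, window: int = 100):
        self.margins = deque(maxlen=window)
        self.accuracies = deque(maxlen=window)

    def update(self, metrics: dict) -> None:
        self.margins.append(metrics["margin"])
        self.accuracies.append(metrics["accuracy"])

    def summary(self) -> dict:
        n = max(len(self.margins), 1)
        return {
            "margin_mean": sum(self.margins) / n,
            "accuracy_mean": sum(self.accuracies) / n,
        }

# In the training loop, after each forward_backward_custom call:
#     monitor.update(result.metrics)
#     print(monitor.summary())
```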
## Common pitfalls

- Mismatched formatting between chosen/rejected sequences corrupts the preference signal — ensure identical prompt prefixes.
- Stale reference cache: if you warm-start from a different model, the cached reference logprobs are invalid.

## Related recipes

- ORPO (`training.recipes.orpo_loop`) — Odds Ratio Preference Optimization. Combines an SFT-style negative-log-likelihood term on the chosen response with a margin term on the odds ratio between chosen and rejected. Unlike DPO, ORPO does not require a reference trainer (no cached reference logprobs), so the recipe runs with a single trainer plus a dataset of preference pairs. See training.recipes.orpo_loop in the public cookbook repo for the full configuration.
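For orientation, the published ORPO objective (Hong et al., 2024) has the shape described above — an NLL term plus a log-odds-ratio margin term; this is quoted as background rather than as the recipe's exact implementation:

$$
\mathcal{L}_{\text{ORPO}} = \mathcal{L}_{\text{SFT}}(y_w) - \lambda \, \log \sigma\!\left(\log \frac{\text{odds}_\theta(y_w \mid x)}{\text{odds}_\theta(y_l \mid x)}\right), \qquad \text{odds}_\theta(y \mid x) = \frac{P_\theta(y \mid x)}{1 - P_\theta(y \mid x)}
$$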