Custom Losses: A 1-Hour Interview Learning Session
Companion notebook: custom_losses_colab.ipynb
This lesson is for a Waymo-style ML Modeling & Fundamentals interview where you may need to explain a loss from first principles and implement a small PyTorch snippet in CoderPad.
You already know cross entropy and KL. The missing piece is not the formula. The missing piece is judgment:
What behavior do I want the model to learn, and what does my loss accidentally reward?
For autonomous driving simulation, that question matters because logs are imbalanced, futures are multi-modal, and the rare cases are often the cases we care about most.
0. One-hour plan
- 0-10 min: What a loss really does (incentives, not just math)
- 10-20 min: Imbalance (weighted cross entropy)
- 20-30 min: Hard examples (focal loss)
- 30-45 min: Multi-modal futures (why MSE fails and what replaces it)
- 45-55 min: Implementation patterns and debugging
- 55-60 min: Interview answers and practice drills

By the end, you should be able to:
- Explain why ordinary losses fail under imbalance and multi-modality.
- Derive weighted cross entropy and focal loss from cross entropy.
- Explain why MSE predicts the conditional mean.
- Implement focal loss and a multi-modal trajectory loss in PyTorch.
- Discuss tradeoffs: recall, calibration, realism, diversity, controllability, and latency.
1. The core idea: loss functions create incentives
Suppose a model predicts whether a scene contains a dangerous cut-in.
Dataset:
- normal lane-following: 99,000 examples
- dangerous cut-in: 1,000 examples

A model that predicts “normal” for everything gets 99% accuracy and 0% recall on cut-ins. That is useless for safety.
The problem is not that cross entropy is mathematically wrong. The problem is that the empirical training distribution says normal examples dominate. If every example has equal weight, the optimizer spends most of its effort improving common cases.
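A few lines of arithmetic on the counts above make the point concrete. The constant “normal” classifier reaches 99% accuracy with zero cut-in recall. A model that ignores its input and always outputs the 1% base rate already has a mean cross entropy of about 0.056 nats, compared with ln 2 ≈ 0.69 for a coin flip. The loss is low, but the model has no safety value.

```python
import math

n_normal, n_cutin = 99_000, 1_000
n_total = n_normal + n_cutin

# Constant classifier: always answer "normal".
accuracy = n_normal / n_total      # 0.99
cutin_recall = 0 / n_cutin         # 0.0: no cut-in is ever flagged

# Prior-only probabilistic model: p(cut-in) = 1% for every scene, ignoring the input.
p_cutin = n_cutin / n_total
mean_ce = (n_normal * -math.log(1 - p_cutin) + n_cutin * -math.log(p_cutin)) / n_total

print(f"accuracy={accuracy:.2f}, cut-in recall={cutin_recall:.2f}, mean CE={mean_ce:.4f} nats")
# accuracy=0.99, cut-in recall=0.00, mean CE=0.0560 nats
```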
In autonomous driving and robotics, the important mistakes are often:
- Rare pedestrian crossing.
- Aggressive merge.
- Cut-in.
- Red-light runner.
- Cyclist swerving.
- Vehicle reversing.
- Unprotected left turn.
- Construction-zone behavior.
Custom losses are how we say:
This type of mistake should count more, or this kind of output structure should be rewarded.
2. Weighted cross entropy
Interview-level intuition
Weighted cross entropy is ordinary cross entropy where some classes matter more.
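For one example with true class $y$, cross entropy is:
$$\mathcal{L}_{\text{CE}} = -\log p_y$$
Weighted cross entropy:
$$\mathcal{L}_{\text{WCE}} = -w_y \log p_y$$
where:
- $p_y$ is the predicted probability for the true class.
- $w_y$ is the weight for the true class.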
If cut-in examples are rare, assign them a larger weight. The gradient from a cut-in example becomes larger, so the optimizer pays more attention.
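A from-scratch version makes the formula concrete. It also exposes a PyTorch detail worth knowing in CoderPad. When you pass `weight=` and keep the default `reduction="mean"`, `F.cross_entropy` divides by the sum of the true-class weights, not by the batch size. This is a sketch: `weighted_ce_manual` is a hypothetical name, and the batch and weights are made up.

```python
import torch
import torch.nn.functional as F

def weighted_ce_manual(logits, targets, class_weights):
    """Per-example loss -w_y * log p_y, reduced the way F.cross_entropy(weight=...) reduces it."""
    log_probs = F.log_softmax(logits, dim=-1)                   # [N, C]
    log_p_y = log_probs.gather(1, targets[:, None]).squeeze(1)  # [N], log-probability of the true class
    w_y = class_weights[targets]                                # [N], weight of each example's true class
    per_example = -w_y * log_p_y
    # Default reduction="mean" in PyTorch normalizes by sum(w_y), not by N.
    return per_example.sum() / w_y.sum()

torch.manual_seed(0)
logits = torch.randn(6, 2)
targets = torch.tensor([0, 0, 0, 0, 1, 1])   # 0 = normal, 1 = cut-in
class_weights = torch.tensor([1.0, 10.0])    # hypothetical: a cut-in counts 10x

print(weighted_ce_manual(logits, targets, class_weights))
print(F.cross_entropy(logits, targets, weight=class_weights))  # same value
```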
Why you need to care
In driving simulation, if your model underpredicts rare events, your simulator becomes too easy. The autonomy stack passes simulated tests but fails on long-tail real-world situations.
Weighted CE is useful when:
- The label space is discrete.
- Rare classes are important.
- You can tolerate some calibration distortion.
- You care about recall for minority classes.
How to choose weights
Naive inverse frequency:
$$w_c = \frac{1}{f_c}$$
where $f_c$ is the frequency of class $c$.
But this can explode for rare noisy classes. Safer variants:
1. normalized inverse frequency
2. sqrt inverse frequency
3. clipped weights
4. manually chosen business/safety weights

Example:
- normal: 99%, cut-in: 1%
- raw inverse ratio: cut-in gets 99x weight, which may be too high
- safer: cut-in gets 5x or 10x, then validate recall/calibration

Mathematical effect
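For logits $z$ and softmax probabilities $p = \operatorname{softmax}(z)$, the cross-entropy gradient is:
$$\frac{\partial \mathcal{L}_{\text{CE}}}{\partial z_k} = p_k - \mathbb{1}[k = y]$$
Weighted CE scales it:
$$\frac{\partial \mathcal{L}_{\text{WCE}}}{\partial z_k} = w_y \big(p_k - \mathbb{1}[k = y]\big)$$
So class weighting directly scales the update for examples of class $y$.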
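You can confirm with autograd that each example's logit gradient is $w_y\,(p_k - \mathbb{1}[k = y])$. This sketch uses `reduction="sum"` so the gradient is not divided by a batch normalizer. The logits are random, and the class weights are made up.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3, requires_grad=True)
targets = torch.tensor([0, 2, 1, 2])
class_weights = torch.tensor([1.0, 5.0, 10.0])  # hypothetical weights

# reduction="sum" keeps each example's gradient unnormalized, so it matches the formula exactly.
loss = F.cross_entropy(logits, targets, weight=class_weights, reduction="sum")
loss.backward()

with torch.no_grad():
    p = logits.softmax(dim=-1)
    one_hot = F.one_hot(targets, num_classes=3).float()
    expected = class_weights[targets][:, None] * (p - one_hot)  # w_y * (p_k - 1[k = y])

print(torch.allclose(logits.grad, expected, atol=1e-6))  # True
```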
Tradeoffs
Pros:
- Simple.
- Easy to implement.
- Strong baseline for imbalance.
- Good interview answer.
Cons:
- Can hurt probability calibration.
- Can overfit rare noisy labels.
- Does not distinguish easy vs hard examples.
- Requires tuning weights.
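Because the weights need tuning, it helps to generate candidates from the schemes listed under “How to choose weights”. This sketch follows one possible convention:
1. Scale so the most frequent class gets weight 1.
2. Optionally cap the ratio.
3. Rescale so the frequency-weighted average weight is 1, which keeps the loss scale close to unweighted CE.

The helper name, its arguments, and the cap of 10 are illustrative assumptions. For the 99%/1% split, the three variants give the cut-in class about 99x, 9.9x, and 10x the weight of normal driving.

```python
import math

def class_weights(freqs, scheme="inverse", max_ratio=None, normalize=True):
    """Hypothetical helper: per-class loss weights from class frequencies.

    scheme:    "inverse" (1/f) or "sqrt_inverse" (1/sqrt(f)).
    max_ratio: optional cap on weight relative to the most frequent class.
    normalize: rescale so sum_c f_c * w_c == 1 (average weight per example is 1).
    """
    if scheme == "inverse":
        raw = [1.0 / f for f in freqs]
    elif scheme == "sqrt_inverse":
        raw = [1.0 / math.sqrt(f) for f in freqs]
    else:
        raise ValueError(f"unknown scheme: {scheme}")

    base = min(raw)                         # weight of the most frequent class
    weights = [w / base for w in raw]       # most frequent class -> 1.0
    if max_ratio is not None:
        weights = [min(w, max_ratio) for w in weights]
    if normalize:
        avg = sum(f * w for f, w in zip(freqs, weights))
        weights = [w / avg for w in weights]
    return weights

freqs = [0.99, 0.01]  # normal, cut-in
for name, kwargs in [("raw inverse", {}),
                     ("sqrt inverse", {"scheme": "sqrt_inverse"}),
                     ("clipped inverse", {"max_ratio": 10.0})]:
    w = class_weights(freqs, **kwargs)
    print(f"{name:16s} weights={[round(x, 3) for x in w]}  cut-in/normal={w[1] / w[0]:.1f}x")
```

Whatever the scheme, the resulting weights are only a starting point. Check per-class recall, precision, and calibration before settling on them.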
3. Focal loss
Interview-level intuition
Weighted CE says: “some classes matter more.”
Focal loss says: “hard examples matter more.”
In a huge driving dataset, many examples are easy:
- Empty lane.
- Stationary parked vehicle.
- Normal car following.
Once the model already predicts those correctly, continuing to spend lots of gradient budget on them is wasteful. Focal loss downweights easy examples and focuses training on examples the model currently struggles with.
Formula
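Cross entropy:
$$\mathcal{L}_{\text{CE}} = -\log p_t$$
Focal loss:
$$\mathcal{L}_{\text{FL}} = -\alpha_t\,(1 - p_t)^{\gamma} \log p_t$$
where:
- $p_t$ is the probability assigned to the true class.
- $\alpha_t$ is an optional class weight.
- $\gamma \ge 0$ controls how aggressively easy examples are downweighted.
If $p_t = 0.9$ and $\gamma = 2$:
$$(1 - p_t)^{\gamma} = (1 - 0.9)^2 = 0.01$$
The example contributes almost nothing.
If $p_t = 0.1$ (a hard example):
$$(1 - p_t)^{\gamma} = (1 - 0.1)^2 = 0.81$$
The hard example still matters.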
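A quick table with γ = 2 shows how fast the modulating factor shrinks as the model grows more confident. An easy example at p_t = 0.99 keeps only 1/10,000 of its cross-entropy loss, while a hard one at p_t = 0.1 keeps 81%. The probe values are arbitrary.

```python
import math

def modulating_factor(p_t, gamma):
    """Focal-loss factor (1 - p_t)^gamma that multiplies the CE term -log p_t."""
    return (1.0 - p_t) ** gamma

gamma = 2.0
for p_t in (0.1, 0.5, 0.9, 0.99):
    ce = -math.log(p_t)
    factor = modulating_factor(p_t, gamma)
    print(f"p_t={p_t:<5} CE={ce:.4f}  factor={factor:.4f}  focal={factor * ce:.6f}")
```

Setting `gamma = 0.0` makes every factor 1, which recovers ordinary (optionally α-weighted) cross entropy.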
Why you need to care
In perception, prediction, and simulation classifiers, easy background examples can dominate. Focal loss was popularized for dense object detection, where most anchors are background. The same idea appears in autonomous driving whenever easy negatives overwhelm rare positives.
Examples:
- “Is this actor likely to cut in?”
- “Is this pedestrian about to cross?”
- “Is this generated scenario invalid?”
- “Is this object a rare class?”
- “Is this trajectory mode safety-critical?”
Focal loss vs weighted CE
- Weighted CE: class-level importance; the rare class gets a larger gradient.
- Focal loss: example-level difficulty; easy examples get a smaller gradient.
- Weighted focal: both class importance and difficulty.

Tradeoffs
Pros:
- Useful when easy examples dominate.
- Improves attention to hard positives/negatives.
- Often improves minority recall.
Cons:
- Adds the hyperparameter $\gamma$.
- Can under-train easy-but-important examples if too aggressive.
- Can hurt calibration.
- Hard examples may include mislabeled data.
Interview warning:
Focal loss focuses on hard examples, but hard examples are not always valuable. Some are just bad labels.
4. MSE and the multi-modal future problem
Interview-level intuition
MSE is fine when the target is roughly unimodal. It is bad when many futures are plausible.
At an intersection, a car may:
- go straight,
- turn left,
- slow down,
- yield.
If the model predicts one trajectory and trains with MSE, the optimal prediction is the average future.
For driving, the average of valid futures can be invalid.
- Future A: turn left
- Future B: go straight
- MSE average: cuts diagonally through the intersection

Mathematical reason
MSE loss:
$$\mathcal{L}(\hat{y}) = \mathbb{E}_{y \sim p(y \mid x)}\big[\|y - \hat{y}\|^2\big]$$
The minimizer is:
$$\hat{y}^{*} = \mathbb{E}[y \mid x]$$
So MSE learns the conditional mean.
That is perfect if the conditional distribution is one blob. It is bad if the distribution has multiple modes.
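The step behind “the minimizer is the conditional mean” is a bias-variance split. Write $\mu = \mathbb{E}[y \mid x]$ and expand the squared error around it:

$$
\begin{aligned}
\mathbb{E}\big[\|y - \hat{y}\|^2 \mid x\big]
&= \mathbb{E}\big[\|(y - \mu) + (\mu - \hat{y})\|^2 \mid x\big] \\
&= \mathbb{E}\big[\|y - \mu\|^2 \mid x\big] + 2\,(\mu - \hat{y})^{\top}\,\mathbb{E}[y - \mu \mid x] + \|\mu - \hat{y}\|^2 \\
&= \mathbb{E}\big[\|y - \mu\|^2 \mid x\big] + \|\mu - \hat{y}\|^2
\end{aligned}
$$

The cross term vanishes because $\mathbb{E}[y - \mu \mid x] = 0$. The first term does not depend on $\hat{y}$, so the loss is minimized exactly at $\hat{y} = \mu$, no matter how many modes $p(y \mid x)$ has.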
Toy example
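Suppose the target takes two values with equal probability:
$$y = \begin{cases} -1 & \text{with probability } 0.5 \\ +1 & \text{with probability } 0.5 \end{cases}$$
The MSE-optimal prediction is:
$$\hat{y}^{*} = 0.5 \cdot (-1) + 0.5 \cdot (+1) = 0$$
But $\hat{y} = 0$ is never observed. In trajectory terms, this is the impossible average trajectory.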
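You can check this numerically with a bimodal toy target: half the samples at -1 and half at +1, values chosen for illustration. A brute-force search over constant predictions lands on the mean, 0, which never occurs in the data.

```python
# Hypothetical toy data: two equally likely "futures", y = -1 and y = +1.
samples = [-1.0] * 500 + [1.0] * 500

def mse(pred, ys):
    return sum((y - pred) ** 2 for y in ys) / len(ys)

grid = [i / 100 for i in range(-200, 201)]        # candidate constant predictions in [-2, 2]
best = min(grid, key=lambda c: mse(c, samples))   # MSE-optimal constant prediction
mean = sum(samples) / len(samples)

print(best, mean)        # 0.0 0.0
print(best in samples)   # False: the MSE-optimal answer is never an observed outcome
```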
Why you need to care
Simulation needs plausible futures, not average futures. A simulator that averages away rare maneuvers will:
- under-test planners,
- reduce long-tail coverage,
- generate boring scenes,
- miss interaction diversity,
- produce physically invalid trajectories.
5. Multi-modal trajectory losses
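Instead of predicting one trajectory, predict $K$ modes:
$$\hat{\tau}^{(1)}, \dots, \hat{\tau}^{(K)}$$
and probabilities:
$$\pi_1, \dots, \pi_K, \qquad \sum_{k=1}^{K} \pi_k = 1$$
Each trajectory:
$$\hat{\tau}^{(k)} = \big(\hat{s}^{(k)}_1, \dots, \hat{s}^{(k)}_T\big), \qquad \hat{s}^{(k)}_t \in \mathbb{R}^2$$
where $T$ is the number of future timesteps.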
Winner-takes-all loss
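Find the mode closest to the logged future $\tau = (s_1, \dots, s_T)$:
$$k^{*} = \arg\min_{k} \sum_{t=1}^{T} \big\|\hat{s}^{(k)}_t - s_t\big\|^2$$
Regression loss:
$$\mathcal{L}_{\text{reg}} = \sum_{t=1}^{T} \big\|\hat{s}^{(k^{*})}_t - s_t\big\|^2$$
Mode classification loss:
$$\mathcal{L}_{\text{cls}} = -\log \pi_{k^{*}}$$
Combined:
$$\mathcal{L} = \mathcal{L}_{\text{reg}} + \lambda\,\mathcal{L}_{\text{cls}}$$
where $\lambda$ balances the two terms.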
What this teaches
This loss says:
- At least one mode should match the logged future.
- The model should assign high probability to that matching mode.
- Other modes are free to cover other plausible futures.
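A quick autograd check shows why the other modes are “free”. Under the min-over-modes regression term, only the winning trajectory receives a gradient. The classification term still touches every mode's logit. Shapes follow the `[B, K, T, 2]` convention used in the implementation section, and the random tensors stand in for model outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, K, T = 2, 3, 5
pred_trajs = torch.randn(B, K, T, 2, requires_grad=True)   # stand-in for predicted modes
mode_logits = torch.randn(B, K, requires_grad=True)
target = torch.randn(B, T, 2)                               # stand-in for the logged future

sq_error = ((pred_trajs - target[:, None]) ** 2).sum(dim=(-1, -2))  # [B, K]
best_mode = sq_error.argmin(dim=1)                                  # [B], k* per example
reg_loss = sq_error[torch.arange(B), best_mode].mean()
cls_loss = F.cross_entropy(mode_logits, best_mode)
(reg_loss + cls_loss).backward()

# For each (example, mode): did that mode's trajectory get any regression gradient?
got_grad = pred_trajs.grad.abs().sum(dim=(-1, -2)) > 0              # [B, K]
print(best_mode)
print(got_grad)   # True only at each example's winning mode
```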
Why this is not perfect
The logged future is only one sample from the real future distribution. If the car went straight in the log, a left turn might still have been plausible. Winner-takes-all loss may not reward that left-turn mode unless the dataset contains similar scenes where the car turned left.
This is why trajectory prediction evaluation often uses:
- minADE: did any mode match?
- minFDE: did any final point match?
- miss rate: did all modes miss?
- probability-aware metrics: did the model rank the right mode high?
- realism metrics: are other modes valid?
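The first three metrics fit in a few lines of PyTorch. This sketch assumes positions in meters and uses L2 displacement. minADE and minFDE take the best mode per example. A miss means even the best endpoint is farther than a threshold. The 2.0 m threshold and the helper name are illustrative assumptions; benchmarks define misses in their own ways.

```python
import torch

def trajectory_metrics(pred_trajs, target, miss_threshold=2.0):
    """
    Hypothetical helper.
    pred_trajs: [B, K, T, 2], target: [B, T, 2], positions in meters.
    miss_threshold: illustrative endpoint-distance threshold (assumption).
    """
    dist = torch.linalg.norm(pred_trajs - target[:, None], dim=-1)  # [B, K, T] per-step L2 error
    ade = dist.mean(dim=-1)                                         # [B, K] average displacement per mode
    fde = dist[..., -1]                                             # [B, K] final displacement per mode
    min_ade = ade.min(dim=1).values                                 # best mode per example
    min_fde = fde.min(dim=1).values
    miss = (min_fde > miss_threshold).float()                       # 1 if every mode's endpoint is too far
    return min_ade.mean().item(), min_fde.mean().item(), miss.mean().item()

# Toy check: each mode is the target shifted by a constant offset.
target = torch.zeros(2, 3, 2)                             # B=2, T=3
offsets = torch.tensor([[[1.0, 0.0], [3.0, 4.0]],         # example 0: best mode is 1 m off
                        [[0.0, 3.0], [0.0, 4.0]]])        # example 1: best mode is 3 m off
pred_trajs = target[:, None] + offsets[:, :, None, :]     # [B, K, T, 2]
print(trajectory_metrics(pred_trajs, target))             # (2.0, 2.0, 0.5)
```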
6. Mixture density loss
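A more probabilistic approach is to predict a mixture distribution:
$$p(\tau \mid x) = \sum_{k=1}^{K} \pi_k\, \mathcal{N}\big(\tau;\, \mu_k, \Sigma_k\big)$$
where:
- $\pi_k$ is the probability of mode $k$.
- $\mu_k$ is the mean trajectory for mode $k$.
- $\Sigma_k$ is its uncertainty (covariance).
Negative log likelihood:
$$\mathcal{L}_{\text{NLL}} = -\log \sum_{k=1}^{K} \pi_k\, \mathcal{N}\big(\tau;\, \mu_k, \Sigma_k\big)$$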
Why use it:
- More probabilistic.
- Can model uncertainty.
- Encourages calibrated probabilities.
Why it is harder:
- Numerical stability.
- Covariance parameterization.
- Mode collapse.
- Overconfident tiny variances.
In interviews, winner-takes-all multi-modal loss is usually easier to explain and implement. Mixture density loss is a good extension if asked.
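If you are asked for that extension, a minimal diagonal-Gaussian version shows the usual answers to the stability issues listed above:
- Work in log space.
- Combine modes with `torch.logsumexp`.
- Predict log standard deviations.
- Floor the log standard deviations so variances cannot collapse to zero.

The diagonal covariance, the `min_log_sigma` floor, and the helper name are simplifying assumptions. The same quantity can be cross-checked against `torch.distributions.MixtureSameFamily`.

```python
import math
import torch
import torch.nn.functional as F

def mixture_nll(mu, log_sigma, mode_logits, target, min_log_sigma=-5.0):
    """
    Hypothetical helper: NLL of a K-mode mixture of diagonal Gaussians over trajectories.
    mu, log_sigma: [B, K, T, 2] per-mode mean and log standard deviation
    mode_logits:   [B, K] unnormalized mode scores (pi = softmax(mode_logits))
    target:        [B, T, 2] logged future
    """
    # Floor log sigma so no mode can shrink its variance toward zero (gradient is cut below the floor).
    log_sigma = log_sigma.clamp(min=min_log_sigma)
    z = (target[:, None] - mu) / log_sigma.exp()                      # [B, K, T, 2]
    log_normal = -0.5 * z ** 2 - log_sigma - 0.5 * math.log(2 * math.pi)
    log_component = log_normal.sum(dim=(-1, -2))                      # [B, K], log N(tau; mu_k, Sigma_k)
    log_pi = F.log_softmax(mode_logits, dim=-1)                       # [B, K], log pi_k
    # log sum_k pi_k N_k, computed stably in log space.
    return -torch.logsumexp(log_pi + log_component, dim=-1).mean()

torch.manual_seed(0)
B, K, T = 3, 4, 5
mu = torch.randn(B, K, T, 2)
log_sigma = 0.3 * torch.randn(B, K, T, 2)
mode_logits = torch.randn(B, K)
target = torch.randn(B, T, 2)
print(mixture_nll(mu, log_sigma, mode_logits, target))
```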
7. Custom losses beyond driving
The same principles apply broadly.
Robotics:
- Grasp success is imbalanced.
- Contact-rich behavior is multi-modal.
- Imitation learning with MSE can average actions and fail.
Medical ML:
- Rare disease detection needs class weighting or focal loss.
- False negatives may be much worse than false positives.
Fraud:
- Rare positives dominate business value.
- Focal loss can help with many easy negatives.
Recommendation:
- Weighted losses encode business value.
- Pairwise/listwise losses may better match ranking objectives.
General lesson:
A good custom loss aligns the training signal with the real decision cost and output structure.
8. Implementation patterns
Weighted cross entropy

```python
import torch
import torch.nn.functional as F

def weighted_cross_entropy(logits, targets, class_weights):
    # logits: [N, C]; targets: [N] class indices; class_weights: [C] on the same device as logits.
    # With the default reduction="mean", PyTorch divides by the sum of the
    # true-class weights, not by N.
    return F.cross_entropy(logits, targets, weight=class_weights)
```

Focal loss
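```python
def focal_loss(logits, targets, alpha=None, gamma=2.0):
    # logits: [N, C]; targets: [N] class indices; alpha: optional [C] class weights.
    log_probs = F.log_softmax(logits, dim=-1)   # numerically stable log p
    probs = log_probs.exp()

    # Pick p_t and log p_t for each example's true class.
    row = torch.arange(targets.numel(), device=targets.device)
    pt = probs[row, targets]
    log_pt = log_probs[row, targets]

    if alpha is None:
        alpha_t = 1.0
    else:
        alpha_t = alpha[targets]

    # (1 - p_t)^gamma shrinks the loss of confident (easy) examples; gamma=0 gives CE.
    loss = -alpha_t * (1.0 - pt).pow(gamma) * log_pt
    return loss.mean()  # plain mean over examples, unlike F.cross_entropy(weight=...)
```

Multi-modal trajectory loss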
```python
def multimodal_trajectory_loss(pred_trajs, mode_logits, target, cls_weight=1.0):
    """
    pred_trajs: [B, K, T, 2]
    mode_logits: [B, K]
    target: [B, T, 2]
    """
    diff = pred_trajs - target[:, None, :, :]   # broadcast target over the K modes
    sq_error = (diff ** 2).sum(dim=(-1, -2))    # [B, K], summed over T and x/y
    best_mode = sq_error.argmin(dim=1)          # [B], winning mode k* (integer, no gradient)

    # Regression: only the winning mode is pulled toward the logged future.
    reg_loss = sq_error[torch.arange(target.size(0), device=target.device), best_mode].mean()
    # Classification: -log pi_{k*}, i.e. raise the winning mode's probability.
    cls_loss = F.cross_entropy(mode_logits, best_mode)
    return reg_loss + cls_weight * cls_loss
```

CoderPad tip
In an interview, first print shapes:

```python
print(pred_trajs.shape)  # [B, K, T, 2]
print(target.shape)      # [B, T, 2]
```

Most loss bugs are shape bugs.
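A classic silent shape bug: forgetting the `None` that broadcasts the target over modes. When the batch size happens to equal the number of modes, `pred_trajs - target` still runs without error. Under PyTorch broadcasting, however, the target's batch axis lines up with the mode axis, so example `b`, mode `k` is compared with the wrong example's future. The sizes below are chosen to trigger exactly that case.

```python
import torch

torch.manual_seed(0)
B, K, T = 3, 3, 4                     # B == K is exactly when the bug is silent
pred_trajs = torch.randn(B, K, T, 2)
target = torch.randn(B, T, 2)

right = pred_trajs - target[:, None, :, :]   # pred[b, k] - target[b]
wrong = pred_trajs - target                  # pred[b, k] - target[k]  (no error raised!)

print(right.shape, wrong.shape)              # both torch.Size([3, 3, 4, 2])
print(torch.allclose(right, wrong))          # False
```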
9. Debugging checklist
Weighted CE
- Are class weights on the same device as logits?
- Are weights too extreme?
- Did minority recall improve?
- Did precision collapse?
- Did calibration get worse?
Focal loss
- Is $\gamma = 0$ equivalent to weighted CE?
- Are hard examples actually valid labels?
- Is the loss ignoring too many easy examples?
- Did rare recall improve?
- Did probability calibration degrade?
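A unit test for the γ = 0 check must account for the reductions. Without class weights, γ = 0 reproduces `F.cross_entropy` exactly. With class weights, the per-example terms match, but the focal loss above takes a plain mean while `F.cross_entropy(weight=...)` divides by the sum of the true-class weights. Comparing sums avoids that mismatch. The `focal_loss` below is equivalent to the one in the implementation section, rewritten with `gather`; inputs and weights are made up.

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=None, gamma=2.0):
    # Equivalent to the implementation-section focal_loss, using gather instead of indexing.
    log_probs = F.log_softmax(logits, dim=-1)
    log_pt = log_probs.gather(1, targets[:, None]).squeeze(1)
    pt = log_pt.exp()
    alpha_t = 1.0 if alpha is None else alpha[targets]
    return (-alpha_t * (1.0 - pt).pow(gamma) * log_pt).mean()

torch.manual_seed(0)
logits = torch.randn(8, 3)
targets = torch.randint(0, 3, (8,))
alpha = torch.tensor([1.0, 2.0, 5.0])  # hypothetical class weights

# Unweighted: gamma = 0 should reproduce plain CE.
same_unweighted = torch.allclose(focal_loss(logits, targets, gamma=0.0), F.cross_entropy(logits, targets))

# Weighted: compare sums, since the two "mean" reductions normalize differently.
fl_sum = focal_loss(logits, targets, alpha=alpha, gamma=0.0) * len(targets)  # undo the plain mean
ce_sum = F.cross_entropy(logits, targets, weight=alpha, reduction="sum")
same_weighted = torch.allclose(fl_sum, ce_sum)

print(same_unweighted, same_weighted)  # True True
```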
Multi-modal trajectory loss
- Are all modes collapsing to the same trajectory?
- Is one mode always winning?
- Are mode probabilities calibrated?
- Does minADE improve while probability-aware metrics get worse?
- Are trajectories physically valid?
- Do they stay on map?
- Do they collide unrealistically?
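The first two questions, mode collapse and a single mode always winning, can be answered with two numbers per batch:
- How often each mode wins the min-over-modes assignment.
- The average distance between distinct modes.

This is a sketch with a hypothetical helper name. A spread near zero or a win fraction near 1.0 for one mode is the warning sign.

```python
import torch

def mode_diagnostics(pred_trajs, target):
    """
    Hypothetical helper.
    pred_trajs: [B, K, T, 2], target: [B, T, 2]
    Returns (win_fraction [K], mean pairwise L2 distance between distinct modes).
    """
    B, K = pred_trajs.shape[:2]
    sq_error = ((pred_trajs - target[:, None]) ** 2).sum(dim=(-1, -2))  # [B, K]
    best_mode = sq_error.argmin(dim=1)
    win_fraction = torch.bincount(best_mode, minlength=K).float() / B   # how often each mode wins

    # Per-step L2 distance between every pair of modes, averaged over time.
    diffs = pred_trajs[:, :, None] - pred_trajs[:, None, :]             # [B, K, K, T, 2]
    pair_dist = torch.linalg.norm(diffs, dim=-1).mean(dim=-1)           # [B, K, K]
    off_diag = ~torch.eye(K, dtype=torch.bool, device=pred_trajs.device)
    spread = pair_dist[:, off_diag].mean()                              # ignore each mode vs itself
    return win_fraction, spread.item()

# Toy check: three identical ("collapsed") modes.
torch.manual_seed(0)
collapsed = torch.randn(4, 1, 6, 2).expand(4, 3, 6, 2).clone()
target = torch.randn(4, 6, 2)
win, spread = mode_diagnostics(collapsed, target)
print(win, spread)  # one mode wins everything, spread is 0.0
```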
General custom loss
- Can the model overfit one small batch?
- Does each term have comparable scale?
- Does the loss decrease while the real metric does not?
- Are units consistent: meters, seconds, radians?
- Are labels noisy or multi-modal?
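The overfit-one-batch check is the cheapest of these. If a loss cannot drive one tiny, fixed batch close to zero, the bug is usually in the loss or the wiring rather than in the data. This sketch uses a linear model on random data; the sizes, learning rate, and step count are arbitrary choices. Swap `loss_fn` for the custom loss under test.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 16)            # one small, fixed batch
y = torch.randint(0, 3, (8,))
model = torch.nn.Linear(16, 3)
opt = torch.optim.Adam(model.parameters(), lr=0.05)

def loss_fn(logits, targets):
    # Replace with the custom loss being debugged (e.g. focal_loss or weighted CE).
    return F.cross_entropy(logits, targets)

initial = loss_fn(model(x), y).item()
for _ in range(300):
    opt.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    opt.step()
final = loss_fn(model(x), y).item()

print(f"initial={initial:.3f}  final={final:.4f}")  # final should be close to 0
```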
10. Common interview questions and strong answers
Q: Why use weighted cross entropy?
A: Because the empirical distribution may not match the cost distribution. In driving, rare cut-ins or pedestrian events may be more important than common normal driving. Weighting scales gradients for those examples.
Q: Why not always use huge rare-class weights?
A: Rare labels may be noisy, and large weights can overfit noise or destroy calibration. I would tune weights and inspect precision/recall by class.
Q: What does focal loss add beyond class weighting?
A: It downweights easy examples based on the model’s current confidence. Class weighting is class-level; focal loss is example-difficulty-level.
Q: Why is MSE bad for future prediction?
A: MSE learns the conditional mean. If futures are multi-modal, the mean can be an invalid trajectory that no agent would actually take.
Q: How would you train a model to predict multiple futures?
A: Predict $K$ trajectories and $K$ mode probabilities. Use a min-over-modes regression loss to train the closest mode, plus cross entropy to assign probability to that mode.
Q: What metrics would you check?
A: For classification, per-class precision/recall and calibration. For trajectories, minADE/minFDE, miss rate, mode entropy, probability calibration, collision, offroad, and realism.
11. A 60-second explanation you can say out loud
Custom losses are about aligning optimization with the real cost and structure of the problem. In autonomous driving, data is imbalanced and futures are multi-modal. Weighted cross entropy increases the gradient contribution of rare important classes like cut-ins. Focal loss goes further by downweighting easy examples, so training focuses on examples the model currently gets wrong. For trajectory prediction, plain MSE is dangerous because it predicts the conditional mean; the average of left-turn and straight futures may be physically invalid. A better approach predicts multiple trajectory modes and probabilities, trains the closest mode to the logged future, and also trains the model to assign that mode high probability. Then I would debug with per-class metrics, mode collapse checks, calibration, and physical validity metrics.
12. Practice exercises with answers
Exercise 1
You have 99% normal driving and 1% cut-in. Why can unweighted CE fail?
Answer: The optimizer sees far more normal examples, so a model can get low loss and high accuracy while missing cut-ins. The training distribution underweights the safety-critical class.
Exercise 2
For focal loss with $\gamma = 2$ and $p_t = 0.9$, what is the modulating factor?
Answer: $(1 - p_t)^{\gamma} = (1 - 0.9)^2 = 0.01$
The easy example is downweighted by 100x.
Exercise 3
Why does MSE predict an impossible average trajectory?
Answer: MSE minimizes squared error, whose optimum is the conditional mean. If valid futures are left and straight, their mean may cut through a lane boundary or curb.
Exercise 4
In a 6-mode trajectory model, mode 0 wins for 95% of examples. What do you inspect?
Answer: Mode collapse, initialization, diversity regularization, mode classification weight, whether other modes are identical, and whether the data actually contains diverse futures.
Exercise 5
Weighted CE improves rare-class recall but precision collapses. What happened?
Answer: The model may be overpredicting the rare class. The weight may be too high, labels may be noisy, or the decision threshold needs tuning.