
Custom Losses: A 1-Hour Interview Learning Session


Companion notebook: custom_losses_colab.ipynb

This lesson is for a Waymo-style ML Modeling & Fundamentals interview where you may need to explain a loss from first principles and implement a small PyTorch snippet in CoderPad.

You already know cross entropy and KL. The missing piece is not the formula. The missing piece is judgment:

What behavior do I want the model to learn, and what does my loss accidentally reward?

For autonomous driving simulation, that question matters because logs are imbalanced, futures are multi-modal, and the rare cases are often the cases we care about most.

  • 0-10 min: What a loss really does (incentives, not just math)
  • 10-20 min: Imbalance and weighted cross entropy
  • 20-30 min: Hard examples and focal loss
  • 30-45 min: Multi-modal futures (why MSE fails and what replaces it)
  • 45-55 min: Implementation patterns and debugging
  • 55-60 min: Interview answers and practice drills

By the end, you should be able to:

  • Explain why ordinary losses fail under imbalance and multi-modality.
  • Derive weighted cross entropy and focal loss from cross entropy.
  • Explain why MSE predicts the conditional mean.
  • Implement focal loss and a multi-modal trajectory loss in PyTorch.
  • Discuss tradeoffs: recall, calibration, realism, diversity, controllability, and latency.

1. The core idea: loss functions create incentives


Suppose a model predicts whether a scene contains a dangerous cut-in.

Dataset:

normal lane-following: 99,000 examples
dangerous cut-in: 1,000 examples

A model that predicts “normal” for everything gets 99% accuracy. That is useless for safety.

The problem is not that cross entropy is mathematically wrong. The problem is that the empirical training distribution says normal examples dominate. If every example has equal weight, the optimizer spends most of its effort improving common cases.

In autonomous driving and robotics, the important mistakes are often:

  • Rare pedestrian crossing.
  • Aggressive merge.
  • Cut-in.
  • Red-light runner.
  • Cyclist swerving.
  • Vehicle reversing.
  • Unprotected left turn.
  • Construction-zone behavior.

Custom losses are how we say:

This type of mistake should count more, or this kind of output structure should be rewarded.


Weighted cross entropy is ordinary cross entropy where some classes matter more.

For one example with true class $y$:

$$\mathcal{L}_{CE} = -\log p_y$$

Weighted cross entropy:

$$\mathcal{L}_{WCE} = -w_y \log p_y$$

where:

  • $p_y$ is the predicted probability for the true class.
  • $w_y$ is the weight for the true class.

If cut-in examples are rare, assign them a larger weight. The gradient from a cut-in example becomes larger, so the optimizer pays more attention.
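Here is a minimal PyTorch sketch of the formula. It assumes a 2-class setup (0 = normal, 1 = cut-in) and uses an illustrative 10x weight. The helper name `weighted_ce_manual` and its `normalize_by_weight_sum` flag are hypothetical. One detail matters in CoderPad: with the default `reduction="mean"`, `F.cross_entropy(..., weight=...)` divides the weighted sum by the sum of the per-example weights $w_y$. It does not divide by the batch size.

```python
import torch
import torch.nn.functional as F

def weighted_ce_manual(logits, targets, class_weights, normalize_by_weight_sum=True):
    """Per-example weighted CE, -w_y * log p_y, reduced over the batch.

    normalize_by_weight_sum=True divides by sum(w_y), matching
    F.cross_entropy(..., weight=class_weights) with reduction="mean".
    False divides by the batch size instead (a plain average).
    """
    log_probs = F.log_softmax(logits, dim=-1)                 # [N, C]
    nll = -log_probs.gather(1, targets[:, None]).squeeze(1)   # [N], -log p_y
    w = class_weights[targets]                                # [N], w_y
    if normalize_by_weight_sum:
        return (w * nll).sum() / w.sum()
    return (w * nll).mean()

torch.manual_seed(0)
logits = torch.randn(8, 2)                     # 0 = normal, 1 = cut-in
targets = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1])
class_weights = torch.tensor([1.0, 10.0])      # illustrative, not a tuned value

manual = weighted_ce_manual(logits, targets, class_weights)
builtin = F.cross_entropy(logits, targets, weight=class_weights)
print(torch.allclose(manual, builtin))  # True
```

Within one batch, the two normalizations give the same gradient direction but different magnitudes. As the class mix changes from batch to batch, they therefore act like different effective learning rates.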

In driving simulation, if your model underpredicts rare events, your simulator becomes too easy. The autonomy stack passes simulated tests but fails on long-tail real-world situations.

Weighted CE is useful when:

  • The label space is discrete.
  • Rare classes are important.
  • You can tolerate some calibration distortion.
  • You care about recall for minority classes.

Naive inverse frequency:

$$w_c = \frac{1}{f_c}$$

where $f_c$ is the frequency of class $c$.

But raw inverse frequency can produce huge weights for very rare classes. If those labels are noisy, the model learns to fit the noise. Safer variants:

1. normalized inverse frequency
2. sqrt inverse frequency
3. clipped weights
4. manually chosen business/safety weights
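This sketch covers the first three variants. It normalizes weights so the most frequent class gets 1.0, which is one common choice; normalizing to a mean of 1 is another. `class_weights_from_counts` is a hypothetical helper, and the cap of 10 is illustrative.

```python
import torch

def class_weights_from_counts(counts, scheme="inverse", max_weight=None):
    """Turn per-class counts into loss weights (hypothetical helper).

    scheme: "inverse" -> 1/f_c, "sqrt" -> 1/sqrt(f_c).
    The most frequent class is normalized to weight 1.0,
    then weights are optionally clipped at max_weight.
    """
    counts = torch.as_tensor(counts, dtype=torch.float32)
    freq = counts / counts.sum()            # f_c
    if scheme == "inverse":
        w = 1.0 / freq
    elif scheme == "sqrt":
        w = 1.0 / freq.sqrt()
    else:
        raise ValueError(f"unknown scheme: {scheme}")
    w = w / w.min()                          # most frequent class -> 1.0
    if max_weight is not None:
        w = w.clamp(max=max_weight)          # cap rare, possibly noisy classes
    return w

counts = [99_000, 1_000]  # normal, cut-in (the example dataset above)
print(class_weights_from_counts(counts, "inverse"))                   # roughly [1, 99]
print(class_weights_from_counts(counts, "sqrt"))                      # roughly [1, 9.95]
print(class_weights_from_counts(counts, "inverse", max_weight=10.0))  # [1, 10]
```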

Example:

normal: 99%
cut-in: 1%
raw inverse ratio: cut-in gets 99x weight
maybe too high
safer: cut-in gets 5x or 10x, then validate recall/calibration

For logits $z$ and softmax probabilities $p$, the cross-entropy gradient is:

$$\frac{\partial \mathcal{L}}{\partial z_j} = p_j - \mathbf{1}[j=y]$$

Weighted CE scales it:

$$\frac{\partial \mathcal{L}_{WCE}}{\partial z_j} = w_y(p_j - \mathbf{1}[j=y])$$

So class weighting directly scales the update for examples of class $y$.
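You can confirm the gradient formula by comparing it against autograd on one example. This sketch uses 3 classes, true class $y=2$, and $w_y=5$, all arbitrary values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, y, w_y = 3, 2, 5.0                   # arbitrary example values
z = torch.randn(num_classes, requires_grad=True)  # logits for one example

# Weighted CE for one example: -w_y * log p_y
loss = -w_y * F.log_softmax(z, dim=-1)[y]
loss.backward()

# Closed form from the text: w_y * (p_j - 1[j = y])
p = F.softmax(z.detach(), dim=-1)
indicator = torch.zeros(num_classes)
indicator[y] = 1.0
expected = w_y * (p - indicator)
print(torch.allclose(z.grad, expected, atol=1e-6))  # True
```

The gradient over the logits also sums to zero, because $\sum_j p_j = 1$. Weighting scales the gradient but leaves that property intact.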

Pros:

  • Simple.
  • Easy to implement.
  • Strong baseline for imbalance.
  • Good interview answer.

Cons:

  • Can hurt probability calibration.
  • Can overfit rare noisy labels.
  • Does not distinguish easy vs hard examples.
  • Requires tuning weights.

Weighted CE says: “some classes matter more.”

Focal loss says: “hard examples matter more.”

In a huge driving dataset, many examples are easy:

  • Empty lane.
  • Stationary parked vehicle.
  • Normal car following.

Once the model already predicts those correctly, continuing to spend lots of gradient budget on them is wasteful. Focal loss downweights easy examples and focuses training on examples the model currently struggles with.

Cross entropy:

$$\mathcal{L}_{CE} = -\log p_y$$

Focal loss:

$$\mathcal{L}_{focal} = -\alpha_y(1-p_y)^\gamma \log p_y$$

where:

  • $p_y$ is the probability assigned to the true class.
  • $\alpha_y$ is an optional class weight.
  • $\gamma$ controls how aggressively easy examples are downweighted.

If $p_y=0.95$ and $\gamma=2$:

$$(1-p_y)^\gamma = 0.05^2 = 0.0025$$

The example's loss is scaled down 400x relative to plain CE, so it contributes almost nothing.

If $p_y=0.2$:

$$(1-p_y)^2 = 0.8^2 = 0.64$$

The hard example still matters.
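The same arithmetic extends to a small table. This sketch compares plain CE with focal loss for a single example at several values of $p_y$, with $\gamma=2$ and $\alpha_y=1$. `focal_vs_ce` is a hypothetical helper.

```python
import math

def focal_vs_ce(p_y, gamma=2.0, alpha=1.0):
    """Return (CE, focal loss, modulating factor) for one example with true-class prob p_y."""
    ce = -math.log(p_y)
    factor = (1.0 - p_y) ** gamma
    return ce, alpha * factor * ce, factor

for p_y in (0.95, 0.9, 0.5, 0.2):
    ce, fl, factor = focal_vs_ce(p_y)
    print(f"p_y={p_y:.2f}  CE={ce:.4f}  focal={fl:.4f}  factor={factor:.4f}")
```

With `gamma=0` the factor is 1, and focal loss falls back to (weighted) CE.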

In perception, prediction, and simulation classifiers, easy background examples can dominate. Focal loss was popularized for dense object detection, where most anchors are background. The same idea appears in autonomous driving whenever easy negatives overwhelm rare positives.

Examples:

  • “Is this actor likely to cut in?”
  • “Is this pedestrian about to cross?”
  • “Is this generated scenario invalid?”
  • “Is this object a rare class?”
  • “Is this trajectory mode safety-critical?”

Summary:

  • Weighted CE: class-level importance; the rare class gets a larger gradient.
  • Focal loss: example-level difficulty; easy examples get a smaller gradient.
  • Weighted focal: both class importance and difficulty.
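Many of the questions above are binary, so in practice you would often write the sigmoid form of focal loss. Below is a minimal weighted-focal sketch. `alpha` weights positives and `1 - alpha` weights negatives. The defaults ($\alpha=0.25$, $\gamma=2$) come from the dense-detection setting, so treat them as starting points to tune. `binary_focal_loss` is a hypothetical name; torchvision ships a similar function, `torchvision.ops.sigmoid_focal_loss`.

```python
import torch
import torch.nn.functional as F

def binary_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Weighted sigmoid focal loss for binary questions such as "will this actor cut in?".

    logits, targets: same shape; targets are 0/1 floats.
    """
    # -log p_t, computed stably from logits, one value per element
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)              # prob of the true label
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # class-level weight
    return (alpha_t * (1 - p_t).pow(gamma) * ce).mean()     # difficulty-level weight

# A confident correct positive vs. a badly missed positive
print(binary_focal_loss(torch.tensor([4.0, -1.0]), torch.tensor([1.0, 1.0])))
```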

Pros:

  • Useful when easy examples dominate.
  • Improves attention to hard positives/negatives.
  • Often improves minority recall.

Cons:

  • Adds the hyperparameter $\gamma$.
  • Can under-train easy-but-important examples if too aggressive.
  • Can hurt calibration.
  • Hard examples may include mislabeled data.

Interview warning:

Focal loss focuses on hard examples, but hard examples are not always valuable. Some are just bad labels.


MSE is fine when the target is roughly unimodal. It is bad when many futures are plausible.

At an intersection, a car may:

  1. go straight,
  2. turn left,
  3. slow down,
  4. yield.

If the model predicts one trajectory and trains with MSE, the optimal prediction is the average future.

For driving, the average of valid futures can be invalid.

Future A: turn left
Future B: go straight
MSE average: cuts diagonally through the intersection

MSE loss:

$$\mathcal{L}(\hat{y}) = \mathbb{E}[(Y-\hat{y})^2 \mid X=x]$$

The minimizer is:

$$\hat{y}^* = \mathbb{E}[Y \mid X=x]$$

So MSE learns the conditional mean.
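The missing step is one line of calculus. Differentiate the expected squared error with respect to the prediction and set the result to zero. The second derivative is positive, so this stationary point is the minimum:

$$\frac{\partial}{\partial \hat{y}} \mathbb{E}[(Y-\hat{y})^2 \mid X=x] = -2\left(\mathbb{E}[Y \mid X=x] - \hat{y}\right) = 0 \quad\Rightarrow\quad \hat{y}^* = \mathbb{E}[Y \mid X=x], \qquad \frac{\partial^2}{\partial \hat{y}^2}\mathbb{E}[(Y-\hat{y})^2 \mid X=x] = 2 > 0$$

For a trajectory in $\mathbb{R}^{T \times 2}$ scored with summed squared error, the same argument applies to each coordinate separately. The MSE-optimal trajectory is therefore the pointwise conditional mean.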

That is perfect if the conditional distribution is one blob. It is bad if the distribution has multiple modes.

Suppose:

$$Y = \begin{cases} -1 & \text{with probability } 0.5 \\ 1 & \text{with probability } 0.5 \end{cases}$$

The MSE-optimal prediction is:

$$\mathbb{E}[Y] = 0$$

But $0$ is never observed. In trajectory terms, this is the impossible average trajectory.
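This toy sketch shows the impossible average numerically. It fits scalar predictions to samples of the $\pm 1$ target above. With one prediction, the loss is plain MSE. With two, each sample trains only its closest prediction, which is a winner-takes-all minimum over modes. The sample size, learning rate, step count, and initialization are arbitrary.

```python
import torch

torch.manual_seed(0)
# Bimodal target: -1 or +1 with equal probability
y = (torch.rand(1000) < 0.5).float() * 2 - 1

def fit(num_modes, steps=500, lr=0.1):
    """Fit num_modes scalar predictions; each sample trains only its closest one.

    num_modes=1 is plain MSE; num_modes>1 is a min-over-modes squared error.
    """
    preds = torch.linspace(-0.1, 0.1, num_modes).requires_grad_(True)  # small, distinct init
    opt = torch.optim.SGD([preds], lr=lr)
    for _ in range(steps):
        sq_err = (y[:, None] - preds[None, :]) ** 2   # [N, num_modes]
        loss = sq_err.min(dim=1).values.mean()        # best mode per sample
        opt.zero_grad()
        loss.backward()
        opt.step()
    return preds.detach()

print(fit(1))  # close to 0: the mean, which is never observed
print(fit(2))  # close to [-1, 1]: one prediction per mode
```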

Simulation needs plausible futures, not average futures. A simulator that averages away rare maneuvers will:

  • under-test planners,
  • reduce long-tail coverage,
  • generate boring scenes,
  • miss interaction diversity,
  • produce physically invalid trajectories.

Instead of predicting one trajectory, predict $K$ modes:

$$\hat{Y}^{(1)}, \dots, \hat{Y}^{(K)}$$

and probabilities:

$$\pi_1, \dots, \pi_K$$

Each trajectory is a sequence of 2D positions:

$$\hat{Y}^{(k)} \in \mathbb{R}^{T \times 2}$$

where $T$ is the number of future timesteps.

Winner-takes-all training first finds the mode closest to the logged future:

$$k^* = \arg\min_k \sum_{t=1}^T \|\hat{Y}^{(k)}_t - Y_t\|_2^2$$

Regression loss:

$$\mathcal{L}_{reg} = \sum_{t=1}^T \|\hat{Y}^{(k^*)}_t - Y_t\|_2^2$$

Mode classification loss:

$$\mathcal{L}_{mode} = -\log \pi_{k^*}$$

Combined:

$$\mathcal{L} = \mathcal{L}_{reg} + \lambda \mathcal{L}_{mode}$$

This loss says:

  1. At least one mode should match the logged future.
  2. The model should assign high probability to that matching mode.
  3. Other modes are free to cover other plausible futures.

The logged future is only one sample from the real future distribution. If the car went straight in the log, a left turn might still have been plausible. Winner-takes-all loss may not reward that left-turn mode unless the dataset contains similar scenes where the car turned left.

This is why trajectory prediction evaluation often uses:

  • minADE: did any mode match?
  • minFDE: did any final point match?
  • miss rate: did all modes miss?
  • probability-aware metrics: did the model rank the right mode high?
  • realism metrics: are other modes valid?
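This sketch computes the first three metrics for a batch of $K$-mode predictions. It assumes positions in meters and a hypothetical 2.0 m miss threshold on final displacement. Benchmarks differ in the details: some pick the best mode by final displacement and report that mode's ADE, and miss definitions vary. Check the definition your evaluation uses.

```python
import torch

def trajectory_metrics(pred_trajs, target, miss_threshold=2.0):
    """minADE, minFDE, and miss rate for K-mode predictions (hypothetical helper).

    pred_trajs: [B, K, T, 2] predicted positions (assumed meters)
    target:     [B, T, 2] logged future
    miss_threshold: illustrative final-displacement threshold in meters
    """
    # L2 displacement per mode and timestep: [B, K, T]
    dist = torch.linalg.norm(pred_trajs - target[:, None], dim=-1)
    ade = dist.mean(dim=-1)                    # [B, K] average displacement per mode
    fde = dist[..., -1]                        # [B, K] final displacement per mode
    min_ade = ade.min(dim=1).values            # did any mode match the whole path?
    min_fde = fde.min(dim=1).values            # did any mode match the endpoint?
    miss = (min_fde > miss_threshold).float()  # 1.0 if every mode misses
    return min_ade.mean(), min_fde.mean(), miss.mean()

# Toy check: one example, two modes, three timesteps
target = torch.zeros(1, 3, 2)
offsets = torch.tensor([[3.0, 0.0], [0.0, 1.0]])          # mode 0 is 3 m off, mode 1 is 1 m off
pred_trajs = target[:, None] + offsets[None, :, None, :]  # [1, 2, 3, 2]
print(trajectory_metrics(pred_trajs, target))  # minADE = 1.0, minFDE = 1.0, miss rate = 0.0
```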

A more probabilistic approach is to predict a mixture distribution:

$$p(Y \mid X) = \sum_{k=1}^K \pi_k \mathcal{N}(Y; \mu_k, \Sigma_k)$$

where:

  • $\pi_k$ is the probability of mode $k$.
  • $\mu_k$ is the mean trajectory for mode $k$.
  • $\Sigma_k$ is the covariance (uncertainty) of mode $k$.

Negative log likelihood:

$$\mathcal{L}_{NLL} = -\log \sum_{k=1}^K \pi_k \mathcal{N}(Y; \mu_k, \Sigma_k)$$

Why use it:

  • More probabilistic.
  • Can model uncertainty.
  • Encourages calibrated probabilities.

Why it is harder:

  • Numerical stability.
  • Covariance parameterization.
  • Mode collapse.
  • Overconfident tiny variances.
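This sketch of the mixture NLL handles two of the difficulties above. It assumes a diagonal covariance (independent x/y noise at each timestep) parameterized by log standard deviations. The clamp range is an illustrative guard against tiny, overconfident variances, and `logsumexp` keeps the sum over modes numerically stable. `mixture_nll` is a hypothetical name.

```python
import math
import torch
import torch.nn.functional as F

def mixture_nll(mode_logits, means, log_stds, target, min_log_std=-3.0, max_log_std=3.0):
    """NLL of a K-mode Gaussian mixture with diagonal covariance (a sketch).

    mode_logits: [B, K]        unnormalized mode scores (pi = softmax)
    means:       [B, K, T, 2]  mean trajectory per mode
    log_stds:    [B, K, T, 2]  per-coordinate log standard deviation
    target:      [B, T, 2]
    """
    # Clamp so variances cannot collapse toward 0 (overconfidence) or blow up
    log_stds = log_stds.clamp(min_log_std, max_log_std)
    z = (target[:, None] - means) / log_stds.exp()   # standardized residuals
    # log N(y; mu, sigma^2) per coordinate, summed over T and x/y -> [B, K]
    log_prob = (-0.5 * z.pow(2) - log_stds - 0.5 * math.log(2 * math.pi)).sum(dim=(-1, -2))
    log_pi = F.log_softmax(mode_logits, dim=-1)       # [B, K]
    # -log sum_k pi_k N(y; mu_k, Sigma_k), computed stably in log space
    return -torch.logsumexp(log_pi + log_prob, dim=-1).mean()
```

With $K = 1$ and a fixed unit variance, this reduces to half the summed squared error plus a constant. That is one way to read MSE as a Gaussian likelihood.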

In interviews, winner-takes-all multi-modal loss is usually easier to explain and implement. Mixture density loss is a good extension if asked.


The same principles apply broadly.

Robotics:

  • Grasp success is imbalanced.
  • Contact-rich behavior is multi-modal.
  • Imitation learning with MSE can average actions and fail.

Medical ML:

  • Rare disease detection needs class weighting or focal loss.
  • False negatives may be much worse than false positives.

Fraud:

  • Rare positives dominate business value.
  • Focal loss can help with many easy negatives.

Recommendation:

  • Weighted losses encode business value.
  • Pairwise/listwise losses may better match ranking objectives.

General lesson:

A good custom loss aligns the training signal with the real decision cost and output structure.


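    import torch
    import torch.nn.functional as F

    def weighted_cross_entropy(logits, targets, class_weights):
        # class_weights: float tensor [C] on the same device as logits.
        # With reduction="mean" (the default), PyTorch divides by the sum
        # of the per-example weights w_y, not by the batch size.
        return F.cross_entropy(logits, targets, weight=class_weights)

    def focal_loss(logits, targets, alpha=None, gamma=2.0):
        # logits: [N, C], targets: [N] class indices, alpha: optional [C] weights
        log_probs = F.log_softmax(logits, dim=-1)
        probs = log_probs.exp()
        row = torch.arange(targets.numel(), device=targets.device)
        pt = probs[row, targets]          # p_y: probability of the true class
        log_pt = log_probs[row, targets]  # log p_y from log_softmax, for stability
        if alpha is None:
            alpha_t = 1.0
        else:
            alpha_t = alpha[targets]      # per-example class weight alpha_y
        # (1 - p_y)^gamma shrinks the loss of confident (easy) examples
        loss = -alpha_t * (1.0 - pt).pow(gamma) * log_pt
        return loss.mean()

    def multimodal_trajectory_loss(pred_trajs, mode_logits, target, cls_weight=1.0):
        """
        pred_trajs: [B, K, T, 2]
        mode_logits: [B, K]
        target: [B, T, 2]
        """
        # Add a mode axis so target broadcasts against every mode
        diff = pred_trajs - target[:, None, :, :]
        sq_error = (diff ** 2).sum(dim=(-1, -2))  # [B, K]
        # Winner-takes-all: only the closest mode gets a regression gradient
        best_mode = sq_error.argmin(dim=1)  # [B]
        reg_loss = sq_error[torch.arange(target.size(0), device=target.device), best_mode].mean()
        # Train the mode probabilities to pick the winning mode
        cls_loss = F.cross_entropy(mode_logits, best_mode)
        return reg_loss + cls_weight * cls_loss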

In an interview, first print shapes:

print(pred_trajs.shape) # [B, K, T, 2]
print(target.shape) # [B, T, 2]

Most loss bugs are shape bugs.
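One shape bug deserves a concrete look because it raises no error. In this sketch, the mode axis on `target` is omitted, and the subtraction still broadcasts when $B = K$. The loss runs, but it silently compares mode $k$ against example $k$'s future. `check_traj_shapes` is a hypothetical fail-fast guard.

```python
import torch

torch.manual_seed(0)
B, K, T = 4, 4, 3                      # batch size equal to the number of modes
pred_trajs = torch.randn(B, K, T, 2)
target = torch.randn(B, T, 2)

# Correct: target becomes [B, 1, T, 2] and is compared against every mode.
good = pred_trajs - target[:, None, :, :]
# Bug: broadcasting aligns trailing dims, so target acts like [1, B, T, 2]
# and mode k is compared against example k's future.
bad = pred_trajs - target
print(good.shape == bad.shape, torch.equal(good, bad))  # True False

def check_traj_shapes(pred_trajs, mode_logits, target):
    """Fail fast on mismatched shapes instead of broadcasting silently (hypothetical helper)."""
    B, K, T, D = pred_trajs.shape
    assert D == 2, f"expected (x, y) coordinates, got last dim {D}"
    assert mode_logits.shape == (B, K), f"mode_logits {tuple(mode_logits.shape)} != {(B, K)}"
    assert target.shape == (B, T, D), f"target {tuple(target.shape)} != {(B, T, D)}"
```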


Weighted CE:

  • Are class weights on the same device as logits?
  • Are weights too extreme?
  • Did minority recall improve?
  • Did precision collapse?
  • Did calibration get worse?

Focal loss:

  • Is $\gamma=0$ equivalent to weighted CE?
  • Are hard examples actually valid labels?
  • Is the loss ignoring too many easy examples?
  • Did rare recall improve?
  • Did probability calibration degrade?

Multi-modal trajectories:

  • Are all modes collapsing to the same trajectory?
  • Is one mode always winning?
  • Are mode probabilities calibrated?
  • Does minADE improve while probability-aware metrics get worse?
  • Are trajectories physically valid?
  • Do they stay on map?
  • Do they collide unrealistically?

General:

  • Can the model overfit one small batch?
  • Does each term have comparable scale?
  • Does the loss decrease while the real metric does not?
  • Are units consistent: meters, seconds, radians?
  • Are labels noisy or multi-modal?
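The $\gamma=0$ question has a precise answer, and this sketch checks it. With $\gamma=0$, focal loss with class weights $\alpha$ equals weighted CE per example, but the batch reduction differs. `focal_loss` is redefined here with the same math as the implementation above (using `gather`) so the block runs on its own. The logits, targets, and weights are arbitrary.

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=None, gamma=2.0):
    # Same math as the focal_loss in the implementation section; mean over examples.
    log_probs = F.log_softmax(logits, dim=-1)
    log_pt = log_probs.gather(1, targets[:, None]).squeeze(1)  # log p_y
    pt = log_pt.exp()
    alpha_t = 1.0 if alpha is None else alpha[targets]
    return (-alpha_t * (1.0 - pt).pow(gamma) * log_pt).mean()

torch.manual_seed(0)
logits = torch.randn(16, 3)
targets = torch.arange(16) % 3          # every class appears
w = torch.tensor([1.0, 2.0, 5.0])       # illustrative class weights

fl0 = focal_loss(logits, targets, alpha=w, gamma=0.0)
# gamma = 0 matches weighted CE averaged over the batch...
print(torch.allclose(fl0, F.cross_entropy(logits, targets, weight=w, reduction="none").mean()))  # True
# ...but not PyTorch's default weighted mean, which divides by the sum of w_y.
print(torch.allclose(fl0, F.cross_entropy(logits, targets, weight=w)))  # False
```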

10. Common interview questions and strong answers


Q: Why use weighted cross entropy?
A: Because the empirical distribution may not match the cost distribution. In driving, rare cut-ins or pedestrian events may be more important than common normal driving. Weighting scales gradients for those examples.

Q: Why not always use huge rare-class weights?
A: Rare labels may be noisy, and large weights can overfit noise or destroy calibration. I would tune weights and inspect precision/recall by class.

Q: What does focal loss add beyond class weighting?
A: It downweights easy examples based on the model’s current confidence. Class weighting is class-level; focal loss is example-difficulty-level.

Q: Why is MSE bad for future prediction?
A: MSE learns the conditional mean. If futures are multi-modal, the mean can be an invalid trajectory that no agent would actually take.

Q: How would you train a model to predict multiple futures?
A: Predict $K$ trajectories and mode probabilities. Use a min-over-modes regression loss to train the closest mode, plus cross entropy to assign probability to that mode.

Q: What metrics would you check?
A: For classification, per-class precision/recall and calibration. For trajectories, minADE/minFDE, miss rate, mode entropy, probability calibration, collision, offroad, and realism.


11. A 60-second explanation you can say out loud


Custom losses are about aligning optimization with the real cost and structure of the problem. In autonomous driving, data is imbalanced and futures are multi-modal. Weighted cross entropy increases the gradient contribution of rare important classes like cut-ins. Focal loss goes further by downweighting easy examples, so training focuses on examples the model currently gets wrong. For trajectory prediction, plain MSE is dangerous because it predicts the conditional mean; the average of left-turn and straight futures may be physically invalid. A better approach predicts multiple trajectory modes and probabilities, trains the closest mode to the logged future, and also trains the model to assign that mode high probability. Then I would debug with per-class metrics, mode collapse checks, calibration, and physical validity metrics.


You have 99% normal driving and 1% cut-in. Why can unweighted CE fail?

Answer: The optimizer sees far more normal examples, so a model can get low loss and high accuracy while missing cut-ins. The training distribution underweights the safety-critical class.

For focal loss with $p_y=0.9$ and $\gamma=2$, what is the modulating factor?

Answer:

$$(1-0.9)^2 = 0.01$$

The easy example is downweighted by 100x.

Why does MSE predict an impossible average trajectory?

Answer: MSE minimizes squared error, whose optimum is the conditional mean. If valid futures are left and straight, their mean may cut through a lane boundary or curb.

In a 6-mode trajectory model, mode 0 wins for 95% of examples. What do you inspect?

Answer: Mode collapse, initialization, diversity regularization, mode classification weight, whether other modes are identical, and whether the data actually contains diverse futures.
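Two of these checks are cheap to compute. This sketch reports how often each mode wins the min-over-modes assignment and how far apart the modes' endpoints are. A win rate concentrated on one mode, together with a near-zero spread, suggests collapse. `mode_diagnostics` is a hypothetical helper and assumes $K \ge 2$.

```python
import torch

def mode_diagnostics(pred_trajs, target):
    """Quick mode-collapse checks for a K-mode trajectory model (hypothetical helper).

    pred_trajs: [B, K, T, 2], target: [B, T, 2]; assumes K >= 2.
    Returns:
      win_rate: [K] fraction of examples where each mode is the closest mode
      spread:   mean pairwise endpoint distance between modes (near 0 = collapsed)
    """
    B, K, T, _ = pred_trajs.shape
    sq_error = (pred_trajs - target[:, None]).pow(2).sum(dim=(-1, -2))  # [B, K]
    winners = sq_error.argmin(dim=1)                                    # [B]
    win_rate = torch.bincount(winners, minlength=K).float() / B
    endpoints = pred_trajs[:, :, -1, :]                                 # [B, K, 2]
    pairwise = torch.cdist(endpoints, endpoints)                        # [B, K, K]
    spread = pairwise.sum(dim=(-1, -2)) / (K * (K - 1))                 # off-diagonal mean
    return win_rate, spread.mean()
```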

Weighted CE improves rare-class recall but precision collapses. What happened?

Answer: The model may be overpredicting the rare class. The weight may be too high, labels may be noisy, or the decision threshold needs tuning.