robomimic-bc-loss
Tags: Other · robomimic · rigorous codebase
Description
Behavioral Cloning: Loss Function Design for Robot Imitation Learning
Research Question
Design an improved loss function for GMM-based behavioral cloning (BC) in robot manipulation. The policy outputs a Gaussian Mixture Model (GMM) distribution over actions, and the loss function receives this distribution along with expert demonstration actions. Your goal is to design a loss that maximizes imitation learning quality as measured by rollout success rate.
What You Can Modify
The CustomBCLoss class in custom_bc_loss.py. This class receives a GMM distribution and target action tensors and must return a scalar loss.
Interface:
- Input: `dist` (a `torch.distributions.MixtureSameFamily` GMM distribution with 5 modes) and `target_actions`: a `[B, 7]` tensor of 7-dim robot actions (6D end-effector delta + 1D gripper)
- Output: a scalar loss tensor
- The default implementation is the negative log-likelihood (NLL): `-dist.log_prob(target_actions).mean()`

You may add parameters to `__init__`, define helper methods, and use any PyTorch operations. The `dist` object supports `.log_prob()`, `.sample()`, `.component_distribution`, and `.mixture_distribution`.
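To make the interface concrete, here is a minimal sketch that builds a toy `MixtureSameFamily` GMM with the same shape the policy produces and evaluates the default NLL on it. The batch size and random parameters are arbitrary placeholders, not values from the task:

```python
import torch
import torch.distributions as D

# Toy GMM matching the policy's output shape: batch of 4 (arbitrary),
# 5 mixture modes, 7-dim actions.
B, K, A = 4, 5, 7
logits = torch.randn(B, K)                        # unnormalized mixture weights
means = torch.tanh(torch.randn(B, K, A))          # tanh-squashed component means
scales = torch.nn.functional.softplus(torch.randn(B, K, A)) + 1e-4

mix = D.Categorical(logits=logits)
comp = D.Independent(D.Normal(means, scales), 1)  # action dims form one event
dist = D.MixtureSameFamily(mix, comp)

target_actions = torch.rand(B, A) * 2 - 1         # fake expert actions in [-1, 1]
nll = -dist.log_prob(target_actions).mean()       # the default BC loss: a scalar
```

`dist.log_prob(target_actions)` returns a `[B]` tensor of per-sample log-likelihoods; the `.mean()` reduces it to the scalar the training loop expects.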
Evaluation
- Metric: `success_rate` -- rollout success rate on the task (higher is better)
- Tasks: Lift (pick up cube), Can (pick-and-place can), Square (nut assembly)
- Dataset: 200 proficient human demonstrations, low-dimensional observations
- Policy: GMM with 5 modes, 2-layer MLP backbone (1024, 1024) with ReLU, tanh-squashed means
- Training: 2000 epochs, Adam optimizer (lr=1e-4), batch size 100
- Rollout evaluation: 50 episodes per task, horizon 400 steps, every 50 epochs
Code
custom_bc_loss.py
```python
"""
Custom BC Loss Function for GMM-based Behavioral Cloning.

This module defines the loss function used by BC-GMM training in robomimic.
The loss receives the GMM distribution produced by the policy network and
the expert demonstration actions, and returns a scalar loss.

The custom loss is imported and used by the patched BC_GMM._compute_losses
method during training.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
```
Results
| Model | Type | Success rate, Tool Hang (PH) ↑ | Success rate, Can (PH) ↑ | Success rate, Square (PH) ↑ |
|---|---|---|---|---|
| default | baseline | 0.147 | 0.853 | 0.733 |
| nll_entropy | baseline | 0.073 | 0.847 | 0.740 |
| weighted_nll | baseline | 0.027 | 0.413 | 0.480 |