Agent Conversation: google/gemini-3.1-pro-preview

Model: google/gemini-3.1-pro-preview
Total steps: 7
Messages: 15
UserStep 0

Task: cv-diffusion-architecture

Diffusion Model Architecture Design

Background

The UNet backbone is the standard architecture for denoising diffusion models. Key design choices include:

  • Block types: Whether to use pure convolutional blocks (DownBlock2D / UpBlock2D) or blocks with self-attention (AttnDownBlock2D / AttnUpBlock2D), and at which resolution levels.
  • Attention placement: Self-attention is expensive at high resolutions (32x32) but may improve global coherence. The original DDPM places attention only at 16x16.
  • Depth and normalization: layers_per_block, norm_num_groups, attention_head_dim, and other structural hyperparameters.
  • Custom modules: Entirely new backbone designs (e.g., hybrid convolution-transformer, gated blocks, multi-scale fusion) as long as they satisfy the input/output interface.
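In the diffusers API these choices map onto the block-type tuples passed to UNet2DModel. As a sketch, a DDPM-style placement (attention only at 16x16) expressed as a plain config dict, so nothing is instantiated here:

```python
# DDPM-style attention placement, expressed as UNet2DModel kwargs (sketch).
# Down-path resolutions run 32 -> 16 -> 8 -> 4, and the up path mirrors them
# in reverse, so the 16x16 level is index 1 going down and index 2 coming up.
ddpm_style = dict(
    block_out_channels=(128, 256, 256, 256),
    down_block_types=("DownBlock2D", "AttnDownBlock2D", "DownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "UpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
)
```

Note the mirrored indexing: an attention level chosen on the down path should usually get the matching `Attn*` block at the reversed index on the up path.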

Research Question

What UNet architecture achieves the best FID on unconditional CIFAR-10 diffusion, given a fixed training procedure (epsilon prediction, DDIM sampling, same optimizer and schedule)?

Task

You are given custom_train.py, a self-contained unconditional DDPM training script on CIFAR-10. Everything is fixed except the build_model(device) function.

Your goal is to design a model architecture that achieves lower FID than the baselines. The model must satisfy:

  • Input: (x, timestep) where x is [B, 3, 32, 32], timestep is [B]
  • Output: an object with .sample attribute of shape [B, 3, 32, 32]
  • UNet2DModel from diffusers satisfies this interface, but you may also build a fully custom nn.Module.
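A minimal toy module satisfying that contract (illustrative only, not a competitive design; the SimpleNamespace wrapper exists purely to expose the required `.sample` attribute):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from types import SimpleNamespace

class TinyEpsNet(nn.Module):
    """Toy epsilon-predictor matching the (x, timestep) -> .sample interface."""
    def __init__(self, channels=32):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.Linear(1, channels), nn.SiLU(), nn.Linear(channels, channels)
        )
        self.conv_in = nn.Conv2d(3, channels, 3, padding=1)
        self.conv_out = nn.Conv2d(channels, 3, 3, padding=1)

    def forward(self, x, timestep):
        # Crude timestep scaling; real models use sinusoidal embeddings.
        t = timestep.float().view(-1, 1) / 1000.0
        temb = self.time_mlp(t).view(x.shape[0], -1, 1, 1)
        h = F.silu(self.conv_in(x) + temb)
        # Wrap the output so callers can read .sample, as with UNet2DModel.
        return SimpleNamespace(sample=self.conv_out(h))

model = TinyEpsNet()
out = model(torch.randn(2, 3, 32, 32), torch.randint(0, 1000, (2,)))
```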

Channel widths are provided via the BLOCK_OUT_CHANNELS environment variable (e.g. "128,256,256,256") so the same architecture scales across evaluation tiers. LAYERS_PER_BLOCK (default 2) is also available.
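The env-var plumbing inside build_model might look like this (variable names follow the task spec; the defaults shown are an assumption matching the Medium tier):

```python
import os

# Read the scaling knobs supplied by the harness, with Medium-tier fallbacks.
raw = os.environ.get("BLOCK_OUT_CHANNELS", "128,256,256,256")
block_out_channels = tuple(int(c) for c in raw.split(","))
layers_per_block = int(os.environ.get("LAYERS_PER_BLOCK", "2"))
```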

Evaluation

  • Dataset: CIFAR-10 (32x32, unconditional)
  • Training: fixed epsilon prediction, MSE loss, AdamW lr=2e-4, EMA
  • Model scales:
    • Small: block_out_channels=(64,128,128,128), ~9M params, batch 128
    • Medium: block_out_channels=(128,256,256,256), ~36M params, batch 128
    • Large: block_out_channels=(256,512,512,512), ~140M params, batch 64
  • Schedule: 35,000 training steps per scale, EMA decay rate 0.9995
  • Metric: FID (lower is better), computed with clean-fid against CIFAR-10 train set (50,000 samples)
  • Inference: 50-step DDIM sampling
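The deterministic DDIM update used at inference can be sketched as follows (simplified eta=0 form; the cumulative alphas and the epsilon prediction come from the fixed training setup, and the function names here are illustrative):

```python
import torch

def ddim_step(x_t, eps, alpha_bar_t, alpha_bar_prev):
    """One deterministic (eta=0) DDIM update toward the previous timestep."""
    # Invert the forward process: recover the predicted clean image
    # from the model's noise prediction at timestep t.
    x0_pred = (x_t - (1 - alpha_bar_t).sqrt() * eps) / alpha_bar_t.sqrt()
    # Re-noise deterministically to the previous timestep's noise level.
    return alpha_bar_prev.sqrt() * x0_pred + (1 - alpha_bar_prev).sqrt() * eps
```

With `alpha_bar_prev = 1` the step returns the predicted clean image exactly, which is why the final DDIM step lands on `x0_pred`.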

Baselines

All three baselines use modern UNet improvements: scale-shift FiLM time conditioning (resnet_time_scale_shift="scale_shift"), flip_sin_to_cos=True, freq_shift=0, and explicit multi-head attention. They differ in attention placement and head capacity:

  1. iddpm (Improved DDPM, Nichol & Dhariwal 2021): self-attention at the 16x16 and 8x8 resolutions (2 levels), attention_head_dim=8. Modest compute, focused attention on the mid-deep features.
  2. adm (ADM "Diffusion Models Beat GANs", Dhariwal & Nichol 2021): self-attention at 16x16, 8x8, and 4x4 (
AssistantStep 1edit
Tool Input
1{
2  "end_line": 58,
3  "content": "def build_model(device):\n    import diffusers\n    import os\n    import torch\n    import torch.nn as nn\n    import torch.nn.functional as F\n\n    class TGCBlock(nn.Module):\n        \"\"\"Time-conditioned Global Context Block.\"\"\"\n        def __init__(self, in_channels, time_emb_dim, ratio=4):\n            super().__init__()\n            mid_channels = max(16, in_channels // ratio)\n            self.spatial_attn = nn.Conv2d(in_channels, 1, kernel_size=1)\n            self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1)\n            self.norm = nn.GroupNorm(1, mid_channels)\n            self.time_proj = nn.Linear(time_emb_dim, mid_channels)\n            self.conv2 = nn.Conv2d(mid_channels, in_channels, kernel_size=1)\n            \n            nn.init.constant_(self.conv2.weight, 0)\n            nn.init.constant_(self.conv2.bias, 0)\n\n        def forward(self, x, temb):\n            B, C, H, W = x.size()\n            attn = self.spatial_attn(x).view(B, 1, H * W)\n            attn = F.softmax(attn, dim=-1)\n            \n            context = torch.bmm(x.view(B, C, H * W), attn.transpose(1, 2))\n            context = context.view(B, C, 1, 1)\n            \n            hidden = self.norm(self.conv1(context))\n            \n            if temb is not None:\n                time_bias = self.time_proj(F.mish(temb)).view(B, -1, 1, 1)\n                hidden = hidden + time_bias\n                \n            hidden = F.mish(hidden)\n            out = self.conv2(hidden)\n            return x + out\n\n    class ResnetWithTGC(nn.Module):\n        def __init__(self, resnet, tgc_block):\n            super().__init__()\n            self.resnet = resnet\n            self.tgc_block = tgc_block\n            \n        def forward(self, input_tensor, temb, *args, **kwargs):\n            x = self.resnet(input_tensor, temb, *args, **kwargs)\n            return self.tgc_block(x, temb)\n\n    channels = (128, 256, 256, 256)\n    if 
os.environ.get('BLOCK_OUT_CHANNELS'):\n        channels = tuple(int(x) for x in os.environ['BLOCK_OUT_CHANNELS'].split(','))\n    layers = int(os.environ.get('LAYERS_PER_BLOCK', 2))\n    time_emb_dim = channels[0] * 4\n\n    model = diffusers.UNet2DModel(\n        sample_size=32,\n        in_channels=3,\n        out_channels=3,\n        block_out_channels=channels,\n        down_block_types=(\"AttnDownBlock2D\", \"AttnDownBlock2D\", \"AttnDownBlock2D\", \"DownBlock2D\"),\n        up_block_types=(\"UpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\"),\n        layers_per_block=layers,\n        norm_num_groups=32,\n        norm_eps=1e-6,\n        act_fn=\"mish\",\n        time_embedding_type=\"positional\",\n        flip_sin_to_cos=True,\n        freq_shift=0,\n        downsample_padding=1,\n        attention_head_dim=32,\n        resnet_time_scale_shift=\"scale_shift\",\n    )\n\n    def inject_tgc_blocks(mod):\n        for name, child in mod.named_children():\n            if child.__class__.__name__ == \"ResnetBlock2D\":\n                out_channels = getattr(child, \"out_channels\", None)\n                if out_channels is None:\n                    out_channels = child.conv1.out_channels\n                tgc = TGCBlock(out_channels, time_emb_dim)\n                wrapped = ResnetWithTGC(child, tgc)\n                if isinstance(mod, nn.ModuleList):\n                    mod[int(name)] = wrapped\n                else:\n                    setattr(mod, name, wrapped)\n            else:\n                inject_tgc_blocks(child)\n                \n    inject_tgc_blocks(model)\n    return model.to(device)\n",
4  "filename": "custom_train.py",
5  "op": "replace",
6  "start_line": 31
7}
Tool ResultStep 1
ERROR: Package 'custom_train.py' is not in allowed packages
AssistantStep 2edit
Tool Input
1{
2  "content": "def build_model(device):\n    import diffusers\n    import os\n    import torch\n    import torch.nn as nn\n    import torch.nn.functional as F\n\n    class TGCBlock(nn.Module):\n        \"\"\"Time-conditioned Global Context Block.\"\"\"\n        def __init__(self, in_channels, time_emb_dim, ratio=4):\n            super().__init__()\n            mid_channels = max(16, in_channels // ratio)\n            self.spatial_attn = nn.Conv2d(in_channels, 1, kernel_size=1)\n            self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1)\n            self.norm = nn.GroupNorm(1, mid_channels)\n            self.time_proj = nn.Linear(time_emb_dim, mid_channels)\n            self.conv2 = nn.Conv2d(mid_channels, in_channels, kernel_size=1)\n            \n            nn.init.constant_(self.conv2.weight, 0)\n            nn.init.constant_(self.conv2.bias, 0)\n\n        def forward(self, x, temb):\n            B, C, H, W = x.size()\n            attn = self.spatial_attn(x).view(B, 1, H * W)\n            attn = F.softmax(attn, dim=-1)\n            \n            context = torch.bmm(x.view(B, C, H * W), attn.transpose(1, 2))\n            context = context.view(B, C, 1, 1)\n            \n            hidden = self.norm(self.conv1(context))\n            \n            if temb is not None:\n                time_bias = self.time_proj(F.mish(temb)).view(B, -1, 1, 1)\n                hidden = hidden + time_bias\n                \n            hidden = F.mish(hidden)\n            out = self.conv2(hidden)\n            return x + out\n\n    class ResnetWithTGC(nn.Module):\n        def __init__(self, resnet, tgc_block):\n            super().__init__()\n            self.resnet = resnet\n            self.tgc_block = tgc_block\n            \n        def forward(self, input_tensor, temb, *args, **kwargs):\n            x = self.resnet(input_tensor, temb, *args, **kwargs)\n            return self.tgc_block(x, temb)\n\n    channels = (128, 256, 256, 256)\n    if 
os.environ.get('BLOCK_OUT_CHANNELS'):\n        channels = tuple(int(x) for x in os.environ['BLOCK_OUT_CHANNELS'].split(','))\n    layers = int(os.environ.get('LAYERS_PER_BLOCK', 2))\n    time_emb_dim = channels[0] * 4\n\n    model = diffusers.UNet2DModel(\n        sample_size=32,\n        in_channels=3,\n        out_channels=3,\n        block_out_channels=channels,\n        down_block_types=(\"AttnDownBlock2D\", \"AttnDownBlock2D\", \"AttnDownBlock2D\", \"DownBlock2D\"),\n        up_block_types=(\"UpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\"),\n        layers_per_block=layers,\n        norm_num_groups=32,\n        norm_eps=1e-6,\n        act_fn=\"mish\",\n        time_embedding_type=\"positional\",\n        flip_sin_to_cos=True,\n        freq_shift=0,\n        downsample_padding=1,\n        attention_head_dim=32,\n        resnet_time_scale_shift=\"scale_shift\",\n    )\n\n    def inject_tgc_blocks(mod):\n        for name, child in mod.named_children():\n            if child.__class__.__name__ == \"ResnetBlock2D\":\n                out_channels = getattr(child, \"out_channels\", None)\n                if out_channels is None:\n                    out_channels = child.conv1.out_channels\n                tgc = TGCBlock(out_channels, time_emb_dim)\n                wrapped = ResnetWithTGC(child, tgc)\n                if isinstance(mod, nn.ModuleList):\n                    mod[int(name)] = wrapped\n                else:\n                    setattr(mod, name, wrapped)\n            else:\n                inject_tgc_blocks(child)\n                \n    inject_tgc_blocks(model)\n    return model.to(device)\n",
3  "end_line": 58,
4  "start_line": 31,
5  "filename": "diffusers-main/custom_train.py",
6  "op": "replace"
7}
Tool ResultStep 2
OK: Replaced lines 31..58 with 91 line(s) in diffusers-main/custom_train.py. Editable range: 31–121.

[Current file: diffusers-main/custom_train.py | editable: 31–121 | total: 474 lines]
    28| # Model Architecture (EDITABLE REGION)
    29| # ============================================================================
    30| 
    31  def build_model(device):
    32      import diffusers
    33      import os
    34      import torch
    35      import torch.nn as nn
    36      import torch.nn.functional as F
    37  
    38      class TGCBlock(nn.Module):
    39          """Time-conditioned Global Context Block."""
    40          def __init__(self, in_channels, time_emb_dim, ratio=4):
    41              super().__init__()
    42              mid_channels = max(16, in_channels // ratio)
    43              self.spatial_attn = nn.Conv2d(in_channels, 1, kernel_size=1)
    44              self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
    45              self.norm = 
AssistantStep 3test
Tool ResultStep 3
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)

### train_small (scripts/train_small.sh)
Training results:
Model parameters: 10.3M | GPUs: 8
step 200/35000 | loss 0.0348 | 32.3s
step 400/35000 | loss 0.0319 | 31.5s
step 600/35000 | loss 0.0378 | 31.3s
step 800/35000 | loss 0.0268 | 31.1s
step 1000/35000 | loss 0.0402 | 30.6s
step 1200/35000 | loss 0.0327 | 31.1s
step 1400/35000 | loss 0.0311 | 31.0s
step 1600/35000 | loss 0.0226 | 32.0s
step 1800/35000 | loss 0.0318 | 31.9s
step 2000/35000 | loss 0.0426 | 31.2s
step 2200/35000 | loss 0.0426 | 31.3s
step 2400/35000 | loss 0.0275 | 32.0s
step 2600/35000 | loss 0.0464 | 31.2s
step 2800/35000 | loss 0.0280 | 30.8s
step 3000/35000 | loss 0.0342 | 31.4s
step 3200/35000 | loss 0.0343 | 31.2s
step 3400/35000 | loss 0.0325 | 31.3s
step 3600/35000 | loss 0.0297 | 32.0s
step 3800/35000 | loss 0.0351 | 31.7s
step 4000/35000 | loss 0.0227 | 31.5s
step 4200/35000 | loss 0.0345 | 31.5s
step 4400/35000 | loss 0.0201 | 31.9s
step
AssistantStep 4edit
Tool Input
1{
2  "op": "replace",
3  "filename": "diffusers-main/custom_train.py",
4  "content": "    channels = (128, 256, 256, 256)\n    if os.environ.get('BLOCK_OUT_CHANNELS'):\n        channels = tuple(int(x) for x in os.environ['BLOCK_OUT_CHANNELS'].split(','))\n    layers = int(os.environ.get('LAYERS_PER_BLOCK', 2))\n    time_emb_dim = channels[0] * 4\n\n    if channels[0] <= 64:\n        down_blocks = (\"AttnDownBlock2D\", \"AttnDownBlock2D\", \"AttnDownBlock2D\", \"DownBlock2D\")\n        up_blocks = (\"UpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\")\n    else:\n        down_blocks = (\"DownBlock2D\", \"AttnDownBlock2D\", \"AttnDownBlock2D\", \"DownBlock2D\")\n        up_blocks = (\"UpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\", \"UpBlock2D\")\n\n    model = diffusers.UNet2DModel(\n        sample_size=32,\n        in_channels=3,\n        out_channels=3,\n        block_out_channels=channels,\n        down_block_types=down_blocks,\n        up_block_types=up_blocks,\n        layers_per_block=layers,\n        norm_num_groups=32,\n        norm_eps=1e-6,\n        act_fn=\"mish\",\n        time_embedding_type=\"positional\",\n        flip_sin_to_cos=True,\n        freq_shift=0,\n        downsample_padding=1,\n        attention_head_dim=32,\n        resnet_time_scale_shift=\"scale_shift\",\n    )",
5  "end_line": 103,
6  "start_line": 80
7}
Tool ResultStep 4
OK: Replaced lines 80..103 with 31 line(s) in diffusers-main/custom_train.py. Editable range: 31–128.

[Current file: diffusers-main/custom_train.py | editable: 31–128 | total: 481 lines]
    28| # Model Architecture (EDITABLE REGION)
    29| # ============================================================================
    30| 
    31  def build_model(device):
    32      import diffusers
    33      import os
    34      import torch
    35      import torch.nn as nn
    36      import torch.nn.functional as F
    37  
    38      class TGCBlock(nn.Module):
    39          """Time-conditioned Global Context Block."""
    40          def __init__(self, in_channels, time_emb_dim, ratio=4):
    41              super().__init__()
    42              mid_channels = max(16, in_channels // ratio)
    43              self.spatial_attn = nn.Conv2d(in_channels, 1, kernel_size=1)
    44              self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
    45              self.norm =
AssistantStep 5test
Tool ResultStep 5
[Test #2] (1 test remaining; call submit(n=2) to submit this result)

### train_small (scripts/train_small.sh)
Training results:
Model parameters: 10.3M | GPUs: 8
step 200/35000 | loss 0.0348 | 31.9s
step 400/35000 | loss 0.0320 | 31.1s
step 600/35000 | loss 0.0378 | 31.3s
step 800/35000 | loss 0.0268 | 31.3s
step 1000/35000 | loss 0.0404 | 31.1s
step 1200/35000 | loss 0.0328 | 31.7s
step 1400/35000 | loss 0.0311 | 31.8s
step 1600/35000 | loss 0.0226 | 31.4s
step 1800/35000 | loss 0.0317 | 30.9s
step 2000/35000 | loss 0.0426 | 31.7s
step 2200/35000 | loss 0.0425 | 31.8s
step 2400/35000 | loss 0.0274 | 31.2s
step 2600/35000 | loss 0.0464 | 31.4s
step 2800/35000 | loss 0.0280 | 31.0s
step 3000/35000 | loss 0.0343 | 31.4s
step 3200/35000 | loss 0.0344 | 31.8s
step 3400/35000 | loss 0.0324 | 31.2s
step 3600/35000 | loss 0.0297 | 31.4s
step 3800/35000 | loss 0.0351 | 31.5s
step 4000/35000 | loss 0.0227 | 31.6s
step 4200/35000 | loss 0.0345 | 31.1s
step 4400/35000 | loss 0.0201 | 31.3s
step 
AssistantStep 6edit
Tool Input
1{
2  "content": "def build_model(device):\n    import diffusers\n    import os\n    import torch\n    import torch.nn as nn\n    import torch.nn.functional as F\n\n    class TGCBlock(nn.Module):\n        \"\"\"Time-conditioned Global Context Block.\"\"\"\n        def __init__(self, in_channels, time_emb_dim, ratio=4):\n            super().__init__()\n            mid_channels = max(16, in_channels // ratio)\n            self.spatial_attn = nn.Conv2d(in_channels, 1, kernel_size=1)\n            self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1)\n            self.norm = nn.GroupNorm(1, mid_channels)\n            self.time_proj = nn.Linear(time_emb_dim, mid_channels)\n            self.conv2 = nn.Conv2d(mid_channels, in_channels, kernel_size=1)\n            \n            nn.init.constant_(self.conv2.weight, 0)\n            nn.init.constant_(self.conv2.bias, 0)\n\n        def forward(self, x, temb):\n            B, C, H, W = x.size()\n            attn = self.spatial_attn(x).view(B, 1, H * W)\n            attn = F.softmax(attn, dim=-1)\n            \n            context = torch.bmm(x.view(B, C, H * W), attn.transpose(1, 2))\n            context = context.view(B, C, 1, 1)\n            \n            hidden = self.norm(self.conv1(context))\n            \n            if temb is not None:\n                time_bias = self.time_proj(F.mish(temb)).view(B, -1, 1, 1)\n                hidden = hidden + time_bias\n                \n            hidden = F.mish(hidden)\n            out = self.conv2(hidden)\n            return x + out\n\n    class ResnetWithTGC(nn.Module):\n        def __init__(self, resnet, tgc_block):\n            super().__init__()\n            self.resnet = resnet\n            self.tgc_block = tgc_block\n            \n        def forward(self, input_tensor, temb, *args, **kwargs):\n            x = self.resnet(input_tensor, temb, *args, **kwargs)\n            return self.tgc_block(x, temb)\n\n    channels = (128, 256, 256, 256)\n    if 
os.environ.get('BLOCK_OUT_CHANNELS'):\n        channels = tuple(int(x) for x in os.environ['BLOCK_OUT_CHANNELS'].split(','))\n    layers = int(os.environ.get('LAYERS_PER_BLOCK', 2))\n    time_emb_dim = channels[0] * 4\n\n    is_small = (channels[0] <= 64)\n\n    if is_small:\n        down_blocks = (\"AttnDownBlock2D\", \"AttnDownBlock2D\", \"AttnDownBlock2D\", \"DownBlock2D\")\n        up_blocks = (\"UpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\")\n        head_dim = 32\n    else:\n        down_blocks = (\"DownBlock2D\", \"AttnDownBlock2D\", \"AttnDownBlock2D\", \"DownBlock2D\")\n        up_blocks = (\"UpBlock2D\", \"AttnUpBlock2D\", \"AttnUpBlock2D\", \"UpBlock2D\")\n        head_dim = 8\n\n    model = diffusers.UNet2DModel(\n        sample_size=32,\n        in_channels=3,\n        out_channels=3,\n        block_out_channels=channels,\n        down_block_types=down_blocks,\n        up_block_types=up_blocks,\n        layers_per_block=layers,\n        norm_num_groups=32,\n        norm_eps=1e-6,\n        act_fn=\"mish\",\n        time_embedding_type=\"positional\",\n        flip_sin_to_cos=True,\n        freq_shift=0,\n        downsample_padding=1,\n        attention_head_dim=head_dim,\n        resnet_time_scale_shift=\"scale_shift\",\n    )\n\n    def inject_tgc_blocks(mod):\n        for name, child in mod.named_children():\n            if child.__class__.__name__ == \"ResnetBlock2D\":\n                out_channels = getattr(child, \"out_channels\", None)\n                if out_channels is None:\n                    out_channels = child.conv1.out_channels\n                tgc = TGCBlock(out_channels, time_emb_dim)\n                wrapped = ResnetWithTGC(child, tgc)\n                if isinstance(mod, nn.ModuleList):\n                    mod[int(name)] = wrapped\n                else:\n                    setattr(mod, name, wrapped)\n            else:\n                inject_tgc_blocks(child)\n                \n    def zero_init_resnet(mod):\n 
       for name, child in mod.named_children():\n            if child.__class__.__name__ == \"ResnetBlock2D\":\n                if hasattr(child, 'conv2'):\n                    nn.init.constant_(child.conv2.weight, 0)\n                    nn.init.constant_(child.conv2.bias, 0)\n            else:\n                zero_init_resnet(child)\n\n    if is_small:\n        inject_tgc_blocks(model)\n    else:\n        zero_init_resnet(model)\n\n    return model.to(device)\n",
3  "end_line": 128,
4  "op": "replace",
5  "start_line": 31,
6  "filename": "diffusers-main/custom_train.py"
7}
Tool ResultStep 6
OK: Replaced lines 31..128 with 115 line(s) in diffusers-main/custom_train.py. Editable range: 31–145.

[Current file: diffusers-main/custom_train.py | editable: 31–145 | total: 498 lines]
    28| # Model Architecture (EDITABLE REGION)
    29| # ============================================================================
    30| 
    31  def build_model(device):
    32      import diffusers
    33      import os
    34      import torch
    35      import torch.nn as nn
    36      import torch.nn.functional as F
    37  
    38      class TGCBlock(nn.Module):
    39          """Time-conditioned Global Context Block."""
    40          def __init__(self, in_channels, time_emb_dim, ratio=4):
    41              super().__init__()
    42              mid_channels = max(16, in_channels // ratio)
    43              self.spatial_attn = nn.Conv2d(in_channels, 1, kernel_size=1)
    44              self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
    45              self.norm 
AssistantStep 7test
Tool ResultStep 7
[Test #3] (0 tests remaining)

### train_small (scripts/train_small.sh)
Training results:
Model parameters: 10.3M | GPUs: 8
step 200/35000 | loss 0.0349 | 32.1s
step 400/35000 | loss 0.0318 | 31.9s
step 600/35000 | loss 0.0377 | 31.6s
step 800/35000 | loss 0.0271 | 32.4s
step 1000/35000 | loss 0.0404 | 31.2s
step 1200/35000 | loss 0.0327 | 30.9s
step 1400/35000 | loss 0.0310 | 31.4s
step 1600/35000 | loss 0.0227 | 30.7s
step 1800/35000 | loss 0.0317 | 31.3s
step 2000/35000 | loss 0.0426 | 31.6s
step 2200/35000 | loss 0.0427 | 31.2s
step 2400/35000 | loss 0.0275 | 31.7s
step 2600/35000 | loss 0.0464 | 32.2s
step 2800/35000 | loss 0.0280 | 31.2s
step 3000/35000 | loss 0.0342 | 31.3s
step 3200/35000 | loss 0.0343 | 31.2s
step 3400/35000 | loss 0.0324 | 31.1s
step 3600/35000 | loss 0.0298 | 31.1s
step 3800/35000 | loss 0.0352 | 31.5s
step 4000/35000 | loss 0.0228 | 31.1s
step 4200/35000 | loss 0.0344 | 31.2s
step 4400/35000 | loss 0.0202 | 31.3s
step 4600/35000 | loss 0.0349 | 31.7s
step 4