Stable Diffusion Training Notes


This is a living document which I am using to keep track of my findings during development.
It is less organized than a typical blog post, and full of TODO and FIXME notes.

Repositories referenced by this post:

  • Containers, launcher, and dataloader - https://github.com/LunNova/translunar-diffusion
  • Base stable diffusion with optimizations for training - https://github.com/LunNova/InvokeAI-nyoom

Clone "translunar-diffusion" with submodules and use start.sh in the docker folder to get a container.


Stable diffusion is an awesome txt2img model. Let's get it ready to train.

We're going to start with the InvokeAI fork, formerly known as lstein/stable-diffusion.

Goals and Results Summary

  • upgrade pytorch-lightning to get support for better strategies than DistributedDataParallel
    • done, works
  • add a generic, configurable local data source
    • done, works
  • make validation optional
    • done, works
  • remove any cuda-specific code and full precision casts
    • this should give us fp16 and bf16 support to speed up training and reduce vram usage
    • fp16 works, can't test bf16
  • try out gradient accumulation to speed up training
    • this didn't work well; the resulting model produced low quality images
  • try out deepspeed to reduce vram usage even further
    • this works but made training 3-4x slower than without deepspeed
  • use fsdp_native to reduce vram requirements on each GPU when scaling
    • this didn't work
  • split things up and publish repos and post
    • you're reading it, so yes!

Hardware

You need at minimum a GPU with 16GB of VRAM to train, and that's cutting it fine and will only work when using deepspeed.

I tested using a Radeon Pro W6800 32GB GPU, which allows training without deepspeed at a batch size of 3.

I haven't tested with any NVIDIA GPUs as I don't have any with sufficient VRAM yet.

AMD Specific Issues and Workarounds

If you're fortunate enough to have an NVIDIA GPU with enough VRAM, there's no equivalent section for you, as CUDA is generally more reliable.

GPU Hangs

Sometimes amdgpu crashes when mixing compute and desktop graphics loads. If you're reading this section after 2022, this information is likely out of date.

Use a modern kernel (5.19+) and set the following kernel parameters:

amdgpu.gpu_recovery=2
amdgpu.reset_method=4

Together, on a GPU which supports BACO (bus active, chip off), these parameters allow amdgpu to reset without interrupting your X session!

Click for fun dmesg logs showing a crash and recovery
[56076.626692] amdgpu: qcm fence wait loop timeout expired
[56076.626695] amdgpu: The cp might be in an unrecoverable state due to an unsuccessful queues preemption
[56076.626697] amdgpu: Failed to evict process queues
[56076.626716] amdgpu 0000:0c:00.0: amdgpu: GPU reset begin!
[56076.626727] amdgpu: Failed to quiesce KFD
[56076.639222] amdgpu: Failed to suspend process 0x800a
[56076.697119] [drm] free PSP TMR buffer
[56076.734968] CPU: 14 PID: 365547 Comm: kworker/u64:1 Not tainted 5.19.1-xanmod1 #1-NixOS
[56076.734970] Hardware name: ASUS System Product Name/ROG STRIX B550-F GAMING (WI-FI), BIOS 2423 08/10/2021
[56076.734971] Workqueue: amdgpu-reset-dev amdgpu_device_queue_gpu_recover_work [amdgpu]
[56076.735057] Call Trace:
[56076.735058]  <TASK>
[56076.735060]  dump_stack_lvl+0x45/0x5e
[56076.735065]  amdgpu_do_asic_reset+0x28/0x434 [amdgpu]
[56076.735178]  amdgpu_device_gpu_recover_imp.cold+0x600/0x9de [amdgpu]
[56076.735280]  amdgpu_device_queue_gpu_recover_work+0x16/0x20 [amdgpu]
[56076.735349]  process_one_work+0x251/0x440
[56076.735352]  worker_thread+0x239/0x4c0
[56076.735353]  ? mod_delayed_work_on+0x130/0x130
[56076.735354]  kthread+0x158/0x180
[56076.735356]  ? kthread_complete_and_exit+0x20/0x20
[56076.735357]  ret_from_fork+0x1f/0x30
[56076.735359]  </TASK>
[56076.735361] amdgpu 0000:0c:00.0: amdgpu: BACO reset
[56076.904156] amdgpu 0000:0c:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x0013 address=0x77db21700 flags=0x0020]
[56076.926570] amdgpu 0000:0c:00.0: amdgpu: GPU reset succeeded, trying to resume
[56076.926727] [drm] PCIE GART of 512M enabled (table at 0x0000008000000000).
[56076.926745] [drm] VRAM is lost due to GPU reset!
[56076.926751] [drm] PSP is resuming...
[56076.986352] [drm] reserve 0xa00000 from 0x85e3c00000 for PSP TMR
[56077.090556] amdgpu 0000:0c:00.0: amdgpu: GECC is enabled
[56077.112143] amdgpu 0000:0c:00.0: amdgpu: SECUREDISPLAY: securedisplay ta ucode is not available
[56077.112147] amdgpu 0000:0c:00.0: amdgpu: SMU is resuming...
[56077.112150] amdgpu 0000:0c:00.0: amdgpu: smu driver if version = 0x00000040, smu fw if version = 0x00000041, smu fw program = 0, version = 0x003a5400 (58.84.0)
[56077.112152] amdgpu 0000:0c:00.0: amdgpu: SMU driver if version not matched
[56077.112183] amdgpu 0000:0c:00.0: amdgpu: use vbios provided pptable
[56077.120431] amdgpu 0000:0c:00.0: amdgpu: SMU is resumed successfully!
[56077.121385] [drm] DMUB hardware initialized: version=0x02020013
[56077.147177] [drm] kiq ring mec 2 pipe 1 q 0
[56077.152644] [drm] VCN decode and encode initialized successfully(under DPG Mode).
[56077.152816] [drm] JPEG decode initialized successfully.
[56077.152827] amdgpu 0000:0c:00.0: amdgpu: ring gfx_0.0.0 uses VM inv eng 0 on hub 0
[56077.152828] amdgpu 0000:0c:00.0: amdgpu: ring comp_1.0.0 uses VM inv eng 1 on hub 0
[56077.152829] amdgpu 0000:0c:00.0: amdgpu: ring comp_1.1.0 uses VM inv eng 4 on hub 0
[56077.152829] amdgpu 0000:0c:00.0: amdgpu: ring comp_1.2.0 uses VM inv eng 5 on hub 0
[56077.152830] amdgpu 0000:0c:00.0: amdgpu: ring comp_1.3.0 uses VM inv eng 6 on hub 0
[56077.152830] amdgpu 0000:0c:00.0: amdgpu: ring comp_1.0.1 uses VM inv eng 7 on hub 0
[56077.152831] amdgpu 0000:0c:00.0: amdgpu: ring comp_1.1.1 uses VM inv eng 8 on hub 0
[56077.152831] amdgpu 0000:0c:00.0: amdgpu: ring comp_1.2.1 uses VM inv eng 9 on hub 0
[56077.152831] amdgpu 0000:0c:00.0: amdgpu: ring comp_1.3.1 uses VM inv eng 10 on hub 0
[56077.152832] amdgpu 0000:0c:00.0: amdgpu: ring kiq_2.1.0 uses VM inv eng 11 on hub 0
[56077.152832] amdgpu 0000:0c:00.0: amdgpu: ring sdma0 uses VM inv eng 12 on hub 0
[56077.152833] amdgpu 0000:0c:00.0: amdgpu: ring sdma1 uses VM inv eng 13 on hub 0
[56077.152833] amdgpu 0000:0c:00.0: amdgpu: ring sdma2 uses VM inv eng 14 on hub 0
[56077.152834] amdgpu 0000:0c:00.0: amdgpu: ring sdma3 uses VM inv eng 15 on hub 0
[56077.152834] amdgpu 0000:0c:00.0: amdgpu: ring vcn_dec_0 uses VM inv eng 0 on hub 1
[56077.152835] amdgpu 0000:0c:00.0: amdgpu: ring vcn_enc_0.0 uses VM inv eng 1 on hub 1
[56077.152835] amdgpu 0000:0c:00.0: amdgpu: ring vcn_enc_0.1 uses VM inv eng 4 on hub 1
[56077.152836] amdgpu 0000:0c:00.0: amdgpu: ring vcn_dec_1 uses VM inv eng 5 on hub 1
[56077.152836] amdgpu 0000:0c:00.0: amdgpu: ring vcn_enc_1.0 uses VM inv eng 6 on hub 1
[56077.152837] amdgpu 0000:0c:00.0: amdgpu: ring vcn_enc_1.1 uses VM inv eng 7 on hub 1
[56077.152837] amdgpu 0000:0c:00.0: amdgpu: ring jpeg_dec uses VM inv eng 8 on hub 1
[56077.160757] amdgpu 0000:0c:00.0: amdgpu: recover vram bo from shadow start
[56077.160760] amdgpu 0000:0c:00.0: amdgpu: recover vram bo from shadow done
[56077.160770] amdgpu 0000:0c:00.0: amdgpu: GPU reset(1) succeeded!

Power Profiles

Modern AMD GPUs have a compute power profile, which should be activated before training. On my machine this is profile 5.

I use amdgpu-power-limit.sh to set a wattage limit and power profile.
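
The script just writes to amdgpu's sysfs interface; here's a minimal Python sketch of the same idea, with assumed paths and values (the card index, hwmon directory, wattage, and profile number all vary between machines):

from pathlib import Path

# Illustrative sysfs paths; adjust card0 and the hwmon directory for your system.
# All of these writes need root.
card = Path("/sys/class/drm/card0/device")
hwmon = next((card / "hwmon").glob("hwmon*"))

# cap board power (power1_cap is in microwatts); 200 W here is just an example
(hwmon / "power1_cap").write_text(str(200 * 1_000_000))

# select the compute power profile (5 on my W6800); needs the manual perf level first
(card / "power_dpm_force_performance_level").write_text("manual")
(card / "pp_power_profile_mode").write_text("5")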

Upgrading pytorch-lightning

latent-diffusion's original release used an ancient pytorch-lightning version. Upgrading to the latest release (1.7.6) isn't too hard, and adds support for new accelerators and strategies.

Most changes will be in the launcher (main.py), which constructs the lightning Trainer and its options. I can't point you to a single nice commit for the minimal changes, as I replaced the launcher entirely.
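
To give a rough idea of the shape of the replacement, the Trainer construction on lightning 1.7 ends up looking something like this (values are illustrative, not the fork's exact defaults):

import pytorch_lightning as pl

# Illustrative values; in the real launcher these come from the trainer section
# of the training yml, and callbacks/loggers are built from the lightning section.
trainer = pl.Trainer(
    accelerator="gpu",  # "gpu" covers both CUDA and ROCm builds of pytorch
    devices=1,
    strategy="ddp",     # or "deepspeed", "fsdp", "fsdp_native" - see below
    max_epochs=1000,
)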

The only other required change was removing the deprecated dataloader_idx arg from on_train_batch_start in ddpm.py.
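
For reference, the change in ddpm.py is just dropping the argument that lightning no longer passes:

# before (old pytorch-lightning):
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
    ...

# after (pytorch-lightning 1.7 no longer passes dataloader_idx):
def on_train_batch_start(self, batch, batch_idx):
    ...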

Hardware Agnostic Training

Pytorch Lightning supports "hardware agnostic training".

Implementing this is mostly just removing casts and .to calls. My changes are in this commit.
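
The pattern is mostly mechanical: anywhere a tensor is forced onto a device or dtype, let it inherit them instead. An illustrative before/after (not the exact lines from the commit):

# before: hard-coded device and dtype, breaks anything that isn't fp32 CUDA
t = torch.randint(0, self.num_timesteps, (x.shape[0],)).cuda().long()
noise = torch.randn(x.shape).cuda().float()

# after: inherit device from the input tensor and let lightning manage precision
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=x.device).long()
noise = torch.randn_like(x)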

fp16/bf16

On a modern pytorch lightning version with hardware agnostic training, fp16 just works. Ensure precision: 16 is set in the lightning section of the training config.

bf16 likely also works but I don't have an NVIDIA card to test on.
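
Under the hood that config key is just passed through as the Trainer's precision argument, so the options on lightning 1.7 are:

import pytorch_lightning as pl

trainer = pl.Trainer(precision=16)      # fp16 mixed precision - tested, works
trainer = pl.Trainer(precision="bf16")  # bf16 - should work on supporting hardware, untested here
trainer = pl.Trainer(precision=32)      # default full precision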

deepspeed

Deepspeed's most important feature for us is called ZeRO-Offload. This offloads the optimizer state and optimizer computation to the CPU and system RAM.

  trainer:
    strategy: "deepspeed"
    # Does NOT work in combination with deepspeed currently
    # FIXME: debug why
    # precision: 16
  deepspeed:
    # Setting this uses a deepspeed config file and ignores other flags
    # config: ./configs/stable-diffusion/deepspeed.json
    stage: 1
    offload_optimizer: True

Add these options to your training yml to enable deepspeed. VRAM usage will go down; CPU usage and PCIe bandwidth usage will go up. Training speed will slow to a crawl, partly due to fp16 not working. :(

In my experience this isn't really practical, as the reduction in training speed outweighs any increase in batch size.

FSDP / Fully Sharded Data Parallel

Fully sharded data parallel training allows splitting the model across GPUs without storing the full model weights and optimizer state on each GPU, reducing VRAM usage.

fsdp

fsdp uses Facebook's FairScale FSDP implementation.

Training seems to work, but occasionally hangs, making it unusable for a full dataset.

TODO: Look into this hang

fsdp_native

fsdp_native is Pytorch Lightning's native fsdp implementation.

In theory, setting the strategy parameter of the pytorch lightning Trainer to fsdp_native should be the only necessary change. In practice this doesn't work, as some weights for the autoencoder are left on the CPU.

There's an open bug for this which seems to match my issue.
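
Both variants are enabled the same way, through the Trainer's strategy argument (strategy names as of lightning 1.7):

import pytorch_lightning as pl

# fairscale-backed FSDP: trains, but occasionally hangs for me
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="fsdp")

# pytorch-native FSDP: currently broken for this model, see the bug above
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="fsdp_native")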

Optimized EMA

See Stable diffusion optimization: EMA weights on CPU.

TL;DR: EMA weights can be stored in system RAM and updated only every N batches to reduce compute and VRAM requirements. I also fixed a memory leak that was doubling the required VRAM.
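
As a rough standalone sketch of the technique (not the fork's actual implementation), the idea looks like this:

import torch

class CPUEMA:
    """Sketch of EMA weights kept in system RAM and updated only every N batches."""

    def __init__(self, model, decay=0.9999, update_every=10):
        self.decay = decay
        self.update_every = update_every
        self.step = 0
        # the shadow copy lives on the CPU, so it costs no VRAM
        self.shadow = {
            name: p.detach().to("cpu", copy=True)
            for name, p in model.named_parameters() if p.requires_grad
        }

    @torch.no_grad()
    def update(self, model):
        self.step += 1
        if self.step % self.update_every != 0:
            return
        # compensate for the skipped batches so the averaging horizon stays the same
        decay = self.decay ** self.update_every
        for name, p in model.named_parameters():
            if name in self.shadow:
                self.shadow[name].mul_(decay).add_(p.detach().to("cpu"), alpha=1.0 - decay)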

Autoencoder training

Training a custom autoencoder on your dataset may be worth doing if the dataset is mostly art, as the autoencoder in the base stable diffusion model seems to be pretty poor at anime or furry style eyes.

FIXME: This only works at fp32, fp16 gives NaN loss

See training-encoder.yml for a configuration for training the encoder.

Setup

Clone github:LunNova/translunar-diffusion, including submodules.

Checkpoints and Logging

Example `lightning` yml file
lightning:
  logger:
    tensorboard:
      target: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
      params:
        flush_secs: 300
        name: tensorboard
  callbacks:
    progress:
      target: lun.callbacks.SmoothedProgressBar
    # These look messy but afaict you need all these settings for correct functioning
    monitored_checkpoint:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        auto_insert_metric_name: false
        monitor: *monitor
        save_top_k: 1
        save_last: false # don't do last.ckpt
        filename: monitor/loss={val/loss_simple_ema:.3f} e={epoch:04d} gs={step:06d}
    periodic_checkpoint:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        auto_insert_metric_name: false
        every_n_train_steps: 20000
        monitor: null
        save_top_k: -1 # keep unlimited
        save_last: false # don't do last.ckpt
        save_on_train_epoch_end: True # val may be off
        filename: periodic/e={epoch:04d} gs={step:06d}
    periodic_checkpoint_overwrite:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        auto_insert_metric_name: false
        every_n_train_steps: 2000
        monitor: null
        save_top_k: 1 # needs this so it will actually overwrite
        save_last: false # don't do last.ckpt
        save_on_train_epoch_end: True
        filename: every-2k-steps
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 500
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        # ignore global_step and keep track internally
        # works around bug when gradient accumulation is on
        check_custom_step: True
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          ddim_steps: 50

I recommend setting up the above callbacks and logger. Tweak frequencies as desired.

This will keep one checkpoint every-2k-steps.ckpt that gets overwritten, a periodic folder of checkpoints every 20k steps that are kept, and a monitor folder of "best" checkpoints based on the monitored metric.

This balance avoids using up too much disk space for periodic checkpoints, but keeps a frequently overwritten checkpoint that you can resume from in case of system crashes. Crashes happened frequently to me earlier due to amdgpu/ROCm issues.

main.py also creates on_exception.ckpt when training is interrupted by an exception or ctrl-c.

Dataset

You'll need to prepare a dataset. There are two options ready in the code: flat files and metadata.

Flat files

dataset_dir
  images
    1.png
    2.jpg
  captions
    1.txt
    2.txt

Each text file should contain the caption for the corresponding image. No filtering is available, and only minimal options are supported.

    train:
      target: lun.data.local.Local
      params:
        size: 512
        flip_p: 0.333
        mode: "train"
        # no metadata_params set
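
If you want a feel for what the loader does with this layout, here's an illustrative sketch; the real implementation is lun.data.local.Local, which also handles resizing, cropping, and flipping:

from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset

class FlatFileCaptions(Dataset):
    """Toy version of the flat-file layout: images/<stem>.* pairs with captions/<stem>.txt."""

    def __init__(self, root):
        root = Path(root)
        self.caption_dir = root / "captions"
        self.images = sorted(p for p in (root / "images").iterdir()
                             if p.suffix.lower() in {".png", ".jpg", ".jpeg"})

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        caption = (self.caption_dir / f"{image_path.stem}.txt").read_text().strip()
        return {"image": Image.open(image_path).convert("RGB"), "caption": caption}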

Metadata

The metadata approach requires this structure:

dataset_dir
  metadata
    <any number of folders deep>
      1.json
      2.json
      ...
  any arbitrary structure with images

The json metadata files should look like this:

Example JSON metadata file (comments are for illustration only)
{
   # list of tags for the image
   "tags":[
      "oc:kindle",
      "oc",
      "kirin",
      "artist:lulubell",
      "oc only",
      "solo",
      "safe"
   ],
   # score eg from an image board
   "score":55,
   # path relative to dataset root
   "path":"images/287/2877648.png"
}

Extra keys in the JSON file are fine and will be ignored.

If you're curious, that's for this cute picture!

Rather long YAML config with metadata.

This is a cut down example of a real dataloader I used to train recently.

    train:
      target: lun.data.local.Local
      params:
        size: 512
        flip_p: 0.333
        metadata_params: &train_metadata_params # YAML anchor so we can merge this into validation below!
          # only check every nth .json in the metadata dir
          # good for validation set
          #consider_every_nth: 8
          # shuffle tags randomly in half the dataset
          shuffle_tag_p: 0.5
          # ignoring images with these tags
          blacklist_tags:
            - animated
            - machine learning generated
          # removing some tags which have no impact on the image contents from the caption
          non_caption_tags:
            - high res
            - alternate version
            - derpibooru exclusive
            - color edit
          # shortening some common multi word tags to reduce
          # how many tokens are used
          replacements:
            princess cadance: cadance
            princess luna: luna
            'artist:': 'by='
            # recommended if your dataset uses : in any tags
            # as by convention : is used for weighted prompts in most stable diffusion
            # txt2img frontend
            ':': '='
            ' ': '_'
          tag_separator: ' '
          # if score's below this it gets filtered out immediately
          abs_min_score: 50
          # if score's below this after applying tag_bonus_scores gets filtered out
          min_score: 200
          # add a tag based on the imageboard score
          # >3200 = scr3200, >1600 = scr1600, and so on
          score_tags: [ 3200, 1600, 800, 600, 400, 300, 200, 150, 100, 50, 25, 5 ]
          # increase or decrease the score used to check min_score
          # used to increase or decrease prevalence of particular tags
          tag_bonus_scores:
            # want to include more images with 1/2 chars exactly in dataset
            solo: 100
            duo: 100
            pride flag: 50
            # preferring full color art
            monochrome: -75
            grayscale: -75
            sketch: -75

Once you've prepared your dataset and config yaml, save it in the configs directory and add it to the --base option when launching the trainer.

