Learning stackable and skippable LEGO bricks for efficient, reconfigurable, and variable-resolution diffusion modeling.

1The University of Texas at Austin, 2ByteDance Inc., 3Microsoft Azure AI

Top Row: 2048 x 600 panorama image sample from the LEGO model trained on ImageNet 256 x 256.
Middle Row: 512 x 512 image samples from the LEGO model trained on ImageNet 512 x 512.
Bottom Row: 256 x 256 image samples from the LEGO model trained on ImageNet 256 x 256.

Abstract

Diffusion models excel at generating photo-realistic images but come with significant computational costs in both training and sampling. While various techniques address these computational challenges, a less-explored issue is designing an efficient and adaptable network backbone for iterative refinement. Current options like U-Net and Vision Transformer often rely on resource-intensive deep networks and lack the flexibility needed for generating images at variable resolutions or with a smaller network than used in training. This study introduces LEGO bricks, which seamlessly integrate Local-feature Enrichment and Global-content Orchestration for hierarchical patch-wise diffusion modeling. These bricks can be stacked to create a test-time reconfigurable diffusion backbone, allowing selective skipping of bricks to reduce sampling costs and generate higher-resolution images than the training data. LEGO bricks enrich local regions with an MLP and transform them using a Transformer block while maintaining a consistent full-resolution image across all bricks. Experimental results demonstrate that LEGO bricks enhance training efficiency, expedite convergence, and facilitate variable-resolution image generation while maintaining strong generative performance. Moreover, LEGO significantly reduces sampling time compared to other methods, establishing it as a valuable enhancement for diffusion models.

Model Overview

Our LEGO bricks are designed to possess several advantageous properties:

  1. Spatial Efficiency in Training: Each LEGO brick conducts patch-wise diffusion modeling; within the ensemble, the bricks jointly form a hierarchical patch-level diffusion model. The majority of LEGO bricks are dedicated to producing local patches using computation-light MLP-mixing and attention modules, which significantly reduces floating-point operations (FLOPs) and substantially shortens the overall training duration.
  2. Efficiency in Sampling: During sampling, the LEGO bricks can be selectively skipped at each time step without a discernible decline in generation performance. Specifically, when t is large, indicating greater uncertainty in the global spatial structure, more patch-level LEGO bricks can be safely skipped. Conversely, when t is small, signifying a more stable global spatial structure, more full-resolution LEGO bricks can be bypassed.
  3. Versatility: LEGO bricks showcase remarkable versatility, accommodating both end-to-end training and sequential training from lower to upper bricks, all while enabling generation at resolutions significantly higher than those employed during training. Furthermore, they readily support the integration of existing pre-trained models as LEGO bricks, enhancing the model's adaptability and ease of use.

Model illustration.

Illustration of the LEGO-PG model. Each brick conducts patch-level diffusion training, conditioned on the output of the previous stage.

Decompose images for patch-wise training

Denote the original image with spatial dimensions \(H \times W\) as \(\mathbf{x}\). The \(k^{\text{th}}\) LEGO brick operates on patches of size \(r_h(k) \times r_w(k)\), where \(r_h(k) \le H\) and \(r_w(k) \le W\), extracted from \(\mathbf{x}\). To simplify, we assume square bricks with \(r_h(k) = r_w(k) = r_{k}\), that both \(\frac{H}{r_{k}}\) and \(\frac{W}{r_{k}}\) are integers, and that the image is divided into non-overlapping patches, represented as:

\[ \mathbf{x}^{(k)}_{(i,j)} = \mathbf{x}[(i-1)r_{k}+1:i r_{k}, (j-1) r_{k}+1:jr_{k}]; \quad i \in \{1, \ldots, \frac{H}{r_{k}}\}, \quad j \in \{1, \ldots, \frac{W}{r_{k}}\}. \]

We also denote \(\mathbf{m} \in [-1,1]^{H \times W}\) as the normalized coordinates of the image pixels, and similarly, \(\mathbf{m}_{(i,j)}^{(k)}\) as the coordinate matrix of the \((i,j)^{\text{th}}\) patch at the \(k^{\text{th}}\) LEGO brick.
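
In code, this decomposition is a standard non-overlapping patchify. Below is a minimal PyTorch sketch under the assumptions above; the function names and tensor layout are ours, not from the released code:

    import torch

    def patchify(x, r):
        # Split a (B, C, H, W) image into non-overlapping r x r patches,
        # returned as (B, H//r * W//r, C, r, r) in row-major (i, j) order,
        # matching the indexing of x^{(k)}_{(i,j)} above.
        B, C, H, W = x.shape
        assert H % r == 0 and W % r == 0, "H and W must be divisible by r_k"
        p = x.unfold(2, r, r).unfold(3, r, r)   # (B, C, H//r, W//r, r, r)
        p = p.permute(0, 2, 3, 1, 4, 5)         # (B, H//r, W//r, C, r, r)
        return p.reshape(B, (H // r) * (W // r), C, r, r)

    def normalized_coords(H, W):
        # The coordinate map m in [-1, 1]^{H x W}, one (y, x) pair per pixel;
        # patch coordinates m^{(k)}_{(i,j)} follow by patchifying this map.
        ys = torch.linspace(-1.0, 1.0, H)
        xs = torch.linspace(-1.0, 1.0, W)
        return torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1)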

Patch-wise diffusion training

Denoting a noise-corrupted patch at time \(t\) as \(\mathbf{x}_t^{(k)}\), we have the diffusion chains as

\[ \text{Forward: } \quad q(\mathbf{x}^{(k)}_{0:T}) = q(\mathbf{x}^{(k)}_0) \prod_{t=1}^T \mathcal{N} \left(\mathbf{x}_t^{(k)}; \frac{\sqrt{\alpha_t}}{\sqrt{\alpha_{t-1}}}\mathbf{x}_{t-1}^{(k)}, \left(1 - \frac{\alpha_t}{\alpha_{t-1}}\right)\mathbf{I}\right), \]

\[ \text{Reverse: } \quad p_{\boldsymbol{\theta}}(\mathbf{x}^{(k)}_{0:T}) = p(\mathbf{x}^{(k)}_T) \prod_{t=1}^T q\left(\mathbf{x}_{t-1}^{(k)} \mid \mathbf{x}_t^{(k)}, \hat{\mathbf{x}}_0^{(k)} = f_{\theta_k}(\mathbf{x}^{(k)}_t, \hat{\mathbf{x}}_0^{(k-1)}, t)\right). \]
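
With \(\alpha_t\) denoting the cumulative schedule coefficient (so \(\alpha_0 = 1\)), the per-step forward kernels telescope to the Gaussian marginal \(q(\mathbf{x}_t^{(k)} \mid \mathbf{x}_0^{(k)}) = \mathcal{N}(\sqrt{\alpha_t}\,\mathbf{x}_0^{(k)}, (1-\alpha_t)\mathbf{I})\), so noisy patches can be sampled in one shot; a minimal sketch (the function name is ours):

    import torch

    def q_sample(x0, alpha_t):
        # Draw x_t ~ q(x_t | x_0) = N(sqrt(alpha_t) * x0, (1 - alpha_t) * I);
        # alpha_t is the cumulative coefficient as a tensor broadcastable to x0.
        eps = torch.randn_like(x0)
        return alpha_t.sqrt() * x0 + (1.0 - alpha_t).sqrt() * eps, eps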

Denote \(\lambda_t^{(k)}\) as time- and brick-dependent weight coefficients, whose settings are described in the Appendix. With the refined image patches \(\hat{\mathbf{x}}_{0,(i,j)}^{(k)}\), we express the training loss over the \(K\) LEGO bricks as

\[ \mathbb{E}_k \mathbb{E}_{t, \mathbf{x}_0^{(k)}, \epsilon, (i,j)} \left[\lambda_t^{(k)} \left\| \mathbf{x}_{0,(i,j)}^{(k)} - \hat{\mathbf{x}}_{0,(i,j)}^{(k)} \right\|_2^2 \right], \quad \hat{\mathbf{x}}_{0,(i,j)}^{(k)} := f_{\theta_k}(\mathbf{x}^{(k)}_t, (i,j), \hat{\mathbf{x}}_{0,(i,j)}^{(k-1)}, t). \]

When processing a noisy image \(\mathbf{x}_t\) at time \(t\), we propagate upward through the stacked LEGO bricks to progressively refine the full image: each brick produces refined image patches \(\hat{\mathbf{x}}_0^{(k)}\) for \(k = 1, \ldots, K\), and the output \(\hat{\mathbf{x}}_{0}^{(k-1)}\) of the previous brick conditions the training of the \(k\)-th brick.
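
A single training step of the \(k\)-th brick can then be sketched as follows. This is a hedged sketch: the brick interface f_k, the tensor layout, and the name lego_brick_loss are illustrative, with the patch index \((i,j)\) folded into a patch dimension:

    import torch

    def lego_brick_loss(f_k, x0_patches, cond_patches, alpha_t, lam, t):
        # x0_patches:   clean patches x^{(k)}_{0,(i,j)}, shape (B, N, C, r, r)
        # cond_patches: refined output of the previous brick (None when k = 1)
        # alpha_t:      cumulative schedule coefficient, broadcastable to x0
        # lam:          time- and brick-dependent weight lambda_t^{(k)}
        eps = torch.randn_like(x0_patches)
        xt = alpha_t.sqrt() * x0_patches + (1.0 - alpha_t).sqrt() * eps
        x0_hat = f_k(xt, cond_patches, t)       # \hat{x}^{(k)}_{0,(i,j)}
        return (lam * (x0_patches - x0_hat) ** 2).mean()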

Recursive ensemble of LEGO bricks

Stacking LEGO bricks together, the vanilla denoising diffusion step of the full model is decomposed into \(K\) consecutive LEGO bricks, defined recursively from the top down as follows:

\[ \hat{\mathbf{x}}_0(\mathbf{x}_t, t; \theta) = \mathbf{z}_t^{(K)}, \quad \text{where} \quad \mathbf{z}_t^{(k)} = f_{\theta_k}(\mathbf{x}_t, \mathbf{z}^{(k-1)}_t, t) \quad \text{for} \quad k = K, \ldots, 1, \]

with \(\mathbf{z}_t^{(0)} := \emptyset\), \(\theta := \{\theta_k\}_{k=1}^{K}\), and \(\mathbf{z}^{(k)}_t\) denoting a grid of refined patches computed from the corresponding patches in the output of the lower LEGO brick at time \(t\). The full model thus forms a hierarchical patch-wise diffusion model.
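
In code, this upward pass is a simple fold over the stacked bricks; a minimal sketch with the same illustrative brick interface as above:

    def lego_denoise(bricks, x_t, t):
        # bricks is ordered bottom (k = 1) to top (k = K); each brick
        # patchifies x_t and its conditioning grid at its own resolution r_k.
        z = None                  # z_t^{(0)} := empty
        for f_k in bricks:        # evaluate the recursion for k = 1, ..., K
            z = f_k(x_t, z, t)    # grid of refined patches at brick k
        return z                  # z_t^{(K)} = \hat{x}_0(x_t, t; theta)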

Stacking LEGO bricks to generate multi-scale resolutions

We can stack LEGO bricks in different ways so that they generate image patches at different resolutions, up to the training resolution. Three common strategies are listed below (see the configuration sketch after the list):
  • Progressive Grow (PG): 4 x 4 -> 16 x 16 -> 64 x 64
  • Progressive Refine (PR): 64 x 64 -> 16 x 16 -> 4 x 4
  • U-shape: 64 x 64 -> 16 x 16 -> 4 x 4 -> 16 x 16 -> 64 x 64
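
One hypothetical way to encode these strategies is as patch-size schedules, read from the bottom brick to the top brick (here for a 64 x 64 training resolution):

    # Patch sizes per brick, bottom to top, for a 64 x 64 training resolution.
    LEGO_STACKS = {
        "PG":      [4, 16, 64],           # grow local patches toward the full image
        "PR":      [64, 16, 4],           # start global, refine with smaller patches
        "U-shape": [64, 16, 4, 16, 64],   # coarse -> fine -> coarse
    }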

Efficiency of LEGO Diffusion

Analysis of the LEGO bricks

Progressive growth (PG) and progressive refinement (PR) offer two distinct spatial refinement strategies in the LEGO model. In PG, the model first uses patch-bricks to generate local patches and then re-aggregates these patches to compose a full image. Conversely, PR begins by utilizing the image-brick to establish a global structure and then employs local-feature-oriented patch-bricks to refine details on top of it.

Below is a visual comparison of the PG and PR models denoising at different timesteps:

Denoising visualization of the LEGO-PG model. Denoising visualization of the LEGO-PR model.

Visualization of the denoising result of PG-/PR-stacked LEGO bricks at different timesteps.

Skipping LEGO bricks in sampling

We can choose to activate different LEGO bricks according to the noise level during sampling.

The design of LEGO inherently facilitates the sampling process by generating images with only a subset of the LEGO bricks. Intuitively, at low-noise timesteps (i.e., when \( t \) is small), the global structure of the image is already well defined, so the model should prioritize local details. Conversely, when images are noisier, global-content orchestration becomes crucial for uncovering the global structure under high uncertainty. Therefore, for improved efficiency, we skip bricks that emphasize local details during high-noise timesteps and bricks that construct global structure during low-noise timesteps, as sketched below.
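
A hedged sketch of one such skipping rule follows; the threshold t_break (a proportion of the reverse steps, see below) and the brick interface are illustrative:

    def active_bricks(bricks, patch_sizes, t, T, t_break=0.5):
        # High-noise steps (t / T > t_break): keep only the full-resolution,
        # global-structure bricks. Low-noise steps: keep only the local,
        # patch-level bricks.
        full_res = max(patch_sizes)
        if t / T > t_break:
            return [f for f, r in zip(bricks, patch_sizes) if r == full_res]
        return [f for f, r in zip(bricks, patch_sizes) if r < full_res]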

Below we show how the FID and the inference time for 50k images change with \( t_{\text{break}} \), the proportion of reverse-diffusion time steps at which the top-level brick is skipped. Interestingly, when \( t_{\text{break}} \) is chosen close to the halfway point of sampling, performance is preserved for both models while significant time savings are achieved during sampling:

FID vs. Skipping LEGO bricks at different timesteps.

Visualization of how the FID and the inference time for 50k images change with the proportion of reverse-diffusion time steps at which the top-level brick is skipped.

Convergence in training and model efficiency

Besides the flexibility in sampling, LEGO bricks also converge faster in training and keep the computational cost of the backbone low, as evidenced by the reduced FLOPs, faster convergence, and shorter training times shown below.

Convergence and model FLOPs comparison.

Left: A comparison of convergence, measured with FID versus training time. Right: A comparison of the computation cost measured with FID versus training FLOPs.

Generation results

Generation at the training resolution

LEGO bricks can generate high-quality images at the training resolution, as shown below.
Generation visualization of the LEGO-PG model. Generation visualization of the LEGO-PR model.

Class-conditional image generation at the training resolution, from models trained on ImageNet (512 x 512 and 256 x 256).

Generation beyond training resolution

Patch-wise diffusion modeling, combined with the capacity for global-content orchestration, allows LEGO bricks to generate images at resolutions significantly higher than those used during training. Below are results of generating larger images with LEGO models trained on ImageNet at 256 x 256 and 512 x 512.
Panorama visualization of the 512 model. Panorama visualization of the 256 model.

Panorama class-conditional image generation from models trained on ImageNet (512 x 512 and 256 x 256 resolution).
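
Since every brick operates patch-wise, the same trained stack can be run on a noise canvas larger than the training resolution. Below is a minimal sketch building on lego_denoise above; the deterministic DDIM-style update and the alphas indexing (a 1-D tensor of cumulative coefficients with alphas[0] = 1) are illustrative, not the paper's exact sampler:

    import torch

    def sample_panorama(bricks, alphas, shape=(1, 3, 600, 2048)):
        # alphas: cumulative schedule coefficients alpha_0..alpha_T, alphas[0] = 1.
        T = len(alphas) - 1
        x = torch.randn(shape)                   # x_T on the enlarged grid
        for t in range(T, 0, -1):
            x0_hat = lego_denoise(bricks, x, t)  # patch-wise: any H, W divisible by each r_k
            eps_hat = (x - alphas[t].sqrt() * x0_hat) / (1.0 - alphas[t]).sqrt()
            x = alphas[t - 1].sqrt() * x0_hat + (1.0 - alphas[t - 1]).sqrt() * eps_hat
        return x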

BibTeX


@inproceedings{zheng2024learning,
  title={Learning Stackable and Skippable {LEGO} Bricks for Efficient, Reconfigurable, and Variable-Resolution Diffusion Modeling},
  author={Huangjie Zheng and Zhendong Wang and Jianbo Yuan and Guanghan Ning and Pengcheng He and Quanzeng You and Hongxia Yang and Mingyuan Zhou},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024},
  url={https://openreview.net/forum?id=qmXedvwrT1}
}