Skip to content

Conversation

@DN6
Copy link
Collaborator

@DN6 DN6 commented Dec 11, 2025

What does this PR do?

Following the plan outlined for Diffusers 1.0.0, this PR introduces changes to our model testing approach in order to reduce the overhead involved in adding comprehensive tests for new models and standardize tests across all models.

Changes include

  1. Introducing feature specific tester Mixins and marks for models (breaking up the very large ModelTesterMixin class)
  2. Introduce new test file structure using Config + Mixin pattern
  3. New markers for selective test execution
  4. Adds a ulility script (generate_model_tests.py) to automatically generate tests based on the model file. Also provide a flag that allows us to include any optional features to test e.g. (we can turn this into a bot down the line)

I've only made changes to Flux to make this PR easy to review. I'll open follow ups in phases for the other models once this is approved.

python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_qwenimage.py 

Will now generate a template test file that can be populated with the necessary config information

# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from diffusers import QwenImageTransformer2DModel
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin
from ..testing_utils import (
    AttentionTesterMixin,
    ContextParallelTesterMixin,
    LoraTesterMixin,
    MemoryTesterMixin,
    ModelTesterMixin,
    TorchCompileTesterMixin,
    TrainingTesterMixin,
)


enable_full_determinism()


class QwenImageTransformerTesterConfig:
    model_class = QwenImageTransformer2DModel
    pretrained_model_name_or_path = ""
    pretrained_model_kwargs = {"subfolder": "transformer"}

    @property
    def generator(self):
        return torch.Generator("cpu").manual_seed(0)

    def get_init_dict(self) -> dict[str, int | list[int]]:
        # __init__ parameters:
        #   patch_size: int = 2
        #   in_channels: int = 64
        #   out_channels: Optional[int] = 16
        #   num_layers: int = 60
        #   attention_head_dim: int = 128
        #   num_attention_heads: int = 24
        #   joint_attention_dim: int = 3584
        #   guidance_embeds: bool = False
        #   axes_dims_rope: Tuple[int, int, int] = <complex>
        return {}

    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        # forward() parameters:
        #   hidden_states: torch.Tensor
        #   encoder_hidden_states: torch.Tensor
        #   encoder_hidden_states_mask: torch.Tensor
        #   timestep: torch.LongTensor
        #   img_shapes: Optional[List[Tuple[int, int, int]]]
        #   txt_seq_lens: Optional[List[int]]
        #   guidance: torch.Tensor
        #   attention_kwargs: Optional[Dict[str, Any]]
        #   controlnet_block_samples
        #   return_dict: bool = True
        # TODO: Fill in dummy inputs
        return {}

    @property
    def input_shape(self) -> tuple[int, ...]:
        return (1, 1)

    @property
    def output_shape(self) -> tuple[int, ...]:
        return (1, 1)


class TestQwenImageTransformerModel(QwenImageTransformerTesterConfig, ModelTesterMixin):
    pass


class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin):
    pass


class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
    pass


class TestQwenImageTransformerTorchCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
        # TODO: Implement dynamic input generation
        return {}


class TestQwenImageTransformerLora(QwenImageTransformerTesterConfig, LoraTesterMixin):
    pass


class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
    pass


class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin):
    pass


class TestQwenImageTransformerLoraHotSwappingForModel(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
        # TODO: Implement dynamic input generation
        return {}

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DN6 DN6 requested review from dg845, sayakpaul and yiyixuxu December 11, 2025 06:16
Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent stuff! S

ome general comments:

  • Normalize the model outputs to a common format before they go to torch.allclose().
  • Initialize the input dict newly before passing to a new initialization of the model with torch.manual_seed(0). This is because some autoencoder models take a generator input.
  • Use fixtures wherever possible to reduce boilerplate and take advantage of pytest features.
    • One particular session-level fixture could be base_output. It should help reduce test time quite a bit.
  • Use pytest.mark.parametrize where possible.

Okay for me to do in a future PR but:

  • Should also account for the attention backends.
  • Should we also do a cross between CP and attention backends?
  • How about the caching mixins?

Some nits:

  • Use torch.no_grad() as an entire decorator as opposed to using it inside the functions.

- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: attention
Use `pytest -m "not attention"` to skip these tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we implement it in an individual model testing class? For example, say we want to skip it for model X where its attention class doesn't inherit from AttentionModuleMixin?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, any model using attention also uses AttentionModuleMixin. The options here

  1. Do not add tests from attention Mixin to a module file.
  2. Add a decorator that makes AttentionModuleMixin a requirement for running attention tests.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But there are important classes like Autoencoders that don't use the Attention mixins.

Let's do this?

Add a decorator that makes AttentionModuleMixin a requirement for running attention tests.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But there are important classes like Autoencoders that don't use the Attention mixins.
The check is for AttentionModuleMixin not AttentionMixin and Autoencoders do use it

ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin

Copy link
Member

@sayakpaul sayakpaul Dec 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see your point!

So, as long as there is an Attention module this class should apply.

So, maybe for each of the tests, at the beginning, we could check if their attention classes inherit from

if isinstance(module, AttentionModuleMixin):

and if that's not the case, we skip.

Otherwise, I think it could be cumbersome to check which model tests should and shouldn't use this class because attention is a common component.

Expected class attributes to be set by subclasses:
- model_class: The model class to test
- base_precision: Tolerance for floating point comparisons (default: 1e-3)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels unnecessarily restrictive. Would rather rely on method-specific rtol and atol values because we have seen that they vary from method to method.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Holdover from initial draft. All precisions will be configurable

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DN6 this seems to be the case, still. We should make precision-related arguments atol and rtol configurable instead of relying on a base_precision.

def test_gguf_quantized_layers(self):
self._test_quantized_layers({"compute_dtype": torch.bfloat16})


Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where do we include:

@@ -0,0 +1,489 @@
#!/usr/bin/env python
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it cannot currently generate the dummy input and init dicts? I would understand if so because inferring those is quite non-trivial.

@sayakpaul
Copy link
Member

One thing I think we should do is get a coverage report for tests/models with `main and with this PR and confirm we are not skipping anything truly critical.

If we are, then we should likely be able to explain why that's the case.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking much better. Thanks a TON for working out all my comments.

I have left a couple of comments which are mostly minor.

My major feedback now is regarding the rewrite of the quantization tests. I think we should do a before and after coverage and see if we're missing something important.

LMK if anything is unclear.

Comment on lines +24 to +28
from ...testing_utils import (
assert_tensors_close,
is_attention,
torch_device,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from ...testing_utils import (
assert_tensors_close,
is_attention,
torch_device,
)
from ...testing_utils import assert_tensors_close, is_attention, torch_device

Expected class attributes to be set by subclasses:
- model_class: The model class to test
- base_precision: Tolerance for floating point comparisons (default: 1e-3)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DN6 this seems to be the case, still. We should make precision-related arguments atol and rtol configurable instead of relying on a base_precision.

Comment on lines +164 to +167
# Create modified inputs for second pass (vary hidden_states to simulate denoising)
inputs_dict_step2 = inputs_dict.copy()
if "hidden_states" in inputs_dict_step2:
inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some models might not have "hidden_states" in their forward():

x: Union[List[torch.Tensor], List[List[torch.Tensor]]],

I agree that this is a special case but while we're at it, would it make sense to:

  • Make "hidden_states" configurable at the class-level.
  • 0.1 might not make as big a difference as expected. Maybe inputs_dict_step2["hidden_states"] + torch.randn_like(inputs_dict_step2["hidden_states"])?

Comment on lines +193 to +195
# Test cache_context works without error
with model.cache_context("test_context"):
pass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we test for anything else here as well? For example if _set_context() is working as expected or not?


model.disable_cache()

def _test_reset_stateful_cache(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency:

Suggested change
def _test_reset_stateful_cache(self):
@torch.no_grad()
def _test_reset_stateful_cache(self):



@is_group_offload
class GroupOffloadTesterMixin:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

"""

@require_group_offload_support
def test_group_offloading(self, record_stream=False):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DN6 this seems to be unresolved:

@parameterized.expand([False, True])

Comment on lines +24 to +27
from ...testing_utils import (
is_context_parallel,
require_torch_multi_accelerator,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from ...testing_utils import (
is_context_parallel,
require_torch_multi_accelerator,
)
from ...testing_utils import is_context_parallel, require_torch_multi_accelerator

@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
base_precision = 1e-3
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove base_precision and give the underlying methods atol and rtol for finer control.



@require_accelerator
class QuantizationTesterMixin:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For sanity checking, I think it would be better to obtain the coverage of current https://github.com/huggingface/diffusers/blob/main/tests/quantization/ and the coverage with the changes from this PR.

This will give us a good idea of what we might be missing and if that needs further fixing.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants