15 changes: 15 additions & 0 deletions tests/conftest.py
@@ -32,6 +32,21 @@

def pytest_configure(config):
config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
config.addinivalue_line("markers", "training: marks tests for training functionality")
config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality")
config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality")
config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality")
config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality")
config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading")
config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality")
config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality")
config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
config.addinivalue_line("markers", "context_parallel: marks tests for context parallel inference functionality")


def pytest_addoption(parser):
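For reference, any marker registered above can be used to select or deselect tests from the command line, e.g. `pytest -m lora` or `pytest -m "not lora"`. A minimal, hypothetical test (not part of this PR) would opt in like this:

import pytest


@pytest.mark.lora
def test_lora_scale_is_applied():
    # Hypothetical example: selected by `pytest -m lora`,
    # deselected by `pytest -m "not lora"`.
    ...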
79 changes: 79 additions & 0 deletions tests/models/testing_utils/__init__.py
@@ -0,0 +1,79 @@
from .attention import AttentionTesterMixin
from .cache import (
CacheTesterMixin,
FasterCacheConfigMixin,
FasterCacheTesterMixin,
FirstBlockCacheConfigMixin,
FirstBlockCacheTesterMixin,
PyramidAttentionBroadcastConfigMixin,
PyramidAttentionBroadcastTesterMixin,
)
from .common import BaseModelTesterConfig, ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .parallelism import ContextParallelTesterMixin
from .quantization import (
BitsAndBytesCompileTesterMixin,
BitsAndBytesConfigMixin,
BitsAndBytesTesterMixin,
GGUFCompileTesterMixin,
GGUFConfigMixin,
GGUFTesterMixin,
ModelOptCompileTesterMixin,
ModelOptConfigMixin,
ModelOptTesterMixin,
QuantizationCompileTesterMixin,
QuantizationTesterMixin,
QuantoCompileTesterMixin,
QuantoConfigMixin,
QuantoTesterMixin,
TorchAoCompileTesterMixin,
TorchAoConfigMixin,
TorchAoTesterMixin,
)
from .single_file import SingleFileTesterMixin
from .training import TrainingTesterMixin


__all__ = [
"AttentionTesterMixin",
"BaseModelTesterConfig",
"BitsAndBytesCompileTesterMixin",
"BitsAndBytesConfigMixin",
"BitsAndBytesTesterMixin",
"CacheTesterMixin",
"ContextParallelTesterMixin",
"CPUOffloadTesterMixin",
"FasterCacheConfigMixin",
"FasterCacheTesterMixin",
"FirstBlockCacheConfigMixin",
"FirstBlockCacheTesterMixin",
"GGUFCompileTesterMixin",
"GGUFConfigMixin",
"GGUFTesterMixin",
"GroupOffloadTesterMixin",
"IPAdapterTesterMixin",
"LayerwiseCastingTesterMixin",
"LoraHotSwappingForModelTesterMixin",
"LoraTesterMixin",
"MemoryTesterMixin",
"ModelOptCompileTesterMixin",
"ModelOptConfigMixin",
"ModelOptTesterMixin",
"ModelTesterMixin",
"PyramidAttentionBroadcastConfigMixin",
"PyramidAttentionBroadcastTesterMixin",
"QuantizationCompileTesterMixin",
"QuantizationTesterMixin",
"QuantoCompileTesterMixin",
"QuantoConfigMixin",
"QuantoTesterMixin",
"SingleFileTesterMixin",
"TorchAoCompileTesterMixin",
"TorchAoConfigMixin",
"TorchAoTesterMixin",
"TorchCompileTesterMixin",
"TrainingTesterMixin",
]
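
For orientation, a concrete model test might compose these mixins roughly as sketched below. The model class, config values, input shapes, and import path are illustrative assumptions, not part of this PR; the required hooks (model_class, get_init_dict(), get_dummy_inputs()) follow the mixin docstrings further down.

import torch

from diffusers import UNet2DModel
from diffusers.utils.testing_utils import torch_device

# Import path assumed from the file layout introduced in this PR.
from tests.models.testing_utils import AttentionTesterMixin, ModelTesterMixin


class TestUNet2DModelSketch(ModelTesterMixin, AttentionTesterMixin):
    model_class = UNet2DModel  # illustrative model choice

    def get_init_dict(self):
        # Minimal config; values are illustrative only.
        return {"sample_size": 8, "in_channels": 3, "out_channels": 3}

    def get_dummy_inputs(self):
        return {
            "sample": torch.randn(1, 3, 8, 8).to(torch_device),
            "timestep": torch.tensor([1], device=torch_device),
        }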
171 changes: 171 additions & 0 deletions tests/models/testing_utils/attention.py
@@ -0,0 +1,171 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
AttnProcessor,
)

from ...testing_utils import (
assert_tensors_close,
is_attention,
torch_device,
)
Comment on lines +24 to +28 (Member)

Suggested change: collapse the multi-line import onto one line:

from ...testing_utils import assert_tensors_close, is_attention, torch_device



@is_attention
class AttentionTesterMixin:
"""
Mixin class for testing attention processor and module functionality on models.

Tests functionality from AttentionModuleMixin including:
- Attention processor management (set/get)
- QKV projection fusion/unfusion
- Attention backends (XFormers, NPU, etc.)

Expected class attributes to be set by subclasses:
- model_class: The model class to test
- uses_custom_attn_processor: Whether model uses custom attention processors (default: False)

Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

Pytest mark: attention
Use `pytest -m "not attention"` to skip these tests
Member:

How do we implement it in an individual model testing class? For example, say we want to skip it for a model X whose attention class doesn't inherit from AttentionModuleMixin?

Collaborator Author:

Ideally, any model using attention also uses AttentionModuleMixin. The options here:

  1. Do not add the attention mixin's tests to that model's test module.
  2. Add a decorator that makes AttentionModuleMixin a requirement for running attention tests.

Member:

But there are important classes like Autoencoders that don't use the Attention mixins. Let's do this: add a decorator that makes AttentionModuleMixin a requirement for running attention tests.

Collaborator Author:

> But there are important classes like Autoencoders that don't use the Attention mixins.

The check is for AttentionModuleMixin, not AttentionMixin, and Autoencoders do use it:

ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin

Member (@sayakpaul, Dec 15, 2025):

Oh, I see your point! So, as long as there is an Attention module, this class should apply. Maybe at the beginning of each test we could check whether the model's attention classes inherit from the mixin, e.g.

if isinstance(module, AttentionModuleMixin):

and skip if that's not the case. Otherwise, I think it could be cumbersome to check which model tests should and shouldn't use this class, because attention is such a common component.
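
A minimal sketch of that skip check (the helper name is hypothetical, not part of this PR):

import pytest

from diffusers.models.attention import AttentionModuleMixin


def _skip_if_no_attention_module_mixin(model):
    # Skip attention tests for models whose attention layers do not
    # inherit from AttentionModuleMixin, as proposed above.
    if not any(isinstance(m, AttentionModuleMixin) for m in model.modules()):
        pytest.skip("Model has no AttentionModuleMixin-based attention layers.")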

"""

@torch.no_grad()
def test_fuse_unfuse_qkv_projections(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()

if not hasattr(model, "fuse_qkv_projections"):
pytest.skip("Model does not support QKV projection fusion.")

output_before_fusion = model(**inputs_dict, return_dict=False)[0]

model.fuse_qkv_projections()

has_fused_projections = False
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
has_fused_projections = True
assert module.fused_projections, "fused_projections flag should be True"
break

if has_fused_projections:
output_after_fusion = model(**inputs_dict, return_dict=False)[0]

assert_tensors_close(
output_before_fusion,
output_after_fusion,
atol=1e-3,
rtol=0,
msg="Output should not change after fusing projections",
)

model.unfuse_qkv_projections()

for module in model.modules():
if isinstance(module, AttentionModuleMixin):
assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
assert not module.fused_projections, "fused_projections flag should be False"

output_after_unfusion = model(**inputs_dict, return_dict=False)[0]

assert_tensors_close(
output_before_fusion,
output_after_unfusion,
atol=1e-3,
rtol=0,
msg="Output should match original after unfusing projections",
)

def test_get_set_processor(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)

# Check if model has attention processors
if not hasattr(model, "attn_processors"):
pytest.skip("Model does not have attention processors.")

# Test getting processors
processors = model.attn_processors
assert isinstance(processors, dict), "attn_processors should return a dict"
assert len(processors) > 0, "Model should have at least one attention processor"

# Test that all processors can be retrieved via get_processor
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
processor = module.get_processor()
assert processor is not None, "get_processor should return a processor"

# Test setting a new processor
new_processor = AttnProcessor()
module.set_processor(new_processor)
retrieved_processor = module.get_processor()
assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"

def test_attention_processor_dict(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)

if not hasattr(model, "set_attn_processor"):
pytest.skip("Model does not support setting attention processors.")

# Get current processors
current_processors = model.attn_processors

# Create a dict of new processors
new_processors = {key: AttnProcessor() for key in current_processors.keys()}

# Set processors using dict
model.set_attn_processor(new_processors)

# Verify all processors were set
updated_processors = model.attn_processors
for key in current_processors.keys():
assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"

def test_attention_processor_count_mismatch_raises_error(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)

if not hasattr(model, "set_attn_processor"):
pytest.skip("Model does not support setting attention processors.")

# Get current processors
current_processors = model.attn_processors

# Create a dict with wrong number of processors
wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}

# Verify error is raised
with pytest.raises(ValueError) as exc_info:
model.set_attn_processor(wrong_processors)

assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"