[WIP] Refactor Model Tests #12822
base: main
Conversation
Excellent stuff! Some general comments (a rough sketch of these suggestions follows below):

- Normalize the model outputs to a common format before they go to `torch.allclose()`.
- Initialize the input dict anew before passing it to a new initialization of the model with `torch.manual_seed(0)`. This is because some autoencoder models take a `generator` input.
- Use fixtures wherever possible to reduce boilerplate and take advantage of `pytest` features. One particular session-level fixture could be `base_output`. It should help reduce test time quite a bit.
- Use `pytest.mark.parametrize` where possible.

Okay for me to do in a future PR, but:

- Should also account for the attention backends.
- Should we also do a cross between CP and attention backends?
- How about the caching mixins?

Some nits:

- Use `torch.no_grad()` as a decorator on the whole test function as opposed to using it inside the functions.
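A minimal sketch of how the fixture / `pytest.mark.parametrize` / `torch.no_grad()` suggestions could fit together; `DummyModel`, the input shapes, and the tolerances are made-up stand-ins rather than the PR's actual test helpers:

```python
import pytest
import torch


class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, hidden_states):
        return self.proj(hidden_states)


@pytest.fixture(scope="session")
def base_output():
    # Computed once per test session and reused, which should cut down test time.
    torch.manual_seed(0)
    model = DummyModel()
    inputs = {"hidden_states": torch.randn(1, 4, 8)}
    with torch.no_grad():
        return model(**inputs)


@torch.no_grad()  # decorator form instead of a context manager inside the body
def test_deterministic_output(base_output):
    torch.manual_seed(0)  # re-seed and rebuild the inputs for every fresh model init
    model = DummyModel()
    inputs = {"hidden_states": torch.randn(1, 4, 8)}
    output = model(**inputs)
    assert torch.allclose(output, base_output, atol=1e-5, rtol=1e-5)


@pytest.mark.parametrize("batch_size", [1, 2])
@torch.no_grad()
def test_output_shape(batch_size):
    model = DummyModel()
    output = model(hidden_states=torch.randn(batch_size, 4, 8))
    assert output.shape == (batch_size, 4, 8)
```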
| - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass | ||
| Pytest mark: attention | ||
| Use `pytest -m "not attention"` to skip these tests |
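For reference, the custom mark would presumably need to be registered (e.g. in a `conftest.py`) so that `pytest -m "not attention"` deselects cleanly without an unknown-marker warning; a minimal sketch, with the marker description being an assumption:

```python
# conftest.py (sketch)
def pytest_configure(config):
    # Register the custom "attention" marker so `pytest -m "not attention"`
    # works without PytestUnknownMarkWarning.
    config.addinivalue_line("markers", "attention: tests that exercise attention backends")
```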
How do we implement it in an individual model testing class? For example, say we want to skip it for model X where its attention class doesn't inherit from AttentionModuleMixin?
Ideally, any model using attention also uses AttentionModuleMixin. The options here are:
- Do not add tests from attention Mixin to a module file.
- Add a decorator that makes AttentionModuleMixin a requirement for running attention tests.
But there are important classes like Autoencoders that don't use the Attention mixins.
Let's do this?
Add a decorator that makes AttentionModuleMixin a requirement for running attention tests.
> But there are important classes like Autoencoders that don't use the Attention mixins.

The check is for `AttentionModuleMixin`, not `AttentionMixin`, and Autoencoders do use it:
ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin
Oh, I see your point!
So, as long as there is an attention module, this class should apply.
Maybe at the beginning of each of the tests we could check whether the model's attention classes inherit from
| if isinstance(module, AttentionModuleMixin): |
and skip if that's not the case (see the sketch below).
Otherwise, I think it could be cumbersome to check which model tests should and shouldn't use this class because attention is a common component.
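One possible shape for that check, as a rough sketch; the `AttentionModuleMixin` import path and the `model_class` / `get_init_dict` helpers are assumptions rather than the PR's actual API:

```python
import pytest
import torch

# Import path is an assumption; adjust to wherever AttentionModuleMixin lives.
from diffusers.models.attention import AttentionModuleMixin


def uses_attention_module_mixin(model: torch.nn.Module) -> bool:
    # Walk all submodules and check whether any of them inherits the mixin.
    return any(isinstance(module, AttentionModuleMixin) for module in model.modules())


class AttentionTesterMixin:
    def test_attention_backend(self):
        # `model_class` / `get_init_dict` are placeholder helpers assumed to exist
        # on the concrete test class.
        model = self.model_class(**self.get_init_dict())
        if not uses_attention_module_mixin(model):
            pytest.skip(f"{self.model_class.__name__} has no AttentionModuleMixin submodules")
        # ... actual attention-backend assertions would go here ...
```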
| Expected class attributes to be set by subclasses: | ||
| - model_class: The model class to test | ||
| - base_precision: Tolerance for floating point comparisons (default: 1e-3) |
This feels unnecessarily restrictive. I'd rather rely on method-specific `rtol` and `atol` values because we have seen that they vary from method to method.
Holdover from the initial draft. All precisions will be configurable.
@DN6 this still seems to be the case. We should make the precision-related arguments `atol` and `rtol` configurable instead of relying on a `base_precision`.
| def test_gguf_quantized_layers(self): | ||
| self._test_quantized_layers({"compute_dtype": torch.bfloat16}) | ||
Where do we include:
| class QuantCompileTests: |
| @@ -0,0 +1,489 @@ | |||
| #!/usr/bin/env python | |||
I guess it cannot currently generate the dummy input and init dicts? I would understand if so because inferring those is quite non-trivial.
One thing I think we should do is get a before/after coverage report to see whether we're missing anything. If we are, then we should likely be able to explain why that's the case.
sayakpaul left a comment
Looking much better. Thanks a TON for working through all my comments.
I have left a couple of comments, which are mostly minor.
My major feedback now is about the rewrite of the quantization tests. I think we should do a before-and-after coverage comparison and see if we're missing something important.
LMK if anything is unclear.
| from ...testing_utils import ( | ||
| assert_tensors_close, | ||
| is_attention, | ||
| torch_device, | ||
| ) |
| from ...testing_utils import ( | |
| assert_tensors_close, | |
| is_attention, | |
| torch_device, | |
| ) | |
| from ...testing_utils import assert_tensors_close, is_attention, torch_device |
| # Create modified inputs for second pass (vary hidden_states to simulate denoising) | ||
| inputs_dict_step2 = inputs_dict.copy() | ||
| if "hidden_states" in inputs_dict_step2: | ||
| inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1 |
Some models might not have `"hidden_states"` in their `forward()`:
| x: Union[List[torch.Tensor], List[List[torch.Tensor]]], |
I agree that this is a special case, but while we're at it, would it make sense to (a rough sketch follows below):
- Make `"hidden_states"` configurable at the class level.
- 0.1 might not make as big a difference as expected. Maybe `inputs_dict_step2["hidden_states"] + torch.randn_like(inputs_dict_step2["hidden_states"])`?
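A rough sketch of both suggestions; `main_input_name` and the helper name are hypothetical, not the PR's actual attributes:

```python
import torch


class CacheTesterMixin:
    # Subclasses whose forward() takes something other than `hidden_states`
    # (e.g. a list input named `x`) can override this attribute.
    main_input_name = "hidden_states"

    def _get_modified_inputs(self, inputs_dict):
        inputs_dict_step2 = inputs_dict.copy()
        key = self.main_input_name
        if key in inputs_dict_step2 and torch.is_tensor(inputs_dict_step2[key]):
            # A random perturbation changes the input far more than adding a
            # constant 0.1, which makes cache-related differences easier to catch.
            inputs_dict_step2[key] = inputs_dict_step2[key] + torch.randn_like(inputs_dict_step2[key])
        return inputs_dict_step2
```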
| # Test cache_context works without error | ||
| with model.cache_context("test_context"): | ||
| pass |
Should we test for anything else here as well? For example, whether `_set_context()` is working as expected?
| model.disable_cache() | ||
| def _test_reset_stateful_cache(self): |
For consistency:
| def _test_reset_stateful_cache(self): | |
| @torch.no_grad() | |
| def _test_reset_stateful_cache(self): |
| @is_group_offload | ||
| class GroupOffloadTesterMixin: |
Same as above.
tests/models/testing_utils/memory.py (outdated)
| """ | ||
| @require_group_offload_support | ||
| def test_group_offloading(self, record_stream=False): |
@DN6 this seems to be unresolved:
diffusers/tests/models/test_modeling_common.py, line 1766 in 1cdb872:
| @parameterized.expand([False, True]) |
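Presumably the intent is to keep both `record_stream` variants; in the new pytest style that could look roughly like the following (placeholder body):

```python
import pytest


class GroupOffloadTesterMixin:
    @pytest.mark.parametrize("record_stream", [False, True])
    def test_group_offloading(self, record_stream):
        # The real test would exercise group offloading with and without
        # record_stream; this is only a structural sketch.
        ...
```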
| from ...testing_utils import ( | ||
| is_context_parallel, | ||
| require_torch_multi_accelerator, | ||
| ) |
| from ...testing_utils import ( | |
| is_context_parallel, | |
| require_torch_multi_accelerator, | |
| ) | |
| from ...testing_utils import is_context_parallel, require_torch_multi_accelerator |
| @is_context_parallel | ||
| @require_torch_multi_accelerator | ||
| class ContextParallelTesterMixin: | ||
| base_precision = 1e-3 |
Let's remove `base_precision` and give the underlying methods `atol` and `rtol` arguments for finer control.
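For illustration, one way the tolerance could move from a class attribute into method arguments (the names and the helper are placeholders, not the PR's API):

```python
import torch


class ContextParallelTesterMixin:
    def _compare_outputs(self, reference, parallel_output, atol=1e-3, rtol=1e-3):
        # Individual test methods (or subclasses) can pass tighter or looser
        # tolerances instead of relying on one class-wide base_precision.
        torch.testing.assert_close(parallel_output, reference, atol=atol, rtol=rtol)

    def test_context_parallel_matches_single_device(self):
        reference, parallel_output = self._run_single_and_parallel()  # hypothetical helper
        self._compare_outputs(reference, parallel_output, atol=1e-2, rtol=1e-2)
```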
| @require_accelerator | ||
| class QuantizationTesterMixin: |
For sanity checking, I think it would be better to obtain the coverage of the current https://github.com/huggingface/diffusers/blob/main/tests/quantization/ tests and compare it with the coverage after the changes from this PR.
This will give us a good idea of what we might be missing and whether that needs further fixing.
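A rough way to collect the two coverage numbers, assuming `pytest-cov` is installed; the paths are illustrative and each invocation would normally run in its own process:

```python
import pytest

# Coverage of the existing quantization tests on main:
pytest.main(["tests/quantization/", "--cov=src/diffusers/quantizers", "--cov-report=term-missing"])

# Coverage of the refactored tests from this PR (path is an assumption):
pytest.main(["tests/models/", "--cov=src/diffusers/quantizers", "--cov-report=term-missing"])
```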
What does this PR do?
Following the plan outlined for Diffusers 1.0.0, this PR introduces changes to our model testing approach in order to reduce the overhead involved in adding comprehensive tests for new models and standardize tests across all models.
Changes include:
- A script (`generate_model_tests.py`) to automatically generate tests based on the model file. It also provides a flag that allows us to include any optional features to test e.g. (we can turn this into a bot down the line).
- The script will now generate a template test file that can be populated with the necessary config information (an illustrative sketch of such a template follows below).

I've only made changes to Flux to make this PR easy to review. I'll open follow-ups in phases for the other models once this is approved.
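Purely as an illustration of the idea (not the script's actual output), a generated template might look roughly like this, with the config bits left for the author to fill in; the `ModelTesterMixin` import location is an assumption:

```python
import torch

from diffusers import FluxTransformer2DModel

from ..testing_utils import ModelTesterMixin  # assumed location of the shared mixin


class TestFluxTransformer2DModel(ModelTesterMixin):
    model_class = FluxTransformer2DModel

    def get_init_dict(self):
        # TODO: fill in a small config for the model under test
        return {}

    def get_dummy_inputs(self):
        # TODO: fill in dummy inputs for the forward pass
        return {"hidden_states": torch.randn(1, 4, 16)}
```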
Fixes # (issue)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.