diff --git a/modules/dataLoader/ChromaBaseDataLoader.py b/modules/dataLoader/ChromaBaseDataLoader.py index 580152ce5..df72920d6 100644 --- a/modules/dataLoader/ChromaBaseDataLoader.py +++ b/modules/dataLoader/ChromaBaseDataLoader.py @@ -139,7 +139,7 @@ def _create_dataset( ): return DataLoaderText2ImageMixin._create_dataset(self, config, model, model_setup, train_progress, is_validation, - aspect_bucketing_quantization=64, + resolution_quantization=64, ) factory.register(BaseDataLoader, ChromaBaseDataLoader, ModelType.CHROMA_1) diff --git a/modules/dataLoader/Flux2BaseDataLoader.py b/modules/dataLoader/Flux2BaseDataLoader.py index be02f24a4..d1f5a332d 100644 --- a/modules/dataLoader/Flux2BaseDataLoader.py +++ b/modules/dataLoader/Flux2BaseDataLoader.py @@ -157,7 +157,7 @@ def _create_dataset( ): return DataLoaderText2ImageMixin._create_dataset(self, config, model, model_setup, train_progress, is_validation, - aspect_bucketing_quantization=64, + resolution_quantization=64, ) diff --git a/modules/dataLoader/FluxBaseDataLoader.py b/modules/dataLoader/FluxBaseDataLoader.py index 93cc44aa6..9c53f49b1 100644 --- a/modules/dataLoader/FluxBaseDataLoader.py +++ b/modules/dataLoader/FluxBaseDataLoader.py @@ -179,7 +179,7 @@ def _create_dataset( ): return DataLoaderText2ImageMixin._create_dataset(self, config, model, model_setup, train_progress, is_validation, - aspect_bucketing_quantization=64, + resolution_quantization=64, ) factory.register(BaseDataLoader, FluxBaseDataLoader, ModelType.FLUX_DEV_1) diff --git a/modules/dataLoader/HiDreamBaseDataLoader.py b/modules/dataLoader/HiDreamBaseDataLoader.py index dedea17be..71114fd9c 100644 --- a/modules/dataLoader/HiDreamBaseDataLoader.py +++ b/modules/dataLoader/HiDreamBaseDataLoader.py @@ -217,7 +217,7 @@ def _create_dataset( ): return DataLoaderText2ImageMixin._create_dataset(self, config, model, model_setup, train_progress, is_validation, - aspect_bucketing_quantization=64, + resolution_quantization=64, ) factory.register(BaseDataLoader, HiDreamBaseDataLoader, ModelType.HI_DREAM_FULL) diff --git a/modules/dataLoader/HunyuanVideoBaseDataLoader.py b/modules/dataLoader/HunyuanVideoBaseDataLoader.py index d587a5fd9..e6f617358 100644 --- a/modules/dataLoader/HunyuanVideoBaseDataLoader.py +++ b/modules/dataLoader/HunyuanVideoBaseDataLoader.py @@ -178,7 +178,7 @@ def _create_dataset( ): return DataLoaderText2ImageMixin._create_dataset(self, config, model, model_setup, train_progress, is_validation, - aspect_bucketing_quantization=64, + resolution_quantization=64, frame_dim_enabled=True, allow_video_files=True, vae_frame_dim=True, diff --git a/modules/dataLoader/PixArtAlphaBaseDataLoader.py b/modules/dataLoader/PixArtAlphaBaseDataLoader.py index f2dc37857..656933ee6 100644 --- a/modules/dataLoader/PixArtAlphaBaseDataLoader.py +++ b/modules/dataLoader/PixArtAlphaBaseDataLoader.py @@ -157,7 +157,7 @@ def _create_dataset( ): return DataLoaderText2ImageMixin._create_dataset(self, config, model, model_setup, train_progress, is_validation, - aspect_bucketing_quantization=16, + resolution_quantization=16, ) factory.register(BaseDataLoader, PixArtAlphaBaseDataLoader, ModelType.PIXART_ALPHA) diff --git a/modules/dataLoader/QwenBaseDataLoader.py b/modules/dataLoader/QwenBaseDataLoader.py index a38987289..3572f0efc 100644 --- a/modules/dataLoader/QwenBaseDataLoader.py +++ b/modules/dataLoader/QwenBaseDataLoader.py @@ -149,7 +149,7 @@ def _create_dataset( ): return DataLoaderText2ImageMixin._create_dataset(self, config, model, model_setup, train_progress, is_validation, - aspect_bucketing_quantization=64, + resolution_quantization=64, allow_video_files=False, #don't allow video files, but... vae_frame_dim=True, #...Qwen has a video-capable VAE. convert images to video dimensions ) diff --git a/modules/dataLoader/SanaBaseDataLoader.py b/modules/dataLoader/SanaBaseDataLoader.py index 38d5c31b0..6062ce9b7 100644 --- a/modules/dataLoader/SanaBaseDataLoader.py +++ b/modules/dataLoader/SanaBaseDataLoader.py @@ -150,7 +150,7 @@ def _create_dataset( ): return DataLoaderText2ImageMixin._create_dataset(self, config, model, model_setup, train_progress, is_validation, - aspect_bucketing_quantization=32, + resolution_quantization=32, ) factory.register(BaseDataLoader, SanaBaseDataLoader, ModelType.SANA) diff --git a/modules/dataLoader/StableDiffusion3BaseDataLoader.py b/modules/dataLoader/StableDiffusion3BaseDataLoader.py index c497a20c6..3214b4d69 100644 --- a/modules/dataLoader/StableDiffusion3BaseDataLoader.py +++ b/modules/dataLoader/StableDiffusion3BaseDataLoader.py @@ -196,7 +196,7 @@ def _create_dataset( ): return DataLoaderText2ImageMixin._create_dataset(self, config, model, model_setup, train_progress, is_validation, - aspect_bucketing_quantization=64, + resolution_quantization=64, ) factory.register(BaseDataLoader, StableDiffusion3BaseDataLoader, ModelType.STABLE_DIFFUSION_35) diff --git a/modules/dataLoader/StableDiffusionBaseDataLoader.py b/modules/dataLoader/StableDiffusionBaseDataLoader.py index 781f769a6..cc6932d51 100644 --- a/modules/dataLoader/StableDiffusionBaseDataLoader.py +++ b/modules/dataLoader/StableDiffusionBaseDataLoader.py @@ -162,7 +162,7 @@ def _create_dataset( ): return DataLoaderText2ImageMixin._create_dataset(self, config, model, model_setup, train_progress, is_validation, - aspect_bucketing_quantization=8, + resolution_quantization=8, ) factory.register(BaseDataLoader, StableDiffusionBaseDataLoader, ModelType.STABLE_DIFFUSION_15) diff --git a/modules/dataLoader/StableDiffusionFineTuneVaeDataLoader.py b/modules/dataLoader/StableDiffusionFineTuneVaeDataLoader.py index ed5dd32b8..44f50f2dd 100644 --- a/modules/dataLoader/StableDiffusionFineTuneVaeDataLoader.py +++ b/modules/dataLoader/StableDiffusionFineTuneVaeDataLoader.py @@ -102,8 +102,10 @@ def __mask_augmentation_modules(self, config: TrainConfig) -> list: def __aspect_bucketing_in(self, config: TrainConfig): calc_aspect = CalcAspect(image_in_name='image', resolution_out_name='original_resolution') + quantization = 8 + aspect_bucketing = AspectBucketing( - quantization=8, + quantization=quantization, resolution_in_name='original_resolution', target_resolution_in_name='settings.target_resolution', enable_target_resolutions_override_in_name='concept.image.enable_resolution_override', @@ -122,7 +124,8 @@ def __aspect_bucketing_in(self, config: TrainConfig): target_resolutions_override_in_name='concept.image.resolution_override', scale_resolution_out_name='scale_resolution', crop_resolution_out_name='crop_resolution', - possible_resolutions_out_name='possible_resolutions' + possible_resolutions_out_name='possible_resolutions', + quantization=quantization, ) modules = [calc_aspect] diff --git a/modules/dataLoader/StableDiffusionXLBaseDataLoader.py b/modules/dataLoader/StableDiffusionXLBaseDataLoader.py index ed1ad491e..8ca405199 100644 --- a/modules/dataLoader/StableDiffusionXLBaseDataLoader.py +++ b/modules/dataLoader/StableDiffusionXLBaseDataLoader.py @@ -174,7 +174,7 @@ def _create_dataset( ): return DataLoaderText2ImageMixin._create_dataset(self, config, model, model_setup, train_progress, is_validation, - aspect_bucketing_quantization=64, + resolution_quantization=64, ) factory.register(BaseDataLoader, StableDiffusionXLBaseDataLoader, ModelType.STABLE_DIFFUSION_XL_10_BASE) factory.register(BaseDataLoader, StableDiffusionXLBaseDataLoader, ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING) diff --git a/modules/dataLoader/WuerstchenBaseDataLoader.py b/modules/dataLoader/WuerstchenBaseDataLoader.py index f5cf2f41a..d378abad1 100644 --- a/modules/dataLoader/WuerstchenBaseDataLoader.py +++ b/modules/dataLoader/WuerstchenBaseDataLoader.py @@ -153,7 +153,7 @@ def _create_dataset( ): return DataLoaderText2ImageMixin._create_dataset(self, config, model, model_setup, train_progress, is_validation, - aspect_bucketing_quantization=128, + resolution_quantization=128, supports_inpainting=False, ) diff --git a/modules/dataLoader/ZImageBaseDataLoader.py b/modules/dataLoader/ZImageBaseDataLoader.py index 3b9450024..978d232c9 100644 --- a/modules/dataLoader/ZImageBaseDataLoader.py +++ b/modules/dataLoader/ZImageBaseDataLoader.py @@ -136,7 +136,7 @@ def _create_dataset( ): return DataLoaderText2ImageMixin._create_dataset(self, config, model, model_setup, train_progress, is_validation, - aspect_bucketing_quantization=64, + resolution_quantization=64, ) factory.register(BaseDataLoader, ZImageBaseDataLoader, ModelType.Z_IMAGE) diff --git a/modules/dataLoader/mixin/DataLoaderText2ImageMixin.py b/modules/dataLoader/mixin/DataLoaderText2ImageMixin.py index 2654bdd19..8fe50c83a 100644 --- a/modules/dataLoader/mixin/DataLoaderText2ImageMixin.py +++ b/modules/dataLoader/mixin/DataLoaderText2ImageMixin.py @@ -151,11 +151,11 @@ def _mask_augmentation_modules(self, config: TrainConfig) -> list: return modules - def _aspect_bucketing_in(self, config: TrainConfig, aspect_bucketing_quantization: int, frame_dim_enabled:bool=False): + def _aspect_bucketing_in(self, config: TrainConfig, resolution_quantization: int, frame_dim_enabled:bool=False): calc_aspect = CalcAspect(image_in_name='image', resolution_out_name='original_resolution') - aspect_bucketing_quantization = AspectBucketing( - quantization=aspect_bucketing_quantization, + aspect_bucketing = AspectBucketing( + quantization=resolution_quantization, resolution_in_name='original_resolution', target_resolution_in_name='settings.target_resolution', enable_target_resolutions_override_in_name='concept.image.enable_resolution_override', @@ -174,13 +174,14 @@ def _aspect_bucketing_in(self, config: TrainConfig, aspect_bucketing_quantizatio target_resolutions_override_in_name='concept.image.resolution_override', scale_resolution_out_name='scale_resolution', crop_resolution_out_name='crop_resolution', - possible_resolutions_out_name='possible_resolutions' + possible_resolutions_out_name='possible_resolutions', + quantization=resolution_quantization, ) modules = [calc_aspect] if config.aspect_ratio_bucketing: - modules.append(aspect_bucketing_quantization) + modules.append(aspect_bucketing) else: modules.append(single_aspect_calculation) @@ -384,7 +385,7 @@ def _create_dataset( model_setup: ModelSetupText2ImageMixin, train_progress: TrainProgress, is_validation: bool, - aspect_bucketing_quantization: int, + resolution_quantization: int, frame_dim_enabled: bool=False, allow_video_files: bool=False, vae_frame_dim: bool=False, @@ -393,7 +394,7 @@ def _create_dataset( enumerate_input = self._enumerate_input_modules(config, allow_videos=allow_video_files) load_input = self._load_input_modules(config, model.train_dtype, vae_frame_dim=vae_frame_dim) mask_augmentation = self._mask_augmentation_modules(config) - aspect_bucketing_in = self._aspect_bucketing_in(config, aspect_bucketing_quantization, frame_dim_enabled) + aspect_bucketing_in = self._aspect_bucketing_in(config, resolution_quantization, frame_dim_enabled) crop_modules = self._crop_modules(config) augmentation_modules = self._augmentation_modules(config) if supports_inpainting: diff --git a/requirements-global.txt b/requirements-global.txt index 7c3254adc..6602e93de 100644 --- a/requirements-global.txt +++ b/requirements-global.txt @@ -33,7 +33,7 @@ pooch==1.8.2 open-clip-torch==2.32.0 # data loader --e git+https://github.com/Nerogar/mgds.git@a0c84a3#egg=mgds +-e git+https://github.com/Nerogar/mgds.git@TODO_PLACEHOLDER_CHANGE_ME_IF_PR_LANDS#egg=mgds # optimizers dadaptation==3.2 # dadaptation optimizers