Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 57 additions & 33 deletions src/mgds/pipelineModules/AspectBucketing.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_inputs(self) -> list[str]:
def get_outputs(self) -> list[str]:
    """Return the names of the outputs produced by this module.

    The list contains, in this order, the configured scale-resolution,
    crop-resolution and possible-resolutions output names.
    """
    return [
        self.scale_resolution_out_name,
        self.crop_resolution_out_name,
        self.possible_resolutions_out_name,
    ]

def __quantize_resolution(self, resolution: tuple[float, float], quantization: int) -> tuple[int, int]:
def __quantize_resolution(self, resolution: tuple[float|int, float|int], quantization: int) -> tuple[int, int]:
return (
round(resolution[0] / quantization) * quantization,
round(resolution[1] / quantization) * quantization,
Expand Down Expand Up @@ -123,34 +123,54 @@ def get_meta(self, variation: int, name: str) -> Any:
return None

def start(self, variation: int):
possible_target_resolutions = set()
possible_fixed_resolutions = set()
possible_target_resolutions: set[int] = set()
possible_fixed_resolutions: set[tuple[int, int]] = set()
possible_frames = {1}

resolutions_warned_about_rounding: set[tuple[int, int]] = set()
def _add_resolution_if_new(resolution: int|tuple[int, int], quantization: int):
resolution_2d = (resolution, resolution) if isinstance(resolution, int) else resolution
quantized_resolution = self.__quantize_resolution(resolution_2d, quantization)

# Warn the user if we are rounding their preferred resolution.
if quantized_resolution != resolution_2d and resolution not in resolutions_warned_about_rounding:
resolutions_warned_about_rounding.add(resolution)
print(f'Warning: Resolution {resolution_2d[1]}x{resolution_2d[0]}'
f' rounded to {quantized_resolution[1]}x{quantized_resolution[0]}'
f' because image model requires multiples of {quantization}.')

if isinstance(resolution, int):
# We are a bucketable resolution, store single dimension input value
possible_target_resolutions.add(resolution)
else:
# We are a fixed-dimension resolution, store quantized value
possible_fixed_resolutions.add(quantized_resolution)

# Default resolution(s)
for index in range(self._get_previous_length(self.target_resolutions_in_name)):
resolutions = self._get_previous_item(variation, self.target_resolutions_in_name, index)
if 'x' in resolutions and ',' not in resolutions:
res = resolutions.strip().split('x')
possible_fixed_resolutions.add(
self.__quantize_resolution(
(int(res[1]), int(res[0])), self.quantization
)
)
res = resolutions.split('x', 1)
_add_resolution_if_new((int(res[1].strip()), int(res[0].strip())), self.quantization)
else:
possible_target_resolutions |= set([int(res.strip()) for res in resolutions.split(',')])
for res in resolutions.split(','):
_add_resolution_if_new(int(res.strip()), self.quantization)

if self.target_resolutions_override_in_name is not None:
# Resolution override(s)
if (self.target_resolutions_override_in_name is not None and
self.enable_target_resolutions_override_in_name is not None
):
for index in range(self._get_previous_length(self.target_resolutions_override_in_name)):
resolutions = self._get_previous_item(variation, self.target_resolutions_override_in_name, index)
if 'x' in resolutions and ',' not in resolutions:
res = resolutions.strip().split('x')
possible_fixed_resolutions.add(
self.__quantize_resolution(
(int(res[1]), int(res[0])), self.quantization
)
)
else:
possible_target_resolutions |= set([int(res.strip()) for res in resolutions.split(',')])
enable_resolution_override = self._get_previous_item(
variation, self.enable_target_resolutions_override_in_name, index)
if enable_resolution_override:
override_resolutions = self._get_previous_item(variation, self.target_resolutions_override_in_name, index)
if 'x' in override_resolutions and ',' not in override_resolutions:
res = override_resolutions.split('x', 1)
_add_resolution_if_new((int(res[1].strip()), int(res[0].strip())), self.quantization)
else:
for res in override_resolutions.split(','):
_add_resolution_if_new(int(res.strip()), self.quantization)

for index in range(self._get_previous_length(self.target_frames_in_name)):
frames = self._get_previous_item(variation, self.target_frames_in_name, index)
Expand All @@ -171,24 +191,28 @@ def start(self, variation: int):
def get_item(self, variation: int, index: int, requested_name: str = None) -> dict:
rand = self._get_rand(variation, index)
resolution = self._get_previous_item(variation, self.resolution_in_name, index)
target_resolutions = self._get_previous_item(variation, self.target_resolutions_in_name, index)

if self.enable_target_resolutions_override_in_name is not None:
enable_resolution_override = self._get_previous_item(
variation, self.enable_target_resolutions_override_in_name, index)
if enable_resolution_override:
target_resolutions = self._get_previous_item(variation, self.target_resolutions_override_in_name, index)
if (self.enable_target_resolutions_override_in_name is not None and
self._get_previous_item(variation, self.enable_target_resolutions_override_in_name, index)
):
# Use override resolution(s)
target_resolutions = self._get_previous_item(variation, self.target_resolutions_override_in_name, index)
else:
# Use base resolution(s)
target_resolutions = self._get_previous_item(variation, self.target_resolutions_in_name, index)

if 'x' in target_resolutions and ',' not in target_resolutions:
res = target_resolutions.strip().split('x')
# Get quantized resolution from a fixed resolution
res = target_resolutions.split('x', 1)
target_resolution = self.__quantize_resolution(
(int(res[1]), int(res[0])), self.quantization
(int(res[1].strip()), int(res[0].strip())), self.quantization
)
else:
target_resolutions = [int(res.strip()) for res in target_resolutions.split(',')]

target_resolution = rand.choice(target_resolutions)
target_resolution = self.__get_bucket(rand, resolution[-2], resolution[-1], target_resolution)
# Get quantized resolution bucket from a random single-dim resolution
target_resolution_list = [int(res.strip())
for res in target_resolutions.split(',')]
random_resolution = rand.choice(target_resolution_list)
target_resolution = self.__get_bucket(rand, resolution[-2], resolution[-1], random_resolution)

aspect = resolution[-2] / resolution[-1]
target_aspect = target_resolution[-2] / target_resolution[-1]
Expand Down
106 changes: 82 additions & 24 deletions src/mgds/pipelineModules/SingleAspectCalculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def __init__(
target_resolutions_override_in_name: str,
scale_resolution_out_name: str,
crop_resolution_out_name: str,
possible_resolutions_out_name: str
possible_resolutions_out_name: str,
quantization: int|None = None,
):
super(SingleAspectCalculation, self).__init__()

Expand All @@ -30,7 +31,9 @@ def __init__(
self.crop_resolution_out_name = crop_resolution_out_name
self.possible_resolutions_out_name = possible_resolutions_out_name

self.possible_target_resolutions = []
self.possible_target_resolutions: list[tuple[int, int]] = []

self.quantization = quantization if quantization is not None else 1

def length(self) -> int:
    """Return the length reported upstream for the resolution input."""
    return self._get_previous_length(self.resolution_in_name)
Expand All @@ -43,45 +46,100 @@ def get_outputs(self) -> list[str]:

def get_meta(self, variation: int, name: str) -> Any:
    """Return metadata for the output called ``name``.

    Only the possible-resolutions output carries metadata: the list of
    ``(a, b)`` resolution tuples stored in ``self.possible_target_resolutions``.
    Any other name yields ``None``. ``variation`` is not used here.
    """
    if name == self.possible_resolutions_out_name:
        # The stored entries are already 2-tuples, so they are returned as-is
        # (not re-wrapped as (x, x)). A shallow copy is returned, so mutating
        # the result does not touch the stored list.
        return self.possible_target_resolutions.copy()
    return None

def __quantize_resolution(self,
resolution: int|tuple[int, int],
quantization: int) -> tuple[int, int]:
if isinstance(resolution, int):
resolution = (resolution, resolution)

if quantization == 1:
return resolution

return (
round(resolution[0] / quantization) * quantization,
round(resolution[1] / quantization) * quantization,
)

def start(self, variation: int):
possible_target_resolutions = set()
possible_target_resolutions: set[tuple[int, int]] = set()
resolutions_warned_about_rounding: set[tuple[int, int]] = set()

def _add_resolution_if_new(resolution: int|tuple[int, int], quantization: int):
resolution_2d = (resolution, resolution) if isinstance(resolution, int) else resolution
quantized_resolution = self.__quantize_resolution(resolution_2d, quantization)

# Warn the user if we are rounding their preferred resolution.
if quantized_resolution != resolution_2d and resolution not in resolutions_warned_about_rounding:
resolutions_warned_about_rounding.add(resolution)
print(f'Warning: Resolution {resolution_2d[1]}x{resolution_2d[0]}'
f' rounded to {quantized_resolution[1]}x{quantized_resolution[0]}'
f' because image model requires multiples of {quantization}.')

possible_target_resolutions.add(quantized_resolution)

# Default resolution
for index in range(self._get_previous_length(self.target_resolutions_in_name)):
resolutions = self._get_previous_item(variation, self.target_resolutions_in_name, index)
if isinstance(resolutions, int):
possible_target_resolutions.add(resolutions)
_add_resolution_if_new(resolutions, self.quantization)
elif isinstance(resolutions, str):
possible_target_resolutions |= set([int(res.strip()) for res in resolutions.split(',')])

if self.target_resolutions_override_in_name is not None:
if 'x' in resolutions and ',' not in resolutions:
res = resolutions.split('x', 1)
_add_resolution_if_new((int(res[1].strip()), int(res[0].strip())), self.quantization)
else:
for res in resolutions.split(','):
_add_resolution_if_new(int(res.strip()), self.quantization)

# Resolution override(s)
if (self.target_resolutions_override_in_name is not None and
self.enable_target_resolutions_override_in_name is not None
):
for index in range(self._get_previous_length(self.target_resolutions_override_in_name)):
resolutions = self._get_previous_item(variation, self.target_resolutions_override_in_name, index)
if isinstance(resolutions, int):
possible_target_resolutions.add(resolutions)
elif isinstance(resolutions, str):
possible_target_resolutions |= set([int(res.strip()) for res in resolutions.split(',')])
enable_resolution_override = self._get_previous_item(
variation, self.enable_target_resolutions_override_in_name, index)
if enable_resolution_override:
resolutions = self._get_previous_item(variation, self.target_resolutions_override_in_name, index)
if isinstance(resolutions, int):
_add_resolution_if_new(resolutions, self.quantization)
elif isinstance(resolutions, str):
if 'x' in resolutions and ',' not in resolutions:
res = resolutions.split('x', 1)
_add_resolution_if_new((int(res[1].strip()), int(res[0].strip())), self.quantization)
else:
for res in resolutions.split(','):
_add_resolution_if_new(int(res.strip()),self.quantization)

self.possible_target_resolutions = list(possible_target_resolutions)

def get_item(self, variation: int, index: int, requested_name: str = None) -> dict:
rand = self._get_rand(variation, index)
resolution = self._get_previous_item(variation, self.resolution_in_name, index)
target_resolutions = self._get_previous_item(variation, self.target_resolutions_in_name, index)

if self.enable_target_resolutions_override_in_name is not None:
enable_resolution_override = self._get_previous_item(
variation, self.enable_target_resolutions_override_in_name, index)
if enable_resolution_override:
target_resolutions = self._get_previous_item(variation, self.target_resolutions_override_in_name, index)

target_resolutions = [int(res.strip()) for res in target_resolutions.split(',')]

target_resolution = rand.choice(target_resolutions)
target_resolution = (target_resolution, target_resolution)
if (self.enable_target_resolutions_override_in_name is not None and
self._get_previous_item(variation, self.enable_target_resolutions_override_in_name, index)
):
# Use override resolution(s)
target_resolutions = self._get_previous_item(variation, self.target_resolutions_override_in_name, index)
else:
# Use base resolution(s)
target_resolutions = self._get_previous_item(variation, self.target_resolutions_in_name, index)

if 'x' in target_resolutions and ',' not in target_resolutions:
# Get quantized resolution from a fixed resolution
res = target_resolutions.split('x', 1)
target_resolution = self.__quantize_resolution(
(int(res[1].strip()), int(res[0].strip())), self.quantization
)
else:
# Get quantized resolution from a random single-dim resolution
target_resolution_list = [int(res.strip())
for res in target_resolutions.split(',')]
random_resolution = rand.choice(target_resolution_list)
target_resolution = self.__quantize_resolution(random_resolution, self.quantization)

aspect = resolution[0] / resolution[1]
target_aspect = target_resolution[0] / target_resolution[1]
Expand Down