Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/mgds/MGDS.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import random

import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch.utils.data import DataLoader, IterableDataset

from mgds.ConceptPipelineModule import ConceptPipelineModule
from mgds.LoadingPipeline import LoadingPipeline
Expand All @@ -19,7 +19,7 @@ def __init__(
device: torch.device,
concepts: list[dict],
settings: dict,
definition: [PipelineModule],
definition: list[PipelineModule],
batch_size: int, #local batch size
state: PipelineState,
seed: int = 42,
Expand Down
5 changes: 4 additions & 1 deletion src/mgds/PipelineModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
from concurrent import futures
from contextlib import ExitStack
from random import Random
from typing import TYPE_CHECKING

import torch

from mgds.pipelineModuleTypes.RandomAccessPipelineModule import RandomAccessPipelineModule
from mgds.pipelineModuleTypes.SerialPipelineModule import SerialPipelineModule
from mgds.pipelineModuleTypes.SingleVariationRandomAccessPipelineModule import SingleVariationRandomAccessPipelineModule

if TYPE_CHECKING:
from mgds.LoadingPipeline import LoadingPipeline


class PipelineState:
"""Container for state shared amongst all pipeline modules in a pipeline.
Expand Down Expand Up @@ -142,7 +146,6 @@ def _get_previous_item(self, variation: int, name: str, index: int):
def _get_previous_length(self, name: str):
split_name = name.split('.')
item_name = split_name[0]
path_names = split_name[1::]

for previous_module_index in range(self.__module_index - 1, -1, -1):
module = self.pipeline.modules[previous_module_index]
Expand Down
4 changes: 2 additions & 2 deletions src/mgds/pipelineModules/DropTags.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def get_item(self, variation: int, index: int, requested_name: str = None) -> di
for s in dropout_tags:
if special_tag_mode == "WHITELIST" and s in special_tags_list:
pruned_tags.append(s)
elif special_tag_mode == "BLACKLIST" and not(s in special_tags_list):
elif special_tag_mode == "BLACKLIST" and s not in special_tags_list:
pruned_tags.append(s)
elif dropout_mode.startswith("RANDOM"):
#iterate through dropout_tags and add to pruned_tags if random > probability
Expand All @@ -156,7 +156,7 @@ def get_item(self, variation: int, index: int, requested_name: str = None) -> di
elif special_tag_mode == "BLACKLIST":
for i, s in enumerate(dropout_tags):
if rand.random() > self.probability_weighted(probability, dropout_mode, i, len(dropout_tags)) \
or not(s in special_tags_list):
or s not in special_tags_list:
pruned_tags.append(s)
else: #NONE or any other unexpected values
for i, s in enumerate(dropout_tags):
Expand Down
2 changes: 0 additions & 2 deletions src/mgds/pipelineModules/ImageToVideo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast, GemmaTokenizer, LlamaTokenizer

from mgds.PipelineModule import PipelineModule
from mgds.pipelineModuleTypes.RandomAccessPipelineModule import RandomAccessPipelineModule

Expand Down
2 changes: 1 addition & 1 deletion src/mgds/pipelineModules/LoadVideo.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def get_item(self, variation: int, index: int, requested_name: str = None) -> di

except FileNotFoundError:
video_tensor = None
except Exception as e:
except Exception:
print("could not load video, it might be corrupted: " + path)
raise
else:
Expand Down