Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,9 @@ def __del__(self):
self._patcher_finalizer.detach()

def is_dead(self):
return self.real_model() is not None and self.model is None
# real_model is plain None until model_load(); guard so a pre-load entry can't raise.
rm = self.real_model
return rm is not None and rm() is not None and self.model is None


def use_more_memory(extra_memory, loaded_models, device):
Expand Down Expand Up @@ -808,28 +810,42 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins

for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i]
# Pin the patcher; LoadedModel.model is a weakref deref a mid-loop finalizer can None.
model = shift_model.model
if model is None:
continue
if device is None or shift_model.device == device:
if shift_model not in keep_loaded and not shift_model.is_dead():
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
# Carry the object so a reentrant cleanup_models pop can't stale the index; i is just a sort tiebreaker.
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(model), shift_model.model_memory(), i, shift_model))
shift_model.currently_used = False

can_unload_sorted = sorted(can_unload)
for x in can_unload_sorted:
i = x[-1]
shift_model = x[-1]
# Pin: shift_model.model still re-derefs the weakref a reentrant cleanup_models can None.
model = shift_model.model
if model is None:
continue
memory_to_free = 1e32
if not DISABLE_SMART_MEMORY or device is None:
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
if model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand.
memory_required -= current_loaded_models[i].model.loaded_size()
memory_required -= model.loaded_size()
memory_to_free = 0
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
unloaded_model.append(i)

for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
if memory_to_free > 0 and shift_model.model_unload(memory_to_free):
logging.debug(f"Unloading {model.model.__class__.__name__}")
unloaded_model.append(shift_model)

for shift_model in unloaded_model:
unloaded_models.append(shift_model)
# Remove by identity (model is None post-unload, so __eq__ is unsafe); tolerate a reentrant pop.
for idx in range(len(current_loaded_models) - 1, -1, -1):
if current_loaded_models[idx] is shift_model:
current_loaded_models.pop(idx)
break

if not for_dynamic and pins_required > 0:
ensure_pin_budget(pins_required)
Expand Down Expand Up @@ -988,12 +1004,17 @@ def archive_model_dtypes(model):
def cleanup_models():
to_delete = []
for i in range(len(current_loaded_models)):
if current_loaded_models[i].real_model() is None:
to_delete = [i] + to_delete

for i in to_delete:
x = current_loaded_models.pop(i)
del x
# real_model can still be the plain None from __init__ here; guard the call.
rm = current_loaded_models[i].real_model
if rm is not None and rm() is None:
to_delete.append(current_loaded_models[i])

for lm in to_delete:
# Remove by identity, mirroring free_memory; tolerate a reentrant pop.
for idx in range(len(current_loaded_models) - 1, -1, -1):
if current_loaded_models[idx] is lm:
current_loaded_models.pop(idx)
break

def dtype_size(dtype):
dtype_size = 4
Expand Down