Execution Model Inversion #2666

Merged (39 commits) on Aug 15, 2024
Commits (the diff below shows changes from 28 of the 39 commits)

36b2214  Execution Model Inversion (guill, Jan 29, 2024)
e4e20d7  Allow `input_info` to be of type `None` (guill, Feb 15, 2024)
2c7145d  Handle errors (like OOM) more gracefully (guill, Feb 15, 2024)
12627ca  Add a command-line argument to enable variants (guill, Feb 15, 2024)
9c1e3f7  Fix an overly aggressive assertion. (guill, Feb 18, 2024)
508d286  Fix Pyright warnings (guill, Feb 18, 2024)
fff2283  Add execution model unit tests (guill, Feb 18, 2024)
e60dbe3  Fix issue with unused literals (guill, Feb 22, 2024)
5ab1565  Merge branch 'master' into execution_model_inversion (guill, Feb 22, 2024)
6d09dd7  Make custom VALIDATE_INPUTS skip normal validation (guill, Feb 25, 2024)
6b6a93c  Merge branch 'master' into execution_model_inversion (guill, Mar 23, 2024)
03394ac  Fix example in unit test (guill, Mar 23, 2024)
a0bf532  Use fstrings instead of '%' formatting syntax (guill, Apr 21, 2024)
5dc1365  Use custom exception types. (guill, Apr 21, 2024)
dd3bafb  Display an error for dependency cycles (guill, Apr 21, 2024)
7dbee88  Add docs on when ExecutionBlocker should be used (guill, Apr 21, 2024)
b5e4583  Remove unused functionality (guill, Apr 21, 2024)
75774c6  Rename ExecutionResult.SLEEPING to PENDING (guill, Apr 21, 2024)
ecbef30  Remove superfluous function parameter (guill, Apr 21, 2024)
1f06588  Pass None for uneval inputs instead of default (guill, Apr 21, 2024)
2dda3f2  Add a test for mixed node expansion (guill, Apr 21, 2024)
06f3ce9  Raise exception for bad get_node calls. (guill, Apr 21, 2024)
b3e547f  Merge branch 'master' into execution_model_inversion (guill, Apr 22, 2024)
fa48ad3  Minor refactor of IsChangedCache.get (guill, Apr 22, 2024)
afa4c7b  Refactor `map_node_over_list` function (guill, Apr 22, 2024)
8d17f3c  Fix ui output for duplicated nodes (guill, Jun 17, 2024)
9d62456  Merge branch 'master' into execution_model_inversion (guill, Jun 17, 2024)
4712df8  Add documentation on `check_lazy_status` (guill, Jun 17, 2024)
85d046b  Merge branch 'master' into execution_model_inversion (guill, Jul 21, 2024)
48d03c4  Add file for execution model unit tests (guill, Jul 21, 2024)
64e3a43  Clean up Javascript code as per review (guill, Aug 2, 2024)
bb5de4d  Improve documentation (guill, Aug 2, 2024)
c4666bf  Add a new unit test for mixed lazy results (guill, Aug 2, 2024)
887ceb3  Merge branch 'master' into execution_model_inversion (guill, Aug 2, 2024)
655548d  Merge branch 'master' into execution_model_inversion (guill, Aug 7, 2024)
36131f0  Allow kwargs in VALIDATE_INPUTS functions (guill, Aug 8, 2024)
fd7229e  List cached nodes in `execution_cached` message (guill, Aug 8, 2024)
e73c917  Merge branch 'master' into execution_model_inversion (guill, Aug 10, 2024)
519d08e  Merge branch 'master' into execution_model_inversion (guill, Aug 15, 2024)
299 changes: 299 additions & 0 deletions comfy/caching.py
@@ -0,0 +1,299 @@
import itertools
from typing import Sequence, Mapping
from comfy.graph import DynamicPrompt

import nodes

from comfy.graph_utils import is_link

class CacheKeySet:
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        self.keys = {}
        self.subcache_keys = {}

    def add_keys(self, node_ids):
        raise NotImplementedError()

    def all_node_ids(self):
        return set(self.keys.keys())

    def get_used_keys(self):
        return self.keys.values()

    def get_used_subcache_keys(self):
        return self.subcache_keys.values()

    def get_data_key(self, node_id):
        return self.keys.get(node_id, None)

    def get_subcache_key(self, node_id):
        return self.subcache_keys.get(node_id, None)

class Unhashable:
    def __init__(self):
        self.value = float("NaN")

def to_hashable(obj):
    # So that we don't infinitely recurse since frozenset and tuples
    # are Sequences.
    if isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    elif isinstance(obj, Mapping):
        return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
    elif isinstance(obj, Sequence):
        return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
    else:
        # TODO - Support other objects like tensors?
        return Unhashable()

class CacheKeySetID(CacheKeySet):
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        super().__init__(dynprompt, node_ids, is_changed_cache)
        self.dynprompt = dynprompt
        self.add_keys(node_ids)

    def add_keys(self, node_ids):
        for node_id in node_ids:
            if node_id in self.keys:
                continue
            node = self.dynprompt.get_node(node_id)
            self.keys[node_id] = (node_id, node["class_type"])
            self.subcache_keys[node_id] = (node_id, node["class_type"])

class CacheKeySetInputSignature(CacheKeySet):
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        super().__init__(dynprompt, node_ids, is_changed_cache)
        self.dynprompt = dynprompt
        self.is_changed_cache = is_changed_cache
        self.add_keys(node_ids)

    def include_node_id_in_input(self) -> bool:
        return False

    def add_keys(self, node_ids):
        for node_id in node_ids:
            if node_id in self.keys:
                continue
            node = self.dynprompt.get_node(node_id)
            self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
            self.subcache_keys[node_id] = (node_id, node["class_type"])

    def get_node_signature(self, dynprompt, node_id):
        signature = []
        ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
        signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
        for ancestor_id in ancestors:
            signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
        return to_hashable(signature)

    def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
        node = dynprompt.get_node(node_id)
        class_type = node["class_type"]
        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
        signature = [class_type, self.is_changed_cache.get(node_id)]
        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
            signature.append(node_id)
        inputs = node["inputs"]
        for key in sorted(inputs.keys()):
            if is_link(inputs[key]):
                (ancestor_id, ancestor_socket) = inputs[key]
                ancestor_index = ancestor_order_mapping[ancestor_id]
                signature.append((key, ("ANCESTOR", ancestor_index, ancestor_socket)))
            else:
                signature.append((key, inputs[key]))
        return signature

    # This function returns a list of all ancestors of the given node. The order of the list is
    # deterministic based on which specific inputs the ancestor is connected by.
    def get_ordered_ancestry(self, dynprompt, node_id):
        ancestors = []
        order_mapping = {}
        self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
        return ancestors, order_mapping

    def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
        inputs = dynprompt.get_node(node_id)["inputs"]
        input_keys = sorted(inputs.keys())
        for key in input_keys:
            if is_link(inputs[key]):
                ancestor_id = inputs[key][0]
                if ancestor_id not in order_mapping:
                    ancestors.append(ancestor_id)
                    order_mapping[ancestor_id] = len(ancestors) - 1
                    self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)

class BasicCache:
    def __init__(self, key_class):
        self.key_class = key_class
        self.initialized = False
        self.dynprompt: DynamicPrompt
        self.cache_key_set: CacheKeySet
        self.cache = {}
        self.subcaches = {}

    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
        self.dynprompt = dynprompt
        self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
        self.is_changed_cache = is_changed_cache
        self.initialized = True

    def all_node_ids(self):
        assert self.initialized
        node_ids = self.cache_key_set.all_node_ids()
        for subcache in self.subcaches.values():
            node_ids = node_ids.union(subcache.all_node_ids())
        return node_ids

    def _clean_cache(self):
        preserve_keys = set(self.cache_key_set.get_used_keys())
        to_remove = []
        for key in self.cache:
            if key not in preserve_keys:
                to_remove.append(key)
        for key in to_remove:
            del self.cache[key]

    def _clean_subcaches(self):
        preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())

        to_remove = []
        for key in self.subcaches:
            if key not in preserve_subcaches:
                to_remove.append(key)
        for key in to_remove:
            del self.subcaches[key]

    def clean_unused(self):
        assert self.initialized
        self._clean_cache()
        self._clean_subcaches()

    def _set_immediate(self, node_id, value):
        assert self.initialized
        cache_key = self.cache_key_set.get_data_key(node_id)
        self.cache[cache_key] = value

    def _get_immediate(self, node_id):
        if not self.initialized:
            return None
        cache_key = self.cache_key_set.get_data_key(node_id)
        if cache_key in self.cache:
            return self.cache[cache_key]
        else:
            return None

    def _ensure_subcache(self, node_id, children_ids):
        subcache_key = self.cache_key_set.get_subcache_key(node_id)
        subcache = self.subcaches.get(subcache_key, None)
        if subcache is None:
            subcache = BasicCache(self.key_class)
            self.subcaches[subcache_key] = subcache
        subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
        return subcache

    def _get_subcache(self, node_id):
        assert self.initialized
        subcache_key = self.cache_key_set.get_subcache_key(node_id)
        if subcache_key in self.subcaches:
            return self.subcaches[subcache_key]
        else:
            return None

    def recursive_debug_dump(self):
        result = []
        for key in self.cache:
            result.append({"key": key, "value": self.cache[key]})
        for key in self.subcaches:
            result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
        return result

class HierarchicalCache(BasicCache):
    def __init__(self, key_class):
        super().__init__(key_class)

    def _get_cache_for(self, node_id):
        assert self.dynprompt is not None
        parent_id = self.dynprompt.get_parent_node_id(node_id)
        if parent_id is None:
            return self

        hierarchy = []
        while parent_id is not None:
            hierarchy.append(parent_id)
            parent_id = self.dynprompt.get_parent_node_id(parent_id)

        cache = self
        for parent_id in reversed(hierarchy):
            cache = cache._get_subcache(parent_id)
            if cache is None:
                return None
        return cache

    def get(self, node_id):
        cache = self._get_cache_for(node_id)
        if cache is None:
            return None
        return cache._get_immediate(node_id)

    def set(self, node_id, value):
        cache = self._get_cache_for(node_id)
        assert cache is not None
        cache._set_immediate(node_id, value)

    def ensure_subcache_for(self, node_id, children_ids):
        cache = self._get_cache_for(node_id)
        assert cache is not None
        return cache._ensure_subcache(node_id, children_ids)

class LRUCache(BasicCache):
    def __init__(self, key_class, max_size=100):
        super().__init__(key_class)
        self.max_size = max_size
        self.min_generation = 0
        self.generation = 0
        self.used_generation = {}
        self.children = {}

    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
        super().set_prompt(dynprompt, node_ids, is_changed_cache)
        self.generation += 1
        for node_id in node_ids:
            self._mark_used(node_id)

    def clean_unused(self):
        while len(self.cache) > self.max_size and self.min_generation < self.generation:

[Inline review thread on LRUCache.clean_unused]

Reviewer: What happens when self.max_size is less than the number of nodes in the workflow? AFAIU, all nodes will be cached anyway until the next call to clean_unused, which happens on the next prompt? If so, it seems like an opportunity to save some memory.

Author (guill): Yes, the LRU caching method will always keep around at least one prompt's worth of cache (just like the default caching method). If we immediately cleared out the cache after the workflow, it would technically save memory while ComfyUI was sitting idle, but it would entirely break caching for extremely large workflows.

Because the LRU caching method is specifically intended as a more memory-intensive (but more convenient) caching mode for people with plenty of RAM/VRAM, I don't think it makes sense for it to be more conservative with memory than the default caching method.

            self.min_generation += 1
            to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
            for key in to_remove:
                del self.cache[key]
                del self.used_generation[key]
                if key in self.children:
                    del self.children[key]
        self._clean_subcaches()

    def get(self, node_id):
        self._mark_used(node_id)
        return self._get_immediate(node_id)

    def _mark_used(self, node_id):
        cache_key = self.cache_key_set.get_data_key(node_id)
        if cache_key is not None:
            self.used_generation[cache_key] = self.generation

    def set(self, node_id, value):
        self._mark_used(node_id)
        return self._set_immediate(node_id, value)

    def ensure_subcache_for(self, node_id, children_ids):
        # Just uses subcaches for tracking 'live' nodes
        super()._ensure_subcache(node_id, children_ids)

        self.cache_key_set.add_keys(children_ids)
        self._mark_used(node_id)
        cache_key = self.cache_key_set.get_data_key(node_id)
        self.children[cache_key] = []
        for child_id in children_ids:
            self._mark_used(child_id)
            self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
        return self
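
For reference, a small illustration (not part of the diff) of how `to_hashable` above behaves; the example inputs are made up:

# Illustration only, based on to_hashable as defined above: nested mappings and
# sequences become frozensets, so structured node inputs can serve as cache keys.
from comfy.caching import to_hashable

a = to_hashable({"seed": 42, "sigmas": [0.1, 0.2]})
b = to_hashable({"sigmas": [0.1, 0.2], "seed": 42})

assert a == b                                                 # mapping key order does not change the key
assert hash(a) == hash(b)                                     # both are hashable, usable as dict keys
assert a != to_hashable({"seed": 42, "sigmas": [0.2, 0.1]})   # sequence order does change the key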

4 changes: 4 additions & 0 deletions comfy/cli_args.py
@@ -89,6 +89,10 @@ class LatentPreviewMethod(enum.Enum):

parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)

cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")

attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
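
For orientation, a hedged sketch (not part of the diff) of how the two cache flavors added above might be constructed and used. How ComfyUI actually wires --cache-lru and --cache-classic into the executor is not shown in this diff, so treat the selection logic and the placeholder names below as illustrative assumptions:

# Hedged sketch only; the real selection logic lives in the execution code and may differ.
from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature

lru_size = 50  # e.g. the value a user passed via --cache-lru (arbitrary example)

if lru_size > 0:
    # LRU mode: keep up to lru_size node results, evicting the least recently used.
    outputs = LRUCache(CacheKeySetInputSignature, max_size=lru_size)
else:
    # Otherwise: keep results only while they are referenced by the current prompt.
    outputs = HierarchicalCache(CacheKeySetInputSignature)

# Typical lifecycle, using the method names from caching.py above (the arguments
# are the executor's DynamicPrompt, node id list, and IsChangedCache):
#   outputs.set_prompt(dynprompt, node_ids, is_changed_cache)   # before a run
#   outputs.get(node_id) / outputs.set(node_id, value)          # during a run
#   outputs.clean_unused()                                      # after a run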