Fix and refine graph save runtime_state_dict #10016

Merged · 6 commits · Mar 22, 2023
Changes from 3 commits
75 changes: 48 additions & 27 deletions python/oneflow/nn/graph/cache.py
@@ -26,48 +26,52 @@
class LRUCache(object):
_cnt: int = 0

def __init__(self, cache_size):
def __init__(self, cache_size, keep_the_1st=True):
assert cache_size >= 2
self.cache_size = cache_size
self.queue = deque()
self.hash_map = dict()

def front(self):
if self.is_empty():
return None

key = self.queue[0]
return self.hash_map[key]
self.keep_the_1st = keep_the_1st
self.queue = deque()

def is_empty(self):
return len(self.queue) == 0
return len(self.hash_map) == 0

def is_queue_full(self):
return len(self.queue) >= self.cache_size
def is_full(self):
return len(self.hash_map) >= self.cache_size

def pop(self):
if len(self.queue) == 0:
return None
pop_key = self.queue.pop()
value = self.hash_map.pop(pop_key)
del value
return pop_key

def set(self, key, value):
new_key = None
old_key = None
if key in self.hash_map:
return None
return new_key, old_key

pop_key = None
while self.is_queue_full():
pop_key = self.pop()
if self.is_full():
old_key = self.pop()
Contributor: Won't this pop evict the key of the oldest (base) graph?

Contributor (Author): Yes, it will.

assert old_key is not None, f"Cache size is {self.cache_size}, at least 2."
assert not self.is_full()

if not (self.keep_the_1st and self.is_empty()):
self.queue.appendleft(key)
Contributor (Author): This avoids the first (base) graph being evicted from the cache.

Contributor: How does this avoid the first base graph being evicted?

Contributor (Author): self.queue stores the cache keys that are eligible for eviction, following an LRU policy: the oldest key sits at the eviction end and the newest at the other, and the ordering is adjusted according to usage so the queue keeps the key order LRU needs.

The base graph is protected because it is always the first graph added to the cache; since it is never added to self.queue, it can never be evicted.
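
To make that concrete, here is a minimal illustrative sketch (not part of this diff); base_graph, graph_a, and graph_b stand in for graph objects that allow attribute assignment:

```python
# Illustrative only: how keep_the_1st shields the base graph from eviction.
cache = LRUCache(cache_size=2, keep_the_1st=True)

cache.set("base", base_graph)    # first entry: never added to self.queue, so never evicted
cache.set("shape_a", graph_a)    # added to self.queue, eligible for eviction
new_key, old_key = cache.set("shape_b", graph_b)
# The cache was full, so the oldest evictable key is popped:
# old_key == "shape_a", while "base" is still present in cache.hash_map.
```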


self.queue.appendleft(key)
value._oneflow_graph_cache_order = LRUCache._cnt
LRUCache._cnt += 1
self.hash_map[key] = value
return pop_key if pop_key is not None else key
new_key = key
return new_key, old_key

def get(self, key):
if key in self.hash_map:
self.queue.remove(key)
self.queue.appendleft(key)
if key in self.queue:
self.queue.remove(key)
self.queue.appendleft(key)
return self.hash_map[key]

return None
@@ -111,21 +115,27 @@ def __call__(self, *args, **kwargs):
return graph(*args, **kwargs)

def runtime_state_dict(
self, destination=None
self, destination=None, with_eager=False,
) -> Dict[str, Dict[str, Union[Dict[str, Tensor], str]]]:
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()

for (key, graph) in self._cache.items():
with AvoidRecursiveCacheCall(graph):
state_dict = graph.runtime_state_dict()
state_dict = graph.runtime_state_dict(with_eager=with_eager)
state_dict["cache_order"] = graph._oneflow_graph_cache_order
state_dict["cache_key"] = key
destination[state_dict["graph_name"]] = state_dict
return destination

def _init_and_get_a_graph_in_cache(self, cache_key):
self._base_graph._print(
0,
0,
self._base_graph._shallow_repr()
+ f" is creating a graph cache with key {cache_key}.",
)
cur_is_base = False
if self._cache.is_empty():
# Has no graph yet
@@ -135,7 +145,7 @@ def _init_and_get_a_graph_in_cache(self, cache_key):
# Create new graph from base
graph = self._base_graph.__class__(
*self._base_graph._cached_init_args,
**self._base_graph._cached_init_kwargs
**self._base_graph._cached_init_kwargs,
)
graph._run_with_cache = False
graph._dynamic_input_graph_cache = None
@@ -147,8 +157,16 @@ def _init_and_get_a_graph_in_cache(self, cache_key):
graph.enable_shared()
else:
graph.share_from(self._base_graph)
ret = self._cache.set(cache_key, graph)
assert ret is not None
new_key, old_key = self._cache.set(cache_key, graph)
if old_key is not None:
self._base_graph._print(
0,
0,
self._base_graph._shallow_repr()
+ f" cache is full(cache size {self._cache_size}), has deleted an old graph cache with key {old_key}.",
)
assert new_key is not None

return graph

def load_runtime_state_dict(
@@ -159,9 +177,12 @@ def load_runtime_state_dict(
cache_order = sub_state_dict["cache_order"]
graph_dict[cache_order] = sub_state_dict

self._cache = LRUCache(self._cache_size)
if self._cache is None:
self._cache = LRUCache(self._cache_size)
for _, sub_state_dict in sorted(graph_dict.items()):
cache_key = sub_state_dict["cache_key"]
graph = self._cache.get(cache_key)
assert graph is None
graph = self._init_and_get_a_graph_in_cache(cache_key)
with AvoidRecursiveCacheCall(graph):
graph.load_runtime_state_dict(sub_state_dict)
@@ -183,12 +204,12 @@ def get_graph(self, *args, **kwargs):

# Create graph
if graph is None:
graph = self._init_and_get_a_graph_in_cache(cache_key)
self._base_graph._print(
0,
0,
self._base_graph._shallow_repr()
+ " got a new input shape, is compiling a new graph.",
)
graph = self._init_and_get_a_graph_in_cache(cache_key)

return graph
44 changes: 38 additions & 6 deletions python/oneflow/nn/graph/graph.py
@@ -730,6 +730,7 @@ def _filter_states(self):
self._variables_conf[state_tensor] = VariableConfig(op_name)

self._state_tensor_tuple = convert_to_tensor_tuple(state_tensors)
self._eager_state_op_names = deepcopy(state_op_names)
return state_op_names

def _generate_config_proto(self):
@@ -1011,13 +1012,15 @@ def enable_save_runtime_state_dict(self, mode: bool = True):
self._enable_save_runtime_state_dict = False

def runtime_state_dict(
self, destination=None
self, destination=None, with_eager=False
Contributor (Author): By default, the tensors held by the eager module are not saved.
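
A rough usage sketch of the new flag (illustrative, not from this PR): it assumes saving is enabled before the first run and that the eager module is rebuilt on the load side, so its parameters do not need to travel inside the runtime state dict; MyGraph and the file name are made up.

```python
import oneflow
import oneflow.nn as nn

class MyGraph(nn.Graph):
    def __init__(self, module):
        super().__init__()
        self.m = module

    def build(self, x):
        return self.m(x)

module = nn.Linear(4, 4)

# Save side: capture the compiled runtime, leaving eager-module tensors out.
g = MyGraph(module)
g.enable_save_runtime_state_dict()
g(oneflow.randn(2, 4))  # trigger compilation
oneflow.save(g.runtime_state_dict(with_eager=False), "runtime_state")

# Load side: parameters come from the eager module, the compiled runtime from the dict.
g2 = MyGraph(module)
g2.load_runtime_state_dict(oneflow.load("runtime_state"))
```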

) -> Union[
Dict[str, Union[Dict[str, Tensor], str]],
Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],
]:
if self._run_with_cache == True:
return self._dynamic_input_graph_cache.runtime_state_dict()
return self._dynamic_input_graph_cache.runtime_state_dict(
with_eager=with_eager
)

assert (
self._enable_save_runtime_state_dict
@@ -1067,10 +1070,28 @@ def gen_index_in_tuple(eager_out):
)
destination["outputs"] = outputs_sub_destination

destination["oneflow_with_eager_tensor"] = with_eager
if not self._build_with_shared_graph:
_state_tensor_tuple4save = []
if with_eager:
_state_tensor_tuple4save = self._state_tensor_tuple
else:
assert len(self._state_tensor_tuple) == len(self._state_op_names)
for state_idx in range(len(self._state_tensor_tuple)):
if self._state_op_names[state_idx] in self._eager_state_op_names:
Contributor: In what case does the eager module hold a tensor that the graph state does not hold? Is that a constant-folded tensor?

Is an eager free tensor a graph state tensor?

Contributor (Author): Yes, constant-folded tensors are one such case.

An eager free tensor is also a graph state tensor, but it does not live on the eager module.

# This state tensor is from eager module. Just save a dummy tensor here.
_state_tensor_tuple4save.append(
oneflow.Tensor().to(
self._state_tensor_tuple[state_idx].device
Contributor (Author): If the tensors on the eager module do not need to be saved here (they are usually saved and loaded through the eager module itself), we can store just an empty tensor instead, which reduces the size of runtime_state_dict. A short illustration follows after this block of code.

)
)
else:
_state_tensor_tuple4save.append(
self._state_tensor_tuple[state_idx]
)
states_sub_destination = OrderedDict()
_fill_sub_destination(
states_sub_destination, self._state_op_names, self._state_tensor_tuple
states_sub_destination, self._state_op_names, _state_tensor_tuple4save
)
destination["states"] = states_sub_destination

@@ -1140,6 +1161,13 @@ def get_tensor_in_tuple(map_item):
get_tensor_in_tuple, *_eager_outputs_index
)
self._eager_outputs = _eager_outputs

# Load state tensor of modules
if "oneflow_with_eager_tensor" in state_dict:
with_eager = state_dict["oneflow_with_eager_tensor"]
else:
with_eager = True

if self._build_with_shared_graph:
self._state_op_names = self._shared_graph._state_op_names
self._state_tensor_tuple = self._shared_graph._state_tensor_tuple
@@ -1160,10 +1188,14 @@
for s_idx, s_name in enumerate(self._state_op_names):
if s_name in states_from_eager:
state_tensor_from_eager = states_from_eager[s_name]
# Note: compare value has extra cost.
assert oneflow.allclose(
state_tensor_from_eager, self._state_tensor_tuple[s_idx]
assert (
state_tensor_from_eager.device
== self._state_tensor_tuple[s_idx].device
)
if with_eager:
assert oneflow.allclose(
state_tensor_from_eager, self._state_tensor_tuple[s_idx]
)
self._state_tensor_tuple[s_idx] = state_tensor_from_eager

self.__build_outputs_buffer()