Fix and refine graph save runtime_state_dict #10016
@@ -26,48 +26,52 @@
 class LRUCache(object):
     _cnt: int = 0

-    def __init__(self, cache_size):
+    def __init__(self, cache_size, keep_the_1st=True):
         assert cache_size >= 2
         self.cache_size = cache_size
-        self.queue = deque()
         self.hash_map = dict()

-    def front(self):
-        if self.is_empty():
-            return None
-
-        key = self.queue[0]
-        return self.hash_map[key]
+        self.keep_the_1st = keep_the_1st
+        self.queue = deque()

     def is_empty(self):
-        return len(self.queue) == 0
+        return len(self.hash_map) == 0

-    def is_queue_full(self):
-        return len(self.queue) >= self.cache_size
+    def is_full(self):
+        return len(self.hash_map) >= self.cache_size

     def pop(self):
         if len(self.queue) == 0:
             return None
         pop_key = self.queue.pop()
         value = self.hash_map.pop(pop_key)
         del value
         return pop_key

     def set(self, key, value):
+        new_key = None
+        old_key = None
         if key in self.hash_map:
-            return None
+            return new_key, old_key

-        pop_key = None
-        while self.is_queue_full():
-            pop_key = self.pop()
+        if self.is_full():
+            old_key = self.pop()
+            assert old_key is not None, f"Cache size is {self.cache_size}, at least 2."
+            assert not self.is_full()

+        if not (self.keep_the_1st and self.is_empty()):
+            self.queue.appendleft(key)
Review comment: Avoid evicting the graph of the first base.
Review comment: How does this avoid the first base graph from being evicted?
Review reply: self.queue stores the cache keys that are allowed to be evicted. Following the LRU policy, the top is the oldest and the end is the newest, and the order is adjusted according to usage so that the key ordering LRU needs is maintained. The way the base graph is protected from eviction here is that the base graph is always the first one added to the cache; since it is never put into self.queue, it can never be evicted.
-        self.queue.appendleft(key)
         value._oneflow_graph_cache_order = LRUCache._cnt
         LRUCache._cnt += 1
         self.hash_map[key] = value
-        return pop_key if pop_key is not None else key
+        new_key = key
+        return new_key, old_key

     def get(self, key):
         if key in self.hash_map:
-            self.queue.remove(key)
-            self.queue.appendleft(key)
+            if key in self.queue:
+                self.queue.remove(key)
+                self.queue.appendleft(key)
             return self.hash_map[key]

         return None
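As a side note on the eviction behavior discussed in the comment thread above, here is a small self-contained sketch (not the PR's actual class; the toy values and the usage at the bottom are made up) of an LRU cache whose first inserted key never enters the eviction queue:

```python
from collections import deque


class ToyLRUCache:
    """Simplified LRU sketch: the first inserted key (the base graph) is never
    put into the eviction queue, so it can never be popped out."""

    def __init__(self, cache_size, keep_the_1st=True):
        assert cache_size >= 2
        self.cache_size = cache_size
        self.keep_the_1st = keep_the_1st
        self.queue = deque()  # evictable keys only; the right end is the oldest
        self.hash_map = dict()  # all cached entries, including the protected one

    def is_empty(self):
        return len(self.hash_map) == 0

    def is_full(self):
        return len(self.hash_map) >= self.cache_size

    def set(self, key, value):
        old_key = None
        if key in self.hash_map:
            return None, None
        if self.is_full():
            # Evict the least recently used *evictable* key.
            old_key = self.queue.pop()
            del self.hash_map[old_key]
        if not (self.keep_the_1st and self.is_empty()):
            # The very first key skips the queue, hence is never evicted.
            self.queue.appendleft(key)
        self.hash_map[key] = value
        return key, old_key


cache = ToyLRUCache(cache_size=2)
print(cache.set("base", "base graph"))  # ('base', None)
print(cache.set("shape_a", "graph a"))  # ('shape_a', None)
print(cache.set("shape_b", "graph b"))  # ('shape_b', 'shape_a'); "base" survives
```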
@@ -111,21 +115,27 @@ def __call__(self, *args, **kwargs):
         return graph(*args, **kwargs)

     def runtime_state_dict(
-        self, destination=None
+        self, destination=None, with_eager=False,
     ) -> Dict[str, Dict[str, Union[Dict[str, Tensor], str]]]:
         if destination is None:
             destination = OrderedDict()
             destination._metadata = OrderedDict()

         for (key, graph) in self._cache.items():
             with AvoidRecursiveCacheCall(graph):
-                state_dict = graph.runtime_state_dict()
+                state_dict = graph.runtime_state_dict(with_eager=with_eager)
             state_dict["cache_order"] = graph._oneflow_graph_cache_order
             state_dict["cache_key"] = key
             destination[state_dict["graph_name"]] = state_dict
         return destination

     def _init_and_get_a_graph_in_cache(self, cache_key):
+        self._base_graph._print(
+            0,
+            0,
+            self._base_graph._shallow_repr()
+            + f" is creating a graph cache with key {cache_key}.",
+        )
         cur_is_base = False
         if self._cache.is_empty():
             # Has no graph yet
@@ -135,7 +145,7 @@ def _init_and_get_a_graph_in_cache(self, cache_key):
             # Create new graph from base
             graph = self._base_graph.__class__(
                 *self._base_graph._cached_init_args,
-                **self._base_graph._cached_init_kwargs
+                **self._base_graph._cached_init_kwargs,
             )
             graph._run_with_cache = False
             graph._dynamic_input_graph_cache = None
@@ -147,8 +157,16 @@
             graph.enable_shared()
         else:
             graph.share_from(self._base_graph)
-        ret = self._cache.set(cache_key, graph)
-        assert ret is not None
+        new_key, old_key = self._cache.set(cache_key, graph)
+        if old_key is not None:
+            self._base_graph._print(
+                0,
+                0,
+                self._base_graph._shallow_repr()
+                + f" cache is full(cache size {self._cache_size}), has deleted an old graph cache with key {old_key}.",
+            )
+        assert new_key is not None

         return graph

     def load_runtime_state_dict(
@@ -159,9 +177,12 @@ def load_runtime_state_dict(
             cache_order = sub_state_dict["cache_order"]
             graph_dict[cache_order] = sub_state_dict

-        self._cache = LRUCache(self._cache_size)
+        if self._cache is None:
+            self._cache = LRUCache(self._cache_size)
         for _, sub_state_dict in sorted(graph_dict.items()):
             cache_key = sub_state_dict["cache_key"]
+            graph = self._cache.get(cache_key)
+            assert graph is None
             graph = self._init_and_get_a_graph_in_cache(cache_key)
             with AvoidRecursiveCacheCall(graph):
                 graph.load_runtime_state_dict(sub_state_dict)
@@ -183,12 +204,12 @@ def get_graph(self, *args, **kwargs):

         # Create graph
         if graph is None:
-            graph = self._init_and_get_a_graph_in_cache(cache_key)
             self._base_graph._print(
                 0,
                 0,
                 self._base_graph._shallow_repr()
                 + " got a new input shape, is compiling a new graph.",
             )
+            graph = self._init_and_get_a_graph_in_cache(cache_key)

         return graph
(diff for a second changed file)
@@ -730,6 +730,7 @@ def _filter_states(self):
                 self._variables_conf[state_tensor] = VariableConfig(op_name)

         self._state_tensor_tuple = convert_to_tensor_tuple(state_tensors)
+        self._eager_state_op_names = deepcopy(state_op_names)
         return state_op_names

     def _generate_config_proto(self):
@@ -1011,13 +1012,15 @@ def enable_save_runtime_state_dict(self, mode: bool = True):
             self._enable_save_runtime_state_dict = False

     def runtime_state_dict(
-        self, destination=None
+        self, destination=None, with_eager=False
Review comment: By default, do not save the tensors on the eager module.
     ) -> Union[
         Dict[str, Union[Dict[str, Tensor], str]],
         Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],
     ]:
         if self._run_with_cache == True:
-            return self._dynamic_input_graph_cache.runtime_state_dict()
+            return self._dynamic_input_graph_cache.runtime_state_dict(
+                with_eager=with_eager
+            )

         assert (
             self._enable_save_runtime_state_dict
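For context, a hedged usage sketch of the new `with_eager` switch on the save side. The module and graph classes below are made-up placeholders, and the exact call sequence may differ from real usage; only `enable_save_runtime_state_dict`, `runtime_state_dict(with_eager=...)`, and `oneflow.save` come from this diff or the existing API:

```python
import oneflow
import oneflow.nn as nn


class MyModule(nn.Module):  # placeholder eager module, not from the PR
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


class MyGraph(nn.Graph):  # placeholder nn.Graph wrapper, not from the PR
    def __init__(self, module):
        super().__init__()
        self.m = module

    def build(self, x):
        return self.m(x)


module = MyModule()
graph = MyGraph(module)
graph.enable_save_runtime_state_dict(True)
graph(oneflow.randn(1, 4))  # compile and run once so there is runtime state to save

# with_eager=False (the new default) skips the tensors owned by the eager module,
# keeping the saved dict small; those tensors are restored from the eager module
# itself at load time.
state = graph.runtime_state_dict(with_eager=False)
oneflow.save(state, "graph_runtime_state")  # placeholder path
```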
@@ -1067,10 +1070,28 @@ def gen_index_in_tuple(eager_out):
         )
         destination["outputs"] = outputs_sub_destination

+        destination["oneflow_with_eager_tensor"] = with_eager
         if not self._build_with_shared_graph:
+            _state_tensor_tuple4save = []
+            if with_eager:
+                _state_tensor_tuple4save = self._state_tensor_tuple
+            else:
+                assert len(self._state_tensor_tuple) == len(self._state_op_names)
+                for state_idx in range(len(self._state_tensor_tuple)):
+                    if self._state_op_names[state_idx] in self._eager_state_op_names:
Review comment: In what cases is there a tensor that the eager module holds but the graph state does not? Is it a tensor that has been constant-folded? Is an eager free tensor a graph state tensor?
Reply: Yes.
Reply: Yes, it is as well, but it is not a tensor on the eager module.
+                        # This state tensor is from eager module. Just save a dummy tensor here.
+                        _state_tensor_tuple4save.append(
+                            oneflow.Tensor().to(
+                                self._state_tensor_tuple[state_idx].device
Review comment: If the tensors on the eager module do not need to be saved (normally the tensors on the eager module are saved and loaded through eager), then we can store just an empty tensor here, which reduces the size of runtime_state_dict.
+                            )
+                        )
+                    else:
+                        _state_tensor_tuple4save.append(
+                            self._state_tensor_tuple[state_idx]
+                        )
             states_sub_destination = OrderedDict()
             _fill_sub_destination(
-                states_sub_destination, self._state_op_names, self._state_tensor_tuple
+                states_sub_destination, self._state_op_names, _state_tensor_tuple4save
             )
             destination["states"] = states_sub_destination
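A minimal sketch of the dummy-tensor trick above, pulled out of the class. The lists `state_op_names`, `eager_state_op_names`, and `state_tensors` are illustrative stand-ins for the graph's real attributes:

```python
import oneflow

# Pretend these came from the graph: every graph state tensor, its op name,
# and the subset of op names that belong to the eager module.
state_op_names = ["m.linear.weight", "m.linear.bias", "free_var0"]
eager_state_op_names = ["m.linear.weight", "m.linear.bias"]
state_tensors = [oneflow.randn(4, 4), oneflow.randn(4), oneflow.randn(2)]

with_eager = False
tensors_to_save = []
for name, tensor in zip(state_op_names, state_tensors):
    if not with_eager and name in eager_state_op_names:
        # The tensor is owned by the eager module, so it will be re-bound from
        # the module at load time; store an empty tensor on the same device to
        # keep the runtime_state_dict small.
        tensors_to_save.append(oneflow.Tensor().to(tensor.device))
    else:
        tensors_to_save.append(tensor)

print([t.shape for t in tensors_to_save])
```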
@@ -1140,6 +1161,13 @@ def get_tensor_in_tuple(map_item):
             get_tensor_in_tuple, *_eager_outputs_index
         )
         self._eager_outputs = _eager_outputs
+
+        # Load state tensor of modules
+        if "oneflow_with_eager_tensor" in state_dict:
+            with_eager = state_dict["oneflow_with_eager_tensor"]
+        else:
+            with_eager = True
+
         if self._build_with_shared_graph:
             self._state_op_names = self._shared_graph._state_op_names
             self._state_tensor_tuple = self._shared_graph._state_tensor_tuple
@@ -1160,10 +1188,14 @@ def get_tensor_in_tuple(map_item):
             for s_idx, s_name in enumerate(self._state_op_names):
                 if s_name in states_from_eager:
                     state_tensor_from_eager = states_from_eager[s_name]
-                    # Note: compare value has extra cost.
-                    assert oneflow.allclose(
-                        state_tensor_from_eager, self._state_tensor_tuple[s_idx]
+                    assert (
+                        state_tensor_from_eager.device
+                        == self._state_tensor_tuple[s_idx].device
                     )
+                    if with_eager:
+                        assert oneflow.allclose(
+                            state_tensor_from_eager, self._state_tensor_tuple[s_idx]
+                        )
                     self._state_tensor_tuple[s_idx] = state_tensor_from_eager

         self.__build_outputs_buffer()
Review comment: Would the pop here pop out the key of the oldest base graph?
Reply: Yes.
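To close the loop, a hedged sketch of the load side, reusing the placeholder `MyModule`/`MyGraph` classes from the save-side sketch above. The file paths are arbitrary; since the dict was saved with `with_eager=False`, the eager-module tensors are re-bound from the module rather than read from the saved dummy entries:

```python
import oneflow

# Placeholder paths and classes; see the save-side sketch earlier.
module = MyModule()
module.load_state_dict(oneflow.load("module_state"))  # eager weights restored the usual way

state = oneflow.load("graph_runtime_state")
new_graph = MyGraph(module)
new_graph.load_runtime_state_dict(state)  # dummy state entries are re-bound to the module's tensors
new_graph(oneflow.randn(1, 4))  # run using the restored runtime state
```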