from oslo.torch.distributed.parallel_mode import ParallelMode
from oslo.torch.nn.parallel.data_parallel._utils import is_ddp_ignored
from oslo.torch.nn.parallel.data_parallel.data_parallel import _DistributedDataParallel
-from oslo.torch.nn.parallel.data_parallel.zero.chunk import (
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.chunk import (
    Chunk,
    ChunkManager,
    TensorState,
)
-from oslo.torch.nn.parallel.data_parallel.zero.heterogeneous_manager import (
-    HeterogeneousMemoryManager,
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.memory_manager import (
+    HeteroMemoryManager,
)
-from oslo.torch.nn.parallel.data_parallel.zero.memory_tracer.param_runtime_order import (
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.memory_tracer.param_runtime_order import (
    OrderedParamGenerator,
)
-from oslo.torch.nn.parallel.data_parallel.zero.utils import (
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.utils import (
    get_current_device,
    get_temp_total_chunk_on_cuda,
)
-from oslo.torch.nn.parallel.data_parallel.zero.heterogeneous_hook import (
-    HeterogeneousZeROHook,
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.hook import (
+    HeteroHook,
)

-from oslo.torch.nn.parallel.data_parallel.zero.memory_tracer import (
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.memory_tracer import (
    MemStats,
)
-from oslo.torch.nn.parallel.data_parallel.zero.chunk import (
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.chunk import (
    init_chunk_manager,
)

@@ -70,22 +70,38 @@ def _cast_float(args, dtype: torch.dtype):
    return args

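The hunk above shows only the tail of `_cast_float`. For context, here is a minimal sketch of such a helper, assuming it recursively walks nested containers and casts floating-point tensors; the structure is illustrative, not the file's actual body:

```python
# Hypothetical sketch of _cast_float, NOT the file's actual implementation:
# recursively cast floating-point tensors inside nested args to `dtype`.
import torch


def _cast_float(args, dtype: torch.dtype):
    if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
        args = args.to(dtype)
    elif isinstance(args, (list, tuple)):
        args = type(args)(_cast_float(arg, dtype) for arg in args)
    elif isinstance(args, dict):
        args = {k: _cast_float(v, dtype) for k, v in args.items()}
    return args
```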
-class _FullyShardedDataParallel(_DistributedDataParallel):
-    """Fully sharded data parallel.
-    Warning: Nested FullyShardedDataParallel is not supported now.
-    It is designed to be used with ChunkManager and HeterogeneousMemoryManager.
-    For more details, see the API reference of ``ChunkManager`` and ``HeterogeneousMemoryManager``.
+class _HeteroDataParallel(_DistributedDataParallel):
+    """Heterogeneous sharded data parallel.
+
+    Inspired by the PatrickStar system introduced in "PatrickStar: Parallel
+    Training of Pre-trained Models via Chunk-based Dynamic Memory Management"
+    by Jiarui Fang, Zilin Zhu, et al. from Tencent Inc.:
+
+    - PatrickStar uses a CPU-GPU heterogeneous memory space to store model data,
+      organized in memory chunks.
+    - Chunks are dynamically distributed across the heterogeneous memory,
+      guided by runtime memory statistics from a warm-up iteration.
+    - This approach reduces CPU-GPU data transmission volume and optimizes
+      bandwidth utilization.
+    - In tandem with the Zero Redundancy Optimizer, PatrickStar can efficiently
+      scale to multiple GPUs across multiple nodes.
+
+    Note:
+        Nested HeteroDataParallel is not supported now. It is designed to be
+        used with ChunkManager and HeteroMemoryManager. For more details,
+        see the API reference of ``ChunkManager`` and ``HeteroMemoryManager``.

    Args:
        module (torch.nn.Module): Module to apply ZeRO-DP.
        device (torch.device): Device to place the module.
        parallel_context (ParallelContext): Process group object.
        placement_policy (str): Placement policy for the chunks.
        pin_memory (bool): If set to True, chunks on CPU memory use pinned memory.
-        force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16.
-            Defaults to False.
+        force_outputs_fp32 (bool): If set to True, outputs will be fp32.
+            Otherwise, outputs will be fp16. Defaults to False.
        search_range_mb (int): Search range for the chunk size. Defaults to 32.
-        hidden_dim (int): Hidden dimension for the chunk size search. Defaults to None.
+        hidden_dim (int): Hidden dimension for the chunk size search.
+            Defaults to None.
        min_chunk_size_mb (int): Minimum chunk size in MB. Defaults to 32.
        memstats (MemStats): Memory statistics. Defaults to None.
    """
@@ -111,11 +127,11 @@ def __init__(
            search_range_mb=search_range_mb,
            min_chunk_size_mb=min_chunk_size_mb,
        )
-        self.heterogeneous_manager = HeterogeneousMemoryManager(
+        self.hetero_memory_manager = HeteroMemoryManager(
            placement_policy, self.chunk_manager, memstats
        )
        self.force_outputs_fp32 = force_outputs_fp32
-        self.param_op_hook = HeterogeneousZeROHook(self.heterogeneous_manager)
+        self.param_op_hook = HeteroHook(self.hetero_memory_manager)
        self.fp32_params: List[torch.Tensor] = list()
        self.fp16_params: List[torch.Tensor] = list()
        self.overflow_counter = 0
@@ -126,9 +142,9 @@ def __init__(
        self._cast_buffers()
        self._logger = DistributedLogger.get_instance(__name__)

-        if self.heterogeneous_manager._premade_memstats_:
+        if self.hetero_memory_manager._premade_memstats_:
            # build chunk in param runtime visited order.
-            param_order = self.heterogeneous_manager.memstats()._param_runtime_order
+            param_order = self.hetero_memory_manager.memstats()._param_runtime_order
        else:
            # build chunk in param initialized order.
            # Note: in this way, it cannot filter out unused params during runtime.
@@ -138,7 +154,7 @@ def __init__(

        self._init_chunks(
            param_order=param_order,
-            cpu_offload=self.heterogeneous_manager.policy_name != "cuda",
+            cpu_offload=self.hetero_memory_manager.policy_name != "cuda",
            pin_memory=pin_memory,
        )

@@ -163,20 +179,20 @@ def _post_forward(self):
            self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
        assert self.chunk_manager.accessed_mem == 0
        # reset all recorded attributes
-        self.heterogeneous_manager.reset_attributes()
+        self.hetero_memory_manager.reset_attributes()

    def forward(self, *args, **kwargs):
        # check whether we are in inference mode
        grad_flag = torch.is_grad_enabled()
        if not grad_flag:
            assert (
-                not self.heterogeneous_manager.need_warmup
-                or not self.heterogeneous_manager.is_warmup()
+                not self.hetero_memory_manager.need_warmup
+                or not self.hetero_memory_manager.is_warmup()
            ), "You should run a completed iteration as your warmup iter"

        args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)

-        self.heterogeneous_manager.pre_iter(*args)
+        self.hetero_memory_manager.pre_iter(*args)
        self.param_op_hook.pre_forward(self.fp16_params)
        outputs = super().forward(*args, **kwargs)
        self.param_op_hook.post_forward(self.fp16_params)
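The assertion above encodes a warm-up contract: the memory manager must observe at least one complete training iteration before a no-grad forward is allowed. A minimal sketch of that ordering, assuming `hetero_model`, `batch`, and `eval_batch` exist on the caller's side:

```python
# Hypothetical sketch of the warm-up contract implied by forward():
out = hetero_model(batch)  # warm-up forward: records runtime memory stats
out.sum().backward()       # completing backward finishes the warm-up iteration

with torch.no_grad():      # inference is only valid after the warm-up iteration
    preds = hetero_model(eval_batch)
```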
@@ -225,9 +241,9 @@ def _post_backward(self):
            )
        self._setup_grads_ptr()
        self._logger.debug(
-            f"comp cuda demand time: {self.heterogeneous_manager._comp_cuda_demand_time}, layout time: {self.heterogeneous_manager._layout_time}, evict time: {self.heterogeneous_manager._evict_time}, CPU->CUDA vol: {self.heterogeneous_manager._h2d_volume} B, CUDA->CPU vol: {self.heterogeneous_manager._d2h_volume}"
+            f"comp cuda demand time: {self.hetero_memory_manager._comp_cuda_demand_time}, layout time: {self.hetero_memory_manager._layout_time}, evict time: {self.hetero_memory_manager._evict_time}, CPU->CUDA vol: {self.hetero_memory_manager._h2d_volume} B, CUDA->CPU vol: {self.hetero_memory_manager._d2h_volume} B"
        )
-        self.heterogeneous_manager.post_iter()
+        self.hetero_memory_manager.post_iter()

    def grad_handle(self, p, grad):
        self.param_op_hook.post_backward([p])
@@ -645,7 +661,7 @@ def _init_chunks(self, param_order, cpu_offload: bool, pin_memory: bool):

            self.fp16_params.append(p)
            self.fp32_params.append(fp32_p)
-            self.grads_device[p] = self.heterogeneous_manager.default_device
+            self.grads_device[p] = self.hetero_memory_manager.default_device

        self.chunk_manager.close_all_groups()