
Commit d7c4e32

[Fix] Refactor ZeRO Directory Structure (#211)
## Title

- [Fix] Refactor ZeRO Directory Structure

## Description

- This PR restructures the zero directory under `oslo/torch/nn/parallel/data_parallel/zero` to enhance code organization and readability. The changes align the implementation with the architecture of our project, providing a more logical separation between different components and functionalities.
- Organized heterogeneous components (inspired by PatrickStar) into the `hetero` subdirectory, centralizing related code and improving maintainability.
- Update to the ZeRO optimizer wrapper interface:
  > In the existing ZeRO optimizer, we were not sharding the optimizer state, so the wrapper interface has been updated accordingly. My sincere apologies for any confusion or inconvenience this change may cause, and I urge reviewers to assess this modification to ensure alignment with our project's requirements.
- Renaming FULL_SHARD to the PatrickStar algorithm:
  > Please note that the previously termed FULL_SHARD strategy was, in fact, implementing the PatrickStar algorithm. PatrickStar is a novel approach to parallel training of pre-trained models via chunk-based memory management, leveraging CPU-GPU heterogeneous memory space. It has demonstrated significant advantages in model scaling and execution speed.
  >
  > However, I felt that the name "PatrickStar" did not adequately convey the specific characteristics of this approach. Therefore, I have taken the liberty of renaming it "hetero," reflecting the heterogeneous memory utilization. I genuinely value the reviewers' opinions on this naming choice and kindly ask for your feedback. If a more suitable name can be agreed upon, I will happily update it accordingly.

## Linked Issues

- N/A
1 parent 21ef4a1 commit d7c4e32

Some content is hidden: large commits have some content hidden by default.

41 files changed: +200 −162 lines

oslo/torch/nn/parallel/data_parallel/data_parallel.py (+16 −16)

@@ -23,9 +23,9 @@


 class ShardingStrategy(Enum):
-    SHARD_OP = auto()
-    SHARD_GRAD_OP = auto()
-    FULL_SHARD = auto()
+    SHARD_PARAM = auto()
+    SHARD_GRAD_PARAM = auto()
+    HETERO_SHARD = auto()


 def DistributedDataParallel(
@@ -45,18 +45,18 @@ def DistributedDataParallel(

     Supported sharding strategies are:
     - None: No sharding is used. This is the default strategy, where each GPU maintains a full replica of the model.
-    - SHARD_OP: The optimizer states are sharded across GPUs. Each GPU maintains only a portion of the optimizer state.
-    - SHARD_GRAD_OP: In addition to sharding the optimizer states, the gradients are also sharded across GPUs.
-    - FULL_SHARD: The model parameters, optimizer states, and gradients are all sharded across GPUs.
+    - SHARD_PARAM: Shards the model parameters across GPUs.
+    - SHARD_GRAD_PARAM: Shards the gradient as well as the model parameters across GPUs.
+    - HETERO_SHARD: Use the CPU-GPU heterogeneous memory space to store the model data, inspired from PatrickStar.

-    For the SHARD_OP, SHARD_GRAD_OP, and FULL_SHARD strategies, it is mandatory to provide an optimizer.
+    For the SHARD_PARAM, SHARD_GRAD_PARAM, and HETERO_SHARD strategies, it is mandatory to provide an optimizer.

     Args:
         module (nn.Module): PyTorch module object to be wrapped.
         parallel_context (ParallelContext): Process group object for distributed training.
         model_wrapper_config (Optional[Dict[str, Any]]): Additional configuration parameters for the model wrapper.
         optimizer_wrapper_config (Optional[Dict[str, Any]]): Additional configuration parameters for the optimizer wrapper.
-        sharding_strategy (Optional[ShardingStrategy]): The strategy for sharding. Options include None, SHARD_OP, SHARD_GRAD_OP, and FULL_SHARD.
+        sharding_strategy (Optional[ShardingStrategy]): The strategy for sharding. Options include None, SHARD_PARAM, SHARD_GRAD_PARAM, and HETERO_SHARD.
         optimizer (Optional[torch.optim.Optimizer]): PyTorch optimizer object to be wrapped if a sharding strategy is specified.

     Returns:
@@ -86,7 +86,7 @@ def default_strategy():
         )
         return module

-    def SHARD_OP_strategy():
+    def shard_param_strategy():
         optimizer_wrapper_config.pop("partition_grad", None)
         return module, zero.ZeroRedundancyOptimizer(
             optimizer,
@@ -95,7 +95,7 @@ def SHARD_OP_strategy():
             **optimizer_wrapper_config,
         )

-    def shard_grad_op_strategy():
+    def shard_grad_param_strategy():
         optimizer_wrapper_config.pop("partition_grad", None)
         return module, zero.ZeroRedundancyOptimizer(
             optimizer,
@@ -104,15 +104,15 @@ def shard_grad_op_strategy():
             **optimizer_wrapper_config,
         )

-    def full_shard_strategy():
-        fsdp = zero._FullyShardedDataParallel(
+    def hetero_shard_strategy():
+        fsdp = zero._HeteroDataParallel(
             module=module,
             device=torch.device("cuda"),
             parallel_context=parallel_context,
             force_outputs_fp32=True,
             **model_wrapper_config,
         )
-        opt = zero._HeterogeneousZeroOptimizer(
+        opt = zero._HeteroOptimizer(
             optimizer,
             module=fsdp,
             **optimizer_wrapper_config,
@@ -127,9 +127,9 @@ def full_shard_strategy():

     strategy_map = {
         None: default_strategy,
-        ShardingStrategy.SHARD_OP: SHARD_OP_strategy,
-        ShardingStrategy.SHARD_GRAD_OP: shard_grad_op_strategy,
-        ShardingStrategy.FULL_SHARD: full_shard_strategy,
+        ShardingStrategy.SHARD_PARAM: shard_param_strategy,
+        ShardingStrategy.SHARD_GRAD_PARAM: shard_grad_param_strategy,
+        ShardingStrategy.HETERO_SHARD: hetero_shard_strategy,
     }

     strategy = strategy_map.get(sharding_strategy)
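
For reviewers checking the renamed interface, here is a minimal usage sketch of the `HETERO_SHARD` strategy (formerly `FULL_SHARD`). Only `ShardingStrategy`, the `DistributedDataParallel` signature, and the import paths come from this diff; the toy model, the optimizer, the `ParallelContext.from_torch` call, and the assumption that the wrapper returns the `(module, optimizer)` pair produced by the chosen strategy are illustrative.

```python
import torch
from torch import nn

from oslo.torch.distributed.parallel_context import ParallelContext
from oslo.torch.nn.parallel.data_parallel.data_parallel import (
    DistributedDataParallel,
    ShardingStrategy,
)

# Toy model and a standard PyTorch optimizer (placeholders).
model = nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Assumption: a ParallelContext for the current process group; the exact
# factory (e.g. from_torch) and its arguments should be checked against OSLO.
parallel_context = ParallelContext.from_torch(data_parallel_size=2)

# HETERO_SHARD (formerly FULL_SHARD) requires an optimizer and is assumed to
# return both the wrapped module and the wrapped optimizer.
model, optimizer = DistributedDataParallel(
    model,
    parallel_context=parallel_context,
    model_wrapper_config={},      # forwarded to _HeteroDataParallel
    optimizer_wrapper_config={},  # forwarded to _HeteroOptimizer
    sharding_strategy=ShardingStrategy.HETERO_SHARD,
    optimizer=optimizer,
)
```
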
@@ -1,15 +1,15 @@
-from oslo.torch.nn.parallel.data_parallel.zero.sharded_optim.sharded_optim import (
+from oslo.torch.nn.parallel.data_parallel.zero.optim.optim import (
     ZeroRedundancyOptimizer,
 )
-from oslo.torch.nn.parallel.data_parallel.zero.fully_sharded_data_parallel import (
-    _FullyShardedDataParallel,
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.data_parallel import (
+    _HeteroDataParallel,
 )
-from oslo.torch.nn.parallel.data_parallel.zero.sharded_optim.heterogeneous_optim import (
-    _HeterogeneousZeroOptimizer,
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.optim import (
+    _HeteroOptimizer,
 )

 __ALL__ = [
     "ZeroRedundancyOptimizer",
-    "_FullyShardedDataParallel",
-    "_HeterogeneousZeroOptimizer",
+    "_HeteroDataParallel",
+    "_HeteroOptimizer",
 ]

oslo/torch/nn/parallel/data_parallel/zero/chunk/__init__.py (−9): This file was deleted.
@@ -0,0 +1,6 @@
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.data_parallel import (
+    _HeteroDataParallel,
+)
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.optim import _HeteroOptimizer
+
+__ALL__ = ["_HeteroDataParallel", "_HeteroOptimizer"]
@@ -0,0 +1,17 @@
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.chunk.chunk import (
+    Chunk,
+    TensorState,
+    ChunkFullError,
+)
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.chunk.manager import ChunkManager
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.chunk.utils import (
+    init_chunk_manager,
+)
+
+__ALL__ = [
+    "Chunk",
+    "TensorState",
+    "ChunkFullError",
+    "ChunkManager",
+    "init_chunk_manager",
+]

oslo/torch/nn/parallel/data_parallel/zero/chunk/chunk.py → oslo/torch/nn/parallel/data_parallel/zero/hetero/chunk/chunk.py (+1 −1)

@@ -24,7 +24,7 @@
 from oslo.torch.distributed.parallel_mode import ParallelMode
 from oslo.torch.distributed.parallel_context import ParallelContext

-from oslo.torch.nn.parallel.data_parallel.zero.utils import get_current_device
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.utils import get_current_device


 class TensorState(Enum):

oslo/torch/nn/parallel/data_parallel/zero/chunk/manager.py → oslo/torch/nn/parallel/data_parallel/zero/hetero/chunk/manager.py (+1 −1)

@@ -20,7 +20,7 @@
 import torch

 from .chunk import Chunk, ChunkFullError, TensorState
-from oslo.torch.nn.parallel.data_parallel.zero.utils import get_current_device
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.utils import get_current_device

 from oslo.torch.distributed.parallel_context import ParallelContext


oslo/torch/nn/parallel/data_parallel/zero/chunk/utils.py → oslo/torch/nn/parallel/data_parallel/zero/hetero/chunk/utils.py (+2 −2)

@@ -20,7 +20,7 @@
 import torch.distributed as dist
 from torch import nn

-from oslo.torch.nn.parallel.data_parallel.zero.chunk.manager import (
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.chunk.manager import (
     ChunkManager,
 )

@@ -34,7 +34,7 @@
 from oslo.torch.nn.parallel.data_parallel._utils import (
     is_ddp_ignored,
 )
-from oslo.torch.nn.parallel.data_parallel.zero.memory_tracer import (
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.memory_tracer import (
     MemStats,
     OrderedParamGenerator,
 )

oslo/torch/nn/parallel/data_parallel/zero/fully_sharded_data_parallel.py → oslo/torch/nn/parallel/data_parallel/zero/hetero/data_parallel.py (+45 −29)

@@ -28,29 +28,29 @@
 from oslo.torch.distributed.parallel_mode import ParallelMode
 from oslo.torch.nn.parallel.data_parallel._utils import is_ddp_ignored
 from oslo.torch.nn.parallel.data_parallel.data_parallel import _DistributedDataParallel
-from oslo.torch.nn.parallel.data_parallel.zero.chunk import (
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.chunk import (
     Chunk,
     ChunkManager,
     TensorState,
 )
-from oslo.torch.nn.parallel.data_parallel.zero.heterogeneous_manager import (
-    HeterogeneousMemoryManager,
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.memory_manager import (
+    HeteroMemoryManager,
 )
-from oslo.torch.nn.parallel.data_parallel.zero.memory_tracer.param_runtime_order import (
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.memory_tracer.param_runtime_order import (
     OrderedParamGenerator,
 )
-from oslo.torch.nn.parallel.data_parallel.zero.utils import (
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.utils import (
     get_current_device,
     get_temp_total_chunk_on_cuda,
 )
-from oslo.torch.nn.parallel.data_parallel.zero.heterogeneous_hook import (
-    HeterogeneousZeROHook,
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.hook import (
+    HeteroHook,
 )

-from oslo.torch.nn.parallel.data_parallel.zero.memory_tracer import (
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.memory_tracer import (
     MemStats,
 )
-from oslo.torch.nn.parallel.data_parallel.zero.chunk import (
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.chunk import (
     init_chunk_manager,
 )

@@ -70,22 +70,38 @@ def _cast_float(args, dtype: torch.dtype):
     return args


-class _FullyShardedDataParallel(_DistributedDataParallel):
-    """Fully sharded data parallel.
-    Warning: Nested FullyShardedDataParallel is not supported now.
-    It is designed to be used with ChunkManager and HeterogeneousMemoryManager.
-    For more details, see the API reference of ``ChunkManager`` and ``HeterogeneousMemoryManager``.
+class _HeteroDataParallel(_DistributedDataParallel):
+    """Heterogeneous sharded data parallel.
+
+    Inspired by the PatrickStar system introduced in "PatrickStar: Parallel
+    Training of Pre-trained Models via Chunk-based Dynamic Memory Management"
+    by Jiarui Fang, Zilin Zhu, et al. from Tencent Inc:
+
+    - PatrickStar uses a CPU-GPU heterogeneous memory space to store model data,
+      organized in memory chunks.
+    - Chunks are dynamically distributed across the heterogeneous memory,
+      guided by runtime memory statistics from a warm-up iteration.
+    - This approach reduces CPU-GPU data transmission volume and optimizes
+      bandwidth utilization.
+    - In tandem with the Zero Redundancy Optimizer, PatrickStar can efficiently
+      scale to multiple GPUs across multiple nodes.
+
+    Note:
+        Nested HeteroDataParallel is not supported now. It is designed to be
+        used with ChunkManager and HeterogeneousMemoryManager. For more details,
+        see the API reference of ``ChunkManager`` and ``HeteroMemoryManager``.

     Args:
         module (torch.nn.Module): Module to apply ZeRO-DP.
         device (torch.device): Device to place the module.
         parallel_context (ParallelContext): process group object.
         placement_policy (str): Placement policy for the chunks.
         pin_memory (bool): Chunks on CPU Memory use pin-memory.
-        force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16.
-            Defaults to False.
+        force_outputs_fp32 (bool): If set to True, outputs will be fp32.
+            Otherwise, outputs will be fp16. Defaults to False.
         search_range_mb (int): Search range for the chunk size. Defaults to 32.
-        hidden_dim (int): Hidden dimension for the chunk size search. Defaults to None.
+        hidden_dim (int): Hidden dimension for the chunk size search.
+            Defaults to None.
         min_chunk_size_mb (int): Minimum chunk size in MB. Defaults to 32.
         memstats (MemStats): Memory statistics. Defaults to None.
     """
@@ -111,11 +127,11 @@ def __init__(
             search_range_mb=search_range_mb,
             min_chunk_size_mb=min_chunk_size_mb,
         )
-        self.heterogeneous_manager = HeterogeneousMemoryManager(
+        self.hetero_memory_manager = HeteroMemoryManager(
             placement_policy, self.chunk_manager, memstats
         )
         self.force_outputs_fp32 = force_outputs_fp32
-        self.param_op_hook = HeterogeneousZeROHook(self.heterogeneous_manager)
+        self.param_op_hook = HeteroHook(self.hetero_memory_manager)
         self.fp32_params: List[torch.Tensor] = list()
         self.fp16_params: List[torch.Tensor] = list()
         self.overflow_counter = 0
@@ -126,9 +142,9 @@ def __init__(
         self._cast_buffers()
         self._logger = DistributedLogger.get_instance(__name__)

-        if self.heterogeneous_manager._premade_memstats_:
+        if self.hetero_memory_manager._premade_memstats_:
             # build chunk in param runtime visited order.
-            param_order = self.heterogeneous_manager.memstats()._param_runtime_order
+            param_order = self.hetero_memory_manager.memstats()._param_runtime_order
         else:
             # build chunk in param initialized order.
             # Note: in this way, it can not get filter unused params during runtime.
@@ -138,7 +154,7 @@ def __init__(

         self._init_chunks(
             param_order=param_order,
-            cpu_offload=self.heterogeneous_manager.policy_name != "cuda",
+            cpu_offload=self.hetero_memory_manager.policy_name != "cuda",
             pin_memory=pin_memory,
         )

@@ -163,20 +179,20 @@ def _post_forward(self):
             self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
         assert self.chunk_manager.accessed_mem == 0
         # reset all recorded attributes
-        self.heterogeneous_manager.reset_attributes()
+        self.hetero_memory_manager.reset_attributes()

     def forward(self, *args, **kwargs):
         # check whether we are in a inference mode
         grad_flag = torch.is_grad_enabled()
         if not grad_flag:
             assert (
-                not self.heterogeneous_manager.need_warmup
-                or not self.heterogeneous_manager.is_warmup()
+                not self.hetero_memory_manager.need_warmup
+                or not self.hetero_memory_manager.is_warmup()
             ), "You should run a completed iteration as your warmup iter"

         args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)

-        self.heterogeneous_manager.pre_iter(*args)
+        self.hetero_memory_manager.pre_iter(*args)
         self.param_op_hook.pre_forward(self.fp16_params)
         outputs = super().forward(*args, **kwargs)
         self.param_op_hook.post_forward(self.fp16_params)
@@ -225,9 +241,9 @@ def _post_backward(self):
         )
         self._setup_grads_ptr()
         self._logger.debug(
-            f"comp cuda demand time: {self.heterogeneous_manager._comp_cuda_demand_time}, layout time: {self.heterogeneous_manager._layout_time}, evict time: {self.heterogeneous_manager._evict_time}, CPU->CUDA vol: {self.heterogeneous_manager._h2d_volume}B, CUDA->CPU vol: {self.heterogeneous_manager._d2h_volume}"
+            f"comp cuda demand time: {self.hetero_memory_manager._comp_cuda_demand_time}, layout time: {self.hetero_memory_manager._layout_time}, evict time: {self.hetero_memory_manager._evict_time}, CPU->CUDA vol: {self.hetero_memory_manager._h2d_volume}B, CUDA->CPU vol: {self.hetero_memory_manager._d2h_volume}"
         )
-        self.heterogeneous_manager.post_iter()
+        self.hetero_memory_manager.post_iter()

     def grad_handle(self, p, grad):
         self.param_op_hook.post_backward([p])
@@ -645,7 +661,7 @@ def _init_chunks(self, param_order, cpu_offload: bool, pin_memory: bool):

             self.fp16_params.append(p)
             self.fp32_params.append(fp32_p)
-            self.grads_device[p] = self.heterogeneous_manager.default_device
+            self.grads_device[p] = self.hetero_memory_manager.default_device

         self.chunk_manager.close_all_groups()

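
The new docstring above describes a warm-up iteration that collects runtime memory statistics before chunks are laid out, and `forward` asserts that no-grad (inference) calls only happen once warm-up is complete. A hedged sketch of the resulting training-loop shape follows; the data, the loss, and the plain `loss.backward()` / `optimizer.step()` calls are assumptions (the hetero optimizer wrapper may expose its own backward/step helpers for loss scaling), not APIs confirmed by this commit.

```python
import torch

# `model` and `optimizer` are assumed to come from the HETERO_SHARD wrapping
# sketched after the first diff; the data and loss below are placeholders.
batch = torch.randn(8, 1024, device="cuda")
target = torch.randn(8, 1024, device="cuda")
loss_fn = torch.nn.MSELoss()

# 1) Warm-up: run at least one complete training iteration first. The assert
#    in _HeteroDataParallel.forward rejects no-grad calls while the memory
#    manager is still in its warm-up phase.
outputs = model(batch)               # inputs are cast to fp16 by the wrapper
loss = loss_fn(outputs.float(), target)
loss.backward()                      # assumption: check whether the hetero
                                     # optimizer wrapper provides its own
                                     # backward() for loss scaling
optimizer.step()
optimizer.zero_grad()

# 2) After the warm-up iteration, chunk placement is guided by the collected
#    memory statistics, and inference under no_grad is allowed.
with torch.no_grad():
    predictions = model(batch)
```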

oslo/torch/nn/parallel/data_parallel/zero/heterogeneous_hook.py → oslo/torch/nn/parallel/data_parallel/zero/hetero/hook.py (+10 −10)

@@ -19,11 +19,11 @@

 import torch

-from oslo.torch.nn.parallel.data_parallel.zero.chunk import (
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.chunk import (
     TensorState,
 )
-from oslo.torch.nn.parallel.data_parallel.zero.heterogeneous_manager import (
-    HeterogeneousMemoryManager,
+from oslo.torch.nn.parallel.data_parallel.zero.hetero.memory_manager import (
+    HeteroMemoryManager,
 )
 from oslo.torch.nn.parallel.data_parallel._utils import is_ddp_ignored

@@ -33,25 +33,25 @@ class TrainingPhase(Enum):
     BACKWARD = 1


-class HeterogeneousZeROHook:
-    def __init__(self, heterogeneous_manager: HeterogeneousMemoryManager) -> None:
+class HeteroHook:
+    def __init__(self, hetero_memory_manager: HeteroMemoryManager) -> None:
         super().__init__()
-        self._heterogeneous_manager = heterogeneous_manager
-        self._chunk_manager = heterogeneous_manager.chunk_manager
+        self._hetero_memory_manager = hetero_memory_manager
+        self._chunk_manager = hetero_memory_manager.chunk_manager
         self._training_phase = TrainingPhase.FORWARD

     def pre_op(self, params):
         params = [p for p in params if not is_ddp_ignored(p)]
         chunks = self._chunk_manager.get_chunks(params)
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
-        self._heterogeneous_manager.sample_overall_data()
-        self._heterogeneous_manager.adjust_layout(chunks)
+        self._hetero_memory_manager.sample_overall_data()
+        self._hetero_memory_manager.adjust_layout(chunks)
         for chunk in chunks:
             self._chunk_manager.access_chunk(chunk)

         # record cuda model data of the current OP
-        self._heterogeneous_manager.record_model_data_volume()
+        self._hetero_memory_manager.record_model_data_volume()

     def post_op(self, params):
         params = [p for p in params if not is_ddp_ignored(p)]
