Yolov5 train sample #13510

Open · wants to merge 4 commits into master
92 changes: 91 additions & 1 deletion train.py
@@ -99,6 +99,74 @@
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
GIT_INFO = check_git_info()

# TPU-MLIR compile-path imports: graph capture, aten decompositions, and FX-to-MLIR conversion
import torch._dynamo
from compile.FxGraphConvertor import fx2mlir
from torch._functorch.aot_autograd import aot_export_module
from tpu_mlir.python.tools.train.tpu_mlir_jit import aot_backend


class JitNet(nn.Module):
    """Wrap the detection model together with its loss so both can be traced as a single module."""

    def __init__(self, net, loss_fn):
        super().__init__()
        self.net = net
        self.loss_fn = loss_fn

    def forward(self, x, y):
        pred = self.net(x)  # run the model once and reuse the predictions for the loss
        loss, loss_items = self.loss_fn(pred, y)
        return loss, loss_items.detach()
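
# Illustrative only: JitNet is referenced solely in a commented-out line inside train();
# the intended usage appears to be
#   jit_net = JitNet(model, compute_loss)
#   loss, loss_items = jit_net(imgs, targets.to(device))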


def _get_disc_decomp():
    """Return the aten decompositions applied when exporting the graph for TPU-MLIR lowering."""
    from torch._decomp import get_decompositions

    aten = torch.ops.aten
    decompositions_dict = get_decompositions(
        [
            aten.gelu,
            aten.gelu_backward,
            aten.native_group_norm_backward,
            # aten.native_layer_norm,
            aten.native_layer_norm_backward,
            # aten.std_mean.correction,
            # aten._softmax,
            aten._softmax_backward_data,
            aten.tanh_backward,
            aten.slice_backward,
            aten.select_backward,
            aten.embedding_dense_backward,
            aten.sigmoid_backward,
            aten.nll_loss_backward,
            aten._log_softmax_backward_data,
            aten.nll_loss_forward,
        ]
    )
    return decompositions_dict


def convert_module_fx(
    submodule_name: str,
    module: torch.fx.GraphModule,
    args={},
    bwd_graph: bool = False,
    para_shape: list = [],
):
    """Lower an FX graph module to MLIR (bmodel) via the fx2mlir converter."""
    c = fx2mlir(submodule_name, args, bwd_graph, para_shape)
    return c.convert(module)


class SophonJointCompile:
    """Export the joint forward/backward graph of a model and lower it to a bmodel."""

    def __init__(self, model, example_inputs, trace_joint=True, output_loss_index=0, args=None):
        self.args = args
        self.fx_g, self.signature = aot_export_module(
            model, example_inputs, trace_joint=trace_joint, output_loss_index=output_loss_index, decompositions=_get_disc_decomp()
        )
        self.fx_g.to_folder("yolov5sc", "joint")  # dump the exported graph module for inspection

    def fx_convert_bmodel(self):
        name = f"test_{self.args.model}_joint_{self.args.batch}"
        convert_module_fx(name, self.fx_g, self.args, bwd_graph=False)
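
# Hedged usage sketch (the class is not instantiated anywhere in this patch); `args` is
# assumed to carry `.model` and `.batch` attributes used to name the exported bmodel:
#   compiler = SophonJointCompile(JitNet(model, compute_loss), [imgs, targets], args=args)
#   compiler.fx_convert_bmodel()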


def train(hyp, opt, device, callbacks):
"""
@@ -384,6 +452,8 @@ def lf(x):
        if RANK in {-1, 0}:
            pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT)  # progress bar
        optimizer.zero_grad()
        model_opt = torch.compile(model, backend=aot_backend)  # compile the model through the TPU-MLIR AOT backend
        # zwyjit = JitNet(model, compute_loss)
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
            callbacks.run("on_train_batch_start")
            ni = i + nb * epoch  # number integrated batches (since train start)
@@ -410,8 +480,13 @@ def lf(x):

            # Forward
            with torch.cuda.amp.autocast(amp):
                pred = model_opt(imgs)  # forward through the compiled model (replaces the original `pred = model(imgs)`)
                loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                # fx_g, signature = aot_export_module(
                #     model, [imgs], trace_joint=False, output_loss_index=0, decompositions=_get_disc_decomp()
                # )
                # print('fx_g:', fx_g)

                if RANK != -1:
                    loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
                if opt.quad:
@@ -983,4 +1058,19 @@ def run(**kwargs):

if __name__ == "__main__":
    opt = parse_opt()

    # TPU-MLIR backend options, parsed separately from YOLOv5's own arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--chip", default="bm1690", choices=["bm1684x", "bm1690", "sg2260"], help="chip name")
    parser.add_argument("--debug", default="print_ori_fx_graph", help="debug")
    parser.add_argument("--cmp", action="store_true", help="enable cmp")
    parser.add_argument("--fast_test", action="store_true", help="run a fast test")
    parser.add_argument("--skip_module_num", default=0, type=int, help="number of modules to skip")
    parser.add_argument("--exit_at", default=-1, type=int, help="index at which to exit")
    parser.add_argument("--num_core", default=1, type=int, help="number of TPU cores used for parallel computation")
    parser.add_argument("--opt", default=2, type=int, help="layer group optimization level")
    parser.add_argument("--fp", default="", help="fp")
    import tpu_mlir.python.tools.train.tpu_mlir_jit as tpu_mlir_jit

    tpu_mlir_jit.args = parser.parse_known_args()[0]
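    # Illustrative invocation (hypothetical; only the TPU-MLIR flags above are defined in this
    # patch, and parse_opt() must tolerate the extra flags for them to be passed on the CLI):
    #   python train.py --data coco128.yaml --weights yolov5s.pt --chip bm1690 --num_core 1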

    main(opt)