fix channel mismatch in new_model.py when i==1 #47

Open · wants to merge 15 commits into master
2 changes: 1 addition & 1 deletion .gitignore
@@ -17,4 +17,4 @@ __pycache__/new_model.cpython-37.pyc
 __pycache__/cell_level_search.cpython-37.pyc
 __pycache__/auto_deeplab.cpython-37.pyc
 __pycache__/decoding_formulas.cpython-37.pyc
-utils/__pycache__/copy_state_dict.cpython-37.pyc
+utils/__pycache__/copy_state_dict.cpython-37.pyc
1 change: 0 additions & 1 deletion decode_autodeeplab.py
@@ -53,7 +53,6 @@ def __init__(self, args):
                 name = k[7:] # remove 'module.' of dataparallel
                 new_state_dict[name] = v
             self.model.load_state_dict(new_state_dict)
-
         else:
             if (torch.cuda.device_count() > 1 or args.load_parallel):
                 self.model.module.load_state_dict(checkpoint['state_dict'])
Empty file modified mypath.py (mode changed 100755 → 100644)
35 changes: 14 additions & 21 deletions new_model.py
@@ -14,15 +14,12 @@
 
 
 class Cell(nn.Module):
-
     def __init__(self, steps, block_multiplier, prev_prev_fmultiplier,
                  prev_filter_multiplier,
                  cell_arch, network_arch,
                  filter_multiplier, downup_sample):
-
         super(Cell, self).__init__()
         self.cell_arch = cell_arch
-
         self.C_in = block_multiplier * filter_multiplier
         self.C_out = filter_multiplier
         self.C_prev = int(block_multiplier * prev_filter_multiplier)
@@ -90,7 +87,7 @@ def forward(self, prev_prev_input, prev_input):
 
 
 class newModel (nn.Module):
-    def __init__(self, network_arch, cell_arch, num_classes, num_layers, criterion=None, filter_multiplier=20, block_multiplier=5, step=5, cell=Cell, full_net='deeplab_v3+'):
+    def __init__(self, network_arch, cell_arch, num_classes, num_layers, filter_multiplier=20, block_multiplier=5, step=5, cell=Cell, full_net='deeplab_v3+'):
         super(newModel, self).__init__()
 
         self.cells = nn.ModuleList()
@@ -101,7 +98,6 @@ def __init__(self, network_arch, cell_arch, num_classes, num_layers, criterion=None, ...
         self._step = step
         self._block_multiplier = block_multiplier
         self._filter_multiplier = filter_multiplier
-        self._criterion = criterion
         self._full_net = full_net
         initial_fm = 128
         self.stem0 = nn.Sequential(
@@ -125,38 +121,35 @@ def __init__(self, network_arch, cell_arch, num_classes, num_layers, criterion=None, ...
         filter_param_dict = {0: 1, 1: 2, 2: 4, 3: 8}
         for i in range(self._num_layers):
             level_option = torch.sum(self.network_arch[i], dim=1)
-            prev_level_option = torch.sum(self.network_arch[i-1], dim=1)
-            prev_prev_level_option = torch.sum(
-                self.network_arch[i-2], dim=1)
             level = torch.argmax(level_option).item()
-            prev_level = torch.argmax(prev_level_option).item()
-            prev_prev_level = torch.argmax(prev_prev_level_option).item()
+            if i>=1:
+                prev_level_option = torch.sum(self.network_arch[i-1], dim=1)
+                prev_level = torch.argmax(prev_level_option).item()
+            if i>=2:
+                prev_prev_level_option = torch.sum(self.network_arch[i-2], dim=1)
+                prev_prev_level = torch.argmax(prev_prev_level_option).item()
             if i == 0:
                 downup_sample = 0
                 _cell = cell(self._step, self._block_multiplier, ini_initial_fm / block_multiplier,
                              initial_fm / block_multiplier,
                              self.cell_arch, self.network_arch[i],
-                             self._filter_multiplier *
-                             filter_param_dict[level],
+                             self._filter_multiplier * filter_param_dict[level],
                              downup_sample)
             else:
                 three_branch_options = torch.sum(self.network_arch[i], dim=0)
                 downup_sample = torch.argmax(three_branch_options).item() - 1
                 if i == 1:
-                    _cell = cell(self._step, self._block_multiplier,
-                                 initial_fm / block_multiplier,
-                                 self._filter_multiplier * 1,
+                    _cell = cell(self._step, self._block_multiplier, initial_fm / block_multiplier,
+                                 self._filter_multiplier * filter_param_dict[prev_level],
                                  self.cell_arch, self.network_arch[i],
-                                 self._filter_multiplier *
-                                 filter_param_dict[level],
+                                 self._filter_multiplier * filter_param_dict[level],
                                  downup_sample)
                 else:
                     _cell = cell(self._step, self._block_multiplier, self._filter_multiplier * filter_param_dict[prev_prev_level],
-                                 self._filter_multiplier *
-                                 filter_param_dict[prev_level],
+                                 self._filter_multiplier * filter_param_dict[prev_level],
                                  self.cell_arch, self.network_arch[i],
-                                 self._filter_multiplier *
-                                 filter_param_dict[level], downup_sample)
+                                 self._filter_multiplier * filter_param_dict[level],
+                                 downup_sample)
 
             self.cells += [_cell]
 
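A note on the core fix, restated outside the diff. This is a minimal sketch, not the repository's code; it assumes `network_arch[i]` is a one-hot matrix whose row sums identify a layer's spatial level, which is what the `torch.sum(..., dim=1)` / `torch.argmax` idiom above implies. At `i == 1` the previous cell's channel count depends on the level layer 0 actually settled at, so the old hard-coded `self._filter_multiplier * 1` mismatched channels whenever layer 0 was not at level 0:

import torch

filter_multiplier = 20                        # the newModel default above
filter_param_dict = {0: 1, 1: 2, 2: 4, 3: 8}  # channels scale with spatial level

def level_of(arch_layer):
    # Spatial level a layer occupies: argmax over the row sums of its one-hot entry.
    return torch.argmax(torch.sum(arch_layer, dim=1)).item()

def prev_filter_count(network_arch, i):
    # Before this PR, i == 1 passed filter_multiplier * 1, silently assuming layer 0
    # sits at level 0; the fix scales by the level that layer i-1 actually occupies.
    prev_level = level_of(network_arch[i - 1])
    return filter_multiplier * filter_param_dict[prev_level]

The new `if i>=1` / `if i>=2` guards serve the same end: they stop `network_arch[i-1]` and `network_arch[i-2]` from silently wrapping around to the last layers for `i = 0, 1` under Python's negative indexing.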
2 changes: 1 addition & 1 deletion train.py
@@ -265,7 +265,7 @@ def main():
     if args.epochs is None:
         epoches = {
             'coco': 30,
-            'cityscapes': 200,
+            'cityscapes': 90,
             'pascal': 50,
         }
         args.epochs = epoches[args.dataset.lower()]
26 changes: 8 additions & 18 deletions train_autodeeplab.py
@@ -24,8 +24,7 @@
     APEX_AVAILABLE = True
 except ModuleNotFoundError:
     APEX_AVAILABLE = False
-
-
+APEX_AVAILABLE = False
 print('working with pytorch version {}'.format(torch.__version__))
 print('with cuda version {}'.format(torch.version.cuda))
 print('cudnn enabled: {}'.format(torch.backends.cudnn.enabled))
@@ -115,10 +114,10 @@ def __init__(self, args):
 
 
         # Using data parallel
-        if args.cuda and len(self.args.gpu_ids) >1:
+        if args.cuda and torch.cuda.device_count() >1:
             if self.opt_level == 'O2' or self.opt_level == 'O3':
                 print('currently cannot run with nn.DataParallel and optimization level', self.opt_level)
-            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
+            self.model = torch.nn.DataParallel(self.model)
             patch_replication_callback(self.model)
             print('training on multiple-GPUs')
 
@@ -356,12 +355,8 @@ def main():
     # cuda, seed and logging
     parser.add_argument('--no-cuda', action='store_true', default=
                         False, help='disables CUDA training')
-
     parser.add_argument('--use_amp', action='store_true', default=
-                        False)
-    parser.add_argument('--gpu-ids', type=str, default='0',
-                        help='use which gpu to train, must be a \
-                        comma-separated list of integers only (default=0)')
+                        False)
     parser.add_argument('--seed', type=int, default=1, metavar='S',
                         help='random seed (default: 1)')
     # checking point
@@ -380,14 +375,9 @@
 
     args = parser.parse_args()
     args.cuda = not args.no_cuda and torch.cuda.is_available()
-    if args.cuda:
-        try:
-            args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
-        except ValueError:
-            raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')
-
+
     if args.sync_bn is None:
-        if args.cuda and len(args.gpu_ids) > 1:
+        if args.cuda and torch.cuda.device_count() > 1:
             args.sync_bn = True
         else:
             args.sync_bn = False
@@ -403,12 +393,12 @@
         args.epochs = epoches[args.dataset.lower()]
 
     if args.batch_size is None:
-        args.batch_size = 4 * len(args.gpu_ids)
+        args.batch_size = 4 * torch.cuda.device_count()
 
     if args.test_batch_size is None:
        args.test_batch_size = args.batch_size
 
-    #args.lr = args.lr / (4 * len(args.gpu_ids)) * args.batch_size
+    #args.lr = args.lr / (4 * torch.cuda.device_count()) * args.batch_size
 
 
     if args.checkname is None:
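Across both training scripts, the PR drops the parsed `--gpu-ids` list and derives multi-GPU defaults from the visible device count instead. A minimal sketch of the resulting behavior (`multi_gpu_defaults` is a hypothetical helper, not part of the PR; it assumes only that PyTorch is importable):

import torch

def multi_gpu_defaults(per_gpu_batch=4):
    # torch.cuda.device_count() returns 0 on a CPU-only machine, so the batch
    # size is guarded here; the diff itself multiplies by the raw count.
    n_gpus = torch.cuda.device_count()
    sync_bn = n_gpus > 1                          # SyncBN only pays off across GPUs
    batch_size = per_gpu_batch * max(n_gpus, 1)
    return sync_bn, batch_size

One trade-off worth noting: `len(args.gpu_ids)` let a user train on a subset of devices, whereas `torch.cuda.device_count()` always sees every visible GPU, so restricting devices now has to happen through `CUDA_VISIBLE_DEVICES`.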
23 changes: 12 additions & 11 deletions train_new_model.py
@@ -40,7 +40,8 @@ def __init__(self, args):
         model = newModel(network_arch= new_network_arch,
                          cell_arch = new_cell_arch,
                          num_classes=self.nclass,
-                         num_layers=12)
+                         num_layers=12,
+                         full_net=None)
                          # output_stride=args.out_stride,
                          # sync_bn=args.sync_bn,
                          # freeze_bn=args.freeze_bn)
@@ -243,7 +244,7 @@ def main():
     parser.add_argument('--use-balanced-weights', action='store_true', default=False,
                         help='whether to use balanced weights (default: False)')
     # optimizer params
-    parser.add_argument('--lr', type=float, default=None, metavar='LR',
+    parser.add_argument('--lr', type=float, default=0.05, metavar='LR',
                         help='learning rate (default: auto)')
     parser.add_argument('--lr-scheduler', type=str, default='poly',
                         choices=['poly', 'step', 'cos'],
@@ -277,22 +278,22 @@
                         help='evaluation interval (default: 1)')
     parser.add_argument('--no-val', action='store_true', default=False,
                         help='skip validation during training')
-    parser.add_argument('--filter_multiplier', type=int, default=20)
+    parser.add_argument('--filter_multiplier', type=int, default=4)
     parser.add_argument('--autodeeplab', type=str, default='train',
                         choices=['search', 'train'])
     parser.add_argument('--load-parallel', type=int, default=0)
     parser.add_argument('--min_lr', type=float, default=0.001) #TODO: CHECK THAT THEY EVEN DO THIS FOR THE MODEL IN THE PAPER
 
     args = parser.parse_args()
     args.cuda = not args.no_cuda and torch.cuda.is_available()
-    if args.cuda:
-        try:
-            args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
-        except ValueError:
-            raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')
+    # if args.cuda:
+    #     try:
+    #         args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
+    #     except ValueError:
+    #         raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')
 
     if args.sync_bn is None:
-        if args.cuda and len(args.gpu_ids) > 1:
+        if args.cuda and torch.cuda.device_count() > 1:
             args.sync_bn = True
         else:
             args.sync_bn = False
@@ -307,7 +308,7 @@
         args.epochs = epoches[args.dataset.lower()]
 
     if args.batch_size is None:
-        args.batch_size = 4 * len(args.gpu_ids)
+        args.batch_size = 4 * torch.cuda.device_count()
 
     if args.test_batch_size is None:
         args.test_batch_size = args.batch_size
@@ -318,7 +319,7 @@
             'cityscapes': 0.01,
             'pascal': 0.007,
         }
-        args.lr = lrs[args.dataset.lower()] / (4 * len(args.gpu_ids)) * args.batch_size
+        args.lr = lrs[args.dataset.lower()] / (4 * torch.cuda.device_count()) * args.batch_size
 
 
     if args.checkname is None:
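The auto-lr branch in the last hunk implements a linear learning-rate scaling rule: the per-dataset base rate is treated as the rate for a reference batch of 4 samples per GPU, then scaled with the actual total batch size. A worked sketch (the cityscapes value is copied from the visible diff; the 2-GPU scenario is hypothetical):

def scaled_lr(base_lr, n_gpus, batch_size):
    # lr = base_lr / (4 * n_gpus) * batch_size, as in the hunk above:
    # base_lr is read as the rate for a reference batch of 4 samples per GPU.
    return base_lr / (4 * n_gpus) * batch_size

# cityscapes (base 0.01) on 2 GPUs with the default batch of 4 * 2 = 8:
# 0.01 / (4 * 2) * 8 = 0.01, so the rule only moves the lr when batch_size
# deviates from 4 samples per GPU.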