Skip to content

Commit d656757

Browse files
committed
Improve code quality
1 parent 78924a0 commit d656757

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

depthwise_conv.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ def __init__(self, filters,
156156
self.depthwise_constraint = constraints.get(depthwise_constraint)
157157
self.bias_initializer = initializers.get(bias_initializer)
158158

159+
self._padding = _preprocess_padding(self.padding)
160+
self._strides = (1,) + self.strides + (1,)
161+
self._data_format = "NHWC" if self.data_format == 'channels_last' else "NCHW"
162+
159163
def build(self, input_shape):
160164
if len(input_shape) < 4:
161165
raise ValueError('Inputs to `DepthwiseConv2D` should have rank 4. '
@@ -194,15 +198,11 @@ def build(self, input_shape):
194198
self.built = True
195199

196200
def call(self, inputs, training=None):
197-
padding = _preprocess_padding(self.padding)
198-
strides = (1,) + self.strides + (1,)
199-
data_format = "NHWC" if self.data_format == 'channels_last' else "NCHW"
200-
201201
outputs = tf.nn.depthwise_conv2d(inputs, self.depthwise_kernel,
202-
strides=strides,
203-
padding=padding,
202+
strides=self._strides,
203+
padding=self._padding,
204204
rate=self.dilation_rate,
205-
data_format=data_format)
205+
data_format=self._data_format)
206206

207207
if self.bias:
208208
outputs = K.bias_add(

predict_imagenet.py

-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ def preprocess_input(x):
1616

1717
if __name__ == '__main__':
1818
model = MobileNets()
19-
model.load_weights('weights/mobilenet_imagenet_tf.h5')
2019

2120
img_path = 'elephant.jpg'
2221
img = image.load_img(img_path, target_size=(224, 224))

0 commit comments

Comments
 (0)