Skip to content

Commit 2882cda

Browse files
committed
Merge pull request #936 from jeffdonahue/not-stage
Add "not_stage" to NetStateRule to exclude NetStates with certain stages
2 parents 02aadf0 + c6e9c59 commit 2882cda

File tree

3 files changed

+95
-18
lines changed

3 files changed

+95
-18
lines changed

src/caffe/net.cpp

+28-16
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ bool Net<Dtype>::StateMeetsRule(const NetState& state,
282282
if (state.level() < rule.min_level()) {
283283
LOG(INFO) << "The NetState level (" << state.level()
284284
<< ") is above the min_level (" << rule.min_level()
285-
<< " specified by a rule in layer " << layer_name;
285+
<< ") specified by a rule in layer " << layer_name;
286286
return false;
287287
}
288288
}
@@ -291,24 +291,36 @@ bool Net<Dtype>::StateMeetsRule(const NetState& state,
291291
if (state.level() > rule.max_level()) {
292292
LOG(INFO) << "The NetState level (" << state.level()
293293
<< ") is above the max_level (" << rule.max_level()
294-
<< " specified by a rule in layer " << layer_name;
294+
<< ") specified by a rule in layer " << layer_name;
295295
return false;
296296
}
297297
}
298-
// Check whether the rule is broken due to stage. If stage is specified,
299-
// the NetState must contain ALL of the rule's stages to meet it.
300-
if (rule.stage_size()) {
301-
for (int i = 0; i < rule.stage_size(); ++i) {
302-
// Check that the NetState contains the rule's ith stage.
303-
bool has_stage = false;
304-
for (int j = 0; !has_stage && j < state.stage_size(); ++j) {
305-
if (rule.stage(i) == state.stage(j)) { has_stage = true; }
306-
}
307-
if (!has_stage) {
308-
LOG(INFO) << "The NetState did not contain stage '" << rule.stage(i)
309-
<< "' specified by a rule in layer " << layer_name;
310-
return false;
311-
}
298+
// Check whether the rule is broken due to stage. The NetState must
299+
// contain ALL of the rule's stages to meet it.
300+
for (int i = 0; i < rule.stage_size(); ++i) {
301+
// Check that the NetState contains the rule's ith stage.
302+
bool has_stage = false;
303+
for (int j = 0; !has_stage && j < state.stage_size(); ++j) {
304+
if (rule.stage(i) == state.stage(j)) { has_stage = true; }
305+
}
306+
if (!has_stage) {
307+
LOG(INFO) << "The NetState did not contain stage '" << rule.stage(i)
308+
<< "' specified by a rule in layer " << layer_name;
309+
return false;
310+
}
311+
}
312+
// Check whether the rule is broken due to not_stage. The NetState must
313+
// contain NONE of the rule's not_stages to meet it.
314+
for (int i = 0; i < rule.not_stage_size(); ++i) {
315+
// Check that the NetState contains the rule's ith not_stage.
316+
bool has_stage = false;
317+
for (int j = 0; !has_stage && j < state.stage_size(); ++j) {
318+
if (rule.not_stage(i) == state.stage(j)) { has_stage = true; }
319+
}
320+
if (has_stage) {
321+
LOG(INFO) << "The NetState contained a not_stage '" << rule.not_stage(i)
322+
<< "' specified by a rule in layer " << layer_name;
323+
return false;
312324
}
313325
}
314326
return true;

src/caffe/proto/caffe.proto

+4-2
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,12 @@ message NetStateRule {
171171
optional int32 min_level = 2;
172172
optional int32 max_level = 3;
173173

174-
// A customizable set of stages.
175-
// The net must have ALL of the specified stages to meet the rule.
174+
// Customizable sets of stages to include or exclude.
175+
// The net must have ALL of the specified stages and NONE of the specified
176+
// "not_stage"s to meet the rule.
176177
// (Use multiple NetStateRules to specify conjunctions of stages.)
177178
repeated string stage = 4;
179+
repeated string not_stage = 5;
178180
}
179181

180182
// NOTE

src/caffe/test/test_net.cpp

+63
Original file line numberDiff line numberDiff line change
@@ -1618,6 +1618,69 @@ TEST_F(FilterNetTest, TestFilterInByMultipleStage2) {
16181618
this->RunFilterNetTest(input_proto, input_proto);
16191619
}
16201620

1621+
TEST_F(FilterNetTest, TestFilterInByNotStage) {
1622+
const string& input_proto =
1623+
"state: { stage: 'mystage' } "
1624+
"name: 'TestNetwork' "
1625+
"layers: { "
1626+
" name: 'data' "
1627+
" type: DATA "
1628+
" top: 'data' "
1629+
" top: 'label' "
1630+
"} "
1631+
"layers: { "
1632+
" name: 'innerprod' "
1633+
" type: INNER_PRODUCT "
1634+
" bottom: 'data' "
1635+
" top: 'innerprod' "
1636+
" include: { not_stage: 'myotherstage' } "
1637+
"} "
1638+
"layers: { "
1639+
" name: 'loss' "
1640+
" type: SOFTMAX_LOSS "
1641+
" bottom: 'innerprod' "
1642+
" bottom: 'label' "
1643+
" include: { not_stage: 'myotherstage' } "
1644+
"} ";
1645+
this->RunFilterNetTest(input_proto, input_proto);
1646+
}
1647+
1648+
TEST_F(FilterNetTest, TestFilterOutByNotStage) {
1649+
const string& input_proto =
1650+
"state: { stage: 'mystage' } "
1651+
"name: 'TestNetwork' "
1652+
"layers: { "
1653+
" name: 'data' "
1654+
" type: DATA "
1655+
" top: 'data' "
1656+
" top: 'label' "
1657+
"} "
1658+
"layers: { "
1659+
" name: 'innerprod' "
1660+
" type: INNER_PRODUCT "
1661+
" bottom: 'data' "
1662+
" top: 'innerprod' "
1663+
" include: { not_stage: 'mystage' } "
1664+
"} "
1665+
"layers: { "
1666+
" name: 'loss' "
1667+
" type: SOFTMAX_LOSS "
1668+
" bottom: 'innerprod' "
1669+
" bottom: 'label' "
1670+
" include: { not_stage: 'mystage' } "
1671+
"} ";
1672+
const string& output_proto =
1673+
"state: { stage: 'mystage' } "
1674+
"name: 'TestNetwork' "
1675+
"layers: { "
1676+
" name: 'data' "
1677+
" type: DATA "
1678+
" top: 'data' "
1679+
" top: 'label' "
1680+
"} ";
1681+
this->RunFilterNetTest(input_proto, output_proto);
1682+
}
1683+
16211684
TEST_F(FilterNetTest, TestFilterOutByMinLevel) {
16221685
const string& input_proto =
16231686
"name: 'TestNetwork' "

0 commit comments

Comments
 (0)