Skip to content

Commit 7bc5a0c

Browse files
committed
Exposing solver callbacks to python
1 parent cff6f3d commit 7bc5a0c

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

python/caffe/_caffe.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,27 @@ bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) {
207207
return bp::object();
208208
}
209209

210+
template<typename Dtype>
211+
class PythonCallback: public Solver<Dtype>::Callback {
212+
protected:
213+
bp::object on_start_, on_gradients_ready_;
214+
215+
public:
216+
PythonCallback(bp::object on_start, bp::object on_gradients_ready)
217+
: on_start_(on_start), on_gradients_ready_(on_gradients_ready) { }
218+
virtual void on_gradients_ready() {
219+
on_gradients_ready_();
220+
}
221+
virtual void on_start() {
222+
on_start_();
223+
}
224+
};
225+
template<typename Dtype>
226+
void Solver_addCallback(Solver<Dtype> * solver, bp::object on_start,
227+
bp::object on_gradients_ready) {
228+
solver->add_callback(new PythonCallback<Dtype>(on_start, on_gradients_ready));
229+
}
230+
210231
BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);
211232

212233
BOOST_PYTHON_MODULE(_caffe) {
@@ -289,6 +310,7 @@ BOOST_PYTHON_MODULE(_caffe) {
289310
.add_property("test_nets", bp::make_function(&Solver<Dtype>::test_nets,
290311
bp::return_internal_reference<>()))
291312
.add_property("iter", &Solver<Dtype>::iter)
313+
.def("add_callback", &Solver_addCallback<Dtype>)
292314
.def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
293315
&Solver<Dtype>::Solve), SolveOverloads())
294316
.def("step", &Solver<Dtype>::Step)

0 commit comments

Comments
 (0)