Skip to content

Commit dec2381

Browse files
committed
Exposing solver callbacks to python
1 parent df412ac commit dec2381

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

python/caffe/_caffe.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,27 @@ bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) {
228228
return bp::object();
229229
}
230230

231+
template<typename Dtype>
232+
class PythonCallback: public Solver<Dtype>::Callback {
233+
protected:
234+
bp::object on_start_, on_gradients_ready_;
235+
236+
public:
237+
PythonCallback(bp::object on_start, bp::object on_gradients_ready)
238+
: on_start_(on_start), on_gradients_ready_(on_gradients_ready) { }
239+
virtual void on_gradients_ready() {
240+
on_gradients_ready_();
241+
}
242+
virtual void on_start() {
243+
on_start_();
244+
}
245+
};
246+
template<typename Dtype>
247+
void Solver_add_callback(Solver<Dtype> * solver, bp::object on_start,
248+
bp::object on_gradients_ready) {
249+
solver->add_callback(new PythonCallback<Dtype>(on_start, on_gradients_ready));
250+
}
251+
231252
BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);
232253

233254
BOOST_PYTHON_MODULE(_caffe) {
@@ -317,6 +338,7 @@ BOOST_PYTHON_MODULE(_caffe) {
317338
.add_property("test_nets", bp::make_function(&Solver<Dtype>::test_nets,
318339
bp::return_internal_reference<>()))
319340
.add_property("iter", &Solver<Dtype>::iter)
341+
.def("add_callback", &Solver_add_callback<Dtype>)
320342
.def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
321343
&Solver<Dtype>::Solve), SolveOverloads())
322344
.def("step", &Solver<Dtype>::Step)

0 commit comments

Comments
 (0)