Skip to content

Commit b7a9e7b

Browse files
committed
Exposing solver callbacks to python
1 parent ff16f6e commit b7a9e7b

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

python/caffe/_caffe.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,27 @@ bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) {
205205
return bp::object();
206206
}
207207

208+
template<typename Dtype>
209+
class PythonCallback: public Solver<Dtype>::Callback {
210+
protected:
211+
bp::object on_start_, on_gradients_ready_;
212+
213+
public:
214+
PythonCallback(bp::object on_start, bp::object on_gradients_ready)
215+
: on_start_(on_start), on_gradients_ready_(on_gradients_ready) { }
216+
virtual void on_gradients_ready() {
217+
on_gradients_ready_();
218+
}
219+
virtual void on_start() {
220+
on_start_();
221+
}
222+
};
223+
template<typename Dtype>
224+
void Solver_addCallback(Solver<Dtype> * solver, bp::object on_start,
225+
bp::object on_gradients_ready) {
226+
solver->add_callback(new PythonCallback<Dtype>(on_start, on_gradients_ready));
227+
}
228+
208229
BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);
209230

210231
BOOST_PYTHON_MODULE(_caffe) {
@@ -283,6 +304,7 @@ BOOST_PYTHON_MODULE(_caffe) {
283304
.add_property("test_nets", bp::make_function(&Solver<Dtype>::test_nets,
284305
bp::return_internal_reference<>()))
285306
.add_property("iter", &Solver<Dtype>::iter)
307+
.def("add_callback", &Solver_addCallback<Dtype>)
286308
.def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
287309
&Solver<Dtype>::Solve), SolveOverloads())
288310
.def("step", &Solver<Dtype>::Step)

0 commit comments

Comments
 (0)