@@ -228,6 +228,27 @@ bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) {
228
228
return bp::object ();
229
229
}
230
230
231
+ template <typename Dtype>
232
+ class PythonCallback : public Solver <Dtype>::Callback {
233
+ protected:
234
+ bp::object on_start_, on_gradients_ready_;
235
+
236
+ public:
237
+ PythonCallback (bp::object on_start, bp::object on_gradients_ready)
238
+ : on_start_(on_start), on_gradients_ready_(on_gradients_ready) { }
239
+ virtual void on_gradients_ready () {
240
+ on_gradients_ready_ ();
241
+ }
242
+ virtual void on_start () {
243
+ on_start_ ();
244
+ }
245
+ };
246
+ template <typename Dtype>
247
+ void Solver_add_callback (Solver<Dtype> * solver, bp::object on_start,
248
+ bp::object on_gradients_ready) {
249
+ solver->add_callback (new PythonCallback<Dtype>(on_start, on_gradients_ready));
250
+ }
251
+
231
252
BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS (SolveOverloads, Solve, 0 , 1 );
232
253
233
254
BOOST_PYTHON_MODULE (_caffe) {
@@ -317,6 +338,7 @@ BOOST_PYTHON_MODULE(_caffe) {
317
338
.add_property (" test_nets" , bp::make_function (&Solver<Dtype>::test_nets,
318
339
bp::return_internal_reference<>()))
319
340
.add_property (" iter" , &Solver<Dtype>::iter)
341
+ .def (" add_callback" , &Solver_add_callback<Dtype>)
320
342
.def (" solve" , static_cast <void (Solver<Dtype>::*)(const char *)>(
321
343
&Solver<Dtype>::Solve), SolveOverloads ())
322
344
.def (" step" , &Solver<Dtype>::Step)
0 commit comments