@@ -207,6 +207,27 @@ bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) {
207
207
return bp::object ();
208
208
}
209
209
210
+ template <typename Dtype>
211
+ class PythonCallback : public Solver <Dtype>::Callback {
212
+ protected:
213
+ bp::object on_start_, on_gradients_ready_;
214
+
215
+ public:
216
+ PythonCallback (bp::object on_start, bp::object on_gradients_ready)
217
+ : on_start_(on_start), on_gradients_ready_(on_gradients_ready) { }
218
+ virtual void on_gradients_ready () {
219
+ on_gradients_ready_ ();
220
+ }
221
+ virtual void on_start () {
222
+ on_start_ ();
223
+ }
224
+ };
225
+ template <typename Dtype>
226
+ void Solver_addCallback (Solver<Dtype> * solver, bp::object on_start,
227
+ bp::object on_gradients_ready) {
228
+ solver->add_callback (new PythonCallback<Dtype>(on_start, on_gradients_ready));
229
+ }
230
+
210
231
BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS (SolveOverloads, Solve, 0 , 1 );
211
232
212
233
BOOST_PYTHON_MODULE (_caffe) {
@@ -289,6 +310,7 @@ BOOST_PYTHON_MODULE(_caffe) {
289
310
.add_property (" test_nets" , bp::make_function (&Solver<Dtype>::test_nets,
290
311
bp::return_internal_reference<>()))
291
312
.add_property (" iter" , &Solver<Dtype>::iter)
313
+ .def (" add_callback" , &Solver_addCallback<Dtype>)
292
314
.def (" solve" , static_cast <void (Solver<Dtype>::*)(const char *)>(
293
315
&Solver<Dtype>::Solve), SolveOverloads ())
294
316
.def (" step" , &Solver<Dtype>::Step)
0 commit comments