@@ -206,6 +206,27 @@ bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) {
206
206
return bp::object ();
207
207
}
208
208
209
+ template <typename Dtype>
210
+ class PythonCallback : public Solver <Dtype>::Callback {
211
+ protected:
212
+ bp::object on_start_, on_gradients_ready_;
213
+
214
+ public:
215
+ PythonCallback (bp::object on_start, bp::object on_gradients_ready)
216
+ : on_start_(on_start), on_gradients_ready_(on_gradients_ready) { }
217
+ virtual void on_gradients_ready () {
218
+ on_gradients_ready_ ();
219
+ }
220
+ virtual void on_start () {
221
+ on_start_ ();
222
+ }
223
+ };
224
+ template <typename Dtype>
225
+ void Solver_addCallback (Solver<Dtype> * solver, bp::object on_start,
226
+ bp::object on_gradients_ready) {
227
+ solver->add_callback (new PythonCallback<Dtype>(on_start, on_gradients_ready));
228
+ }
229
+
209
230
BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS (SolveOverloads, Solve, 0 , 1 );
210
231
211
232
BOOST_PYTHON_MODULE (_caffe) {
@@ -284,6 +305,7 @@ BOOST_PYTHON_MODULE(_caffe) {
284
305
.add_property (" test_nets" , bp::make_function (&Solver<Dtype>::test_nets,
285
306
bp::return_internal_reference<>()))
286
307
.add_property (" iter" , &Solver<Dtype>::iter)
308
+ .def (" add_callback" , &Solver_addCallback<Dtype>)
287
309
.def (" solve" , static_cast <void (Solver<Dtype>::*)(const char *)>(
288
310
&Solver<Dtype>::Solve), SolveOverloads ())
289
311
.def (" step" , &Solver<Dtype>::Step)
0 commit comments