@@ -205,6 +205,27 @@ bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) {
205
205
return bp::object ();
206
206
}
207
207
208
+ template <typename Dtype>
209
+ class PythonCallback : public Solver <Dtype>::Callback {
210
+ protected:
211
+ bp::object on_start_, on_gradients_ready_;
212
+
213
+ public:
214
+ PythonCallback (bp::object on_start, bp::object on_gradients_ready)
215
+ : on_start_(on_start), on_gradients_ready_(on_gradients_ready) { }
216
+ virtual void on_gradients_ready () {
217
+ on_gradients_ready_ ();
218
+ }
219
+ virtual void on_start () {
220
+ on_start_ ();
221
+ }
222
+ };
223
+ template <typename Dtype>
224
+ void Solver_addCallback (Solver<Dtype> * solver, bp::object on_start,
225
+ bp::object on_gradients_ready) {
226
+ solver->add_callback (new PythonCallback<Dtype>(on_start, on_gradients_ready));
227
+ }
228
+
208
229
BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS (SolveOverloads, Solve, 0 , 1 );
209
230
210
231
BOOST_PYTHON_MODULE (_caffe) {
@@ -283,6 +304,7 @@ BOOST_PYTHON_MODULE(_caffe) {
283
304
.add_property (" test_nets" , bp::make_function (&Solver<Dtype>::test_nets,
284
305
bp::return_internal_reference<>()))
285
306
.add_property (" iter" , &Solver<Dtype>::iter)
307
+ .def (" add_callback" , &Solver_addCallback<Dtype>)
286
308
.def (" solve" , static_cast <void (Solver<Dtype>::*)(const char *)>(
287
309
&Solver<Dtype>::Solve), SolveOverloads ())
288
310
.def (" step" , &Solver<Dtype>::Step)
0 commit comments