Skip to content

Commit 53f37c0

Browse files
committed
Exposing solver callbacks to python
1 parent ca4e342 commit 53f37c0

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

python/caffe/_caffe.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,27 @@ bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) {
206206
return bp::object();
207207
}
208208

209+
template<typename Dtype>
210+
class PythonCallback: public Solver<Dtype>::Callback {
211+
protected:
212+
bp::object on_start_, on_gradients_ready_;
213+
214+
public:
215+
PythonCallback(bp::object on_start, bp::object on_gradients_ready)
216+
: on_start_(on_start), on_gradients_ready_(on_gradients_ready) { }
217+
virtual void on_gradients_ready() {
218+
on_gradients_ready_();
219+
}
220+
virtual void on_start() {
221+
on_start_();
222+
}
223+
};
224+
template<typename Dtype>
225+
void Solver_addCallback(Solver<Dtype> * solver, bp::object on_start,
226+
bp::object on_gradients_ready) {
227+
solver->add_callback(new PythonCallback<Dtype>(on_start, on_gradients_ready));
228+
}
229+
209230
BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);
210231

211232
BOOST_PYTHON_MODULE(_caffe) {
@@ -284,6 +305,7 @@ BOOST_PYTHON_MODULE(_caffe) {
284305
.add_property("test_nets", bp::make_function(&Solver<Dtype>::test_nets,
285306
bp::return_internal_reference<>()))
286307
.add_property("iter", &Solver<Dtype>::iter)
308+
.def("add_callback", &Solver_addCallback<Dtype>)
287309
.def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
288310
&Solver<Dtype>::Solve), SolveOverloads())
289311
.def("step", &Solver<Dtype>::Step)

0 commit comments

Comments
 (0)