Skip to content

Commit

Permalink
Merge pull request #3045 from reyoung/feature/make_golang_client_lazy…
Browse files Browse the repository at this point in the history
…_load

Make C lib in `paddle.v2.master.client` lazy load
  • Loading branch information
reyoung authored Jul 25, 2017
2 parents 1ab2e44 + 61cd828 commit 8273dd7
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions python/paddle/v2/master/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import ctypes
import os

path = os.path.join(os.path.dirname(__file__), "libpaddle_master.so")
lib = ctypes.cdll.LoadLibrary(path)
__lib__ = None


def get_c_lib():
global __lib__
if __lib__ is None:
path = os.path.join(os.path.dirname(__file__), "libpaddle_master.so")
__lib__ = ctypes.cdll.LoadLibrary(path)
return __lib__


class client(object):
Expand All @@ -11,8 +18,8 @@ class client(object):
"""

def __init__(self, etcd_endpoints, timeout_sec, buf_size=0):
self.c = lib.paddle_new_etcd_master_client(etcd_endpoints, timeout_sec,
buf_size)
self.c = get_c_lib().paddle_new_etcd_master_client(
etcd_endpoints, timeout_sec, buf_size)

def request_save_model(self, trainer_id, block_ms):
"""request to save model
Expand All @@ -32,10 +39,11 @@ def request_save_model(self, trainer_id, block_ms):
saving the model, -1 if error happened.
"""
return lib.paddle_request_save_model(self.c, trainer_id, block_ms)
return get_c_lib().paddle_request_save_model(self.c, trainer_id,
block_ms)

def release(self):
lib.paddle_release_master_client(self.c)
get_c_lib().paddle_release_master_client(self.c)
self.c = None

def set_dataset(self, paths):
Expand All @@ -45,7 +53,7 @@ def set_dataset(self, paths):
for idx, path in enumerate(paths):
c_ptr = ctypes.c_char_p(path)
holder[idx] = c_ptr
lib.paddle_set_dataset(self.c, holder, len(paths))
get_c_lib().paddle_set_dataset(self.c, holder, len(paths))

def next_record(self):
"""gets next record for training
Expand All @@ -56,7 +64,7 @@ def next_record(self):
"""
p = ctypes.c_char_p()
ret = ctypes.pointer(p)
size = lib.paddle_next_record(self.c, ret)
size = get_c_lib().paddle_next_record(self.c, ret)
if size < 0:
# Error
return None, size
Expand All @@ -67,5 +75,5 @@ def next_record(self):

record = ret.contents.value[:size]
# Memory created from C should be freed.
lib.mem_free(ret.contents)
get_c_lib().mem_free(ret.contents)
return record, 0

0 comments on commit 8273dd7

Please sign in to comment.