diff --git a/Lib/test/test__xxsubinterpreters.py b/Lib/test/test__xxsubinterpreters.py index 30f8f98acc9dd3..0bec9f9e3b3050 100644 --- a/Lib/test/test__xxsubinterpreters.py +++ b/Lib/test/test__xxsubinterpreters.py @@ -1337,6 +1337,97 @@ def test_run_string_arg_resolved(self): self.assertEqual(obj, b'spam') self.assertEqual(out.strip(), 'send') + def test_channel_send_wait_same_interpreter(self): + cid = interpreters.channel_create() + + received = interpreters.channel_send_wait(cid, b"send", timeout=0) + self.assertFalse(received) + + obj = interpreters.channel_recv(cid) + self.assertEqual(obj, b"send") + + def test_channel_send_wait_different_interpreters(self): + cid = interpreters.channel_create() + interp = interpreters.create() + _run_output(interp, dedent(f""" + import _xxsubinterpreters as _interpreters + import time + import math + + start = time.time() + rc = _interpreters.channel_send_wait({cid}, b"send", timeout=1) + end = time.time() + + assert not rc + assert math.floor(end-start) == 1 + """)) + + obj = interpreters.channel_recv(cid) + self.assertEqual(obj, b"send") + + def test_channel_send_wait_different_threads_and_interpreters(self): + cid = interpreters.channel_create() + interp = interpreters.create() + + thread_exc = None + def run(): + try: + out = _run_output(interp, dedent(f""" + import _xxsubinterpreters as _interpreters + import time + + rc = _interpreters.channel_send_wait({cid}, b"send") + assert rc + """)) + except Exception as e: + nonlocal thread_exc + thread_exc = e + t = threading.Thread(target=run) + t.start() + time.sleep(0.5) + + obj = interpreters.channel_recv(cid) + self.assertEqual(obj, b"send") + t.join() + assert thread_exc is None, f"{thread_exc}" + + def test_channel_send_wait_no_timeout(self): + cid = interpreters.channel_create() + interp = interpreters.create() + + thread_exc = None + def run(): + try: + out = _run_output(interp, dedent(f""" + import _xxsubinterpreters as _interpreters + import time + + rc = _interpreters.channel_send_wait({cid}, b"send", timeout=10) + assert rc + """)) + except Exception as e: + nonlocal thread_exc + thread_exc = e + t = threading.Thread(target=run) + t.start() + time.sleep(0.5) + + obj = interpreters.channel_recv(cid) + self.assertEqual(obj, b"send") + t.join() + assert thread_exc is None, f"{thread_exc}" + + def test_invalid_channel_send_wait(self): + does_not_exist_cid = 1000 + closed_cid = interpreters.channel_create() + interpreters.channel_close(closed_cid) + + with self.assertRaises(interpreters.ChannelNotFoundError): + interpreters.channel_send_wait(does_not_exist_cid, b"error") + + with self.assertRaises(interpreters.ChannelClosedError): + interpreters.channel_send_wait(closed_cid, b"error") + # close def test_close_single_user(self): @@ -1519,6 +1610,30 @@ def test_close_used_multiple_times_by_single_user(self): with self.assertRaises(interpreters.ChannelClosedError): interpreters.channel_recv(cid) + def test_close_while_sender_waiting(self): + cid = interpreters.channel_create() + interp = interpreters.create() + + thread_exc = None + def run(): + try: + out = _run_output(interp, dedent(f""" + import _xxsubinterpreters as _interpreters + + rc = _interpreters.channel_send_wait({cid}, b"send") + assert not rc + """)) + except Exception as e: + nonlocal thread_exc + thread_exc = e + + t = threading.Thread(target=run) + t.start() + time.sleep(0.1) + interpreters.channel_close(cid, force=True) + t.join() + assert thread_exc is None, f"{thread_exc}" + class ChannelReleaseTests(TestBase): diff --git a/Modules/_xxsubinterpretersmodule.c b/Modules/_xxsubinterpretersmodule.c index fa35e14c554012..5153cf27ea25d9 100644 --- a/Modules/_xxsubinterpretersmodule.c +++ b/Modules/_xxsubinterpretersmodule.c @@ -347,12 +347,165 @@ channel_exceptions_init(PyObject *ns) return 0; } +typedef struct _channelitem_wait_lock { + PyThread_type_lock lock; + PyThread_type_lock mutex; + int lock_is_allocated; + int is_sent; + int is_recv; +} _channelitem_wait_lock; + +/** + * Allocate a new wait lock for a channel item. + * + * This function allocates memory that must be freed by calling + * _channelitem_wait_lock_recv and _channelitem_wait_lock_wait. + */ +static void * +_channelitem_wait_lock_new(void) +{ + _channelitem_wait_lock *wait_lock = calloc(1, sizeof(*wait_lock)); + wait_lock->lock = PyThread_allocate_lock(); + wait_lock->mutex = PyThread_allocate_lock(); + wait_lock->lock_is_allocated = 1; + wait_lock->is_sent = 0; + wait_lock->is_recv = 0; + + return wait_lock; +} + +/** + * To update the wait lock's state when the channel item is sent by an + * interpreter. + * This function must be called by the interpeter sending the message. + * + * NOTE: This must be called before _channelitem_wait_lock_wait or + * _channelitem_wait_lock_recv can be called. + */ +static void +_channelitem_wait_lock_sent(_channelitem_wait_lock *wait_lock) +{ + assert(wait_lock != NULL); + assert(wait_lock->is_sent == 0); + PyThread_acquire_lock(wait_lock->lock, WAIT_LOCK); + wait_lock->is_sent = 1; +} + +/** + * To update the wait lock's state when the channel item is removed from the + * channel's queue. + * This function must be called by the interpeter receiving the message. + * + * If received is: + * - 1: then the channel item was received by an interpreter. + * - 0: then the channel item was removed from the queue, or the queue was + * closed, before it could be received. + * + * If this function is called after _channelitem_wait_lock_wait has timed out + * then it is responsible for freeing the wait lock. + * + * NOTE: This must be called after _channelitem_wait_lock_sent has been called. + */ +static void +_channelitem_wait_lock_recv(_channelitem_wait_lock *wait_lock, int received) +{ + assert(wait_lock != NULL); + assert(wait_lock->is_sent == 1); + PyThread_type_lock mutex = wait_lock->mutex; + + PyThread_acquire_lock(mutex, WAIT_LOCK); + /* + * If the lock is still allocated it means that _channelitem_wait_lock_wait + * has not timed out. + */ + if (wait_lock->lock_is_allocated == 1) { + PyThread_release_lock(wait_lock->lock); + wait_lock->is_recv = received; + + } else { + free(wait_lock); + wait_lock = NULL; + } + PyThread_release_lock(mutex); + + if (wait_lock == NULL) { + PyThread_free_lock(mutex); + } +} + +/** + * Will wait until the channel item with the wait_lock has been revieved or + * for timeout microseconds. + * This function must be called by the interpeter sending the message. + * + * If timeout is: + * - < 0: then wait until the channelitem has been received. + * - >= 0: then wait until the channelitem has been received or timeout + * microseconds, whichever comes first. + * + * If channel item is received before this function has timed out then it is + * responsible for freeing the wait lock. + * + * This function returns: + * - PY_LOCK_ACQUIRED: The channel item was received by an iterpreter before + * this function timed out. + * - PY_LOCK_FAILURE: The channel item was not received by an interpreter + * before this function timed out. + * + * NOTE: This must be called after _channelitem_wait_lock_sent has been called. + */ +static PyLockStatus +_channelitem_wait_lock_wait(_channelitem_wait_lock *wait_lock, int timeout) +{ + assert(wait_lock != NULL); + assert(wait_lock->is_sent == 1); + PyLockStatus lock_rc; + Py_BEGIN_ALLOW_THREADS + lock_rc = PyThread_acquire_lock_timed(wait_lock->lock, timeout, 0); + Py_END_ALLOW_THREADS + + if (lock_rc == PY_LOCK_INTR) { + lock_rc = PY_LOCK_FAILURE; + } + + PyThread_type_lock mutex = wait_lock->mutex; + PyThread_acquire_lock(mutex, WAIT_LOCK); + PyThread_free_lock(wait_lock->lock); + wait_lock->lock_is_allocated = 0; + + /* + * If lock_rc is PY_LOCK_ACQUIRED then _channelitem_wait_lock_recv must + * have already been called. + */ + if (lock_rc == PY_LOCK_ACQUIRED) { + if (wait_lock->is_recv == 0) { + /* + * The channel item was removed from the queue but not by an + * interpreter. + */ + lock_rc = PY_LOCK_FAILURE; + } + + free(wait_lock); + wait_lock = NULL; + } + PyThread_release_lock(mutex); + + if (wait_lock == NULL) { + PyThread_free_lock(mutex); + } + + return lock_rc; +} + /* the channel queue */ struct _channelitem; typedef struct _channelitem { _PyCrossInterpreterData *data; + /* The lock is owned by the sender. */ + void *wait_lock; struct _channelitem *next; } _channelitem; @@ -393,6 +546,9 @@ _channelitem_free_all(_channelitem *item) while (item != NULL) { _channelitem *last = item; item = item->next; + if (last->wait_lock != NULL) { + _channelitem_wait_lock_recv(last->wait_lock, 0); + } _channelitem_free(last); } } @@ -402,6 +558,9 @@ _channelitem_popped(_channelitem *item) { _PyCrossInterpreterData *data = item->data; item->data = NULL; + if (item->wait_lock != NULL) { + _channelitem_wait_lock_recv(item->wait_lock, 1); + } _channelitem_free(item); return data; } @@ -443,13 +602,18 @@ _channelqueue_free(_channelqueue *queue) } static int -_channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data) +_channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data, + void *wait_lock) { _channelitem *item = _channelitem_new(); if (item == NULL) { return -1; } item->data = data; + item->wait_lock = wait_lock; + if (wait_lock != NULL) { + _channelitem_wait_lock_sent(item->wait_lock); + } queue->count += 1; if (queue->first == NULL) { @@ -761,7 +925,7 @@ _channel_free(_PyChannelState *chan) static int _channel_add(_PyChannelState *chan, int64_t interp, - _PyCrossInterpreterData *data) + _PyCrossInterpreterData *data, void *wait_lock) { int res = -1; PyThread_acquire_lock(chan->mutex, WAIT_LOCK); @@ -774,7 +938,7 @@ _channel_add(_PyChannelState *chan, int64_t interp, goto done; } - if (_channelqueue_put(chan->queue, data) != 0) { + if (_channelqueue_put(chan->queue, data, wait_lock) != 0) { goto done; } @@ -1285,7 +1449,8 @@ _channel_destroy(_channels *channels, int64_t id) } static int -_channel_send(_channels *channels, int64_t id, PyObject *obj) +_channel_send(_channels *channels, int64_t id, PyObject *obj, + void *wait_lock) { PyInterpreterState *interp = _get_current(); if (interp == NULL) { @@ -1319,7 +1484,7 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj) } // Add the data to the channel. - int res = _channel_add(chan, PyInterpreterState_GetID(interp), data); + int res = _channel_add(chan, PyInterpreterState_GetID(interp), data, wait_lock); PyThread_release_lock(mutex); if (res != 0) { _PyCrossInterpreterData_Release(data); @@ -2337,7 +2502,7 @@ channel_send(PyObject *self, PyObject *args, PyObject *kwds) return NULL; } - if (_channel_send(&_globals.channels, cid, obj) != 0) { + if (_channel_send(&_globals.channels, cid, obj, NULL) != 0) { return NULL; } Py_RETURN_NONE; @@ -2366,6 +2531,50 @@ PyDoc_STRVAR(channel_recv_doc, \n\ Return a new object from the data at the from of the channel's queue."); +static PyObject * +channel_send_wait(PyObject *self, PyObject *args, PyObject *kwds) +{ + static char *kwlist[] = {"cid", "obj", "timeout", NULL}; + int64_t cid; + PyObject *obj; + int64_t timeout = -1; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O|l:channel_send_wait", + kwlist, channel_id_converter, + &cid, &obj, &timeout)) { + return NULL; + } + + void *wait_lock = _channelitem_wait_lock_new(); + if (_channel_send(&_globals.channels, cid, obj, wait_lock) != 0) { + return NULL; + } + + long long microseconds; + if (timeout >= 0) { + microseconds = timeout * 1000000; + } + else { + microseconds = -1; + } + PyLockStatus lock_rc = _channelitem_wait_lock_wait(wait_lock, microseconds); + + if (lock_rc == PY_LOCK_ACQUIRED) { + Py_RETURN_TRUE; + } + else { + Py_RETURN_FALSE; + } +} + +PyDoc_STRVAR(channel_send_wait_doc, + "channel_send_wait(cid, obj, timeout)\n\ +\n\ +Add the object's data to the channel's queue and wait until it's removed.\n\ +\n\ +If the timeout is set as:\n\ + * < 0 then wait forever until the object is removed from the queue.\n\ + * >= 0 then wait until the object is removed or for timeout seconds."); + static PyObject * channel_close(PyObject *self, PyObject *args, PyObject *kwds) { @@ -2481,6 +2690,8 @@ static PyMethodDef module_functions[] = { METH_NOARGS, channel_list_all_doc}, {"channel_send", (PyCFunction)(void(*)(void))channel_send, METH_VARARGS | METH_KEYWORDS, channel_send_doc}, + {"channel_send_wait", (PyCFunction)(void (*)(void))channel_send_wait, + METH_VARARGS | METH_KEYWORDS, channel_send_wait_doc}, {"channel_recv", (PyCFunction)(void(*)(void))channel_recv, METH_VARARGS | METH_KEYWORDS, channel_recv_doc}, {"channel_close", (PyCFunction)(void(*)(void))channel_close,