|
| 1 | +#ifndef STAN_CALLBACKS_CONCURRENT_WRITER_HPP |
| 2 | +#define STAN_CALLBACKS_CONCURRENT_WRITER_HPP |
| 3 | + |
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/meta.hpp>
#include <tbb/concurrent_queue.h>
#include <tbb/global_control.h>
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <mutex>
#include <string>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>
| 12 | + |
| 13 | +namespace stan::callbacks { |
| 14 | +#ifdef STAN_THREADS |
| 15 | +/** |
| 16 | + * Enables thread-safe writing of numeric values to a writer. |
| 17 | + * On construction, a thread is spawned to write to the writer. |
| 18 | + * This class uses an `std::thread` instead of a tbb task graph because |
| 19 | + * of deadlocking issues. A deadlock can happen in two major cases. |
| 20 | + * 1. If TBB gives all threads a task, and all threads hit an instance of max |
| 21 | + * capacity. TBB can choose to wait for a thread to finish instead of spinning |
| 22 | + * up the write thread. So to circumvent that issue, we use an std::thread. |
| 23 | + * 2. If the bounded queues are full but the consumer thread is not scheduled |
| 24 | + * because there are more busy threads than the number of threads available. |
| 25 | + * The producer threads are blocked because the queues are full. The consumer |
| 26 | + * thread is blocked because the producer thread is spinning. Then we have a |
| 27 | + * deadlock because the consumer thread is blocked because the producer threads |
| 28 | + * are blocked. |
| 29 | + * i.e. queue(full)->producer(blocked)->consumer(blocked)->producer(blocked) |
| 30 | + * To circumvent this issue, we check in the producer threads if |
| 31 | + * the queues are almost* full and if they are, we make a lock and wait for |
| 32 | + * the consumer thread to signal it's queues are not longer at capacity. This |
| 33 | + * frees a thread for the consumer thread to write to the writer. Once the |
| 34 | + * consumer thread is finished writing, it will notify all the producer threads |
| 35 | + * to continue sending data. The check for the queues being almost full is |
| 36 | + * done by checking if the size of the queue is greater than the max capacity |
| 37 | + * minus the number of threads times 2. This is a heuristic to make sure that, |
| 38 | + * if some threads slip past the check and write to the queue, the bounded queue |
| 39 | + * will still not be full. |
| 40 | + * |
| 41 | + * @tparam Writer A type that inherits from `writer` |
| 42 | + */ |
| 43 | +template <typename Writer> |
| 44 | +struct concurrent_writer { |
| 45 | + // A reference to the writer to write to |
| 46 | + std::reference_wrapper<Writer> writer; |
| 47 | + // Queue for Eigen vector messages |
| 48 | + tbb::concurrent_bounded_queue<Eigen::RowVectorXd> eigen_messages_{}; |
| 49 | + // Block threads from writing to queues if the queues are full |
| 50 | + std::mutex block_{}; |
| 51 | + // The writing thread |
| 52 | + std::thread thread_; |
| 53 | + // Condition variable to signal the writing thread to continue |
| 54 | + std::condition_variable cv; |
| 55 | + // Maximum number of threads that can be in use |
| 56 | + std::size_t max_threads{tbb::global_control::max_allowed_parallelism}; |
| 57 | + // Max capacity of queue |
| 58 | + std::size_t max_capacity{1000 + max_threads}; |
| 59 | + // Threshold where the writing threads will wait for the queues to empty |
| 60 | + std::size_t wait_threshold{max_capacity - max_threads - 1}; |
| 61 | + // Flag to stop the writing thread once all queues are empty |
| 62 | + bool continue_writing_{true}; |
| 63 | + |
| 64 | + /** |
| 65 | + * Constructs a concurrent writer from a writer. |
| 66 | + * @note This will start a thread to write to the writer. |
| 67 | + * @param writer A writer to write to |
| 68 | + */ |
| 69 | + explicit concurrent_writer(Writer& writer) : writer(writer) { |
| 70 | + eigen_messages_.set_capacity(max_capacity); |
| 71 | + thread_ = std::thread([&]() { |
| 72 | + Eigen::RowVectorXd eigen; |
| 73 | + while (continue_writing_ || !eigen_messages_.empty()) { |
| 74 | + while (eigen_messages_.try_pop(eigen)) { |
| 75 | + writer(eigen); |
| 76 | + } |
| 77 | + if (this->empty()) { |
| 78 | + cv.notify_all(); |
| 79 | + std::this_thread::yield(); |
| 80 | + } |
| 81 | + } |
| 82 | + }); |
| 83 | + } |
| 84 | + |
| 85 | + /** |
| 86 | + * Checks if all queues are empty |
| 87 | + */ |
| 88 | + inline bool empty() { return eigen_messages_.empty(); } |
| 89 | + |
| 90 | + /** |
| 91 | + * Check if any of the queues are at capacity |
| 92 | + */ |
| 93 | + inline bool hit_capacity() { |
| 94 | + return eigen_messages_.size() >= wait_threshold; |
| 95 | + } |
| 96 | + |
| 97 | + /** |
| 98 | + * Place a value in a queue for writing. |
| 99 | + * @note If any of the queues are at capacity, the thread yields itself until |
| 100 | + * the the queues empty. In the case of spurious startups the wait just checks |
| 101 | + * that the queues are not full. |
| 102 | + * @tparam T An Eigen vector |
| 103 | + * @param t A value to put on a queue |
| 104 | + */ |
| 105 | + template <typename T> |
| 106 | + void operator()(T&& t) { |
| 107 | + bool pushed = false; |
| 108 | + if (this->hit_capacity()) { |
| 109 | + std::unique_lock lk(block_); |
| 110 | + cv.wait(lk, [this_ = this] { return !(this_->hit_capacity()); }); |
| 111 | + } |
| 112 | + while (!pushed) { |
| 113 | + if constexpr (stan::is_std_vector<T>::value) { |
| 114 | + pushed = eigen_messages_.try_push( |
| 115 | + Eigen::RowVectorXd::Map(t.data(), t.size())); |
| 116 | + } else if constexpr (stan::is_eigen_vector<T>::value) { |
| 117 | + pushed = eigen_messages_.try_push(std::forward<T>(t)); |
| 118 | + } else { |
| 119 | + constexpr bool is_numeric_std_vector |
| 120 | + = stan::is_std_vector<T>::value |
| 121 | + && std::is_arithmetic_v<stan::value_type_t<T>>; |
| 122 | + static_assert( |
| 123 | + (!is_numeric_std_vector && !stan::is_eigen_vector<T>::value), |
| 124 | + "Unsupported type passed to concurrent_writer. This is an " |
| 125 | + "internal error. Please file an issue on the stan github " |
| 126 | + "repository with the error log from the compiler.\n" |
| 127 | + "https://github.com/stan-dev/stan/issues/new?template=Blank+issue"); |
| 128 | + } |
| 129 | + if (!pushed) { |
| 130 | + std::this_thread::yield(); |
| 131 | + } |
| 132 | + } |
| 133 | + } |
| 134 | + |
| 135 | + /** |
| 136 | + * Waits till all writes are finished on the thread |
| 137 | + */ |
| 138 | + void wait() { |
| 139 | + continue_writing_ = false; |
| 140 | + if (thread_.joinable()) { |
| 141 | + // If any threads are waiting for the queues to empty, notify them |
| 142 | + cv.notify_all(); |
| 143 | + thread_.join(); |
| 144 | + } |
| 145 | + } |
| 146 | + /** |
| 147 | + * Destructor makes sure the thread is joined before destruction |
| 148 | + */ |
| 149 | + ~concurrent_writer() { wait(); } |
| 150 | +}; |
| 151 | +#else |
/**
 * Fallback used when STAN_THREADS is not defined: a thin wrapper that hands
 * every value straight to the wrapped writer on the calling thread.
 *
 * @tparam Writer The type of the callable being wrapped
 */
template <typename Writer>
struct concurrent_writer {
  // The wrapped writer
  std::reference_wrapper<Writer> writer;

  /**
   * Wraps a writer without taking ownership of it.
   * @param writer A writer to write to
   */
  explicit concurrent_writer(Writer& writer) : writer(writer) {}

  /**
   * Forwards a value to the wrapped writer.
   * @tparam Value The type of the value being written
   * @param value The value passed on to the writer
   */
  template <typename Value>
  void operator()(Value&& value) {
    Writer& target = writer.get();
    target(std::forward<Value>(value));
  }

  /**
   * No-op: the value has already been handed to the writer when
   * `operator()` returns.
   */
  static constexpr void wait() {}
};
| 165 | +#endif |
| 166 | +} // namespace stan::callbacks |
| 167 | +#endif |
0 commit comments