Skip to content

Commit 86fb358

Browse files
authored
Merge pull request #3325 from stan-dev/feature/reduce-mem-pathfinder
Feature/reduce mem pathfinder
2 parents 30098c0 + 52ab8a3 commit 86fb358

14 files changed

+1471
-605
lines changed
+167
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
#ifndef STAN_CALLBACKS_CONCURRENT_WRITER_HPP
2+
#define STAN_CALLBACKS_CONCURRENT_WRITER_HPP
3+
4+
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/meta.hpp>
6+
#include <tbb/concurrent_queue.h>
7+
#include <condition_variable>
8+
#include <functional>
9+
#include <string>
10+
#include <thread>
11+
#include <vector>
12+
13+
namespace stan::callbacks {
14+
#ifdef STAN_THREADS
15+
/**
16+
* Enables thread-safe writing of numeric values to a writer.
17+
* On construction, a thread is spawned to write to the writer.
18+
* This class uses an `std::thread` instead of a tbb task graph because
19+
* of deadlocking issues. A deadlock can happen in two major cases.
20+
* 1. If TBB gives all threads a task, and all threads hit an instance of max
21+
* capacity. TBB can choose to wait for a thread to finish instead of spinning
22+
* up the write thread. So to circumvent that issue, we use an std::thread.
23+
* 2. If the bounded queues are full but the consumer thread is not scheduled
24+
* because there are more busy threads than the number of threads available.
25+
* The producer threads are blocked because the queues are full. The consumer
26+
* thread is blocked because the producer thread is spinning. Then we have a
27+
* deadlock because the consumer thread is blocked because the producer threads
28+
* are blocked.
29+
* i.e. queue(full)->producer(blocked)->consumer(blocked)->producer(blocked)
30+
* To circumvent this issue, we check in the producer threads if
31+
* the queues are almost* full and if they are, we make a lock and wait for
32+
* the consumer thread to signal it's queues are not longer at capacity. This
33+
* frees a thread for the consumer thread to write to the writer. Once the
34+
* consumer thread is finished writing, it will notify all the producer threads
35+
* to continue sending data. The check for the queues being almost full is
36+
* done by checking if the size of the queue is greater than the max capacity
37+
* minus the number of threads times 2. This is a heuristic to make sure that,
38+
* if some threads slip past the check and write to the queue, the bounded queue
39+
* will still not be full.
40+
*
41+
* @tparam Writer A type that inherits from `writer`
42+
*/
43+
template <typename Writer>
44+
struct concurrent_writer {
45+
// A reference to the writer to write to
46+
std::reference_wrapper<Writer> writer;
47+
// Queue for Eigen vector messages
48+
tbb::concurrent_bounded_queue<Eigen::RowVectorXd> eigen_messages_{};
49+
// Block threads from writing to queues if the queues are full
50+
std::mutex block_{};
51+
// The writing thread
52+
std::thread thread_;
53+
// Condition variable to signal the writing thread to continue
54+
std::condition_variable cv;
55+
// Maximum number of threads that can be in use
56+
std::size_t max_threads{tbb::global_control::max_allowed_parallelism};
57+
// Max capacity of queue
58+
std::size_t max_capacity{1000 + max_threads};
59+
// Threshold where the writing threads will wait for the queues to empty
60+
std::size_t wait_threshold{max_capacity - max_threads - 1};
61+
// Flag to stop the writing thread once all queues are empty
62+
bool continue_writing_{true};
63+
64+
/**
65+
* Constructs a concurrent writer from a writer.
66+
* @note This will start a thread to write to the writer.
67+
* @param writer A writer to write to
68+
*/
69+
explicit concurrent_writer(Writer& writer) : writer(writer) {
70+
eigen_messages_.set_capacity(max_capacity);
71+
thread_ = std::thread([&]() {
72+
Eigen::RowVectorXd eigen;
73+
while (continue_writing_ || !eigen_messages_.empty()) {
74+
while (eigen_messages_.try_pop(eigen)) {
75+
writer(eigen);
76+
}
77+
if (this->empty()) {
78+
cv.notify_all();
79+
std::this_thread::yield();
80+
}
81+
}
82+
});
83+
}
84+
85+
/**
86+
* Checks if all queues are empty
87+
*/
88+
inline bool empty() { return eigen_messages_.empty(); }
89+
90+
/**
91+
* Check if any of the queues are at capacity
92+
*/
93+
inline bool hit_capacity() {
94+
return eigen_messages_.size() >= wait_threshold;
95+
}
96+
97+
/**
98+
* Place a value in a queue for writing.
99+
* @note If any of the queues are at capacity, the thread yields itself until
100+
* the the queues empty. In the case of spurious startups the wait just checks
101+
* that the queues are not full.
102+
* @tparam T An Eigen vector
103+
* @param t A value to put on a queue
104+
*/
105+
template <typename T>
106+
void operator()(T&& t) {
107+
bool pushed = false;
108+
if (this->hit_capacity()) {
109+
std::unique_lock lk(block_);
110+
cv.wait(lk, [this_ = this] { return !(this_->hit_capacity()); });
111+
}
112+
while (!pushed) {
113+
if constexpr (stan::is_std_vector<T>::value) {
114+
pushed = eigen_messages_.try_push(
115+
Eigen::RowVectorXd::Map(t.data(), t.size()));
116+
} else if constexpr (stan::is_eigen_vector<T>::value) {
117+
pushed = eigen_messages_.try_push(std::forward<T>(t));
118+
} else {
119+
constexpr bool is_numeric_std_vector
120+
= stan::is_std_vector<T>::value
121+
&& std::is_arithmetic_v<stan::value_type_t<T>>;
122+
static_assert(
123+
(!is_numeric_std_vector && !stan::is_eigen_vector<T>::value),
124+
"Unsupported type passed to concurrent_writer. This is an "
125+
"internal error. Please file an issue on the stan github "
126+
"repository with the error log from the compiler.\n"
127+
"https://github.com/stan-dev/stan/issues/new?template=Blank+issue");
128+
}
129+
if (!pushed) {
130+
std::this_thread::yield();
131+
}
132+
}
133+
}
134+
135+
/**
136+
* Waits till all writes are finished on the thread
137+
*/
138+
void wait() {
139+
continue_writing_ = false;
140+
if (thread_.joinable()) {
141+
// If any threads are waiting for the queues to empty, notify them
142+
cv.notify_all();
143+
thread_.join();
144+
}
145+
}
146+
/**
147+
* Destructor makes sure the thread is joined before destruction
148+
*/
149+
~concurrent_writer() { wait(); }
150+
};
151+
#else
152+
/**
153+
* When STAN_THREADS is not defined, the concurrent writer is just a wrapper
154+
*/
155+
template <typename Writer>
156+
struct concurrent_writer {
157+
std::reference_wrapper<Writer> writer;
158+
explicit concurrent_writer(Writer& writer) : writer(writer) {}
159+
template <typename T>
160+
void operator()(T&& t) {
161+
writer(std::forward<T>(t));
162+
}
163+
inline static constexpr void wait() {}
164+
};
165+
#endif
166+
} // namespace stan::callbacks
167+
#endif

src/stan/callbacks/stream_writer.hpp

+48
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,55 @@ class stream_writer : public writer {
6868
output_ << comment_prefix_ << message << std::endl;
6969
}
7070

71+
/**
72+
* Writes multiple rows and columns of values in csv format.
73+
*
74+
* Note: the precision of the output is determined by the settings
75+
* of the stream on construction.
76+
*
77+
* @param[in] values A matrix of values. The input is expected to have
78+
* parameters in the rows and samples in the columns. The matrix is then
79+
* transposed for the output.
80+
*/
81+
void operator()(const Eigen::Matrix<double, -1, -1>& values) {
82+
output_ << values.transpose().format(CommaInitFmt);
83+
}
84+
/**
85+
* Write a row of values in csv format.
86+
*
87+
* Note: the precision of the output is determined by the settings
88+
* of the stream on construction.
89+
*
90+
* @param[in] values A column vector of values.
91+
*/
92+
void operator()(const Eigen::Matrix<double, -1, 1>& values) {
93+
output_ << values.transpose().format(CommaInitFmt);
94+
}
95+
96+
/**
97+
* Write a row of values in csv format
98+
*
99+
* Note: the precision of the output is determined by the settings
100+
* of the stream on construction.
101+
*
102+
* @param[in] values A row vector of values.
103+
*/
104+
void operator()(const Eigen::Matrix<double, 1, -1>& values) {
105+
output_ << values.format(CommaInitFmt);
106+
}
107+
108+
/**
109+
* Checks if stream is valid.
110+
*/
111+
virtual bool is_valid() const noexcept { return output_.good(); }
112+
71113
private:
114+
/**
115+
* Comma formatter for writing Eigen matrices
116+
*/
117+
Eigen::IOFormat CommaInitFmt{
118+
Eigen::StreamPrecision, Eigen::DontAlignCols, ", ", "", "", "\n", "", ""};
119+
72120
/**
73121
* Output stream
74122
*/

src/stan/callbacks/tee_writer.hpp

+59-38
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,86 @@
11
#ifndef STAN_CALLBACKS_TEE_WRITER_HPP
22
#define STAN_CALLBACKS_TEE_WRITER_HPP
33

4+
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/functor/apply.hpp>
6+
#include <stan/math/prim/functor/for_each.hpp>
47
#include <stan/callbacks/writer.hpp>
58
#include <ostream>
69
#include <vector>
710
#include <string>
811

9-
namespace stan {
10-
namespace callbacks {
12+
namespace stan::callbacks {
1113

1214
/**
13-
* <code>tee_writer</code> is an implementation that writes to
14-
* two writers.
15-
*
16-
* For any call to this writer, it will tee the call to both writers
17-
* provided in the constructor.
15+
* `tee_writer` is an layer on top of a writer class that
16+
* allows for multiple output streams to be written to.
17+
* @tparam Writers A parameter pack of types that inherit from `writer`
1818
*/
19-
class tee_writer final : public writer {
19+
template <typename... Writers>
20+
class tee_writer {
2021
public:
2122
/**
22-
* Constructor accepting two writers.
23-
*
24-
* @param[in, out] writer1 first writer
25-
* @param[in, out] writer2 second writer
23+
* Constructs a multi stream writer from a parameter pack of writers.
24+
* @param[in, out] args A parameter pack of writers
2625
*/
27-
tee_writer(writer& writer1, writer& writer2)
28-
: writer1_(writer1), writer2_(writer2) {}
26+
explicit tee_writer(Writers&... args) : output_(args...) {}
2927

30-
virtual ~tee_writer() {}
28+
tee_writer() = default;
3129

32-
void operator()(const std::vector<std::string>& names) {
33-
writer1_(names);
34-
writer2_(names);
35-
}
36-
37-
void operator()(const std::vector<double>& state) {
38-
writer1_(state);
39-
writer2_(state);
30+
/**
31+
* @tparam T Any type accepted by a `writer` overload
32+
* @param[in] x A value to write to the output streams
33+
*/
34+
template <typename T>
35+
void operator()(T&& x) {
36+
stan::math::for_each([&](auto&& output) { output(x); }, output_);
4037
}
41-
38+
/**
39+
* Write a comment prefix to each writer
40+
*/
4241
void operator()() {
43-
writer1_();
44-
writer2_();
45-
}
46-
47-
void operator()(const std::string& message) {
48-
writer1_(message);
49-
writer2_(message);
42+
stan::math::for_each([](auto&& output) { output(); }, output_);
5043
}
5144

52-
private:
5345
/**
54-
* The first writer
46+
* Checks if all underlying writers are nonnull.
5547
*/
56-
writer& writer1_;
48+
inline bool is_valid() const noexcept {
49+
return stan::math::apply(
50+
[](auto&&... output) { return (output.is_valid() && ...); }, output_);
51+
}
52+
5753
/**
58-
* The second writer
54+
* Get the tuple of underlying streams
5955
*/
60-
writer& writer2_;
56+
inline auto& get_stream() noexcept { return output_; }
57+
58+
private:
59+
// Output streams
60+
std::tuple<std::reference_wrapper<Writers>...> output_;
6161
};
6262

63-
} // namespace callbacks
64-
} // namespace stan
63+
namespace internal {
64+
template <typename T>
65+
struct is_tee_writer : std::false_type {};
66+
67+
template <typename... Types>
68+
struct is_tee_writer<tee_writer<Types...>> : std::true_type {};
69+
} // namespace internal
70+
71+
/**
72+
* Type trait that checks if a type is a `tee_writer`
73+
* @tparam T A type to check
74+
*/
75+
template <typename T>
76+
struct is_tee_writer : internal::is_tee_writer<std::decay_t<T>> {};
77+
78+
/**
79+
* Helper variable template to check if a type is a `tee_writer`
80+
*/
81+
template <typename T>
82+
inline constexpr bool is_tee_writer_v = is_tee_writer<T>::value;
83+
84+
} // namespace stan::callbacks
85+
6586
#endif

0 commit comments

Comments
 (0)