diff --git a/lib/stan_math b/lib/stan_math index afea210fd8..ac977e4ed7 160000 --- a/lib/stan_math +++ b/lib/stan_math @@ -1 +1 @@ -Subproject commit afea210fd8d68edcf0679a7618510e7dfe767103 +Subproject commit ac977e4ed79bd593440b9afbc25e1bba905e9276 diff --git a/src/stan/callbacks/dispatcher.hpp b/src/stan/callbacks/dispatcher.hpp new file mode 100644 index 0000000000..f8b3f116ad --- /dev/null +++ b/src/stan/callbacks/dispatcher.hpp @@ -0,0 +1,175 @@ +#ifndef STAN_CALLBACKS_DISPATCHER_HPP +#define STAN_CALLBACKS_DISPATCHER_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace callbacks { + +enum class InfoType { + CONFIG, // series of string messages + SAMPLE, // draw from posterior + SAMPLE_RAW, // draw from posterior + METRIC, // struct with kv pairs 'metric_type', 'stepsize', 'inv_metric' + ALGORITHM_STATE, // sampler state for returned draw +}; + +struct InfoTypeHash { + std::size_t operator()(const InfoType& type) const { + return std::hash()(static_cast(type)); + } +}; + +// Base type for type erasure. +class Channel { + public: + virtual ~Channel() = default; +}; + +// Adapter for plain writers. +class WriterChannel : public Channel { + public: + explicit WriterChannel(stan::callbacks::writer* w) : writer_(w) { + if (!w) { + throw std::runtime_error("config error, null writer"); + } + } + + // Handle all types that writer supports via operator() + void dispatch() { (*writer_)(); } + void dispatch(const std::string& value) { (*writer_)(value); } + void dispatch(const std::vector& value) { (*writer_)(value); } + void dispatch(const std::vector& value) { (*writer_)(value); } + + // Handle any Eigen Matrix type + template + void dispatch(const Eigen::Matrix& value) { + (*writer_)(value); + } + + // No key-value support for plain writers + template + void dispatch(const std::string&, const T&) {} + + private: + stan::callbacks::writer* writer_; +}; + +// Adapter for structured writers. +class StructuredWriterChannel : public Channel { + public: + explicit StructuredWriterChannel(stan::callbacks::structured_writer* sw) + : writer_(sw) { + if (!sw) + throw std::runtime_error("config error, null writer"); + } + // Forward all key-value calls directly to the writer + void dispatch(const std::string& key) { writer_->write(key); } + // Perfect forwarding for any key-value pair + template + void dispatch(const std::string& key, T&& value) { + writer_->write(key, std::forward(value)); + } + void begin_record() { writer_->begin_record(); } + void begin_record(const std::string& key) { writer_->begin_record(key); } + void end_record() { writer_->end_record(); } + + private: + stan::callbacks::structured_writer* writer_; +}; + +// dispatcher class +class dispatcher { + public: + dispatcher() = default; + ~dispatcher() = default; + + void register_channel(InfoType type, std::unique_ptr channel) { + channels_[type] = std::move(channel); + } + + // no-arg call to writer operator () + void dispatch(InfoType type) { + auto it = channels_.find(type); + if (it == channels_.end()) + return; // silently do nothing + if (auto* wc = dynamic_cast(it->second.get())) + wc->dispatch(); + } + + // single non-string argument - call writer operator () + template < + typename T, + typename = std::enable_if_t< + std::is_same_v< + std::decay_t, + std:: + string> || std::is_same_v, std::vector> || std::is_same_v, std::vector>>> // NOLINT + void dispatch(InfoType type, T&& value) { + if (auto* wc = find_channel(type)) + wc->dispatch(std::forward(value)); + } + + // Eigen matrix types + template + void dispatch(InfoType type, const Eigen::Matrix& value) { + if (auto* wc = find_channel(type)) + wc->dispatch(value); + } + + // Key with no value (null) + void dispatch(InfoType type, const std::string& key) { + if (auto* sw = find_channel(type)) + sw->dispatch(key); + } + + // Key-value pairs (forward to structured writers) + template + void dispatch(InfoType type, const std::string& key, T&& value) { + if (auto* sw = find_channel(type)) + sw->dispatch(key, std::forward(value)); + } + + // Record operations + void begin_record(InfoType type) { + if (auto* sw = find_channel(type)) + sw->begin_record(); + } + + void begin_record(InfoType type, const std::string& key) { + if (auto* sw = find_channel(type)) + sw->begin_record(key); + } + + void end_record(InfoType type) { + if (auto* sw = find_channel(type)) + sw->end_record(); + } + + private: + // Helper to find and cast a channel of specific type + template + ChannelType* find_channel(InfoType type) { + auto it = channels_.find(type); + if (it == channels_.end()) + return nullptr; + return dynamic_cast(it->second.get()); + } + + std::unordered_map, InfoTypeHash> + channels_; +}; + +} // namespace callbacks +} // namespace stan + +#endif // STAN_CALLBACKS_DISPATCHER_HPP diff --git a/src/stan/callbacks/in_memory_writer.hpp b/src/stan/callbacks/in_memory_writer.hpp new file mode 100644 index 0000000000..25d38ec905 --- /dev/null +++ b/src/stan/callbacks/in_memory_writer.hpp @@ -0,0 +1,183 @@ +#ifndef STAN_CALLBACKS_IN_MEMORY_WRITER_HPP +#define STAN_CALLBACKS_IN_MEMORY_WRITER_HPP + +#include +#include +#include +#include +#include + +namespace stan { +namespace callbacks { + +/** + * A general in-memory writer that stores draws from Stan's write_array + * callback into a contiguous Eigen matrix. The matrix is allocated with + * column-major storage (Eigen's default), meaning that the element at + * row i and column j is stored at index (i + j * num_rows) in memory. + * + * The writer accepts draws row-by-row and writes them into the matrix. + * If more rows are written than initially allocated, or if the input row + * does not have the expected number of columns, an exception is thrown. + */ +class in_memory_writer : public stan::callbacks::writer { + public: + /** + * Construct an in-memory writer. + * + * @param num_rows the total number of rows (draws) expected + * @param num_cols the number of columns (parameters) per draw + * + * The underlying Eigen matrix is allocated to have dimensions + * num_rows x num_cols in column-major order. + */ + in_memory_writer(std::size_t num_rows, std::size_t num_cols) + : num_rows_(num_rows), + num_cols_(num_cols), + data_(Eigen::MatrixXd::Zero(num_rows, num_cols)), // column-major + current_row_(0), + names_() {} + + virtual ~in_memory_writer() {} + + /** + * Writes a set of names. + * + * @param[in] names Names in a std::vector + */ + void operator()(const std::vector& names) override { + names_ = names; + } + + /** + * Writes a single row (draw) to the in-memory matrix. + * + * The input vector must have exactly num_cols elements. + * The row is written into the matrix at the current row index. + * If the matrix is full, an exception is thrown. + * + * @param row A vector containing one draw from the posterior. + */ + void operator()(const std::vector& row) override { + if (row.size() != num_cols_) { + throw std::runtime_error("Row size does not match the number of columns"); + } + if (current_row_ >= num_rows_) { + throw std::runtime_error("Attempted to write more rows than allocated"); + } + // Because Eigen::MatrixXd is column-major by default, simply assigning + // row by row will place the data in memory in the order: + // index = current_row_ + (column_index * num_rows_). + for (std::size_t j = 0; j < num_cols_; ++j) { + data_(current_row_, j) = row[j]; + } + ++current_row_; + } + + /** + * Default implementation for empty call. + */ + void operator()() override {} + + /** + * Default implementation for string message. + */ + void operator()(const std::string& message) override {} + + /** + * Handles Eigen matrix input by converting to rows and inserting. + * This method treats each row of the input matrix as a separate draw. + */ + void operator()(const Eigen::Matrix& matrix) override { + for (Eigen::Index i = 0; i < matrix.rows(); ++i) { + std::vector row(matrix.cols()); + for (Eigen::Index j = 0; j < matrix.cols(); ++j) { + row[j] = matrix(i, j); + } + (*this)(row); // Use the vector operator to insert the row + } + } + + /** + * Handles Eigen vector input by inserting as a single row. + */ + void operator()(const Eigen::Matrix& vector) override { + std::vector row(vector.size()); + for (Eigen::Index i = 0; i < vector.size(); ++i) { + row[i] = vector(i); + } + (*this)(row); // Use the vector operator to insert the row + } + + /** + * Handles Eigen row vector input by inserting as a single row. + */ + void operator()(const Eigen::Matrix& vector) override { + std::vector row(vector.size()); + for (Eigen::Index i = 0; i < vector.size(); ++i) { + row[i] = vector(i); + } + (*this)(row); // Use the vector operator to insert the row + } + + /** + * Always returns true as the in-memory writer is always valid. + */ + bool is_valid() const noexcept override { return true; } + + /** + * Returns a const reference to the in-memory Eigen matrix containing all + * draws. + * + * The matrix is stored in column-major order. + * + * @return const reference to the Eigen::MatrixXd holding the draws. + */ + const Eigen::MatrixXd& get_eigen_state_values() const { return data_; } + + /** + * Returns a const reference to the column names. + * + * @return const reference to the vector of column names. + */ + const std::vector& get_names() const { return names_; } + + /** + * Returns the number of rows that have been written so far. + * + * @return The current row count. + */ + std::size_t get_row_count() const { return current_row_; } + + /** + * Resets the writer to its initial state. + * + * Clears the stored data (sets the matrix to zero) and resets the current row + * index. Column names are retained. + */ + void reset() { + current_row_ = 0; + data_.setZero(); + } + + /** + * Fully resets the writer, including clearing column names. + */ + void clear() { + current_row_ = 0; + data_.setZero(); + names_.clear(); + } + + private: + std::size_t num_rows_; // Total number of draws (rows) expected. + std::size_t num_cols_; // Number of parameters (columns) per draw. + Eigen::MatrixXd data_; // Internal storage; Eigen matrices are column-major. + std::size_t current_row_; // Next row index to be written. + std::vector names_; // Column names +}; + +} // namespace callbacks +} // namespace stan + +#endif // STAN_CALLBACKS_IN_MEMORY_WRITER_HPP diff --git a/src/stan/callbacks/matrix_writer.hpp b/src/stan/callbacks/matrix_writer.hpp new file mode 100644 index 0000000000..81f7974728 --- /dev/null +++ b/src/stan/callbacks/matrix_writer.hpp @@ -0,0 +1,50 @@ +#ifndef STAN_CALLBACKS_MATRIX_WRITER_HPP +#define STAN_CALLBACKS_MATRIX_WRITER_HPP + +#include +#include +#include + +namespace stan { +namespace callbacks { + +class matrix_writer : public writer { + public: + matrix_writer(size_t rows, size_t cols) + : rows_(rows), cols_(cols), data_(rows * cols) {} + + void operator()(const std::vector& x) override { + if (current_row_ >= rows_) { + throw std::runtime_error("Matrix writer: too many rows"); + } + if (x.size() != cols_) { + throw std::runtime_error("Matrix writer: incorrect number of columns"); + } + + // Store in column-major order + for (size_t j = 0; j < cols_; ++j) { + data_[j * rows_ + current_row_] = x[j]; + } + current_row_++; + } + + void operator()(const std::vector& x) override { + throw std::runtime_error("Matrix writer does not support string data"); + } + + double* data() { return data_.data(); } + size_t rows() const { return rows_; } + size_t cols() const { return cols_; } + size_t current_row() const { return current_row_; } + + private: + size_t rows_; + size_t cols_; + size_t current_row_{0}; + std::vector data_; +}; + +} // namespace callbacks +} // namespace stan + +#endif diff --git a/src/test/unit/callbacks/dispatcher_test.cpp b/src/test/unit/callbacks/dispatcher_test.cpp new file mode 100644 index 0000000000..2d992d8462 --- /dev/null +++ b/src/test/unit/callbacks/dispatcher_test.cpp @@ -0,0 +1,347 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using stan::callbacks::dispatcher; +using stan::callbacks::InfoType; + +struct deleter_noop { + template + constexpr void operator()(T* arg) const {} +}; + +class DispatcherTest : public ::testing::Test { + public: + DispatcherTest() + : ss_sample(), + ss_config(), + ss_metric(), + writer_sample(ss_sample), + writer_config(ss_config), + writer_metric( + std::unique_ptr(&ss_metric)), + writer_sample_in_memory(5, 3), + dispatcher() {} + + void SetUp() { + ss_sample.str(std::string()); + ss_sample.clear(); + ss_config.str(std::string()); + ss_config.clear(); + ss_metric.str(std::string()); + ss_metric.clear(); + + dispatcher.register_channel( + InfoType::CONFIG, + std::unique_ptr( + new stan::callbacks::WriterChannel(&writer_config))); + + dispatcher.register_channel( + InfoType::SAMPLE, + std::unique_ptr( + new stan::callbacks::WriterChannel(&writer_sample))); + + dispatcher.register_channel( + InfoType::METRIC, + std::unique_ptr( + new stan::callbacks::StructuredWriterChannel(&writer_metric))); + + dispatcher.register_channel( + InfoType::SAMPLE_RAW, + std::unique_ptr( + new stan::callbacks::WriterChannel(&writer_sample_in_memory))); + } + + void TearDown() {} + + std::stringstream ss_sample; + std::stringstream ss_config; + std::stringstream ss_metric; + + stan::callbacks::stream_writer writer_sample; + stan::callbacks::stream_writer writer_config; + stan::callbacks::json_writer writer_metric; + stan::callbacks::in_memory_writer writer_sample_in_memory; + stan::callbacks::dispatcher dispatcher; +}; + +// Test basic string dispatch to plain writer +TEST_F(DispatcherTest, StringDispatch) { + dispatcher.dispatch(InfoType::CONFIG, std::string("Message1")); + EXPECT_EQ(ss_config.str(), "Message1\n"); +} + +// Test multiple string dispatches +TEST_F(DispatcherTest, MultipleStringDispatch) { + dispatcher.dispatch(InfoType::CONFIG, std::string("Message1")); + dispatcher.dispatch(InfoType::CONFIG, std::string("Message2")); + dispatcher.dispatch(InfoType::CONFIG, std::string("Message3")); + EXPECT_EQ(ss_config.str(), "Message1\nMessage2\nMessage3\n"); +} + +// Test empty call dispatch +TEST_F(DispatcherTest, EmptyDispatch) { + dispatcher.dispatch(InfoType::CONFIG); + // Empty dispatch should produce just a newline in stream_writer + EXPECT_EQ(ss_config.str(), "\n"); +} + +// Test vector of doubles dispatch +TEST_F(DispatcherTest, VectorDoubleDispatch) { + std::vector values = {1.1, 2.2, 3.3}; + dispatcher.dispatch(InfoType::SAMPLE, values); + std::string output = ss_sample.str(); + EXPECT_NE(output.find("1.1"), std::string::npos); + EXPECT_NE(output.find("2.2"), std::string::npos); + EXPECT_NE(output.find("3.3"), std::string::npos); +} + +// Test vector of strings dispatch +TEST_F(DispatcherTest, VectorStringDispatch) { + std::vector names = {"alpha", "beta", "gamma"}; + dispatcher.dispatch(InfoType::SAMPLE, names); + std::string output = ss_sample.str(); + EXPECT_NE(output.find("alpha"), std::string::npos); + EXPECT_NE(output.find("beta"), std::string::npos); + EXPECT_NE(output.find("gamma"), std::string::npos); +} + +// Test Eigen matrix dispatch +TEST_F(DispatcherTest, EigenMatrixDispatch) { + Eigen::MatrixXd matrix(2, 2); + matrix << 1.0, 2.0, 3.0, 4.0; + dispatcher.dispatch(InfoType::SAMPLE, matrix); + std::string output = ss_sample.str(); + EXPECT_NE(output.find("1"), std::string::npos); + EXPECT_NE(output.find("2"), std::string::npos); + EXPECT_NE(output.find("3"), std::string::npos); + EXPECT_NE(output.find("4"), std::string::npos); +} + +// Test Eigen vector dispatch +TEST_F(DispatcherTest, EigenVectorDispatch) { + Eigen::VectorXd vector(3); + vector << 1.0, 2.0, 3.0; + dispatcher.dispatch(InfoType::SAMPLE, vector); + std::string output = ss_sample.str(); + EXPECT_NE(output.find("1"), std::string::npos); + EXPECT_NE(output.find("2"), std::string::npos); + EXPECT_NE(output.find("3"), std::string::npos); +} + +// Test Eigen row vector dispatch +TEST_F(DispatcherTest, EigenRowVectorDispatch) { + Eigen::RowVectorXd vector(3); + vector << 1.0, 2.0, 3.0; + dispatcher.dispatch(InfoType::SAMPLE, vector); + std::string output = ss_sample.str(); + EXPECT_NE(output.find("1"), std::string::npos); + EXPECT_NE(output.find("2"), std::string::npos); + EXPECT_NE(output.find("3"), std::string::npos); +} + +// Test structured writer begin/end record +TEST_F(DispatcherTest, StructuredBeginEndRecord) { + dispatcher.begin_record(InfoType::METRIC); + dispatcher.end_record(InfoType::METRIC); + std::string output = ss_metric.str(); + // JSON output should contain opening and closing braces + EXPECT_NE(output.find("{"), std::string::npos); + EXPECT_NE(output.find("}"), std::string::npos); +} + +TEST_F(DispatcherTest, MetricStructuredKeyValueRecord) { + // For METRIC (structured writer), open a record, dispatch key/value pairs, + // then close the record. + dispatcher.begin_record(InfoType::METRIC); + dispatcher.dispatch(InfoType::METRIC, "metric_type", std::string("diag")); + dispatcher.dispatch(InfoType::METRIC, "stepsize", 0.6789); + std::vector inv_metric = {0.1, 0.2, 0.3}; + dispatcher.dispatch(InfoType::METRIC, "inv_metric", inv_metric); + dispatcher.end_record(InfoType::METRIC); + // Expected output: + // Begin record marker, followed by key/value pairs each formatted as + // "key:value;" and then end record marker. + std::string output = ss_metric.str(); + EXPECT_NE(output.find("metric_type"), std::string::npos); + EXPECT_NE(output.find("diag"), std::string::npos); +} + +// Test structured writer with multiple key-value types +TEST_F(DispatcherTest, StructuredMultipleValueTypes) { + dispatcher.begin_record(InfoType::METRIC); + dispatcher.dispatch(InfoType::METRIC, "string_key", + std::string("string_value")); + dispatcher.dispatch(InfoType::METRIC, "int_key", 42); + dispatcher.dispatch(InfoType::METRIC, "double_key", 3.14159); + dispatcher.dispatch(InfoType::METRIC, "bool_key", true); + dispatcher.end_record(InfoType::METRIC); + + std::string output = ss_metric.str(); + EXPECT_NE(output.find("string_key"), std::string::npos); + EXPECT_NE(output.find("string_value"), std::string::npos); + EXPECT_NE(output.find("int_key"), std::string::npos); + EXPECT_NE(output.find("42"), std::string::npos); + EXPECT_NE(output.find("double_key"), std::string::npos); + EXPECT_NE(output.find("3.14159"), std::string::npos); + EXPECT_NE(output.find("bool_key"), std::string::npos); + EXPECT_NE(output.find("true"), std::string::npos); +} + +// Test structured writer with Eigen values +TEST_F(DispatcherTest, StructuredEigenValues) { + dispatcher.begin_record(InfoType::METRIC); + + Eigen::MatrixXd matrix(2, 2); + matrix << 1.0, 2.0, 3.0, 4.0; + dispatcher.dispatch(InfoType::METRIC, "matrix", matrix); + + Eigen::VectorXd vector(3); + vector << 5.0, 6.0, 7.0; + dispatcher.dispatch(InfoType::METRIC, "vector", vector); + + dispatcher.end_record(InfoType::METRIC); + + std::string output = ss_metric.str(); + EXPECT_NE(output.find("matrix"), std::string::npos); + EXPECT_NE(output.find("1"), std::string::npos); + EXPECT_NE(output.find("4"), std::string::npos); + EXPECT_NE(output.find("vector"), std::string::npos); + EXPECT_NE(output.find("5"), std::string::npos); + EXPECT_NE(output.find("7"), std::string::npos); +} + +// Test unregistered channel +TEST_F(DispatcherTest, UnregisteredChannel) { + // Dispatch to unregistered channel should silently do nothing + dispatcher.dispatch(InfoType::ALGORITHM_STATE, std::string("Message")); + dispatcher.dispatch(InfoType::ALGORITHM_STATE, std::vector{1.0, 2.0}); + dispatcher.begin_record(InfoType::ALGORITHM_STATE); + dispatcher.dispatch(InfoType::ALGORITHM_STATE, "key", "value"); + dispatcher.end_record(InfoType::ALGORITHM_STATE); + + // No exceptions should be thrown +} + +// Test named record +TEST_F(DispatcherTest, NamedRecord) { + dispatcher.begin_record(InfoType::METRIC, "record_name"); + dispatcher.dispatch(InfoType::METRIC, "key", "value"); + dispatcher.end_record(InfoType::METRIC); + + std::string output = ss_metric.str(); + EXPECT_NE(output.find("record_name"), std::string::npos); + EXPECT_NE(output.find("key"), std::string::npos); + EXPECT_NE(output.find("value"), std::string::npos); +} + +// Test that begin_record and end_record on a plain writer channel are +// silently ignored +TEST_F(DispatcherTest, RecordOperationsOnPlainWriter) { + dispatcher.begin_record(InfoType::CONFIG); + dispatcher.end_record(InfoType::CONFIG); + + // Should not generate any output + EXPECT_EQ(ss_config.str(), ""); +} + +// Test in_memory_writer integration with dispatcher +TEST_F(DispatcherTest, InMemoryWriterBasic) { + // Write rows of data to the in_memory_writer via the dispatcher + std::vector row1 = {1.1, 2.2, 3.3}; + std::vector row2 = {4.4, 5.5, 6.6}; + std::vector row3 = {7.7, 8.8, 9.9}; + + dispatcher.dispatch(InfoType::SAMPLE_RAW, row1); + dispatcher.dispatch(InfoType::SAMPLE_RAW, row2); + dispatcher.dispatch(InfoType::SAMPLE_RAW, row3); + + // Check that the data was stored correctly + const Eigen::MatrixXd& data + = writer_sample_in_memory.get_eigen_state_values(); + + EXPECT_EQ(data.rows(), 5); // As initialized + EXPECT_EQ(data.cols(), 3); // As initialized + + // Check the stored values (first 3 rows should have our data) + EXPECT_DOUBLE_EQ(data(0, 0), 1.1); + EXPECT_DOUBLE_EQ(data(0, 1), 2.2); + EXPECT_DOUBLE_EQ(data(0, 2), 3.3); + + EXPECT_DOUBLE_EQ(data(1, 0), 4.4); + EXPECT_DOUBLE_EQ(data(1, 1), 5.5); + EXPECT_DOUBLE_EQ(data(1, 2), 6.6); + + EXPECT_DOUBLE_EQ(data(2, 0), 7.7); + EXPECT_DOUBLE_EQ(data(2, 1), 8.8); + EXPECT_DOUBLE_EQ(data(2, 2), 9.9); +} + +// Test in_memory_writer with row overflow +TEST_F(DispatcherTest, InMemoryWriterRowOverflow) { + // Try to write more rows than allocated + std::vector row = {1.0, 2.0, 3.0}; + + // Should be able to write 5 rows (as initialized) + for (int i = 0; i < 5; i++) { + dispatcher.dispatch(InfoType::SAMPLE_RAW, row); + } + + // Sixth row should throw exception + EXPECT_THROW(dispatcher.dispatch(InfoType::SAMPLE_RAW, row), + std::runtime_error); +} + +// Test in_memory_writer with column mismatch +TEST_F(DispatcherTest, InMemoryWriterColumnMismatch) { + // Try to write a row with wrong number of columns + std::vector wrong_size_row + = {1.0, 2.0}; // Only 2 columns when we need 3 + + EXPECT_THROW(dispatcher.dispatch(InfoType::SAMPLE_RAW, wrong_size_row), + std::runtime_error); + + std::vector also_wrong_size + = {1.0, 2.0, 3.0, 4.0}; // 4 columns when we need 3 + + EXPECT_THROW(dispatcher.dispatch(InfoType::SAMPLE_RAW, also_wrong_size), + std::runtime_error); +} + +// Test in_memory_writer reset +TEST_F(DispatcherTest, InMemoryWriterReset) { + // Write some data + std::vector row = {1.1, 2.2, 3.3}; + dispatcher.dispatch(InfoType::SAMPLE_RAW, row); + + // Verify data is written + const Eigen::MatrixXd& data1 + = writer_sample_in_memory.get_eigen_state_values(); + EXPECT_DOUBLE_EQ(data1(0, 0), 1.1); + + // Reset the writer + writer_sample_in_memory.reset(); + + // Verify data is cleared + const Eigen::MatrixXd& data2 + = writer_sample_in_memory.get_eigen_state_values(); + EXPECT_DOUBLE_EQ(data2(0, 0), 0.0); + + // Write new data + std::vector new_row = {4.4, 5.5, 6.6}; + dispatcher.dispatch(InfoType::SAMPLE_RAW, new_row); + + // Verify new data is written + const Eigen::MatrixXd& data3 + = writer_sample_in_memory.get_eigen_state_values(); + EXPECT_DOUBLE_EQ(data3(0, 0), 4.4); +} diff --git a/src/test/unit/callbacks/in_memory_writer_test.cpp b/src/test/unit/callbacks/in_memory_writer_test.cpp new file mode 100644 index 0000000000..21a750c5be --- /dev/null +++ b/src/test/unit/callbacks/in_memory_writer_test.cpp @@ -0,0 +1,358 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +using stan::callbacks::dispatcher; +using stan::callbacks::in_memory_writer; +using stan::callbacks::InfoType; +using stan::callbacks::WriterChannel; + +// Helper function to split a comma-separated line into tokens +std::vector split_csv_line(const std::string& line) { + std::vector tokens; + std::stringstream ss(line); + std::string token; + + while (std::getline(ss, token, ',')) { + // Trim whitespace + token.erase(0, token.find_first_not_of(" \t")); + token.erase(token.find_last_not_of(" \t") + 1); + tokens.push_back(token); + } + + return tokens; +} + +// Helper function to convert string tokens to doubles +std::vector convert_to_doubles(const std::vector& tokens) { + std::vector values; + for (const auto& token : tokens) { + values.push_back(std::stod(token)); + } + return values; +} + +class InMemoryWriterTest : public ::testing::Test { + protected: + void SetUp() override { + // Parse the header + header_tokens = split_csv_line( + "theta, " + "mu,y_rep.1,y_rep.2,y_rep.3,y_rep.4,y_rep.5,y_rep.6,y_rep.7,y_rep.8,y_" + "rep.9,y_rep.10"); + num_cols = header_tokens.size(); + + // Parse the sample rows + std::vector sample_lines + = {"0.309252,0.309252,0,0,1,1,0,0,0,1,0,0", + "0.572524,0.572524,0,0,0,1,0,1,1,0,1,0", + "0.0795978,0.0795978,0,0,0,0,0,0,0,0,0,0"}; + + for (const auto& line : sample_lines) { + auto tokens = split_csv_line(line); + auto values = convert_to_doubles(tokens); + sample_rows.push_back(values); + } + + num_rows = sample_rows.size(); + } + + std::vector header_tokens; + std::vector> sample_rows; + size_t num_rows; + size_t num_cols; +}; + +// Test basic writing and retrieval of samples +TEST_F(InMemoryWriterTest, BasicWriteAndRetrieve) { + // Create writer with exact capacity + in_memory_writer writer(num_rows, num_cols); + + // Write header + writer(header_tokens); + + // Write sample rows + for (const auto& row : sample_rows) { + writer(row); + } + + // Verify names were stored correctly + ASSERT_EQ(writer.get_names().size(), num_cols); + for (size_t i = 0; i < num_cols; ++i) { + EXPECT_EQ(writer.get_names()[i], header_tokens[i]); + } + + // Verify data was stored correctly + const Eigen::MatrixXd& data = writer.get_eigen_state_values(); + ASSERT_EQ(data.rows(), num_rows); + ASSERT_EQ(data.cols(), num_cols); + + for (size_t i = 0; i < num_rows; ++i) { + for (size_t j = 0; j < num_cols; ++j) { + EXPECT_DOUBLE_EQ(data(i, j), sample_rows[i][j]); + } + } + + // Verify row count + EXPECT_EQ(writer.get_row_count(), num_rows); +} + +// Test writing more data than allocated +TEST_F(InMemoryWriterTest, WriteOverflow) { + // Create writer with less capacity than needed + in_memory_writer writer(num_rows - 1, num_cols); + + // Write header + writer(header_tokens); + + // Write sample rows until overflow + for (size_t i = 0; i < num_rows - 1; ++i) { + writer(sample_rows[i]); + } + + // The next write should throw an exception + EXPECT_THROW(writer(sample_rows.back()), std::runtime_error); +} + +// Test writing rows with incorrect column count +TEST_F(InMemoryWriterTest, ColumnMismatch) { + // Create writer + in_memory_writer writer(num_rows, num_cols); + + // Write header + writer(header_tokens); + + // Create an invalid row with too few columns + std::vector invalid_row(num_cols - 1, 0.5); + + // Writing should throw an exception + EXPECT_THROW(writer(invalid_row), std::runtime_error); + + // Create an invalid row with too many columns + invalid_row.resize(num_cols + 1, 0.5); + + // Writing should throw an exception + EXPECT_THROW(writer(invalid_row), std::runtime_error); +} + +// Test reset functionality +TEST_F(InMemoryWriterTest, Reset) { + // Create writer + in_memory_writer writer(num_rows, num_cols); + + // Write header and first row + writer(header_tokens); + writer(sample_rows[0]); + + // Verify row was written + EXPECT_EQ(writer.get_row_count(), 1); + EXPECT_DOUBLE_EQ(writer.get_eigen_state_values()(0, 0), sample_rows[0][0]); + + // Reset the writer + writer.reset(); + + // Verify row count is reset + EXPECT_EQ(writer.get_row_count(), 0); + + // Verify data is cleared + EXPECT_DOUBLE_EQ(writer.get_eigen_state_values()(0, 0), 0.0); + + // Verify names are preserved + ASSERT_EQ(writer.get_names().size(), num_cols); + EXPECT_EQ(writer.get_names()[0], header_tokens[0]); + + // Write a different row + writer(sample_rows[1]); + + // Verify new data is written + EXPECT_EQ(writer.get_row_count(), 1); + EXPECT_DOUBLE_EQ(writer.get_eigen_state_values()(0, 0), sample_rows[1][0]); +} + +// Test clear functionality +TEST_F(InMemoryWriterTest, Clear) { + // Create writer + in_memory_writer writer(num_rows, num_cols); + + // Write header and first row + writer(header_tokens); + writer(sample_rows[0]); + + // Clear the writer + writer.clear(); + + // Verify row count is reset + EXPECT_EQ(writer.get_row_count(), 0); + + // Verify data is cleared + EXPECT_DOUBLE_EQ(writer.get_eigen_state_values()(0, 0), 0.0); + + // Verify names are also cleared + EXPECT_EQ(writer.get_names().size(), 0); +} + +// Test with different size allocations +TEST_F(InMemoryWriterTest, DifferentSizeAllocations) { + // Create writer with more capacity than needed + in_memory_writer writer(num_rows * 2, num_cols); + + // Write header + writer(header_tokens); + + // Write sample rows + for (const auto& row : sample_rows) { + writer(row); + } + + // Verify data was stored correctly + const Eigen::MatrixXd& data = writer.get_eigen_state_values(); + ASSERT_EQ(data.rows(), num_rows * 2); // Total allocation + ASSERT_EQ(data.cols(), num_cols); + + // Verify only written rows have data + for (size_t i = 0; i < num_rows; ++i) { + for (size_t j = 0; j < num_cols; ++j) { + EXPECT_DOUBLE_EQ(data(i, j), sample_rows[i][j]); + } + } + + // Remaining rows should be zero + for (size_t i = num_rows; i < num_rows * 2; ++i) { + for (size_t j = 0; j < num_cols; ++j) { + EXPECT_DOUBLE_EQ(data(i, j), 0.0); + } + } +} + +// Test integration with dispatcher +TEST_F(InMemoryWriterTest, DispatcherIntegration) { + // Create writer and dispatcher + in_memory_writer writer(num_rows, num_cols); + dispatcher disp; + + // Register channel + disp.register_channel( + InfoType::SAMPLE_RAW, + std::unique_ptr(new WriterChannel(&writer))); + + // Send header and data through dispatcher + disp.dispatch(InfoType::SAMPLE_RAW, header_tokens); + for (const auto& row : sample_rows) { + disp.dispatch(InfoType::SAMPLE_RAW, row); + } + + // Verify data was received correctly + const Eigen::MatrixXd& data = writer.get_eigen_state_values(); + for (size_t i = 0; i < num_rows; ++i) { + for (size_t j = 0; j < num_cols; ++j) { + EXPECT_DOUBLE_EQ(data(i, j), sample_rows[i][j]); + } + } +} + +// Test handling of string messages +TEST_F(InMemoryWriterTest, StringMessages) { + in_memory_writer writer(num_rows, num_cols); + + // Write header + writer(header_tokens); + + // Write a string message (should be ignored) + writer("This is a message that should be ignored"); + + // Write sample row + writer(sample_rows[0]); + + // Verify row count is correct (only the data row should be counted) + EXPECT_EQ(writer.get_row_count(), 1); + + // Verify data was stored correctly + const Eigen::MatrixXd& data = writer.get_eigen_state_values(); + for (size_t j = 0; j < num_cols; ++j) { + EXPECT_DOUBLE_EQ(data(0, j), sample_rows[0][j]); + } +} + +// Test handling of empty calls +TEST_F(InMemoryWriterTest, EmptyCalls) { + in_memory_writer writer(num_rows, num_cols); + + // Write header + writer(header_tokens); + + // Write an empty call (should be ignored) + writer(); + + // Write sample row + writer(sample_rows[0]); + + // Verify row count is correct (only the data row should be counted) + EXPECT_EQ(writer.get_row_count(), 1); + + // Verify data was stored correctly + const Eigen::MatrixXd& data = writer.get_eigen_state_values(); + for (size_t j = 0; j < num_cols; ++j) { + EXPECT_DOUBLE_EQ(data(0, j), sample_rows[0][j]); + } +} + +// Test handling of Eigen input types +TEST_F(InMemoryWriterTest, EigenInputs) { + in_memory_writer writer(num_rows + 3, num_cols); + + // Write header + writer(header_tokens); + + // Write sample row as vector + writer(sample_rows[0]); + + // Write sample row as Eigen::VectorXd + Eigen::VectorXd vec(num_cols); + for (size_t j = 0; j < num_cols; ++j) { + vec(j) = sample_rows[1][j]; + } + writer(vec); + + // Write sample row as Eigen::RowVectorXd + Eigen::RowVectorXd row_vec(num_cols); + for (size_t j = 0; j < num_cols; ++j) { + row_vec(j) = sample_rows[2][j]; + } + writer(row_vec); + + // Create a matrix with all three rows + Eigen::MatrixXd matrix(3, num_cols); + for (size_t i = 0; i < 3; ++i) { + for (size_t j = 0; j < num_cols; ++j) { + matrix(i, j) = sample_rows[i][j]; + } + } + + // Write all rows at once as a matrix + writer(matrix); + + // Verify row count (should be 1 + 1 + 1 + 3 = 6) + EXPECT_EQ(writer.get_row_count(), 6); + + // Verify data was stored correctly + const Eigen::MatrixXd& data = writer.get_eigen_state_values(); + + // First three rows should match the original inputs + for (size_t i = 0; i < 3; ++i) { + for (size_t j = 0; j < num_cols; ++j) { + EXPECT_DOUBLE_EQ(data(i, j), sample_rows[i][j]); + } + } + + // Next three rows should match the matrix input + for (size_t i = 0; i < 3; ++i) { + for (size_t j = 0; j < num_cols; ++j) { + EXPECT_DOUBLE_EQ(data(i + 3, j), sample_rows[i][j]); + } + } +}