From 3b7a3f7137e33ec799c957ab9466c3cd1b1b6408 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 18:51:43 -0500 Subject: [PATCH 01/53] output_controller and tests --- src/stan/io/output_controller.hpp | 49 +++++++++ src/test/unit/io/output_controller_test.cpp | 115 ++++++++++++++++++++ 2 files changed, 164 insertions(+) create mode 100644 src/stan/io/output_controller.hpp create mode 100644 src/test/unit/io/output_controller_test.cpp diff --git a/src/stan/io/output_controller.hpp b/src/stan/io/output_controller.hpp new file mode 100644 index 0000000000..6baa641e6a --- /dev/null +++ b/src/stan/io/output_controller.hpp @@ -0,0 +1,49 @@ +#ifndef STAN_IO_OUTPUT_CONTROLLER_HPP +#define STAN_IO_OUTPUT_CONTROLLER_HPP + +#include +#include +#include +#include +#include + +namespace stan { +namespace io { + +class output_controller { + public: + output_controller() = default; + + // Add a writer for samples + void add_sample_writer(std::unique_ptr writer) { + sample_writers_.push_back(std::move(writer)); + } + + // Add a writer for diagnostics + void add_diagnostic_writer(std::unique_ptr writer) { + diagnostic_writers_.push_back(std::move(writer)); + } + + // Write sample data to all registered sample writers + void write_sample(const std::vector& sample) { + for (auto& writer : sample_writers_) { + writer->operator()(sample); + } + } + + // Write diagnostic data to all registered diagnostic writers + void write_diagnostic(const std::vector& diagnostic) { + for (auto& writer : diagnostic_writers_) { + writer->operator()(diagnostic); + } + } + + private: + std::vector> sample_writers_; + std::vector> diagnostic_writers_; +}; + +} // namespace io +} // namespace stan + +#endif \ No newline at end of file diff --git a/src/test/unit/io/output_controller_test.cpp b/src/test/unit/io/output_controller_test.cpp new file mode 100644 index 0000000000..44b669b7ca --- /dev/null +++ b/src/test/unit/io/output_controller_test.cpp @@ -0,0 +1,115 @@ +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace test { +namespace unit { +namespace io { + +// Mock writer for testing +class mock_writer : public callbacks::writer { + public: + std::vector> samples; + std::vector> diagnostics; + + void operator()(const std::vector& x) override { + samples.push_back(x); + } + + void operator()(const std::vector& x) override { + // Not used in tests + } +}; + +TEST(output_controller, add_and_write_sample) { + output_controller controller; + + // Create two mock writers + auto writer1 = std::make_unique(); + auto writer2 = std::make_unique(); + + // Store raw pointers for testing + auto* writer1_ptr = writer1.get(); + auto* writer2_ptr = writer2.get(); + + // Add writers to controller + controller.add_sample_writer(std::move(writer1)); + controller.add_sample_writer(std::move(writer2)); + + // Write sample data + std::vector sample = {1.0, 2.0, 3.0}; + controller.write_sample(sample); + + // Check both writers received the data + EXPECT_EQ(writer1_ptr->samples.size(), 1); + EXPECT_EQ(writer2_ptr->samples.size(), 1); + EXPECT_EQ(writer1_ptr->samples[0], sample); + EXPECT_EQ(writer2_ptr->samples[0], sample); +} + +TEST(output_controller, add_and_write_diagnostic) { + output_controller controller; + + // Create two mock writers + auto writer1 = std::make_unique(); + auto writer2 = std::make_unique(); + + // Store raw pointers for testing + auto* writer1_ptr = writer1.get(); + auto* writer2_ptr = writer2.get(); + + // Add writers to controller + controller.add_diagnostic_writer(std::move(writer1)); + controller.add_diagnostic_writer(std::move(writer2)); + + // Write diagnostic data + std::vector diagnostic = {4.0, 5.0, 6.0}; + controller.write_diagnostic(diagnostic); + + // Check both writers received the data + EXPECT_EQ(writer1_ptr->diagnostics.size(), 1); + EXPECT_EQ(writer2_ptr->diagnostics.size(), 1); + EXPECT_EQ(writer1_ptr->diagnostics[0], diagnostic); + EXPECT_EQ(writer2_ptr->diagnostics[0], diagnostic); +} + +TEST(output_controller, empty_writers) { + output_controller controller; + + // Writing to empty writers should not crash + std::vector data = {1.0, 2.0, 3.0}; + controller.write_sample(data); + controller.write_diagnostic(data); + + // Test passes if no crash occurs +} + +TEST(output_controller, multiple_writes) { + output_controller controller; + + auto writer = std::make_unique(); + auto* writer_ptr = writer.get(); + + controller.add_sample_writer(std::move(writer)); + + // Write multiple samples + std::vector sample1 = {1.0, 2.0, 3.0}; + std::vector sample2 = {4.0, 5.0, 6.0}; + + controller.write_sample(sample1); + controller.write_sample(sample2); + + // Check all samples were written + EXPECT_EQ(writer_ptr->samples.size(), 2); + EXPECT_EQ(writer_ptr->samples[0], sample1); + EXPECT_EQ(writer_ptr->samples[1], sample2); +} + +} // namespace io +} // namespace unit +} // namespace test +} // namespace stan \ No newline at end of file From 2cdfda86cc1990a0f0da51e8bd219b1ac4f4f01e Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 18:59:37 -0500 Subject: [PATCH 02/53] path fix --- src/stan/io/output_controller.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stan/io/output_controller.hpp b/src/stan/io/output_controller.hpp index 6baa641e6a..857d841841 100644 --- a/src/stan/io/output_controller.hpp +++ b/src/stan/io/output_controller.hpp @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include From db3e8a5ac80c1aefdc111ecf91913fbc9f11c8df Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 19:07:01 -0500 Subject: [PATCH 03/53] fix namespace --- src/test/unit/io/output_controller_test.cpp | 22 ++++++--------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/src/test/unit/io/output_controller_test.cpp b/src/test/unit/io/output_controller_test.cpp index 44b669b7ca..9675303eb8 100644 --- a/src/test/unit/io/output_controller_test.cpp +++ b/src/test/unit/io/output_controller_test.cpp @@ -5,13 +5,8 @@ #include #include -namespace stan { -namespace test { -namespace unit { -namespace io { - // Mock writer for testing -class mock_writer : public callbacks::writer { +class mock_writer : public stan::callbacks::writer { public: std::vector> samples; std::vector> diagnostics; @@ -26,7 +21,7 @@ class mock_writer : public callbacks::writer { }; TEST(output_controller, add_and_write_sample) { - output_controller controller; + stan::io::output_controller controller; // Create two mock writers auto writer1 = std::make_unique(); @@ -52,7 +47,7 @@ TEST(output_controller, add_and_write_sample) { } TEST(output_controller, add_and_write_diagnostic) { - output_controller controller; + stan::io::output_controller controller; // Create two mock writers auto writer1 = std::make_unique(); @@ -78,7 +73,7 @@ TEST(output_controller, add_and_write_diagnostic) { } TEST(output_controller, empty_writers) { - output_controller controller; + stan::io::output_controller controller; // Writing to empty writers should not crash std::vector data = {1.0, 2.0, 3.0}; @@ -89,7 +84,7 @@ TEST(output_controller, empty_writers) { } TEST(output_controller, multiple_writes) { - output_controller controller; + stan::io::output_controller controller; auto writer = std::make_unique(); auto* writer_ptr = writer.get(); @@ -107,9 +102,4 @@ TEST(output_controller, multiple_writes) { EXPECT_EQ(writer_ptr->samples.size(), 2); EXPECT_EQ(writer_ptr->samples[0], sample1); EXPECT_EQ(writer_ptr->samples[1], sample2); -} - -} // namespace io -} // namespace unit -} // namespace test -} // namespace stan \ No newline at end of file +} \ No newline at end of file From 8473d2c59881e7ec28a3eca288d1060c4ff197ad Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 19:10:59 -0500 Subject: [PATCH 04/53] unneeded include --- src/test/unit/io/output_controller_test.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/test/unit/io/output_controller_test.cpp b/src/test/unit/io/output_controller_test.cpp index 9675303eb8..9df6f0c0eb 100644 --- a/src/test/unit/io/output_controller_test.cpp +++ b/src/test/unit/io/output_controller_test.cpp @@ -1,6 +1,5 @@ #include #include -#include #include #include #include From f4acf745c197fc0adf6e6deca24a9211b6fcb1d2 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 19:50:10 -0500 Subject: [PATCH 05/53] 2nd try --- src/stan/io/output_controller.hpp | 63 ++++++++++++++------- src/test/unit/io/output_controller_test.cpp | 12 ++-- 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/src/stan/io/output_controller.hpp b/src/stan/io/output_controller.hpp index 857d841841..362ab7e00e 100644 --- a/src/stan/io/output_controller.hpp +++ b/src/stan/io/output_controller.hpp @@ -6,41 +6,66 @@ #include #include #include +#include +#include namespace stan { namespace io { +enum class OutputFormat { + CSV, // Plain text CSV files + MATRIX, // In-memory column-major matrix + ARROW, // Apache Arrow files + JSON // JSON files +}; + +struct OutputConfig { + OutputFormat format; + std::string path; // File path or identifier for the output +}; + class output_controller { public: output_controller() = default; - // Add a writer for samples - void add_sample_writer(std::unique_ptr writer) { - sample_writers_.push_back(std::move(writer)); - } - - // Add a writer for diagnostics - void add_diagnostic_writer(std::unique_ptr writer) { - diagnostic_writers_.push_back(std::move(writer)); + // Configure output format for a specific information type + void configure_output(const std::string& info_type, OutputConfig config) { + output_configs_[info_type] = config; } - // Write sample data to all registered sample writers - void write_sample(const std::vector& sample) { - for (auto& writer : sample_writers_) { - writer->operator()(sample); + // Get appropriate writer for information type + std::unique_ptr get_writer(const std::string& info_type) { + auto it = output_configs_.find(info_type); + if (it == output_configs_.end()) { + throw std::runtime_error("No output configuration for " + info_type); } + + return create_writer(it->second); } - // Write diagnostic data to all registered diagnostic writers - void write_diagnostic(const std::vector& diagnostic) { - for (auto& writer : diagnostic_writers_) { - writer->operator()(diagnostic); - } + // Write data using appropriate writer + void write(const std::string& info_type, const std::vector& data) { + auto writer = get_writer(info_type); + writer->operator()(data); } private: - std::vector> sample_writers_; - std::vector> diagnostic_writers_; + std::map output_configs_; + + std::unique_ptr create_writer(const OutputConfig& config) { + switch (config.format) { + case OutputFormat::CSV: + return std::make_unique(config.path); + case OutputFormat::MATRIX: + return std::make_unique(/* dimensions */); + case OutputFormat::ARROW: + return std::make_unique(config.path); + case OutputFormat::JSON: + return std::make_unique(config.path); + default: + throw std::runtime_error("Unsupported output format"); + } + } }; } // namespace io diff --git a/src/test/unit/io/output_controller_test.cpp b/src/test/unit/io/output_controller_test.cpp index 9df6f0c0eb..95b2d71deb 100644 --- a/src/test/unit/io/output_controller_test.cpp +++ b/src/test/unit/io/output_controller_test.cpp @@ -14,6 +14,10 @@ class mock_writer : public stan::callbacks::writer { samples.push_back(x); } + void write_diagnostic(const std::vector& x) { + diagnostics.push_back(x); + } + void operator()(const std::vector& x) override { // Not used in tests } @@ -65,10 +69,10 @@ TEST(output_controller, add_and_write_diagnostic) { controller.write_diagnostic(diagnostic); // Check both writers received the data - EXPECT_EQ(writer1_ptr->diagnostics.size(), 1); - EXPECT_EQ(writer2_ptr->diagnostics.size(), 1); - EXPECT_EQ(writer1_ptr->diagnostics[0], diagnostic); - EXPECT_EQ(writer2_ptr->diagnostics[0], diagnostic); + EXPECT_EQ(writer1_ptr->samples.size(), 1); + EXPECT_EQ(writer2_ptr->samples.size(), 1); + EXPECT_EQ(writer1_ptr->samples[0], diagnostic); + EXPECT_EQ(writer2_ptr->samples[0], diagnostic); } TEST(output_controller, empty_writers) { From 7e7c110122125d3517f699bab04a06654c5931c5 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 19:57:18 -0500 Subject: [PATCH 06/53] another try --- src/stan/callbacks/matrix_writer.hpp | 50 ++++++++++++++++++++++++++++ src/stan/io/output_controller.hpp | 7 +++- 2 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 src/stan/callbacks/matrix_writer.hpp diff --git a/src/stan/callbacks/matrix_writer.hpp b/src/stan/callbacks/matrix_writer.hpp new file mode 100644 index 0000000000..36c8c0d57a --- /dev/null +++ b/src/stan/callbacks/matrix_writer.hpp @@ -0,0 +1,50 @@ +#ifndef STAN_CALLBACKS_MATRIX_WRITER_HPP +#define STAN_CALLBACKS_MATRIX_WRITER_HPP + +#include +#include +#include + +namespace stan { +namespace callbacks { + +class matrix_writer : public writer { + public: + matrix_writer(size_t rows, size_t cols) + : rows_(rows), cols_(cols), data_(rows * cols) {} + + void operator()(const std::vector& x) override { + if (current_row_ >= rows_) { + throw std::runtime_error("Matrix writer: too many rows"); + } + if (x.size() != cols_) { + throw std::runtime_error("Matrix writer: incorrect number of columns"); + } + + // Store in column-major order + for (size_t j = 0; j < cols_; ++j) { + data_[j * rows_ + current_row_] = x[j]; + } + current_row_++; + } + + void operator()(const std::vector& x) override { + throw std::runtime_error("Matrix writer does not support string data"); + } + + double* data() { return data_.data(); } + size_t rows() const { return rows_; } + size_t cols() const { return cols_; } + size_t current_row() const { return current_row_; } + + private: + size_t rows_; + size_t cols_; + size_t current_row_{0}; + std::vector data_; +}; + +} // namespace callbacks +} // namespace stan + +#endif \ No newline at end of file diff --git a/src/stan/io/output_controller.hpp b/src/stan/io/output_controller.hpp index 362ab7e00e..4864ff911e 100644 --- a/src/stan/io/output_controller.hpp +++ b/src/stan/io/output_controller.hpp @@ -22,6 +22,8 @@ enum class OutputFormat { struct OutputConfig { OutputFormat format; std::string path; // File path or identifier for the output + size_t rows{0}; // For matrix format + size_t cols{0}; // For matrix format }; class output_controller { @@ -57,7 +59,10 @@ class output_controller { case OutputFormat::CSV: return std::make_unique(config.path); case OutputFormat::MATRIX: - return std::make_unique(/* dimensions */); + if (config.rows == 0 || config.cols == 0) { + throw std::runtime_error("Matrix dimensions must be specified"); + } + return std::make_unique(config.rows, config.cols); case OutputFormat::ARROW: return std::make_unique(config.path); case OutputFormat::JSON: From c16635e1167d1fb9b15a85e285088d8084318235 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 20:04:05 -0500 Subject: [PATCH 07/53] try try again --- src/stan/io/output_controller.hpp | 8 +- src/test/unit/io/output_controller_test.cpp | 108 ++++++++------------ 2 files changed, 45 insertions(+), 71 deletions(-) diff --git a/src/stan/io/output_controller.hpp b/src/stan/io/output_controller.hpp index 4864ff911e..c7cd852457 100644 --- a/src/stan/io/output_controller.hpp +++ b/src/stan/io/output_controller.hpp @@ -2,8 +2,9 @@ #define STAN_IO_OUTPUT_CONTROLLER_HPP #include -#include -#include +#include +#include +#include #include #include #include @@ -15,7 +16,6 @@ namespace io { enum class OutputFormat { CSV, // Plain text CSV files MATRIX, // In-memory column-major matrix - ARROW, // Apache Arrow files JSON // JSON files }; @@ -63,8 +63,6 @@ class output_controller { throw std::runtime_error("Matrix dimensions must be specified"); } return std::make_unique(config.rows, config.cols); - case OutputFormat::ARROW: - return std::make_unique(config.path); case OutputFormat::JSON: return std::make_unique(config.path); default: diff --git a/src/test/unit/io/output_controller_test.cpp b/src/test/unit/io/output_controller_test.cpp index 95b2d71deb..455e5a6ed9 100644 --- a/src/test/unit/io/output_controller_test.cpp +++ b/src/test/unit/io/output_controller_test.cpp @@ -1,5 +1,7 @@ #include #include +#include +#include #include #include #include @@ -7,15 +9,10 @@ // Mock writer for testing class mock_writer : public stan::callbacks::writer { public: - std::vector> samples; - std::vector> diagnostics; + std::vector> data; void operator()(const std::vector& x) override { - samples.push_back(x); - } - - void write_diagnostic(const std::vector& x) { - diagnostics.push_back(x); + data.push_back(x); } void operator()(const std::vector& x) override { @@ -23,86 +20,65 @@ class mock_writer : public stan::callbacks::writer { } }; -TEST(output_controller, add_and_write_sample) { +TEST(output_controller, configure_and_write) { stan::io::output_controller controller; - // Create two mock writers - auto writer1 = std::make_unique(); - auto writer2 = std::make_unique(); - - // Store raw pointers for testing - auto* writer1_ptr = writer1.get(); - auto* writer2_ptr = writer2.get(); - - // Add writers to controller - controller.add_sample_writer(std::move(writer1)); - controller.add_sample_writer(std::move(writer2)); + // Configure output for samples + controller.configure_output("samples", + {stan::io::OutputFormat::MATRIX, "memory", 100, 3}); // Write sample data std::vector sample = {1.0, 2.0, 3.0}; - controller.write_sample(sample); - - // Check both writers received the data - EXPECT_EQ(writer1_ptr->samples.size(), 1); - EXPECT_EQ(writer2_ptr->samples.size(), 1); - EXPECT_EQ(writer1_ptr->samples[0], sample); - EXPECT_EQ(writer2_ptr->samples[0], sample); + controller.write("samples", sample); + + // Get writer and verify data + auto writer = controller.get_writer("samples"); + auto* matrix_writer = dynamic_cast(writer.get()); + EXPECT_NE(matrix_writer, nullptr); + EXPECT_EQ(matrix_writer->rows(), 100); + EXPECT_EQ(matrix_writer->cols(), 3); + EXPECT_EQ(matrix_writer->current_row(), 1); } -TEST(output_controller, add_and_write_diagnostic) { +TEST(output_controller, multiple_formats) { stan::io::output_controller controller; - // Create two mock writers - auto writer1 = std::make_unique(); - auto writer2 = std::make_unique(); + // Configure different formats for different information types + controller.configure_output("samples", + {stan::io::OutputFormat::MATRIX, "memory", 100, 3}); + controller.configure_output("diagnostics", + {stan::io::OutputFormat::CSV, "diagnostics.csv"}); - // Store raw pointers for testing - auto* writer1_ptr = writer1.get(); - auto* writer2_ptr = writer2.get(); + // Write data + std::vector sample = {1.0, 2.0, 3.0}; + std::vector diagnostic = {4.0, 5.0, 6.0}; - // Add writers to controller - controller.add_diagnostic_writer(std::move(writer1)); - controller.add_diagnostic_writer(std::move(writer2)); + controller.write("samples", sample); + controller.write("diagnostics", diagnostic); - // Write diagnostic data - std::vector diagnostic = {4.0, 5.0, 6.0}; - controller.write_diagnostic(diagnostic); + // Verify writers + auto sample_writer = controller.get_writer("samples"); + auto diagnostic_writer = controller.get_writer("diagnostics"); - // Check both writers received the data - EXPECT_EQ(writer1_ptr->samples.size(), 1); - EXPECT_EQ(writer2_ptr->samples.size(), 1); - EXPECT_EQ(writer1_ptr->samples[0], diagnostic); - EXPECT_EQ(writer2_ptr->samples[0], diagnostic); + EXPECT_NE(dynamic_cast(sample_writer.get()), nullptr); + EXPECT_NE(dynamic_cast(diagnostic_writer.get()), nullptr); } -TEST(output_controller, empty_writers) { +TEST(output_controller, unconfigured_output) { stan::io::output_controller controller; - // Writing to empty writers should not crash + // Attempt to write without configuration std::vector data = {1.0, 2.0, 3.0}; - controller.write_sample(data); - controller.write_diagnostic(data); - - // Test passes if no crash occurs + EXPECT_THROW(controller.write("samples", data), std::runtime_error); } -TEST(output_controller, multiple_writes) { +TEST(output_controller, invalid_format) { stan::io::output_controller controller; - auto writer = std::make_unique(); - auto* writer_ptr = writer.get(); - - controller.add_sample_writer(std::move(writer)); + // Configure with invalid format + controller.configure_output("samples", + {static_cast(999), "invalid"}); - // Write multiple samples - std::vector sample1 = {1.0, 2.0, 3.0}; - std::vector sample2 = {4.0, 5.0, 6.0}; - - controller.write_sample(sample1); - controller.write_sample(sample2); - - // Check all samples were written - EXPECT_EQ(writer_ptr->samples.size(), 2); - EXPECT_EQ(writer_ptr->samples[0], sample1); - EXPECT_EQ(writer_ptr->samples[1], sample2); + std::vector data = {1.0, 2.0, 3.0}; + EXPECT_THROW(controller.write("samples", data), std::runtime_error); } \ No newline at end of file From 32acdc0f08b2eb8c65a77c9379f485c459cd33e3 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 20:10:43 -0500 Subject: [PATCH 08/53] try again --- src/stan/io/output_controller.hpp | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/stan/io/output_controller.hpp b/src/stan/io/output_controller.hpp index c7cd852457..174bb485fb 100644 --- a/src/stan/io/output_controller.hpp +++ b/src/stan/io/output_controller.hpp @@ -9,6 +9,7 @@ #include #include #include +#include namespace stan { namespace io { @@ -56,15 +57,21 @@ class output_controller { std::unique_ptr create_writer(const OutputConfig& config) { switch (config.format) { - case OutputFormat::CSV: - return std::make_unique(config.path); + case OutputFormat::CSV: { + auto* file = new std::ofstream(config.path); + return std::unique_ptr( + new callbacks::stream_writer(*file)); + } case OutputFormat::MATRIX: if (config.rows == 0 || config.cols == 0) { throw std::runtime_error("Matrix dimensions must be specified"); } - return std::make_unique(config.rows, config.cols); - case OutputFormat::JSON: - return std::make_unique(config.path); + return std::make_unique(config.rows, config.cols); + case OutputFormat::JSON: { + auto* file = new std::ofstream(config.path); + return std::unique_ptr( + new callbacks::json_writer(*file)); + } default: throw std::runtime_error("Unsupported output format"); } From 63194af196a5d5dd812e15fb0d180be0d605b189 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 20:25:44 -0500 Subject: [PATCH 09/53] try again --- src/stan/io/output_controller.hpp | 30 +++++++++- src/test/unit/io/output_controller_test.cpp | 66 +++++++++++++++++++-- 2 files changed, 89 insertions(+), 7 deletions(-) diff --git a/src/stan/io/output_controller.hpp b/src/stan/io/output_controller.hpp index 174bb485fb..f732e663af 100644 --- a/src/stan/io/output_controller.hpp +++ b/src/stan/io/output_controller.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -15,9 +16,10 @@ namespace stan { namespace io { enum class OutputFormat { - CSV, // Plain text CSV files + CSV, // Plain text CSV files for streaming data MATRIX, // In-memory column-major matrix - JSON // JSON files + ARROW, // Apache Arrow format for streaming data + JSON // JSON format for flexible metadata }; struct OutputConfig { @@ -46,12 +48,29 @@ class output_controller { return create_writer(it->second); } - // Write data using appropriate writer + // Write streaming data (samples, diagnostics) - works for CSV, MATRIX, and ARROW void write(const std::string& info_type, const std::vector& data) { auto writer = get_writer(info_type); writer->operator()(data); } + // Write metadata (flexible JSON format) + void write_metadata(const std::string& info_type, + const std::vector& metadata) { + auto writer = get_writer(info_type); + writer->operator()(metadata); + } + + // For backward compatibility with existing JSON writer usage + callbacks::json_writer* get_json_writer(const std::string& info_type) { + auto writer = get_writer(info_type); + auto* json_writer = dynamic_cast(writer.get()); + if (!json_writer) { + throw std::runtime_error("Writer is not a JSON writer"); + } + return json_writer; + } + private: std::map output_configs_; @@ -67,6 +86,11 @@ class output_controller { throw std::runtime_error("Matrix dimensions must be specified"); } return std::make_unique(config.rows, config.cols); + case OutputFormat::ARROW: { + auto* file = new std::ofstream(config.path); + return std::unique_ptr( + new callbacks::structured_writer(*file)); + } case OutputFormat::JSON: { auto* file = new std::ofstream(config.path); return std::unique_ptr( diff --git a/src/test/unit/io/output_controller_test.cpp b/src/test/unit/io/output_controller_test.cpp index 455e5a6ed9..7fdf131083 100644 --- a/src/test/unit/io/output_controller_test.cpp +++ b/src/test/unit/io/output_controller_test.cpp @@ -2,6 +2,8 @@ #include #include #include +#include +#include #include #include #include @@ -10,17 +12,18 @@ class mock_writer : public stan::callbacks::writer { public: std::vector> data; + std::vector> metadata; void operator()(const std::vector& x) override { data.push_back(x); } void operator()(const std::vector& x) override { - // Not used in tests + metadata.push_back(x); } }; -TEST(output_controller, configure_and_write) { +TEST(output_controller, configure_and_write_streaming) { stan::io::output_controller controller; // Configure output for samples @@ -40,6 +43,50 @@ TEST(output_controller, configure_and_write) { EXPECT_EQ(matrix_writer->current_row(), 1); } +TEST(output_controller, write_metadata) { + stan::io::output_controller controller; + + // Configure output for metadata + controller.configure_output("model_info", + {stan::io::OutputFormat::JSON, "model.json"}); + + // Write metadata + std::vector metadata = {"model_name", "bernoulli", "version", "2.29.0"}; + controller.write_metadata("model_info", metadata); + + // Get writer and verify it's a JSON writer + auto writer = controller.get_writer("model_info"); + auto* json_writer = dynamic_cast(writer.get()); + EXPECT_NE(json_writer, nullptr); +} + +TEST(output_controller, get_json_writer) { + stan::io::output_controller controller; + + // Configure JSON writer for metric + controller.configure_output("metric", + {stan::io::OutputFormat::JSON, "metric.json"}); + + // Get JSON writer directly + auto* metric_writer = controller.get_json_writer("metric"); + EXPECT_NE(metric_writer, nullptr); + + // Test using JSON writer interface directly + std::vector metric_data = {"stepsize", "0.1", "metric_type", "dense"}; + metric_writer->operator()(metric_data); +} + +TEST(output_controller, get_json_writer_wrong_type) { + stan::io::output_controller controller; + + // Configure non-JSON writer + controller.configure_output("samples", + {stan::io::OutputFormat::MATRIX, "memory", 100, 3}); + + // Attempt to get JSON writer + EXPECT_THROW(controller.get_json_writer("samples"), std::runtime_error); +} + TEST(output_controller, multiple_formats) { stan::io::output_controller controller; @@ -48,20 +95,27 @@ TEST(output_controller, multiple_formats) { {stan::io::OutputFormat::MATRIX, "memory", 100, 3}); controller.configure_output("diagnostics", {stan::io::OutputFormat::CSV, "diagnostics.csv"}); + controller.configure_output("metric", + {stan::io::OutputFormat::JSON, "metric.json"}); - // Write data + // Write streaming data std::vector sample = {1.0, 2.0, 3.0}; std::vector diagnostic = {4.0, 5.0, 6.0}; - controller.write("samples", sample); controller.write("diagnostics", diagnostic); + // Write metric using JSON writer interface + auto* metric_writer = controller.get_json_writer("metric"); + std::vector metric_data = {"stepsize", "0.1"}; + metric_writer->operator()(metric_data); + // Verify writers auto sample_writer = controller.get_writer("samples"); auto diagnostic_writer = controller.get_writer("diagnostics"); EXPECT_NE(dynamic_cast(sample_writer.get()), nullptr); EXPECT_NE(dynamic_cast(diagnostic_writer.get()), nullptr); + EXPECT_NE(controller.get_json_writer("metric"), nullptr); } TEST(output_controller, unconfigured_output) { @@ -69,7 +123,11 @@ TEST(output_controller, unconfigured_output) { // Attempt to write without configuration std::vector data = {1.0, 2.0, 3.0}; + std::vector metadata = {"model_name", "bernoulli"}; + EXPECT_THROW(controller.write("samples", data), std::runtime_error); + EXPECT_THROW(controller.write_metadata("model_info", metadata), std::runtime_error); + EXPECT_THROW(controller.get_json_writer("metric"), std::runtime_error); } TEST(output_controller, invalid_format) { From c80c2ed615484d36ec57eb2284c1420ee0b8bcf5 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 21:22:41 -0500 Subject: [PATCH 10/53] try again --- src/stan/io/output_controller.hpp | 50 ++++++++++++++------- src/test/unit/io/output_controller_test.cpp | 37 ++++++++++++--- 2 files changed, 66 insertions(+), 21 deletions(-) diff --git a/src/stan/io/output_controller.hpp b/src/stan/io/output_controller.hpp index f732e663af..87f8cb3fcf 100644 --- a/src/stan/io/output_controller.hpp +++ b/src/stan/io/output_controller.hpp @@ -61,43 +61,61 @@ class output_controller { writer->operator()(metadata); } - // For backward compatibility with existing JSON writer usage - callbacks::json_writer* get_json_writer(const std::string& info_type) { + // Get structured writer for backward compatibility with existing code + template> + callbacks::structured_writer* get_structured_writer(const std::string& info_type) { auto writer = get_writer(info_type); - auto* json_writer = dynamic_cast(writer.get()); + auto* structured_writer = dynamic_cast(writer.get()); + if (!structured_writer) { + throw std::runtime_error("Writer for " + info_type + " is not a structured writer"); + } + return structured_writer; + } + + // Get JSON writer for backward compatibility with existing code + template> + callbacks::json_writer* get_json_writer(const std::string& info_type) { + auto writer = get_writer(info_type); + auto* json_writer = dynamic_cast*>(writer.get()); if (!json_writer) { - throw std::runtime_error("Writer is not a JSON writer"); + throw std::runtime_error("Writer for " + info_type + " is not a JSON writer"); } return json_writer; } + // Get Arrow writer for backward compatibility with existing code + template> + callbacks::arrow_writer* get_arrow_writer(const std::string& info_type) { + auto writer = get_writer(info_type); + auto* arrow_writer = dynamic_cast*>(writer.get()); + if (!arrow_writer) { + throw std::runtime_error("Writer for " + info_type + " is not an Arrow writer"); + } + return arrow_writer; + } + private: std::map output_configs_; std::unique_ptr create_writer(const OutputConfig& config) { switch (config.format) { - case OutputFormat::CSV: { - auto* file = new std::ofstream(config.path); - return std::unique_ptr( - new callbacks::stream_writer(*file)); - } + case OutputFormat::CSV: + return std::make_unique(config.path); case OutputFormat::MATRIX: if (config.rows == 0 || config.cols == 0) { throw std::runtime_error("Matrix dimensions must be specified"); } return std::make_unique(config.rows, config.cols); case OutputFormat::ARROW: { - auto* file = new std::ofstream(config.path); - return std::unique_ptr( - new callbacks::structured_writer(*file)); + auto file = std::make_unique(config.path); + return std::make_unique>(std::move(file)); } case OutputFormat::JSON: { - auto* file = new std::ofstream(config.path); - return std::unique_ptr( - new callbacks::json_writer(*file)); + auto file = std::make_unique(config.path); + return std::make_unique>(std::move(file)); } default: - throw std::runtime_error("Unsupported output format"); + throw std::runtime_error("Invalid output format"); } } }; diff --git a/src/test/unit/io/output_controller_test.cpp b/src/test/unit/io/output_controller_test.cpp index 7fdf131083..400a42c052 100644 --- a/src/test/unit/io/output_controller_test.cpp +++ b/src/test/unit/io/output_controller_test.cpp @@ -87,12 +87,39 @@ TEST(output_controller, get_json_writer_wrong_type) { EXPECT_THROW(controller.get_json_writer("samples"), std::runtime_error); } -TEST(output_controller, multiple_formats) { +TEST(output_controller, get_arrow_writer) { stan::io::output_controller controller; - // Configure different formats for different information types + // Configure Arrow writer for samples + controller.configure_output("samples", + {stan::io::OutputFormat::ARROW, "samples.arrow"}); + + // Get Arrow writer directly + auto* arrow_writer = controller.get_arrow_writer("samples"); + EXPECT_NE(arrow_writer, nullptr); + + // Test using Arrow writer interface directly + std::vector sample = {1.0, 2.0, 3.0}; + arrow_writer->operator()(sample); +} + +TEST(output_controller, get_arrow_writer_wrong_type) { + stan::io::output_controller controller; + + // Configure non-Arrow writer controller.configure_output("samples", {stan::io::OutputFormat::MATRIX, "memory", 100, 3}); + + // Attempt to get Arrow writer + EXPECT_THROW(controller.get_arrow_writer("samples"), std::runtime_error); +} + +TEST(output_controller, multiple_formats_with_arrow) { + stan::io::output_controller controller; + + // Configure different formats for different information types + controller.configure_output("samples", + {stan::io::OutputFormat::ARROW, "samples.arrow"}); controller.configure_output("diagnostics", {stan::io::OutputFormat::CSV, "diagnostics.csv"}); controller.configure_output("metric", @@ -105,7 +132,7 @@ TEST(output_controller, multiple_formats) { controller.write("diagnostics", diagnostic); // Write metric using JSON writer interface - auto* metric_writer = controller.get_json_writer("metric"); + auto* metric_writer = controller.get_json_writer("metric"); std::vector metric_data = {"stepsize", "0.1"}; metric_writer->operator()(metric_data); @@ -113,9 +140,9 @@ TEST(output_controller, multiple_formats) { auto sample_writer = controller.get_writer("samples"); auto diagnostic_writer = controller.get_writer("diagnostics"); - EXPECT_NE(dynamic_cast(sample_writer.get()), nullptr); + EXPECT_NE(dynamic_cast*>(sample_writer.get()), nullptr); EXPECT_NE(dynamic_cast(diagnostic_writer.get()), nullptr); - EXPECT_NE(controller.get_json_writer("metric"), nullptr); + EXPECT_NE(controller.get_json_writer("metric"), nullptr); } TEST(output_controller, unconfigured_output) { From eb6d343a7703b9536e01493df801c4fa7eaa7d5f Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 21:27:53 -0500 Subject: [PATCH 11/53] try again --- src/stan/io/output_controller.hpp | 112 ++++++++------------ src/test/unit/io/output_controller_test.cpp | 46 ++------ 2 files changed, 52 insertions(+), 106 deletions(-) diff --git a/src/stan/io/output_controller.hpp b/src/stan/io/output_controller.hpp index 87f8cb3fcf..55b5cce6d3 100644 --- a/src/stan/io/output_controller.hpp +++ b/src/stan/io/output_controller.hpp @@ -7,9 +7,8 @@ #include #include #include -#include #include -#include +#include #include namespace stan { @@ -22,54 +21,53 @@ enum class OutputFormat { JSON // JSON format for flexible metadata }; +struct OutputDims { + size_t rows; + size_t cols; +}; + struct OutputConfig { OutputFormat format; - std::string path; // File path or identifier for the output - size_t rows{0}; // For matrix format - size_t cols{0}; // For matrix format + std::string file_path; + OutputDims dims; }; class output_controller { - public: - output_controller() = default; - - // Configure output format for a specific information type - void configure_output(const std::string& info_type, OutputConfig config) { - output_configs_[info_type] = config; - } - - // Get appropriate writer for information type - std::unique_ptr get_writer(const std::string& info_type) { - auto it = output_configs_.find(info_type); - if (it == output_configs_.end()) { - throw std::runtime_error("No output configuration for " + info_type); + private: + std::unordered_map> writers_; + + std::shared_ptr create_writer(const OutputConfig& config) { + switch (config.format) { + case OutputFormat::CSV: { + auto file = std::make_unique(config.file_path); + return std::make_shared(*file); + } + case OutputFormat::MATRIX: + return std::make_shared( + config.dims.rows, config.dims.cols); + case OutputFormat::ARROW: + // TODO: Implement Arrow writer + throw std::runtime_error("Arrow writer not yet implemented"); + case OutputFormat::JSON: { + auto file = std::make_unique(config.file_path); + return std::make_shared>(std::move(file)); + } + default: + throw std::runtime_error("Invalid output format"); } - - return create_writer(it->second); - } - - // Write streaming data (samples, diagnostics) - works for CSV, MATRIX, and ARROW - void write(const std::string& info_type, const std::vector& data) { - auto writer = get_writer(info_type); - writer->operator()(data); } - // Write metadata (flexible JSON format) - void write_metadata(const std::string& info_type, - const std::vector& metadata) { - auto writer = get_writer(info_type); - writer->operator()(metadata); + public: + void configure_output(const std::string& info_type, const OutputConfig& config) { + writers_[info_type] = create_writer(config); } - // Get structured writer for backward compatibility with existing code - template> - callbacks::structured_writer* get_structured_writer(const std::string& info_type) { - auto writer = get_writer(info_type); - auto* structured_writer = dynamic_cast(writer.get()); - if (!structured_writer) { - throw std::runtime_error("Writer for " + info_type + " is not a structured writer"); + std::shared_ptr get_writer(const std::string& info_type) { + auto it = writers_.find(info_type); + if (it == writers_.end()) { + throw std::runtime_error("No writer configured for " + info_type); } - return structured_writer; + return it->second; } // Get JSON writer for backward compatibility with existing code @@ -83,40 +81,14 @@ class output_controller { return json_writer; } - // Get Arrow writer for backward compatibility with existing code - template> - callbacks::arrow_writer* get_arrow_writer(const std::string& info_type) { + void write(const std::string& info_type, const std::vector& data) { auto writer = get_writer(info_type); - auto* arrow_writer = dynamic_cast*>(writer.get()); - if (!arrow_writer) { - throw std::runtime_error("Writer for " + info_type + " is not an Arrow writer"); - } - return arrow_writer; + writer->operator()(data); } - private: - std::map output_configs_; - - std::unique_ptr create_writer(const OutputConfig& config) { - switch (config.format) { - case OutputFormat::CSV: - return std::make_unique(config.path); - case OutputFormat::MATRIX: - if (config.rows == 0 || config.cols == 0) { - throw std::runtime_error("Matrix dimensions must be specified"); - } - return std::make_unique(config.rows, config.cols); - case OutputFormat::ARROW: { - auto file = std::make_unique(config.path); - return std::make_unique>(std::move(file)); - } - case OutputFormat::JSON: { - auto file = std::make_unique(config.path); - return std::make_unique>(std::move(file)); - } - default: - throw std::runtime_error("Invalid output format"); - } + void write_metadata(const std::string& info_type, const std::vector& metadata) { + auto writer = get_writer(info_type); + writer->operator()(metadata); } }; diff --git a/src/test/unit/io/output_controller_test.cpp b/src/test/unit/io/output_controller_test.cpp index 400a42c052..0b6313c5cd 100644 --- a/src/test/unit/io/output_controller_test.cpp +++ b/src/test/unit/io/output_controller_test.cpp @@ -7,6 +7,7 @@ #include #include #include +#include // Mock writer for testing class mock_writer : public stan::callbacks::writer { @@ -56,7 +57,7 @@ TEST(output_controller, write_metadata) { // Get writer and verify it's a JSON writer auto writer = controller.get_writer("model_info"); - auto* json_writer = dynamic_cast(writer.get()); + auto* json_writer = dynamic_cast*>(writer.get()); EXPECT_NE(json_writer, nullptr); } @@ -68,12 +69,12 @@ TEST(output_controller, get_json_writer) { {stan::io::OutputFormat::JSON, "metric.json"}); // Get JSON writer directly - auto* metric_writer = controller.get_json_writer("metric"); + auto* metric_writer = controller.get_json_writer("metric"); EXPECT_NE(metric_writer, nullptr); // Test using JSON writer interface directly std::vector metric_data = {"stepsize", "0.1", "metric_type", "dense"}; - metric_writer->operator()(metric_data); + metric_writer->write(metric_data); } TEST(output_controller, get_json_writer_wrong_type) { @@ -84,42 +85,15 @@ TEST(output_controller, get_json_writer_wrong_type) { {stan::io::OutputFormat::MATRIX, "memory", 100, 3}); // Attempt to get JSON writer - EXPECT_THROW(controller.get_json_writer("samples"), std::runtime_error); -} - -TEST(output_controller, get_arrow_writer) { - stan::io::output_controller controller; - - // Configure Arrow writer for samples - controller.configure_output("samples", - {stan::io::OutputFormat::ARROW, "samples.arrow"}); - - // Get Arrow writer directly - auto* arrow_writer = controller.get_arrow_writer("samples"); - EXPECT_NE(arrow_writer, nullptr); - - // Test using Arrow writer interface directly - std::vector sample = {1.0, 2.0, 3.0}; - arrow_writer->operator()(sample); + EXPECT_THROW(controller.get_json_writer("samples"), std::runtime_error); } -TEST(output_controller, get_arrow_writer_wrong_type) { - stan::io::output_controller controller; - - // Configure non-Arrow writer - controller.configure_output("samples", - {stan::io::OutputFormat::MATRIX, "memory", 100, 3}); - - // Attempt to get Arrow writer - EXPECT_THROW(controller.get_arrow_writer("samples"), std::runtime_error); -} - -TEST(output_controller, multiple_formats_with_arrow) { +TEST(output_controller, multiple_formats) { stan::io::output_controller controller; // Configure different formats for different information types controller.configure_output("samples", - {stan::io::OutputFormat::ARROW, "samples.arrow"}); + {stan::io::OutputFormat::MATRIX, "memory", 100, 3}); controller.configure_output("diagnostics", {stan::io::OutputFormat::CSV, "diagnostics.csv"}); controller.configure_output("metric", @@ -134,13 +108,13 @@ TEST(output_controller, multiple_formats_with_arrow) { // Write metric using JSON writer interface auto* metric_writer = controller.get_json_writer("metric"); std::vector metric_data = {"stepsize", "0.1"}; - metric_writer->operator()(metric_data); + metric_writer->write(metric_data); // Verify writers auto sample_writer = controller.get_writer("samples"); auto diagnostic_writer = controller.get_writer("diagnostics"); - EXPECT_NE(dynamic_cast*>(sample_writer.get()), nullptr); + EXPECT_NE(dynamic_cast(sample_writer.get()), nullptr); EXPECT_NE(dynamic_cast(diagnostic_writer.get()), nullptr); EXPECT_NE(controller.get_json_writer("metric"), nullptr); } @@ -154,7 +128,7 @@ TEST(output_controller, unconfigured_output) { EXPECT_THROW(controller.write("samples", data), std::runtime_error); EXPECT_THROW(controller.write_metadata("model_info", metadata), std::runtime_error); - EXPECT_THROW(controller.get_json_writer("metric"), std::runtime_error); + EXPECT_THROW(controller.get_json_writer("metric"), std::runtime_error); } TEST(output_controller, invalid_format) { From d7579c6e62cd0c3150a5d6be70043f212311129a Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 21:32:01 -0500 Subject: [PATCH 12/53] try again --- src/stan/io/output_controller.hpp | 3 ++- src/test/unit/io/output_controller_test.cpp | 9 ++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/stan/io/output_controller.hpp b/src/stan/io/output_controller.hpp index 55b5cce6d3..cba427e88f 100644 --- a/src/stan/io/output_controller.hpp +++ b/src/stan/io/output_controller.hpp @@ -50,7 +50,8 @@ class output_controller { throw std::runtime_error("Arrow writer not yet implemented"); case OutputFormat::JSON: { auto file = std::make_unique(config.file_path); - return std::make_shared>(std::move(file)); + return std::make_shared>>( + std::move(file)); } default: throw std::runtime_error("Invalid output format"); diff --git a/src/test/unit/io/output_controller_test.cpp b/src/test/unit/io/output_controller_test.cpp index 0b6313c5cd..89be3121c8 100644 --- a/src/test/unit/io/output_controller_test.cpp +++ b/src/test/unit/io/output_controller_test.cpp @@ -57,7 +57,7 @@ TEST(output_controller, write_metadata) { // Get writer and verify it's a JSON writer auto writer = controller.get_writer("model_info"); - auto* json_writer = dynamic_cast*>(writer.get()); + auto* json_writer = dynamic_cast>*>(writer.get()); EXPECT_NE(json_writer, nullptr); } @@ -73,8 +73,8 @@ TEST(output_controller, get_json_writer) { EXPECT_NE(metric_writer, nullptr); // Test using JSON writer interface directly - std::vector metric_data = {"stepsize", "0.1", "metric_type", "dense"}; - metric_writer->write(metric_data); + metric_writer->write("stepsize", "0.1"); + metric_writer->write("metric_type", "dense"); } TEST(output_controller, get_json_writer_wrong_type) { @@ -107,8 +107,7 @@ TEST(output_controller, multiple_formats) { // Write metric using JSON writer interface auto* metric_writer = controller.get_json_writer("metric"); - std::vector metric_data = {"stepsize", "0.1"}; - metric_writer->write(metric_data); + metric_writer->write("stepsize", "0.1"); // Verify writers auto sample_writer = controller.get_writer("samples"); From 8742b7d24a6e1a6e7400463b00458b39ecb10802 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 21:34:11 -0500 Subject: [PATCH 13/53] try again --- src/stan/io/output_controller.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/stan/io/output_controller.hpp b/src/stan/io/output_controller.hpp index cba427e88f..4d842afb70 100644 --- a/src/stan/io/output_controller.hpp +++ b/src/stan/io/output_controller.hpp @@ -50,8 +50,9 @@ class output_controller { throw std::runtime_error("Arrow writer not yet implemented"); case OutputFormat::JSON: { auto file = std::make_unique(config.file_path); - return std::make_shared>>( + auto json_writer = std::make_shared>>( std::move(file)); + return std::static_pointer_cast(json_writer); } default: throw std::runtime_error("Invalid output format"); From b2297c2b12b24aec55c08fcd5b6f00741ffeda04 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 21:35:23 -0500 Subject: [PATCH 14/53] try again --- src/stan/io/output_controller.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stan/io/output_controller.hpp b/src/stan/io/output_controller.hpp index 4d842afb70..ce1ff6fa40 100644 --- a/src/stan/io/output_controller.hpp +++ b/src/stan/io/output_controller.hpp @@ -52,7 +52,7 @@ class output_controller { auto file = std::make_unique(config.file_path); auto json_writer = std::make_shared>>( std::move(file)); - return std::static_pointer_cast(json_writer); + return std::dynamic_pointer_cast(json_writer); } default: throw std::runtime_error("Invalid output format"); From 11024d9d1cc0e10a835237e7cd17dfca26745d91 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 21:37:48 -0500 Subject: [PATCH 15/53] try again --- src/stan/io/output_controller.hpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/stan/io/output_controller.hpp b/src/stan/io/output_controller.hpp index ce1ff6fa40..dae188e82a 100644 --- a/src/stan/io/output_controller.hpp +++ b/src/stan/io/output_controller.hpp @@ -35,11 +35,13 @@ struct OutputConfig { class output_controller { private: std::unordered_map> writers_; + std::unordered_map> files_; // Keep files alive std::shared_ptr create_writer(const OutputConfig& config) { switch (config.format) { case OutputFormat::CSV: { - auto file = std::make_unique(config.file_path); + auto file = std::make_shared(config.file_path); + files_[config.file_path] = file; // Store file return std::make_shared(*file); } case OutputFormat::MATRIX: @@ -49,10 +51,11 @@ class output_controller { // TODO: Implement Arrow writer throw std::runtime_error("Arrow writer not yet implemented"); case OutputFormat::JSON: { - auto file = std::make_unique(config.file_path); + auto file = std::make_shared(config.file_path); + files_[config.file_path] = file; // Store file auto json_writer = std::make_shared>>( - std::move(file)); - return std::dynamic_pointer_cast(json_writer); + file.get()); + return json_writer; // No need for dynamic_cast since json_writer inherits from writer } default: throw std::runtime_error("Invalid output format"); From 37000c4537a434c7f7152dd44d15ec7873b370cd Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 21:40:12 -0500 Subject: [PATCH 16/53] try again --- src/stan/io/output_controller.hpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/stan/io/output_controller.hpp b/src/stan/io/output_controller.hpp index dae188e82a..c8452d1d6c 100644 --- a/src/stan/io/output_controller.hpp +++ b/src/stan/io/output_controller.hpp @@ -53,9 +53,10 @@ class output_controller { case OutputFormat::JSON: { auto file = std::make_shared(config.file_path); files_[config.file_path] = file; // Store file + auto file_ptr = std::make_unique(config.file_path); auto json_writer = std::make_shared>>( - file.get()); - return json_writer; // No need for dynamic_cast since json_writer inherits from writer + std::move(file_ptr)); + return std::dynamic_pointer_cast(json_writer); } default: throw std::runtime_error("Invalid output format"); From b5c858a5c911e385762ca53b8541b48fe93bcec8 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 21:50:30 -0500 Subject: [PATCH 17/53] fix namespaces --- src/stan/{io => callbacks}/output_controller.hpp | 0 src/test/unit/{io => callbacks}/output_controller_test.cpp | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename src/stan/{io => callbacks}/output_controller.hpp (100%) rename src/test/unit/{io => callbacks}/output_controller_test.cpp (100%) diff --git a/src/stan/io/output_controller.hpp b/src/stan/callbacks/output_controller.hpp similarity index 100% rename from src/stan/io/output_controller.hpp rename to src/stan/callbacks/output_controller.hpp diff --git a/src/test/unit/io/output_controller_test.cpp b/src/test/unit/callbacks/output_controller_test.cpp similarity index 100% rename from src/test/unit/io/output_controller_test.cpp rename to src/test/unit/callbacks/output_controller_test.cpp From ff2f43e3d5117ba04a6ac9df32895f25c9ebfe5a Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 21:56:56 -0500 Subject: [PATCH 18/53] namespace fix --- src/stan/callbacks/output_controller.hpp | 24 +++++++++---------- .../unit/callbacks/output_controller_test.cpp | 20 ++++++++-------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/stan/callbacks/output_controller.hpp b/src/stan/callbacks/output_controller.hpp index c8452d1d6c..f889a75242 100644 --- a/src/stan/callbacks/output_controller.hpp +++ b/src/stan/callbacks/output_controller.hpp @@ -12,7 +12,7 @@ #include namespace stan { -namespace io { +namespace callbacks { enum class OutputFormat { CSV, // Plain text CSV files for streaming data @@ -34,18 +34,18 @@ struct OutputConfig { class output_controller { private: - std::unordered_map> writers_; + std::unordered_map> writers_; std::unordered_map> files_; // Keep files alive - std::shared_ptr create_writer(const OutputConfig& config) { + std::shared_ptr create_writer(const OutputConfig& config) { switch (config.format) { case OutputFormat::CSV: { auto file = std::make_shared(config.file_path); files_[config.file_path] = file; // Store file - return std::make_shared(*file); + return std::make_shared(*file); } case OutputFormat::MATRIX: - return std::make_shared( + return std::make_shared( config.dims.rows, config.dims.cols); case OutputFormat::ARROW: // TODO: Implement Arrow writer @@ -54,9 +54,9 @@ class output_controller { auto file = std::make_shared(config.file_path); files_[config.file_path] = file; // Store file auto file_ptr = std::make_unique(config.file_path); - auto json_writer = std::make_shared>>( + auto json_writer = std::make_shared>>( std::move(file_ptr)); - return std::dynamic_pointer_cast(json_writer); + return std::dynamic_pointer_cast(json_writer); } default: throw std::runtime_error("Invalid output format"); @@ -68,7 +68,7 @@ class output_controller { writers_[info_type] = create_writer(config); } - std::shared_ptr get_writer(const std::string& info_type) { + std::shared_ptr get_writer(const std::string& info_type) { auto it = writers_.find(info_type); if (it == writers_.end()) { throw std::runtime_error("No writer configured for " + info_type); @@ -78,9 +78,9 @@ class output_controller { // Get JSON writer for backward compatibility with existing code template> - callbacks::json_writer* get_json_writer(const std::string& info_type) { + json_writer* get_json_writer(const std::string& info_type) { auto writer = get_writer(info_type); - auto* json_writer = dynamic_cast*>(writer.get()); + auto* json_writer = dynamic_cast*>(writer.get()); if (!json_writer) { throw std::runtime_error("Writer for " + info_type + " is not a JSON writer"); } @@ -98,7 +98,7 @@ class output_controller { } }; -} // namespace io +} // namespace callbacks } // namespace stan -#endif \ No newline at end of file +#endif diff --git a/src/test/unit/callbacks/output_controller_test.cpp b/src/test/unit/callbacks/output_controller_test.cpp index 89be3121c8..f80f9fe104 100644 --- a/src/test/unit/callbacks/output_controller_test.cpp +++ b/src/test/unit/callbacks/output_controller_test.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include @@ -29,7 +29,7 @@ TEST(output_controller, configure_and_write_streaming) { // Configure output for samples controller.configure_output("samples", - {stan::io::OutputFormat::MATRIX, "memory", 100, 3}); + {stan::callbacks::OutputFormat::MATRIX, "memory", 100, 3}); // Write sample data std::vector sample = {1.0, 2.0, 3.0}; @@ -49,7 +49,7 @@ TEST(output_controller, write_metadata) { // Configure output for metadata controller.configure_output("model_info", - {stan::io::OutputFormat::JSON, "model.json"}); + {stan::callbacks::OutputFormat::JSON, "model.json"}); // Write metadata std::vector metadata = {"model_name", "bernoulli", "version", "2.29.0"}; @@ -66,7 +66,7 @@ TEST(output_controller, get_json_writer) { // Configure JSON writer for metric controller.configure_output("metric", - {stan::io::OutputFormat::JSON, "metric.json"}); + {stan::callbacks::OutputFormat::JSON, "metric.json"}); // Get JSON writer directly auto* metric_writer = controller.get_json_writer("metric"); @@ -82,7 +82,7 @@ TEST(output_controller, get_json_writer_wrong_type) { // Configure non-JSON writer controller.configure_output("samples", - {stan::io::OutputFormat::MATRIX, "memory", 100, 3}); + {stan::callbacks::OutputFormat::MATRIX, "memory", 100, 3}); // Attempt to get JSON writer EXPECT_THROW(controller.get_json_writer("samples"), std::runtime_error); @@ -93,11 +93,11 @@ TEST(output_controller, multiple_formats) { // Configure different formats for different information types controller.configure_output("samples", - {stan::io::OutputFormat::MATRIX, "memory", 100, 3}); + {stan::callbacks::OutputFormat::MATRIX, "memory", 100, 3}); controller.configure_output("diagnostics", - {stan::io::OutputFormat::CSV, "diagnostics.csv"}); + {stan::callbacks::OutputFormat::CSV, "diagnostics.csv"}); controller.configure_output("metric", - {stan::io::OutputFormat::JSON, "metric.json"}); + {stan::callbacks::OutputFormat::JSON, "metric.json"}); // Write streaming data std::vector sample = {1.0, 2.0, 3.0}; @@ -135,8 +135,8 @@ TEST(output_controller, invalid_format) { // Configure with invalid format controller.configure_output("samples", - {static_cast(999), "invalid"}); + {static_cast(999), "invalid"}); std::vector data = {1.0, 2.0, 3.0}; EXPECT_THROW(controller.write("samples", data), std::runtime_error); -} \ No newline at end of file +} From dc90f4f6e261bbc98bcd80cd04ecd99bec5ae102 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 22:07:59 -0500 Subject: [PATCH 19/53] try again --- src/stan/io/output_controller.hpp | 99 +++++++++++++++++++ .../unit/callbacks/output_controller_test.cpp | 14 +-- 2 files changed, 106 insertions(+), 7 deletions(-) create mode 100644 src/stan/io/output_controller.hpp diff --git a/src/stan/io/output_controller.hpp b/src/stan/io/output_controller.hpp new file mode 100644 index 0000000000..edc37b1bfc --- /dev/null +++ b/src/stan/io/output_controller.hpp @@ -0,0 +1,99 @@ +#ifndef STAN_IO_OUTPUT_CONTROLLER_HPP +#define STAN_IO_OUTPUT_CONTROLLER_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace callbacks { + +enum class OutputFormat { + CSV, // Plain text CSV files for streaming data + MATRIX, // In-memory column-major matrix + ARROW, // Apache Arrow format for streaming data + JSON // JSON format for flexible metadata +}; + +struct OutputDims { + size_t rows; + size_t cols; +}; + +struct OutputConfig { + OutputFormat format; + std::string file_path; + OutputDims dims; +}; + +class output_controller { + private: + std::unordered_map> writers_; + std::unordered_map> files_; // Keep files alive + + std::shared_ptr create_writer(const OutputConfig& config) { + switch (config.format) { + case OutputFormat::CSV: { + auto file = std::make_shared(config.file_path); + files_[config.file_path] = file; // Store file + return std::make_shared(*file); + } + case OutputFormat::MATRIX: + return std::make_shared( + config.dims.rows, config.dims.cols); + case OutputFormat::ARROW: + // TODO: Implement Arrow writer + throw std::runtime_error("Arrow writer not yet implemented"); + case OutputFormat::JSON: { + auto file = std::make_shared(config.file_path); + files_[config.file_path] = file; // Store file + auto file_ptr = std::make_unique(config.file_path); + auto json_writer = std::make_shared>>( + std::move(file_ptr)); + return std::dynamic_pointer_cast(json_writer); + } + default: + throw std::runtime_error("Invalid output format"); + } + } + + public: + void configure_output(const std::string& info_type, const OutputConfig& config) { + writers_[info_type] = create_writer(config); + } + + std::shared_ptr get_writer(const std::string& info_type) { + auto it = writers_.find(info_type); + if (it == writers_.end()) { + throw std::runtime_error("No writer configured for " + info_type); + } + return it->second; + } + + // Forward all JSON writer methods to the underlying writer + template> + json_writer* get_json_writer(const std::string& info_type) { + auto writer = get_writer(info_type); + auto* json_writer = dynamic_cast*>(writer.get()); + if (!json_writer) { + throw std::runtime_error("Writer for " + info_type + " is not a JSON writer"); + } + return json_writer; + } + + void write(const std::string& info_type, const std::vector& data) { + auto writer = get_writer(info_type); + writer->operator()(data); + } +}; + +} // namespace callbacks +} // namespace stan + +#endif \ No newline at end of file diff --git a/src/test/unit/callbacks/output_controller_test.cpp b/src/test/unit/callbacks/output_controller_test.cpp index f80f9fe104..33df43c7e2 100644 --- a/src/test/unit/callbacks/output_controller_test.cpp +++ b/src/test/unit/callbacks/output_controller_test.cpp @@ -25,7 +25,7 @@ class mock_writer : public stan::callbacks::writer { }; TEST(output_controller, configure_and_write_streaming) { - stan::io::output_controller controller; + stan::callbacks::output_controller controller; // Configure output for samples controller.configure_output("samples", @@ -45,7 +45,7 @@ TEST(output_controller, configure_and_write_streaming) { } TEST(output_controller, write_metadata) { - stan::io::output_controller controller; + stan::callbacks::output_controller controller; // Configure output for metadata controller.configure_output("model_info", @@ -62,7 +62,7 @@ TEST(output_controller, write_metadata) { } TEST(output_controller, get_json_writer) { - stan::io::output_controller controller; + stan::callbacks::output_controller controller; // Configure JSON writer for metric controller.configure_output("metric", @@ -78,7 +78,7 @@ TEST(output_controller, get_json_writer) { } TEST(output_controller, get_json_writer_wrong_type) { - stan::io::output_controller controller; + stan::callbacks::output_controller controller; // Configure non-JSON writer controller.configure_output("samples", @@ -89,7 +89,7 @@ TEST(output_controller, get_json_writer_wrong_type) { } TEST(output_controller, multiple_formats) { - stan::io::output_controller controller; + stan::callbacks::output_controller controller; // Configure different formats for different information types controller.configure_output("samples", @@ -119,7 +119,7 @@ TEST(output_controller, multiple_formats) { } TEST(output_controller, unconfigured_output) { - stan::io::output_controller controller; + stan::callbacks::output_controller controller; // Attempt to write without configuration std::vector data = {1.0, 2.0, 3.0}; @@ -131,7 +131,7 @@ TEST(output_controller, unconfigured_output) { } TEST(output_controller, invalid_format) { - stan::io::output_controller controller; + stan::callbacks::output_controller controller; // Configure with invalid format controller.configure_output("samples", From ff72b3b0404e78f95dd9d1c5b8a9a4b27841281d Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 22:11:06 -0500 Subject: [PATCH 20/53] try again --- .../unit/callbacks/output_controller_test.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/test/unit/callbacks/output_controller_test.cpp b/src/test/unit/callbacks/output_controller_test.cpp index 33df43c7e2..c85ec17309 100644 --- a/src/test/unit/callbacks/output_controller_test.cpp +++ b/src/test/unit/callbacks/output_controller_test.cpp @@ -51,14 +51,15 @@ TEST(output_controller, write_metadata) { controller.configure_output("model_info", {stan::callbacks::OutputFormat::JSON, "model.json"}); - // Write metadata - std::vector metadata = {"model_name", "bernoulli", "version", "2.29.0"}; - controller.write_metadata("model_info", metadata); - - // Get writer and verify it's a JSON writer - auto writer = controller.get_writer("model_info"); - auto* json_writer = dynamic_cast>*>(writer.get()); + // Get JSON writer + auto* json_writer = controller.get_json_writer("model_info"); EXPECT_NE(json_writer, nullptr); + + // Write metadata as a properly structured JSON record + json_writer->begin_record(); + json_writer->write("model_name", "bernoulli"); + json_writer->write("version", "2.29.0"); + json_writer->end_record(); } TEST(output_controller, get_json_writer) { From 1dfa9143ba3544a4d7219cc36d07e3e2cf73fa47 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 22:14:49 -0500 Subject: [PATCH 21/53] TRY AGAIN --- src/stan/callbacks/output_controller.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/stan/callbacks/output_controller.hpp b/src/stan/callbacks/output_controller.hpp index f889a75242..2324348143 100644 --- a/src/stan/callbacks/output_controller.hpp +++ b/src/stan/callbacks/output_controller.hpp @@ -1,5 +1,5 @@ -#ifndef STAN_IO_OUTPUT_CONTROLLER_HPP -#define STAN_IO_OUTPUT_CONTROLLER_HPP +#ifndef STAN_CALLBACKS_OUTPUT_CONTROLLER_HPP +#define STAN_CALLBACKS_OUTPUT_CONTROLLER_HPP #include #include @@ -54,9 +54,9 @@ class output_controller { auto file = std::make_shared(config.file_path); files_[config.file_path] = file; // Store file auto file_ptr = std::make_unique(config.file_path); - auto json_writer = std::make_shared>>( + auto writer = std::make_shared>>( std::move(file_ptr)); - return std::dynamic_pointer_cast(json_writer); + return std::dynamic_pointer_cast(writer); } default: throw std::runtime_error("Invalid output format"); @@ -76,7 +76,7 @@ class output_controller { return it->second; } - // Get JSON writer for backward compatibility with existing code + // Forward all JSON writer methods to the underlying writer template> json_writer* get_json_writer(const std::string& info_type) { auto writer = get_writer(info_type); From ef0fa162f6427602e91f5be0ebd0216c6c567e9b Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 22:16:26 -0500 Subject: [PATCH 22/53] try again --- src/stan/callbacks/output_controller.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/stan/callbacks/output_controller.hpp b/src/stan/callbacks/output_controller.hpp index 2324348143..e4660a2732 100644 --- a/src/stan/callbacks/output_controller.hpp +++ b/src/stan/callbacks/output_controller.hpp @@ -54,9 +54,8 @@ class output_controller { auto file = std::make_shared(config.file_path); files_[config.file_path] = file; // Store file auto file_ptr = std::make_unique(config.file_path); - auto writer = std::make_shared>>( + return std::make_shared>>( std::move(file_ptr)); - return std::dynamic_pointer_cast(writer); } default: throw std::runtime_error("Invalid output format"); From 21fd3e8ee8ac1bd413470a2283aacc9f3157295e Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 22:18:14 -0500 Subject: [PATCH 23/53] try again --- src/stan/callbacks/output_controller.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/stan/callbacks/output_controller.hpp b/src/stan/callbacks/output_controller.hpp index e4660a2732..65de344d98 100644 --- a/src/stan/callbacks/output_controller.hpp +++ b/src/stan/callbacks/output_controller.hpp @@ -54,8 +54,9 @@ class output_controller { auto file = std::make_shared(config.file_path); files_[config.file_path] = file; // Store file auto file_ptr = std::make_unique(config.file_path); - return std::make_shared>>( + auto json_writer = std::make_shared>>( std::move(file_ptr)); + return std::static_pointer_cast(json_writer); } default: throw std::runtime_error("Invalid output format"); From 944c0a89251fd6555196dc874470a9c06654eefd Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 22:20:56 -0500 Subject: [PATCH 24/53] try again --- src/stan/callbacks/output_controller.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stan/callbacks/output_controller.hpp b/src/stan/callbacks/output_controller.hpp index 65de344d98..60995b2a84 100644 --- a/src/stan/callbacks/output_controller.hpp +++ b/src/stan/callbacks/output_controller.hpp @@ -54,9 +54,9 @@ class output_controller { auto file = std::make_shared(config.file_path); files_[config.file_path] = file; // Store file auto file_ptr = std::make_unique(config.file_path); - auto json_writer = std::make_shared>>( + auto jwriter = std::make_shared>>( std::move(file_ptr)); - return std::static_pointer_cast(json_writer); + return std::static_pointer_cast(jwriter); } default: throw std::runtime_error("Invalid output format"); From fb7966264f326b6a028c23b8e2f444e8f55dc458 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 22:30:18 -0500 Subject: [PATCH 25/53] try again --- src/stan/callbacks/output_controller.hpp | 18 +++++------------- .../unit/callbacks/output_controller_test.cpp | 3 +-- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/src/stan/callbacks/output_controller.hpp b/src/stan/callbacks/output_controller.hpp index 60995b2a84..1dbfebeff4 100644 --- a/src/stan/callbacks/output_controller.hpp +++ b/src/stan/callbacks/output_controller.hpp @@ -17,7 +17,6 @@ namespace callbacks { enum class OutputFormat { CSV, // Plain text CSV files for streaming data MATRIX, // In-memory column-major matrix - ARROW, // Apache Arrow format for streaming data JSON // JSON format for flexible metadata }; @@ -47,9 +46,6 @@ class output_controller { case OutputFormat::MATRIX: return std::make_shared( config.dims.rows, config.dims.cols); - case OutputFormat::ARROW: - // TODO: Implement Arrow writer - throw std::runtime_error("Arrow writer not yet implemented"); case OutputFormat::JSON: { auto file = std::make_shared(config.file_path); files_[config.file_path] = file; // Store file @@ -80,22 +76,18 @@ class output_controller { template> json_writer* get_json_writer(const std::string& info_type) { auto writer = get_writer(info_type); - auto* json_writer = dynamic_cast*>(writer.get()); - if (!json_writer) { + auto* jwriter = dynamic_cast*>(writer.get()); + if (!jwriter) { throw std::runtime_error("Writer for " + info_type + " is not a JSON writer"); } - return json_writer; + return jwriter; } void write(const std::string& info_type, const std::vector& data) { - auto writer = get_writer(info_type); - writer->operator()(data); + auto swriter = get_writer(info_type); + swriter->operator()(data); } - void write_metadata(const std::string& info_type, const std::vector& metadata) { - auto writer = get_writer(info_type); - writer->operator()(metadata); - } }; } // namespace callbacks diff --git a/src/test/unit/callbacks/output_controller_test.cpp b/src/test/unit/callbacks/output_controller_test.cpp index c85ec17309..ef75a6780e 100644 --- a/src/test/unit/callbacks/output_controller_test.cpp +++ b/src/test/unit/callbacks/output_controller_test.cpp @@ -51,11 +51,10 @@ TEST(output_controller, write_metadata) { controller.configure_output("model_info", {stan::callbacks::OutputFormat::JSON, "model.json"}); - // Get JSON writer + // Get JSON writer and write metadata directly auto* json_writer = controller.get_json_writer("model_info"); EXPECT_NE(json_writer, nullptr); - // Write metadata as a properly structured JSON record json_writer->begin_record(); json_writer->write("model_name", "bernoulli"); json_writer->write("version", "2.29.0"); From b52f087ed4aafb904f552c64c34252a6e5a6c95f Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 22:32:56 -0500 Subject: [PATCH 26/53] try again --- src/test/unit/callbacks/output_controller_test.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/test/unit/callbacks/output_controller_test.cpp b/src/test/unit/callbacks/output_controller_test.cpp index ef75a6780e..c585b2efea 100644 --- a/src/test/unit/callbacks/output_controller_test.cpp +++ b/src/test/unit/callbacks/output_controller_test.cpp @@ -13,15 +13,10 @@ class mock_writer : public stan::callbacks::writer { public: std::vector> data; - std::vector> metadata; void operator()(const std::vector& x) override { data.push_back(x); } - - void operator()(const std::vector& x) override { - metadata.push_back(x); - } }; TEST(output_controller, configure_and_write_streaming) { @@ -126,7 +121,6 @@ TEST(output_controller, unconfigured_output) { std::vector metadata = {"model_name", "bernoulli"}; EXPECT_THROW(controller.write("samples", data), std::runtime_error); - EXPECT_THROW(controller.write_metadata("model_info", metadata), std::runtime_error); EXPECT_THROW(controller.get_json_writer("metric"), std::runtime_error); } From 75f59da962169e3f86e9cd6e6b86a9ead8fa261c Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 22:43:34 -0500 Subject: [PATCH 27/53] try again --- src/stan/callbacks/output_controller.hpp | 36 ++++++++++--------- .../unit/callbacks/output_controller_test.cpp | 5 +-- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/src/stan/callbacks/output_controller.hpp b/src/stan/callbacks/output_controller.hpp index 1dbfebeff4..efd3ffc7f0 100644 --- a/src/stan/callbacks/output_controller.hpp +++ b/src/stan/callbacks/output_controller.hpp @@ -34,13 +34,14 @@ struct OutputConfig { class output_controller { private: std::unordered_map> writers_; - std::unordered_map> files_; // Keep files alive + std::unordered_map> structured_writers_; + std::unordered_map> files_; std::shared_ptr create_writer(const OutputConfig& config) { switch (config.format) { case OutputFormat::CSV: { auto file = std::make_shared(config.file_path); - files_[config.file_path] = file; // Store file + files_[config.file_path] = file; return std::make_shared(*file); } case OutputFormat::MATRIX: @@ -48,11 +49,12 @@ class output_controller { config.dims.rows, config.dims.cols); case OutputFormat::JSON: { auto file = std::make_shared(config.file_path); - files_[config.file_path] = file; // Store file + files_[config.file_path] = file; auto file_ptr = std::make_unique(config.file_path); - auto jwriter = std::make_shared>>( + auto jwriter = std::make_shared>( std::move(file_ptr)); - return std::static_pointer_cast(jwriter); + structured_writers_[config.file_path] = jwriter; + return nullptr; // JSON writers are handled separately } default: throw std::runtime_error("Invalid output format"); @@ -61,7 +63,10 @@ class output_controller { public: void configure_output(const std::string& info_type, const OutputConfig& config) { - writers_[info_type] = create_writer(config); + auto writer = create_writer(config); + if (writer) { + writers_[info_type] = writer; + } } std::shared_ptr get_writer(const std::string& info_type) { @@ -72,22 +77,19 @@ class output_controller { return it->second; } - // Forward all JSON writer methods to the underlying writer - template> - json_writer* get_json_writer(const std::string& info_type) { - auto writer = get_writer(info_type); - auto* jwriter = dynamic_cast*>(writer.get()); - if (!jwriter) { - throw std::runtime_error("Writer for " + info_type + " is not a JSON writer"); + template + json_writer* get_json_writer(const std::string& info_type) { + auto it = structured_writers_.find(info_type); + if (it == structured_writers_.end()) { + throw std::runtime_error("No JSON writer configured for " + info_type); } - return jwriter; + return dynamic_cast*>(it->second.get()); } void write(const std::string& info_type, const std::vector& data) { - auto swriter = get_writer(info_type); - swriter->operator()(data); + auto writer = get_writer(info_type); + writer->operator()(data); } - }; } // namespace callbacks diff --git a/src/test/unit/callbacks/output_controller_test.cpp b/src/test/unit/callbacks/output_controller_test.cpp index c585b2efea..76f5961624 100644 --- a/src/test/unit/callbacks/output_controller_test.cpp +++ b/src/test/unit/callbacks/output_controller_test.cpp @@ -46,10 +46,11 @@ TEST(output_controller, write_metadata) { controller.configure_output("model_info", {stan::callbacks::OutputFormat::JSON, "model.json"}); - // Get JSON writer and write metadata directly - auto* json_writer = controller.get_json_writer("model_info"); + // Get JSON writer - note we don't need to specify template parameters + auto* json_writer = controller.get_json_writer("model_info"); EXPECT_NE(json_writer, nullptr); + // Write metadata as a properly structured JSON record json_writer->begin_record(); json_writer->write("model_name", "bernoulli"); json_writer->write("version", "2.29.0"); From ce5b7a92a368fe5e8cb78dd4e7552d4f8c06020b Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 22:51:29 -0500 Subject: [PATCH 28/53] try again --- src/stan/callbacks/output_controller.hpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/stan/callbacks/output_controller.hpp b/src/stan/callbacks/output_controller.hpp index efd3ffc7f0..d33ebed1b7 100644 --- a/src/stan/callbacks/output_controller.hpp +++ b/src/stan/callbacks/output_controller.hpp @@ -53,7 +53,6 @@ class output_controller { auto file_ptr = std::make_unique(config.file_path); auto jwriter = std::make_shared>( std::move(file_ptr)); - structured_writers_[config.file_path] = jwriter; return nullptr; // JSON writers are handled separately } default: @@ -63,9 +62,18 @@ class output_controller { public: void configure_output(const std::string& info_type, const OutputConfig& config) { - auto writer = create_writer(config); - if (writer) { - writers_[info_type] = writer; + if (config.format == OutputFormat::JSON) { + auto file = std::make_shared(config.file_path); + files_[config.file_path] = file; + auto file_ptr = std::make_unique(config.file_path); + auto jwriter = std::make_shared>( + std::move(file_ptr)); + structured_writers_[info_type] = jwriter; + } else { + auto writer = create_writer(config); + if (writer) { + writers_[info_type] = writer; + } } } From e4f9e00aaf698b9db87bfbad446a8d7c35765156 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 22:54:48 -0500 Subject: [PATCH 29/53] try again --- src/stan/callbacks/output_controller.hpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/stan/callbacks/output_controller.hpp b/src/stan/callbacks/output_controller.hpp index d33ebed1b7..74eabe7564 100644 --- a/src/stan/callbacks/output_controller.hpp +++ b/src/stan/callbacks/output_controller.hpp @@ -56,7 +56,7 @@ class output_controller { return nullptr; // JSON writers are handled separately } default: - throw std::runtime_error("Invalid output format"); + return nullptr; // Let configure_output handle the error } } @@ -71,9 +71,10 @@ class output_controller { structured_writers_[info_type] = jwriter; } else { auto writer = create_writer(config); - if (writer) { - writers_[info_type] = writer; + if (!writer) { + throw std::runtime_error("Invalid output format"); } + writers_[info_type] = writer; } } From 7cf3f4ef5ae30fee5281a5ee4ff2f0bb007b9f7d Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 3 Mar 2025 22:58:39 -0500 Subject: [PATCH 30/53] try again --- src/test/unit/callbacks/output_controller_test.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/test/unit/callbacks/output_controller_test.cpp b/src/test/unit/callbacks/output_controller_test.cpp index 76f5961624..d24dbfeff0 100644 --- a/src/test/unit/callbacks/output_controller_test.cpp +++ b/src/test/unit/callbacks/output_controller_test.cpp @@ -128,10 +128,12 @@ TEST(output_controller, unconfigured_output) { TEST(output_controller, invalid_format) { stan::callbacks::output_controller controller; - // Configure with invalid format - controller.configure_output("samples", - {static_cast(999), "invalid"}); + // Configure with invalid format - should throw during configuration + EXPECT_THROW(controller.configure_output("samples", + {static_cast(999), "invalid"}), + std::runtime_error); + // Verify we can't write after invalid configuration std::vector data = {1.0, 2.0, 3.0}; EXPECT_THROW(controller.write("samples", data), std::runtime_error); } From 0558bad62f066c1fd44f80a0c4540148c60e6c2f Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Mon, 3 Mar 2025 23:05:19 -0500 Subject: [PATCH 31/53] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- src/stan/callbacks/matrix_writer.hpp | 10 +- src/stan/callbacks/output_controller.hpp | 32 +++--- src/stan/io/output_controller.hpp | 36 ++++--- .../unit/callbacks/output_controller_test.cpp | 99 ++++++++++--------- 4 files changed, 95 insertions(+), 82 deletions(-) diff --git a/src/stan/callbacks/matrix_writer.hpp b/src/stan/callbacks/matrix_writer.hpp index 36c8c0d57a..d9b5aaf43b 100644 --- a/src/stan/callbacks/matrix_writer.hpp +++ b/src/stan/callbacks/matrix_writer.hpp @@ -10,7 +10,7 @@ namespace callbacks { class matrix_writer : public writer { public: - matrix_writer(size_t rows, size_t cols) + matrix_writer(size_t rows, size_t cols) : rows_(rows), cols_(cols), data_(rows * cols) {} void operator()(const std::vector& x) override { @@ -20,7 +20,7 @@ class matrix_writer : public writer { if (x.size() != cols_) { throw std::runtime_error("Matrix writer: incorrect number of columns"); } - + // Store in column-major order for (size_t j = 0; j < cols_; ++j) { data_[j * rows_ + current_row_] = x[j]; @@ -44,7 +44,7 @@ class matrix_writer : public writer { std::vector data_; }; -} // namespace callbacks -} // namespace stan +} // namespace callbacks +} // namespace stan -#endif \ No newline at end of file +#endif \ No newline at end of file diff --git a/src/stan/callbacks/output_controller.hpp b/src/stan/callbacks/output_controller.hpp index 74eabe7564..e8b377344c 100644 --- a/src/stan/callbacks/output_controller.hpp +++ b/src/stan/callbacks/output_controller.hpp @@ -15,9 +15,9 @@ namespace stan { namespace callbacks { enum class OutputFormat { - CSV, // Plain text CSV files for streaming data - MATRIX, // In-memory column-major matrix - JSON // JSON format for flexible metadata + CSV, // Plain text CSV files for streaming data + MATRIX, // In-memory column-major matrix + JSON // JSON format for flexible metadata }; struct OutputDims { @@ -34,7 +34,8 @@ struct OutputConfig { class output_controller { private: std::unordered_map> writers_; - std::unordered_map> structured_writers_; + std::unordered_map> + structured_writers_; std::unordered_map> files_; std::shared_ptr create_writer(const OutputConfig& config) { @@ -45,14 +46,14 @@ class output_controller { return std::make_shared(*file); } case OutputFormat::MATRIX: - return std::make_shared( - config.dims.rows, config.dims.cols); + return std::make_shared(config.dims.rows, + config.dims.cols); case OutputFormat::JSON: { auto file = std::make_shared(config.file_path); files_[config.file_path] = file; auto file_ptr = std::make_unique(config.file_path); - auto jwriter = std::make_shared>( - std::move(file_ptr)); + auto jwriter + = std::make_shared>(std::move(file_ptr)); return nullptr; // JSON writers are handled separately } default: @@ -61,13 +62,14 @@ class output_controller { } public: - void configure_output(const std::string& info_type, const OutputConfig& config) { + void configure_output(const std::string& info_type, + const OutputConfig& config) { if (config.format == OutputFormat::JSON) { auto file = std::make_shared(config.file_path); files_[config.file_path] = file; auto file_ptr = std::make_unique(config.file_path); - auto jwriter = std::make_shared>( - std::move(file_ptr)); + auto jwriter + = std::make_shared>(std::move(file_ptr)); structured_writers_[info_type] = jwriter; } else { auto writer = create_writer(config); @@ -86,7 +88,7 @@ class output_controller { return it->second; } - template + template json_writer* get_json_writer(const std::string& info_type) { auto it = structured_writers_.find(info_type); if (it == structured_writers_.end()) { @@ -101,7 +103,7 @@ class output_controller { } }; -} // namespace callbacks -} // namespace stan +} // namespace callbacks +} // namespace stan -#endif +#endif diff --git a/src/stan/io/output_controller.hpp b/src/stan/io/output_controller.hpp index edc37b1bfc..9c72cd380a 100644 --- a/src/stan/io/output_controller.hpp +++ b/src/stan/io/output_controller.hpp @@ -15,10 +15,10 @@ namespace stan { namespace callbacks { enum class OutputFormat { - CSV, // Plain text CSV files for streaming data - MATRIX, // In-memory column-major matrix - ARROW, // Apache Arrow format for streaming data - JSON // JSON format for flexible metadata + CSV, // Plain text CSV files for streaming data + MATRIX, // In-memory column-major matrix + ARROW, // Apache Arrow format for streaming data + JSON // JSON format for flexible metadata }; struct OutputDims { @@ -35,7 +35,8 @@ struct OutputConfig { class output_controller { private: std::unordered_map> writers_; - std::unordered_map> files_; // Keep files alive + std::unordered_map> + files_; // Keep files alive std::shared_ptr create_writer(const OutputConfig& config) { switch (config.format) { @@ -45,8 +46,8 @@ class output_controller { return std::make_shared(*file); } case OutputFormat::MATRIX: - return std::make_shared( - config.dims.rows, config.dims.cols); + return std::make_shared(config.dims.rows, + config.dims.cols); case OutputFormat::ARROW: // TODO: Implement Arrow writer throw std::runtime_error("Arrow writer not yet implemented"); @@ -54,7 +55,8 @@ class output_controller { auto file = std::make_shared(config.file_path); files_[config.file_path] = file; // Store file auto file_ptr = std::make_unique(config.file_path); - auto json_writer = std::make_shared>>( + auto json_writer = std::make_shared< + json_writer>>( std::move(file_ptr)); return std::dynamic_pointer_cast(json_writer); } @@ -64,7 +66,8 @@ class output_controller { } public: - void configure_output(const std::string& info_type, const OutputConfig& config) { + void configure_output(const std::string& info_type, + const OutputConfig& config) { writers_[info_type] = create_writer(config); } @@ -77,12 +80,15 @@ class output_controller { } // Forward all JSON writer methods to the underlying writer - template> + template > json_writer* get_json_writer(const std::string& info_type) { auto writer = get_writer(info_type); - auto* json_writer = dynamic_cast*>(writer.get()); + auto* json_writer + = dynamic_cast*>(writer.get()); if (!json_writer) { - throw std::runtime_error("Writer for " + info_type + " is not a JSON writer"); + throw std::runtime_error("Writer for " + info_type + + " is not a JSON writer"); } return json_writer; } @@ -93,7 +99,7 @@ class output_controller { } }; -} // namespace callbacks -} // namespace stan +} // namespace callbacks +} // namespace stan -#endif \ No newline at end of file +#endif \ No newline at end of file diff --git a/src/test/unit/callbacks/output_controller_test.cpp b/src/test/unit/callbacks/output_controller_test.cpp index d24dbfeff0..d09e78e67f 100644 --- a/src/test/unit/callbacks/output_controller_test.cpp +++ b/src/test/unit/callbacks/output_controller_test.cpp @@ -13,26 +13,25 @@ class mock_writer : public stan::callbacks::writer { public: std::vector> data; - - void operator()(const std::vector& x) override { - data.push_back(x); - } + + void operator()(const std::vector& x) override { data.push_back(x); } }; TEST(output_controller, configure_and_write_streaming) { stan::callbacks::output_controller controller; - + // Configure output for samples - controller.configure_output("samples", - {stan::callbacks::OutputFormat::MATRIX, "memory", 100, 3}); - + controller.configure_output( + "samples", {stan::callbacks::OutputFormat::MATRIX, "memory", 100, 3}); + // Write sample data std::vector sample = {1.0, 2.0, 3.0}; controller.write("samples", sample); - + // Get writer and verify data auto writer = controller.get_writer("samples"); - auto* matrix_writer = dynamic_cast(writer.get()); + auto* matrix_writer + = dynamic_cast(writer.get()); EXPECT_NE(matrix_writer, nullptr); EXPECT_EQ(matrix_writer->rows(), 100); EXPECT_EQ(matrix_writer->cols(), 3); @@ -41,15 +40,15 @@ TEST(output_controller, configure_and_write_streaming) { TEST(output_controller, write_metadata) { stan::callbacks::output_controller controller; - + // Configure output for metadata - controller.configure_output("model_info", - {stan::callbacks::OutputFormat::JSON, "model.json"}); - + controller.configure_output( + "model_info", {stan::callbacks::OutputFormat::JSON, "model.json"}); + // Get JSON writer - note we don't need to specify template parameters auto* json_writer = controller.get_json_writer("model_info"); EXPECT_NE(json_writer, nullptr); - + // Write metadata as a properly structured JSON record json_writer->begin_record(); json_writer->write("model_name", "bernoulli"); @@ -59,15 +58,15 @@ TEST(output_controller, write_metadata) { TEST(output_controller, get_json_writer) { stan::callbacks::output_controller controller; - + // Configure JSON writer for metric - controller.configure_output("metric", - {stan::callbacks::OutputFormat::JSON, "metric.json"}); - + controller.configure_output( + "metric", {stan::callbacks::OutputFormat::JSON, "metric.json"}); + // Get JSON writer directly auto* metric_writer = controller.get_json_writer("metric"); EXPECT_NE(metric_writer, nullptr); - + // Test using JSON writer interface directly metric_writer->write("stepsize", "0.1"); metric_writer->write("metric_type", "dense"); @@ -75,65 +74,71 @@ TEST(output_controller, get_json_writer) { TEST(output_controller, get_json_writer_wrong_type) { stan::callbacks::output_controller controller; - + // Configure non-JSON writer - controller.configure_output("samples", - {stan::callbacks::OutputFormat::MATRIX, "memory", 100, 3}); - + controller.configure_output( + "samples", {stan::callbacks::OutputFormat::MATRIX, "memory", 100, 3}); + // Attempt to get JSON writer - EXPECT_THROW(controller.get_json_writer("samples"), std::runtime_error); + EXPECT_THROW(controller.get_json_writer("samples"), + std::runtime_error); } TEST(output_controller, multiple_formats) { stan::callbacks::output_controller controller; - + // Configure different formats for different information types - controller.configure_output("samples", - {stan::callbacks::OutputFormat::MATRIX, "memory", 100, 3}); - controller.configure_output("diagnostics", - {stan::callbacks::OutputFormat::CSV, "diagnostics.csv"}); - controller.configure_output("metric", - {stan::callbacks::OutputFormat::JSON, "metric.json"}); - + controller.configure_output( + "samples", {stan::callbacks::OutputFormat::MATRIX, "memory", 100, 3}); + controller.configure_output( + "diagnostics", {stan::callbacks::OutputFormat::CSV, "diagnostics.csv"}); + controller.configure_output( + "metric", {stan::callbacks::OutputFormat::JSON, "metric.json"}); + // Write streaming data std::vector sample = {1.0, 2.0, 3.0}; std::vector diagnostic = {4.0, 5.0, 6.0}; controller.write("samples", sample); controller.write("diagnostics", diagnostic); - + // Write metric using JSON writer interface auto* metric_writer = controller.get_json_writer("metric"); metric_writer->write("stepsize", "0.1"); - + // Verify writers auto sample_writer = controller.get_writer("samples"); auto diagnostic_writer = controller.get_writer("diagnostics"); - - EXPECT_NE(dynamic_cast(sample_writer.get()), nullptr); - EXPECT_NE(dynamic_cast(diagnostic_writer.get()), nullptr); + + EXPECT_NE(dynamic_cast(sample_writer.get()), + nullptr); + EXPECT_NE( + dynamic_cast(diagnostic_writer.get()), + nullptr); EXPECT_NE(controller.get_json_writer("metric"), nullptr); } TEST(output_controller, unconfigured_output) { stan::callbacks::output_controller controller; - + // Attempt to write without configuration std::vector data = {1.0, 2.0, 3.0}; std::vector metadata = {"model_name", "bernoulli"}; - + EXPECT_THROW(controller.write("samples", data), std::runtime_error); - EXPECT_THROW(controller.get_json_writer("metric"), std::runtime_error); + EXPECT_THROW(controller.get_json_writer("metric"), + std::runtime_error); } TEST(output_controller, invalid_format) { stan::callbacks::output_controller controller; - + // Configure with invalid format - should throw during configuration - EXPECT_THROW(controller.configure_output("samples", - {static_cast(999), "invalid"}), - std::runtime_error); - + EXPECT_THROW(controller.configure_output( + "samples", {static_cast(999), + "invalid"}), + std::runtime_error); + // Verify we can't write after invalid configuration std::vector data = {1.0, 2.0, 3.0}; EXPECT_THROW(controller.write("samples", data), std::runtime_error); -} +} From 03d24f6fce93870e4ca07598d96dcbe9fc9d3f8a Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 4 Mar 2025 12:10:52 -0500 Subject: [PATCH 32/53] lint fix, cleanup --- src/stan/io/output_controller.hpp | 105 ------------------------------ 1 file changed, 105 deletions(-) delete mode 100644 src/stan/io/output_controller.hpp diff --git a/src/stan/io/output_controller.hpp b/src/stan/io/output_controller.hpp deleted file mode 100644 index 9c72cd380a..0000000000 --- a/src/stan/io/output_controller.hpp +++ /dev/null @@ -1,105 +0,0 @@ -#ifndef STAN_IO_OUTPUT_CONTROLLER_HPP -#define STAN_IO_OUTPUT_CONTROLLER_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace stan { -namespace callbacks { - -enum class OutputFormat { - CSV, // Plain text CSV files for streaming data - MATRIX, // In-memory column-major matrix - ARROW, // Apache Arrow format for streaming data - JSON // JSON format for flexible metadata -}; - -struct OutputDims { - size_t rows; - size_t cols; -}; - -struct OutputConfig { - OutputFormat format; - std::string file_path; - OutputDims dims; -}; - -class output_controller { - private: - std::unordered_map> writers_; - std::unordered_map> - files_; // Keep files alive - - std::shared_ptr create_writer(const OutputConfig& config) { - switch (config.format) { - case OutputFormat::CSV: { - auto file = std::make_shared(config.file_path); - files_[config.file_path] = file; // Store file - return std::make_shared(*file); - } - case OutputFormat::MATRIX: - return std::make_shared(config.dims.rows, - config.dims.cols); - case OutputFormat::ARROW: - // TODO: Implement Arrow writer - throw std::runtime_error("Arrow writer not yet implemented"); - case OutputFormat::JSON: { - auto file = std::make_shared(config.file_path); - files_[config.file_path] = file; // Store file - auto file_ptr = std::make_unique(config.file_path); - auto json_writer = std::make_shared< - json_writer>>( - std::move(file_ptr)); - return std::dynamic_pointer_cast(json_writer); - } - default: - throw std::runtime_error("Invalid output format"); - } - } - - public: - void configure_output(const std::string& info_type, - const OutputConfig& config) { - writers_[info_type] = create_writer(config); - } - - std::shared_ptr get_writer(const std::string& info_type) { - auto it = writers_.find(info_type); - if (it == writers_.end()) { - throw std::runtime_error("No writer configured for " + info_type); - } - return it->second; - } - - // Forward all JSON writer methods to the underlying writer - template > - json_writer* get_json_writer(const std::string& info_type) { - auto writer = get_writer(info_type); - auto* json_writer - = dynamic_cast*>(writer.get()); - if (!json_writer) { - throw std::runtime_error("Writer for " + info_type - + " is not a JSON writer"); - } - return json_writer; - } - - void write(const std::string& info_type, const std::vector& data) { - auto writer = get_writer(info_type); - writer->operator()(data); - } -}; - -} // namespace callbacks -} // namespace stan - -#endif \ No newline at end of file From 072d882e3e102e3a3c46dab004ffdae7afbdccc2 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 4 Mar 2025 12:15:47 -0500 Subject: [PATCH 33/53] lint fix --- src/stan/callbacks/matrix_writer.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/stan/callbacks/matrix_writer.hpp b/src/stan/callbacks/matrix_writer.hpp index d9b5aaf43b..97a26319f0 100644 --- a/src/stan/callbacks/matrix_writer.hpp +++ b/src/stan/callbacks/matrix_writer.hpp @@ -47,4 +47,5 @@ class matrix_writer : public writer { } // namespace callbacks } // namespace stan -#endif \ No newline at end of file +#endif + From 3d6f08f6041df2988830e13a29d1edca9e05d197 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Tue, 4 Mar 2025 12:16:11 -0500 Subject: [PATCH 34/53] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- src/stan/callbacks/matrix_writer.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/stan/callbacks/matrix_writer.hpp b/src/stan/callbacks/matrix_writer.hpp index 97a26319f0..81f7974728 100644 --- a/src/stan/callbacks/matrix_writer.hpp +++ b/src/stan/callbacks/matrix_writer.hpp @@ -48,4 +48,3 @@ class matrix_writer : public writer { } // namespace stan #endif - From f86be7773d74fbaaab01712a1349d6dcc8901847 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 4 Mar 2025 14:27:02 -0500 Subject: [PATCH 35/53] lint fix --- src/stan/callbacks/matrix_writer.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/stan/callbacks/matrix_writer.hpp b/src/stan/callbacks/matrix_writer.hpp index 97a26319f0..81f7974728 100644 --- a/src/stan/callbacks/matrix_writer.hpp +++ b/src/stan/callbacks/matrix_writer.hpp @@ -48,4 +48,3 @@ class matrix_writer : public writer { } // namespace stan #endif - From 87b71659397418e9d86bf8369d8ab95e3264a211 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Wed, 5 Mar 2025 19:06:00 -0500 Subject: [PATCH 36/53] checkpointing --- src/stan/callbacks/dispatcher.hpp | 139 +++++++++++++++++ src/stan/callbacks/output_controller.hpp | 109 ------------- src/test/unit/callbacks/dispatcher_test.cpp | 117 ++++++++++++++ .../unit/callbacks/output_controller_test.cpp | 144 ------------------ 4 files changed, 256 insertions(+), 253 deletions(-) create mode 100644 src/stan/callbacks/dispatcher.hpp delete mode 100644 src/stan/callbacks/output_controller.hpp create mode 100644 src/test/unit/callbacks/dispatcher_test.cpp delete mode 100644 src/test/unit/callbacks/output_controller_test.cpp diff --git a/src/stan/callbacks/dispatcher.hpp b/src/stan/callbacks/dispatcher.hpp new file mode 100644 index 0000000000..e586c57265 --- /dev/null +++ b/src/stan/callbacks/dispatcher.hpp @@ -0,0 +1,139 @@ +#ifndef STAN_CALLBACKS_DISPATCHER_HPP +#define STAN_CALLBACKS_DISPATCHER_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace stan { + namespace callbacks { + + enum class InfoType { + CONFIG, // series of string messages + SAMPLE, // draw from posterior + METRIC, // struct with kv pairs 'metric_type', 'stepsize', 'inv_metric' + ALGORITHM_STATE, // sampler state for returned draw + }; + + struct InfoTypeHash { + std::size_t operator()(const InfoType& type) const { + return std::hash()(static_cast(type)); + } + }; + + // Base type for type erasure. + class Channel { + public: + virtual ~Channel() {} + }; + + // Adapter for plain writers. + // These writer types (e.g., stream_writer, unique_stream_writer) support only a one-argument operator(). + class WriterChannel : public Channel { + public: + explicit WriterChannel(stan::callbacks::writer* w) : writer_(w) { + if (!w) + throw std::runtime_error("Null writer pointer provided to WriterChannel"); + } + // Single-argument dispatch: forwards to operator(). + template + void dispatch(const T& value) { + (*writer_)(value); + } + // Plain writers do not support key/value writes. + void begin_record() { } + void end_record() { } + private: + stan::callbacks::writer* writer_; + }; + + // Adapter for structured writers. + // The structured_writer interface provides a one-argument write(const std::string&) + // and a key/value write(const std::string&, const T&) overload. + class StructuredWriterChannel : public Channel { + public: + explicit StructuredWriterChannel(stan::callbacks::structured_writer* sw) : writer_(sw) { + if (!sw) + throw std::runtime_error("Null structured writer pointer provided to StructuredWriterChannel"); + } + // Key dispatch + void dispatch(const std::string& key) { + writer_->write(key); + } + // Key/value dispatch + template + void dispatch(const std::string& key, const T& value) { + writer_->write(key, std::forward(value)); + } + void begin_record() { + writer_->begin_record(); + } + void end_record() { + writer_->end_record(); + } + private: + stan::callbacks::structured_writer* writer_; + }; + + // dispatcher class with two overloads for dispatch(). + class dispatcher { + public: + dispatcher() = default; + ~dispatcher() = default; + + void register_channel(InfoType type, std::unique_ptr channel) { + channels_[type] = std::move(channel); + } + + // Overload for non-string types: only forward to plain writer channels. + template , std::string>::value>> + void dispatch(InfoType type, T&& info) { + auto it = channels_.find(type); + if (it == channels_.end()) return; // silently do nothing + if (auto* wc = dynamic_cast(it->second.get())) + wc->dispatch(std::forward(info)); + // We do not forward non-string types to structured writer channels. + } + + // Overload for string types: forward to both plain and structured writer channels. + void dispatch(InfoType type, const std::string& info) { + auto it = channels_.find(type); + if (it == channels_.end()) return; // silently do nothing + if (auto* wc = dynamic_cast(it->second.get())) + wc->dispatch(info); + if (auto* sw = dynamic_cast(it->second.get())) + sw->dispatch(info); + } + + // Forward a begin_record call. + void begin_record(InfoType type) { + auto it = channels_.find(type); + if (it == channels_.end()) return; + if (auto* sw = dynamic_cast(it->second.get())) + sw->begin_record(); + } + + // Forward an end_record call. + void end_record(InfoType type) { + auto it = channels_.find(type); + if (it == channels_.end()) return; + if (auto* sw = dynamic_cast(it->second.get())) + sw->end_record(); + } + + protected: + std::unordered_map, InfoTypeHash> channels_; + }; + + } // namespace callbacks +} // namespace stan + +#endif // STAN_CALLBACKS_DISPATCHER_HPP diff --git a/src/stan/callbacks/output_controller.hpp b/src/stan/callbacks/output_controller.hpp deleted file mode 100644 index e8b377344c..0000000000 --- a/src/stan/callbacks/output_controller.hpp +++ /dev/null @@ -1,109 +0,0 @@ -#ifndef STAN_CALLBACKS_OUTPUT_CONTROLLER_HPP -#define STAN_CALLBACKS_OUTPUT_CONTROLLER_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace stan { -namespace callbacks { - -enum class OutputFormat { - CSV, // Plain text CSV files for streaming data - MATRIX, // In-memory column-major matrix - JSON // JSON format for flexible metadata -}; - -struct OutputDims { - size_t rows; - size_t cols; -}; - -struct OutputConfig { - OutputFormat format; - std::string file_path; - OutputDims dims; -}; - -class output_controller { - private: - std::unordered_map> writers_; - std::unordered_map> - structured_writers_; - std::unordered_map> files_; - - std::shared_ptr create_writer(const OutputConfig& config) { - switch (config.format) { - case OutputFormat::CSV: { - auto file = std::make_shared(config.file_path); - files_[config.file_path] = file; - return std::make_shared(*file); - } - case OutputFormat::MATRIX: - return std::make_shared(config.dims.rows, - config.dims.cols); - case OutputFormat::JSON: { - auto file = std::make_shared(config.file_path); - files_[config.file_path] = file; - auto file_ptr = std::make_unique(config.file_path); - auto jwriter - = std::make_shared>(std::move(file_ptr)); - return nullptr; // JSON writers are handled separately - } - default: - return nullptr; // Let configure_output handle the error - } - } - - public: - void configure_output(const std::string& info_type, - const OutputConfig& config) { - if (config.format == OutputFormat::JSON) { - auto file = std::make_shared(config.file_path); - files_[config.file_path] = file; - auto file_ptr = std::make_unique(config.file_path); - auto jwriter - = std::make_shared>(std::move(file_ptr)); - structured_writers_[info_type] = jwriter; - } else { - auto writer = create_writer(config); - if (!writer) { - throw std::runtime_error("Invalid output format"); - } - writers_[info_type] = writer; - } - } - - std::shared_ptr get_writer(const std::string& info_type) { - auto it = writers_.find(info_type); - if (it == writers_.end()) { - throw std::runtime_error("No writer configured for " + info_type); - } - return it->second; - } - - template - json_writer* get_json_writer(const std::string& info_type) { - auto it = structured_writers_.find(info_type); - if (it == structured_writers_.end()) { - throw std::runtime_error("No JSON writer configured for " + info_type); - } - return dynamic_cast*>(it->second.get()); - } - - void write(const std::string& info_type, const std::vector& data) { - auto writer = get_writer(info_type); - writer->operator()(data); - } -}; - -} // namespace callbacks -} // namespace stan - -#endif diff --git a/src/test/unit/callbacks/dispatcher_test.cpp b/src/test/unit/callbacks/dispatcher_test.cpp new file mode 100644 index 0000000000..fb6879c50c --- /dev/null +++ b/src/test/unit/callbacks/dispatcher_test.cpp @@ -0,0 +1,117 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// For this test we assume that InfoType has at least these values: +using stan::callbacks::InfoType; +using stan::callbacks::dispatcher; + +struct deleter_noop { + template + constexpr void operator()(T* arg) const {} +}; + +class DispatcherTest : public ::testing::Test { + public: + DispatcherTest() : + ss_sample(), + ss_config(), + ss_metric(), + writer_sample(ss_sample), + writer_config(ss_config), + writer_metric(std::unique_ptr(&ss_metric)), + dispatcher() {} + + void SetUp() { + ss_sample.str(std::string()); + ss_sample.clear(); + ss_config.str(std::string()); + ss_config.clear(); + ss_metric.str(std::string()); + ss_metric.clear(); + + dispatcher.register_channel(InfoType::CONFIG, + std::unique_ptr(new stan::callbacks::WriterChannel(&writer_config))); + + dispatcher.register_channel(InfoType::SAMPLE, + std::unique_ptr(new stan::callbacks::WriterChannel(&writer_sample))); + + dispatcher.register_channel(InfoType::METRIC, + std::unique_ptr(new stan::callbacks::StructuredWriterChannel(&writer_metric))); + } + + void TearDown() {} + + std::stringstream ss_sample; + std::stringstream ss_config; + std::stringstream ss_metric; + + stan::callbacks::stream_writer writer_sample; + stan::callbacks::stream_writer writer_config; + stan::callbacks::json_writer writer_metric; + stan::callbacks::dispatcher dispatcher; +}; + +TEST_F(DispatcherTest, ConfigPlainMultipleMessages) { + // Dispatch several string messages to a plain writer (CONFIG). + dispatcher.dispatch(InfoType::CONFIG, std::string("Config1")); + dispatcher.dispatch(InfoType::CONFIG, std::string("Config2")); + dispatcher.dispatch(InfoType::CONFIG, std::string("Config3")); + EXPECT_EQ(ss_config.str(), "Config1\nConfig2\nConfig3\n"); +} + +TEST_F(DispatcherTest, SamplePlainVector) { + // Dispatch a vector of doubles to a plain writer (SAMPLE). + std::vector sample = {1.1, 2.2, 3.3}; + dispatcher.dispatch(InfoType::SAMPLE, sample); + std::string output(ss_sample.str()); + EXPECT_EQ(count_matches("1.1", output), 1); + EXPECT_EQ(count_matches("2.2", output), 1); + EXPECT_EQ(count_matches("2.2", output), 1); +} + +TEST_F(DispatcherTest, MetricStructuredKeyValueRecord) { + // For METRIC (structured writer), open a record, dispatch key/value pairs, then close the record. + dispatcher.begin_record(InfoType::METRIC); + dispatcher.dispatch(InfoType::METRIC, "metric_type", std::string("diag")); + dispatcher.dispatch(InfoType::METRIC, "stepsize", 0.6789); + // For the inv_metric, assume the caller converts the vector to a comma-separated string. + std::vector inv_metric = {0.1, 0.2, 0.3}; + std::string inv_metric_str; + for (size_t i = 0; i < inv_metric.size(); ++i) { + inv_metric_str += std::to_string(inv_metric[i]); + if(i != inv_metric.size() - 1) inv_metric_str += ","; + } + dispatcher.dispatch(InfoType::METRIC, "inv_metric", inv_metric_str); + dispatcher.end_record(InfoType::METRIC); + // Expected output: + // Begin record marker, followed by key/value pairs each formatted as "key:value;" and then end record marker. + std::cout << ss_metric.str() << std::endl; +} + +// TEST_F(DispatcherTest, NonStringDispatchToPlain) { +// // Dispatch a non-string (e.g., an int) to a plain writer (CONFIG). +// dispatcher.dispatch(InfoType::CONFIG, 999); +// // Expect conversion via std::to_string, so output should be "999". +// EXPECT_EQ(plainWriter.output, "999"); +// } + +// TEST_F(DispatcherTest, NonStringDispatchToStructuredIgnored) { +// // Dispatch a non-string (e.g., an int) to a structured writer (METRIC) should be ignored. +// dispatcher.dispatch(InfoType::METRIC, 123); +// EXPECT_EQ(structuredWriter.output, ""); +// } + +// TEST_F(DispatcherTest, UnregisteredInfoType) { +// // Dispatch to an unregistered InfoType (e.g., ALGORITHM_STATE) produces no output. +// dispatcher.dispatch(InfoType::ALGORITHM_STATE, std::string("NoOutput")); +// EXPECT_EQ(plainWriter.output, ""); +// EXPECT_EQ(structuredWriter.output, ""); +// } diff --git a/src/test/unit/callbacks/output_controller_test.cpp b/src/test/unit/callbacks/output_controller_test.cpp deleted file mode 100644 index d09e78e67f..0000000000 --- a/src/test/unit/callbacks/output_controller_test.cpp +++ /dev/null @@ -1,144 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -// Mock writer for testing -class mock_writer : public stan::callbacks::writer { - public: - std::vector> data; - - void operator()(const std::vector& x) override { data.push_back(x); } -}; - -TEST(output_controller, configure_and_write_streaming) { - stan::callbacks::output_controller controller; - - // Configure output for samples - controller.configure_output( - "samples", {stan::callbacks::OutputFormat::MATRIX, "memory", 100, 3}); - - // Write sample data - std::vector sample = {1.0, 2.0, 3.0}; - controller.write("samples", sample); - - // Get writer and verify data - auto writer = controller.get_writer("samples"); - auto* matrix_writer - = dynamic_cast(writer.get()); - EXPECT_NE(matrix_writer, nullptr); - EXPECT_EQ(matrix_writer->rows(), 100); - EXPECT_EQ(matrix_writer->cols(), 3); - EXPECT_EQ(matrix_writer->current_row(), 1); -} - -TEST(output_controller, write_metadata) { - stan::callbacks::output_controller controller; - - // Configure output for metadata - controller.configure_output( - "model_info", {stan::callbacks::OutputFormat::JSON, "model.json"}); - - // Get JSON writer - note we don't need to specify template parameters - auto* json_writer = controller.get_json_writer("model_info"); - EXPECT_NE(json_writer, nullptr); - - // Write metadata as a properly structured JSON record - json_writer->begin_record(); - json_writer->write("model_name", "bernoulli"); - json_writer->write("version", "2.29.0"); - json_writer->end_record(); -} - -TEST(output_controller, get_json_writer) { - stan::callbacks::output_controller controller; - - // Configure JSON writer for metric - controller.configure_output( - "metric", {stan::callbacks::OutputFormat::JSON, "metric.json"}); - - // Get JSON writer directly - auto* metric_writer = controller.get_json_writer("metric"); - EXPECT_NE(metric_writer, nullptr); - - // Test using JSON writer interface directly - metric_writer->write("stepsize", "0.1"); - metric_writer->write("metric_type", "dense"); -} - -TEST(output_controller, get_json_writer_wrong_type) { - stan::callbacks::output_controller controller; - - // Configure non-JSON writer - controller.configure_output( - "samples", {stan::callbacks::OutputFormat::MATRIX, "memory", 100, 3}); - - // Attempt to get JSON writer - EXPECT_THROW(controller.get_json_writer("samples"), - std::runtime_error); -} - -TEST(output_controller, multiple_formats) { - stan::callbacks::output_controller controller; - - // Configure different formats for different information types - controller.configure_output( - "samples", {stan::callbacks::OutputFormat::MATRIX, "memory", 100, 3}); - controller.configure_output( - "diagnostics", {stan::callbacks::OutputFormat::CSV, "diagnostics.csv"}); - controller.configure_output( - "metric", {stan::callbacks::OutputFormat::JSON, "metric.json"}); - - // Write streaming data - std::vector sample = {1.0, 2.0, 3.0}; - std::vector diagnostic = {4.0, 5.0, 6.0}; - controller.write("samples", sample); - controller.write("diagnostics", diagnostic); - - // Write metric using JSON writer interface - auto* metric_writer = controller.get_json_writer("metric"); - metric_writer->write("stepsize", "0.1"); - - // Verify writers - auto sample_writer = controller.get_writer("samples"); - auto diagnostic_writer = controller.get_writer("diagnostics"); - - EXPECT_NE(dynamic_cast(sample_writer.get()), - nullptr); - EXPECT_NE( - dynamic_cast(diagnostic_writer.get()), - nullptr); - EXPECT_NE(controller.get_json_writer("metric"), nullptr); -} - -TEST(output_controller, unconfigured_output) { - stan::callbacks::output_controller controller; - - // Attempt to write without configuration - std::vector data = {1.0, 2.0, 3.0}; - std::vector metadata = {"model_name", "bernoulli"}; - - EXPECT_THROW(controller.write("samples", data), std::runtime_error); - EXPECT_THROW(controller.get_json_writer("metric"), - std::runtime_error); -} - -TEST(output_controller, invalid_format) { - stan::callbacks::output_controller controller; - - // Configure with invalid format - should throw during configuration - EXPECT_THROW(controller.configure_output( - "samples", {static_cast(999), - "invalid"}), - std::runtime_error); - - // Verify we can't write after invalid configuration - std::vector data = {1.0, 2.0, 3.0}; - EXPECT_THROW(controller.write("samples", data), std::runtime_error); -} From 89d756b23a601c560cb63a09930f5fcbce011efc Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Wed, 5 Mar 2025 19:06:38 -0500 Subject: [PATCH 37/53] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- src/stan/callbacks/dispatcher.hpp | 254 ++++++++++---------- src/test/unit/callbacks/dispatcher_test.cpp | 59 +++-- 2 files changed, 164 insertions(+), 149 deletions(-) diff --git a/src/stan/callbacks/dispatcher.hpp b/src/stan/callbacks/dispatcher.hpp index e586c57265..b4630860f6 100644 --- a/src/stan/callbacks/dispatcher.hpp +++ b/src/stan/callbacks/dispatcher.hpp @@ -11,129 +11,133 @@ #include #include - namespace stan { - namespace callbacks { - - enum class InfoType { - CONFIG, // series of string messages - SAMPLE, // draw from posterior - METRIC, // struct with kv pairs 'metric_type', 'stepsize', 'inv_metric' - ALGORITHM_STATE, // sampler state for returned draw - }; - - struct InfoTypeHash { - std::size_t operator()(const InfoType& type) const { - return std::hash()(static_cast(type)); - } - }; - - // Base type for type erasure. - class Channel { - public: - virtual ~Channel() {} - }; - - // Adapter for plain writers. - // These writer types (e.g., stream_writer, unique_stream_writer) support only a one-argument operator(). - class WriterChannel : public Channel { - public: - explicit WriterChannel(stan::callbacks::writer* w) : writer_(w) { - if (!w) - throw std::runtime_error("Null writer pointer provided to WriterChannel"); - } - // Single-argument dispatch: forwards to operator(). - template - void dispatch(const T& value) { - (*writer_)(value); - } - // Plain writers do not support key/value writes. - void begin_record() { } - void end_record() { } - private: - stan::callbacks::writer* writer_; - }; - - // Adapter for structured writers. - // The structured_writer interface provides a one-argument write(const std::string&) - // and a key/value write(const std::string&, const T&) overload. - class StructuredWriterChannel : public Channel { - public: - explicit StructuredWriterChannel(stan::callbacks::structured_writer* sw) : writer_(sw) { - if (!sw) - throw std::runtime_error("Null structured writer pointer provided to StructuredWriterChannel"); - } - // Key dispatch - void dispatch(const std::string& key) { - writer_->write(key); - } - // Key/value dispatch - template - void dispatch(const std::string& key, const T& value) { - writer_->write(key, std::forward(value)); - } - void begin_record() { - writer_->begin_record(); - } - void end_record() { - writer_->end_record(); - } - private: - stan::callbacks::structured_writer* writer_; - }; - - // dispatcher class with two overloads for dispatch(). - class dispatcher { - public: - dispatcher() = default; - ~dispatcher() = default; - - void register_channel(InfoType type, std::unique_ptr channel) { - channels_[type] = std::move(channel); - } - - // Overload for non-string types: only forward to plain writer channels. - template , std::string>::value>> - void dispatch(InfoType type, T&& info) { - auto it = channels_.find(type); - if (it == channels_.end()) return; // silently do nothing - if (auto* wc = dynamic_cast(it->second.get())) - wc->dispatch(std::forward(info)); - // We do not forward non-string types to structured writer channels. - } - - // Overload for string types: forward to both plain and structured writer channels. - void dispatch(InfoType type, const std::string& info) { - auto it = channels_.find(type); - if (it == channels_.end()) return; // silently do nothing - if (auto* wc = dynamic_cast(it->second.get())) - wc->dispatch(info); - if (auto* sw = dynamic_cast(it->second.get())) - sw->dispatch(info); - } - - // Forward a begin_record call. - void begin_record(InfoType type) { - auto it = channels_.find(type); - if (it == channels_.end()) return; - if (auto* sw = dynamic_cast(it->second.get())) - sw->begin_record(); - } - - // Forward an end_record call. - void end_record(InfoType type) { - auto it = channels_.find(type); - if (it == channels_.end()) return; - if (auto* sw = dynamic_cast(it->second.get())) - sw->end_record(); - } - - protected: - std::unordered_map, InfoTypeHash> channels_; - }; - - } // namespace callbacks -} // namespace stan - -#endif // STAN_CALLBACKS_DISPATCHER_HPP +namespace callbacks { + +enum class InfoType { + CONFIG, // series of string messages + SAMPLE, // draw from posterior + METRIC, // struct with kv pairs 'metric_type', 'stepsize', 'inv_metric' + ALGORITHM_STATE, // sampler state for returned draw +}; + +struct InfoTypeHash { + std::size_t operator()(const InfoType& type) const { + return std::hash()(static_cast(type)); + } +}; + +// Base type for type erasure. +class Channel { + public: + virtual ~Channel() {} +}; + +// Adapter for plain writers. +// These writer types (e.g., stream_writer, unique_stream_writer) support only a +// one-argument operator(). +class WriterChannel : public Channel { + public: + explicit WriterChannel(stan::callbacks::writer* w) : writer_(w) { + if (!w) + throw std::runtime_error("Null writer pointer provided to WriterChannel"); + } + // Single-argument dispatch: forwards to operator(). + template + void dispatch(const T& value) { + (*writer_)(value); + } + // Plain writers do not support key/value writes. + void begin_record() {} + void end_record() {} + + private: + stan::callbacks::writer* writer_; +}; + +// Adapter for structured writers. +// The structured_writer interface provides a one-argument write(const +// std::string&) and a key/value write(const std::string&, const T&) overload. +class StructuredWriterChannel : public Channel { + public: + explicit StructuredWriterChannel(stan::callbacks::structured_writer* sw) + : writer_(sw) { + if (!sw) + throw std::runtime_error( + "Null structured writer pointer provided to StructuredWriterChannel"); + } + // Key dispatch + void dispatch(const std::string& key) { writer_->write(key); } + // Key/value dispatch + template + void dispatch(const std::string& key, const T& value) { + writer_->write(key, std::forward(value)); + } + void begin_record() { writer_->begin_record(); } + void end_record() { writer_->end_record(); } + + private: + stan::callbacks::structured_writer* writer_; +}; + +// dispatcher class with two overloads for dispatch(). +class dispatcher { + public: + dispatcher() = default; + ~dispatcher() = default; + + void register_channel(InfoType type, std::unique_ptr channel) { + channels_[type] = std::move(channel); + } + + // Overload for non-string types: only forward to plain writer channels. + template , std::string>::value>> + void dispatch(InfoType type, T&& info) { + auto it = channels_.find(type); + if (it == channels_.end()) + return; // silently do nothing + if (auto* wc = dynamic_cast(it->second.get())) + wc->dispatch(std::forward(info)); + // We do not forward non-string types to structured writer channels. + } + + // Overload for string types: forward to both plain and structured writer + // channels. + void dispatch(InfoType type, const std::string& info) { + auto it = channels_.find(type); + if (it == channels_.end()) + return; // silently do nothing + if (auto* wc = dynamic_cast(it->second.get())) + wc->dispatch(info); + if (auto* sw = dynamic_cast(it->second.get())) + sw->dispatch(info); + } + + // Forward a begin_record call. + void begin_record(InfoType type) { + auto it = channels_.find(type); + if (it == channels_.end()) + return; + if (auto* sw = dynamic_cast(it->second.get())) + sw->begin_record(); + } + + // Forward an end_record call. + void end_record(InfoType type) { + auto it = channels_.find(type); + if (it == channels_.end()) + return; + if (auto* sw = dynamic_cast(it->second.get())) + sw->end_record(); + } + + protected: + std::unordered_map, InfoTypeHash> + channels_; +}; + +} // namespace callbacks +} // namespace stan + +#endif // STAN_CALLBACKS_DISPATCHER_HPP diff --git a/src/test/unit/callbacks/dispatcher_test.cpp b/src/test/unit/callbacks/dispatcher_test.cpp index fb6879c50c..ab4b73cfd5 100644 --- a/src/test/unit/callbacks/dispatcher_test.cpp +++ b/src/test/unit/callbacks/dispatcher_test.cpp @@ -10,8 +10,8 @@ #include // For this test we assume that InfoType has at least these values: -using stan::callbacks::InfoType; using stan::callbacks::dispatcher; +using stan::callbacks::InfoType; struct deleter_noop { template @@ -20,14 +20,15 @@ struct deleter_noop { class DispatcherTest : public ::testing::Test { public: - DispatcherTest() : - ss_sample(), - ss_config(), - ss_metric(), - writer_sample(ss_sample), - writer_config(ss_config), - writer_metric(std::unique_ptr(&ss_metric)), - dispatcher() {} + DispatcherTest() + : ss_sample(), + ss_config(), + ss_metric(), + writer_sample(ss_sample), + writer_config(ss_config), + writer_metric( + std::unique_ptr(&ss_metric)), + dispatcher() {} void SetUp() { ss_sample.str(std::string()); @@ -37,14 +38,20 @@ class DispatcherTest : public ::testing::Test { ss_metric.str(std::string()); ss_metric.clear(); - dispatcher.register_channel(InfoType::CONFIG, - std::unique_ptr(new stan::callbacks::WriterChannel(&writer_config))); + dispatcher.register_channel( + InfoType::CONFIG, + std::unique_ptr( + new stan::callbacks::WriterChannel(&writer_config))); - dispatcher.register_channel(InfoType::SAMPLE, - std::unique_ptr(new stan::callbacks::WriterChannel(&writer_sample))); + dispatcher.register_channel( + InfoType::SAMPLE, + std::unique_ptr( + new stan::callbacks::WriterChannel(&writer_sample))); - dispatcher.register_channel(InfoType::METRIC, - std::unique_ptr(new stan::callbacks::StructuredWriterChannel(&writer_metric))); + dispatcher.register_channel( + InfoType::METRIC, + std::unique_ptr( + new stan::callbacks::StructuredWriterChannel(&writer_metric))); } void TearDown() {} @@ -78,21 +85,25 @@ TEST_F(DispatcherTest, SamplePlainVector) { } TEST_F(DispatcherTest, MetricStructuredKeyValueRecord) { - // For METRIC (structured writer), open a record, dispatch key/value pairs, then close the record. + // For METRIC (structured writer), open a record, dispatch key/value pairs, + // then close the record. dispatcher.begin_record(InfoType::METRIC); dispatcher.dispatch(InfoType::METRIC, "metric_type", std::string("diag")); dispatcher.dispatch(InfoType::METRIC, "stepsize", 0.6789); - // For the inv_metric, assume the caller converts the vector to a comma-separated string. + // For the inv_metric, assume the caller converts the vector to a + // comma-separated string. std::vector inv_metric = {0.1, 0.2, 0.3}; std::string inv_metric_str; for (size_t i = 0; i < inv_metric.size(); ++i) { inv_metric_str += std::to_string(inv_metric[i]); - if(i != inv_metric.size() - 1) inv_metric_str += ","; + if (i != inv_metric.size() - 1) + inv_metric_str += ","; } dispatcher.dispatch(InfoType::METRIC, "inv_metric", inv_metric_str); dispatcher.end_record(InfoType::METRIC); // Expected output: - // Begin record marker, followed by key/value pairs each formatted as "key:value;" and then end record marker. + // Begin record marker, followed by key/value pairs each formatted as + // "key:value;" and then end record marker. std::cout << ss_metric.str() << std::endl; } @@ -104,14 +115,14 @@ TEST_F(DispatcherTest, MetricStructuredKeyValueRecord) { // } // TEST_F(DispatcherTest, NonStringDispatchToStructuredIgnored) { -// // Dispatch a non-string (e.g., an int) to a structured writer (METRIC) should be ignored. -// dispatcher.dispatch(InfoType::METRIC, 123); +// // Dispatch a non-string (e.g., an int) to a structured writer (METRIC) +// should be ignored. dispatcher.dispatch(InfoType::METRIC, 123); // EXPECT_EQ(structuredWriter.output, ""); // } // TEST_F(DispatcherTest, UnregisteredInfoType) { -// // Dispatch to an unregistered InfoType (e.g., ALGORITHM_STATE) produces no output. -// dispatcher.dispatch(InfoType::ALGORITHM_STATE, std::string("NoOutput")); -// EXPECT_EQ(plainWriter.output, ""); +// // Dispatch to an unregistered InfoType (e.g., ALGORITHM_STATE) produces no +// output. dispatcher.dispatch(InfoType::ALGORITHM_STATE, +// std::string("NoOutput")); EXPECT_EQ(plainWriter.output, ""); // EXPECT_EQ(structuredWriter.output, ""); // } From 261c26949b60fb0929808420d383cc9c2706dea4 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 6 Mar 2025 14:43:46 -0500 Subject: [PATCH 38/53] claude 3.7 rewrite --- src/stan/callbacks/dispatcher.hpp | 142 ++++++---- src/test/unit/callbacks/dispatcher_test.cpp | 282 ++++++++++++++++---- 2 files changed, 313 insertions(+), 111 deletions(-) diff --git a/src/stan/callbacks/dispatcher.hpp b/src/stan/callbacks/dispatcher.hpp index e586c57265..aac0cfc71c 100644 --- a/src/stan/callbacks/dispatcher.hpp +++ b/src/stan/callbacks/dispatcher.hpp @@ -10,7 +10,7 @@ #include #include #include - +#include namespace stan { namespace callbacks { @@ -31,58 +31,59 @@ namespace stan { // Base type for type erasure. class Channel { public: - virtual ~Channel() {} + virtual ~Channel() = default; }; // Adapter for plain writers. - // These writer types (e.g., stream_writer, unique_stream_writer) support only a one-argument operator(). class WriterChannel : public Channel { public: explicit WriterChannel(stan::callbacks::writer* w) : writer_(w) { - if (!w) - throw std::runtime_error("Null writer pointer provided to WriterChannel"); + if (!w) throw std::runtime_error("Null writer pointer provided to WriterChannel"); } - // Single-argument dispatch: forwards to operator(). + + // Handle all types that writer supports via operator() + void dispatch() { (*writer_)(); } + void dispatch(const std::string& value) { (*writer_)(value); } + void dispatch(const std::vector& value) { (*writer_)(value); } + void dispatch(const std::vector& value) { (*writer_)(value); } + + // Handle any Eigen Matrix type + template + void dispatch(const Eigen::Matrix& value) { (*writer_)(value); } + + // No key-value support for plain writers template - void dispatch(const T& value) { - (*writer_)(value); - } - // Plain writers do not support key/value writes. - void begin_record() { } - void end_record() { } + void dispatch(const std::string&, const T&) {} + private: stan::callbacks::writer* writer_; }; // Adapter for structured writers. - // The structured_writer interface provides a one-argument write(const std::string&) - // and a key/value write(const std::string&, const T&) overload. class StructuredWriterChannel : public Channel { public: explicit StructuredWriterChannel(stan::callbacks::structured_writer* sw) : writer_(sw) { - if (!sw) - throw std::runtime_error("Null structured writer pointer provided to StructuredWriterChannel"); + if (!sw) throw std::runtime_error("Null structured writer pointer provided to StructuredWriterChannel"); } - // Key dispatch - void dispatch(const std::string& key) { - writer_->write(key); - } - // Key/value dispatch + + // Forward all key-value calls directly to the writer + void dispatch(const std::string& key) { writer_->write(key); } + + // Perfect forwarding for any key-value pair template - void dispatch(const std::string& key, const T& value) { + void dispatch(const std::string& key, T&& value) { writer_->write(key, std::forward(value)); } - void begin_record() { - writer_->begin_record(); - } - void end_record() { - writer_->end_record(); - } + + void begin_record() { writer_->begin_record(); } + void begin_record(const std::string& key) { writer_->begin_record(key); } + void end_record() { writer_->end_record(); } + private: stan::callbacks::structured_writer* writer_; }; - // dispatcher class with two overloads for dispatch(). + // dispatcher class class dispatcher { public: dispatcher() = default; @@ -92,44 +93,69 @@ namespace stan { channels_[type] = std::move(channel); } - // Overload for non-string types: only forward to plain writer channels. - template , std::string>::value>> - void dispatch(InfoType type, T&& info) { - auto it = channels_.find(type); - if (it == channels_.end()) return; // silently do nothing - if (auto* wc = dynamic_cast(it->second.get())) - wc->dispatch(std::forward(info)); - // We do not forward non-string types to structured writer channels. + // Empty call + void dispatch(InfoType type) { + if (auto* wc = find_channel(type)) + wc->dispatch(); } - - // Overload for string types: forward to both plain and structured writer channels. - void dispatch(InfoType type, const std::string& info) { - auto it = channels_.find(type); - if (it == channels_.end()) return; // silently do nothing - if (auto* wc = dynamic_cast(it->second.get())) - wc->dispatch(info); - if (auto* sw = dynamic_cast(it->second.get())) - sw->dispatch(info); + + // String, vector, vector + template , std::string> || + std::is_same_v, std::vector> || + std::is_same_v, std::vector> + >> + void dispatch(InfoType type, T&& value) { + if (auto* wc = find_channel(type)) + wc->dispatch(std::forward(value)); } - - // Forward a begin_record call. + + // Eigen matrix types + template + void dispatch(InfoType type, const Eigen::Matrix& value) { + if (auto* wc = find_channel(type)) + wc->dispatch(value); + } + + // Key with no value (null) + void dispatch(InfoType type, const std::string& key) { + if (auto* sw = find_channel(type)) + sw->dispatch(key); + } + + // Key-value pairs (forward to structured writers) + template + void dispatch(InfoType type, const std::string& key, T&& value) { + if (auto* sw = find_channel(type)) + sw->dispatch(key, std::forward(value)); + } + + // Record operations void begin_record(InfoType type) { - auto it = channels_.find(type); - if (it == channels_.end()) return; - if (auto* sw = dynamic_cast(it->second.get())) + if (auto* sw = find_channel(type)) sw->begin_record(); } - - // Forward an end_record call. + + void begin_record(InfoType type, const std::string& key) { + if (auto* sw = find_channel(type)) + sw->begin_record(key); + } + void end_record(InfoType type) { - auto it = channels_.find(type); - if (it == channels_.end()) return; - if (auto* sw = dynamic_cast(it->second.get())) + if (auto* sw = find_channel(type)) sw->end_record(); } - protected: + private: + // Helper to find and cast a channel of specific type + template + ChannelType* find_channel(InfoType type) { + auto it = channels_.find(type); + if (it == channels_.end()) return nullptr; + return dynamic_cast(it->second.get()); + } + std::unordered_map, InfoTypeHash> channels_; }; diff --git a/src/test/unit/callbacks/dispatcher_test.cpp b/src/test/unit/callbacks/dispatcher_test.cpp index fb6879c50c..9f2f43bee9 100644 --- a/src/test/unit/callbacks/dispatcher_test.cpp +++ b/src/test/unit/callbacks/dispatcher_test.cpp @@ -8,6 +8,7 @@ #include #include #include +#include // For this test we assume that InfoType has at least these values: using stan::callbacks::InfoType; @@ -38,13 +39,16 @@ class DispatcherTest : public ::testing::Test { ss_metric.clear(); dispatcher.register_channel(InfoType::CONFIG, - std::unique_ptr(new stan::callbacks::WriterChannel(&writer_config))); + std::unique_ptr( + new stan::callbacks::WriterChannel(&writer_config))); dispatcher.register_channel(InfoType::SAMPLE, - std::unique_ptr(new stan::callbacks::WriterChannel(&writer_sample))); + std::unique_ptr( + new stan::callbacks::WriterChannel(&writer_sample))); dispatcher.register_channel(InfoType::METRIC, - std::unique_ptr(new stan::callbacks::StructuredWriterChannel(&writer_metric))); + std::unique_ptr( + new stan::callbacks::StructuredWriterChannel(&writer_metric))); } void TearDown() {} @@ -59,59 +63,231 @@ class DispatcherTest : public ::testing::Test { stan::callbacks::dispatcher dispatcher; }; -TEST_F(DispatcherTest, ConfigPlainMultipleMessages) { - // Dispatch several string messages to a plain writer (CONFIG). - dispatcher.dispatch(InfoType::CONFIG, std::string("Config1")); - dispatcher.dispatch(InfoType::CONFIG, std::string("Config2")); - dispatcher.dispatch(InfoType::CONFIG, std::string("Config3")); - EXPECT_EQ(ss_config.str(), "Config1\nConfig2\nConfig3\n"); +// Test basic string dispatch to plain writer +TEST_F(DispatcherTest, StringDispatch) { + dispatcher.dispatch(InfoType::CONFIG, std::string("Message1")); + EXPECT_EQ(ss_config.str(), "Message1\n"); } -TEST_F(DispatcherTest, SamplePlainVector) { - // Dispatch a vector of doubles to a plain writer (SAMPLE). - std::vector sample = {1.1, 2.2, 3.3}; - dispatcher.dispatch(InfoType::SAMPLE, sample); - std::string output(ss_sample.str()); - EXPECT_EQ(count_matches("1.1", output), 1); - EXPECT_EQ(count_matches("2.2", output), 1); - EXPECT_EQ(count_matches("2.2", output), 1); +// Test multiple string dispatches +TEST_F(DispatcherTest, MultipleStringDispatch) { + dispatcher.dispatch(InfoType::CONFIG, std::string("Message1")); + dispatcher.dispatch(InfoType::CONFIG, std::string("Message2")); + dispatcher.dispatch(InfoType::CONFIG, std::string("Message3")); + EXPECT_EQ(ss_config.str(), "Message1\nMessage2\nMessage3\n"); } -TEST_F(DispatcherTest, MetricStructuredKeyValueRecord) { - // For METRIC (structured writer), open a record, dispatch key/value pairs, then close the record. +// Test empty call dispatch +TEST_F(DispatcherTest, EmptyDispatch) { + dispatcher.dispatch(InfoType::CONFIG); + // Empty dispatch should produce just a newline in stream_writer + EXPECT_EQ(ss_config.str(), "\n"); +} + +// Test vector of doubles dispatch +TEST_F(DispatcherTest, VectorDoubleDispatch) { + std::vector values = {1.1, 2.2, 3.3}; + dispatcher.dispatch(InfoType::SAMPLE, values); + std::string output = ss_sample.str(); + EXPECT_NE(output.find("1.1"), std::string::npos); + EXPECT_NE(output.find("2.2"), std::string::npos); + EXPECT_NE(output.find("3.3"), std::string::npos); +} + +// Test vector of strings dispatch +TEST_F(DispatcherTest, VectorStringDispatch) { + std::vector names = {"alpha", "beta", "gamma"}; + dispatcher.dispatch(InfoType::SAMPLE, names); + std::string output = ss_sample.str(); + EXPECT_NE(output.find("alpha"), std::string::npos); + EXPECT_NE(output.find("beta"), std::string::npos); + EXPECT_NE(output.find("gamma"), std::string::npos); +} + +// Test Eigen matrix dispatch +TEST_F(DispatcherTest, EigenMatrixDispatch) { + Eigen::MatrixXd matrix(2, 2); + matrix << 1.0, 2.0, 3.0, 4.0; + dispatcher.dispatch(InfoType::SAMPLE, matrix); + std::string output = ss_sample.str(); + EXPECT_NE(output.find("1"), std::string::npos); + EXPECT_NE(output.find("2"), std::string::npos); + EXPECT_NE(output.find("3"), std::string::npos); + EXPECT_NE(output.find("4"), std::string::npos); +} + +// Test Eigen vector dispatch +TEST_F(DispatcherTest, EigenVectorDispatch) { + Eigen::VectorXd vector(3); + vector << 1.0, 2.0, 3.0; + dispatcher.dispatch(InfoType::SAMPLE, vector); + std::string output = ss_sample.str(); + EXPECT_NE(output.find("1"), std::string::npos); + EXPECT_NE(output.find("2"), std::string::npos); + EXPECT_NE(output.find("3"), std::string::npos); +} + +// Test Eigen row vector dispatch +TEST_F(DispatcherTest, EigenRowVectorDispatch) { + Eigen::RowVectorXd vector(3); + vector << 1.0, 2.0, 3.0; + dispatcher.dispatch(InfoType::SAMPLE, vector); + std::string output = ss_sample.str(); + EXPECT_NE(output.find("1"), std::string::npos); + EXPECT_NE(output.find("2"), std::string::npos); + EXPECT_NE(output.find("3"), std::string::npos); +} + +// Test structured writer begin/end record +TEST_F(DispatcherTest, StructuredBeginEndRecord) { dispatcher.begin_record(InfoType::METRIC); - dispatcher.dispatch(InfoType::METRIC, "metric_type", std::string("diag")); - dispatcher.dispatch(InfoType::METRIC, "stepsize", 0.6789); - // For the inv_metric, assume the caller converts the vector to a comma-separated string. - std::vector inv_metric = {0.1, 0.2, 0.3}; - std::string inv_metric_str; - for (size_t i = 0; i < inv_metric.size(); ++i) { - inv_metric_str += std::to_string(inv_metric[i]); - if(i != inv_metric.size() - 1) inv_metric_str += ","; - } - dispatcher.dispatch(InfoType::METRIC, "inv_metric", inv_metric_str); dispatcher.end_record(InfoType::METRIC); - // Expected output: - // Begin record marker, followed by key/value pairs each formatted as "key:value;" and then end record marker. - std::cout << ss_metric.str() << std::endl; -} - -// TEST_F(DispatcherTest, NonStringDispatchToPlain) { -// // Dispatch a non-string (e.g., an int) to a plain writer (CONFIG). -// dispatcher.dispatch(InfoType::CONFIG, 999); -// // Expect conversion via std::to_string, so output should be "999". -// EXPECT_EQ(plainWriter.output, "999"); -// } - -// TEST_F(DispatcherTest, NonStringDispatchToStructuredIgnored) { -// // Dispatch a non-string (e.g., an int) to a structured writer (METRIC) should be ignored. -// dispatcher.dispatch(InfoType::METRIC, 123); -// EXPECT_EQ(structuredWriter.output, ""); -// } - -// TEST_F(DispatcherTest, UnregisteredInfoType) { -// // Dispatch to an unregistered InfoType (e.g., ALGORITHM_STATE) produces no output. -// dispatcher.dispatch(InfoType::ALGORITHM_STATE, std::string("NoOutput")); -// EXPECT_EQ(plainWriter.output, ""); -// EXPECT_EQ(structuredWriter.output, ""); -// } + std::string output = ss_metric.str(); + // JSON output should contain opening and closing braces + EXPECT_NE(output.find("{"), std::string::npos); + EXPECT_NE(output.find("}"), std::string::npos); +} + +// Test structured writer key-value pairs with string value +TEST_F(DispatcherTest, StructuredKeyStringValue) { + dispatcher.begin_record(InfoType::METRIC); + dispatcher.dispatch(InfoType::METRIC, "key1", std::string("value1")); + dispatcher.end_record(InfoType::METRIC); + + std::string output = ss_metric.str(); + EXPECT_NE(output.find("key1"), std::string::npos); + EXPECT_NE(output.find("value1"), std::string::npos); +} + +// Test structured writer with multiple key-value types +TEST_F(DispatcherTest, StructuredMultipleValueTypes) { + dispatcher.begin_record(InfoType::METRIC); + dispatcher.dispatch(InfoType::METRIC, "string_key", std::string("string_value")); + dispatcher.dispatch(InfoType::METRIC, "int_key", 42); + dispatcher.dispatch(InfoType::METRIC, "double_key", 3.14159); + dispatcher.dispatch(InfoType::METRIC, "bool_key", true); + dispatcher.end_record(InfoType::METRIC); + + std::string output = ss_metric.str(); + EXPECT_NE(output.find("string_key"), std::string::npos); + EXPECT_NE(output.find("string_value"), std::string::npos); + EXPECT_NE(output.find("int_key"), std::string::npos); + EXPECT_NE(output.find("42"), std::string::npos); + EXPECT_NE(output.find("double_key"), std::string::npos); + EXPECT_NE(output.find("3.14159"), std::string::npos); + EXPECT_NE(output.find("bool_key"), std::string::npos); + EXPECT_NE(output.find("true"), std::string::npos); +} + +// Test structured writer with vector values +TEST_F(DispatcherTest, StructuredVectorValues) { + dispatcher.begin_record(InfoType::METRIC); + + std::vector doubles = {1.1, 2.2, 3.3}; + dispatcher.dispatch(InfoType::METRIC, "doubles", doubles); + + std::vector strings = {"one", "two", "three"}; + dispatcher.dispatch(InfoType::METRIC, "strings", strings); + + dispatcher.end_record(InfoType::METRIC); + + std::string output = ss_metric.str(); + EXPECT_NE(output.find("doubles"), std::string::npos); + EXPECT_NE(output.find("1.1"), std::string::npos); + EXPECT_NE(output.find("strings"), std::string::npos); + EXPECT_NE(output.find("one"), std::string::npos); +} + +// Test structured writer with Eigen values +TEST_F(DispatcherTest, StructuredEigenValues) { + dispatcher.begin_record(InfoType::METRIC); + + Eigen::MatrixXd matrix(2, 2); + matrix << 1.0, 2.0, 3.0, 4.0; + dispatcher.dispatch(InfoType::METRIC, "matrix", matrix); + + Eigen::VectorXd vector(3); + vector << 5.0, 6.0, 7.0; + dispatcher.dispatch(InfoType::METRIC, "vector", vector); + + dispatcher.end_record(InfoType::METRIC); + + std::string output = ss_metric.str(); + EXPECT_NE(output.find("matrix"), std::string::npos); + EXPECT_NE(output.find("1"), std::string::npos); + EXPECT_NE(output.find("4"), std::string::npos); + EXPECT_NE(output.find("vector"), std::string::npos); + EXPECT_NE(output.find("5"), std::string::npos); + EXPECT_NE(output.find("7"), std::string::npos); +} + +// Test unregistered channel +TEST_F(DispatcherTest, UnregisteredChannel) { + // Dispatch to unregistered channel should silently do nothing + dispatcher.dispatch(InfoType::ALGORITHM_STATE, std::string("Message")); + dispatcher.dispatch(InfoType::ALGORITHM_STATE, std::vector{1.0, 2.0}); + dispatcher.begin_record(InfoType::ALGORITHM_STATE); + dispatcher.dispatch(InfoType::ALGORITHM_STATE, "key", "value"); + dispatcher.end_record(InfoType::ALGORITHM_STATE); + + // No exceptions should be thrown +} + +// Test named record +TEST_F(DispatcherTest, NamedRecord) { + dispatcher.begin_record(InfoType::METRIC, "record_name"); + dispatcher.dispatch(InfoType::METRIC, "key", "value"); + dispatcher.end_record(InfoType::METRIC); + + std::string output = ss_metric.str(); + EXPECT_NE(output.find("record_name"), std::string::npos); + EXPECT_NE(output.find("key"), std::string::npos); + EXPECT_NE(output.find("value"), std::string::npos); +} + +// Test that begin_record and end_record on a plain writer channel are silently ignored +TEST_F(DispatcherTest, RecordOperationsOnPlainWriter) { + dispatcher.begin_record(InfoType::CONFIG); + dispatcher.end_record(InfoType::CONFIG); + + // Should not generate any output + EXPECT_EQ(ss_config.str(), ""); +} + +// Test complex sampler metric output pattern +TEST_F(DispatcherTest, ComplexSamplerMetricPattern) { + // This test simulates a more complex real-world usage pattern + + // Start a record for a sampling iteration + dispatcher.begin_record(InfoType::METRIC); + + // Add various diagnostic info + dispatcher.dispatch(InfoType::METRIC, "iter", 10); + dispatcher.dispatch(InfoType::METRIC, "lp", -105.2); + dispatcher.dispatch(InfoType::METRIC, "accept_stat", 0.8); + + // Add a nested object for adaptation + dispatcher.begin_record(InfoType::METRIC, "adaptation"); + dispatcher.dispatch(InfoType::METRIC, "step_size", 0.85); + + // Add an inverse metric matrix + Eigen::MatrixXd inv_metric(2, 2); + inv_metric << 1.2, 0.1, 0.1, 0.9; + dispatcher.dispatch(InfoType::METRIC, "inv_metric", inv_metric); + + // End adaptation object + dispatcher.end_record(InfoType::METRIC); + + // End the main record + dispatcher.end_record(InfoType::METRIC); + + // Verify key entries exist in the output + std::string output = ss_metric.str(); + EXPECT_NE(output.find("iter"), std::string::npos); + EXPECT_NE(output.find("10"), std::string::npos); + EXPECT_NE(output.find("lp"), std::string::npos); + EXPECT_NE(output.find("-105.2"), std::string::npos); + EXPECT_NE(output.find("adaptation"), std::string::npos); + EXPECT_NE(output.find("step_size"), std::string::npos); + EXPECT_NE(output.find("inv_metric"), std::string::npos); +} From e082f64e665722b2aa40f3c0e94e431af6ca4f67 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 6 Mar 2025 15:03:03 -0500 Subject: [PATCH 39/53] lint fix --- src/stan/callbacks/dispatcher.hpp | 292 +++++++++++++++--------------- 1 file changed, 149 insertions(+), 143 deletions(-) diff --git a/src/stan/callbacks/dispatcher.hpp b/src/stan/callbacks/dispatcher.hpp index aac0cfc71c..b2d46fdd4c 100644 --- a/src/stan/callbacks/dispatcher.hpp +++ b/src/stan/callbacks/dispatcher.hpp @@ -13,153 +13,159 @@ #include namespace stan { - namespace callbacks { - - enum class InfoType { - CONFIG, // series of string messages - SAMPLE, // draw from posterior - METRIC, // struct with kv pairs 'metric_type', 'stepsize', 'inv_metric' - ALGORITHM_STATE, // sampler state for returned draw - }; - - struct InfoTypeHash { - std::size_t operator()(const InfoType& type) const { - return std::hash()(static_cast(type)); - } - }; - - // Base type for type erasure. - class Channel { - public: - virtual ~Channel() = default; - }; - - // Adapter for plain writers. - class WriterChannel : public Channel { - public: - explicit WriterChannel(stan::callbacks::writer* w) : writer_(w) { - if (!w) throw std::runtime_error("Null writer pointer provided to WriterChannel"); - } +namespace callbacks { - // Handle all types that writer supports via operator() - void dispatch() { (*writer_)(); } - void dispatch(const std::string& value) { (*writer_)(value); } - void dispatch(const std::vector& value) { (*writer_)(value); } - void dispatch(const std::vector& value) { (*writer_)(value); } - - // Handle any Eigen Matrix type - template - void dispatch(const Eigen::Matrix& value) { (*writer_)(value); } - - // No key-value support for plain writers - template - void dispatch(const std::string&, const T&) {} - - private: - stan::callbacks::writer* writer_; - }; - - // Adapter for structured writers. - class StructuredWriterChannel : public Channel { - public: - explicit StructuredWriterChannel(stan::callbacks::structured_writer* sw) : writer_(sw) { - if (!sw) throw std::runtime_error("Null structured writer pointer provided to StructuredWriterChannel"); - } - - // Forward all key-value calls directly to the writer - void dispatch(const std::string& key) { writer_->write(key); } - - // Perfect forwarding for any key-value pair - template - void dispatch(const std::string& key, T&& value) { - writer_->write(key, std::forward(value)); - } - - void begin_record() { writer_->begin_record(); } - void begin_record(const std::string& key) { writer_->begin_record(key); } - void end_record() { writer_->end_record(); } - - private: - stan::callbacks::structured_writer* writer_; - }; - - // dispatcher class - class dispatcher { - public: - dispatcher() = default; - ~dispatcher() = default; - - void register_channel(InfoType type, std::unique_ptr channel) { - channels_[type] = std::move(channel); - } + enum class InfoType { + CONFIG, // series of string messages + SAMPLE, // draw from posterior + METRIC, // struct with kv pairs 'metric_type', 'stepsize', 'inv_metric' + ALGORITHM_STATE, // sampler state for returned draw + }; + + struct InfoTypeHash { + std::size_t operator()(const InfoType& type) const { + return std::hash()(static_cast(type)); + } + }; - // Empty call - void dispatch(InfoType type) { - if (auto* wc = find_channel(type)) - wc->dispatch(); + // Base type for type erasure. + class Channel { + public: + virtual ~Channel() = default; + }; + + // Adapter for plain writers. + class WriterChannel : public Channel { + public: + explicit WriterChannel(stan::callbacks::writer* w) : writer_(w) { + if (!w) { + throw std::runtime_error("config error, null writer"); } - - // String, vector, vector - template , std::string> || - std::is_same_v, std::vector> || - std::is_same_v, std::vector> + } + + // Handle all types that writer supports via operator() + void dispatch() { (*writer_)(); } + void dispatch(const std::string& value) { (*writer_)(value); } + void dispatch(const std::vector& value) {(*writer_)(value); } + void dispatch(const std::vector& value) { + (*writer_)(value); + } + + // Handle any Eigen Matrix type + template + void dispatch(const Eigen::Matrix& value) { + (*writer_)(value); + } + + // No key-value support for plain writers + template + void dispatch(const std::string&, const T&) {} + + private: + stan::callbacks::writer* writer_; + }; + + // Adapter for structured writers. + class StructuredWriterChannel : public Channel { + public: + explicit StructuredWriterChannel(stan::callbacks::structured_writer* sw) + : writer_(sw) { + if (!sw) throw std::runtime_error("config error, null writer"); + } + // Forward all key-value calls directly to the writer + void dispatch(const std::string& key) { writer_->write(key); } + // Perfect forwarding for any key-value pair + template + void dispatch(const std::string& key, T&& value) { + writer_->write(key, std::forward(value)); + } + void begin_record() { writer_->begin_record(); } + void begin_record(const std::string& key) { writer_->begin_record(key); } + void end_record() { writer_->end_record(); } + + private: + stan::callbacks::structured_writer* writer_; + }; + + // dispatcher class + class dispatcher { + public: + dispatcher() = default; + ~dispatcher() = default; + + void register_channel(InfoType type, std::unique_ptr channel) { + channels_[type] = std::move(channel); + } + + // Empty call + void dispatch(InfoType type) { + if (auto* wc = find_channel(type)) + wc->dispatch(); + } + + // String, vector, vector + template , std::string> || + std::is_same_v, std::vector> || + std::is_same_v, std::vector> >> - void dispatch(InfoType type, T&& value) { - if (auto* wc = find_channel(type)) - wc->dispatch(std::forward(value)); - } - - // Eigen matrix types - template - void dispatch(InfoType type, const Eigen::Matrix& value) { - if (auto* wc = find_channel(type)) - wc->dispatch(value); - } - - // Key with no value (null) - void dispatch(InfoType type, const std::string& key) { - if (auto* sw = find_channel(type)) - sw->dispatch(key); - } - - // Key-value pairs (forward to structured writers) - template - void dispatch(InfoType type, const std::string& key, T&& value) { - if (auto* sw = find_channel(type)) - sw->dispatch(key, std::forward(value)); - } - - // Record operations - void begin_record(InfoType type) { - if (auto* sw = find_channel(type)) - sw->begin_record(); - } - - void begin_record(InfoType type, const std::string& key) { - if (auto* sw = find_channel(type)) - sw->begin_record(key); - } - - void end_record(InfoType type) { - if (auto* sw = find_channel(type)) - sw->end_record(); - } + void dispatch(InfoType type, T&& value) { + if (auto* wc = find_channel(type)) + wc->dispatch(std::forward(value)); + } - private: - // Helper to find and cast a channel of specific type - template - ChannelType* find_channel(InfoType type) { - auto it = channels_.find(type); - if (it == channels_.end()) return nullptr; - return dynamic_cast(it->second.get()); - } - - std::unordered_map, InfoTypeHash> channels_; - }; + // Eigen matrix types + template + void dispatch(InfoType type, const Eigen::Matrix& value) { + if (auto* wc = find_channel(type)) + wc->dispatch(value); + } + + // Key with no value (null) + void dispatch(InfoType type, const std::string& key) { + if (auto* sw = find_channel(type)) + sw->dispatch(key); + } + + // Key-value pairs (forward to structured writers) + template + void dispatch(InfoType type, const std::string& key, T&& value) { + if (auto* sw = find_channel(type)) + sw->dispatch(key, std::forward(value)); + } + + // Record operations + void begin_record(InfoType type) { + if (auto* sw = find_channel(type)) + sw->begin_record(); + } + + void begin_record(InfoType type, const std::string& key) { + if (auto* sw = find_channel(type)) + sw->begin_record(key); + } + + void end_record(InfoType type) { + if (auto* sw = find_channel(type)) + sw->end_record(); + } + + private: + // Helper to find and cast a channel of specific type + template + ChannelType* find_channel(InfoType type) { + auto it = channels_.find(type); + if (it == channels_.end()) return nullptr; + return dynamic_cast(it->second.get()); + } + + std::unordered_map, + InfoTypeHash> channels_; + }; - } // namespace callbacks -} // namespace stan +} // namespace callbacks +} // namespace stan -#endif // STAN_CALLBACKS_DISPATCHER_HPP +#endif // STAN_CALLBACKS_DISPATCHER_HPP From 91f7d4ca547ab10032383bb0416cc8f72b5a06a8 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Thu, 6 Mar 2025 15:07:40 -0500 Subject: [PATCH 40/53] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- src/stan/callbacks/dispatcher.hpp | 297 +++++++++-------- src/test/unit/callbacks/dispatcher_test.cpp | 340 ++++++++++---------- 2 files changed, 321 insertions(+), 316 deletions(-) diff --git a/src/stan/callbacks/dispatcher.hpp b/src/stan/callbacks/dispatcher.hpp index 3280c105e5..4593a8d51e 100644 --- a/src/stan/callbacks/dispatcher.hpp +++ b/src/stan/callbacks/dispatcher.hpp @@ -12,159 +12,158 @@ #include #include - namespace stan { namespace callbacks { - enum class InfoType { - CONFIG, // series of string messages - SAMPLE, // draw from posterior - METRIC, // struct with kv pairs 'metric_type', 'stepsize', 'inv_metric' - ALGORITHM_STATE, // sampler state for returned draw - }; - - struct InfoTypeHash { - std::size_t operator()(const InfoType& type) const { - return std::hash()(static_cast(type)); - } - }; - - // Base type for type erasure. - class Channel { - public: - virtual ~Channel() = default; - }; - - // Adapter for plain writers. - class WriterChannel : public Channel { - public: - explicit WriterChannel(stan::callbacks::writer* w) : writer_(w) { - if (!w) { - throw std::runtime_error("config error, null writer"); - } - } - - // Handle all types that writer supports via operator() - void dispatch() { (*writer_)(); } - void dispatch(const std::string& value) { (*writer_)(value); } - void dispatch(const std::vector& value) {(*writer_)(value); } - void dispatch(const std::vector& value) { - (*writer_)(value); - } - - // Handle any Eigen Matrix type - template - void dispatch(const Eigen::Matrix& value) { - (*writer_)(value); - } - - // No key-value support for plain writers - template - void dispatch(const std::string&, const T&) {} - - private: - stan::callbacks::writer* writer_; - }; - - // Adapter for structured writers. - class StructuredWriterChannel : public Channel { - public: - explicit StructuredWriterChannel(stan::callbacks::structured_writer* sw) +enum class InfoType { + CONFIG, // series of string messages + SAMPLE, // draw from posterior + METRIC, // struct with kv pairs 'metric_type', 'stepsize', 'inv_metric' + ALGORITHM_STATE, // sampler state for returned draw +}; + +struct InfoTypeHash { + std::size_t operator()(const InfoType& type) const { + return std::hash()(static_cast(type)); + } +}; + +// Base type for type erasure. +class Channel { + public: + virtual ~Channel() = default; +}; + +// Adapter for plain writers. +class WriterChannel : public Channel { + public: + explicit WriterChannel(stan::callbacks::writer* w) : writer_(w) { + if (!w) { + throw std::runtime_error("config error, null writer"); + } + } + + // Handle all types that writer supports via operator() + void dispatch() { (*writer_)(); } + void dispatch(const std::string& value) { (*writer_)(value); } + void dispatch(const std::vector& value) { (*writer_)(value); } + void dispatch(const std::vector& value) { (*writer_)(value); } + + // Handle any Eigen Matrix type + template + void dispatch(const Eigen::Matrix& value) { + (*writer_)(value); + } + + // No key-value support for plain writers + template + void dispatch(const std::string&, const T&) {} + + private: + stan::callbacks::writer* writer_; +}; + +// Adapter for structured writers. +class StructuredWriterChannel : public Channel { + public: + explicit StructuredWriterChannel(stan::callbacks::structured_writer* sw) : writer_(sw) { - if (!sw) throw std::runtime_error("config error, null writer"); - } - // Forward all key-value calls directly to the writer - void dispatch(const std::string& key) { writer_->write(key); } - // Perfect forwarding for any key-value pair - template - void dispatch(const std::string& key, T&& value) { - writer_->write(key, std::forward(value)); - } - void begin_record() { writer_->begin_record(); } - void begin_record(const std::string& key) { writer_->begin_record(key); } - void end_record() { writer_->end_record(); } - - private: - stan::callbacks::structured_writer* writer_; - }; - - // dispatcher class - class dispatcher { - public: - dispatcher() = default; - ~dispatcher() = default; - - void register_channel(InfoType type, std::unique_ptr channel) { - channels_[type] = std::move(channel); - } - - // Empty call - void dispatch(InfoType type) { - if (auto* wc = find_channel(type)) - wc->dispatch(); - } - - // String, vector, vector - template , std::string> || - std::is_same_v, std::vector> || - std::is_same_v, std::vector> - >> - void dispatch(InfoType type, T&& value) { - if (auto* wc = find_channel(type)) - wc->dispatch(std::forward(value)); - } - - // Eigen matrix types - template - void dispatch(InfoType type, const Eigen::Matrix& value) { - if (auto* wc = find_channel(type)) - wc->dispatch(value); - } - - // Key with no value (null) - void dispatch(InfoType type, const std::string& key) { - if (auto* sw = find_channel(type)) - sw->dispatch(key); - } - - // Key-value pairs (forward to structured writers) - template - void dispatch(InfoType type, const std::string& key, T&& value) { - if (auto* sw = find_channel(type)) - sw->dispatch(key, std::forward(value)); - } - - // Record operations - void begin_record(InfoType type) { - if (auto* sw = find_channel(type)) - sw->begin_record(); - } - - void begin_record(InfoType type, const std::string& key) { - if (auto* sw = find_channel(type)) - sw->begin_record(key); - } - - void end_record(InfoType type) { - if (auto* sw = find_channel(type)) - sw->end_record(); - } - - private: - // Helper to find and cast a channel of specific type - template - ChannelType* find_channel(InfoType type) { - auto it = channels_.find(type); - if (it == channels_.end()) return nullptr; - return dynamic_cast(it->second.get()); - } - - std::unordered_map, - InfoTypeHash> channels_; - }; + if (!sw) + throw std::runtime_error("config error, null writer"); + } + // Forward all key-value calls directly to the writer + void dispatch(const std::string& key) { writer_->write(key); } + // Perfect forwarding for any key-value pair + template + void dispatch(const std::string& key, T&& value) { + writer_->write(key, std::forward(value)); + } + void begin_record() { writer_->begin_record(); } + void begin_record(const std::string& key) { writer_->begin_record(key); } + void end_record() { writer_->end_record(); } + + private: + stan::callbacks::structured_writer* writer_; +}; + +// dispatcher class +class dispatcher { + public: + dispatcher() = default; + ~dispatcher() = default; + + void register_channel(InfoType type, std::unique_ptr channel) { + channels_[type] = std::move(channel); + } + + // Empty call + void dispatch(InfoType type) { + if (auto* wc = find_channel(type)) + wc->dispatch(); + } + + // String, vector, vector + template < + typename T, + typename = std::enable_if_t< + std::is_same_v< + std::decay_t, + std:: + string> || std::is_same_v, std::vector> || std::is_same_v, std::vector>>> + void dispatch(InfoType type, T&& value) { + if (auto* wc = find_channel(type)) + wc->dispatch(std::forward(value)); + } + + // Eigen matrix types + template + void dispatch(InfoType type, const Eigen::Matrix& value) { + if (auto* wc = find_channel(type)) + wc->dispatch(value); + } + + // Key with no value (null) + void dispatch(InfoType type, const std::string& key) { + if (auto* sw = find_channel(type)) + sw->dispatch(key); + } + + // Key-value pairs (forward to structured writers) + template + void dispatch(InfoType type, const std::string& key, T&& value) { + if (auto* sw = find_channel(type)) + sw->dispatch(key, std::forward(value)); + } + + // Record operations + void begin_record(InfoType type) { + if (auto* sw = find_channel(type)) + sw->begin_record(); + } + + void begin_record(InfoType type, const std::string& key) { + if (auto* sw = find_channel(type)) + sw->begin_record(key); + } + + void end_record(InfoType type) { + if (auto* sw = find_channel(type)) + sw->end_record(); + } + + private: + // Helper to find and cast a channel of specific type + template + ChannelType* find_channel(InfoType type) { + auto it = channels_.find(type); + if (it == channels_.end()) + return nullptr; + return dynamic_cast(it->second.get()); + } + + std::unordered_map, InfoTypeHash> + channels_; +}; } // namespace callbacks } // namespace stan diff --git a/src/test/unit/callbacks/dispatcher_test.cpp b/src/test/unit/callbacks/dispatcher_test.cpp index 39cdb32025..37c2cf635c 100644 --- a/src/test/unit/callbacks/dispatcher_test.cpp +++ b/src/test/unit/callbacks/dispatcher_test.cpp @@ -39,17 +39,20 @@ class DispatcherTest : public ::testing::Test { ss_metric.str(std::string()); ss_metric.clear(); - dispatcher.register_channel(InfoType::CONFIG, - std::unique_ptr( - new stan::callbacks::WriterChannel(&writer_config))); + dispatcher.register_channel( + InfoType::CONFIG, + std::unique_ptr( + new stan::callbacks::WriterChannel(&writer_config))); - dispatcher.register_channel(InfoType::SAMPLE, - std::unique_ptr( - new stan::callbacks::WriterChannel(&writer_sample))); + dispatcher.register_channel( + InfoType::SAMPLE, + std::unique_ptr( + new stan::callbacks::WriterChannel(&writer_sample))); - dispatcher.register_channel(InfoType::METRIC, - std::unique_ptr( - new stan::callbacks::StructuredWriterChannel(&writer_metric))); + dispatcher.register_channel( + InfoType::METRIC, + std::unique_ptr( + new stan::callbacks::StructuredWriterChannel(&writer_metric))); } void TearDown() {} @@ -148,171 +151,174 @@ TEST_F(DispatcherTest, StructuredBeginEndRecord) { EXPECT_NE(output.find("{"), std::string::npos); EXPECT_NE(output.find("}"), std::string::npos); ======= -TEST_F(DispatcherTest, MetricStructuredKeyValueRecord) { - // For METRIC (structured writer), open a record, dispatch key/value pairs, - // then close the record. - dispatcher.begin_record(InfoType::METRIC); - dispatcher.dispatch(InfoType::METRIC, "metric_type", std::string("diag")); - dispatcher.dispatch(InfoType::METRIC, "stepsize", 0.6789); - // For the inv_metric, assume the caller converts the vector to a - // comma-separated string. - std::vector inv_metric = {0.1, 0.2, 0.3}; - std::string inv_metric_str; - for (size_t i = 0; i < inv_metric.size(); ++i) { - inv_metric_str += std::to_string(inv_metric[i]); - if (i != inv_metric.size() - 1) - inv_metric_str += ","; - } - dispatcher.dispatch(InfoType::METRIC, "inv_metric", inv_metric_str); - dispatcher.end_record(InfoType::METRIC); - // Expected output: - // Begin record marker, followed by key/value pairs each formatted as - // "key:value;" and then end record marker. - std::cout << ss_metric.str() << std::endl; + TEST_F(DispatcherTest, MetricStructuredKeyValueRecord) { + // For METRIC (structured writer), open a record, dispatch key/value pairs, + // then close the record. + dispatcher.begin_record(InfoType::METRIC); + dispatcher.dispatch(InfoType::METRIC, "metric_type", std::string("diag")); + dispatcher.dispatch(InfoType::METRIC, "stepsize", 0.6789); + // For the inv_metric, assume the caller converts the vector to a + // comma-separated string. + std::vector inv_metric = {0.1, 0.2, 0.3}; + std::string inv_metric_str; + for (size_t i = 0; i < inv_metric.size(); ++i) { + inv_metric_str += std::to_string(inv_metric[i]); + if (i != inv_metric.size() - 1) + inv_metric_str += ","; + } + dispatcher.dispatch(InfoType::METRIC, "inv_metric", inv_metric_str); + dispatcher.end_record(InfoType::METRIC); + // Expected output: + // Begin record marker, followed by key/value pairs each formatted as + // "key:value;" and then end record marker. + std::cout << ss_metric.str() << std::endl; >>>>>>> 89d756b23a601c560cb63a09930f5fcbce011efc -} + } -// Test structured writer key-value pairs with string value -TEST_F(DispatcherTest, StructuredKeyStringValue) { - dispatcher.begin_record(InfoType::METRIC); - dispatcher.dispatch(InfoType::METRIC, "key1", std::string("value1")); - dispatcher.end_record(InfoType::METRIC); - - std::string output = ss_metric.str(); - EXPECT_NE(output.find("key1"), std::string::npos); - EXPECT_NE(output.find("value1"), std::string::npos); -} + // Test structured writer key-value pairs with string value + TEST_F(DispatcherTest, StructuredKeyStringValue) { + dispatcher.begin_record(InfoType::METRIC); + dispatcher.dispatch(InfoType::METRIC, "key1", std::string("value1")); + dispatcher.end_record(InfoType::METRIC); + + std::string output = ss_metric.str(); + EXPECT_NE(output.find("key1"), std::string::npos); + EXPECT_NE(output.find("value1"), std::string::npos); + } <<<<<<< HEAD -// Test structured writer with multiple key-value types -TEST_F(DispatcherTest, StructuredMultipleValueTypes) { - dispatcher.begin_record(InfoType::METRIC); - dispatcher.dispatch(InfoType::METRIC, "string_key", std::string("string_value")); - dispatcher.dispatch(InfoType::METRIC, "int_key", 42); - dispatcher.dispatch(InfoType::METRIC, "double_key", 3.14159); - dispatcher.dispatch(InfoType::METRIC, "bool_key", true); - dispatcher.end_record(InfoType::METRIC); - - std::string output = ss_metric.str(); - EXPECT_NE(output.find("string_key"), std::string::npos); - EXPECT_NE(output.find("string_value"), std::string::npos); - EXPECT_NE(output.find("int_key"), std::string::npos); - EXPECT_NE(output.find("42"), std::string::npos); - EXPECT_NE(output.find("double_key"), std::string::npos); - EXPECT_NE(output.find("3.14159"), std::string::npos); - EXPECT_NE(output.find("bool_key"), std::string::npos); - EXPECT_NE(output.find("true"), std::string::npos); -} + // Test structured writer with multiple key-value types + TEST_F(DispatcherTest, StructuredMultipleValueTypes) { + dispatcher.begin_record(InfoType::METRIC); + dispatcher.dispatch(InfoType::METRIC, "string_key", + std::string("string_value")); + dispatcher.dispatch(InfoType::METRIC, "int_key", 42); + dispatcher.dispatch(InfoType::METRIC, "double_key", 3.14159); + dispatcher.dispatch(InfoType::METRIC, "bool_key", true); + dispatcher.end_record(InfoType::METRIC); -// Test structured writer with vector values -TEST_F(DispatcherTest, StructuredVectorValues) { - dispatcher.begin_record(InfoType::METRIC); - - std::vector doubles = {1.1, 2.2, 3.3}; - dispatcher.dispatch(InfoType::METRIC, "doubles", doubles); - - std::vector strings = {"one", "two", "three"}; - dispatcher.dispatch(InfoType::METRIC, "strings", strings); - - dispatcher.end_record(InfoType::METRIC); - - std::string output = ss_metric.str(); - EXPECT_NE(output.find("doubles"), std::string::npos); - EXPECT_NE(output.find("1.1"), std::string::npos); - EXPECT_NE(output.find("strings"), std::string::npos); - EXPECT_NE(output.find("one"), std::string::npos); -} + std::string output = ss_metric.str(); + EXPECT_NE(output.find("string_key"), std::string::npos); + EXPECT_NE(output.find("string_value"), std::string::npos); + EXPECT_NE(output.find("int_key"), std::string::npos); + EXPECT_NE(output.find("42"), std::string::npos); + EXPECT_NE(output.find("double_key"), std::string::npos); + EXPECT_NE(output.find("3.14159"), std::string::npos); + EXPECT_NE(output.find("bool_key"), std::string::npos); + EXPECT_NE(output.find("true"), std::string::npos); + } -// Test structured writer with Eigen values -TEST_F(DispatcherTest, StructuredEigenValues) { - dispatcher.begin_record(InfoType::METRIC); - - Eigen::MatrixXd matrix(2, 2); - matrix << 1.0, 2.0, 3.0, 4.0; - dispatcher.dispatch(InfoType::METRIC, "matrix", matrix); - - Eigen::VectorXd vector(3); - vector << 5.0, 6.0, 7.0; - dispatcher.dispatch(InfoType::METRIC, "vector", vector); - - dispatcher.end_record(InfoType::METRIC); - - std::string output = ss_metric.str(); - EXPECT_NE(output.find("matrix"), std::string::npos); - EXPECT_NE(output.find("1"), std::string::npos); - EXPECT_NE(output.find("4"), std::string::npos); - EXPECT_NE(output.find("vector"), std::string::npos); - EXPECT_NE(output.find("5"), std::string::npos); - EXPECT_NE(output.find("7"), std::string::npos); -} + // Test structured writer with vector values + TEST_F(DispatcherTest, StructuredVectorValues) { + dispatcher.begin_record(InfoType::METRIC); -// Test unregistered channel -TEST_F(DispatcherTest, UnregisteredChannel) { - // Dispatch to unregistered channel should silently do nothing - dispatcher.dispatch(InfoType::ALGORITHM_STATE, std::string("Message")); - dispatcher.dispatch(InfoType::ALGORITHM_STATE, std::vector{1.0, 2.0}); - dispatcher.begin_record(InfoType::ALGORITHM_STATE); - dispatcher.dispatch(InfoType::ALGORITHM_STATE, "key", "value"); - dispatcher.end_record(InfoType::ALGORITHM_STATE); - - // No exceptions should be thrown -} + std::vector doubles = {1.1, 2.2, 3.3}; + dispatcher.dispatch(InfoType::METRIC, "doubles", doubles); -// Test named record -TEST_F(DispatcherTest, NamedRecord) { - dispatcher.begin_record(InfoType::METRIC, "record_name"); - dispatcher.dispatch(InfoType::METRIC, "key", "value"); - dispatcher.end_record(InfoType::METRIC); - - std::string output = ss_metric.str(); - EXPECT_NE(output.find("record_name"), std::string::npos); - EXPECT_NE(output.find("key"), std::string::npos); - EXPECT_NE(output.find("value"), std::string::npos); -} + std::vector strings = {"one", "two", "three"}; + dispatcher.dispatch(InfoType::METRIC, "strings", strings); -// Test that begin_record and end_record on a plain writer channel are silently ignored -TEST_F(DispatcherTest, RecordOperationsOnPlainWriter) { - dispatcher.begin_record(InfoType::CONFIG); - dispatcher.end_record(InfoType::CONFIG); - - // Should not generate any output - EXPECT_EQ(ss_config.str(), ""); -} + dispatcher.end_record(InfoType::METRIC); -// Test complex sampler metric output pattern -TEST_F(DispatcherTest, ComplexSamplerMetricPattern) { - // This test simulates a more complex real-world usage pattern - - // Start a record for a sampling iteration - dispatcher.begin_record(InfoType::METRIC); - - // Add various diagnostic info - dispatcher.dispatch(InfoType::METRIC, "iter", 10); - dispatcher.dispatch(InfoType::METRIC, "lp", -105.2); - dispatcher.dispatch(InfoType::METRIC, "accept_stat", 0.8); - - // Add a nested object for adaptation - dispatcher.begin_record(InfoType::METRIC, "adaptation"); - dispatcher.dispatch(InfoType::METRIC, "step_size", 0.85); - - // Add an inverse metric matrix - Eigen::MatrixXd inv_metric(2, 2); - inv_metric << 1.2, 0.1, 0.1, 0.9; - dispatcher.dispatch(InfoType::METRIC, "inv_metric", inv_metric); - - // End adaptation object - dispatcher.end_record(InfoType::METRIC); - - // End the main record - dispatcher.end_record(InfoType::METRIC); - - // Verify key entries exist in the output - std::string output = ss_metric.str(); - EXPECT_NE(output.find("iter"), std::string::npos); - EXPECT_NE(output.find("10"), std::string::npos); - EXPECT_NE(output.find("lp"), std::string::npos); - EXPECT_NE(output.find("-105.2"), std::string::npos); - EXPECT_NE(output.find("adaptation"), std::string::npos); - EXPECT_NE(output.find("step_size"), std::string::npos); - EXPECT_NE(output.find("inv_metric"), std::string::npos); -} + std::string output = ss_metric.str(); + EXPECT_NE(output.find("doubles"), std::string::npos); + EXPECT_NE(output.find("1.1"), std::string::npos); + EXPECT_NE(output.find("strings"), std::string::npos); + EXPECT_NE(output.find("one"), std::string::npos); + } + + // Test structured writer with Eigen values + TEST_F(DispatcherTest, StructuredEigenValues) { + dispatcher.begin_record(InfoType::METRIC); + + Eigen::MatrixXd matrix(2, 2); + matrix << 1.0, 2.0, 3.0, 4.0; + dispatcher.dispatch(InfoType::METRIC, "matrix", matrix); + + Eigen::VectorXd vector(3); + vector << 5.0, 6.0, 7.0; + dispatcher.dispatch(InfoType::METRIC, "vector", vector); + + dispatcher.end_record(InfoType::METRIC); + + std::string output = ss_metric.str(); + EXPECT_NE(output.find("matrix"), std::string::npos); + EXPECT_NE(output.find("1"), std::string::npos); + EXPECT_NE(output.find("4"), std::string::npos); + EXPECT_NE(output.find("vector"), std::string::npos); + EXPECT_NE(output.find("5"), std::string::npos); + EXPECT_NE(output.find("7"), std::string::npos); + } + + // Test unregistered channel + TEST_F(DispatcherTest, UnregisteredChannel) { + // Dispatch to unregistered channel should silently do nothing + dispatcher.dispatch(InfoType::ALGORITHM_STATE, std::string("Message")); + dispatcher.dispatch(InfoType::ALGORITHM_STATE, + std::vector{1.0, 2.0}); + dispatcher.begin_record(InfoType::ALGORITHM_STATE); + dispatcher.dispatch(InfoType::ALGORITHM_STATE, "key", "value"); + dispatcher.end_record(InfoType::ALGORITHM_STATE); + + // No exceptions should be thrown + } + + // Test named record + TEST_F(DispatcherTest, NamedRecord) { + dispatcher.begin_record(InfoType::METRIC, "record_name"); + dispatcher.dispatch(InfoType::METRIC, "key", "value"); + dispatcher.end_record(InfoType::METRIC); + + std::string output = ss_metric.str(); + EXPECT_NE(output.find("record_name"), std::string::npos); + EXPECT_NE(output.find("key"), std::string::npos); + EXPECT_NE(output.find("value"), std::string::npos); + } + + // Test that begin_record and end_record on a plain writer channel are + // silently ignored + TEST_F(DispatcherTest, RecordOperationsOnPlainWriter) { + dispatcher.begin_record(InfoType::CONFIG); + dispatcher.end_record(InfoType::CONFIG); + + // Should not generate any output + EXPECT_EQ(ss_config.str(), ""); + } + + // Test complex sampler metric output pattern + TEST_F(DispatcherTest, ComplexSamplerMetricPattern) { + // This test simulates a more complex real-world usage pattern + + // Start a record for a sampling iteration + dispatcher.begin_record(InfoType::METRIC); + + // Add various diagnostic info + dispatcher.dispatch(InfoType::METRIC, "iter", 10); + dispatcher.dispatch(InfoType::METRIC, "lp", -105.2); + dispatcher.dispatch(InfoType::METRIC, "accept_stat", 0.8); + + // Add a nested object for adaptation + dispatcher.begin_record(InfoType::METRIC, "adaptation"); + dispatcher.dispatch(InfoType::METRIC, "step_size", 0.85); + + // Add an inverse metric matrix + Eigen::MatrixXd inv_metric(2, 2); + inv_metric << 1.2, 0.1, 0.1, 0.9; + dispatcher.dispatch(InfoType::METRIC, "inv_metric", inv_metric); + + // End adaptation object + dispatcher.end_record(InfoType::METRIC); + + // End the main record + dispatcher.end_record(InfoType::METRIC); + + // Verify key entries exist in the output + std::string output = ss_metric.str(); + EXPECT_NE(output.find("iter"), std::string::npos); + EXPECT_NE(output.find("10"), std::string::npos); + EXPECT_NE(output.find("lp"), std::string::npos); + EXPECT_NE(output.find("-105.2"), std::string::npos); + EXPECT_NE(output.find("adaptation"), std::string::npos); + EXPECT_NE(output.find("step_size"), std::string::npos); + EXPECT_NE(output.find("inv_metric"), std::string::npos); + } From b5e99554e2f7b9073d1294a67f32c4c9f8a9c574 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 6 Mar 2025 16:07:03 -0500 Subject: [PATCH 41/53] unit tests --- lib/stan_math | 2 +- src/test/unit/callbacks/dispatcher_test.cpp | 29 ++++----------------- 2 files changed, 6 insertions(+), 25 deletions(-) diff --git a/lib/stan_math b/lib/stan_math index afea210fd8..ac977e4ed7 160000 --- a/lib/stan_math +++ b/lib/stan_math @@ -1 +1 @@ -Subproject commit afea210fd8d68edcf0679a7618510e7dfe767103 +Subproject commit ac977e4ed79bd593440b9afbc25e1bba905e9276 diff --git a/src/test/unit/callbacks/dispatcher_test.cpp b/src/test/unit/callbacks/dispatcher_test.cpp index 39cdb32025..9b0a208ea5 100644 --- a/src/test/unit/callbacks/dispatcher_test.cpp +++ b/src/test/unit/callbacks/dispatcher_test.cpp @@ -10,7 +10,6 @@ #include #include -// For this test we assume that InfoType has at least these values: using stan::callbacks::dispatcher; using stan::callbacks::InfoType; @@ -147,43 +146,25 @@ TEST_F(DispatcherTest, StructuredBeginEndRecord) { // JSON output should contain opening and closing braces EXPECT_NE(output.find("{"), std::string::npos); EXPECT_NE(output.find("}"), std::string::npos); -======= +} + TEST_F(DispatcherTest, MetricStructuredKeyValueRecord) { // For METRIC (structured writer), open a record, dispatch key/value pairs, // then close the record. dispatcher.begin_record(InfoType::METRIC); dispatcher.dispatch(InfoType::METRIC, "metric_type", std::string("diag")); dispatcher.dispatch(InfoType::METRIC, "stepsize", 0.6789); - // For the inv_metric, assume the caller converts the vector to a - // comma-separated string. std::vector inv_metric = {0.1, 0.2, 0.3}; - std::string inv_metric_str; - for (size_t i = 0; i < inv_metric.size(); ++i) { - inv_metric_str += std::to_string(inv_metric[i]); - if (i != inv_metric.size() - 1) - inv_metric_str += ","; - } - dispatcher.dispatch(InfoType::METRIC, "inv_metric", inv_metric_str); + dispatcher.dispatch(InfoType::METRIC, "inv_metric", inv_metric); dispatcher.end_record(InfoType::METRIC); // Expected output: // Begin record marker, followed by key/value pairs each formatted as // "key:value;" and then end record marker. - std::cout << ss_metric.str() << std::endl; ->>>>>>> 89d756b23a601c560cb63a09930f5fcbce011efc -} - -// Test structured writer key-value pairs with string value -TEST_F(DispatcherTest, StructuredKeyStringValue) { - dispatcher.begin_record(InfoType::METRIC); - dispatcher.dispatch(InfoType::METRIC, "key1", std::string("value1")); - dispatcher.end_record(InfoType::METRIC); - std::string output = ss_metric.str(); - EXPECT_NE(output.find("key1"), std::string::npos); - EXPECT_NE(output.find("value1"), std::string::npos); + EXPECT_NE(output.find("metric_type"), std::string::npos); + EXPECT_NE(output.find("diag"), std::string::npos); } -<<<<<<< HEAD // Test structured writer with multiple key-value types TEST_F(DispatcherTest, StructuredMultipleValueTypes) { dispatcher.begin_record(InfoType::METRIC); From e50ea62856ef2d5aa8558bea92d6d85d9a8fb1e2 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Thu, 6 Mar 2025 16:13:41 -0500 Subject: [PATCH 42/53] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- src/test/unit/callbacks/dispatcher_test.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/test/unit/callbacks/dispatcher_test.cpp b/src/test/unit/callbacks/dispatcher_test.cpp index 86bd61d001..b0c58d6b3a 100644 --- a/src/test/unit/callbacks/dispatcher_test.cpp +++ b/src/test/unit/callbacks/dispatcher_test.cpp @@ -171,12 +171,13 @@ TEST_F(DispatcherTest, MetricStructuredKeyValueRecord) { // Test structured writer with multiple key-value types TEST_F(DispatcherTest, StructuredMultipleValueTypes) { dispatcher.begin_record(InfoType::METRIC); - dispatcher.dispatch(InfoType::METRIC, "string_key", std::string("string_value")); + dispatcher.dispatch(InfoType::METRIC, "string_key", + std::string("string_value")); dispatcher.dispatch(InfoType::METRIC, "int_key", 42); dispatcher.dispatch(InfoType::METRIC, "double_key", 3.14159); dispatcher.dispatch(InfoType::METRIC, "bool_key", true); dispatcher.end_record(InfoType::METRIC); - + std::string output = ss_metric.str(); EXPECT_NE(output.find("string_key"), std::string::npos); EXPECT_NE(output.find("string_value"), std::string::npos); @@ -188,7 +189,6 @@ TEST_F(DispatcherTest, StructuredMultipleValueTypes) { EXPECT_NE(output.find("true"), std::string::npos); } - // Test structured writer with Eigen values TEST_F(DispatcherTest, StructuredEigenValues) { dispatcher.begin_record(InfoType::METRIC); @@ -216,8 +216,7 @@ TEST_F(DispatcherTest, StructuredEigenValues) { TEST_F(DispatcherTest, UnregisteredChannel) { // Dispatch to unregistered channel should silently do nothing dispatcher.dispatch(InfoType::ALGORITHM_STATE, std::string("Message")); - dispatcher.dispatch(InfoType::ALGORITHM_STATE, - std::vector{1.0, 2.0}); + dispatcher.dispatch(InfoType::ALGORITHM_STATE, std::vector{1.0, 2.0}); dispatcher.begin_record(InfoType::ALGORITHM_STATE); dispatcher.dispatch(InfoType::ALGORITHM_STATE, "key", "value"); dispatcher.end_record(InfoType::ALGORITHM_STATE); From 877d6e5c67bd02ac0fc8eb52ab2f857b65bc93fa Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 6 Mar 2025 16:21:48 -0500 Subject: [PATCH 43/53] lint fix --- src/stan/callbacks/dispatcher.hpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/stan/callbacks/dispatcher.hpp b/src/stan/callbacks/dispatcher.hpp index 4593a8d51e..ee360b38cd 100644 --- a/src/stan/callbacks/dispatcher.hpp +++ b/src/stan/callbacks/dispatcher.hpp @@ -106,10 +106,9 @@ class dispatcher { template < typename T, typename = std::enable_if_t< - std::is_same_v< - std::decay_t, - std:: - string> || std::is_same_v, std::vector> || std::is_same_v, std::vector>>> + std::is_same_v, std::string> + || std::is_same_v, std::vector> + || std::is_same_v, std::vector>>> void dispatch(InfoType type, T&& value) { if (auto* wc = find_channel(type)) wc->dispatch(std::forward(value)); From cef5f2be53de19eba9cc5b9761e9c38fbb30c1f6 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Thu, 6 Mar 2025 16:22:23 -0500 Subject: [PATCH 44/53] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- src/stan/callbacks/dispatcher.hpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/stan/callbacks/dispatcher.hpp b/src/stan/callbacks/dispatcher.hpp index ee360b38cd..4593a8d51e 100644 --- a/src/stan/callbacks/dispatcher.hpp +++ b/src/stan/callbacks/dispatcher.hpp @@ -106,9 +106,10 @@ class dispatcher { template < typename T, typename = std::enable_if_t< - std::is_same_v, std::string> - || std::is_same_v, std::vector> - || std::is_same_v, std::vector>>> + std::is_same_v< + std::decay_t, + std:: + string> || std::is_same_v, std::vector> || std::is_same_v, std::vector>>> void dispatch(InfoType type, T&& value) { if (auto* wc = find_channel(type)) wc->dispatch(std::forward(value)); From a62c1eafb4c4eda6191bb0da80bcb51346974970 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 6 Mar 2025 18:57:35 -0500 Subject: [PATCH 45/53] lint fix --- src/stan/callbacks/dispatcher.hpp | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/src/stan/callbacks/dispatcher.hpp b/src/stan/callbacks/dispatcher.hpp index 4593a8d51e..e7d5550eb8 100644 --- a/src/stan/callbacks/dispatcher.hpp +++ b/src/stan/callbacks/dispatcher.hpp @@ -91,25 +91,10 @@ class dispatcher { public: dispatcher() = default; ~dispatcher() = default; - - void register_channel(InfoType type, std::unique_ptr channel) { - channels_[type] = std::move(channel); - } - - // Empty call - void dispatch(InfoType type) { - if (auto* wc = find_channel(type)) - wc->dispatch(); - } - - // String, vector, vector - template < - typename T, - typename = std::enable_if_t< - std::is_same_v< - std::decay_t, - std:: - string> || std::is_same_v, std::vector> || std::is_same_v, std::vector>>> + typename = std::enable_if_t< + std::is_same_v, std::string> + || std::is_same_v, std::vector> + || std::is_same_v, std::vector>>> // NOLINT void dispatch(InfoType type, T&& value) { if (auto* wc = find_channel(type)) wc->dispatch(std::forward(value)); From 09aba5af39a338e5bbb4c1d335a78c42f53a3cec Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 6 Mar 2025 20:01:56 -0500 Subject: [PATCH 46/53] unit tests working --- src/stan/callbacks/dispatcher.hpp | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/stan/callbacks/dispatcher.hpp b/src/stan/callbacks/dispatcher.hpp index e7d5550eb8..70af88ecb9 100644 --- a/src/stan/callbacks/dispatcher.hpp +++ b/src/stan/callbacks/dispatcher.hpp @@ -91,10 +91,27 @@ class dispatcher { public: dispatcher() = default; ~dispatcher() = default; - typename = std::enable_if_t< - std::is_same_v, std::string> - || std::is_same_v, std::vector> - || std::is_same_v, std::vector>>> // NOLINT + + void register_channel(InfoType type, std::unique_ptr channel) { + channels_[type] = std::move(channel); + } + + // no-arg call to writer operator () + void dispatch(InfoType type) { + auto it = channels_.find(type); + if (it == channels_.end()) + return; // silently do nothing + if (auto* wc = dynamic_cast(it->second.get())) + wc->dispatch(); + } + + // single non-string argument - call writer operator () + template < + typename T, + typename = std::enable_if_t< + std::is_same_v, std::string> + || std::is_same_v, std::vector> + || std::is_same_v, std::vector>>> // NOLINT void dispatch(InfoType type, T&& value) { if (auto* wc = find_channel(type)) wc->dispatch(std::forward(value)); From 0552dbed0ababf1eba223822d630d9cbee1e0961 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Thu, 6 Mar 2025 20:02:17 -0500 Subject: [PATCH 47/53] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- src/stan/callbacks/dispatcher.hpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/stan/callbacks/dispatcher.hpp b/src/stan/callbacks/dispatcher.hpp index 70af88ecb9..c55f279340 100644 --- a/src/stan/callbacks/dispatcher.hpp +++ b/src/stan/callbacks/dispatcher.hpp @@ -104,14 +104,15 @@ class dispatcher { if (auto* wc = dynamic_cast(it->second.get())) wc->dispatch(); } - + // single non-string argument - call writer operator () template < - typename T, - typename = std::enable_if_t< - std::is_same_v, std::string> - || std::is_same_v, std::vector> - || std::is_same_v, std::vector>>> // NOLINT + typename T, + typename = std::enable_if_t< + std::is_same_v< + std::decay_t, + std:: + string> || std::is_same_v, std::vector> || std::is_same_v, std::vector>>> // NOLINT void dispatch(InfoType type, T&& value) { if (auto* wc = find_channel(type)) wc->dispatch(std::forward(value)); From 3a430cf1c4cf1eb1ec18c1faa75447ac2498bd6f Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 6 Mar 2025 21:26:16 -0500 Subject: [PATCH 48/53] added in_memory_writer and unit tests --- src/stan/callbacks/dispatcher.hpp | 3 +- src/stan/callbacks/in_memory_writer.hpp | 190 ++++++++++ src/test/unit/callbacks/dispatcher_test.cpp | 99 +++++ .../unit/callbacks/in_memory_writer_test.cpp | 358 ++++++++++++++++++ 4 files changed, 649 insertions(+), 1 deletion(-) create mode 100644 src/stan/callbacks/in_memory_writer.hpp create mode 100644 src/test/unit/callbacks/in_memory_writer_test.cpp diff --git a/src/stan/callbacks/dispatcher.hpp b/src/stan/callbacks/dispatcher.hpp index 70af88ecb9..914c2dfa84 100644 --- a/src/stan/callbacks/dispatcher.hpp +++ b/src/stan/callbacks/dispatcher.hpp @@ -18,6 +18,7 @@ namespace callbacks { enum class InfoType { CONFIG, // series of string messages SAMPLE, // draw from posterior + SAMPLE_RAW, // draw from posterior METRIC, // struct with kv pairs 'metric_type', 'stepsize', 'inv_metric' ALGORITHM_STATE, // sampler state for returned draw }; @@ -104,7 +105,7 @@ class dispatcher { if (auto* wc = dynamic_cast(it->second.get())) wc->dispatch(); } - + // single non-string argument - call writer operator () template < typename T, diff --git a/src/stan/callbacks/in_memory_writer.hpp b/src/stan/callbacks/in_memory_writer.hpp new file mode 100644 index 0000000000..32cbd8f3f6 --- /dev/null +++ b/src/stan/callbacks/in_memory_writer.hpp @@ -0,0 +1,190 @@ +#ifndef STAN_CALLBACKS_IN_MEMORY_WRITER_HPP +#define STAN_CALLBACKS_IN_MEMORY_WRITER_HPP + +#include +#include +#include +#include +#include + +namespace stan { +namespace callbacks { + +/** + * A general in-memory writer that stores draws from Stan's write_array + * callback into a contiguous Eigen matrix. The matrix is allocated with + * column-major storage (Eigen's default), meaning that the element at + * row i and column j is stored at index (i + j * num_rows) in memory. + * + * The writer accepts draws row-by-row and writes them into the matrix. + * If more rows are written than initially allocated, or if the input row + * does not have the expected number of columns, an exception is thrown. + */ +class in_memory_writer : public stan::callbacks::writer { +public: + /** + * Construct an in-memory writer. + * + * @param num_rows the total number of rows (draws) expected + * @param num_cols the number of columns (parameters) per draw + * + * The underlying Eigen matrix is allocated to have dimensions + * num_rows x num_cols in column-major order. + */ + in_memory_writer(std::size_t num_rows, std::size_t num_cols) + : num_rows_(num_rows), + num_cols_(num_cols), + data_(Eigen::MatrixXd::Zero(num_rows, num_cols)), // column-major + current_row_(0), + names_() {} + + virtual ~in_memory_writer() {} + + /** + * Writes a set of names. + * + * @param[in] names Names in a std::vector + */ + void operator()(const std::vector& names) override { + names_ = names; + } + + /** + * Writes a single row (draw) to the in-memory matrix. + * + * The input vector must have exactly num_cols elements. + * The row is written into the matrix at the current row index. + * If the matrix is full, an exception is thrown. + * + * @param row A vector containing one draw from the posterior. + */ + void operator()(const std::vector& row) override { + if (row.size() != num_cols_) { + throw std::runtime_error("Row size does not match the number of columns"); + } + if (current_row_ >= num_rows_) { + throw std::runtime_error("Attempted to write more rows than allocated"); + } + // Because Eigen::MatrixXd is column-major by default, simply assigning + // row by row will place the data in memory in the order: + // index = current_row_ + (column_index * num_rows_). + for (std::size_t j = 0; j < num_cols_; ++j) { + data_(current_row_, j) = row[j]; + } + ++current_row_; + } + + /** + * Default implementation for empty call. + */ + void operator()() override { } + + /** + * Default implementation for string message. + */ + void operator()(const std::string& message) override { } + + /** + * Handles Eigen matrix input by converting to rows and inserting. + * This method treats each row of the input matrix as a separate draw. + */ + void operator()(const Eigen::Matrix& matrix) override { + for (Eigen::Index i = 0; i < matrix.rows(); ++i) { + std::vector row(matrix.cols()); + for (Eigen::Index j = 0; j < matrix.cols(); ++j) { + row[j] = matrix(i, j); + } + (*this)(row); // Use the vector operator to insert the row + } + } + + /** + * Handles Eigen vector input by inserting as a single row. + */ + void operator()(const Eigen::Matrix& vector) override { + std::vector row(vector.size()); + for (Eigen::Index i = 0; i < vector.size(); ++i) { + row[i] = vector(i); + } + (*this)(row); // Use the vector operator to insert the row + } + + /** + * Handles Eigen row vector input by inserting as a single row. + */ + void operator()(const Eigen::Matrix& vector) override { + std::vector row(vector.size()); + for (Eigen::Index i = 0; i < vector.size(); ++i) { + row[i] = vector(i); + } + (*this)(row); // Use the vector operator to insert the row + } + + /** + * Always returns true as the in-memory writer is always valid. + */ + bool is_valid() const noexcept override { + return true; + } + + /** + * Returns a const reference to the in-memory Eigen matrix containing all draws. + * + * The matrix is stored in column-major order. + * + * @return const reference to the Eigen::MatrixXd holding the draws. + */ + const Eigen::MatrixXd& get_eigen_state_values() const { + return data_; + } + + /** + * Returns a const reference to the column names. + * + * @return const reference to the vector of column names. + */ + const std::vector& get_names() const { + return names_; + } + + /** + * Returns the number of rows that have been written so far. + * + * @return The current row count. + */ + std::size_t get_row_count() const { + return current_row_; + } + + /** + * Resets the writer to its initial state. + * + * Clears the stored data (sets the matrix to zero) and resets the current row index. + * Column names are retained. + */ + void reset() { + current_row_ = 0; + data_.setZero(); + } + + /** + * Fully resets the writer, including clearing column names. + */ + void clear() { + current_row_ = 0; + data_.setZero(); + names_.clear(); + } + +private: + std::size_t num_rows_; // Total number of draws (rows) expected. + std::size_t num_cols_; // Number of parameters (columns) per draw. + Eigen::MatrixXd data_; // Internal storage; Eigen matrices are column-major. + std::size_t current_row_; // Next row index to be written. + std::vector names_; // Column names +}; + +} // namespace callbacks +} // namespace stan + +#endif // STAN_CALLBACKS_IN_MEMORY_WRITER_HPP diff --git a/src/test/unit/callbacks/dispatcher_test.cpp b/src/test/unit/callbacks/dispatcher_test.cpp index b0c58d6b3a..559e5b7170 100644 --- a/src/test/unit/callbacks/dispatcher_test.cpp +++ b/src/test/unit/callbacks/dispatcher_test.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -52,6 +53,11 @@ class DispatcherTest : public ::testing::Test { InfoType::METRIC, std::unique_ptr( new stan::callbacks::StructuredWriterChannel(&writer_metric))); + + dispatcher.register_channel( + InfoType::SAMPLE_RAW, + std::unique_ptr( + new stan::callbacks::WriterChannel(&writer_sample_in_memory))); } void TearDown() {} @@ -63,6 +69,7 @@ class DispatcherTest : public ::testing::Test { stan::callbacks::stream_writer writer_sample; stan::callbacks::stream_writer writer_config; stan::callbacks::json_writer writer_metric; + stan::callbacks::in_memory_writer writer_sample_in_memory; stan::callbacks::dispatcher dispatcher; }; @@ -245,3 +252,95 @@ TEST_F(DispatcherTest, RecordOperationsOnPlainWriter) { // Should not generate any output EXPECT_EQ(ss_config.str(), ""); } + +// Test in_memory_writer integration with dispatcher +TEST_F(DispatcherTest, InMemoryWriterBasic) { + // Write rows of data to the in_memory_writer via the dispatcher + std::vector row1 = {1.1, 2.2, 3.3}; + std::vector row2 = {4.4, 5.5, 6.6}; + std::vector row3 = {7.7, 8.8, 9.9}; + + dispatcher.dispatch(InfoType::SAMPLE_RAW, row1); + dispatcher.dispatch(InfoType::SAMPLE_RAW, row2); + dispatcher.dispatch(InfoType::SAMPLE_RAW, row3); + + // Check that the data was stored correctly + const Eigen::MatrixXd& data = in_memory_writer_.get_eigen_state_values(); + + EXPECT_EQ(data.rows(), 5); // As initialized + EXPECT_EQ(data.cols(), 3); // As initialized + + // Check the stored values (first 3 rows should have our data) + EXPECT_DOUBLE_EQ(data(0, 0), 1.1); + EXPECT_DOUBLE_EQ(data(0, 1), 2.2); + EXPECT_DOUBLE_EQ(data(0, 2), 3.3); + + EXPECT_DOUBLE_EQ(data(1, 0), 4.4); + EXPECT_DOUBLE_EQ(data(1, 1), 5.5); + EXPECT_DOUBLE_EQ(data(1, 2), 6.6); + + EXPECT_DOUBLE_EQ(data(2, 0), 7.7); + EXPECT_DOUBLE_EQ(data(2, 1), 8.8); + EXPECT_DOUBLE_EQ(data(2, 2), 9.9); +} + +// Test in_memory_writer with row overflow +TEST_F(DispatcherTest, InMemoryWriterRowOverflow) { + // Try to write more rows than allocated + std::vector row = {1.0, 2.0, 3.0}; + + // Should be able to write 5 rows (as initialized) + for (int i = 0; i < 5; i++) { + dispatcher.dispatch(InfoType::SAMPLE_RAW, row); + } + + // Sixth row should throw exception + EXPECT_THROW( + dispatcher.dispatch(InfoType::SAMPLE_RAW, row), + std::runtime_error + ); +} + +// Test in_memory_writer with column mismatch +TEST_F(DispatcherTest, InMemoryWriterColumnMismatch) { + // Try to write a row with wrong number of columns + std::vector wrong_size_row = {1.0, 2.0}; // Only 2 columns when we need 3 + + EXPECT_THROW( + dispatcher.dispatch(InfoType::SAMPLE_RAW, wrong_size_row), + std::runtime_error + ); + + std::vector also_wrong_size = {1.0, 2.0, 3.0, 4.0}; // 4 columns when we need 3 + + EXPECT_THROW( + dispatcher.dispatch(InfoType::SAMPLE_RAW, also_wrong_size), + std::runtime_error + ); +} + +// Test in_memory_writer reset +TEST_F(DispatcherTest, InMemoryWriterReset) { + // Write some data + std::vector row = {1.1, 2.2, 3.3}; + dispatcher.dispatch(InfoType::SAMPLE_RAW, row); + + // Verify data is written + const Eigen::MatrixXd& data1 = in_memory_writer_.get_eigen_state_values(); + EXPECT_DOUBLE_EQ(data1(0, 0), 1.1); + + // Reset the writer + in_memory_writer_.reset(); + + // Verify data is cleared + const Eigen::MatrixXd& data2 = in_memory_writer_.get_eigen_state_values(); + EXPECT_DOUBLE_EQ(data2(0, 0), 0.0); + + // Write new data + std::vector new_row = {4.4, 5.5, 6.6}; + dispatcher.dispatch(InfoType::SAMPLE_RAW, new_row); + + // Verify new data is written + const Eigen::MatrixXd& data3 = in_memory_writer_.get_eigen_state_values(); + EXPECT_DOUBLE_EQ(data3(0, 0), 4.4); +} diff --git a/src/test/unit/callbacks/in_memory_writer_test.cpp b/src/test/unit/callbacks/in_memory_writer_test.cpp new file mode 100644 index 0000000000..487c2aa301 --- /dev/null +++ b/src/test/unit/callbacks/in_memory_writer_test.cpp @@ -0,0 +1,358 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +using stan::callbacks::in_memory_writer; +using stan::callbacks::dispatcher; +using stan::callbacks::InfoType; +using stan::callbacks::WriterChannel; + +// Helper function to split a comma-separated line into tokens +std::vector split_csv_line(const std::string& line) { + std::vector tokens; + std::stringstream ss(line); + std::string token; + + while (std::getline(ss, token, ',')) { + // Trim whitespace + token.erase(0, token.find_first_not_of(" \t")); + token.erase(token.find_last_not_of(" \t") + 1); + tokens.push_back(token); + } + + return tokens; +} + +// Helper function to convert string tokens to doubles +std::vector convert_to_doubles(const std::vector& tokens) { + std::vector values; + for (const auto& token : tokens) { + values.push_back(std::stod(token)); + } + return values; +} + +class InMemoryWriterTest : public ::testing::Test { +protected: + void SetUp() override { + // Parse the header + header_tokens = split_csv_line("theta, mu,y_rep.1,y_rep.2,y_rep.3,y_rep.4,y_rep.5,y_rep.6,y_rep.7,y_rep.8,y_rep.9,y_rep.10"); + num_cols = header_tokens.size(); + + // Parse the sample rows + std::vector sample_lines = { + "0.309252,0.309252,0,0,1,1,0,0,0,1,0,0", + "0.572524,0.572524,0,0,0,1,0,1,1,0,1,0", + "0.0795978,0.0795978,0,0,0,0,0,0,0,0,0,0" + }; + + for (const auto& line : sample_lines) { + auto tokens = split_csv_line(line); + auto values = convert_to_doubles(tokens); + sample_rows.push_back(values); + } + + num_rows = sample_rows.size(); + } + + std::vector header_tokens; + std::vector> sample_rows; + size_t num_rows; + size_t num_cols; +}; + +// Test basic writing and retrieval of samples +TEST_F(InMemoryWriterTest, BasicWriteAndRetrieve) { + // Create writer with exact capacity + in_memory_writer writer(num_rows, num_cols); + + // Write header + writer(header_tokens); + + // Write sample rows + for (const auto& row : sample_rows) { + writer(row); + } + + // Verify names were stored correctly + ASSERT_EQ(writer.get_names().size(), num_cols); + for (size_t i = 0; i < num_cols; ++i) { + EXPECT_EQ(writer.get_names()[i], header_tokens[i]); + } + + // Verify data was stored correctly + const Eigen::MatrixXd& data = writer.get_eigen_state_values(); + ASSERT_EQ(data.rows(), num_rows); + ASSERT_EQ(data.cols(), num_cols); + + for (size_t i = 0; i < num_rows; ++i) { + for (size_t j = 0; j < num_cols; ++j) { + EXPECT_DOUBLE_EQ(data(i, j), sample_rows[i][j]); + } + } + + // Verify row count + EXPECT_EQ(writer.get_row_count(), num_rows); +} + +// Test writing more data than allocated +TEST_F(InMemoryWriterTest, WriteOverflow) { + // Create writer with less capacity than needed + in_memory_writer writer(num_rows - 1, num_cols); + + // Write header + writer(header_tokens); + + // Write sample rows until overflow + for (size_t i = 0; i < num_rows - 1; ++i) { + writer(sample_rows[i]); + } + + // The next write should throw an exception + EXPECT_THROW(writer(sample_rows.back()), std::runtime_error); +} + +// Test writing rows with incorrect column count +TEST_F(InMemoryWriterTest, ColumnMismatch) { + // Create writer + in_memory_writer writer(num_rows, num_cols); + + // Write header + writer(header_tokens); + + // Create an invalid row with too few columns + std::vector invalid_row(num_cols - 1, 0.5); + + // Writing should throw an exception + EXPECT_THROW(writer(invalid_row), std::runtime_error); + + // Create an invalid row with too many columns + invalid_row.resize(num_cols + 1, 0.5); + + // Writing should throw an exception + EXPECT_THROW(writer(invalid_row), std::runtime_error); +} + +// Test reset functionality +TEST_F(InMemoryWriterTest, Reset) { + // Create writer + in_memory_writer writer(num_rows, num_cols); + + // Write header and first row + writer(header_tokens); + writer(sample_rows[0]); + + // Verify row was written + EXPECT_EQ(writer.get_row_count(), 1); + EXPECT_DOUBLE_EQ(writer.get_eigen_state_values()(0, 0), sample_rows[0][0]); + + // Reset the writer + writer.reset(); + + // Verify row count is reset + EXPECT_EQ(writer.get_row_count(), 0); + + // Verify data is cleared + EXPECT_DOUBLE_EQ(writer.get_eigen_state_values()(0, 0), 0.0); + + // Verify names are preserved + ASSERT_EQ(writer.get_names().size(), num_cols); + EXPECT_EQ(writer.get_names()[0], header_tokens[0]); + + // Write a different row + writer(sample_rows[1]); + + // Verify new data is written + EXPECT_EQ(writer.get_row_count(), 1); + EXPECT_DOUBLE_EQ(writer.get_eigen_state_values()(0, 0), sample_rows[1][0]); +} + +// Test clear functionality +TEST_F(InMemoryWriterTest, Clear) { + // Create writer + in_memory_writer writer(num_rows, num_cols); + + // Write header and first row + writer(header_tokens); + writer(sample_rows[0]); + + // Clear the writer + writer.clear(); + + // Verify row count is reset + EXPECT_EQ(writer.get_row_count(), 0); + + // Verify data is cleared + EXPECT_DOUBLE_EQ(writer.get_eigen_state_values()(0, 0), 0.0); + + // Verify names are also cleared + EXPECT_EQ(writer.get_names().size(), 0); +} + +// Test with different size allocations +TEST_F(InMemoryWriterTest, DifferentSizeAllocations) { + // Create writer with more capacity than needed + in_memory_writer writer(num_rows * 2, num_cols); + + // Write header + writer(header_tokens); + + // Write sample rows + for (const auto& row : sample_rows) { + writer(row); + } + + // Verify data was stored correctly + const Eigen::MatrixXd& data = writer.get_eigen_state_values(); + ASSERT_EQ(data.rows(), num_rows * 2); // Total allocation + ASSERT_EQ(data.cols(), num_cols); + + // Verify only written rows have data + for (size_t i = 0; i < num_rows; ++i) { + for (size_t j = 0; j < num_cols; ++j) { + EXPECT_DOUBLE_EQ(data(i, j), sample_rows[i][j]); + } + } + + // Remaining rows should be zero + for (size_t i = num_rows; i < num_rows * 2; ++i) { + for (size_t j = 0; j < num_cols; ++j) { + EXPECT_DOUBLE_EQ(data(i, j), 0.0); + } + } +} + +// Test integration with dispatcher +TEST_F(InMemoryWriterTest, DispatcherIntegration) { + // Create writer and dispatcher + in_memory_writer writer(num_rows, num_cols); + dispatcher disp; + + // Register channel + disp.register_channel( + InfoType::SAMPLE_RAW, + std::unique_ptr( + new WriterChannel(&writer)) + ); + + // Send header and data through dispatcher + disp.dispatch(InfoType::SAMPLE_RAW, header_tokens); + for (const auto& row : sample_rows) { + disp.dispatch(InfoType::SAMPLE_RAW, row); + } + + // Verify data was received correctly + const Eigen::MatrixXd& data = writer.get_eigen_state_values(); + for (size_t i = 0; i < num_rows; ++i) { + for (size_t j = 0; j < num_cols; ++j) { + EXPECT_DOUBLE_EQ(data(i, j), sample_rows[i][j]); + } + } +} + +// Test handling of string messages +TEST_F(InMemoryWriterTest, StringMessages) { + in_memory_writer writer(num_rows, num_cols); + + // Write header + writer(header_tokens); + + // Write a string message (should be ignored) + writer("This is a message that should be ignored"); + + // Write sample row + writer(sample_rows[0]); + + // Verify row count is correct (only the data row should be counted) + EXPECT_EQ(writer.get_row_count(), 1); + + // Verify data was stored correctly + const Eigen::MatrixXd& data = writer.get_eigen_state_values(); + for (size_t j = 0; j < num_cols; ++j) { + EXPECT_DOUBLE_EQ(data(0, j), sample_rows[0][j]); + } +} + +// Test handling of empty calls +TEST_F(InMemoryWriterTest, EmptyCalls) { + in_memory_writer writer(num_rows, num_cols); + + // Write header + writer(header_tokens); + + // Write an empty call (should be ignored) + writer(); + + // Write sample row + writer(sample_rows[0]); + + // Verify row count is correct (only the data row should be counted) + EXPECT_EQ(writer.get_row_count(), 1); + + // Verify data was stored correctly + const Eigen::MatrixXd& data = writer.get_eigen_state_values(); + for (size_t j = 0; j < num_cols; ++j) { + EXPECT_DOUBLE_EQ(data(0, j), sample_rows[0][j]); + } +} + +// Test handling of Eigen input types +TEST_F(InMemoryWriterTest, EigenInputs) { + in_memory_writer writer(num_rows + 3, num_cols); + + // Write header + writer(header_tokens); + + // Write sample row as vector + writer(sample_rows[0]); + + // Write sample row as Eigen::VectorXd + Eigen::VectorXd vec(num_cols); + for (size_t j = 0; j < num_cols; ++j) { + vec(j) = sample_rows[1][j]; + } + writer(vec); + + // Write sample row as Eigen::RowVectorXd + Eigen::RowVectorXd row_vec(num_cols); + for (size_t j = 0; j < num_cols; ++j) { + row_vec(j) = sample_rows[2][j]; + } + writer(row_vec); + + // Create a matrix with all three rows + Eigen::MatrixXd matrix(3, num_cols); + for (size_t i = 0; i < 3; ++i) { + for (size_t j = 0; j < num_cols; ++j) { + matrix(i, j) = sample_rows[i][j]; + } + } + + // Write all rows at once as a matrix + writer(matrix); + + // Verify row count (should be 1 + 1 + 1 + 3 = 6) + EXPECT_EQ(writer.get_row_count(), 6); + + // Verify data was stored correctly + const Eigen::MatrixXd& data = writer.get_eigen_state_values(); + + // First three rows should match the original inputs + for (size_t i = 0; i < 3; ++i) { + for (size_t j = 0; j < num_cols; ++j) { + EXPECT_DOUBLE_EQ(data(i, j), sample_rows[i][j]); + } + } + + // Next three rows should match the matrix input + for (size_t i = 0; i < 3; ++i) { + for (size_t j = 0; j < num_cols; ++j) { + EXPECT_DOUBLE_EQ(data(i + 3, j), sample_rows[i][j]); + } + } +} From 811a03381d1b266bc3809f2b859d0e5d4d765431 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Thu, 6 Mar 2025 21:26:51 -0500 Subject: [PATCH 49/53] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- src/stan/callbacks/dispatcher.hpp | 6 +- src/stan/callbacks/in_memory_writer.hpp | 49 +++--- src/test/unit/callbacks/dispatcher_test.cpp | 58 ++++--- .../unit/callbacks/in_memory_writer_test.cpp | 144 +++++++++--------- 4 files changed, 123 insertions(+), 134 deletions(-) diff --git a/src/stan/callbacks/dispatcher.hpp b/src/stan/callbacks/dispatcher.hpp index 25ffc95bb1..f8b3f116ad 100644 --- a/src/stan/callbacks/dispatcher.hpp +++ b/src/stan/callbacks/dispatcher.hpp @@ -16,10 +16,10 @@ namespace stan { namespace callbacks { enum class InfoType { - CONFIG, // series of string messages - SAMPLE, // draw from posterior + CONFIG, // series of string messages + SAMPLE, // draw from posterior SAMPLE_RAW, // draw from posterior - METRIC, // struct with kv pairs 'metric_type', 'stepsize', 'inv_metric' + METRIC, // struct with kv pairs 'metric_type', 'stepsize', 'inv_metric' ALGORITHM_STATE, // sampler state for returned draw }; diff --git a/src/stan/callbacks/in_memory_writer.hpp b/src/stan/callbacks/in_memory_writer.hpp index 32cbd8f3f6..25d38ec905 100644 --- a/src/stan/callbacks/in_memory_writer.hpp +++ b/src/stan/callbacks/in_memory_writer.hpp @@ -21,7 +21,7 @@ namespace callbacks { * does not have the expected number of columns, an exception is thrown. */ class in_memory_writer : public stan::callbacks::writer { -public: + public: /** * Construct an in-memory writer. * @@ -32,11 +32,11 @@ class in_memory_writer : public stan::callbacks::writer { * num_rows x num_cols in column-major order. */ in_memory_writer(std::size_t num_rows, std::size_t num_cols) - : num_rows_(num_rows), - num_cols_(num_cols), - data_(Eigen::MatrixXd::Zero(num_rows, num_cols)), // column-major - current_row_(0), - names_() {} + : num_rows_(num_rows), + num_cols_(num_cols), + data_(Eigen::MatrixXd::Zero(num_rows, num_cols)), // column-major + current_row_(0), + names_() {} virtual ~in_memory_writer() {} @@ -77,12 +77,12 @@ class in_memory_writer : public stan::callbacks::writer { /** * Default implementation for empty call. */ - void operator()() override { } + void operator()() override {} /** * Default implementation for string message. */ - void operator()(const std::string& message) override { } + void operator()(const std::string& message) override {} /** * Handles Eigen matrix input by converting to rows and inserting. @@ -123,44 +123,37 @@ class in_memory_writer : public stan::callbacks::writer { /** * Always returns true as the in-memory writer is always valid. */ - bool is_valid() const noexcept override { - return true; - } + bool is_valid() const noexcept override { return true; } /** - * Returns a const reference to the in-memory Eigen matrix containing all draws. + * Returns a const reference to the in-memory Eigen matrix containing all + * draws. * * The matrix is stored in column-major order. * * @return const reference to the Eigen::MatrixXd holding the draws. */ - const Eigen::MatrixXd& get_eigen_state_values() const { - return data_; - } + const Eigen::MatrixXd& get_eigen_state_values() const { return data_; } /** * Returns a const reference to the column names. * * @return const reference to the vector of column names. */ - const std::vector& get_names() const { - return names_; - } + const std::vector& get_names() const { return names_; } /** * Returns the number of rows that have been written so far. * * @return The current row count. */ - std::size_t get_row_count() const { - return current_row_; - } + std::size_t get_row_count() const { return current_row_; } /** * Resets the writer to its initial state. * - * Clears the stored data (sets the matrix to zero) and resets the current row index. - * Column names are retained. + * Clears the stored data (sets the matrix to zero) and resets the current row + * index. Column names are retained. */ void reset() { current_row_ = 0; @@ -176,11 +169,11 @@ class in_memory_writer : public stan::callbacks::writer { names_.clear(); } -private: - std::size_t num_rows_; // Total number of draws (rows) expected. - std::size_t num_cols_; // Number of parameters (columns) per draw. - Eigen::MatrixXd data_; // Internal storage; Eigen matrices are column-major. - std::size_t current_row_; // Next row index to be written. + private: + std::size_t num_rows_; // Total number of draws (rows) expected. + std::size_t num_cols_; // Number of parameters (columns) per draw. + Eigen::MatrixXd data_; // Internal storage; Eigen matrices are column-major. + std::size_t current_row_; // Next row index to be written. std::vector names_; // Column names }; diff --git a/src/test/unit/callbacks/dispatcher_test.cpp b/src/test/unit/callbacks/dispatcher_test.cpp index 559e5b7170..02a8600787 100644 --- a/src/test/unit/callbacks/dispatcher_test.cpp +++ b/src/test/unit/callbacks/dispatcher_test.cpp @@ -55,7 +55,7 @@ class DispatcherTest : public ::testing::Test { new stan::callbacks::StructuredWriterChannel(&writer_metric))); dispatcher.register_channel( - InfoType::SAMPLE_RAW, + InfoType::SAMPLE_RAW, std::unique_ptr( new stan::callbacks::WriterChannel(&writer_sample_in_memory))); } @@ -259,26 +259,26 @@ TEST_F(DispatcherTest, InMemoryWriterBasic) { std::vector row1 = {1.1, 2.2, 3.3}; std::vector row2 = {4.4, 5.5, 6.6}; std::vector row3 = {7.7, 8.8, 9.9}; - + dispatcher.dispatch(InfoType::SAMPLE_RAW, row1); dispatcher.dispatch(InfoType::SAMPLE_RAW, row2); dispatcher.dispatch(InfoType::SAMPLE_RAW, row3); - + // Check that the data was stored correctly const Eigen::MatrixXd& data = in_memory_writer_.get_eigen_state_values(); - + EXPECT_EQ(data.rows(), 5); // As initialized EXPECT_EQ(data.cols(), 3); // As initialized - + // Check the stored values (first 3 rows should have our data) EXPECT_DOUBLE_EQ(data(0, 0), 1.1); EXPECT_DOUBLE_EQ(data(0, 1), 2.2); EXPECT_DOUBLE_EQ(data(0, 2), 3.3); - + EXPECT_DOUBLE_EQ(data(1, 0), 4.4); EXPECT_DOUBLE_EQ(data(1, 1), 5.5); EXPECT_DOUBLE_EQ(data(1, 2), 6.6); - + EXPECT_DOUBLE_EQ(data(2, 0), 7.7); EXPECT_DOUBLE_EQ(data(2, 1), 8.8); EXPECT_DOUBLE_EQ(data(2, 2), 9.9); @@ -288,35 +288,31 @@ TEST_F(DispatcherTest, InMemoryWriterBasic) { TEST_F(DispatcherTest, InMemoryWriterRowOverflow) { // Try to write more rows than allocated std::vector row = {1.0, 2.0, 3.0}; - + // Should be able to write 5 rows (as initialized) for (int i = 0; i < 5; i++) { dispatcher.dispatch(InfoType::SAMPLE_RAW, row); } - + // Sixth row should throw exception - EXPECT_THROW( - dispatcher.dispatch(InfoType::SAMPLE_RAW, row), - std::runtime_error - ); + EXPECT_THROW(dispatcher.dispatch(InfoType::SAMPLE_RAW, row), + std::runtime_error); } // Test in_memory_writer with column mismatch TEST_F(DispatcherTest, InMemoryWriterColumnMismatch) { // Try to write a row with wrong number of columns - std::vector wrong_size_row = {1.0, 2.0}; // Only 2 columns when we need 3 - - EXPECT_THROW( - dispatcher.dispatch(InfoType::SAMPLE_RAW, wrong_size_row), - std::runtime_error - ); - - std::vector also_wrong_size = {1.0, 2.0, 3.0, 4.0}; // 4 columns when we need 3 - - EXPECT_THROW( - dispatcher.dispatch(InfoType::SAMPLE_RAW, also_wrong_size), - std::runtime_error - ); + std::vector wrong_size_row + = {1.0, 2.0}; // Only 2 columns when we need 3 + + EXPECT_THROW(dispatcher.dispatch(InfoType::SAMPLE_RAW, wrong_size_row), + std::runtime_error); + + std::vector also_wrong_size + = {1.0, 2.0, 3.0, 4.0}; // 4 columns when we need 3 + + EXPECT_THROW(dispatcher.dispatch(InfoType::SAMPLE_RAW, also_wrong_size), + std::runtime_error); } // Test in_memory_writer reset @@ -324,22 +320,22 @@ TEST_F(DispatcherTest, InMemoryWriterReset) { // Write some data std::vector row = {1.1, 2.2, 3.3}; dispatcher.dispatch(InfoType::SAMPLE_RAW, row); - + // Verify data is written const Eigen::MatrixXd& data1 = in_memory_writer_.get_eigen_state_values(); EXPECT_DOUBLE_EQ(data1(0, 0), 1.1); - + // Reset the writer in_memory_writer_.reset(); - + // Verify data is cleared const Eigen::MatrixXd& data2 = in_memory_writer_.get_eigen_state_values(); EXPECT_DOUBLE_EQ(data2(0, 0), 0.0); - + // Write new data std::vector new_row = {4.4, 5.5, 6.6}; dispatcher.dispatch(InfoType::SAMPLE_RAW, new_row); - + // Verify new data is written const Eigen::MatrixXd& data3 = in_memory_writer_.get_eigen_state_values(); EXPECT_DOUBLE_EQ(data3(0, 0), 4.4); diff --git a/src/test/unit/callbacks/in_memory_writer_test.cpp b/src/test/unit/callbacks/in_memory_writer_test.cpp index 487c2aa301..21a750c5be 100644 --- a/src/test/unit/callbacks/in_memory_writer_test.cpp +++ b/src/test/unit/callbacks/in_memory_writer_test.cpp @@ -7,8 +7,8 @@ #include #include -using stan::callbacks::in_memory_writer; using stan::callbacks::dispatcher; +using stan::callbacks::in_memory_writer; using stan::callbacks::InfoType; using stan::callbacks::WriterChannel; @@ -17,14 +17,14 @@ std::vector split_csv_line(const std::string& line) { std::vector tokens; std::stringstream ss(line); std::string token; - + while (std::getline(ss, token, ',')) { // Trim whitespace token.erase(0, token.find_first_not_of(" \t")); token.erase(token.find_last_not_of(" \t") + 1); tokens.push_back(token); } - + return tokens; } @@ -38,28 +38,30 @@ std::vector convert_to_doubles(const std::vector& tokens) { } class InMemoryWriterTest : public ::testing::Test { -protected: + protected: void SetUp() override { // Parse the header - header_tokens = split_csv_line("theta, mu,y_rep.1,y_rep.2,y_rep.3,y_rep.4,y_rep.5,y_rep.6,y_rep.7,y_rep.8,y_rep.9,y_rep.10"); + header_tokens = split_csv_line( + "theta, " + "mu,y_rep.1,y_rep.2,y_rep.3,y_rep.4,y_rep.5,y_rep.6,y_rep.7,y_rep.8,y_" + "rep.9,y_rep.10"); num_cols = header_tokens.size(); - + // Parse the sample rows - std::vector sample_lines = { - "0.309252,0.309252,0,0,1,1,0,0,0,1,0,0", - "0.572524,0.572524,0,0,0,1,0,1,1,0,1,0", - "0.0795978,0.0795978,0,0,0,0,0,0,0,0,0,0" - }; - + std::vector sample_lines + = {"0.309252,0.309252,0,0,1,1,0,0,0,1,0,0", + "0.572524,0.572524,0,0,0,1,0,1,1,0,1,0", + "0.0795978,0.0795978,0,0,0,0,0,0,0,0,0,0"}; + for (const auto& line : sample_lines) { auto tokens = split_csv_line(line); auto values = convert_to_doubles(tokens); sample_rows.push_back(values); } - + num_rows = sample_rows.size(); } - + std::vector header_tokens; std::vector> sample_rows; size_t num_rows; @@ -70,32 +72,32 @@ class InMemoryWriterTest : public ::testing::Test { TEST_F(InMemoryWriterTest, BasicWriteAndRetrieve) { // Create writer with exact capacity in_memory_writer writer(num_rows, num_cols); - + // Write header writer(header_tokens); - + // Write sample rows for (const auto& row : sample_rows) { writer(row); } - + // Verify names were stored correctly ASSERT_EQ(writer.get_names().size(), num_cols); for (size_t i = 0; i < num_cols; ++i) { EXPECT_EQ(writer.get_names()[i], header_tokens[i]); } - + // Verify data was stored correctly const Eigen::MatrixXd& data = writer.get_eigen_state_values(); ASSERT_EQ(data.rows(), num_rows); ASSERT_EQ(data.cols(), num_cols); - + for (size_t i = 0; i < num_rows; ++i) { for (size_t j = 0; j < num_cols; ++j) { EXPECT_DOUBLE_EQ(data(i, j), sample_rows[i][j]); } } - + // Verify row count EXPECT_EQ(writer.get_row_count(), num_rows); } @@ -104,15 +106,15 @@ TEST_F(InMemoryWriterTest, BasicWriteAndRetrieve) { TEST_F(InMemoryWriterTest, WriteOverflow) { // Create writer with less capacity than needed in_memory_writer writer(num_rows - 1, num_cols); - + // Write header writer(header_tokens); - + // Write sample rows until overflow for (size_t i = 0; i < num_rows - 1; ++i) { writer(sample_rows[i]); } - + // The next write should throw an exception EXPECT_THROW(writer(sample_rows.back()), std::runtime_error); } @@ -121,19 +123,19 @@ TEST_F(InMemoryWriterTest, WriteOverflow) { TEST_F(InMemoryWriterTest, ColumnMismatch) { // Create writer in_memory_writer writer(num_rows, num_cols); - + // Write header writer(header_tokens); - + // Create an invalid row with too few columns std::vector invalid_row(num_cols - 1, 0.5); - + // Writing should throw an exception EXPECT_THROW(writer(invalid_row), std::runtime_error); - + // Create an invalid row with too many columns invalid_row.resize(num_cols + 1, 0.5); - + // Writing should throw an exception EXPECT_THROW(writer(invalid_row), std::runtime_error); } @@ -142,31 +144,31 @@ TEST_F(InMemoryWriterTest, ColumnMismatch) { TEST_F(InMemoryWriterTest, Reset) { // Create writer in_memory_writer writer(num_rows, num_cols); - + // Write header and first row writer(header_tokens); writer(sample_rows[0]); - + // Verify row was written EXPECT_EQ(writer.get_row_count(), 1); EXPECT_DOUBLE_EQ(writer.get_eigen_state_values()(0, 0), sample_rows[0][0]); - + // Reset the writer writer.reset(); - + // Verify row count is reset EXPECT_EQ(writer.get_row_count(), 0); - + // Verify data is cleared EXPECT_DOUBLE_EQ(writer.get_eigen_state_values()(0, 0), 0.0); - + // Verify names are preserved ASSERT_EQ(writer.get_names().size(), num_cols); EXPECT_EQ(writer.get_names()[0], header_tokens[0]); - + // Write a different row writer(sample_rows[1]); - + // Verify new data is written EXPECT_EQ(writer.get_row_count(), 1); EXPECT_DOUBLE_EQ(writer.get_eigen_state_values()(0, 0), sample_rows[1][0]); @@ -176,20 +178,20 @@ TEST_F(InMemoryWriterTest, Reset) { TEST_F(InMemoryWriterTest, Clear) { // Create writer in_memory_writer writer(num_rows, num_cols); - + // Write header and first row writer(header_tokens); writer(sample_rows[0]); - + // Clear the writer writer.clear(); - + // Verify row count is reset EXPECT_EQ(writer.get_row_count(), 0); - + // Verify data is cleared EXPECT_DOUBLE_EQ(writer.get_eigen_state_values()(0, 0), 0.0); - + // Verify names are also cleared EXPECT_EQ(writer.get_names().size(), 0); } @@ -198,27 +200,27 @@ TEST_F(InMemoryWriterTest, Clear) { TEST_F(InMemoryWriterTest, DifferentSizeAllocations) { // Create writer with more capacity than needed in_memory_writer writer(num_rows * 2, num_cols); - + // Write header writer(header_tokens); - + // Write sample rows for (const auto& row : sample_rows) { writer(row); } - + // Verify data was stored correctly const Eigen::MatrixXd& data = writer.get_eigen_state_values(); ASSERT_EQ(data.rows(), num_rows * 2); // Total allocation ASSERT_EQ(data.cols(), num_cols); - + // Verify only written rows have data for (size_t i = 0; i < num_rows; ++i) { for (size_t j = 0; j < num_cols; ++j) { EXPECT_DOUBLE_EQ(data(i, j), sample_rows[i][j]); } } - + // Remaining rows should be zero for (size_t i = num_rows; i < num_rows * 2; ++i) { for (size_t j = 0; j < num_cols; ++j) { @@ -232,20 +234,18 @@ TEST_F(InMemoryWriterTest, DispatcherIntegration) { // Create writer and dispatcher in_memory_writer writer(num_rows, num_cols); dispatcher disp; - + // Register channel disp.register_channel( InfoType::SAMPLE_RAW, - std::unique_ptr( - new WriterChannel(&writer)) - ); - + std::unique_ptr(new WriterChannel(&writer))); + // Send header and data through dispatcher disp.dispatch(InfoType::SAMPLE_RAW, header_tokens); for (const auto& row : sample_rows) { disp.dispatch(InfoType::SAMPLE_RAW, row); } - + // Verify data was received correctly const Eigen::MatrixXd& data = writer.get_eigen_state_values(); for (size_t i = 0; i < num_rows; ++i) { @@ -258,19 +258,19 @@ TEST_F(InMemoryWriterTest, DispatcherIntegration) { // Test handling of string messages TEST_F(InMemoryWriterTest, StringMessages) { in_memory_writer writer(num_rows, num_cols); - + // Write header writer(header_tokens); - + // Write a string message (should be ignored) writer("This is a message that should be ignored"); - + // Write sample row writer(sample_rows[0]); - + // Verify row count is correct (only the data row should be counted) EXPECT_EQ(writer.get_row_count(), 1); - + // Verify data was stored correctly const Eigen::MatrixXd& data = writer.get_eigen_state_values(); for (size_t j = 0; j < num_cols; ++j) { @@ -281,19 +281,19 @@ TEST_F(InMemoryWriterTest, StringMessages) { // Test handling of empty calls TEST_F(InMemoryWriterTest, EmptyCalls) { in_memory_writer writer(num_rows, num_cols); - + // Write header writer(header_tokens); - + // Write an empty call (should be ignored) writer(); - + // Write sample row writer(sample_rows[0]); - + // Verify row count is correct (only the data row should be counted) EXPECT_EQ(writer.get_row_count(), 1); - + // Verify data was stored correctly const Eigen::MatrixXd& data = writer.get_eigen_state_values(); for (size_t j = 0; j < num_cols; ++j) { @@ -304,27 +304,27 @@ TEST_F(InMemoryWriterTest, EmptyCalls) { // Test handling of Eigen input types TEST_F(InMemoryWriterTest, EigenInputs) { in_memory_writer writer(num_rows + 3, num_cols); - + // Write header writer(header_tokens); - + // Write sample row as vector writer(sample_rows[0]); - + // Write sample row as Eigen::VectorXd Eigen::VectorXd vec(num_cols); for (size_t j = 0; j < num_cols; ++j) { vec(j) = sample_rows[1][j]; } writer(vec); - + // Write sample row as Eigen::RowVectorXd Eigen::RowVectorXd row_vec(num_cols); for (size_t j = 0; j < num_cols; ++j) { row_vec(j) = sample_rows[2][j]; } writer(row_vec); - + // Create a matrix with all three rows Eigen::MatrixXd matrix(3, num_cols); for (size_t i = 0; i < 3; ++i) { @@ -332,23 +332,23 @@ TEST_F(InMemoryWriterTest, EigenInputs) { matrix(i, j) = sample_rows[i][j]; } } - + // Write all rows at once as a matrix writer(matrix); - + // Verify row count (should be 1 + 1 + 1 + 3 = 6) EXPECT_EQ(writer.get_row_count(), 6); - + // Verify data was stored correctly const Eigen::MatrixXd& data = writer.get_eigen_state_values(); - + // First three rows should match the original inputs for (size_t i = 0; i < 3; ++i) { for (size_t j = 0; j < num_cols; ++j) { EXPECT_DOUBLE_EQ(data(i, j), sample_rows[i][j]); } } - + // Next three rows should match the matrix input for (size_t i = 0; i < 3; ++i) { for (size_t j = 0; j < num_cols; ++j) { From 61db5b824ddf74983d733d5e9352f827848a02f6 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Fri, 7 Mar 2025 00:11:22 -0500 Subject: [PATCH 50/53] in_memory_sampler unit test --- src/test/unit/callbacks/dispatcher_test.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/test/unit/callbacks/dispatcher_test.cpp b/src/test/unit/callbacks/dispatcher_test.cpp index 559e5b7170..b43b9be8a0 100644 --- a/src/test/unit/callbacks/dispatcher_test.cpp +++ b/src/test/unit/callbacks/dispatcher_test.cpp @@ -29,6 +29,7 @@ class DispatcherTest : public ::testing::Test { writer_config(ss_config), writer_metric( std::unique_ptr(&ss_metric)), + writer_sample_in_memory(3,12), dispatcher() {} void SetUp() { From 9c0882561b96d670fa973f1cd9677a48fb4db4f4 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Fri, 7 Mar 2025 00:12:29 -0500 Subject: [PATCH 51/53] in_memory_sampler unit test --- src/test/unit/callbacks/dispatcher_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/unit/callbacks/dispatcher_test.cpp b/src/test/unit/callbacks/dispatcher_test.cpp index b43b9be8a0..2bca316f75 100644 --- a/src/test/unit/callbacks/dispatcher_test.cpp +++ b/src/test/unit/callbacks/dispatcher_test.cpp @@ -29,7 +29,7 @@ class DispatcherTest : public ::testing::Test { writer_config(ss_config), writer_metric( std::unique_ptr(&ss_metric)), - writer_sample_in_memory(3,12), + writer_sample_in_memory(3,3), dispatcher() {} void SetUp() { From 5ce6073ea7209e0965d769f6597a5189abcc03e1 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Fri, 7 Mar 2025 00:14:51 -0500 Subject: [PATCH 52/53] dispatcher unit test --- src/test/unit/callbacks/dispatcher_test.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/test/unit/callbacks/dispatcher_test.cpp b/src/test/unit/callbacks/dispatcher_test.cpp index 2bca316f75..a920edc796 100644 --- a/src/test/unit/callbacks/dispatcher_test.cpp +++ b/src/test/unit/callbacks/dispatcher_test.cpp @@ -29,7 +29,7 @@ class DispatcherTest : public ::testing::Test { writer_config(ss_config), writer_metric( std::unique_ptr(&ss_metric)), - writer_sample_in_memory(3,3), + writer_sample_in_memory(5,3), dispatcher() {} void SetUp() { @@ -266,7 +266,7 @@ TEST_F(DispatcherTest, InMemoryWriterBasic) { dispatcher.dispatch(InfoType::SAMPLE_RAW, row3); // Check that the data was stored correctly - const Eigen::MatrixXd& data = in_memory_writer_.get_eigen_state_values(); + const Eigen::MatrixXd& data = writer_sample_in_memory.get_eigen_state_values(); EXPECT_EQ(data.rows(), 5); // As initialized EXPECT_EQ(data.cols(), 3); // As initialized @@ -327,14 +327,14 @@ TEST_F(DispatcherTest, InMemoryWriterReset) { dispatcher.dispatch(InfoType::SAMPLE_RAW, row); // Verify data is written - const Eigen::MatrixXd& data1 = in_memory_writer_.get_eigen_state_values(); + const Eigen::MatrixXd& data1 = writer_sample_in_memory.get_eigen_state_values(); EXPECT_DOUBLE_EQ(data1(0, 0), 1.1); // Reset the writer - in_memory_writer_.reset(); + writer_sample_in_memory.reset(); // Verify data is cleared - const Eigen::MatrixXd& data2 = in_memory_writer_.get_eigen_state_values(); + const Eigen::MatrixXd& data2 = writer_sample_in_memory.get_eigen_state_values(); EXPECT_DOUBLE_EQ(data2(0, 0), 0.0); // Write new data @@ -342,6 +342,6 @@ TEST_F(DispatcherTest, InMemoryWriterReset) { dispatcher.dispatch(InfoType::SAMPLE_RAW, new_row); // Verify new data is written - const Eigen::MatrixXd& data3 = in_memory_writer_.get_eigen_state_values(); + const Eigen::MatrixXd& data3 = writer_sample_in_memory.get_eigen_state_values(); EXPECT_DOUBLE_EQ(data3(0, 0), 4.4); } From 4aced3e8a24a1a046f37403791e2db5b70dd04b0 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Fri, 7 Mar 2025 00:16:48 -0500 Subject: [PATCH 53/53] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- src/test/unit/callbacks/dispatcher_test.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/test/unit/callbacks/dispatcher_test.cpp b/src/test/unit/callbacks/dispatcher_test.cpp index f7024f658f..2d992d8462 100644 --- a/src/test/unit/callbacks/dispatcher_test.cpp +++ b/src/test/unit/callbacks/dispatcher_test.cpp @@ -29,7 +29,7 @@ class DispatcherTest : public ::testing::Test { writer_config(ss_config), writer_metric( std::unique_ptr(&ss_metric)), - writer_sample_in_memory(5,3), + writer_sample_in_memory(5, 3), dispatcher() {} void SetUp() { @@ -266,7 +266,8 @@ TEST_F(DispatcherTest, InMemoryWriterBasic) { dispatcher.dispatch(InfoType::SAMPLE_RAW, row3); // Check that the data was stored correctly - const Eigen::MatrixXd& data = writer_sample_in_memory.get_eigen_state_values(); + const Eigen::MatrixXd& data + = writer_sample_in_memory.get_eigen_state_values(); EXPECT_EQ(data.rows(), 5); // As initialized EXPECT_EQ(data.cols(), 3); // As initialized @@ -323,14 +324,16 @@ TEST_F(DispatcherTest, InMemoryWriterReset) { dispatcher.dispatch(InfoType::SAMPLE_RAW, row); // Verify data is written - const Eigen::MatrixXd& data1 = writer_sample_in_memory.get_eigen_state_values(); + const Eigen::MatrixXd& data1 + = writer_sample_in_memory.get_eigen_state_values(); EXPECT_DOUBLE_EQ(data1(0, 0), 1.1); // Reset the writer writer_sample_in_memory.reset(); // Verify data is cleared - const Eigen::MatrixXd& data2 = writer_sample_in_memory.get_eigen_state_values(); + const Eigen::MatrixXd& data2 + = writer_sample_in_memory.get_eigen_state_values(); EXPECT_DOUBLE_EQ(data2(0, 0), 0.0); // Write new data @@ -338,6 +341,7 @@ TEST_F(DispatcherTest, InMemoryWriterReset) { dispatcher.dispatch(InfoType::SAMPLE_RAW, new_row); // Verify new data is written - const Eigen::MatrixXd& data3 = writer_sample_in_memory.get_eigen_state_values(); + const Eigen::MatrixXd& data3 + = writer_sample_in_memory.get_eigen_state_values(); EXPECT_DOUBLE_EQ(data3(0, 0), 4.4); }