Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dispatcher object to callbacks to coordinate sampler outputs #3334

Closed
wants to merge 65 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
bef2b07
Merge branch 'develop' of https://github.com/stan-dev/stan into develop
mitzimorris Feb 4, 2025
dc17597
Merge branch 'develop' of https://github.com/stan-dev/stan into develop
mitzimorris Feb 19, 2025
cb727ce
Merge branch 'develop' of https://github.com/stan-dev/stan into develop
mitzimorris Mar 3, 2025
3b7a3f7
output_controller and tests
mitzimorris Mar 3, 2025
2cdfda8
path fix
mitzimorris Mar 3, 2025
db3e8a5
fix namespace
mitzimorris Mar 4, 2025
8473d2c
unneeded include
mitzimorris Mar 4, 2025
f4acf74
2nd try
mitzimorris Mar 4, 2025
7e7c110
another try
mitzimorris Mar 4, 2025
c16635e
try try again
mitzimorris Mar 4, 2025
32acdc0
try again
mitzimorris Mar 4, 2025
63194af
try again
mitzimorris Mar 4, 2025
c80c2ed
try again
mitzimorris Mar 4, 2025
eb6d343
try again
mitzimorris Mar 4, 2025
d7579c6
try again
mitzimorris Mar 4, 2025
8742b7d
try again
mitzimorris Mar 4, 2025
b2297c2
try again
mitzimorris Mar 4, 2025
11024d9
try again
mitzimorris Mar 4, 2025
37000c4
try again
mitzimorris Mar 4, 2025
b5c858a
fix namespaces
mitzimorris Mar 4, 2025
ff2f43e
namespace fix
mitzimorris Mar 4, 2025
dc90f4f
try again
mitzimorris Mar 4, 2025
ff72b3b
try again
mitzimorris Mar 4, 2025
1dfa914
TRY AGAIN
mitzimorris Mar 4, 2025
ef0fa16
try again
mitzimorris Mar 4, 2025
21fd3e8
try again
mitzimorris Mar 4, 2025
944c0a8
try again
mitzimorris Mar 4, 2025
fb79662
try again
mitzimorris Mar 4, 2025
b52f087
try again
mitzimorris Mar 4, 2025
75f59da
try again
mitzimorris Mar 4, 2025
ce5b7a9
try again
mitzimorris Mar 4, 2025
e4f9e00
try again
mitzimorris Mar 4, 2025
7cf3f4e
try again
mitzimorris Mar 4, 2025
0558bad
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 4, 2025
54f39e4
Merge branch 'feature/in-memory-writer' of https://github.com/stan-de…
mitzimorris Mar 4, 2025
03d24f6
lint fix, cleanup
mitzimorris Mar 4, 2025
072d882
lint fix
mitzimorris Mar 4, 2025
3d6f08f
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 4, 2025
f86be77
lint fix
mitzimorris Mar 4, 2025
0f31cd4
Merge branch 'feature/in-memory-writer' of https://github.com/stan-de…
mitzimorris Mar 4, 2025
87b7165
checkpointing
mitzimorris Mar 6, 2025
71c5863
Merge commit '092f96849701b63ae18d44cdbc3c1f0bd2efc6da' into HEAD
yashikno Mar 6, 2025
89d756b
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 6, 2025
261c269
claude 3.7 rewrite
mitzimorris Mar 6, 2025
e082f64
lint fix
mitzimorris Mar 6, 2025
db29cfa
merge fix
mitzimorris Mar 6, 2025
91f7d4c
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 6, 2025
b5e9955
unit tests
mitzimorris Mar 6, 2025
be90629
merge fix
mitzimorris Mar 6, 2025
e50ea62
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 6, 2025
877d6e5
lint fix
mitzimorris Mar 6, 2025
30a854e
Merge branch 'feature/in-memory-writer' of https://github.com/stan-de…
mitzimorris Mar 6, 2025
cef5f2b
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 6, 2025
ede4279
Merge branch 'feature/in-memory-writer' of https://github.com/stan-de…
mitzimorris Mar 6, 2025
a62c1ea
lint fix
mitzimorris Mar 6, 2025
09aba5a
unit tests working
mitzimorris Mar 7, 2025
0552dbe
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 7, 2025
3a430cf
added in_memory_writer and unit tests
mitzimorris Mar 7, 2025
b061f51
Merge branch 'feature/in-memory-writer' of https://github.com/stan-de…
mitzimorris Mar 7, 2025
811a033
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 7, 2025
61db5b8
in_memory_sampler unit test
mitzimorris Mar 7, 2025
9c08825
in_memory_sampler unit test
mitzimorris Mar 7, 2025
5ce6073
dispatcher unit test
mitzimorris Mar 7, 2025
a6aeb17
merge fix
mitzimorris Mar 7, 2025
4aced3e
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/stan_math
175 changes: 175 additions & 0 deletions src/stan/callbacks/dispatcher.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
#ifndef STAN_CALLBACKS_DISPATCHER_HPP
#define STAN_CALLBACKS_DISPATCHER_HPP

#include <stan/callbacks/writer.hpp>
#include <stan/callbacks/structured_writer.hpp>
#include <memory>
#include <unordered_map>
#include <string>
#include <vector>
#include <stdexcept>
#include <utility>
#include <type_traits>
#include <variant>

namespace stan {
namespace callbacks {

enum class InfoType {
CONFIG, // series of string messages
SAMPLE, // draw from posterior
SAMPLE_RAW, // draw from posterior
METRIC, // struct with kv pairs 'metric_type', 'stepsize', 'inv_metric'
ALGORITHM_STATE, // sampler state for returned draw
};

struct InfoTypeHash {
std::size_t operator()(const InfoType& type) const {
return std::hash<int>()(static_cast<int>(type));
}
};

// Base type for type erasure.
class Channel {
public:
virtual ~Channel() = default;
};

// Adapter for plain writers.
class WriterChannel : public Channel {
public:
explicit WriterChannel(stan::callbacks::writer* w) : writer_(w) {
if (!w) {
throw std::runtime_error("config error, null writer");
}
}

// Handle all types that writer supports via operator()
void dispatch() { (*writer_)(); }
void dispatch(const std::string& value) { (*writer_)(value); }
void dispatch(const std::vector<double>& value) { (*writer_)(value); }
void dispatch(const std::vector<std::string>& value) { (*writer_)(value); }

// Handle any Eigen Matrix type
template <int R, int C>
void dispatch(const Eigen::Matrix<double, R, C>& value) {
(*writer_)(value);
}

// No key-value support for plain writers
template <typename T>
void dispatch(const std::string&, const T&) {}

private:
stan::callbacks::writer* writer_;
};

// Adapter for structured writers.
class StructuredWriterChannel : public Channel {
public:
explicit StructuredWriterChannel(stan::callbacks::structured_writer* sw)
: writer_(sw) {
if (!sw)
throw std::runtime_error("config error, null writer");
}
// Forward all key-value calls directly to the writer
void dispatch(const std::string& key) { writer_->write(key); }
// Perfect forwarding for any key-value pair
template <typename T>
void dispatch(const std::string& key, T&& value) {
writer_->write(key, std::forward<T>(value));
}
void begin_record() { writer_->begin_record(); }
void begin_record(const std::string& key) { writer_->begin_record(key); }
void end_record() { writer_->end_record(); }

private:
stan::callbacks::structured_writer* writer_;
};

// dispatcher class
class dispatcher {
public:
dispatcher() = default;
~dispatcher() = default;

void register_channel(InfoType type, std::unique_ptr<Channel> channel) {
channels_[type] = std::move(channel);
}

// no-arg call to writer operator ()
void dispatch(InfoType type) {
auto it = channels_.find(type);
if (it == channels_.end())
return; // silently do nothing
if (auto* wc = dynamic_cast<WriterChannel*>(it->second.get()))
wc->dispatch();
}

// single non-string argument - call writer operator ()
template <
typename T,
typename = std::enable_if_t<
std::is_same_v<
std::decay_t<T>,
std::
string> || std::is_same_v<std::decay_t<T>, std::vector<double>> || std::is_same_v<std::decay_t<T>, std::vector<std::string>>>> // NOLINT
void dispatch(InfoType type, T&& value) {
if (auto* wc = find_channel<WriterChannel>(type))
wc->dispatch(std::forward<T>(value));
}

// Eigen matrix types
template <int R, int C>
void dispatch(InfoType type, const Eigen::Matrix<double, R, C>& value) {
if (auto* wc = find_channel<WriterChannel>(type))
wc->dispatch(value);
}

// Key with no value (null)
void dispatch(InfoType type, const std::string& key) {
if (auto* sw = find_channel<StructuredWriterChannel>(type))
sw->dispatch(key);
}

// Key-value pairs (forward to structured writers)
template <typename T>
void dispatch(InfoType type, const std::string& key, T&& value) {
if (auto* sw = find_channel<StructuredWriterChannel>(type))
sw->dispatch(key, std::forward<T>(value));
}

// Record operations
void begin_record(InfoType type) {
if (auto* sw = find_channel<StructuredWriterChannel>(type))
sw->begin_record();
}

void begin_record(InfoType type, const std::string& key) {
if (auto* sw = find_channel<StructuredWriterChannel>(type))
sw->begin_record(key);
}

void end_record(InfoType type) {
if (auto* sw = find_channel<StructuredWriterChannel>(type))
sw->end_record();
}

private:
// Helper to find and cast a channel of specific type
template <typename ChannelType>
ChannelType* find_channel(InfoType type) {
auto it = channels_.find(type);
if (it == channels_.end())
return nullptr;
return dynamic_cast<ChannelType*>(it->second.get());
}

std::unordered_map<InfoType, std::unique_ptr<Channel>, InfoTypeHash>
channels_;
};

} // namespace callbacks
} // namespace stan

#endif // STAN_CALLBACKS_DISPATCHER_HPP
183 changes: 183 additions & 0 deletions src/stan/callbacks/in_memory_writer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
#ifndef STAN_CALLBACKS_IN_MEMORY_WRITER_HPP
#define STAN_CALLBACKS_IN_MEMORY_WRITER_HPP

#include <vector>
#include <stdexcept>
#include <cstddef>
#include <Eigen/Dense>
#include <stan/callbacks/writer.hpp>

namespace stan {
namespace callbacks {

/**
* A general in-memory writer that stores draws from Stan's write_array
* callback into a contiguous Eigen matrix. The matrix is allocated with
* column-major storage (Eigen's default), meaning that the element at
* row i and column j is stored at index (i + j * num_rows) in memory.
*
* The writer accepts draws row-by-row and writes them into the matrix.
* If more rows are written than initially allocated, or if the input row
* does not have the expected number of columns, an exception is thrown.
*/
class in_memory_writer : public stan::callbacks::writer {
public:
/**
* Construct an in-memory writer.
*
* @param num_rows the total number of rows (draws) expected
* @param num_cols the number of columns (parameters) per draw
*
* The underlying Eigen matrix is allocated to have dimensions
* num_rows x num_cols in column-major order.
*/
in_memory_writer(std::size_t num_rows, std::size_t num_cols)
: num_rows_(num_rows),
num_cols_(num_cols),
data_(Eigen::MatrixXd::Zero(num_rows, num_cols)), // column-major
current_row_(0),
names_() {}

virtual ~in_memory_writer() {}

/**
* Writes a set of names.
*
* @param[in] names Names in a std::vector
*/
void operator()(const std::vector<std::string>& names) override {
names_ = names;
}

/**
* Writes a single row (draw) to the in-memory matrix.
*
* The input vector must have exactly num_cols elements.
* The row is written into the matrix at the current row index.
* If the matrix is full, an exception is thrown.
*
* @param row A vector containing one draw from the posterior.
*/
void operator()(const std::vector<double>& row) override {
if (row.size() != num_cols_) {
throw std::runtime_error("Row size does not match the number of columns");
}
if (current_row_ >= num_rows_) {
throw std::runtime_error("Attempted to write more rows than allocated");
}
// Because Eigen::MatrixXd is column-major by default, simply assigning
// row by row will place the data in memory in the order:
// index = current_row_ + (column_index * num_rows_).
for (std::size_t j = 0; j < num_cols_; ++j) {
data_(current_row_, j) = row[j];
}
++current_row_;
}

/**
* Default implementation for empty call.
*/
void operator()() override {}

/**
* Default implementation for string message.
*/
void operator()(const std::string& message) override {}

/**
* Handles Eigen matrix input by converting to rows and inserting.
* This method treats each row of the input matrix as a separate draw.
*/
void operator()(const Eigen::Matrix<double, -1, -1>& matrix) override {
for (Eigen::Index i = 0; i < matrix.rows(); ++i) {
std::vector<double> row(matrix.cols());
for (Eigen::Index j = 0; j < matrix.cols(); ++j) {
row[j] = matrix(i, j);
}
(*this)(row); // Use the vector<double> operator to insert the row
}
}

/**
* Handles Eigen vector input by inserting as a single row.
*/
void operator()(const Eigen::Matrix<double, -1, 1>& vector) override {
std::vector<double> row(vector.size());
for (Eigen::Index i = 0; i < vector.size(); ++i) {
row[i] = vector(i);
}
(*this)(row); // Use the vector<double> operator to insert the row
}

/**
* Handles Eigen row vector input by inserting as a single row.
*/
void operator()(const Eigen::Matrix<double, 1, -1>& vector) override {
std::vector<double> row(vector.size());
for (Eigen::Index i = 0; i < vector.size(); ++i) {
row[i] = vector(i);
}
(*this)(row); // Use the vector<double> operator to insert the row
}

/**
* Always returns true as the in-memory writer is always valid.
*/
bool is_valid() const noexcept override { return true; }

/**
* Returns a const reference to the in-memory Eigen matrix containing all
* draws.
*
* The matrix is stored in column-major order.
*
* @return const reference to the Eigen::MatrixXd holding the draws.
*/
const Eigen::MatrixXd& get_eigen_state_values() const { return data_; }

/**
* Returns a const reference to the column names.
*
* @return const reference to the vector of column names.
*/
const std::vector<std::string>& get_names() const { return names_; }

/**
* Returns the number of rows that have been written so far.
*
* @return The current row count.
*/
std::size_t get_row_count() const { return current_row_; }

/**
* Resets the writer to its initial state.
*
* Clears the stored data (sets the matrix to zero) and resets the current row
* index. Column names are retained.
*/
void reset() {
current_row_ = 0;
data_.setZero();
}

/**
* Fully resets the writer, including clearing column names.
*/
void clear() {
current_row_ = 0;
data_.setZero();
names_.clear();
}

private:
std::size_t num_rows_; // Total number of draws (rows) expected.
std::size_t num_cols_; // Number of parameters (columns) per draw.
Eigen::MatrixXd data_; // Internal storage; Eigen matrices are column-major.
std::size_t current_row_; // Next row index to be written.
std::vector<std::string> names_; // Column names
};

} // namespace callbacks
} // namespace stan

#endif // STAN_CALLBACKS_IN_MEMORY_WRITER_HPP
Loading