diff --git a/.gitignore b/.gitignore index 5e504a5a7b..4eb1b58b86 100644 --- a/.gitignore +++ b/.gitignore @@ -52,3 +52,7 @@ output*.csv *.o # gdb .gdb_history + +#vscode +.vscode/* +.vscode/** diff --git a/src/cmdstan/arguments/arg_pathfinder.hpp b/src/cmdstan/arguments/arg_pathfinder.hpp index 1999807462..c414b89a8c 100644 --- a/src/cmdstan/arguments/arg_pathfinder.hpp +++ b/src/cmdstan/arguments/arg_pathfinder.hpp @@ -21,6 +21,22 @@ class arg_pathfinder : public arg_lbfgs { _subarguments.push_back(new arg_single_bool( "save_single_paths", "Output single-path pathfinder draws as CSV", false)); + _subarguments.push_back(new arg_single_bool( + "psis_resample", + "If true, perform psis resampling on samples returned" + " from individual pathfinders. If false, returns num_paths * num_draws" + " samples", + true)); + _subarguments.push_back(new arg_single_bool( + "calculate_lp", + "If true, individual pathfinders lp calculations are calculated and" + " returned with the output. If false, each pathfinder will only " + " calculate the lp values needed for the elbo calculation." + " If false, psis resampling cannot be performed and" + " the algorithm returns num_paths * num_draws samples." + " The output will still contain any lp values used when" + " calculating ELBO scores within LBFGS iterations.", + true)); _subarguments.push_back(new arg_single_int_pos( "max_lbfgs_iters", "Maximum number of LBFGS iterations", 1000)); _subarguments.push_back(new arg_single_int_pos( diff --git a/src/cmdstan/command.hpp b/src/cmdstan/command.hpp index e4b2d8b2cf..53f620eed7 100644 --- a/src/cmdstan/command.hpp +++ b/src/cmdstan/command.hpp @@ -313,7 +313,10 @@ int command(int argc, const char *argv[]) { int num_draws = get_arg_val(*pathfinder_arg, "num_draws"); int num_psis_draws = get_arg_val(*pathfinder_arg, "num_psis_draws"); - + bool psis_resample + = get_arg_val(*pathfinder_arg, "psis_resample"); + bool calculate_lp + = get_arg_val(*pathfinder_arg, "calculate_lp"); if (num_psis_draws > num_draws * num_chains) { logger.warn( "Warning: Number of PSIS draws is larger than the total number of " @@ -332,7 +335,7 @@ int command(int argc, const char *argv[]) { history_size, init_alpha, tol_obj, tol_rel_obj, tol_grad, tol_rel_grad, tol_param, max_lbfgs_iters, num_elbo_draws, num_draws, save_single_paths, refresh, interrupt, logger, init_writer, - sample_writers[0], diagnostic_json_writers[0]); + sample_writers[0], diagnostic_json_writers[0], calculate_lp); } else { auto output_filenames = make_filenames(output_file, "", ".csv", 1, id); auto ofs = std::make_unique(output_filenames[0]); @@ -348,7 +351,7 @@ int command(int argc, const char *argv[]) { max_lbfgs_iters, num_elbo_draws, num_draws, num_psis_draws, num_chains, save_single_paths, refresh, interrupt, logger, init_writers, sample_writers, diagnostic_json_writers, - pathfinder_writer, dummy_json_writer); + pathfinder_writer, dummy_json_writer, calculate_lp, psis_resample); } // ---- pathfinder end ---- // } else if (user_method->arg("generate_quantities")) { diff --git a/src/cmdstan/command_helper.hpp b/src/cmdstan/command_helper.hpp index 105b2b9515..7e2cf38c4f 100644 --- a/src/cmdstan/command_helper.hpp +++ b/src/cmdstan/command_helper.hpp @@ -78,7 +78,12 @@ inline constexpr auto get_arg(List &&arg_list, const char *arg1, */ template inline constexpr auto get_arg_val(Arg &&argument, const char *arg_name) { - return dynamic_cast *>(argument.arg(arg_name))->value(); + auto *arg = argument.arg(arg_name); + if (arg) { + return dynamic_cast *>(arg)->value(); + } else { + throw std::invalid_argument(std::string("Unable to find: ") + arg_name); + } } /**