Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the Dirichlet-Multinomial distribution (issue 54) #2979

Merged
merged 10 commits into from
Dec 17, 2023
Merged

Added the Dirichlet-Multinomial distribution (issue 54) #2979

merged 10 commits into from
Dec 17, 2023

Conversation

chvandorp
Copy link
Contributor

Summary

This PR adds the Dirichlet-Multinomial distribution to the Stan math library. The Dirichlet-Multinomial (DirMult) distribution generalizes the Beta-Binomial distribution with more than two categories. It can also be seen as an over-dispersed multinomial distribution. It is implemented in other popular frameworks such as Pyro and the Python package scipy. In issue 54 (first created in 2014) it is suggested to add this distribution.

I largely based my implementation on the existing multinomial distribution. As such the DirMult distribution is currently not vectorized, as I could not find an example of a vectorized multivariate discrete distribution in the current code base. However, I would be interested in implementing vectorization in a future PR. I have added clear Doxygen documentation for the lpmf, rng (and log) functions.

I added 4 files for the log-PMF, the log function, the PRNG, a number of tests, and I updated the prob.hpp header. I have also prepared PRs for stanc3 and the docs, and tested the new native distribution in action with a Stan model using cmdstanpy.

Tests

I have included 8 unit tests.

  • test 1 compares LPMF values with pre-computed values (using scipy.stats).
  • test 2 checks behavior of propto (a log-prob of 0.0)
  • test 3 checks that the right exceptions are thrown in case of incorrect arguments
  • test 4 checks that the observation [0, 0, ..., 0] has log-prob 0.0
  • test 5 checks that the PRNG returns values in the right domain
  • test 6 checks that the PRNG throws the correct exceptions
  • test 7 is a goodness-of-fit test
  • test 8 checks that for two categories, the DirMult coincides the BetaBinom

Side Effects

There should not be any side effects.

Release notes

Added the Dirichlet-Multinomial distribution to the Stan Math library (dirichlet_multinomial_lpmf, dirichlet_multinomial_log, and dirichlet_multinomial_rng).

Checklist

  • [x ] Copyright holder: Christiaan H. van Dorp

By submitting this pull request, the copyright holder is agreeing to the license the submitted work under the following licenses:
- Code: BSD 3-clause (https://opensource.org/licenses/BSD-3-Clause)
- Documentation: CC-BY 4.0 (https://creativecommons.org/licenses/by/4.0/)

  • [x ] the basic tests are passing

    • unit tests pass (to run, use: ./runTests.py test/unit)
    • header checks pass, (make test-headers)
    • dependencies checks pass, (make test-math-dependencies)
    • docs build, (make doxygen)
    • code passes the built in C++ standards checks (make cpplint)
  • [x ] the code is written in idiomatic C++ and changes are documented in the doxygen

  • [x ] the new changes are tested

@SteveBronder
Copy link
Collaborator

Thanks! Can you also add a test in test/unit/math/mix/prob/ that uses the expect_ad test framework? The expect_ad() function will check and compare the gradients and higher order derivativers against a finite difference method to see if they are close. You can see how the other mix tests files use the expect_ad() function as examples.

@chvandorp
Copy link
Contributor Author

Thanks @SteveBronder! I wrote the AD test. But before I push this, do you know how to fix this continuous integration issue I get? I think it might be because I named my fork "stan_math" instead of "math"

@SteveBronder
Copy link
Collaborator

Yes the name looks to be an issue. We can fix it on our side but it might just be easier to change the name of your fork

Copy link
Collaborator

@andrjohns andrjohns left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution! A couple of minor requests to code style and files.

I've also updated the lpmf function to take advantage of our/Eigen's vectorised functions and added analytical gradients. Let me know if there's anything in those changes that you disagree with or similar.

Thanks!

@chvandorp
Copy link
Contributor Author

Thanks @andrjohns for the improvements. I've added a (ns_array > 0).select to the ops_partials definition. It does not change the math, but I found that it gives a significant speed increase when the data contains lots of zeros. I also added a test case to the AD test that contains zeros in the count vector to cover this case.

Copy link
Collaborator

@andrjohns andrjohns left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for adding this!

@andrjohns andrjohns merged commit 56d0432 into stan-dev:develop Dec 17, 2023
@chvandorp chvandorp deleted the feature/issue-54-dirichlet-multinomial branch December 18, 2023 15:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants