Skip to content

Commit e30247a

Browse files
committed
fixes for general_matrix_matrix_product
1 parent 95748de commit e30247a

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

stan/math/rev/core/Eigen_NumTraits.hpp

+16
Original file line numberDiff line numberDiff line change
@@ -389,11 +389,19 @@ struct general_matrix_vector_product<Index, stan::math::var, LhsMapper,
389389
}
390390
};
391391

392+
#if EIGEN_VERSION_AT_LEAST(3, 3, 8)
393+
template <typename Index, int LhsStorageOrder, bool ConjugateLhs,
394+
int RhsStorageOrder, bool ConjugateRhs, int ResInnerStride>
395+
struct general_matrix_matrix_product<
396+
Index, stan::math::var, LhsStorageOrder, ConjugateLhs, stan::math::var,
397+
RhsStorageOrder, ConjugateRhs, ColMajor, ResInnerStride> {
398+
#else
392399
template <typename Index, int LhsStorageOrder, bool ConjugateLhs,
393400
int RhsStorageOrder, bool ConjugateRhs>
394401
struct general_matrix_matrix_product<Index, stan::math::var, LhsStorageOrder,
395402
ConjugateLhs, stan::math::var,
396403
RhsStorageOrder, ConjugateRhs, ColMajor> {
404+
#endif
397405
using LhsScalar = stan::math::var;
398406
using RhsScalar = stan::math::var;
399407
using ResScalar = stan::math::var;
@@ -406,11 +414,19 @@ struct general_matrix_matrix_product<Index, stan::math::var, LhsStorageOrder,
406414
= const_blas_data_mapper<stan::math::var, Index, RhsStorageOrder>;
407415

408416
EIGEN_DONT_INLINE
417+
#if EIGEN_VERSION_AT_LEAST(3, 3, 8)
418+
static void run(Index rows, Index cols, Index depth, const LhsScalar* lhs,
419+
Index lhsStride, const RhsScalar* rhs, Index rhsStride,
420+
ResScalar* res, Index resIncr, Index resStride, const ResScalar& alpha,
421+
level3_blocking<LhsScalar, RhsScalar>& /* blocking */,
422+
GemmParallelInfo<Index>* /* info = 0 */) {
423+
#else
409424
static void run(Index rows, Index cols, Index depth, const LhsScalar* lhs,
410425
Index lhsStride, const RhsScalar* rhs, Index rhsStride,
411426
ResScalar* res, Index resStride, const ResScalar& alpha,
412427
level3_blocking<LhsScalar, RhsScalar>& /* blocking */,
413428
GemmParallelInfo<Index>* /* info = 0 */) {
429+
#endif
414430
for (Index i = 0; i < cols; i++) {
415431
general_matrix_vector_product<
416432
Index, LhsScalar, LhsMapper, LhsStorageOrder, ConjugateLhs, RhsScalar,

0 commit comments

Comments
 (0)