Skip to content

Commit 8fce5fb

Browse files
authored
Merge pull request #1228 from stan-dev/fix/1224-recursive-indexing
Call eval on UDF arguments which are indices in UDF bodies
2 parents e1c6ac6 + 43b8282 commit 8fce5fb

File tree

3 files changed

+926
-0
lines changed

3 files changed

+926
-0
lines changed

src/stan_math_backend/Transform_Mir.ml

+53
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,32 @@ let rec change_kwrds_stmts s =
3030
| x -> map Fn.id change_kwrds_stmts x in
3131
{s with pattern}
3232

33+
(** A list of functions which return an Eigen block expression *)
34+
let eigen_block_expr_fns =
35+
["head"; "tail"; "segment"; "col"; "row"; "block"; "sub_row"; "sub_col"]
36+
|> String.Set.of_list
37+
38+
let eval_eigen_blocks e =
39+
let open Expr.Fixed in
40+
let f ({pattern; meta} as expr) =
41+
match (pattern, UnsizedType.is_eigen_type (Expr.Typed.type_of expr)) with
42+
| Indexed _, true ->
43+
{meta; pattern= FunApp (StanLib ("eval", FnPlain, AoS), [expr])}
44+
| FunApp (StanLib (fname, _, _), _), true
45+
when Set.mem eigen_block_expr_fns fname ->
46+
{meta; pattern= FunApp (StanLib ("eval", FnPlain, AoS), [expr])}
47+
| _ -> expr in
48+
rewrite_bottom_up ~f e
49+
50+
let eval_udf_indexed_calls e =
51+
let open Expr.Fixed in
52+
let f ({pattern; _} as expr) =
53+
match pattern with
54+
| FunApp ((UserDefined (_, _) as kind), args) ->
55+
{expr with pattern= FunApp (kind, List.map ~f:eval_eigen_blocks args)}
56+
| _ -> expr in
57+
rewrite_bottom_up ~f e
58+
3359
let opencl_trigger_restrictions =
3460
String.Map.of_alist_exn
3561
[ ( "bernoulli_lpmf"
@@ -480,6 +506,33 @@ let trans_prog (p : Program.Typed.t) =
480506
|> map translate_funapps_and_kwrds map_stmt
481507
|> map Fn.id change_kwrds_stmts) in
482508
let p = Program.map Fn.id map_fn_names p in
509+
(* Eval indexed eigen types in UDF calls to prevent
510+
infinite template expansion if the call is recursive
511+
*)
512+
let possibly_recursive_fns =
513+
List.filter_map
514+
~f:(function
515+
| {fdname; fdargs; fdbody= None; _} -> Some (fdname, fdargs) | _ -> None
516+
)
517+
p.functions_block
518+
|> Set.Poly.of_list in
519+
let rec map_stmt {Stmt.Fixed.pattern; meta} =
520+
match pattern with
521+
| NRFunApp ((UserDefined _ as kind), args) ->
522+
{ Stmt.Fixed.meta
523+
; pattern= NRFunApp (kind, List.map ~f:eval_eigen_blocks args) }
524+
| _ ->
525+
{ Stmt.Fixed.pattern=
526+
Stmt.Fixed.Pattern.map eval_udf_indexed_calls map_stmt pattern
527+
; meta } in
528+
let eval_udf_indexed_stmts (s : 'a Program.fun_def) =
529+
if Set.mem possibly_recursive_fns (s.fdname, s.fdargs) then
530+
{s with fdbody= Option.map ~f:map_stmt s.fdbody}
531+
else s in
532+
let p =
533+
{ p with
534+
functions_block= List.map ~f:eval_udf_indexed_stmts p.functions_block }
535+
in
483536
let init_pos =
484537
[ Stmt.Fixed.Pattern.Decl
485538
{ decl_adtype= DataOnly

0 commit comments

Comments
 (0)