@@ -30,6 +30,32 @@ let rec change_kwrds_stmts s =
30
30
| x -> map Fn. id change_kwrds_stmts x in
31
31
{s with pattern}
32
32
33
+ (* * A list of functions which return an Eigen block expression *)
34
+ let eigen_block_expr_fns =
35
+ [" head" ; " tail" ; " segment" ; " col" ; " row" ; " block" ; " sub_row" ; " sub_col" ]
36
+ |> String.Set. of_list
37
+
38
+ let eval_eigen_blocks e =
39
+ let open Expr.Fixed in
40
+ let f ({pattern; meta} as expr ) =
41
+ match (pattern, UnsizedType. is_eigen_type (Expr.Typed. type_of expr)) with
42
+ | Indexed _ , true ->
43
+ {meta; pattern= FunApp (StanLib (" eval" , FnPlain , AoS ), [expr])}
44
+ | FunApp (StanLib (fname, _, _), _), true
45
+ when Set. mem eigen_block_expr_fns fname ->
46
+ {meta; pattern= FunApp (StanLib (" eval" , FnPlain , AoS ), [expr])}
47
+ | _ -> expr in
48
+ rewrite_bottom_up ~f e
49
+
50
+ let eval_udf_indexed_calls e =
51
+ let open Expr.Fixed in
52
+ let f ({pattern; _} as expr ) =
53
+ match pattern with
54
+ | FunApp ((UserDefined (_ , _ ) as kind ), args ) ->
55
+ {expr with pattern= FunApp (kind, List. map ~f: eval_eigen_blocks args)}
56
+ | _ -> expr in
57
+ rewrite_bottom_up ~f e
58
+
33
59
let opencl_trigger_restrictions =
34
60
String.Map. of_alist_exn
35
61
[ ( " bernoulli_lpmf"
@@ -480,6 +506,33 @@ let trans_prog (p : Program.Typed.t) =
480
506
|> map translate_funapps_and_kwrds map_stmt
481
507
|> map Fn. id change_kwrds_stmts) in
482
508
let p = Program. map Fn. id map_fn_names p in
509
+ (* Eval indexed eigen types in UDF calls to prevent
510
+ infinite template expansion if the call is recursive
511
+ *)
512
+ let possibly_recursive_fns =
513
+ List. filter_map
514
+ ~f: (function
515
+ | {fdname; fdargs; fdbody = None ; _} -> Some (fdname, fdargs) | _ -> None
516
+ )
517
+ p.functions_block
518
+ |> Set.Poly. of_list in
519
+ let rec map_stmt {Stmt.Fixed. pattern; meta} =
520
+ match pattern with
521
+ | NRFunApp ((UserDefined _ as kind ), args ) ->
522
+ { Stmt.Fixed. meta
523
+ ; pattern= NRFunApp (kind, List. map ~f: eval_eigen_blocks args) }
524
+ | _ ->
525
+ { Stmt.Fixed. pattern=
526
+ Stmt.Fixed.Pattern. map eval_udf_indexed_calls map_stmt pattern
527
+ ; meta } in
528
+ let eval_udf_indexed_stmts (s : 'a Program.fun_def ) =
529
+ if Set. mem possibly_recursive_fns (s.fdname, s.fdargs) then
530
+ {s with fdbody= Option. map ~f: map_stmt s.fdbody}
531
+ else s in
532
+ let p =
533
+ { p with
534
+ functions_block= List. map ~f: eval_udf_indexed_stmts p.functions_block }
535
+ in
483
536
let init_pos =
484
537
[ Stmt.Fixed.Pattern. Decl
485
538
{ decl_adtype= DataOnly
0 commit comments