Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement jacobian+= statement #1435

Merged
merged 4 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/analysis_and_optimization/Factor_graph.ml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type factor_graph =

let extract_factors_statement stmt =
match stmt with
| Stmt.Fixed.Pattern.TargetPE e ->
| Stmt.Fixed.Pattern.TargetPE e | JacobianPE e ->
List.map (summation_terms e) ~f:(fun x -> TargetTerm x)
| NRFunApp (CompilerInternal (FnReject | FnFatalError), _) -> [Reject]
| NRFunApp ((UserDefined (s, FnTarget) | StanLib (s, FnTarget, _)), args) ->
Expand Down
3 changes: 2 additions & 1 deletion src/analysis_and_optimization/Memory_patterns.ml
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
| SList lst | Profile (_, lst) | Block lst ->
Set.Poly.union_list
(List.map ~f:(query_initial_demotable_stmt in_loop acc) lst)
| TargetPE expr -> query_expr acc expr
| TargetPE expr | JacobianPE expr -> query_expr acc expr
(* NOTE: loops generated by inlining are not actually loops;
we do not unconditionally set "in_loop" *)
| For
Expand Down Expand Up @@ -603,6 +603,7 @@ let rec modify_stmt_pattern
; upper= mod_expr false upper
; body= mod_stmt body }
| TargetPE expr -> TargetPE ((mod_expr false) expr)
| JacobianPE expr -> JacobianPE ((mod_expr false) expr)
| Return optional_expr ->
Return (Option.map ~f:(mod_expr false) optional_expr)
| Profile ((p_name : string), stmt) ->
Expand Down
2 changes: 2 additions & 0 deletions src/analysis_and_optimization/Mir_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ let fwd_traverse_statement stmt ~init ~f =
(s', SList (List.rev ls))
| Assignment _ as s -> (init, s)
| TargetPE _ as s -> (init, s)
| JacobianPE _ as s -> (init, s)
| NRFunApp _ as s -> (init, s)
| Break as s -> (init, s)
| Continue as s -> (init, s)
Expand Down Expand Up @@ -268,6 +269,7 @@ let stmt_rhs stmt =
|While (rhs, _)
|Assignment (_, _, rhs)
|TargetPE rhs
|JacobianPE rhs
|Return (Some rhs) ->
Set.Poly.singleton rhs
| Return None
Expand Down
36 changes: 21 additions & 15 deletions src/analysis_and_optimization/Monotone_framework.ml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ let rec free_vars_lval ((lval, idxs) : Expr.Typed.t Stmt.Fixed.Pattern.lvalue) =
let rec free_vars_stmt (s : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t)
=
match s with
| Return (Some e) | TargetPE e -> free_vars_expr e
| Return (Some e) | TargetPE e | JacobianPE e -> free_vars_expr e
| Assignment (l, _, e) ->
Set.Poly.union_list [free_vars_expr e; free_vars_lval l]
| NRFunApp (kind, l) -> free_vars_fnapp kind l
Expand All @@ -81,8 +81,8 @@ let top_free_vars_stmt
(flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t)
(s : (Expr.Typed.t, int) Stmt.Fixed.Pattern.t) =
match s with
| Assignment _ | Return _ | TargetPE _ | NRFunApp _ | Decl _ | Break
|Continue | Skip ->
| Assignment _ | Return _ | TargetPE _ | JacobianPE _ | NRFunApp _ | Decl _
|Break | Continue | Skip ->
free_vars_stmt
(statement_stmt_loc_of_statement_stmt_loc_num flowgraph_to_mir s)
| While (e, _) | IfElse (e, _, _) -> free_vars_expr e
Expand Down Expand Up @@ -349,7 +349,7 @@ let constant_propagation_transfer ?(preserve_stability = false)
| Decl {decl_id= s; _} -> Map.remove m s
| Assignment (lhs, _, _) ->
Map.remove m (Middle.Stmt.Helpers.lhs_variable lhs)
| TargetPE _
| TargetPE _ | JacobianPE _
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip
|IfElse (_, _, _)
Expand Down Expand Up @@ -407,7 +407,7 @@ let expression_propagation_transfer ?(preserve_stability = false)
Set.Poly.union_list
(List.map ~f:(label_top_decls flowgraph_to_mir) b) in
Set.fold kills ~init:m ~f:kill_var
| TargetPE _
| TargetPE _ | JacobianPE _
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip
|IfElse (_, _, _)
Expand Down Expand Up @@ -451,7 +451,7 @@ let copy_propagation_transfer (globals : string Set.Poly.t)
(List.map ~f:(label_top_decls flowgraph_to_mir) stmt_lst)
in
Set.fold kills ~init:expr_map ~f:kill_var
| TargetPE _
| TargetPE _ | JacobianPE _
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip | SList _
|IfElse (_, _, _)
Expand All @@ -472,8 +472,11 @@ let assigned_vars_stmt (s : (Expr.Typed.t, 'a) Stmt.Fixed.Pattern.t) =
match s with
| Assignment (lhs, _, _) ->
Set.Poly.singleton (Middle.Stmt.Helpers.lhs_variable lhs)
| TargetPE _ -> Set.Poly.singleton "target"
| NRFunApp ((UserDefined (_, FnTarget) | StanLib (_, FnTarget, _)), _) ->
| TargetPE _ | JacobianPE _ -> Set.Poly.singleton "target"
| NRFunApp
( ( UserDefined (_, (FnTarget | FnJacobian))
| StanLib (_, (FnTarget | FnJacobian), _) )
, _ ) ->
Set.Poly.singleton "target"
| For {loopvar= x; _} -> Set.Poly.singleton x
| Decl {decl_id= _; _}
Expand Down Expand Up @@ -514,9 +517,12 @@ let reaching_definitions_transfer
|Assignment ((LVariable x, []), _, _)
|For {loopvar= x; _} ->
Set.filter p ~f:(fun (y, _) -> y = x)
| TargetPE _ -> Set.filter p ~f:(fun (y, _) -> y = "target")
| NRFunApp ((UserDefined (_, FnTarget) | StanLib (_, FnTarget, _)), _)
->
| TargetPE _ | JacobianPE _ ->
Set.filter p ~f:(fun (y, _) -> y = "target")
| NRFunApp
( ( UserDefined (_, (FnTarget | FnJacobian))
| StanLib (_, (FnTarget | FnJacobian), _) )
, _ ) ->
Set.filter p ~f:(fun (y, _) -> y = "target")
| NRFunApp (_, _)
|Break | Continue | Return _ | Skip
Expand Down Expand Up @@ -558,7 +564,7 @@ let live_variables_transfer (never_kill : string Set.Poly.t)
match mir_node with
| Assignment ((LVariable x, []), _, _) | Decl {decl_id= x; _} ->
Set.Poly.singleton x
| TargetPE _
| TargetPE _ | JacobianPE _
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip
|IfElse (_, _, _)
Expand Down Expand Up @@ -616,7 +622,7 @@ let used_expressions_expr e = Expr.Typed.Set.singleton e
let rec used_expressions_stmt_help f
(s : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) =
match s with
| TargetPE e | Return (Some e) -> f e
| TargetPE e | JacobianPE e | Return (Some e) -> f e
| Assignment (l, _, e) -> Set.union (f e) (used_expressions_lval f l)
| IfElse (e, b1, Some b2) ->
Expr.Typed.Set.union_list
Expand Down Expand Up @@ -650,7 +656,7 @@ let used_expressions_stmt = used_expressions_stmt_help used_expressions_expr
let top_used_expressions_stmt_help f
(s : (Expr.Typed.t, int) Stmt.Fixed.Pattern.t) =
match s with
| TargetPE e | Return (Some e) -> f e
| TargetPE e | JacobianPE e | Return (Some e) -> f e
| Assignment (l, _, e) -> Set.union (f e) (used_expressions_lval f l)
| While (e, _) | IfElse (e, _, _) -> f e
| NRFunApp (k, l) ->
Expand Down Expand Up @@ -882,7 +888,7 @@ let rec declared_variables_stmt
match s with
| Decl {decl_id= x; _} -> Set.Poly.singleton x
| Assignment (_, _, _)
|TargetPE _
|TargetPE _ | JacobianPE _
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip ->
Set.Poly.empty
Expand Down
7 changes: 5 additions & 2 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,9 @@ let rec inline_function_statement propto adt fim Stmt.Fixed.{pattern; meta} =
| TargetPE e ->
let d, s, e = inline_function_expression propto adt fim e in
slist_concat_no_loc (d @ s) (TargetPE e)
| JacobianPE e ->
let d, s, e = inline_function_expression propto adt fim e in
slist_concat_no_loc (d @ s) (JacobianPE e)
| NRFunApp (kind, exprs) ->
let d_list, s_list, es =
inline_list (inline_function_expression propto adt fim) exprs
Expand Down Expand Up @@ -531,7 +534,7 @@ let rec contains_top_break_or_continue Stmt.Fixed.{pattern; _} =
match pattern with
| Break | Continue -> true
| Assignment (_, _, _)
|TargetPE _
|TargetPE _ | JacobianPE _
|NRFunApp (_, _)
|Return _ | Decl _
|While (_, _)
Expand Down Expand Up @@ -749,7 +752,7 @@ let dead_code_elimination (mir : Program.Typed.t) =
remove an assignment to a variable
due to side effects. *)
(* TODO: maybe we should revisit that. *)
| Decl _ | TargetPE _
| Decl _ | TargetPE _ | JacobianPE _
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip ->
stmt
Expand Down
29 changes: 22 additions & 7 deletions src/frontend/Ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ type ('e, 's, 'l, 'f) statement =
{assign_lhs: 'l lvalue_pack; assign_op: assignmentoperator; assign_rhs: 'e}
| NRFunApp of 'f * identifier * 'e list
| TargetPE of 'e
| JacobianPE of 'e
| Tilde of
{ arg: 'e
; distribution: identifier
Expand Down Expand Up @@ -285,12 +286,25 @@ let rec untyped_lvalue_of_typed_lvalue_pack :

(** Forgetful function from typed to untyped statements *)
let rec untyped_statement_of_typed_statement {stmt; smeta} =
{ stmt=
map_statement untyped_expression_of_typed_expression
untyped_statement_of_typed_statement untyped_lvalue_of_typed_lvalue
(fun _ -> ())
stmt
; smeta= {loc= smeta.loc} }
match stmt with
(* TODO(2.38): Remove this workaround *)
| JacobianPE e ->
{ stmt=
Assignment
{ assign_lhs=
LValue
{ lval= LVariable {name= "jacobian"; id_loc= smeta.loc}
; lmeta= {loc= smeta.loc} }
; assign_op= OperatorAssign Plus
; assign_rhs= untyped_expression_of_typed_expression e }
; smeta= {loc= smeta.loc} }
| _ ->
{ stmt=
map_statement untyped_expression_of_typed_expression
untyped_statement_of_typed_statement untyped_lvalue_of_typed_lvalue
(fun _ -> ())
stmt
; smeta= {loc= smeta.loc} }

(** Forgetful function from typed to untyped programs *)
let untyped_program_of_typed_program : typed_program -> untyped_program =
Expand Down Expand Up @@ -390,7 +404,8 @@ let get_first_loc (s : untyped_statement) =
|ForEach (id, _, _)
|FunDef {funname= id; _} ->
id.id_loc.begin_loc
| TargetPE e | Return e | IfThenElse (e, _, _) | While (e, _) ->
| TargetPE e | JacobianPE e | Return e | IfThenElse (e, _, _) | While (e, _)
->
e.emeta.loc.begin_loc
| Assignment _ | Profile _ | Block _ | Tilde _ | Break | Continue
|ReturnVoid | Print _ | Reject _ | FatalError _ | Skip ->
Expand Down
1 change: 1 addition & 0 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
| Ast.NRFunApp (fn_kind, {name; _}, args) ->
NRFunApp (trans_fn_kind fn_kind name, trans_exprs args) |> swrap
| Ast.TargetPE e -> TargetPE (trans_expr e) |> swrap
| Ast.JacobianPE e -> JacobianPE (trans_expr e) |> swrap
| Ast.Tilde {arg; distribution; args; truncation} ->
let suffix =
Stan_math_signatures.dist_name_suffix ud_dists distribution.name in
Expand Down
6 changes: 0 additions & 6 deletions src/frontend/Input_warnings.ml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,3 @@ let empty file =
add_warning Middle.Location_span.empty
("Empty file '" ^ file
^ "' detected; this is a valid stan model but likely unintended!")

let future_keyword kwrd version positions =
add_warning
(Preprocessor.location_span_of_positions positions)
("Variable name '" ^ kwrd ^ "' will be a reserved word starting in Stan "
^ version ^ ". Please rename it!")
4 changes: 0 additions & 4 deletions src/frontend/Input_warnings.mli
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,3 @@ val add_warning : Middle.Location_span.t -> string -> unit

val empty : string -> unit
(** Register that an empty file is being parsed *)

val future_keyword :
string -> string -> Lexing.position * Lexing.position -> unit
(** Warn on a keyword which will be reserved in the future*)
1 change: 1 addition & 0 deletions src/frontend/Pretty_printing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ and pp_statement ppf ({stmt= s_content; smeta= {loc}} as ss : untyped_statement)
| NRFunApp (_, id, es) ->
pf ppf "%a(@[%a);@]" pp_identifier id pp_list_of_expression (es, loc)
| TargetPE e -> pf ppf "target += %a;" pp_expression e
| JacobianPE e -> pf ppf "jacobian += %a;" pp_expression e
| Tilde {arg= e; distribution= id; args= es; truncation= t} ->
pf ppf "%a ~ %a(@[%a)@]%a;" pp_expression e pp_identifier id
pp_list_of_expression (es, loc) pp_truncation t
Expand Down
12 changes: 10 additions & 2 deletions src/frontend/Semantic_error.ml
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ module StatementError = struct
| InvalidTildeCDForCCDF of string
| InvalidTildeNoSuchDistribution of string * bool
| TargetPlusEqualsOutsideModelOrLogProb
| JacobianPlusEqualsNotAllowed
| InvalidTruncationCDForCCDF of
(UnsizedType.autodifftype * UnsizedType.t) list
| BreakOutsideLoop
Expand Down Expand Up @@ -428,11 +429,15 @@ module StatementError = struct
Fmt.(list ~sep:comma string)
ids
| TargetPlusEqualsOutsideModelOrLogProb ->
Fmt.pf ppf
Fmt.string ppf
"Target can only be accessed in the model block or in definitions of \
functions with the suffix _lp."
| JacobianPlusEqualsNotAllowed ->
Fmt.string ppf
"The jacobian adjustment can only be applied in the transformed \
parameters block or in functions ending with _jacobian"
| InvalidTildePDForPMF ->
Fmt.pf ppf
Fmt.string ppf
"~ statement should refer to a distribution without its \
\"_lpdf/_lupdf\" or \"_lpmf/_lupmf\" suffix.\n\
For example, \"target += normal_lpdf(y, 0, 1)\" should become \"y ~ \
Expand Down Expand Up @@ -722,6 +727,9 @@ let invalid_tilde_no_such_dist loc name is_int =
let target_plusequals_outside_model_or_logprob loc =
StatementError (loc, StatementError.TargetPlusEqualsOutsideModelOrLogProb)

let jacobian_plusequals_not_allowed loc =
StatementError (loc, StatementError.JacobianPlusEqualsNotAllowed)

let invalid_truncation_cdf_or_ccdf loc args =
StatementError (loc, StatementError.InvalidTruncationCDForCCDF args)

Expand Down
1 change: 1 addition & 0 deletions src/frontend/Semantic_error.mli
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ val invalid_tilde_pdf_or_pmf : Location_span.t -> t
val invalid_tilde_cdf_or_ccdf : Location_span.t -> string -> t
val invalid_tilde_no_such_dist : Location_span.t -> string -> bool -> t
val target_plusequals_outside_model_or_logprob : Location_span.t -> t
val jacobian_plusequals_not_allowed : Location_span.t -> t

val invalid_truncation_cdf_or_ccdf :
Location_span.t -> (UnsizedType.autodifftype * UnsizedType.t) list -> t
Expand Down
3 changes: 2 additions & 1 deletion src/frontend/SignatureMismatch.ml
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,8 @@ let pp_signature_mismatch ppf (name, arg_tys, (sigs, omitted)) =
| Fun_kind.FnPlain -> "a pure function"
| FnRng -> "an rng function"
| FnLpdf () -> "a probability density or mass function"
| FnTarget -> "an _lp function" in
| FnTarget -> "an _lp function"
| FnJacobian -> "a _jacobian function" in
let index_str = function
| 1 -> "first"
| 2 -> "second"
Expand Down
Loading