Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate forward decls #1277

Merged
merged 13 commits into from
Jan 12, 2023
5 changes: 5 additions & 0 deletions src/frontend/Ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,11 @@ let rec lvalue_of_expr {expr; emeta} =
let rec id_of_lvalue {lval; _} =
match lval with LVariable s -> s | LIndexed (l, _) -> id_of_lvalue l

let type_of_arguments :
(UnsizedType.autodifftype * UnsizedType.t * 'a) list
-> UnsizedType.argumentlist =
List.map ~f:(fun (a, t, _) -> (a, t))

(* XXX: the parser produces inaccurate locations: smeta.loc.begin_loc is the last
token before the current statement and all the whitespace between two statements
appears as if it were part of the second statement.
Expand Down
11 changes: 10 additions & 1 deletion src/frontend/Canonicalize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,16 @@ let repair_syntax program settings =
let canonicalize_program program settings : typed_program =
let program =
if settings.deprecations then
program
let fundefs = userdef_functions program in
let drop_forwarddecl = function
| {stmt= FunDef {body= {stmt= Skip; _}; funname; arguments; _}; _}
when is_redundant_forwarddecl fundefs funname arguments ->
false
| _ -> true in
{ program with
functionblock=
Option.map program.functionblock ~f:(fun x ->
{x with stmts= List.filter ~f:drop_forwarddecl x.stmts} ) }
|> map_program
(replace_deprecated_stmt (collect_userdef_distributions program))
else program in
Expand Down
44 changes: 35 additions & 9 deletions src/frontend/Deprecation_analysis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,21 @@ let is_deprecated_distribution name =
let rename_deprecated map name =
Map.find map name |> Option.map ~f:fst |> Option.value ~default:name

let userdef_functions program =
match program.functionblock with
| None -> []
| Some {stmts; _} ->
List.filter_map stmts ~f:(function
| {stmt= FunDef {body= {stmt= Skip; _}; _}; _} -> None
| {stmt= FunDef {funname; arguments; _}; _} ->
Some (funname.name, Ast.type_of_arguments arguments)
| _ -> None )

let is_redundant_forwarddecl fundefs funname arguments =
let equal (id1, a1) (id2, a2) =
String.equal id1 id2 && UnsizedType.equal_argumentlist a1 a2 in
List.mem ~equal fundefs (funname.name, Ast.type_of_arguments arguments)

let userdef_distributions stmts =
let open String in
List.filter_map
Expand Down Expand Up @@ -156,9 +171,16 @@ let rec collect_deprecated_expr (acc : (Location_span.t * string) list)
let collect_deprecated_lval acc l =
fold_lval_with collect_deprecated_expr (fun x _ -> x) acc l

let rec collect_deprecated_stmt (acc : (Location_span.t * string) list) {stmt; _}
: (Location_span.t * string) list =
let rec collect_deprecated_stmt fundefs (acc : (Location_span.t * string) list)
{stmt; _} : (Location_span.t * string) list =
match stmt with
| FunDef {body= {stmt= Skip; _}; funname; arguments; _}
when is_redundant_forwarddecl fundefs funname arguments ->
acc
@ [ ( funname.id_loc
, "Functions do not need to be declared before definition; all user \
defined function names are always in scope regardless of \
defintion order." ) ]
| FunDef
{ body
; funname= {name; id_loc}
Expand All @@ -174,8 +196,8 @@ let rec collect_deprecated_stmt (acc : (Location_span.t * string) list) {stmt; _
^ "' instead if you intend on using this function in ~ \
statements or calling unnormalized probability functions \
inside of it." ) ] in
collect_deprecated_stmt acc body
| FunDef {body; _} -> collect_deprecated_stmt acc body
collect_deprecated_stmt fundefs acc body
| FunDef {body; _} -> collect_deprecated_stmt fundefs acc body
| IfThenElse ({emeta= {type_= UReal; loc; _}; _}, ifb, elseb) ->
let acc =
acc
Expand All @@ -184,8 +206,10 @@ let rec collect_deprecated_stmt (acc : (Location_span.t * string) list) {stmt; _
Stan 2.34. Use an explicit != 0 comparison instead. This can be \
automatically changed using the canonicalize flag for stanc" ) ]
in
let acc = collect_deprecated_stmt acc ifb in
Option.value_map ~default:acc ~f:(collect_deprecated_stmt acc) elseb
let acc = collect_deprecated_stmt fundefs acc ifb in
Option.value_map ~default:acc
~f:(collect_deprecated_stmt fundefs acc)
elseb
| While ({emeta= {type_= UReal; loc; _}; _}, body) ->
let acc =
acc
Expand All @@ -194,9 +218,10 @@ let rec collect_deprecated_stmt (acc : (Location_span.t * string) list) {stmt; _
Stan 2.34. Use an explicit != 0 comparison instead. This can be \
automatically changed using the canonicalize flag for stanc" ) ]
in
collect_deprecated_stmt acc body
collect_deprecated_stmt fundefs acc body
| _ ->
fold_statement collect_deprecated_expr collect_deprecated_stmt
fold_statement collect_deprecated_expr
(collect_deprecated_stmt fundefs)
collect_deprecated_lval
(fun l _ -> l)
acc stmt
Expand All @@ -208,4 +233,5 @@ let collect_userdef_distributions program =
|> String.Map.of_alist_exn

let collect_warnings (program : typed_program) =
fold_program collect_deprecated_stmt [] program
let fundefs = userdef_functions program in
fold_program (collect_deprecated_stmt fundefs) [] program
11 changes: 11 additions & 0 deletions src/frontend/Deprecation_analysis.mli
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,16 @@ val is_deprecated_distribution : string -> bool
val deprecated_distributions : (string * string) String.Map.t
val deprecated_functions : (string * string) String.Map.t
val rename_deprecated : (string * string) String.Map.t -> string -> string

val userdef_functions :
('a, 'b, 'c, 'd) statement_with program
-> (string * Middle.UnsizedType.argumentlist) list

val is_redundant_forwarddecl :
(string * Middle.UnsizedType.argumentlist) list
-> identifier
-> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t * 'a) list
-> bool

val userdef_distributions : untyped_statement block option -> string list
val collect_warnings : typed_program -> Warnings.t list
34 changes: 25 additions & 9 deletions src/frontend/Typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,8 @@ let verify_name_fresh_var loc tenv name =
Semantic_error.ident_has_unnormalized_suffix loc name |> error
else if
List.exists (Env.find tenv name) ~f:(function
| {kind= `StanMath; _} ->
false (* user variables can shadow library names *)
| _ -> true )
| {kind= `Variable _; _} -> true
| _ -> false (* user variables can shadow function names *) )
then Semantic_error.ident_in_use loc name |> error

(** verify that the variable being declared is previous unused. *)
Expand Down Expand Up @@ -1579,15 +1578,9 @@ and check_fundef loc cf tenv return_ty id args body =
let arg_types = List.map ~f:(fun (w, y, _) -> (w, y)) args in
let arg_identifiers = List.map ~f:(fun (_, _, z) -> z) args in
let arg_names = List.map ~f:(fun x -> x.name) arg_identifiers in
verify_fundef_overloaded loc tenv id arg_types return_ty ;
let defined = get_fn_decl_or_defn loc tenv id arg_types return_ty body in
verify_fundef_dist_rt loc id return_ty ;
verify_pdf_fundef_first_arg_ty loc id arg_types ;
verify_pmf_fundef_first_arg_ty loc id arg_types ;
let tenv =
add_function tenv id.name
(UFun (arg_types, return_ty, Fun_kind.suffix_from_name id.name, AoS))
defined in
List.iter
~f:(fun id -> verify_name_fresh tenv id ~is_udf:false)
arg_identifiers ;
Expand Down Expand Up @@ -1682,6 +1675,28 @@ let verify_functions_have_defn tenv function_block_stmts_opt =
| Some {stmts= []; _} | None -> ()
| Some {stmts= ls; _} -> List.iter ~f:verify_fun_def_body_in_block ls

let add_userdefined_functions tenv stmts_opt =
match stmts_opt with
| None -> tenv
| Some {stmts; _} ->
let f tenv (s : Ast.untyped_statement) =
match s with
| {stmt= FunDef {returntype; funname; arguments; body}; smeta= {loc}} ->
let arg_types = Ast.type_of_arguments arguments in
verify_fundef_overloaded loc tenv funname arg_types returntype ;
let defined =
get_fn_decl_or_defn loc tenv funname arg_types returntype body
in
add_function tenv funname.name
(UFun
( arg_types
, returntype
, Fun_kind.suffix_from_name funname.name
, AoS ) )
defined
| _ -> tenv in
List.fold ~init:tenv ~f stmts

let check_toplevel_block block tenv stmts_opt =
let cf = context block in
match stmts_opt with
Expand Down Expand Up @@ -1714,6 +1729,7 @@ let check_program_exn
warnings := [] ;
(* create a new type environment which has only stan-math functions *)
let tenv = Env.stan_math_environment in
let tenv = add_userdefined_functions tenv fb in
let tenv, typed_fb = check_toplevel_block Functions tenv fb in
verify_functions_have_defn tenv typed_fb ;
let tenv, typed_db = check_toplevel_block Data tenv db in
Expand Down
2 changes: 1 addition & 1 deletion src/middle/Fun_kind.ml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ open Core_kernel
open Core_kernel.Poly

type 'propto suffix = FnPlain | FnRng | FnLpdf of 'propto | FnTarget
[@@deriving compare, sexp, hash, map]
[@@deriving compare, sexp, hash, map, equal]

let without_propto = map_suffix (function true | false -> ())

Expand Down
2 changes: 1 addition & 1 deletion src/middle/Mem_pattern.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ open Core_kernel.Poly
* (fyi a var in the C++ code is an alias for var_value<double>)
*
**)
type t = AoS | SoA [@@deriving sexp, compare, map, hash, fold]
type t = AoS | SoA [@@deriving sexp, compare, map, hash, fold, equal]

let pp ppf = function
| AoS -> Fmt.string ppf "AoS"
Expand Down
11 changes: 5 additions & 6 deletions src/middle/UnsizedType.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,15 @@ type t =
| UComplexRowVector
| UComplexMatrix
| UArray of t
| UFun of
(autodifftype * t) list
* returntype
* bool Fun_kind.suffix
* Mem_pattern.t
| UFun of argumentlist * returntype * bool Fun_kind.suffix * Mem_pattern.t
| UMathLibraryFunction

and argumentlist = (autodifftype * t) list

and autodifftype = DataOnly | AutoDiffable

and returntype = Void | ReturnType of t [@@deriving compare, hash, sexp]
and returntype = Void | ReturnType of t
[@@deriving compare, hash, sexp, equal]

let pp_autodifftype ppf = function
| DataOnly -> Fmt.string ppf "data "
Expand Down
4 changes: 4 additions & 0 deletions src/stan_math_backend/Cpp.ml
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ type fun_defn =
; body: stmt list option }
[@@deriving make, sexp]

let split_fun_decl_defn (fn : fun_defn) =
( {fn with body= None}
, {fn with templates_init= (fst fn.templates_init, false)} )

type constructor =
{ args: (type_ * string) list
; init_list: (identifier * expr list) list
Expand Down
22 changes: 15 additions & 7 deletions src/stan_math_backend/Lower_expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,19 @@ let functor_suffix = "_functor__"
let reduce_sum_functor_suffix = "_rsfunctor__"
let variadic_functor_suffix x = sprintf "_variadic%d_functor__" x

let functor_suffix_select hof =
type variadic = FixedArgs | ReduceSum | VariadicHOF of int
[@@deriving compare, hash]

let functor_type hof =
match Hashtbl.find Stan_math_signatures.stan_math_variadic_signatures hof with
| Some {required_fn_args; _} ->
variadic_functor_suffix (List.length required_fn_args)
| None when Stan_math_signatures.is_reduce_sum_fn hof ->
reduce_sum_functor_suffix
| None -> functor_suffix
| Some {required_fn_args; _} -> VariadicHOF (List.length required_fn_args)
| None when Stan_math_signatures.is_reduce_sum_fn hof -> ReduceSum
| None -> FixedArgs

let functor_suffix_select = function
| VariadicHOF n -> variadic_functor_suffix n
| ReduceSum -> reduce_sum_functor_suffix
| FixedArgs -> functor_suffix

(* retun true if the type of the expression
is integer, real, or complex (e.g. not a container) *)
Expand Down Expand Up @@ -292,7 +298,9 @@ and lower_functionals fname suffix es mem_pattern =
pattern=
FunApp
( StanLib
(name ^ functor_suffix_select fname, FnPlain, mem_pattern)
( name ^ functor_suffix_select (functor_type fname)
, FnPlain
, mem_pattern )
, [] ) }
| e -> e in
let converted_es = List.map ~f:convert_hof_vars es in
Expand Down
Loading