forked from stan-dev/stanc3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathAst.ml
363 lines (313 loc) · 12.2 KB
/
Ast.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
(** Abstract syntax tree for Stan. Defined with the
'two-level types' pattern, where the variant types are not
directly recursive, but rather parametric in some other type.
This type ends up being substituted for the fixpoint of the recursive
type itself including metadata. So instead of recursively referencing
[expression] you would instead reference type parameter ['e], which will
later be filled in with something like [type expr_with_meta = metadata expression]
*)
open Core_kernel
open Middle
(** Our type for identifiers, on which we record a location.
[name] is the identifier text and [id_loc] is the recorded location span.
Per the attributes on [id_loc], the location is treated as opaque by the
derived sexp converters and is skipped by the derived [compare]. *)
type identifier =
{name: string; id_loc: (Location_span.t[@sexp.opaque] [@compare.ignore])}
[@@deriving sexp, hash, compare]
(** Indices for array access, parametric in the expression type ['e].
[All] carries no expression; [Single], [Upfrom] and [Downfrom] carry one;
[Between] carries two. *)
type 'e index =
| All
| Single of 'e
| Upfrom of 'e
| Downfrom of 'e
| Between of 'e * 'e
[@@deriving sexp, hash, compare, map, fold]
(** Front-end function kinds. Both constructors carry a
[bool Fun_kind.suffix]. This is the type substituted for the ['f] parameter
in [typed_expression] and [typed_statement]. *)
type fun_kind =
| StanLib of bool Fun_kind.suffix
| UserDefined of bool Fun_kind.suffix
[@@deriving compare, sexp, hash]
(** Expression shapes (used for both typed and untyped expressions, where we
substitute untyped_expression or typed_expression for 'e).
The second parameter 'f is the payload stored on [FunApp] and [CondDistApp]:
it is [unit] in [untyped_expression] and [fun_kind] in [typed_expression]. *)
type ('e, 'f) expression =
| TernaryIf of 'e * 'e * 'e
| BinOp of 'e * Operator.t * 'e
| PrefixOp of Operator.t * 'e
| PostfixOp of 'e * Operator.t
| Variable of identifier
(* Numeric literals are stored as strings *)
| IntNumeral of string
| RealNumeral of string
| ImagNumeral of string
| FunApp of 'f * identifier * 'e list
| CondDistApp of 'f * identifier * 'e list
| Promotion of 'e * UnsizedType.t * UnsizedType.autodifftype
(* GetLP is deprecated *)
| GetLP
| GetTarget
| ArrayExpr of 'e list
| RowVectorExpr of 'e list
| Paren of 'e
| Indexed of 'e * 'e index list
[@@deriving sexp, hash, compare, map, fold]
(** An expression node: the shape [expr], whose sub-expressions are again
[expr_with] nodes, paired with meta-data [emeta] of type ['m]. *)
type ('m, 'f) expr_with = {expr: (('m, 'f) expr_with, 'f) expression; emeta: 'm}
[@@deriving sexp, compare, map, hash, fold]
(** Meta-data of untyped expressions: only a location span. Per the
attributes, [loc] is opaque to the derived sexp converters and ignored by
the derived [compare]. *)
type located_meta = {loc: (Location_span.t[@sexp.opaque] [@compare.ignore])}
[@@deriving sexp, compare, map, hash, fold]
(** Untyped expressions: nodes carry [located_meta] as meta-data and [unit]
as the function-kind payload. *)
type untyped_expression = (located_meta, unit) expr_with
[@@deriving sexp, compare, map, hash, fold]
(** Typed expressions also have meta-data after type checking: a location_span
([loc]), an autodiff level ([ad_level]) and an unsized type ([type_]).
[loc] is opaque to the derived sexp converters and ignored by the derived
[compare]. *)
type typed_expr_meta =
{ loc: (Location_span.t[@sexp.opaque] [@compare.ignore])
; ad_level: UnsizedType.autodifftype
; type_: UnsizedType.t }
[@@deriving sexp, compare, map, hash, fold]
(** Typed expressions: nodes carry [typed_expr_meta] as meta-data and
[fun_kind] as the function-kind payload. *)
type typed_expression = (typed_expr_meta, fun_kind) expr_with
[@@deriving sexp, compare, map, hash, fold]
(** [mk_untyped_expression ~expr ~loc] wraps the shape [expr] in a node whose
    meta-data is a [located_meta] holding [loc]. *)
let mk_untyped_expression ~expr ~loc =
  let emeta : located_meta = {loc} in
  {expr; emeta}
(** [mk_typed_expression ~expr ~loc ~type_ ~ad_level] wraps the shape [expr]
    in a node whose meta-data records the location, type and autodiff level. *)
let mk_typed_expression ~expr ~loc ~type_ ~ad_level =
  let emeta = {loc; ad_level; type_} in
  {expr; emeta}
(** Combine the locations of all expressions in [exprs] by folding
    [Location_span.merge] over them from left to right. Returns
    [Location_span.empty] when [exprs] is empty. *)
let expr_loc_lub exprs =
  let spans = List.map exprs ~f:(fun e -> e.emeta.loc) in
  match spans with
  | [] -> Location_span.empty
  | first :: rest -> List.fold rest ~init:first ~f:Location_span.merge
(** Least upper bound of expression autodiff types: collects the [ad_level]
    of every expression in [exprs] and passes the list to
    [UnsizedType.lub_ad_type]. *)
let expr_ad_lub exprs =
  let ad_levels = List.map exprs ~f:(fun e -> e.emeta.ad_level) in
  UnsizedType.lub_ad_type ad_levels
(** Assignment operators. [OperatorAssign] carries the [Operator.t] that is
combined with the assignment. *)
type assignmentoperator =
| Assign
(* ArrowAssign is deprecated *)
| ArrowAssign
| OperatorAssign of Operator.t
[@@deriving sexp, hash, compare]
(** Truncations, parametric in the expression type ['e]. [NoTruncate] carries
no bound, [TruncateUpFrom] and [TruncateDownFrom] carry one bound expression
and [TruncateBetween] carries two. *)
type 'e truncation =
| NoTruncate
| TruncateUpFrom of 'e
| TruncateDownFrom of 'e
| TruncateBetween of 'e * 'e
[@@deriving sexp, hash, compare, map, fold]
(** Things that can be printed: either a literal string ([PString]) or an
expression ([PExpr]). Used as the argument type of the [Print] and [Reject]
statements. *)
type 'e printable = PString of string | PExpr of 'e
[@@deriving sexp, compare, map, hash, fold]
(** Lvalue shapes: either a plain variable or an indexed lvalue. ['l] is the
type of the nested lvalue and ['e] the type of the index expressions. *)
type ('l, 'e) lvalue =
| LVariable of identifier
| LIndexed of 'l * 'e index list
[@@deriving sexp, hash, compare, map, fold]
(** An lvalue node: the shape [lval], whose nested lvalues are again
[lval_with] nodes, paired with meta-data [lmeta] of type ['m]. *)
type ('e, 'm) lval_with = {lval: (('e, 'm) lval_with, 'e) lvalue; lmeta: 'm}
[@@deriving sexp, hash, compare, map, fold]
(** Untyped lvalues: index expressions are [untyped_expression]s and the
meta-data is [located_meta]. *)
type untyped_lval = (untyped_expression, located_meta) lval_with
[@@deriving sexp, hash, compare, map, fold]
(** Typed lvalues: index expressions are [typed_expression]s and the
meta-data is [typed_expr_meta]. *)
type typed_lval = (typed_expression, typed_expr_meta) lval_with
[@@deriving sexp, hash, compare, map, fold]
(** One entry of a variable declaration (see the [variables] field of
[VarDecl]): the declared identifier and an optional initial value. *)
type 'e variable = {identifier: identifier; initial_value: 'e option}
[@@deriving sexp, hash, compare, map, fold]
(** Statement shapes, where we substitute untyped_expression and untyped_statement
for 'e and 's respectively to get untyped_statement and typed_expression and
typed_statement to get typed_statement.
The remaining parameters are 'l, the lvalue type used on the left-hand side
of [Assignment], and 'f, the function-kind payload stored on [NRFunApp]. *)
type ('e, 's, 'l, 'f) statement =
| Assignment of {assign_lhs: 'l; assign_op: assignmentoperator; assign_rhs: 'e}
| NRFunApp of 'f * identifier * 'e list
| TargetPE of 'e
(* IncrementLogProb is deprecated *)
| IncrementLogProb of 'e
| Tilde of
{ arg: 'e
; distribution: identifier
; args: 'e list
; truncation: 'e truncation }
| Break
| Continue
| Return of 'e
| ReturnVoid
| Print of 'e printable list
| Reject of 'e printable list
| Skip
(* The else branch of IfThenElse is optional *)
| IfThenElse of 'e * 's * 's option
| While of 'e * 's
| For of
{ loop_variable: identifier
; lower_bound: 'e
; upper_bound: 'e
; loop_body: 's }
| ForEach of identifier * 'e * 's
| Profile of string * 's list
| Block of 's list
(* A single VarDecl can declare several variables of the same type *)
| VarDecl of
{ decl_type: 'e SizedType.t
; transformation: 'e Transformation.t
; is_global: bool
; variables: 'e variable list }
| FunDef of
{ returntype: UnsizedType.returntype
; funname: identifier
; arguments:
(Middle.UnsizedType.autodifftype * Middle.UnsizedType.t * identifier)
list
; body: 's }
[@@deriving sexp, hash, compare, map, fold]
(** Statement return types which we will decorate statements with during type
checking: the purpose is to check that function bodies have the correct
return type in every possible execution branch.
- [NoReturnType] corresponds to not having a return statement in it.
- [Incomplete rt] corresponds to having some return statement(s) of type rt
in it, but not one in every branch.
- [Complete rt] corresponds to having a return statement of type rt in every
branch.
- [AnyReturnType] corresponds to statements which have an error in every
branch. *)
type statement_returntype =
| NoReturnType
| Incomplete of Middle.UnsizedType.returntype
| Complete of Middle.UnsizedType.returntype
| AnyReturnType
[@@deriving sexp, hash, compare]
(** A statement node: the shape [stmt], whose sub-statements are again
[statement_with] nodes, paired with meta-data [smeta] of type ['m].
['e], ['l] and ['f] are passed through to [statement] as the expression,
lvalue and function-kind types. *)
type ('e, 'm, 'l, 'f) statement_with =
{stmt: ('e, ('e, 'm, 'l, 'f) statement_with, 'l, 'f) statement; smeta: 'm}
[@@deriving sexp, compare, map, hash, fold]
(** Untyped statements, which have location_spans as meta-data
([located_meta]), [untyped_expression]s as expressions, [untyped_lval]s as
lvalues and [unit] as the function-kind payload. *)
type untyped_statement =
(untyped_expression, located_meta, untyped_lval, unit) statement_with
[@@deriving sexp, compare, map, hash]
(** [mk_untyped_statement ~stmt ~loc] wraps the shape [stmt] in an untyped
    statement node whose meta-data is a [located_meta] holding [loc]. *)
let mk_untyped_statement ~stmt ~loc : untyped_statement =
  let smeta : located_meta = {loc} in
  {stmt; smeta}
(** Meta-data of typed statements: a location span and the statement's
[statement_returntype]. [loc] is opaque to the derived sexp converters and
ignored by the derived [compare]. *)
type stmt_typed_located_meta =
{ loc: (Middle.Location_span.t[@sexp.opaque] [@compare.ignore])
; return_type: statement_returntype }
[@@deriving sexp, compare, map, hash]
(** Typed statements also have meta-data after type checking: a location_span,
as well as a statement returntype to check that function bodies have the
right return type. Expressions are [typed_expression]s, lvalues are
[typed_lval]s and the function-kind payload is [fun_kind]. *)
type typed_statement =
( typed_expression
, stmt_typed_located_meta
, typed_lval
, fun_kind )
statement_with
[@@deriving sexp, compare, map, hash]
(** [mk_typed_statement ~stmt ~loc ~return_type] wraps the shape [stmt] in a
    node whose meta-data records the location and the statement return
    type. *)
let mk_typed_statement ~stmt ~loc ~return_type =
  let smeta = {loc; return_type} in
  {stmt; smeta}
(** Program shapes, where we obtain types of programs if we substitute typed or untyped
statements for 's.
A ['s block] is the list of statements of one program block together with a
location span [xloc]. *)
type 's block = {stmts: 's list; xloc: Middle.Location_span.t [@ignore]}
(* Comments and related items, each tagged with a location *)
and comment_type =
| LineComment of string * Middle.Location_span.t
| Include of string * Middle.Location_span.t
| BlockComment of string list * Middle.Location_span.t
| Separator of Middle.Location.t
(** Separator records the location of items like commas, operators, and keywords
which don't have location information stored in the AST
but are useful for placing comments in pretty printing *)
(* A whole program: every block is optional; [comments] holds the
comment_type items alongside the blocks *)
and 's program =
{ functionblock: 's block option
; datablock: 's block option
; transformeddatablock: 's block option
; parametersblock: 's block option
; transformedparametersblock: 's block option
; modelblock: 's block option
; generatedquantitiesblock: 's block option
; comments: (comment_type list[@sexp.opaque] [@ignore]) }
[@@deriving sexp, hash, compare, map, fold]
(** Statements of an optional block: the block's [stmts] when it is present,
    the empty list otherwise. *)
let get_stmts = function None -> [] | Some blk -> blk.stmts
(** Untyped programs (before type checking): programs whose statements are
[untyped_statement]s. *)
type untyped_program = untyped_statement program [@@deriving sexp, compare, map]
(** Typed programs (after type checking): programs whose statements are
[typed_statement]s. *)
type typed_program = typed_statement program [@@deriving sexp, compare, map]
(*========================== Helper functions ===============================*)
(** Forgetful function from typed to untyped expressions.
    A [Promotion] node is dropped and the expression it wraps is converted
    in its place. For every other shape, sub-expressions are converted
    recursively, function-kind payloads are replaced by [()], and only the
    location is kept from the meta-data. *)
let rec untyped_expression_of_typed_expression (typed : typed_expression) :
    untyped_expression =
  match typed.expr with
  | Promotion (inner, _, _) -> untyped_expression_of_typed_expression inner
  | shape ->
      let forget_fun_kind _ = () in
      let expr =
        map_expression untyped_expression_of_typed_expression forget_fun_kind
          shape
      in
      let emeta : located_meta = {loc= typed.emeta.loc} in
      {expr; emeta}
(** Forgetful function from typed to untyped lvalues: nested lvalues and
    index expressions are converted recursively and only the location is
    kept from the meta-data. *)
let rec untyped_lvalue_of_typed_lvalue (typed : typed_lval) : untyped_lval =
  let lval =
    map_lvalue untyped_lvalue_of_typed_lvalue
      untyped_expression_of_typed_expression typed.lval
  in
  let lmeta : located_meta = {loc= typed.lmeta.loc} in
  {lval; lmeta}
(** Forgetful function from typed to untyped statements: expressions,
    sub-statements and lvalues are converted recursively, function-kind
    payloads are replaced by [()], and only the location is kept from the
    meta-data. *)
let rec untyped_statement_of_typed_statement typed =
  let stmt =
    map_statement untyped_expression_of_typed_expression
      untyped_statement_of_typed_statement untyped_lvalue_of_typed_lvalue
      (fun _ -> ())
      typed.stmt
  in
  let smeta : located_meta = {loc= typed.smeta.loc} in
  {stmt; smeta}
(** Forgetful function from typed to untyped programs: applies
    [untyped_statement_of_typed_statement] to every statement of the
    program via the derived [map_program]. *)
let untyped_program_of_typed_program (prog : typed_program) : untyped_program
    =
  map_program untyped_statement_of_typed_statement prog
(** Convert an lvalue into the corresponding expression: [LVariable]
    becomes [Variable], [LIndexed] becomes [Indexed] (converting the nested
    lvalue recursively), and the lvalue's meta-data is reused unchanged as
    the expression's meta-data. *)
let rec expr_of_lvalue lv =
  let expr =
    match lv.lval with
    | LVariable id -> Variable id
    | LIndexed (base, indices) -> Indexed (expr_of_lvalue base, indices)
  in
  {expr; emeta= lv.lmeta}
(** Convert an expression into the corresponding lvalue: [Variable] becomes
    [LVariable] and [Indexed] becomes [LIndexed] (converting the indexed
    expression recursively); the expression's meta-data is reused unchanged.
    Any other expression shape is reported through
    [Common.FatalError.fatal_error_msg]. *)
let rec lvalue_of_expr e =
  let lval =
    match e.expr with
    | Variable id -> LVariable id
    | Indexed (base, indices) -> LIndexed (lvalue_of_expr base, indices)
    | _ ->
        Common.FatalError.fatal_error_msg
          [%message "Trying to convert illegal expression to lval."]
  in
  {lval; lmeta= e.emeta}
(** Identifier at the root of an lvalue: strips every [LIndexed] layer until
    the underlying [LVariable] is reached and returns its identifier. *)
let rec id_of_lvalue lv =
  match lv.lval with
  | LVariable id -> id
  | LIndexed (base, _) -> id_of_lvalue base
(** Project a list of (autodiff type, unsized type, third component)
    triples onto its first two components, giving an
    [UnsizedType.argumentlist]. *)
let type_of_arguments :
       (UnsizedType.autodifftype * UnsizedType.t * 'a) list
    -> UnsizedType.argumentlist =
 fun arguments ->
  List.map arguments ~f:(fun (ad_level, ty, _) -> (ad_level, ty))
(* XXX: the parser produces inaccurate locations: smeta.loc.begin_loc is the last
token before the current statement and all the whitespace between two statements
appears as if it were part of the second statement.
get_first_loc tries to skip the leading whitespace and approximate the location
of the first token in the statement.
TODO: See if $sloc works better than $loc for this
*)
(** Start location of the expression picked out of a sized type [t], if any.
    [SInt], [SReal] and [SComplex] carry no expression and give [None]; for
    every other constructor the expression bound below is used and the
    [begin_loc] of its location span is returned. *)
let get_loc_dt (t : untyped_expression SizedType.t) =
  let picked =
    match t with
    | SInt | SReal | SComplex -> None
    | SVector (_, e)
     |SRowVector (_, e)
     |SMatrix (_, e, _)
     |SComplexVector e
     |SComplexRowVector e
     |SComplexMatrix (e, _)
     |SArray (_, e) ->
        Some e
  in
  Option.map picked ~f:(fun (e : untyped_expression) ->
      e.emeta.loc.begin_loc )
(** Start location of the first expression argument of a transformation
    [t], if any. Only [Lower], [Upper], [LowerUpper], [Offset],
    [Multiplier] and [OffsetMultiplier] are inspected; every other
    transformation gives [None]. *)
let get_loc_tf (t : untyped_expression Transformation.t) =
  let picked =
    match t with
    | Lower e
     |Upper e
     |LowerUpper (e, _)
     |Offset e
     |Multiplier e
     |OffsetMultiplier (e, _) ->
        Some e
    | _ -> None
  in
  Option.map picked ~f:(fun (e : untyped_expression) ->
      e.emeta.loc.begin_loc )
(** Approximate the location of the first token of the statement [s].
    - For [NRFunApp], [For], [ForEach] and [FunDef], the start of the
      statement's identifier location is used.
    - For [TargetPE], [IncrementLogProb], [Return], [IfThenElse] and
      [While], the start of the leading expression's location is used.
    - For [VarDecl], the candidates are tried in order: [get_loc_dt] on the
      declared type, [get_loc_tf] on the transformation, then the
      identifier of the first declared variable.
    - Every other statement uses the start of its own recorded location.

    A [VarDecl] whose [variables] list is empty also falls back to the
    statement's own recorded location instead of raising. *)
let get_first_loc (s : untyped_statement) =
  match s.stmt with
  | NRFunApp (_, id, _)
   |For {loop_variable= id; _}
   |ForEach (id, _, _)
   |FunDef {funname= id; _} ->
      id.id_loc.begin_loc
  | TargetPE e
   |IncrementLogProb e
   |Return e
   |IfThenElse (e, _, _)
   |While (e, _) ->
      e.emeta.loc.begin_loc
  | Assignment _ | Profile _ | Block _ | Tilde _ | Break | Continue
   |ReturnVoid | Print _ | Reject _ | Skip ->
      s.smeta.loc.begin_loc
  | VarDecl {decl_type; transformation; variables; _} -> (
    match get_loc_dt decl_type with
    | Some loc -> loc
    | None -> (
      match get_loc_tf transformation with
      | Some loc -> loc
      | None -> (
        (* [List.hd] returns an option, so an empty declaration list is
           handled here rather than raising as [List.hd_exn] would. *)
        match List.hd variables with
        | Some first_var -> first_var.identifier.id_loc.begin_loc
        | None -> s.smeta.loc.begin_loc ) ) )