Skip to content

Commit 624c071

Browse files
ZuseZ4bytesnake
andcommittedOct 11, 2024
Single commit implementing the enzyme/autodiff frontend
Co-authored-by: Lorenz Schmidt <[email protected]>
1 parent 52fd998 commit 624c071

File tree

17 files changed

+1384
-1
lines changed

17 files changed

+1384
-1
lines changed
 
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
//! This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
2+
//! we create an [`AutoDiffItem`] which contains the source and target function names. The source
3+
//! is the function to which the autodiff attribute is applied, and the target is the function
4+
//! getting generated by us (with a name given by the user as the first autodiff arg).
5+
6+
use std::fmt::{self, Display, Formatter};
7+
use std::str::FromStr;
8+
9+
use crate::expand::typetree::TypeTree;
10+
use crate::expand::{Decodable, Encodable, HashStable_Generic};
11+
use crate::ptr::P;
12+
use crate::{Ty, TyKind};
13+
14+
/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
15+
/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
16+
/// are a hack to support higher order derivatives. We need to compute first order derivatives
17+
/// before we compute second order derivatives, otherwise we would differentiate our placeholder
18+
/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
19+
/// as it's already done in the C++ and Julia frontend of Enzyme.
20+
///
21+
/// (FIXME) remove *First variants.
22+
/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
23+
/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
24+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
25+
pub enum DiffMode {
26+
/// No autodiff is applied (used during error handling).
27+
Error,
28+
/// The primal function which we will differentiate.
29+
Source,
30+
/// The target function, to be created using forward mode AD.
31+
Forward,
32+
/// The target function, to be created using reverse mode AD.
33+
Reverse,
34+
/// The target function, to be created using forward mode AD.
35+
/// This target function will also be used as a source for higher order derivatives,
36+
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
37+
ForwardFirst,
38+
/// The target function, to be created using reverse mode AD.
39+
/// This target function will also be used as a source for higher order derivatives,
40+
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
41+
ReverseFirst,
42+
}
43+
44+
/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
45+
/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
46+
/// we add to the previous shadow value. To not surprise users, we picked different names.
47+
/// Dual numbers is also a quite well known name for forward mode AD types.
48+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
49+
pub enum DiffActivity {
50+
/// Implicit or Explicit () return type, so a special case of Const.
51+
None,
52+
/// Don't compute derivatives with respect to this input/output.
53+
Const,
54+
/// Reverse Mode, Compute derivatives for this scalar input/output.
55+
Active,
56+
/// Reverse Mode, Compute derivatives for this scalar output, but don't compute
57+
/// the original return value.
58+
ActiveOnly,
59+
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
60+
/// with it.
61+
Dual,
62+
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
63+
/// with it. Drop the code which updates the original input/output for maximum performance.
64+
DualOnly,
65+
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
66+
Duplicated,
67+
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
68+
/// Drop the code which updates the original input for maximum performance.
69+
DuplicatedOnly,
70+
/// All Integers must be Const, but these are used to mark the integer which represents the
71+
/// length of a slice/vec. This is used for safety checks on slices.
72+
FakeActivitySize,
73+
}
74+
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
75+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
76+
pub struct AutoDiffItem {
77+
/// The name of the function getting differentiated
78+
pub source: String,
79+
/// The name of the function being generated
80+
pub target: String,
81+
pub attrs: AutoDiffAttrs,
82+
/// Describe the memory layout of input types
83+
pub inputs: Vec<TypeTree>,
84+
/// Describe the memory layout of the output type
85+
pub output: TypeTree,
86+
}
87+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
88+
pub struct AutoDiffAttrs {
89+
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
90+
/// e.g. in the [JAX
91+
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
92+
pub mode: DiffMode,
93+
pub ret_activity: DiffActivity,
94+
pub input_activity: Vec<DiffActivity>,
95+
}
96+
97+
impl DiffMode {
98+
pub fn is_rev(&self) -> bool {
99+
matches!(self, DiffMode::Reverse | DiffMode::ReverseFirst)
100+
}
101+
pub fn is_fwd(&self) -> bool {
102+
matches!(self, DiffMode::Forward | DiffMode::ForwardFirst)
103+
}
104+
}
105+
106+
impl Display for DiffMode {
107+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
108+
match self {
109+
DiffMode::Error => write!(f, "Error"),
110+
DiffMode::Source => write!(f, "Source"),
111+
DiffMode::Forward => write!(f, "Forward"),
112+
DiffMode::Reverse => write!(f, "Reverse"),
113+
DiffMode::ForwardFirst => write!(f, "ForwardFirst"),
114+
DiffMode::ReverseFirst => write!(f, "ReverseFirst"),
115+
}
116+
}
117+
}
118+
119+
/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
120+
/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
121+
/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
122+
/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
123+
/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
124+
pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
125+
if activity == DiffActivity::None {
126+
// Only valid if primal returns (), but we can't check that here.
127+
return true;
128+
}
129+
match mode {
130+
DiffMode::Error => false,
131+
DiffMode::Source => false,
132+
DiffMode::Forward | DiffMode::ForwardFirst => {
133+
activity == DiffActivity::Dual
134+
|| activity == DiffActivity::DualOnly
135+
|| activity == DiffActivity::Const
136+
}
137+
DiffMode::Reverse | DiffMode::ReverseFirst => {
138+
activity == DiffActivity::Const
139+
|| activity == DiffActivity::Active
140+
|| activity == DiffActivity::ActiveOnly
141+
}
142+
}
143+
}
144+
145+
/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
146+
/// for the given argument, but we generally can't know the size of such a type.
147+
/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
148+
/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
149+
/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
150+
/// users here from marking scalars as Duplicated, due to type aliases.
151+
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
152+
use DiffActivity::*;
153+
// It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
154+
if matches!(activity, Const) {
155+
return true;
156+
}
157+
if matches!(activity, Dual | DualOnly) {
158+
return true;
159+
}
160+
// FIXME(ZuseZ4) We should make this more robust to also
161+
// handle type aliases. Once that is done, we can be more restrictive here.
162+
if matches!(activity, Active | ActiveOnly) {
163+
return true;
164+
}
165+
matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
166+
&& matches!(activity, Duplicated | DuplicatedOnly)
167+
}
168+
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
169+
use DiffActivity::*;
170+
return match mode {
171+
DiffMode::Error => false,
172+
DiffMode::Source => false,
173+
DiffMode::Forward | DiffMode::ForwardFirst => {
174+
matches!(activity, Dual | DualOnly | Const)
175+
}
176+
DiffMode::Reverse | DiffMode::ReverseFirst => {
177+
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
178+
}
179+
};
180+
}
181+
182+
impl Display for DiffActivity {
183+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
184+
match self {
185+
DiffActivity::None => write!(f, "None"),
186+
DiffActivity::Const => write!(f, "Const"),
187+
DiffActivity::Active => write!(f, "Active"),
188+
DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
189+
DiffActivity::Dual => write!(f, "Dual"),
190+
DiffActivity::DualOnly => write!(f, "DualOnly"),
191+
DiffActivity::Duplicated => write!(f, "Duplicated"),
192+
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
193+
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
194+
}
195+
}
196+
}
197+
198+
impl FromStr for DiffMode {
199+
type Err = ();
200+
201+
fn from_str(s: &str) -> Result<DiffMode, ()> {
202+
match s {
203+
"Error" => Ok(DiffMode::Error),
204+
"Source" => Ok(DiffMode::Source),
205+
"Forward" => Ok(DiffMode::Forward),
206+
"Reverse" => Ok(DiffMode::Reverse),
207+
"ForwardFirst" => Ok(DiffMode::ForwardFirst),
208+
"ReverseFirst" => Ok(DiffMode::ReverseFirst),
209+
_ => Err(()),
210+
}
211+
}
212+
}
213+
impl FromStr for DiffActivity {
214+
type Err = ();
215+
216+
fn from_str(s: &str) -> Result<DiffActivity, ()> {
217+
match s {
218+
"None" => Ok(DiffActivity::None),
219+
"Active" => Ok(DiffActivity::Active),
220+
"ActiveOnly" => Ok(DiffActivity::ActiveOnly),
221+
"Const" => Ok(DiffActivity::Const),
222+
"Dual" => Ok(DiffActivity::Dual),
223+
"DualOnly" => Ok(DiffActivity::DualOnly),
224+
"Duplicated" => Ok(DiffActivity::Duplicated),
225+
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
226+
_ => Err(()),
227+
}
228+
}
229+
}
230+
231+
impl AutoDiffAttrs {
232+
pub fn has_ret_activity(&self) -> bool {
233+
self.ret_activity != DiffActivity::None
234+
}
235+
pub fn has_active_only_ret(&self) -> bool {
236+
self.ret_activity == DiffActivity::ActiveOnly
237+
}
238+
239+
pub fn error() -> Self {
240+
AutoDiffAttrs {
241+
mode: DiffMode::Error,
242+
ret_activity: DiffActivity::None,
243+
input_activity: Vec::new(),
244+
}
245+
}
246+
pub fn source() -> Self {
247+
AutoDiffAttrs {
248+
mode: DiffMode::Source,
249+
ret_activity: DiffActivity::None,
250+
input_activity: Vec::new(),
251+
}
252+
}
253+
254+
pub fn is_active(&self) -> bool {
255+
self.mode != DiffMode::Error
256+
}
257+
258+
pub fn is_source(&self) -> bool {
259+
self.mode == DiffMode::Source
260+
}
261+
pub fn apply_autodiff(&self) -> bool {
262+
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
263+
}
264+
265+
pub fn into_item(
266+
self,
267+
source: String,
268+
target: String,
269+
inputs: Vec<TypeTree>,
270+
output: TypeTree,
271+
) -> AutoDiffItem {
272+
AutoDiffItem { source, target, inputs, output, attrs: self }
273+
}
274+
}
275+
276+
impl fmt::Display for AutoDiffItem {
277+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278+
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
279+
write!(f, " with attributes: {:?}", self.attrs)?;
280+
write!(f, " with inputs: {:?}", self.inputs)?;
281+
write!(f, " with output: {:?}", self.output)
282+
}
283+
}

‎compiler/rustc_ast/src/expand/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ use rustc_span::symbol::Ident;
77
use crate::MetaItem;
88

99
pub mod allocator;
10+
pub mod autodiff_attrs;
11+
pub mod typetree;
1012

1113
#[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]
1214
pub struct StrippedCfgItem<ModId = DefId> {
+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
//! This module contains the definition of the `TypeTree` and `Type` structs.
2+
//! They are thin Rust wrappers around the TypeTrees used by Enzyme as the LLVM based autodiff
3+
//! backend. The Enzyme TypeTrees currently have various limitations and should be rewritten, so the
4+
//! Rust frontend obviously has the same limitations. The main motivation of TypeTrees is to
5+
//! represent how a type looks like "in memory". Enzyme can deduce this based on usage patterns in
6+
//! the user code, but this is extremely slow and not even always sufficient. As such we lower some
7+
//! information from rustc to help Enzyme. For a full explanation of their design it is necessary to
8+
//! analyze the implementation in Enzyme core itself. As a rough summary, `-1` in Enzyme speech means
9+
//! everywhere. That is `{0:-1: Float}` means at index 0 you have a ptr, if you dereference it it
10+
//! will be floats everywhere. Thus `* f32`. If you have `{-1:int}` it means int's everywhere,
11+
//! e.g. [i32; N]. `{0:-1:-1 float}` then means one pointer at offset 0, if you dereference it there
12+
//! will be only pointers, if you dereference these new pointers they will point to array of floats.
13+
//! Generally, it allows byte-specific descriptions.
14+
//! FIXME: This description might be partly inaccurate and should be extended, along with
15+
//! adding documentation to the corresponding Enzyme core code.
16+
//! FIXME: Rewrite the TypeTree logic in Enzyme core to reduce the need for the rustc frontend to
17+
//! provide typetree information.
18+
//! FIXME: We should also re-evaluate where we create TypeTrees from Rust types, since MIR
19+
//! representations of some types might not be accurate. For example a vector of floats might be
20+
//! represented as a vector of u8s in MIR in some cases.
21+
22+
use std::fmt;
23+
24+
use crate::expand::{Decodable, Encodable, HashStable_Generic};
25+
26+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
27+
pub enum Kind {
28+
Anything,
29+
Integer,
30+
Pointer,
31+
Half,
32+
Float,
33+
Double,
34+
Unknown,
35+
}
36+
37+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
38+
pub struct TypeTree(pub Vec<Type>);
39+
40+
impl TypeTree {
41+
pub fn new() -> Self {
42+
Self(Vec::new())
43+
}
44+
pub fn all_ints() -> Self {
45+
Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }])
46+
}
47+
pub fn int(size: usize) -> Self {
48+
let mut ints = Vec::with_capacity(size);
49+
for i in 0..size {
50+
ints.push(Type {
51+
offset: i as isize,
52+
size: 1,
53+
kind: Kind::Integer,
54+
child: TypeTree::new(),
55+
});
56+
}
57+
Self(ints)
58+
}
59+
}
60+
61+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
62+
pub struct FncTree {
63+
pub args: Vec<TypeTree>,
64+
pub ret: TypeTree,
65+
}
66+
67+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
68+
pub struct Type {
69+
pub offset: isize,
70+
pub size: usize,
71+
pub kind: Kind,
72+
pub child: TypeTree,
73+
}
74+
75+
impl Type {
76+
pub fn add_offset(self, add: isize) -> Self {
77+
let offset = match self.offset {
78+
-1 => add,
79+
x => add + x,
80+
};
81+
82+
Self { size: self.size, kind: self.kind, child: self.child, offset }
83+
}
84+
}
85+
86+
impl fmt::Display for Type {
87+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88+
<Self as fmt::Debug>::fmt(self, f)
89+
}
90+
}

‎compiler/rustc_builtin_macros/Cargo.toml

+4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ name = "rustc_builtin_macros"
33
version = "0.0.0"
44
edition = "2021"
55

6+
7+
[lints.rust]
8+
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(llvm_enzyme)'] }
9+
610
[lib]
711
doctest = false
812

‎compiler/rustc_builtin_macros/messages.ftl

+9
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,15 @@ builtin_macros_assert_requires_boolean = macro requires a boolean expression as
6969
builtin_macros_assert_requires_expression = macro requires an expression as an argument
7070
.suggestion = try removing semicolon
7171
72+
builtin_macros_autodiff = autodiff must be applied to function
73+
builtin_macros_autodiff_missing_config = autodiff requires at least a name and mode
74+
builtin_macros_autodiff_mode = unknown Mode: `{$mode}`. Use `Forward` or `Reverse`
75+
builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode
76+
builtin_macros_autodiff_not_build = this rustc version does not support autodiff
77+
builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found}
78+
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
79+
80+
builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}`
7281
builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s
7382
.label = not applicable here
7483
.label2 = not a `struct`, `enum` or `union`

‎compiler/rustc_builtin_macros/src/autodiff.rs

+820
Large diffs are not rendered by default.

‎compiler/rustc_builtin_macros/src/errors.rs

+72
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,78 @@ pub(crate) struct AllocMustStatics {
145145
pub(crate) span: Span,
146146
}
147147

148+
#[cfg(llvm_enzyme)]
149+
pub(crate) use autodiff::*;
150+
151+
#[cfg(llvm_enzyme)]
152+
mod autodiff {
153+
use super::*;
154+
#[derive(Diagnostic)]
155+
#[diag(builtin_macros_autodiff_missing_config)]
156+
pub(crate) struct AutoDiffMissingConfig {
157+
#[primary_span]
158+
pub(crate) span: Span,
159+
}
160+
#[derive(Diagnostic)]
161+
#[diag(builtin_macros_autodiff_unknown_activity)]
162+
pub(crate) struct AutoDiffUnknownActivity {
163+
#[primary_span]
164+
pub(crate) span: Span,
165+
pub(crate) act: String,
166+
}
167+
#[derive(Diagnostic)]
168+
#[diag(builtin_macros_autodiff_ty_activity)]
169+
pub(crate) struct AutoDiffInvalidTypeForActivity {
170+
#[primary_span]
171+
pub(crate) span: Span,
172+
pub(crate) act: String,
173+
}
174+
#[derive(Diagnostic)]
175+
#[diag(builtin_macros_autodiff_number_activities)]
176+
pub(crate) struct AutoDiffInvalidNumberActivities {
177+
#[primary_span]
178+
pub(crate) span: Span,
179+
pub(crate) expected: usize,
180+
pub(crate) found: usize,
181+
}
182+
#[derive(Diagnostic)]
183+
#[diag(builtin_macros_autodiff_mode_activity)]
184+
pub(crate) struct AutoDiffInvalidApplicationModeAct {
185+
#[primary_span]
186+
pub(crate) span: Span,
187+
pub(crate) mode: String,
188+
pub(crate) act: String,
189+
}
190+
191+
#[derive(Diagnostic)]
192+
#[diag(builtin_macros_autodiff_mode)]
193+
pub(crate) struct AutoDiffInvalidMode {
194+
#[primary_span]
195+
pub(crate) span: Span,
196+
pub(crate) mode: String,
197+
}
198+
199+
#[derive(Diagnostic)]
200+
#[diag(builtin_macros_autodiff)]
201+
pub(crate) struct AutoDiffInvalidApplication {
202+
#[primary_span]
203+
pub(crate) span: Span,
204+
}
205+
}
206+
207+
#[cfg(not(llvm_enzyme))]
208+
pub(crate) use ad_fallback::*;
209+
#[cfg(not(llvm_enzyme))]
210+
mod ad_fallback {
211+
use super::*;
212+
#[derive(Diagnostic)]
213+
#[diag(builtin_macros_autodiff_not_build)]
214+
pub(crate) struct AutoDiffSupportNotBuild {
215+
#[primary_span]
216+
pub(crate) span: Span,
217+
}
218+
}
219+
148220
#[derive(Diagnostic)]
149221
#[diag(builtin_macros_concat_bytes_invalid)]
150222
pub(crate) struct ConcatBytesInvalid {

‎compiler/rustc_builtin_macros/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#![allow(internal_features)]
66
#![allow(rustc::diagnostic_outside_of_impl)]
77
#![allow(rustc::untranslatable_diagnostic)]
8+
#![cfg_attr(not(bootstrap), feature(autodiff))]
89
#![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")]
910
#![doc(rust_logo)]
1011
#![feature(assert_matches)]
@@ -29,6 +30,7 @@ use crate::deriving::*;
2930

3031
mod alloc_error_handler;
3132
mod assert;
33+
mod autodiff;
3234
mod cfg;
3335
mod cfg_accessible;
3436
mod cfg_eval;
@@ -106,6 +108,7 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) {
106108

107109
register_attr! {
108110
alloc_error_handler: alloc_error_handler::expand,
111+
autodiff: autodiff::expand,
109112
bench: test::expand_bench,
110113
cfg_accessible: cfg_accessible::Expander,
111114
cfg_eval: cfg_eval::expand,

‎compiler/rustc_expand/src/build.rs

+29
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,10 @@ impl<'a> ExtCtxt<'a> {
220220
self.stmt_local(local, span)
221221
}
222222

223+
pub fn stmt_semi(&self, expr: P<ast::Expr>) -> ast::Stmt {
224+
ast::Stmt { id: ast::DUMMY_NODE_ID, span: expr.span, kind: ast::StmtKind::Semi(expr) }
225+
}
226+
223227
pub fn stmt_local(&self, local: P<ast::Local>, span: Span) -> ast::Stmt {
224228
ast::Stmt { id: ast::DUMMY_NODE_ID, kind: ast::StmtKind::Let(local), span }
225229
}
@@ -287,6 +291,25 @@ impl<'a> ExtCtxt<'a> {
287291
self.expr(sp, ast::ExprKind::Paren(e))
288292
}
289293

294+
pub fn expr_method_call(
295+
&self,
296+
span: Span,
297+
expr: P<ast::Expr>,
298+
ident: Ident,
299+
args: ThinVec<P<ast::Expr>>,
300+
) -> P<ast::Expr> {
301+
let seg = ast::PathSegment::from_ident(ident);
302+
self.expr(
303+
span,
304+
ast::ExprKind::MethodCall(Box::new(ast::MethodCall {
305+
seg,
306+
receiver: expr,
307+
args,
308+
span,
309+
})),
310+
)
311+
}
312+
290313
pub fn expr_call(
291314
&self,
292315
span: Span,
@@ -295,6 +318,12 @@ impl<'a> ExtCtxt<'a> {
295318
) -> P<ast::Expr> {
296319
self.expr(span, ast::ExprKind::Call(expr, args))
297320
}
321+
pub fn expr_loop(&self, sp: Span, block: P<ast::Block>) -> P<ast::Expr> {
322+
self.expr(sp, ast::ExprKind::Loop(block, None, sp))
323+
}
324+
pub fn expr_asm(&self, sp: Span, expr: P<ast::InlineAsm>) -> P<ast::Expr> {
325+
self.expr(sp, ast::ExprKind::InlineAsm(expr))
326+
}
298327
pub fn expr_call_ident(
299328
&self,
300329
span: Span,

‎compiler/rustc_feature/src/builtin_attrs.rs

+5
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,11 @@ pub const BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[
752752
template!(NameValueStr: "transparent|semitransparent|opaque"), ErrorFollowing,
753753
EncodeCrossCrate::Yes, "used internally for testing macro hygiene",
754754
),
755+
rustc_attr!(
756+
rustc_autodiff, Normal,
757+
template!(Word, List: r#""...""#), DuplicatesOk,
758+
EncodeCrossCrate::No, INTERNAL_UNSTABLE
759+
),
755760

756761
// ==========================================================================
757762
// Internal attributes, Diagnostics related:

‎compiler/rustc_passes/messages.ftl

+4
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ passes_attr_crate_level =
4949
passes_attr_only_in_functions =
5050
`{$attr}` attribute can only be used on functions
5151
52+
passes_autodiff_attr =
53+
`#[autodiff]` should be applied to a function
54+
.label = not a function
55+
5256
passes_both_ffi_const_and_pure =
5357
`#[ffi_const]` function cannot be `#[ffi_pure]`
5458

‎compiler/rustc_passes/src/check_attr.rs

+15
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,9 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
243243
self.check_generic_attr(hir_id, attr, target, Target::Fn);
244244
self.check_proc_macro(hir_id, target, ProcMacroKind::Derive)
245245
}
246+
[sym::autodiff, ..] => {
247+
self.check_autodiff(hir_id, attr, span, target)
248+
}
246249
[sym::coroutine, ..] => {
247250
self.check_coroutine(attr, target);
248251
}
@@ -2345,6 +2348,18 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
23452348
self.dcx().emit_err(errors::RustcPubTransparent { span, attr_span });
23462349
}
23472350
}
2351+
2352+
/// Checks if `#[autodiff]` is applied to an item other than a function item.
2353+
fn check_autodiff(&self, _hir_id: HirId, _attr: &Attribute, span: Span, target: Target) {
2354+
debug!("check_autodiff");
2355+
match target {
2356+
Target::Fn => {}
2357+
_ => {
2358+
self.dcx().emit_err(errors::AutoDiffAttr { attr_span: span });
2359+
self.abort.set(true);
2360+
}
2361+
}
2362+
}
23482363
}
23492364

23502365
impl<'tcx> Visitor<'tcx> for CheckAttrVisitor<'tcx> {

‎compiler/rustc_passes/src/errors.rs

+8
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ use crate::lang_items::Duplicate;
2020
#[diag(passes_incorrect_do_not_recommend_location)]
2121
pub(crate) struct IncorrectDoNotRecommendLocation;
2222

23+
#[derive(Diagnostic)]
24+
#[diag(passes_autodiff_attr)]
25+
pub(crate) struct AutoDiffAttr {
26+
#[primary_span]
27+
#[label]
28+
pub attr_span: Span,
29+
}
30+
2331
#[derive(LintDiagnostic)]
2432
#[diag(passes_outer_crate_level_attr)]
2533
pub(crate) struct OuterCrateLevelAttr;

‎compiler/rustc_span/src/symbol.rs

+5
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,8 @@ symbols! {
481481
audit_that,
482482
augmented_assignments,
483483
auto_traits,
484+
autodiff,
485+
autodiff_fallback,
484486
automatically_derived,
485487
avx,
486488
avx512_target_feature,
@@ -544,6 +546,7 @@ symbols! {
544546
cfg_accessible,
545547
cfg_attr,
546548
cfg_attr_multi,
549+
cfg_autodiff_fallback,
547550
cfg_boolean_literals,
548551
cfg_doctest,
549552
cfg_eval,
@@ -998,6 +1001,7 @@ symbols! {
9981001
hashset_iter_ty,
9991002
hexagon_target_feature,
10001003
hidden,
1004+
hint,
10011005
homogeneous_aggregate,
10021006
host,
10031007
html_favicon_url,
@@ -1650,6 +1654,7 @@ symbols! {
16501654
rustc_allow_incoherent_impl,
16511655
rustc_allowed_through_unstable_modules,
16521656
rustc_attrs,
1657+
rustc_autodiff,
16531658
rustc_box,
16541659
rustc_builtin_macro,
16551660
rustc_capture_analysis,

‎library/core/src/lib.rs

+9
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,15 @@ pub mod assert_matches {
278278
pub use crate::macros::{assert_matches, debug_assert_matches};
279279
}
280280

281+
// We don't export this through #[macro_export] for now, to avoid breakage.
282+
#[cfg(not(bootstrap))]
283+
#[unstable(feature = "autodiff", issue = "124509")]
284+
/// Unstable module containing the unstable `autodiff` macro.
285+
pub mod autodiff {
286+
#[unstable(feature = "autodiff", issue = "124509")]
287+
pub use crate::macros::builtin::autodiff;
288+
}
289+
281290
#[unstable(feature = "cfg_match", issue = "115585")]
282291
pub use crate::macros::cfg_match;
283292

‎library/core/src/macros/mod.rs

+18
Original file line numberDiff line numberDiff line change
@@ -1539,6 +1539,24 @@ pub(crate) mod builtin {
15391539
($file:expr $(,)?) => {{ /* compiler built-in */ }};
15401540
}
15411541

1542+
/// Automatic Differentiation macro which allows generating a new function to compute
1543+
/// the derivative of a given function. It may only be applied to a function.
1544+
/// The expected usage syntax is
1545+
/// `#[autodiff(NAME, MODE, INPUT_ACTIVITIES, OUTPUT_ACTIVITY)]`
1546+
/// where:
1547+
/// NAME is a string that represents a valid function name.
1548+
/// MODE is any of Forward, Reverse, ForwardFirst, ReverseFirst.
1549+
/// INPUT_ACTIVITIES consists of one valid activity for each input parameter.
1550+
/// OUTPUT_ACTIVITY must not be set if we implicitely return nothing (or explicitely return
1551+
/// `-> ()`. Otherwise it must be set to one of the allowed activities.
1552+
#[unstable(feature = "autodiff", issue = "124509")]
1553+
#[allow_internal_unstable(rustc_attrs)]
1554+
#[rustc_builtin_macro]
1555+
#[cfg(not(bootstrap))]
1556+
pub macro autodiff($item:item) {
1557+
/* compiler built-in */
1558+
}
1559+
15421560
/// Asserts that a boolean expression is `true` at runtime.
15431561
///
15441562
/// This will invoke the [`panic!`] macro if the provided expression cannot be

‎library/std/src/lib.rs

+8-1
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@
267267
#![allow(unused_features)]
268268
//
269269
// Features:
270+
#![cfg_attr(not(bootstrap), feature(autodiff))]
270271
#![cfg_attr(test, feature(internal_output_capture, print_internals, update_panic_count, rt))]
271272
#![cfg_attr(
272273
all(target_vendor = "fortanix", target_env = "sgx"),
@@ -627,7 +628,13 @@ pub mod simd {
627628
#[doc(inline)]
628629
pub use crate::std_float::StdFloat;
629630
}
630-
631+
#[cfg(not(bootstrap))]
632+
#[unstable(feature = "autodiff", issue = "124509")]
633+
/// This module provides support for automatic differentiation.
634+
pub mod autodiff {
635+
/// This macro handles automatic differentiation.
636+
pub use core::autodiff::autodiff;
637+
}
631638
#[stable(feature = "futures_api", since = "1.36.0")]
632639
pub mod task {
633640
//! Types and Traits for working with asynchronous tasks.

0 commit comments

Comments
 (0)
Please sign in to comment.