Skip to content

Add PrintTAFn flag for targeted type analysis printing #142809

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion compiler/rustc_codegen_llvm/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ fn thin_lto(
}

fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
for &val in ad {
for val in ad {
// We intentionally don't use a wildcard, to not forget handling anything new.
match val {
config::AutoDiff::PrintPerf => {
Expand All @@ -599,6 +599,9 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
config::AutoDiff::PrintTA => {
llvm::set_print_type(true);
}
config::AutoDiff::PrintTAFn(fun) => {
llvm::set_print_type_fun(&fun);
}
config::AutoDiff::Inline => {
llvm::set_inline(true);
}
Expand Down
16 changes: 16 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,17 @@ pub(crate) use self::Enzyme_AD::*;
#[cfg(llvm_enzyme)]
pub(crate) mod Enzyme_AD {
use libc::c_void;
use std::ffi::{CString, c_char};

unsafe extern "C" {
pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
}
unsafe extern "C" {
static mut EnzymePrintPerf: c_void;
static mut EnzymePrintActivity: c_void;
static mut EnzymePrintType: c_void;
static mut FunctionToAnalyze: c_void;
static mut EnzymePrint: c_void;
static mut EnzymeStrictAliasing: c_void;
static mut looseTypeAnalysis: c_void;
Expand All @@ -86,6 +90,15 @@ pub(crate) mod Enzyme_AD {
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8);
}
}
pub(crate) fn set_print_type_fun(fun_name: &str) {
let c_fun_name = CString::new(fun_name).unwrap();
unsafe {
EnzymeSetCLString(
std::ptr::addr_of_mut!(FunctionToAnalyze),
c_fun_name.as_ptr() as *const c_char,
);
}
}
pub(crate) fn set_print(print: bool) {
unsafe {
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8);
Expand Down Expand Up @@ -132,6 +145,9 @@ pub(crate) mod Fallback_AD {
pub(crate) fn set_print_type(print: bool) {
unimplemented!()
}
pub(crate) fn set_print_type_fun(fun_name: &str) {
unimplemented!()
}
pub(crate) fn set_print(print: bool) {
unimplemented!()
}
Expand Down
4 changes: 3 additions & 1 deletion compiler/rustc_session/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,15 @@ pub enum CoverageLevel {
}

/// The different settings that the `-Z autodiff` flag can have.
#[derive(Clone, Copy, PartialEq, Hash, Debug)]
#[derive(Clone, PartialEq, Hash, Debug)]
pub enum AutoDiff {
/// Enable the autodiff opt pipeline
Enable,

/// Print TypeAnalysis information
PrintTA,
/// Print TypeAnalysis information for a specific function
PrintTAFn(String),
/// Print ActivityAnalysis Information
PrintAA,
/// Print Performance Warnings from Enzyme
Expand Down
22 changes: 19 additions & 3 deletions compiler/rustc_session/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ mod desc {
pub(crate) const parse_list: &str = "a space-separated list of strings";
pub(crate) const parse_list_with_polarity: &str =
"a comma-separated list of strings, with elements beginning with + or -";
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
pub(crate) const parse_number: &str = "a number";
Expand Down Expand Up @@ -1351,9 +1351,25 @@ pub mod parse {
let mut v: Vec<&str> = v.split(",").collect();
v.sort_unstable();
for &val in v.iter() {
let variant = match val {
// Split each entry on '=' if it has an argument
let (key, arg) = match val.split_once('=') {
Some((k, a)) => (k, Some(a)),
None => (val, None),
};

let variant = match key {
"Enable" => AutoDiff::Enable,
"PrintTA" => AutoDiff::PrintTA,
"PrintTAFn" => {
if let Some(fun) = arg {
AutoDiff::PrintTAFn(fun.to_string())
} else {
eprintln!(
"Missing argument for PrintTAFn (expected PrintTAFn=<function_name>)"
);
return false;
}
}
"PrintAA" => AutoDiff::PrintAA,
"PrintPerf" => AutoDiff::PrintPerf,
"PrintSteps" => AutoDiff::PrintSteps,
Expand All @@ -1365,7 +1381,7 @@ pub mod parse {
"LooseTypes" => AutoDiff::LooseTypes,
"Inline" => AutoDiff::Inline,
_ => {
// FIXME(ZuseZ4): print an error saying which value is not recognized
eprintln!("Unknown autodiff option: {key}");
return false;
}
};
Expand Down
1 change: 1 addition & 0 deletions src/doc/rustc-dev-guide/src/autodiff/flags.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ To support you while debugging or profiling, we have added support for an experi

```text
PrintTA // Print TypeAnalysis information
PrintTAFn // Print TypeAnalysis information for a specific function
PrintAA // Print ActivityAnalysis information
Print // Print differentiated functions while they are being generated and optimized
PrintPerf // Print AD related Performance warnings
Expand Down
1 change: 1 addition & 0 deletions src/doc/unstable-book/src/compiler-flags/autodiff.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Multiple options can be separated with a comma. Valid options are:

`Enable` - Required flag to enable autodiff
`PrintTA` - print Type Analysis Information
`PrintTAFn` - print Type Analysis Information for a specific function
`PrintAA` - print Activity Analysis Information
`PrintPerf` - print Performance Warnings from Enzyme
`PrintSteps` - prints all intermediate transformations
Expand Down
Loading