Skip to content

Commit 4eeb02a

Browse files
committed
feat: add a trait for call-like container enums
1 parent 7f2d8c4 commit 4eeb02a

File tree

7 files changed

+244
-44
lines changed

7 files changed

+244
-44
lines changed

crates/sol-macro/src/expand/contract.rs

+52-41
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ pub(super) fn expand(cx: &ExpCtxt<'_>, contract: &ItemContract) -> Result<TokenS
4040

4141
let functions_enum = if functions.len() > 1 {
4242
let mut attrs = d_attrs.clone();
43-
let doc_str = format!("Container for all the [`{name}`] function calls.");
43+
let doc_str = format!("Container for all the `{name}` function calls.");
4444
attrs.push(parse_quote!(#[doc = #doc_str]));
4545
Some(expand_functions_enum(cx, name, functions, &attrs))
4646
} else {
@@ -49,7 +49,7 @@ pub(super) fn expand(cx: &ExpCtxt<'_>, contract: &ItemContract) -> Result<TokenS
4949

5050
let errors_enum = if errors.len() > 1 {
5151
let mut attrs = d_attrs;
52-
let doc_str = format!("Container for all the [`{name}`] custom errors.");
52+
let doc_str = format!("Container for all the `{name}` custom errors.");
5353
attrs.push(parse_quote!(#[doc = #doc_str]));
5454
Some(expand_errors_enum(cx, name, errors, &attrs))
5555
} else {
@@ -71,11 +71,11 @@ pub(super) fn expand(cx: &ExpCtxt<'_>, contract: &ItemContract) -> Result<TokenS
7171

7272
fn expand_functions_enum(
7373
cx: &ExpCtxt<'_>,
74-
name: &SolIdent,
74+
contract_name: &SolIdent,
7575
functions: Vec<&ItemFunction>,
7676
attrs: &[Attribute],
7777
) -> TokenStream {
78-
let name = format_ident!("{name}Calls");
78+
let name = format_ident!("{contract_name}Calls");
7979
let variants: Vec<_> = functions
8080
.iter()
8181
.map(|f| cx.function_name_ident(f).0)
@@ -86,24 +86,24 @@ fn expand_functions_enum(
8686
let min_data_len = functions
8787
.iter()
8888
.map(|function| r#type::params_min_data_size(cx, &function.arguments))
89-
.max()
89+
.min()
9090
.unwrap();
9191
let trt = Ident::new("SolCall", Span::call_site());
9292
expand_call_like_enum(name, &variants, &types, min_data_len, trt, attrs)
9393
}
9494

9595
fn expand_errors_enum(
9696
cx: &ExpCtxt<'_>,
97-
name: &SolIdent,
97+
contract_name: &SolIdent,
9898
errors: Vec<&ItemError>,
9999
attrs: &[Attribute],
100100
) -> TokenStream {
101-
let name = format_ident!("{name}Errors");
101+
let name = format_ident!("{contract_name}Errors");
102102
let variants: Vec<_> = errors.iter().map(|error| error.name.0.clone()).collect();
103103
let min_data_len = errors
104104
.iter()
105105
.map(|error| r#type::params_min_data_size(cx, &error.parameters))
106-
.max()
106+
.min()
107107
.unwrap();
108108
let trt = Ident::new("SolError", Span::call_site());
109109
expand_call_like_enum(name, &variants, &variants, min_data_len, trt, attrs)
@@ -120,64 +120,75 @@ fn expand_call_like_enum(
120120
assert_eq!(variants.len(), types.len());
121121
let name_s = name.to_string();
122122
let count = variants.len();
123-
let min_data_len = min_data_len.min(4);
124123
quote! {
125124
#(#attrs)*
126125
pub enum #name {
127126
#(#variants(#types),)*
128127
}
129128

130-
// TODO: Implement these functions using traits?
131129
#[automatically_derived]
132-
impl #name {
133-
/// The number of variants.
134-
pub const COUNT: usize = #count;
135-
136-
// no decode_raw is possible because we need the selector to know which variant to
137-
// decode into
138-
139-
/// ABI-decodes the given data into one of the variants of `self`.
140-
pub fn decode(data: &[u8], validate: bool) -> ::alloy_sol_types::Result<Self> {
141-
if data.len() >= #min_data_len {
142-
// TODO: Replace with `data.split_array_ref` once it's stable
143-
let (selector, data) = data.split_at(4);
144-
let selector: &[u8; 4] =
145-
::core::convert::TryInto::try_into(selector).expect("unreachable");
146-
match *selector {
147-
#(<#types as ::alloy_sol_types::#trt>::SELECTOR => {
148-
return <#types as ::alloy_sol_types::#trt>::decode_raw(data, validate)
149-
.map(Self::#variants)
150-
})*
151-
_ => {}
152-
}
130+
impl ::alloy_sol_types::SolCalls for #name {
131+
const NAME: &'static str = #name_s;
132+
const MIN_DATA_LENGTH: usize = #min_data_len;
133+
const COUNT: usize = #count;
134+
135+
#[inline]
136+
fn selector(&self) -> [u8; 4] {
137+
match self {#(
138+
Self::#variants(_) => <#types as ::alloy_sol_types::#trt>::SELECTOR,
139+
)*}
140+
}
141+
142+
#[inline]
143+
fn type_check(selector: [u8; 4]) -> ::alloy_sol_types::Result<()> {
144+
match selector {
145+
#(<#types as ::alloy_sol_types::#trt>::SELECTOR)|* => Ok(()),
146+
s => ::core::result::Result::Err(::alloy_sol_types::Error::unknown_selector(
147+
Self::NAME,
148+
s,
149+
)),
153150
}
154-
::core::result::Result::Err(::alloy_sol_types::Error::type_check_fail(
155-
data,
156-
#name_s,
157-
))
158151
}
159152

160-
/// ABI-encodes `self` into the given buffer.
161-
pub fn encode_raw(&self, out: &mut Vec<u8>) {
153+
#[inline]
154+
fn decode_raw(
155+
selector: [u8; 4],
156+
data: &[u8],
157+
validate: bool
158+
)-> ::alloy_sol_types::Result<Self> {
159+
match selector {
160+
#(<#types as ::alloy_sol_types::#trt>::SELECTOR => {
161+
<#types as ::alloy_sol_types::#trt>::decode_raw(data, validate)
162+
.map(Self::#variants)
163+
})*
164+
s => ::core::result::Result::Err(::alloy_sol_types::Error::unknown_selector(
165+
Self::NAME,
166+
s,
167+
)),
168+
}
169+
}
170+
171+
#[inline]
172+
fn encoded_size(&self) -> usize {
162173
match self {#(
163174
Self::#variants(inner) =>
164-
<#types as ::alloy_sol_types::#trt>::encode_raw(inner, out),
175+
<#types as ::alloy_sol_types::#trt>::encoded_size(inner),
165176
)*}
166177
}
167178

168-
/// ABI-encodes `self` into the given buffer.
169179
#[inline]
170-
pub fn encode(&self) -> Vec<u8> {
180+
fn encode_raw(&self, out: &mut Vec<u8>) {
171181
match self {#(
172182
Self::#variants(inner) =>
173-
<#types as ::alloy_sol_types::#trt>::encode(inner),
183+
<#types as ::alloy_sol_types::#trt>::encode_raw(inner, out),
174184
)*}
175185
}
176186
}
177187

178188
#(
179189
#[automatically_derived]
180190
impl From<#types> for #name {
191+
#[inline]
181192
fn from(value: #types) -> Self {
182193
Self::#variants(value)
183194
}

crates/sol-types/src/errors.rs

+20
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ pub enum Error {
4040
max: u8,
4141
},
4242

43+
/// Unknown selector.
44+
UnknownSelector {
45+
/// The type name.
46+
name: &'static str,
47+
/// The unknown selector.
48+
selector: alloy_primitives::FixedBytes<4>,
49+
},
50+
4351
/// Hex error.
4452
FromHexError(hex::FromHexError),
4553

@@ -73,6 +81,9 @@ impl fmt::Display for Error {
7381
f,
7482
"`{value}` is not a valid {name} enum value (max: `{max}`)"
7583
),
84+
Self::UnknownSelector { name, selector } => {
85+
write!(f, "Unknown selector `{selector}` for {name}")
86+
}
7687
Self::FromHexError(e) => e.fmt(f),
7788
Self::Other(e) => f.write_str(e),
7889
}
@@ -104,6 +115,15 @@ impl Error {
104115
data: hex::encode(data),
105116
}
106117
}
118+
119+
/// Instantiates a [`Error::UnknownSelector`] with the provided data.
120+
#[inline]
121+
pub fn unknown_selector(name: &'static str, selector: [u8; 4]) -> Self {
122+
Self::UnknownSelector {
123+
name,
124+
selector: selector.into(),
125+
}
126+
}
107127
}
108128

109129
impl From<hex::FromHexError> for Error {

crates/sol-types/src/lib.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,8 @@ mod impl_core;
181181

182182
mod types;
183183
pub use types::{
184-
data_type as sol_data, Encodable, EventTopic, Panic, PanicKind, Revert, SolCall, SolError,
185-
SolEvent, SolStruct, SolType, TopicList,
184+
data_type as sol_data, ContractError, Encodable, EventTopic, Panic, PanicKind, Revert, SolCall,
185+
SolCalls, SolEnum, SolError, SolEvent, SolStruct, SolType, TopicList,
186186
};
187187

188188
mod util;

crates/sol-types/src/types/calls.rs

+130
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
use crate::{Panic, Result, Revert, SolError};
2+
use alloc::vec::Vec;
3+
4+
/// A collection of ABI-encoded call-like types. This currently includes
5+
/// [`SolCall`] and [`SolError`].
6+
///
7+
/// [`SolCall`]: crate::SolCall
8+
/// [`SolError`]: crate::SolError
9+
///
10+
/// ### Implementer's Guide
11+
///
12+
/// We do not recommend implementing this trait directly. Instead, we recommend
13+
/// using the [`sol`][crate::sol] proc macro to parse a Solidity contract
14+
/// definition.
15+
pub trait SolCalls: Sized {
16+
/// The name of this type.
17+
const NAME: &'static str;
18+
19+
/// The minimum length of the data for this type.
20+
///
21+
/// This does *not* include the selector's length (4).
22+
const MIN_DATA_LENGTH: usize;
23+
24+
/// The number of variants.
25+
const COUNT: usize;
26+
27+
/// The selector of this type.
28+
fn selector(&self) -> [u8; 4];
29+
30+
/// Checks if the given selector is known to this type.
31+
fn type_check(selector: [u8; 4]) -> Result<()>;
32+
33+
/// ABI-decodes the given data into one of the variants of `self`.
34+
fn decode_raw(selector: [u8; 4], data: &[u8], validate: bool) -> Result<Self>;
35+
36+
/// The size of the encoded data, *without* any selectors.
37+
fn encoded_size(&self) -> usize;
38+
39+
/// ABI-encodes `self` into the given buffer, *without* any selectors.
40+
fn encode_raw(&self, out: &mut Vec<u8>);
41+
42+
/// ABI-encodes `self` into the given buffer.
43+
fn encode(&self) -> Vec<u8> {
44+
let mut out = Vec::with_capacity(4 + self.encoded_size());
45+
out.extend(self.selector());
46+
self.encode_raw(&mut out);
47+
out
48+
}
49+
50+
/// ABI-decodes the given data into one of the variants of `self`.
51+
#[inline]
52+
fn decode(data: &[u8], validate: bool) -> Result<Self> {
53+
if data.len() < Self::MIN_DATA_LENGTH + 4 {
54+
Err(crate::Error::type_check_fail(data, Self::NAME))
55+
} else {
56+
let (selector, data) = crate::impl_core::split_array_ref(data);
57+
Self::decode_raw(*selector, data, validate)
58+
}
59+
}
60+
}
61+
62+
/// A generic contract error.
63+
///
64+
/// Contains a [`Revert`] or [`Panic`] error, or a custom error.
65+
#[derive(Clone, Debug, PartialEq, Eq)]
66+
pub enum ContractError<T> {
67+
/// A contract's custom error.
68+
CustomError(T),
69+
/// A generic revert. See [`Revert`] for more information.
70+
Revert(Revert),
71+
/// A panic. See [`Panic`] for more information.
72+
Panic(Panic),
73+
}
74+
75+
impl<T: SolCalls> SolCalls for ContractError<T> {
76+
const NAME: &'static str = "ContractError";
77+
78+
// revert is 64, panic is 32
79+
const MIN_DATA_LENGTH: usize = if T::MIN_DATA_LENGTH < 32 {
80+
T::MIN_DATA_LENGTH
81+
} else {
82+
32
83+
};
84+
85+
const COUNT: usize = T::COUNT + 2;
86+
87+
#[inline]
88+
fn selector(&self) -> [u8; 4] {
89+
match self {
90+
Self::CustomError(error) => error.selector(),
91+
Self::Panic(_) => Panic::SELECTOR,
92+
Self::Revert(_) => Revert::SELECTOR,
93+
}
94+
}
95+
96+
#[inline]
97+
fn type_check(selector: [u8; 4]) -> Result<()> {
98+
match selector {
99+
Revert::SELECTOR | Panic::SELECTOR => Ok(()),
100+
s => T::type_check(s),
101+
}
102+
}
103+
104+
#[inline]
105+
fn decode_raw(selector: [u8; 4], data: &[u8], validate: bool) -> Result<Self> {
106+
match selector {
107+
Revert::SELECTOR => Revert::decode_raw(data, validate).map(Self::Revert),
108+
Panic::SELECTOR => Panic::decode_raw(data, validate).map(Self::Panic),
109+
_ => T::decode(data, validate).map(Self::CustomError),
110+
}
111+
}
112+
113+
#[inline]
114+
fn encoded_size(&self) -> usize {
115+
match self {
116+
Self::CustomError(error) => error.encoded_size(),
117+
Self::Panic(panic) => panic.encoded_size(),
118+
Self::Revert(revert) => revert.encoded_size(),
119+
}
120+
}
121+
122+
#[inline]
123+
fn encode_raw(&self, out: &mut Vec<u8>) {
124+
match self {
125+
Self::CustomError(error) => error.encode_raw(out),
126+
Self::Panic(panic) => panic.encode_raw(out),
127+
Self::Revert(revert) => revert.encode_raw(out),
128+
}
129+
}
130+
}

crates/sol-types/src/types/enum.rs

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
use crate::{token::WordToken, Result, SolType, Word};
2+
use alloc::vec::Vec;
3+
4+
/// Solidity enum. This is always a wrapper around a [`u8`].
5+
///
6+
/// ### Implementer's Guide
7+
///
8+
/// We do not recommend implementing this trait directly. Instead, we recommend
9+
/// using the [`sol`][crate::sol] proc macro to parse a Solidity error
10+
/// definition.
11+
pub trait SolEnum: Sized + Copy + Into<u8> + TryFrom<u8, Error = crate::Error> {
12+
/// Tokenize the enum.
13+
#[inline]
14+
fn tokenize(self) -> WordToken {
15+
WordToken(Word::with_last_byte(self.into()))
16+
}
17+
18+
/// ABI decode the enum from the given buffer.
19+
#[inline]
20+
fn decode(data: &[u8], validate: bool) -> Result<Self> {
21+
<crate::sol_data::Uint<8> as SolType>::decode_single(data, validate)
22+
.and_then(Self::try_from)
23+
}
24+
25+
/// ABI encode the enum into the given buffer.
26+
#[inline]
27+
fn encode_raw(self, out: &mut Vec<u8>) {
28+
out.extend(self.tokenize().0)
29+
}
30+
31+
/// ABI encode the enum.
32+
#[inline]
33+
fn encode(self) -> Vec<u8> {
34+
self.tokenize().0.to_vec()
35+
}
36+
}

crates/sol-types/src/types/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
mod calls;
2+
pub use calls::{ContractError, SolCalls};
3+
14
pub mod data_type;
25

36
mod r#enum;

0 commit comments

Comments
 (0)