@@ -103,20 +103,43 @@ bool HKDFTraits::DeriveBits(
103
103
EVPKeyCtxPointer ctx =
104
104
EVPKeyCtxPointer (EVP_PKEY_CTX_new_id (EVP_PKEY_HKDF, nullptr ));
105
105
if (!ctx || !EVP_PKEY_derive_init (ctx.get ()) ||
106
- !EVP_PKEY_CTX_hkdf_mode (ctx.get (),
107
- EVP_PKEY_HKDEF_MODE_EXTRACT_AND_EXPAND) ||
108
106
!EVP_PKEY_CTX_set_hkdf_md (ctx.get (), params.digest ) ||
109
- !EVP_PKEY_CTX_set1_hkdf_salt (
110
- ctx.get (), params.salt .data <unsigned char >(), params.salt .size ()) ||
111
- !EVP_PKEY_CTX_set1_hkdf_key (
112
- ctx.get (),
113
- reinterpret_cast <const unsigned char *>(params.key ->GetSymmetricKey ()),
114
- params.key ->GetSymmetricKeySize ()) ||
115
107
!EVP_PKEY_CTX_add1_hkdf_info (
116
108
ctx.get (), params.info .data <unsigned char >(), params.info .size ())) {
117
109
return false ;
118
110
}
119
111
112
+ if (params.key ->GetSymmetricKeySize () != 0 ) {
113
+ if (!EVP_PKEY_CTX_hkdf_mode (ctx.get (),
114
+ EVP_PKEY_HKDEF_MODE_EXTRACT_AND_EXPAND) ||
115
+ !EVP_PKEY_CTX_set1_hkdf_salt (
116
+ ctx.get (), params.salt .data <unsigned char >(), params.salt .size ()) ||
117
+ !EVP_PKEY_CTX_set1_hkdf_key (ctx.get (),
118
+ reinterpret_cast <const unsigned char *>(
119
+ params.key ->GetSymmetricKey ()),
120
+ params.key ->GetSymmetricKeySize ())) {
121
+ return false ;
122
+ }
123
+ } else {
124
+ unsigned int len = EVP_MD_size (params.digest );
125
+ uint8_t tempKey[len]; // NOLINT(runtime/arrays)
126
+ if (params.salt .size ()) {
127
+ HMAC (params.digest ,
128
+ params.salt .data (),
129
+ params.salt .size (),
130
+ nullptr ,
131
+ 0 ,
132
+ tempKey,
133
+ &len);
134
+ } else {
135
+ HMAC (params.digest , new char [len]{}, len, nullptr , 0 , tempKey, &len);
136
+ }
137
+ if (!EVP_PKEY_CTX_hkdf_mode (ctx.get (), EVP_PKEY_HKDEF_MODE_EXPAND_ONLY) ||
138
+ !EVP_PKEY_CTX_set1_hkdf_key (ctx.get (), tempKey, len)) {
139
+ return false ;
140
+ }
141
+ }
142
+
120
143
size_t length = params.length ;
121
144
ByteSource::Builder buf (length);
122
145
if (EVP_PKEY_derive (ctx.get (), buf.data <unsigned char >(), &length) <= 0 )
0 commit comments