diff --git a/wolfcrypt/src/asn.c b/wolfcrypt/src/asn.c index 4874cd1a6..81c5ee5de 100644 --- a/wolfcrypt/src/asn.c +++ b/wolfcrypt/src/asn.c @@ -599,7 +599,7 @@ char* GetSigName(int oid) { } -#if !defined(NO_DSA) || defined(HAVE_ECC) || \ +#if !defined(NO_DSA) || defined(HAVE_ECC) || !defined(NO_CERTS) || \ (!defined(NO_RSA) && \ (defined(WOLFSSL_CERT_GEN) || \ ((defined(WOLFSSL_KEY_GEN) || defined(OPENSSL_EXTRA)) && !defined(HAVE_USER_RSA)))) @@ -8898,43 +8898,39 @@ WOLFSSL_LOCAL int SetMyVersion(word32 version, byte* output, int header) WOLFSSL_LOCAL int SetSerialNumber(const byte* sn, word32 snSz, byte* output, - int maxSnSz) + word32 outputSz, int maxSnSz) { - int i = 0; + int i; int snSzInt = (int)snSz; if (sn == NULL || output == NULL || snSzInt < 0) return BAD_FUNC_ARG; /* remove leading zeros */ - while (snSzInt > 0 && sn[0] == 0) { + while (snSzInt > 1 && sn[0] == 0) { snSzInt--; sn++; } + if (sn[0] & 0x80) + maxSnSz--; /* truncate if input is too long */ if (snSzInt > maxSnSz) snSzInt = maxSnSz; - /* encode ASN Integer, with length and value */ - output[i++] = ASN_INTEGER; - - /* handle MSB, to make sure value is positive */ - if (sn[0] & 0x80) { - /* make room for zero pad */ - if (snSzInt > maxSnSz-1) - snSzInt = maxSnSz-1; - - /* add zero pad */ - i += SetLength(snSzInt+1, &output[i]); - output[i++] = 0x00; - XMEMCPY(&output[i], sn, snSzInt); - } - else { - i += SetLength(snSzInt, &output[i]); - XMEMCPY(&output[i], sn, snSzInt); + i = SetASNInt(snSzInt, sn[0], NULL); + /* truncate if input is too long */ + if ((word32)snSzInt > outputSz - i) + snSzInt = outputSz - i; + /* sanity check number of bytes to copy */ + if (snSzInt <= 0) { + return BUFFER_E; } + /* write out ASN.1 Integer */ + (void)SetASNInt(snSzInt, sn[0], output); + XMEMCPY(output + i, sn, snSzInt); + /* compute final length */ i += snSzInt; @@ -11876,7 +11872,7 @@ static int EncodeCert(Cert* cert, DerCert* der, RsaKey* rsaKey, ecc_key* eccKey, return ret; } der->serialSz = SetSerialNumber(cert->serial, cert->serialSz, der->serial, - CTC_SERIAL_SIZE); + sizeof(der->serial), CTC_SERIAL_SIZE); if (der->serialSz < 0) return der->serialSz; @@ -15369,7 +15365,8 @@ int EncodeOcspRequest(OcspRequest* req, byte* output, word32 size) issuerSz = SetDigest(req->issuerHash, KEYID_SIZE, issuerArray); issuerKeySz = SetDigest(req->issuerKeyHash, KEYID_SIZE, issuerKeyArray); - snSz = SetSerialNumber(req->serial, req->serialSz, snArray, MAX_SN_SZ); + snSz = SetSerialNumber(req->serial, req->serialSz, snArray, + MAX_SN_SZ, MAX_SN_SZ); extSz = 0; if (snSz < 0) diff --git a/wolfcrypt/src/pkcs7.c b/wolfcrypt/src/pkcs7.c index 9cecc4f5d..0b0626f5d 100644 --- a/wolfcrypt/src/pkcs7.c +++ b/wolfcrypt/src/pkcs7.c @@ -2173,7 +2173,7 @@ static int PKCS7_EncodeSigned(PKCS7* pkcs7, ESD* esd, if (pkcs7->sidType == CMS_ISSUER_AND_SERIAL_NUMBER) { /* IssuerAndSerialNumber */ esd->issuerSnSz = SetSerialNumber(pkcs7->issuerSn, pkcs7->issuerSnSz, - esd->issuerSn, MAX_SN_SZ); + esd->issuerSn, MAX_SN_SZ, MAX_SN_SZ); signerInfoSz += esd->issuerSnSz; esd->issuerNameSz = SetSequence(pkcs7->issuerSz, esd->issuerName); signerInfoSz += esd->issuerNameSz + pkcs7->issuerSz; @@ -6128,7 +6128,7 @@ int wc_PKCS7_AddRecipient_KTRI(PKCS7* pkcs7, const byte* cert, word32 certSz, return -1; } snSz = SetSerialNumber(decoded->serial, decoded->serialSz, serial, - MAX_SN_SZ); + MAX_SN_SZ, MAX_SN_SZ); issuerSerialSeqSz = SetSequence(issuerSeqSz + issuerSz + snSz, issuerSerialSeq); diff --git a/wolfssl/wolfcrypt/asn.h b/wolfssl/wolfcrypt/asn.h index e79d87411..1d70fd7e4 100644 --- a/wolfssl/wolfcrypt/asn.h +++ b/wolfssl/wolfcrypt/asn.h @@ -1122,7 +1122,7 @@ WOLFSSL_LOCAL word32 SetSet(word32 len, byte* output); WOLFSSL_LOCAL word32 SetAlgoID(int algoOID,byte* output,int type,int curveSz); WOLFSSL_LOCAL int SetMyVersion(word32 version, byte* output, int header); WOLFSSL_LOCAL int SetSerialNumber(const byte* sn, word32 snSz, byte* output, - int maxSnSz); + word32 outputSz, int maxSnSz); WOLFSSL_LOCAL int GetSerialNumber(const byte* input, word32* inOutIdx, byte* serial, int* serialSz, word32 maxIdx); WOLFSSL_LOCAL int GetNameHash(const byte* source, word32* idx, byte* hash,