diff --git a/src/bio.c b/src/bio.c index 9129e2f0b..cdf3e046f 100644 --- a/src/bio.c +++ b/src/bio.c @@ -869,20 +869,34 @@ size_t wolfSSL_BIO_ctrl_pending(WOLFSSL_BIO *bio) long wolfSSL_BIO_get_mem_ptr(WOLFSSL_BIO *bio, WOLFSSL_BUF_MEM **ptr) { + WOLFSSL_BIO* front = bio; + long ret = WOLFSSL_FAILURE; + WOLFSSL_ENTER("wolfSSL_BIO_get_mem_ptr"); if (bio == NULL || ptr == NULL) { return WOLFSSL_FAILURE; } - if (bio->type != WOLFSSL_BIO_MEMORY) { - WOLFSSL_MSG("BIO is not memory buffer type"); - return SSL_FAILURE; + /* start at end and work backwards to find a memory BIO in the BIO chain */ + while ((bio != NULL) && (bio->next != NULL)) { + bio = bio->next; } - *ptr = bio->mem_buf; + while (bio != NULL) { - return SSL_SUCCESS; + if (bio->type == WOLFSSL_BIO_MEMORY) { + *ptr = bio->mem_buf; + ret = WOLFSSL_SUCCESS; + } + + if (bio == front) { + break; + } + bio = bio->prev; + } + + return ret; } WOLFSSL_API long wolfSSL_BIO_int_ctrl(WOLFSSL_BIO *bp, int cmd, long larg, int iarg) diff --git a/tests/api.c b/tests/api.c index ab95d98a4..03fc03cd8 100644 --- a/tests/api.c +++ b/tests/api.c @@ -22974,6 +22974,7 @@ static void test_wolfSSL_BIO_write(void) char msg[] = "conversion test"; char out[40]; char expected[] = "Y29udmVyc2lvbiB0ZXN0AA==\n"; + BUF_MEM* buf = NULL; printf(testingFmt, "wolfSSL_BIO_write()"); @@ -22983,6 +22984,12 @@ static void test_wolfSSL_BIO_write(void) /* now should convert to base64 then write to memory */ AssertIntEQ(BIO_write(bio, msg, sizeof(msg)), 25); BIO_flush(bio); + + /* test BIO chain */ + AssertIntEQ(SSL_SUCCESS, (int)BIO_get_mem_ptr(bio, &buf)); + AssertNotNull(buf); + AssertIntEQ(buf->length, 25); + AssertNotNull(ptr = BIO_find_type(bio, BIO_TYPE_MEM)); sz = sizeof(out); XMEMSET(out, 0, sz); diff --git a/wolfcrypt/src/coding.c b/wolfcrypt/src/coding.c index 94a85a2e1..cee699575 100644 --- a/wolfcrypt/src/coding.c +++ b/wolfcrypt/src/coding.c @@ -130,6 +130,10 @@ int Base64_Decode(const byte* in, word32 inLen, byte* out, word32* outLen) } } } + + if (out && *outLen > i) + out[i]= '\0'; + *outLen = i; return 0; @@ -321,9 +325,14 @@ static int DoBase64_Encode(const byte* in, word32 inLen, byte* out, if (i != outSz && escaped != 1 && ret == 0) return ASN_INPUT_E; + if (out && *outLen > i) + out[i]= '\0'; + *outLen = i; - if(ret == 0) + + if (ret == 0) return getSzOnly ? LENGTH_ONLY_E : 0; + return ret; }