This commit is contained in:
Sean Parkinson
2024-08-02 11:58:50 +10:00
parent ebb49b6e68
commit 423c1d3e57

View File

@ -1395,10 +1395,10 @@ static WC_INLINE int wc_chacha_encrypt_256(const word32* input, const byte* m,
/* Odd Round */ /* Odd Round */
QUARTER_ROUND_ODD_4() QUARTER_ROUND_ODD_4()
ODD_SHUFFLE_4() ODD_SHUFFLE_4()
"addi a3, a3, -1\n\t"
/* Even Round */ /* Even Round */
QUARTER_ROUND_EVEN_4() QUARTER_ROUND_EVEN_4()
EVEN_SHUFFLE_4() EVEN_SHUFFLE_4()
"addi a3, a3, -1\n\t"
"bnez a3, L_chacha20_riscv_256_loop\n\t" "bnez a3, L_chacha20_riscv_256_loop\n\t"
/* Load message */ /* Load message */
"mv t2, %[m]\n\t" "mv t2, %[m]\n\t"
@ -1770,13 +1770,13 @@ static WC_INLINE void wc_chacha_encrypt_64(const word32* input, const byte* m,
EIGHT_QUARTER_ROUNDS(REG_V0, REG_V1, REG_V2, REG_V3, REG_V12) EIGHT_QUARTER_ROUNDS(REG_V0, REG_V1, REG_V2, REG_V3, REG_V12)
EIGHT_QUARTER_ROUNDS(REG_V0, REG_V1, REG_V2, REG_V3, REG_V12) EIGHT_QUARTER_ROUNDS(REG_V0, REG_V1, REG_V2, REG_V3, REG_V12)
EIGHT_QUARTER_ROUNDS(REG_V0, REG_V1, REG_V2, REG_V3, REG_V12) EIGHT_QUARTER_ROUNDS(REG_V0, REG_V1, REG_V2, REG_V3, REG_V12)
"addi t1, %[bytes], -64\n\t"
/* Add back state */ /* Add back state */
VADD_VV(REG_V0, REG_V0, REG_V8) VADD_VV(REG_V0, REG_V0, REG_V8)
VADD_VV(REG_V1, REG_V1, REG_V9) VADD_VV(REG_V1, REG_V1, REG_V9)
VADD_VV(REG_V2, REG_V2, REG_V10) VADD_VV(REG_V2, REG_V2, REG_V10)
VADD_VV(REG_V3, REG_V3, REG_V11) VADD_VV(REG_V3, REG_V3, REG_V11)
"addi t2, %[bytes], -64\n\t" "bltz t1, L_chacha20_riscv_64_lt_64\n\t"
"bltz t2, L_chacha20_riscv_64_lt_64\n\t"
"mv t2, %[m]\n\t" "mv t2, %[m]\n\t"
VL4RE32_V(REG_V4, REG_T2) VL4RE32_V(REG_V4, REG_T2)
VXOR_VV(REG_V4, REG_V4, REG_V0) VXOR_VV(REG_V4, REG_V4, REG_V0)
@ -1785,73 +1785,73 @@ static WC_INLINE void wc_chacha_encrypt_64(const word32* input, const byte* m,
VXOR_VV(REG_V7, REG_V7, REG_V3) VXOR_VV(REG_V7, REG_V7, REG_V3)
"mv t2, %[c]\n\t" "mv t2, %[c]\n\t"
VS4R_V(REG_V4, REG_T2) VS4R_V(REG_V4, REG_T2)
"addi %[bytes], %[bytes], -64\n\t"
"addi %[c], %[c], 64\n\t" "addi %[c], %[c], 64\n\t"
"addi %[m], %[m], 64\n\t" "addi %[m], %[m], 64\n\t"
"addi %[bytes], %[bytes], -64\n\t"
VADD_VV(REG_V11, REG_V11, REG_V13) VADD_VV(REG_V11, REG_V11, REG_V13)
"bnez %[bytes], L_chacha20_riscv_64_loop\n\t" "bnez %[bytes], L_chacha20_riscv_64_loop\n\t"
"beqz %[bytes], L_chacha20_riscv_64_done\n\t" "beqz %[bytes], L_chacha20_riscv_64_done\n\t"
"\n" "\n"
"L_chacha20_riscv_64_lt_64:\n\t" "L_chacha20_riscv_64_lt_64:\n\t"
"mv t2, %[over]\n\t" "mv t2, %[over]\n\t"
"addi t1, %[bytes], -32\n\t"
VS4R_V(REG_V0, REG_T2) VS4R_V(REG_V0, REG_T2)
"addi t2, %[bytes], -32\n\t" "bltz t1, L_chacha20_riscv_64_lt_32\n\t"
"bltz t2, L_chacha20_riscv_64_lt_32\n\t"
"mv t2, %[m]\n\t" "mv t2, %[m]\n\t"
VL2RE32_V(REG_V4, REG_T2) VL2RE32_V(REG_V4, REG_T2)
VXOR_VV(REG_V4, REG_V4, REG_V0) VXOR_VV(REG_V4, REG_V4, REG_V0)
VXOR_VV(REG_V5, REG_V5, REG_V1) VXOR_VV(REG_V5, REG_V5, REG_V1)
"mv t2, %[c]\n\t" "mv t2, %[c]\n\t"
VS2R_V(REG_V4, REG_T2) VS2R_V(REG_V4, REG_T2)
"addi %[bytes], %[bytes], -32\n\t"
"addi %[c], %[c], 32\n\t" "addi %[c], %[c], 32\n\t"
"addi %[m], %[m], 32\n\t" "addi %[m], %[m], 32\n\t"
"addi %[bytes], %[bytes], -32\n\t"
"beqz %[bytes], L_chacha20_riscv_64_done\n\t" "beqz %[bytes], L_chacha20_riscv_64_done\n\t"
VMVR_V(REG_V0, REG_V2, 2) VMVR_V(REG_V0, REG_V2, 2)
"\n" "\n"
"L_chacha20_riscv_64_lt_32:\n\t" "L_chacha20_riscv_64_lt_32:\n\t"
"addi t2, %[bytes], -16\n\t" "addi t1, %[bytes], -16\n\t"
"bltz t2, L_chacha20_riscv_64_lt_16\n\t" "bltz t1, L_chacha20_riscv_64_lt_16\n\t"
"mv t2, %[m]\n\t" "mv t2, %[m]\n\t"
VL1RE32_V(REG_V4, REG_T2) VL1RE32_V(REG_V4, REG_T2)
VXOR_VV(REG_V4, REG_V4, REG_V0) VXOR_VV(REG_V4, REG_V4, REG_V0)
"mv t2, %[c]\n\t" "mv t2, %[c]\n\t"
VS1R_V(REG_V4, REG_T2) VS1R_V(REG_V4, REG_T2)
"addi %[bytes], %[bytes], -16\n\t"
"addi %[c], %[c], 16\n\t" "addi %[c], %[c], 16\n\t"
"addi %[m], %[m], 16\n\t" "addi %[m], %[m], 16\n\t"
"addi %[bytes], %[bytes], -16\n\t"
"beqz %[bytes], L_chacha20_riscv_64_done\n\t" "beqz %[bytes], L_chacha20_riscv_64_done\n\t"
VMV_V_V(REG_V0, REG_V1) VMV_V_V(REG_V0, REG_V1)
"\n" "\n"
"L_chacha20_riscv_64_lt_16:\n\t" "L_chacha20_riscv_64_lt_16:\n\t"
"addi t2, %[bytes], -8\n\t" "addi t1, %[bytes], -8\n\t"
"bltz t2, L_chacha20_riscv_64_lt_8\n\t" "bltz t1, L_chacha20_riscv_64_lt_8\n\t"
VSETIVLI(REG_X0, 2, 1, 1, 0b011, 0b000) VSETIVLI(REG_X0, 2, 1, 1, 0b011, 0b000)
VMV_X_S(REG_T0, REG_V0) VMV_X_S(REG_T0, REG_V0)
VSETIVLI(REG_X0, 4, 1, 1, 0b010, 0b000) VSETIVLI(REG_X0, 4, 1, 1, 0b010, 0b000)
"ld t1, (%[m])\n\t" "ld t1, (%[m])\n\t"
"xor t1, t1, t0\n\t" "xor t1, t1, t0\n\t"
"sd t1, (%[c])\n\t" "sd t1, (%[c])\n\t"
"addi %[bytes], %[bytes], -8\n\t"
"addi %[c], %[c], 8\n\t" "addi %[c], %[c], 8\n\t"
"addi %[m], %[m], 8\n\t" "addi %[m], %[m], 8\n\t"
"addi %[bytes], %[bytes], -8\n\t"
"beqz %[bytes], L_chacha20_riscv_64_done\n\t" "beqz %[bytes], L_chacha20_riscv_64_done\n\t"
VSLIDEDOWN_VI(REG_V0, REG_V0, 2) VSLIDEDOWN_VI(REG_V0, REG_V0, 2)
"\n" "\n"
"L_chacha20_riscv_64_lt_8:\n\t" "L_chacha20_riscv_64_lt_8:\n\t"
"addi %[bytes], %[bytes], -1\n\t"
VSETIVLI(REG_X0, 2, 1, 1, 0b011, 0b000) VSETIVLI(REG_X0, 2, 1, 1, 0b011, 0b000)
VMV_X_S(REG_T0, REG_V0) VMV_X_S(REG_T0, REG_V0)
VSETIVLI(REG_X0, 4, 1, 1, 0b010, 0b000) VSETIVLI(REG_X0, 4, 1, 1, 0b010, 0b000)
"addi %[bytes], %[bytes], -1\n\t"
"\n" "\n"
"L_chacha20_riscv_64_loop_lt_8:\n\t" "L_chacha20_riscv_64_loop_lt_8:\n\t"
"addi %[bytes], %[bytes], -1\n\t"
"lb t1, (%[m])\n\t" "lb t1, (%[m])\n\t"
"addi %[m], %[m], 1\n\t" "addi %[m], %[m], 1\n\t"
"xor t1, t1, t0\n\t" "xor t1, t1, t0\n\t"
"sb t1, (%[c])\n\t" "sb t1, (%[c])\n\t"
"addi %[c], %[c], 1\n\t" "addi %[c], %[c], 1\n\t"
"addi %[bytes], %[bytes], -1\n\t"
"srli t0, t0, 8\n\t" "srli t0, t0, 8\n\t"
"bgez %[bytes], L_chacha20_riscv_64_loop_lt_8\n\t" "bgez %[bytes], L_chacha20_riscv_64_loop_lt_8\n\t"
"\n" "\n"
@ -2085,9 +2085,11 @@ static void wc_chacha_encrypt_bytes(ChaCha* ctx, const byte* m, byte* c,
static WC_INLINE void wc_chacha_encrypt(const word32* input, const byte* m, static WC_INLINE void wc_chacha_encrypt(const word32* input, const byte* m,
byte* c, word32 bytes, word32* over) byte* c, word32 bytes, word32* over)
{ {
word64 bytes64 = (word64)bytes;
__asm__ __volatile__ ( __asm__ __volatile__ (
/* Ensure 64-bit bytes has top bits clear. */
"slli %[bytes], %[bytes], 32\n\t"
"srli %[bytes], %[bytes], 32\n\t"
"L_chacha20_riscv_outer:\n\t" "L_chacha20_riscv_outer:\n\t"
/* Move state into regular registers */ /* Move state into regular registers */
"ld a4, 0(%[input])\n\t" "ld a4, 0(%[input])\n\t"
@ -2113,11 +2115,13 @@ static WC_INLINE void wc_chacha_encrypt(const word32* input, const byte* m,
"L_chacha20_riscv_loop:\n\t" "L_chacha20_riscv_loop:\n\t"
/* Odd Round */ /* Odd Round */
QUARTER_ROUND_ODD() QUARTER_ROUND_ODD()
"addi a3, a3, -1\n\t"
/* Even Round */ /* Even Round */
QUARTER_ROUND_EVEN() QUARTER_ROUND_EVEN()
"addi a3, a3, -1\n\t"
"bnez a3, L_chacha20_riscv_loop\n\t" "bnez a3, L_chacha20_riscv_loop\n\t"
"addi %[bytes], %[bytes], -64\n\t"
"ld t0, 0(%[input])\n\t" "ld t0, 0(%[input])\n\t"
"ld t1, 8(%[input])\n\t" "ld t1, 8(%[input])\n\t"
"ld t2, 16(%[input])\n\t" "ld t2, 16(%[input])\n\t"
@ -2141,9 +2145,11 @@ static WC_INLINE void wc_chacha_encrypt(const word32* input, const byte* m,
"add s2, s2, t0\n\t" "add s2, s2, t0\n\t"
"add s4, s4, t1\n\t" "add s4, s4, t1\n\t"
"add s6, s6, t2\n\t" "add s6, s6, t2\n\t"
"addi t2, t2, 1\n\t"
"add s8, s8, s1\n\t" "add s8, s8, s1\n\t"
"srli t0, t0, 32\n\t" "srli t0, t0, 32\n\t"
"srli t1, t1, 32\n\t" "srli t1, t1, 32\n\t"
"sw t2, 48(%[input])\n\t"
"srli t2, t2, 32\n\t" "srli t2, t2, 32\n\t"
"srli s1, s1, 32\n\t" "srli s1, s1, 32\n\t"
"add s3, s3, t0\n\t" "add s3, s3, t0\n\t"
@ -2151,79 +2157,8 @@ static WC_INLINE void wc_chacha_encrypt(const word32* input, const byte* m,
"add s7, s7, t2\n\t" "add s7, s7, t2\n\t"
"add s9, s9, s1\n\t" "add s9, s9, s1\n\t"
"addi %[bytes], %[bytes], -64\n\t" "bltz %[bytes], L_chacha20_riscv_over\n\t"
"bgez %[bytes], L_chacha20_riscv_xor\n\t"
"addi a3, %[bytes], 64\n\t"
"sw a4, 0(%[over])\n\t"
"sw a5, 4(%[over])\n\t"
"sw a6, 8(%[over])\n\t"
"sw a7, 12(%[over])\n\t"
"sw t3, 16(%[over])\n\t"
"sw t4, 20(%[over])\n\t"
"sw t5, 24(%[over])\n\t"
"sw t6, 28(%[over])\n\t"
"sw s2, 32(%[over])\n\t"
"sw s3, 36(%[over])\n\t"
"sw s4, 40(%[over])\n\t"
"sw s5, 44(%[over])\n\t"
"sw s6, 48(%[over])\n\t"
"sw s7, 52(%[over])\n\t"
"sw s8, 56(%[over])\n\t"
"sw s9, 60(%[over])\n\t"
"addi t0, a3, -8\n\t"
"bltz t0, L_chacha20_riscv_32bit\n\t"
"addi a3, a3, -1\n\t"
"L_chacha20_riscv_64bit_loop:\n\t"
"ld t0, (%[m])\n\t"
"ld t1, (%[over])\n\t"
"xor t0, t0, t1\n\t"
"sd t0, (%[c])\n\t"
"addi %[m], %[m], 8\n\t"
"addi %[c], %[c], 8\n\t"
"addi %[over], %[over], 8\n\t"
"addi a3, a3, -8\n\t"
"bgez a3, L_chacha20_riscv_64bit_loop\n\t"
"addi a3, a3, 1\n\t"
"L_chacha20_riscv_32bit:\n\t"
"addi t0, a3, -4\n\t"
"bltz t0, L_chacha20_riscv_16bit\n\t"
"lw t0, (%[m])\n\t"
"lw t1, (%[over])\n\t"
"xor t0, t0, t1\n\t"
"sw t0, (%[c])\n\t"
"addi %[m], %[m], 4\n\t"
"addi %[c], %[c], 4\n\t"
"addi %[over], %[over], 4\n\t"
"L_chacha20_riscv_16bit:\n\t"
"addi t0, a3, -2\n\t"
"bltz t0, L_chacha20_riscv_8bit\n\t"
"lh t0, (%[m])\n\t"
"lh t1, (%[over])\n\t"
"xor t0, t0, t1\n\t"
"sh t0, (%[c])\n\t"
"addi %[m], %[m], 2\n\t"
"addi %[c], %[c], 2\n\t"
"addi %[over], %[over], 2\n\t"
"L_chacha20_riscv_8bit:\n\t"
"addi t0, a3, -1\n\t"
"bltz t0, L_chacha20_riscv_bytes_done\n\t"
"lb t0, (%[m])\n\t"
"lb t1, (%[over])\n\t"
"xor t0, t0, t1\n\t"
"sb t0, (%[c])\n\t"
"L_chacha20_riscv_bytes_done:\n\t"
"lw t0, 48(%[input])\n\t"
"addi t0, t0, 1\n\t"
"sw t0, 48(%[input])\n\t"
"bltz %[bytes], L_chacha20_riscv_done\n\t"
"L_chacha20_riscv_xor:\n\t"
#if !defined(WOLFSSL_RISCV_BIT_MANIPULATION) #if !defined(WOLFSSL_RISCV_BIT_MANIPULATION)
"ld t0, 0(%[m])\n\t" "ld t0, 0(%[m])\n\t"
"ld t1, 8(%[m])\n\t" "ld t1, 8(%[m])\n\t"
@ -2308,16 +2243,80 @@ static WC_INLINE void wc_chacha_encrypt(const word32* input, const byte* m,
"sd s8, 56(%[c])\n\t" "sd s8, 56(%[c])\n\t"
#endif #endif
"lw t0, 48(%[input])\n\t"
"addi %[m], %[m], 64\n\t" "addi %[m], %[m], 64\n\t"
"addi t0, t0, 1\n\t"
"addi %[c], %[c], 64\n\t" "addi %[c], %[c], 64\n\t"
"sw t0, 48(%[input])\n\t"
"bnez %[bytes], L_chacha20_riscv_outer\n\t" "bnez %[bytes], L_chacha20_riscv_outer\n\t"
"beqz %[bytes], L_chacha20_riscv_done\n\t"
"L_chacha20_riscv_over:\n\t"
"addi a3, %[bytes], 64\n\t"
"sw a4, 0(%[over])\n\t"
"sw a5, 4(%[over])\n\t"
"sw a6, 8(%[over])\n\t"
"sw a7, 12(%[over])\n\t"
"sw t3, 16(%[over])\n\t"
"sw t4, 20(%[over])\n\t"
"sw t5, 24(%[over])\n\t"
"sw t6, 28(%[over])\n\t"
"sw s2, 32(%[over])\n\t"
"sw s3, 36(%[over])\n\t"
"sw s4, 40(%[over])\n\t"
"sw s5, 44(%[over])\n\t"
"sw s6, 48(%[over])\n\t"
"sw s7, 52(%[over])\n\t"
"sw s8, 56(%[over])\n\t"
"sw s9, 60(%[over])\n\t"
"addi t0, a3, -8\n\t"
"bltz t0, L_chacha20_riscv_32bit\n\t"
"addi a3, a3, -1\n\t"
"L_chacha20_riscv_64bit_loop:\n\t"
"ld t0, (%[m])\n\t"
"ld t1, (%[over])\n\t"
"xor t0, t0, t1\n\t"
"sd t0, (%[c])\n\t"
"addi %[m], %[m], 8\n\t"
"addi %[c], %[c], 8\n\t"
"addi %[over], %[over], 8\n\t"
"addi a3, a3, -8\n\t"
"bgez a3, L_chacha20_riscv_64bit_loop\n\t"
"addi a3, a3, 1\n\t"
"L_chacha20_riscv_32bit:\n\t"
"addi t0, a3, -4\n\t"
"bltz t0, L_chacha20_riscv_16bit\n\t"
"lw t0, (%[m])\n\t"
"lw t1, (%[over])\n\t"
"xor t0, t0, t1\n\t"
"sw t0, (%[c])\n\t"
"addi %[m], %[m], 4\n\t"
"addi %[c], %[c], 4\n\t"
"addi %[over], %[over], 4\n\t"
"L_chacha20_riscv_16bit:\n\t"
"addi t0, a3, -2\n\t"
"bltz t0, L_chacha20_riscv_8bit\n\t"
"lh t0, (%[m])\n\t"
"lh t1, (%[over])\n\t"
"xor t0, t0, t1\n\t"
"sh t0, (%[c])\n\t"
"addi %[m], %[m], 2\n\t"
"addi %[c], %[c], 2\n\t"
"addi %[over], %[over], 2\n\t"
"L_chacha20_riscv_8bit:\n\t"
"addi t0, a3, -1\n\t"
"bltz t0, L_chacha20_riscv_done\n\t\n\t"
"lb t0, (%[m])\n\t"
"lb t1, (%[over])\n\t"
"xor t0, t0, t1\n\t"
"sb t0, (%[c])\n\t"
"bltz %[bytes], L_chacha20_riscv_done\n\t"
"L_chacha20_riscv_done:\n\t" "L_chacha20_riscv_done:\n\t"
: [m] "+r" (m), [c] "+r" (c), [bytes] "+r" (bytes64), [over] "+r" (over) : [m] "+r" (m), [c] "+r" (c), [bytes] "+r" (bytes), [over] "+r" (over)
: [input] "r" (input) : [input] "r" (input)
: "memory", "t0", "t1", "t2", "s1", "a3", : "memory", "t0", "t1", "t2", "s1", "a3",
"t3", "t4", "t5", "t6", "t3", "t4", "t5", "t6",
@ -2330,12 +2329,12 @@ static WC_INLINE void wc_chacha_encrypt(const word32* input, const byte* m,
/** /**
* Encrypt a stream of bytes * Encrypt a stream of bytes
*/ */
static void wc_chacha_encrypt_bytes(ChaCha* ctx, const byte* m, byte* c, static WC_INLINE void wc_chacha_encrypt_bytes(ChaCha* ctx, const byte* m,
word32 bytes) byte* c, word32 bytes)
{ {
wc_chacha_encrypt(ctx->X, m, c, bytes, ctx->over); wc_chacha_encrypt(ctx->X, m, c, bytes, ctx->over);
ctx->left = CHACHA_CHUNK_BYTES - (bytes & (CHACHA_CHUNK_BYTES - 1)); ctx->left = (CHACHA_CHUNK_BYTES - (bytes & (CHACHA_CHUNK_BYTES - 1))) &
ctx->left &= CHACHA_CHUNK_BYTES - 1; (CHACHA_CHUNK_BYTES - 1);
} }
#endif #endif
@ -2350,24 +2349,20 @@ int wc_Chacha_Process(ChaCha* ctx, byte* output, const byte* input,
if ((ctx == NULL) || (output == NULL) || (input == NULL)) { if ((ctx == NULL) || (output == NULL) || (input == NULL)) {
ret = BAD_FUNC_ARG; ret = BAD_FUNC_ARG;
} }
else { else if (msglen > 0) {
/* handle left overs */ if (ctx->left > 0) {
if (msglen > 0 && ctx->left > 0) { word32 processed = min(msglen, ctx->left);
byte* out; byte* out = (byte*)ctx->over + CHACHA_CHUNK_BYTES - ctx->left;
word32 i;
out = (byte*)ctx->over + CHACHA_CHUNK_BYTES - ctx->left; xorbufout(output, input, out, processed);
for (i = 0; i < msglen && i < ctx->left; i++) {
output[i] = (byte)(input[i] ^ out[i]);
}
ctx->left -= i;
msglen -= i; ctx->left -= processed;
output += i; msglen -= processed;
input += i; output += processed;
input += processed;
} }
if (msglen != 0) { if (msglen > 0) {
wc_chacha_encrypt_bytes(ctx, input, output, msglen); wc_chacha_encrypt_bytes(ctx, input, output, msglen);
} }
} }