From 46c495f1265469a0fa7f97a5b3f7082864290db6 Mon Sep 17 00:00:00 2001 From: Sam Hadow Date: Sun, 9 Mar 2025 14:51:26 +0100 Subject: [PATCH] improve aead implementation --- src/aead.js | 58 ++++++++++++++++++++++++++++++++-------------- tests/aead.test.js | 5 ++-- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/src/aead.js b/src/aead.js index 8a6d8b1..446c9c5 100644 --- a/src/aead.js +++ b/src/aead.js @@ -12,9 +12,16 @@ function splitIntoChunks(data) { } class keccakAEAD { - constructor(iv, key) { - const input = kdf.concatUint8Arrays(iv, key); - this.state = keccak.SHAKE256(input, 32); + constructor(iv, key, nonce) { + const input = kdf.concatUint8Arrays(iv, key, nonce); + this.state = keccak.SHAKE256(input, 40); + let r = this.state.slice(0, 16); + let c = this.state.slice(16, 40); + let padded_key = kdf.concatUint8Arrays(new Uint8Array(24-key.length), key); + for (let i = 0; i < padded_key.length; i++) { + c[i] ^= key[i]; + } + this.state = kdf.concatUint8Arrays(r, c); } associated_data_processing(associated_data) { @@ -25,14 +32,22 @@ class keccakAEAD { let input = null; chunks.forEach((chunk) => { to_xor = this.state.slice(0, 16); - c = this.state.slice(16, 32); + c = this.state.slice(16, 40); r = new Uint8Array(chunk.length); for (let i = 0; i < chunk.length; i++) { r[i] = chunk[i] ^ to_xor[i]; } - input = kdf.concatUint8Arrays(c, r); - this.state = keccak.SHAKE256(input, 32); + input = kdf.concatUint8Arrays(r, c); + this.state = keccak.SHAKE256(input, 40); }); + r = this.state.slice(0, 16); + c = this.state.slice(16, 40); + to_xor = new Uint8Array(24); + to_xor[23] = 1; + for (let i = 0; i < c.length; i++) { + c[i] ^= to_xor[i]; + } + this.state = kdf.concatUint8Arrays(r, c); } plaintext_processing(plaintext) { @@ -44,14 +59,14 @@ class keccakAEAD { let cipherchunks = []; chunks.forEach((chunk) => { to_xor = this.state.slice(0, 16); - c = this.state.slice(16, 32); + c = this.state.slice(16, 40); r = new Uint8Array(chunk.length); for (let i = 0; i < chunk.length; i++) { r[i] = chunk[i] ^ to_xor[i]; } cipherchunks.push(r); - input = kdf.concatUint8Arrays(c, r); - this.state = keccak.SHAKE256(input, 32); + input = kdf.concatUint8Arrays(r, c); + this.state = keccak.SHAKE256(input, 40); }); return cipherchunks; } @@ -65,21 +80,28 @@ class keccakAEAD { let plaintextchunks = []; chunks.forEach((chunk) => { to_xor = this.state.slice(0, 16); - c = this.state.slice(16, 32); + c = this.state.slice(16, 40); r = new Uint8Array(chunk.length); for (let i = 0; i < chunk.length; i++) { r[i] = chunk[i] ^ to_xor[i]; } plaintextchunks.push(r); - input = kdf.concatUint8Arrays(c, chunk); - this.state = keccak.SHAKE256(input, 32); + input = kdf.concatUint8Arrays(chunk, c); + this.state = keccak.SHAKE256(input, 40); }); return plaintextchunks; } finalize(key) { - const output = keccak.SHAKE256(this.state, 32); - let to_xor = output.slice(16, 32); + let r = this.state.slice(0, 16); + let c = this.state.slice(16, 40); + let padded_key = kdf.concatUint8Arrays(new Uint8Array(24-key.length), key); + for (let i = 0; i < padded_key.length; i++) { + c[i] ^= padded_key[i]; + } + this.state = kdf.concatUint8Arrays(r, c); + const output = keccak.SHAKE256(this.state, 40); + let to_xor = output.slice(40-key.length, 40); let tag = new Uint8Array(key.length); for (let i = 0; i < key.length; i++) { tag[i] = key[i] ^ to_xor[i]; @@ -87,8 +109,8 @@ class keccakAEAD { return tag; } - static encrypt(key, plaintext, iv, associated_data) { - let sponge = new keccakAEAD(iv, key); + static encrypt(key, plaintext, iv, associated_data, nonce) { + let sponge = new keccakAEAD(iv, key, nonce); sponge.associated_data_processing(associated_data); let cipherChunks = sponge.plaintext_processing(plaintext); let ciphertext = kdf.concatUint8Arrays(...cipherChunks); @@ -99,8 +121,8 @@ class keccakAEAD { }; } - static decrypt(key, ciphertext, iv, associated_data) { - let sponge = new keccakAEAD(iv, key); + static decrypt(key, ciphertext, iv, associated_data, nonce) { + let sponge = new keccakAEAD(iv, key, nonce); sponge.associated_data_processing(associated_data); let plaintextChunks = sponge.ciphertext_processing(ciphertext); let plaintext = kdf.concatUint8Arrays(...plaintextChunks); diff --git a/tests/aead.test.js b/tests/aead.test.js index 065aa05..23c9f4b 100644 --- a/tests/aead.test.js +++ b/tests/aead.test.js @@ -15,11 +15,12 @@ describe('aead.js functions', () => { let msg_hex = stringutils.arrayToHex(msg); let ad = generateRandomUint8Array(83); let iv = generateRandomUint8Array(); + let nonce = generateRandomUint8Array(); let key = generateRandomUint8Array(); - let result = aead.keccakAEAD.encrypt(key, msg, iv, ad); + let result = aead.keccakAEAD.encrypt(key, msg, iv, ad, nonce); let tag_encrypt_hex = stringutils.arrayToHex(result.tag); let cipher_hex = stringutils.arrayToHex(result.cipher); - let result2 = aead.keccakAEAD.decrypt(key, result.cipher, iv, ad); + let result2 = aead.keccakAEAD.decrypt(key, result.cipher, iv, ad, nonce); let tag_decrypt_hex = stringutils.arrayToHex(result2.tag); let decrypted_hex = stringutils.arrayToHex(result2.plaintext); expect(decrypted_hex).toBe(msg_hex);