improve aead implementation

This commit is contained in:
Sam Hadow 2025-03-09 14:51:26 +01:00
parent 0c80e915f3
commit 46c495f126
2 changed files with 43 additions and 20 deletions

View File

@ -12,9 +12,16 @@ function splitIntoChunks(data) {
} }
class keccakAEAD { class keccakAEAD {
constructor(iv, key) { constructor(iv, key, nonce) {
const input = kdf.concatUint8Arrays(iv, key); const input = kdf.concatUint8Arrays(iv, key, nonce);
this.state = keccak.SHAKE256(input, 32); this.state = keccak.SHAKE256(input, 40);
let r = this.state.slice(0, 16);
let c = this.state.slice(16, 40);
let padded_key = kdf.concatUint8Arrays(new Uint8Array(24-key.length), key);
for (let i = 0; i < padded_key.length; i++) {
c[i] ^= key[i];
}
this.state = kdf.concatUint8Arrays(r, c);
} }
associated_data_processing(associated_data) { associated_data_processing(associated_data) {
@ -25,14 +32,22 @@ class keccakAEAD {
let input = null; let input = null;
chunks.forEach((chunk) => { chunks.forEach((chunk) => {
to_xor = this.state.slice(0, 16); to_xor = this.state.slice(0, 16);
c = this.state.slice(16, 32); c = this.state.slice(16, 40);
r = new Uint8Array(chunk.length); r = new Uint8Array(chunk.length);
for (let i = 0; i < chunk.length; i++) { for (let i = 0; i < chunk.length; i++) {
r[i] = chunk[i] ^ to_xor[i]; r[i] = chunk[i] ^ to_xor[i];
} }
input = kdf.concatUint8Arrays(c, r); input = kdf.concatUint8Arrays(r, c);
this.state = keccak.SHAKE256(input, 32); this.state = keccak.SHAKE256(input, 40);
}); });
r = this.state.slice(0, 16);
c = this.state.slice(16, 40);
to_xor = new Uint8Array(24);
to_xor[23] = 1;
for (let i = 0; i < c.length; i++) {
c[i] ^= to_xor[i];
}
this.state = kdf.concatUint8Arrays(r, c);
} }
plaintext_processing(plaintext) { plaintext_processing(plaintext) {
@ -44,14 +59,14 @@ class keccakAEAD {
let cipherchunks = []; let cipherchunks = [];
chunks.forEach((chunk) => { chunks.forEach((chunk) => {
to_xor = this.state.slice(0, 16); to_xor = this.state.slice(0, 16);
c = this.state.slice(16, 32); c = this.state.slice(16, 40);
r = new Uint8Array(chunk.length); r = new Uint8Array(chunk.length);
for (let i = 0; i < chunk.length; i++) { for (let i = 0; i < chunk.length; i++) {
r[i] = chunk[i] ^ to_xor[i]; r[i] = chunk[i] ^ to_xor[i];
} }
cipherchunks.push(r); cipherchunks.push(r);
input = kdf.concatUint8Arrays(c, r); input = kdf.concatUint8Arrays(r, c);
this.state = keccak.SHAKE256(input, 32); this.state = keccak.SHAKE256(input, 40);
}); });
return cipherchunks; return cipherchunks;
} }
@ -65,21 +80,28 @@ class keccakAEAD {
let plaintextchunks = []; let plaintextchunks = [];
chunks.forEach((chunk) => { chunks.forEach((chunk) => {
to_xor = this.state.slice(0, 16); to_xor = this.state.slice(0, 16);
c = this.state.slice(16, 32); c = this.state.slice(16, 40);
r = new Uint8Array(chunk.length); r = new Uint8Array(chunk.length);
for (let i = 0; i < chunk.length; i++) { for (let i = 0; i < chunk.length; i++) {
r[i] = chunk[i] ^ to_xor[i]; r[i] = chunk[i] ^ to_xor[i];
} }
plaintextchunks.push(r); plaintextchunks.push(r);
input = kdf.concatUint8Arrays(c, chunk); input = kdf.concatUint8Arrays(chunk, c);
this.state = keccak.SHAKE256(input, 32); this.state = keccak.SHAKE256(input, 40);
}); });
return plaintextchunks; return plaintextchunks;
} }
finalize(key) { finalize(key) {
const output = keccak.SHAKE256(this.state, 32); let r = this.state.slice(0, 16);
let to_xor = output.slice(16, 32); let c = this.state.slice(16, 40);
let padded_key = kdf.concatUint8Arrays(new Uint8Array(24-key.length), key);
for (let i = 0; i < padded_key.length; i++) {
c[i] ^= padded_key[i];
}
this.state = kdf.concatUint8Arrays(r, c);
const output = keccak.SHAKE256(this.state, 40);
let to_xor = output.slice(40-key.length, 40);
let tag = new Uint8Array(key.length); let tag = new Uint8Array(key.length);
for (let i = 0; i < key.length; i++) { for (let i = 0; i < key.length; i++) {
tag[i] = key[i] ^ to_xor[i]; tag[i] = key[i] ^ to_xor[i];
@ -87,8 +109,8 @@ class keccakAEAD {
return tag; return tag;
} }
static encrypt(key, plaintext, iv, associated_data) { static encrypt(key, plaintext, iv, associated_data, nonce) {
let sponge = new keccakAEAD(iv, key); let sponge = new keccakAEAD(iv, key, nonce);
sponge.associated_data_processing(associated_data); sponge.associated_data_processing(associated_data);
let cipherChunks = sponge.plaintext_processing(plaintext); let cipherChunks = sponge.plaintext_processing(plaintext);
let ciphertext = kdf.concatUint8Arrays(...cipherChunks); let ciphertext = kdf.concatUint8Arrays(...cipherChunks);
@ -99,8 +121,8 @@ class keccakAEAD {
}; };
} }
static decrypt(key, ciphertext, iv, associated_data) { static decrypt(key, ciphertext, iv, associated_data, nonce) {
let sponge = new keccakAEAD(iv, key); let sponge = new keccakAEAD(iv, key, nonce);
sponge.associated_data_processing(associated_data); sponge.associated_data_processing(associated_data);
let plaintextChunks = sponge.ciphertext_processing(ciphertext); let plaintextChunks = sponge.ciphertext_processing(ciphertext);
let plaintext = kdf.concatUint8Arrays(...plaintextChunks); let plaintext = kdf.concatUint8Arrays(...plaintextChunks);

View File

@ -15,11 +15,12 @@ describe('aead.js functions', () => {
let msg_hex = stringutils.arrayToHex(msg); let msg_hex = stringutils.arrayToHex(msg);
let ad = generateRandomUint8Array(83); let ad = generateRandomUint8Array(83);
let iv = generateRandomUint8Array(); let iv = generateRandomUint8Array();
let nonce = generateRandomUint8Array();
let key = generateRandomUint8Array(); let key = generateRandomUint8Array();
let result = aead.keccakAEAD.encrypt(key, msg, iv, ad); let result = aead.keccakAEAD.encrypt(key, msg, iv, ad, nonce);
let tag_encrypt_hex = stringutils.arrayToHex(result.tag); let tag_encrypt_hex = stringutils.arrayToHex(result.tag);
let cipher_hex = stringutils.arrayToHex(result.cipher); let cipher_hex = stringutils.arrayToHex(result.cipher);
let result2 = aead.keccakAEAD.decrypt(key, result.cipher, iv, ad); let result2 = aead.keccakAEAD.decrypt(key, result.cipher, iv, ad, nonce);
let tag_decrypt_hex = stringutils.arrayToHex(result2.tag); let tag_decrypt_hex = stringutils.arrayToHex(result2.tag);
let decrypted_hex = stringutils.arrayToHex(result2.plaintext); let decrypted_hex = stringutils.arrayToHex(result2.plaintext);
expect(decrypted_hex).toBe(msg_hex); expect(decrypted_hex).toBe(msg_hex);