diff --git a/main.py b/main.py index ea31a7f..5084c5a 100755 --- a/main.py +++ b/main.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 from random import randint from math import log2 +import multiprocessing class lfsr(object): def __init__(self, state, taps): @@ -32,8 +33,9 @@ class lfsr(object): def test_lfsr17(): print("test lfsr17") - key = [randint(0, 1) for _ in range(16)] # first 16 bits - key.append(1) # prevent initial state from being {0}^17 + key = [randint(0, 1) for _ in range(17)] # random initial state + while key == [0]*17: # initial state shouldn't be 0 + key = [randint(0, 1) for _ in range(17)] taps = [0, 14] lfsr17 = lfsr(key, taps) states = [lfsr17.state] @@ -80,9 +82,9 @@ def css_encrypt(text, key): carry = 1 if x + y > 255 else 0 cipher_byte = z ^ byte - # print(cipher_byte.to_bytes((cipher_byte.bit_length() + 7) // 8, 'big')) + cipher = (cipher << 8) | cipher_byte - cipher_bytes = b'\x00'*(len(Bytes) - len(cipher.to_bytes((cipher.bit_length() + 7) // 8, 'big'))) +cipher.to_bytes((cipher.bit_length() + 7) // 8, 'big') # padding + cipher_bytes = b'\x00'*(len(Bytes) - len(cipher.to_bytes((cipher.bit_length() + 7) // 8, 'big'))) +cipher.to_bytes((cipher.bit_length() + 7) // 8, 'big') # padding return cipher_bytes def test_encrypt(): @@ -98,8 +100,6 @@ def gen_6_bytes(key=[randint(0, 1) for _ in range(40)]): cipher = css_encrypt(text, key) return cipher -import multiprocessing - def attack_worker(start, end, Bytes, result_queue, stop_event): taps17 = [0, 14] taps25 = [0, 3, 4, 12] @@ -193,6 +193,7 @@ def test_attack(n=1): print("".join(str(bit) for bit in key)) def test_fail(): + # happened with 74/36000 keys, probability of ~0.0021 taps17 = [0, 14] taps25 = [0, 3, 4, 12] key = [1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1] @@ -243,6 +244,142 @@ def test_fail(): print(f'\nencrypting 0xffffffffffff (6 bytes)\nwith 1st key: {hex(cipher)}\nwith 2nd key: {hex(cipher2)}\nboth outputs are equal: {cipher==cipher2}\n') + +def safer_attack_worker(start, end, Bytes, result_queue): + taps17 = [0, 14] + taps25 = [0, 3, 4, 12] + for i in range(start, end): + lfsr17_init = [int(bit) for bit in bin(i)[2:].zfill(16)]+[1] + lfsr17 = lfsr(lfsr17_init, taps17) + x = [] + for _ in range(3): + x_bin = "" + for _ in range(8): + x_bin += str(lfsr17.shift()) + x.append(int(x_bin[::-1], 2)) + y = [(Bytes[0]-x[0])%256] + c=1 if x[0]+y[0]>255 else 0 + y.append((Bytes[1]-(x[1]+c))%256) + c=1 if x[1]+y[1]>255 else 0 + y.append((Bytes[2]-(x[2]+c))%256) + lfsr25_init = [int(bit) for bit in (bin(y[0])[2:].zfill(8)[::-1] + bin(y[1])[2:].zfill(8)[::-1] + bin(y[2])[2:].zfill(8)[::-1] ) ]+[1] + lfsr25 = lfsr(lfsr25_init, taps25) + for _ in range(24): + lfsr25.shift() + for _ in range(3): + x_bin = "" + y_bin = "" + for _ in range(8): + x_bin += str(lfsr17.shift()) + y_bin += str(lfsr25.shift()) + x.append(int(x_bin[::-1], 2)) + y.append(int(y_bin[::-1], 2)) + c=1 if x[2]+y[2]>255 else 0 + z4 = (x[3]+y[3]+c)%256 + c=1 if x[3]+y[3]>255 else 0 + z5 = (x[4]+y[4]+c)%256 + c=1 if x[4]+y[4]>255 else 0 + z6 = (x[5]+y[5]+c)%256 + if z4 == Bytes[3] and z5 == Bytes[4] and z6 == Bytes[5]: + key = bin(x[0])[2:].zfill(8)[::-1] + bin(x[1])[2:].zfill(8)[::-1] + bin(y[0])[2:].zfill(8)[::-1] + bin(y[1])[2:].zfill(8)[::-1] + bin(y[2])[2:].zfill(8)[::-1] + result_queue.put(key) + + +def safer_attack(ask_queue, answer_queue, Bytes=gen_6_bytes()): + result_queue = multiprocessing.Queue() + processes = [] + num_cores = multiprocessing.cpu_count() + max_upper_limit = 2**16 + chunk_size = max_upper_limit // num_cores + + for i in range(num_cores-1): + start = i * chunk_size + end = start + chunk_size + process = multiprocessing.Process(target=safer_attack_worker, args=(start, end, Bytes, result_queue)) + processes.append(process) + process.start() + # last process + start = (num_cores-1) * chunk_size + end = max_upper_limit + process = multiprocessing.Process(target=safer_attack_worker, args=(start, end, Bytes, result_queue)) + processes.append(process) + process.start() + + for process in processes: + process.join() + + keys = [] + while not result_queue.empty(): + key = result_queue.get() + print(f'key found: \t{key}') + keys.append(key) + if len(keys)==1: + ask_queue.put(False) + answer_queue.put([int(bit) for bit in keys[0]]) + else: + print("collision detected") + ask_queue.put(True) + Byte7 = answer_queue.get() + taps17 = [0, 14] + taps25 = [0, 3, 4, 12] + for key_str in keys: + key = [int(bit) for bit in key_str] + lfsr17 = lfsr(key[:16]+[1], taps17) + lfsr25 = lfsr(key[16:]+[1], taps25) + for _ in range(40): + lfsr17.shift() + lfsr25.shift() + x = [] + y = [] + for _ in range(2): + x_bin = "" + y_bin = "" + for _ in range(8): + x_bin += str(lfsr17.shift()) + y_bin += str(lfsr25.shift()) + x.append(int(x_bin[::-1], 2)) + y.append(int(y_bin[::-1], 2)) + c=1 if x[0]+y[0]>255 else 0 + z7 = (x[1]+y[1]+c)%256 + if z7 == Byte7: + print(f'kept key: \t{key_str}') + answer_queue.put(key) + return + + +def test_safer_attack(n=1): + success = 0 + ask_queue = multiprocessing.Queue() + answer_queue = multiprocessing.Queue() + print(f'testing safer attack in 2^16 against CSS {n} times (keys randomly generated each time)\n') + for _ in range(n): + key = [randint(0, 1) for _ in range(40)] + # key = [1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0] + key_string = ''.join(str(bit) for bit in key) + print(f'key generated: \t{key_string}') + Bytes = gen_6_bytes(key) + process = multiprocessing.Process(target=safer_attack, args=(ask_queue, answer_queue, Bytes)) + process.start() + failed = [] + additional_byte = ask_queue.get() + if additional_byte: # give access to an additional byte if asked + print("giving an additional byte") + text = b'\x00\x00\x00\x00\x00\x00\x00' + cipher = css_encrypt(text, key) + answer_queue.put(cipher[-1]) + found_key = answer_queue.get() + if found_key == key: + success += 1 + else: + failed.append(key) + print() + print(f'{success}/{n} success') + if len(failed)>0: + print("fails:") + for key in failed: + print("".join(str(bit) for bit in key)) + + test_lfsr17() print("\n") test_encrypt() @@ -253,3 +390,5 @@ print("\n") test_attack(10) print("\n") test_fail() +print("\n") +test_safer_attack(5)