diff --git a/kasten/__init__.py b/kasten/__init__.py new file mode 100644 index 0000000..c3b02a9 --- /dev/null +++ b/kasten/__init__.py @@ -0,0 +1,4 @@ +from . import exceptions +from . import generator +from . import types +from .main import Kasten diff --git a/kasten/exceptions/__init__.py b/kasten/exceptions/__init__.py index cbff9f3..6f5681a 100644 --- a/kasten/exceptions/__init__.py +++ b/kasten/exceptions/__init__.py @@ -7,3 +7,7 @@ class InvalidKastenTypeLength(KastenException): class InvalidEncryptionMode(KastenException): pass + +class InvalidID(KastenException): + pass + diff --git a/kasten/exceptions/__pycache__/__init__.cpython-38.pyc b/kasten/exceptions/__pycache__/__init__.cpython-38.pyc index 6afe4a6..a10cf21 100644 Binary files a/kasten/exceptions/__pycache__/__init__.cpython-38.pyc and b/kasten/exceptions/__pycache__/__init__.cpython-38.pyc differ diff --git a/kasten/generator/__init__.py b/kasten/generator/__init__.py new file mode 100644 index 0000000..9993993 --- /dev/null +++ b/kasten/generator/__init__.py @@ -0,0 +1,19 @@ +from kasten.types import KastenPacked +from kasten.types import KastenChecksum +from kasten.exceptions import InvalidID + +from hashlib import sha3_384 +from ..main import Kasten + + +class KastenBaseGenerator: + @classmethod + def generate(cls, packed_bytes: KastenPacked) -> Kasten: + return Kasten(sha3_384(packed_bytes).digest(), packed_bytes, cls, + auto_check_generator=False) + + @staticmethod + def validate_id(hash: KastenChecksum, packed_bytes: KastenPacked) -> None: + if not sha3_384(packed_bytes).digest() == hash: + raise InvalidID + return None \ No newline at end of file diff --git a/kasten/pack/__init__.py b/kasten/generator/pack.py similarity index 77% rename from kasten/pack/__init__.py rename to kasten/generator/pack.py index 534ee28..f9875bc 100644 --- a/kasten/pack/__init__.py +++ b/kasten/generator/pack.py @@ -10,15 +10,22 @@ encrypted with specified mode: data: bytes """ +from math import floor +from time import time + from msgpack import packb -from .. import exceptions +from kasten import exceptions + +from kasten.types import KastenPacked def pack(data: bytes, data_type: 'KastenDataType', enc_mode: 'KastenEncryptionModeID', signer: bytes = None, signature: bytes = None, - app_metadata: 'KastenSerializeableDict' = None) -> 'PreparedKasten': + app_metadata: 'KastenSerializeableDict' = None, + timestamp: int = None + ) -> KastenPacked: # Ensure data type does not exceed 4 characters if not data_type or len(data_type) > 4: @@ -31,13 +38,17 @@ def pack(data: bytes, data_type: 'KastenDataType', raise exceptions.InvalidEncryptionMode if not enc_mode >= 0 or enc_mode >= 100: raise exceptions.InvalidEncryptionMode - + try: data = data.encode('utf8') except AttributeError: pass + if timestamp is None: + timestamp = floor(time()) + assert int(timestamp) - kasten_header = [data_type, enc_mode] + + kasten_header = [data_type, enc_mode, timestamp] if signer: if signature is None: raise ValueError("Signer specified without signature") diff --git a/kasten/main.py b/kasten/main.py new file mode 100644 index 0000000..5ae7dd4 --- /dev/null +++ b/kasten/main.py @@ -0,0 +1,17 @@ +from .types import KastenChecksum +from .types import KastenPacked + + +class Kasten: + def __init__(self, id: KastenChecksum, + packed_bytes: KastenPacked, + generator: 'KastenBaseGenerator', + auto_check_generator = False): # noqa + if auto_check_generator: + generator.validate_id(id, packed_bytes) + self.id = id + self.packed_bytes = packed_bytes + self.generator = generator + + def check_generator(self): + self.generator.validate_id(self.id, self.packed_bytes) diff --git a/kasten/types.py b/kasten/types.py index 2276261..cd29efb 100644 --- a/kasten/types.py +++ b/kasten/types.py @@ -1,2 +1,18 @@ +from typing import Tuple +from typing import NewType +from typing import NamedTuple + +KastenDataType = NewType('KastenDataType', str) + class KastenDataType(str): + pass + + +class KastenPacked(bytes): + """Raw Kasten bytes that have not yet been passed through a KastenGenerator""" + + +class KastenChecksum(bytes): + """hash or checksum of a Kasten object""" + diff --git a/tests/test_generator_base.py b/tests/test_generator_base.py new file mode 100644 index 0000000..a237d51 --- /dev/null +++ b/tests/test_generator_base.py @@ -0,0 +1,18 @@ +import unittest +from hashlib import sha3_384 + +from kasten import exceptions +from kasten.generator import KastenBaseGenerator + + +class TestBaseGenerator(unittest.TestCase): + def test_base_generator(self): + k = b'\x92\xa3bin\x00\xc4\x01\n(\x86!\xd7\xb5\x8ar\xae\x97z' + K = KastenBaseGenerator.generate(k) + h = sha3_384(k).digest() + self.assertTrue(len(K.packed_bytes) > 0) + KastenBaseGenerator.validate_id(h, k) + self.assertRaises(exceptions.InvalidID, KastenBaseGenerator.validate_id, h, b"\x92\xa3txt\x00\xc4\x01\n(\x86!\xd7\xb5\x8ar\xae\x97z") + + +unittest.main() diff --git a/tests/test_kasten.py b/tests/test_kasten.py new file mode 100644 index 0000000..988ff4f --- /dev/null +++ b/tests/test_kasten.py @@ -0,0 +1,14 @@ +import unittest +from kasten import Kasten +from hashlib import sha3_384 + +from kasten import exceptions +from kasten.generator import KastenBaseGenerator + + +class TestKasten(unittest.TestCase): + def test_kasten(self): + k = b'\x92\xa3bin\x00\xc4\x01\n(\x86!\xd7\xb5\x8ar\xae\x97z' + + +unittest.main() diff --git a/tests/test_pack.py b/tests/test_pack.py index c508458..d42c166 100644 --- a/tests/test_pack.py +++ b/tests/test_pack.py @@ -1,7 +1,7 @@ import unittest import os -from kasten import pack +from kasten.generator import pack from kasten import exceptions @@ -11,14 +11,14 @@ class TestPack(unittest.TestCase): data = os.urandom(10) packed = pack.pack(data, 'bin', 0) parts = packed.split(b'\n', 1) - self.assertEqual(parts[0], b'\x92\xa3bin\x00\xc4\x01') + self.assertEqual(parts[0], b'\x93\xa3bin\x00\xce^\x95\x82:\xc4\x01') self.assertEqual(parts[1], data) def test_linebreak_data(self): data = os.urandom(9) + b'\n' + b"okay" packed = pack.pack(data, 'bin', 0) parts = packed.split(b'\n', 1) - self.assertEqual(parts[0], b'\x92\xa3bin\x00\xc4\x01') + self.assertEqual(parts[0], b'\x93\xa3bin\x00\xce^\x95\x82:\xc4\x01') self.assertEqual(parts[1], data) def test_invalid_data_type(self):