diff --git a/src/safedb/__init__.py b/src/safedb/__init__.py index dd1f02f1..e7275ae2 100644 --- a/src/safedb/__init__.py +++ b/src/safedb/__init__.py @@ -1,22 +1,51 @@ from typing import Union from enum import Enum, auto - import dbm +from .securestring import generate_key_file, protect_string, unprotect_string + class SafeDB: - def safe_get(key: Union[str, bytes]) -> bytes: - return + """Wrapper around dbm to optionally encrypt db values.""" - def __enter__(self): - self.db = dbm.open(self.db_path, "c") - return self.db + def get(self, key: Union[str, bytes, bytearray]) -> bytes: + if self.protected: + return self.db_conn[key] + return unprotect_string(self.db_conn[key]) - def __exit__(self): - self.db.close() + def put( + self, key: [str, bytes, bytearray], value: [bytes, bytearray]): + if self.protected: + self.db_conn[key] = protect_string(value) + else: + self.db_conn[key] = value - def __init__(self, db_path: str, use_): + def close(self): + self.db_conn.close() + + def __init__(self, db_path: str, protected=True): self.db_path = db_path + self.db_conn = dbm.open(db_path, "c") + + try: + existing_protected_mode = self.db_conn['enc'] + if protected and existing_protected_mode != b'1': + raise ValueError( + "Cannot open unencrypted database with protected=True") + elif not protected and existing_protected_mode != b'0': + raise ValueError( + "Cannot open encrypted database in protected=False") + except KeyError: + if protected: + self.db_conn['enc'] = b'1' + else: + self.db_conn['enc'] = b'0' + try: + generate_key_file() + except FileExistsError: + pass + + self.protected = protected diff --git a/src/safedb/securestring/__init__.py b/src/safedb/securestring/__init__.py index a78b2900..0a6c7101 100644 --- a/src/safedb/securestring/__init__.py +++ b/src/safedb/securestring/__init__.py @@ -1,3 +1,12 @@ +"""Wrap RinseOff, a c# CLI tool for secure data erasure via a keyfile. + +Intended for encrypting database entries. + +It is quite slow since it spawns an external process, +but an ext process is necessary to keep the key out +of memory as much as possible +""" + import os from typing import Union @@ -9,29 +18,13 @@ import logger _rinseoff = f"{app_root}/src/rinseoff/rinseoffcli" - - def generate_key_file(): if os.path.exists(secure_erase_key_file): - raise FileExistsError - - process = subprocess.Popen( - ["dotnet", "run", - "--project", _rinseoff, - "keygen", f"{secure_erase_key_file}"], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - res = process.communicate() - - if res[0]: - for line in res[0].decode('utf-8').split('\n'): - logger.info(line, terminal=True) - if res[1]: - logger.warn("Error when generating database encryption keyfile") - for line in res[1].decode('utf-8').split('\n'): - logger.error(line, terminal=True) - raise subprocess.CalledProcessError + raise FileExistsError( + "Key file for rinseoff secure erase already exists") + with open(secure_erase_key_file, 'wb') as f: + f.write(os.urandom(32)) def protect_string(plaintext: Union[bytes, bytearray, str]) -> bytes: @@ -58,7 +51,8 @@ def protect_string(plaintext: Union[bytes, bytearray, str]) -> bytes: logger.warn("Error when protecting string for database", terminal=True) for line in res[1].decode('utf-8').split('\n'): logger.error(line, terminal=True) - raise subprocess.CalledProcessError + raise subprocess.CalledProcessError( + "Error protecting string") def unprotect_string(ciphertext: Union[bytes, bytearray]) -> bytes: @@ -78,4 +72,5 @@ def unprotect_string(ciphertext: Union[bytes, bytearray]) -> bytes: "Error when decrypting ciphertext from database", terminal=True) for line in res[1].decode('utf-8').split('\n'): logger.error(line, terminal=True) - raise subprocess.CalledProcessError + raise subprocess.CalledProcessError( + "Error unprotecting string") diff --git a/tests/test_safedb.py b/tests/test_safedb.py new file mode 100644 index 00000000..b54b867a --- /dev/null +++ b/tests/test_safedb.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +import sys, os +sys.path.append(".") +sys.path.append("src/") +import uuid +TEST_DIR = 'testdata/%s-%s' % (uuid.uuid4(), os.path.basename(__file__)) + '/' +print("Test directory:", TEST_DIR) +os.environ["ONIONR_HOME"] = TEST_DIR +import unittest, json +import dbm + +from utils import identifyhome, createdirs +from onionrsetup import setup_config +createdirs.create_dirs() +setup_config() +import safedb + +db_path = identifyhome.identify_home() + "test.db" + +def _remove_db(): + try: + os.remove(db_path) + except FileNotFoundError: + pass + +class TestSafeDB(unittest.TestCase): + def test_db_create_unprotected(self): + _remove_db() + db = safedb.SafeDB(db_path, protected=False) + db.close() + with dbm.open(db_path) as db: + self.assertEqual(db['enc'], b'0') + + def test_db_create_proteced(self): + _remove_db() + db = safedb.SafeDB(db_path, protected=True) + db.close() + with dbm.open(db_path) as db: + self.assertEqual(db['enc'], b'1') + + def test_db_open_protected(self): + _remove_db() + with dbm.open(db_path, 'c') as db: + db['enc'] = b'1' + db = safedb.SafeDB(db_path, protected=True) + db.close() + self.assertRaises(ValueError, safedb.SafeDB, db_path, protected=False) + + def test_db_open_unproteced(self): + _remove_db() + with dbm.open(db_path, 'c') as db: + db['enc'] = b'0' + db = safedb.SafeDB(db_path, protected=False) + db.close() + self.assertRaises(ValueError, safedb.SafeDB, db_path, protected=True) + + + +unittest.main()