added put/get from db
This commit is contained in:
parent
27085845eb
commit
1a59a465c0
@ -1,22 +1,51 @@
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
|
|
||||||
import dbm
|
import dbm
|
||||||
|
|
||||||
|
from .securestring import generate_key_file, protect_string, unprotect_string
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SafeDB:
|
class SafeDB:
|
||||||
def safe_get(key: Union[str, bytes]) -> bytes:
|
"""Wrapper around dbm to optionally encrypt db values."""
|
||||||
return
|
|
||||||
|
|
||||||
def __enter__(self):
|
def get(self, key: Union[str, bytes, bytearray]) -> bytes:
|
||||||
self.db = dbm.open(self.db_path, "c")
|
if self.protected:
|
||||||
return self.db
|
return self.db_conn[key]
|
||||||
|
return unprotect_string(self.db_conn[key])
|
||||||
|
|
||||||
def __exit__(self):
|
def put(
|
||||||
self.db.close()
|
self, key: [str, bytes, bytearray], value: [bytes, bytearray]):
|
||||||
|
if self.protected:
|
||||||
|
self.db_conn[key] = protect_string(value)
|
||||||
|
else:
|
||||||
|
self.db_conn[key] = value
|
||||||
|
|
||||||
def __init__(self, db_path: str, use_):
|
def close(self):
|
||||||
|
self.db_conn.close()
|
||||||
|
|
||||||
|
def __init__(self, db_path: str, protected=True):
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
|
self.db_conn = dbm.open(db_path, "c")
|
||||||
|
|
||||||
|
try:
|
||||||
|
existing_protected_mode = self.db_conn['enc']
|
||||||
|
if protected and existing_protected_mode != b'1':
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot open unencrypted database with protected=True")
|
||||||
|
elif not protected and existing_protected_mode != b'0':
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot open encrypted database in protected=False")
|
||||||
|
except KeyError:
|
||||||
|
if protected:
|
||||||
|
self.db_conn['enc'] = b'1'
|
||||||
|
else:
|
||||||
|
self.db_conn['enc'] = b'0'
|
||||||
|
try:
|
||||||
|
generate_key_file()
|
||||||
|
except FileExistsError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.protected = protected
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,3 +1,12 @@
|
|||||||
|
"""Wrap RinseOff, a c# CLI tool for secure data erasure via a keyfile.
|
||||||
|
|
||||||
|
Intended for encrypting database entries.
|
||||||
|
|
||||||
|
It is quite slow since it spawns an external process,
|
||||||
|
but an ext process is necessary to keep the key out
|
||||||
|
of memory as much as possible
|
||||||
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@ -9,29 +18,13 @@ import logger
|
|||||||
_rinseoff = f"{app_root}/src/rinseoff/rinseoffcli"
|
_rinseoff = f"{app_root}/src/rinseoff/rinseoffcli"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def generate_key_file():
|
def generate_key_file():
|
||||||
if os.path.exists(secure_erase_key_file):
|
if os.path.exists(secure_erase_key_file):
|
||||||
raise FileExistsError
|
raise FileExistsError(
|
||||||
|
"Key file for rinseoff secure erase already exists")
|
||||||
process = subprocess.Popen(
|
|
||||||
["dotnet", "run",
|
|
||||||
"--project", _rinseoff,
|
|
||||||
"keygen", f"{secure_erase_key_file}"],
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.PIPE)
|
|
||||||
res = process.communicate()
|
|
||||||
|
|
||||||
if res[0]:
|
|
||||||
for line in res[0].decode('utf-8').split('\n'):
|
|
||||||
logger.info(line, terminal=True)
|
|
||||||
if res[1]:
|
|
||||||
logger.warn("Error when generating database encryption keyfile")
|
|
||||||
for line in res[1].decode('utf-8').split('\n'):
|
|
||||||
logger.error(line, terminal=True)
|
|
||||||
raise subprocess.CalledProcessError
|
|
||||||
|
|
||||||
|
with open(secure_erase_key_file, 'wb') as f:
|
||||||
|
f.write(os.urandom(32))
|
||||||
|
|
||||||
|
|
||||||
def protect_string(plaintext: Union[bytes, bytearray, str]) -> bytes:
|
def protect_string(plaintext: Union[bytes, bytearray, str]) -> bytes:
|
||||||
@ -58,7 +51,8 @@ def protect_string(plaintext: Union[bytes, bytearray, str]) -> bytes:
|
|||||||
logger.warn("Error when protecting string for database", terminal=True)
|
logger.warn("Error when protecting string for database", terminal=True)
|
||||||
for line in res[1].decode('utf-8').split('\n'):
|
for line in res[1].decode('utf-8').split('\n'):
|
||||||
logger.error(line, terminal=True)
|
logger.error(line, terminal=True)
|
||||||
raise subprocess.CalledProcessError
|
raise subprocess.CalledProcessError(
|
||||||
|
"Error protecting string")
|
||||||
|
|
||||||
|
|
||||||
def unprotect_string(ciphertext: Union[bytes, bytearray]) -> bytes:
|
def unprotect_string(ciphertext: Union[bytes, bytearray]) -> bytes:
|
||||||
@ -78,4 +72,5 @@ def unprotect_string(ciphertext: Union[bytes, bytearray]) -> bytes:
|
|||||||
"Error when decrypting ciphertext from database", terminal=True)
|
"Error when decrypting ciphertext from database", terminal=True)
|
||||||
for line in res[1].decode('utf-8').split('\n'):
|
for line in res[1].decode('utf-8').split('\n'):
|
||||||
logger.error(line, terminal=True)
|
logger.error(line, terminal=True)
|
||||||
raise subprocess.CalledProcessError
|
raise subprocess.CalledProcessError(
|
||||||
|
"Error unprotecting string")
|
||||||
|
59
tests/test_safedb.py
Normal file
59
tests/test_safedb.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import sys, os
|
||||||
|
sys.path.append(".")
|
||||||
|
sys.path.append("src/")
|
||||||
|
import uuid
|
||||||
|
TEST_DIR = 'testdata/%s-%s' % (uuid.uuid4(), os.path.basename(__file__)) + '/'
|
||||||
|
print("Test directory:", TEST_DIR)
|
||||||
|
os.environ["ONIONR_HOME"] = TEST_DIR
|
||||||
|
import unittest, json
|
||||||
|
import dbm
|
||||||
|
|
||||||
|
from utils import identifyhome, createdirs
|
||||||
|
from onionrsetup import setup_config
|
||||||
|
createdirs.create_dirs()
|
||||||
|
setup_config()
|
||||||
|
import safedb
|
||||||
|
|
||||||
|
db_path = identifyhome.identify_home() + "test.db"
|
||||||
|
|
||||||
|
def _remove_db():
|
||||||
|
try:
|
||||||
|
os.remove(db_path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class TestSafeDB(unittest.TestCase):
|
||||||
|
def test_db_create_unprotected(self):
|
||||||
|
_remove_db()
|
||||||
|
db = safedb.SafeDB(db_path, protected=False)
|
||||||
|
db.close()
|
||||||
|
with dbm.open(db_path) as db:
|
||||||
|
self.assertEqual(db['enc'], b'0')
|
||||||
|
|
||||||
|
def test_db_create_proteced(self):
|
||||||
|
_remove_db()
|
||||||
|
db = safedb.SafeDB(db_path, protected=True)
|
||||||
|
db.close()
|
||||||
|
with dbm.open(db_path) as db:
|
||||||
|
self.assertEqual(db['enc'], b'1')
|
||||||
|
|
||||||
|
def test_db_open_protected(self):
|
||||||
|
_remove_db()
|
||||||
|
with dbm.open(db_path, 'c') as db:
|
||||||
|
db['enc'] = b'1'
|
||||||
|
db = safedb.SafeDB(db_path, protected=True)
|
||||||
|
db.close()
|
||||||
|
self.assertRaises(ValueError, safedb.SafeDB, db_path, protected=False)
|
||||||
|
|
||||||
|
def test_db_open_unproteced(self):
|
||||||
|
_remove_db()
|
||||||
|
with dbm.open(db_path, 'c') as db:
|
||||||
|
db['enc'] = b'0'
|
||||||
|
db = safedb.SafeDB(db_path, protected=False)
|
||||||
|
db.close()
|
||||||
|
self.assertRaises(ValueError, safedb.SafeDB, db_path, protected=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue
Block a user