diff --git a/.gitignore b/.gitignore index 9f30c53..174824b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ venv/* .vscode/* -.mypy_cache/* \ No newline at end of file +.mypy_cache/* +build/* +dist/* +mimcvdf.egg-info/* diff --git a/mimcvdf/__init__.py b/mimcvdf/__init__.py index ac4fc90..ce2dff6 100644 --- a/mimcvdf/__init__.py +++ b/mimcvdf/__init__.py @@ -25,18 +25,18 @@ along with this program. If not, see . DEFAULT_ROUNDS = 8000 -def _sha3_256_hash(data: bytes) -> int: +def _sha3_256_hash(data: bytes) -> bytes: sha3 = sha3_256() sha3.update(data) - return int.from_bytes(sha3.digest(), byteorder='big') + return sha3.digest() def vdf_create(data: bytes, rounds: int = DEFAULT_ROUNDS, dec=False) -> str: assert rounds > 1 input_data: int = _sha3_256_hash(data) - if not dec: - return hex(reverse_mimc(input_data, rounds)).replace('0x', '') - return reverse_mimc(input_data, rounds) + if dec: + return int.from_bytes(reverse_mimc(input_data, rounds), "big") + return reverse_mimc(input_data, rounds).hex() def vdf_verify( @@ -45,11 +45,14 @@ def vdf_verify( rounds: int = DEFAULT_ROUNDS) -> bool: """Verify data for test_hash generated by vdf_create.""" assert rounds > 1 - should_match = _sha3_256_hash(data) - try: - test_hash = int(test_hash, 16) - except TypeError: - pass + should_match = _sha3_256_hash(data).lstrip(b'\0') + if isinstance(test_hash, int): + test_hash = test_hash.to_bytes((test_hash.bit_length() + 7) // 8, "big") + else: + try: + test_hash = bytes.fromhex(test_hash) + except ValueError: + return False return forward_mimc(test_hash, rounds) == should_match diff --git a/mimcvdf/mimc/__init__.py b/mimcvdf/mimc/__init__.py new file mode 100644 index 0000000..6cca951 --- /dev/null +++ b/mimcvdf/mimc/__init__.py @@ -0,0 +1,8 @@ +"""Mimc hash function.""" + +try: + from .native import forward_mimc, reverse_mimc + is_fast = True +except ImportError: + from .fallback import forward_mimc, reverse_mimc + is_fast = False diff --git a/mimcvdf/mimc.py b/mimcvdf/mimc/fallback.py similarity index 53% rename from mimcvdf/mimc.py rename to mimcvdf/mimc/fallback.py index ff7f041..a7cf23f 100644 --- a/mimcvdf/mimc.py +++ b/mimcvdf/mimc/fallback.py @@ -1,4 +1,4 @@ -"""Mimc hash function.""" +"""Slow fallback implementation of MIMC written in pure Python.""" """ This module adapted from https://github.com/OlegJakushkin/deepblockchains/blob/master/vdf/mimc/python/mimc.py by Sourabh Niyogi https://github.com/sourabhniyogi @@ -15,22 +15,25 @@ GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . """ -modulus = 2**256 - 2**32 * 351 + 1 -little_fermat_expt = (modulus*2-1)//3 -round_constants = [(i**7) ^ 42 for i in range(64)] + +is_fast = False + +_modulus = 2**256 - 2**32 * 351 + 1 +_little_fermat_expt = (_modulus*2-1)//3 +_round_constants = [(i**7) ^ 42 for i in range(64)] -def forward_mimc(inp: int, steps: int) -> int: +def forward_mimc(input_data: bytes, steps: int) -> bytes: + inp = int.from_bytes(input_data, "big") for i in range(1,steps): - inp = (inp**3 + round_constants[i % len(round_constants)]) % modulus - return inp + inp = (inp**3 + _round_constants[i % len(_round_constants)]) % _modulus + return inp.to_bytes((inp.bit_length() + 7) // 8, "big") -def reverse_mimc(input_data: int, steps: int) -> int: - rtrace = input_data +def reverse_mimc(input_data: bytes, steps: int) -> bytes: + rtrace = int.from_bytes(input_data, "big") for i in range(steps - 1, 0, -1): - rtrace = pow(rtrace-round_constants[i%len(round_constants)], - little_fermat_expt, modulus) - return rtrace - + rtrace = pow(rtrace-_round_constants[i%len(_round_constants)], + _little_fermat_expt, _modulus) + return rtrace.to_bytes((rtrace.bit_length() + 7) // 8, "big") diff --git a/mimcvdf/mimc/native.c b/mimcvdf/mimc/native.c new file mode 100644 index 0000000..dc6441c --- /dev/null +++ b/mimcvdf/mimc/native.c @@ -0,0 +1,157 @@ +#define PY_SSIZE_T_CLEAN +#include + +#include +#include + +/* +This module adapted from https://github.com/OlegJakushkin/deepblockchains/blob/master/vdf/mimc/python/mimc.py by Sourabh Niyogi https://github.com/sourabhniyogi + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . +*/ + +static mpz_t MODULUS; +static mpz_t LITTLE_FERMAT_EXPT; +#define ROUND_CONSTANTS_COUNT 64 +static mpz_t ROUND_CONSTANTS[ROUND_CONSTANTS_COUNT]; + +static void +mimc_init_constants() +{ + mpz_t fortytwo; + + mpz_init(fortytwo); + mpz_set_ui(fortytwo, 42); + + // Set MODULUS to hex(2**256 - 2**32 * 351 + 1) + mpz_init(MODULUS); + mpz_set_str(MODULUS, + "ffffffffffffffffffffffffffffffff" + "fffffffffffffffffffffea100000001", + 16); + + // Set LITTLE_FERMAT_EXPT to hex((MODULUS * 2 - 1) // 3) + mpz_init(LITTLE_FERMAT_EXPT); + mpz_set_str(LITTLE_FERMAT_EXPT, + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + "aaaaaaaaaaaaaaaaaaaaa9c0aaaaaaab", + 16); + + // Set all the constants to (i**7 ^ 42) + for (unsigned long int i = 0; i < ROUND_CONSTANTS_COUNT; i++) { + mpz_init(ROUND_CONSTANTS[i]); + mpz_ui_pow_ui(ROUND_CONSTANTS[i], i, 7); + mpz_xor(ROUND_CONSTANTS[i], ROUND_CONSTANTS[i], fortytwo); + } + + mpz_clear(fortytwo); +} + +static bool +unpack_args(PyObject *args, mpz_t input, unsigned int *steps) +{ + const char *data_bytes; + Py_ssize_t count; + + if (!PyArg_ParseTuple(args, "y#I", &data_bytes, &count, steps)) { + return false; + } + + mpz_import(input, count, 1, 1, 0, 0, data_bytes); + return true; +} + + +static PyObject * +convert_mpz_to_bytes(mpz_t op) +{ + char string[34]; + size_t size; + + mpz_export(string, &size, 1, 1, 0, 0, op); + PyObject *result = PyBytes_FromStringAndSize(string, size); + return result; +} + +static PyObject * +forward_mimc(PyObject *_self, PyObject *args) +{ + mpz_t result; + unsigned int steps; + + mpz_init(result); + if (!unpack_args(args, result, &steps)) { + mpz_clear(result); + return NULL; + } + + for (unsigned int i = 1; i < steps; ++i) { + mpz_powm_ui(result, result, 3, MODULUS); + mpz_add(result, result, ROUND_CONSTANTS[i % ROUND_CONSTANTS_COUNT]); + if (mpz_cmp(result, MODULUS) >= 0) { + mpz_sub(result, result, MODULUS); + } + } + + PyObject *result_obj = convert_mpz_to_bytes(result); + mpz_clear(result); + return result_obj; +} + +static PyObject * +reverse_mimc(PyObject *_self, PyObject *args) +{ + mpz_t result; + unsigned int steps; + + mpz_init(result); + if (!unpack_args(args, result, &steps)) { + mpz_clear(result); + return NULL; + } + + for (unsigned int i = steps - 1; i > 0; --i) { + mpz_sub(result, result, ROUND_CONSTANTS[i % ROUND_CONSTANTS_COUNT]); + mpz_powm(result, result, LITTLE_FERMAT_EXPT, MODULUS); + } + + PyObject *result_obj = convert_mpz_to_bytes(result); + mpz_clear(result); + return result_obj; +} + +static PyMethodDef Methods[] = { + {"forward_mimc", forward_mimc, METH_VARARGS, + "Run MIMC forward (the fast direction)"}, + {"reverse_mimc", reverse_mimc, METH_VARARGS, + "Run MIMC in reverse (the slow direction)"}, + {NULL, NULL, 0, NULL} +}; + +static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + "native", + "Fast native implementation of MIMC.", + -1, + Methods, + NULL +}; + +PyMODINIT_FUNC +PyInit_native() +{ + mimc_init_constants(); + + return PyModule_Create(&moduledef); +} diff --git a/setup.py b/setup.py index 5ec8da6..dbe9fab 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup, find_packages +from setuptools import setup, find_packages, Extension setup(name='mimcvdf', version='1.1.0', @@ -8,6 +8,12 @@ setup(name='mimcvdf', url='https://www.chaoswebs.net/', packages=find_packages(exclude=['contrib', 'docs', 'tests']), install_requires=[], + ext_package="mimcvdf", + ext_modules=[ + Extension('mimc.native', ['mimcvdf/mimc/native.c'], + libraries=['gmp'], optional=True, + extra_compile_args=['-O2']) + ], classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", diff --git a/tests/test_mimc.py b/tests/test_mimc.py index a076525..8edc185 100644 --- a/tests/test_mimc.py +++ b/tests/test_mimc.py @@ -14,14 +14,23 @@ class TestMimc(unittest.TestCase): data = b"a" * 6000000 h = sha3_256() h.update(data) - data = int(h.hexdigest(), 16) + data = h.digest() forw = mimcvdf.forward_mimc(data, 2000) rev = mimcvdf.reverse_mimc(forw, 2000) - print(data) - print(forw, rev) + print(data.hex()) + print(forw.hex(), rev.hex()) self.assertEqual(rev, data) -unittest.main() \ No newline at end of file + def test_expected_data(self): + h = sha3_256() + h.update(b"test") + data = h.digest() + self.assertEqual( + "66ea2a863bd103f2c7f190503cf8456198f31660069d4903afbd5f2e40a28695", + mimcvdf.forward_mimc(data, 2000).hex() + ) + +unittest.main() diff --git a/tests/test_vdf.py b/tests/test_vdf.py index db4cdff..ea593c1 100644 --- a/tests/test_vdf.py +++ b/tests/test_vdf.py @@ -3,6 +3,7 @@ import os sys.path.append('..') import unittest from time import time +from hashlib import sha3_256 import mimcvdf @@ -38,5 +39,11 @@ class TestVDF(unittest.TestCase): h = mimcvdf.vdf_create(b"test", dec=True) self.assertTrue(mimcvdf.vdf_verify(b"test", h)) + def test_hash_starting_with_zero(self): + s = b"test vector 1097" + self.assertEqual(0, sha3_256(s).digest()[0]) + h = mimcvdf.vdf_create(s) + self.assertTrue(mimcvdf.vdf_verify(s, h)) -unittest.main() \ No newline at end of file + +unittest.main()