C implementation of MIMC using GMP (~10x faster than Python)
This commit is contained in:
parent
9498f7aa0f
commit
505fc96c2c
5
.gitignore
vendored
5
.gitignore
vendored
@ -1,3 +1,6 @@
|
|||||||
venv/*
|
venv/*
|
||||||
.vscode/*
|
.vscode/*
|
||||||
.mypy_cache/*
|
.mypy_cache/*
|
||||||
|
build/*
|
||||||
|
dist/*
|
||||||
|
mimcvdf.egg-info/*
|
||||||
|
@ -25,18 +25,18 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|||||||
DEFAULT_ROUNDS = 8000
|
DEFAULT_ROUNDS = 8000
|
||||||
|
|
||||||
|
|
||||||
def _sha3_256_hash(data: bytes) -> int:
|
def _sha3_256_hash(data: bytes) -> bytes:
|
||||||
sha3 = sha3_256()
|
sha3 = sha3_256()
|
||||||
sha3.update(data)
|
sha3.update(data)
|
||||||
return int.from_bytes(sha3.digest(), byteorder='big')
|
return sha3.digest()
|
||||||
|
|
||||||
|
|
||||||
def vdf_create(data: bytes, rounds: int = DEFAULT_ROUNDS, dec=False) -> str:
|
def vdf_create(data: bytes, rounds: int = DEFAULT_ROUNDS, dec=False) -> str:
|
||||||
assert rounds > 1
|
assert rounds > 1
|
||||||
input_data: int = _sha3_256_hash(data)
|
input_data: int = _sha3_256_hash(data)
|
||||||
if not dec:
|
if dec:
|
||||||
return hex(reverse_mimc(input_data, rounds)).replace('0x', '')
|
return int.from_bytes(reverse_mimc(input_data, rounds), "big")
|
||||||
return reverse_mimc(input_data, rounds)
|
return reverse_mimc(input_data, rounds).hex()
|
||||||
|
|
||||||
|
|
||||||
def vdf_verify(
|
def vdf_verify(
|
||||||
@ -45,11 +45,14 @@ def vdf_verify(
|
|||||||
rounds: int = DEFAULT_ROUNDS) -> bool:
|
rounds: int = DEFAULT_ROUNDS) -> bool:
|
||||||
"""Verify data for test_hash generated by vdf_create."""
|
"""Verify data for test_hash generated by vdf_create."""
|
||||||
assert rounds > 1
|
assert rounds > 1
|
||||||
should_match = _sha3_256_hash(data)
|
should_match = _sha3_256_hash(data).lstrip(b'\0')
|
||||||
try:
|
if isinstance(test_hash, int):
|
||||||
test_hash = int(test_hash, 16)
|
test_hash = test_hash.to_bytes((test_hash.bit_length() + 7) // 8, "big")
|
||||||
except TypeError:
|
else:
|
||||||
pass
|
try:
|
||||||
|
test_hash = bytes.fromhex(test_hash)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
return forward_mimc(test_hash, rounds) == should_match
|
return forward_mimc(test_hash, rounds) == should_match
|
||||||
|
|
||||||
|
|
||||||
|
8
mimcvdf/mimc/__init__.py
Normal file
8
mimcvdf/mimc/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
"""Mimc hash function."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .native import forward_mimc, reverse_mimc
|
||||||
|
is_fast = True
|
||||||
|
except ImportError:
|
||||||
|
from .fallback import forward_mimc, reverse_mimc
|
||||||
|
is_fast = False
|
@ -1,4 +1,4 @@
|
|||||||
"""Mimc hash function."""
|
"""Slow fallback implementation of MIMC written in pure Python."""
|
||||||
"""
|
"""
|
||||||
This module adapted from https://github.com/OlegJakushkin/deepblockchains/blob/master/vdf/mimc/python/mimc.py by Sourabh Niyogi https://github.com/sourabhniyogi
|
This module adapted from https://github.com/OlegJakushkin/deepblockchains/blob/master/vdf/mimc/python/mimc.py by Sourabh Niyogi https://github.com/sourabhniyogi
|
||||||
|
|
||||||
@ -15,22 +15,25 @@ GNU General Public License for more details.
|
|||||||
You should have received a copy of the GNU General Public License
|
You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
modulus = 2**256 - 2**32 * 351 + 1
|
|
||||||
little_fermat_expt = (modulus*2-1)//3
|
is_fast = False
|
||||||
round_constants = [(i**7) ^ 42 for i in range(64)]
|
|
||||||
|
_modulus = 2**256 - 2**32 * 351 + 1
|
||||||
|
_little_fermat_expt = (_modulus*2-1)//3
|
||||||
|
_round_constants = [(i**7) ^ 42 for i in range(64)]
|
||||||
|
|
||||||
|
|
||||||
def forward_mimc(inp: int, steps: int) -> int:
|
def forward_mimc(input_data: bytes, steps: int) -> bytes:
|
||||||
|
inp = int.from_bytes(input_data, "big")
|
||||||
for i in range(1,steps):
|
for i in range(1,steps):
|
||||||
inp = (inp**3 + round_constants[i % len(round_constants)]) % modulus
|
inp = (inp**3 + _round_constants[i % len(_round_constants)]) % _modulus
|
||||||
return inp
|
return inp.to_bytes((inp.bit_length() + 7) // 8, "big")
|
||||||
|
|
||||||
|
|
||||||
def reverse_mimc(input_data: int, steps: int) -> int:
|
def reverse_mimc(input_data: bytes, steps: int) -> bytes:
|
||||||
rtrace = input_data
|
rtrace = int.from_bytes(input_data, "big")
|
||||||
|
|
||||||
for i in range(steps - 1, 0, -1):
|
for i in range(steps - 1, 0, -1):
|
||||||
rtrace = pow(rtrace-round_constants[i%len(round_constants)],
|
rtrace = pow(rtrace-_round_constants[i%len(_round_constants)],
|
||||||
little_fermat_expt, modulus)
|
_little_fermat_expt, _modulus)
|
||||||
return rtrace
|
return rtrace.to_bytes((rtrace.bit_length() + 7) // 8, "big")
|
||||||
|
|
157
mimcvdf/mimc/native.c
Normal file
157
mimcvdf/mimc/native.c
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
#define PY_SSIZE_T_CLEAN
|
||||||
|
#include <Python.h>
|
||||||
|
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include <gmp.h>
|
||||||
|
|
||||||
|
/*
|
||||||
|
This module adapted from https://github.com/OlegJakushkin/deepblockchains/blob/master/vdf/mimc/python/mimc.py by Sourabh Niyogi https://github.com/sourabhniyogi
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License
|
||||||
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
static mpz_t MODULUS;
|
||||||
|
static mpz_t LITTLE_FERMAT_EXPT;
|
||||||
|
#define ROUND_CONSTANTS_COUNT 64
|
||||||
|
static mpz_t ROUND_CONSTANTS[ROUND_CONSTANTS_COUNT];
|
||||||
|
|
||||||
|
static void
|
||||||
|
mimc_init_constants()
|
||||||
|
{
|
||||||
|
mpz_t fortytwo;
|
||||||
|
|
||||||
|
mpz_init(fortytwo);
|
||||||
|
mpz_set_ui(fortytwo, 42);
|
||||||
|
|
||||||
|
// Set MODULUS to hex(2**256 - 2**32 * 351 + 1)
|
||||||
|
mpz_init(MODULUS);
|
||||||
|
mpz_set_str(MODULUS,
|
||||||
|
"ffffffffffffffffffffffffffffffff"
|
||||||
|
"fffffffffffffffffffffea100000001",
|
||||||
|
16);
|
||||||
|
|
||||||
|
// Set LITTLE_FERMAT_EXPT to hex((MODULUS * 2 - 1) // 3)
|
||||||
|
mpz_init(LITTLE_FERMAT_EXPT);
|
||||||
|
mpz_set_str(LITTLE_FERMAT_EXPT,
|
||||||
|
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||||
|
"aaaaaaaaaaaaaaaaaaaaa9c0aaaaaaab",
|
||||||
|
16);
|
||||||
|
|
||||||
|
// Set all the constants to (i**7 ^ 42)
|
||||||
|
for (unsigned long int i = 0; i < ROUND_CONSTANTS_COUNT; i++) {
|
||||||
|
mpz_init(ROUND_CONSTANTS[i]);
|
||||||
|
mpz_ui_pow_ui(ROUND_CONSTANTS[i], i, 7);
|
||||||
|
mpz_xor(ROUND_CONSTANTS[i], ROUND_CONSTANTS[i], fortytwo);
|
||||||
|
}
|
||||||
|
|
||||||
|
mpz_clear(fortytwo);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool
|
||||||
|
unpack_args(PyObject *args, mpz_t input, unsigned int *steps)
|
||||||
|
{
|
||||||
|
const char *data_bytes;
|
||||||
|
Py_ssize_t count;
|
||||||
|
|
||||||
|
if (!PyArg_ParseTuple(args, "y#I", &data_bytes, &count, steps)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
mpz_import(input, count, 1, 1, 0, 0, data_bytes);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static PyObject *
|
||||||
|
convert_mpz_to_bytes(mpz_t op)
|
||||||
|
{
|
||||||
|
char string[34];
|
||||||
|
size_t size;
|
||||||
|
|
||||||
|
mpz_export(string, &size, 1, 1, 0, 0, op);
|
||||||
|
PyObject *result = PyBytes_FromStringAndSize(string, size);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
static PyObject *
|
||||||
|
forward_mimc(PyObject *_self, PyObject *args)
|
||||||
|
{
|
||||||
|
mpz_t result;
|
||||||
|
unsigned int steps;
|
||||||
|
|
||||||
|
mpz_init(result);
|
||||||
|
if (!unpack_args(args, result, &steps)) {
|
||||||
|
mpz_clear(result);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned int i = 1; i < steps; ++i) {
|
||||||
|
mpz_powm_ui(result, result, 3, MODULUS);
|
||||||
|
mpz_add(result, result, ROUND_CONSTANTS[i % ROUND_CONSTANTS_COUNT]);
|
||||||
|
if (mpz_cmp(result, MODULUS) >= 0) {
|
||||||
|
mpz_sub(result, result, MODULUS);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PyObject *result_obj = convert_mpz_to_bytes(result);
|
||||||
|
mpz_clear(result);
|
||||||
|
return result_obj;
|
||||||
|
}
|
||||||
|
|
||||||
|
static PyObject *
|
||||||
|
reverse_mimc(PyObject *_self, PyObject *args)
|
||||||
|
{
|
||||||
|
mpz_t result;
|
||||||
|
unsigned int steps;
|
||||||
|
|
||||||
|
mpz_init(result);
|
||||||
|
if (!unpack_args(args, result, &steps)) {
|
||||||
|
mpz_clear(result);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned int i = steps - 1; i > 0; --i) {
|
||||||
|
mpz_sub(result, result, ROUND_CONSTANTS[i % ROUND_CONSTANTS_COUNT]);
|
||||||
|
mpz_powm(result, result, LITTLE_FERMAT_EXPT, MODULUS);
|
||||||
|
}
|
||||||
|
|
||||||
|
PyObject *result_obj = convert_mpz_to_bytes(result);
|
||||||
|
mpz_clear(result);
|
||||||
|
return result_obj;
|
||||||
|
}
|
||||||
|
|
||||||
|
static PyMethodDef Methods[] = {
|
||||||
|
{"forward_mimc", forward_mimc, METH_VARARGS,
|
||||||
|
"Run MIMC forward (the fast direction)"},
|
||||||
|
{"reverse_mimc", reverse_mimc, METH_VARARGS,
|
||||||
|
"Run MIMC in reverse (the slow direction)"},
|
||||||
|
{NULL, NULL, 0, NULL}
|
||||||
|
};
|
||||||
|
|
||||||
|
static struct PyModuleDef moduledef = {
|
||||||
|
PyModuleDef_HEAD_INIT,
|
||||||
|
"native",
|
||||||
|
"Fast native implementation of MIMC.",
|
||||||
|
-1,
|
||||||
|
Methods,
|
||||||
|
NULL
|
||||||
|
};
|
||||||
|
|
||||||
|
PyMODINIT_FUNC
|
||||||
|
PyInit_native()
|
||||||
|
{
|
||||||
|
mimc_init_constants();
|
||||||
|
|
||||||
|
return PyModule_Create(&moduledef);
|
||||||
|
}
|
8
setup.py
8
setup.py
@ -1,4 +1,4 @@
|
|||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages, Extension
|
||||||
|
|
||||||
setup(name='mimcvdf',
|
setup(name='mimcvdf',
|
||||||
version='1.1.0',
|
version='1.1.0',
|
||||||
@ -8,6 +8,12 @@ setup(name='mimcvdf',
|
|||||||
url='https://www.chaoswebs.net/',
|
url='https://www.chaoswebs.net/',
|
||||||
packages=find_packages(exclude=['contrib', 'docs', 'tests']),
|
packages=find_packages(exclude=['contrib', 'docs', 'tests']),
|
||||||
install_requires=[],
|
install_requires=[],
|
||||||
|
ext_package="mimcvdf",
|
||||||
|
ext_modules=[
|
||||||
|
Extension('mimc.native', ['mimcvdf/mimc/native.c'],
|
||||||
|
libraries=['gmp'], optional=True,
|
||||||
|
extra_compile_args=['-O2'])
|
||||||
|
],
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
|
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
|
||||||
|
@ -14,14 +14,23 @@ class TestMimc(unittest.TestCase):
|
|||||||
data = b"a" * 6000000
|
data = b"a" * 6000000
|
||||||
h = sha3_256()
|
h = sha3_256()
|
||||||
h.update(data)
|
h.update(data)
|
||||||
data = int(h.hexdigest(), 16)
|
data = h.digest()
|
||||||
|
|
||||||
forw = mimcvdf.forward_mimc(data, 2000)
|
forw = mimcvdf.forward_mimc(data, 2000)
|
||||||
|
|
||||||
rev = mimcvdf.reverse_mimc(forw, 2000)
|
rev = mimcvdf.reverse_mimc(forw, 2000)
|
||||||
|
|
||||||
print(data)
|
print(data.hex())
|
||||||
print(forw, rev)
|
print(forw.hex(), rev.hex())
|
||||||
self.assertEqual(rev, data)
|
self.assertEqual(rev, data)
|
||||||
|
|
||||||
unittest.main()
|
def test_expected_data(self):
|
||||||
|
h = sha3_256()
|
||||||
|
h.update(b"test")
|
||||||
|
data = h.digest()
|
||||||
|
self.assertEqual(
|
||||||
|
"66ea2a863bd103f2c7f190503cf8456198f31660069d4903afbd5f2e40a28695",
|
||||||
|
mimcvdf.forward_mimc(data, 2000).hex()
|
||||||
|
)
|
||||||
|
|
||||||
|
unittest.main()
|
||||||
|
@ -3,6 +3,7 @@ import os
|
|||||||
sys.path.append('..')
|
sys.path.append('..')
|
||||||
import unittest
|
import unittest
|
||||||
from time import time
|
from time import time
|
||||||
|
from hashlib import sha3_256
|
||||||
|
|
||||||
import mimcvdf
|
import mimcvdf
|
||||||
|
|
||||||
@ -38,5 +39,11 @@ class TestVDF(unittest.TestCase):
|
|||||||
h = mimcvdf.vdf_create(b"test", dec=True)
|
h = mimcvdf.vdf_create(b"test", dec=True)
|
||||||
self.assertTrue(mimcvdf.vdf_verify(b"test", h))
|
self.assertTrue(mimcvdf.vdf_verify(b"test", h))
|
||||||
|
|
||||||
|
def test_hash_starting_with_zero(self):
|
||||||
|
s = b"test vector 1097"
|
||||||
|
self.assertEqual(0, sha3_256(s).digest()[0])
|
||||||
|
h = mimcvdf.vdf_create(s)
|
||||||
|
self.assertTrue(mimcvdf.vdf_verify(s, h))
|
||||||
|
|
||||||
unittest.main()
|
|
||||||
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user