C implementation of MIMC using GMP (~10x faster than Python)

This commit is contained in:
Carter Sande 2021-01-17 17:25:03 -08:00 committed by Kevin Froman
parent 9498f7aa0f
commit 505fc96c2c
8 changed files with 226 additions and 30 deletions

5
.gitignore vendored
View File

@ -1,3 +1,6 @@
venv/* venv/*
.vscode/* .vscode/*
.mypy_cache/* .mypy_cache/*
build/*
dist/*
mimcvdf.egg-info/*

View File

@ -25,18 +25,18 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
DEFAULT_ROUNDS = 8000 DEFAULT_ROUNDS = 8000
def _sha3_256_hash(data: bytes) -> int: def _sha3_256_hash(data: bytes) -> bytes:
sha3 = sha3_256() sha3 = sha3_256()
sha3.update(data) sha3.update(data)
return int.from_bytes(sha3.digest(), byteorder='big') return sha3.digest()
def vdf_create(data: bytes, rounds: int = DEFAULT_ROUNDS, dec=False) -> str: def vdf_create(data: bytes, rounds: int = DEFAULT_ROUNDS, dec=False) -> str:
assert rounds > 1 assert rounds > 1
input_data: int = _sha3_256_hash(data) input_data: int = _sha3_256_hash(data)
if not dec: if dec:
return hex(reverse_mimc(input_data, rounds)).replace('0x', '') return int.from_bytes(reverse_mimc(input_data, rounds), "big")
return reverse_mimc(input_data, rounds) return reverse_mimc(input_data, rounds).hex()
def vdf_verify( def vdf_verify(
@ -45,11 +45,14 @@ def vdf_verify(
rounds: int = DEFAULT_ROUNDS) -> bool: rounds: int = DEFAULT_ROUNDS) -> bool:
"""Verify data for test_hash generated by vdf_create.""" """Verify data for test_hash generated by vdf_create."""
assert rounds > 1 assert rounds > 1
should_match = _sha3_256_hash(data) should_match = _sha3_256_hash(data).lstrip(b'\0')
try: if isinstance(test_hash, int):
test_hash = int(test_hash, 16) test_hash = test_hash.to_bytes((test_hash.bit_length() + 7) // 8, "big")
except TypeError: else:
pass try:
test_hash = bytes.fromhex(test_hash)
except ValueError:
return False
return forward_mimc(test_hash, rounds) == should_match return forward_mimc(test_hash, rounds) == should_match

8
mimcvdf/mimc/__init__.py Normal file
View File

@ -0,0 +1,8 @@
"""Mimc hash function."""
try:
from .native import forward_mimc, reverse_mimc
is_fast = True
except ImportError:
from .fallback import forward_mimc, reverse_mimc
is_fast = False

View File

@ -1,4 +1,4 @@
"""Mimc hash function.""" """Slow fallback implementation of MIMC written in pure Python."""
""" """
This module adapted from https://github.com/OlegJakushkin/deepblockchains/blob/master/vdf/mimc/python/mimc.py by Sourabh Niyogi https://github.com/sourabhniyogi This module adapted from https://github.com/OlegJakushkin/deepblockchains/blob/master/vdf/mimc/python/mimc.py by Sourabh Niyogi https://github.com/sourabhniyogi
@ -15,22 +15,25 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
""" """
modulus = 2**256 - 2**32 * 351 + 1
little_fermat_expt = (modulus*2-1)//3 is_fast = False
round_constants = [(i**7) ^ 42 for i in range(64)]
_modulus = 2**256 - 2**32 * 351 + 1
_little_fermat_expt = (_modulus*2-1)//3
_round_constants = [(i**7) ^ 42 for i in range(64)]
def forward_mimc(inp: int, steps: int) -> int: def forward_mimc(input_data: bytes, steps: int) -> bytes:
inp = int.from_bytes(input_data, "big")
for i in range(1,steps): for i in range(1,steps):
inp = (inp**3 + round_constants[i % len(round_constants)]) % modulus inp = (inp**3 + _round_constants[i % len(_round_constants)]) % _modulus
return inp return inp.to_bytes((inp.bit_length() + 7) // 8, "big")
def reverse_mimc(input_data: int, steps: int) -> int: def reverse_mimc(input_data: bytes, steps: int) -> bytes:
rtrace = input_data rtrace = int.from_bytes(input_data, "big")
for i in range(steps - 1, 0, -1): for i in range(steps - 1, 0, -1):
rtrace = pow(rtrace-round_constants[i%len(round_constants)], rtrace = pow(rtrace-_round_constants[i%len(_round_constants)],
little_fermat_expt, modulus) _little_fermat_expt, _modulus)
return rtrace return rtrace.to_bytes((rtrace.bit_length() + 7) // 8, "big")

157
mimcvdf/mimc/native.c Normal file
View File

@ -0,0 +1,157 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <stdbool.h>
#include <gmp.h>
/*
This module adapted from https://github.com/OlegJakushkin/deepblockchains/blob/master/vdf/mimc/python/mimc.py by Sourabh Niyogi https://github.com/sourabhniyogi
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
static mpz_t MODULUS;
static mpz_t LITTLE_FERMAT_EXPT;
#define ROUND_CONSTANTS_COUNT 64
static mpz_t ROUND_CONSTANTS[ROUND_CONSTANTS_COUNT];
static void
mimc_init_constants()
{
mpz_t fortytwo;
mpz_init(fortytwo);
mpz_set_ui(fortytwo, 42);
// Set MODULUS to hex(2**256 - 2**32 * 351 + 1)
mpz_init(MODULUS);
mpz_set_str(MODULUS,
"ffffffffffffffffffffffffffffffff"
"fffffffffffffffffffffea100000001",
16);
// Set LITTLE_FERMAT_EXPT to hex((MODULUS * 2 - 1) // 3)
mpz_init(LITTLE_FERMAT_EXPT);
mpz_set_str(LITTLE_FERMAT_EXPT,
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
"aaaaaaaaaaaaaaaaaaaaa9c0aaaaaaab",
16);
// Set all the constants to (i**7 ^ 42)
for (unsigned long int i = 0; i < ROUND_CONSTANTS_COUNT; i++) {
mpz_init(ROUND_CONSTANTS[i]);
mpz_ui_pow_ui(ROUND_CONSTANTS[i], i, 7);
mpz_xor(ROUND_CONSTANTS[i], ROUND_CONSTANTS[i], fortytwo);
}
mpz_clear(fortytwo);
}
static bool
unpack_args(PyObject *args, mpz_t input, unsigned int *steps)
{
const char *data_bytes;
Py_ssize_t count;
if (!PyArg_ParseTuple(args, "y#I", &data_bytes, &count, steps)) {
return false;
}
mpz_import(input, count, 1, 1, 0, 0, data_bytes);
return true;
}
static PyObject *
convert_mpz_to_bytes(mpz_t op)
{
char string[34];
size_t size;
mpz_export(string, &size, 1, 1, 0, 0, op);
PyObject *result = PyBytes_FromStringAndSize(string, size);
return result;
}
static PyObject *
forward_mimc(PyObject *_self, PyObject *args)
{
mpz_t result;
unsigned int steps;
mpz_init(result);
if (!unpack_args(args, result, &steps)) {
mpz_clear(result);
return NULL;
}
for (unsigned int i = 1; i < steps; ++i) {
mpz_powm_ui(result, result, 3, MODULUS);
mpz_add(result, result, ROUND_CONSTANTS[i % ROUND_CONSTANTS_COUNT]);
if (mpz_cmp(result, MODULUS) >= 0) {
mpz_sub(result, result, MODULUS);
}
}
PyObject *result_obj = convert_mpz_to_bytes(result);
mpz_clear(result);
return result_obj;
}
static PyObject *
reverse_mimc(PyObject *_self, PyObject *args)
{
mpz_t result;
unsigned int steps;
mpz_init(result);
if (!unpack_args(args, result, &steps)) {
mpz_clear(result);
return NULL;
}
for (unsigned int i = steps - 1; i > 0; --i) {
mpz_sub(result, result, ROUND_CONSTANTS[i % ROUND_CONSTANTS_COUNT]);
mpz_powm(result, result, LITTLE_FERMAT_EXPT, MODULUS);
}
PyObject *result_obj = convert_mpz_to_bytes(result);
mpz_clear(result);
return result_obj;
}
static PyMethodDef Methods[] = {
{"forward_mimc", forward_mimc, METH_VARARGS,
"Run MIMC forward (the fast direction)"},
{"reverse_mimc", reverse_mimc, METH_VARARGS,
"Run MIMC in reverse (the slow direction)"},
{NULL, NULL, 0, NULL}
};
static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
"native",
"Fast native implementation of MIMC.",
-1,
Methods,
NULL
};
PyMODINIT_FUNC
PyInit_native()
{
mimc_init_constants();
return PyModule_Create(&moduledef);
}

View File

@ -1,4 +1,4 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages, Extension
setup(name='mimcvdf', setup(name='mimcvdf',
version='1.1.0', version='1.1.0',
@ -8,6 +8,12 @@ setup(name='mimcvdf',
url='https://www.chaoswebs.net/', url='https://www.chaoswebs.net/',
packages=find_packages(exclude=['contrib', 'docs', 'tests']), packages=find_packages(exclude=['contrib', 'docs', 'tests']),
install_requires=[], install_requires=[],
ext_package="mimcvdf",
ext_modules=[
Extension('mimc.native', ['mimcvdf/mimc/native.c'],
libraries=['gmp'], optional=True,
extra_compile_args=['-O2'])
],
classifiers=[ classifiers=[
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)", "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",

View File

@ -14,14 +14,23 @@ class TestMimc(unittest.TestCase):
data = b"a" * 6000000 data = b"a" * 6000000
h = sha3_256() h = sha3_256()
h.update(data) h.update(data)
data = int(h.hexdigest(), 16) data = h.digest()
forw = mimcvdf.forward_mimc(data, 2000) forw = mimcvdf.forward_mimc(data, 2000)
rev = mimcvdf.reverse_mimc(forw, 2000) rev = mimcvdf.reverse_mimc(forw, 2000)
print(data) print(data.hex())
print(forw, rev) print(forw.hex(), rev.hex())
self.assertEqual(rev, data) self.assertEqual(rev, data)
unittest.main() def test_expected_data(self):
h = sha3_256()
h.update(b"test")
data = h.digest()
self.assertEqual(
"66ea2a863bd103f2c7f190503cf8456198f31660069d4903afbd5f2e40a28695",
mimcvdf.forward_mimc(data, 2000).hex()
)
unittest.main()

View File

@ -3,6 +3,7 @@ import os
sys.path.append('..') sys.path.append('..')
import unittest import unittest
from time import time from time import time
from hashlib import sha3_256
import mimcvdf import mimcvdf
@ -38,5 +39,11 @@ class TestVDF(unittest.TestCase):
h = mimcvdf.vdf_create(b"test", dec=True) h = mimcvdf.vdf_create(b"test", dec=True)
self.assertTrue(mimcvdf.vdf_verify(b"test", h)) self.assertTrue(mimcvdf.vdf_verify(b"test", h))
def test_hash_starting_with_zero(self):
s = b"test vector 1097"
self.assertEqual(0, sha3_256(s).digest()[0])
h = mimcvdf.vdf_create(s)
self.assertTrue(mimcvdf.vdf_verify(s, h))
unittest.main()
unittest.main()