C implementation of MIMC using GMP (~10x faster than Python)
This commit is contained in:
parent
9498f7aa0f
commit
505fc96c2c
5
.gitignore
vendored
5
.gitignore
vendored
@ -1,3 +1,6 @@
|
||||
venv/*
|
||||
.vscode/*
|
||||
.mypy_cache/*
|
||||
.mypy_cache/*
|
||||
build/*
|
||||
dist/*
|
||||
mimcvdf.egg-info/*
|
||||
|
@ -25,18 +25,18 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
DEFAULT_ROUNDS = 8000
|
||||
|
||||
|
||||
def _sha3_256_hash(data: bytes) -> int:
|
||||
def _sha3_256_hash(data: bytes) -> bytes:
|
||||
sha3 = sha3_256()
|
||||
sha3.update(data)
|
||||
return int.from_bytes(sha3.digest(), byteorder='big')
|
||||
return sha3.digest()
|
||||
|
||||
|
||||
def vdf_create(data: bytes, rounds: int = DEFAULT_ROUNDS, dec=False) -> str:
|
||||
assert rounds > 1
|
||||
input_data: int = _sha3_256_hash(data)
|
||||
if not dec:
|
||||
return hex(reverse_mimc(input_data, rounds)).replace('0x', '')
|
||||
return reverse_mimc(input_data, rounds)
|
||||
if dec:
|
||||
return int.from_bytes(reverse_mimc(input_data, rounds), "big")
|
||||
return reverse_mimc(input_data, rounds).hex()
|
||||
|
||||
|
||||
def vdf_verify(
|
||||
@ -45,11 +45,14 @@ def vdf_verify(
|
||||
rounds: int = DEFAULT_ROUNDS) -> bool:
|
||||
"""Verify data for test_hash generated by vdf_create."""
|
||||
assert rounds > 1
|
||||
should_match = _sha3_256_hash(data)
|
||||
try:
|
||||
test_hash = int(test_hash, 16)
|
||||
except TypeError:
|
||||
pass
|
||||
should_match = _sha3_256_hash(data).lstrip(b'\0')
|
||||
if isinstance(test_hash, int):
|
||||
test_hash = test_hash.to_bytes((test_hash.bit_length() + 7) // 8, "big")
|
||||
else:
|
||||
try:
|
||||
test_hash = bytes.fromhex(test_hash)
|
||||
except ValueError:
|
||||
return False
|
||||
return forward_mimc(test_hash, rounds) == should_match
|
||||
|
||||
|
||||
|
8
mimcvdf/mimc/__init__.py
Normal file
8
mimcvdf/mimc/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
"""Mimc hash function."""
|
||||
|
||||
try:
|
||||
from .native import forward_mimc, reverse_mimc
|
||||
is_fast = True
|
||||
except ImportError:
|
||||
from .fallback import forward_mimc, reverse_mimc
|
||||
is_fast = False
|
@ -1,4 +1,4 @@
|
||||
"""Mimc hash function."""
|
||||
"""Slow fallback implementation of MIMC written in pure Python."""
|
||||
"""
|
||||
This module adapted from https://github.com/OlegJakushkin/deepblockchains/blob/master/vdf/mimc/python/mimc.py by Sourabh Niyogi https://github.com/sourabhniyogi
|
||||
|
||||
@ -15,22 +15,25 @@ GNU General Public License for more details.
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
modulus = 2**256 - 2**32 * 351 + 1
|
||||
little_fermat_expt = (modulus*2-1)//3
|
||||
round_constants = [(i**7) ^ 42 for i in range(64)]
|
||||
|
||||
is_fast = False
|
||||
|
||||
_modulus = 2**256 - 2**32 * 351 + 1
|
||||
_little_fermat_expt = (_modulus*2-1)//3
|
||||
_round_constants = [(i**7) ^ 42 for i in range(64)]
|
||||
|
||||
|
||||
def forward_mimc(inp: int, steps: int) -> int:
|
||||
def forward_mimc(input_data: bytes, steps: int) -> bytes:
|
||||
inp = int.from_bytes(input_data, "big")
|
||||
for i in range(1,steps):
|
||||
inp = (inp**3 + round_constants[i % len(round_constants)]) % modulus
|
||||
return inp
|
||||
inp = (inp**3 + _round_constants[i % len(_round_constants)]) % _modulus
|
||||
return inp.to_bytes((inp.bit_length() + 7) // 8, "big")
|
||||
|
||||
|
||||
def reverse_mimc(input_data: int, steps: int) -> int:
|
||||
rtrace = input_data
|
||||
def reverse_mimc(input_data: bytes, steps: int) -> bytes:
|
||||
rtrace = int.from_bytes(input_data, "big")
|
||||
|
||||
for i in range(steps - 1, 0, -1):
|
||||
rtrace = pow(rtrace-round_constants[i%len(round_constants)],
|
||||
little_fermat_expt, modulus)
|
||||
return rtrace
|
||||
|
||||
rtrace = pow(rtrace-_round_constants[i%len(_round_constants)],
|
||||
_little_fermat_expt, _modulus)
|
||||
return rtrace.to_bytes((rtrace.bit_length() + 7) // 8, "big")
|
157
mimcvdf/mimc/native.c
Normal file
157
mimcvdf/mimc/native.c
Normal file
@ -0,0 +1,157 @@
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <gmp.h>
|
||||
|
||||
/*
|
||||
This module adapted from https://github.com/OlegJakushkin/deepblockchains/blob/master/vdf/mimc/python/mimc.py by Sourabh Niyogi https://github.com/sourabhniyogi
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
static mpz_t MODULUS;
|
||||
static mpz_t LITTLE_FERMAT_EXPT;
|
||||
#define ROUND_CONSTANTS_COUNT 64
|
||||
static mpz_t ROUND_CONSTANTS[ROUND_CONSTANTS_COUNT];
|
||||
|
||||
static void
|
||||
mimc_init_constants()
|
||||
{
|
||||
mpz_t fortytwo;
|
||||
|
||||
mpz_init(fortytwo);
|
||||
mpz_set_ui(fortytwo, 42);
|
||||
|
||||
// Set MODULUS to hex(2**256 - 2**32 * 351 + 1)
|
||||
mpz_init(MODULUS);
|
||||
mpz_set_str(MODULUS,
|
||||
"ffffffffffffffffffffffffffffffff"
|
||||
"fffffffffffffffffffffea100000001",
|
||||
16);
|
||||
|
||||
// Set LITTLE_FERMAT_EXPT to hex((MODULUS * 2 - 1) // 3)
|
||||
mpz_init(LITTLE_FERMAT_EXPT);
|
||||
mpz_set_str(LITTLE_FERMAT_EXPT,
|
||||
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
"aaaaaaaaaaaaaaaaaaaaa9c0aaaaaaab",
|
||||
16);
|
||||
|
||||
// Set all the constants to (i**7 ^ 42)
|
||||
for (unsigned long int i = 0; i < ROUND_CONSTANTS_COUNT; i++) {
|
||||
mpz_init(ROUND_CONSTANTS[i]);
|
||||
mpz_ui_pow_ui(ROUND_CONSTANTS[i], i, 7);
|
||||
mpz_xor(ROUND_CONSTANTS[i], ROUND_CONSTANTS[i], fortytwo);
|
||||
}
|
||||
|
||||
mpz_clear(fortytwo);
|
||||
}
|
||||
|
||||
static bool
|
||||
unpack_args(PyObject *args, mpz_t input, unsigned int *steps)
|
||||
{
|
||||
const char *data_bytes;
|
||||
Py_ssize_t count;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "y#I", &data_bytes, &count, steps)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
mpz_import(input, count, 1, 1, 0, 0, data_bytes);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
static PyObject *
|
||||
convert_mpz_to_bytes(mpz_t op)
|
||||
{
|
||||
char string[34];
|
||||
size_t size;
|
||||
|
||||
mpz_export(string, &size, 1, 1, 0, 0, op);
|
||||
PyObject *result = PyBytes_FromStringAndSize(string, size);
|
||||
return result;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
forward_mimc(PyObject *_self, PyObject *args)
|
||||
{
|
||||
mpz_t result;
|
||||
unsigned int steps;
|
||||
|
||||
mpz_init(result);
|
||||
if (!unpack_args(args, result, &steps)) {
|
||||
mpz_clear(result);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
for (unsigned int i = 1; i < steps; ++i) {
|
||||
mpz_powm_ui(result, result, 3, MODULUS);
|
||||
mpz_add(result, result, ROUND_CONSTANTS[i % ROUND_CONSTANTS_COUNT]);
|
||||
if (mpz_cmp(result, MODULUS) >= 0) {
|
||||
mpz_sub(result, result, MODULUS);
|
||||
}
|
||||
}
|
||||
|
||||
PyObject *result_obj = convert_mpz_to_bytes(result);
|
||||
mpz_clear(result);
|
||||
return result_obj;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
reverse_mimc(PyObject *_self, PyObject *args)
|
||||
{
|
||||
mpz_t result;
|
||||
unsigned int steps;
|
||||
|
||||
mpz_init(result);
|
||||
if (!unpack_args(args, result, &steps)) {
|
||||
mpz_clear(result);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
for (unsigned int i = steps - 1; i > 0; --i) {
|
||||
mpz_sub(result, result, ROUND_CONSTANTS[i % ROUND_CONSTANTS_COUNT]);
|
||||
mpz_powm(result, result, LITTLE_FERMAT_EXPT, MODULUS);
|
||||
}
|
||||
|
||||
PyObject *result_obj = convert_mpz_to_bytes(result);
|
||||
mpz_clear(result);
|
||||
return result_obj;
|
||||
}
|
||||
|
||||
static PyMethodDef Methods[] = {
|
||||
{"forward_mimc", forward_mimc, METH_VARARGS,
|
||||
"Run MIMC forward (the fast direction)"},
|
||||
{"reverse_mimc", reverse_mimc, METH_VARARGS,
|
||||
"Run MIMC in reverse (the slow direction)"},
|
||||
{NULL, NULL, 0, NULL}
|
||||
};
|
||||
|
||||
static struct PyModuleDef moduledef = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
"native",
|
||||
"Fast native implementation of MIMC.",
|
||||
-1,
|
||||
Methods,
|
||||
NULL
|
||||
};
|
||||
|
||||
PyMODINIT_FUNC
|
||||
PyInit_native()
|
||||
{
|
||||
mimc_init_constants();
|
||||
|
||||
return PyModule_Create(&moduledef);
|
||||
}
|
8
setup.py
8
setup.py
@ -1,4 +1,4 @@
|
||||
from setuptools import setup, find_packages
|
||||
from setuptools import setup, find_packages, Extension
|
||||
|
||||
setup(name='mimcvdf',
|
||||
version='1.1.0',
|
||||
@ -8,6 +8,12 @@ setup(name='mimcvdf',
|
||||
url='https://www.chaoswebs.net/',
|
||||
packages=find_packages(exclude=['contrib', 'docs', 'tests']),
|
||||
install_requires=[],
|
||||
ext_package="mimcvdf",
|
||||
ext_modules=[
|
||||
Extension('mimc.native', ['mimcvdf/mimc/native.c'],
|
||||
libraries=['gmp'], optional=True,
|
||||
extra_compile_args=['-O2'])
|
||||
],
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
|
||||
|
@ -14,14 +14,23 @@ class TestMimc(unittest.TestCase):
|
||||
data = b"a" * 6000000
|
||||
h = sha3_256()
|
||||
h.update(data)
|
||||
data = int(h.hexdigest(), 16)
|
||||
data = h.digest()
|
||||
|
||||
forw = mimcvdf.forward_mimc(data, 2000)
|
||||
|
||||
rev = mimcvdf.reverse_mimc(forw, 2000)
|
||||
|
||||
print(data)
|
||||
print(forw, rev)
|
||||
print(data.hex())
|
||||
print(forw.hex(), rev.hex())
|
||||
self.assertEqual(rev, data)
|
||||
|
||||
unittest.main()
|
||||
def test_expected_data(self):
|
||||
h = sha3_256()
|
||||
h.update(b"test")
|
||||
data = h.digest()
|
||||
self.assertEqual(
|
||||
"66ea2a863bd103f2c7f190503cf8456198f31660069d4903afbd5f2e40a28695",
|
||||
mimcvdf.forward_mimc(data, 2000).hex()
|
||||
)
|
||||
|
||||
unittest.main()
|
||||
|
@ -3,6 +3,7 @@ import os
|
||||
sys.path.append('..')
|
||||
import unittest
|
||||
from time import time
|
||||
from hashlib import sha3_256
|
||||
|
||||
import mimcvdf
|
||||
|
||||
@ -38,5 +39,11 @@ class TestVDF(unittest.TestCase):
|
||||
h = mimcvdf.vdf_create(b"test", dec=True)
|
||||
self.assertTrue(mimcvdf.vdf_verify(b"test", h))
|
||||
|
||||
def test_hash_starting_with_zero(self):
|
||||
s = b"test vector 1097"
|
||||
self.assertEqual(0, sha3_256(s).digest()[0])
|
||||
h = mimcvdf.vdf_create(s)
|
||||
self.assertTrue(mimcvdf.vdf_verify(s, h))
|
||||
|
||||
unittest.main()
|
||||
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user