diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 0a3e6e964945..840bab896a89 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -9,6 +9,7 @@ jobs: strategy: max-parallel: 4 matrix: + python-version: [3.7, 3.8] os: [ubuntu-latest, macOS-latest] steps: @@ -16,27 +17,32 @@ jobs: - name: Setup Python environment uses: actions/setup-python@v1.1.1 with: - python-version: 3.7 # optional, default is 3.x + python-version: ${{ matrix.python-version }} - name: Install dependencies - env: - CHIA_MACHINE_SSH_KEY: ${{ secrets.CHIA_MACHINE_SSH_KEY }} - GIT_SSH_COMMAND: "ssh -o StrictHostKeyChecking=no" run: | - eval "$(ssh-agent -s)" - ssh-add - <<< "${CHIA_MACHINE_SSH_KEY}" - git submodule update --init --recursive - brew update && brew install gmp || echo "" - python3 -m venv .venv + brew update && brew install gmp boost || echo "" + sh install.sh + - name: Test proof of space + run: | + cd lib/chiapos + mkdir -p build && cd build + cmake ../ + cmake --build . -- -j 6 + ./RunTests + cd ../../../ + - name: Test vdf + run: | . .venv/bin/activate - pip install -e . - pip install -r requirements.txt - - name: Lint with flake8 + cd lib/chiavdf/fast_vdf + python python_bindings/test_verifier.py + cd ../../../ + - name: Lint source with flake8 run: | ./.venv/bin/flake8 src - - name: Lint with mypy + - name: Lint source with mypy run: | ./.venv/bin/mypy src tests - - name: Test with pytest + - name: Test blockchain code with pytest run: | ./.venv/bin/py.test tests -s -v diff --git a/.gitignore b/.gitignore index 5c8c1ab02254..dbcbae704a78 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ mongod.log* fndb_test* blockchain_test* *.db +*.db-journal # Logs *.log diff --git a/README.md b/README.md index c3e4f4ca381d..0955368f8510 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ sh install.sh . .venv/bin/activate ``` -### CentOS 7 +### CentOS 7.7 ```bash sudo yum update @@ -63,9 +63,28 @@ git clone https://github.com/Chia-Network/chia-blockchain.git cd chia-blockchain sh install.sh + . .venv/bin/activate ``` +### RHEL 8.1 + +```bash +sudo yum update +sudo yum install gcc-c++ cmake3 git openssl openssl-devel +sudo yum install wget make libffi-devel gmp-devel sqlite-devel + +# Install Python 3.7.5 (current rpm's are 3.6.x) +wget https://www.python.org/ftp/python/3.7.5/Python-3.7.5.tgz +tar -zxvf Python-3.7.5.tgz; cd Python-3.7.5 +./configure --enable-optimizations; sudo make install; cd .. + +git clone https://github.com/Chia-Network/chia-blockchain.git +cd chia-blockchain +sh install.sh + +. .venv/bin/activate +``` ### Windows (WSL + Ubuntu) #### Install WSL + Ubuntu 18.04 LTS, upgrade to Ubuntu 19.x @@ -89,13 +108,14 @@ git clone https://github.com/Chia-Network/chia-blockchain.git cd chia-blockchain sudo sh install.sh + . .venv/bin/activate ``` #### Alternate method for Ubuntu 18.04 LTS In `./install.sh`: Change `python3` to `python3.7` -Each line that starts with `pip ...` becomes `python -m pip ...` +Each line that starts with `pip ...` becomes `python3.7 -m pip ...` ```bash sudo apt-get -y update @@ -129,14 +149,14 @@ sh install.sh ## Step 2: Install timelord (optional) Note: this step is needed only if you intend to run a timelord or a local simulation. -These assume you've already successfully installed harvester, farmer, plotting, and full node above. boost 1.67 or newer is required on all platforms. +These assume you've already successfully installed harvester, farmer, plotting, and full node above. 
boost 1.66 or newer is required on all platforms.

### Ubuntu/Debian
```bash
cd chia-blockchain
sh install_timelord.sh
```
-### Amazon Linux 2 and CentOS 7
+### Amazon Linux 2 and CentOS 7.7
```bash
#Only for Amazon Linux 2
sudo amazon-linux-extras install epel
@@ -149,12 +169,20 @@ tar -zxvf boost_1_72_0.tar.gz
cd boost_1_72_0
./bootstrap.sh --prefix=/usr/local
sudo ./b2 install --prefix=/usr/local --with=all; cd ..
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib

cd chia-blockchain
sh install_timelord.sh
```
+### RHEL 8.1
+```bash
+sudo yum install mpfr-devel boost boost-devel
+
+cd chia-blockchain
+sh install_timelord.sh
+```
### Windows (WSL + Ubuntu)
#### Install WSL + Ubuntu upgraded to 19.x
```bash
@@ -189,7 +217,9 @@ python -m scripts.regenerate_keys

## Step 4a: Run a full node

To run a full node on port 8444, and connect to the testnet, run the following command.
This will also start an ssh server on port 8222 for the UI, which you can connect to
-to see the state of the node.
+to see the state of the node. If you want to see log output on stdout, set the logging.log_stdout
+variable in ./config/config.yaml.
+
```bash
./scripts/run_full_node.sh
ssh -p 8222 localhost
@@ -201,7 +231,8 @@ Farmers are entities in the network who use their hard drive space to try to create
blocks (like Bitcoin's miners), and earn block rewards. First, you must generate some hard drive plots, which can
take a long time depending on the [size of the plots](https://github.com/Chia-Network/chia-blockchain/wiki/k-sizes)
(the k variable). Then, run the farmer + full node with the following script. A full node is also started,
-which you can ssh into to view the node UI (previous ssh command).
+which you can ssh into to view the node UI (previous ssh command). You can also change the working directory and
+final directory for plotting, with the "-t" and "-d" arguments to the create_plots script.
```bash
python -m scripts.create_plots -k 20 -n 10
sh ./scripts/run_farming.sh
@@ -229,8 +260,6 @@ Due to the nature of proof of space lookups by the harvester in the current alpha
the number of plots on a physical drive to 50 or less. This limit should significantly increase before beta.

You can also run the simulation, which runs all servers and multiple full nodes, locally, at once.
-If you want to run the simulation, change the introducer ip in ./config/config.yaml so that the
-full node points to the local introducer (127.0.0.1:8445).

Note that the simulation is local only and requires installation of timelords and VDFs.
@@ -240,3 +269,15 @@ ips to external peers.
```bash
sh ./scripts/run_all_simulation.sh
```
+
+For increased networking performance, install uvloop:
+```bash
+pip install -e ".[uvloop]"
+```
+
+You can also use the [HTTP RPC](https://github.com/Chia-Network/chia-blockchain/wiki/Networking-and-Serialization#rpc) API to access information and control the full node:
+
+```bash
+curl -X POST http://localhost:8555/get_blockchain_state
+```
diff --git a/config/config.yaml b/config/config.yaml
index 43a15dfb4d5d..0f39eae3ebf2 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -2,6 +2,11 @@ network_id: testnet  # testnet/mainnet
# Send a ping to all peers after ping_interval seconds
ping_interval: 120

+# Controls logging of all servers (harvester, farmer, etc.). Each one can be overridden.
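+# The "&logging" anchor below is referenced as "*logging" by each service
+# section, so every service inherits these defaults; to override one, give
+# that service its own logging values instead of the alias.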
+logging: &logging + log_stdout: False # If True, outputs to stdout instead of a file + log_filename: "chia.log" + harvester: # The harvester server (if run) will run on this host and port host: 127.0.0.1 @@ -9,8 +14,9 @@ harvester: farmer_peer: host: 127.0.0.1 port: 8447 - # Location of all the plots, default ./plots + # Location of all the plots, default ./plots, for relative paths in plots.yaml. # plot_root: "/mnt/pos" + logging: *logging farmer: # The farmer server (if run) will run on this host and port @@ -26,7 +32,9 @@ farmer: # To send a share to a pool, a block must be faster than this, in seconds pool_share_threshold: 12000 # To send to the full node, a block must be faster than this, in seconds + propagate_threshold: 10000 + logging: *logging timelord: # The timelord server (if run) will run on this host and port @@ -49,11 +57,20 @@ timelord: full_node_peer: host: 127.0.0.1 port: 8444 + logging: *logging full_node: # The full node server (if run) will run on this host and port host: 127.0.0.1 port: 8444 + + # Run multiple nodes with different databases by changing the database_id + database_id: 1 + + # If True, starts an RPC server at the following port + start_rpc_server: True + rpc_port: 8555 + enable_upnp: True # Don't send any more than these number of headers and blocks, in one message max_headers_to_send: 25 @@ -63,8 +80,6 @@ full_node: # If node is more than these blocks behind, will do a sync sync_blocks_behind_threshold: 20 - # This SSH key is for the ui SSH server - ssh_filename: config/ssh_host_key # How often to connect to introducer if we need to learn more peers introducer_connect_interval: 500 # Continue trying to connect to more peers until this number of connections @@ -72,6 +87,9 @@ full_node: # Only connect to peers who we have heard about in the last recent_peer_threshold seconds recent_peer_threshold: 6000 + connect_to_farmer: False + connect_to_timelord: False + farmer_peer: host: 127.0.0.1 port: 8447 @@ -79,11 +97,21 @@ full_node: host: 127.0.0.1 port: 8446 introducer_peer: - # To run the simulation, set host to 127.0.0.1 and port to 8445 - # host: 127.0.0.1 - # port: 8445 host: introducer.chia.net # Chia AWS introducer IPv4/IPv6 port: 8444 + logging: *logging + +ui: + # The ui node server (if run) will run on this host and port + host: 127.0.0.1 + port: 8222 + + # Which port to use to communicate with the full node + rpc_port: 8555 + + # This SSH key is for the ui SSH server + ssh_filename: config/ssh_host_key + logging: *logging introducer: host: 127.0.0.1 @@ -92,3 +120,4 @@ introducer: # The introducer will only return peers who it has seen in the last # recent_peer_threshold seconds recent_peer_threshold: 6000 + logging: *logging diff --git a/lib/chiapos/CMakeLists.txt b/lib/chiapos/CMakeLists.txt index 6108647df2d3..c45cf137ae68 100644 --- a/lib/chiapos/CMakeLists.txt +++ b/lib/chiapos/CMakeLists.txt @@ -31,7 +31,21 @@ add_subdirectory(lib/pybind11) pybind11_add_module(chiapos ${CMAKE_CURRENT_SOURCE_DIR}/python-bindings/chiapos.cpp) -set (CMAKE_CXX_FLAGS "-g -O3 -Wall -msse2 -msse -march=native -std=c++11 -maes") +set (CMAKE_CXX_FLAGS "-g -O3 -Wall -msse2 -msse -march=native -std=c++1z -maes") +try_run(CMAKE_AESNI_TEST_RUN_RESULT + CMAKE_AESNI_TEST_COMPILE_RESULT + ${CMAKE_CURRENT_BINARY_DIR}/cmake_aesni_test + ${CMAKE_CURRENT_SOURCE_DIR}/src/cmake_aesni_test.cpp) + +# Did compilation succeed and process return 0 (success)? 
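+# (try_run sets the compile-result variable to TRUE/FALSE and, when the test
+# binary was built, stores its exit status in the run-result variable.)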
+IF("${CMAKE_AESNI_TEST_COMPILE_RESULT}" AND ("${CMAKE_AESNI_TEST_RUN_RESULT}" EQUAL 0)) + message(STATUS "AESNI Enabled") + set (CMAKE_CXX_FLAGS "-g -O3 -Wall -msse2 -msse -march=native -std=c++17 -maes") +ELSE() + message(STATUS "AESNI Disabled") + add_compile_definitions (DISABLE_AESNI) + set (CMAKE_CXX_FLAGS "-g -O3 -Wall -march=native -std=c++17") +ENDIF() add_executable(ProofOfSpace src/cli.cpp @@ -45,7 +59,15 @@ add_executable(RunTests tests/test-main.cpp tests/test.cpp ) -target_link_libraries(chiapos PRIVATE fse) -target_link_libraries(ProofOfSpace fse) -target_link_libraries(HellmanAttacks fse) -target_link_libraries(RunTests fse) + +if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + target_link_libraries(chiapos PRIVATE fse) + target_link_libraries(ProofOfSpace fse) + target_link_libraries(HellmanAttacks fse) + target_link_libraries(RunTests fse) +else() + target_link_libraries(chiapos PRIVATE fse stdc++fs) + target_link_libraries(ProofOfSpace fse stdc++fs) + target_link_libraries(HellmanAttacks fse stdc++fs) + target_link_libraries(RunTests fse stdc++fs) +endif() diff --git a/lib/chiapos/README.md b/lib/chiapos/README.md index 9830a5c5d15a..1e4fb4732b42 100644 --- a/lib/chiapos/README.md +++ b/lib/chiapos/README.md @@ -8,8 +8,7 @@ Only runs on 64 bit architectures with AES-NI support. Read the [Proof of Space ### Compile ```bash -git submodule update --init --recursive -mkdir build && cd build +mkdir -p build && cd build cmake ../ cmake --build . -- -j 6 ``` diff --git a/lib/chiapos/prepare.sh b/lib/chiapos/prepare.sh new file mode 100755 index 000000000000..cfc4bb569ce3 --- /dev/null +++ b/lib/chiapos/prepare.sh @@ -0,0 +1,5 @@ +#!/bin/bash +git submodule update --init --recursive +mkdir build -p +cd build +cmake ../ \ No newline at end of file diff --git a/lib/chiapos/python-bindings/chiapos.cpp b/lib/chiapos/python-bindings/chiapos.cpp index 57129d872854..a06ca0269bfc 100644 --- a/lib/chiapos/python-bindings/chiapos.cpp +++ b/lib/chiapos/python-bindings/chiapos.cpp @@ -34,13 +34,14 @@ PYBIND11_MODULE(chiapos, m) { py::class_(m, "DiskPlotter") .def(py::init<>()) - .def("create_plot_disk", [](DiskPlotter &dp, const std::string filename, uint8_t k, + .def("create_plot_disk", [](DiskPlotter &dp, const std::string tmp_dir, const std::string final_dir, + const std::string filename, uint8_t k, const py::bytes &memo, const py::bytes &id) { std::string memo_str(memo); const uint8_t* memo_ptr = reinterpret_cast(memo_str.data()); std::string id_str(id); const uint8_t* id_ptr = reinterpret_cast(id_str.data()); - dp.CreatePlotDisk(filename, k, memo_ptr, len(memo), id_ptr, len(id)); + dp.CreatePlotDisk(tmp_dir, final_dir, filename, k, memo_ptr, len(memo), id_ptr, len(id)); }); py::class_(m, "DiskProver") diff --git a/lib/chiapos/python-bindings/test.py b/lib/chiapos/python-bindings/test.py deleted file mode 100644 index e8fcbdf984f2..000000000000 --- a/lib/chiapos/python-bindings/test.py +++ /dev/null @@ -1,32 +0,0 @@ -from chiapos import DiskProver, DiskPlotter, Verifier -from hashlib import sha256 -import secrets -import os - -challenge: bytes = bytes([i for i in range(0, 32)]) - -plot_id: bytes = bytes([5, 104, 52, 4, 51, 55, 23, 84, 91, 10, 111, 12, 13, - 222, 151, 16, 228, 211, 254, 45, 92, 198, 204, 10, 9, - 10, 11, 129, 139, 171, 15, 23]) -filename = "./myplot.dat" -pl = DiskPlotter() -pl.create_plot_disk(filename, 21, bytes([1, 2, 3, 4, 5]), plot_id) -pr = DiskProver(filename) - - -total_proofs: int = 0 -iterations: int = 5000 - -v = Verifier() -for i in range(iterations): - challenge = 
sha256(i.to_bytes(4, "big")).digest() - for index, quality in enumerate(pr.get_qualities_for_challenge(challenge)): - proof = pr.get_full_proof(challenge, index) - total_proofs += 1 - ver_quality = v.validate_proof(plot_id, 21, challenge, proof) - assert(quality == ver_quality) - -os.remove(filename) - -print(f"total proofs {total_proofs} out of {iterations}\ - {total_proofs / iterations}") diff --git a/lib/chiapos/setup.py b/lib/chiapos/setup.py index a5cb7d1c6ab0..b49b54ab711d 100644 --- a/lib/chiapos/setup.py +++ b/lib/chiapos/setup.py @@ -66,7 +66,7 @@ def build_extension(self, ext): setup( name='chiapos', - version='0.2.2', + version='0.2.3', author='Mariano Sorgente', author_email='mariano@chia.net', description='Chia proof of space plotting, proving, and verifying (wraps C++)', diff --git a/lib/chiapos/src/aes.hpp b/lib/chiapos/src/aes.hpp index def86c3b9121..db91ad7c64ff 100644 --- a/lib/chiapos/src/aes.hpp +++ b/lib/chiapos/src/aes.hpp @@ -1,270 +1,493 @@ -// Copyright 2018 Chia Network Inc - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// Some public domain code is taken from pycrypto: -// https://github.com/dlitz/pycrypto/blob/master/src/AESNI.c -// -// AESNI.c: AES using AES-NI instructions -// -// Written in 2013 by Sebastian Ramacher - -#ifndef SRC_CPP_AES_HPP_ -#define SRC_CPP_AES_HPP_ - -#include // for memcmp -#include // for intrinsics for AES-NI - -/** - * Encrypts a message of 128 bits with a 128 bit key, using - * 10 rounds of AES128 (9 full rounds and one final round). Uses AES-NI - * assembly instructions. - */ -#define DO_ENC_BLOCK_128(m, k) \ - do \ - { \ - m = _mm_xor_si128(m, k[0]); \ - m = _mm_aesenc_si128(m, k[1]); \ - m = _mm_aesenc_si128(m, k[2]); \ - m = _mm_aesenc_si128(m, k[3]); \ - m = _mm_aesenc_si128(m, k[4]); \ - m = _mm_aesenc_si128(m, k[5]); \ - m = _mm_aesenc_si128(m, k[6]); \ - m = _mm_aesenc_si128(m, k[7]); \ - m = _mm_aesenc_si128(m, k[8]); \ - m = _mm_aesenc_si128(m, k[9]); \ - m = _mm_aesenclast_si128(m, k[10]); \ - } while (0) - -/** - * Encrypts a message of 128 bits with a 256 bit key, using - * 13 rounds of AES256 (13 full rounds and one final round). Uses - * AES-NI assembly instructions. - */ -#define DO_ENC_BLOCK_256(m, k) \ - do {\ - m = _mm_xor_si128(m, k[ 0]); \ - m = _mm_aesenc_si128(m, k[ 1]); \ - m = _mm_aesenc_si128(m, k[ 2]); \ - m = _mm_aesenc_si128(m, k[ 3]); \ - m = _mm_aesenc_si128(m, k[ 4]); \ - m = _mm_aesenc_si128(m, k[ 5]); \ - m = _mm_aesenc_si128(m, k[ 6]); \ - m = _mm_aesenc_si128(m, k[ 7]); \ - m = _mm_aesenc_si128(m, k[ 8]); \ - m = _mm_aesenc_si128(m, k[ 9]); \ - m = _mm_aesenc_si128(m, k[ 10]);\ - m = _mm_aesenc_si128(m, k[ 11]);\ - m = _mm_aesenc_si128(m, k[ 12]);\ - m = _mm_aesenc_si128(m, k[ 13]);\ - m = _mm_aesenclast_si128(m, k[ 14]);\ - }while(0) - -/** - * Encrypts a message of 128 bits with a 128 bit key, using - * 2 full rounds of AES128. Uses AES-NI assembly instructions. 
- */ -#define DO_ENC_BLOCK_2ROUND(m, k) \ - do \ - { \ - m = _mm_xor_si128(m, k[0]); \ - m = _mm_aesenc_si128(m, k[1]); \ - m = _mm_aesenc_si128(m, k[2]); \ - } while (0) -/** - * Decrypts a ciphertext of 128 bits with a 128 bit key, using - * 10 rounds of AES128 (9 full rounds and one final round). - * Uses AES-NI assembly instructions. - */ -#define DO_DEC_BLOCK(m, k) \ - do \ - { \ - m = _mm_xor_si128(m, k[10 + 0]); \ - m = _mm_aesdec_si128(m, k[10 + 1]); \ - m = _mm_aesdec_si128(m, k[10 + 2]); \ - m = _mm_aesdec_si128(m, k[10 + 3]); \ - m = _mm_aesdec_si128(m, k[10 + 4]); \ - m = _mm_aesdec_si128(m, k[10 + 5]); \ - m = _mm_aesdec_si128(m, k[10 + 6]); \ - m = _mm_aesdec_si128(m, k[10 + 7]); \ - m = _mm_aesdec_si128(m, k[10 + 8]); \ - m = _mm_aesdec_si128(m, k[10 + 9]); \ - m = _mm_aesdeclast_si128(m, k[0]); \ - } while (0) - -/** - * Decrypts a ciphertext of 128 bits with a 128 bit key, using - * 2 full rounds of AES128. Uses AES-NI assembly instructions. - * Will not work unless key schedule is modified. - */ /* -#define DO_DEC_BLOCK_2ROUND(m, k) \ - do \ - { \ - m = _mm_xor_si128(m, k[2 + 0]); \ - m = _mm_aesdec_si128(m, k[2 + 1]); \ - m = _mm_aesdec_si128(m, k[2 + 2]); \ - } while (0) + +The code in this file is originally from the Tiny AES project, which is in the +public domain. + +https://github.com/kokke/tiny-AES-c + +It has been heavily modified by Chia. + +*** + +This is an implementation of the AES algorithm, specifically ECB, CTR and CBC mode. +Block size can be chosen in aes.h - available choices are AES128, AES192, AES256. + +The implementation is verified against the test vectors in: + National Institute of Standards and Technology Special Publication 800-38A 2001 ED + +ECB-AES128 +---------- + + plain-text: + 6bc1bee22e409f96e93d7e117393172a + ae2d8a571e03ac9c9eb76fac45af8e51 + 30c81c46a35ce411e5fbc1191a0a52ef + f69f2445df4f9b17ad2b417be66c3710 + + key: + 2b7e151628aed2a6abf7158809cf4f3c + + resulting cipher + 3ad77bb40d7a3660a89ecaf32466ef97 + f5d3d58503b9699de785895a96fdbaaf + 43b1cd7f598ece23881b00e3ed030688 + 7b0c785e27e8ad3f8223207104725dd4 + + +NOTE: String length must be evenly divisible by 16byte (str_len % 16 == 0) + You should pad the end of the string with zeros if this is not the case. + For AES192/256 the key size is proportionally larger. + */ -static __m128i key_schedule[20]; // The expanded key -static __m128i aes128_keyexpand(__m128i key) { - key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); - key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); - return _mm_xor_si128(key, _mm_slli_si128(key, 4)); +/*****************************************************************************/ +/* Includes: */ +/*****************************************************************************/ +#include +#include // CBC mode, for memset + +//#define DISABLE_AESNI + +#ifndef DISABLE_AESNI +#include +#include "aesni.hpp" + +bool bHasAES=false; +bool bCheckedAES=false; +#endif // DISABLE_AESNI + +/*****************************************************************************/ +/* Defines: */ +/*****************************************************************************/ +// The number of columns comprising a state in AES. This is a constant in AES. Value=4 +#define Nb 4 + +/*****************************************************************************/ +/* Private variables: */ +/*****************************************************************************/ +// state - array holding the intermediate results during decryption. 
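+// (In this file the same 4x4 byte layout is used for encryption as well.)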
+typedef uint8_t state_t[4][4]; + +// The lookup-tables are marked const so they can be placed in read-only storage instead of RAM +// The numbers below can be computed dynamically trading ROM for RAM - +// This can be useful in (embedded) bootloader applications, where ROM is often limited. +static const uint8_t sbox[256] = { + //0 1 2 3 4 5 6 7 8 9 A B C D E F + 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, + 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, + 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, + 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, + 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, + 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, + 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, + 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, + 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, + 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, + 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, + 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, + 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, + 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, + 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, + 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16 }; + +static const uint8_t rsbox[256] = { + 0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb, + 0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb, + 0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e, + 0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25, + 0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92, + 0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84, + 0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06, + 0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b, + 0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73, + 0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e, + 0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b, + 0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4, + 0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f, + 0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef, + 0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61, + 0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d }; + +// The 
round constant word array, Rcon[i], contains the values given by +// x to the power (i-1) being powers of x (x is denoted as {02}) in the field GF(2^8) +static const uint8_t Rcon[11] = { + 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36 }; + +/* + * Jordan Goulder points out in PR #12 (https://github.com/kokke/tiny-AES-C/pull/12), + * that you can remove most of the elements in the Rcon array, because they are unused. + * + * From Wikipedia's article on the Rijndael key schedule @ https://en.wikipedia.org/wiki/Rijndael_key_schedule#Rcon + * + * "Only the first some of these constants are actually used – up to rcon[10] for AES-128 (as 11 round keys are needed), + * up to rcon[8] for AES-192, up to rcon[7] for AES-256. rcon[0] is not used in AES algorithm." + */ + +/*****************************************************************************/ +/* Private functions: */ +/*****************************************************************************/ + +#define getSBoxValue(num) (sbox[(num)]) + + +// This function adds the round key to state. +// The round key is added to the state by an XOR function. +static void AddRoundKey(uint8_t round, state_t* state, const uint8_t* RoundKey) +{ + uint8_t i,j; + for (i = 0; i < 4; ++i) + { + for (j = 0; j < 4; ++j) + { + (*state)[i][j] ^= RoundKey[(round * Nb * 4) + (i * Nb) + j]; + } + } +} + +// The SubBytes Function Substitutes the values in the +// state matrix with values in an S-box. +static void SubBytes(state_t* state) +{ + uint8_t i, j; + for (i = 0; i < 4; ++i) + { + for (j = 0; j < 4; ++j) + { + (*state)[j][i] = getSBoxValue((*state)[j][i]); + } + } } -#define KEYEXP128_H(K1, K2, I, S) _mm_xor_si128(aes128_keyexpand(K1), \ - _mm_shuffle_epi32(_mm_aeskeygenassist_si128(K2, I), S)) +// The ShiftRows() function shifts the rows in the state to the left. +// Each row is shifted with different offset. +// Offset = Row number. So the first row is not shifted. 
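+// e.g. an offset-1 row [a0,a1,a2,a3] becomes [a1,a2,a3,a0] after the shift.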
+static void ShiftRows(state_t* state) +{ + uint8_t temp; + + // Rotate first row 1 columns to left + temp = (*state)[0][1]; + (*state)[0][1] = (*state)[1][1]; + (*state)[1][1] = (*state)[2][1]; + (*state)[2][1] = (*state)[3][1]; + (*state)[3][1] = temp; + + // Rotate second row 2 columns to left + temp = (*state)[0][2]; + (*state)[0][2] = (*state)[2][2]; + (*state)[2][2] = temp; + + temp = (*state)[1][2]; + (*state)[1][2] = (*state)[3][2]; + (*state)[3][2] = temp; + + // Rotate third row 3 columns to left + temp = (*state)[0][3]; + (*state)[0][3] = (*state)[3][3]; + (*state)[3][3] = (*state)[2][3]; + (*state)[2][3] = (*state)[1][3]; + (*state)[1][3] = temp; +} -#define KEYEXP128(K, I) KEYEXP128_H(K, K, I, 0xff) -#define KEYEXP256(K1, K2, I) KEYEXP128_H(K1, K2, I, 0xff) -#define KEYEXP256_2(K1, K2) KEYEXP128_H(K1, K2, 0x00, 0xaa) +static uint8_t xtime(uint8_t x) +{ + return ((x<<1) ^ (((x>>7) & 1) * 0x1b)); +} -// public API +// MixColumns function mixes the columns of the state matrix +static void MixColumns(state_t* state) +{ + uint8_t i; + uint8_t Tmp, Tm, t; + for (i = 0; i < 4; ++i) + { + t = (*state)[i][0]; + Tmp = (*state)[i][0] ^ (*state)[i][1] ^ (*state)[i][2] ^ (*state)[i][3] ; + Tm = (*state)[i][0] ^ (*state)[i][1] ; Tm = xtime(Tm); (*state)[i][0] ^= Tm ^ Tmp ; + Tm = (*state)[i][1] ^ (*state)[i][2] ; Tm = xtime(Tm); (*state)[i][1] ^= Tm ^ Tmp ; + Tm = (*state)[i][2] ^ (*state)[i][3] ; Tm = xtime(Tm); (*state)[i][2] ^= Tm ^ Tmp ; + Tm = (*state)[i][3] ^ t ; Tm = xtime(Tm); (*state)[i][3] ^= Tm ^ Tmp ; + } +} + +uint8_t RoundKey128[176]; // AES_KEYLEN 16 Key length in bytes +uint8_t RoundKey256[240]; // AES_KEYLEN 32 Key length in bytes + +#define KEYNR256 14 +#define KEYNR128 10 + +#define KEYNK256 8 +#define KEYNK128 4 + +#define ENCRYPTNR256 14 +#define ENCRYPTNR128 3 + +// This function produces Nb(Nr+1) round keys. The round keys are used in each round to decrypt the states. +static void KeyExpansion(uint8_t* RoundKey, const uint8_t* Key, int keyNr, int keyNk) +{ + int i, j, k; + uint8_t tempa[4]; // Used for the column/row operations + + // The first round key is the key itself. + for (i = 0; i < keyNk; ++i) + { + RoundKey[(i * 4) + 0] = Key[(i * 4) + 0]; + RoundKey[(i * 4) + 1] = Key[(i * 4) + 1]; + RoundKey[(i * 4) + 2] = Key[(i * 4) + 2]; + RoundKey[(i * 4) + 3] = Key[(i * 4) + 3]; + } + + // All other round keys are found from the previous round keys. + for (i = keyNk; i < Nb * (keyNr + 1); ++i) + { + { + k = (i - 1) * 4; + tempa[0]=RoundKey[k + 0]; + tempa[1]=RoundKey[k + 1]; + tempa[2]=RoundKey[k + 2]; + tempa[3]=RoundKey[k + 3]; + + } + + if (i % keyNk == 0) + { + // This function shifts the 4 bytes in a word to the left once. + // [a0,a1,a2,a3] becomes [a1,a2,a3,a0] + + // Function RotWord() + { + const uint8_t u8tmp = tempa[0]; + tempa[0] = tempa[1]; + tempa[1] = tempa[2]; + tempa[2] = tempa[3]; + tempa[3] = u8tmp; + } + + // SubWord() is a function that takes a four-byte input word and + // applies the S-box to each of the four bytes to produce an output word. 
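+      // e.g. SubWord([0x00,0x01,0x02,0x03]) = [0x63,0x7c,0x77,0x7b], per the sbox table above.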
+ + // Function Subword() + { + tempa[0] = getSBoxValue(tempa[0]); + tempa[1] = getSBoxValue(tempa[1]); + tempa[2] = getSBoxValue(tempa[2]); + tempa[3] = getSBoxValue(tempa[3]); + } + + tempa[0] = tempa[0] ^ Rcon[i/keyNk]; + } + + // AES256 only + if ((keyNk==8)&&(i % keyNk == 4)) + { + // Function Subword() + { + tempa[0] = getSBoxValue(tempa[0]); + tempa[1] = getSBoxValue(tempa[1]); + tempa[2] = getSBoxValue(tempa[2]); + tempa[3] = getSBoxValue(tempa[3]); + } + } + + j = i * 4; k=(i - keyNk) * 4; + RoundKey[j + 0] = RoundKey[k + 0] ^ tempa[0]; + RoundKey[j + 1] = RoundKey[k + 1] ^ tempa[1]; + RoundKey[j + 2] = RoundKey[k + 2] ^ tempa[2]; + RoundKey[j + 3] = RoundKey[k + 3] ^ tempa[3]; + } +} /* * Loads an AES key. Can either be a 16 byte or 32 byte bytearray. */ void aes_load_key(uint8_t *enc_key, int keylen) { - switch (keylen) { - case 16: { - /* 128 bit key setup */ - key_schedule[0] = _mm_loadu_si128((const __m128i*) enc_key); - key_schedule[1] = KEYEXP128(key_schedule[0], 0x01); - key_schedule[2] = KEYEXP128(key_schedule[1], 0x02); - key_schedule[3] = KEYEXP128(key_schedule[2], 0x04); - key_schedule[4] = KEYEXP128(key_schedule[3], 0x08); - key_schedule[5] = KEYEXP128(key_schedule[4], 0x10); - key_schedule[6] = KEYEXP128(key_schedule[5], 0x20); - key_schedule[7] = KEYEXP128(key_schedule[6], 0x40); - key_schedule[8] = KEYEXP128(key_schedule[7], 0x80); - key_schedule[9] = KEYEXP128(key_schedule[8], 0x1B); - key_schedule[10] = KEYEXP128(key_schedule[9], 0x36); + +#ifndef DISABLE_AESNI + if(!bCheckedAES) + { + uint32_t eax, ebx, ecx, edx; + + eax = ebx = ecx = edx = 0; + __get_cpuid(1, &eax, &ebx, &ecx, &edx); + bHasAES=(ecx & bit_AES) > 0; + bCheckedAES=true; + } + + if(bHasAES) + return ni_aes_load_key(enc_key, keylen); +#endif // DISABLE_AESNI + + switch(keylen){ + case 32: + KeyExpansion(RoundKey256, enc_key, KEYNR256, KEYNK256); break; - } - case 32: { - /* 256 bit key setup */ - key_schedule[0] = _mm_loadu_si128((const __m128i*) enc_key); - key_schedule[1] = _mm_loadu_si128((const __m128i*) (enc_key+16)); - key_schedule[2] = KEYEXP256(key_schedule[0], key_schedule[1], 0x01); - key_schedule[3] = KEYEXP256_2(key_schedule[1], key_schedule[2]); - key_schedule[4] = KEYEXP256(key_schedule[2], key_schedule[3], 0x02); - key_schedule[5] = KEYEXP256_2(key_schedule[3], key_schedule[4]); - key_schedule[6] = KEYEXP256(key_schedule[4], key_schedule[5], 0x04); - key_schedule[7] = KEYEXP256_2(key_schedule[5], key_schedule[6]); - key_schedule[8] = KEYEXP256(key_schedule[6], key_schedule[7], 0x08); - key_schedule[9] = KEYEXP256_2(key_schedule[7], key_schedule[8]); - key_schedule[10] = KEYEXP256(key_schedule[8], key_schedule[9], 0x10); - key_schedule[11] = KEYEXP256_2(key_schedule[9], key_schedule[10]); - key_schedule[12] = KEYEXP256(key_schedule[10], key_schedule[11], 0x20); - key_schedule[13] = KEYEXP256_2(key_schedule[11], key_schedule[12]); - key_schedule[14] = KEYEXP256(key_schedule[12], key_schedule[13], 0x40); + case 16: + KeyExpansion(RoundKey128, enc_key, KEYNR128, KEYNK128); break; - } } } -// Declares a global variable for efficiency. -__m128i m_global; - /* - * Encrypts a plaintext using AES256. 
- */ -static inline void aes256_enc(const uint8_t *plainText, uint8_t *cipherText) { - m_global = _mm_loadu_si128(reinterpret_cast(plainText)); - - DO_ENC_BLOCK_256(m_global, key_schedule); - - _mm_storeu_si128(reinterpret_cast<__m128i *>(cipherText), m_global); +* XOR 128 bits +*/ +static inline void xor128(const uint8_t *in1, const uint8_t *in2, uint8_t *out) { + for(int i=0;i<16;i++) { + out[i]=in1[i]^in2[i]; + } } /* - * Encrypts a plaintext using AES128 with 2 rounds. + * Encrypts a plaintext using AES256. */ -static inline void aes128_enc(const uint8_t *plainText, uint8_t *cipherText) { - m_global = _mm_loadu_si128(reinterpret_cast(plainText)); - - // Uses the 2 round encryption innstead of the full 10 round encryption - DO_ENC_BLOCK_2ROUND(m_global, key_schedule); +static inline void aes256_enc(const uint8_t *in, uint8_t *out) { + +#ifndef DISABLE_AESNI + if(bHasAES) + return ni_aes256_enc(in, out); +#endif // DISABLE_AESNI + + memcpy(out,in,16); + + state_t *state=(state_t*)out; + + uint8_t round = 0; + + // Add the First round key to the state before starting the rounds. + AddRoundKey(0, state, RoundKey256); + + // There will be Nr rounds. + // The first Nr-1 rounds are identical. + // These Nr-1 rounds are executed in the loop below. + for (round = 1; round < ENCRYPTNR256; ++round) + { + SubBytes(state); + ShiftRows(state); + MixColumns(state); + AddRoundKey(round, state, RoundKey256); + } - _mm_storeu_si128(reinterpret_cast<__m128i *>(cipherText), m_global); + // The last round is given below. + // The MixColumns function is not here in the last round. + SubBytes(state); + ShiftRows(state); + AddRoundKey(ENCRYPTNR256, state, RoundKey256); } - + /* * Encrypts an integer using AES128 with 2 rounds. */ -static inline __m128i aes128_enc_int(__m128i plainText) { - // Uses the 2 round encryption innstead of the full 10 round encryption - DO_ENC_BLOCK_2ROUND(plainText, key_schedule); - return plainText; +static inline void aes128_enc(uint8_t *in, uint8_t *out) { + +#ifndef DISABLE_AESNI + if(bHasAES) + return ni_aes128_enc(in, out); +#endif // DISABLE_AESNI + + memcpy(out,in,16); + + state_t *state=(state_t*)out; + + uint8_t round = 0; + + // Add the First round key to the state before starting the rounds. + AddRoundKey(0, state, RoundKey128); + + // There will be Nr rounds. + // The first Nr-1 rounds are identical. + // These Nr-1 rounds are executed in the loop below. + for (round = 1; round < ENCRYPTNR128; ++round) + { + SubBytes(state); + ShiftRows(state); + MixColumns(state); + AddRoundKey(round, state, RoundKey128); + } } - -__m128i m1; -__m128i m2; -__m128i m3; -__m128i m4; - + /* * Uses AES cache mode to map a 2 block ciphertext into 128 bit result. */ static inline void aes128_2b(uint8_t *block1, uint8_t *block2, uint8_t *res) { - m1 = _mm_loadu_si128(reinterpret_cast<__m128i *>(block1)); - m2 = _mm_loadu_si128(reinterpret_cast<__m128i *>(block2)); - m3 = aes128_enc_int(m1); // E(L) - m3 = aes128_enc_int(_mm_xor_si128(m3, m2)); - _mm_storeu_si128(reinterpret_cast<__m128i *>(res), m3); + +#ifndef DISABLE_AESNI + if(bHasAES) + return ni_aes128_2b(block1, block2, res); +#endif // DISABLE_AESNI + + uint8_t m1[16]; + uint8_t m2[16]; + uint8_t m3[16]; + uint8_t intermediate[16]; + + memcpy(m1,block1,16); + memcpy(m2,block2,16); + + aes128_enc(m1,m3); + xor128(m3, m2, intermediate); + aes128_enc(intermediate,m3); + + memcpy(res,m3,16); } - + /* * Uses AES cache mode to map a 3 block ciphertext into 128 bit result. 
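+ * (Computes E(E(block1) ^ E(block2) ^ E(block3)) using the 2-round aes128_enc.)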
*/ static inline void aes128_3b(uint8_t *block1, uint8_t* block2, uint8_t *block3, uint8_t* res) { - m1 = _mm_loadu_si128(reinterpret_cast<__m128i *>(block1)); - m2 = _mm_loadu_si128(reinterpret_cast<__m128i *>(block2)); - - m1 = aes128_enc_int(m1); // E(La) - m2 = aes128_enc_int(m2); // E(Ra) - - m1 = _mm_xor_si128(m1, m2); - m2 = _mm_loadu_si128(reinterpret_cast<__m128i *>(block3)); - - m2 = aes128_enc_int(m2); - m1 = _mm_xor_si128(m1, m2); - m3 = aes128_enc_int(m1); - _mm_storeu_si128(reinterpret_cast<__m128i *>(res), m3); + +#ifndef DISABLE_AESNI + if(bHasAES) + return ni_aes128_3b(block1, block2, block3, res); +#endif // DISABLE_AESNI + + uint8_t m1[16]; + uint8_t m2[16]; + uint8_t m3[16]; + + memcpy(m1,block1,16); + memcpy(m2,block2,16); + + aes128_enc(m1,m1); // E(La) + aes128_enc(m2,m2); // E(Ra) + + xor128(m1, m2, m1); + memcpy(m2,block3,16); + + aes128_enc(m2,m2); + xor128(m1, m2, m1); + aes128_enc(m1,m3); + memcpy(res,m3,16); } /* * Uses AES cache mode to map a 4 block ciphertext into 128 bit result. */ static inline void aes128_4b(uint8_t *block1, uint8_t* block2, uint8_t *block3, uint8_t* block4, uint8_t* res) { - m1 = _mm_loadu_si128(reinterpret_cast<__m128i *>(block1)); - m2 = _mm_loadu_si128(reinterpret_cast<__m128i *>(block3)); - m3 = _mm_loadu_si128(reinterpret_cast<__m128i *>(block2)); - m4 = _mm_loadu_si128(reinterpret_cast<__m128i *>(block4)); - - m1 = aes128_enc_int(m1); // E(La) - m1 = _mm_xor_si128(m1, m3); - m1 = aes128_enc_int(m1); // E(E(La) ^ Lb) - m2 = aes128_enc_int(m2); // E(Ra) - - m1 = _mm_xor_si128(m1, m2); // xor e(Ra) - m1 = _mm_xor_si128(m1, m4); // xor Rb - - m3 = aes128_enc_int(m1); - _mm_storeu_si128(reinterpret_cast<__m128i *>(res), m3); + +#ifndef DISABLE_AESNI + if(bHasAES) + return ni_aes128_4b(block1, block2, block3, block4, res); +#endif // DISABLE_AESNI + + uint8_t m1[16]; + uint8_t m2[16]; + uint8_t m3[16]; + uint8_t m4[16]; + + memcpy(m1,block1,16); + memcpy(m2,block2,16); + memcpy(m3,block3,16); + memcpy(m4,block4,16); + + aes128_enc(m1,m1); // E(La) + xor128(m1, m3, m1); + aes128_enc(m1,m1); // E(E(La) ^ Lb) + aes128_enc(m2,m2); // E(Ra) + + xor128(m1, m2, m1); // xor e(Ra) + xor128(m1, m4, m1); // xor Rb + + aes128_enc(m1,m3); // E(La) + memcpy(res,m3,16); } -#endif // SRC_CPP_AES_HPP_ diff --git a/lib/chiapos/src/aesni.hpp b/lib/chiapos/src/aesni.hpp new file mode 100644 index 000000000000..d3d469d5353c --- /dev/null +++ b/lib/chiapos/src/aesni.hpp @@ -0,0 +1,270 @@ +// Copyright 2018 Chia Network Inc + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Some public domain code is taken from pycrypto: +// https://github.com/dlitz/pycrypto/blob/master/src/AESNI.c +// +// AESNI.c: AES using AES-NI instructions +// +// Written in 2013 by Sebastian Ramacher + +#ifndef SRC_CPP_AES_HPP_ +#define SRC_CPP_AES_HPP_ + +#include // for memcmp +#include // for intrinsics for AES-NI + +/** + * Encrypts a message of 128 bits with a 128 bit key, using + * 10 rounds of AES128 (9 full rounds and one final round). 
Uses AES-NI + * assembly instructions. + */ +#define DO_ENC_BLOCK_128(m, k) \ + do \ + { \ + m = _mm_xor_si128(m, k[0]); \ + m = _mm_aesenc_si128(m, k[1]); \ + m = _mm_aesenc_si128(m, k[2]); \ + m = _mm_aesenc_si128(m, k[3]); \ + m = _mm_aesenc_si128(m, k[4]); \ + m = _mm_aesenc_si128(m, k[5]); \ + m = _mm_aesenc_si128(m, k[6]); \ + m = _mm_aesenc_si128(m, k[7]); \ + m = _mm_aesenc_si128(m, k[8]); \ + m = _mm_aesenc_si128(m, k[9]); \ + m = _mm_aesenclast_si128(m, k[10]); \ + } while (0) + +/** + * Encrypts a message of 128 bits with a 256 bit key, using + * 13 rounds of AES256 (13 full rounds and one final round). Uses + * AES-NI assembly instructions. + */ +#define DO_ENC_BLOCK_256(m, k) \ + do {\ + m = _mm_xor_si128(m, k[ 0]); \ + m = _mm_aesenc_si128(m, k[ 1]); \ + m = _mm_aesenc_si128(m, k[ 2]); \ + m = _mm_aesenc_si128(m, k[ 3]); \ + m = _mm_aesenc_si128(m, k[ 4]); \ + m = _mm_aesenc_si128(m, k[ 5]); \ + m = _mm_aesenc_si128(m, k[ 6]); \ + m = _mm_aesenc_si128(m, k[ 7]); \ + m = _mm_aesenc_si128(m, k[ 8]); \ + m = _mm_aesenc_si128(m, k[ 9]); \ + m = _mm_aesenc_si128(m, k[ 10]);\ + m = _mm_aesenc_si128(m, k[ 11]);\ + m = _mm_aesenc_si128(m, k[ 12]);\ + m = _mm_aesenc_si128(m, k[ 13]);\ + m = _mm_aesenclast_si128(m, k[ 14]);\ + }while(0) + +/** + * Encrypts a message of 128 bits with a 128 bit key, using + * 2 full rounds of AES128. Uses AES-NI assembly instructions. + */ +#define DO_ENC_BLOCK_2ROUND(m, k) \ + do \ + { \ + m = _mm_xor_si128(m, k[0]); \ + m = _mm_aesenc_si128(m, k[1]); \ + m = _mm_aesenc_si128(m, k[2]); \ + } while (0) +/** + * Decrypts a ciphertext of 128 bits with a 128 bit key, using + * 10 rounds of AES128 (9 full rounds and one final round). + * Uses AES-NI assembly instructions. + */ +#define DO_DEC_BLOCK(m, k) \ + do \ + { \ + m = _mm_xor_si128(m, k[10 + 0]); \ + m = _mm_aesdec_si128(m, k[10 + 1]); \ + m = _mm_aesdec_si128(m, k[10 + 2]); \ + m = _mm_aesdec_si128(m, k[10 + 3]); \ + m = _mm_aesdec_si128(m, k[10 + 4]); \ + m = _mm_aesdec_si128(m, k[10 + 5]); \ + m = _mm_aesdec_si128(m, k[10 + 6]); \ + m = _mm_aesdec_si128(m, k[10 + 7]); \ + m = _mm_aesdec_si128(m, k[10 + 8]); \ + m = _mm_aesdec_si128(m, k[10 + 9]); \ + m = _mm_aesdeclast_si128(m, k[0]); \ + } while (0) + +/** + * Decrypts a ciphertext of 128 bits with a 128 bit key, using + * 2 full rounds of AES128. Uses AES-NI assembly instructions. + * Will not work unless key schedule is modified. + */ +/* +#define DO_DEC_BLOCK_2ROUND(m, k) \ + do \ + { \ + m = _mm_xor_si128(m, k[2 + 0]); \ + m = _mm_aesdec_si128(m, k[2 + 1]); \ + m = _mm_aesdec_si128(m, k[2 + 2]); \ + } while (0) +*/ + +static __m128i key_schedule[20]; // The expanded key + +static __m128i aes128_keyexpand(__m128i key) { + key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); + key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); + return _mm_xor_si128(key, _mm_slli_si128(key, 4)); +} + +#define KEYEXP128_H(K1, K2, I, S) _mm_xor_si128(aes128_keyexpand(K1), \ + _mm_shuffle_epi32(_mm_aeskeygenassist_si128(K2, I), S)) + +#define KEYEXP128(K, I) KEYEXP128_H(K, K, I, 0xff) +#define KEYEXP256(K1, K2, I) KEYEXP128_H(K1, K2, I, 0xff) +#define KEYEXP256_2(K1, K2) KEYEXP128_H(K1, K2, 0x00, 0xaa) + +// public API + +/* + * Loads an AES key. Can either be a 16 byte or 32 byte bytearray. 
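+ * (Dispatched to from aes_load_key() in aes.hpp when CPUID reports AES-NI support.)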
+ */
+void ni_aes_load_key(uint8_t *enc_key, int keylen) {
+    switch (keylen) {
+        case 16: {
+            /* 128 bit key setup */
+            key_schedule[0] = _mm_loadu_si128((const __m128i*) enc_key);
+            key_schedule[1] = KEYEXP128(key_schedule[0], 0x01);
+            key_schedule[2] = KEYEXP128(key_schedule[1], 0x02);
+            key_schedule[3] = KEYEXP128(key_schedule[2], 0x04);
+            key_schedule[4] = KEYEXP128(key_schedule[3], 0x08);
+            key_schedule[5] = KEYEXP128(key_schedule[4], 0x10);
+            key_schedule[6] = KEYEXP128(key_schedule[5], 0x20);
+            key_schedule[7] = KEYEXP128(key_schedule[6], 0x40);
+            key_schedule[8] = KEYEXP128(key_schedule[7], 0x80);
+            key_schedule[9] = KEYEXP128(key_schedule[8], 0x1B);
+            key_schedule[10] = KEYEXP128(key_schedule[9], 0x36);
+            break;
+        }
+        case 32: {
+            /* 256 bit key setup */
+            key_schedule[0] = _mm_loadu_si128((const __m128i*) enc_key);
+            key_schedule[1] = _mm_loadu_si128((const __m128i*) (enc_key+16));
+            key_schedule[2] = KEYEXP256(key_schedule[0], key_schedule[1], 0x01);
+            key_schedule[3] = KEYEXP256_2(key_schedule[1], key_schedule[2]);
+            key_schedule[4] = KEYEXP256(key_schedule[2], key_schedule[3], 0x02);
+            key_schedule[5] = KEYEXP256_2(key_schedule[3], key_schedule[4]);
+            key_schedule[6] = KEYEXP256(key_schedule[4], key_schedule[5], 0x04);
+            key_schedule[7] = KEYEXP256_2(key_schedule[5], key_schedule[6]);
+            key_schedule[8] = KEYEXP256(key_schedule[6], key_schedule[7], 0x08);
+            key_schedule[9] = KEYEXP256_2(key_schedule[7], key_schedule[8]);
+            key_schedule[10] = KEYEXP256(key_schedule[8], key_schedule[9], 0x10);
+            key_schedule[11] = KEYEXP256_2(key_schedule[9], key_schedule[10]);
+            key_schedule[12] = KEYEXP256(key_schedule[10], key_schedule[11], 0x20);
+            key_schedule[13] = KEYEXP256_2(key_schedule[11], key_schedule[12]);
+            key_schedule[14] = KEYEXP256(key_schedule[12], key_schedule[13], 0x40);
+            break;
+        }
+    }
+}
+
+// Declares a global variable for efficiency.
+__m128i m_global;
+
+/*
+ * Encrypts a plaintext using AES256.
+ */
+static inline void ni_aes256_enc(const uint8_t *plainText, uint8_t *cipherText) {
+    m_global = _mm_loadu_si128(reinterpret_cast<const __m128i *>(plainText));
+
+    DO_ENC_BLOCK_256(m_global, key_schedule);
+
+    _mm_storeu_si128(reinterpret_cast<__m128i *>(cipherText), m_global);
+}
+
+/*
+ * Encrypts a plaintext using AES128 with 2 rounds.
+ */
+static inline void ni_aes128_enc(const uint8_t *plainText, uint8_t *cipherText) {
+    m_global = _mm_loadu_si128(reinterpret_cast<const __m128i *>(plainText));
+
+    // Uses the 2 round encryption instead of the full 10 round encryption
+    DO_ENC_BLOCK_2ROUND(m_global, key_schedule);
+
+    _mm_storeu_si128(reinterpret_cast<__m128i *>(cipherText), m_global);
+}
+
+/*
+ * Encrypts an integer using AES128 with 2 rounds.
+ */
+static inline __m128i ni_aes128_enc_int(__m128i plainText) {
+    // Uses the 2 round encryption instead of the full 10 round encryption
+    DO_ENC_BLOCK_2ROUND(plainText, key_schedule);
+    return plainText;
+}
+
+__m128i m1;
+__m128i m2;
+__m128i m3;
+__m128i m4;
+
+/*
+ * Uses AES cache mode to map a 2 block ciphertext into 128 bit result.
+ */
+static inline void ni_aes128_2b(uint8_t *block1, uint8_t *block2, uint8_t *res) {
+    m1 = _mm_loadu_si128(reinterpret_cast<__m128i *>(block1));
+    m2 = _mm_loadu_si128(reinterpret_cast<__m128i *>(block2));
+    m3 = ni_aes128_enc_int(m1);  // E(L)
+    m3 = ni_aes128_enc_int(_mm_xor_si128(m3, m2));
+    _mm_storeu_si128(reinterpret_cast<__m128i *>(res), m3);
+}
+
+/*
+ * Uses AES cache mode to map a 3 block ciphertext into 128 bit result.
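+ * (AES-NI counterpart of aes128_3b in aes.hpp; computes the same E(E(b1) ^ E(b2) ^ E(b3)) shape.)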
+ */ +static inline void ni_aes128_3b(uint8_t *block1, uint8_t* block2, uint8_t *block3, uint8_t* res) { + m1 = _mm_loadu_si128(reinterpret_cast<__m128i *>(block1)); + m2 = _mm_loadu_si128(reinterpret_cast<__m128i *>(block2)); + + m1 = ni_aes128_enc_int(m1); // E(La) + m2 = ni_aes128_enc_int(m2); // E(Ra) + + m1 = _mm_xor_si128(m1, m2); + m2 = _mm_loadu_si128(reinterpret_cast<__m128i *>(block3)); + + m2 = ni_aes128_enc_int(m2); + m1 = _mm_xor_si128(m1, m2); + m3 = ni_aes128_enc_int(m1); + _mm_storeu_si128(reinterpret_cast<__m128i *>(res), m3); +} + +/* + * Uses AES cache mode to map a 4 block ciphertext into 128 bit result. + */ +static inline void ni_aes128_4b(uint8_t *block1, uint8_t* block2, uint8_t *block3, uint8_t* block4, uint8_t* res) { + m1 = _mm_loadu_si128(reinterpret_cast<__m128i *>(block1)); + m2 = _mm_loadu_si128(reinterpret_cast<__m128i *>(block3)); + m3 = _mm_loadu_si128(reinterpret_cast<__m128i *>(block2)); + m4 = _mm_loadu_si128(reinterpret_cast<__m128i *>(block4)); + + m1 = ni_aes128_enc_int(m1); // E(La) + m1 = _mm_xor_si128(m1, m3); + m1 = ni_aes128_enc_int(m1); // E(E(La) ^ Lb) + m2 = ni_aes128_enc_int(m2); // E(Ra) + + m1 = _mm_xor_si128(m1, m2); // xor e(Ra) + m1 = _mm_xor_si128(m1, m4); // xor Rb + + m3 = ni_aes128_enc_int(m1); + _mm_storeu_si128(reinterpret_cast<__m128i *>(res), m3); +} + +#endif // SRC_CPP_AES_HPP_ diff --git a/lib/chiapos/src/bits.hpp b/lib/chiapos/src/bits.hpp index 60173b68ca11..30a547415ebd 100644 --- a/lib/chiapos/src/bits.hpp +++ b/lib/chiapos/src/bits.hpp @@ -31,6 +31,8 @@ // A stack vector of length 5, having the functions of std::vector needed for Bits. struct SmallVector { + typedef uint16_t size_type; + SmallVector() { count_ = 0; } @@ -49,24 +51,26 @@ struct SmallVector { SmallVector& operator = (const SmallVector& other) { count_ = other.count_; - for (uint16_t i = 0; i < other.count_; i++) + for (size_type i = 0; i < other.count_; i++) v_[i] = other.v_[i]; return (*this); } - uint16_t size() const { + size_type size() const { return count_; } private: uint128_t v_[5]; - uint16_t count_; + size_type count_; }; // A stack vector of length 1024, having the functions of std::vector needed for Bits. 
// The max number of Bits that can be stored is 1024 * 128 struct ParkVector { + typedef uint32_t size_type; + ParkVector() { count_ = 0; } @@ -85,18 +89,18 @@ struct ParkVector { ParkVector& operator = (const ParkVector& other) { count_ = other.count_; - for (uint32_t i = 0; i < other.count_; i++) + for (size_type i = 0; i < other.count_; i++) v_[i] = other.v_[i]; return (*this); } - uint32_t size() const { + size_type size() const { return count_; } private: uint128_t v_[1024]; - uint32_t count_; + size_type count_; }; /* @@ -212,12 +216,12 @@ template class BitsGeneric { } BitsGeneric result; if (values_.size() > 0) { - for (uint32_t i = 0; i < values_.size() - 1; i++) + for (typename T::size_type i = 0; i < values_.size() - 1; i++) result.AppendValue(values_[i], 128); result.AppendValue(values_[values_.size() - 1], last_size_); } if (b.values_.size() > 0) { - for (uint32_t i = 0; i < b.values_.size() - 1; i++) + for (typename T::size_type i = 0; i < b.values_.size() - 1; i++) result.AppendValue(b.values_[i], 128); result.AppendValue(b.values_[b.values_.size() - 1], b.last_size_); } @@ -228,12 +232,12 @@ template class BitsGeneric { template BitsGeneric& operator += (const BitsGeneric& b) { if (b.values_.size() > 0) { - for (uint32_t i = 0; i < b.values_.size() - 1; i++) + for (typename T2::size_type i = 0; i < b.values_.size() - 1; i++) this->AppendValue(b.values_[i], 128); this->AppendValue(b.values_[b.values_.size() - 1], b.last_size_); } return *this; - } + } BitsGeneric& operator++() { uint128_t limit = ((uint128_t)std::numeric_limits :: max() << 64) + @@ -243,12 +247,10 @@ template class BitsGeneric { if (values_[values_.size() - 1] != last_bucket_mask) { values_[values_.size() - 1]++; } else { - bool all_one = true; if (values_.size() > 1) { // Otherwise, search for the first bucket that isn't full of 1 bits. for (int16_t i = values_.size() - 2; i >= 0; i--) if (values_[i] != limit) { - all_one = false; // Increment it. values_[i]++; // Buckets that were full of 1 bits turn all to 0 bits. @@ -277,6 +279,7 @@ template class BitsGeneric { values_[values_.size() - 1]--; return *this; } + if (values_.size() > 1) { // Search for the first bucket different than 0. for (int16_t i = values_.size() - 2; i >= 0; i--) @@ -288,7 +291,7 @@ template class BitsGeneric { (uint128_t)std::numeric_limits :: max(); // All buckets that were previously 0, now become full of 1s. // (i.e. 1010000 - 1 = 1001111) - for (uint32_t j = i + 1; j < values_.size() - 1; j++) + for (typename T::size_type j = i + 1; j < values_.size() - 1; j++) values_[j] = limit; values_[values_.size() - 1] = (last_size_ == 128) ? limit : ((static_cast(1) << last_size_) - 1); @@ -317,12 +320,12 @@ template class BitsGeneric { return res; } - BitsGeneric Slice(int32_t start_index) const { + BitsGeneric Slice(uint32_t start_index) const { return Slice(start_index, GetSize()); } // Slices the bits from [start_index, end_index) - BitsGeneric Slice(int32_t start_index, int32_t end_index) const { + BitsGeneric Slice(uint32_t start_index, uint32_t end_index) const { if (end_index > GetSize()) { end_index = GetSize(); } @@ -363,7 +366,7 @@ template class BitsGeneric { } // Same as 'Slice', but result fits into an uint64_t. Used for memory optimization. 
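+    // (The sliced span must fit into 64 bits; the bounds clamp in the body is commented out for speed.)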
- uint64_t SliceBitsToInt(int32_t start_index, int32_t end_index) const { + uint64_t SliceBitsToInt(uint32_t start_index, uint32_t end_index) const { /*if (end_index > GetSize()) { end_index = GetSize(); } @@ -372,7 +375,7 @@ template class BitsGeneric { } */ if ((start_index >> 7) == (end_index >> 7)) { uint128_t res = values_[start_index >> 7]; - if (((uint32_t)start_index >> 7) == values_.size() - 1) + if ((start_index >> 7) == values_.size() - 1) res = res >> (last_size_ - (end_index & 127)); else res = res >> (128 - (end_index & 127)); @@ -383,7 +386,7 @@ template class BitsGeneric { uint128_t prefix, suffix; SplitNumberByPrefix(values_[(start_index >> 7)], 128, start_index & 127, &prefix, &suffix); uint128_t result = suffix; - uint8_t bucket_size = (((uint32_t)end_index >> 7) == values_.size() - 1) ? last_size_ : 128; + uint8_t bucket_size = ((end_index >> 7) == values_.size() - 1) ? last_size_ : 128; SplitNumberByPrefix(values_[(end_index >> 7)], bucket_size, end_index & 127, &prefix, &suffix); result = (result << (end_index & 127)) + prefix; return result; @@ -431,11 +434,11 @@ template class BitsGeneric { std::string ToString() const { std::string str = ""; - for (uint32_t i = 0; i < values_.size(); i++) { + for (typename T::size_type i = 0; i < values_.size(); i++) { uint128_t val = values_[i]; - uint32_t size = (i == values_.size() - 1) ? last_size_ : 128; + typename T::size_type size = (i == values_.size() - 1) ? last_size_ : 128; std::string str_bucket = ""; - for (int i = 0; i < size; i++) { + for (typename T::size_type i = 0; i < size; i++) { if (val % 2) str_bucket = "1" + str_bucket; else diff --git a/lib/chiapos/src/cli.cpp b/lib/chiapos/src/cli.cpp index 2aad7ef57951..696a42c76c8f 100644 --- a/lib/chiapos/src/cli.cpp +++ b/lib/chiapos/src/cli.cpp @@ -64,6 +64,8 @@ int main(int argc, char *argv[]) { // Default values uint8_t k = 20; string filename = "plot.dat"; + string tempdir = "."; + string finaldir = "."; string operation = "help"; string memo = "0102030405"; string id = "022fb42c08c12de3a6af053880199806532e79515f94e83461612101f9412f9e"; @@ -71,6 +73,8 @@ int main(int argc, char *argv[]) { options.allow_unrecognised_options() .add_options() ("k, size", "Plot size", cxxopts::value(k)) + ("t, tempdir", "Temporary directory", cxxopts::value(tempdir)) + ("d, finaldir", "Final directory", cxxopts::value(finaldir)) ("f, file", "Filename", cxxopts::value(filename)) ("m, memo", "Memo to insert into the plot", cxxopts::value(memo)) ("i, id", "Unique 32-byte seed for the plot", cxxopts::value(id)) @@ -82,6 +86,7 @@ int main(int argc, char *argv[]) { HelpAndQuit(options); } operation = argv[1]; + std::cout << "operation" << operation << std::endl; if (operation == "help") { HelpAndQuit(options); @@ -101,7 +106,7 @@ int main(int argc, char *argv[]) { HexToBytes(id, id_bytes); DiskPlotter plotter = DiskPlotter(); - plotter.CreatePlotDisk(filename, k, memo_bytes, 5, id_bytes, 32); + plotter.CreatePlotDisk(tempdir, finaldir, filename, k, memo_bytes, 5, id_bytes, 32); } else if (operation == "prove") { if (argc < 3) { HelpAndQuit(options); diff --git a/lib/chiapos/src/cmake_aesni_test.cpp b/lib/chiapos/src/cmake_aesni_test.cpp new file mode 100644 index 000000000000..daed6189dfb3 --- /dev/null +++ b/lib/chiapos/src/cmake_aesni_test.cpp @@ -0,0 +1,16 @@ +#include "aes.hpp" + +int main() { + uint8_t enc_key[32]; + uint8_t in[16]; + uint8_t out[16]; + + memset(enc_key,0x00,sizeof(enc_key)); + memset(in,0x00,sizeof(in)); + + aes_load_key(enc_key, sizeof(enc_key)); + aes256_enc(in, out); 
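+    // A zero exit status here is what the try_run() check in CMakeLists.txt
+    // looks for before keeping the AES-NI compile flags.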
+
+    return 0;
+}
+
diff --git a/lib/chiapos/src/plotter_disk.hpp b/lib/chiapos/src/plotter_disk.hpp
index 5aeb8382d3c3..0ea2363d2425 100644
--- a/lib/chiapos/src/plotter_disk.hpp
+++ b/lib/chiapos/src/plotter_disk.hpp
@@ -26,6 +26,21 @@
 #include
 #include
+#if __has_include(<filesystem>)
+
+#include <filesystem>
+namespace filesystem = std::filesystem;
+
+#elif __has_include(<experimental/filesystem>)
+
+#include <experimental/filesystem>
+namespace filesystem = std::experimental::filesystem;
+
+#else
+#error "an implementation of filesystem is required!"
+#endif
+
+
 #include "util.hpp"
 #include "encoding.hpp"
 #include "calculate_bucket.hpp"
@@ -71,13 +86,32 @@ class DiskPlotter {
     // This method creates a plot on disk with the filename. A temporary file, "plotting" + filename,
     // is created and will be larger than the final plot file. This file is deleted at the end of
     // the process.
-    void CreatePlotDisk(std::string filename, uint8_t k, const uint8_t* memo,
+    void CreatePlotDisk(std::string tmp_dirname, std::string final_dirname, std::string filename,
+                        uint8_t k, const uint8_t* memo,
                         uint32_t memo_len, const uint8_t* id, uint32_t id_len) {
-        std::cout << std::endl << "Starting plotting progress into file " << filename << "." << std::endl;
+        std::cout << std::endl << "Starting plotting progress into temporary dir " << tmp_dirname << "." << std::endl;
         std::cout << "Memo: " << Util::HexStr(memo, memo_len) << std::endl;
         std::cout << "ID: " << Util::HexStr(id, id_len) << std::endl;
         std::cout << "Plot size is: " << static_cast<int>(k) << std::endl;
+        // Cross platform way to concatenate paths, C++17.
+        filesystem::path tmp_1_filename = filesystem::path(tmp_dirname) / filesystem::path(filename + ".tmp");
+        filesystem::path tmp_2_filename = filesystem::path(tmp_dirname) / filesystem::path(filename + ".2.tmp");
+        filesystem::path final_filename = filesystem::path(final_dirname) / filesystem::path(filename);
+
+        // Check if the paths exist
+        if (!filesystem::exists(tmp_dirname)) {
+            std::string err_string = "Directory " + tmp_dirname + " does not exist";
+            std::cerr << err_string << std::endl;
+            throw err_string;
+        }
+
+        if (!filesystem::exists(final_dirname)) {
+            std::string err_string = "Directory " + final_dirname + " does not exist";
+            std::cerr << err_string << std::endl;
+            throw err_string;
+        }
+
         // These variables are used in the WriteParkToFile method. They are preallocated here
         // to save time.
         first_line_point_bytes = new uint8_t[CalculateLinePointSize(k)];
@@ -88,27 +122,27 @@
         assert(k >= kMinPlotSize);
         assert(k <= kMaxPlotSize);
-        std::string plot_filename = filename + ".tmp";
+        std::cout << std::endl << "Starting phase 1/4: Forward Propagation... " << Timer::GetNow();
-        std::cout << std::endl << "Starting phase 1/4: Forward Propagation..." << std::endl;
         Timer p1;
         Timer all_phases;
-        std::vector<uint64_t> results = WritePlotFile(plot_filename, k, id, memo, memo_len);
+        std::vector<uint64_t> results = WritePlotFile(tmp_1_filename, k, id, memo, memo_len);
         p1.PrintElapsed("Time for phase 1 =");
-        std::cout << std::endl << "Starting phase 2/4: Backpropagation..." << std::endl;
+        std::cout << std::endl << "Starting phase 2/4: Backpropagation into " << tmp_1_filename << " and " << tmp_2_filename << " ..." << Timer::GetNow();
+
         Timer p2;
-        Backpropagate(filename, plot_filename, k, id, memo, memo_len, results);
+        Backpropagate(tmp_2_filename, tmp_1_filename, k, id, memo, memo_len, results);
         p2.PrintElapsed("Time for phase 2 =");
-        std::cout << std::endl << "Starting phase 3/4: Compression..." << std::endl;
+        std::cout << std::endl << "Starting phase 3/4: Compression... 
" << Timer::GetNow(); Timer p3; - Phase3Results res = CompressTables(k, results, filename, plot_filename, id, memo, memo_len); + Phase3Results res = CompressTables(k, results, tmp_2_filename, tmp_1_filename, id, memo, memo_len); p3.PrintElapsed("Time for phase 3 ="); - std::cout << std::endl << "Starting phase 4/4: Write Checkpoint tables..." << std::endl; + std::cout << std::endl << "Starting phase 4/4: Write Checkpoint tables... " << Timer::GetNow(); Timer p4; - WriteCTables(k, k + 1, filename, plot_filename, res); + WriteCTables(k, k + 1, tmp_2_filename, tmp_1_filename, res); p4.PrintElapsed("Time for phase 4 ="); std::cout << "Approximate working space used: " << @@ -117,7 +151,14 @@ class DiskPlotter { static_cast(res.final_table_begin_pointers[11])/(1024*1024*1024) << " GB" << std::endl; all_phases.PrintElapsed("Total time ="); - remove(plot_filename.c_str()); + bool removed_1 = filesystem::remove(tmp_1_filename); + filesystem::copy(tmp_2_filename, final_filename, filesystem::copy_options::overwrite_existing); + + bool removed_2 = filesystem::remove(tmp_2_filename); + + std::cout << "Removed " << tmp_1_filename << "? " << removed_1 << std::endl; + std::cout << "Removed " << tmp_2_filename << "? " << removed_2 << std::endl; + std::cout << "Copied final file to " << final_filename << std::endl; delete[] first_line_point_bytes; delete[] park_stubs_bytes; @@ -885,7 +926,7 @@ class DiskPlotter { deltas_bits.ToBytes(park_deltas_bytes); uint16_t encoded_size = deltas_bits.GetSize() / 8; - + assert((uint32_t)(encoded_size + 2) < CalculateMaxDeltasSize(k, table_index)); writer.write((const char*)&encoded_size, 2); writer.write((const char*)park_deltas_bytes, encoded_size); diff --git a/lib/chiapos/src/util.hpp b/lib/chiapos/src/util.hpp index 41e08ae757f3..17b68709f0be 100644 --- a/lib/chiapos/src/util.hpp +++ b/lib/chiapos/src/util.hpp @@ -50,6 +50,13 @@ class Timer { this->cpu_time_start_ = clock(); } + static char* GetNow() + { + auto now = std::chrono::system_clock::now(); + auto tt = std::chrono::system_clock::to_time_t(now); + return ctime(&tt); // ctime includes newline + } + void PrintElapsed(std::string name) { auto end = std::chrono::steady_clock::now(); auto wall_clock_ms = std::chrono::duration_cast( @@ -59,7 +66,7 @@ class Timer { double cpu_ratio = static_cast(10000 * (cpu_time_ms / wall_clock_ms)) / 100.0; - std::cout << name << " " << (wall_clock_ms / 1000.0) << " seconds. CPU (" << cpu_ratio << "%)" << std::endl; + std::cout << name << " " << (wall_clock_ms / 1000.0) << " seconds. 
CPU (" << cpu_ratio << "%) " << Timer::GetNow(); } private: diff --git a/lib/chiapos/tests/test.cpp b/lib/chiapos/tests/test.cpp index 44d0f5c58f0f..23f9a370cf09 100644 --- a/lib/chiapos/tests/test.cpp +++ b/lib/chiapos/tests/test.cpp @@ -364,7 +364,7 @@ void PlotAndTestProofOfSpace(std::string filename, uint32_t iterations, uint8_t uint32_t expected_success) { DiskPlotter plotter = DiskPlotter(); uint8_t memo[5] = {1, 2, 3, 4, 5}; - plotter.CreatePlotDisk(filename, k, memo, 5, plot_id, 32); + plotter.CreatePlotDisk(".", ".", filename, k, memo, 5, plot_id, 32); TestProofOfSpace(filename, iterations, k, plot_id, expected_success); REQUIRE(remove(filename.c_str()) == 0); } @@ -388,7 +388,7 @@ TEST_CASE("Invalid plot") { uint8_t memo[5] = {1, 2, 3, 4, 5}; string filename = "invalid-plot.dat"; uint8_t k = 22; - plotter.CreatePlotDisk(filename, k, memo, 5, plot_id_1, 32); + plotter.CreatePlotDisk(".", ".", filename, k, memo, 5, plot_id_1, 32); DiskProver prover(filename); uint8_t* proof_data = new uint8_t[8 * k]; uint8_t challenge[32]; diff --git a/lib/chiapos/tests/test_python_bindings.py b/lib/chiapos/tests/test_python_bindings.py index 70971a57e780..a540e832480a 100644 --- a/lib/chiapos/tests/test_python_bindings.py +++ b/lib/chiapos/tests/test_python_bindings.py @@ -13,7 +13,7 @@ def test_k_21(self): 10, 11, 129, 139, 171, 15, 23]) pl = DiskPlotter() - pl.create_plot_disk("./myplot.dat", 21, bytes([1, 2, 3, 4, 5]), plot_seed) + pl.create_plot_disk(".", ".", "myplot.dat", 21, bytes([1, 2, 3, 4, 5]), plot_seed) pr = DiskProver("./myplot.dat") total_proofs: int = 0 @@ -32,7 +32,7 @@ def test_k_21(self): print(f"total proofs {total_proofs} out of {iterations}\ {total_proofs / iterations}") assert total_proofs == 4647 - os.remove("./myplot.dat") + os.remove("myplot.dat") if __name__ == '__main__': diff --git a/lib/chiavdf/fast_vdf/python_bindings/verifier.py b/lib/chiavdf/fast_vdf/python_bindings/test_verifier.py similarity index 96% rename from lib/chiavdf/fast_vdf/python_bindings/verifier.py rename to lib/chiavdf/fast_vdf/python_bindings/test_verifier.py index 1837344a14a2..01bf059aaca8 100644 --- a/lib/chiavdf/fast_vdf/python_bindings/verifier.py +++ b/lib/chiavdf/fast_vdf/python_bindings/test_verifier.py @@ -54,7 +54,7 @@ witness_type = 2 t1 = time.time() -result = verify( +result_1 = verify( 1024, challenge_hash, a, @@ -65,10 +65,11 @@ ) t2 = time.time() -print(f"Result test 1: {result}") +print(f"Result test 1: {result_1}") print(f"Test time: {t2 - t1}") +assert result_1 -result = verify( +result_2 = verify( 1024, challenge_hash, a, @@ -78,4 +79,5 @@ witness_type ) -print(f"Result test 2: {result}") +print(f"Result test 2: {result_2}") +assert not result_2 diff --git a/lib/chiavdf/inkfish/classgroup.py b/lib/chiavdf/inkfish/classgroup.py index 948b32d240f5..88b97810a942 100644 --- a/lib/chiavdf/inkfish/classgroup.py +++ b/lib/chiavdf/inkfish/classgroup.py @@ -11,7 +11,7 @@ def from_ab_discriminant(class_, a, b, discriminant): assert discriminant < 0 assert discriminant % 4 == 1 c = (b * b - discriminant) // (4 * a) - p = class_(a, b, c).reduced() + p = class_((a, b, c)).reduced() assert p.discriminant() == discriminant return p @@ -20,12 +20,14 @@ def from_bytes(class_, bytearray, discriminant): int_size = (discriminant.bit_length() + 16) >> 4 a = int.from_bytes(bytearray[0:int_size], "big", signed=True) b = int.from_bytes(bytearray[int_size:], "big", signed=True) - return ClassGroup(a, b, (b**2 - discriminant)//(4*a)) + return ClassGroup((a, b, (b**2 - discriminant)//(4*a))) - def 
__new__(self, a, b, c): - return tuple.__new__(self, (a, b, c)) + def __new__(cls, t): + a, b, c = t + return tuple.__new__(cls, (a, b, c)) - def __init__(self, a, b, c): + def __init__(self, t): + a, b, c = t super(ClassGroup, self).__init__() self._discriminant = None @@ -50,7 +52,7 @@ def reduced(self): while a > c or (a == c and b < 0): s = (c + b) // (c + c) a, b, c = c, -b + 2 * s * c, c * s * s - b * s + a - return self.__class__(a, b, c).normalized() + return self.__class__((a, b, c)).normalized() def normalized(self): a, b, c = self @@ -58,7 +60,7 @@ def normalized(self): return self r = (a - b) // (2 * a) b, c = b + 2 * r * a, a * r * r + b * r + c - return self.__class__(a, b, c) + return self.__class__((a, b, c)) def serialize(self): r = self.reduced() @@ -68,7 +70,7 @@ def serialize(self): for x in [r[0], r[1]]]) def __eq__(self, other): - return tuple(self.reduced()) == tuple(ClassGroup(*other).reduced()) + return tuple(self.reduced()) == tuple(ClassGroup((other[0], other[1], other[2])).reduced()) def __ne__(self, other): return not self.__eq__(other) @@ -85,7 +87,7 @@ def __pow__(self, n): def inverse(self): a, b, c = self - return self.__class__(a, -b, c) + return self.__class__((a, -b, c)) def multiply(self, other): """ @@ -130,7 +132,7 @@ def multiply(self, other): a3 = s * t - r * u b3 = (j * u + m * r) - (k * t + l * s) c3 = k * l - j * m - return self.__class__(a3, b3, c3).reduced() + return self.__class__((a3, b3, c3)).reduced() def square(self): """ @@ -174,7 +176,7 @@ def square(self): a3 = s * t - r * u b3 = (j * u + m * r) - (k * t + l * s) c3 = k * l - j * m - return self.__class__(a3, b3, c3).reduced() + return self.__class__((a3, b3, c3)).reduced() """ diff --git a/requirements.txt b/requirements.txt index f8604209cc8c..7f0ed0acf2f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,23 +6,19 @@ async-timeout==3.0.1 asyncssh==2.1.0 attrs==19.3.0 autoflake==1.3.1 -bitstring==3.1.6 black==19.10b0 blspy==0.1.14 -cbor2==5.0.0 +cbor2==5.0.1 cffi==1.13.2 chardet==3.0.4 Click==7.0 colorlog==4.1.0 -cppimport==18.11.8 cryptography==2.8 entrypoints==0.3 flake8==3.7.9 idna==2.8 importlib-metadata==1.4.0 isort==4.3.21 -Mako==1.1.0 -MarkupSafe==1.1.1 mccabe==0.6.1 miniupnpc==2.0.2 more-itertools==8.1.0 @@ -34,24 +30,23 @@ pathspec==0.7.0 pluggy==0.13.1 prompt-toolkit==3.0.2 py==1.8.1 -pybind11==2.4.3 pycodestyle==2.5.0 pycparser==2.19 pyflakes==2.1.1 pyparsing==2.4.6 -pytest==5.3.2 +pytest==5.3.4 pytest-asyncio==0.10.0 PyYAML==5.3 regex==2020.1.8 -setuptools-scm==3.3.3 +setuptools-scm==3.4.2 six==1.14.0 toml==0.10.0 typed-ast==1.4.1 typing-extensions==3.7.4.1 -uvloop==0.14.0 wcwidth==0.1.8 yarl==1.4.2 -zipp==1.0.0 +zipp==2.0.0 -e lib/chiapos -e lib/py-setproctitle -e lib/chiavdf/fast_vdf + diff --git a/scripts/check_plots.py b/scripts/check_plots.py index e27a9b61a5ea..f55125602448 100644 --- a/scripts/check_plots.py +++ b/scripts/check_plots.py @@ -20,7 +20,9 @@ def main(): """ parser = argparse.ArgumentParser(description="Chia plot checking script.") - parser.add_argument("-n", "--num", help="Number of challenges", type=int, default=1000) + parser.add_argument( + "-n", "--num", help="Number of challenges", type=int, default=1000 + ) args = parser.parse_args() v = Verifier() @@ -29,30 +31,41 @@ def main(): for plot_filename, plot_info in plot_config["plots"].items(): plot_seed: bytes32 = ProofOfSpace.calculate_plot_seed( PublicKey.from_bytes(bytes.fromhex(plot_info["pool_pk"])), - PrivateKey.from_bytes(bytes.fromhex(plot_info["sk"])).get_public_key() + 
PrivateKey.from_bytes(bytes.fromhex(plot_info["sk"])).get_public_key(), ) - # Tries relative path - full_path: str = os.path.join(plot_root, plot_filename) - if not os.path.isfile(full_path): - # Tries absolute path - full_path: str = plot_filename + if not os.path.isfile(plot_filename): + # Tries relative path + full_path: str = os.path.join(plot_root, plot_filename) if not os.path.isfile(full_path): - print(f"Plot file {full_path} not found.") - continue - pr = DiskProver(full_path) + # Tries absolute path + full_path: str = plot_filename + if not os.path.isfile(full_path): + print(f"Plot file {full_path} not found.") + continue + pr = DiskProver(full_path) + else: + pr = DiskProver(plot_filename) total_proofs = 0 try: for i in range(args.num): challenge = sha256(i.to_bytes(32, "big")).digest() - for index, quality in enumerate(pr.get_qualities_for_challenge(challenge)): + for index, quality in enumerate( + pr.get_qualities_for_challenge(challenge) + ): proof = pr.get_full_proof(challenge, index) total_proofs += 1 - ver_quality = v.validate_proof(plot_seed, pr.get_size(), challenge, proof) - assert(quality == ver_quality) + ver_quality = v.validate_proof( + plot_seed, pr.get_size(), challenge, proof + ) + assert quality == ver_quality except BaseException as e: - print(f"{type(e)}: {e} error in proving/verifying for plot {plot_filename}") - print(f"{plot_filename}: Proofs {total_proofs} / {args.num}, {round(total_proofs/float(args.num), 4)}") + print( + f"{type(e)}: {e} error in proving/verifying for plot {plot_filename}" + ) + print( + f"{plot_filename}: Proofs {total_proofs} / {args.num}, {round(total_proofs/float(args.num), 4)}" + ) else: print(f"Not plot file found at {plot_config_filename}") diff --git a/scripts/common.sh b/scripts/common.sh index 24ab9a761c7a..726511f37997 100755 --- a/scripts/common.sh +++ b/scripts/common.sh @@ -1,5 +1,5 @@ _kill_servers() { - PROCS=`ps -e | grep -E 'chia_|vdf_server' | awk '{print $1}'` + PROCS=`ps -e | grep -E 'chia_|vdf_server' | awk '!/grep/' | awk '{print $1}'` if [ -n "$PROCS" ]; then echo "$PROCS" | xargs -L1 kill fi diff --git a/scripts/create_plots.py b/scripts/create_plots.py index c8d3df141f48..1f0811dd39b3 100755 --- a/scripts/create_plots.py +++ b/scripts/create_plots.py @@ -10,7 +10,6 @@ from src.types.proof_of_space import ProofOfSpace from src.types.sized_bytes import bytes32 -plot_root = os.path.join(ROOT_DIR, "plots") plot_config_filename = os.path.join(ROOT_DIR, "config", "plots.yaml") key_config_filename = os.path.join(ROOT_DIR, "config", "keys.yaml") @@ -28,11 +27,27 @@ def main(): parser.add_argument( "-p", "--pool_pub_key", help="Hex public key of pool", type=str, default="" ) + parser.add_argument( + "-t", + "--tmp_dir", + help="Temporary directory for plotting files (relative or absolute)", + type=str, + default="./plots", + ) + parser.add_argument( + "-d", + "--final_dir", + help="Final directory for plots (relative or absolute)", + type=str, + default="./plots", + ) # We need the keys file, to access pool keys (if the exist), and the sk_seed. args = parser.parse_args() if not os.path.isfile(key_config_filename): - raise RuntimeError("Keys not generated. Run python3.7 ./scripts/regenerate_keys.py.") + raise RuntimeError( + "Keys not generated. Run python3 ./scripts/regenerate_keys.py." 
+ ) # The seed is what will be used to generate a private key for each plot key_config = safe_load(open(key_config_filename, "r")) @@ -62,13 +77,15 @@ def main(): pool_pk, sk.get_public_key() ) filename: str = f"plot-{i}-{args.size}-{plot_seed}.dat" - full_path: str = os.path.join(plot_root, filename) + full_path: str = os.path.join(args.final_dir, filename) if os.path.isfile(full_path): print(f"Plot {filename} already exists") else: # Creates the plot. This will take a long time for larger plots. plotter: DiskPlotter = DiskPlotter() - plotter.create_plot_disk(full_path, args.size, bytes([]), plot_seed) + plotter.create_plot_disk( + args.tmp_dir, args.final_dir, filename, args.size, bytes([]), plot_seed + ) # Updates the config if necessary. if os.path.isfile(plot_config_filename): @@ -76,8 +93,8 @@ def main(): else: plot_config = {"plots": {}} plot_config_plots_new = deepcopy(plot_config["plots"]) - if filename not in plot_config_plots_new: - plot_config_plots_new[filename] = { + if full_path not in plot_config_plots_new: + plot_config_plots_new[full_path] = { "sk": bytes(sk).hex(), "pool_pk": bytes(pool_pk).hex(), } diff --git a/scripts/run_all.sh b/scripts/run_all.sh index 9ce03aad3834..3224111b5099 100755 --- a/scripts/run_all.sh +++ b/scripts/run_all.sh @@ -6,7 +6,7 @@ _run_bg_cmd python -m src.server.start_harvester _run_bg_cmd python -m src.server.start_timelord _run_bg_cmd python -m src.server.start_farmer -_run_bg_cmd python -m src.server.start_full_node "127.0.0.1" 8444 -id 1 -f -t -r 8555 -_run_bg_cmd python -m src.ui.start_ui 8222 -r 8555 +_run_bg_cmd python -m src.server.start_full_node --port=8444 --database_id=1 --connect_to_farmer=True --connect_to_timelord=True --rpc_port=8555 +_run_bg_cmd python -m src.ui.start_ui --port=8222 --rpc_port=8555 wait diff --git a/scripts/run_all_simulation.sh b/scripts/run_all_simulation.sh index 5d049cc54232..0c45614d65fd 100755 --- a/scripts/run_all_simulation.sh +++ b/scripts/run_all_simulation.sh @@ -1,7 +1,7 @@ . .venv/bin/activate . scripts/common.sh -echo "Starting local blockchain simulation. Make sure full node is configured to point to the local introducer (127.0.0.1:8445) in config/config.py." +echo "Starting local blockchain simulation. Runs a local introducer and chia system." echo "Note that this simulation will not work if connected to external nodes." # Starts a harvester, farmer, timelord, introducer, and 3 full nodes, locally. 
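The launch scripts above switch from positional arguments to explicit named flags (`--port`, `--database_id`, `--rpc_port`), and the simulation script that follows adds dotted overrides such as `--introducer_peer.host`. Below is a minimal, hypothetical sketch of how dotted keys like these could be folded into a nested config mapping; `apply_overrides` is a stand-in for illustration, not the project's actual parser.

```python
# Hypothetical helper, for illustration only: fold --a.b=value flags
# into a nested dict. The real start_full_node parser may differ.
import sys
from typing import Any, Dict, List


def apply_overrides(config: Dict[str, Any], argv: List[str]) -> Dict[str, Any]:
    for arg in argv:
        if not (arg.startswith("--") and "=" in arg):
            continue  # only handle --key=value style arguments
        key, value = arg[2:].split("=", 1)
        *parents, leaf = key.split(".")
        node = config
        for part in parents:
            node = node.setdefault(part, {})  # descend, creating sub-dicts
        node[leaf] = value
    return config


if __name__ == "__main__":
    # e.g.: python overrides.py --port=8444 --introducer_peer.host=127.0.0.1
    print(apply_overrides({}, sys.argv[1:]))
```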
@@ -12,9 +12,9 @@ _run_bg_cmd python -m src.server.start_harvester _run_bg_cmd python -m src.server.start_timelord _run_bg_cmd python -m src.server.start_farmer _run_bg_cmd python -m src.server.start_introducer -_run_bg_cmd python -m src.server.start_full_node "127.0.0.1" 8444 -id 1 -f -t -r 8555 -_run_bg_cmd python -m src.server.start_full_node "127.0.0.1" 8002 -id 2 -r 8556 -_run_bg_cmd python -m src.ui.start_ui 8222 -r 8555 -_run_bg_cmd python -m src.ui.start_ui 8223 -r 8556 +_run_bg_cmd python -m src.server.start_full_node --port=8444 --database_id=1 --connect_to_farmer=True --connect_to_timelord=True --rpc_port=8555 --introducer_peer.host="127.0.0.1" --introducer_peer.port=8445 +_run_bg_cmd python -m src.server.start_full_node --port=8002 --database_id=2 --rpc_port=8556 --introducer_peer.host="127.0.0.1" --introducer_peer.port=8445 +_run_bg_cmd python -m src.ui.start_ui --port=8222 --rpc_port=8555 +_run_bg_cmd python -m src.ui.start_ui --port=8223 --rpc_port=8556 wait \ No newline at end of file diff --git a/scripts/run_farming.sh b/scripts/run_farming.sh index 7e55ba10892a..119747b18d30 100755 --- a/scripts/run_farming.sh +++ b/scripts/run_farming.sh @@ -5,7 +5,7 @@ _run_bg_cmd python -m src.server.start_harvester _run_bg_cmd python -m src.server.start_farmer -_run_bg_cmd python -m src.server.start_full_node "127.0.0.1" 8444 -id 1 -f -r 8555 -_run_bg_cmd python -m src.ui.start_ui 8222 -r 8555 +_run_bg_cmd python -m src.server.start_full_node --port=8444 --database_id=1 --connect_to_farmer=True --rpc_port=8555 +_run_bg_cmd python -m src.ui.start_ui --port=8222 --rpc_port=8555 wait diff --git a/scripts/run_full_node.sh b/scripts/run_full_node.sh index c8867ffdeb17..bbe1b6c2abbb 100755 --- a/scripts/run_full_node.sh +++ b/scripts/run_full_node.sh @@ -2,7 +2,7 @@ . 
scripts/common.sh # Starts a full node -_run_bg_cmd python -m src.server.start_full_node "127.0.0.1" 8444 -id 1 -r 8555 -_run_bg_cmd python -m src.ui.start_ui 8222 -r 8555 +_run_bg_cmd python -m src.server.start_full_node --port=8444 --database_id=1 --connect_to_farmer=True --connect_to_timelord=True --rpc_port=8555 +_run_bg_cmd python -m src.ui.start_ui --port=8222 --rpc_port=8555 wait diff --git a/scripts/run_timelord.sh b/scripts/run_timelord.sh index 20e8c1f1284c..68c418ab3fab 100755 --- a/scripts/run_timelord.sh +++ b/scripts/run_timelord.sh @@ -4,7 +4,7 @@ # Starts a timelord, and a full node _run_bg_cmd python -m src.server.start_timelord -_run_bg_cmd python -m src.server.start_full_node "127.0.0.1" 8444 -id 1 -t -r 8555 -_run_bg_cmd python -m src.ui.start_ui 8222 -r 8555 +_run_bg_cmd python -m src.server.start_full_node --port=8444 --database_id=1 --connect_to_timelord=True --rpc_port=8555 +_run_bg_cmd python -m src.ui.start_ui --port=8222 --rpc_port=8555 wait diff --git a/setup.py b/setup.py index fc1e4b35daf4..ee4ea18a2f97 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,6 @@ "pyyaml", # Used for config file format "asyncssh", # Used for SSH server for UI "miniupnpc", # Allows users to open ports on their router - "uvloop", # Faster replacement to asyncio "aiosqlite", # asyncio wrapper for sqlite, to store blocks "aiohttp", # HTTP server for full node rpc "setuptools-scm", # Used for versioning @@ -35,6 +34,9 @@ keywords="chia blockchain node", install_requires=dependencies + dev_dependencies, setup_requires=["setuptools_scm"], + extras_require={ + 'uvloop': ["uvloop"], + }, use_scm_version={"fallback_version": "unknown-no-.git-directory"}, long_description=open("README.md").read(), zip_safe=False, diff --git a/src/blockchain.py b/src/blockchain.py index 59c8f9d30498..94ff6f07ff63 100644 --- a/src/blockchain.py +++ b/src/blockchain.py @@ -15,7 +15,7 @@ ) from src.types.full_block import FullBlock -from src.types.header_block import HeaderBlock +from src.types.header_block import HeaderBlock, SmallHeaderBlock from src.types.sized_bytes import bytes32 from src.util.errors import BlockNotInBlockchain, InvalidGenesisBlock from src.util.ints import uint32, uint64 @@ -43,13 +43,13 @@ class Blockchain: # Allow passing in custom overrides for any consesus parameters constants: Dict # Tips of the blockchain - tips: List[HeaderBlock] + tips: List[SmallHeaderBlock] # Least common ancestor of tips - lca_block: HeaderBlock - # Defines the path from genesis to the tip + lca_block: SmallHeaderBlock + # Defines the path from genesis to the lca height_to_hash: Dict[uint32, bytes32] # All headers (but not orphans) from genesis to the tip are guaranteed to be in header_blocks - header_blocks: Dict[bytes32, HeaderBlock] + headers: Dict[bytes32, SmallHeaderBlock] # Process pool to verify blocks pool: concurrent.futures.ProcessPoolExecutor # Genesis block @@ -57,7 +57,7 @@ class Blockchain: @staticmethod async def create( - header_blocks: Dict[str, HeaderBlock], override_constants: Dict = {} + headers_input: Dict[str, SmallHeaderBlock], override_constants: Dict = {} ): """ Initializes a blockchain with the given header blocks, assuming they have all been @@ -71,7 +71,7 @@ async def create( self.constants[key] = value self.tips = [] self.height_to_hash = {} - self.header_blocks = {} + self.headers = {} self.genesis = FullBlock.from_bytes(self.constants["GENESIS_BLOCK"]) @@ -80,23 +80,23 @@ async def create( raise InvalidGenesisBlock() assert self.lca_block is not None - if len(header_blocks) > 0: - 
self.header_blocks = header_blocks - for _, header_block in self.header_blocks.items(): + if len(headers_input) > 0: + self.headers = headers_input + for _, header_block in self.headers.items(): self.height_to_hash[header_block.height] = header_block.header_hash await self._reconsider_heads(header_block, False) assert ( - self.header_blocks[self.height_to_hash[uint32(0)]] - == self.genesis.header_block + self.headers[self.height_to_hash[uint32(0)]].header.get_hash() + == self.genesis.header_block.header_hash ) - if len(header_blocks) > 1: + if len(headers_input) > 1: assert ( - self.header_blocks[self.height_to_hash[uint32(1)]].prev_header_hash + self.headers[self.height_to_hash[uint32(1)]].prev_header_hash == self.genesis.header_hash ) return self - def get_current_tips(self) -> List[HeaderBlock]: + def get_current_tips(self) -> List[SmallHeaderBlock]: """ Return the heads. """ @@ -112,20 +112,20 @@ def is_child_of_head(self, block: FullBlock): return False def cointains_block(self, header_hash: bytes32): - return header_hash in self.header_blocks + return header_hash in self.headers def get_header_hashes(self, tip_header_hash: bytes32) -> List[bytes32]: - if tip_header_hash not in self.header_blocks: + if tip_header_hash not in self.headers: raise ValueError("Invalid tip requested") - curr = self.header_blocks[tip_header_hash] + curr = self.headers[tip_header_hash] ret_hashes = [tip_header_hash] while curr.height != 0: - curr = self.header_blocks[curr.prev_header_hash] + curr = self.headers[curr.prev_header_hash] ret_hashes.append(curr.header_hash) return list(reversed(ret_hashes)) - def get_header_blocks_by_height( + def get_header_hashes_by_height( self, heights: List[uint32], tip_header_hash: bytes32 ) -> List[HeaderBlock]: """ @@ -138,29 +138,31 @@ def get_header_blocks_by_height( [(height, index) for index, height in enumerate(heights)], reverse=True ) - curr_block: Optional[HeaderBlock] = self.header_blocks[tip_header_hash] + curr_block: Optional[SmallHeaderBlock] = self.headers.get(tip_header_hash, None) if curr_block is None: raise BlockNotInBlockchain( f"Header hash {tip_header_hash} not present in chain." ) - headers: List[Tuple[int, HeaderBlock]] = [] + headers: List[Tuple[int, SmallHeaderBlock]] = [] for height, index in sorted_heights: if height > curr_block.height: raise ValueError("Height is not valid for tip {tip_header_hash}") while height < curr_block.height: - curr_block = self.header_blocks.get(curr_block.prev_header_hash, None) + curr_block = self.headers.get(curr_block.prev_header_hash, None) if curr_block is None: raise ValueError(f"Do not have header {height}") headers.append((index, curr_block)) - return [b for index, b in sorted(headers)] + + # Return sorted by index (original order) + return [b.header_hash for _, b in sorted(headers, key=lambda pair: pair[0])] def find_fork_point(self, alternate_chain: List[bytes32]) -> uint32: """ Takes in an alternate blockchain (headers), and compares it to self. Returns the last header where both blockchains are equal. """ - lca: HeaderBlock = self.lca_block + lca: SmallHeaderBlock = self.lca_block if lca.height >= len(alternate_chain) - 1: raise ValueError("Alternate chain is shorter") @@ -192,7 +194,7 @@ def get_next_difficulty(self, header_hash: bytes32) -> uint64: Returns the difficulty of the next block that extends onto header_hash. Used to calculate the number of iterations. 
""" - block: HeaderBlock = self.header_blocks[header_hash] + block: SmallHeaderBlock = self.headers[header_hash] next_height: uint32 = uint32(block.height + 1) if next_height < self.constants["DIFFICULTY_EPOCH"]: @@ -207,7 +209,7 @@ def get_next_difficulty(self, header_hash: bytes32) -> uint64: != self.constants["DIFFICULTY_DELAY"] ): # Not at a point where difficulty would change - prev_block: HeaderBlock = self.header_blocks[block.prev_header_hash] + prev_block: SmallHeaderBlock = self.headers[block.prev_header_hash] assert block.challenge is not None assert prev_block is not None and prev_block.challenge is not None if prev_block is None: @@ -238,7 +240,7 @@ def get_next_difficulty(self, header_hash: bytes32) -> uint64: if block not in self.get_current_tips() or height3 not in self.height_to_hash: # This means we are either on a fork, or on one of the chains, but after the LCA, # so we manually backtrack. - curr: Optional[HeaderBlock] = block + curr: Optional[SmallHeaderBlock] = block assert curr is not None while ( curr.height not in self.height_to_hash @@ -250,16 +252,16 @@ def get_next_difficulty(self, header_hash: bytes32) -> uint64: block2 = curr elif curr.height == height3: block3 = curr - curr = self.header_blocks.get(curr.prev_header_hash, None) + curr = self.headers.get(curr.prev_header_hash, None) assert curr is not None # Once we are before the fork point (and before the LCA), we can use the height_to_hash map if not block1 and height1 >= 0: # height1 could be -1, for the first difficulty calculation - block1 = self.header_blocks[self.height_to_hash[height1]] + block1 = self.headers[self.height_to_hash[height1]] if not block2: - block2 = self.header_blocks[self.height_to_hash[height2]] + block2 = self.headers[self.height_to_hash[height2]] if not block3: - block3 = self.header_blocks[self.height_to_hash[height3]] + block3 = self.headers[self.height_to_hash[height3]] assert block2 is not None and block3 is not None # Current difficulty parameter (diff of block h = i - 1) @@ -272,7 +274,7 @@ def get_next_difficulty(self, header_hash: bytes32) -> uint64: else: # In the case of height == -1, there is no timestamp here, so assume the genesis block # took constants["BLOCK_TIME_TARGET"] seconds to mine. - genesis = self.header_blocks[self.height_to_hash[uint32(0)]] + genesis = self.headers[self.height_to_hash[uint32(0)]] timestamp1 = ( genesis.header.data.timestamp - self.constants["BLOCK_TIME_TARGET"] ) @@ -317,12 +319,12 @@ def get_next_difficulty(self, header_hash: bytes32) -> uint64: ] ) - def get_next_ips(self, header_hash) -> uint64: + def get_next_ips(self, header_block: HeaderBlock) -> uint64: """ Returns the VDF speed in iterations per seconds, to be used for the next block. This depends on the number of iterations of the last epoch, and changes at the same block as the difficulty. 
""" - block: HeaderBlock = self.header_blocks[header_hash] + block: SmallHeaderBlock = self.headers[header_block.header_hash] assert block.challenge is not None next_height: uint32 = uint32(block.height + 1) @@ -330,10 +332,10 @@ def get_next_ips(self, header_hash) -> uint64: # First epoch has a hardcoded vdf speed return self.constants["VDF_IPS_STARTING"] - prev_block: HeaderBlock = self.header_blocks[block.prev_header_hash] + prev_block: SmallHeaderBlock = self.headers[block.prev_header_hash] assert prev_block.challenge is not None - proof_of_space = block.proof_of_space + proof_of_space = header_block.proof_of_space difficulty = self.get_next_difficulty(prev_block.header_hash) iterations = uint64( block.challenge.total_iters - prev_block.challenge.total_iters @@ -365,12 +367,12 @@ def get_next_ips(self, header_hash) -> uint64: # Height2 is the last block in the previous epoch height2 = uint32(next_height - self.constants["DIFFICULTY_DELAY"] - 1) - block1: Optional[HeaderBlock] = None - block2: Optional[HeaderBlock] = None + block1: Optional[SmallHeaderBlock] = None + block2: Optional[SmallHeaderBlock] = None if block not in self.get_current_tips() or height2 not in self.height_to_hash: # This means we are either on a fork, or on one of the chains, but after the LCA, # so we manually backtrack. - curr: Optional[HeaderBlock] = block + curr: Optional[SmallHeaderBlock] = block assert curr is not None while ( curr.height not in self.height_to_hash @@ -380,14 +382,14 @@ def get_next_ips(self, header_hash) -> uint64: block1 = curr elif curr.height == height2: block2 = curr - curr = self.header_blocks.get(curr.prev_header_hash, None) + curr = self.headers.get(curr.prev_header_hash, None) assert curr is not None # Once we are before the fork point (and before the LCA), we can use the height_to_hash map if block1 is None and height1 >= 0: # height1 could be -1, for the first difficulty calculation - block1 = self.header_blocks.get(self.height_to_hash[height1], None) + block1 = self.headers.get(self.height_to_hash[height1], None) if block2 is None: - block2 = self.header_blocks.get(self.height_to_hash[height2], None) + block2 = self.headers.get(self.height_to_hash[height2], None) assert block2 is not None assert block2.challenge is not None @@ -398,7 +400,7 @@ def get_next_ips(self, header_hash) -> uint64: else: # In the case of height == -1, there is no timestamp here, so assume the genesis block # took constants["BLOCK_TIME_TARGET"] seconds to mine. 
- genesis: HeaderBlock = self.header_blocks[self.height_to_hash[uint32(0)]] + genesis: SmallHeaderBlock = self.headers[self.height_to_hash[uint32(0)]] timestamp1 = ( genesis.header.data.timestamp - self.constants["BLOCK_TIME_TARGET"] ) @@ -419,7 +421,11 @@ def get_next_ips(self, header_hash) -> uint64: ) async def receive_block( - self, block: FullBlock, pre_validated: bool = False, pos_quality: bytes32 = None + self, + block: FullBlock, + prev_block: Optional[HeaderBlock] = None, + pre_validated: bool = False, + pos_quality: bytes32 = None, ) -> ReceiveBlockResult: """ Adds a new block into the blockchain, if it's valid and connected to the current @@ -427,19 +433,25 @@ async def receive_block( """ genesis: bool = block.height == 0 and not self.tips - if block.header_hash in self.header_blocks: + if block.header_hash in self.headers: return ReceiveBlockResult.ALREADY_HAVE_BLOCK - if block.prev_header_hash not in self.header_blocks and not genesis: + if block.prev_header_hash not in self.headers and not genesis: return ReceiveBlockResult.DISCONNECTED_BLOCK - if not await self.validate_block(block, genesis, pre_validated, pos_quality): + if not await self.validate_block( + block, prev_block, genesis, pre_validated, pos_quality + ): return ReceiveBlockResult.INVALID_BLOCK # Cache header in memory - self.header_blocks[block.header_hash] = block.header_block + assert block.header_block.challenge is not None + small_header_block = SmallHeaderBlock( + block.header_block.header, block.header_block.challenge + ) + self.headers[block.header_hash] = small_header_block - if await self._reconsider_heads(block.header_block, genesis): + if await self._reconsider_heads(small_header_block, genesis): return ReceiveBlockResult.ADDED_TO_HEAD else: return ReceiveBlockResult.ADDED_AS_ORPHAN @@ -490,21 +502,21 @@ async def validate_unfinished_block( return False # 6. Check previous pointer(s) / flyclient - if not genesis and block.prev_header_hash not in self.header_blocks: + if not genesis and block.prev_header_hash not in self.headers: return False # 7. 
Check Now+2hrs > timestamp > avg timestamp of last 11 blocks - prev_block: Optional[HeaderBlock] = None + prev_block: Optional[SmallHeaderBlock] = None if not genesis: # TODO: do something about first 11 blocks last_timestamps: List[uint64] = [] - prev_block = self.header_blocks.get(block.prev_header_hash, None) + prev_block = self.headers.get(block.prev_header_hash, None) if not prev_block: return False curr = prev_block while len(last_timestamps) < self.constants["NUMBER_OF_TIMESTAMPS"]: last_timestamps.append(curr.header.data.timestamp) - fetched = self.header_blocks.get(curr.prev_header_hash, None) + fetched = self.headers.get(curr.prev_header_hash, None) if not fetched: break curr = fetched @@ -570,6 +582,7 @@ async def validate_unfinished_block( async def validate_block( self, block: FullBlock, + prev_full_block: Optional[HeaderBlock] = None, genesis: bool = False, pre_validated: bool = False, pos_quality: bytes32 = None, @@ -590,7 +603,8 @@ async def validate_block( ips: uint64 if not genesis: difficulty = self.get_next_difficulty(block.prev_header_hash) - ips = self.get_next_ips(block.prev_header_hash) + assert prev_full_block is not None + ips = self.get_next_ips(prev_full_block) else: difficulty = uint64(self.constants["DIFFICULTY_STARTING"]) ips = uint64(self.constants["VDF_IPS_STARTING"]) @@ -640,7 +654,7 @@ async def validate_block( return False if not genesis: - prev_block: Optional[HeaderBlock] = self.header_blocks.get( + prev_block: Optional[SmallHeaderBlock] = self.headers.get( block.prev_header_hash, None ) if not prev_block or not prev_block.challenge: @@ -766,47 +780,49 @@ def pre_validate_block_multi(data) -> Tuple[bool, Optional[bytes]]: return True, bytes(pos_quality) - def _reconsider_heights(self, old_lca: Optional[HeaderBlock], new_lca: HeaderBlock): + def _reconsider_heights( + self, old_lca: Optional[SmallHeaderBlock], new_lca: SmallHeaderBlock + ): """ Update the mapping from height to block hash, when the lca changes. """ - curr_old: Optional[HeaderBlock] = old_lca if old_lca else None - curr_new: HeaderBlock = new_lca + curr_old: Optional[SmallHeaderBlock] = old_lca if old_lca else None + curr_new: SmallHeaderBlock = new_lca while True: - fetched: Optional[HeaderBlock] + fetched: Optional[SmallHeaderBlock] if not curr_old or curr_old.height < curr_new.height: self.height_to_hash[uint32(curr_new.height)] = curr_new.header_hash - self.header_blocks[curr_new.header_hash] = curr_new + self.headers[curr_new.header_hash] = curr_new if curr_new.height == 0: return - curr_new = self.header_blocks[curr_new.prev_header_hash] + curr_new = self.headers[curr_new.prev_header_hash] elif curr_old.height > curr_new.height: del self.height_to_hash[uint32(curr_old.height)] - curr_old = self.header_blocks[curr_old.prev_header_hash] + curr_old = self.headers[curr_old.prev_header_hash] else: if curr_new.header_hash == curr_old.header_hash: return self.height_to_hash[uint32(curr_new.height)] = curr_new.header_hash - curr_new = self.header_blocks[curr_new.prev_header_hash] - curr_old = self.header_blocks[curr_old.prev_header_hash] + curr_new = self.headers[curr_new.prev_header_hash] + curr_old = self.headers[curr_old.prev_header_hash] async def _reconsider_lca(self, genesis: bool): """ Update the least common ancestor of the heads. This is useful, since we can just assume there is one block per height before the LCA (and use the height_to_hash dict). 
""" - cur: List[HeaderBlock] = self.tips[:] + cur: List[SmallHeaderBlock] = self.tips[:] while any(b.header_hash != cur[0].header_hash for b in cur): heights = [b.height for b in cur] i = heights.index(max(heights)) - cur[i] = self.header_blocks[cur[i].prev_header_hash] + cur[i] = self.headers[cur[i].prev_header_hash] if genesis: self._reconsider_heights(None, cur[0]) else: self._reconsider_heights(self.lca_block, cur[0]) self.lca_block = cur[0] - async def _reconsider_heads(self, block: HeaderBlock, genesis: bool) -> bool: + async def _reconsider_heads(self, block: SmallHeaderBlock, genesis: bool) -> bool: """ When a new block is added, this is called, to check if the new block is heavier than one of the heads. diff --git a/src/consensus/weight_verifier.py b/src/consensus/weight_verifier.py index 89707121790c..cc63bbb8a54d 100644 --- a/src/consensus/weight_verifier.py +++ b/src/consensus/weight_verifier.py @@ -1,10 +1,10 @@ from typing import List -from src.types.header_block import HeaderBlock +from src.types.header_block import HeaderBlock, SmallHeaderBlock def verify_weight( - tip: HeaderBlock, proof_blocks: List[HeaderBlock], fork_point: HeaderBlock + tip: HeaderBlock, proof_blocks: List[HeaderBlock], fork_point: SmallHeaderBlock ) -> bool: """ Verifies whether the weight of the tip is valid or not. Naively, looks at every block diff --git a/src/farmer.py b/src/farmer.py index 3a2e232c7af8..3b12757396f8 100644 --- a/src/farmer.py +++ b/src/farmer.py @@ -1,12 +1,9 @@ import logging -import os from hashlib import sha256 from typing import Any, Dict, List, Set from blspy import PrependSignature, PrivateKey, Util -from yaml import safe_load -from definitions import ROOT_DIR from src.consensus.block_rewards import calculate_block_reward from src.consensus.constants import constants from src.consensus.pot_iterations import calculate_iterations_quality @@ -27,15 +24,9 @@ class Farmer: - def __init__(self): - config_filename = os.path.join(ROOT_DIR, "config", "config.yaml") - key_config_filename = os.path.join(ROOT_DIR, "config", "keys.yaml") - if not os.path.isfile(key_config_filename): - raise RuntimeError( - "Keys not generated. Run python3.7 ./scripts/regenerate_keys.py." 
- ) - self.config = safe_load(open(config_filename, "r"))["farmer"] - self.key_config = safe_load(open(key_config_filename, "r")) + def __init__(self, farmer_config: Dict, key_config: Dict): + self.config = farmer_config + self.key_config = key_config self.harvester_responses_header_hash: Dict[bytes32, bytes32] = {} self.harvester_responses_challenge: Dict[bytes32, bytes32] = {} self.harvester_responses_proofs: Dict[bytes32, ProofOfSpace] = {} diff --git a/src/full_node.py b/src/full_node.py index 8a8e8b05631c..b10d7ce41241 100644 --- a/src/full_node.py +++ b/src/full_node.py @@ -1,18 +1,15 @@ import asyncio import concurrent import logging -import os import time from asyncio import Event from hashlib import sha256 from secrets import token_bytes -from typing import AsyncGenerator, List, Optional, Tuple +from typing import AsyncGenerator, List, Optional, Tuple, Dict -import yaml from blspy import PrivateKey, Signature from chiapos import Verifier -from definitions import ROOT_DIR from src.blockchain import Blockchain, ReceiveBlockResult from src.consensus.constants import constants from src.consensus.pot_iterations import calculate_iterations @@ -26,7 +23,7 @@ from src.types.fees_target import FeesTarget from src.types.full_block import FullBlock from src.types.header import Header, HeaderData -from src.types.header_block import HeaderBlock +from src.types.header_block import HeaderBlock, SmallHeaderBlock from src.types.peer_info import PeerInfo from src.types.proof_of_space import ProofOfSpace from src.types.sized_bytes import bytes32 @@ -40,16 +37,13 @@ class FullNode: - store: FullNodeStore - blockchain: Blockchain - - def __init__(self, store: FullNodeStore, blockchain: Blockchain): - config_filename = os.path.join(ROOT_DIR, "config", "config.yaml") - self.config = yaml.safe_load(open(config_filename, "r"))["full_node"] - self.store = store - self.blockchain = blockchain - self._shut_down = False # Set to true to close all infinite loops + def __init__(self, store: FullNodeStore, blockchain: Blockchain, config: Dict): + self.store: FullNodeStore = store + self.blockchain: Blockchain = blockchain + self.config: Dict = config + self._shut_down: bool = False # Set to true to close all infinite loops self.server: Optional[ChiaServer] = None + log.warning(f"{self.config}") def _set_server(self, server: ChiaServer): self.server = server @@ -63,7 +57,10 @@ async def _send_tips_to_farmers( """ requests: List[farmer_protocol.ProofOfSpaceFinalized] = [] async with self.store.lock: - tips = self.blockchain.get_current_tips() + tips_raw = self.blockchain.get_current_tips() + tips = await self.store.get_header_blocks_by_hash( + [t.header_hash for t in tips_raw] + ) for tip in tips: assert tip.proof_of_time and tip.challenge challenge_hash = tip.challenge.get_hash() @@ -80,9 +77,7 @@ async def _send_tips_to_farmers( challenge_hash, height, tip.weight, quality, difficulty ) ) - proof_of_time_rate: uint64 = self.blockchain.get_next_ips( - tips[0].header_hash - ) + proof_of_time_rate: uint64 = self.blockchain.get_next_ips(tips[0]) rate_update = farmer_protocol.ProofOfTimeRate(proof_of_time_rate) yield OutboundMessage( NodeType.FARMER, Message("proof_of_time_rate", rate_update), delivery @@ -100,27 +95,24 @@ async def _send_challenges_to_timelords( """ challenge_requests: List[timelord_protocol.ChallengeStart] = [] pos_info_requests: List[timelord_protocol.ProofOfSpaceInfo] = [] - async with self.store.lock: - tips: List[HeaderBlock] = self.blockchain.get_current_tips() - for tip in tips: - assert 
tip.challenge - challenge_hash = tip.challenge.get_hash() - challenge_requests.append( - timelord_protocol.ChallengeStart( - challenge_hash, tip.challenge.total_weight - ) + tips: List[SmallHeaderBlock] = self.blockchain.get_current_tips() + for tip in tips: + assert tip.challenge + challenge_hash = tip.challenge.get_hash() + challenge_requests.append( + timelord_protocol.ChallengeStart( + challenge_hash, tip.challenge.total_weight ) + ) - tip_hashes = [tip.header_hash for tip in tips] - tip_infos = [ - tup[0] - for tup in list((await self.store.get_unfinished_blocks()).items()) - if tup[1].prev_header_hash in tip_hashes - ] - for chall, iters in tip_infos: - pos_info_requests.append( - timelord_protocol.ProofOfSpaceInfo(chall, iters) - ) + tip_hashes = [tip.header_hash for tip in tips] + tip_infos = [ + tup[0] + for tup in list((self.store.get_unfinished_blocks()).items()) + if tup[1].prev_header_hash in tip_hashes + ] + for chall, iters in tip_infos: + pos_info_requests.append(timelord_protocol.ProofOfSpaceInfo(chall, iters)) for challenge_msg in challenge_requests: yield OutboundMessage( NodeType.TIMELORD, Message("challenge_start", challenge_msg), delivery @@ -139,12 +131,11 @@ async def _on_connect(self) -> OutboundMessageGenerator: """ blocks: List[FullBlock] = [] - async with self.store.lock: - heads: List[HeaderBlock] = self.blockchain.get_current_tips() - for h in heads: - block = await self.store.get_block(h.header.get_hash()) - assert block - blocks.append(block) + tips: List[SmallHeaderBlock] = self.blockchain.get_current_tips() + for t in tips: + block = await self.store.get_block(t.header.get_hash()) + assert block + blocks.append(block) for block in blocks: request = peer_protocol.Block(block) yield OutboundMessage( @@ -209,34 +200,29 @@ async def _sync(self) -> OutboundMessageGenerator: highest_weight: uint64 = uint64(0) tip_block: FullBlock tip_height = 0 + sync_start_time = time.time() # Based on responses from peers about the current heads, see which head is the heaviest # (similar to longest chain rule). 
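As the comment above notes, sync starts by choosing the heaviest candidate tip reported by peers, compared by total weight rather than height. A minimal sketch of that selection rule with simplified stand-in types (the real code iterates over `FullBlock` tuples from the store):

```python
# Simplified stand-ins for the real block types, for illustration.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Challenge:
    total_weight: int


@dataclass
class TipBlock:
    header_hash: str
    challenge: Optional[Challenge]


def heaviest_tip(candidates: List[TipBlock]) -> TipBlock:
    """Pick the candidate whose challenge carries the greatest weight."""
    valid = [t for t in candidates if t.challenge is not None]
    if not valid:
        raise ValueError("no valid tip blocks received")
    return max(valid, key=lambda t: t.challenge.total_weight)


tips = [TipBlock("a", Challenge(10)), TipBlock("b", None), TipBlock("c", Challenge(25))]
assert heaviest_tip(tips).header_hash == "c"
```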
- async with self.store.lock: - potential_tips: List[ - Tuple[bytes32, FullBlock] - ] = await self.store.get_potential_tips_tuples() - log.info(f"Have collected {len(potential_tips)} potential tips") - for header_hash, potential_tip_block in potential_tips: - if potential_tip_block.header_block.challenge is None: - raise ValueError( - f"Invalid tip block {potential_tip_block.header_hash} received" - ) - if ( - potential_tip_block.header_block.challenge.total_weight - > highest_weight - ): - highest_weight = ( - potential_tip_block.header_block.challenge.total_weight - ) - tip_block = potential_tip_block - tip_height = potential_tip_block.header_block.challenge.height - if highest_weight <= max( - [t.weight for t in self.blockchain.get_current_tips()] - ): - log.info("Not performing sync, already caught up.") - return + potential_tips: List[ + Tuple[bytes32, FullBlock] + ] = self.store.get_potential_tips_tuples() + log.info(f"Have collected {len(potential_tips)} potential tips") + for header_hash, potential_tip_block in potential_tips: + if potential_tip_block.header_block.challenge is None: + raise ValueError( + f"Invalid tip block {potential_tip_block.header_hash} received" + ) + if potential_tip_block.header_block.challenge.total_weight > highest_weight: + highest_weight = potential_tip_block.header_block.challenge.total_weight + tip_block = potential_tip_block + tip_height = potential_tip_block.header_block.challenge.height + if highest_weight <= max( + [t.weight for t in self.blockchain.get_current_tips()] + ): + log.info("Not performing sync, already caught up.") + return assert tip_block log.info(f"Tip block {tip_block.header_hash} tip height {tip_block.height}") @@ -274,10 +260,9 @@ async def _sync(self) -> OutboundMessageGenerator: log.warning("Did not receive desired header hashes") # Finding the fork point allows us to only download headers and blocks from the fork point - async with self.store.lock: - header_hashes = self.store.get_potential_hashes() - fork_point_height: uint32 = self.blockchain.find_fork_point(header_hashes) - fork_point_hash: bytes32 = header_hashes[fork_point_height] + header_hashes = self.store.get_potential_hashes() + fork_point_height: uint32 = self.blockchain.find_fork_point(header_hashes) + fork_point_hash: bytes32 = header_hashes[fork_point_height] log.info(f"Fork point: {fork_point_hash} at height {fork_point_height}") # Now, we download all of the headers in order to verify the weight, in batches @@ -366,17 +351,14 @@ async def _sync(self) -> OutboundMessageGenerator: total_time_slept += sleep_interval log.info(f"Did not receive desired header blocks") - async with self.store.lock: - for h in range(fork_point_height + 1, tip_height + 1): - header = self.store.get_potential_header(uint32(h)) - assert header is not None - headers.append(header) + for h in range(fork_point_height + 1, tip_height + 1): + header = self.store.get_potential_header(uint32(h)) + assert header is not None + headers.append(header) log.info(f"Downloaded headers up to tip height: {tip_height}") if not verify_weight( - tip_block.header_block, - headers, - self.blockchain.header_blocks[fork_point_hash], + tip_block.header_block, headers, self.blockchain.headers[fork_point_hash], ): raise errors.InvalidWeight( f"Weight of {tip_block.header_block.header.get_hash()} not valid." 
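Locating the fork point is what lets sync skip everything both chains already share and fetch only headers and blocks past the divergence. The idea behind `find_fork_point`, sketched against two plain hash lists (the real method walks back from the LCA and raises if the alternate chain is shorter):

```python
# Sketch of the fork-point search; genesis is assumed shared.
from typing import List


def find_fork_point(local_hashes: List[str], peer_hashes: List[str]) -> int:
    """Height of the last header on which the two chains agree."""
    fork_height = 0
    for h in range(min(len(local_hashes), len(peer_hashes))):
        if local_hashes[h] != peer_hashes[h]:
            break
        fork_height = h
    return fork_height


# The chains agree through height 2 and diverge at height 3.
assert find_fork_point(["g", "a", "b", "c"], ["g", "a", "b", "x"]) == 2
```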
@@ -485,6 +467,7 @@ async def _sync(self) -> OutboundMessageGenerator: assert b is not None blocks.append(b) + validation_start_time = time.time() prevalidate_results = await self.blockchain.pre_validate_blocks(blocks) index = 0 for height in range(height_checkpoint, end_height): @@ -494,42 +477,57 @@ async def _sync(self) -> OutboundMessageGenerator: uint32(height) ) assert block is not None - start = time.time() + + prev_block: Optional[FullBlock] = await self.store.get_potential_block( + uint32(height - 1) + ) + if prev_block is None: + prev_block = await self.store.get_block(block.prev_header_hash) + assert prev_block is not None + + # The block gets permanantly added to the blockchain + validated, pos = prevalidate_results[index] + index += 1 + async with self.store.lock: - # The block gets permanantly added to the blockchain - validated, pos = prevalidate_results[index] - index += 1 - result = await self.blockchain.receive_block(block, validated, pos) + result = await self.blockchain.receive_block( + block, prev_block.header_block, validated, pos + ) if ( result == ReceiveBlockResult.INVALID_BLOCK or result == ReceiveBlockResult.DISCONNECTED_BLOCK ): raise RuntimeError(f"Invalid block {block.header_hash}") - log.info( - f"Took {time.time() - start} seconds to validate and add block {block.height}." - ) + # Always immediately add the block to the database, after updating blockchain state await self.store.add_block(block) - assert ( - max([h.height for h in self.blockchain.get_current_tips()]) - >= height - ) - await self.store.set_proof_of_time_estimate_ips( - self.blockchain.get_next_ips(block.header_hash) - ) + + assert ( + max([h.height for h in self.blockchain.get_current_tips()]) + >= height + ) + self.store.set_proof_of_time_estimate_ips( + self.blockchain.get_next_ips(block.header_block) + ) + log.info( + f"Took {time.time() - validation_start_time} seconds to validate and add blocks " + f"{height_checkpoint} to {end_height}." + ) assert max([h.height for h in self.blockchain.get_current_tips()]) == tip_height - log.info(f"Finished sync up to height {tip_height}") + log.info( + f"Finished sync up to height {tip_height}. Total time: " + f"{round((time.time() - sync_start_time)/60, 2)} minutes." + ) async def _finish_sync(self) -> OutboundMessageGenerator: """ Finalize sync by setting sync mode to False, clearing all sync information, and adding any final blocks that we have finalized recently. 
""" + potential_fut_blocks = (self.store.get_potential_future_blocks()).copy() + self.store.set_sync_mode(False) + async with self.store.lock: - potential_fut_blocks = ( - await self.store.get_potential_future_blocks() - ).copy() - await self.store.set_sync_mode(False) await self.store.clear_sync_info() for block in potential_fut_blocks: @@ -562,11 +560,10 @@ async def all_header_hashes( self, all_header_hashes: peer_protocol.AllHeaderHashes ) -> OutboundMessageGenerator: assert len(all_header_hashes.header_hashes) > 0 - async with self.store.lock: - self.store.set_potential_hashes(all_header_hashes.header_hashes) - phr = self.store.get_potential_hashes_received() - assert phr is not None - phr.set() + self.store.set_potential_hashes(all_header_hashes.header_hashes) + phr = self.store.get_potential_hashes_received() + assert phr is not None + phr.set() for _ in []: # Yields nothing yield _ @@ -585,9 +582,14 @@ async def request_header_blocks( ) try: - headers: List[HeaderBlock] = self.blockchain.get_header_blocks_by_height( + header_hashes: List[ + HeaderBlock + ] = self.blockchain.get_header_hashes_by_height( request.heights, request.tip_header_hash ) + header_blocks: List[ + HeaderBlock + ] = await self.store.get_header_blocks_by_hash(header_hashes) log.info(f"Got header blocks by height {time.time() - start}") except KeyError: return @@ -595,7 +597,7 @@ async def request_header_blocks( log.info(f"{e}") return - response = peer_protocol.HeaderBlocks(request.tip_header_hash, headers) + response = peer_protocol.HeaderBlocks(request.tip_header_hash, header_blocks) yield OutboundMessage( NodeType.FULL_NODE, Message("header_blocks", response), Delivery.RESPOND ) @@ -610,10 +612,9 @@ async def header_blocks( log.info( f"Received header blocks {request.header_blocks[0].height, request.header_blocks[-1].height}." ) - async with self.store.lock: - for header_block in request.header_blocks: - self.store.add_potential_header(header_block) - (self.store.get_potential_headers_received(header_block.height)).set() + for header_block in request.header_blocks: + self.store.add_potential_header(header_block) + (self.store.get_potential_headers_received(header_block.height)).set() for _ in []: # Yields nothing yield _ @@ -626,9 +627,10 @@ async def request_sync_blocks( Responsd to a peers request for syncing blocks. """ blocks: List[FullBlock] = [] - tip_block: Optional[FullBlock] = await self.store.get_block( - request.tip_header_hash - ) + async with self.store.lock: + tip_block: Optional[FullBlock] = await self.store.get_block( + request.tip_header_hash + ) if tip_block is not None: if len(request.heights) > self.config["max_blocks_to_send"]: raise errors.TooManyheadersRequested( @@ -637,11 +639,14 @@ async def request_sync_blocks( f"but requested {len(request.heights)}" ) try: - header_blocks: List[ + header_hashes: List[ HeaderBlock - ] = self.blockchain.get_header_blocks_by_height( + ] = self.blockchain.get_header_hashes_by_height( request.heights, request.tip_header_hash ) + header_blocks: List[ + HeaderBlock + ] = await self.store.get_header_blocks_by_hash(header_hashes) for header_block in header_blocks: fetched = await self.store.get_block(header_block.header.get_hash()) assert fetched @@ -669,14 +674,18 @@ async def sync_blocks( We have received the blocks that we needed for syncing. Add them to processing queue. 
""" log.info(f"Received sync blocks {[b.height for b in request.blocks]}") - async with self.store.lock: - if not await self.store.get_sync_mode(): - log.warning("Receiving sync blocks when we are not in sync mode.") - return - for block in request.blocks: - await self.store.add_potential_block(block) - (self.store.get_potential_blocks_received(block.height)).set() + if not self.store.get_sync_mode(): + log.warning("Receiving sync blocks when we are not in sync mode.") + return + + for block in request.blocks: + await self.store.add_potential_block(block) + if ( + not self.store.get_sync_mode() + ): # We might have left sync mode after the previous await + return + (self.store.get_potential_blocks_received(block.height)).set() for _ in []: # Yields nothing yield _ @@ -700,64 +709,63 @@ async def request_header_hash( ) assert quality_string - async with self.store.lock: - # Retrieves the correct head for the challenge - heads: List[HeaderBlock] = self.blockchain.get_current_tips() - target_head: Optional[HeaderBlock] = None - for head in heads: - assert head.challenge - if head.challenge.get_hash() == request.challenge_hash: - target_head = head - if target_head is None: - # TODO: should we still allow the farmer to farm? - log.warning( - f"Challenge hash: {request.challenge_hash} not in one of three heads" - ) - return - - # TODO: use mempool to grab best transactions, for the selected head - transactions_generator: bytes32 = sha256(b"").digest() - # TODO: calculate the fees of these transactions - fees: FeesTarget = FeesTarget(request.fees_target_puzzle_hash, uint64(0)) - aggregate_sig: Signature = PrivateKey.from_seed(b"12345").sign(b"anything") - # TODO: calculate aggregate signature based on transactions - # TODO: calculate cost of all transactions - cost = uint64(0) - - # Creates a block with transactions, coinbase, and fees - body: Body = Body( - request.coinbase, - request.coinbase_signature, - fees, - aggregate_sig, - transactions_generator, - cost, + # Retrieves the correct head for the challenge + tips: List[SmallHeaderBlock] = self.blockchain.get_current_tips() + target_tip: Optional[SmallHeaderBlock] = None + for tip in tips: + assert tip.challenge + if tip.challenge.get_hash() == request.challenge_hash: + target_tip = tip + if target_tip is None: + # TODO: should we still allow the farmer to farm? 
+ log.warning( + f"Challenge hash: {request.challenge_hash} not in one of three heads" ) + return - # Creates the block header - prev_header_hash: bytes32 = target_head.header.get_hash() - timestamp: uint64 = uint64(int(time.time())) - - # TODO: use a real BIP158 filter based on transactions - filter_hash: bytes32 = token_bytes(32) - proof_of_space_hash: bytes32 = request.proof_of_space.get_hash() - body_hash: Body = body.get_hash() - extension_data: bytes32 = bytes32([0] * 32) - block_header_data: HeaderData = HeaderData( - prev_header_hash, - timestamp, - filter_hash, - proof_of_space_hash, - body_hash, - extension_data, - ) + # TODO: use mempool to grab best transactions, for the selected head + transactions_generator: bytes32 = sha256(b"").digest() + # TODO: calculate the fees of these transactions + fees: FeesTarget = FeesTarget(request.fees_target_puzzle_hash, uint64(0)) + aggregate_sig: Signature = PrivateKey.from_seed(b"12345").sign(b"anything") + # TODO: calculate aggregate signature based on transactions + # TODO: calculate cost of all transactions + cost = uint64(0) + + # Creates a block with transactions, coinbase, and fees + body: Body = Body( + request.coinbase, + request.coinbase_signature, + fees, + aggregate_sig, + transactions_generator, + cost, + ) - block_header_data_hash: bytes32 = block_header_data.get_hash() + # Creates the block header + prev_header_hash: bytes32 = target_tip.header.get_hash() + timestamp: uint64 = uint64(int(time.time())) + + # TODO: use a real BIP158 filter based on transactions + filter_hash: bytes32 = token_bytes(32) + proof_of_space_hash: bytes32 = request.proof_of_space.get_hash() + body_hash: Body = body.get_hash() + extension_data: bytes32 = bytes32([0] * 32) + block_header_data: HeaderData = HeaderData( + prev_header_hash, + timestamp, + filter_hash, + proof_of_space_hash, + body_hash, + extension_data, + ) - # self.stores this block so we can submit it to the blockchain after it's signed by harvester - await self.store.add_candidate_block( - proof_of_space_hash, body, block_header_data, request.proof_of_space - ) + block_header_data_hash: bytes32 = block_header_data.get_hash() + + # self.stores this block so we can submit it to the blockchain after it's signed by harvester + self.store.add_candidate_block( + proof_of_space_hash, body, block_header_data, request.proof_of_space + ) message = farmer_protocol.HeaderHash( proof_of_space_hash, block_header_data_hash @@ -775,25 +783,22 @@ async def header_signature( block, which only needs a Proof of Time to be finished. If the signature is valid, we call the unfinished_block routine. 
""" - async with self.store.lock: - candidate: Optional[ - Tuple[Body, HeaderData, ProofOfSpace] - ] = await self.store.get_candidate_block(header_signature.pos_hash) - if candidate is None: - log.warning( - f"PoS hash {header_signature.pos_hash} not found in database" - ) - return - # Verifies that we have the correct header and body self.stored - block_body, block_header_data, pos = candidate + candidate: Optional[ + Tuple[Body, HeaderData, ProofOfSpace] + ] = self.store.get_candidate_block(header_signature.pos_hash) + if candidate is None: + log.warning(f"PoS hash {header_signature.pos_hash} not found in database") + return + # Verifies that we have the correct header and body self.stored + block_body, block_header_data, pos = candidate - assert block_header_data.get_hash() == header_signature.header_hash + assert block_header_data.get_hash() == header_signature.header_hash - block_header: Header = Header( - block_header_data, header_signature.header_signature - ) - header: HeaderBlock = HeaderBlock(pos, None, None, block_header) - unfinished_block_obj: FullBlock = FullBlock(header, block_body) + block_header: Header = Header( + block_header_data, header_signature.header_signature + ) + header: HeaderBlock = HeaderBlock(pos, None, None, block_header) + unfinished_block_obj: FullBlock = FullBlock(header, block_body) # Propagate to ourselves (which validates and does further propagations) request = peer_protocol.UnfinishedBlock(unfinished_block_obj) @@ -810,29 +815,28 @@ async def proof_of_time_finished( A proof of time, received by a peer timelord. We can use this to complete a block, and call the block routine (which handles propagation and verification of blocks). """ - async with self.store.lock: - dict_key = ( - request.proof.challenge_hash, - request.proof.number_of_iterations, - ) + dict_key = ( + request.proof.challenge_hash, + request.proof.number_of_iterations, + ) - unfinished_block_obj: Optional[ - FullBlock - ] = await self.store.get_unfinished_block(dict_key) - if not unfinished_block_obj: - log.warning( - f"Received a proof of time that we cannot use to complete a block {dict_key}" - ) - return - prev_full_block = await self.store.get_block( - unfinished_block_obj.prev_header_hash - ) - assert prev_full_block - prev_block: HeaderBlock = prev_full_block.header_block - difficulty: uint64 = self.blockchain.get_next_difficulty( - unfinished_block_obj.prev_header_hash + unfinished_block_obj: Optional[FullBlock] = self.store.get_unfinished_block( + dict_key + ) + if not unfinished_block_obj: + log.warning( + f"Received a proof of time that we cannot use to complete a block {dict_key}" ) - assert prev_block.challenge + return + prev_full_block = await self.store.get_block( + unfinished_block_obj.prev_header_hash + ) + assert prev_full_block + prev_block: HeaderBlock = prev_full_block.header_block + difficulty: uint64 = self.blockchain.get_next_difficulty( + unfinished_block_obj.prev_header_hash + ) + assert prev_block.challenge challenge: Challenge = Challenge( request.proof.challenge_hash, @@ -855,12 +859,8 @@ async def proof_of_time_finished( new_header_block, unfinished_block_obj.body ) - async with self.store.lock: - sync_mode = await self.store.get_sync_mode() - - if sync_mode: - async with self.store.lock: - await self.store.add_potential_future_block(new_full_block) + if self.store.get_sync_mode(): + self.store.add_potential_future_block(new_full_block) else: async for msg in self.block(peer_protocol.Block(new_full_block)): yield msg @@ -876,17 +876,17 @@ async def 
new_proof_of_time( """ finish_block: bool = False propagate_proof: bool = False - async with self.store.lock: - if await self.store.get_unfinished_block( - ( - new_proof_of_time.proof.challenge_hash, - new_proof_of_time.proof.number_of_iterations, - ) - ): + if self.store.get_unfinished_block( + ( + new_proof_of_time.proof.challenge_hash, + new_proof_of_time.proof.number_of_iterations, + ) + ): + + finish_block = True + elif new_proof_of_time.proof.is_valid(constants["DISCRIMINANT_SIZE_BITS"]): + propagate_proof = True - finish_block = True - elif new_proof_of_time.proof.is_valid(constants["DISCRIMINANT_SIZE_BITS"]): - propagate_proof = True if finish_block: request = timelord_protocol.ProofOfTimeFinished(new_proof_of_time.proof) async for msg in self.proof_of_time_finished(request): @@ -920,6 +920,7 @@ async def unfinished_block( prev_full_block: Optional[FullBlock] = await self.store.get_block( unfinished_block.block.prev_header_hash ) + assert prev_full_block prev_block: HeaderBlock = prev_full_block.header_block @@ -930,9 +931,7 @@ async def unfinished_block( difficulty: uint64 = self.blockchain.get_next_difficulty( unfinished_block.block.header_block.prev_header_hash ) - vdf_ips: uint64 = self.blockchain.get_next_ips( - unfinished_block.block.header_block.prev_header_hash - ) + vdf_ips: uint64 = self.blockchain.get_next_ips(prev_block) iterations_needed: uint64 = calculate_iterations( unfinished_block.block.header_block.proof_of_space, @@ -942,13 +941,13 @@ async def unfinished_block( ) if ( - await self.store.get_unfinished_block((challenge_hash, iterations_needed)) + self.store.get_unfinished_block((challenge_hash, iterations_needed)) is not None ): return expected_time: uint64 = uint64( - int(iterations_needed / (await self.store.get_proof_of_time_estimate_ips())) + int(iterations_needed / (self.store.get_proof_of_time_estimate_ips())) ) if expected_time > constants["PROPAGATION_DELAY_THRESHOLD"]: @@ -956,37 +955,36 @@ async def unfinished_block( # If this block is slow, sleep to allow faster blocks to come out first await asyncio.sleep(5) - async with self.store.lock: - leader: Tuple[uint32, uint64] = self.store.get_unfinished_block_leader() - if leader is None or unfinished_block.block.height > leader[0]: + leader: Tuple[uint32, uint64] = self.store.get_unfinished_block_leader() + if leader is None or unfinished_block.block.height > leader[0]: + log.info( + f"This is the first unfinished block at height {unfinished_block.block.height}, so propagate." + ) + # If this is the first block we see at this height, propagate + self.store.set_unfinished_block_leader( + (unfinished_block.block.height, expected_time) + ) + elif unfinished_block.block.height == leader[0]: + if expected_time > leader[1] + constants["PROPAGATION_THRESHOLD"]: + # If VDF is expected to finish X seconds later than the best, don't propagate log.info( - f"This is the first unfinished block at height {unfinished_block.block.height}, so propagate." 
+ f"VDF will finish too late {expected_time} seconds, so don't propagate" ) - # If this is the first block we see at this height, propagate - self.store.set_unfinished_block_leader( - (unfinished_block.block.height, expected_time) - ) - elif unfinished_block.block.height == leader[0]: - if expected_time > leader[1] + constants["PROPAGATION_THRESHOLD"]: - # If VDF is expected to finish X seconds later than the best, don't propagate - log.info( - f"VDF will finish too late {expected_time} seconds, so don't propagate" - ) - return - elif expected_time < leader[1]: - log.info( - f"New best unfinished block at height {unfinished_block.block.height}" - ) - # If this will be the first block to finalize, update our leader - self.store.set_unfinished_block_leader((leader[0], expected_time)) - else: - # If we have seen an unfinished block at a greater or equal height, don't propagate - log.info(f"Unfinished block at old height, so don't propagate") return + elif expected_time < leader[1]: + log.info( + f"New best unfinished block at height {unfinished_block.block.height}" + ) + # If this will be the first block to finalize, update our leader + self.store.set_unfinished_block_leader((leader[0], expected_time)) + else: + # If we have seen an unfinished block at a greater or equal height, don't propagate + log.info(f"Unfinished block at old height, so don't propagate") + return - await self.store.add_unfinished_block( - (challenge_hash, iterations_needed), unfinished_block.block - ) + self.store.add_unfinished_block( + (challenge_hash, iterations_needed), unfinished_block.block + ) timelord_request = timelord_protocol.ProofOfSpaceInfo( challenge_hash, iterations_needed @@ -1014,24 +1012,35 @@ async def block(self, block: peer_protocol.Block) -> OutboundMessageGenerator: if self.blockchain.cointains_block(header_hash): return + if self.store.get_sync_mode(): + # Add the block to our potential tips list + self.store.add_potential_tip(block.block) + return + + prevalidate_block = await self.blockchain.pre_validate_blocks([block.block]) + val, pos = prevalidate_block[0] + async with self.store.lock: - if await self.store.get_sync_mode(): - # Add the block to our potential tips list - await self.store.add_potential_tip(block.block) - return - prevalidate_block = await self.blockchain.pre_validate_blocks([block.block]) - val, pos = prevalidate_block[0] - # Tries to add the block to the blockchain - added: ReceiveBlockResult = await self.blockchain.receive_block( - block.block, val, pos + prev_block: Optional[FullBlock] = await self.store.get_block( + block.block.prev_header_hash ) + added: ReceiveBlockResult + if prev_block is None: + added = ReceiveBlockResult.DISCONNECTED_BLOCK + else: + # Tries to add the block to the blockchain + added = await self.blockchain.receive_block( + block.block, prev_block.header_block, val, pos + ) + # Always immediately add the block to the database, after updating blockchain state if ( added == ReceiveBlockResult.ADDED_AS_ORPHAN or added == ReceiveBlockResult.ADDED_TO_HEAD ): await self.store.add_block(block.block) + if added == ReceiveBlockResult.ALREADY_HAVE_BLOCK: return elif added == ReceiveBlockResult.INVALID_BLOCK: @@ -1041,21 +1050,20 @@ async def block(self, block: peer_protocol.Block) -> OutboundMessageGenerator: return elif added == ReceiveBlockResult.DISCONNECTED_BLOCK: log.warning(f"Disconnected block {header_hash}") - async with self.store.lock: - tip_height = min( - [head.height for head in self.blockchain.get_current_tips()] - ) + tip_height = min( + 
[head.height for head in self.blockchain.get_current_tips()] + ) if ( block.block.height > tip_height + self.config["sync_blocks_behind_threshold"] ): async with self.store.lock: - if await self.store.get_sync_mode(): + if self.store.get_sync_mode(): return await self.store.clear_sync_info() - await self.store.add_potential_tip(block.block) - await self.store.set_sync_mode(True) + self.store.add_potential_tip(block.block) + self.store.set_sync_mode(True) log.info( f"We are too far behind this block. Our height is {tip_height} and block is at " f"{block.block.height}" @@ -1080,27 +1088,22 @@ async def block(self, block: peer_protocol.Block) -> OutboundMessageGenerator: "request_block", peer_protocol.RequestBlock(block.block.prev_header_hash), ) - async with self.store.lock: - await self.store.add_disconnected_block(block.block) + self.store.add_disconnected_block(block.block) yield OutboundMessage(NodeType.FULL_NODE, msg, Delivery.RESPOND) return elif added == ReceiveBlockResult.ADDED_TO_HEAD: # Only propagate blocks which extend the blockchain (becomes one of the heads) - ips_changed: bool = False - async with self.store.lock: - log.info( - f"Updated heads, new heights: {[b.height for b in self.blockchain.get_current_tips()]}" - ) + log.info( + f"Updated heads, new heights: {[b.height for b in self.blockchain.get_current_tips()]}" + ) - difficulty = self.blockchain.get_next_difficulty( - block.block.prev_header_hash - ) - next_vdf_ips = self.blockchain.get_next_ips(block.block.header_hash) - log.info(f"Difficulty {difficulty} IPS {next_vdf_ips}") - if next_vdf_ips != await self.store.get_proof_of_time_estimate_ips(): - await self.store.set_proof_of_time_estimate_ips(next_vdf_ips) - ips_changed = True - if ips_changed: + difficulty = self.blockchain.get_next_difficulty( + block.block.prev_header_hash + ) + next_vdf_ips = self.blockchain.get_next_ips(block.block.header_block) + log.info(f"Difficulty {difficulty} IPS {next_vdf_ips}") + if next_vdf_ips != self.store.get_proof_of_time_estimate_ips(): + self.store.set_proof_of_time_estimate_ips(next_vdf_ips) rate_update = farmer_protocol.ProofOfTimeRate(next_vdf_ips) log.info(f"Sending proof of time rate {next_vdf_ips}") yield OutboundMessage( @@ -1160,21 +1163,20 @@ async def block(self, block: peer_protocol.Block) -> OutboundMessageGenerator: # Recursively process the next block if we have it # This code path is reached if added == ADDED_AS_ORPHAN or ADDED_TO_HEAD - async with self.store.lock: - next_block: Optional[ - FullBlock - ] = await self.store.get_disconnected_block_by_prev(block.block.header_hash) + next_block: Optional[FullBlock] = self.store.get_disconnected_block_by_prev( + block.block.header_hash + ) + if next_block is not None: async for ret_msg in self.block(peer_protocol.Block(next_block)): yield ret_msg - async with self.store.lock: - # Removes all temporary data for old blocks - lowest_tip = min(tip.height for tip in self.blockchain.get_current_tips()) - clear_height = uint32(max(0, lowest_tip - 30)) - await self.store.clear_candidate_blocks_below(clear_height) - await self.store.clear_unfinished_blocks_below(clear_height) - await self.store.clear_disconnected_blocks_below(clear_height) + # Removes all temporary data for old blocks + lowest_tip = min(tip.height for tip in self.blockchain.get_current_tips()) + clear_height = uint32(max(0, lowest_tip - 30)) + self.store.clear_candidate_blocks_below(clear_height) + self.store.clear_unfinished_blocks_below(clear_height) + self.store.clear_disconnected_blocks_below(clear_height) 
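
# A condensed sketch (simplified names, not the patch itself) of the unfinished-block
# propagation rule above: the first block seen at a height becomes the leader; a later
# block at the same height is dropped if its VDF would finish more than
# PROPAGATION_THRESHOLD seconds after the leader's, and a strictly faster block takes
# over as leader. The threshold value here is illustrative, not the real constant.
from typing import Optional, Tuple

PROPAGATION_THRESHOLD = 10  # seconds; stand-in for constants["PROPAGATION_THRESHOLD"]

def should_propagate(
    leader: Optional[Tuple[int, int]],  # (height, expected_time) of current leader
    height: int,
    expected_time: int,
) -> Tuple[bool, Optional[Tuple[int, int]]]:
    if leader is None or height > leader[0]:
        return True, (height, expected_time)  # first block seen at this height
    if height == leader[0]:
        if expected_time > leader[1] + PROPAGATION_THRESHOLD:
            return False, leader  # VDF would finish too late; don't propagate
        if expected_time < leader[1]:
            return True, (leader[0], expected_time)  # new best at this height
        return True, leader  # close enough to the leader; still propagate
    return False, leader  # older height; don't propagate

assert should_propagate(None, 5, 30) == (True, (5, 30))
assert should_propagate((5, 30), 5, 45) == (False, (5, 30))
assert should_propagate((5, 30), 5, 20) == (True, (5, 20))
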
@api_request async def request_block( diff --git a/src/harvester.py b/src/harvester.py index 534068752515..701f93733272 100644 --- a/src/harvester.py +++ b/src/harvester.py @@ -5,7 +5,6 @@ from typing import Dict, Optional, Tuple from blspy import PrependSignature, PrivateKey, PublicKey, Util -from yaml import safe_load from chiapos import DiskProver from definitions import ROOT_DIR @@ -20,23 +19,10 @@ class Harvester: - def __init__(self): - config_filename = os.path.join(ROOT_DIR, "config", "config.yaml") - plot_config_filename = os.path.join(ROOT_DIR, "config", "plots.yaml") - key_config_filename = os.path.join(ROOT_DIR, "config", "keys.yaml") - - if not os.path.isfile(key_config_filename): - raise RuntimeError( - "Keys not generated. Run python3.7 ./scripts/regenerate_keys.py." - ) - if not os.path.isfile(plot_config_filename): - raise RuntimeError( - "Plots not generated. Run python3.7 ./scripts/create_plots.py." - ) - - self.config = safe_load(open(config_filename, "r"))["harvester"] - self.key_config = safe_load(open(key_config_filename, "r")) - self.plot_config = safe_load(open(plot_config_filename, "r")) + def __init__(self, config: Dict, key_config: Dict, plot_config: Dict): + self.config: Dict = config + self.key_config: Dict = key_config + self.plot_config: Dict = plot_config # From filename to prover self.provers: Dict[str, DiskProver] = {} @@ -44,7 +30,7 @@ def __init__(self): # From quality to (challenge_hash, filename, index) self.challenge_hashes: Dict[bytes32, Tuple[bytes32, str, uint8]] = {} self._plot_notification_task = asyncio.create_task(self._plot_notification()) - self._is_shutdown = False + self._is_shutdown: bool = False async def _plot_notification(self): """ @@ -74,26 +60,35 @@ async def harvester_handshake( use any plots which don't have one of the pool keys. """ for partial_filename, plot_config in self.plot_config["plots"].items(): + potential_filenames = [partial_filename] if "plot_root" in self.config: - filename = os.path.join(self.config["plot_root"], partial_filename) + potential_filenames.append( + os.path.join(self.config["plot_root"], partial_filename) + ) else: - filename = os.path.join(ROOT_DIR, "plots", partial_filename) + potential_filenames.append( + os.path.join(ROOT_DIR, "plots", partial_filename) + ) pool_pubkey = PublicKey.from_bytes(bytes.fromhex(plot_config["pool_pk"])) # Only use plots that correct pools associated with them - if pool_pubkey in harvester_handshake.pool_pubkeys: + if pool_pubkey not in harvester_handshake.pool_pubkeys: + log.warning( + f"Plot {partial_filename} has a pool key that is not in the farmer's pool_pk list." + ) + continue + + found = False + for filename in potential_filenames: if os.path.isfile(filename): self.provers[partial_filename] = DiskProver(filename) log.info( f"Farming plot {filename} of size {self.provers[partial_filename].get_size()}" ) - else: - log.warn(f"Plot at {filename} does not exist.") - - else: - log.warning( - f"Plot {filename} has a pool key that is not in the farmer's pool_pk list." 
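
# A minimal sketch (illustrative; resolve_plot_path is a hypothetical helper, not in
# the patch) of the lookup above: a plot may live at its literal path or under a
# configured plot_root (falling back to ROOT_DIR/plots), so each candidate is tried
# in order and a warning is logged when none exists.
import logging
import os
from typing import Dict, List, Optional

log = logging.getLogger(__name__)

def resolve_plot_path(partial_filename: str, config: Dict, root_dir: str) -> Optional[str]:
    candidates: List[str] = [partial_filename]
    if "plot_root" in config:
        candidates.append(os.path.join(config["plot_root"], partial_filename))
    else:
        candidates.append(os.path.join(root_dir, "plots", partial_filename))
    for candidate in candidates:
        if os.path.isfile(candidate):
            return candidate
    log.warning(f"Plot at {candidates} does not exist.")
    return None
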
- ) + found = True + break + if not found: + log.warning(f"Plot at {potential_filenames} does not exist.") @api_request async def new_challenge(self, new_challenge: harvester_protocol.NewChallenge): diff --git a/src/introducer.py b/src/introducer.py index ea975ce3fcbf..ede83a72d93f 100644 --- a/src/introducer.py +++ b/src/introducer.py @@ -1,11 +1,7 @@ import asyncio import logging -import os from typing import AsyncGenerator, Dict -import yaml - -from definitions import ROOT_DIR from src.protocols.peer_protocol import Peers, RequestPeers from src.server.outbound_message import Delivery, Message, NodeType, OutboundMessage from src.server.server import ChiaServer @@ -16,9 +12,8 @@ class Introducer: - def __init__(self): - config_filename = os.path.join(ROOT_DIR, "config", "config.yaml") - self.config = yaml.safe_load(open(config_filename, "r"))["introducer"] + def __init__(self, config: Dict): + self.config: Dict = config self.vetted: Dict[bytes32, bool] = {} def set_server(self, server: ChiaServer): diff --git a/src/rpc/rpc_server.py b/src/rpc/rpc_server.py index a0d1593384d7..ce50d403edca 100644 --- a/src/rpc/rpc_server.py +++ b/src/rpc/rpc_server.py @@ -1,18 +1,19 @@ import dataclasses import json -from typing import Any, Callable, List, Optional, Dict +from typing import Any, Callable, List, Optional, Dict, Tuple from aiohttp import web +from blspy import PublicKey from src.full_node import FullNode -from src.types.header_block import SmallHeaderBlock, HeaderBlock +from src.types.header_block import SmallHeaderBlock from src.types.full_block import FullBlock from src.types.peer_info import PeerInfo from src.types.challenge import Challenge -from src.util.ints import uint16, uint64 -from src.util.byte_types import hexstr_to_bytes +from src.util.ints import uint16, uint32, uint64 from src.consensus.block_rewards import calculate_block_reward +from src.util.byte_types import hexstr_to_bytes class EnhancedJSONEncoder(json.JSONEncoder): @@ -52,19 +53,17 @@ async def get_blockchain_state(self, request) -> web.Response: """ Returns a summary of the node's view of the blockchain. 
""" - tips_hb: List[HeaderBlock] = self.full_node.blockchain.get_current_tips() - lca_hb: HeaderBlock = self.full_node.blockchain.lca_block - tips = [] - for tip in tips_hb: - assert tip.challenge is not None - tips.append(SmallHeaderBlock(tip.header, tip.challenge)) - assert lca_hb.challenge is not None - lca = SmallHeaderBlock(lca_hb.header, lca_hb.challenge) - sync_mode: bool = await self.full_node.store.get_sync_mode() + tips: List[SmallHeaderBlock] = self.full_node.blockchain.get_current_tips() + lca: SmallHeaderBlock = self.full_node.blockchain.lca_block + assert lca.challenge is not None + sync_mode: bool = self.full_node.store.get_sync_mode() difficulty: uint64 = self.full_node.blockchain.get_next_difficulty( - lca_hb.header_hash + lca.header_hash ) - ips: uint64 = self.full_node.blockchain.get_next_ips(lca_hb.header_hash) + lca_hb = ( + await self.full_node.store.get_header_blocks_by_hash([lca.header_hash]) + )[0] + ips: uint64 = self.full_node.blockchain.get_next_ips(lca_hb) response = { "tips": tips, "lca": lca, @@ -96,14 +95,12 @@ async def get_header(self, request) -> web.Response: if "header_hash" not in request_data: raise web.HTTPBadRequest() header_hash = hexstr_to_bytes(request_data["header_hash"]) - header_block: Optional[ - HeaderBlock - ] = self.full_node.blockchain.header_blocks.get(header_hash, None) - if header_block is None or header_block.challenge is None: + small_header_block: Optional[ + SmallHeaderBlock + ] = self.full_node.blockchain.headers.get(header_hash, None) + if small_header_block is None or small_header_block.challenge is None: raise web.HTTPNotFound() - return obj_to_response( - SmallHeaderBlock(header_block.header, header_block.challenge) - ) + return obj_to_response(small_header_block) async def get_connections(self, request) -> web.Response: """ @@ -178,23 +175,18 @@ async def get_pool_balances(self, request) -> web.Response: Retrieves the coinbase balances earned by all pools. TODO: remove after transactions and coins are added. 
""" - tips: List[HeaderBlock] = self.full_node.blockchain.get_current_tips() - header_block = tips[0] - coin_balances: Dict[str, uint64] = { - f"0x{bytes(header_block.proof_of_space.pool_pubkey).hex()}": calculate_block_reward( - header_block.height - ) - } - while header_block.height != 0: - header_block = self.full_node.blockchain.header_blocks[ - header_block.prev_header_hash - ] - pool_pk = f"0x{bytes(header_block.proof_of_space.pool_pubkey).hex()}" + ppks: List[ + Tuple[uint32, PublicKey] + ] = await self.full_node.store.get_pool_pks_hack() + + coin_balances: Dict[str, uint64] = {} + for height, pk in ppks: + pool_pk = f"0x{bytes(pk).hex()}" if pool_pk not in coin_balances: coin_balances[pool_pk] = uint64(0) coin_balances[pool_pk] = uint64( - coin_balances[pool_pk] + calculate_block_reward(header_block.height) + coin_balances[pool_pk] + calculate_block_reward(height) ) return obj_to_response(coin_balances) @@ -202,14 +194,14 @@ async def get_heaviest_block_seen(self, request) -> web.Response: """ Returns the heaviest block ever seen, whether it's been added to the blockchain or not """ - tips: List[HeaderBlock] = self.full_node.blockchain.get_current_tips() + tips: List[SmallHeaderBlock] = self.full_node.blockchain.get_current_tips() tip_weights = [tip.weight for tip in tips] i = tip_weights.index(max(tip_weights)) assert tips[i].challenge is not None challenge: Challenge = tips[i].challenge # type: ignore max_tip: SmallHeaderBlock = SmallHeaderBlock(tips[i].header, challenge) - if await self.full_node.store.get_sync_mode(): - potential_tips = await self.full_node.store.get_potential_tips_tuples() + if self.full_node.store.get_sync_mode(): + potential_tips = self.full_node.store.get_potential_tips_tuples() for _, pot_block in potential_tips: if pot_block.weight > max_tip.weight: assert pot_block.header_block.challenge is not None diff --git a/src/server/start_farmer.py b/src/server/start_farmer.py index df870db4a470..35552041dc6d 100644 --- a/src/server/start_farmer.py +++ b/src/server/start_farmer.py @@ -1,8 +1,13 @@ import asyncio import signal from typing import List +import logging + +try: + import uvloop +except ImportError: + uvloop = None -import uvloop from blspy import PrivateKey from src.farmer import Farmer @@ -10,24 +15,32 @@ from src.server.outbound_message import Delivery, Message, NodeType, OutboundMessage from src.server.server import ChiaServer from src.types.peer_info import PeerInfo -from src.util.network import parse_host_port from src.util.logging import initialize_logging +from src.util.config import load_config, load_config_cli from setproctitle import setproctitle -initialize_logging("Farmer %(name)-25s") -setproctitle("chia_farmer") - async def main(): - farmer = Farmer() + config = load_config_cli("config.yaml", "farmer") + try: + key_config = load_config("keys.yaml") + except FileNotFoundError: + raise RuntimeError( + "Keys not generated. Run python3 ./scripts/regenerate_keys.py." 
+ ) + initialize_logging("Farmer %(name)-25s", config["logging"]) + log = logging.getLogger(__name__) + setproctitle("chia_farmer") + + farmer = Farmer(config, key_config) + harvester_peer = PeerInfo( - farmer.config["harvester_peer"]["host"], farmer.config["harvester_peer"]["port"] + config["harvester_peer"]["host"], config["harvester_peer"]["port"] ) full_node_peer = PeerInfo( - farmer.config["full_node_peer"]["host"], farmer.config["full_node_peer"]["port"] + config["full_node_peer"]["host"], config["full_node_peer"]["port"] ) - host, port = parse_host_port(farmer) - server = ChiaServer(port, farmer, NodeType.FARMER) + server = ChiaServer(config["port"], farmer, NodeType.FARMER) asyncio.get_running_loop().add_signal_handler(signal.SIGINT, server.close_all) asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, server.close_all) @@ -35,21 +48,22 @@ async def main(): async def on_connect(): # Sends a handshake to the harvester pool_sks: List[PrivateKey] = [ - PrivateKey.from_bytes(bytes.fromhex(ce)) - for ce in farmer.key_config["pool_sks"] + PrivateKey.from_bytes(bytes.fromhex(ce)) for ce in key_config["pool_sks"] ] msg = HarvesterHandshake([sk.get_public_key() for sk in pool_sks]) yield OutboundMessage( NodeType.HARVESTER, Message("harvester_handshake", msg), Delivery.BROADCAST ) - _ = await server.start_server(host, on_connect) + _ = await server.start_server(config["host"], on_connect) await asyncio.sleep(1) # Prevents TCP simultaneous connect with harvester _ = await server.start_client(harvester_peer, None) _ = await server.start_client(full_node_peer, None) await server.await_closed() + log.info("Farmer fully closed.") -uvloop.install() +if uvloop is not None: + uvloop.install() asyncio.run(main()) diff --git a/src/server/start_full_node.py b/src/server/start_full_node.py index 427b55031a1b..fa149643da6c 100644 --- a/src/server/start_full_node.py +++ b/src/server/start_full_node.py @@ -2,11 +2,14 @@ import logging import logging.config import signal -import sys -from typing import Dict, List +from typing import List, Dict import miniupnpc -import uvloop + +try: + import uvloop +except ImportError: + uvloop = None from src.blockchain import Blockchain from src.consensus.constants import constants @@ -16,33 +19,27 @@ from src.server.outbound_message import NodeType from src.server.server import ChiaServer from src.types.full_block import FullBlock -from src.types.header_block import HeaderBlock +from src.types.header_block import SmallHeaderBlock from src.types.peer_info import PeerInfo -from src.util.network import parse_host_port from src.util.logging import initialize_logging +from src.util.config import load_config_cli from setproctitle import setproctitle -setproctitle("chia_full_node") -initialize_logging("FullNode %(name)-23s") -log = logging.getLogger(__name__) - -server_closed = False - async def load_header_blocks_from_store( store: FullNodeStore, -) -> Dict[str, HeaderBlock]: - seen_blocks: Dict[str, HeaderBlock] = {} - tips: List[HeaderBlock] = [] - async for full_block in store.get_blocks(): - if not tips or full_block.weight > tips[0].weight: - tips = [full_block.header_block] - seen_blocks[full_block.header_hash] = full_block.header_block +) -> Dict[str, SmallHeaderBlock]: + seen_blocks: Dict[str, SmallHeaderBlock] = {} + tips: List[SmallHeaderBlock] = [] + for small_header_block in await store.get_small_header_blocks(): + if not tips or small_header_block.weight > tips[0].weight: + tips = [small_header_block] + seen_blocks[small_header_block.header_hash] = 
small_header_block header_blocks = {} if len(tips) > 0: - curr: HeaderBlock = tips[0] - reverse_blocks: List[HeaderBlock] = [curr] + curr: SmallHeaderBlock = tips[0] + reverse_blocks: List[SmallHeaderBlock] = [curr] while curr.height > 0: curr = seen_blocks[curr.prev_header_hash] reverse_blocks.append(curr) @@ -53,42 +50,49 @@ async def load_header_blocks_from_store( async def main(): + config = load_config_cli("config.yaml", "full_node") + setproctitle("chia_full_node") + initialize_logging("FullNode %(name)-23s", config["logging"]) + + log = logging.getLogger(__name__) + server_closed = False + # Create the store (DB) and full node instance - db_id = 0 - if "-id" in sys.argv: - db_id = int(sys.argv[sys.argv.index("-id") + 1]) - store = await FullNodeStore.create(f"blockchain_{db_id}.db") + store = await FullNodeStore.create(f"blockchain_{config['database_id']}.db") genesis: FullBlock = FullBlock.from_bytes(constants["GENESIS_BLOCK"]) await store.add_block(genesis) log.info("Initializing blockchain from disk") - header_blocks: Dict[str, HeaderBlock] = await load_header_blocks_from_store(store) - blockchain = await Blockchain.create(header_blocks) + small_header_blocks: Dict[ + str, SmallHeaderBlock + ] = await load_header_blocks_from_store(store) + blockchain = await Blockchain.create(small_header_blocks) - full_node = FullNode(store, blockchain) - # Starts the full node server (which full nodes can connect to) - host, port = parse_host_port(full_node) + full_node = FullNode(store, blockchain, config) - if full_node.config["enable_upnp"]: - log.info(f"Attempting to enable UPnP (open up port {port})") + if config["enable_upnp"]: + log.info(f"Attempting to enable UPnP (open up port {config['port']})") try: upnp = miniupnpc.UPnP() upnp.discoverdelay = 5 upnp.discover() upnp.selectigd() - upnp.addportmapping(port, "TCP", upnp.lanaddr, port, "chia", "") - log.info(f"Port {port} opened with UPnP.") + upnp.addportmapping( + config["port"], "TCP", upnp.lanaddr, config["port"], "chia", "" + ) + log.info(f"Port {config['port']} opened with UPnP.") except Exception as e: log.warning(f"UPnP failed: {e}") - server = ChiaServer(port, full_node, NodeType.FULL_NODE) + # Starts the full node server (which full nodes can connect to) + server = ChiaServer(config["port"], full_node, NodeType.FULL_NODE) full_node._set_server(server) - _ = await server.start_server(host, full_node._on_connect) + _ = await server.start_server(config["host"], full_node._on_connect) rpc_cleanup = None def master_close_cb(): - global server_closed + nonlocal server_closed if not server_closed: # Called by the UI, when node is closed, or when a signal is sent log.info("Closing all connections, and server...") @@ -96,32 +100,29 @@ def master_close_cb(): server.close_all() server_closed = True - if "-r" in sys.argv: + if config["start_rpc_server"]: # Starts the RPC server if -r is provided - index = sys.argv.index("-r") - rpc_port = int(sys.argv[index + 1]) - rpc_cleanup = await start_rpc_server(full_node, master_close_cb, rpc_port) + rpc_cleanup = await start_rpc_server( + full_node, master_close_cb, config["rpc_port"] + ) asyncio.get_running_loop().add_signal_handler(signal.SIGINT, master_close_cb) asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, master_close_cb) - connect_to_farmer = "-f" in sys.argv - connect_to_timelord = "-t" in sys.argv - full_node._start_bg_tasks() log.info("Waiting to connect to some peers...") await asyncio.sleep(3) log.info(f"Connected to {len(server.global_connections.get_connections())} 
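
# A compact sketch (illustrative types, not the patch) of the reconstruction above:
# keep the heaviest header seen as the tip, then walk prev_header_hash links back to
# genesis and reverse, yielding the chain in height order.
from typing import Dict, List, NamedTuple, Optional

class Hdr(NamedTuple):  # hypothetical stand-in for SmallHeaderBlock
    header_hash: str
    prev_header_hash: str
    height: int
    weight: int

def chain_from_store(headers: List[Hdr]) -> List[Hdr]:
    seen: Dict[str, Hdr] = {}
    tip: Optional[Hdr] = None
    for h in headers:
        if tip is None or h.weight > tip.weight:
            tip = h
        seen[h.header_hash] = h
    if tip is None:
        return []
    reverse_chain = [tip]
    curr = tip
    while curr.height > 0:
        curr = seen[curr.prev_header_hash]
        reverse_chain.append(curr)
    return list(reversed(reverse_chain))

genesis = Hdr("g", "", 0, 1)
b1 = Hdr("b1", "g", 1, 2)
assert chain_from_store([b1, genesis]) == [genesis, b1]
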
peers.") - if connect_to_farmer and not server_closed: + if config["connect_to_farmer"] and not server_closed: peer_info = PeerInfo( full_node.config["farmer_peer"]["host"], full_node.config["farmer_peer"]["port"], ) _ = await server.start_client(peer_info, None) - if connect_to_timelord and not server_closed: + if config["connect_to_timelord"] and not server_closed: peer_info = PeerInfo( full_node.config["timelord_peer"]["host"], full_node.config["timelord_peer"]["port"], @@ -130,15 +131,20 @@ def master_close_cb(): # Awaits for server and all connections to close await server.await_closed() + log.info("Closed all node servers.") # Waits for the rpc server to close if rpc_cleanup is not None: await rpc_cleanup() + log.info("Closed RPC server.") await store.close() + log.info("Closed store.") + await asyncio.get_running_loop().shutdown_asyncgens() log.info("Node fully closed.") -uvloop.install() +if uvloop is not None: + uvloop.install() asyncio.run(main()) diff --git a/src/server/start_harvester.py b/src/server/start_harvester.py index 1030efb6aae5..e2273ab2f49d 100644 --- a/src/server/start_harvester.py +++ b/src/server/start_harvester.py @@ -1,25 +1,43 @@ import asyncio import signal +import logging -import uvloop +try: + import uvloop +except ImportError: + uvloop = None from src.harvester import Harvester from src.server.outbound_message import NodeType from src.server.server import ChiaServer from src.types.peer_info import PeerInfo -from src.util.network import parse_host_port from src.util.logging import initialize_logging +from src.util.config import load_config, load_config_cli from setproctitle import setproctitle -initialize_logging("Harvester %(name)-22s") -setproctitle("chia_harvester") - async def main(): - harvester = Harvester() - host, port = parse_host_port(harvester) - server = ChiaServer(port, harvester, NodeType.HARVESTER) - _ = await server.start_server(host, None) + config = load_config_cli("config.yaml", "harvester") + try: + key_config = load_config("keys.yaml") + except FileNotFoundError: + raise RuntimeError( + "Keys not generated. Run python3 ./scripts/regenerate_keys.py." + ) + try: + plot_config = load_config("plots.yaml") + except FileNotFoundError: + raise RuntimeError( + "Plots not generated. Run python3.7 ./scripts/create_plots.py." 
+        )
+
+    initialize_logging("Harvester %(name)-22s", config["logging"])
+    log = logging.getLogger(__name__)
+    setproctitle("chia_harvester")
+
+    harvester = Harvester(config, key_config, plot_config)
+    server = ChiaServer(config["port"], harvester, NodeType.HARVESTER)
+    _ = await server.start_server(config["host"], None)

     asyncio.get_running_loop().add_signal_handler(signal.SIGINT, server.close_all)
     asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, server.close_all)
@@ -32,7 +50,9 @@ async def main():
     await server.await_closed()
     harvester._shutdown()
     await harvester._await_shutdown()
+    log.info("Harvester fully closed.")

-uvloop.install()
+if uvloop is not None:
+    uvloop.install()
 asyncio.run(main())
diff --git a/src/server/start_introducer.py b/src/server/start_introducer.py
index 2bd1e495092f..3b22930a29f4 100644
--- a/src/server/start_introducer.py
+++ b/src/server/start_introducer.py
@@ -1,31 +1,39 @@
 import asyncio
 import signal
+import logging

-import uvloop
+try:
+    import uvloop
+except ImportError:
+    uvloop = None

 from src.introducer import Introducer
 from src.server.outbound_message import NodeType
 from src.server.server import ChiaServer
-from src.util.network import parse_host_port
 from src.util.logging import initialize_logging
+from src.util.config import load_config_cli
 from setproctitle import setproctitle

-initialize_logging("Introducer %(name)-21s")
-setproctitle("chia_introducer")
-

 async def main():
-    introducer = Introducer()
-    host, port = parse_host_port(introducer)
-    server = ChiaServer(port, introducer, NodeType.INTRODUCER)
+    config = load_config_cli("config.yaml", "introducer")
+
+    initialize_logging("Introducer %(name)-21s", config["logging"])
+    log = logging.getLogger(__name__)
+    setproctitle("chia_introducer")
+
+    introducer = Introducer(config)
+    server = ChiaServer(config["port"], introducer, NodeType.INTRODUCER)
     introducer.set_server(server)
-    _ = await server.start_server(host, None)
+    _ = await server.start_server(config["host"], None)

     asyncio.get_running_loop().add_signal_handler(signal.SIGINT, server.close_all)
     asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, server.close_all)

     await server.await_closed()
+    log.info("Introducer fully closed.")

-uvloop.install()
+if uvloop is not None:
+    uvloop.install()
 asyncio.run(main())
diff --git a/src/server/start_timelord.py b/src/server/start_timelord.py
index 397d4584d233..f5ca812b3a5d 100644
--- a/src/server/start_timelord.py
+++ b/src/server/start_timelord.py
@@ -1,23 +1,31 @@
 import asyncio
 import signal
+import logging
+
+try:
+    import uvloop
+except ImportError:
+    uvloop = None

 from src.server.outbound_message import NodeType
 from src.server.server import ChiaServer
 from src.timelord import Timelord
 from src.types.peer_info import PeerInfo
-from src.util.network import parse_host_port
 from src.util.logging import initialize_logging
+from src.util.config import load_config_cli
 from setproctitle import setproctitle

-initialize_logging("Timelord %(name)-23s")
-setproctitle("chia_timelord")
-

 async def main():
-    timelord = Timelord()
-    host, port = parse_host_port(timelord)
-    server = ChiaServer(port, timelord, NodeType.TIMELORD)
-    _ = await server.start_server(host, None)
+    config = load_config_cli("config.yaml", "timelord")
+
+    initialize_logging("Timelord %(name)-23s", config["logging"])
+    log = logging.getLogger(__name__)
+    setproctitle("chia_timelord")
+
+    timelord = Timelord(config)
+    server = ChiaServer(config["port"], timelord, NodeType.TIMELORD)
+    _ = await
server.start_server(config["host"], None) def signal_received(): server.close_all() @@ -38,6 +46,9 @@ def signal_received(): server.push_message(msg) await server.await_closed() + log.info("Timelord fully closed.") +if uvloop is not None: + uvloop.install() asyncio.run(main()) diff --git a/src/store.py b/src/store.py index 1742b3702797..3713d9ee5d3f 100644 --- a/src/store.py +++ b/src/store.py @@ -1,12 +1,13 @@ import asyncio import logging import aiosqlite -from typing import AsyncGenerator, Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple +from blspy import PublicKey from src.types.body import Body from src.types.full_block import FullBlock from src.types.header import HeaderData -from src.types.header_block import HeaderBlock +from src.types.header_block import HeaderBlock, SmallHeaderBlock from src.types.proof_of_space import ProofOfSpace from src.types.sized_bytes import bytes32 from src.util.ints import uint32, uint64 @@ -64,10 +65,20 @@ async def create(cls, db_name: str): "CREATE TABLE IF NOT EXISTS potential_blocks(height bigint PRIMARY KEY, block blob)" ) + # Headers + await self.db.execute( + "CREATE TABLE IF NOT EXISTS small_header_blocks(height bigint, header_hash " + "text PRIMARY KEY, pool_pk text, small_header_block blob)" + ) + # Height index so we can look up in order of height for sync purposes await self.db.execute( "CREATE INDEX IF NOT EXISTS block_height on blocks(height)" ) + await self.db.execute( + "CREATE INDEX IF NOT EXISTS small_header__block_height on small_header_blocks(height)" + ) + await self.db.commit() self.sync_mode = False @@ -96,13 +107,29 @@ async def close(self): async def _clear_database(self): await self.db.execute("DELETE FROM blocks") await self.db.execute("DELETE FROM potential_blocks") + await self.db.execute("DELETE FROM small_header_blocks") await self.db.commit() async def add_block(self, block: FullBlock) -> None: - await self.db.execute( + cursor_1 = await self.db.execute( "INSERT OR REPLACE INTO blocks VALUES(?, ?, ?)", (block.height, block.header_hash.hex(), bytes(block)), ) + await cursor_1.close() + assert block.header_block.challenge is not None + small_header_block: SmallHeaderBlock = SmallHeaderBlock( + block.header_block.header, block.header_block.challenge + ) + cursor_2 = await self.db.execute( + ("INSERT OR REPLACE INTO small_header_blocks VALUES(?, ?, ?, ?)"), + ( + block.height, + block.header_hash.hex(), + bytes(block.header_block.proof_of_space.pool_pubkey).hex(), + bytes(small_header_block), + ), + ) + await cursor_2.close() await self.db.commit() async def get_block(self, header_hash: bytes32) -> Optional[FullBlock]: @@ -110,20 +137,65 @@ async def get_block(self, header_hash: bytes32) -> Optional[FullBlock]: "SELECT * from blocks WHERE header_hash=?", (header_hash.hex(),) ) row = await cursor.fetchone() + await cursor.close() if row is not None: return FullBlock.from_bytes(row[2]) return None - async def get_blocks(self) -> AsyncGenerator[FullBlock, None]: - async with self.db.execute("SELECT * FROM blocks") as cursor: - async for row in cursor: - yield FullBlock.from_bytes(row[2]) + async def get_header_blocks_by_hash( + self, header_hashes: List[bytes32] + ) -> List[HeaderBlock]: + if len(header_hashes) == 0: + return [] + header_hashes_db = tuple(h.hex() for h in header_hashes) + formatted_str = f'SELECT * from blocks WHERE header_hash in ({"?," * (len(header_hashes_db) - 1)}?)' + cursor = await self.db.execute(formatted_str, header_hashes_db) + rows = await cursor.fetchall() + await 
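
# A self-contained sketch (illustrative; the index name is cleaned up here) of the
# aiosqlite pattern used in this store: create the table plus a height index, close
# each cursor explicitly, then commit.
import asyncio
import aiosqlite

async def demo() -> None:
    db = await aiosqlite.connect(":memory:")
    await db.execute(
        "CREATE TABLE IF NOT EXISTS small_header_blocks(height bigint, header_hash "
        "text PRIMARY KEY, pool_pk text, small_header_block blob)"
    )
    await db.execute(
        "CREATE INDEX IF NOT EXISTS small_header_block_height on small_header_blocks(height)"
    )
    cursor = await db.execute(
        "INSERT OR REPLACE INTO small_header_blocks VALUES(?, ?, ?, ?)",
        (0, "aa" * 32, "bb" * 48, b"\x00"),
    )
    await cursor.close()
    await db.commit()
    await db.close()

asyncio.run(demo())
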
cursor.close()
+        header_blocks: List[HeaderBlock] = []
+        for row in rows:
+            header_blocks.append(FullBlock.from_bytes(row[2]).header_block)
+
+        # Sorts the passed in header hashes by hash, with original index
+        header_hashes_sorted = sorted(
+            enumerate(header_hashes), key=lambda pair: pair[1]
+        )
+
+        # Sorts the fetched header blocks by hash
+        header_blocks_sorted = sorted(header_blocks, key=lambda hb: hb.header_hash)
+
+        # Combine both and sort by the original indices
+        combined = sorted(
+            zip(header_hashes_sorted, header_blocks_sorted), key=lambda pair: pair[0][0]
+        )
+
+        # Return only the header blocks in the original order
+        return [pair[1] for pair in combined]
+
+    async def get_small_header_blocks(self) -> List[SmallHeaderBlock]:
+        cursor = await self.db.execute("SELECT * from small_header_blocks")
+        rows = await cursor.fetchall()
+        await cursor.close()
+        return [SmallHeaderBlock.from_bytes(row[3]) for row in rows]
+
+    async def get_pool_pks_hack(self) -> List[Tuple[uint32, PublicKey]]:
+        # TODO: this API call is a hack to allow us to see block winners. Replace with coin/UTXO set.
+        cursor = await self.db.execute("SELECT * from small_header_blocks")
+        rows = await cursor.fetchall()
+        return [
+            (
+                SmallHeaderBlock.from_bytes(row[3]).height,
+                PublicKey.from_bytes(bytes.fromhex(row[2])),
+            )
+            for row in rows
+        ]

     async def add_potential_block(self, block: FullBlock) -> None:
-        await self.db.execute(
+        cursor = await self.db.execute(
             "INSERT OR REPLACE INTO potential_blocks VALUES(?, ?)",
             (block.height, bytes(block)),
         )
+        await cursor.close()
         await self.db.commit()

     async def get_potential_block(self, height: uint32) -> Optional[FullBlock]:
@@ -131,14 +203,15 @@ async def get_potential_block(self, height: uint32) -> Optional[FullBlock]:
             "SELECT * from potential_blocks WHERE height=?", (height,)
         )
         row = await cursor.fetchone()
+        await cursor.close()
         if row is not None:
             return FullBlock.from_bytes(row[1])
         return None

-    async def add_disconnected_block(self, block: FullBlock) -> None:
+    def add_disconnected_block(self, block: FullBlock) -> None:
         self.disconnected_blocks[block.header_hash] = block

-    async def get_disconnected_block_by_prev(
+    def get_disconnected_block_by_prev(
         self, prev_header_hash: bytes32
     ) -> Optional[FullBlock]:
         for _, block in self.disconnected_blocks.items():
@@ -146,35 +219,35 @@ async def get_disconnected_block_by_prev(
                 return block
         return None

-    async def get_disconnected_block(self, header_hash: bytes32) -> Optional[FullBlock]:
+    def get_disconnected_block(self, header_hash: bytes32) -> Optional[FullBlock]:
         return self.disconnected_blocks.get(header_hash, None)

-    async def clear_disconnected_blocks_below(self, height: uint32) -> None:
+    def clear_disconnected_blocks_below(self, height: uint32) -> None:
         for key in list(self.disconnected_blocks.keys()):
             if self.disconnected_blocks[key].height < height:
                 del self.disconnected_blocks[key]

-    async def set_sync_mode(self, sync_mode: bool) -> None:
+    def set_sync_mode(self, sync_mode: bool) -> None:
         self.sync_mode = sync_mode

-    async def get_sync_mode(self) -> bool:
+    def get_sync_mode(self) -> bool:
         return self.sync_mode

     async def clear_sync_info(self):
         self.potential_tips.clear()
         self.potential_headers.clear()
-        await self.db.execute("DELETE FROM potential_blocks")
-        await self.db.commit()
+        cursor = await self.db.execute("DELETE FROM potential_blocks")
+        await cursor.close()
         self.potential_blocks_received.clear()
         self.potential_future_blocks.clear()

-    async def get_potential_tips_tuples(self) ->
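
# A standalone sketch (illustrative) of the reordering trick above: rows fetched with
# "WHERE header_hash IN (...)" come back in arbitrary order, so the requested hashes
# (tagged with their original indices) and the fetched items are each sorted by hash,
# zipped together, and re-sorted by original index. Strings stand in for hash-keyed
# header blocks.
from typing import List

def restore_order(requested: List[str], fetched: List[str]) -> List[str]:
    requested_sorted = sorted(enumerate(requested), key=lambda pair: pair[1])
    fetched_sorted = sorted(fetched)
    combined = sorted(zip(requested_sorted, fetched_sorted), key=lambda pair: pair[0][0])
    return [item for _, item in combined]

assert restore_order(["c", "a", "b"], ["a", "b", "c"]) == ["c", "a", "b"]
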
List[Tuple[bytes32, FullBlock]]: + def get_potential_tips_tuples(self) -> List[Tuple[bytes32, FullBlock]]: return list(self.potential_tips.items()) - async def add_potential_tip(self, block: FullBlock) -> None: + def add_potential_tip(self, block: FullBlock) -> None: self.potential_tips[block.header_hash] = block - async def get_potential_tip(self, header_hash: bytes32) -> Optional[FullBlock]: + def get_potential_tip(self, header_hash: bytes32) -> Optional[FullBlock]: return self.potential_tips.get(header_hash, None) def add_potential_header(self, block: HeaderBlock) -> None: @@ -210,18 +283,18 @@ def set_potential_blocks_received(self, height: uint32, event: asyncio.Event): def get_potential_blocks_received(self, height: uint32) -> asyncio.Event: return self.potential_blocks_received[height] - async def add_potential_future_block(self, block: FullBlock): + def add_potential_future_block(self, block: FullBlock): self.potential_future_blocks.append(block) - async def get_potential_future_blocks(self): + def get_potential_future_blocks(self): return self.potential_future_blocks - async def add_candidate_block( + def add_candidate_block( self, pos_hash: bytes32, body: Body, header: HeaderData, pos: ProofOfSpace, ): self.candidate_blocks[pos_hash] = (body, header, pos, body.coinbase.height) - async def get_candidate_block( + def get_candidate_block( self, pos_hash: bytes32 ) -> Optional[Tuple[Body, HeaderData, ProofOfSpace]]: res = self.candidate_blocks.get(pos_hash, None) @@ -229,19 +302,17 @@ async def get_candidate_block( return None return (res[0], res[1], res[2]) - async def clear_candidate_blocks_below(self, height: uint32) -> None: + def clear_candidate_blocks_below(self, height: uint32) -> None: for key in list(self.candidate_blocks.keys()): if self.candidate_blocks[key][3] < height: del self.candidate_blocks[key] - async def add_unfinished_block( + def add_unfinished_block( self, key: Tuple[bytes32, uint64], block: FullBlock ) -> None: self.unfinished_blocks[key] = block - async def get_unfinished_block( - self, key: Tuple[bytes32, uint64] - ) -> Optional[FullBlock]: + def get_unfinished_block(self, key: Tuple[bytes32, uint64]) -> Optional[FullBlock]: return self.unfinished_blocks.get(key, None) def seen_unfinished_block(self, header_hash: bytes32) -> bool: @@ -253,10 +324,10 @@ def seen_unfinished_block(self, header_hash: bytes32) -> bool: def clear_seen_unfinished_blocks(self) -> None: self.seen_unfinished_blocks.clear() - async def get_unfinished_blocks(self) -> Dict[Tuple[bytes32, uint64], FullBlock]: + def get_unfinished_blocks(self) -> Dict[Tuple[bytes32, uint64], FullBlock]: return self.unfinished_blocks.copy() - async def clear_unfinished_blocks_below(self, height: uint32) -> None: + def clear_unfinished_blocks_below(self, height: uint32) -> None: for key in list(self.unfinished_blocks.keys()): if self.unfinished_blocks[key].height < height: del self.unfinished_blocks[key] @@ -267,8 +338,8 @@ def set_unfinished_block_leader(self, key: Tuple[bytes32, uint64]) -> None: def get_unfinished_block_leader(self) -> Tuple[bytes32, uint64]: return self.unfinished_blocks_leader - async def set_proof_of_time_estimate_ips(self, estimate: uint64): + def set_proof_of_time_estimate_ips(self, estimate: uint64): self.proof_of_time_estimate_ips = estimate - async def get_proof_of_time_estimate_ips(self) -> uint64: + def get_proof_of_time_estimate_ips(self) -> uint64: return self.proof_of_time_estimate_ips diff --git a/src/timelord.py b/src/timelord.py index 56e93203d8b9..334399556d75 100644 --- 
a/src/timelord.py +++ b/src/timelord.py @@ -1,14 +1,11 @@ import asyncio import io import logging -import os import time from asyncio import Lock, StreamReader, StreamWriter from typing import Dict, List, Optional, Tuple -from yaml import safe_load -from definitions import ROOT_DIR from lib.chiavdf.inkfish.classgroup import ClassGroup from lib.chiavdf.inkfish.create_discriminant import create_discriminant from lib.chiavdf.inkfish.proof_of_time import check_proof_of_time_nwesolowski @@ -25,9 +22,8 @@ class Timelord: - def __init__(self): - config_filename = os.path.join(ROOT_DIR, "config", "config.yaml") - self.config = safe_load(open(config_filename, "r"))["timelord"] + def __init__(self, config: Dict): + self.config: Dict = config self.free_servers: List[Tuple[str, str]] = list( zip(self.config["vdf_server_ips"], self.config["vdf_server_ports"]) ) @@ -251,7 +247,7 @@ async def _do_process_communication( try: data = await reader.readexactly(4) except (asyncio.IncompleteReadError, ConnectionResetError) as e: - log.warn(f"{type(e)} {e}") + log.warning(f"{type(e)} {e}") break if data.decode() == "STOP": diff --git a/src/ui/prompt_ui.py b/src/ui/prompt_ui.py index f45fab311b44..597a6fea5710 100644 --- a/src/ui/prompt_ui.py +++ b/src/ui/prompt_ui.py @@ -340,7 +340,7 @@ async def inner(): else: self.syncing.text = f"Syncing" else: - self.syncing.text = "Not syncing" + self.syncing.text = "Synced" total_iters = self.lca_block.challenge.total_iters @@ -489,7 +489,7 @@ async def update_data(self): self.latest_blocks = await self.get_latest_blocks(self.tips) self.data_initialized = True - if counter % 20 == 0: + if counter % 50 == 0: # Only request balances periodically, since it's an expensive operation coin_balances: Dict[ bytes, uint64 diff --git a/src/ui/start_ui.py b/src/ui/start_ui.py index 5b180f7112a1..68f90f17c343 100644 --- a/src/ui/start_ui.py +++ b/src/ui/start_ui.py @@ -1,28 +1,19 @@ import asyncio import signal -import sys -import yaml -import os from src.ui.prompt_ui import start_ssh_server -from definitions import ROOT_DIR from src.util.logging import initialize_logging +from src.util.config import load_config_cli from setproctitle import setproctitle -initialize_logging("UI %(name)-29s") -setproctitle("chia_full_node_ui") - async def main(): - config_filename = os.path.join(ROOT_DIR, "config", "config.yaml") - config = yaml.safe_load(open(config_filename, "r"))["full_node"] - - rpc_index = sys.argv.index("-r") - rpc_port = int(sys.argv[rpc_index + 1]) + config = load_config_cli("config.yaml", "ui") + initialize_logging("UI %(name)-29s", config["logging"]) + setproctitle("chia_full_node_ui") - port = int(sys.argv[1]) await_all_closed, ui_close_cb = await start_ssh_server( - port, config["ssh_filename"], rpc_port + config["port"], config["ssh_filename"], config["rpc_port"] ) asyncio.get_running_loop().add_signal_handler( diff --git a/src/util/byte_types.py b/src/util/byte_types.py index e97408fde09b..ab9b771e9ca8 100644 --- a/src/util/byte_types.py +++ b/src/util/byte_types.py @@ -18,11 +18,11 @@ def make_sized_bytes(size): """ name = "bytes%d" % size - def __new__(self, v): + def __new__(cls, v): v = bytes(v) if not isinstance(v, bytes) or len(v) != size: raise ValueError("bad %s initializer %s" % (name, v)) - return bytes.__new__(self, v) # type: ignore + return bytes.__new__(cls, v) # type: ignore @classmethod # type: ignore def parse(cls, f: BinaryIO) -> Any: diff --git a/src/util/config.py b/src/util/config.py new file mode 100644 index 000000000000..da39e033ce71 --- /dev/null 
+++ b/src/util/config.py @@ -0,0 +1,82 @@ +import os + +import yaml +import argparse +from typing import Dict, Any, Callable, Optional +from definitions import ROOT_DIR + + +def load_config(filename: str, sub_config: Optional[str] = None) -> Dict: + config_filename = os.path.join(ROOT_DIR, "config", filename) + if sub_config is not None: + return yaml.safe_load(open(config_filename, "r"))[sub_config] + else: + return yaml.safe_load(open(config_filename, "r")) + + +def load_config_cli(filename: str, sub_config: Optional[str] = None) -> Dict: + """ + Loads configuration from the specified filename, in the config directory, + and then overrides any properties using the passed in command line arguments. + Nested properties in the config file can be used in the command line with ".", + for example --farmer_peer.host. Does not support lists. + """ + config = load_config(filename, sub_config) + + flattened_props = flatten_properties(config) + parser = argparse.ArgumentParser() + + for prop_name, value in flattened_props.items(): + if type(value) is list: + continue + prop_type: Callable = str2bool if type(value) is bool else type(value) # type: ignore + parser.add_argument(f"--{prop_name}", type=prop_type, dest=prop_name) + + for key, value in vars(parser.parse_args()).items(): + if value is not None: + flattened_props[key] = value + + return unflatten_properties(flattened_props) + + +def flatten_properties(config: Dict): + properties = {} + for key, value in config.items(): + if type(value) is dict: + for key_2, value_2 in flatten_properties(value).items(): + properties[key + "." + key_2] = value_2 + else: + properties[key] = value + return properties + + +def unflatten_properties(config: Dict): + properties: Dict = {} + for key, value in config.items(): + if "." in key: + add_property(properties, key, value) + else: + properties[key] = value + return properties + + +def add_property(d: Dict, partial_key: str, value: Any): + key_1, key_2 = partial_key.split(".") + if key_1 not in d: + d[key_1] = {} + if "." in key_2: + add_property(d, key_2, value) + else: + d[key_1][key_2] = value + + +def str2bool(v: Any) -> bool: + # Source from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") diff --git a/src/util/logging.py b/src/util/logging.py index 55f84ab35af3..50064261a283 100644 --- a/src/util/logging.py +++ b/src/util/logging.py @@ -1,17 +1,30 @@ import logging import colorlog +from typing import Dict -def initialize_logging(prefix): - handler = colorlog.StreamHandler() - handler.setFormatter( - colorlog.ColoredFormatter( - f"{prefix}: %(log_color)s%(levelname)-8s%(reset)s %(asctime)s.%(msecs)03d %(message)s", - datefmt="%H:%M:%S", - reset=True, +def initialize_logging(prefix: str, logging_config: Dict): + if logging_config["log_stdout"]: + handler = colorlog.StreamHandler() + handler.setFormatter( + colorlog.ColoredFormatter( + f"{prefix}: %(log_color)s%(levelname)-8s%(reset)s %(asctime)s.%(msecs)03d %(message)s", + datefmt="%H:%M:%S", + reset=True, + ) ) - ) - logger = colorlog.getLogger() - logger.addHandler(handler) + logger = colorlog.getLogger() + logger.addHandler(handler) + else: + print( + f"Starting process and logging to {logging_config['log_filename']}. Run with & to run in the background." 
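
# A round-trip sketch of the flattening scheme above: nested config keys are joined
# with "." so each leaf can become a CLI flag such as --farmer_peer.host, then split
# back into nested dicts after parsing. This unflatten variant also handles keys
# nested more than two levels deep, which the two-part unpack in add_property above
# does not.
from typing import Any, Dict

def flatten(config: Dict) -> Dict[str, Any]:
    flat: Dict[str, Any] = {}
    for key, value in config.items():
        if isinstance(value, dict):
            for sub_key, sub_value in flatten(value).items():
                flat[f"{key}.{sub_key}"] = sub_value
        else:
            flat[key] = value
    return flat

def unflatten(flat: Dict[str, Any]) -> Dict:
    nested: Dict = {}
    for key, value in flat.items():
        parts = key.split(".")
        node = nested
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return nested

cfg = {"port": 8444, "farmer_peer": {"host": "127.0.0.1", "port": 8447}}
assert flatten(cfg) == {"port": 8444, "farmer_peer.host": "127.0.0.1", "farmer_peer.port": 8447}
assert unflatten(flatten(cfg)) == cfg
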
+ ) + logging.basicConfig( + filename=logging_config["log_filename"], + filemode="a", + format=f"{prefix}: %(levelname)-8s %(asctime)s.%(msecs)03d %(message)s", + datefmt="%H:%M:%S", + ) + logger = logging.getLogger() logger.setLevel(logging.INFO) diff --git a/src/util/network.py b/src/util/network.py index 1228498bbafd..55ae27168664 100644 --- a/src/util/network.py +++ b/src/util/network.py @@ -1,16 +1,8 @@ import secrets -import sys -from typing import Tuple from src.types.sized_bytes import bytes32 -def parse_host_port(api) -> Tuple[str, int]: - host: str = sys.argv[1] if len(sys.argv) >= 3 else api.config["host"] - port: int = int(sys.argv[2]) if len(sys.argv) >= 3 else api.config["port"] - return (host, port) - - def create_node_id() -> bytes32: """Generates a transient random node_id.""" return bytes32(secrets.token_bytes(32)) diff --git a/src/util/struct_stream.py b/src/util/struct_stream.py index 44942b4c15dc..f400c4f5b248 100644 --- a/src/util/struct_stream.py +++ b/src/util/struct_stream.py @@ -18,7 +18,6 @@ def __new__(cls: Any, value: int): f"Value {value} of size {value.bit_length()} does not fit into " f"{cls.__name__} of size {bits}" ) - return int.__new__(cls, value) # type: ignore @classmethod diff --git a/tests/block_tools.py b/tests/block_tools.py index 0a39ba1d489f..06954a961cc5 100644 --- a/tests/block_tools.py +++ b/tests/block_tools.py @@ -54,28 +54,34 @@ def __init__(self): plot_seeds: List[bytes32] = [ ProofOfSpace.calculate_plot_seed(pool_pk, plot_pk) for plot_pk in plot_pks ] + self.plot_dir = os.path.join("tests", "plots") self.filenames: List[str] = [ - os.path.join( - "tests", - "plots", - "genesis-plots-" - + str(k) - + sha256(int.to_bytes(i, 4, "big")).digest().hex() - + ".dat", - ) + "genesis-plots-" + + str(k) + + sha256(int.to_bytes(i, 4, "big")).digest().hex() + + ".dat" for i in range(num_plots) ] done_filenames = set() try: for pn, filename in enumerate(self.filenames): - if not os.path.exists(filename): + if not os.path.exists(os.path.join(self.plot_dir, filename)): plotter = DiskPlotter() - plotter.create_plot_disk(filename, k, b"genesis", plot_seeds[pn]) + plotter.create_plot_disk( + self.plot_dir, + self.plot_dir, + filename, + k, + b"genesis", + plot_seeds[pn], + ) done_filenames.add(filename) except KeyboardInterrupt: for filename in self.filenames: - if filename not in done_filenames and os.path.exists(filename): - os.remove(filename) + if filename not in done_filenames and os.path.exists( + os.path.join(self.plot_dir, filename) + ): + os.remove(os.path.join(self.plot_dir, filename)) sys.exit(1) def get_consecutive_blocks( @@ -317,7 +323,7 @@ def _create_block( filename = self.filenames[seeded_pn] plot_pk = plot_pks[seeded_pn] plot_sk = plot_sks[seeded_pn] - prover = DiskProver(filename) + prover = DiskProver(os.path.join(self.plot_dir, filename)) qualities = prover.get_qualities_for_challenge(challenge_hash) if len(qualities) > 0: break diff --git a/tests/rpc/test_rpc.py b/tests/rpc/test_rpc.py index 7f33a169d84d..65f4634a0c02 100644 --- a/tests/rpc/test_rpc.py +++ b/tests/rpc/test_rpc.py @@ -1,5 +1,6 @@ import asyncio from typing import Any, Dict +import os import pytest @@ -11,6 +12,7 @@ from tests.block_tools import BlockTools from src.rpc.rpc_server import start_rpc_server from src.rpc.rpc_client import RpcClient +from src.util.config import load_config bt = BlockTools() @@ -42,19 +44,23 @@ async def test1(self): test_node_1_port = 21234 test_node_2_port = 21235 test_rpc_port = 21236 + db_filename = "blockchain_test" - store = await 
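
# A minimal sketch (a simplified stand-in for the StructStream-backed int types in
# src.util.ints) of the bounds check above: a fixed-width unsigned int subclass
# rejects values whose bit_length exceeds the type's size.
class uint8_demo(int):
    BITS = 8

    def __new__(cls, value: int) -> "uint8_demo":
        if value < 0 or value.bit_length() > cls.BITS:
            raise ValueError(
                f"Value {value} of size {value.bit_length()} does not fit into "
                f"{cls.__name__} of size {cls.BITS}"
            )
        return int.__new__(cls, value)

assert uint8_demo(255) == 255
try:
    uint8_demo(256)
    raise AssertionError("expected ValueError")
except ValueError:
    pass
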
diff --git a/tests/block_tools.py b/tests/block_tools.py
index 0a39ba1d489f..06954a961cc5 100644
--- a/tests/block_tools.py
+++ b/tests/block_tools.py
@@ -54,28 +54,34 @@ def __init__(self):
         plot_seeds: List[bytes32] = [
             ProofOfSpace.calculate_plot_seed(pool_pk, plot_pk) for plot_pk in plot_pks
         ]
+        self.plot_dir = os.path.join("tests", "plots")
         self.filenames: List[str] = [
-            os.path.join(
-                "tests",
-                "plots",
-                "genesis-plots-"
-                + str(k)
-                + sha256(int.to_bytes(i, 4, "big")).digest().hex()
-                + ".dat",
-            )
+            "genesis-plots-"
+            + str(k)
+            + sha256(int.to_bytes(i, 4, "big")).digest().hex()
+            + ".dat"
             for i in range(num_plots)
         ]
         done_filenames = set()
         try:
             for pn, filename in enumerate(self.filenames):
-                if not os.path.exists(filename):
+                if not os.path.exists(os.path.join(self.plot_dir, filename)):
                     plotter = DiskPlotter()
-                    plotter.create_plot_disk(filename, k, b"genesis", plot_seeds[pn])
+                    plotter.create_plot_disk(
+                        self.plot_dir,
+                        self.plot_dir,
+                        filename,
+                        k,
+                        b"genesis",
+                        plot_seeds[pn],
+                    )
                     done_filenames.add(filename)
         except KeyboardInterrupt:
             for filename in self.filenames:
-                if filename not in done_filenames and os.path.exists(filename):
-                    os.remove(filename)
+                if filename not in done_filenames and os.path.exists(
+                    os.path.join(self.plot_dir, filename)
+                ):
+                    os.remove(os.path.join(self.plot_dir, filename))
             sys.exit(1)
@@ -317,7 +323,7 @@ def _create_block(
             filename = self.filenames[seeded_pn]
             plot_pk = plot_pks[seeded_pn]
             plot_sk = plot_sks[seeded_pn]
-            prover = DiskProver(filename)
+            prover = DiskProver(os.path.join(self.plot_dir, filename))
             qualities = prover.get_qualities_for_challenge(challenge_hash)
             if len(qualities) > 0:
                 break
diff --git a/tests/rpc/test_rpc.py b/tests/rpc/test_rpc.py
index 7f33a169d84d..65f4634a0c02 100644
--- a/tests/rpc/test_rpc.py
+++ b/tests/rpc/test_rpc.py
@@ -1,5 +1,6 @@
 import asyncio
 from typing import Any, Dict
+import os

 import pytest

@@ -11,6 +12,7 @@
 from tests.block_tools import BlockTools
 from src.rpc.rpc_server import start_rpc_server
 from src.rpc.rpc_client import RpcClient
+from src.util.config import load_config

 bt = BlockTools()

@@ -42,19 +44,23 @@ async def test1(self):
         test_node_1_port = 21234
         test_node_2_port = 21235
         test_rpc_port = 21236
+        db_filename = "blockchain_test"

-        store = await FullNodeStore.create("blockchain_test")
+        if os.path.isfile(db_filename):
+            os.remove(db_filename)
+        store = await FullNodeStore.create(db_filename)
         await store._clear_database()
         blocks = bt.get_consecutive_blocks(test_constants, 10, [], 10)
         b: Blockchain = await Blockchain.create({}, test_constants)
         await store.add_block(blocks[0])
         for i in range(1, 9):
             assert (
-                await b.receive_block(blocks[i])
+                await b.receive_block(blocks[i], blocks[i - 1].header_block)
             ) == ReceiveBlockResult.ADDED_TO_HEAD
             await store.add_block(blocks[i])

-        full_node_1 = FullNode(store, b)
+        config = load_config("config.yaml", "full_node")
+        full_node_1 = FullNode(store, b, config)
         server_1 = ChiaServer(test_node_1_port, full_node_1, NodeType.FULL_NODE)
         _ = await server_1.start_server("127.0.0.1", None)
         full_node_1._set_server(server_1)
@@ -65,47 +71,56 @@ def stop_node_cb():

         rpc_cleanup = await start_rpc_server(full_node_1, stop_node_cb, test_rpc_port)

-        client = await RpcClient.create(test_rpc_port)
-        state = await client.get_blockchain_state()
-        assert state["lca"].header_hash is not None
-        assert not state["sync_mode"]
-        assert len(state["tips"]) > 0
-        assert state["difficulty"] > 0
-        assert state["ips"] > 0
+        try:
+            client = await RpcClient.create(test_rpc_port)
+            state = await client.get_blockchain_state()
+            assert state["lca"].header_hash is not None
+            assert not state["sync_mode"]
+            assert len(state["tips"]) > 0
+            assert state["difficulty"] > 0
+            assert state["ips"] > 0
+
+            block = await client.get_block(state["lca"].header_hash)
+            assert block == blocks[6]
+            assert (await client.get_block(bytes([1] * 32))) is None
+
+            small_header_block = await client.get_header(state["lca"].header_hash)
+            assert small_header_block.header == blocks[6].header_block.header
+
+            assert len(await client.get_pool_balances()) > 0
+            assert len(await client.get_connections()) == 0
+
+            full_node_2 = FullNode(store, b, config)
+            server_2 = ChiaServer(test_node_2_port, full_node_2, NodeType.FULL_NODE)
+            full_node_2._set_server(server_2)
+
+            _ = await server_2.start_server("127.0.0.1", None)
+            await asyncio.sleep(2)  # Allow server to start
+            cons = await client.get_connections()
+            assert len(cons) == 0
+
+            # Open a connection through the RPC
+            await client.open_connection(host="127.0.0.1", port=test_node_2_port)
+            cons = await client.get_connections()
+            assert len(cons) == 1
+
+            # Close a connection through the RPC
+            await client.close_connection(cons[0]["node_id"])
+            cons = await client.get_connections()
+            assert len(cons) == 0
+        except AssertionError:
+            # Checks that the RPC manages to stop the node
+            await client.stop_node()
+            client.close()
+            await client.await_closed()
+            server_2.close_all()
+            await server_1.await_closed()
+            await server_2.await_closed()
+            await rpc_cleanup()
+            await store.close()
+            raise

-        block = await client.get_block(state["lca"].header_hash)
-        assert block == blocks[6]
-        assert (await client.get_block(bytes([1] * 32))) is None
-
-        small_header_block = await client.get_header(state["lca"].header_hash)
-        assert small_header_block.header == blocks[6].header_block.header
-
-        assert len(await client.get_pool_balances()) > 0
-        assert len(await client.get_connections()) == 0
-
-        full_node_2 = FullNode(store, b)
-        server_2 = ChiaServer(test_node_2_port, full_node_2, NodeType.FULL_NODE)
-        full_node_2._set_server(server_2)
-
-        _ = await server_2.start_server("127.0.0.1", None)
-        await asyncio.sleep(2)  # Allow server to start
-
-        cons = await client.get_connections()
-        assert len(cons) == 0
-
-        # Open a connection through the RPC
-        await client.open_connection(host="127.0.0.1", port=test_node_2_port)
-        cons = await client.get_connections()
-        assert len(cons) == 1
-
-        # Close a connection through the RPC
-        await client.close_connection(cons[0]["node_id"])
-        cons = await client.get_connections()
-        assert len(cons) == 0
-
-        # Checks that the RPC manages to stop the node
         await client.stop_node()
-
         client.close()
         await client.await_closed()
         server_2.close_all()
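One caveat on the except AssertionError teardown above: it duplicates the success-path cleanup, and if an assertion fails before server_2 is created, the handler itself raises NameError. A try/finally gives the same guarantee without the duplication; the sketch below is an editorial suggestion on simplified, hypothetical names, not code from this diff:

```python
# Sketch: single teardown path via try/finally (names simplified/hypothetical).
async def run_rpc_checks(client, servers, rpc_cleanup, store):
    try:
        state = await client.get_blockchain_state()
        assert not state["sync_mode"]
        # ... remaining assertions ...
    finally:
        # Runs on success and on failure, so nothing is leaked either way.
        await client.stop_node()
        client.close()
        await client.await_closed()
        for server in servers:
            server.close_all()
            await server.await_closed()
        await rpc_cleanup()
        await store.close()
```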
diff --git a/tests/setup_nodes.py b/tests/setup_nodes.py
index bfa5a35d06dd..78ed14bb6b00 100644
--- a/tests/setup_nodes.py
+++ b/tests/setup_nodes.py
@@ -1,3 +1,4 @@
+import os
 from typing import Any, Dict

 from src.blockchain import Blockchain
@@ -7,6 +8,7 @@
 from src.server.server import ChiaServer
 from src.types.full_block import FullBlock
 from tests.block_tools import BlockTools
+from src.util.config import load_config

 bt = BlockTools()

@@ -41,12 +43,13 @@ async def setup_two_nodes():
     await store_1.add_block(FullBlock.from_bytes(test_constants["GENESIS_BLOCK"]))
     await store_2.add_block(FullBlock.from_bytes(test_constants["GENESIS_BLOCK"]))

-    full_node_1 = FullNode(store_1, b_1)
+    config = load_config("config.yaml", "full_node")
+    full_node_1 = FullNode(store_1, b_1, config)
     server_1 = ChiaServer(21234, full_node_1, NodeType.FULL_NODE)
     _ = await server_1.start_server("127.0.0.1", full_node_1._on_connect)
     full_node_1._set_server(server_1)

-    full_node_2 = FullNode(store_2, b_2)
+    full_node_2 = FullNode(store_2, b_2, config)
     server_2 = ChiaServer(21235, full_node_2, NodeType.FULL_NODE)
     full_node_2._set_server(server_2)

@@ -61,3 +64,5 @@ async def setup_two_nodes():
     await server_2.await_closed()
     await store_1.close()
     await store_2.close()
+    os.remove("blockchain_test")
+    os.remove("blockchain_test_2")
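Several tests in this diff delete their sqlite files by hand on both the success and exception paths. A yield-based pytest fixture centralizes that cleanup so the files go away even when a test fails midway; a hypothetical sketch (fixture name and file list invented):

```python
# Sketch: a yield-style pytest fixture ties DB cleanup to teardown.
# Fixture name and filenames are hypothetical.
import os

import pytest


@pytest.fixture
def test_databases():
    filenames = ["blockchain_test", "blockchain_test_2"]
    yield filenames
    for filename in filenames:
        if os.path.isfile(filename):
            os.remove(filename)
```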
diff --git a/tests/test_blockchain.py b/tests/test_blockchain.py
index fb0085991949..017d8d9fe1c7 100644
--- a/tests/test_blockchain.py
+++ b/tests/test_blockchain.py
@@ -14,6 +14,7 @@
 from src.types.header_block import HeaderBlock
 from src.types.proof_of_space import ProofOfSpace
 from src.util.ints import uint8, uint32, uint64
+from src.util.errors import BlockNotInBlockchain
 from tests.block_tools import BlockTools

 bt = BlockTools()
@@ -48,12 +49,12 @@ async def test_basic_blockchain(self):
         assert genesis_block.height == 0
         assert genesis_block.challenge
         assert (
-            bc1.get_header_blocks_by_height([uint32(0)], genesis_block.header_hash)
-        )[0] == genesis_block
+            bc1.get_header_hashes_by_height([uint32(0)], genesis_block.header_hash)
+        )[0] == genesis_block.header_hash
         assert (
             bc1.get_next_difficulty(genesis_block.header_hash)
         ) == genesis_block.challenge.total_weight
-        assert bc1.get_next_ips(genesis_block.header_hash) > 0
+        assert bc1.get_next_ips(bc1.genesis.header_block) > 0


 class TestBlockValidation:
@@ -66,10 +67,36 @@ async def initial_blockchain(self):
         b: Blockchain = await Blockchain.create({}, test_constants)
         for i in range(1, 9):
             assert (
-                await b.receive_block(blocks[i])
+                await b.receive_block(blocks[i], blocks[i - 1].header_block)
             ) == ReceiveBlockResult.ADDED_TO_HEAD
         return (blocks, b)

+    @pytest.mark.asyncio
+    async def test_get_header_hashes(self, initial_blockchain):
+        blocks, b = initial_blockchain
+        header_hashes_1 = b.get_header_hashes_by_height(
+            [0, 8, 3], blocks[8].header_hash
+        )
+        assert header_hashes_1 == [
+            blocks[0].header_hash,
+            blocks[8].header_hash,
+            blocks[3].header_hash,
+        ]
+
+        try:
+            b.get_header_hashes_by_height([0, 8, 3], blocks[6].header_hash)
+            thrown = False
+        except ValueError:
+            thrown = True
+        assert thrown
+
+        try:
+            b.get_header_hashes_by_height([0, 8, 3], blocks[9].header_hash)
+            thrown_2 = False
+        except BlockNotInBlockchain:
+            thrown_2 = True
+        assert thrown_2
+
     @pytest.mark.asyncio
     async def test_prev_pointer(self, initial_blockchain):
         blocks, b = initial_blockchain
@@ -93,7 +120,7 @@ async def test_prev_pointer(self, initial_blockchain):
             blocks[9].body,
         )
         assert (
-            await b.receive_block(block_bad)
+            await b.receive_block(block_bad, blocks[8].header_block)
         ) == ReceiveBlockResult.DISCONNECTED_BLOCK

     @pytest.mark.asyncio
@@ -119,7 +146,9 @@ async def test_timestamp(self, initial_blockchain):
             ),
             blocks[9].body,
         )
-        assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK
+        assert (
+            await b.receive_block(block_bad, blocks[8].header_block)
+        ) == ReceiveBlockResult.INVALID_BLOCK

         # Time too far in the future
         block_bad = FullBlock(
@@ -142,7 +171,9 @@ async def test_timestamp(self, initial_blockchain):
             blocks[9].body,
         )

-        assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK
+        assert (
+            await b.receive_block(block_bad, blocks[8].header_block)
+        ) == ReceiveBlockResult.INVALID_BLOCK

     @pytest.mark.asyncio
     async def test_body_hash(self, initial_blockchain):
@@ -167,7 +198,9 @@ async def test_body_hash(self, initial_blockchain):
             blocks[9].body,
         )

-        assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK
+        assert (
+            await b.receive_block(block_bad, blocks[8].header_block)
+        ) == ReceiveBlockResult.INVALID_BLOCK

     @pytest.mark.asyncio
     async def test_harvester_signature(self, initial_blockchain):
@@ -185,7 +218,9 @@ async def test_harvester_signature(self, initial_blockchain):
             ),
             blocks[9].body,
         )
-        assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK
+        assert (
+            await b.receive_block(block_bad, blocks[8].header_block)
+        ) == ReceiveBlockResult.INVALID_BLOCK

     @pytest.mark.asyncio
     async def test_invalid_pos(self, initial_blockchain):
@@ -209,7 +244,9 @@ async def test_invalid_pos(self, initial_blockchain):
             ),
             blocks[9].body,
         )
-        assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK
+        assert (
+            await b.receive_block(block_bad, blocks[8].header_block)
+        ) == ReceiveBlockResult.INVALID_BLOCK

     @pytest.mark.asyncio
     async def test_invalid_coinbase_height(self, initial_blockchain):
@@ -231,7 +268,9 @@ async def test_invalid_coinbase_height(self, initial_blockchain):
                 blocks[9].body.cost,
             ),
         )
-        assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK
+        assert (
+            await b.receive_block(block_bad, blocks[8].header_block)
+        ) == ReceiveBlockResult.INVALID_BLOCK

     @pytest.mark.asyncio
     async def test_difficulty_change(self):
@@ -242,7 +281,7 @@ async def test_difficulty_change(self):
         b: Blockchain = await Blockchain.create({}, test_constants)
         for i in range(1, num_blocks):
             assert (
-                await b.receive_block(blocks[i])
+                await b.receive_block(blocks[i], blocks[i - 1].header_block)
             ) == ReceiveBlockResult.ADDED_TO_HEAD

         diff_25 = b.get_next_difficulty(blocks[24].header_hash)
@@ -253,18 +292,18 @@ async def test_difficulty_change(self):
         assert diff_27 > diff_26
         assert (diff_27 / diff_26) <= test_constants["DIFFICULTY_FACTOR"]

-        assert (b.get_next_ips(blocks[1].header_hash)) == constants["VDF_IPS_STARTING"]
-        assert (b.get_next_ips(blocks[24].header_hash)) == (
-            b.get_next_ips(blocks[23].header_hash)
+        assert (b.get_next_ips(blocks[1].header_block)) == constants["VDF_IPS_STARTING"]
+        assert (b.get_next_ips(blocks[24].header_block)) == (
+            b.get_next_ips(blocks[23].header_block)
         )
-        assert (b.get_next_ips(blocks[25].header_hash)) == (
-            b.get_next_ips(blocks[24].header_hash)
+        assert (b.get_next_ips(blocks[25].header_block)) == (
+            b.get_next_ips(blocks[24].header_block)
         )
-        assert (b.get_next_ips(blocks[26].header_hash)) > (
-            b.get_next_ips(blocks[25].header_hash)
+        assert (b.get_next_ips(blocks[26].header_block)) > (
+            b.get_next_ips(blocks[25].header_block)
         )
-        assert (b.get_next_ips(blocks[27].header_hash)) == (
-            b.get_next_ips(blocks[26].header_hash)
+        assert (b.get_next_ips(blocks[27].header_block)) == (
+            b.get_next_ips(blocks[26].header_block)
         )


@@ -274,15 +313,18 @@ async def test_basic_reorg(self):
         blocks = bt.get_consecutive_blocks(test_constants, 100, [], 9)
         b: Blockchain = await Blockchain.create({}, test_constants)

-        for block in blocks:
-            await b.receive_block(block)
+        for i in range(1, len(blocks)):
+            await b.receive_block(blocks[i], blocks[i - 1].header_block)
         assert b.get_current_tips()[0].height == 100

         blocks_reorg_chain = bt.get_consecutive_blocks(
             test_constants, 30, blocks[:90], 9, b"1"
         )
-        for reorg_block in blocks_reorg_chain:
-            result = await b.receive_block(reorg_block)
+        for i in range(1, len(blocks_reorg_chain)):
+            reorg_block = blocks_reorg_chain[i]
+            result = await b.receive_block(
+                reorg_block, blocks_reorg_chain[i - 1].header_block
+            )
             if reorg_block.height < 90:
                 assert result == ReceiveBlockResult.ALREADY_HAVE_BLOCK
             elif reorg_block.height < 99:
@@ -294,17 +336,20 @@ async def test_basic_reorg(self):
     @pytest.mark.asyncio
     async def test_reorg_from_genesis(self):
         blocks = bt.get_consecutive_blocks(test_constants, 20, [], 9, b"0")
         b: Blockchain = await Blockchain.create({}, test_constants)

-        for block in blocks:
-            await b.receive_block(block)
+        for i in range(1, len(blocks)):
+            await b.receive_block(blocks[i], blocks[i - 1].header_block)
         assert b.get_current_tips()[0].height == 20

         # Reorg from genesis
         blocks_reorg_chain = bt.get_consecutive_blocks(
             test_constants, 21, [blocks[0]], 9, b"1"
         )
-        for reorg_block in blocks_reorg_chain:
-            result = await b.receive_block(reorg_block)
+        for i in range(1, len(blocks_reorg_chain)):
+            reorg_block = blocks_reorg_chain[i]
+            result = await b.receive_block(
+                reorg_block, blocks_reorg_chain[i - 1].header_block
+            )
             if reorg_block.height == 0:
                 assert result == ReceiveBlockResult.ALREADY_HAVE_BLOCK
             elif reorg_block.height < 19:
@@ -315,45 +360,52 @@ async def test_reorg_from_genesis(self):

         # Reorg back to original branch
         blocks_reorg_chain_2 = bt.get_consecutive_blocks(
-            test_constants, 3, blocks, 9, b"3"
+            test_constants, 3, blocks[:-1], 9, b"3"
+        )
+        assert (
+            await b.receive_block(
+                blocks_reorg_chain_2[20], blocks_reorg_chain_2[19].header_block
+            )
+            == ReceiveBlockResult.ADDED_AS_ORPHAN
         )
-        await b.receive_block(
-            blocks_reorg_chain_2[20]
-        ) == ReceiveBlockResult.ADDED_AS_ORPHAN
         assert (
-            await b.receive_block(blocks_reorg_chain_2[21])
+            await b.receive_block(
+                blocks_reorg_chain_2[21], blocks_reorg_chain_2[20].header_block
+            )
         ) == ReceiveBlockResult.ADDED_TO_HEAD
         assert (
-            await b.receive_block(blocks_reorg_chain_2[22])
+            await b.receive_block(
+                blocks_reorg_chain_2[22], blocks_reorg_chain_2[21].header_block
+            )
         ) == ReceiveBlockResult.ADDED_TO_HEAD

     @pytest.mark.asyncio
     async def test_lca(self):
         blocks = bt.get_consecutive_blocks(test_constants, 5, [], 9, b"0")
         b: Blockchain = await Blockchain.create({}, test_constants)
-        for block in blocks:
-            await b.receive_block(block)
+        for i in range(1, len(blocks)):
+            await b.receive_block(blocks[i], blocks[i - 1].header_block)

-        assert b.lca_block == blocks[3].header_block
-        block_5_2 = bt.get_consecutive_blocks(test_constants, 1, blocks[:5], 9, b"1")[5]
-        block_5_3 = bt.get_consecutive_blocks(test_constants, 1, blocks[:5], 9, b"2")[5]
+        assert b.lca_block.header_hash == blocks[3].header_block.header_hash
+        block_5_2 = bt.get_consecutive_blocks(test_constants, 1, blocks[:5], 9, b"1")
+        block_5_3 = bt.get_consecutive_blocks(test_constants, 1, blocks[:5], 9, b"2")

-        await b.receive_block(block_5_2)
-        assert b.lca_block == blocks[4].header_block
-        await b.receive_block(block_5_3)
-        assert b.lca_block == blocks[4].header_block
+        await b.receive_block(block_5_2[5], block_5_2[4].header_block)
+        assert b.lca_block.header_hash == blocks[4].header_block.header_hash
+        await b.receive_block(block_5_3[5], block_5_3[4].header_block)
+        assert b.lca_block.header_hash == blocks[4].header_block.header_hash

         reorg = bt.get_consecutive_blocks(test_constants, 6, [], 9, b"3")
-        for block in reorg:
-            await b.receive_block(block)
-        assert b.lca_block == blocks[0].header_block
+        for i in range(1, len(reorg)):
+            await b.receive_block(reorg[i], reorg[i - 1].header_block)
+        assert b.lca_block.header_hash == blocks[0].header_block.header_hash

     @pytest.mark.asyncio
     async def test_get_header_hashes(self):
         blocks = bt.get_consecutive_blocks(test_constants, 5, [], 9, b"0")
         b: Blockchain = await Blockchain.create({}, test_constants)
-        for block in blocks:
-            await b.receive_block(block)
+        for i in range(1, len(blocks)):
+            await b.receive_block(blocks[i], blocks[i - 1].header_block)

         header_hashes = b.get_header_hashes(blocks[-1].header_hash)
         assert len(header_hashes) == 6
         print(header_hashes)
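The new exception checks in test_get_header_hashes hand-roll thrown/thrown_2 flags; pytest already provides a context manager for exactly this. An equivalent, more idiomatic form as an editorial sketch, where b and blocks stand in for the initial_blockchain fixture's return value:

```python
# Sketch: pytest.raises replaces the manual thrown/thrown_2 bookkeeping above.
import pytest

from src.util.errors import BlockNotInBlockchain


def check_bad_tips(b, blocks):
    # Tip too low for the requested heights -> ValueError
    with pytest.raises(ValueError):
        b.get_header_hashes_by_height([0, 8, 3], blocks[6].header_hash)

    # Tip unknown to the blockchain -> BlockNotInBlockchain
    with pytest.raises(BlockNotInBlockchain):
        b.get_header_hashes_by_height([0, 8, 3], blocks[9].header_hash)
```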
diff --git a/tests/test_store.py b/tests/test_store.py
index fe2d3a7c6819..565fdf07659d 100644
--- a/tests/test_store.py
+++ b/tests/test_store.py
@@ -1,6 +1,9 @@
 import asyncio
 from secrets import token_bytes
 from typing import Any, Dict
+import os
+import sqlite3
+import random

 import pytest
 from src.consensus.constants import constants
@@ -36,10 +39,21 @@ def event_loop():
 class TestStore:
     @pytest.mark.asyncio
     async def test_basic_store(self):
+        assert sqlite3.threadsafety == 1
         blocks = bt.get_consecutive_blocks(test_constants, 9, [], 9, b"0")
-
-        db = await FullNodeStore.create("blockchain_test")
-        db_2 = await FullNodeStore.create("blockchain_test_2")
+        db_filename = "blockchain_test"
+        db_filename_2 = "blockchain_test_2"
+        db_filename_3 = "blockchain_test_3"
+
+        if os.path.isfile(db_filename):
+            os.remove(db_filename)
+        if os.path.isfile(db_filename_2):
+            os.remove(db_filename_2)
+        if os.path.isfile(db_filename_3):
+            os.remove(db_filename_3)
+
+        db = await FullNodeStore.create(db_filename)
+        db_2 = await FullNodeStore.create(db_filename_2)

         try:
             await db._clear_database()
@@ -50,17 +64,27 @@ async def test_basic_store(self):
                 await db.add_block(block)
                 assert block == await db.get_block(block.header_hash)

+            # Get small header blocks
+            assert len(await db.get_small_header_blocks()) == len(blocks)
+
+            # Get header_blocks
+            header_blocks = await db.get_header_blocks_by_hash(
+                [blocks[4].header_hash, blocks[0].header_hash]
+            )
+            assert header_blocks[0] == blocks[4].header_block
+            assert header_blocks[1] == blocks[0].header_block
+
             # Save/get sync
             for sync_mode in (False, True):
-                await db.set_sync_mode(sync_mode)
-                assert sync_mode == await db.get_sync_mode()
+                db.set_sync_mode(sync_mode)
+                assert sync_mode == db.get_sync_mode()

             # clear sync info
             await db.clear_sync_info()

             # add/get potential tip, get potential tips num
-            await db.add_potential_tip(blocks[6])
-            assert blocks[6] == await db.get_potential_tip(blocks[6].header_hash)
+            db.add_potential_tip(blocks[6])
+            assert blocks[6] == db.get_potential_tip(blocks[6].header_hash)

             # add/get potential trunk
             header = genesis.header_block
@@ -72,16 +96,16 @@ async def test_basic_store(self):
             assert genesis == await db.get_potential_block(uint32(0))

             # Add/get candidate block
-            assert await db.get_candidate_block(0) is None
+            assert db.get_candidate_block(0) is None
             partial = (
                 blocks[5].body,
                 blocks[5].header_block.header.data,
                 blocks[5].header_block.proof_of_space,
             )
-            await db.add_candidate_block(blocks[5].header_hash, *partial)
-            assert await db.get_candidate_block(blocks[5].header_hash) == partial
-            await db.clear_candidate_blocks_below(uint32(8))
-            assert await db.get_candidate_block(blocks[5].header_hash) is None
+            db.add_candidate_block(blocks[5].header_hash, *partial)
+            assert db.get_candidate_block(blocks[5].header_hash) == partial
+            db.clear_candidate_blocks_below(uint32(8))
+            assert db.get_candidate_block(blocks[5].header_hash) is None

             # Add/get unfinished block
             i = 1
@@ -89,29 +113,29 @@ async def test_basic_store(self):
                 key = (block.header_hash, uint64(1000))

                 # Different database should have different data
-                await db_2.add_unfinished_block(key, block)
+                db_2.add_unfinished_block(key, block)

-                assert await db.get_unfinished_block(key) is None
-                await db.add_unfinished_block(key, block)
-                assert await db.get_unfinished_block(key) == block
-                assert len(await db.get_unfinished_blocks()) == i
+                assert db.get_unfinished_block(key) is None
+                db.add_unfinished_block(key, block)
+                assert db.get_unfinished_block(key) == block
+                assert len(db.get_unfinished_blocks()) == i
                 i += 1

-            await db.clear_unfinished_blocks_below(uint32(5))
-            assert len(await db.get_unfinished_blocks()) == 5
+            db.clear_unfinished_blocks_below(uint32(5))
+            assert len(db.get_unfinished_blocks()) == 5

             # Set/get unf block leader
             assert db.get_unfinished_block_leader() == (0, (1 << 64) - 1)
             db.set_unfinished_block_leader(key)
             assert db.get_unfinished_block_leader() == key

-            assert await db.get_disconnected_block(blocks[0].prev_header_hash) is None
+            assert db.get_disconnected_block(blocks[0].prev_header_hash) is None
             # Disconnected blocks
             for block in blocks:
-                await db.add_disconnected_block(block)
-                await db.get_disconnected_block(block.prev_header_hash) == block
+                db.add_disconnected_block(block)
+                assert db.get_disconnected_block(block.prev_header_hash) == block

-            await db.clear_disconnected_blocks_below(uint32(5))
-            assert await db.get_disconnected_block(blocks[4].prev_header_hash) is None
+            db.clear_disconnected_blocks_below(uint32(5))
+            assert db.get_disconnected_block(blocks[4].prev_header_hash) is None

             h_hash_1 = bytes32(token_bytes(32))
             assert not db.seen_unfinished_block(h_hash_1)
@@ -122,12 +146,50 @@ async def test_basic_store(self):
         except Exception:
             await db.close()
             await db_2.close()
+            os.remove(db_filename)
+            os.remove(db_filename_2)
             raise

         # Different database should have different data
-        db_3 = await FullNodeStore.create("blockchain_test_3")
+        db_3 = await FullNodeStore.create(db_filename_3)
         assert db_3.get_unfinished_block_leader() == (0, (1 << 64) - 1)

         await db.close()
         await db_2.close()
         await db_3.close()
+        os.remove(db_filename)
+        os.remove(db_filename_2)
+        os.remove(db_filename_3)
+
+    @pytest.mark.asyncio
+    async def test_deadlock(self):
+        blocks = bt.get_consecutive_blocks(test_constants, 10, [], 9, b"0")
+        db_filename = "blockchain_test"
+
+        if os.path.isfile(db_filename):
+            os.remove(db_filename)
+
+        db = await FullNodeStore.create(db_filename)
+        tasks = []
+
+        for i in range(10000):
+            rand_i = random.randint(0, 10)
+            if random.random() < 0.5:
+                tasks.append(asyncio.create_task(db.add_block(blocks[rand_i])))
+            if random.random() < 0.5:
+                tasks.append(
+                    asyncio.create_task(db.add_potential_block(blocks[rand_i]))
+                )
+            if random.random() < 0.5:
+                tasks.append(
+                    asyncio.create_task(db.get_block(blocks[rand_i].header_hash))
+                )
+            if random.random() < 0.5:
+                tasks.append(
+                    asyncio.create_task(
+                        db.get_potential_block(blocks[rand_i].header_hash)
+                    )
+                )
+        await asyncio.gather(*tasks)
+        await db.close()
+        os.remove(db_filename)
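The gather-based stress test above only catches real deadlocks if the store serializes concurrent coroutines over its single sqlite connection. A minimal sketch of that pattern; the asyncio.Lock around a shared connection is an assumption about what FullNodeStore does internally, not code from this diff:

```python
# Sketch: serializing concurrent coroutines over one sqlite connection with
# an asyncio.Lock. That FullNodeStore works this way is an assumption.
import asyncio
import sqlite3


class TinyStore:
    def __init__(self, filename: str):
        self.db = sqlite3.connect(filename)
        self.db.execute(
            "CREATE TABLE IF NOT EXISTS blocks (hash TEXT PRIMARY KEY, data BLOB)"
        )
        self.lock = asyncio.Lock()

    async def add_block(self, block_hash: str, data: bytes) -> None:
        async with self.lock:  # one statement sequence at a time, no interleaving
            self.db.execute(
                "INSERT OR REPLACE INTO blocks VALUES (?, ?)", (block_hash, data)
            )
            self.db.commit()

    async def get_block(self, block_hash: str):
        async with self.lock:
            cursor = self.db.execute(
                "SELECT data FROM blocks WHERE hash=?", (block_hash,)
            )
            return cursor.fetchone()


async def main():
    store = TinyStore(":memory:")
    tasks = [store.add_block(f"h{i}", b"x") for i in range(100)]
    tasks += [store.get_block(f"h{i}") for i in range(100)]
    await asyncio.gather(*tasks)


asyncio.run(main())
```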