diff --git a/.codecov.yml b/.codecov.yml index a4bd51ef..6e201182 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -12,7 +12,7 @@ coverage: threshold: 10% patch: default: - threshold: 50% + threshold: 95% ignore: - "docs/*" diff --git a/apps/asynchromix/asynchromix.py b/apps/asynchromix/asynchromix.py index 11e19d7a..7e18e9fd 100644 --- a/apps/asynchromix/asynchromix.py +++ b/apps/asynchromix/asynchromix.py @@ -12,6 +12,7 @@ from web3 import HTTPProvider, Web3 from web3.contract import ConciseContract +from web3.exceptions import TransactionNotFound from apps.asynchromix.butterfly_network import iterated_butterfly_network @@ -36,7 +37,10 @@ async def wait_for_receipt(w3, tx_hash): while True: - tx_receipt = w3.eth.getTransactionReceipt(tx_hash) + try: + tx_receipt = w3.eth.getTransactionReceipt(tx_hash) + except TransactionNotFound: + tx_receipt = None if tx_receipt is not None: break await asyncio.sleep(5) @@ -48,8 +52,30 @@ async def wait_for_receipt(w3, tx_hash): ######## -class AsynchromixClient(object): +class AsynchromixClient: + """An Asynchromix client sends "masked" messages to an Ethereum contract. + ... + """ + def __init__(self, sid, myid, send, recv, w3, contract, req_mask): + """ + Parameters + ---------- + sid: int + Session id. + myid: int + Client id. + send: + Function used to send messages. Not used? + recv: + Function used to receive messages. Not used? + w3: + Connection instance to an Ethereum node. + contract: + Contract instance on the Ethereum blockchain. + req_mask: + Function used to request an input mask from a server. + """ self.sid = sid self.myid = myid self.contract = contract @@ -62,7 +88,8 @@ async def _run(self): contract_concise = ConciseContract(self.contract) await asyncio.sleep(60) # give the servers a head start # Client sends several batches of messages then quits - for epoch in range(1000): + # for epoch in range(1000): + for epoch in range(10): logging.info(f"[Client] Starting Epoch {epoch}") receipts = [] for i in range(32): @@ -121,13 +148,13 @@ async def send_message(self, m): # Step 3. Fetch the input mask from the servers inputmask = await self._get_inputmask(inputmask_idx) message = int.from_bytes(m.encode(), "big") - maskedinput = message + inputmask - maskedinput_bytes = self.w3.toBytes(hexstr=hex(maskedinput.value)) - maskedinput_bytes = maskedinput_bytes.rjust(32, b"\x00") + masked_message = message + inputmask + masked_message_bytes = self.w3.toBytes(hexstr=hex(masked_message.value)) + masked_message_bytes = masked_message_bytes.rjust(32, b"\x00") # Step 4. Publish the masked input tx_hash = self.contract.functions.submit_message( - inputmask_idx, maskedinput_bytes + inputmask_idx, masked_message_bytes ).transact({"from": self.w3.eth.accounts[0]}) tx_receipt = await wait_for_receipt(self.w3, tx_hash) @@ -138,19 +165,42 @@ async def send_message(self, m): class AsynchromixServer(object): + """Asynchromix server class to ...""" + def __init__(self, sid, myid, send, recv, w3, contract): + """ + Parameters + ---------- + sid: int + Session id. + myid: int + Client id. + send: + Function used to send messages. + recv: + Function used to receive messages. + w3: + Connection instance to an Ethereum node. + contract: + Contract instance on the Ethereum blockchain. 
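Note on the renamed masking variables above: the scheme in `send_message` is plain field addition, `masked_message = message + inputmask`, and the servers later recover `[m] = masked_message - inputmask`. A minimal sketch of that round trip with a toy modulus (the real code uses `GF(Subgroup.BLS12_381)` from honeybadgermpc; the small prime here is only for easy checking):

```python
# Toy stand-in for the BLS12-381 scalar field used by the app.
p = 2 ** 31 - 1

message = int.from_bytes("hi".encode(), "big")
r = 123456789  # the input mask [r], reconstructed from t+1 shares

masked = (message + r) % p   # what the client publishes on-chain
unmasked = (masked - r) % p  # what the servers compute as [m]
assert unmasked == message
assert unmasked.to_bytes(32, "big").lstrip(b"\x00").decode() == "hi"
```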
+ """ self.sid = sid self.myid = myid self.contract = contract self.w3 = w3 + self._task1a = asyncio.ensure_future(self._offline_inputmasks_loop()) self._task1a.add_done_callback(print_exception_callback) + self._task1b = asyncio.ensure_future(self._offline_mixes_loop()) self._task1b.add_done_callback(print_exception_callback) + self._task2 = asyncio.ensure_future(self._client_request_loop()) self._task2.add_done_callback(print_exception_callback) + self._task3 = asyncio.ensure_future(self._mixing_loop()) self._task3.add_done_callback(print_exception_callback) + self._task4 = asyncio.ensure_future(self._mixing_initiate_loop()) self._task4.add_done_callback(print_exception_callback) @@ -182,7 +232,7 @@ async def join(self): The bits and triples are consumed by each mixing epoch. The input masks may be claimed at a different rate than - than the mixing epochs so they are replenished in a separate + the mixing epochs so they are replenished in a separate task """ @@ -322,13 +372,13 @@ async def _mixing_loop(self): # 3.b. Collect the inputs inputs = [] for idx in range(epoch * K, (epoch + 1) * K): - # Get the public input - masked_input, inputmask_idx = contract_concise.input_queue(idx) - masked_input = field(int.from_bytes(masked_input, "big")) - # Get the input masks + # Get the public input (masked message) + masked_message_bytes, inputmask_idx = contract_concise.input_queue(idx) + masked_message = field(int.from_bytes(masked_message_bytes, "big")) + # Get the input mask inputmask = self._inputmasks[inputmask_idx] - m_share = masked_input - inputmask + m_share = masked_message - inputmask inputs.append(m_share) # 3.c. Collect the preprocessing @@ -349,18 +399,28 @@ async def prog(ctx): pp_elements._init_data_dir() # Overwrite triples and one_minus_ones - for kind, elems in zip(("triples", "one_minus_one"), (triples, bits)): + logging.info("overwriting triples and one_minus_ones") + for kind, elems in zip(("triples", "one_minus_ones"), (triples, bits)): if kind == "triples": elems = flatten_lists(elems) elems = [e.value for e in elems] - mixin = pp_elements.mixins[kind] + # mixin = pp_elements.mixins[kind] + mixin = getattr(pp_elements, f"_{kind}") mixin_filename = mixin.build_filename(ctx.N, ctx.t, ctx.myid) + logging.info( + f"writing preprocessed {kind} to file {mixin_filename}" + ) + logging.info(f"number of elements is: {len(elems)}") mixin._write_preprocessing_file( mixin_filename, ctx.t, ctx.myid, elems, append=False ) - pp_elements._init_mixins() + # FIXME Not sure what this is supposed to be ... + # the method does not exist. + # pp_elements._init_mixins() + pp_elements._triples._refresh_cache() + pp_elements._one_minus_ones._refresh_cache() logging.info(f"[{ctx.myid}] Running permutation network") inps = list(map(ctx.Share, inputs)) @@ -419,9 +479,17 @@ async def _mixing_initiate_loop(self): await asyncio.sleep(5) # Step 4.b. Call initiate mix - tx_hash = self.contract.functions.initiate_mix().transact( - {"from": self.w3.eth.accounts[0]} - ) + try: + tx_hash = self.contract.functions.initiate_mix().transact( + {"from": self.w3.eth.accounts[0]} + ) + except ValueError as err: + logging.info("\n") + logging.info(79 * "*") + logging.info(err) + logging.info(79 * "*") + logging.info("\n") + continue tx_receipt = await wait_for_receipt(self.w3, tx_hash) rich_logs = self.contract.events.MixingEpochInitiated().processReceipt( tx_receipt @@ -484,6 +552,7 @@ async def main_loop(w3): ] # Step 3. 
Create the client + # TODO communicate with server instead of fetching from list of servers async def req_mask(i, idx): # client requests input mask {idx} from server {i} return servers[i]._inputmasks[idx] diff --git a/apps/asynchromix/asynchromix.sol b/apps/asynchromix/asynchromix.sol index a0f16672..cfa279e8 100644 --- a/apps/asynchromix/asynchromix.sol +++ b/apps/asynchromix/asynchromix.sol @@ -8,13 +8,13 @@ contract AsynchromixCoordinator { * 3. Initiates mixing epochs (MPC computations) * (makes use of preprocess triples, bits, powers) */ - + // Session parameters uint public n; uint public t; address[] public servers; mapping (address => uint) public servermap; - + constructor(address[] _servers, uint _t) public { n = _servers.length; t = _t; @@ -34,7 +34,7 @@ contract AsynchromixCoordinator { "0xca35b7d915458ef540ade6068dfe2f44e8fa733c"] * */ - + // ############################################### // 1. Preprocessing Buffer (the MPC offline phase) // ############################################### @@ -44,26 +44,26 @@ contract AsynchromixCoordinator { uint bits; // [b] with b in {-1,1} uint inputmasks; // [r] } - + // Consensus count (min of the player report counts) PreProcessCount public preprocess; - + // How many of each have been reserved already PreProcessCount public preprocess_used; function inputmasks_available () public view returns(uint) { return preprocess.inputmasks - preprocess_used.inputmasks; - } + } // Report of preprocess buffer size from each server mapping ( uint => PreProcessCount ) public preprocess_reports; - + event PreProcessUpdated(); - + function min(uint a, uint b) private pure returns (uint) { return a < b ? a : b; - } - + } + function max(uint a, uint b) private pure returns (uint) { return a > b ? a : b; } @@ -96,27 +96,27 @@ contract AsynchromixCoordinator { preprocess.bits = mins.bits; preprocess.inputmasks = mins.inputmasks; } - - - + + + // ###################### // 2. Accept client input // ###################### - + // Step 2.a. Clients can reserve an input mask [r] from Preprocessing // maps each element of preprocess.inputmasks to the client (if any) that claims it mapping (uint => address) public inputmasks_claimed; - + event InputMaskClaimed(address client, uint inputmask_idx); - + // Client reserves a random values function reserve_inputmask() public returns(uint) { // Extension point: override this function to add custom token rules - + // An unclaimed input mask must already be available require(preprocess.inputmasks > preprocess_used.inputmasks); - + // Acquire this input mask for msg.sender uint idx = preprocess_used.inputmasks; inputmasks_claimed[idx] = msg.sender; @@ -124,57 +124,58 @@ contract AsynchromixCoordinator { emit InputMaskClaimed(msg.sender, idx); return idx; } - + // Step 2.b. Client requests (out of band, e.g. over https) shares of [r] // from each server. Servers use this function to check authorization. // Authentication using client's address is also out of band function client_authorized(address client, uint idx) view public returns(bool) { return inputmasks_claimed[idx] == client; } - + // Step 2.c. 
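The "servermap is off-by-one" convention in the constructor above deserves a note: storing `i+1` lets the contract treat `0` (the default mapping value) as "not a server", so authorization checks read `servermap[msg.sender] > 0`. The equivalent logic in Python:

```python
# Off-by-one server map: 0 means "not a registered server".
servers = ["0xaaa", "0xbbb", "0xccc", "0xddd"]
servermap = {addr: i + 1 for i, addr in enumerate(servers)}

def server_id(addr):
    assert servermap.get(addr, 0) > 0, "only valid servers"
    return servermap[addr] - 1

assert server_id("0xccc") == 2
```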
Clients publish masked message (m+r) to provide a new input [m] // and bind it to the preprocess input mapping (uint => bool) public inputmask_map; // Maps a mask - + struct Input { bytes32 masked_input; // (m+r) uint inputmask; // index in inputmask of mask [r] // Extension point: add more metadata about each input } - + Input[] public input_queue; // All inputs sent so far function input_queue_length() public view returns(uint) { return input_queue.length; } - + event MessageSubmitted(uint idx, uint inputmask_idx, bytes32 masked_input); function submit_message(uint inputmask_idx, bytes32 masked_input) public { // Must be authorized to use this input mask require(inputmasks_claimed[inputmask_idx] == msg.sender); - + // Extension point: add additional client authorizations, // e.g. prevent the client from submitting more than one message per mix - + uint idx = input_queue.length; input_queue.length += 1; - + input_queue[idx].masked_input = masked_input; input_queue[idx].inputmask = inputmask_idx; - + + // QUESTION: What is the purpose of this event? emit MessageSubmitted(idx, inputmask_idx, masked_input); // The input masks are deactivated after first use inputmasks_claimed[inputmask_idx] = address(0); } - + // ######################### // 3. Initiate Mixing Epochs // ######################### - + uint public constant K = 32; // Mix Size - + // Preprocessing requirements uint public constant PER_MIX_TRIPLES = (K / 2) * 5 * 5; // k log^2 k uint public constant PER_MIX_BITS = (K / 2) * 5 * 5; @@ -187,7 +188,7 @@ contract AsynchromixCoordinator { return min(triples_available / PER_MIX_TRIPLES, bits_available / PER_MIX_BITS); } - + // Step 3.a. Trigger a mix to start uint public inputs_mixed; uint public epochs_initiated; @@ -196,27 +197,27 @@ contract AsynchromixCoordinator { function inputs_ready() public view returns(uint) { return input_queue.length - inputs_mixed; } - + function initiate_mix() public { // Must mix eactly K values in each epoch require(input_queue.length >= inputs_mixed + K); - + // Can only initiate mix if enough preprocessings are ready require(preprocess.triples >= preprocess_used.triples + PER_MIX_TRIPLES); require(preprocess.bits >= preprocess_used.bits + PER_MIX_BITS); preprocess_used.triples += PER_MIX_TRIPLES; preprocess_used.bits += PER_MIX_BITS; - + inputs_mixed += K; emit MixingEpochInitiated(epochs_initiated); epochs_initiated += 1; output_votes.length = epochs_initiated; output_hashes.length = epochs_initiated; } - + // Step 3.b. 
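A quick check of the preprocessing constants above: with `K = 32` inputs, the iterated-butterfly switching network uses `(K/2) * log2(K)^2` switches (the `k log^2 k` in the contract comment), and each switch consumes one Beaver triple and one `{-1,1}` bit, which matches `PER_MIX_TRIPLES = PER_MIX_BITS = (K / 2) * 5 * 5`:

```python
import math

K = 32
switches = (K // 2) * int(math.log2(K)) ** 2
assert switches == 400 == (K // 2) * 5 * 5
```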
Output reporting: the output is considered "approved" once // at least t+1 servers report it - + uint public outputs_ready; event MixOutput(uint epoch, string output); bytes32[] public output_hashes; @@ -242,7 +243,7 @@ contract AsynchromixCoordinator { } else { output_hashes[epoch] = output_hash; } - + output_votes[epoch] += 1; if (output_votes[epoch] == t + 1) { // at least one honest node agrees emit MixOutput(epoch, output); diff --git a/apps/asynchromix/powermixing.py b/apps/asynchromix/powermixing.py index 29d91e6f..f6df5b45 100644 --- a/apps/asynchromix/powermixing.py +++ b/apps/asynchromix/powermixing.py @@ -168,6 +168,7 @@ async def async_mixing_in_processes(network_info, n, t, k, run_id, node_id): if __name__ == "__main__": from honeybadgermpc.config import HbmpcConfig + logging.info("Running powermixing app ...") HbmpcConfig.load_config() run_id = HbmpcConfig.extras["run_id"] @@ -181,6 +182,10 @@ async def async_mixing_in_processes(network_info, n, t, k, run_id, node_id): try: if not HbmpcConfig.skip_preprocessing: + logging.info( + "Running preprocessing.\n" + 'To skip preprocessing phase set "skip_preprocessing" config to true.' + ) # Need to keep these fixed when running on processes. field = GF(Subgroup.BLS12_381) a_s = [field(i) for i in range(1000 + k, 1000, -1)] @@ -191,6 +196,11 @@ async def async_mixing_in_processes(network_info, n, t, k, run_id, node_id): pp_elements.preprocessing_done() else: loop.run_until_complete(pp_elements.wait_for_preprocessing()) + else: + logging.info( + "Skipping preprocessing.\n" + 'To run preprocessing phase set "skip_preprocessing" config to false.' + ) loop.run_until_complete( async_mixing_in_processes( diff --git a/apps/baseclient.py b/apps/baseclient.py new file mode 100644 index 00000000..90a34672 --- /dev/null +++ b/apps/baseclient.py @@ -0,0 +1,225 @@ +import asyncio +import logging +from collections import namedtuple +from pathlib import Path + +from aiohttp import ClientSession + +import toml + +from web3 import HTTPProvider, Web3 +from web3.contract import ConciseContract + +from apps.masks.config import CONTRACT_ADDRESS_FILEPATH +from apps.utils import fetch_contract, get_contract_address, wait_for_receipt + +from honeybadgermpc.elliptic_curve import Subgroup +from honeybadgermpc.field import GF +from honeybadgermpc.polynomial import EvalPoint, polynomials_over +from honeybadgermpc.utils.misc import print_exception_callback + +PARENT_DIR = Path(__file__).resolve().parent +field = GF(Subgroup.BLS12_381) +Server = namedtuple("Server", ("id", "host", "port")) + + +class Client: + """An MPC client that sends "masked" messages to an Ethereum contract.""" + + def __init__(self, sid, myid, w3, req_mask, *, contract_context, mpc_network): + """ + Parameters + ---------- + sid: int + Session id. + myid: int + Client id. + w3: + Connection instance to an Ethereum node. + req_mask: + Function used to request an input mask from a server. + contract_context: dict + Contract attributes needed to interact with the contract + using web3. Should contain the address, name and source code + file path. + mpc_network : list or tuple or set + List or tuple or set of MPC servers, where each element is a + dictionary of server attributes: "id", "host", and "port". 
+ """ + self.sid = sid + self.myid = myid + self._contract_context = contract_context + self.contract = fetch_contract(w3, **contract_context) + self.w3 = w3 + self.req_mask = req_mask + self.mpc_network = [Server(**server_attrs) for server_attrs in mpc_network] + self._task = asyncio.create_task(self._run()) + self._task.add_done_callback(print_exception_callback) + + @classmethod + def from_dict_config(cls, config): + """Create a ``Client`` class instance from a config dict. + + Parameters + ---------- + config : dict + The configuration to create the ``Client`` instance. + """ + eth_config = config["eth"] + # contract + contract_context = { + "address": get_contract_address(CONTRACT_ADDRESS_FILEPATH), + "filepath": eth_config["contract_path"], + "name": eth_config["contract_name"], + } + + # web3 + eth_rpc_hostname = eth_config["rpc_host"] + eth_rpc_port = eth_config["rpc_port"] + w3_endpoint_uri = f"http://{eth_rpc_hostname}:{eth_rpc_port}" + w3 = Web3(HTTPProvider(w3_endpoint_uri)) + + # mpc network + mpc_network = config["servers"] + + return cls( + config["session_id"], + config["id"], + w3, + None, # TODO remove or pass callable for GET /inputmasks/{id} + contract_context=contract_context, + mpc_network=mpc_network, + ) + + @classmethod + def from_toml_config(cls, config_path): + """Create a ``Client`` class instance from a config TOML file. + + Parameters + ---------- + config_path : str + The path to the TOML configuration file to create the + ``Client`` instance. + """ + config = toml.load(config_path) + # TODO extract resolving of relative path into utils + context_path = Path(config_path).resolve().parent.joinpath(config["context"]) + config["eth"]["contract_path"] = context_path.joinpath( + config["eth"]["contract_path"] + ) + return cls.from_dict_config(config) + + async def _run(self): + contract_concise = ConciseContract(self.contract) + # Client sends several batches of messages then quits + # for epoch in range(1000): + for epoch in range(3): + logging.info(f"[Client] Starting Epoch {epoch}") + receipts = [] + m = f"Hello! 
(Epoch: {epoch})" + task = asyncio.ensure_future(self.send_message(m)) + task.add_done_callback(print_exception_callback) + receipts.append(task) + receipts = await asyncio.gather(*receipts) + + while True: # wait before sending next + if contract_concise.outputs_ready() > epoch: + break + await asyncio.sleep(5) + + async def _request_mask_share(self, server, mask_idx): + logging.info( + f"query server {server.host}:{server.port} " + f"for its share of input mask with id {mask_idx}" + ) + url = f"http://{server.host}:{server.port}/inputmasks/{mask_idx}" + async with ClientSession() as session: + async with session.get(url) as resp: + json_response = await resp.json() + return json_response["inputmask"] + + def _request_mask_shares(self, mpc_network, mask_idx): + shares = [] + for server in mpc_network: + share = self._request_mask_share(server, mask_idx) + shares.append(share) + return shares + + def _req_masks(self, server_ids, mask_idx): + shares = [] + for server_id in server_ids: + share = self.req_mask(server_id, mask_idx) + shares.append(share) + return shares + + async def _get_inputmask(self, idx): + # Private reconstruct + contract_concise = ConciseContract(self.contract) + n = contract_concise.n() + poly = polynomials_over(field) + eval_point = EvalPoint(field, n, use_omega_powers=False) + # shares = self._req_masks(range(n), idx) + shares = self._request_mask_shares(self.mpc_network, idx) + shares = await asyncio.gather(*shares) + logging.info( + f"{len(shares)} of input mask shares have" + "been received from the MPC servers" + ) + logging.info( + "privately reconstruct the input mask from the received shares ..." + ) + shares = [(eval_point(i), share) for i, share in enumerate(shares)] + mask = poly.interpolate_at(shares, 0) + return mask + + async def join(self): + await self._task + + async def send_message(self, m): + logging.info("sending message ...") + # Submit a message to be unmasked + contract_concise = ConciseContract(self.contract) + + # Step 1. Wait until there is input available, and enough triples + while True: + inputmasks_available = contract_concise.inputmasks_available() + logging.info(f"inputmasks_available: {inputmasks_available}") + if inputmasks_available >= 1: + break + await asyncio.sleep(5) + + # Step 2. Reserve the input mask + logging.info("trying to reserve an input mask ...") + tx_hash = self.contract.functions.reserve_inputmask().transact( + {"from": self.w3.eth.accounts[0]} + ) + tx_receipt = await wait_for_receipt(self.w3, tx_hash) + rich_logs = self.contract.events.InputMaskClaimed().processReceipt(tx_receipt) + if rich_logs: + inputmask_idx = rich_logs[0]["args"]["inputmask_idx"] + else: + raise ValueError + logging.info(f"input mask (id: {inputmask_idx}) reserved") + logging.info(f"tx receipt hash is: {tx_receipt['transactionHash'].hex()}") + + # Step 3. Fetch the input mask from the servers + logging.info("query the MPC servers for their share of the input mask ...") + inputmask = await self._get_inputmask(inputmask_idx) + logging.info("input mask has been privately reconstructed") + message = int.from_bytes(m.encode(), "big") + logging.info("masking the message ...") + masked_message = message + inputmask + masked_message_bytes = self.w3.toBytes(hexstr=hex(masked_message.value)) + masked_message_bytes = masked_message_bytes.rjust(32, b"\x00") + + # Step 4. 
Publish the masked input + logging.info("publish the masked message to the public contract ...") + tx_hash = self.contract.functions.submit_message( + inputmask_idx, masked_message_bytes + ).transact({"from": self.w3.eth.accounts[0]}) + tx_receipt = await wait_for_receipt(self.w3, tx_hash) + logging.info( + f"masked message has been published to the " + f"public contract at address {self.contract.address}" + ) + logging.info(f"tx receipt hash is: {tx_receipt['transactionHash'].hex()}") diff --git a/apps/helloshard/__init__.py b/apps/helloshard/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/apps/helloshard/client.py b/apps/helloshard/client.py new file mode 100644 index 00000000..bb261b3c --- /dev/null +++ b/apps/helloshard/client.py @@ -0,0 +1,117 @@ +import asyncio +import logging + +from web3.contract import ConciseContract + +from apps.utils import wait_for_receipt + +from honeybadgermpc.elliptic_curve import Subgroup +from honeybadgermpc.field import GF +from honeybadgermpc.polynomial import EvalPoint, polynomials_over +from honeybadgermpc.utils.misc import print_exception_callback + +field = GF(Subgroup.BLS12_381) + + +class Client: + """An MPC client that sends "masked" messages to an Ethereum contract.""" + + def __init__(self, sid, myid, send, recv, w3, contract, req_mask): + """ + Parameters + ---------- + sid: int + Session id. + myid: int + Client id. + send: + Function used to send messages. Not used? + recv: + Function used to receive messages. Not used? + w3: + Connection instance to an Ethereum node. + contract: + Contract instance on the Ethereum blockchain. + req_mask: + Function used to request an input mask from a server. + """ + self.sid = sid + self.myid = myid + self.contract = contract + self.w3 = w3 + self.req_mask = req_mask + self._task = asyncio.ensure_future(self._run()) + self._task.add_done_callback(print_exception_callback) + + async def _run(self): + contract_concise = ConciseContract(self.contract) + await asyncio.sleep(60) # give the servers a head start + # Client sends several batches of messages then quits + for epoch in range(3): + logging.info(f"[Client] Starting Epoch {epoch}") + receipts = [] + m = f"Hello Shard! (Epoch: {epoch})" + task = asyncio.ensure_future(self.send_message(m)) + task.add_done_callback(print_exception_callback) + receipts.append(task) + receipts = await asyncio.gather(*receipts) + + while True: # wait before sending next + # if contract_concise.intershard_msg_ready() > epoch: + if contract_concise.outputs_ready() > epoch: + break + await asyncio.sleep(5) + + async def _get_inputmask(self, idx): + # Private reconstruct + contract_concise = ConciseContract(self.contract) + n = contract_concise.n() + poly = polynomials_over(field) + eval_point = EvalPoint(field, n, use_omega_powers=False) + shares = [] + for i in range(n): + share = self.req_mask(i, idx) + shares.append(share) + shares = await asyncio.gather(*shares) + shares = [(eval_point(i), share) for i, share in enumerate(shares)] + mask = poly.interpolate_at(shares, 0) + return mask + + async def join(self): + await self._task + + async def send_message(self, m): + # Submit a message to be unmasked + contract_concise = ConciseContract(self.contract) + + # Step 1. Wait until there is input available, and enough triples + while True: + inputmasks_available = contract_concise.inputmasks_available() + # logging.infof'inputmasks_available: {inputmasks_available}') + if inputmasks_available >= 1: + break + await asyncio.sleep(5) + + # Step 2. 
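In Step 2 below, the reserved mask index is read back out of the transaction receipt via `processReceipt()`. The rich log entries it yields are attribute-dict records whose shape is roughly the following (a plain-dict stand-in for illustration):

```python
rich_logs = [
    {"event": "InputMaskClaimed",
     "args": {"client": "0x...", "inputmask_idx": 7}},
]
if rich_logs:
    inputmask_idx = rich_logs[0]["args"]["inputmask_idx"]
else:
    raise ValueError("no InputMaskClaimed event in receipt")
assert inputmask_idx == 7
```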
Reserve the input mask + tx_hash = self.contract.functions.reserve_inputmask().transact( + {"from": self.w3.eth.accounts[0]} + ) + tx_receipt = await wait_for_receipt(self.w3, tx_hash) + rich_logs = self.contract.events.InputMaskClaimed().processReceipt(tx_receipt) + if rich_logs: + inputmask_idx = rich_logs[0]["args"]["inputmask_idx"] + else: + raise ValueError + + # Step 3. Fetch the input mask from the servers + inputmask = await self._get_inputmask(inputmask_idx) + message = int.from_bytes(m.encode(), "big") + masked_message = message + inputmask + masked_message_bytes = self.w3.toBytes(hexstr=hex(masked_message.value)) + masked_message_bytes = masked_message_bytes.rjust(32, b"\x00") + + # Step 4. Publish the masked input + tx_hash = self.contract.functions.submit_message( + inputmask_idx, masked_message_bytes + ).transact({"from": self.w3.eth.accounts[0]}) + tx_receipt = await wait_for_receipt(self.w3, tx_hash) diff --git a/apps/helloshard/contract.sol b/apps/helloshard/contract.sol new file mode 100644 index 00000000..9a1ae23c --- /dev/null +++ b/apps/helloshard/contract.sol @@ -0,0 +1,291 @@ +pragma solidity >=0.4.22 <0.6.0; + +contract MpcCoordinator { + /* A blockchain-based MPC coordinator. + * 1. Keeps track of the MPC "preprocessing buffer" + * 2. Accepts client input + * (makes use of preprocess randoms) + * 3. Initiates MPC epochs (MPC computations) + * (can make use of preprocessing values if needed) + */ + + // Session parameters + uint public n; + uint public t; + address[] public shard_1; + address[] public shard_2; + mapping (address => uint) public servermap; + // mapping (address => uint) public shard_1_map; + // mapping (address => uint) public shard_2_map; + + // Who shards? + // =========== + // A different approach could be that that a list of servers is + // passed and then split into shards, as opposed to expect shards. + // + // In other words, the task of assigning servers to shards can be: + // + // 1. the responsibility of the code that instantiates the contract + // 2. the responsibility of the contract code + constructor(address[] _shard_1, address[] _shard_2, uint _t) public { + // for simplicity, both shards have the same number of servers + require(_shard_1.length == _shard_2.length); + n = _shard_1.length; + t = _t; + require(3*t < n); + shard_1.length = n; + shard_2.length = n; + for (uint i = 0; i < n; i++) { + shard_1[i] = _shard_1[i]; + shard_2[i] = _shard_2[i]; + servermap[_shard_1[i]] = i+1; // servermap is off-by-one + servermap[_shard_2[i]] = i+1+n; // servermap is off-by-one + // shard_1_map[_shard_1[i]] = i+1; // servermap is off-by-one + // shard_2_map[_shard_2[i]] = i+1; // servermap is off-by-one + } + } + + // ############################################### + // 1. Preprocessing Buffer (the MPC offline phase) + // ############################################### + + struct PreProcessCount { + uint intershardmasks; + uint inputmasks; // [r] + } + + // Consensus count (min of the player report counts) + PreProcessCount public preprocess; + + // How many of each have been reserved already + PreProcessCount public preprocess_used; + + function inputmasks_available () public view returns(uint) { + return preprocess.inputmasks - preprocess_used.inputmasks; + } + + // Report of preprocess buffer size from each server + mapping ( uint => PreProcessCount ) public preprocess_reports; + + event PreProcessUpdated(); + + function min(uint a, uint b) private pure returns (uint) { + return a < b ? 
a : b; + } + + function max(uint a, uint b) private pure returns (uint) { + return a > b ? a : b; + } + + function preprocess_report(uint[1] rep) public { + // Update the Report + require(servermap[msg.sender] > 0); // only valid servers + uint id = servermap[msg.sender] - 1; + preprocess_reports[id].inputmasks = rep[0]; + + // Update the consensus + // .triples = min (over each id) of _reports[id].triples; same for bits, etc. + PreProcessCount memory mins; + mins.inputmasks = preprocess_reports[0].inputmasks; + for (uint i = 1; i < n; i++) { + mins.inputmasks = min(mins.inputmasks, preprocess_reports[i].inputmasks); + } + if (preprocess.inputmasks < mins.inputmasks) { + emit PreProcessUpdated(); + } + preprocess.inputmasks = mins.inputmasks; + } + + + + // ###################### + // 2. Accept client input + // ###################### + + // Step 2.a. Clients can reserve an input mask [r] from Preprocessing + + // maps each element of preprocess.inputmasks to the client (if any) that claims it + mapping (uint => address) public inputmasks_claimed; + + event InputMaskClaimed(address client, uint inputmask_idx); + + // Client reserves a random values + function reserve_inputmask() public returns(uint) { + // Extension point: override this function to add custom token rules + + // An unclaimed input mask must already be available + require(preprocess.inputmasks > preprocess_used.inputmasks); + + // Acquire this input mask for msg.sender + uint idx = preprocess_used.inputmasks; + inputmasks_claimed[idx] = msg.sender; + preprocess_used.inputmasks += 1; + emit InputMaskClaimed(msg.sender, idx); + return idx; + } + + // Step 2.b. Client requests (out of band, e.g. over https) shares of [r] + // from each server. Servers use this function to check authorization. + // Authentication using client's address is also out of band + function client_authorized(address client, uint idx) view public returns(bool) { + return inputmasks_claimed[idx] == client; + } + + // Step 2.c. Clients publish masked message (m+r) to provide a new input [m] + // and bind it to the preprocess input + mapping (uint => bool) public inputmask_map; // Maps a mask + + struct Input { + bytes32 masked_input; // (m+r) + uint inputmask; // index in inputmask of mask [r] + + // Extension point: add more metadata about each input + } + + Input[] public input_queue; // All inputs sent so far + function input_queue_length() public view returns(uint) { + return input_queue.length; + } + + event MessageSubmitted(uint idx, uint inputmask_idx, bytes32 masked_input); + + function submit_message(uint inputmask_idx, bytes32 masked_input) public { + // Must be authorized to use this input mask + require(inputmasks_claimed[inputmask_idx] == msg.sender); + + // Extension point: add additional client authorizations, + // e.g. prevent the client from submitting more than one message per mix + + uint idx = input_queue.length; + input_queue.length += 1; + + input_queue[idx].masked_input = masked_input; + input_queue[idx].inputmask = inputmask_idx; + + // QUESTION: What is the purpose of this event? + emit MessageSubmitted(idx, inputmask_idx, masked_input); + + // The input masks are deactivated after first use + inputmasks_claimed[inputmask_idx] = address(0); + } + + // ###################### + // 3. 
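`preprocess_report` above takes, for each counter, the minimum across all server reports, so the on-chain "available" count never exceeds what every server can actually serve:

```python
# Inputmask counts reported by each server id.
reports = {0: 12, 1: 10, 2: 11, 3: 10}
consensus = min(reports.values())
assert consensus == 10  # preprocess.inputmasks after the update
```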
Initiate MPC Epochs + // ###################### + + uint public constant K = 1; // number of messages per epoch + + // Preprocessing requirements + uint public constant PER_EPOCH_INTERSHARDMASKS = K * n * 2; + + // Return the maximum number of mixes that can be run with the + // available preprocessing + function intershardmasks_available() public view returns(uint) { + return preprocess.intershardmasks - preprocess_used.intershardmasks; + } + + // Step 3.a. Trigger MPC to start + uint public inputs_unmasked; + uint public epochs_initiated; + event MpcEpochInitiated(uint epoch); + + function inputs_ready() public view returns(uint) { + return input_queue.length - inputs_unmasked; + } + + function initiate_mpc() public { + // Must mix eactly K values in each epoch + require(input_queue.length >= inputs_unmasked + K); + inputs_unmasked += K; + emit MpcEpochInitiated(epochs_initiated); + epochs_initiated += 1; + intershard_msg_votes.length = epochs_initiated; + intershard_messages.length = epochs_initiated; + output_votes.length = epochs_initiated; + output_hashes.length = epochs_initiated; + } + + // Step 3.b. Output reporting: the output is considered "approved" once + // at least t+1 servers report it + struct IntershardMessage { + bytes32 masked_msg; // (m+r) + uint mask_idx; // index of intershard mask + } + + IntershardMessage[] public intershard_msg_queue; // All inputs sent so far + function intershard_msg_queue_length() public view returns(uint) { + return intershard_msg_queue.length; + } + + uint public intershard_msg_ready; + event IntershardMessageReady(uint epoch, uint msg_idx, bytes32 masked_msg); + bytes32[] public intershard_messages; + uint[] public intershard_msg_votes; + mapping (uint => uint) public server_voted_in_epoch; // highest epoch voted in + + // function propose_intershard_secrets(uint epoch, uint[] secrets) public { + function transfer_intershard_message(uint epoch, bytes32 masked_msg) public { + require(epoch < epochs_initiated); // can't provide output if it hasn't been initiated + require(servermap[msg.sender] > 0); // only valid servers + uint id = servermap[msg.sender] - 1; + + // Each server can only vote once per epoch + // Hazard note: honest servers must vote in strict ascending order, or votes + // will be lost! 
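The vote-counting pattern behind this hazard note (shared by `transfer_intershard_message` and `propose_output`): a value is accepted once exactly `t + 1` matching votes arrive, since with at most `t` faulty servers that guarantees at least one honest vote. A sketch of the accounting:

```python
t = 1
votes, value = 0, None
for server_vote in ("0xabc", "0xabc", "0xabc"):
    if value is None:
        value = server_vote
    assert server_vote == value  # all the votes must match
    votes += 1
    if votes == t + 1:
        print("IntershardMessageReady")  # fires exactly once
```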
+ require(epoch <= server_voted_in_epoch[id]); + server_voted_in_epoch[id] = max(epoch + 1, server_voted_in_epoch[id]); + + if (intershard_messages[epoch] > 0) { + // All the votes must match + require(masked_msg == intershard_messages[epoch]); + } else { + intershard_messages[epoch] = masked_msg; + } + + intershard_msg_votes[epoch] += 1; + if (intershard_msg_votes[epoch] == t + 1) { // at least one honest node agrees + uint idx = intershard_msg_queue.length; + intershard_msg_queue.length += 1; + intershard_msg_queue[idx].masked_msg = masked_msg; + intershard_msg_queue[idx].mask_idx = epoch; + emit IntershardMessageReady(epoch, idx, masked_msg); + intershard_msg_ready += 1; + } + } + + // Output reporting: the output is considered "approved" once + // at least t+1 servers report it + + uint public outputs_ready; + event MpcOutput(uint epoch, string output); + bytes32[] public output_hashes; + uint[] public output_votes; + mapping (uint => uint) public server_voted; // highest epoch voted in + + function propose_output(uint epoch, string output) public { + require(epoch < epochs_initiated); // can't provide output if it hasn't been initiated + require(servermap[msg.sender] > 0); // only valid servers + uint id = servermap[msg.sender] - 1; + + // Each server can only vote once per epoch + // Hazard note: honest servers must vote in strict ascending order, or votes + // will be lost! + require(epoch <= server_voted[id]); + server_voted[id] = max(epoch + 1, server_voted[id]); + + bytes32 output_hash = sha3(output); + + if (output_votes[epoch] > 0) { + // All the votes must match + require(output_hash == output_hashes[epoch]); + } else { + output_hashes[epoch] = output_hash; + } + + output_votes[epoch] += 1; + if (output_votes[epoch] == t + 1) { // at least one honest node agrees + emit MpcOutput(epoch, output); + outputs_ready += 1; + } + } +} diff --git a/apps/helloshard/main.py b/apps/helloshard/main.py new file mode 100644 index 00000000..97e8f7f2 --- /dev/null +++ b/apps/helloshard/main.py @@ -0,0 +1,151 @@ +import asyncio +import logging +import subprocess +from contextlib import contextmanager +from pathlib import Path + +from ethereum.tools._solidity import compile_code as compile_source + +from web3 import HTTPProvider, Web3 +from web3.contract import ConciseContract + +from apps.helloshard.client import Client +from apps.helloshard.server import Server +from apps.utils import wait_for_receipt + +from honeybadgermpc.preprocessing import PreProcessedElements +from honeybadgermpc.router import SimpleRouter + + +async def main_loop(w3, *, contract_name, contract_filepath): + # system config parameters + n = 4 + t = 1 + k = 1000 # number of intershard masks to generate + pp_elements = PreProcessedElements() + # deletes sharedata/ if present + pp_elements.clear_preprocessing() + pp_elements.generate_intershard_masks(k, n, t, shard_1_id=0, shard_2_id=1) + intershard_masks = pp_elements._intershard_masks + + # Step 1. 
+ # Create the coordinator contract and web3 interface to it + compiled_sol = compile_source( + open(contract_filepath).read() + ) # Compiled source code + contract_interface = compiled_sol[f":{contract_name}"] + contract_class = w3.eth.contract( + abi=contract_interface["abi"], bytecode=contract_interface["bin"] + ) + + # 2 shards: n=4, t=1 for each shard + shard_1_accounts = w3.eth.accounts[:4] + shard_2_accounts = w3.eth.accounts[4:8] + tx_hash = contract_class.constructor( + shard_1_accounts, shard_2_accounts, 1 + ).transact({"from": w3.eth.accounts[0]}) + + # Get tx receipt to get contract address + tx_receipt = await wait_for_receipt(w3, tx_hash) + contract_address = tx_receipt["contractAddress"] + + if w3.eth.getCode(contract_address) == b"": + logging.critical("code was empty 0x, constructor may have run out of gas") + raise ValueError + + # Contract instance in concise mode + abi = contract_interface["abi"] + contract = w3.eth.contract(address=contract_address, abi=abi) + contract_concise = ConciseContract(contract) + + # Call read only methods to check, and check that n in contract is as expected + assert contract_concise.n() == n + + # Step 2: Create the servers + servers = [] + for shard_id in (0, 1): + is_gateway_shard = True if shard_id == 0 else False + router = SimpleRouter(n) + sends, recvs = router.sends, router.recvs + for i in range(n): + servers.append( + Server( + "sid", + i, + sends[i], + recvs[i], + w3, + contract, + shard_id=shard_id, + intershardmask_shares=intershard_masks.cache[ + (f"{i}-{shard_id}", n, t) + ], + is_gateway_shard=is_gateway_shard, + ) + ) + + # Step 3. Create the client + # TODO communicate with server instead of fetching from list of servers + async def req_mask(i, idx): + # client requests input mask {idx} from server {i} + return servers[i]._inputmasks[idx] + + client = Client("sid", "client", None, None, w3, contract, req_mask) + + # Step 4. Wait for conclusion + for i, server in enumerate(servers): + await server.join() + await client.join() + + +@contextmanager +def run_and_terminate_process(*args, **kwargs): + try: + p = subprocess.Popen(*args, **kwargs) + yield p + finally: + logging.info(f"Killing ganache-cli {p.pid}") + p.terminate() # send sigterm, or ... 
+ p.kill() # send sigkill + p.wait() + logging.info("done") + + +def run_eth(*, contract_name, contract_filepath): + w3 = Web3(HTTPProvider()) # Connect to localhost:8545 + asyncio.set_event_loop(asyncio.new_event_loop()) + loop = asyncio.get_event_loop() + + try: + logging.info("entering loop") + loop.run_until_complete( + asyncio.gather( + main_loop( + w3, contract_name=contract_name, contract_filepath=contract_filepath + ) + ) + ) + finally: + logging.info("closing") + loop.close() + + +def test_asynchromix(contract_name=None, contract_filepath=None): + import time + + # cmd = 'testrpc -a 50 2>&1 | tee -a acctKeys.json' + # with run_and_terminate_process(cmd, shell=True, + # stdout=sys.stdout, stderr=sys.stderr) as proc: + cmd = "ganache-cli -p 8545 -a 50 -b 1 > acctKeys.json 2>&1" + logging.info(f"Running {cmd}") + with run_and_terminate_process(cmd, shell=True): + time.sleep(5) + run_eth(contract_name=contract_name, contract_filepath=contract_filepath) + + +if __name__ == "__main__": + # Launch an ethereum test chain + contract_name = "MpcCoordinator" + contract_filename = "contract.sol" + contract_filepath = Path(__file__).resolve().parent.joinpath(contract_filename) + test_asynchromix(contract_name=contract_name, contract_filepath=contract_filepath) diff --git a/apps/helloshard/server.py b/apps/helloshard/server.py new file mode 100644 index 00000000..0d424d50 --- /dev/null +++ b/apps/helloshard/server.py @@ -0,0 +1,422 @@ +import asyncio +import logging +import time + +from web3.contract import ConciseContract + +from apps.utils import wait_for_receipt + +from honeybadgermpc.elliptic_curve import Subgroup +from honeybadgermpc.field import GF +from honeybadgermpc.mpc import Mpc +from honeybadgermpc.offline_randousha import randousha +from honeybadgermpc.utils.misc import ( + print_exception_callback, + subscribe_recv, + wrap_send, +) + +field = GF(Subgroup.BLS12_381) + + +class Server: + """MPC server class. The server's main functions are, for one epoch: + + * preprocessing for client masks and intershard masks + * consume secret from client + * produce masked message for other shard + * consume secret from other shard + + Notes + ----- + preprocessing + ^^^^^^^^^^^^^ + 1. (intra-shard communication) generate input masks via randousha + (requires intrashard collab with other nodes) + 2. (inter-shard communication) generate intershard masks via + randousha (requires intershard collab with other nodes) + + consume secret from client + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + 1. (blockchain state read) consume a client's secret from a contract + 2. (intra-shard communication) unmask the secret in a MPC with + nodes of its shard + + produce masked message for other shard + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + 1. mask the client message with an intershard mask share + 2. (intra-shard communication) open the masked share to get the + intershard masked message + 3. (blockchain state write) submit the intershard masked message to + the contract + + consume secret from other shard + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + 1. (blockchain state read) consume intershard secret + 2. set [m] = secret - intershard_mask + 3. m = [m].open() + 4. (blockchain state write) notify other shard that message has been + received -- propose output to contract + """ + + def __init__( + self, + sid, + myid, + send, + recv, + w3, + contract, + *, + shard_id, + intershardmask_shares, + is_gateway_shard=False, + ): + """ + Parameters + ---------- + sid: int + Session id. + myid: int + Client id. 
+ send: + Function used to send messages. + recv: + Function used to receive messages. + w3: + Connection instance to an Ethereum node. + contract: + Contract instance on the Ethereum blockchain. + """ + self.sid = sid + self.myid = myid + self.contract = contract + self.w3 = w3 + self.shard_id = shard_id + self.intershardmask_shares = tuple(intershardmask_shares) + self.is_gateway_shard = is_gateway_shard + self._init_tasks() + self._subscribe_task, subscribe = subscribe_recv(recv) + + def _get_send_recv(tag): + return wrap_send(tag, send), subscribe(tag) + + self.get_send_recv = _get_send_recv + self._inputmasks = [] + + @property + def global_id(self): + """Unique of id of the server with respect to other servers, + in its shard and other shards. + """ + return f"{self.myid}-{self.shard_id}" + + @property + def eth_account_index(self): + return self.myid + self.shard_id * 4 + + def _init_tasks(self): + if self.is_gateway_shard: + self._task1 = asyncio.ensure_future(self._offline_inputmasks_loop()) + self._task1.add_done_callback(print_exception_callback) + if not self.is_gateway_shard: + self._task1b = asyncio.ensure_future(self._recv_intershard_msg_loop()) + self._task1b.add_done_callback(print_exception_callback) + self._task2 = asyncio.ensure_future(self._client_request_loop()) + self._task2.add_done_callback(print_exception_callback) + self._task3 = asyncio.ensure_future(self._mpc_loop()) + self._task3.add_done_callback(print_exception_callback) + self._task4 = asyncio.ensure_future(self._mpc_initiate_loop()) + self._task4.add_done_callback(print_exception_callback) + + async def join(self): + if self.is_gateway_shard: + await self._task1 + if not self.is_gateway_shard: + await self._task1b + await self._task2 + await self._task3 + await self._task4 + await self._subscribe_task + + ####################### + # Step 1. Offline Phase + ####################### + """ + 1a. offline inputmasks + """ + + async def _preprocess_report(self): + # Submit the preprocessing report + tx_hash = self.contract.functions.preprocess_report( + [len(self._inputmasks)] + ).transact({"from": self.w3.eth.accounts[self.eth_account_index]}) + + # Wait for the tx receipt + tx_receipt = await wait_for_receipt(self.w3, tx_hash) + return tx_receipt + + async def _offline_inputmasks_loop(self): + contract_concise = ConciseContract(self.contract) + n = contract_concise.n() + t = contract_concise.t() + preproc_round = 0 + k = 1 # batch size + while True: + # Step 1. I) Wait until needed + while True: + inputmasks_available = contract_concise.inputmasks_available() + totalmasks = contract_concise.preprocess()[1] + # Policy: try to maintain a buffer of 10 input masks + target = 10 + if inputmasks_available < target: + break + # already have enough input masks, sleep + await asyncio.sleep(5) + + # Step 1. 
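The replenishment policy in `_offline_inputmasks_loop` below is worth calling out: the loop sleeps while the on-chain consensus says enough masks remain, and wakes to run another Randousha batch when the buffer dips under the target of 10. Its waiting phase, isolated (illustrative names):

```python
import asyncio

TARGET = 10  # maintain a buffer of 10 unclaimed input masks

async def wait_until_replenish_needed(inputmasks_available):
    while True:
        if inputmasks_available() < TARGET:
            return  # fall through to another Randousha run
        await asyncio.sleep(5)
```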
II) Run Randousha + logging.info( + f"[{self.global_id}] totalmasks: {totalmasks} \ + inputmasks available: {inputmasks_available} \ + target: {target} Initiating Randousha {k * (n - 2*t)}" + ) + send, recv = self.get_send_recv(f"preproc:inputmasks:{preproc_round}") + start_time = time.time() + rs_t, rs_2t = zip(*await randousha(n, t, k, self.myid, send, recv, field)) + assert len(rs_t) == len(rs_2t) == k * (n - 2 * t) + + # Note: here we just discard the rs_2t + # In principle both sides of randousha could be used with + # a small modification to randousha + end_time = time.time() + logging.debug( + f"[{self.global_id}] Randousha finished in {end_time-start_time}" + ) + logging.debug(f"len(rs_t): {len(rs_t)}") + logging.debug(f"rs_t: {rs_t}") + self._inputmasks += rs_t + + # Step 1. III) Submit an updated report + await self._preprocess_report() + + # Increment the preprocessing round and continue + preproc_round += 1 + + async def _client_request_loop(self): + # Task 2. Handling client input + # TODO: if a client requests a share, + # check if it is authorized and if so send it along + pass + + def _collect_client_input(self, *, index, queue): + # Get the public input (masked message) + masked_message_bytes, inputmask_idx = queue(index) + masked_message = field(int.from_bytes(masked_message_bytes, "big")) + # Get the input mask + logging.debug(f"[{self.global_id}] inputmask idx: {inputmask_idx}") + logging.debug(f"[{self.global_id}] inputmasks: {self._inputmasks}") + try: + inputmask = self._inputmasks[inputmask_idx] + except KeyError as err: + logging.error(err) + logging.error(f"[{self.global_id}] inputmask idx: {inputmask_idx}") + logging.error(f"[{self.global_id}] inputmasks: {self._inputmasks}") + + msg_field_elem = masked_message - inputmask + return msg_field_elem + + # TODO generalize client and intershard collection + def _collect_intershard_msg(self, *, index, queue): + masked_message_bytes, mask_idx = queue(index) + masked_message = field(int.from_bytes(masked_message_bytes, "big")) + mask = self.intershardmask_shares[mask_idx] + msg_field_elem = masked_message - mask + return msg_field_elem + + async def _mpc_loop(self): + # Task 3. Participating in MPC epochs + contract_concise = ConciseContract(self.contract) + n = contract_concise.n() + t = contract_concise.t() + + epoch = 0 + while True: + # 3.a. Wait for the next MPC to be initiated + while True: + epochs_initiated = contract_concise.epochs_initiated() + if epochs_initiated > epoch: + break + await asyncio.sleep(5) + + if self.is_gateway_shard: + client_msg_field_elem = self._collect_client_input( + index=epoch, queue=contract_concise.input_queue + ) + + # 3.d. Call the MPC program + async def prog(ctx): + logging.info(f"[{self.global_id}] Running MPC network") + client_msg_share = ctx.Share(client_msg_field_elem) + client_msg = await client_msg_share.open() + logging.info(f"[{self.global_id}] Client secret opened.") + mask_field_elem = field(self.intershardmask_shares[epoch]) + intershard_masked_msg_share = ctx.Share( + client_msg + mask_field_elem + ) + intershard_masked_msg = await intershard_masked_msg_share.open() + return intershard_masked_msg.value + + send, recv = self.get_send_recv(f"mpc:{epoch}") + logging.info(f"[{self.global_id}] MPC initiated:{epoch}") + + config = {} + ctx = Mpc( + f"mpc:{epoch}", + n, + t, + self.myid, + send, + recv, + prog, + config, + shard_id=self.shard_id, + ) + result = await ctx._run() + logging.info( + f"[{self.global_id}] MPC Intershard message queued: {result}" + ) + + # 3.e. 
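Sizing the Randousha batch above: each call yields `k * (n - 2t)` degree-`t` shares (the degree-`2t` companions are discarded here, per the comment), so with the demo parameters from `main.py`:

```python
n, t, k = 4, 1, 1
masks_per_round = k * (n - 2 * t)
assert masks_per_round == 2  # two fresh input masks per Randousha run
```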
Output the published messages to contract + # TODO instead of proposing output, mask the output with an intershard + # mask, and submit the masked output + # + # 1) fetch an intershard mask + # 2) intershard_secret = message + intershard_mask + # 3) submit intershard secret to contract + intershard_masked_msg = result.to_bytes(32, "big") + tx_hash = self.contract.functions.transfer_intershard_message( + epoch, intershard_masked_msg + ).transact({"from": self.w3.eth.accounts[self.eth_account_index]}) + tx_receipt = await wait_for_receipt(self.w3, tx_hash) + rich_logs = self.contract.events.IntershardMessageReady().processReceipt( + tx_receipt + ) + if rich_logs: + epoch = rich_logs[0]["args"]["epoch"] + msg_idx = rich_logs[0]["args"]["msg_idx"] + masked_msg = rich_logs[0]["args"]["masked_msg"] + logging.info( + f"[{self.global_id}] MPC INTERSHARD XFER [{epoch}] {msg_idx} {masked_msg}" + ) + + epoch += 1 + + async def _recv_intershard_msg_loop(self): + # Task 3. Participating in MPC epochs + contract_concise = ConciseContract(self.contract) + n = contract_concise.n() + t = contract_concise.t() + + epoch = 0 + while True: + # Wait for a message to be available + while True: + intershard_msg_ready = contract_concise.intershard_msg_ready() + if intershard_msg_ready > epoch: + break + await asyncio.sleep(5) + + if not self.is_gateway_shard: + msg_field_elem = self._collect_intershard_msg( + index=epoch, queue=contract_concise.intershard_msg_queue + ) + + # Call the MPC program + async def prog(ctx): + logging.info(f"[{self.global_id}] Processing intershard message") + msg_share = ctx.Share(msg_field_elem) + msg = await msg_share.open() + logging.info(f"[{self.global_id}] Intershard message opened.") + return msg.value + + send, recv = self.get_send_recv(f"mpc:{epoch}") + logging.info( + f"[{self.global_id}] MPC intershard message processing:{epoch}" + ) + + config = {} + ctx = Mpc( + f"mpc:{epoch}", + n, + t, + self.myid, + send, + recv, + prog, + config, + shard_id=self.shard_id, + ) + result = await ctx._run() + logging.info( + f"[{self.global_id}] MPC intershard message processing done {result}" + ) + intershard_msg = result.to_bytes(32, "big").decode().strip("\x00") + logging.info( + f"[{self.global_id}] MPC intershard message processing done {intershard_msg}" + ) + logging.info( + f"[{self.global_id}] eth account index: {self.eth_account_index}" + ) + eth_addr = self.w3.eth.accounts[self.eth_account_index] + logging.info(f"[{self.global_id}] eth addr: {eth_addr}") + balance = self.w3.eth.getBalance(eth_addr) + logging.info(f"[{self.global_id}] eth account balance: {balance}") + try: + tx_hash = self.contract.functions.propose_output( + epoch, intershard_msg + ).transact({"from": self.w3.eth.accounts[self.eth_account_index]}) + except ValueError as err: + logging.error(f"[{self.global_id}] eth addr: {eth_addr}") + logging.error(f"[{self.global_id}] balance: {balance}") + raise ValueError(f"[{self.global_id}] {err}") + tx_receipt = await wait_for_receipt(self.w3, tx_hash) + rich_logs = self.contract.events.MpcOutput().processReceipt(tx_receipt) + if rich_logs: + epoch = rich_logs[0]["args"]["epoch"] + output = rich_logs[0]["args"]["output"] + logging.info(f"[{self.global_id}] MPC OUTPUT[{epoch}] {output}") + epoch += 1 + + async def _mpc_initiate_loop(self): + # Task 4. Initiate MPC epochs + contract_concise = ConciseContract(self.contract) + K = contract_concise.K() # noqa: N806 + while True: + # Step 4.a. 
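The decoding step in `_recv_intershard_msg_loop` above re-serializes the opened field element as 32 big-endian bytes and strips the zero padding to recover the client's original string:

```python
plaintext = "Hello Shard! (Epoch: 0)"
as_int = int.from_bytes(plaintext.encode(), "big")
recovered = as_int.to_bytes(32, "big").decode().strip("\x00")
assert recovered == plaintext
```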
Wait until there are k values then call initiate_mpc + while True: + inputs_ready = contract_concise.inputs_ready() + if inputs_ready >= K: + break + await asyncio.sleep(5) + + # Step 4.b. Call initiate_mpc + try: + tx_hash = self.contract.functions.initiate_mpc().transact( + {"from": self.w3.eth.accounts[0]} + ) + except ValueError as err: + # Since only one server is needed to initiate the MPC, once + # intiated, a ValueError will occur due to the race condition + # between the servers. + logging.debug(err) + continue + tx_receipt = await wait_for_receipt(self.w3, tx_hash) + rich_logs = self.contract.events.MpcEpochInitiated().processReceipt( + tx_receipt + ) + if rich_logs: + epoch = rich_logs[0]["args"]["epoch"] + logging.info(f"[{self.global_id}] MPC epoch initiated: {epoch}") + else: + logging.info(f"[{self.global_id}] initiate_mpc failed (redundant?)") + await asyncio.sleep(10) diff --git a/apps/masks/Makefile b/apps/masks/Makefile new file mode 100644 index 00000000..a609e41d --- /dev/null +++ b/apps/masks/Makefile @@ -0,0 +1,54 @@ +.PHONY: clean clean-test clean-pyc clean-build docs help + +.DEFAULT_GOAL := help + +define PRINT_HELP_PYSCRIPT +import re, sys + +for line in sys.stdin: + match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) + if match: + target, help = match.groups() + print("%-20s %s" % (target, help)) +endef + +export PRINT_HELP_PYSCRIPT + + + +help: + @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) + +clean: clean-pyc + +clean-pyc: ## remove Python file artifacts + docker-compose run --no-deps --rm mpcnet find . -name '*.pyc' -exec rm -f {} + + docker-compose run --no-deps --rm mpcnet find . -name '*.pyo' -exec rm -f {} + + docker-compose run --no-deps --rm mpcnet find . -name '*~' -exec rm -f {} + + docker-compose run --no-deps --rm mpcnet find . -name '__pycache__' -exec rm -fr {} + + +down: ## stop and remove containers, networks, images, and volumes + docker-compose down + +stop: + docker-compose stop blockchain mpcnet + +rm: + docker-compose rm --stop --force blockchain setup mpcnet client + +run: rm ## run the example + docker-compose up -d blockchain + docker-compose up setup + docker-compose up -d client + sh follow-logs-with-tmux.sh + +run-without-tmux: rm ## run the example + docker-compose up -d blockchain + docker-compose up setup + docker-compose up -d client + docker-compose logs --follow blockchain mpcnet client + +setup: rm + docker-compose up -d blockchain + docker-compose up setup + docker-compose rm --stop --force blockchain diff --git a/apps/masks/README.md b/apps/masks/README.md new file mode 100644 index 00000000..37500c96 --- /dev/null +++ b/apps/masks/README.md @@ -0,0 +1,18 @@ +# Masks App +Simple MPC application in which a client can send a masked message to +an Ethereum (public) contract, and a network of servers perform a +simple Multi-Party Computation (MPC) to un-mask the secret message. + +This simple application contains the building blocks for more complex +applications such as the message mixing application `asynchromix` found +under [apps/asynchromix](../asynchromix). 
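The demo is orchestrated with docker-compose: the `Makefile` below brings up a local ganache-cli blockchain, runs a one-off setup container, and then starts the MPC servers and the client. If tmux is not installed, the Makefile also offers a variant that simply follows the container logs:

```shell
$ make run-without-tmux
```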
+ +To run a demo of this app: + +```shell +$ make run +``` + +You should then see something like: + +![example](./example-tmux.png) diff --git a/apps/masks/__init__.py b/apps/masks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/apps/masks/client.py b/apps/masks/client.py new file mode 100644 index 00000000..f91c04fe --- /dev/null +++ b/apps/masks/client.py @@ -0,0 +1,28 @@ +import argparse +import asyncio +from pathlib import Path + +PARENT_DIR = Path(__file__).resolve().parent + + +async def main(config_file): + from apps.baseclient import Client + + client = Client.from_toml_config(config_file) + await client.join() + + +if __name__ == "__main__": + # arg parsing + default_config_path = PARENT_DIR.joinpath("client.toml") + parser = argparse.ArgumentParser(description="MPC client.") + parser.add_argument( + "-c", + "--config-file", + default=str(default_config_path), + help=f"Configuration file to use. Defaults to '{default_config_path}'.", + ) + args = parser.parse_args() + + # Launch a client + asyncio.run(main(args.config_file)) diff --git a/apps/masks/client.toml b/apps/masks/client.toml new file mode 100644 index 00000000..662e2251 --- /dev/null +++ b/apps/masks/client.toml @@ -0,0 +1,44 @@ +# Client configuation + +# the context is used to evaluate relative file paths +context = "." + +id = "client" +session_id = "sid" +eth_address = '' + +# MPC network params +n = 4 +t = 1 + +# Ethereum blockchain +[eth] +rpc_host = "blockchain" +rpc_port = 8545 + +# contract +# the contract file path can be relative or absolute +contract_path = "contract.sol" +contract_name = "MpcCoordinator" +public_data_dir = "public-data" + +# MPC servers +[[servers]] +id = 0 +host = "mpcnet" +port = 8080 + +[[servers]] +id = 1 +host = "mpcnet" +port = 8081 + +[[servers]] +id = 2 +host = "mpcnet" +port = 8082 + +[[servers]] +id = 3 +host = "mpcnet" +port = 8083 diff --git a/apps/masks/config.py b/apps/masks/config.py new file mode 100644 index 00000000..cce25c00 --- /dev/null +++ b/apps/masks/config.py @@ -0,0 +1,8 @@ +from pathlib import Path + +PARENT_DIR = Path(__file__).resolve().parent +PUBLIC_DATA_DIR = "public-data" +CONTRACT_ADDRESS_FILENAME = "contract_address" +CONTRACT_ADDRESS_FILEPATH = PARENT_DIR.joinpath( + PUBLIC_DATA_DIR, CONTRACT_ADDRESS_FILENAME +) diff --git a/apps/masks/contract.sol b/apps/masks/contract.sol new file mode 100644 index 00000000..eb756b6a --- /dev/null +++ b/apps/masks/contract.sol @@ -0,0 +1,212 @@ +pragma solidity >=0.4.22 <0.6.0; + +contract MpcCoordinator { + /* A blockchain-based MPC coordinator. + * 1. Keeps track of the MPC "preprocessing buffer" + * 2. Accepts client input + * (makes use of preprocess randoms) + * 3. Initiates MPC epochs (MPC computations) + * (can make use of preprocessing values if needed) + */ + + // Session parameters + uint public n; + uint public t; + address[] public servers; + mapping (address => uint) public servermap; + + constructor(address[] _servers, uint _t) public { + n = _servers.length; + t = _t; + require(3*t < n); + servers.length = n; + for (uint i = 0; i < n; i++) { + servers[i] = _servers[i]; + servermap[_servers[i]] = i+1; // servermap is off-by-one + } + } + + // ############################################### + // 1. 
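One assumption worth flagging: `get_contract_address`, imported from `apps.utils` in `baseclient.py`, is not shown in this diff. Given that `config.py` above points at `apps/masks/public-data/contract_address`, presumably written during the setup phase, a plausible reading is a plain file read (hypothetical sketch, not the actual implementation):

```python
from pathlib import Path

def get_contract_address(filepath):
    """Read a deployed contract's address from a plain-text file."""
    return Path(filepath).read_text().strip()
```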
Preprocessing Buffer (the MPC offline phase) + // ############################################### + + struct PreProcessCount { + uint inputmasks; // [r] + } + + // Consensus count (min of the player report counts) + PreProcessCount public preprocess; + + // How many of each have been reserved already + PreProcessCount public preprocess_used; + + function inputmasks_available () public view returns(uint) { + return preprocess.inputmasks - preprocess_used.inputmasks; + } + + // Report of preprocess buffer size from each server + mapping ( uint => PreProcessCount ) public preprocess_reports; + + event PreProcessUpdated(); + + function min(uint a, uint b) private pure returns (uint) { + return a < b ? a : b; + } + + function max(uint a, uint b) private pure returns (uint) { + return a > b ? a : b; + } + + function preprocess_report(uint[1] rep) public { + // Update the Report + require(servermap[msg.sender] > 0); // only valid servers + uint id = servermap[msg.sender] - 1; + preprocess_reports[id].inputmasks = rep[0]; + + // Update the consensus + // .triples = min (over each id) of _reports[id].triples; same for bits, etc. + PreProcessCount memory mins; + mins.inputmasks = preprocess_reports[0].inputmasks; + for (uint i = 1; i < n; i++) { + mins.inputmasks = min(mins.inputmasks, preprocess_reports[i].inputmasks); + } + if (preprocess.inputmasks < mins.inputmasks) { + emit PreProcessUpdated(); + } + preprocess.inputmasks = mins.inputmasks; + } + + + + // ###################### + // 2. Accept client input + // ###################### + + // Step 2.a. Clients can reserve an input mask [r] from Preprocessing + + // maps each element of preprocess.inputmasks to the client (if any) that claims it + mapping (uint => address) public inputmasks_claimed; + + event InputMaskClaimed(address client, uint inputmask_idx); + + // Client reserves a random values + function reserve_inputmask() public returns(uint) { + // Extension point: override this function to add custom token rules + + // An unclaimed input mask must already be available + require(preprocess.inputmasks > preprocess_used.inputmasks); + + // Acquire this input mask for msg.sender + uint idx = preprocess_used.inputmasks; + inputmasks_claimed[idx] = msg.sender; + preprocess_used.inputmasks += 1; + emit InputMaskClaimed(msg.sender, idx); + return idx; + } + + // Step 2.b. Client requests (out of band, e.g. over https) shares of [r] + // from each server. Servers use this function to check authorization. + // Authentication using client's address is also out of band + function client_authorized(address client, uint idx) view public returns(bool) { + return inputmasks_claimed[idx] == client; + } + + // Step 2.c. Clients publish masked message (m+r) to provide a new input [m] + // and bind it to the preprocess input + mapping (uint => bool) public inputmask_map; // Maps a mask + + struct Input { + bytes32 masked_input; // (m+r) + uint inputmask; // index in inputmask of mask [r] + + // Extension point: add more metadata about each input + } + + Input[] public input_queue; // All inputs sent so far + function input_queue_length() public view returns(uint) { + return input_queue.length; + } + + event MessageSubmitted(uint idx, uint inputmask_idx, bytes32 masked_input); + + function submit_message(uint inputmask_idx, bytes32 masked_input) public { + // Must be authorized to use this input mask + require(inputmasks_claimed[inputmask_idx] == msg.sender); + + // Extension point: add additional client authorizations, + // e.g. 
prevent the client from submitting more than one message per mix + + uint idx = input_queue.length; + input_queue.length += 1; + + input_queue[idx].masked_input = masked_input; + input_queue[idx].inputmask = inputmask_idx; + + // QUESTION: What is the purpose of this event? + emit MessageSubmitted(idx, inputmask_idx, masked_input); + + // The input masks are deactivated after first use + inputmasks_claimed[inputmask_idx] = address(0); + } + + // ###################### + // 3. Initiate MPC Epochs + // ###################### + + uint public constant K = 1; // number of messages per epoch + + // Step 3.a. Trigger MPC to start + uint public inputs_unmasked; + uint public epochs_initiated; + event MpcEpochInitiated(uint epoch); + + function inputs_ready() public view returns(uint) { + return input_queue.length - inputs_unmasked; + } + + function initiate_mpc() public { + // Must unmask eactly K values in each epoch + require(input_queue.length >= inputs_unmasked + K); + inputs_unmasked += K; + emit MpcEpochInitiated(epochs_initiated); + epochs_initiated += 1; + output_votes.length = epochs_initiated; + output_hashes.length = epochs_initiated; + } + + // Step 3.b. Output reporting: the output is considered "approved" once + // at least t+1 servers report it + + uint public outputs_ready; + event MpcOutput(uint epoch, string output); + bytes32[] public output_hashes; + uint[] public output_votes; + mapping (uint => uint) public server_voted; // highest epoch voted in + + function propose_output(uint epoch, string output) public { + require(epoch < epochs_initiated); // can't provide output if it hasn't been initiated + require(servermap[msg.sender] > 0); // only valid servers + uint id = servermap[msg.sender] - 1; + + // Each server can only vote once per epoch + // Hazard note: honest servers must vote in strict ascending order, or votes + // will be lost! + require(epoch <= server_voted[id]); + server_voted[id] = max(epoch + 1, server_voted[id]); + + bytes32 output_hash = sha3(output); + + if (output_votes[epoch] > 0) { + // All the votes must match + require(output_hash == output_hashes[epoch]); + } else { + output_hashes[epoch] = output_hash; + } + + output_votes[epoch] += 1; + if (output_votes[epoch] == t + 1) { // at least one honest node agrees + emit MpcOutput(epoch, output); + outputs_ready += 1; + } + } +} diff --git a/apps/masks/docker-compose.yml b/apps/masks/docker-compose.yml new file mode 100644 index 00000000..0f847178 --- /dev/null +++ b/apps/masks/docker-compose.yml @@ -0,0 +1,40 @@ +version: '3.7' + +services: + blockchain: + container_name: blockchain + image: trufflesuite/ganache-cli + command: --accounts 50 --blockTime 1 > acctKeys.json 2>&1 + setup: + image: honeybadgermpc-local + build: + context: ../.. + dockerfile: Dockerfile + volumes: + - ../../apps:/usr/src/HoneyBadgerMPC/apps + - ../../honeybadgermpc:/usr/src/honeybadgermpc/honeybadgermpc + depends_on: + - blockchain + command: ["./apps/wait-for-it.sh", "blockchain:8545", "--", "python", "apps/masks/setup_phase.py"] + mpcnet: + image: honeybadgermpc-local + build: + context: ../.. + dockerfile: Dockerfile + volumes: + - ../../apps:/usr/src/HoneyBadgerMPC/apps + - ../../honeybadgermpc:/usr/src/honeybadgermpc/honeybadgermpc + depends_on: + - setup + command: ["./apps/wait-for-it.sh", "blockchain:8545", "--", "python", "apps/masks/mpcnet.py"] + client: + image: honeybadgermpc-local + build: + context: ../.. 
+ dockerfile: Dockerfile + volumes: + - ../../apps:/usr/src/HoneyBadgerMPC/apps + - ../../honeybadgermpc:/usr/src/honeybadgermpc/honeybadgermpc + depends_on: + - mpcnet + command: ["./apps/wait-for-it.sh", "mpcnet:8083", "--", "python", "apps/masks/client.py"] diff --git a/apps/masks/example-tmux.png b/apps/masks/example-tmux.png new file mode 100644 index 00000000..77d2f00f Binary files /dev/null and b/apps/masks/example-tmux.png differ diff --git a/apps/masks/follow-logs-with-tmux.sh b/apps/masks/follow-logs-with-tmux.sh new file mode 100755 index 00000000..1d56a6c3 --- /dev/null +++ b/apps/masks/follow-logs-with-tmux.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +if [ -z $TMUX ]; then + echo "tmux is not active, will start new session" + TMUX_CMD="new-session" +else + echo "tmux is active, will launch into new window" + TMUX_CMD="new-window" +fi + +tmux $TMUX_CMD "docker-compose logs -f blockchain; sh" \; \ + splitw -h -p 50 "docker-compose logs -f setup; sh" \; \ + splitw -v -p 50 "docker-compose logs -f mpcnet; sh" \; \ + selectp -t 0 \; \ + splitw -v -p 50 "docker-compose logs -f client; sh" diff --git a/apps/masks/mpcnet.py b/apps/masks/mpcnet.py new file mode 100644 index 00000000..2a66a6a5 --- /dev/null +++ b/apps/masks/mpcnet.py @@ -0,0 +1,71 @@ +import argparse +import asyncio +from pathlib import Path + +import toml + +from apps.masks.server import Server + +from honeybadgermpc.preprocessing import PreProcessedElements +from honeybadgermpc.router import SimpleRouter + +PARENT_DIR = Path(__file__).resolve().parent + + +class MPCNet: + def __init__(self, servers): + self.servers = servers + pp_elements = PreProcessedElements() + pp_elements.clear_preprocessing() # deletes sharedata/ if present + + @classmethod + def from_toml_config(cls, config_path): + config = toml.load(config_path) + + # TODO extract resolving of relative path into utils + context_path = Path(config_path).resolve().parent.joinpath(config["context"]) + config["eth"]["contract_path"] = context_path.joinpath( + config["eth"]["contract_path"] + ) + + n = config["n"] + + # communication channels + router = SimpleRouter(n) + sends, recvs = router.sends, router.recvs + + base_config = {k: v for k, v in config.items() if k != "servers"} + servers = [] + for i in range(n): + server_config = {k: v for k, v in config["servers"][i].items()} + server_config.update(base_config, session_id="sid") + server = Server.from_dict_config( + server_config, send=sends[i], recv=recvs[i] + ) + servers.append(server) + return cls(servers) + + async def start(self): + for server in self.servers: + await server.join() + + +async def main(config_file): + mpcnet = MPCNet.from_toml_config(config_file) + await mpcnet.start() + + +if __name__ == "__main__": + # arg parsing + default_config_path = PARENT_DIR.joinpath("mpcnet.toml") + parser = argparse.ArgumentParser(description="MPC network.") + parser.add_argument( + "-c", + "--config-file", + default=str(default_config_path), + help=f"Configuration file to use. Defaults to '{default_config_path}'.", + ) + args = parser.parse_args() + + # Launch MPC network + asyncio.run(main(args.config_file)) diff --git a/apps/masks/mpcnet.toml b/apps/masks/mpcnet.toml new file mode 100644 index 00000000..a1828ec9 --- /dev/null +++ b/apps/masks/mpcnet.toml @@ -0,0 +1,40 @@ +# MPC network configuation + +# the context is used to evaluate relative file paths +context = "." 
+ +# MPC network params +n = 4 +t = 1 + +# Ethereum blockchain +[eth] +rpc_host = "blockchain" +rpc_port = 8545 + +# contract +# the contract file path can be relative or absolute +contract_path = "contract.sol" +contract_name = "MpcCoordinator" +public_data_dir = "public-data" + +# MPC servers +[[servers]] +id = 0 +host = "mpcnet" +port = 8080 + +[[servers]] +id = 1 +host = "mpcnet" +port = 8081 + +[[servers]] +id = 2 +host = "mpcnet" +port = 8082 + +[[servers]] +id = 3 +host = "mpcnet" +port = 8083 diff --git a/apps/masks/public-data/.gitignore b/apps/masks/public-data/.gitignore new file mode 100644 index 00000000..2ebf87a4 --- /dev/null +++ b/apps/masks/public-data/.gitignore @@ -0,0 +1 @@ +contract_address diff --git a/apps/masks/public-data/README.md b/apps/masks/public-data/README.md new file mode 100644 index 00000000..6abfe4eb --- /dev/null +++ b/apps/masks/public-data/README.md @@ -0,0 +1,14 @@ +# Public Data Directory +This directory, `public-data`, serves the purpose of storing data that +does not require any privacy protection, and is meant to be accessible +by all participants involved in an MPC computation, including the +MPC servers, clients, coordinator, and any other entity that may play a +role in the MPC computation from the point of view of making its +execution a reality. + +An example of such data is the address of the MPC coordinator contract, +which is needed by clients and MPC servers. + +**NOTE**: This data is not meant to be tracked by a version control +system (git) as it may differ from one MPC protocol execution to +another. diff --git a/apps/masks/public-data/config.toml b/apps/masks/public-data/config.toml new file mode 100644 index 00000000..b424e840 --- /dev/null +++ b/apps/masks/public-data/config.toml @@ -0,0 +1,37 @@ +n = 4 +t = 1 +deployer_address = "0x7A09EF88e6371d2c838611c24b8d6E8B2d2D81d0" +[[servers]] +id = 0 +host = "mpcnet" +port = 8080 +address = "0x4106F7F1F23Ea636ed395c8ba7E84945E5AFb1f6" +eth_address = "0x5B538955Aa2fafC533d9CdB98f5FCbA5433281fc" + +[[servers]] +id = 1 +host = "mpcnet" +port = 8081 +address = "0xAC8534EAf2777BF1558052Bf876e2248e8897231" +eth_address = "0x4a597657De07b4740BFa1E381Ad98d4a58587901" + +[[servers]] +id = 2 +host = "mpcnet" +port = 8082 +address = "0xa6688B66878af966f6Cc13bc5E8E49717209e648" +eth_address = "0x797A2d8DF92c08F75B1B2CA9192BFC8Ec5E581A4" + +[[servers]] +id = 3 +host = "mpcnet" +port = 8083 +address = "0xE2F90DDA3F39b048284A259984149cDce38cd32F" +eth_address = "0xaDc5E5c05f30aD30e20B6bc34B85b3380956F526" + +[eth] +rpc_host = "blockchain" +rpc_port = 8545 +contract_filename = "contract.sol" +contract_name = "MpcCoordinator" +contract_address = "0xeF631D9D1125f4837B2aa86254eb81B045a7c303" diff --git a/apps/masks/screenshot-tmux.png b/apps/masks/screenshot-tmux.png new file mode 100644 index 00000000..1c42cece Binary files /dev/null and b/apps/masks/screenshot-tmux.png differ diff --git a/apps/masks/server.py b/apps/masks/server.py new file mode 100644 index 00000000..0f7b6e64 --- /dev/null +++ b/apps/masks/server.py @@ -0,0 +1,340 @@ +import asyncio +import logging +import time +from pathlib import Path + +from aiohttp import web + +import toml + +from web3.contract import ConciseContract + +from apps.utils import fetch_contract, wait_for_receipt + +from honeybadgermpc.elliptic_curve import Subgroup +from honeybadgermpc.field import GF +from honeybadgermpc.mpc import Mpc +from honeybadgermpc.offline_randousha import randousha +from honeybadgermpc.utils.misc import ( + print_exception_callback, + 
subscribe_recv, + wrap_send, +) + +field = GF(Subgroup.BLS12_381) + + +class Server: + """MPC server class to ...""" + + def __init__( + self, + sid, + myid, + send, + recv, + w3, + *, + contract_context, + http_host="0.0.0.0", + http_port=8080, + ): + """ + Parameters + ---------- + sid: int + Session id. + myid: int + Client id. + send: + Function used to send messages. + recv: + Function used to receive messages. + w3: + Connection instance to an Ethereum node. + contract_context: dict + Contract attributes needed to interact with the contract + using web3. Should contain the address, name and source code + file path. + """ + self.sid = sid + self.myid = myid + self._contract_context = contract_context + self.contract = fetch_contract(w3, **contract_context) + self.w3 = w3 + self._init_tasks() + self._subscribe_task, subscribe = subscribe_recv(recv) + self._http_host = http_host + self._http_port = http_port + + def _get_send_recv(tag): + return wrap_send(tag, send), subscribe(tag) + + self.get_send_recv = _get_send_recv + self._inputmasks = [] + + def _init_tasks(self): + self._task1 = asyncio.ensure_future(self._offline_inputmasks_loop()) + self._task1.add_done_callback(print_exception_callback) + self._task2 = asyncio.ensure_future(self._client_request_loop()) + self._task2.add_done_callback(print_exception_callback) + self._task3 = asyncio.ensure_future(self._mpc_loop()) + self._task3.add_done_callback(print_exception_callback) + self._task4 = asyncio.ensure_future(self._mpc_initiate_loop()) + self._task4.add_done_callback(print_exception_callback) + # self._http_server = asyncio.create_task(self._client_request_loop()) + # self._http_server.add_done_callback(print_exception_callback) + + @classmethod + def from_dict_config(cls, config, *, send, recv): + """Create a ``Server`` class instance from a config dict. + + Parameters + ---------- + config : dict + The configuration to create the ``Server`` instance. + send: + Function used to send messages. + recv: + Function used to receive messages. + """ + from web3 import HTTPProvider, Web3 + from apps.masks.config import CONTRACT_ADDRESS_FILEPATH + from apps.utils import get_contract_address + + eth_config = config["eth"] + # contract + contract_context = { + "address": get_contract_address(CONTRACT_ADDRESS_FILEPATH), + "filepath": eth_config["contract_path"], + "name": eth_config["contract_name"], + } + + # web3 + eth_rpc_hostname = eth_config["rpc_host"] + eth_rpc_port = eth_config["rpc_port"] + w3_endpoint_uri = f"http://{eth_rpc_hostname}:{eth_rpc_port}" + w3 = Web3(HTTPProvider(w3_endpoint_uri)) + + return cls( + config["session_id"], + config["id"], + send, + recv, + w3, + contract_context=contract_context, + http_host=config["host"], + http_port=config["port"], + ) + + @classmethod + def from_toml_config(cls, config_path, *, send, recv): + """Create a ``Server`` class instance from a config TOML file. + + Parameters + ---------- + config_path : str + The path to the TOML configuration file to create the + ``Server`` instance. + send: + Function used to send messages. + recv: + Function used to receive messages. 
+ """ + config = toml.load(config_path) + # TODO extract resolving of relative path into utils + context_path = Path(config_path).resolve().parent.joinpath(config["context"]) + config["eth"]["contract_path"] = context_path.joinpath( + config["eth"]["contract_path"] + ) + return cls.from_dict_config(config, send=send, recv=recv) + + async def join(self): + await self._task1 + await self._task2 + await self._task3 + await self._task4 + await self._subscribe_task + # await self._http_server + await self._client_request_loop() + + ####################### + # Step 1. Offline Phase + ####################### + """ + 1a. offline inputmasks + """ + + async def _preprocess_report(self): + # Submit the preprocessing report + tx_hash = self.contract.functions.preprocess_report( + [len(self._inputmasks)] + ).transact({"from": self.w3.eth.accounts[self.myid]}) + + # Wait for the tx receipt + tx_receipt = await wait_for_receipt(self.w3, tx_hash) + return tx_receipt + + async def _offline_inputmasks_loop(self): + contract_concise = ConciseContract(self.contract) + n = contract_concise.n() + t = contract_concise.t() + preproc_round = 0 + k = 1 + while True: + # Step 1. I) Wait until needed + while True: + inputmasks_available = contract_concise.inputmasks_available() + totalmasks = contract_concise.preprocess() + # Policy: try to maintain a buffer of 10 input masks + target = 10 + if inputmasks_available < target: + break + # already have enough input masks, sleep + await asyncio.sleep(5) + + # Step 1. II) Run Randousha + logging.info( + f"[{self.myid}] totalmasks: {totalmasks} \ + inputmasks available: {inputmasks_available} \ + target: {target} Initiating Randousha {k * (n - 2*t)}" + ) + send, recv = self.get_send_recv(f"preproc:inputmasks:{preproc_round}") + start_time = time.time() + rs_t, rs_2t = zip(*await randousha(n, t, k, self.myid, send, recv, field)) + assert len(rs_t) == len(rs_2t) == k * (n - 2 * t) + + # Note: here we just discard the rs_2t + # In principle both sides of randousha could be used with + # a small modification to randousha + end_time = time.time() + logging.debug(f"[{self.myid}] Randousha finished in {end_time-start_time}") + logging.debug(f"len(rs_t): {len(rs_t)}") + logging.debug(f"rs_t: {rs_t}") + self._inputmasks += rs_t + + # Step 1. III) Submit an updated report + await self._preprocess_report() + + # Increment the preprocessing round and continue + preproc_round += 1 + + ################################## + # Web server for input mask shares + ################################## + + async def _client_request_loop(self): + """ Task 2. Handling client input + + .. todo:: if a client requests a share, check if it is + authorized and if so send it along + + """ + routes = web.RouteTableDef() + + @routes.get("/inputmasks/{idx}") + async def _handler(request): + idx = int(request.match_info.get("idx")) + inputmask = self._inputmasks[idx] + data = { + "inputmask": inputmask, + "server_id": self.myid, + "server_port": self._http_port, + } + return web.json_response(data) + + app = web.Application() + app.add_routes(routes) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, host=self._http_host, port=self._http_port) + await site.start() + print(f"======= Serving on http://{self._http_host}:{self._http_port}/ ======") + # pause here for very long time by serving HTTP requests and + # waiting for keyboard interruption + await asyncio.sleep(100 * 3600) + + async def _mpc_loop(self): + # Task 3. 
Participating in MPC epochs + contract_concise = ConciseContract(self.contract) + n = contract_concise.n() + t = contract_concise.t() + + epoch = 0 + while True: + # 3.a. Wait for the next MPC to be initiated + while True: + epochs_initiated = contract_concise.epochs_initiated() + if epochs_initiated > epoch: + break + await asyncio.sleep(5) + + # 3.b. Collect the input + # Get the public input (masked message) + masked_message_bytes, inputmask_idx = contract_concise.input_queue(epoch) + masked_message = field(int.from_bytes(masked_message_bytes, "big")) + inputmask = self._inputmasks[inputmask_idx] # Get the input mask + msg_field_elem = masked_message - inputmask + + # 3.d. Call the MPC program + async def prog(ctx): + logging.info(f"[{ctx.myid}] Running MPC network") + msg_share = ctx.Share(msg_field_elem) + opened_value = await msg_share.open() + msg = opened_value.value.to_bytes(32, "big").decode().strip("\x00") + return msg + + send, recv = self.get_send_recv(f"mpc:{epoch}") + logging.info(f"[{self.myid}] MPC initiated:{epoch}") + + config = {} + ctx = Mpc(f"mpc:{epoch}", n, t, self.myid, send, recv, prog, config) + result = await ctx._run() + logging.info(f"[{self.myid}] MPC complete {result}") + + # 3.e. Output the published messages to contract + tx_hash = self.contract.functions.propose_output(epoch, result).transact( + {"from": self.w3.eth.accounts[self.myid]} + ) + tx_receipt = await wait_for_receipt(self.w3, tx_hash) + rich_logs = self.contract.events.MpcOutput().processReceipt(tx_receipt) + if rich_logs: + epoch = rich_logs[0]["args"]["epoch"] + output = rich_logs[0]["args"]["output"] + logging.info(f"[{self.myid}] MPC OUTPUT[{epoch}] {output}") + + epoch += 1 + + async def _mpc_initiate_loop(self): + # Task 4. Initiate MPC epochs + contract_concise = ConciseContract(self.contract) + K = contract_concise.K() # noqa: N806 + while True: + # Step 4.a. Wait until there are k values then call initiate_mpc + while True: + inputs_ready = contract_concise.inputs_ready() + if inputs_ready >= K: + break + await asyncio.sleep(5) + + # Step 4.b. Call initiate_mpc + try: + tx_hash = self.contract.functions.initiate_mpc().transact( + {"from": self.w3.eth.accounts[0]} + ) + except ValueError as err: + # Since only one server is needed to initiate the MPC, once + # intiated, a ValueError will occur due to the race condition + # between the servers. + logging.debug(err) + continue + tx_receipt = await wait_for_receipt(self.w3, tx_hash) + rich_logs = self.contract.events.MpcEpochInitiated().processReceipt( + tx_receipt + ) + if rich_logs: + epoch = rich_logs[0]["args"]["epoch"] + logging.info(f"[{self.myid}] MPC epoch initiated: {epoch}") + else: + logging.info(f"[{self.myid}] initiate_mpc failed (redundant?)") + await asyncio.sleep(10) diff --git a/apps/masks/setup_phase.py b/apps/masks/setup_phase.py new file mode 100644 index 00000000..fe79645c --- /dev/null +++ b/apps/masks/setup_phase.py @@ -0,0 +1,106 @@ +import argparse +import logging +import pprint +from pathlib import Path + +import toml + +from web3 import HTTPProvider, Web3 + +from apps.masks.config import CONTRACT_ADDRESS_FILEPATH +from apps.utils import create_and_deploy_contract + +PARENT_DIR = Path(__file__).resolve().parent + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def set_eth_addrs(config_dict, config_filepath): + """Set eth addresses for the contract deployer, the MPC servers and + the client and update the given config file. 
+ + Parameters + ---------- + config_dict : dict + Configuration dict to update with eth addresses. + config_filepath : str + Toml file path to the configuration to update. + """ + raise NotImplementedError + + +def deploy_contract( + w3, *, contract_name, contract_filepath, n=4, t=1, deployer_addr, mpc_addrs +): + contract_address, abi = create_and_deploy_contract( + w3, + deployer=deployer_addr, + contract_name=contract_name, + contract_filepath=contract_filepath, + args=(mpc_addrs, t), + ) + return contract_address + + +if __name__ == "__main__": + # TODO figure out why logging does not show up in the output + # NOTE appears to be a configuration issue with respect to the + # level as `.warning()` works. + logger.info(f"Deploying contract ...") + print(f"Deploying contract ...") + + default_config_path = PARENT_DIR.joinpath("public-data/config.toml") + parser = argparse.ArgumentParser(description="Setup phase.") + parser.add_argument( + "-c", + "--config-file", + default=str(default_config_path), + help=f"Configuration file to use. Defaults to '{default_config_path}'.", + ) + args = parser.parse_args() + config_file = args.config_file + config = toml.load(config_file) + print(config) + + n = config["n"] + t = config["t"] + eth_config = config["eth"] + contract_name = eth_config["contract_name"] + contract_filename = eth_config["contract_filename"] + contract_filepath = PARENT_DIR.joinpath(contract_filename) + eth_rpc_hostname = eth_config["rpc_host"] + eth_rpc_port = eth_config["rpc_port"] + w3_endpoint_uri = f"http://{eth_rpc_hostname}:{eth_rpc_port}" + w3 = Web3(HTTPProvider(w3_endpoint_uri)) + + deployer_addr = w3.eth.accounts[49] + mpc_addrs = [] + for s in config["servers"]: + mpc_addr = w3.eth.accounts[s["id"]] + s["eth_address"] = mpc_addr + mpc_addrs.append(mpc_addr) + + contract_address = deploy_contract( + w3, + contract_name=contract_name, + contract_filepath=contract_filepath, + n=n, + t=t, + deployer_addr=deployer_addr, + mpc_addrs=mpc_addrs, + ) + config["deployer_address"] = deployer_addr + config["eth"]["contract_address"] = contract_address + logger.info(f"Contract deployed at address: {contract_address}") + print(f"Contract deployed at address: {contract_address}") + + with open(config_file, "w") as f: + toml.dump(config, f) + with open(CONTRACT_ADDRESS_FILEPATH, "w") as f: + f.write(contract_address) + + logger.info(f"Wrote contract address to file: {CONTRACT_ADDRESS_FILEPATH}") + print(f"Wrote contract address to file: {CONTRACT_ADDRESS_FILEPATH}") + print(f"\nUpdated common config file: {config_file}\n") + pprint.pprint(config) diff --git a/apps/utils.py b/apps/utils.py new file mode 100644 index 00000000..263f733d --- /dev/null +++ b/apps/utils.py @@ -0,0 +1,172 @@ +import asyncio +import logging + +from ethereum.tools._solidity import compile_code as compile_source + +from web3.exceptions import TransactionNotFound + + +async def wait_for_receipt(w3, tx_hash): + while True: + try: + tx_receipt = w3.eth.getTransactionReceipt(tx_hash) + except TransactionNotFound: + tx_receipt = None + if tx_receipt is not None: + break + await asyncio.sleep(5) + return tx_receipt + + +def compile_contract_source(filepath): + """Compiles the contract located in given file path. + + filepath : str + File path to the contract. 
+ """ + with open(filepath, "r") as f: + source = f.read() + return compile_source(source) + + +def get_contract_interface(*, contract_name, contract_filepath): + compiled_sol = compile_contract_source(contract_filepath) + try: + contract_interface = compiled_sol[f":{contract_name}"] + except KeyError: + logging.error(f"Contract {contract_name} not found") + raise + + return contract_interface + + +def get_contract_abi(*, contract_name, contract_filepath): + ci = get_contract_interface( + contract_name=contract_name, contract_filepath=contract_filepath + ) + return ci["abi"] + + +def deploy_contract(w3, *, abi, bytecode, deployer, args=(), kwargs=None): + """Deploy the contract. + + Parameters + ---------- + w3 : + Web3-based connection to an Ethereum network. + abi : + ABI of the contract to deploy. + bytecode : + Bytecode of the contract to deploy. + deployer : str + Ethereum address of the deployer. The deployer is the one + making the transaction to deploy the contract, meaning that + the costs of the transaction to deploy the contract are consumed + from the ``deployer`` address. + args : tuple, optional + Positional arguments to be passed to the contract constructor. + Defaults to ``()``. + kwargs : dict, optional + Keyword arguments to be passed to the contract constructor. + Defaults to ``{}``. + + Returns + ------- + contract_address: str + Contract address in hexadecimal format. + + Raises + ------ + ValueError + If the contract deployment failed. + """ + if kwargs is None: + kwargs = {} + contract_class = w3.eth.contract(abi=abi, bytecode=bytecode) + tx_hash = contract_class.constructor(*args, **kwargs).transact({"from": deployer}) + + # Get tx receipt to get contract address + tx_receipt = w3.eth.waitForTransactionReceipt(tx_hash) + contract_address = tx_receipt["contractAddress"] + + if w3.eth.getCode(contract_address) == b"": + err_msg = "code was empty 0x, constructor may have run out of gas" + logging.critical(err_msg) + raise ValueError(err_msg) + return contract_address + + +def create_and_deploy_contract( + w3, *, deployer, contract_name, contract_filepath, args=(), kwargs=None +): + """Create and deploy the contract. + + Parameters + ---------- + w3 : + Web3-based connection to an Ethereum network. + deployer : str + Ethereum address of the deployer. The deployer is the one + making the transaction to deploy the contract, meaning that + the costs of the transaction to deploy the contract are consumed + from the ``deployer`` address. + contract_name : str + Name of the contract to be created. + contract_filepath : str + Path of the Solidity contract file. + args : tuple, optional + Positional arguments to be passed to the contract constructor. + Defaults to ``()``. + kwargs : dict, optional + Keyword arguments to be passed to the contract constructor. + Defaults to ``None``. + + Returns + ------- + contract_address: str + Contract address in hexadecimal format. + abi: + Contract abi. 
+ """ + compiled_sol = compile_contract_source(contract_filepath) + contract_interface = compiled_sol[f":{contract_name}"] + abi = contract_interface["abi"] + contract_address = deploy_contract( + w3, + abi=abi, + bytecode=contract_interface["bin"], + deployer=deployer, + args=args, + kwargs=kwargs, + ) + return contract_address, abi + + +def get_contract_address(filepath): + with open(filepath, "r") as f: + line = f.readline() + contract_address = line.strip() + return contract_address + + +def fetch_contract(w3, *, address, name, filepath): + """Fetch a contract using the given web3 connection, and contract + attributes. + + Parameters + ---------- + address : str + Ethereum address of the contract. + name : str + Name of the contract. + filepath : str + File path to the source code of the contract. + + Returns + ------- + web3.contract.Contract + The ``web3`` ``Contract`` object. + """ + abi = get_contract_abi(contract_name=name, contract_filepath=filepath) + contract = w3.eth.contract(address=address, abi=abi) + return contract diff --git a/apps/wait-for-it.sh b/apps/wait-for-it.sh new file mode 100755 index 00000000..c5773a44 --- /dev/null +++ b/apps/wait-for-it.sh @@ -0,0 +1,184 @@ +#!/usr/bin/env bash +# Use this script to test if a given TCP host/port are available + +# Source: https://github.com/vishnubob/wait-for-it/blob/c096cface5fbd9f2d6b037391dfecae6fde1362e/wait-for-it.sh + +WAITFORIT_cmdname=${0##*/} + +echoerr() { if [[ $WAITFORIT_QUIET -ne 1 ]]; then echo "$@" 1>&2; fi } + +usage() +{ + cat << USAGE >&2 +Usage: + $WAITFORIT_cmdname host:port [-s] [-t timeout] [-- command args] + -h HOST | --host=HOST Host or IP under test + -p PORT | --port=PORT TCP port under test + Alternatively, you specify the host and port as host:port + -s | --strict Only execute subcommand if the test succeeds + -q | --quiet Don't output any status messages + -t TIMEOUT | --timeout=TIMEOUT + Timeout in seconds, zero for no timeout + -- COMMAND ARGS Execute command with args after the test finishes +USAGE + exit 1 +} + +wait_for() +{ + if [[ $WAITFORIT_TIMEOUT -gt 0 ]]; then + echoerr "$WAITFORIT_cmdname: waiting $WAITFORIT_TIMEOUT seconds for $WAITFORIT_HOST:$WAITFORIT_PORT" + else + echoerr "$WAITFORIT_cmdname: waiting for $WAITFORIT_HOST:$WAITFORIT_PORT without a timeout" + fi + WAITFORIT_start_ts=$(date +%s) + while : + do + if [[ $WAITFORIT_ISBUSY -eq 1 ]]; then + nc -z $WAITFORIT_HOST $WAITFORIT_PORT + WAITFORIT_result=$? + else + (echo > /dev/tcp/$WAITFORIT_HOST/$WAITFORIT_PORT) >/dev/null 2>&1 + WAITFORIT_result=$? + fi + if [[ $WAITFORIT_result -eq 0 ]]; then + WAITFORIT_end_ts=$(date +%s) + echoerr "$WAITFORIT_cmdname: $WAITFORIT_HOST:$WAITFORIT_PORT is available after $((WAITFORIT_end_ts - WAITFORIT_start_ts)) seconds" + break + fi + sleep 1 + done + return $WAITFORIT_result +} + +wait_for_wrapper() +{ + # In order to support SIGINT during timeout: http://unix.stackexchange.com/a/57692 + if [[ $WAITFORIT_QUIET -eq 1 ]]; then + timeout $WAITFORIT_BUSYTIMEFLAG $WAITFORIT_TIMEOUT $0 --quiet --child --host=$WAITFORIT_HOST --port=$WAITFORIT_PORT --timeout=$WAITFORIT_TIMEOUT & + else + timeout $WAITFORIT_BUSYTIMEFLAG $WAITFORIT_TIMEOUT $0 --child --host=$WAITFORIT_HOST --port=$WAITFORIT_PORT --timeout=$WAITFORIT_TIMEOUT & + fi + WAITFORIT_PID=$! + trap "kill -INT -$WAITFORIT_PID" INT + wait $WAITFORIT_PID + WAITFORIT_RESULT=$? 
+ if [[ $WAITFORIT_RESULT -ne 0 ]]; then + echoerr "$WAITFORIT_cmdname: timeout occurred after waiting $WAITFORIT_TIMEOUT seconds for $WAITFORIT_HOST:$WAITFORIT_PORT" + fi + return $WAITFORIT_RESULT +} + +# process arguments +while [[ $# -gt 0 ]] +do + case "$1" in + *:* ) + WAITFORIT_hostport=(${1//:/ }) + WAITFORIT_HOST=${WAITFORIT_hostport[0]} + WAITFORIT_PORT=${WAITFORIT_hostport[1]} + shift 1 + ;; + --child) + WAITFORIT_CHILD=1 + shift 1 + ;; + -q | --quiet) + WAITFORIT_QUIET=1 + shift 1 + ;; + -s | --strict) + WAITFORIT_STRICT=1 + shift 1 + ;; + -h) + WAITFORIT_HOST="$2" + if [[ $WAITFORIT_HOST == "" ]]; then break; fi + shift 2 + ;; + --host=*) + WAITFORIT_HOST="${1#*=}" + shift 1 + ;; + -p) + WAITFORIT_PORT="$2" + if [[ $WAITFORIT_PORT == "" ]]; then break; fi + shift 2 + ;; + --port=*) + WAITFORIT_PORT="${1#*=}" + shift 1 + ;; + -t) + WAITFORIT_TIMEOUT="$2" + if [[ $WAITFORIT_TIMEOUT == "" ]]; then break; fi + shift 2 + ;; + --timeout=*) + WAITFORIT_TIMEOUT="${1#*=}" + shift 1 + ;; + --) + shift + WAITFORIT_CLI=("$@") + break + ;; + --help) + usage + ;; + *) + echoerr "Unknown argument: $1" + usage + ;; + esac +done + +if [[ "$WAITFORIT_HOST" == "" || "$WAITFORIT_PORT" == "" ]]; then + echoerr "Error: you need to provide a host and port to test." + usage +fi + +WAITFORIT_TIMEOUT=${WAITFORIT_TIMEOUT:-15} +WAITFORIT_STRICT=${WAITFORIT_STRICT:-0} +WAITFORIT_CHILD=${WAITFORIT_CHILD:-0} +WAITFORIT_QUIET=${WAITFORIT_QUIET:-0} + +# Check to see if timeout is from busybox? +WAITFORIT_TIMEOUT_PATH=$(type -p timeout) +WAITFORIT_TIMEOUT_PATH=$(realpath $WAITFORIT_TIMEOUT_PATH 2>/dev/null || readlink -f $WAITFORIT_TIMEOUT_PATH) + +WAITFORIT_BUSYTIMEFLAG="" +if [[ $WAITFORIT_TIMEOUT_PATH =~ "busybox" ]]; then + WAITFORIT_ISBUSY=1 + # Check if busybox timeout uses -t flag + # (recent Alpine versions don't support -t anymore) + if timeout &>/dev/stdout | grep -q -e '-t '; then + WAITFORIT_BUSYTIMEFLAG="-t" + fi +else + WAITFORIT_ISBUSY=0 +fi + +if [[ $WAITFORIT_CHILD -gt 0 ]]; then + wait_for + WAITFORIT_RESULT=$? + exit $WAITFORIT_RESULT +else + if [[ $WAITFORIT_TIMEOUT -gt 0 ]]; then + wait_for_wrapper + WAITFORIT_RESULT=$? + else + wait_for + WAITFORIT_RESULT=$? 
+ fi +fi + +if [[ $WAITFORIT_CLI != "" ]]; then + if [[ $WAITFORIT_RESULT -ne 0 && $WAITFORIT_STRICT -eq 1 ]]; then + echoerr "$WAITFORIT_cmdname: strict mode, refusing to execute subprocess" + exit $WAITFORIT_RESULT + fi + exec "${WAITFORIT_CLI[@]}" +else + exit $WAITFORIT_RESULT +fi diff --git a/doc8.ini b/doc8.ini index aa064011..d86deba1 100644 --- a/doc8.ini +++ b/doc8.ini @@ -1,2 +1,5 @@ [doc8] ignore-path=docs/_build,honeybadgermpc.egg-info/ +# ignore math mode nowrap errors until problem is fixed +# see https://github.com/PyCQA/doc8/pull/32 for more details +ignore=D000 diff --git a/docker-compose.yml b/docker-compose.yml index 3852d0eb..246954dd 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -13,7 +13,9 @@ services: - ./benchmark:/usr/src/HoneyBadgerMPC/benchmark - ./aws:/usr/src/HoneyBadgerMPC/aws - ./conf:/usr/src/HoneyBadgerMPC/conf + - ./contracts:/usr/src/HoneyBadgerMPC/contracts - ./docs:/usr/src/HoneyBadgerMPC/docs + - ./doc8.ini:/usr/src/HoneyBadgerMPC/doc8.ini - ./honeybadgermpc:/usr/src/HoneyBadgerMPC/honeybadgermpc - ./scripts:/usr/src/HoneyBadgerMPC/scripts - ./tests:/usr/src/HoneyBadgerMPC/tests @@ -29,3 +31,6 @@ services: - ./pairing/setup.py:/usr/src/HoneyBadgerMPC/pairing/setup.py - /usr/src/HoneyBadgerMPC/honeybadgermpc/ntl # Directory _not_ mounted from host command: pytest -v --cov=honeybadgermpc + environment: + # FIXME temporary, should be in developer settings + PYTHONBREAKPOINT: ipdb.set_trace diff --git a/docs/conf.py b/docs/conf.py index 994c802c..1a2ad1a1 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -44,12 +44,15 @@ "sphinx.ext.autodoc", "sphinx.ext.doctest", "sphinx.ext.intersphinx", + "sphinx.ext.napoleon", "sphinx.ext.todo", "sphinx.ext.coverage", "sphinx.ext.mathjax", "sphinx.ext.viewcode", "sphinx_tabs.tabs", "m2r", + "sphinxcontrib.bibtex", + "sphinxcontrib.soliditydomain", ] autodoc_default_options = { @@ -57,7 +60,7 @@ "undoc-members": None, "private-members": None, "inherited-members": None, - "show-inheritance": None, + # "show-inheritance": None, } # Add any paths that contain templates here, relative to this directory. diff --git a/docs/index.rst b/docs/index.rst index f42df0d3..b550110a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -30,8 +30,9 @@ .. toctree:: :maxdepth: 1 - :caption: Integrations + :caption: Blockchain Integrations + integrations/eth integrations/hyperledger-fabric .. toctree:: diff --git a/docs/integrations/eth.rst b/docs/integrations/eth.rst new file mode 100644 index 00000000..1ee72115 --- /dev/null +++ b/docs/integrations/eth.rst @@ -0,0 +1,461 @@ +AsynchroMix with Ethereum as an MPC Coordinator +=============================================== +A blockchain can be used as a coordinating mechanism to run an MPC +program. This document covers the AsynchroMix application. + +In the paper, Reliable Broadcast and Common Subset are used to +"coordinate" the MPC operations. Below is the protocol as it is in the +paper and after that the AsynchroMix protocol is revisited and +presented as it is implemented in :mod:`apps.asynchromix.asynchromix` +where Ethereum is used in place of Reliable Broadcast and Common +Subset. + +AsynchroMix (paper version) +--------------------------- +As presented in the paper :cite:`honeybadgermpc` (`eprint iacr version +`_), figure 5, section 4. 
+
+* Input: Each client :math:`C_j` receives an input :math:`m_j`
+* Output: In each epoch a subset of client inputs
+  :math:`m_1, \ldots, m_k` are selected, and a permutation
+  :math:`\pi (m_1, \ldots, m_k)` is published, where :math:`\pi` does
+  not depend on the input permutation
+* Preprocessing:
+
+  * For each :math:`m_j`, a random :math:`[\![r_j]\!]`, where each
+    client has received :math:`r_j`
+  * Preprocessing for PowerMix and/or Switching-Network
+
+* Protocol (for client :math:`C_j`):
+
+  1. Set :math:`\overline m_j := m_j + r_j`
+  2. :math:`\textsf{ReliableBroadcast} \; \overline m_j`
+  3. Wait until :math:`m_j` appears in the output of a mixing epoch
+
+* Protocol (for server :math:`P_i`):
+
+  - Initialize for each client :math:`C_j`
+
+    .. math::
+       :nowrap:
+
+       \begin{align*}
+       \textsf{input}_j & := 0 \quad \textit{ // No. of inputs received from } C_j \\
+       \textsf{done}_j & := 0 \quad \textit{ // No. of messages mixed for } C_j
+       \end{align*}
+
+  - On receiving :math:`\overline m_j` output from
+    :math:`\textsf{ReliableBroadcast}` for client :math:`C_j` at any
+    time, set :math:`\textsf{input}_j := \textsf{input}_j + 1`
+  - Proceed in consecutive mixing epochs :math:`e`:
+
+    **Input Collection Phase**
+
+    * Let :math:`b_i` be a :math:`\lvert \mathcal{C} \rvert`-bit vector
+      where :math:`b_{i,j} = 1` if :math:`\textsf{input}_j \gt
+      \textsf{done}_j`.
+    * Pass :math:`b_i` as input to an instance of
+      :math:`\textsf{CommonSubset}`.
+    * Wait to receive :math:`b` from :math:`\textsf{CommonSubset}`, where
+      :math:`b` is an :math:`n \times \lvert \mathcal{C} \rvert` matrix,
+      each row of :math:`b` corresponds to the input from one server, and
+      at least :math:`n - t` of the rows are non-default.
+    * Let :math:`b_{\cdot, j}` denote the column corresponding to client
+      :math:`C_j`.
+    * For each :math:`C_j`,
+
+      .. math::
+         :nowrap:
+
+         \begin{equation}
+         [\![m_j]\!] :=
+         \begin{cases}
+         \overline m_j - [\![r_j]\!] & \sum b_{\cdot,j} \geq t+1 \\
+         0 & \text{otherwise}
+         \end{cases}
+         \end{equation}
+
+    **Online Phase**
+
+    Switch Network Option
+
+        Run the MPC Program switching-network on
+        :math:`\{[\![m_{j,k_j}]\!]\}`, resulting in
+        :math:`\pi (m_1, \ldots, m_k)`.
+        Requires :math:`k` rounds.
+
+    Powermix Option
+
+        Run the MPC Program power-mix on
+        :math:`\{[\![m_{j,k_j}]\!]\}`, resulting in
+        :math:`\pi (m_1, \ldots, m_k)`.
+
+    Set :math:`\textsf{done}_j := \textsf{done}_j + 1` for each
+    client :math:`C_j` whose input was mixed this epoch
+
+
+AsynchroMix & Ethereum
+----------------------
+In the original protocol, asynchronous Reliable Broadcast and Common
+Subset are used to orchestrate the different MPC operations that
+require consensus amongst the MPC servers; see sections 2.3 and 4 of
+the paper for details. This section presents the protocol as it is
+implemented in :mod:`apps.asynchromix.asynchromix`, where Ethereum is
+used as a consensus backbone to orchestrate the MPC operations.
+
+**Main components:**
+
+* coordinator: blockchain (:sol:contract:`AsynchromixCoordinator`)
+* asynchromix servers (
+  :class:`~apps.asynchromix.asynchromix.AsynchromixServer`)
+* asynchromix clients (
+  :class:`~apps.asynchromix.asynchromix.AsynchromixClient`)
+
+
+Input
+^^^^^
+Each client :math:`C_j` receives an input :math:`m_j`.
+
+Currently, only one client is used, and the client itself sends a
+series of "dummy" messages. In
+:func:`~apps.asynchromix.asynchromix.AsynchromixClient._run()`:
+
+..
code-block:: python + + class AsynchromixClient: + + async def _run(self): + + # ... + for epoch in range(1000): + receipts = [] + for i in range(32): + m = f"message:{epoch}:{i}" + task = asyncio.ensure_future(self.send_message(m)) + receipts.append(task) + receipts = await asyncio.gather(*receipts) + # ... + +Output +^^^^^^ +In each epoch a subset of client inputs :math:`m_1, \ldots, m_k` are +selected, and a permutation :math:`\pi (m_1, \ldots, m_k)` is published +where :math:`\pi` does not depend on the input permutation + +Preprocessing +^^^^^^^^^^^^^ +* For each :math:`m_j`, a random :math:`[\![r_j]\!]`, where each client + has received :math:`r_j` +* Preprocessing for PowerMix and/or Switching-Network + +.. note:: At the moment the MPC program uses the switching network ( + :func:`~apps.asynchromix.butterfly_network.iterated_butterfly_network`). + +.. todo:: Explain how the preprocessing values are generated. + +.. todo:: Explain what preprocessing is done for the switching + (butterfly) network. + +In the :mod:`~apps.asynchromix.asynchromix` example the client ( +:class:`~apps.asynchromix.asynchromix.AsynchromixClient`) + +1. waits for an input mask to be ready via the smart contract function + :sol:func:`inputmasks_available`; +2. reserves an input mask via :sol:func:`reserve_inputmask`; +3. fetches the input mask from the servers (the client reconstructs the + input mask, given sufficient shares from the servers) + +Below are some code snippets that perform the above 3 steps. *Some +details of the implementation are omitted in order to ease the +presentation.* + +.. code-block:: python + + class AsynchromixClient: + + async def send_message(self, m): + contract_concise = ConciseContract(self.contract) + + # Step 1. Wait until there is input available, and enough triples + while True: + inputmasks_available = contract_concise.inputmasks_available() + if inputmasks_available >= 1: + break + await asyncio.sleep(5) + + # Step 2. Reserve the input mask + tx_hash = self.contract.functions.reserve_inputmask().transact( + {"from": self.w3.eth.accounts[0]} + ) + tx_receipt = await wait_for_receipt(self.w3, tx_hash) + rich_logs = self.contract.events.InputMaskClaimed().processReceipt(tx_receipt) + inputmask_idx = rich_logs[0]["args"]["inputmask_idx"] + + # Step 3. Fetch the input mask from the servers + inputmask = await self._get_inputmask(inputmask_idx) + + async def _get_inputmask(self, idx): + contract_concise = ConciseContract(self.contract) + n = contract_concise.n() + poly = polynomials_over(field) + eval_point = EvalPoint(field, n, use_omega_powers=False) + shares = [] + for i in range(n): + share = self.req_mask(i, idx) + shares.append(share) + shares = await asyncio.gather(*shares) + shares = [(eval_point(i), share) for i, share in enumerate(shares)] + mask = poly.interpolate_at(shares, 0) + return mask + + +AsynchromixClient Protocol (for client :math:`C_j`) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +1. Set :math:`\overline m_j := m_j + r_j` +2. :math:`\textsf{ReliableBroadcast} \; \overline m_j` +3. Wait until :math:`m_j` appears in the output of a mixing epoch + +For step 2, instead of using :math:`\textsf{ReliableBroadcast}` the +client (:class:`~apps.asynchromix.asynchromix.AsynchromixClient`) +publishes the masked message :math:`\overline m_j` onto the Ethereum +blockchain via the smart contract function :sol:func:`submit_message`. +The masked messages are stored in the +:sol:contract:`AsynchromixCoordinator` contract' state variable +:sol:svar:`input_queue`. 
+
+.. code-block:: python
+
+    class AsynchromixClient:
+
+        async def send_message(self, m):
+            # ...
+            masked_message = message + inputmask
+            masked_message_bytes = self.w3.toBytes(hexstr=hex(masked_message.value))
+            masked_message_bytes = masked_message_bytes.rjust(32, b"\x00")
+
+            # Step 4. Publish the masked input
+            tx_hash = self.contract.functions.submit_message(
+                inputmask_idx, masked_message_bytes
+            ).transact({"from": self.w3.eth.accounts[0]})
+            tx_receipt = await wait_for_receipt(self.w3, tx_hash)
+
+
+AsynchromixServer Protocol (for server :math:`P_i`)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+- Initialize for each client :math:`C_j`
+
+.. math::
+   :nowrap:
+
+   \begin{align*}
+   \textsf{input}_j & := 0 \quad \textit{ // No. of inputs received from } C_j \\
+   \textsf{done}_j & := 0 \quad \textit{ // No. of messages mixed for } C_j
+   \end{align*}
+
+.. todo:: Is there a :math:`\textsf{done}_j` state variable in the
+   code or contract?
+
+- On receiving :math:`\overline m_j` output from
+  :math:`\textsf{ReliableBroadcast}` for client :math:`C_j` at any time,
+  set :math:`\textsf{input}_j := \textsf{input}_j + 1`
+
+This step is handled by the smart contract function
+:sol:func:`submit_message`. When the client submits a masked message,
+the masked input (message) is stored in the contract's state variable
+:sol:svar:`input_queue` and the length of the input queue
+(:math:`\textsf{input}_j`) is incremented by one.
+
+.. code-block:: solidity
+
+    struct Input {
+        bytes32 masked_input;   // (m+r)
+        uint inputmask;         // index in inputmask of mask [r]
+    }
+
+    Input[] public input_queue; // All inputs sent so far
+
+    event MessageSubmitted(uint idx, uint inputmask_idx, bytes32 masked_input);
+
+    function submit_message(uint inputmask_idx, bytes32 masked_input) public {
+        // Must be authorized to use this input mask
+        require(inputmasks_claimed[inputmask_idx] == msg.sender);
+
+        uint idx = input_queue.length;
+        input_queue.length += 1;
+
+        input_queue[idx].masked_input = masked_input;
+        input_queue[idx].inputmask = inputmask_idx;
+
+        emit MessageSubmitted(idx, inputmask_idx, masked_input);
+
+        // The input masks are deactivated after first use
+        inputmasks_claimed[inputmask_idx] = address(0);
+    }
+
+- Proceed in consecutive mixing epochs :math:`e`:
+
+  **Input Collection Phase**
+
+  * Let :math:`b_i` be a :math:`|\mathcal{C}|`-bit vector where
+    :math:`b_{i,j} = 1` if :math:`\textsf{input}_j \gt
+    \textsf{done}_j`.
+  * Pass :math:`b_i` as input to an instance of
+    :math:`\textsf{CommonSubset}`.
+  * Wait to receive :math:`b` from :math:`\textsf{CommonSubset}`, where
+    :math:`b` is an :math:`n \times |\mathcal{C}|` matrix, each row of
+    :math:`b` corresponds to the input from one server, and at least
+    :math:`n - t` of the rows are non-default.
+  * Let :math:`b_{\cdot, j}` denote the column corresponding to client
+    :math:`C_j`.
+  * For each :math:`C_j`,
+
+    .. math::
+       :nowrap:
+
+       \begin{equation}
+       [\![m_j]\!] :=
+       \begin{cases}
+       \overline m_j - [\![r_j]\!] & \sum b_{\cdot,j} \geq t+1 \\
+       0 & \text{otherwise}
+       \end{cases}
+       \end{equation}
+
+  .. todo:: Explain how the contract function :sol:func:`propose_output`
+     is instead used by the servers to submit their shuffled messages
+     :math:`\pi (m_1, \ldots, m_k)` that were obtained in the MPC run
+     for the epoch.
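+
+  In the Ethereum-based implementation each server recovers its share of
+  :math:`[\![m_j]\!]` locally, by subtracting its share of the input mask
+  from the on-chain masked input. A simplified sketch of the server-side
+  computation (variable names illustrative; error handling omitted):
+
+  .. code-block:: python
+
+      # for each queued input of the epoch, on each server
+      masked_input, inputmask_idx = contract_concise.input_queue(idx)
+      masked_message = field(int.from_bytes(masked_input, "big"))
+      inputmask = self._inputmasks[inputmask_idx]  # this server's share of [r_j]
+      m_share = masked_message - inputmask         # this server's share of [m_j]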
+
+  **Online Phase**
+
+  Switch Network Option
+
+      Run the MPC Program switching-network on
+      :math:`\{[\![m_{j,k_j}]\!]\}`, resulting in
+      :math:`\pi (m_1, \ldots, m_k)`.
+      Requires :math:`k` rounds.
+
+  Powermix Option
+
+      Run the MPC Program power-mix on
+      :math:`\{[\![m_{j,k_j}]\!]\}`, resulting in
+      :math:`\pi (m_1, \ldots, m_k)`.
+
+  Set :math:`\textsf{done}_j := \textsf{done}_j + 1` for each
+  client :math:`C_j` whose input was mixed this epoch.
+
+  .. todo:: Explain *briefly* when, where (in the code), and how the
+     messages are shuffled via the switching (butterfly) network
+     in the MPC program.
+
+     Also, is there a :math:`\textsf{done}_j` state variable in the
+     code or contract?
+
+Walkthrough
+-----------
+This section presents a step-by-step walkthrough of the code involved
+in running the asynchromix example.
+
+To run the example:
+
+.. code-block:: shell
+
+    $ python apps/asynchromix/asynchromix.py
+
+
+So what happens when the above command is run?
+
+1. :py:func:`~apps.asynchromix.asynchromix.test_asynchromix` is run.
+2. :py:func:`~apps.asynchromix.asynchromix.test_asynchromix` takes care
+   of running a local test Ethereum blockchain using `Ganache`_ and of
+   starting the main loop via
+   :py:func:`~apps.asynchromix.asynchromix.run_eth()`. More precisely,
+   :py:func:`~apps.asynchromix.asynchromix.test_asynchromix` runs the
+   command:
+
+   .. code-block:: shell
+
+       ganache-cli -p 8545 -a 50 -b 1 > acctKeys.json 2>&1
+
+   in a subprocess, in a :py:func:`contextmanager` (
+   :py:func:`~apps.asynchromix.asynchromix.run_and_terminate_process`),
+   and within this context, in which Ethereum is running, the function
+   :py:func:`~apps.asynchromix.asynchromix.run_eth()` is invoked.
+3. :py:func:`~apps.asynchromix.asynchromix.run_eth()` takes care of
+   instantiating a connection to the local Ethereum node:
+
+   .. code-block:: python
+
+       w3 = Web3(HTTPProvider())
+
+   and of starting the main loop, which needs a connection to Ethereum:
+
+   .. code-block:: python
+
+       loop.run_until_complete(asyncio.gather(main_loop(w3)))
+4. The :py:func:`~apps.asynchromix.asynchromix.main_loop` takes care of
+   four main things:
+
+   1. creating a coordinator contract (and web3 interface to it);
+   2. instantiating the asynchromix servers;
+   3. instantiating an asynchromix client;
+   4. starting the servers and client and waiting for the completion of
+      their tasks.
+
+Initialization Phase
+--------------------
+.. todo:: This section's goal is to outline the basic setup
+   requirements such as:
+
+   * eth accounts creation for the MPC servers;
+   * "loading" of the contract on chain.
+
+
+
+Internal API docs
+-----------------
+
+Asynchromix Coordinator Contract
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. .. sol:contract:: AsynchromixCoordinator
+..
+..    .. sol:function:: inputmasks_available () public view returns (uint)
+..
+..       Returns the number of input masks that are available.
+
+.. autosolcontract:: AsynchromixCoordinator
+
+
+
+Asynchromix Servers
+^^^^^^^^^^^^^^^^^^^
+.. autoclass:: apps.asynchromix.asynchromix.AsynchromixServer
+
+.. automodule:: apps.asynchromix.butterfly_network
+
+Asynchromix Client
+^^^^^^^^^^^^^^^^^^
+.. autoclass:: apps.asynchromix.asynchromix.AsynchromixClient
+
+
+
+
+.. .. automodule:: apps.asynchromix.asynchromix
+
+
+Questions
+---------
+When submitting a message to Ethereum via the contract, is the
+identity of the client public? Can it be kept hidden?
+
+What about intersection attacks?
+
+
+References
+----------
+.. bibliography:: refs.bib
+
+
+..
_paper: https://eprint.iacr.org/2019/883.pdf +.. _Ganache: https://github.com/trufflesuite/ganache diff --git a/docs/integrations/refs.bib b/docs/integrations/refs.bib new file mode 100644 index 00000000..5fd7d3a7 --- /dev/null +++ b/docs/integrations/refs.bib @@ -0,0 +1,16 @@ +@inproceedings{honeybadgermpc, +author = {Lu, Donghang and Yurek, Thomas and Kulshreshtha, Samarth and Govind, Rahul and Kate, Aniket and Miller, Andrew}, +title = {"HoneyBadgerMPC and AsynchroMix: Practical Asynchronous MPC and Its Application to Anonymous Communication"}, +year = {2019}, +isbn = {9781450367479}, +publisher = {Association for Computing Machinery}, +address = {New York, NY, USA}, +url = {https://doi.org/10.1145/3319535.3354238}, +doi = {10.1145/3319535.3354238}, +booktitle = {Proceedings of the 2019 ACM SIGSAC Conference on Computer and Communications Security}, +pages = {887–903}, +numpages = {17}, +keywords = {robustness, anonymous communication, asynchronous mixing, fairness, honeybadgerMPC}, +location = {London, United Kingdom}, +series = {CCS ’19} +} diff --git a/honeybadgermpc/__init__.py b/honeybadgermpc/__init__.py index 08e09058..d24c13bd 100644 --- a/honeybadgermpc/__init__.py +++ b/honeybadgermpc/__init__.py @@ -14,4 +14,4 @@ os.makedirs(ROOT_DIR / "benchmark-logs", exist_ok=True) logging_config = yaml.safe_load(f.read()) logging.config.dictConfig(logging_config) - logging.getLogger("asyncio").setLevel(logging.WARNING) + logging.getLogger("asyncio").setLevel(logging.INFO) diff --git a/honeybadgermpc/mpc.py b/honeybadgermpc/mpc.py index 49133290..dfef2599 100644 --- a/honeybadgermpc/mpc.py +++ b/honeybadgermpc/mpc.py @@ -24,7 +24,18 @@ class Mpc(object): def __init__( - self, sid, n, t, myid, send, recv, prog, config, preproc=None, **prog_args + self, + sid, + n, + t, + myid, + send, + recv, + prog, + config, + preproc=None, + shard_id=None, + **prog_args, ): # Parameters for robust MPC # Note: tolerates min(t,N-t) crash faults @@ -38,6 +49,7 @@ def __init__( self.poly = polynomials_over(self.field) self.config = config self.preproc = preproc if preproc is not None else PreProcessedElements() + self.shard_id = shard_id # send(j, o): sends object o to party j with (current sid) # recv(): returns (j, o) from party j @@ -286,22 +298,51 @@ def __init__(self, n, t, config={}): self.loop = asyncio.get_event_loop() self.router = SimpleRouter(self.N) - def add(self, program, **kwargs): - for i in range(self.N): + def add(self, program, *, shard_id=None, **kwargs): + """Creates an :class:`Mpc` instance for the given program and + adds it to a list of tasks to run. + + Parameters + ---------- + program: + The MPC program to run. + shard_id: + The shard_id of the MPC subnetwork to run. If not `None` indicates + that we are in a sharded network. Defaults to None, i.e. non-sharded + network. + + Optional keyword arguments will be passed to the :class:`Mpc` class. 
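+
+        Examples
+        --------
+        A hypothetical sketch, assuming this runner class has been
+        instantiated as ``program_runner``::
+
+            # Non-sharded network: party ids are plain integers.
+            program_runner.add(my_prog)
+
+            # Sharded network: tag the MPC subnetwork with its shard id.
+            program_runner.add(my_prog, shard_id=3)
+
+            # Equivalently, via the convenience wrapper:
+            program_runner.add_to_shard(my_prog, shard_id=3)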
+ """ + + def f(i): + return i + + def g(i): + return f"{shard_id}:{i}" + + # get_myid = f if shard_id is None else g + + for myid in range(self.N): + # myid = f"{shard_id}:{i}" if shard_id is not None else i + # myid = get_myid(i) context = Mpc( "mpc:%d" % (self.counter,), self.N, self.t, - i, - self.router.sends[i], - self.router.recvs[i], + myid, + self.router.sends[myid], + self.router.recvs[myid], program, self.config, + shard_id=shard_id, **kwargs, ) self.tasks.append(self.loop.create_task(context._run())) self.counter += 1 + def add_to_shard(self, program, shard_id, **kwargs): + self.add(program, shard_id=shard_id, **kwargs) + async def join(self): return await asyncio.gather(*self.tasks) diff --git a/honeybadgermpc/offline_randousha.py b/honeybadgermpc/offline_randousha.py index 6e988f2d..9bd3cf04 100644 --- a/honeybadgermpc/offline_randousha.py +++ b/honeybadgermpc/offline_randousha.py @@ -232,6 +232,10 @@ async def prog(ctx): return result +async def generate_intershardmasks(): + raise NotImplementedError + + ######################## # Process runner ######################## diff --git a/honeybadgermpc/polynomial.py b/honeybadgermpc/polynomial.py index 935d05a0..19c27946 100644 --- a/honeybadgermpc/polynomial.py +++ b/honeybadgermpc/polynomial.py @@ -127,6 +127,7 @@ def evaluate_fft(self, omega, n): return fft(self, omega, n) @classmethod + # TODO def random(cls, degree, *, y0=None): def random(cls, degree, y0=None): coeffs = [field.random() for _ in range(degree + 1)] if y0 is not None: diff --git a/honeybadgermpc/preprocessing.py b/honeybadgermpc/preprocessing.py index 0f1dc413..e297ba7a 100644 --- a/honeybadgermpc/preprocessing.py +++ b/honeybadgermpc/preprocessing.py @@ -1,3 +1,10 @@ +""" +TODO: + +Consider renaming to FakePreprocessing ... + +Real preprocsseing takes place in offline_randousha +""" import asyncio import logging import os @@ -8,6 +15,7 @@ from itertools import chain from os import listdir, makedirs from os.path import isfile, join +from pathlib import Path from random import randint from shutil import rmtree from uuid import uuid4 @@ -17,6 +25,10 @@ from .ntl import vandermonde_batch_evaluate from .polynomial import polynomials_over +logger = logging.getLogger(__name__) +# FIXME move log level setting to an entrypoint (e.g. __init__ or an app main entry) +logger.setLevel(os.environ.get("HBMPC_LOGLEVEL", logging.INFO)) + class PreProcessingConstants(Enum): SHARED_DATA_DIR = "sharedata/" @@ -28,9 +40,11 @@ class PreProcessingConstants(Enum): BITS = "bits" POWERS = "powers" SHARES = "share" - ONE_MINUS_ONE = "one_minus_one" + ONE_MINUS_ONES = "one_minus_ones" DOUBLE_SHARES = "double_shares" SHARE_BITS = "share_bits" + INTERSHARD_MASKS = "intershards_masks" + SHARE_FILE_EXT = ".share" def __str__(self): return self.value @@ -73,6 +87,14 @@ def file_prefix(self): """ return f"{self.data_dir}{self.preprocessing_name}" + @property + def data_dir_path(self): + return Path(self.data_dir) + + def key(self, myid, n, t, shard_id=None): + _id = myid if shard_id is None else f"{myid}-{shard_id}" + return _id, n, t + def min_count(self, n, t): """ Returns the minimum number of preprocessing stored in the cache across all of the keys with the given n, t values. @@ -87,7 +109,7 @@ def min_count(self, n, t): return min(counts) // self._preprocessing_stride - def get_value(self, context, *args, **kwargs): + def get_value(self, context, *args, shard_id=None, **kwargs): """ Given an MPC context, retrieve one preprocessing value. 
diff --git a/honeybadgermpc/offline_randousha.py b/honeybadgermpc/offline_randousha.py
index 6e988f2d..9bd3cf04 100644
--- a/honeybadgermpc/offline_randousha.py
+++ b/honeybadgermpc/offline_randousha.py
@@ -232,6 +232,10 @@ async def prog(ctx):
     return result


+async def generate_intershard_masks():
+    raise NotImplementedError
+
+
 ########################
 # Process runner
 ########################
diff --git a/honeybadgermpc/polynomial.py b/honeybadgermpc/polynomial.py
index 935d05a0..19c27946 100644
--- a/honeybadgermpc/polynomial.py
+++ b/honeybadgermpc/polynomial.py
@@ -127,6 +127,7 @@ def evaluate_fft(self, omega, n):
         return fft(self, omega, n)

     @classmethod
+    # TODO def random(cls, degree, *, y0=None):
     def random(cls, degree, y0=None):
         coeffs = [field.random() for _ in range(degree + 1)]
         if y0 is not None:
diff --git a/honeybadgermpc/preprocessing.py b/honeybadgermpc/preprocessing.py
index 0f1dc413..e297ba7a 100644
--- a/honeybadgermpc/preprocessing.py
+++ b/honeybadgermpc/preprocessing.py
@@ -1,3 +1,10 @@
+"""
+TODO:
+
+Consider renaming to FakePreprocessing ...
+
+Real preprocessing takes place in offline_randousha.
+"""
 import asyncio
 import logging
 import os
@@ -8,6 +15,7 @@ from itertools import chain
 from os import listdir, makedirs
 from os.path import isfile, join
+from pathlib import Path
 from random import randint
 from shutil import rmtree
 from uuid import uuid4
@@ -17,6 +25,10 @@ from .ntl import vandermonde_batch_evaluate
 from .polynomial import polynomials_over

+logger = logging.getLogger(__name__)
+# FIXME move log level setting to an entrypoint (e.g. __init__ or an app main entry)
+logger.setLevel(os.environ.get("HBMPC_LOGLEVEL", logging.INFO))
+

 class PreProcessingConstants(Enum):
     SHARED_DATA_DIR = "sharedata/"
@@ -28,9 +40,11 @@ class PreProcessingConstants(Enum):
     BITS = "bits"
     POWERS = "powers"
     SHARES = "share"
-    ONE_MINUS_ONE = "one_minus_one"
+    ONE_MINUS_ONES = "one_minus_ones"
     DOUBLE_SHARES = "double_shares"
     SHARE_BITS = "share_bits"
+    INTERSHARD_MASKS = "intershard_masks"
+    SHARE_FILE_EXT = ".share"

     def __str__(self):
         return self.value
@@ -73,6 +87,14 @@ def file_prefix(self):
         """
         return f"{self.data_dir}{self.preprocessing_name}"

+    @property
+    def data_dir_path(self):
+        return Path(self.data_dir)
+
+    def key(self, myid, n, t, shard_id=None):
+        _id = myid if shard_id is None else f"{myid}-{shard_id}"
+        return _id, n, t
+
     def min_count(self, n, t):
         """ Returns the minimum number of preprocessing stored in the cache
        across all of the keys with the given n, t values.
         """
         return min(counts) // self._preprocessing_stride

-    def get_value(self, context, *args, **kwargs):
+    def get_value(self, context, *args, shard_id=None, **kwargs):
         """ Given an MPC context, retrieve one preprocessing value.

         args:
@@ -96,10 +118,14 @@ def get_value(self, context, *args, **kwargs):
         outputs:
             Preprocessing value for this mixin
         """
-        key = (context.myid, context.N, context.t)
+        key = self.key(context.myid, context.N, context.t, shard_id=shard_id)
         to_return, used = self._get_value(context, key, *args, **kwargs)
+        logger.debug(f'got value "{to_return}" and used "{used}"')
+        logger.debug(f"decrement count by {used}")
         self.count[key] -= used
+        logger.debug(f"count is now: {self.count}")
         return to_return

@@ -123,7 +149,7 @@ def _read_preprocessing_file(self, file_name):
         return values[3:]

     def _write_preprocessing_file(
-        self, file_name, degree, context_id, values, append=False
+        self, file_name, degree, context_id, values, append=False, refresh_cache=False
     ):
         """ Write the values to the preprocessing file given by the filename.
         When append is true, this will append to an existing file, otherwise, it will
@@ -148,8 +174,10 @@ def _write_preprocessing_file(
         print(*values, file=f, sep="\n")
         f.close()

+        if refresh_cache:
+            self._refresh_cache()
+
-    def build_filename(self, n, t, context_id, prefix=None):
+    def build_filename(self, n, t, context_id, prefix=None, **kwargs):
         """ Given a file prefix, and metadata, return the filename to put
         the shares in.
@@ -189,6 +217,10 @@ def _refresh_cache(self):
         """ Refreshes the cache by reading in sharedata files, and
        updating the cache values and count variables.
         """
+        logger.debug(f"(- {self.preprocessing_name} -) refreshing cache")
+        logger.debug(
+            f"(- {self.preprocessing_name} -) before cache refresh, count is: {dict(self.count)}"
+        )
         self.cache = defaultdict(chain)
         self.count = defaultdict(int)
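The new `refresh_cache` flag on `_write_preprocessing_file` above simply chains writing and cache rebuilding; a minimal sketch of the intended call pattern, where `mixin`, `values`, `n`, `t` and `myid` are placeholders:

    filename = mixin.build_filename(n, t, myid)
    # write the shares to disk, then rebuild mixin.cache / mixin.count from the files
    mixin._write_preprocessing_file(
        filename, t, myid, values, append=False, refresh_cache=True
    )
    # equivalent to passing refresh_cache=False and then calling mixin._refresh_cache()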
""" polys = [[coeff.value for coeff in poly.coeffs] for poly in polys] all_values = vandermonde_batch_evaluate( @@ -226,10 +269,13 @@ def _write_polys(self, n, t, polys, append=False, prefix=None): for i in range(n): values = [v[i] for v in all_values] - file_name = self.build_filename(n, t, i, prefix=prefix) + file_name = self.build_filename( + n, t, i, prefix=prefix, shard_id=shard_id, recv_shard_id=recv_shard_id, + ) self._write_preprocessing_file(file_name, t, i, values, append=append) - key = (i, n, t) + context_id = i if shard_id is None else f"{i}-{shard_id}" + key = (context_id, n, t) if append: self.cache[key] = chain(self.cache[key], values) self.count[key] += len(values) @@ -410,9 +456,145 @@ def _generate_polys(self, k, n, t): def _get_value(self, context, key, t=None): t = t if t is not None else context.t - assert self.count[key] >= 1 + assert self.count[key] >= 1, f"key is: {key}\ncount is: {self.count}\n" + return context.Share(next(self.cache[key]), t), 1 + + +class InterShardMasksPreProcessing(PreProcessingMixin): + preprocessing_name = PreProcessingConstants.INTERSHARD_MASKS.value + _preprocessing_stride = 1 + + def mk_subdirs(self, context_id): + """Create intermediate sub directories for the given server id, aka + context id. + + Parameters: + ----------- + context_id : str or int + The id of a server. In a sharded network it MUST be a string + of the form: ``"{myid}:{shard_id}"`` where ``myid`` is the integer + used to identify a server within a shard. + + Returns + ------- + Path object. + """ + new_path = self.data_dir_path.joinpath(context_id, self.preprocessing_name) + new_path.mkdir(exist_ok=True) + return new_path + + def build_filename( + self, n, t, context_id, prefix=None, shard_id=None, recv_shard_id=None, **kwargs + ): + """ Given context id, shard ids, and metadata, return the filename to put + the shares in. + + The filename structure is as follows: + + data_dir/context_id/preprocessing_name/n_t-sh1_sh2.share + + For example, for node 1, in shard 5, and shard 108 as a receiving shard, in a + network of n=1000 and t=300: + + sharedata/1-5/intershard_masks/1000_300-5_8.share + + The reasoning behind the above file path is that by having the context id as + a base directory it helps to remind one that the files are meant to be seen + only by the node in question. In other words, everything under the directory + "sharedata/1-5/" would be private in a real world scenario. + + Parsing this kind of file path should also be easier, such that regular + expressions will not be needed. + + Parameters + ---------- + n : int + Total number of nodes in the shard. + t : int + Polynomial degree, and maximum fault tolerance for the + number of nodes that may deviate from the protocol. + context_id : str + The id of the mpc context/server the file belongs to. + shard_id : str or int + The id of the shard the server belongs to. + recv_shard_idc: str or int + The id of the receiving shard. 
+    def generate_values(self, k, n, t, *, shard_1_id, shard_2_id, append=False):
+        polys = self._generate_polys(
+            k, n, t, shard_1_id=shard_1_id, shard_2_id=shard_2_id
+        )
+        shard_ids = {shard_1_id, shard_2_id}
+        for shard_id, shard_polys in polys.items():
+            recv_shard_id = (shard_ids - {shard_id}).pop()
+            self._write_polys(
+                n,
+                t,
+                shard_polys,
+                append=append,
+                shard_id=shard_id,
+                recv_shard_id=recv_shard_id,
+            )
+
+    def _generate_polys(self, k, n, t, *, shard_1_id, shard_2_id):
+        """Return a pair of polys for each of the k values. The second poly
+        of a pair is generated with its constant term (y0) set to that of
+        the first, so that both shards hold shares of the same mask value.
+
+        .. note:: negligible chance that coeffs is empty
+        """
+        polys = defaultdict(list)
+        for _ in range(k):
+            poly_1 = self.poly.random(t)
+            poly_2 = self.poly.random(t, y0=poly_1.coeffs[0])
+            polys[shard_1_id].append(poly_1)
+            polys[shard_2_id].append(poly_2)
+        return polys
+
+    def _get_value(self, context, key, t=None):
+        t = t if t is not None else context.t
+        assert self.count[key] >= 1, f"key is: {key}\ncount is: {self.count}\n"
+        return context.Share(next(self.cache[key]), t), 1
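The correctness of `_generate_polys` rests on the two polynomials of a pair sharing their constant term; a minimal sketch of that invariant, assuming a `field` and a threshold `t`:

    from honeybadgermpc.polynomial import polynomials_over

    poly = polynomials_over(field)
    p1 = poly.random(t)
    p2 = poly.random(t, y0=p1.coeffs[0])
    # both polynomials hide the same mask in their constant term, so shard 1's
    # shares of p1 and shard 2's shares of p2 open to the same value
    assert p1.coeffs[0] == p2.coeffs[0]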
+    def _refresh_cache(self):
+        """ Refreshes the cache by reading in sharedata files, and
+        updating the cache values and count variables.
+        """
+        logger.debug(f"(- {self.preprocessing_name} -) refreshing cache")
+        logger.debug(
+            f"(- {self.preprocessing_name} -) before cache refresh, count is: {dict(self.count)}"
+        )
+        self.cache = defaultdict(chain)
+        self.count = defaultdict(int)
+
+        for server_dir in self.data_dir_path.iterdir():
+            if server_dir.is_file():
+                continue
+            context_id = server_dir.stem
+            for pp_elements_dir in server_dir.iterdir():
+                for share_file in pp_elements_dir.iterdir():
+                    # expected filename format is: "n_t-sh1_sh2.share"
+                    n, t = (int(v) for v in share_file.stem.split("-")[0].split("_"))
+                    key = (context_id, n, t)
+                    values = self._read_preprocessing_file(share_file)
+
+                    self.cache[key] = chain(values)
+                    self.count[key] = len(values)
+
+        logger.debug(
+            f"(- {self.preprocessing_name} -) after cache refresh, count is: {dict(self.count)}"
+        )
+

 class SimplePreProcessing(PreProcessingMixin):
     """ Subclass of PreProcessingMixin to be used in the trivial case
@@ -426,9 +608,12 @@ def _get_value(self, context, key):
         assert self.count[key] >= self._preprocessing_stride, (
             f"Expected "
             f"{self._preprocessing_stride} elements of {self.preprocessing_name}, "
-            f"but found only {self.count[key]}"
+            f"but found only {self.count[key]}\n"
+            f"key is: {key}\n"
+            f"count is: {self.count}\n"
         )

+        logger.debug("getting value ...")
         values = tuple(
             context.Share(next(self.cache[key]))
             for _ in range(self._preprocessing_stride)
@@ -487,7 +672,7 @@ def _generate_polys(self, k, n, t):


 class SignedBitPreProcessing(SimplePreProcessing):
-    preprocessing_name = PreProcessingConstants.ONE_MINUS_ONE.value
+    preprocessing_name = PreProcessingConstants.ONE_MINUS_ONES.value
     _preprocessing_stride = 1

     def _generate_polys(self, k, n, t):
@@ -557,6 +742,9 @@ def __init__(self, append=True, data_directory=None, field=None):
         self._share_bits = ShareBitsPreProcessing(
             self.field, self.poly, self.data_directory
         )
+        self._intershard_masks = InterShardMasksPreProcessing(
+            self.field, self.poly, self.data_directory
+        )

     @classmethod
     def reset_cache(cls):
@@ -572,11 +760,12 @@ def _init_data_dir(self):

     def clear_preprocessing(self):
         """ Delete all things from the preprocessing folder """
+        logger.debug(
+            f"Deleting all files from preprocessing folder: {self.data_directory}"
+        )
         rmtree(
             self.data_directory,
-            onerror=lambda f, p, e: logging.debug(
-                f"Error deleting data directory: {e}"
-            ),
+            onerror=lambda f, p, e: logger.debug(f"Error deleting data directory: {e}"),
         )
         self._init_data_dir()

@@ -585,7 +774,7 @@ async def wait_for_preprocessing(self, timeout=1):
         """ Block until the ready file is created """
         while not os.path.exists(self._ready_file):
-            logging.info(f"waiting for preprocessing {self._ready_file}")
+            logger.debug(f"waiting for preprocessing {self._ready_file}")
             await asyncio.sleep(timeout)

     def preprocessing_done(self):
@@ -635,6 +824,16 @@ def generate_powers(self, k, n, t, z):

     def generate_share(self, n, t, *args, **kwargs):
         return self._generate(self._shares, 1, n, t, *args, **kwargs)

+    def generate_intershard_masks(self, k, n, t, *, shard_1_id, shard_2_id):
+        return self._generate(
+            self._intershard_masks,
+            k,
+            n,
+            t,
+            shard_1_id=shard_1_id,
+            shard_2_id=shard_2_id,
+        )
+
     ## Preprocessing retrieval methods:

     def get_triples(self, context):
@@ -666,3 +865,6 @@ def get_double_shares(self, context):

     def get_share_bits(self, context):
         return self._share_bits.get_value(context)
+
+    def get_intershard_masks(self, context, shard_id):
+        return self._intershard_masks.get_value(context, shard_id=shard_id)
diff --git a/setup.py b/setup.py
index e4ecafaa..fa611b85 100644
--- a/setup.py
+++ b/setup.py
@@ -13,7 +13,7 @@
 REQUIRES_PYTHON = ">=3.7.0"
 VERSION = None

-REQUIRED = ["gmpy2", "zfec", "pycrypto", "cffi", "psutil", "pyzmq"]
+REQUIRED = ["aiohttp", "gmpy2", "zfec", "pycrypto", "cffi", "psutil", "pyzmq", "toml"]

 TESTS_REQUIRES = [
     "black",
@@ -36,13 +36,15 @@
 DOCS_REQUIRE = [
     "Sphinx",
     "sphinx-autobuild",
+    "sphinxcontrib-bibtex",
+    "sphinxcontrib-soliditydomain",
     "sphinx_rtd_theme",
     "sphinx_tabs",
     "m2r",
     "doc8",
 ]

-ETH_REQUIRES = ["web3", "ethereum"]
+ETH_REQUIRES = ["bitcoin", "web3", "ethereum"]

 AWS_REQUIRES = ["boto3", "paramiko"]
diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py
index f7e91666..00fb6b21 100644
--- a/tests/test_preprocessing.py
+++ b/tests/test_preprocessing.py
@@ -3,7 +3,7 @@
 from pytest import mark

 from honeybadgermpc.mpc import TaskProgramRunner
-from honeybadgermpc.preprocessing import PreProcessedElements
+from honeybadgermpc.preprocessing import PreProcessedElements, PreProcessingConstants


 @mark.asyncio
@@ -168,3 +168,91 @@ async def _prog(ctx):
     program_runner = TaskProgramRunner(n, t)
     program_runner.add(_prog)
     await program_runner.join()
+
+
+@mark.asyncio
+async def test_get_intershard_masks():
+    k, n, t = 100, 4, 1
+    shards = (3, 8)
+    pp_elements = PreProcessedElements()
+    pp_elements.generate_intershard_masks(
+        k, n, t, shard_1_id=shards[0], shard_2_id=shards[1]
+    )
+    intershard_masks = pp_elements._intershard_masks
+    # check that all masks are there
+    assert all(
+        (f"{i}-{s}", n, t) in intershard_masks.count for i in range(n) for s in shards
+    )
+    num_masks = 2
+    masks_3 = []
+    masks_8 = []
+
+    # TODO
+    # * simplify the two progs; if possible, use a single def parametrized
+    #   with the shard id
+    # * also: can the shard be accessed via the ctx object instead? The
+    #   information seems redundant: if the ctx has access to the shard id,
+    #   then there is perhaps no need to pass it to the method
+    #   `get_intershard_masks()`
+    async def _prog3(ctx):
+        for _ in range(num_masks):
+            mask_share = ctx.preproc.get_intershard_masks(ctx, shards[0])
+            mask = await mask_share.open()
+            masks_3.append(mask)
+
+    async def _prog8(ctx):
+        for _ in range(num_masks):
+            mask_share = ctx.preproc.get_intershard_masks(ctx, shards[1])
+            mask = await mask_share.open()
+            masks_8.append(mask)
+
+    program_runner = TaskProgramRunner(n, t)
+    program_runner.add(_prog3, shard_id=shards[0])
+    await program_runner.join()
+    # join() gathers all tasks added so far, including the finished _prog3 runs
+    program_runner.add(_prog8, shard_id=shards[1])
+    await program_runner.join()
+    print(f"\nmasks for shard 3: {masks_3}")
+    print(f"len of masks: {len(masks_3)}")
+    print(f"\nmasks for shard 8: {masks_8}")
+    print(f"len of masks: {len(masks_8)}")
+    assert masks_3 == masks_8
+
+
+def test_generate_intershard_masks():
+    k, n, t = 100, 4, 1
+    shards = (3, 8)
+    pp_elements = PreProcessedElements()
+    pp_elements.generate_intershard_masks(
+        k, n, t, shard_1_id=shards[0], shard_2_id=shards[1]
+    )
+    intershard_masks = pp_elements._intershard_masks
+    # check the cache and count
+    cache = intershard_masks.cache
+    count = intershard_masks.count
+    assert len(cache) == 2 * n  # there are 2 shards with n servers in each
+    # Check that the cache contains all expected keys. A key is a 3-tuple made
+    # from (context_id, n, t). The context_id is made from "{i}-{shard_id}".
+    assert all((f"{i}-{s}", n, t) in cache for i in range(n) for s in shards)
+    # note: len(tuple(...)) consumes the cached chain iterators
+    assert all(len(tuple(elements)) == k for elements in cache.values())
+    assert all(c == k for c in count.values())
+    assert all((f"{i}-{s}", n, t) in count for i in range(n) for s in shards)
+    # check all the expected files have been created
+    data_dir_path = intershard_masks.data_dir_path
+    for shard_index, shard_id in enumerate(shards):
+        other_shard = shards[1 - shard_index]
+        for node_id in range(n):
+            node_path = data_dir_path.joinpath(f"{node_id}-{shard_id}")
+            assert node_path.exists()
+            csm_path = node_path.joinpath(intershard_masks.preprocessing_name)
+            assert csm_path.exists()
+            file_path = csm_path.joinpath(
+                f"{n}_{t}-{shard_id}_{other_shard}"
+            ).with_suffix(PreProcessingConstants.SHARE_FILE_EXT.value)
+            assert file_path.exists()
+            with file_path.open() as f:
+                _lines = f.readlines()
+                lines = [int(line) for line in _lines]
+            assert len(lines) == 3 + k  # header: modulus, degree t, context id; then k masks
+            assert lines[0] == intershard_masks.field.modulus
+            assert lines[1] == t
+            assert lines[2] == node_id
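For reference, the on-disk layout these last assertions pin down, matching the three header lines that `_read_preprocessing_file` skips with `values[3:]`; a small illustrative reader (not part of the patch):

    def read_share_file(path):
        # header: field modulus, degree t, context id; then one share value per line
        with open(path) as f:
            lines = [int(line) for line in f]
        modulus, degree, context_id = lines[:3]
        return modulus, degree, context_id, lines[3:]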