Skip to content

Commit

Permalink
Merge pull request #52 from mikeshultz/chore/updates
Browse files Browse the repository at this point in the history
chore: depdency upgrades and web3.py v7 support
  • Loading branch information
mikeshultz authored Oct 19, 2024
2 parents 77d3e7b + d6c90f0 commit a82f378
Show file tree
Hide file tree
Showing 11 changed files with 69 additions and 49 deletions.
90 changes: 51 additions & 39 deletions ledgereth/web3.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
"""Web3.py middleware for Ledger devices."""

from typing import Any

# Some of the following imports utilize web3.py deps that are not deps of
# ledgereth.
from eth_account.messages import encode_structured_data
from eth_account.messages import encode_typed_data
from eth_utils import decode_hex, encode_hex
from rlp import encode
from web3.middleware import Web3Middleware
from web3.types import MakeRequestFn, RPCEndpoint, RPCResponse

from ledgereth.accounts import find_account, get_accounts
from ledgereth.messages import sign_message, sign_typed_data_draft
Expand Down Expand Up @@ -39,7 +45,15 @@
"""


class LedgerSignerMiddleware:
def _make_response(result: Any) -> RPCResponse:
return {
"jsonrpc": "2.0",
"id": 1337,
"result": result,
}


class LedgerSignerMiddleware(Web3Middleware):
"""Web3.py middleware. It will automatically intercept the relevant
JSON-RPC calls and respond with data from your Ledger device.
Expand All @@ -64,33 +78,33 @@ class LedgerSignerMiddleware:

_dongle = None

def __init__(self, make_request, w3):
self.w3 = w3
self.make_request = make_request
def wrap_make_request(self, make_request: MakeRequestFn):
"""Intercept some JSON-RPC requests and forward them to the Ledger device."""

def __call__(self, method, params):
if method == "eth_sendTransaction":
return self._handle_eth_sendTransaction(method, params)
def middleware(method: RPCEndpoint, params: Any) -> RPCResponse:
if method == "eth_sendTransaction":
return self._handle_eth_sendTransaction(params, make_request)
elif method == "eth_accounts":
return self._handle_eth_accounts(params)
elif method == "eth_sign":
return self._handle_eth_sign(params)
elif method == "eth_signTypedData":
return self._handle_eth_signTypedData(params)

elif method == "eth_accounts":
return self._handle_eth_accounts(method, params)
# Send on to the next middleware(s)
return make_request(method, params)

elif method == "eth_sign":
return self._handle_eth_sign(method, params)
return middleware

elif method == "eth_signTypedData":
return self._handle_eth_signTypedData(method, params)

# Send on to the next middleware(s)
return self.make_request(method, params)

def _handle_eth_accounts(self, method, params):
def _handle_eth_accounts(self, _: Any) -> RPCResponse:
"""Handler for eth_accounts RPC calls"""
return {
"result": list(map(lambda a: a.address, get_accounts(dongle=self._dongle))),
}
return _make_response(
list(map(lambda a: a.address, get_accounts(dongle=self._dongle)))
)

def _handle_eth_sendTransaction(self, method, params):
def _handle_eth_sendTransaction(
self, params: Any, make_request: MakeRequestFn
) -> RPCResponse:
"""Handler for eth_sendTransaction RPC calls"""
new_params = []

Expand Down Expand Up @@ -121,21 +135,23 @@ def _handle_eth_sendTransaction(self, method, params):
raise Exception(f"Account {sender_address} not found")

if nonce is None:
nonce = self.w3.eth.get_transaction_count(sender_address)
nonce = self._w3.eth.get_transaction_count(sender_address)

if "accessList" in tx_obj:
access_list = decode_web3_access_list(tx_obj["accessList"])

signed_tx = create_transaction(
chain_id=self.w3.eth.chain_id,
chain_id=self._w3.eth.chain_id,
destination=tx_obj.get("to"),
amount=int(value, 16),
gas=int(gas, 16),
gas_price=int(gas_price, 16) if gas_price else None,
max_fee_per_gas=int(max_fee_per_gas, 16) if max_fee_per_gas else None,
max_priority_fee_per_gas=int(max_priority_fee_per_gas, 16)
if max_priority_fee_per_gas
else None,
max_priority_fee_per_gas=(
int(max_priority_fee_per_gas, 16)
if max_priority_fee_per_gas
else None
),
nonce=nonce,
data=tx_obj.get("data", b""),
sender_path=sender_account.path,
Expand All @@ -146,12 +162,12 @@ def _handle_eth_sendTransaction(self, method, params):
new_params.append(signed_tx.rawTransaction)

# Change to raw tx call
method = "eth_sendRawTransaction"
method: RPCEndpoint = "eth_sendRawTransaction"
params = new_params

return self.make_request(method, params)
return make_request(method, params)

def _handle_eth_sign(self, mehtod, params):
def _handle_eth_sign(self, params: Any) -> RPCResponse:
"""Handler for eth_sign RPC calls"""
if len(params) != 2:
raise ValueError("Unexpected RPC request params length for eth_sign")
Expand All @@ -162,11 +178,9 @@ def _handle_eth_sign(self, mehtod, params):
signer_account = find_account(account, dongle=self._dongle)
signed = sign_message(message, signer_account.path, dongle=self._dongle)

return {
"result": signed.signature,
}
return _make_response(signed.signature)

def _handle_eth_signTypedData(self, mehtod, params):
def _handle_eth_signTypedData(self, params: Any) -> RPCResponse:
"""Handler for eth_signTypedData RPC calls"""
if len(params) != 2:
raise ValueError("Unexpected RPC request params length for eth_sign")
Expand All @@ -180,7 +194,7 @@ def _handle_eth_signTypedData(self, mehtod, params):
)

# Use eth_account to encode and hash the typed data
signable = encode_structured_data(typed_data)
signable = encode_typed_data(full_message=typed_data)
domain_hash = signable.header
message_hash = signable.body

Expand All @@ -190,6 +204,4 @@ def _handle_eth_signTypedData(self, mehtod, params):
domain_hash, message_hash, signer_account.path, dongle=self._dongle
)

return {
"result": signed.signature,
}
return _make_response(signed.signature)
6 changes: 4 additions & 2 deletions requirements.dev.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
autoflake~=1.4
black>=21.11b1
bump2version~=1.0.1
eth-account>=0.8.0
eth-account>=0.13.1
isort>=5.10.1
mypy>=0.910
pytest>=5.3.2
setuptools>=44.0.0
twine>=3.1.1
web3[tester]~=6.2.0
web3[tester]~=7.3.1
# peer dep of eth-tester of web3[tester]
#eth-account==0.12.1
wheel>=0.33.6
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
eth-utils>=2.1.0,<3.0.0
eth-utils>=2.1.0,<6
ledgerblue==0.1.48
rlp~=3.0.0
rlp~=4.0.1
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest
from eth_account.account import Account
from eth_account.messages import SignableMessage, encode_defunct, encode_structured_data
from eth_account.messages import SignableMessage, encode_defunct
from eth_utils import decode_hex, encode_hex
from hexbytes import HexBytes
from ledgerblue.comm import getDongle
Expand Down Expand Up @@ -195,7 +195,7 @@ def exchange(self, adpu, timeout=20000):
raise self.exception

def close(self):
...
pass


def getMockDongle():
Expand Down
1 change: 1 addition & 0 deletions tests/test_chain_id.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test signing for multiple chains"""

import pytest
from eth_account import Account

Expand Down
1 change: 1 addition & 0 deletions tests/test_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
approve all the transactions when testing.. Might work with mock dongle if
that ever gets done.
"""

import binascii
import os
import re
Expand Down
1 change: 1 addition & 0 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Test that exceptions are translated/rendered correctly """

import pytest
from ledgerblue.commException import CommException

Expand Down
5 changes: 3 additions & 2 deletions tests/test_message_signing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""
Test higher level message signing functionality
"""

from eth_account import Account
from eth_account.messages import encode_defunct, encode_structured_data
from eth_account.messages import encode_defunct, encode_typed_data
from eth_utils import decode_hex, encode_hex

from ledgereth.accounts import get_accounts
Expand Down Expand Up @@ -54,7 +55,7 @@ def test_sign_large_message(yield_dongle):

def test_sign_typed_data(yield_dongle):
"""Test signing an EIP-712 typed data"""
signable = encode_structured_data(eip712_dict)
signable = encode_typed_data(full_message=eip712_dict)

# header/body is eth_account naming, presumably to be generic
domain_hash = signable.header
Expand Down
1 change: 1 addition & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Test objects and serialization
"""

from eth_utils import decode_hex, is_checksum_address

from ledgereth.constants import DEFAULT_CHAIN_ID, DEFAULTS
Expand Down
1 change: 1 addition & 0 deletions tests/test_transactions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Test higher level transaction functionality
"""

from eth_account import Account
from eth_utils import decode_hex

Expand Down
4 changes: 2 additions & 2 deletions tests/test_web3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from eth_account import Account
from eth_account.messages import encode_defunct, encode_structured_data
from eth_account.messages import encode_defunct, encode_typed_data
from eth_utils import encode_hex
from web3 import Web3
from web3.datastructures import AttributeDict
Expand Down Expand Up @@ -278,7 +278,7 @@ def test_web3_middleware_sign_text(yield_dongle):
def test_web3_middleware_sign_typed_data(yield_dongle):
"""Test LedgerSignerMiddleware EIP-712 typed data signing"""

signable = encode_structured_data(eip712_dict)
signable = encode_typed_data(full_message=eip712_dict)
provider = EthereumTesterProvider()
web3 = Web3(provider)
clean_web3 = Web3(provider)
Expand Down

0 comments on commit a82f378

Please sign in to comment.