Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: depdency upgrades and web3.py v7 support #52

Merged
merged 4 commits into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 51 additions & 39 deletions ledgereth/web3.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
"""Web3.py middleware for Ledger devices."""

from typing import Any

# Some of the following imports utilize web3.py deps that are not deps of
# ledgereth.
from eth_account.messages import encode_structured_data
from eth_account.messages import encode_typed_data
from eth_utils import decode_hex, encode_hex
from rlp import encode
from web3.middleware import Web3Middleware
from web3.types import MakeRequestFn, RPCEndpoint, RPCResponse

from ledgereth.accounts import find_account, get_accounts
from ledgereth.messages import sign_message, sign_typed_data_draft
Expand Down Expand Up @@ -39,7 +45,15 @@
"""


class LedgerSignerMiddleware:
def _make_response(result: Any) -> RPCResponse:
return {
"jsonrpc": "2.0",
"id": 1337,
"result": result,
}


class LedgerSignerMiddleware(Web3Middleware):
"""Web3.py middleware. It will automatically intercept the relevant
JSON-RPC calls and respond with data from your Ledger device.

Expand All @@ -64,33 +78,33 @@ class LedgerSignerMiddleware:

_dongle = None

def __init__(self, make_request, w3):
self.w3 = w3
self.make_request = make_request
def wrap_make_request(self, make_request: MakeRequestFn):
"""Intercept some JSON-RPC requests and forward them to the Ledger device."""

def __call__(self, method, params):
if method == "eth_sendTransaction":
return self._handle_eth_sendTransaction(method, params)
def middleware(method: RPCEndpoint, params: Any) -> RPCResponse:
if method == "eth_sendTransaction":
return self._handle_eth_sendTransaction(params, make_request)
elif method == "eth_accounts":
return self._handle_eth_accounts(params)
elif method == "eth_sign":
return self._handle_eth_sign(params)
elif method == "eth_signTypedData":
return self._handle_eth_signTypedData(params)

elif method == "eth_accounts":
return self._handle_eth_accounts(method, params)
# Send on to the next middleware(s)
return make_request(method, params)

elif method == "eth_sign":
return self._handle_eth_sign(method, params)
return middleware

elif method == "eth_signTypedData":
return self._handle_eth_signTypedData(method, params)

# Send on to the next middleware(s)
return self.make_request(method, params)

def _handle_eth_accounts(self, method, params):
def _handle_eth_accounts(self, _: Any) -> RPCResponse:
"""Handler for eth_accounts RPC calls"""
return {
"result": list(map(lambda a: a.address, get_accounts(dongle=self._dongle))),
}
return _make_response(
list(map(lambda a: a.address, get_accounts(dongle=self._dongle)))
)

def _handle_eth_sendTransaction(self, method, params):
def _handle_eth_sendTransaction(
self, params: Any, make_request: MakeRequestFn
) -> RPCResponse:
"""Handler for eth_sendTransaction RPC calls"""
new_params = []

Expand Down Expand Up @@ -121,21 +135,23 @@ def _handle_eth_sendTransaction(self, method, params):
raise Exception(f"Account {sender_address} not found")

if nonce is None:
nonce = self.w3.eth.get_transaction_count(sender_address)
nonce = self._w3.eth.get_transaction_count(sender_address)

if "accessList" in tx_obj:
access_list = decode_web3_access_list(tx_obj["accessList"])

signed_tx = create_transaction(
chain_id=self.w3.eth.chain_id,
chain_id=self._w3.eth.chain_id,
destination=tx_obj.get("to"),
amount=int(value, 16),
gas=int(gas, 16),
gas_price=int(gas_price, 16) if gas_price else None,
max_fee_per_gas=int(max_fee_per_gas, 16) if max_fee_per_gas else None,
max_priority_fee_per_gas=int(max_priority_fee_per_gas, 16)
if max_priority_fee_per_gas
else None,
max_priority_fee_per_gas=(
int(max_priority_fee_per_gas, 16)
if max_priority_fee_per_gas
else None
),
nonce=nonce,
data=tx_obj.get("data", b""),
sender_path=sender_account.path,
Expand All @@ -146,12 +162,12 @@ def _handle_eth_sendTransaction(self, method, params):
new_params.append(signed_tx.rawTransaction)

# Change to raw tx call
method = "eth_sendRawTransaction"
method: RPCEndpoint = "eth_sendRawTransaction"
params = new_params

return self.make_request(method, params)
return make_request(method, params)

def _handle_eth_sign(self, mehtod, params):
def _handle_eth_sign(self, params: Any) -> RPCResponse:
"""Handler for eth_sign RPC calls"""
if len(params) != 2:
raise ValueError("Unexpected RPC request params length for eth_sign")
Expand All @@ -162,11 +178,9 @@ def _handle_eth_sign(self, mehtod, params):
signer_account = find_account(account, dongle=self._dongle)
signed = sign_message(message, signer_account.path, dongle=self._dongle)

return {
"result": signed.signature,
}
return _make_response(signed.signature)

def _handle_eth_signTypedData(self, mehtod, params):
def _handle_eth_signTypedData(self, params: Any) -> RPCResponse:
"""Handler for eth_signTypedData RPC calls"""
if len(params) != 2:
raise ValueError("Unexpected RPC request params length for eth_sign")
Expand All @@ -180,7 +194,7 @@ def _handle_eth_signTypedData(self, mehtod, params):
)

# Use eth_account to encode and hash the typed data
signable = encode_structured_data(typed_data)
signable = encode_typed_data(full_message=typed_data)
domain_hash = signable.header
message_hash = signable.body

Expand All @@ -190,6 +204,4 @@ def _handle_eth_signTypedData(self, mehtod, params):
domain_hash, message_hash, signer_account.path, dongle=self._dongle
)

return {
"result": signed.signature,
}
return _make_response(signed.signature)
6 changes: 4 additions & 2 deletions requirements.dev.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
autoflake~=1.4
black>=21.11b1
bump2version~=1.0.1
eth-account>=0.8.0
eth-account>=0.13.1
isort>=5.10.1
mypy>=0.910
pytest>=5.3.2
setuptools>=44.0.0
twine>=3.1.1
web3[tester]~=6.2.0
web3[tester]~=7.3.1
# peer dep of eth-tester of web3[tester]
#eth-account==0.12.1
wheel>=0.33.6
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
eth-utils>=2.1.0,<3.0.0
eth-utils>=2.1.0,<6
ledgerblue==0.1.48
rlp~=3.0.0
rlp~=4.0.1
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest
from eth_account.account import Account
from eth_account.messages import SignableMessage, encode_defunct, encode_structured_data
from eth_account.messages import SignableMessage, encode_defunct
from eth_utils import decode_hex, encode_hex
from hexbytes import HexBytes
from ledgerblue.comm import getDongle
Expand Down Expand Up @@ -195,7 +195,7 @@ def exchange(self, adpu, timeout=20000):
raise self.exception

def close(self):
...
pass


def getMockDongle():
Expand Down
1 change: 1 addition & 0 deletions tests/test_chain_id.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test signing for multiple chains"""

import pytest
from eth_account import Account

Expand Down
1 change: 1 addition & 0 deletions tests/test_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
approve all the transactions when testing.. Might work with mock dongle if
that ever gets done.
"""

import binascii
import os
import re
Expand Down
1 change: 1 addition & 0 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Test that exceptions are translated/rendered correctly """

import pytest
from ledgerblue.commException import CommException

Expand Down
5 changes: 3 additions & 2 deletions tests/test_message_signing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""
Test higher level message signing functionality
"""

from eth_account import Account
from eth_account.messages import encode_defunct, encode_structured_data
from eth_account.messages import encode_defunct, encode_typed_data
from eth_utils import decode_hex, encode_hex

from ledgereth.accounts import get_accounts
Expand Down Expand Up @@ -54,7 +55,7 @@ def test_sign_large_message(yield_dongle):

def test_sign_typed_data(yield_dongle):
"""Test signing an EIP-712 typed data"""
signable = encode_structured_data(eip712_dict)
signable = encode_typed_data(full_message=eip712_dict)

# header/body is eth_account naming, presumably to be generic
domain_hash = signable.header
Expand Down
1 change: 1 addition & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Test objects and serialization
"""

from eth_utils import decode_hex, is_checksum_address

from ledgereth.constants import DEFAULT_CHAIN_ID, DEFAULTS
Expand Down
1 change: 1 addition & 0 deletions tests/test_transactions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Test higher level transaction functionality
"""

from eth_account import Account
from eth_utils import decode_hex

Expand Down
4 changes: 2 additions & 2 deletions tests/test_web3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from eth_account import Account
from eth_account.messages import encode_defunct, encode_structured_data
from eth_account.messages import encode_defunct, encode_typed_data
from eth_utils import encode_hex
from web3 import Web3
from web3.datastructures import AttributeDict
Expand Down Expand Up @@ -278,7 +278,7 @@ def test_web3_middleware_sign_text(yield_dongle):
def test_web3_middleware_sign_typed_data(yield_dongle):
"""Test LedgerSignerMiddleware EIP-712 typed data signing"""

signable = encode_structured_data(eip712_dict)
signable = encode_typed_data(full_message=eip712_dict)
provider = EthereumTesterProvider()
web3 = Web3(provider)
clean_web3 = Web3(provider)
Expand Down
Loading