diff --git a/src/main.py b/src/main.py index 3ed909f..fa6ed5e 100644 --- a/src/main.py +++ b/src/main.py @@ -30,9 +30,31 @@ def parse_arguments(): # TODO pokracovani - block_transactions = [COINBASE_TRANSACTION] + mempool.valid_transactions - - transaction_hashes = [calculate_txid(COINBASE_TRANSACTION)] + [calculate_txid(json_transaction) for json_transaction in block_transactions[1:]] + + #block_transactions = [COINBASE_TRANSACTION] + mempool.valid_transactions + + # Initialize an empty list for transactions + block_transactions = [COINBASE_TRANSACTION] + + # Initialize total weight and total fees + total_weight = 0 + total_fees = 0 + + # Set the maximum block weight + max_block_weight = 4000000 + + # Sort the transactions by the fee in descending order + transactions_sorted_by_fee = sorted(mempool.valid_transactions, key=lambda tx: tx.fee, reverse=True) + + for tx in transactions_sorted_by_fee: + tx_weight = tx.calculate_weight() + if total_weight + tx_weight > max_block_weight: + break + block_transactions.append(tx) + total_weight = total_weight + tx_weight + total_fees = total_fees + tx.fee + + transaction_hashes = [calculate_txid(COINBASE_TRANSACTION, True)] + [calculate_txid(transaction.json_transaction, transaction.has_witness) for transaction in block_transactions[1:]] block_hash = block_mining(transaction_hashes).hex() wtxids = ["0000000000000000000000000000000000000000000000000000000000000000"] + transaction_hashes[1:] diff --git a/src/mempool.py b/src/mempool.py index 1231f62..fa11785 100644 --- a/src/mempool.py +++ b/src/mempool.py @@ -7,4 +7,4 @@ def __init__(self, root_dir): self.root_dir = root_dir self.transaction_files = [os.path.join(self.root_dir, file) for file in os.listdir(self.root_dir) if file.endswith('.json')] self.transactions = [Transaction(file) for file in self.transaction_files] - self.valid_transactions = [transaction.json_transaction for transaction in self.transactions if transaction.is_valid()] + self.valid_transactions = [transaction for transaction in self.transactions if transaction.is_valid()] diff --git a/src/serialize.py b/src/serialize.py index db84e33..d3bac3c 100644 --- a/src/serialize.py +++ b/src/serialize.py @@ -82,6 +82,9 @@ def serialize_transaction(transaction, index=-1, sighash_type=1, segwit=False): # witness if segwit: for tx_in in inputs: + if "witness" not in tx_in: + break + out += [encode_varint(len(tx_in["witness"]))] for item in tx_in["witness"]: diff --git a/src/transaction.py b/src/transaction.py index 4f24bbf..f93a7b1 100644 --- a/src/transaction.py +++ b/src/transaction.py @@ -10,12 +10,9 @@ from src.verify import parse_der_signature_bytes, valid_transaction_syntax -def calculate_txid(transaction_content, coinbase=False): +def calculate_txid(transaction_content, segwit=False): # Serialize the transaction content - if coinbase: - serialized_transaction = serialize_transaction(transaction_content, segwit=True) #json.dumps(transaction_content, sort_keys=True).encode() - else: - serialized_transaction = serialize_transaction(transaction_content) #json.dumps(transaction_content, sort_keys=True).encode() + serialized_transaction = serialize_transaction(transaction_content, segwit=segwit) # Calculate double SHA-256 hash hash_result = hashlib.sha256(hashlib.sha256(serialized_transaction).digest()).digest() @@ -39,6 +36,8 @@ def __init__(self, transaction_json_file): self.vin = json_transaction['vin'] self.vout = json_transaction['vout'] self.json_transaction = json_transaction + self.fee = 0 + self.has_witness = False else: # TODO jestli nejakej error print('Invalid transaction syntax') @@ -90,12 +89,20 @@ def check_input_output_sum(self): for output in self.vout: output_sum = output_sum + output['value'] + self.fee = input_sum - output_sum + # Output sum can't be greater than the input sum. if input_sum < output_sum: return False return True + def calculate_weight(self): + base_size = len(serialize_transaction(self.json_transaction)) + total_size = len(serialize_transaction(self.json_transaction, segwit=self.has_witness)) + + return int(base_size * 3 + total_size) + def valid_input(self, vin_idx, vin): if vin.get("is_coinbase", False): return False @@ -109,12 +116,15 @@ def valid_input(self, vin_idx, vin): #return self.validate_p2sh(vin_idx, vin) pass elif scriptpubkey_type == "v0_p2wsh": + self.has_witness = True pass elif scriptpubkey_type == "v1_p2tr": pass elif scriptpubkey_type == "v0_p2wpkh": #return self.validate_p2wpkh(vin_idx, vin) - pass + #pass + self.has_witness = True + return True # Unknown script type. return False