tAdd address and header sync to protocol tests - electrum-personal-server - Maximally lightweight electrum server for a single user
git clone https://git.parazyd.org/electrum-personal-server
Log
Files
Refs
README
---
commit 76cbd8a2c618bc1b2fe72770fad76a2f324b5a8d
parent b38006736d604f06e3feb9e988bf20362e382c69
Author: chris-belcher 
Date:   Sat, 28 Dec 2019 14:51:16 +0000

Add address and header sync to protocol tests

Diffstat:
  M electrumpersonalserver/server/comm… |      18 +++++++++++++++---
  M electrumpersonalserver/server/elec… |      20 ++++++--------------
  M test/test_electrum_protocol.py      |     167 ++++++++++++++++++++++++-------

3 files changed, 152 insertions(+), 53 deletions(-)
---
diff --git a/electrumpersonalserver/server/common.py b/electrumpersonalserver/server/common.py
t@@ -7,6 +7,7 @@ import os.path
 import logging
 import tempfile
 import platform
+import json
 from configparser import RawConfigParser, NoSectionError, NoOptionError
 from ipaddress import ip_network, ip_address
 
t@@ -126,9 +127,14 @@ def run_electrum_server(rpc, txmonitor, config):
                 except (ConnectionRefusedError, ssl.SSLError):
                     sock.close()
                     sock = None
-
             logger.info('Electrum connected from ' + str(addr[0]))
-            protocol.set_send_line_fun(lambda l: sock.sendall(l + b'\n'))
+
+            def send_reply_fun(reply):
+                line = json.dumps(reply)
+                sock.sendall(line.encode('utf-8') + b'\n')
+                logger.debug('<= ' + line)
+            protocol.set_send_reply_fun(send_reply_fun)
+
             sock.settimeout(poll_interval_connected)
             recv_buffer = bytearray()
             while True:
t@@ -144,7 +150,13 @@ def run_electrum_server(rpc, txmonitor, config):
                         line = recv_buffer[:lb].rstrip()
                         recv_buffer = recv_buffer[lb + 1:]
                         lb = recv_buffer.find(b'\n')
-                        protocol.handle_query(line.decode("utf-8"))
+                        line = line.decode("utf-8")
+                        logger.debug("=> " + line)
+                        try:
+                            query = json.loads(line)
+                        except json.decoder.JSONDecodeError as e:
+                            raise IOError(repr(e))
+                        protocol.handle_query(query)
                 except socket.timeout:
                     on_heartbeat_connected(rpc, txmonitor, protocol)
         except (IOError, EOFError) as e:
diff --git a/electrumpersonalserver/server/electrumprotocol.py b/electrumpersonalserver/server/electrumprotocol.py
t@@ -123,8 +123,8 @@ class ElectrumProtocol(object):
         self.are_headers_raw = False
         self.txid_blockhash_map = {}
 
-    def set_send_line_fun(self, send_line_fun):
-        self.send_line_fun = send_line_fun
+    def set_send_reply_fun(self, send_reply_fun):
+        self.send_reply_fun = send_reply_fun
 
     def on_blockchain_tip_updated(self, header):
         if self.subscribed_to_headers:
t@@ -145,25 +145,17 @@ class ElectrumProtocol(object):
 
     def _send_response(self, query, result):
         response = {"jsonrpc": "2.0", "result": result, "id": query["id"]}
-        self.send_line_fun(json.dumps(response).encode('utf-8'))
-        self.logger.debug('<= ' + json.dumps(response))
+        self.send_reply_fun(response)
 
     def _send_update(self, update):
         update["jsonrpc"] = "2.0"
-        self.send_line_fun(json.dumps(update).encode('utf-8'))
-        self.logger.debug('<= ' + json.dumps(update))
+        self.send_reply_fun(update)
 
     def _send_error(self, nid, error):
         payload = {"error": error, "jsonrpc": "2.0", "id": nid}
-        self.send_line_fun(json.dumps(payload).encode('utf-8'))
-        self.logger.debug('<= ' + json.dumps(payload))
+        self.send_reply_fun(payload)
 
-    def handle_query(self, line):
-        self.logger.debug("=> " + line)
-        try:
-            query = json.loads(line)
-        except json.decoder.JSONDecodeError as e:
-            raise IOError(e)
+    def handle_query(self, query):
         if "method" not in query:
             raise IOError("Bad client query, no \"method\"")
         method = query["method"]
diff --git a/test/test_electrum_protocol.py b/test/test_electrum_protocol.py
t@@ -10,22 +10,29 @@ from electrumpersonalserver.server import (
     get_block_header,
     get_current_header,
     get_block_headers_hex,
-    JsonRpcError
+    JsonRpcError,
+    get_status_electrum
 )
 
 logger = logging.getLogger('ELECTRUMPERSONALSERVER-TEST')
 logger.setLevel(logging.DEBUG)
 
+DUMMY_JSONRPC_BLOCKCHAIN_HEIGHT = 100000
+
 def get_dummy_hash_from_height(height):
+    if height == 0:
+        return "00"*32
     return str(height) + "a"*(64 - len(str(height)))
 
 def get_height_from_dummy_hash(hhash):
+    if hhash == "00"*32:
+        return 0
     return int(hhash[:hhash.index("a")])
 
 class DummyJsonRpc(object):
     def __init__(self):
         self.calls = {}
-        self.blockchain_height = 100000
+        self.blockchain_height = DUMMY_JSONRPC_BLOCKCHAIN_HEIGHT
 
     def call(self, method, params):
         if method not in self.calls:
t@@ -58,9 +65,15 @@ class DummyJsonRpc(object):
                     + "00000000000000000da",
                 "nTx": 1,
             }
-            if height > 0:
+            if height > 1:
                 header["previousblockhash"] = get_dummy_hash_from_height(
                     height - 1)
+            elif height == 1:
+                header["previousblockhash"] = "00"*32 #genesis block
+            elif height == 0:
+                pass #no prevblock for genesis
+            else:
+                assert 0
             if height < self.blockchain_height:
                 header["nextblockhash"] = get_dummy_hash_from_height(height + 1)
             return header
t@@ -102,26 +115,29 @@ def test_get_current_header():
             assert type(ret[1]) == dict
             assert len(ret[1]) == 7
 
-def test_get_block_headers_hex_out_of_bounds():
-    rpc = DummyJsonRpc()
-    ret = get_block_headers_hex(rpc, rpc.blockchain_height + 10, 5)
-    assert len(ret) == 2
-    assert ret[0] == ""
-    assert ret[1] == 0
-
-def test_get_block_headers_hex():
+@pytest.mark.parametrize(
+    "start_height, count",
+    [(100, 200),
+    (DUMMY_JSONRPC_BLOCKCHAIN_HEIGHT + 10, 5),
+    (DUMMY_JSONRPC_BLOCKCHAIN_HEIGHT - 10, 15),
+    (0, 250)
+    ]
+)
+def test_get_block_headers_hex(start_height, count):
     rpc = DummyJsonRpc()
-    count = 200
-    ret = get_block_headers_hex(rpc, 100, count)
+    ret = get_block_headers_hex(rpc, start_height, count)
+    print("start_height=" + str(start_height) + " count=" + str(count))
     assert len(ret) == 2
-    assert ret[1] == count
-    assert len(ret[0]) == count*80*2 #80 bytes per header, 2 chars per byte
+    available_blocks = -min(0, start_height - DUMMY_JSONRPC_BLOCKCHAIN_HEIGHT
+        - 1)
+    expected_count = min(available_blocks, count)
+    assert len(ret[0]) == expected_count*80*2 #80 bytes/header, 2 chars/byte
+    assert ret[1] == expected_count
 
 @pytest.mark.parametrize(
     "invalid_json_query",
     [
-        "{\"invalid-json\":}",
-        "{\"valid-json-no-method\": 5}"
+        {"valid-json-no-method": 5}
     ]
 ) 
 def test_invalid_json_query_line(invalid_json_query):
t@@ -129,43 +145,122 @@ def test_invalid_json_query_line(invalid_json_query):
     with pytest.raises(IOError) as e:
         protocol.handle_query(invalid_json_query)
 
+def create_electrum_protocol_instance(broadcast_method="own-node",
+        tor_hostport=("127.0.0.1", 9050),
+        disable_mempool_fee_histogram=False):
+    protocol = ElectrumProtocol(DummyJsonRpc(), DummyTransactionMonitor(),
+        logger, broadcast_method, tor_hostport, disable_mempool_fee_histogram)
+    sent_replies = []
+    protocol.set_send_reply_fun(lambda l: sent_replies.append(l))
+    assert len(sent_replies) == 0
+    return protocol, sent_replies
+
+def dummy_script_hash_to_history(scrhash):
+    index = int(scrhash[:scrhash.index("s")])
+    tx_count = (index+2) % 5
+    height = 500
+    return [(index_to_dummy_txid(i), height) for i in range(tx_count)]
+
+def index_to_dummy_script_hash(index):
+    return str(index) + "s"*(64 - len(str(index)))
+
+def index_to_dummy_txid(index):
+    return str(index) + "t"*(64 - len(str(index)))
+
+def dummy_txid_to_dummy_tx(txid):
+    return txid[::-1] * 6
+
 class DummyTransactionMonitor(object):
     def __init__(self):
         self.deterministic_wallets = list(range(5))
         self.address_history = list(range(5))
+        self.subscribed_addresses = []
+        self.history_hashes = {}
 
     def get_electrum_history_hash(self, scrhash):
-        pass
+        history = dummy_script_hash_to_history(scrhash)
+        hhash = get_status_electrum(history)
+        self.history_hashes[scrhash] = history
+        return hhash
 
     def get_electrum_history(self, scrhash):
-        pass
+        return self.history_hashes[scrhash]
 
     def unsubscribe_all_addresses(self):
-        pass
+        self.subscribed_addresses = []
 
     def subscribe_address(self, scrhash):
-        pass
+        self.subscribed_addresses.append(scrhash)
+        return True
 
     def get_address_balance(self, scrhash):
         pass
 
-def create_electrum_protocol_instance(broadcast_method="own-node",
-        tor_hostport=("127.0.0.01", 9050),
-        disable_mempool_fee_histogram=False):
-    protocol = ElectrumProtocol(DummyJsonRpc(), DummyTransactionMonitor(),
-        logger, broadcast_method, tor_hostport, disable_mempool_fee_histogram)
-    sent_lines = []
-    protocol.set_send_line_fun(lambda l: sent_lines.append(json.loads(
-        l.decode())))
-    return protocol, sent_lines
+def test_script_hash_sync():
+    protocol, sent_replies = create_electrum_protocol_instance()
+    scrhash_index = 0
+    scrhash = index_to_dummy_script_hash(scrhash_index)
+    protocol.handle_query({"method": "blockchain.scripthash.subscribe",
+        "params": [scrhash], "id": 0})
+    assert len(sent_replies) == 1
+    assert len(protocol.txmonitor.subscribed_addresses) == 1
+    assert protocol.txmonitor.subscribed_addresses[0] == scrhash
+    assert len(sent_replies) == 1
+    assert len(sent_replies[0]["result"]) == 64
+    history_hash = sent_replies[0]["result"]
+
+    protocol.handle_query({"method": "blockchain.scripthash.get_history",
+        "params": [scrhash], "id": 0})
+    assert len(sent_replies) == 2
+    assert get_status_electrum(sent_replies[1]["result"]) == history_hash
+
+    #updated scripthash but actually nothing changed, history_hash unchanged
+    protocol.on_updated_scripthashes([scrhash])
+    assert len(sent_replies) == 3
+    assert sent_replies[2]["method"] == "blockchain.scripthash.subscribe"
+    assert sent_replies[2]["params"][0] == scrhash
+    assert sent_replies[2]["params"][1] == history_hash
+
+    protocol.on_disconnect()
+    assert len(protocol.txmonitor.subscribed_addresses) == 0
+
+def test_headers_subscribe():
+    protocol, sent_replies = create_electrum_protocol_instance()
+
+    protocol.handle_query({"method": "server.version", "params": ["test-code",
+        1.4], "id": 0}) #protocol version of 1.4 means only raw headers used
+    assert len(sent_replies) == 1
+
+    protocol.handle_query({"method": "blockchain.headers.subscribe", "params":
+        [], "id": 0})
+    assert len(sent_replies) == 2
+    assert "height" in sent_replies[1]["result"]
+    assert sent_replies[1]["result"]["height"] == protocol.rpc.blockchain_height
+    assert "hex" in sent_replies[1]["result"]
+    assert len(sent_replies[1]["result"]["hex"]) == 80*2 #80 b/header, 2 b/char
+
+    protocol.rpc.blockchain_height += 1
+    new_bestblockhash, header = get_current_header(protocol.rpc,
+        protocol.are_headers_raw)
+    protocol.on_blockchain_tip_updated(header)
+    assert len(sent_replies) == 3
+    assert "method" in sent_replies[2]
+    assert sent_replies[2]["method"] == "blockchain.headers.subscribe"
+    assert "params" in sent_replies[2]
+    assert "height" in sent_replies[2]["params"][0]
+    assert sent_replies[2]["params"][0]["height"]\
+        == protocol.rpc.blockchain_height
+    assert "hex" in sent_replies[2]["params"][0]
+    assert len(sent_replies[2]["params"][0]["hex"]) == 80*2 #80 b/header, 2 b/c
 
 def test_server_ping():
-    protocol, sent_lines = create_electrum_protocol_instance()
+    protocol, sent_replies = create_electrum_protocol_instance()
     idd = 1
-    protocol.handle_query(json.dumps({"method": "server.ping", "id": idd}))
-    assert len(sent_lines) == 1
-    assert sent_lines[0]["result"] == None
-    assert sent_lines[0]["id"] == idd
-
+    protocol.handle_query({"method": "server.ping", "id": idd})
+    assert len(sent_replies) == 1
+    assert sent_replies[0]["result"] == None
+    assert sent_replies[0]["id"] == idd
 
+#test scripthash.subscribe, scripthash.get_history transaction.get
+# transaction.get_merkle