tprotocol: Do task cleanup when peer disconnects. - obelisk - Electrum server using libbitcoin as its backend
git clone https://git.parazyd.org/obelisk
Log
Files
Refs
README
LICENSE
---
commit f6449ea78a20d6ef3d62d7dc00de34ec05bbfc10
parent 0c8ef25aea1b5e0bab605b950f77d93279861c9b
Author: parazyd 
Date:   Mon, 19 Apr 2021 17:59:10 +0200

protocol: Do task cleanup when peer disconnects.

Diffstat:
  M obelisk/protocol.py                 |     135 ++++++++++++++++++-------------
  M obelisk/zeromq.py                   |      10 +++++-----
  M tests/test_electrum_protocol.py     |      12 ++++++++++++

3 files changed, 94 insertions(+), 63 deletions(-)
---
diff --git a/obelisk/protocol.py b/obelisk/protocol.py
t@@ -65,13 +65,9 @@ class ElectrumProtocol(asyncio.Protocol):  # pylint: disable=R0904,R0902
         self.endpoints = endpoints
         self.server_cfg = server_cfg
         self.loop = asyncio.get_event_loop()
-        self.chain_tip = 0
-        # Consider renaming bx to something else
         self.bx = Client(log, endpoints, self.loop)
         self.block_queue = None
-        # TODO: Clean up on client disconnect
-        self.tasks = []
-        self.sh_subscriptions = {}
+        self.peers = {}
 
         if chain == "mainnet":  # pragma: no cover
             self.genesis = "000000000019d6689c085ae165831e934ff763ae46a2a6c172b3f1b60a8ce26f"
t@@ -112,28 +108,32 @@ class ElectrumProtocol(asyncio.Protocol):  # pylint: disable=R0904,R0902
         self.log.debug("ElectrumProtocol.stop()")
         self.stopped = True
         if self.bx:
-            # unsub_pool = []
-            # for i in self.sh_subscriptions:  # pragma: no cover
-            # self.log.debug("bx.unsubscribe %s", i)
-            # unsub_pool.append(self.bx.unsubscribe_scripthash(i))
-            # await asyncio.gather(*unsub_pool, return_exceptions=True)
+            for i in self.peers:
+                await self._peer_cleanup(i)
             await self.bx.stop()
 
-        # idxs = []
-        # for task in self.tasks:
-        # idxs.append(self.tasks.index(task))
-        # task.cancel()
-        # for i in idxs:
-        # del self.tasks[i]
+    async def _peer_cleanup(self, peer):
+        """Cleanup tasks and data for peer"""
+        self.log.debug("Cleaning up data for %s", peer)
+        for i in self.peers[peer]["tasks"]:
+            i.cancel()
+        for i in self.peers[peer]["sh"]:
+            self.peers[peer]["sh"][i]["task"].cancel()
+
+    @staticmethod
+    def _get_peer(writer):
+        peer_t = writer._transport.get_extra_info("peername")  # pylint: disable=W0212
+        return f"{peer_t[0]}:{peer_t[1]}"
 
     async def recv(self, reader, writer):
         """Loop ran upon a connection which acts as a JSON-RPC handler"""
         recv_buf = bytearray()
+        self.peers[self._get_peer(writer)] = {"tasks": [], "sh": {}}
+
         while not self.stopped:
             data = await reader.read(4096)
             if not data or len(data) == 0:
-                self.log.debug("Received EOF, disconnect")
-                # TODO: cancel asyncio tasks for this client here?
+                await self._peer_cleanup(self._get_peer(writer))
                 return
             recv_buf.extend(data)
             lb = recv_buf.find(b"\n")
t@@ -181,12 +181,7 @@ class ElectrumProtocol(asyncio.Protocol):  # pylint: disable=R0904,R0902
 
     async def handle_query(self, writer, query):  # pylint: disable=R0915,R0912,R0911
         """Electrum protocol method handler mapper"""
-        if "method" not in query:
-            self.log.debug("No 'method' in query: %s", query)
-            return await self._send_reply(writer, JsonRPCError.invalidrequest(),
-                                          None)
-        if "id" not in query:
-            self.log.debug("No 'id' in query: %s", query)
+        if "method" not in query or "id" not in query:
             return await self._send_reply(writer, JsonRPCError.invalidrequest(),
                                           None)
 
t@@ -304,13 +299,11 @@ class ElectrumProtocol(asyncio.Protocol):  # pylint: disable=R0904,R0902
         self.block_queue = asyncio.Queue()
         await self.bx.subscribe_to_blocks(self.block_queue)
         while True:
-            # item = (seq, height, block_data)
             item = await self.block_queue.get()
             if len(item) != 3:
                 self.log.debug("error: item from block queue len != 3")
                 continue
 
-            self.chain_tip = item[1]
             header = block_to_header(item[2])
             params = [{"height": item[1], "hex": safe_hexlify(header)}]
             await self._send_notification(writer,
t@@ -331,8 +324,8 @@ class ElectrumProtocol(asyncio.Protocol):  # pylint: disable=R0904,R0902
             self.log.debug("Got error: %s", repr(_ec))
             return JsonRPCError.internalerror()
 
-        self.chain_tip = height
-        self.tasks.append(asyncio.create_task(self.header_notifier(writer)))
+        self.peers[self._get_peer(writer)]["tasks"].append(
+            asyncio.create_task(self.header_notifier(writer)))
         ret = {"height": height, "hex": safe_hexlify(tip_header)}
         return {"result": ret}
 
t@@ -428,32 +421,56 @@ class ElectrumProtocol(asyncio.Protocol):  # pylint: disable=R0904,R0902
 
         return {"result": ret}
 
+    async def scripthash_renewer(self, scripthash, queue):
+        while True:
+            try:
+                self.log.debug("scriphash renewer: %s", scripthash)
+                _ec = await self.bx.subscribe_scripthash(scripthash, queue)
+                if _ec and _ec != 0:
+                    self.log.error("bx.subscribe_scripthash failed: %s",
+                                   repr(_ec))
+                await asyncio.sleep(60)
+            except asyncio.CancelledError:
+                self.log.debug("%s renewer cancelled", scripthash)
+                break
+
     async def scripthash_notifier(self, writer, scripthash):
         # TODO: Mempool
+        # TODO: This is still flaky and not always notified. Investigate.
+        self.log.debug("notifier")
         method = "blockchain.scripthash.subscribe"
-        while True:
-            _ec, sh_queue = await self.bx.subscribe_scripthash(scripthash)
-            if _ec and _ec != 0:
-                self.log.error("bx.subscribe_scripthash failed: %s", repr(_ec))
-                return
-
-            item = await sh_queue.get()
-            _ec, height, txid = struct.unpack("
diff --git a/obelisk/zeromq.py b/obelisk/zeromq.py
t@@ -266,11 +266,11 @@ class Client:
         socket.connect(self._endpoints["query"])
         return socket
 
-    async def _subscription_request(self, command, data):
+    async def _subscription_request(self, command, data, queue):
         request = await self._request(command, data)
-        request.queue = asyncio.Queue()
+        request.queue = queue
         error_code, _ = await self._wait_for_response(request)
-        return error_code, request.queue
+        return error_code
 
     async def _simple_request(self, command, data):
         return await self._wait_for_response(await self._request(command, data))
t@@ -345,11 +345,11 @@ class Client:
             return error_code, None
         return error_code, data
 
-    async def subscribe_scripthash(self, scripthash):
+    async def subscribe_scripthash(self, scripthash, queue):
         """Subscribe to scripthash"""
         command = b"subscribe.key"
         decoded_address = unhexlify(scripthash)
-        return await self._subscription_request(command, decoded_address)
+        return await self._subscription_request(command, decoded_address, queue)
 
     async def unsubscribe_scripthash(self, scripthash):
         """Unsubscribe scripthash"""
diff --git a/tests/test_electrum_protocol.py b/tests/test_electrum_protocol.py
t@@ -399,11 +399,21 @@ async def test_send_reply(protocol, writer, method):
     assert_equal(writer.mock, expect)
 
 
+class MockTransport:
+
+    def __init__(self):
+        self.peername = ("foo", 42)
+
+    def get_extra_info(self, param):
+        return self.peername
+
+
 class MockWriter(asyncio.StreamWriter):  # pragma: no cover
     """Mock class for StreamWriter"""
 
     def __init__(self):
         self.mock = None
+        self._transport = MockTransport()
 
     def write(self, data):
         self.mock = data
t@@ -455,6 +465,8 @@ async def main():
     protocol = ElectrumProtocol(log, "testnet", libbitcoin, {})
     writer = MockWriter()
 
+    protocol.peers[protocol._get_peer(writer)] = {"tasks": [], "sh": {}}
+
     for func in orchestration:
         try:
             await orchestration[func](protocol, writer, func)