| t@@ -65,13 +65,9 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902
self.endpoints = endpoints
self.server_cfg = server_cfg
self.loop = asyncio.get_event_loop()
- self.chain_tip = 0
- # Consider renaming bx to something else
self.bx = Client(log, endpoints, self.loop)
self.block_queue = None
- # TODO: Clean up on client disconnect
- self.tasks = []
- self.sh_subscriptions = {}
+ self.peers = {}
if chain == "mainnet": # pragma: no cover
self.genesis = "000000000019d6689c085ae165831e934ff763ae46a2a6c172b3f1b60a8ce26f"
t@@ -112,28 +108,32 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902
self.log.debug("ElectrumProtocol.stop()")
self.stopped = True
if self.bx:
- # unsub_pool = []
- # for i in self.sh_subscriptions: # pragma: no cover
- # self.log.debug("bx.unsubscribe %s", i)
- # unsub_pool.append(self.bx.unsubscribe_scripthash(i))
- # await asyncio.gather(*unsub_pool, return_exceptions=True)
+ for i in self.peers:
+ await self._peer_cleanup(i)
await self.bx.stop()
- # idxs = []
- # for task in self.tasks:
- # idxs.append(self.tasks.index(task))
- # task.cancel()
- # for i in idxs:
- # del self.tasks[i]
+ async def _peer_cleanup(self, peer):
+ """Cleanup tasks and data for peer"""
+ self.log.debug("Cleaning up data for %s", peer)
+ for i in self.peers[peer]["tasks"]:
+ i.cancel()
+ for i in self.peers[peer]["sh"]:
+ self.peers[peer]["sh"][i]["task"].cancel()
+
+ @staticmethod
+ def _get_peer(writer):
+ peer_t = writer._transport.get_extra_info("peername") # pylint: disable=W0212
+ return f"{peer_t[0]}:{peer_t[1]}"
async def recv(self, reader, writer):
"""Loop ran upon a connection which acts as a JSON-RPC handler"""
recv_buf = bytearray()
+ self.peers[self._get_peer(writer)] = {"tasks": [], "sh": {}}
+
while not self.stopped:
data = await reader.read(4096)
if not data or len(data) == 0:
- self.log.debug("Received EOF, disconnect")
- # TODO: cancel asyncio tasks for this client here?
+ await self._peer_cleanup(self._get_peer(writer))
return
recv_buf.extend(data)
lb = recv_buf.find(b"\n")
t@@ -181,12 +181,7 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902
async def handle_query(self, writer, query): # pylint: disable=R0915,R0912,R0911
"""Electrum protocol method handler mapper"""
- if "method" not in query:
- self.log.debug("No 'method' in query: %s", query)
- return await self._send_reply(writer, JsonRPCError.invalidrequest(),
- None)
- if "id" not in query:
- self.log.debug("No 'id' in query: %s", query)
+ if "method" not in query or "id" not in query:
return await self._send_reply(writer, JsonRPCError.invalidrequest(),
None)
t@@ -304,13 +299,11 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902
self.block_queue = asyncio.Queue()
await self.bx.subscribe_to_blocks(self.block_queue)
while True:
- # item = (seq, height, block_data)
item = await self.block_queue.get()
if len(item) != 3:
self.log.debug("error: item from block queue len != 3")
continue
- self.chain_tip = item[1]
header = block_to_header(item[2])
params = [{"height": item[1], "hex": safe_hexlify(header)}]
await self._send_notification(writer,
t@@ -331,8 +324,8 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902
self.log.debug("Got error: %s", repr(_ec))
return JsonRPCError.internalerror()
- self.chain_tip = height
- self.tasks.append(asyncio.create_task(self.header_notifier(writer)))
+ self.peers[self._get_peer(writer)]["tasks"].append(
+ asyncio.create_task(self.header_notifier(writer)))
ret = {"height": height, "hex": safe_hexlify(tip_header)}
return {"result": ret}
t@@ -428,32 +421,56 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902
return {"result": ret}
+ async def scripthash_renewer(self, scripthash, queue):
+ while True:
+ try:
+ self.log.debug("scriphash renewer: %s", scripthash)
+ _ec = await self.bx.subscribe_scripthash(scripthash, queue)
+ if _ec and _ec != 0:
+ self.log.error("bx.subscribe_scripthash failed: %s",
+ repr(_ec))
+ await asyncio.sleep(60)
+ except asyncio.CancelledError:
+ self.log.debug("%s renewer cancelled", scripthash)
+ break
+
async def scripthash_notifier(self, writer, scripthash):
# TODO: Mempool
+ # TODO: This is still flaky and not always notified. Investigate.
+ self.log.debug("notifier")
method = "blockchain.scripthash.subscribe"
- while True:
- _ec, sh_queue = await self.bx.subscribe_scripthash(scripthash)
- if _ec and _ec != 0:
- self.log.error("bx.subscribe_scripthash failed: %s", repr(_ec))
- return
-
- item = await sh_queue.get()
- _ec, height, txid = struct.unpack(" |
| t@@ -266,11 +266,11 @@ class Client:
socket.connect(self._endpoints["query"])
return socket
- async def _subscription_request(self, command, data):
+ async def _subscription_request(self, command, data, queue):
request = await self._request(command, data)
- request.queue = asyncio.Queue()
+ request.queue = queue
error_code, _ = await self._wait_for_response(request)
- return error_code, request.queue
+ return error_code
async def _simple_request(self, command, data):
return await self._wait_for_response(await self._request(command, data))
t@@ -345,11 +345,11 @@ class Client:
return error_code, None
return error_code, data
- async def subscribe_scripthash(self, scripthash):
+ async def subscribe_scripthash(self, scripthash, queue):
"""Subscribe to scripthash"""
command = b"subscribe.key"
decoded_address = unhexlify(scripthash)
- return await self._subscription_request(command, decoded_address)
+ return await self._subscription_request(command, decoded_address, queue)
async def unsubscribe_scripthash(self, scripthash):
"""Unsubscribe scripthash""" |
| t@@ -399,11 +399,21 @@ async def test_send_reply(protocol, writer, method):
assert_equal(writer.mock, expect)
+class MockTransport:
+
+ def __init__(self):
+ self.peername = ("foo", 42)
+
+ def get_extra_info(self, param):
+ return self.peername
+
+
class MockWriter(asyncio.StreamWriter): # pragma: no cover
"""Mock class for StreamWriter"""
def __init__(self):
self.mock = None
+ self._transport = MockTransport()
def write(self, data):
self.mock = data
t@@ -455,6 +465,8 @@ async def main():
protocol = ElectrumProtocol(log, "testnet", libbitcoin, {})
writer = MockWriter()
+ protocol.peers[protocol._get_peer(writer)] = {"tasks": [], "sh": {}}
+
for func in orchestration:
try:
await orchestration[func](protocol, writer, func) |