import asyncio
import logging
import socket
import re
from ..abc.sink import Sink
from .stream import Stream, TLSStream
#
L = logging.getLogger(__name__)
#
[docs]class StreamClientSink(Sink):
"""
Description:
"""
ConfigDefaults = {
'address': '127.0.0.1 8888', # IPv4, IPv6 or unix socket path
'outbound_queue_max_size': 100, # Maximum size of the output queue before throttling
}
[docs] def __init__(self, app, pipeline, id=None, config=None):
"""
Description:
"""
super().__init__(app, pipeline, id=id, config=config)
self.OutboundQueue = asyncio.Queue()
# Throttle till we are connected
self.Pipeline.throttle(self, enable=True)
# Maximum size for the queue
self.OutboundQueueMaxSize = int(self.Config['outbound_queue_max_size'])
assert (self.OutboundQueueMaxSize >= 1)
self.Task = None
self.Pipeline.PubSub.subscribe("bspump.pipeline.start!", self._open_connection)
self.Pipeline.PubSub.subscribe("bspump.pipeline.stop!", self._close_connection)
app.PubSub.subscribe("Application.tick!", self._on_tick)
async def _open_connection(self, message, pipeline):
"""
Description:
"""
# Connection is established
if self.Task is not None:
await self._close_connection(message, pipeline)
self.Task = self.Pipeline.Loop.create_task(
self._client_connected_task()
)
async def _close_connection(self, message, pipeline):
"""
Description:
"""
self.Pipeline.throttle(self, enable=True)
if self.Task is not None:
self.Task.cancel()
self.Task = None
def _on_tick(self, event_name):
"""
Description:
"""
# Unthrottle the queue if needed
if self.OutboundQueue in self.Pipeline.get_throttles() and self.OutboundQueue.qsize() < self.OutboundQueueMaxSize:
print("Unthrottling")
self.Pipeline.throttle(self.OutboundQueue, False)
if self.Task is not None and self.Task.done():
# We should be connected but we are not
# Let's do a bit of clean-up and commence reconnection
try:
self.Task.result()
except Exception:
L.exception("Error when handling client socket")
self.Task = self.Pipeline.Loop.create_task(
self._client_connected_task()
)
async def _client_connected_task(self):
"""
Description:
"""
loop = self.Pipeline.Loop
addr = self.Config['address']
if ' ' in addr:
addr = re.split(r"\s+", addr)
else:
# This line allows the (obsolete) format of IPv4 with ':'
# such as "0.0.0.0:8001"
addr = re.split(r"[:\s]", addr, 1)
host = addr.pop(0).strip()
port = addr.pop(0).strip()
port = int(port)
addrinfo = await loop.getaddrinfo(host, port, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM)
client_sock = None
connection_exception = ''
for family, socktype, proto, canonname, sockaddr in addrinfo:
client_sock = socket.socket(family, socktype, proto)
client_sock.setblocking(False)
try:
await loop.sock_connect(client_sock, sockaddr)
except Exception as e:
connection_exception = e
client_sock = None
continue
if client_sock is None:
L.warning("Connection to '{}' failed: {}".format(addr, connection_exception))
return
# TODO: Support also TLSStream ...
stream = Stream(self.Pipeline.Loop, client_sock, outbound_queue=self.OutboundQueue)
outbound = loop.create_task(stream.outbound())
inbound = loop.create_task(self._client_inbound_task(client_sock))
self.Pipeline.throttle(self, enable=False)
try:
done, active = await asyncio.wait(
{outbound, inbound},
return_when=asyncio.FIRST_COMPLETED
)
except asyncio.CancelledError:
active = {outbound, inbound}
done = {}
finally:
self.Pipeline.throttle(self, enable=True)
# Cancel remaining active tasks
for t in active:
# TODO: There could be outstanding data in the outbound queue
# Consider flushing them
t.cancel()
# Collect results from completed tasks
for t in done:
try:
await t
except asyncio.CancelledError:
pass
except Exception:
L.exception("Error when handling client socket")
# Close the stream
await stream.close()
async def _client_inbound_task(self, client_sock):
"""
Description:
"""
while True:
data = await self.Pipeline.Loop.sock_recv(client_sock, 4096)
if len(data) == 0:
# Client closed the connection
return
# Incoming data are discarted ...
[docs] def process(self, context, event):
"""
Description:
"""
self.OutboundQueue.put_nowait(event)
if self.OutboundQueue.qsize() == self.OutboundQueueMaxSize:
print("Throttling")
self.Pipeline.throttle(self.OutboundQueue, True)