# Source code for bspump.ipc.stream_server_source

import os
import ssl
import errno
import socket
import asyncio
import logging

from ..abc.source import Source

from .stream import Stream, TLSStream
from .protocol import LineSourceProtocol

#

L = logging.getLogger(__name__)

#


class StreamServerSource(Source):
    """
    Description: Listen on one or more stream sockets (TCP over IPv4/IPv6, or UNIX)
    and pass the stream of every accepted client connection to ``self.Protocol.handle()``.

    The 'address' option holds one listening address per line. A line is either
    ``host port`` (whitespace-separated, e.g. ``127.0.0.1 8888`` or ``:: 8888``),
    ``host:port`` (accepted only when the line contains exactly one colon), or a
    UNIX socket path (any other line, and any line starting with ``/`` or ``.``).
    TLS is enabled when a 'cert' or 'key' option is present.
    """

    ConfigDefaults = {
        'address': '127.0.0.1 8888',  # IPv4, IPv6 or unix socket path; one address per line
        'backlog': '',  # Listen backlog; '' lets socket.listen() pick its default
        # Specify 'cert' or 'key' to enable SSL / TLS mode

        # An encoding a line is going to be decoded from
        # - Pass '' (empty string) to prevent decoding
        'decode': 'utf-8',
    }

    def __init__(self, app, pipeline, id=None, config=None, protocol_class=LineSourceProtocol):
        """
        Description: Read the configuration, prepare the optional TLS context and
        instantiate the client protocol.

        :param app: the application; its PubSub is used to subscribe to "Application.tick!".
        :param pipeline: the pipeline this source belongs to.
        :param id: optional source id, passed to ``Source``.
        :param config: optional configuration overrides, passed to ``Source``.
        :param protocol_class: instantiated as ``protocol_class(app, pipeline, config=...)``;
            its ``handle(source, stream, context)`` serves each client connection.
        """
        super().__init__(app, pipeline, id=id, config=config)

        self.Address = self.Config['address']

        if 'cert' in self.Config or 'key' in self.Config:
            # Imported lazily: only needed when TLS is configured
            import asab.net
            sslbuilder = asab.net.SSLContextBuilder('[none]', config=self.Config)
            # PROTOCOL_TLS_SERVER replaces the deprecated PROTOCOL_SSLv23 alias; this
            # context is only ever used on the server side (see _client_connected_task)
            self.SSL = sslbuilder.build(protocol=ssl.PROTOCOL_TLS_SERVER)
        else:
            self.SSL = None

        self.AcceptingSockets = []  # Listening sockets created by start()
        self.ConnectedClients = set()  # Set of active _client_connected_task()

        self.Protocol = protocol_class(app, pipeline, config=self.Config)

        app.PubSub.subscribe("Application.tick!", self._on_tick)

    def start(self, loop):
        """
        Description: Create, bind and listen on every socket listed in the 'address'
        option, then start the source. Addresses that cannot be used are logged and
        skipped. Does nothing when ``self.Task`` is already set.
        """
        if self.Task is not None:
            return

        backlog = self.Config['backlog']
        backlog = None if backlog == '' else int(backlog)

        # Create all required sockets, bind them to specific ports and start listening
        for addrline in self.Address.split('\n'):
            addrline = addrline.strip()
            if len(addrline) == 0:
                # Tolerate blank lines, e.g. a trailing newline in a multi-line option
                continue

            target = self._parse_address(addrline)
            if target is None:
                self._listen_unix(addrline, backlog)
            else:
                host, port = target
                self._listen_tcp(addrline, host, port, backlog)

        super().start(loop)

    @staticmethod
    def _parse_address(addrline):
        # Return (host, port) for a TCP address line, or None when it names a UNIX socket path.
        if addrline.startswith(('/', '.')):
            return None
        parts = addrline.rsplit(None, 1)
        if len(parts) == 2:
            return parts[0], parts[1]
        if addrline.count(':') == 1:
            # Legacy "host:port" form; IPv6 literals have more colons and need "host port"
            host, port = addrline.rsplit(':', 1)
            return host, port
        return None

    @staticmethod
    def _listen(sock, backlog):
        # Start listening and switch to non-blocking mode, which loop.sock_accept() requires.
        if backlog is None:
            sock.listen()
        else:
            sock.listen(backlog)
        sock.setblocking(False)

    def _listen_tcp(self, addrline, host, port, backlog):
        # Bind and listen on every address 'host' resolves to; failures are logged and skipped.
        try:
            addrinfo = socket.getaddrinfo(
                host, port,
                family=socket.AF_UNSPEC, type=socket.SOCK_STREAM, flags=socket.AI_PASSIVE
            )
        except OSError as e:  # socket.gaierror is a subclass of OSError
            L.warning("Failed to start listening at '{}': {}".format(addrline, e))
            return

        for family, socktype, proto, _canonname, sockaddr in addrinfo:
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                sock.bind(sockaddr)
                self._listen(sock, backlog)
            except OSError as e:
                if sock is not None:
                    sock.close()  # Don't leak the descriptor of a socket we cannot use
                L.warning("Failed to start listening at '{}': {}".format(addrline, e))
                continue
            self.AcceptingSockets.append(sock)

    def _listen_unix(self, path, backlog):
        # Bind and listen on a UNIX stream socket at 'path'; failures are logged.
        sock = None
        try:
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            sock.bind(path)
            self._listen(sock, backlog)
        except OSError as e:
            if sock is not None:
                sock.close()
            L.warning("Failed to start listening at '{}': {}".format(path, e))
            return
        self.AcceptingSockets.append(sock)

    async def stop(self):
        """
        Description: Stop accepting connections, cancel all client tasks and close
        the listening sockets (removing the file of a filesystem UNIX socket).
        """
        # The main() will be canceled in the parent class. Do it first so that accepting
        # stops before the client tasks are cancelled.
        # NOTE(review): relies on Source.stop() cancelling main() before it returns.
        await super().stop()

        # Close client connections
        for task in self.ConnectedClients:
            task.cancel()
        self.ConnectedClients = set()

        for sock in self.AcceptingSockets:
            self._close_listening_socket(sock)
        self.AcceptingSockets = []

    @staticmethod
    def _close_listening_socket(sock):
        # Close a listening socket; for a filesystem UNIX socket also remove its path
        # so that a later start() can bind it again.
        path = None
        if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX:
            try:
                path = sock.getsockname()
            except OSError:
                path = None
        sock.close()
        # Abstract-namespace addresses are reported as bytes and have no file to remove
        if isinstance(path, str) and len(path) > 0:
            try:
                os.unlink(path)
            except OSError as e:
                L.warning("Failed to remove unix socket '{}': {}".format(path, e))

    async def main(self):
        """
        Description: Accept connections on all listening sockets until cancelled.
        """
        if len(self.AcceptingSockets) == 0:
            L.error("No listening socket configured")
            return

        results = await asyncio.gather(
            *[self._handle_accept(sock) for sock in self.AcceptingSockets],
            return_exceptions=True
        )
        # An accept loop only ends on an unexpected error; gather() collected it instead
        # of raising, so report it rather than losing it silently.
        for result in results:
            if isinstance(result, Exception):
                L.error("Accepting of connections failed", exc_info=result)

    async def _handle_accept(self, sock):
        """
        Description: Accept clients on ``sock`` forever, spawning one
        _client_connected_task() per connection.
        """
        loop = self.Pipeline.App.Loop
        server_addr = sock.getsockname()
        while True:
            try:
                client_sock, client_addr = await loop.sock_accept(sock)
            except OSError as e:
                # A single failed accept (e.g. the client aborted) must not stop the server
                L.warning("Failed to accept a connection at '{}': {}".format(server_addr, e))
                if e.errno in (errno.EMFILE, errno.ENFILE, errno.ENOBUFS, errno.ENOMEM):
                    # Out of descriptors/memory: retrying immediately would spin, so back off
                    await asyncio.sleep(1.0)
                continue

            task = loop.create_task(
                self._client_connected_task(client_sock, client_addr, server_addr)
            )
            self.ConnectedClients.add(task)

    async def _client_connected_task(self, client_sock, client_addr, server_addr):
        """
        Description: Serve one accepted client. Optionally perform the TLS handshake,
        then run the protocol's inbound handler and the stream's outbound task until
        either finishes, and finally close the stream.
        """
        loop = self.Pipeline.App.Loop
        client_sock.setblocking(False)

        if client_sock.family in (socket.AF_INET, socket.AF_INET6):
            # Only host and port are reported; IPv6 sockaddr tuples also carry flowinfo and scope_id
            me = '{} {}'.format(server_addr[0], server_addr[1])
            peer = '{} {}'.format(client_addr[0], client_addr[1])
        else:
            me = server_addr
            peer = client_addr

        context = {
            'stream_type': client_sock.family.name,
            'stream_dir': 'in',
            'stream_peer': peer,
            'stream_me': me,
        }

        if self.SSL is not None:
            stream = TLSStream(loop, self.SSL, client_sock, server_side=True)
            try:
                ok = await stream.handshake()
            except BaseException:
                client_sock.close()  # Also covers cancellation during the handshake
                raise
            if not ok:
                # NOTE(review): closing the raw socket is assumed to be the right cleanup for a
                # stream whose handshake did not complete; socket.close() is idempotent.
                client_sock.close()
                return
        else:
            stream = Stream(loop, client_sock)

        # This allows to send a reply to a client
        context['stream'] = stream

        # asyncio.wait() accepts only futures/tasks (bare coroutines are rejected since Python 3.11)
        inbound = asyncio.ensure_future(self.Protocol.handle(self, stream, context), loop=loop)
        outbound = asyncio.ensure_future(stream.outbound(), loop=loop)
        try:
            done, _active = await asyncio.wait(
                {inbound, outbound},
                return_when=asyncio.FIRST_COMPLETED
            )
            for t in done:
                try:
                    await t
                except asyncio.CancelledError:
                    pass
                except Exception:
                    L.exception("Error when handling client socket")
        finally:
            # TODO: There could be outstanding data in the outbound queue
            # Consider flushing them
            # Cancel remaining active tasks; this also runs when this task itself is
            # cancelled (e.g. by stop()), since asyncio.wait() does not cancel them.
            for t in (inbound, outbound):
                if not t.done():
                    t.cancel()
            # Close the stream
            await stream.close()

    def _on_tick(self, event_name):
        # Remove clients that disconnected and report the exceptions their tasks ended with
        finished = [task for task in self.ConnectedClients if task.done()]
        for task in finished:
            self.ConnectedClients.remove(task)
            try:
                task.result()
            except asyncio.CancelledError:
                pass
            except Exception:
                L.exception("Exception when handling client socket")