mirror of
https://github.com/FreeOpcUa/opcua-asyncio
synced 2025-10-29 17:07:18 +08:00
* implement oroulet changes * closing_task must be a task, not only run once * do not catch CancelledError in closing_task * make sure we use the correct loop in asyncio binary_server Co-authored-by: Joey Faulkner <joeymfaulkner@gmail.com>
171 lines
6.3 KiB
Python
171 lines
6.3 KiB
Python
"""
|
|
Socket server forwarding request to internal server
|
|
"""
|
|
import logging
|
|
import asyncio
|
|
from typing import Optional
|
|
|
|
from ..ua.ua_binary import header_from_binary
|
|
from ..common.utils import Buffer, NotEnoughData
|
|
from .uaprocessor import UaProcessor
|
|
from .internal_server import InternalServer
|
|
|
|
# Module-level logger used by the protocol and server classes in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class OPCUAProtocol(asyncio.Protocol):
    """
    Instantiated for every connection.

    Received bytes are accumulated in an internal buffer, cut into complete
    messages (header + body) and put on a queue. A background task created in
    ``connection_made`` takes the messages from the queue and hands them to
    the ``UaProcessor`` of this connection.
    """

    def __init__(self, iserver: InternalServer, policies, clients, closing_tasks):
        """
        :param iserver: internal server; its ``loop`` and ``asyncio_transports``
            attributes are used by this protocol
        :param policies: policies handed to the processor via ``set_policies``
        :param clients: shared list this protocol adds itself to on connect and
            removes itself from on disconnect
        :param closing_tasks: shared list that receives the task closing the
            processor when the connection is lost
        """
        self.peer_name = None
        self.transport = None
        self.processor = None
        self._buffer = b''
        self.iserver: InternalServer = iserver
        self.policies = policies
        self.clients = clients
        self.closing_tasks = closing_tasks
        self.messages = asyncio.Queue()
        self._task = None

    def __str__(self):
        # The processor only exists once connection_made() has run; do not
        # crash when the protocol is printed or logged before that.
        session = self.processor.session if self.processor is not None else None
        return f'OPCUAProtocol({self.peer_name}, {session})'

    __repr__ = __str__

    def connection_made(self, transport):
        """
        Set up the processor for the new connection and start the task
        consuming the message queue.
        """
        self.peer_name = transport.get_extra_info('peername')
        logger.info('New connection from %s', self.peer_name)
        self.transport = transport
        self.processor = UaProcessor(self.iserver, self.transport)
        self.processor.set_policies(self.policies)
        self.iserver.asyncio_transports.append(transport)
        self.clients.append(self)
        self._task = self.iserver.loop.create_task(self._process_received_message_loop())

    def connection_lost(self, ex):
        """
        Unregister the connection, schedule the closing of the processor and
        stop the message processing task.
        """
        logger.info('Lost connection from %s, %s', self.peer_name, ex)
        self.transport.close()
        self.iserver.asyncio_transports.remove(self.transport)
        # Closing the processor is a coroutine and cannot be awaited from this
        # synchronous callback: run it as a task and hand it over through the
        # shared list.
        closing_task = self.iserver.loop.create_task(self.processor.close())
        self.closing_tasks.append(closing_task)
        if self in self.clients:
            self.clients.remove(self)
        # (None, None) is the sentinel that ends the processing loop; the task
        # is additionally cancelled.
        self.messages.put_nowait((None, None))
        self._task.cancel()

    def data_received(self, data):
        """
        Append ``data`` to the receive buffer and queue every complete message
        found in it. Incomplete data stays in the buffer until more arrives.
        """
        self._buffer += data
        # try to parse the incoming data
        while self._buffer:
            try:
                buf = Buffer(self._buffer)
                try:
                    header = header_from_binary(buf)
                except NotEnoughData:
                    logger.debug('Not enough data while parsing header from client, waiting for more')
                    return
                if len(buf) < header.body_size:
                    logger.debug('We did not receive enough data from client. Need %s got %s', header.body_size,
                                 len(buf))
                    return
                # we have a complete message
                self.messages.put_nowait((header, buf))
                # drop the bytes of the queued message, keep what follows
                self._buffer = self._buffer[(header.header_size + header.body_size):]
            except Exception:
                logger.exception('Exception raised while parsing message from client')
                return

    async def _process_received_message_loop(self):
        """
        Take message from the queue and try to process it.

        Runs until the ``(None, None)`` sentinel is received. An exception
        raised while processing one message is logged and does not stop the loop.
        """
        while True:
            header, buf = await self.messages.get()
            if header is None and buf is None:
                # Connection was closed, end task
                break
            try:
                await self._process_one_msg(header, buf)
            except Exception:
                logger.exception('Exception raised while processing message from client')

    async def _process_one_msg(self, header, buf):
        """
        Hand one message to the processor and close the transport when the
        processor returns a false value.
        """
        logger.debug('_process_received_message %s %s', header.body_size, len(buf))
        ret = await self.processor.process(header, buf)
        if not ret:
            logger.info('processor returned False, we close connection from %s', self.peer_name)
            self.transport.close()
|
|
|
|
|
|
class BinaryServer:
    """
    Socket server creating one ``OPCUAProtocol`` per incoming connection.

    Besides accepting connections it runs a background task that awaits the
    processor closing tasks registered by the protocols when they lose their
    connection.
    """

    # seconds between two periodic sweeps of ``closing_tasks``
    _CLEANUP_INTERVAL = 10

    def __init__(self, internal_server: InternalServer, hostname, port):
        """
        :param internal_server: internal server; its ``loop`` and
            ``asyncio_transports`` attributes are used by this class
        :param hostname: host name or address to listen on
        :param port: port to listen on, 0 for dynamic port assignment
        """
        self.logger = logging.getLogger(__name__)
        self.hostname = hostname
        self.port = port
        self.iserver: InternalServer = internal_server
        self._server: Optional[asyncio.AbstractServer] = None
        self._policies = []
        self.clients = []
        self.closing_tasks = []
        self.cleanup_task = None

    def set_policies(self, policies):
        """Set the policies handed to every protocol created afterwards."""
        self._policies = policies

    def _make_protocol(self):
        """Protocol Factory"""
        return OPCUAProtocol(
            iserver=self.iserver,
            policies=self._policies,
            clients=self.clients,
            closing_tasks=self.closing_tasks,
        )

    async def start(self):
        """Start listening and launch the periodic cleanup task."""
        self._server = await self.iserver.loop.create_server(self._make_protocol, self.hostname, self.port)
        # get the port and the hostname from the created server socket
        # only relevant for dynamic port assignment (when self.port == 0)
        if self.port == 0 and len(self._server.sockets) == 1:
            # will work for AF_INET and AF_INET6 socket names
            # these are the only families supported by the create_server call
            sockname = self._server.sockets[0].getsockname()
            self.hostname = sockname[0]
            self.port = sockname[1]
        self.logger.info('Listening on %s:%s', self.hostname, self.port)
        self.cleanup_task = self.iserver.loop.create_task(self._await_closing_tasks())

    async def stop(self):
        """
        Close all client transports, await the pending closing tasks a last
        time and close the listening server.

        Safe to call even when ``start`` was never called or did not complete.
        """
        self.logger.info('Closing asyncio socket server')
        for transport in self.iserver.asyncio_transports:
            transport.close()

        # stop cleanup process and run it a last time
        # (the cleanup task only exists once start() has completed)
        if self.cleanup_task is not None:
            self.cleanup_task.cancel()
            try:
                await self.cleanup_task
            except asyncio.CancelledError:
                pass
        await self._await_closing_tasks(recursive=False)

        if self._server:
            self.iserver.loop.call_soon(self._server.close)
            await self._server.wait_closed()

    async def _drain_closing_tasks(self):
        """Await every task currently registered in ``closing_tasks``."""
        while self.closing_tasks:
            task = self.closing_tasks.pop()
            try:
                await task
            except asyncio.CancelledError:
                # this means a stop request has been sent, it should not be caught
                raise
            except Exception:
                logger.exception("Unexpected crash in BinaryServer._await_closing_tasks")

    async def _await_closing_tasks(self, recursive=True):
        """
        Await the registered closing tasks.

        :param recursive: when true, keep sweeping ``closing_tasks`` forever,
            sleeping between two sweeps, until the coroutine is cancelled;
            when false, do a single sweep and return.
        """
        # Implemented as a loop: a coroutine awaiting itself after every sleep
        # nests one level deeper each time and finally raises RecursionError
        # on a long running server.
        while True:
            await self._drain_closing_tasks()
            if not recursive:
                return
            await asyncio.sleep(self._CLEANUP_INTERVAL)
|