Skip to content

Commit

Permalink
* Added reconnect functionality.
Browse files Browse the repository at this point in the history
* Added `keepLinked` and `keepSynced`.
* Added error messages for missing host, node and lane uri when opening downlinks.
* Added method for waiting until the client is closed manually.
  • Loading branch information
DobromirM committed Aug 1, 2024
1 parent 67213ed commit def5508
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 40 deletions.
4 changes: 2 additions & 2 deletions swimos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .client import SwimClient
from .client import SwimClient, IntervalStrategy

__all__ = [SwimClient]
__all__ = [SwimClient, IntervalStrategy]
3 changes: 2 additions & 1 deletion swimos/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@
# limitations under the License.

from ._swim_client import SwimClient
from ._connections import IntervalStrategy

__all__ = [SwimClient]
__all__ = [SwimClient, IntervalStrategy]
136 changes: 116 additions & 20 deletions swimos/client/_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,39 +13,102 @@
# limitations under the License.

import asyncio
from abc import ABC, abstractmethod
import websockets

from enum import Enum
from websockets import ConnectionClosed
from swimos.warp._warp import _Envelope
from typing import TYPE_CHECKING, Any
from ._utils import exception_warn

if TYPE_CHECKING:
from ._downlinks._downlinks import _DownlinkModel
from ._downlinks._downlinks import _DownlinkView


class RetryStrategy(ABC):
@abstractmethod
async def retry(self) -> bool:
"""
Wait for a period of time that is defined by the retry strategy.
"""
raise NotImplementedError

@abstractmethod
def reset(self):
"""
Reset the retry strategy to its original state.
"""
raise NotImplementedError


class IntervalStrategy(RetryStrategy):

def __init__(self, retries_limit=None, delay=3) -> None:
super().__init__()
self.retries_limit = retries_limit
self.delay = delay
self.retries = 0

async def retry(self) -> bool:
if self.retries_limit is None or self.retries_limit >= self.retries:
await asyncio.sleep(self.delay)
self.retries += 1
return True
else:
return False

def reset(self):
self.retries = 0


class ExponentialStrategy(RetryStrategy):

def __init__(self, retries_limit=None, max_interval=16) -> None:
super().__init__()
self.retries_limit = retries_limit
self.max_interval = max_interval
self.retries = 0

async def retry(self) -> bool:
if self.retries_limit is None or self.retries_limit >= self.retries:
await asyncio.sleep(min(2 ** self.retries, self.max_interval))
self.retries += 1
return True
else:
return False

def reset(self):
self.retries = 0


class _ConnectionPool:

def __init__(self) -> None:
def __init__(self, retry_strategy: RetryStrategy = None) -> None:
self.__connections = dict()
self.retry_strategy = retry_strategy

@property
def _size(self) -> int:
return len(self.__connections)

async def _get_connection(self, host_uri: str, scheme: str) -> '_WSConnection':
async def _get_connection(self, host_uri: str, scheme: str, keep_linked: bool,
keep_synced: bool) -> '_WSConnection':
"""
Return a WebSocket connection to the given Host URI. If it is a new
host or the existing connection is closing, create a new connection.
:param host_uri: - URI of the connection host.
:param scheme: - URI scheme.
:param keep_linked: - Whether the link should be automatically re-established after connection failures.
:param keep_synced: - Whether the link should synchronize its state with the remote lane.
:return: - WebSocket connection.
"""
connection = self.__connections.get(host_uri)

if connection is None or connection.status == _ConnectionStatus.CLOSED:
connection = _WSConnection(host_uri, scheme)
connection = _WSConnection(host_uri, scheme, keep_linked, keep_synced, self.retry_strategy)
self.__connections[host_uri] = connection

return connection
Expand All @@ -70,7 +133,9 @@ async def _add_downlink_view(self, downlink_view: '_DownlinkView') -> None:
"""
host_uri = downlink_view._host_uri
scheme = downlink_view._scheme
connection = await self._get_connection(host_uri, scheme)
keep_linked = downlink_view._keep_linked
keep_synced = downlink_view._keep_synced
connection = await self._get_connection(host_uri, scheme, keep_linked, keep_synced)
downlink_view._connection = connection

await connection._subscribe(downlink_view)
Expand All @@ -95,29 +160,44 @@ async def _remove_downlink_view(self, downlink_view: '_DownlinkView') -> None:

class _WSConnection:

def __init__(self, host_uri: str, scheme: str) -> None:
def __init__(self, host_uri: str, scheme: str, keep_linked, keep_synced,
retry_strategy: RetryStrategy = None) -> None:
self.host_uri = host_uri
self.scheme = scheme
self.retry_strategy = retry_strategy

self.connected = asyncio.Event()
self.websocket = None
self.status = _ConnectionStatus.CLOSED
self.init_message = None

self.keep_linked = keep_linked
self.keep_synced = keep_synced

self.__subscribers = _DownlinkManagerPool()

async def _open(self) -> None:
if self.status == _ConnectionStatus.CLOSED:
self.status = _ConnectionStatus.CONNECTING

try:
if self.scheme == "wss":
self.websocket = await websockets.connect(self.host_uri, ssl=True)
else:
self.websocket = await websockets.connect(self.host_uri)
except Exception as error:
self.status = _ConnectionStatus.CLOSED
raise error

self.status = _ConnectionStatus.IDLE
while self.status == _ConnectionStatus.CONNECTING:
try:
if self.scheme == "wss":
self.websocket = await websockets.connect(self.host_uri, ssl=True)
self.retry_strategy.reset()
self.status = _ConnectionStatus.IDLE
else:
self.websocket = await websockets.connect(self.host_uri)
self.retry_strategy.reset()
self.status = _ConnectionStatus.IDLE
except Exception as error:
if self.keep_linked and await self.retry_strategy.retry():
exception_warn(error)
continue
else:
self.status = _ConnectionStatus.CLOSED
raise error

self.connected.set()

async def _close(self) -> None:
Expand All @@ -129,6 +209,20 @@ async def _close(self) -> None:
await self.websocket.close()
self.connected.clear()

def _set_init_message(self, message: str) -> None:
"""
Set the initial message that gets sent when the underlying downlink is established.
"""

self.init_message = message

async def _send_init_message(self) -> None:
"""
Send the initial message for the underlying downlink if it is set.
"""
if self.init_message is not None:
await self._send_message(self.init_message)

def _has_subscribers(self) -> bool:
"""
Check if the connection has any subscribers.
Expand Down Expand Up @@ -181,18 +275,20 @@ async def _wait_for_messages(self) -> None:
Wait for messages from the remote agent and propagate them
to all subscribers.
"""

if self.status == _ConnectionStatus.IDLE:
while self.status == _ConnectionStatus.IDLE:
self.status = _ConnectionStatus.RUNNING
try:
while self.status == _ConnectionStatus.RUNNING:
message = await self.websocket.recv()
response = _Envelope._parse_recon(message)
await self.__subscribers._receive_message(response)
# except:
# pass
finally:
except ConnectionClosed as error:
exception_warn(error)
await self._close()
if self.keep_linked and await self.retry_strategy.retry():
await self._open()
await self._send_init_message()
continue


class _ConnectionStatus(Enum):
Expand Down
51 changes: 44 additions & 7 deletions swimos/client/_downlinks/_downlinks.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def __init__(self, client: 'SwimClient') -> None:
self.host_uri = None
self.node_uri = None
self.lane_uri = None
self.keep_linked = None
self.keep_synced = None

self.task = None
self.connection = None
self.linked = asyncio.Event()
Expand Down Expand Up @@ -143,6 +146,9 @@ def __init__(self, client: 'SwimClient') -> None:
self._will_unlink_callback = None
self._did_unlink_callback = None

self._keep_linked = True
self._keep_synced = True

self.__registered_classes = dict()
self.__deregistered_classes = set()
self.__clear_classes = False
Expand Down Expand Up @@ -181,6 +187,13 @@ def registered_classes(self) -> dict:
return self._downlink_manager.registered_classes

def open(self) -> '_DownlinkView':
if self._host_uri is None:
raise Exception(f'Downlink cannot be opened without first setting the host URI!')
if self._node_uri is None:
raise Exception(f'Downlink cannot be opened without first setting the node URI!')
if self._lane_uri is None:
raise Exception(f'Downlink cannot be opened without first setting the lane URI!')

if not self._is_open:
task = self._client._schedule_task(self._client._add_downlink_view, self)
if task is not None:
Expand Down Expand Up @@ -210,6 +223,16 @@ def set_lane_uri(self, lane_uri: str) -> '_DownlinkView':
self._lane_uri = lane_uri
return self

@before_open
def keep_linked(self, keep_linked: bool) -> '_DownlinkView':
self._keep_linked = keep_linked
return self

@before_open
def keep_synced(self, keep_synced: bool) -> '_DownlinkView':
self._keep_synced = keep_synced
return self

def did_open(self, function: Callable) -> '_DownlinkView':
"""
Set the `did_open` callback of the current downlink view to a given function.
Expand Down Expand Up @@ -426,6 +449,8 @@ async def _initalise_model(self, manager: '_DownlinkManager', model: '_DownlinkM
model.host_uri = self._host_uri
model.node_uri = self._node_uri
model.lane_uri = self._lane_uri
model.keep_linked = self.keep_linked
model.keep_synced = self.keep_synced

async def _assign_manager(self, manager: '_DownlinkManager') -> None:
"""
Expand Down Expand Up @@ -463,8 +488,10 @@ def __register_class(self, custom_class: Any) -> None:
class _EventDownlinkModel(_DownlinkModel):

async def _establish_downlink(self) -> None:
link_request = _LinkRequest(self.node_uri, self.lane_uri)
await self.connection._send_message(link_request._to_recon())
request = _LinkRequest(self.node_uri, self.lane_uri)

self.connection._set_init_message(request._to_recon())
await self.connection._send_init_message()

async def _receive_event(self, message: _Envelope) -> None:
converter = RecordConverter.get_converter()
Expand Down Expand Up @@ -520,8 +547,13 @@ def __init__(self, client: 'SwimClient') -> None:
self._synced = asyncio.Event()

async def _establish_downlink(self) -> None:
sync_request = _SyncRequest(self.node_uri, self.lane_uri)
await self.connection._send_message(sync_request._to_recon())
if self.keep_synced:
request = _SyncRequest(self.node_uri, self.lane_uri)
else:
request = _LinkRequest(self.node_uri, self.lane_uri)

self.connection._set_init_message(request._to_recon())
await self.connection._send_init_message()

async def _receive_event(self, message: '_Envelope') -> None:
await self.__set_value(message)
Expand Down Expand Up @@ -550,7 +582,7 @@ async def _get_value(self) -> Any:

async def __set_value(self, message: '_Envelope') -> None:
"""
Set the value of the the downlink and trigger the `did_set` callback of the downlink subscribers.
Set the value of the downlink and trigger the `did_set` callback of the downlink subscribers.
:param message: - The message from the remote agent.
:return:
Expand Down Expand Up @@ -702,8 +734,13 @@ def __init__(self, client: 'SwimClient') -> None:
self._synced = asyncio.Event()

async def _establish_downlink(self) -> None:
sync_request = _SyncRequest(self.node_uri, self.lane_uri)
await self.connection._send_message(sync_request._to_recon())
if self.keep_synced:
request = _SyncRequest(self.node_uri, self.lane_uri)
else:
request = _LinkRequest(self.node_uri, self.lane_uri)

self.connection._set_init_message(request._to_recon())
await self.connection._send_init_message()

async def _receive_event(self, message: '_Envelope') -> None:
if message._body._tag == 'update':
Expand Down
3 changes: 2 additions & 1 deletion swimos/client/_downlinks/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def wrapper(*args, **kwargs):
return function(*args, **kwargs)
else:
try:
raise Exception(f'Cannot execute "{function.__name__}" before the downlink has been opened!')
raise Exception(
f'Cannot execute "{function.__name__}" before the downlink has been opened or after it has closed!')
except Exception:
exc_type, exc_value, exc_traceback = sys.exc_info()
args[0]._client._handle_exception(exc_value, exc_traceback)
Expand Down
Loading

0 comments on commit def5508

Please sign in to comment.