diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index bc44d0c..fed0c57 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -1,5 +1,5 @@ -# This workflow will install Python dependencies and run tests with PyTest using Python 3.8 -# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions +# This workflow will install Python dependencies and run tests with PyTest using Python 3.11 +# For more information see: https://docs.github.com/en/actions/about-github-actions name: Run tests @@ -16,23 +16,29 @@ jobs: steps: # Checkout repository - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 + with: + submodules: recursive # Set Python version - - name: Set up Python 3.8 - uses: actions/setup-python@v4 + - name: Set up Python 3.11 + uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: 3.11 + + # Upgrade pip + - name: Upgrade pip + run: | + python -m pip install --upgrade pip # Set up submodules - name: Set up all submodules and project dependencies run: | - git submodule update --init --recursive --remote + git submodule foreach --recursive "pip install -r requirements.txt" # Install project dependencies - name: Install project dependencies run: | - python -m pip install --upgrade pip pip install -r requirements.txt # Install zbar library to resolve pyzbar import error @@ -46,6 +52,6 @@ jobs: flake8 . pylint . - # Install dependencies and run tests with PyTest - - name: Run PyTest + # Run unit tests with PyTest + - name: Run unit tests run: pytest -vv diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..a7b24fe --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "mavlink/dronekit"] + path = mavlink/dronekit + url = https://github.com/UWARG/dronekit.git + branch = WARG-minimal diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/camera/test_camera.py b/camera/test_camera.py index b651d0c..f4de6fd 100644 --- a/camera/test_camera.py +++ b/camera/test_camera.py @@ -6,7 +6,7 @@ import cv2 -from camera.modules.camera_device import CameraDevice +from .modules.camera_device import CameraDevice IMAGE_LOG_PREFIX = pathlib.Path("logs", "log_image") @@ -16,7 +16,9 @@ def main() -> int: """ Main function. """ - device = CameraDevice(0, 100, IMAGE_LOG_PREFIX) + device = CameraDevice(0, 100, str(IMAGE_LOG_PREFIX)) + + IMAGE_LOG_PREFIX.parent.mkdir(parents=True, exist_ok=True) while True: result, image = device.get_image() diff --git a/camera_qr_example.py b/camera_qr_example.py index dbfff71..f035c3e 100644 --- a/camera_qr_example.py +++ b/camera_qr_example.py @@ -4,8 +4,8 @@ import cv2 -from camera.modules import camera_device -from qr.modules import qr_scanner +from .camera.modules import camera_device +from .qr.modules import qr_scanner def main() -> int: diff --git a/image_encoding/__init__.py b/image_encoding/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_encoding/modules/__init__.py b/image_encoding/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_encoding/modules/decoder.py b/image_encoding/modules/decoder.py index a617be2..6aaaf21 100644 --- a/image_encoding/modules/decoder.py +++ b/image_encoding/modules/decoder.py @@ -2,15 +2,13 @@ Decodes images from JPEG bytes to numpy array. """ -# Used in type annotation of flight interface output -# pylint: disable-next=unused-import import io from PIL import Image import numpy as np -def decode(data: "io.BytesIO | bytes") -> np.ndarray: +def decode(data: "bytes") -> np.ndarray: """ Decodes a JPEG encoded image and returns it as a numpy array. @@ -19,6 +17,6 @@ def decode(data: "io.BytesIO | bytes") -> np.ndarray: Returns: NDArray with in RGB format. Shape is (Height, Width, 3) """ - image = Image.open(data, formats=["JPEG"]) + image = Image.open(io.BytesIO(data), formats=["JPEG"]) return np.asarray(image) diff --git a/image_encoding/modules/encoder.py b/image_encoding/modules/encoder.py index d40f392..1a87216 100644 --- a/image_encoding/modules/encoder.py +++ b/image_encoding/modules/encoder.py @@ -11,7 +11,7 @@ QUALITY = 80 # Quality of JPEG encoding to use (0-100) -def encode(image_array: np.ndarray) -> "io.BytesIO | bytes": +def encode(image_array: np.ndarray) -> "bytes": """ Encodes an image in numpy array form into bytes of a JPEG. @@ -25,4 +25,4 @@ def encode(image_array: np.ndarray) -> "io.BytesIO | bytes": buffer = io.BytesIO() img.save(buffer, format="JPEG", quality=QUALITY) - return buffer + return buffer.getvalue() diff --git a/image_encoding/test_image_encode_decode.py b/image_encoding/test_image_encode_decode.py index d2c374f..f515f14 100644 --- a/image_encoding/test_image_encode_decode.py +++ b/image_encoding/test_image_encode_decode.py @@ -7,8 +7,8 @@ from PIL import Image import numpy as np -from image_encoding.modules import decoder -from image_encoding.modules import encoder +from .modules import decoder +from .modules import encoder ROOT_DIR = "image_encoding" @@ -16,9 +16,10 @@ RESULT_IMG = "result.jpg" -def main() -> int: +def test_image_encode_decode() -> int: """ Main testing sequence of encoding and decoding an image. + Note that JPEG is a lossy compression algorithm, so data cannot be recovered. """ # Get test image in numpy form im = Image.open(pathlib.Path(ROOT_DIR, TEST_IMG)) @@ -36,17 +37,3 @@ def main() -> int: # Check output shape assert img_array.shape == raw_data.shape - - # Note: the following fail since JPEG encoding is lossy - # assert (raw_data == img_array).all() - - return 0 - - -if __name__ == "__main__": - result_main = main() - - if result_main < 0: - print(f"ERROR: Status code: {result_main}") - - print("Done!") diff --git a/kml/test_ground_locations_to_kml.py b/kml/test_ground_locations_to_kml.py index f69f075..9c3cd09 100644 --- a/kml/test_ground_locations_to_kml.py +++ b/kml/test_ground_locations_to_kml.py @@ -6,8 +6,8 @@ import pytest -from kml.modules import ground_locations_to_kml -from kml.modules import location_ground +from .modules import ground_locations_to_kml +from .modules import location_ground PARENT_DIRECTORY = "kml" diff --git a/logger/__init__.py b/logger/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/logger/modules/__init__.py b/logger/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/logger/modules/config_logger.yaml b/logger/modules/config_logger.yaml new file mode 100644 index 0000000..f226ac3 --- /dev/null +++ b/logger/modules/config_logger.yaml @@ -0,0 +1,7 @@ +# Displaying datetime: https://docs.python.org/3/howto/logging.html#displaying-the-date-time-in-messages +# Changing the format of log messages https://docs.python.org/3/howto/logging.html#changing-the-format-of-displayed-messages +logger: + directory_path: "logs" + file_datetime_format: "%Y-%m-%d_%H-%M-%S" + format: "%(asctime)s: [%(levelname)s] %(message)s" + log_datetime_format: "%I:%M:%S" diff --git a/logger/modules/logger.py b/logger/modules/logger.py new file mode 100644 index 0000000..fe07c56 --- /dev/null +++ b/logger/modules/logger.py @@ -0,0 +1,174 @@ +""" +Logs debug messages. +""" + +import datetime +import inspect +import logging +import os +import pathlib +import sys + +# Used in type annotation of logger parameters +# pylint: disable-next=unused-import +import types + +from ..read_yaml.modules import read_yaml + + +CONFIG_FILE_PATH = pathlib.Path(os.path.dirname(__file__), "config_logger.yaml") + + +class Logger: + """ + Instantiates Logger objects. + """ + + __create_key = object() + + @classmethod + def create(cls, name: str, enable_log_to_file: bool) -> "tuple[bool, Logger | None]": + """ + Create and configure a logger. + """ + # Configuration settings + result, config = read_yaml.open_config(CONFIG_FILE_PATH) + if not result: + print("ERROR: Failed to load configuration file") + return False, None + + # Get Pylance to stop complaining + assert config is not None + + try: + log_directory_path = config["logger"]["directory_path"] + file_datetime_format = config["logger"]["file_datetime_format"] + logger_format = config["logger"]["format"] + logger_datetime_format = config["logger"]["log_datetime_format"] + except KeyError as exception: + print(f"Config key(s) not found: {exception}") + return False, None + + # Create a unique logger instance + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + + formatter = logging.Formatter( + fmt=logger_format, + datefmt=logger_datetime_format, + ) + + # Handles logging to terminal + stream_handler = logging.StreamHandler(sys.stdout) + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) + + # Handles logging to file + if enable_log_to_file: + # Get the path to the logs directory. + entries = os.listdir(log_directory_path) + + if len(entries) == 0: + print("ERROR: The directory for this log session was not found.") + return False, None + + log_names = [ + entry for entry in entries if os.path.isdir(os.path.join(log_directory_path, entry)) + ] + + # Find the log directory for the current run, which is the most recent timestamp. + log_path = max( + log_names, + key=lambda datetime_string: datetime.datetime.strptime( + datetime_string, file_datetime_format + ), + ) + + filepath = pathlib.Path(log_directory_path, log_path, f"{name}.log") + try: + file = os.open(filepath, os.O_RDWR | os.O_EXCL | os.O_CREAT) + os.close(file) + except OSError: + print("ERROR: Log file already exists.") + return False, None + + file_handler = logging.FileHandler(filename=filepath, mode="w") + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return True, Logger(cls.__create_key, logger) + + def __init__(self, class_create_private_key: object, logger: logging.Logger) -> None: + """ + Private constructor, use create() method. + """ + assert class_create_private_key is Logger.__create_key, "Use create() method." + + self.logger = logger + + @staticmethod + def message_and_metadata(message: str, frame: "types.FrameType | None") -> str: + """ + Extracts metadata from frame and appends it to the message. + """ + if frame is None: + return message + + # Get Pylance to stop complaining + assert frame is not None + + function_name = frame.f_code.co_name + filename = frame.f_code.co_filename + line_number = inspect.getframeinfo(frame).lineno + + return f"[{filename} | {function_name} | {line_number}] {message}" + + def debug(self, message: str, log_with_frame_info: bool = True) -> None: + """ + Logs a debug level message. + """ + if log_with_frame_info: + logger_frame = inspect.currentframe() + caller_frame = logger_frame.f_back + message = self.message_and_metadata(message, caller_frame) + self.logger.debug(message) + + def info(self, message: str, log_with_frame_info: bool = True) -> None: + """ + Logs an info level message. + """ + if log_with_frame_info: + logger_frame = inspect.currentframe() + caller_frame = logger_frame.f_back + message = self.message_and_metadata(message, caller_frame) + self.logger.info(message) + + def warning(self, message: str, log_with_frame_info: bool = True) -> None: + """ + Logs a warning level message. + """ + if log_with_frame_info: + logger_frame = inspect.currentframe() + caller_frame = logger_frame.f_back + message = self.message_and_metadata(message, caller_frame) + self.logger.warning(message) + + def error(self, message: str, log_with_frame_info: bool = True) -> None: + """ + Logs an error level message. + """ + if log_with_frame_info: + logger_frame = inspect.currentframe() + caller_frame = logger_frame.f_back + message = self.message_and_metadata(message, caller_frame) + self.logger.error(message) + + def critical(self, message: str, log_with_frame_info: bool = True) -> None: + """ + Logs a critical level message. + """ + if log_with_frame_info: + logger_frame = inspect.currentframe() + caller_frame = logger_frame.f_back + message = self.message_and_metadata(message, caller_frame) + self.logger.critical(message) diff --git a/logger/modules/logger_setup_main.py b/logger/modules/logger_setup_main.py new file mode 100644 index 0000000..3a08932 --- /dev/null +++ b/logger/modules/logger_setup_main.py @@ -0,0 +1,49 @@ +""" +Logger setup for `main()` . +""" + +import datetime +import pathlib + +from . import logger + + +MAIN_LOGGER_NAME = "main" + + +def setup_main_logger( + config: "dict", main_logger_name: str = MAIN_LOGGER_NAME, enable_log_to_file: bool = True +) -> "tuple[bool, logger.Logger | None, pathlib.Path | None]": + """ + Setup prerequisites for logging in `main()` . + + config: The configuration. + + Returns: Success, logger, logger path. + """ + # Get settings + try: + log_directory_path = config["logger"]["directory_path"] + log_path_format = config["logger"]["file_datetime_format"] + start_time = datetime.datetime.now().strftime(log_path_format) + except KeyError as exception: + print(f"ERROR: Config key(s) not found: {exception}") + return False, None, None + + logging_path = pathlib.Path(log_directory_path, start_time) + + # Create logging directory + logging_path.mkdir(exist_ok=True, parents=True) + + # Setup logger + result, main_logger = logger.Logger.create(main_logger_name, enable_log_to_file) + if not result: + print("ERROR: Failed to create main logger") + return False, None, None + + # Get Pylance to stop complaining + assert main_logger is not None + + main_logger.info(f"{main_logger_name} logger initialized", True) + + return True, main_logger, logging_path diff --git a/logger/read_yaml/__init__.py b/logger/read_yaml/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/logger/read_yaml/config_test_files/config_no_error.yaml b/logger/read_yaml/config_test_files/config_no_error.yaml new file mode 100644 index 0000000..dc345ad --- /dev/null +++ b/logger/read_yaml/config_test_files/config_no_error.yaml @@ -0,0 +1 @@ +"config": "no_error" diff --git a/logger/read_yaml/modules/__init__.py b/logger/read_yaml/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/logger/read_yaml/modules/read_yaml.py b/logger/read_yaml/modules/read_yaml.py new file mode 100644 index 0000000..a4361a8 --- /dev/null +++ b/logger/read_yaml/modules/read_yaml.py @@ -0,0 +1,26 @@ +""" +For YAML files. +""" + +import pathlib + +import yaml + + +def open_config(file_path: pathlib.Path) -> "tuple[bool, dict | None]": + """ + Open and decode YAML file. + """ + try: + with file_path.open("r", encoding="utf8") as file: + try: + config = yaml.safe_load(file) + return True, config + except yaml.YAMLError as exception: + print(f"ERROR: Could not parse YAML file: {exception}") + except FileNotFoundError as exception: + print(f"ERROR: YAML file not found: {exception}") + except IOError as exception: + print(f"ERROR: Could not open file: {exception}") + + return False, None diff --git a/logger/read_yaml/test_read_yaml.py b/logger/read_yaml/test_read_yaml.py new file mode 100644 index 0000000..eb09af4 --- /dev/null +++ b/logger/read_yaml/test_read_yaml.py @@ -0,0 +1,43 @@ +""" +Test if read_yaml function correctly reads yaml files +""" + +import os +import pathlib + +from .modules import read_yaml + + +CURRENT_DIRECTORY = os.path.dirname(__file__) + + +class TestOpenConfig: + """ + Test the open_config function + """ + + def test_open_config(self) -> None: + """ + Test if the function correctly reads the yaml file + """ + expected = {"config": "no_error"} + + result, actual = read_yaml.open_config( + pathlib.Path(CURRENT_DIRECTORY, "config_test_files/config_no_error.yaml") + ) + + assert result + assert actual == expected + + def test_open_config_file_not_found(self) -> None: + """ + Test if the function handles file not found + """ + expected = None + + result, actual = read_yaml.open_config( + pathlib.Path(CURRENT_DIRECTORY, "config_test_files/config_nonexistant_file.yaml") + ) + + assert not result + assert actual == expected diff --git a/logger/test_logger.py b/logger/test_logger.py new file mode 100644 index 0000000..c5745e4 --- /dev/null +++ b/logger/test_logger.py @@ -0,0 +1,220 @@ +""" +Logger unit tests. +""" + +import inspect +import pathlib +import re + +import pytest + +from .modules import logger, logger_setup_main +from .read_yaml.modules import read_yaml + + +@pytest.fixture +def main_logger_instance_and_log_file_path() -> logger.Logger: # type: ignore + """ + Returns the main logger with logging to file enabled and sets up logging directory. + """ + result, config = read_yaml.open_config(logger.CONFIG_FILE_PATH) + assert result + + result, instance, log_file_path = logger_setup_main.setup_main_logger(config=config) + assert result + yield instance, log_file_path + + +@pytest.fixture +def logger_instance_to_file_enabled() -> logger.Logger: # type: ignore + """ + Returns a logger with logging to file enabled. + """ + result, instance = logger.Logger.create("test_logger_to_file_enabled", True) + assert result + yield instance + + +@pytest.fixture +def logger_instance_to_file_disabled() -> logger.Logger: # type: ignore + """ + Returns a logger with logging to file disabled. + """ + result, instance = logger.Logger.create("test_logger_to_file_disabled", False) + assert result + yield instance + + +class TestMessageAndMetadata: + """ + Test if message_and_metadata function correctly extracts information from the frame. + """ + + def test_message_and_metadata_with_frame(self) -> None: + """ + Test by passing in a frame + """ + frame = inspect.currentframe() + message = "Test message" + actual = logger.Logger.message_and_metadata(message, frame) + + expected = f"[{__file__} | test_message_and_metadata_with_frame | 59] Test message" + + assert actual == expected + + def test_message_and_metadata_without_frame(self) -> None: + """ + Test with frame is None + """ + frame = None + message = "Test message" + actual = logger.Logger.message_and_metadata(message, frame) + + expected = "Test message" + + assert actual == expected + + +# Fixtures are used to setup and teardown resources for tests +# pylint: disable=redefined-outer-name +class TestLogger: + """ + Test if logger logs the correct messages to file and stdout + """ + + def test_log_with_frame_info( + self, caplog: pytest.LogCaptureFixture, logger_instance_to_file_disabled: logger.Logger + ) -> None: + """ + Test if frame information is logged + """ + test_message = "test message" + + logger_instance_to_file_disabled.debug(test_message, True) + actual = caplog.text + + expected_pattern = re.compile( + r"DEBUG.*\[" + + re.escape(__file__) + + r" | test_log_with_frame_info | 93\]" + + re.escape(test_message) + ) + + assert re.search(expected_pattern, actual) is not None + + def test_log_to_file( + self, + main_logger_instance_and_log_file_path: "tuple[logger.Logger | None, pathlib.Path | None]", + logger_instance_to_file_enabled: logger.Logger, + ) -> None: + """ + Test if messages are logged to file + All levels are done in one test since they will all be logged to the same file + """ + main_logger_instance, log_file_path = main_logger_instance_and_log_file_path + + main_message = "main message" + main_logger_instance.debug(main_message, False) + main_logger_instance.info(main_message, False) + main_logger_instance.warning(main_message, False) + main_logger_instance.error(main_message, False) + main_logger_instance.critical(main_message, False) + + test_message = "test message" + logger_instance_to_file_enabled.debug(test_message, False) + logger_instance_to_file_enabled.info(test_message, False) + logger_instance_to_file_enabled.warning(test_message, False) + logger_instance_to_file_enabled.error(test_message, False) + logger_instance_to_file_enabled.critical(test_message, False) + + main_log_file_path = pathlib.Path(log_file_path, "main.log") + test_log_file_path = pathlib.Path(log_file_path, "test_logger_to_file_enabled.log") + + with open(main_log_file_path, "r", encoding="utf8") as log_file: + actual_main = log_file.read() + + with open(test_log_file_path, "r", encoding="utf8") as log_file: + actual_test = log_file.read() + + for level in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]: + expected = f"[{level}] {main_message}\n" + assert expected in actual_main # don't know timestamps, so check existance of message + + for level in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]: + expected = f"[{level}] {test_message}\n" + assert expected in actual_test # don't know timestamps, so check existance of message + + def test_debug_log_debug_to_stdout( + self, caplog: pytest.LogCaptureFixture, logger_instance_to_file_disabled: logger.Logger + ) -> None: + """ + Test if debug level message is logged to stdout + """ + test_message = "test message" + + logger_instance_to_file_disabled.debug(test_message, False) + actual = caplog.text + + expected_pattern = re.compile(r"DEBUG.*" + re.escape(test_message)) + + assert re.search(expected_pattern, actual) is not None + + def test_debug_log_info_to_stdout( + self, caplog: pytest.LogCaptureFixture, logger_instance_to_file_disabled: logger.Logger + ) -> None: + """ + Test if info level message is logged to stdout + """ + test_message = "test message" + + logger_instance_to_file_disabled.info(test_message, False) + actual = caplog.text + + expected_pattern = re.compile(r"INFO.*" + re.escape(test_message)) + + assert re.search(expected_pattern, actual) is not None + + def test_debug_log_warning_to_stdout( + self, caplog: pytest.LogCaptureFixture, logger_instance_to_file_disabled: logger.Logger + ) -> None: + """ + Test if warning level message is logged to stdout + """ + test_message = "test message" + + logger_instance_to_file_disabled.warning(test_message, False) + actual = caplog.text + + expected_pattern = re.compile(r"WARNING.*" + re.escape(test_message)) + + assert re.search(expected_pattern, actual) is not None + + def test_debug_log_error_to_stdout( + self, caplog: pytest.LogCaptureFixture, logger_instance_to_file_disabled: logger.Logger + ) -> None: + """ + Test if error level message is logged to stdout + """ + test_message = "test message" + + logger_instance_to_file_disabled.error(test_message, False) + actual = caplog.text + + expected_pattern = re.compile(r"ERROR.*" + re.escape(test_message)) + + assert re.search(expected_pattern, actual) is not None + + def test_debug_log_critical_to_stdout( + self, caplog: pytest.LogCaptureFixture, logger_instance_to_file_disabled: logger.Logger + ) -> None: + """ + Test if critical level message is logged to stdout + """ + test_message = "test message" + + logger_instance_to_file_disabled.critical(test_message, False) + actual = caplog.text + + expected_pattern = re.compile(r"CRITICAL.*" + re.escape(test_message)) + + assert re.search(expected_pattern, actual) is not None diff --git a/mavlink/dronekit b/mavlink/dronekit new file mode 160000 index 0000000..3192482 --- /dev/null +++ b/mavlink/dronekit @@ -0,0 +1 @@ +Subproject commit 319248273f8a9c55785e26addaa8d7f513f14aeb diff --git a/mavlink/modules/drone_odometry.py b/mavlink/modules/drone_odometry.py index d602b8a..7c56c41 100644 --- a/mavlink/modules/drone_odometry.py +++ b/mavlink/modules/drone_odometry.py @@ -2,6 +2,7 @@ Position and orientation of drone. """ +import enum import math @@ -109,16 +110,26 @@ def __str__(self) -> str: return f"{self.__class__}, yaw: {self.yaw}, pitch: {self.pitch}, roll: {self.roll}" +class FlightMode(enum.Enum): + """ + Possible drone flight modes. + """ + + STOPPED = 0 + MOVING = 1 + MANUAL = 2 + + class DroneOdometry: """ - Wrapper for DronePosition and DroneOrientation. + Wrapper for DronePosition, DroneOrientation, and FlightMode. """ __create_key = object() @classmethod def create( - cls, position: DronePosition, orientation: DroneOrientation + cls, position: DronePosition, orientation: DroneOrientation, flight_mode: FlightMode ) -> "tuple[bool, DroneOdometry | None]": """ Position and orientation in one class. @@ -129,13 +140,17 @@ def create( if orientation is None: return False, None - return True, DroneOdometry(cls.__create_key, position, orientation) + if flight_mode is None: + return False, None + + return True, DroneOdometry(cls.__create_key, position, orientation, flight_mode) def __init__( self, class_private_create_key: object, position: DronePosition, orientation: DroneOrientation, + flight_mode: FlightMode, ) -> None: """ Private constructor, use create() method. @@ -144,9 +159,10 @@ def __init__( self.position = position self.orientation = orientation + self.flight_mode = flight_mode def __str__(self) -> str: """ To string. """ - return f"{self.__class__}, position: {self.position}, orientation: {self.orientation}" + return f"{self.__class__}, position: {self.position}, orientation: {self.orientation}, flight mode: {self.flight_mode}" diff --git a/mavlink/modules/flight_controller.py b/mavlink/modules/flight_controller.py index 97b92fa..ca3716e 100644 --- a/mavlink/modules/flight_controller.py +++ b/mavlink/modules/flight_controller.py @@ -4,8 +4,9 @@ import time -import dronekit +from pymavlink import mavutil +from .. import dronekit from . import drone_odometry @@ -16,8 +17,9 @@ class FlightController: __create_key = object() - __MAVLINK_LANDING_FRAME = dronekit.mavutil.mavlink.MAV_FRAME_GLOBAL - __MAVLINK_LANDING_COMMAND = dronekit.mavutil.mavlink.MAV_CMD_NAV_LAND + __MAVLINK_LANDING_FRAME = mavutil.mavlink.MAV_FRAME_GLOBAL + __MAVLINK_LANDING_COMMAND = mavutil.mavlink.MAV_CMD_NAV_LAND + __MAVLINK_WAYPOINT_COMMAND = mavutil.mavlink.MAV_CMD_NAV_WAYPOINT @classmethod def create(cls, address: str, baud: int = 57600) -> "tuple[bool, FlightController | None]": @@ -69,13 +71,17 @@ def get_odometry(self) -> "tuple[bool, drone_odometry.DroneOdometry | None]": if not result: return False, None + result, flight_mode = self.get_flight_mode() + if not result: + return False, None + # Get Pylance to stop complaining assert position_data is not None assert orientation_data is not None + assert flight_mode is not None result, odometry_data = drone_odometry.DroneOdometry.create( - position_data, - orientation_data, + position_data, orientation_data, flight_mode ) if not result: return False, None @@ -249,3 +255,94 @@ def get_current_position(self) -> drone_odometry.DronePosition: if not result or position is None: raise RuntimeError("Failed to get current position") return position + + def get_flight_mode(self) -> "tuple[bool, drone_odometry.FlightMode | None]": + """ + Gets the current flight mode of the drone. + """ + flight_mode = self.drone.mode.name + + if flight_mode is None: + return False, None + if flight_mode == "LOITER": + return True, drone_odometry.FlightMode.STOPPED + if flight_mode == "AUTO": + return True, drone_odometry.FlightMode.MOVING + return True, drone_odometry.FlightMode.MANUAL + + def download_commands(self) -> "tuple[bool, list[dronekit.Command]]": + """ + Downloads the current list of commands from the drone. + + Returns + ------- + tuple[bool, list[dronekit.Command]] + A tuple where the first element is a boolean indicating success or failure, + and the second element is the list of commands currently held by the drone. + """ + try: + command_sequence = self.drone.commands + command_sequence.download() + command_sequence.wait_ready() + commands = list(command_sequence) + return True, commands + except dronekit.TimeoutError: + print("ERROR: Download timeout, commands are not being received.") + return False, [] + except ConnectionResetError: + print("ERROR: Connection with drone reset. Unable to download commands.") + return False, [] + + def get_next_waypoint(self) -> "tuple[bool, drone_odometry.DronePosition | None]": + """ + Gets the next waypoint. + + Returns + ------- + tuple[bool, drone_odometry.DronePosition | None] + A tuple where the first element is a boolean indicating success or failure, + and the second element is the next waypoint currently held by the drone. + """ + result, commands = self.download_commands() + if not result: + return False, None + + next_command_index = self.drone.commands.next + if next_command_index >= len(commands): + return False, None + + for command in commands[next_command_index:]: + if command.command == self.__MAVLINK_WAYPOINT_COMMAND: + return drone_odometry.DronePosition.create(command.x, command.y, command.z) + return False, None + + def insert_waypoint( + self, index: int, latitude: float, longitude: float, altitude: float + ) -> bool: + """ + Insert a waypoint into the current list of commands at a certain index and reupload the list to the drone. + """ + result, commands = self.download_commands() + if not result: + return False + + new_waypoint = dronekit.Command( + 0, + 0, + 0, + self.__MAVLINK_LANDING_FRAME, + self.__MAVLINK_WAYPOINT_COMMAND, + 0, + 0, + 0, # param1 + 0, + 0, + 0, + latitude, + longitude, + altitude, + ) + + commands.insert(index, new_waypoint) + + return self.upload_commands(commands) diff --git a/mavlink/test_flight_controller.py b/mavlink/test_flight_controller.py index bf2aee9..9b7c88e 100644 --- a/mavlink/test_flight_controller.py +++ b/mavlink/test_flight_controller.py @@ -4,7 +4,7 @@ import time -from mavlink.modules import flight_controller +from .modules import flight_controller DELAY_TIME = 0.5 # seconds @@ -47,6 +47,23 @@ def main() -> int: time.sleep(DELAY_TIME) + # Download and print commands + success, commands = controller.download_commands() + if success: + print("Downloaded commands:") + for command in commands: + print(command) + else: + print("Failed to download commands.") + + result, next_waypoint = controller.get_next_waypoint() + if result: + print("next waypoint lat: " + str(next_waypoint.latitude)) + print("next waypoint lon: " + str(next_waypoint.longitude)) + print("next waypoint alt: " + str(next_waypoint.altitude)) + else: + print("Failed to get next waypoint.") + result, home = controller.get_home_location(TIMEOUT) if not result: print("Failed to get home location") diff --git a/mavlink/test_mission_ended.py b/mavlink/test_mission_ended.py index 9589553..b06986c 100644 --- a/mavlink/test_mission_ended.py +++ b/mavlink/test_mission_ended.py @@ -4,19 +4,20 @@ import time -import dronekit +from pymavlink import mavutil -from mavlink.modules import flight_controller +from . import dronekit +from .modules import flight_controller DELAY_TIME = 1.0 # seconds MISSION_PLANNER_ADDRESS = "tcp:127.0.0.1:14550" TIMEOUT = 1.0 # seconds -MAVLINK_TAKEOFF_FRAME = dronekit.mavutil.mavlink.MAV_FRAME_GLOBAL_RELATIVE_ALT -MAVLINK_TAKEOFF_COMMAND = dronekit.mavutil.mavlink.MAV_CMD_NAV_TAKEOFF -MAVLINK_FRAME = dronekit.mavutil.mavlink.MAV_FRAME_GLOBAL_RELATIVE_ALT -MAVLINK_COMMAND = dronekit.mavutil.mavlink.MAV_CMD_NAV_WAYPOINT +MAVLINK_TAKEOFF_FRAME = mavutil.mavlink.MAV_FRAME_GLOBAL_RELATIVE_ALT +MAVLINK_TAKEOFF_COMMAND = mavutil.mavlink.MAV_CMD_NAV_TAKEOFF +MAVLINK_FRAME = mavutil.mavlink.MAV_FRAME_GLOBAL_RELATIVE_ALT +MAVLINK_COMMAND = mavutil.mavlink.MAV_CMD_NAV_WAYPOINT ALTITUDE = 10 # metres ACCEPT_RADIUS = 10 # metres diff --git a/network/README.md b/network/README.md new file mode 100644 index 0000000..4b8ca37 --- /dev/null +++ b/network/README.md @@ -0,0 +1,31 @@ +# Network +This module facilitates communication over TCP and UDP. + +## Testing +Instructions on how to use the unit tests. + +### TCP +To test `TcpClientSocket` and `TcpServerSocket`, run the test scripts in the following order (could be on 2 different machines, but edit network addresses accordingly): + +1. `test_tcp_receiver.py` +2. `test_tcp_sender.py` + +`start_tcp_receiver.py` will start a server socket listening for connections on `localhost:8080`. +`start_tcp_sender.py` will then start a client socket and connect to the server. +Then, the client will send an integer (4 bytes) representing the message length, followed by the actual test message. +The server will receive the message and send it back to the client. +This process repeats until all test messages are sent. + +### UDP +To test `UdpClientSocket` and `UdpServerSocket`, run the test scripts in the following order (could be on 2 different machines, but edit network addresses accordingly): + +1. `test_udp_receiver.py` +2. `test_udp_sender.py` + +`start_udp_receiver.py` will start a server socket listening for data on `localhost:8080`. +`start_udp_sender.py` will then start a client socket and send data to the server. +Then, the client will send an integer (4 bytes) representing the message length, followed by the actual test message. +This process repeats until all test messages are sent. + +*Note: UDP does not guarantee that data is sent or is not corrupted. +It is a connectionless protocol and thus the server cannot send any messages. diff --git a/network/__init__.py b/network/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/network/modules/__init__.py b/network/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/network/modules/tcp/__init__.py b/network/modules/tcp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/network/modules/tcp/client_socket.py b/network/modules/tcp/client_socket.py new file mode 100644 index 0000000..18a9234 --- /dev/null +++ b/network/modules/tcp/client_socket.py @@ -0,0 +1,78 @@ +""" +Wrapper for TCP client socket operations. +""" + +import socket + +from .socket_wrapper import TcpSocket + + +class TcpClientSocket(TcpSocket): + """ + Wrapper for TCP client socket operations. + """ + + __create_key = object() + + def __init__(self, class_private_create_key: object, socket_instance: socket.socket) -> None: + """ + Private constructor, use create() method. + """ + + assert class_private_create_key is TcpClientSocket.__create_key, "Use create() method" + + super().__init__(socket_instance=socket_instance) + + @classmethod + def create( + cls, + instance: socket.socket = None, + host: str = "localhost", + port: int = 5000, + connection_timeout: float = 60.0, + ) -> "tuple[bool, TcpClientSocket | None]": + """ + Establishes socket connection through provided host and port. + + Parameters + ---------- + instance: socket.socket (default None) + For initializing Socket with an existing socket object. + host: str (default "localhost") + port: int (default 5000) + The host combined with the port will form an address (e.g. localhost:5000) + connection_timeout: float (default 60.0) + Timeout for establishing connection, in seconds + + Returns + ------- + tuple[bool, TcpClientSocket | None] + The first parameter represents if the socket creation is successful. + - If it is not successful, the second parameter will be None. + - If it is successful, the second parameter will be the created TcpClientSocket object. + """ + + # Reassign instance before check or Pylance will complain + socket_instance = instance + if socket_instance is not None: + return True, TcpClientSocket(cls.__create_key, socket_instance) + + if connection_timeout <= 0: + # Zero puts it on non-blocking mode, which complicates things + print("Must be a positive non-zero value") + return False, None + + try: + socket_instance = socket.create_connection((host, port), connection_timeout) + return True, TcpClientSocket(cls.__create_key, socket_instance) + except TimeoutError: + print("Connection timed out.") + except socket.gaierror as e: + print( + f"Could not connect to socket, address related error: {e}. " + "Make sure the host and port are correct." + ) + except socket.error as e: + print(f"Could not connect to socket, connection error: {e}.") + + return False, None diff --git a/network/modules/tcp/server_socket.py b/network/modules/tcp/server_socket.py new file mode 100644 index 0000000..e6ae80a --- /dev/null +++ b/network/modules/tcp/server_socket.py @@ -0,0 +1,114 @@ +""" +Wrapper for TCP server socket operations. +""" + +import socket + +from .socket_wrapper import TcpSocket + + +class TcpServerSocket(TcpSocket): + """ + Wrapper for TCP server socket operations. + """ + + __create_key = object() + + def __init__(self, class_private_create_key: object, socket_instance: socket.socket) -> None: + """ + Private constructor, use create() method. + """ + + assert class_private_create_key is TcpServerSocket.__create_key, "Use create() method" + + super().__init__(socket_instance=socket_instance) + + @classmethod + def create( + cls, + instance: socket.socket = None, + host: str = "", + port: int = 5000, + connection_timeout: float = 60.0, + ) -> "tuple[bool, TcpServerSocket | None]": + """ + Establishes socket connection through provided host and port. + Note: Although in practice a TCP 'server' simply connects 2 clients together, + this newly created client is called 'server' for simplicity and differentiation. + + Parameters + ---------- + instance: socket.socket (default None) + For initializing Socket with an existing socket object. + host: str (default "") + Empty string is interpreted as '0.0.0.0' (IPv4) or '::' (IPv6), which is all addresses. + Could also use socket.gethostname(). (needed to enable other machines to connect) + port: int (default 5000) + The host combined with the port will form an address (e.g. localhost:5000) + connection_timeout: float (default 60.0) + Timeout for operations such as recieve + + Returns + ------- + tuple[bool, TcpServerSocket | None] + The first parameter represents if the socket creation is successful. + - If it is not successful, the second parameter will be None. + - If it is successful, the second parameter will be the created + TcpServerSocket object. + """ + + # Reassign instance before check or Pylance will complain + socket_instance = instance + if socket_instance is not None: + return True, TcpServerSocket(cls.__create_key, socket_instance) + + if socket.has_dualstack_ipv6(): + # Create server which can accept both IPv6 and IPv4 if possible + try: + server = socket.create_server( + (host, port), + family=socket.AF_INET6, + dualstack_ipv6=True, + ) + except socket.gaierror as e: + print( + f"Could not connect to socket, address related error: {e}.", + "Make sure the host and port are correct.", + ) + return False, None + except socket.error as e: + print(f"Could not connect to socket, connection error: {e}.") + return False, None + else: + # Otherwise, server can only accept IPv4 + try: + server = socket.create_server((host, port)) + except socket.gaierror as e: + print( + f"Could not connect to socket, address related error: {e}. " + "Make sure the host and port are correct." + ) + return False, None + except socket.error as e: + print(f"Could not connect to socket, connection error: {e}.") + return False, None + + # Currently listening, waiting for a connection + if host == "": + print(f"Listening for external connections on port {port}") + else: + print(f"Listening for internal connections on {host}:{port}") + + server.settimeout(connection_timeout) + # This is in blocking mode, nothing can happen until this finishes, even keyboard interrupt + socket_instance, addr = server.accept() + + # Now, a connection as been accepted (created a new 'client' socket) + print(f"Accepted a connection from {addr[0]}:{addr[1]}") + + socket_instance.settimeout(connection_timeout) + + server.close() + print("No longer accepting new connections.") + + return True, TcpServerSocket(cls.__create_key, socket_instance) diff --git a/network/modules/tcp/socket_wrapper.py b/network/modules/tcp/socket_wrapper.py new file mode 100644 index 0000000..0c4fd58 --- /dev/null +++ b/network/modules/tcp/socket_wrapper.py @@ -0,0 +1,111 @@ +""" +Wrapper for a TCP socket. +""" + +import socket + + +class TcpSocket: + """ + Wrapper for a TCP socket. + """ + + def __init__(self, socket_instance: socket.socket) -> None: + """ + Parameters + ---------- + instance: socket.socket + For initializing Socket with an existing socket object. + """ + + self.__socket = socket_instance + + def send(self, data: bytes) -> bool: + """ + Sends all data at once over the socket. + + Parameters + ---------- + data: bytes + + Returns + ------- + bool: If the data was sent successfully. + """ + + try: + self.__socket.sendall(data) + except socket.error as e: + print(f"Could not send data: {e}.") + return False + + return True + + def recv(self, buf_size: int) -> "tuple[bool, bytes | None]": + """ + Reads buf_size bytes from the socket. + + Parameters + ---------- + buf_size: int + The number of bytes to receive. + + Returns + ------- + tuple[bool, bytes | None] + The first parameter represents if the read is successful. + - If it is not successful, the second parameter will be None. + - If it is successful, the second parameter will be the data that is read. + """ + + message = b"" + bytes_recd = 0 + while bytes_recd < buf_size: + # 4096 or other low powers of 2 is recommended + # Although while testing without a limit, it has been observed to reach above 100000 + chunk = self.__socket.recv(min(buf_size - bytes_recd, 4096)) + + if chunk == b"": + print("Socket connection broken") # When 0 is received, means error + return False, None + + message += chunk + bytes_recd += len(chunk) + + return True, message + + def close(self) -> bool: + """ + Closes the socket object. All future operations on the socket object will fail. + + Returns + ------- + bool: If the socket was closed successfully. + """ + + try: + self.__socket.close() + except socket.error as e: + print(f"Could not close socket: {e}.") + return False + + return True + + def address(self) -> "tuple[str, int]": + """ + Retrieves the address that the socket is listening on. + + Returns + ------- + tuple[str, int] + The address in the format (ip address, port). + """ + + return self.__socket.getsockname() + + def get_socket(self) -> socket.socket: + """ + Getter for the underlying socket objet. + """ + + return self.__socket diff --git a/network/modules/udp/__init__.py b/network/modules/udp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/network/modules/udp/client_socket.py b/network/modules/udp/client_socket.py new file mode 100644 index 0000000..9b3842c --- /dev/null +++ b/network/modules/udp/client_socket.py @@ -0,0 +1,118 @@ +""" +Wrapper for UDP client socket operations. +""" + +import socket + +from .socket_wrapper import UdpSocket + + +class UdpClientSocket(UdpSocket): + """ + Wrapper for UDP client socket operations. + """ + + __create_key = object() + + def __init__( + self, + class_private_create_key: object, + socket_instance: socket.socket, + server_address: tuple, + ) -> None: + """ + Private Constructor, use create() method. + """ + + assert class_private_create_key is UdpClientSocket.__create_key + + super().__init__(socket_instance=socket_instance) + self.__server_address = server_address + + @classmethod + def create( + cls, host: str = "localhost", port: int = 5000, connection_timeout: float = 60.0 + ) -> "tuple[bool, UdpClientSocket | None]": + """ + Initializes UDP client socket with the appropriate server address. + + Parameters + ---------- + host: str (default "localhost") + The hostname or IP address of the server. + port: int (default 5000) + The port number of the server. + connection_timeout: float (default 60.0) + Timeout for establishing connection, in seconds + + Returns + ------- + tuple[bool, UdpClientSocket | None] + The boolean value represents whether the initialization was successful or not. + - If it is not successful, the second parameter will be None. + - If it is successful, the method will return True and a UdpClientSocket object will be created. + """ + + if connection_timeout <= 0: + print("Must provide positive non-zero value.") + return False, None + + try: + socket_instance = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + socket_instance.settimeout(connection_timeout) + server_address = (host, port) + return True, UdpClientSocket(cls.__create_key, socket_instance, server_address) + except TimeoutError as e: + print(f"Connection timed out: {e}") + + except socket.gaierror as e: + print( + f"Could not connect to socket, address related error: {e}. Make sure the host and port are correct." + ) + + except socket.error as e: + print(f"Could not connect: {e}") + + return False, None + + def send(self, data: bytes) -> bool: + """ + Sends data to the specified server address during this socket's creation. + + Parameters + ---------- + data: bytes + Takes in raw data that we wish to send + + Returns + ------- + bool: True if data is sent successfully, and false if it fails to send + """ + + try: + host, port = self.__server_address + super().send_to(data, host, port) + except socket.error as e: + print(f"Could not send data: {e}") + return False + + return True + + def recv(self, buf_size: int) -> None: + """ + Receive data method override to prevent client sockets from receiving data. + + Parameters + ---------- + bufsize: int + The amount of data to be received. + + Raises + ------ + NotImplementedError + Always raised because client sockets should not receive data. + """ + + raise NotImplementedError( + "Client sockets cannot receive data as they are not bound by a port." + ) diff --git a/network/modules/udp/server_socket.py b/network/modules/udp/server_socket.py new file mode 100644 index 0000000..5c28328 --- /dev/null +++ b/network/modules/udp/server_socket.py @@ -0,0 +1,77 @@ +""" +Wrapper for server socket operations. +""" + +import socket + +from .socket_wrapper import UdpSocket + + +class UdpServerSocket(UdpSocket): + """ + Wrapper for server socket operations. + """ + + __create_key = object() + + def __init__( + self, + class_private_create_key: object, + socket_instance: socket.socket, + ) -> None: + """ + Private Constructor, use create() method. + """ + + assert class_private_create_key is UdpServerSocket.__create_key, "Use create() method" + + super().__init__(socket_instance=socket_instance) + + @classmethod + def create( + cls, host: str = "", port: int = 5000, connection_timeout: float = 60.0 + ) -> "tuple[bool, UdpServerSocket | None]": + """ + Creates a UDP server socket bound to the provided host and port. + + Parameters + ---------- + host: str (default "") + The hostname or IP address to bind the socket to. + port: int (default 5000) + The port number to bind the socket to. + connection_timeout: float (default 60.0) + Timeout for establishing connection, in seconds + + Returns + ------- + tuple[bool, UdpServerSocket | None] + The first parameter represents if the socket creation is successful. + - If it is not successful, the second parameter will be None. + - If it is successful, the second parameter will be the created + UdpServerSocket object. + """ + + if connection_timeout <= 0: + print("Must provide a positive non-zero value.") + return False, None + + try: + socket_instance = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + socket_instance.settimeout(connection_timeout) + server_address = (host, port) + socket_instance.bind(server_address) + + if host == "": + print(f"Listening for external data on port {port}") + else: + print(f"Listening for internal data on {host}:{port}") + + return True, UdpServerSocket(cls.__create_key, socket_instance) + + except TimeoutError: + print("Connection timed out.") + return False, None + except socket.error as e: + print(f"Could not create socket, error: {e}.") + return False, None diff --git a/network/modules/udp/socket_wrapper.py b/network/modules/udp/socket_wrapper.py new file mode 100644 index 0000000..7d12400 --- /dev/null +++ b/network/modules/udp/socket_wrapper.py @@ -0,0 +1,114 @@ +""" +Wrapper for a UDP socket. +""" + +import socket +import time + +CHUNK_SIZE = 2**15 # 32 kb, may need to be shrunk on pi becasue its buffer may not be as large +SEND_DELAY = 1e-4 # Delay in seconds in between sends to avoid filling socket buffer + + +class UdpSocket: + """ + Wrapper for a UDP socket. + """ + + def __init__(self, socket_instance: socket.socket = None) -> None: + """ + Parameters + ---------- + instance: socket.socket + For initializing Socket with an existing socket object. + """ + + self.__socket = socket_instance + + def send_to( + self, + data: bytes, + host: str = "", + port: int = 5000, + chunk_size: int = CHUNK_SIZE, + send_delay: float = SEND_DELAY, + ) -> bool: + """ + Sends data to specified address + + Parameters + ---------- + data: bytes + host: str (default "") + Empty string is interpreted as '0.0.0.0' (IPv4) or '::' (IPv6), which is an open address + port: int (default 5000) + The host, combined with the port, will form the address as a tuple + + Returns + ------- + bool: if data was transferred successfully + """ + + address = (host, port) + data_sent = 0 + data_size = len(data) + + while data_sent < data_size: + if data_sent + chunk_size > data_size: + chunk = data[data_sent:data_size] + else: + chunk = data[data_sent : data_sent + chunk_size] + + try: + self.__socket.sendto(chunk, address) + data_sent += len(chunk) + except socket.error as e: + print(f"Could not send data: {e}") + return False + + time.sleep(send_delay) + + return True + + def recv(self, buf_size: int) -> "tuple[bool, bytes | None]": + """ + Parameters + ---------- + buf_size: int + The number of bytes to receive + + Returns + ------- + tuple: + bool - True if data was received and unpacked successfully, False otherwise + bytes | None - The received data, or None if unsuccessful + """ + + data = b"" + addr = None + data_size = 0 + + while data_size < buf_size: + try: + packet, current_addr = self.__socket.recvfrom(buf_size) + if addr is None: + addr = current_addr + elif addr != current_addr: + print(f"Data received from multiple addresses: {addr} and {current_addr}") + packet = b"" + + # Add the received packet to the accumulated data and increment the size accordingly + data += packet + data_size += len(packet) + + except socket.error as e: + print(f"Could not receive data: {e}") + return False, None + + return True, data + + def get_socket(self) -> socket.socket: + """ + Getter for the underlying socket objet. + """ + + return self.__socket diff --git a/network/start_tcp_receiver.py b/network/start_tcp_receiver.py new file mode 100644 index 0000000..3601630 --- /dev/null +++ b/network/start_tcp_receiver.py @@ -0,0 +1,60 @@ +""" +Test TCP socket operations by receiving images over server sockets. +""" + +import struct +import sys + +from .modules.tcp.server_socket import TcpServerSocket + + +# Since the socket may be using either IPv4 or IPv6, do not specify 127.0.0.1 or ::1. +# Instead, use localhost if wanting to test on same the machine. +SOCKET_ADDRESS = "" +SOCKET_PORT = 8080 + + +# pylint: disable=R0801 +def start_server(host: str, port: int) -> int: + """ + Starts server listening on host:port that receives messages and sends them back to the client. + """ + result, server_socket = TcpServerSocket.create(host=host, port=port) + assert result, "Server creation failed." + + while True: + result, data_len = server_socket.recv(4) + if not result: + print("Client closed the connection.") + break + + print("Received data length from client.") + + data_len = struct.unpack("!I", data_len) + print(f"data length: {data_len}") + + result, data = server_socket.recv(data_len[0]) + assert result, "Could not receive data from client." + print("Received data from client.") + + result = server_socket.send(data) + assert result, "Failed to send data back to client." + print("Sent data back to client.") + + result = server_socket.close() + assert result, "Failed to close server connection" + + print("Connection to client closed.") + + return 0 + + +if __name__ == "__main__": + if len(sys.argv) > 1: + result_main = start_server(SOCKET_ADDRESS, int(sys.argv[1])) + else: + result_main = start_server(SOCKET_ADDRESS, SOCKET_PORT) + if result_main < 0: + print(f"ERROR: Status code: {result_main}") + + print("Done!") diff --git a/network/start_tcp_sender.py b/network/start_tcp_sender.py new file mode 100644 index 0000000..6c38e0b --- /dev/null +++ b/network/start_tcp_sender.py @@ -0,0 +1,63 @@ +""" +Test TCP socket operations by sending images over client sockets. +""" + +import struct + +import numpy as np + +from .modules.tcp.client_socket import TcpClientSocket + + +SOCKET_ADDRESS = "localhost" +SOCKET_PORT = 8080 + + +IMAGE_ENCODE_EXT = ".png" + + +def start_sender(host: str, port: int) -> int: + """ + Client will send messages to the server, and the server will send them back. + """ + + test_messages = [ + b"Hello world!", + np.random.bytes(4096), + np.random.bytes(10000000), + ] + + result, client_socket = TcpClientSocket.create(host=host, port=port) + assert result, "Failed to create ClientSocket." + print(f"Connected to: {host}:{port}.") + + for data in test_messages: + # Send data length, 4 byte message (unsigned int, network or big-endian format) + data_len = struct.pack("!I", len(data)) + result = client_socket.send(data_len) + assert result, "Failed to send data byte length." + print("Sent data byte length to server.") + + result = client_socket.send(data) + assert result, "Failed to send data." + print("Sent data to server.") + + result, recv_data = client_socket.recv(len(data)) + assert result, "Failed to receive returning data." + print("Received data from server.") + assert data == recv_data, "Sent data does not match received data" + print("Received data is same as sent data, no corruption has occured.") + + result = client_socket.close() + assert result, "Failed to close client connection." + + print("Connection to server closed.") + return 0 + + +if __name__ == "__main__": + result_main = start_sender(SOCKET_ADDRESS, SOCKET_PORT) + if result_main < 0: + print(f"ERROR: Status code: {result_main}") + + print("Done!") diff --git a/network/start_udp_receiver.py b/network/start_udp_receiver.py new file mode 100644 index 0000000..2f837d3 --- /dev/null +++ b/network/start_udp_receiver.py @@ -0,0 +1,52 @@ +""" +Test UDP socket operations by receiving images over server sockets. +""" + +import sys +import struct + +from .modules.udp.server_socket import UdpServerSocket + + +# Since the socket may be using either IPv4 or IPv6, do not specify 127.0.0.1 or ::1. +# Instead, use localhost if wanting to test on same the machine. +SOCKET_ADDRESS = "" +SOCKET_PORT = 8080 + + +# pylint: disable=R0801 +def start_server(host: str, port: int) -> int: + """ + Starts server listening on host:port that receives some messages. + """ + result, server_socket = UdpServerSocket.create(host=host, port=port) + assert result, "Server cration failed." + + while True: + result, data_len = server_socket.recv(4) + if not result: + print("Client closed the connection.") + break + + print("Received data length from client.") + + data_len = struct.unpack("!I", data_len) + print(f"data length: {data_len}") + + result, data = server_socket.recv(data_len[0]) + assert result, "Could not receive data from client." + assert len(data) == data_len[0], "Data lengths not matching" + print("Received data from client.") + + return 0 + + +if __name__ == "__main__": + if len(sys.argv) > 1: + result_main = start_server(SOCKET_ADDRESS, int(sys.argv[1])) + else: + result_main = start_server(SOCKET_ADDRESS, SOCKET_PORT) + if result_main < 0: + print(f"ERROR: Status code: {result_main}") + + print("Done!") diff --git a/network/start_udp_sender.py b/network/start_udp_sender.py new file mode 100644 index 0000000..2a77be9 --- /dev/null +++ b/network/start_udp_sender.py @@ -0,0 +1,53 @@ +""" +Test UDP socket operations by sending images over client sockets. +""" + +import struct + +import numpy as np + +from .modules.udp.client_socket import UdpClientSocket + + +SOCKET_ADDRESS = "localhost" +SOCKET_PORT = 8080 + + +IMAGE_ENCODE_EXT = ".png" + + +def start_sender(host: str, port: int) -> int: + """ + Client will send some test data (random bytes) to server. + """ + + test_messages = [ + b"Hello world!", + np.random.bytes(4096), + np.random.bytes(10000000), + ] + + result, client_socket = UdpClientSocket.create(host=host, port=port) + assert result, "Failed to create ClientSocket." + print(f"Connected to: {host}:{port}.") + + for data in test_messages: + # Send data length, 4 byte message (unsigned int, network or big-endian format) + data_len = struct.pack("!I", len(data)) + result = client_socket.send(data_len) + assert result, "Failed to send data byte length." + print("Sent data byte length to server.") + + result = client_socket.send(data) + assert result, "Failed to send data." + print("Sent data to server.") + + return 0 + + +if __name__ == "__main__": + result_main = start_sender(SOCKET_ADDRESS, SOCKET_PORT) + if result_main < 0: + print(f"ERROR: Status code: {result_main}") + + print("Done!") diff --git a/network/test_tcp.py b/network/test_tcp.py new file mode 100644 index 0000000..980a033 --- /dev/null +++ b/network/test_tcp.py @@ -0,0 +1,86 @@ +""" +Test TCP sockets by sending random data to an echo server (for Pytest). +""" + +import os +from pathlib import Path +import struct +from typing import Generator + +import numpy as np +import pytest +from xprocess import ProcessStarter, XProcess + +from .modules.tcp.client_socket import TcpClientSocket + + +# Since the socket may be using either IPv4 or IPv6, do not specify 127.0.0.1 or ::1. +# Instead, use localhost if wanting to test on same the machine +SERVER_PORT = 8080 +ROOT_DIR = Path(__file__).parent.parent + + +@pytest.fixture +def test_messages() -> "Generator[bytes]": + """ + Test messages to send to server. + """ + + yield [ + b"Hello world!", + np.random.bytes(4096), + np.random.bytes(10000000), + ] + + +# fmt: off +@pytest.fixture +def myserver(xprocess: XProcess) -> Generator: + """ + Starts echo server. + """ + + myenv = os.environ.copy() + myenv["PYTHONPATH"] = str(ROOT_DIR) + myenv["PYTHONUNBUFFERED"] = "1" + + class Starter(ProcessStarter): + """ + xprocess config to start the server as another process. + """ + + pattern = f"Listening for external connections on port {SERVER_PORT}" + timeout = 60 + args = ["python", "-m", "network.start_tcp_receiver", SERVER_PORT] + env = myenv + + xprocess.ensure("mysever", Starter) + + yield + + xprocess.getinfo("myserver").terminate() +# fmt: on + + +# pylint: disable=W0621,W0613 +def test_client(test_messages: "Generator[bytes]", myserver: Generator) -> None: + """ + Client will send messages to the server, and the server will send them back. + """ + + result, client = TcpClientSocket.create(port=SERVER_PORT) + assert result + + for data in test_messages: + # Send data length, 4 byte message (unsigned int, network or big-endian format) + data_len = struct.pack("!I", len(data)) + result = client.send(data_len) + assert result + + result = client.send(data) + assert result + + result, recv_data = client.recv(len(data)) + assert result + + assert data == recv_data diff --git a/network/test_udp.py b/network/test_udp.py new file mode 100644 index 0000000..8232e9b --- /dev/null +++ b/network/test_udp.py @@ -0,0 +1,82 @@ +""" +Test UDP sockets by sending random data to a server (for Pytest). +""" + +import os +from pathlib import Path +import struct +from typing import Generator + +import numpy as np +import pytest +from xprocess import ProcessStarter, XProcess + +from .modules.udp.client_socket import UdpClientSocket + + +# Since the socket may be using either IPv4 or IPv6, do not specify 127.0.0.1 or ::1. +# Instead, use localhost if wanting to test on same the machine. +SERVER_PORT = 8825 +ROOT_DIR = Path(__file__).parent.parent + + +@pytest.fixture +def test_messages() -> "Generator[bytes]": + """ + Test messages to send to server. + """ + + yield [ + b"Hello world!", + np.random.bytes(4096), + np.random.bytes(10000000), + ] + + +# fmt: off +@pytest.fixture +def myserver(xprocess: XProcess) -> Generator: + """ + Starts server. + """ + + myenv = os.environ.copy() + myenv["PYTHONPATH"] = str(ROOT_DIR) + myenv["PYTHONUNBUFFERED"] = "1" + + class Starter(ProcessStarter): + """ + xprocess config to start the server as another process. + """ + + pattern = f"Listening for external data on port {SERVER_PORT}" + timeout = 60 + args = ["python", "-m", "network.start_udp_receiver", SERVER_PORT] + env = myenv + + xprocess.ensure("mysever", Starter) + + yield + + xprocess.getinfo("myserver").terminate() +# fmt: on + + +# pylint: disable=W0621,W0613 +def test_client(test_messages: "Generator[bytes]", myserver: Generator) -> None: + """ + Client will send messages to the server. + We do not know whether they have been received successfully or not, since these are UDP packets + """ + + result, client = UdpClientSocket.create(port=SERVER_PORT) + assert result + + for data in test_messages: + # Send data length, 4 byte message (unsigned int, network or big-endian format) + data_len = struct.pack("!I", len(data)) + result = client.send(data_len) + assert result + + result = client.send(data) + assert result diff --git a/pyproject.toml b/pyproject.toml index bdde5db..141de84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,9 @@ ignore-paths = [ # Logging "logs/", + + # Submodule dronekit + "mavlink/dronekit", ] # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the @@ -24,7 +27,7 @@ jobs = 0 # Minimum Python version to use for version dependent checks. Will default to the # version used to run pylint. -py-version = "3.8" +py-version = "3.11" # Discover python modules and packages in the file system subtree. recursive = true @@ -97,6 +100,12 @@ min-similarity-lines = 10 [tool.pytest.ini_options] minversion = "6.0" +# Submodules +addopts = "--ignore=mavlink/dronekit/" + [tool.black] line-length = 100 -target-version = ["py38"] +target-version = ["py311"] +# Excludes files or directories in addition to the defaults +# Submodules +extend-exclude = "mavlink/dronekit/*" diff --git a/qr/test_qr.py b/qr/test_qr.py index e50c32c..71f6d58 100644 --- a/qr/test_qr.py +++ b/qr/test_qr.py @@ -6,7 +6,7 @@ import cv2 -from qr.modules.qr_scanner import QrScanner +from .modules.qr_scanner import QrScanner PARENT_DIRECTORY = "qr" diff --git a/requirements.txt b/requirements.txt index 42301b6..cdd7076 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,23 +3,27 @@ # Global pytest +pytest-xprocess # Module camera opencv-python numpy -# Module lte +# Module image_encoding Pillow +# Module kml +simplekml + +# Module logger +pyyaml + # Module mavlink -dronekit +pymavlink # Module qr pyzbar -# Module kml -simplekml - # Linters and formatters are explicitly versioned black==24.2.0 flake8-annotations==3.0.1 diff --git a/setup.cfg b/setup.cfg index 338a854..694fd37 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,3 +17,6 @@ extend-exclude= # Logging logs/, + + # Submodule dronekit + mavlink/dronekit, diff --git a/setup_project.ps1 b/setup_project.ps1 new file mode 100644 index 0000000..fb025d9 --- /dev/null +++ b/setup_project.ps1 @@ -0,0 +1,21 @@ +# Initialize the project for Windows + +# Activate venv to prevent accidentally installing into global space +./venv/Scripts/Activate.ps1 + +if($?) { + # If successfully activated venv + "Installing project dependencies..." + pip install -r requirements.txt + + "" + "Installing submodules and their dependencies..." + git submodule update --init --remote --recursive + git submodule foreach --recursive "pip install -r requirements.txt" + + deactivate + "" + "Seutp complete!" +} else { + "Please install a virtual environment in the directory 'venv', at the project root directory" +} diff --git a/setup_project.sh b/setup_project.sh new file mode 100644 index 0000000..92705ef --- /dev/null +++ b/setup_project.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +# Initialize and update submodules script for Linux + +# Activate venv to prevent accidentally installing into global space +source ./venv/bin/activate + +if [ $? -eq 0 ]; then + echo "Installing project dependencies..." + pip install -r requirements.txt + + echo "" + echo "Installing submodules and their dependencies..." + git submodule update --init --remote --recursive + git submodule foreach --recursive "pip install -r requirements.txt" + + deactivate + echo "" + echo "Setup complete!" +else + echo "Please install a virtual environment in the directory 'venv', at the project root directory" +fi diff --git a/test_send_image.py b/test_send_image.py new file mode 100644 index 0000000..37f1d2f --- /dev/null +++ b/test_send_image.py @@ -0,0 +1,149 @@ +""" +Integration test for image_encoding and network modules. +Encodes images then sends them to server through network sockets. +""" + +import os +from pathlib import Path +import struct +from typing import Generator + +import numpy as np +from PIL import Image +import pytest +from xprocess import ProcessStarter, XProcess + +from .image_encoding.modules import decoder +from .image_encoding.modules import encoder +from .network.modules.tcp.client_socket import TcpClientSocket +from .network.modules.udp.client_socket import UdpClientSocket + +# Since the socket may be using either IPv4 or IPv6, do not specify 127.0.0.1 or ::1. +# Instead, use localhost if wanting to test on same the machine +SERVER_PORT = 9145 +ROOT_DIR = Path(__file__).parent + + +@pytest.fixture +def images() -> "Generator[np.ndarray]": + """ + Images to send to server. + """ + + image = Image.open(Path(ROOT_DIR, "image_encoding", "test.png")) + image_bytes = [np.asarray(image)] + + yield image_bytes + + +# fmt: off +@pytest.fixture +def tcp_server(xprocess: XProcess) -> Generator: + """ + Starts echo server. + """ + + myenv = os.environ.copy() + myenv["PYTHONPATH"] = str(ROOT_DIR) + myenv["PYTHONUNBUFFERED"] = "1" + + class Starter(ProcessStarter): + """ + xprocess config to start a tcp echo server as another process. + """ + + pattern = f"Listening for external connections on port {SERVER_PORT}" + timeout = 60 + args = ["python", "-m", "network.start_tcp_receiver", SERVER_PORT] + env = myenv + + xprocess.ensure("tcp_sever", Starter) + + yield + + xprocess.getinfo("tcp_server").terminate() +# fmt: on + + +# pylint: disable=W0621,W0613 +def test_tcp_client(images: "Generator[np.ndarray]", tcp_server: Generator) -> None: + """ + Client will send images to the server, and the server will send them back. + """ + + result, client = TcpClientSocket.create(port=SERVER_PORT) + assert result + + for image in images: + # Encode image (into jpeg) + data = encoder.encode(image) + + # Send image length, 4 byte message (unsigned int, network or big-endian format) + data_len = struct.pack("!I", len(data)) + result = client.send(data_len) + assert result + + # Send image to server + result = client.send(data) + assert result + + # Recive image back from echo server + result, recv_data = client.recv(len(data)) + assert result + + # Decode image + recv_image = decoder.decode(recv_data) + assert image.shape == recv_image.shape + + +# fmt: off +@pytest.fixture +def udp_server(xprocess: XProcess) -> Generator: + """ + Starts server. + """ + + myenv = os.environ.copy() + myenv["PYTHONPATH"] = str(ROOT_DIR) + myenv["PYTHONUNBUFFERED"] = "1" + + class Starter(ProcessStarter): + """ + xprocess config to start a udp server as another process. + """ + + pattern = f"Listening for external data on port {SERVER_PORT + 1}" + timeout = 60 + args = ["python", "-m", "network.start_udp_receiver", SERVER_PORT + 1] + env = myenv + + xprocess.ensure("udp_sever", Starter) + + yield + + xprocess.getinfo("udp_server").terminate() +# fmt: on + + +# pylint: disable=W0621,W0613 +def test_udp_client(images: "Generator[bytes]", udp_server: Generator) -> None: + """ + Client will send image to the server. + We do not know whether they have been received successfully or not, since these are UDP packets + """ + + result, client = UdpClientSocket.create(port=SERVER_PORT + 1) + assert result + + for image in images: + # Encode image + data = encoder.encode(image) + + # Send data length, 4 byte message (unsigned int, network or big-endian format) + data_len = struct.pack("!I", len(data)) + result = client.send(data_len) + assert result + + # Send image to server + result = client.send(data) + assert result