diff --git a/.ci/tests/examples/inference_test.py b/.ci/tests/examples/prediction_test.py
similarity index 87%
rename from .ci/tests/examples/inference_test.py
rename to .ci/tests/examples/prediction_test.py
index 6e27d2499..cd5076e4f 100644
--- a/.ci/tests/examples/inference_test.py
+++ b/.ci/tests/examples/prediction_test.py
@@ -15,7 +15,7 @@ def _eprint(*args, **kwargs):
 def _wait_n_rounds(collection):
     n = 0
     for _ in range(RETRIES):
-        query = {'type': 'INFERENCE'}
+        query = {'type': 'MODEL_PREDICTION'}
         n = collection.count_documents(query)
         if n == N_CLIENTS:
             return n
@@ -32,4 +32,4 @@ def _wait_n_rounds(collection):
     # Wait for successful rounds
     succeded = _wait_n_rounds(client['fedn-test-network']['control']['status'])
     assert(succeded == N_CLIENTS) # check that all rounds succeeded
-    _eprint(f'Succeded inference clients: {succeded}. Test passed.')
+    _eprint(f'Succeeded prediction clients: {succeded}. Test passed.')
diff --git a/.ci/tests/examples/run_inference.sh b/.ci/tests/examples/run_prediction.sh
similarity index 57%
rename from .ci/tests/examples/run_inference.sh
rename to .ci/tests/examples/run_prediction.sh
index d78771d70..eb676fce3 100755
--- a/.ci/tests/examples/run_inference.sh
+++ b/.ci/tests/examples/run_prediction.sh
@@ -8,12 +8,12 @@ if [ "$#" -lt 1 ]; then
 fi
 example="$1"
->&2 echo "Run inference"
+>&2 echo "Run prediction"
 pushd "examples/$example"
-curl -k -X POST https://localhost:8090/infer
+curl -k -X POST https://localhost:8090/predict
->&2 echo "Checking inference success"
-".$example/bin/python" ../../.ci/tests/examples/inference_test.py
+>&2 echo "Checking prediction success"
+".$example/bin/python" ../../.ci/tests/examples/prediction_test.py
 >&2 echo "Test completed successfully"
 popd
\ No newline at end of file
diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml
index ed354a78b..b89df6ca6 100644
--- a/.github/workflows/integration-tests.yaml
+++ b/.github/workflows/integration-tests.yaml
@@ -38,8 +38,8 @@ jobs:
       - name: run ${{ matrix.to_test }}
         run: .ci/tests/examples/run.sh ${{ matrix.to_test }}
-      # - name: run ${{ matrix.to_test }} inference
-      #   run: .ci/tests/examples/run_inference.sh ${{ matrix.to_test }}
+      # - name: run ${{ matrix.to_test }} prediction
+      #   run: .ci/tests/examples/run_prediction.sh ${{ matrix.to_test }}
      #   if: ${{ matrix.os != 'macos-11' && matrix.to_test == 'mnist-keras keras' }} # example available for Keras
       - name: print logs
diff --git a/examples/mnist-keras/client/predict.py b/examples/mnist-keras/client/predict.py
index 9d502ed75..412bb74ae 100644
--- a/examples/mnist-keras/client/predict.py
+++ b/examples/mnist-keras/client/predict.py
@@ -11,13 +11,13 @@ def predict(in_model_path, out_json_path, data_path=None):
-    # Using test data for inference but another dataset could be loaded
+    # Using test data for prediction but another dataset could be loaded
     x_test, _ = load_data(data_path, is_train=False)
     # Load model
     model = load_parameters(in_model_path)
-    # Infer
+    # Predict
     y_pred = model.predict(x_test)
     y_pred = np.argmax(y_pred, axis=1)
diff --git a/fedn/cli/status_cmd.py b/fedn/cli/status_cmd.py
index 078acaf13..c879ca1ef 100644
--- a/fedn/cli/status_cmd.py
+++ b/fedn/cli/status_cmd.py
@@ -8,8 +8,7 @@
 @main.group("status")
 @click.pass_context
 def status_cmd(ctx):
-    """:param ctx:
-    """
+    """:param ctx:"""
     pass
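The renamed CI script above exercises the legacy /predict route; the v1 blueprint registered below additionally exposes /api/v1/predict/start. A minimal sketch of calling it from Python (the controller address, disabled auth, and the prediction id are assumptions for illustration, not part of this patch):

import requests

# Hypothetical controller address; the CI scripts above use https://localhost:8090 with curl -k.
resp = requests.post(
    "https://localhost:8090/api/v1/predict/start",
    json={"prediction_id": "test-prediction-1"},  # required by the new route
    verify=False,  # self-signed certificate, mirroring curl -k
)
print(resp.status_code, resp.json())  # expect 200 and {"message": "Prediction session started"}

diff --git a/fedn/network/api/v1/__init__.py b/fedn/network/api/v1/__init__.py
index e5542084f..83af60aad 100644
--- a/fedn/network/api/v1/__init__.py
+++ 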
b/fedn/network/api/v1/__init__.py
@@ -1,12 +1,12 @@
 from fedn.network.api.v1.client_routes import bp as client_bp
 from fedn.network.api.v1.combiner_routes import bp as combiner_bp
 from fedn.network.api.v1.helper_routes import bp as helper_bp
-from fedn.network.api.v1.inference_routes import bp as inference_bp
 from fedn.network.api.v1.model_routes import bp as model_bp
 from fedn.network.api.v1.package_routes import bp as package_bp
+from fedn.network.api.v1.prediction_routes import bp as prediction_bp
 from fedn.network.api.v1.round_routes import bp as round_bp
 from fedn.network.api.v1.session_routes import bp as session_bp
 from fedn.network.api.v1.status_routes import bp as status_bp
 from fedn.network.api.v1.validation_routes import bp as validation_bp

-_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp, inference_bp, helper_bp]
+_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp, prediction_bp, helper_bp]
diff --git a/fedn/network/api/v1/inference_routes.py b/fedn/network/api/v1/inference_routes.py
deleted file mode 100644
index 6da2dc8b4..000000000
--- a/fedn/network/api/v1/inference_routes.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import threading
-
-from flask import Blueprint, jsonify, request
-
-from fedn.network.api.auth import jwt_auth_required
-from fedn.network.api.shared import control
-from fedn.network.api.v1.shared import api_version
-
-bp = Blueprint("inference", __name__, url_prefix=f"/api/{api_version}/infer")
-
-
-@bp.route("/start", methods=["POST"])
-@jwt_auth_required(role="admin")
-def start_session():
-    """Start a new inference session.
-    param: session_id: The session id to start.
-    type: session_id: str
-    param: rounds: The number of rounds to run.
-    type: rounds: int
-    """
-    try:
-        data = request.json if request.headers["Content-Type"] == "application/json" else request.form.to_dict()
-        session_id: str = data.get("session_id")
-
-        if not session_id or session_id == "":
-            return jsonify({"message": "Session ID is required"}), 400
-
-        session_config = {"session_id": session_id}
-
-        threading.Thread(target=control.inference_session, kwargs={"config": session_config}).start()
-
-        return jsonify({"message": "Inference session started"}), 200
-    except Exception:
-        return jsonify({"message": "Failed to start inference session"}), 500
diff --git a/fedn/network/api/v1/prediction_routes.py b/fedn/network/api/v1/prediction_routes.py
new file mode 100644
index 000000000..d5dd804cc
--- /dev/null
+++ b/fedn/network/api/v1/prediction_routes.py
@@ -0,0 +1,51 @@
+import threading
+
+from flask import Blueprint, jsonify, request
+
+from fedn.network.api.auth import jwt_auth_required
+from fedn.network.api.shared import control
+from fedn.network.api.v1.shared import api_version, mdb
+from fedn.network.storage.statestore.stores.model_store import ModelStore
+from fedn.network.storage.statestore.stores.prediction_store import PredictionStore
+from fedn.network.storage.statestore.stores.shared import EntityNotFound
+
+bp = Blueprint("prediction", __name__, url_prefix=f"/api/{api_version}/predict")
+
+prediction_store = PredictionStore(mdb, "control.predictions")
+model_store = ModelStore(mdb, "control.model")
+
+
+@bp.route("/start", methods=["POST"])
+@jwt_auth_required(role="admin")
+def start_session():
+    """Start a new prediction session.
+    param: prediction_id: The prediction id to start.
+    type: prediction_id: str
+    param: rounds: The number of rounds to run. 
+    type: rounds: int
+    """
+    try:
+        data = request.json if request.headers["Content-Type"] == "application/json" else request.form.to_dict()
+        prediction_id: str = data.get("prediction_id")
+
+        if not prediction_id or prediction_id == "":
+            return jsonify({"message": "prediction_id is required"}), 400
+
+        if data.get("model_id") is None:
+            count = model_store.count()
+            if count == 0:
+                return jsonify({"message": "No models available"}), 400
+        else:
+            try:
+                model_id = data.get("model_id")
+                _ = model_store.get(model_id)
+            except EntityNotFound:
+                return jsonify({"message": f"Model {model_id} not found"}), 404
+
+        session_config = {"prediction_id": prediction_id}
+
+        threading.Thread(target=control.prediction_session, kwargs={"config": session_config}).start()
+
+        return jsonify({"message": "Prediction session started"}), 200
+    except Exception:
+        return jsonify({"message": "Failed to start prediction session"}), 500
diff --git a/fedn/network/api/v1/status_routes.py b/fedn/network/api/v1/status_routes.py
index cf3907bea..0716f965b 100644
--- a/fedn/network/api/v1/status_routes.py
+++ b/fedn/network/api/v1/status_routes.py
@@ -124,8 +124,12 @@ def get_statuses():
     limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers)
     kwargs = request.args.to_dict()
+    # print all the typed headers
+    print(f"limit: {limit}, skip: {skip}, sort_key: {sort_key}, sort_order: {sort_order}, use_typing: {use_typing}")
+    print(f"kwargs: {kwargs}")
     statuses = status_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs)
+    print(f"statuses: {statuses}")
     result = [status.__dict__ for status in statuses["result"]] if use_typing else statuses["result"]
     response = {"count": statuses["count"], "result": result}
diff --git a/fedn/network/clients/client.py b/fedn/network/clients/client.py
index fa7000a99..12f76e2a4 100644
--- a/fedn/network/clients/client.py
+++ b/fedn/network/clients/client.py
@@ -176,7 +176,7 @@ def connect(self, combiner_config):
         logger.debug("Client using metadata: {}.".format(self.metadata))
         port = combiner_config["port"]
         secure = False
-        if combiner_config["fqdn"] is not None:
+        if "fqdn" in combiner_config.keys() and combiner_config["fqdn"] is not None:
             host = combiner_config["fqdn"]
             # assuming https if fqdn is used
             port = 443
@@ -417,12 +417,12 @@ def _listen_to_task_stream(self):
                     self.inbox.put(("train", request))
                 elif request.type == fedn.StatusType.MODEL_VALIDATION and self.config["validator"]:
                     self.inbox.put(("validate", request))
-                elif request.type == fedn.StatusType.INFERENCE and self.config["validator"]:
-                    logger.info("Received inference request for model_id {}".format(request.model_id))
+                elif request.type == fedn.StatusType.MODEL_PREDICTION and self.config["validator"]:
+                    logger.info("Received prediction request for model_id {}".format(request.model_id))
                     presigned_url = json.loads(request.data)
                     presigned_url = presigned_url["presigned_url"]
-                    logger.info("Inference presigned URL: {}".format(presigned_url))
-                    self.inbox.put(("infer", request))
+                    logger.info("Prediction presigned URL: {}".format(presigned_url))
+                    self.inbox.put(("predict", request))
                 else:
                     logger.error("Unknown request type: {}".format(request.type))
@@ -519,25 +519,17 @@ def _process_training_request(self, model_id: str, session_id: str = None):
         return updated_model_id, meta
-    def _process_validation_request(self, model_id: str, is_inference: bool, session_id: str = None):
+    def _process_validation_request(self, model_id: str, session_id: str = None):
        """Process a 
validation request. :param model_id: The model id of the model to be validated. :type model_id: str - :param is_inference: True if the validation is an inference request, False if it is a validation request. - :type is_inference: bool :param session_id: The id of the current session. :type session_id: str :return: The validation metrics, or None if validation failed. :rtype: dict """ - # Figure out cmd - if is_inference: - cmd = "infer" - else: - cmd = "validate" - - self.send_status(f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id) + self.send_status(f"Processing validation request for model_id {model_id}", sesssion_id=session_id) self.state = ClientState.validating try: model = self.get_model_from_combiner(str(model_id)) @@ -550,7 +542,7 @@ def _process_validation_request(self, model_id: str, is_inference: bool, session fh.write(model.getbuffer()) outpath = get_tmp_path() - self.dispatcher.run_cmd(f"{cmd} {inpath} {outpath}") + self.dispatcher.run_cmd(f"validate {inpath} {outpath}") with open(outpath, "r") as fh: validation = json.loads(fh.read()) @@ -566,22 +558,22 @@ def _process_validation_request(self, model_id: str, is_inference: bool, session self.state = ClientState.idle return validation - def _process_inference_request(self, model_id: str, session_id: str, presigned_url: str): - """Process an inference request. + def _process_prediction_request(self, model_id: str, session_id: str, presigned_url: str): + """Process a prediction request. - :param model_id: The model id of the model to be used for inference. + :param model_id: The model id of the model to be used for prediction. :type model_id: str :param session_id: The id of the current session. :type session_id: str - :param presigned_url: The presigned URL for the data to be used for inference. + :param presigned_url: The presigned URL for the data to be used for prediction. :type presigned_url: str :return: None """ - self.send_status(f"Processing inference request for model_id {model_id}", sesssion_id=session_id) + self.send_status(f"Processing prediction request for model_id {model_id}", sesssion_id=session_id) try: model = self.get_model_from_combiner(str(model_id)) if model is None: - logger.error("Could not retrieve model from combiner. Aborting inference request.") + logger.error("Could not retrieve model from combiner. 
Aborting prediction request.") return inpath = self.helper.get_tmp_path() @@ -591,7 +583,7 @@ def _process_inference_request(self, model_id: str, session_id: str, presigned_u outpath = get_tmp_path() self.dispatcher.run_cmd(f"predict {inpath} {outpath}") - # Upload the inference result to the presigned URL + # Upload the prediction result to the presigned URL with open(outpath, "rb") as fh: response = requests.put(presigned_url, data=fh.read()) @@ -599,12 +591,12 @@ def _process_inference_request(self, model_id: str, session_id: str, presigned_u os.unlink(outpath) if response.status_code != 200: - logger.warning("Inference upload failed with status code {}".format(response.status_code)) + logger.warning("Prediction upload failed with status code {}".format(response.status_code)) self.state = ClientState.idle return except Exception as e: - logger.warning("Inference failed with exception {}".format(e)) + logger.warning("Prediction failed with exception {}".format(e)) self.state = ClientState.idle return @@ -668,7 +660,7 @@ def process_request(self): elif task_type == "validate": self.state = ClientState.validating - metrics = self._process_validation_request(request.model_id, False, request.session_id) + metrics = self._process_validation_request(request.model_id, request.session_id) if metrics is not None: # Send validation @@ -707,21 +699,47 @@ def process_request(self): self.state = ClientState.idle self.inbox.task_done() - elif task_type == "infer": - self.state = ClientState.inferencing + elif task_type == "predict": + self.state = ClientState.predicting try: presigned_url = json.loads(request.data) except json.JSONDecodeError as e: - logger.error(f"Failed to decode inference request data: {e}") + logger.error(f"Failed to decode prediction request data: {e}") self.state = ClientState.idle continue if "presigned_url" not in presigned_url: - logger.error("Inference request missing presigned_url.") + logger.error("Prediction request missing presigned_url.") self.state = ClientState.idle continue presigned_url = presigned_url["presigned_url"] - _ = self._process_inference_request(request.model_id, request.session_id, presigned_url) + # Obs that session_id in request is the prediction_id + _ = self._process_prediction_request(request.model_id, request.session_id, presigned_url) + prediction = fedn.ModelPrediction() + prediction.sender.name = self.name + prediction.sender.role = fedn.WORKER + prediction.receiver.name = request.sender.name + prediction.receiver.name = request.sender.name + prediction.receiver.role = request.sender.role + prediction.model_id = str(request.model_id) + # TODO: Add prediction data + prediction.data = "" + prediction.timestamp.GetCurrentTime() + prediction.correlation_id = request.correlation_id + # Obs that session_id in request is the prediction_id + prediction.prediction_id = request.session_id + + try: + _ = self.combinerStub.SendModelPrediction(prediction, metadata=self.metadata) + status_type = fedn.StatusType.MODEL_PREDICTION + self.send_status( + "Model prediction completed.", log_level=fedn.Status.AUDIT, type=status_type, request=prediction, sesssion_id=request.session_id + ) + except grpc.RpcError as e: + status_code = e.code() + logger.error("GRPC error, {}.".format(status_code.name)) + logger.debug(e) + self.state = ClientState.idle except queue.Empty: pass diff --git a/fedn/network/clients/grpc_handler.py b/fedn/network/clients/grpc_handler.py index 9b8550344..4d2b5d569 100644 --- a/fedn/network/clients/grpc_handler.py +++ 
b/fedn/network/clients/grpc_handler.py @@ -25,6 +25,7 @@ def __init__(self, key): def __call__(self, context, callback): callback((("authorization", f"{FEDN_AUTH_SCHEME} {self._key}"),), None) + def _get_ssl_certificate(domain, port=443): context = SSL.Context(SSL.TLSv1_2_METHOD) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -39,6 +40,7 @@ def _get_ssl_certificate(domain, port=443): cert = cert.to_cryptography().public_bytes(Encoding.PEM).decode() return cert + class GrpcHandler: def __init__(self, host: str, port: int, name: str, token: str, combiner_name: str): self.metadata = [ @@ -59,6 +61,11 @@ def _init_secure_channel(self, host: str, port: int, token: str): url = f"{host}:{port}" logger.info(f"Connecting (GRPC) to {url}") + # Keepalive settings: these help keep the connection open for long-lived clients + KEEPALIVE_TIME_MS = 60 * 1000 # send keepalive ping every 60 seconds + KEEPALIVE_TIMEOUT_MS = 20 * 1000 # wait 20 seconds for keepalive ping ack before considering connection dead + KEEPALIVE_PERMIT_WITHOUT_CALLS = True # allow keepalive pings even when there are no RPCs + if os.getenv("FEDN_GRPC_ROOT_CERT_PATH"): logger.info("Using root certificate from environment variable for GRPC channel.") with open(os.environ["FEDN_GRPC_ROOT_CERT_PATH"], "rb") as f: @@ -70,7 +77,16 @@ def _init_secure_channel(self, host: str, port: int, token: str): cert = _get_ssl_certificate(host, port) credentials = grpc.ssl_channel_credentials(cert.encode("utf-8")) auth_creds = grpc.metadata_call_credentials(GrpcAuth(token)) - self.channel = grpc.secure_channel("{}:{}".format(host, str(port)), grpc.composite_channel_credentials(credentials, auth_creds)) + self.channel = grpc.secure_channel( + "{}:{}".format(host, str(port)), + grpc.composite_channel_credentials(credentials, auth_creds), + options=[ + ("grpc.keepalive_time_ms", KEEPALIVE_TIME_MS), + ("grpc.keepalive_timeout_ms", KEEPALIVE_TIMEOUT_MS), + ("grpc.keepalive_permit_without_calls", KEEPALIVE_PERMIT_WITHOUT_CALLS), + ("grpc.http2.max_pings_without_data", 0), # unlimited pings without data + ], + ) def _init_insecure_channel(self, host: str, port: int): url = f"{host}:{port}" @@ -115,7 +131,7 @@ def listen_to_task_stream(self, client_name: str, client_id: str, callback: Call type=fedn.StatusType.MODEL_UPDATE_REQUEST, request=request, sesssion_id=request.session_id, - sender_name=client_name + sender_name=client_name, ) logger.info(f"Received task request of type {request.type} for model_id {request.model_id}") @@ -234,7 +250,8 @@ def send_model_to_combiner(self, model: BytesIO, id: str): return result - def send_model_update(self, + def send_model_update( + self, sender_name: str, sender_role: fedn.Role, client_id: str, @@ -242,7 +259,7 @@ def send_model_update(self, model_update_id: str, receiver_name: str, receiver_role: fedn.Role, - meta: dict + meta: dict, ): update = fedn.ModelUpdate() update.sender.name = sender_name @@ -260,17 +277,7 @@ def send_model_update(self, _ = self.combinerStub.SendModelUpdate(update, metadata=self.metadata) except grpc.RpcError as e: return self._handle_grpc_error( - e, - "SendModelUpdate", - lambda: self.send_model_update( - sender_name, - sender_role, - model_id, - model_update_id, - receiver_name, - receiver_role, - meta - ) + e, "SendModelUpdate", lambda: self.send_model_update(sender_name, sender_role, model_id, model_update_id, receiver_name, receiver_role, meta) ) except Exception as e: logger.error(f"GRPC (SendModelUpdate): An error occurred: {e}") @@ -278,14 +285,8 @@ def 
send_model_update(self, return True - def send_model_validation(self, - sender_name: str, - receiver_name: str, - receiver_role: fedn.Role, - model_id: str, - metrics: str, - correlation_id: str, - session_id: str + def send_model_validation( + self, sender_name: str, receiver_name: str, receiver_role: fedn.Role, model_id: str, metrics: str, correlation_id: str, session_id: str ) -> bool: validation = fedn.ModelValidation() validation.sender.name = sender_name @@ -298,7 +299,6 @@ def send_model_validation(self, validation.correlation_id = correlation_id validation.session_id = session_id - try: logger.info("Sending model validation to combiner.") _ = self.combinerStub.SendModelValidation(validation, metadata=self.metadata) @@ -306,15 +306,7 @@ def send_model_validation(self, return self._handle_grpc_error( e, "SendModelValidation", - lambda: self.send_model_validation( - sender_name, - receiver_name, - receiver_role, - model_id, - metrics, - correlation_id, - session_id - ) + lambda: self.send_model_validation(sender_name, receiver_name, receiver_role, model_id, metrics, correlation_id, session_id), ) except Exception as e: logger.error(f"GRPC (SendModelValidation): An error occurred: {e}") diff --git a/fedn/network/clients/state.py b/fedn/network/clients/state.py index d7f82a769..678bb578c 100644 --- a/fedn/network/clients/state.py +++ b/fedn/network/clients/state.py @@ -7,7 +7,7 @@ class ClientState(Enum): idle = 1 training = 2 validating = 3 - inferencing = 4 + predicting = 4 def ClientStateToString(state): @@ -24,5 +24,7 @@ def ClientStateToString(state): return "TRAINING" if state == ClientState.validating: return "VALIDATING" + if state == ClientState.predicting: + return "PREDICTING" return "UNKNOWN" diff --git a/fedn/network/combiner/combiner.py b/fedn/network/combiner/combiner.py index d336932c5..3f732ecd4 100644 --- a/fedn/network/combiner/combiner.py +++ b/fedn/network/combiner/combiner.py @@ -9,13 +9,16 @@ from enum import Enum from typing import TypedDict +from google.protobuf.json_format import MessageToDict + import fedn.network.grpc.fedn_pb2 as fedn import fedn.network.grpc.fedn_pb2_grpc as rpc from fedn.common.certificate.certificate import Certificate from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream from fedn.network.combiner.roundhandler import RoundConfig, RoundHandler -from fedn.network.combiner.shared import repository, statestore +from fedn.network.combiner.shared import client_store, combiner_store, prediction_store, repository, statestore, status_store, validation_store from fedn.network.grpc.server import Server, ServerConfig +from fedn.network.storage.statestore.stores.shared import EntityNotFound VALID_NAME_REGEX = "^[a-zA-Z0-9_-]*$" @@ -82,6 +85,12 @@ def __init__(self, config): set_log_stream(config.get("logfile", None)) # Client queues + # Each client in the dict is stored with its client_id as key, and the value is a dict with keys: + # name: str + # status: str + # last_seen: str + # fedn.Queue.TASK_QUEUE: queue.Queue + # Obs that fedn.Queue.TASK_QUEUE is just str(1) self.clients = {} # Validate combiner name @@ -106,19 +115,33 @@ def __init__(self, config): "address": config["host"], "parent": "localhost", "ip": "", + "updated_at": str(datetime.now()), } - self.statestore.set_combiner(interface_config) + # Check if combiner already exists in statestore + try: + _ = combiner_store.get(config["name"]) + except EntityNotFound: + combiner_store.add(interface_config) # Fetch all clients previously connected to the combiner # 
If a client and a combiner goes down at the same time, # the client will be stuck listed as "online" in the statestore. # Set the status to offline for previous clients. - previous_clients = self.statestore.clients.find({"combiner": config["name"]}) - for client in previous_clients: + previous_clients = client_store.list(limit=0, skip=0, sort_key=None, kwargs={"combiner": self.id}) + count = previous_clients["count"] + result = previous_clients["result"] + logger.info(f"Found {count} previous clients") + logger.info("Updating previous clients status to offline") + for client in result: try: - self.statestore.set_client({"name": client["name"], "status": "offline", "client_id": client["client_id"]}) - except KeyError: - self.statestore.set_client({"name": client["name"], "status": "offline"}) + if "client_id" in client.keys(): + client_store.update("client_id", client["client_id"], {"name": client["name"], "status": "offline"}) + else: + # Old clients might not have a client_id + client_store.update("name", client["name"], {"name": client["name"], "status": "offline"}) + + except Exception as e: + logger.error("Failed to update previous client status: {}".format(str(e))) # Set up gRPC server configuration if config["secure"]: @@ -191,10 +214,10 @@ def request_model_validation(self, session_id, model_id, clients=[]): else: logger.info("Sent model validation request for model {} to {} clients".format(model_id, len(clients))) - def request_model_inference(self, session_id: str, model_id: str, clients: list = []) -> None: - """Ask clients to perform inference on the model. + def request_model_prediction(self, prediction_id: str, model_id: str, clients: list = []) -> None: + """Ask clients to perform prediction on the model. - :param model_id: the model id to perform inference on + :param model_id: the model id to perform prediction on :type model_id: str :param config: the model configuration to send to clients :type config: dict @@ -202,18 +225,20 @@ def request_model_inference(self, session_id: str, model_id: str, clients: list :type clients: list """ - clients = self._send_request_type(fedn.StatusType.INFERENCE, session_id, model_id, clients) + clients = self._send_request_type(fedn.StatusType.MODEL_PREDICTION, prediction_id, model_id, clients) if len(clients) < 20: - logger.info("Sent model inference request for model {} to clients {}".format(model_id, clients)) + logger.info("Sent model prediction request for model {} to clients {}".format(model_id, clients)) else: - logger.info("Sent model inference request for model {} to {} clients".format(model_id, len(clients))) + logger.info("Sent model prediction request for model {} to {} clients".format(model_id, len(clients))) def _send_request_type(self, request_type, session_id, model_id, config=None, clients=[]): """Send a request of a specific type to clients. :param request_type: the type of request :type request_type: :class:`fedn.network.grpc.fedn_pb2.StatusType` + :param session_id: the session id to send in the request. Obs that for prediction, this is the prediction id. 
+ :type session_id: str :param model_id: the model id to send in the request :type model_id: str :param config: the model configuration to send to clients @@ -228,8 +253,8 @@ def _send_request_type(self, request_type, session_id, model_id, config=None, cl clients = self.get_active_trainers() elif request_type == fedn.StatusType.MODEL_VALIDATION: clients = self.get_active_validators() - elif request_type == fedn.StatusType.INFERENCE: - # TODO: add inference clients type + elif request_type == fedn.StatusType.MODEL_PREDICTION: + # TODO: add prediction clients type clients = self.get_active_validators() for client in clients: request = fedn.TaskRequest() @@ -244,9 +269,9 @@ def _send_request_type(self, request_type, session_id, model_id, config=None, cl request.receiver.client_id = client request.receiver.role = fedn.WORKER # Set the request data, not used in validation - if request_type == fedn.StatusType.INFERENCE: - presigned_url = self.repository.presigned_put_url(self.repository.inference_bucket, f"{client}/{session_id}") - # TODO: in inference, request.data should also contain user-defined data/parameters + if request_type == fedn.StatusType.MODEL_PREDICTION: + presigned_url = self.repository.presigned_put_url(self.repository.prediction_bucket, f"{client}/{session_id}") + # TODO: in prediction, request.data should also contain user-defined data/parameters request.data = json.dumps({"presigned_url": presigned_url}) elif request_type == fedn.StatusType.MODEL_UPDATE: request.data = json.dumps(config) @@ -354,7 +379,7 @@ def _list_active_clients(self, channel): then = self.clients[client]["last_seen"] if (now - then) < timedelta(seconds=10): clients["active_clients"].append(client) - # If client has changed status, update statestore + # If client has changed status, update client queue if status != "online": self.clients[client]["status"] = "online" clients["update_active_clients"].append(client) @@ -363,9 +388,11 @@ def _list_active_clients(self, channel): clients["update_offline_clients"].append(client) # Update statestore with client status if len(clients["update_active_clients"]) > 0: - self.statestore.update_client_status(clients["update_active_clients"], "online") + for client in clients["update_active_clients"]: + client_store.update("client_id", client, {"status": "online"}) if len(clients["update_offline_clients"]) > 0: - self.statestore.update_client_status(clients["update_offline_clients"], "offline") + for client in clients["update_offline_clients"]: + client_store.update("client_id", client, {"status": "offline"}) return clients["active_clients"] @@ -395,10 +422,11 @@ def _put_request_to_client_queue(self, request, queue_name): def _send_status(self, status): """Report a status to backend db. - :param status: the status to report + :param status: the status message to report :type status: :class:`fedn.network.grpc.fedn_pb2.Status` """ - self.statestore.report_status(status) + data = MessageToDict(status, including_default_value_fields=True) + _ = status_store.add(data) def _flush_model_update_queue(self): """Clear the model update queue (aggregator). 
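As a sketch of what _send_status now persists: the Status proto is converted with MessageToDict and inserted into the control.status collection through the StatusStore wired up in fedn/network/combiner/shared.py. The message contents below are placeholders:

from google.protobuf.json_format import MessageToDict

import fedn.network.grpc.fedn_pb2 as fedn

status = fedn.Status(status="Client test-client connecting to TaskStream.")
status.log_level = fedn.Status.INFO
status.timestamp.GetCurrentTime()

# Same conversion _send_status applies before the insert; proto field names
# come out camelCased, e.g. {"status": ..., "logLevel": "INFO", "timestamp": ...}
data = MessageToDict(status, including_default_value_fields=True)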
@@ -622,11 +650,13 @@ def TaskStream(self, response, context): logger.info("grpc.Combiner.TaskStream: Client connected: {}\n".format(metadata["client"])) status = fedn.Status(status="Client {} connecting to TaskStream.".format(client.name)) + logger.info("Client {} connecting to TaskStream.".format(client.name)) status.log_level = fedn.Status.INFO status.timestamp.GetCurrentTime() self.__whoami(status.sender, self) + # Subscribe client, this also adds the client to self.clients self._subscribe_client_to_queue(client, fedn.Queue.TASK_QUEUE) q = self.__get_queue(client, fedn.Queue.TASK_QUEUE) @@ -634,7 +664,30 @@ def TaskStream(self, response, context): # Set client status to online self.clients[client.client_id]["status"] = "online" - self.statestore.set_client({"name": client.name, "status": "online", "client_id": client.client_id, "last_seen": datetime.now()}) + try: + # If the client is already in the client store, update the status + success, result = client_store.update( + "client_id", + client.client_id, + {"name": client.name, "status": "online", "client_id": client.client_id, "last_seen": datetime.now(), "combiner": self.id}, + ) + if not success and result == "Entity not found": + # If the client is not in the client store, add the client + success, result = client_store.add( + { + "name": client.name, + "status": "online", + "client_id": client.client_id, + "last_seen": datetime.now(), + "combiner": self.id, + "combiner_preferred": self.id, + "updated_at": datetime.now(), + } + ) + elif not success: + logger.error(result) + except Exception as e: + logger.error(f"Failed to update client status: {str(e)}") # Keep track of the time context has been active start_time = time.time() @@ -650,6 +703,11 @@ def TaskStream(self, response, context): pass except Exception as e: logger.error("Error in ModelUpdateRequestStream: {}".format(e)) + logger.warning("Client {} disconnected from TaskStream".format(client.name)) + status = fedn.Status(status="Client {} disconnected from TaskStream.".format(client.name)) + status.log_level = fedn.Status.INFO + status.timestamp.GetCurrentTime() + self._send_status(status) def SendModelUpdate(self, request, context): """Send a model update response. @@ -673,7 +731,12 @@ def register_model_validation(self, validation): :param validation: the model validation :type validation: :class:`fedn.network.grpc.fedn_pb2.ModelValidation` """ - self.statestore.report_validation(validation) + data = MessageToDict(validation, including_default_value_fields=True) + success, result = validation_store.add(data) + if not success: + logger.error(result) + else: + logger.info("Model validation registered: {}".format(result)) def SendModelValidation(self, request, context): """Send a model validation response. @@ -687,12 +750,32 @@ def SendModelValidation(self, request, context): """ logger.info("Recieved ModelValidation from {}".format(request.sender.name)) - self.register_model_validation(request) + validation = MessageToDict(request, including_default_value_fields=True) + validation_store.add(validation) response = fedn.Response() response.response = "RECEIVED ModelValidation {} from client {}".format(response, response.sender.name) return response + def SendModelPrediction(self, request, context): + """Send a model prediction response. 
+
+        :param request: the request
+        :type request: :class:`fedn.network.grpc.fedn_pb2.ModelPrediction`
+        :param context: the context
+        :type context: :class:`grpc._server._Context`
+        :return: the response
+        :rtype: :class:`fedn.network.grpc.fedn_pb2.Response`
+        """
+        logger.info("Received ModelPrediction from {}".format(request.sender.name))
+
+        result = MessageToDict(request, including_default_value_fields=True)
+        prediction_store.add(result)
+
+        response = fedn.Response()
+        response.response = "RECEIVED ModelPrediction {} from client {}".format(request, request.sender.name)
+        return response
+
    ####################################################################################################################

    def run(self):
diff --git a/fedn/network/combiner/roundhandler.py b/fedn/network/combiner/roundhandler.py
index 1f0025303..a4808f7a5 100644
--- a/fedn/network/combiner/roundhandler.py
+++ b/fedn/network/combiner/roundhandler.py
@@ -21,7 +21,7 @@ class RoundConfig(TypedDict):
    :type _job_id: str
    :param committed_at: The time the round was committed. Set by Controller.
    :type committed_at: str
-    :param task: The task to perform in the round. Set by Controller. Supported tasks are "training", "validation", and "inference".
+    :param task: The task to perform in the round. Set by Controller. Supported tasks are "training", "validation", and "prediction".
    :type task: str
    :param round_id: The round identifier as str(int)
    :type round_id: str
@@ -42,6 +42,8 @@ class RoundConfig(TypedDict):
    :type model_metadata: dict
    :param session_id: The session identifier. Set by (Controller?).
    :type session_id: str
+    :param prediction_id: The prediction identifier. Only used for prediction tasks.
+    :type prediction_id: str
    :param helper_type: The helper type.
    :type helper_type: str
    :param aggregator: The aggregator type.
    :type aggregator: str
@@ -250,17 +252,17 @@ def _validation_round(self, session_id, model_id, clients):
        """
        self.server.request_model_validation(session_id, model_id, clients=clients)
-    def _inference_round(self, session_id: str, model_id: str, clients: list):
-        """Send model inference requests to clients.
+    def _prediction_round(self, prediction_id: str, model_id: str, clients: list):
+        """Send model prediction requests to clients.
        :param config: The round config object (passed to the client).
        :type config: dict
-        :param clients: clients to send inference requests to
+        :param clients: clients to send prediction requests to
        :type clients: list
-        :param model_id: The ID of the model to use for inference
+        :param model_id: The ID of the model to use for prediction
        :type model_id: str
        """
-        self.server.request_model_inference(session_id, model_id, clients=clients)
+        self.server.request_model_prediction(prediction_id, model_id, clients=clients)
    def stage_model(self, model_id, timeout_retry=3, retry=2):
        """Download a model from persistent storage and set in modelservice.
@@ -348,17 +350,17 @@ def execute_validation_round(self, session_id, model_id):
        validators = self._assign_round_clients(self.server.max_clients, type="validators")
        self._validation_round(session_id, model_id, validators)
-    def execute_inference_round(self, session_id: str, model_id: str) -> None:
-        """Coordinate inference rounds as specified in config.
+    def execute_prediction_round(self, prediction_id: str, model_id: str) -> None:
+        """Coordinate prediction rounds as specified in config.
        :param round_config: The round config object. 
:type round_config: dict """ - logger.info("COMBINER orchestrating inference using model {}".format(model_id)) + logger.info("COMBINER orchestrating prediction using model {}".format(model_id)) self.stage_model(model_id) - # TODO: Implement inference client type + # TODO: Implement prediction client type clients = self._assign_round_clients(self.server.max_clients, type="validators") - self._inference_round(session_id, model_id, clients) + self._prediction_round(prediction_id, model_id, clients) def execute_training_round(self, config): """Coordinates clients to execute training tasks. @@ -407,8 +409,6 @@ def run(self, polling_interval=1.0): while True: try: round_config = self.round_configs.get(block=False) - session_id = round_config["session_id"] - model_id = round_config["model_id"] # Check that the minimum allowed number of clients are connected ready = self._check_nr_round_clients(round_config) @@ -416,6 +416,8 @@ def run(self, polling_interval=1.0): if ready: if round_config["task"] == "training": + session_id = round_config["session_id"] + model_id = round_config["model_id"] tic = time.time() round_meta = self.execute_training_round(round_config) round_meta["time_exec_training"] = time.time() - tic @@ -423,9 +425,13 @@ def run(self, polling_interval=1.0): round_meta["name"] = self.server.id self.server.statestore.set_round_combiner_data(round_meta) elif round_config["task"] == "validation": + session_id = round_config["session_id"] + model_id = round_config["model_id"] self.execute_validation_round(session_id, model_id) - elif round_config["task"] == "inference": - self.execute_inference_round(session_id, model_id) + elif round_config["task"] == "prediction": + prediction_id = round_config["prediction_id"] + model_id = round_config["model_id"] + self.execute_prediction_round(prediction_id, model_id) else: logger.warning("config contains unkown task type.") else: diff --git a/fedn/network/combiner/shared.py b/fedn/network/combiner/shared.py index e1dd6854f..bf9a63032 100644 --- a/fedn/network/combiner/shared.py +++ b/fedn/network/combiner/shared.py @@ -1,13 +1,33 @@ +import pymongo +from pymongo.database import Database + from fedn.common.config import get_modelstorage_config, get_network_config, get_statestore_config from fedn.network.combiner.modelservice import ModelService from fedn.network.storage.s3.repository import Repository from fedn.network.storage.statestore.mongostatestore import MongoStateStore +from fedn.network.storage.statestore.stores.client_store import ClientStore +from fedn.network.storage.statestore.stores.combiner_store import CombinerStore +from fedn.network.storage.statestore.stores.prediction_store import PredictionStore +from fedn.network.storage.statestore.stores.status_store import StatusStore +from fedn.network.storage.statestore.stores.validation_store import ValidationStore statestore_config = get_statestore_config() modelstorage_config = get_modelstorage_config() network_id = get_network_config() statestore = MongoStateStore(network_id, statestore_config["mongo_config"]) + +if statestore_config["type"] == "MongoDB": + mc = pymongo.MongoClient(**statestore_config["mongo_config"]) + mc.server_info() + mdb: Database = mc[network_id] + +client_store = ClientStore(mdb, "network.clients") +validation_store = ValidationStore(mdb, "control.validations") +combiner_store = CombinerStore(mdb, "network.combiners") +status_store = StatusStore(mdb, "control.status") +prediction_store = PredictionStore(mdb, "control.predictions") + repository = 
Repository(modelstorage_config["storage_config"], init_buckets=False) modelservice = ModelService() diff --git a/fedn/network/controller/control.py b/fedn/network/controller/control.py index 38e775d0e..129dd1af0 100644 --- a/fedn/network/controller/control.py +++ b/fedn/network/controller/control.py @@ -82,7 +82,7 @@ def __init__(self, message): class Control(ControlBase): - """Controller, implementing the overall global training, validation and inference logic. + """Controller, implementing the overall global training, validation and prediction logic. :param statestore: A StateStorage instance. :type statestore: class: `fedn.network.statestorebase.StateStorageBase` @@ -211,11 +211,11 @@ def session(self, config: RoundConfig) -> None: self.set_session_status(config["session_id"], "Finished") self._state = ReducerState.idle - def inference_session(self, config: RoundConfig) -> None: - """Execute a new inference session. + def prediction_session(self, config: RoundConfig) -> None: + """Execute a new prediction session. :param config: The round config. - :type config: InferenceConfig + :type config: PredictionConfig :return: None """ if self._state == ReducerState.instructing: @@ -223,14 +223,14 @@ def inference_session(self, config: RoundConfig) -> None: return if len(self.network.get_combiners()) < 1: - logger.warning("Inference round cannot start, no combiners connected!") + logger.warning("Prediction round cannot start, no combiners connected!") return if "model_id" not in config.keys(): config["model_id"] = self.statestore.get_latest_model() config["committed_at"] = datetime.datetime.now() - config["task"] = "inference" + config["task"] = "prediction" config["rounds"] = str(1) config["clients_required"] = 1 @@ -240,10 +240,10 @@ def inference_session(self, config: RoundConfig) -> None: round_start = self.evaluate_round_start_policy(participating_combiners) if round_start: - logger.info("Inference round start policy met, {} participating combiners.".format(len(participating_combiners))) + logger.info("Prediction round start policy met, {} participating combiners.".format(len(participating_combiners))) for combiner, _ in participating_combiners: combiner.submit(config) - logger.info("Inference round submitted to combiner {}".format(combiner)) + logger.info("Prediction round submitted to combiner {}".format(combiner)) def round(self, session_config: RoundConfig, round_id: str): """Execute one global round. @@ -437,10 +437,10 @@ def reduce(self, combiners): return model, meta - def infer_instruct(self, config): - """Main entrypoint for executing the inference compute plan. + def predict_instruct(self, config): + """Main entrypoint for executing the prediction compute plan. - : param config: configuration for the inference round + : param config: configuration for the prediction round """ # Check/set instucting state if self.__state == ReducerState.instructing: @@ -455,19 +455,19 @@ def infer_instruct(self, config): # Set reducer in monitoring state self.__state = ReducerState.monitoring - # Start inference round + # Start prediction round try: - self.inference_round(config) + self.prediction_round(config) except TypeError: logger.error("Round failed.") # Set reducer in idle state self.__state = ReducerState.idle - def inference_round(self, config): - """Execute an inference round. + def prediction_round(self, config): + """Execute a prediction round. 
-        : param config: configuration for the inference round
+        : param config: configuration for the prediction round
        """
        # Init meta
        round_data = {}
@@ -480,7 +480,7 @@
        # Setup combiner configuration
        combiner_config = copy.deepcopy(config)
        combiner_config["model_id"] = self.statestore.get_latest_model()
-        combiner_config["task"] = "inference"
+        combiner_config["task"] = "prediction"
        combiner_config["helper_type"] = self.statestore.get_framework()
        # Select combiners
@@ -494,12 +494,12 @@
        logger.warning("Round start policy not met, skipping round!")
        return None
-        # Synch combiners with latest model and trigger inference
+        # Synch combiners with latest model and trigger prediction
        for combiner, combiner_config in validating_combiners:
            try:
                combiner.submit(combiner_config)
            except CombinerUnavailableError:
-                # It is OK if inference fails for a combiner
+                # It is OK if prediction fails for a combiner
                self._handle_unavailable_combiner(combiner)
                pass
diff --git a/fedn/network/grpc/fedn.proto b/fedn/network/grpc/fedn.proto
index 04a1e5175..558b7e67d 100644
--- a/fedn/network/grpc/fedn.proto
+++ b/fedn/network/grpc/fedn.proto
@@ -15,7 +15,7 @@ enum StatusType {
  MODEL_UPDATE = 2;
  MODEL_VALIDATION_REQUEST = 3;
  MODEL_VALIDATION = 4;
-  INFERENCE = 5;
+  MODEL_PREDICTION = 5;
 }

 message Status {
@@ -80,6 +80,17 @@ message ModelValidation {
  string session_id = 8;
 }

+message ModelPrediction {
+  Client sender = 1;
+  Client receiver = 2;
+  string model_id = 3;
+  string data = 4;
+  string correlation_id = 5;
+  google.protobuf.Timestamp timestamp = 6;
+  string meta = 7;
+  string prediction_id = 8;
+}
+
 enum ModelStatus {
  OK = 0;
  IN_PROGRESS = 1;
@@ -241,6 +252,7 @@ service Combiner {
    rpc SendModelUpdate (ModelUpdate) returns (Response);
    rpc SendModelValidation (ModelValidation) returns (Response);
+    rpc SendModelPrediction (ModelPrediction) returns (Response);
 }
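A minimal sketch of the new RPC from a client's perspective, using the stubs regenerated below (the combiner address and all field values are assumptions; a real client also attaches auth metadata, as grpc_handler.py does):

import grpc

import fedn.network.grpc.fedn_pb2 as fedn
import fedn.network.grpc.fedn_pb2_grpc as rpc

channel = grpc.insecure_channel("localhost:12080")  # hypothetical combiner address
stub = rpc.CombinerStub(channel)

prediction = fedn.ModelPrediction()
prediction.sender.name = "client-1"  # placeholder identifiers
prediction.sender.role = fedn.WORKER
prediction.model_id = "model-uuid"
prediction.prediction_id = "pred-1"
prediction.correlation_id = "corr-1"
prediction.timestamp.GetCurrentTime()

response = stub.SendModelPrediction(prediction)  # returns fedn.Response
print(response.response)

diff --git a/fedn/network/grpc/fedn_pb2.py b/fedn/network/grpc/fedn_pb2.py
index d47f08792..cb637baba 100644
--- a/fedn/network/grpc/fedn_pb2.py
+++ b/fedn/network/grpc/fedn_pb2.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 # Generated by the protocol buffer compiler. DO NOT EDIT! 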
# source: network/grpc/fedn.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -14,26 +15,25 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17network/grpc/fedn.proto\x12\x04\x66\x65\x64n\x1a\x1fgoogle/protobuf/timestamp.proto\":\n\x08Response\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08response\x18\x02 \x01(\t\"\xbc\x02\n\x06Status\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x0e\n\x06status\x18\x02 \x01(\t\x12(\n\tlog_level\x18\x03 \x01(\x0e\x32\x15.fedn.Status.LogLevel\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x1e\n\x04type\x18\x07 \x01(\x0e\x32\x10.fedn.StatusType\x12\r\n\x05\x65xtra\x18\x08 \x01(\t\x12\x12\n\nsession_id\x18\t \x01(\t\"B\n\x08LogLevel\x12\x08\n\x04INFO\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x0b\n\x07WARNING\x10\x02\x12\t\n\x05\x45RROR\x10\x03\x12\t\n\x05\x41UDIT\x10\x04\"\xd8\x01\n\x0bTaskRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x12\n\nsession_id\x18\x08 \x01(\t\x12\x1e\n\x04type\x18\t \x01(\x0e\x32\x10.fedn.StatusType\"\xbf\x01\n\x0bModelUpdate\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x17\n\x0fmodel_update_id\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x0e\n\x06\x63onfig\x18\x08 \x01(\t\"\xd8\x01\n\x0fModelValidation\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x12\n\nsession_id\x18\x08 \x01(\t\"\x89\x01\n\x0cModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x12\n\n\x02id\x18\x04 \x01(\t\x12!\n\x06status\x18\x05 \x01(\x0e\x32\x11.fedn.ModelStatus\"]\n\rModelResponse\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\n\n\x02id\x18\x02 \x01(\t\x12!\n\x06status\x18\x03 \x01(\x0e\x32\x11.fedn.ModelStatus\x12\x0f\n\x07message\x18\x04 \x01(\t\"U\n\x15GetGlobalModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\"h\n\x16GetGlobalModelResponse\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\")\n\tHeartbeat\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\"W\n\x16\x43lientAvailableMessage\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 
\x01(\t\"P\n\x12ListClientsRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1c\n\x07\x63hannel\x18\x02 \x01(\x0e\x32\x0b.fedn.Queue\"*\n\nClientList\x12\x1c\n\x06\x63lient\x18\x01 \x03(\x0b\x32\x0c.fedn.Client\"C\n\x06\x43lient\x12\x18\n\x04role\x18\x01 \x01(\x0e\x32\n.fedn.Role\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x11\n\tclient_id\x18\x03 \x01(\t\"m\n\x0fReassignRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x0e\n\x06server\x18\x03 \x01(\t\x12\x0c\n\x04port\x18\x04 \x01(\r\"c\n\x10ReconnectRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x11\n\treconnect\x18\x03 \x01(\r\"\'\n\tParameter\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"T\n\x0e\x43ontrolRequest\x12\x1e\n\x07\x63ommand\x18\x01 \x01(\x0e\x32\r.fedn.Command\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.fedn.Parameter\"F\n\x0f\x43ontrolResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.fedn.Parameter\"\x13\n\x11\x43onnectionRequest\"<\n\x12\x43onnectionResponse\x12&\n\x06status\x18\x01 \x01(\x0e\x32\x16.fedn.ConnectionStatus*\x84\x01\n\nStatusType\x12\x07\n\x03LOG\x10\x00\x12\x18\n\x14MODEL_UPDATE_REQUEST\x10\x01\x12\x10\n\x0cMODEL_UPDATE\x10\x02\x12\x1c\n\x18MODEL_VALIDATION_REQUEST\x10\x03\x12\x14\n\x10MODEL_VALIDATION\x10\x04\x12\r\n\tINFERENCE\x10\x05*$\n\x05Queue\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x0e\n\nTASK_QUEUE\x10\x01*S\n\x0bModelStatus\x12\x06\n\x02OK\x10\x00\x12\x0f\n\x0bIN_PROGRESS\x10\x01\x12\x12\n\x0eIN_PROGRESS_OK\x10\x02\x12\n\n\x06\x46\x41ILED\x10\x03\x12\x0b\n\x07UNKNOWN\x10\x04*8\n\x04Role\x12\n\n\x06WORKER\x10\x00\x12\x0c\n\x08\x43OMBINER\x10\x01\x12\x0b\n\x07REDUCER\x10\x02\x12\t\n\x05OTHER\x10\x03*J\n\x07\x43ommand\x12\x08\n\x04IDLE\x10\x00\x12\t\n\x05START\x10\x01\x12\t\n\x05PAUSE\x10\x02\x12\x08\n\x04STOP\x10\x03\x12\t\n\x05RESET\x10\x04\x12\n\n\x06REPORT\x10\x05*I\n\x10\x43onnectionStatus\x12\x11\n\rNOT_ACCEPTING\x10\x00\x12\r\n\tACCEPTING\x10\x01\x12\x13\n\x0fTRY_AGAIN_LATER\x10\x02\x32z\n\x0cModelService\x12\x33\n\x06Upload\x12\x12.fedn.ModelRequest\x1a\x13.fedn.ModelResponse(\x01\x12\x35\n\x08\x44ownload\x12\x12.fedn.ModelRequest\x1a\x13.fedn.ModelResponse0\x01\x32\xf8\x01\n\x07\x43ontrol\x12\x34\n\x05Start\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12\x33\n\x04Stop\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12\x44\n\x15\x46lushAggregationQueue\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12<\n\rSetAggregator\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse2V\n\x07Reducer\x12K\n\x0eGetGlobalModel\x12\x1b.fedn.GetGlobalModelRequest\x1a\x1c.fedn.GetGlobalModelResponse2\xab\x03\n\tConnector\x12\x44\n\x14\x41llianceStatusStream\x12\x1c.fedn.ClientAvailableMessage\x1a\x0c.fedn.Status0\x01\x12*\n\nSendStatus\x12\x0c.fedn.Status\x1a\x0e.fedn.Response\x12?\n\x11ListActiveClients\x12\x18.fedn.ListClientsRequest\x1a\x10.fedn.ClientList\x12\x45\n\x10\x41\x63\x63\x65ptingClients\x12\x17.fedn.ConnectionRequest\x1a\x18.fedn.ConnectionResponse\x12\x30\n\rSendHeartbeat\x12\x0f.fedn.Heartbeat\x1a\x0e.fedn.Response\x12\x37\n\x0eReassignClient\x12\x15.fedn.ReassignRequest\x1a\x0e.fedn.Response\x12\x39\n\x0fReconnectClient\x12\x16.fedn.ReconnectRequest\x1a\x0e.fedn.Response2\xbf\x01\n\x08\x43ombiner\x12?\n\nTaskStream\x12\x1c.fedn.ClientAvailableMessage\x1a\x11.fedn.TaskRequest0\x01\x12\x34\n\x0fSendModelUpdate\x12\x11.f
edn.ModelUpdate\x1a\x0e.fedn.Response\x12<\n\x13SendModelValidation\x12\x15.fedn.ModelValidation\x1a\x0e.fedn.Responseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17network/grpc/fedn.proto\x12\x04\x66\x65\x64n\x1a\x1fgoogle/protobuf/timestamp.proto\":\n\x08Response\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08response\x18\x02 \x01(\t\"\xbc\x02\n\x06Status\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x0e\n\x06status\x18\x02 \x01(\t\x12(\n\tlog_level\x18\x03 \x01(\x0e\x32\x15.fedn.Status.LogLevel\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x1e\n\x04type\x18\x07 \x01(\x0e\x32\x10.fedn.StatusType\x12\r\n\x05\x65xtra\x18\x08 \x01(\t\x12\x12\n\nsession_id\x18\t \x01(\t\"B\n\x08LogLevel\x12\x08\n\x04INFO\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x0b\n\x07WARNING\x10\x02\x12\t\n\x05\x45RROR\x10\x03\x12\t\n\x05\x41UDIT\x10\x04\"\xd8\x01\n\x0bTaskRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x12\n\nsession_id\x18\x08 \x01(\t\x12\x1e\n\x04type\x18\t \x01(\x0e\x32\x10.fedn.StatusType\"\xbf\x01\n\x0bModelUpdate\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x17\n\x0fmodel_update_id\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x0e\n\x06\x63onfig\x18\x08 \x01(\t\"\xd8\x01\n\x0fModelValidation\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x12\n\nsession_id\x18\x08 \x01(\t\"\xdb\x01\n\x0fModelPrediction\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x15\n\rprediction_id\x18\x08 \x01(\t\"\x89\x01\n\x0cModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x12\n\n\x02id\x18\x04 \x01(\t\x12!\n\x06status\x18\x05 \x01(\x0e\x32\x11.fedn.ModelStatus\"]\n\rModelResponse\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\n\n\x02id\x18\x02 \x01(\t\x12!\n\x06status\x18\x03 \x01(\x0e\x32\x11.fedn.ModelStatus\x12\x0f\n\x07message\x18\x04 \x01(\t\"U\n\x15GetGlobalModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\"h\n\x16GetGlobalModelResponse\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\")\n\tHeartbeat\x12\x1c\n\x06sender\x18\x01 
\x01(\x0b\x32\x0c.fedn.Client\"W\n\x16\x43lientAvailableMessage\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"P\n\x12ListClientsRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1c\n\x07\x63hannel\x18\x02 \x01(\x0e\x32\x0b.fedn.Queue\"*\n\nClientList\x12\x1c\n\x06\x63lient\x18\x01 \x03(\x0b\x32\x0c.fedn.Client\"C\n\x06\x43lient\x12\x18\n\x04role\x18\x01 \x01(\x0e\x32\n.fedn.Role\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x11\n\tclient_id\x18\x03 \x01(\t\"m\n\x0fReassignRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x0e\n\x06server\x18\x03 \x01(\t\x12\x0c\n\x04port\x18\x04 \x01(\r\"c\n\x10ReconnectRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x11\n\treconnect\x18\x03 \x01(\r\"\'\n\tParameter\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"T\n\x0e\x43ontrolRequest\x12\x1e\n\x07\x63ommand\x18\x01 \x01(\x0e\x32\r.fedn.Command\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.fedn.Parameter\"F\n\x0f\x43ontrolResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.fedn.Parameter\"\x13\n\x11\x43onnectionRequest\"<\n\x12\x43onnectionResponse\x12&\n\x06status\x18\x01 \x01(\x0e\x32\x16.fedn.ConnectionStatus*\x8b\x01\n\nStatusType\x12\x07\n\x03LOG\x10\x00\x12\x18\n\x14MODEL_UPDATE_REQUEST\x10\x01\x12\x10\n\x0cMODEL_UPDATE\x10\x02\x12\x1c\n\x18MODEL_VALIDATION_REQUEST\x10\x03\x12\x14\n\x10MODEL_VALIDATION\x10\x04\x12\x14\n\x10MODEL_PREDICTION\x10\x05*$\n\x05Queue\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x0e\n\nTASK_QUEUE\x10\x01*S\n\x0bModelStatus\x12\x06\n\x02OK\x10\x00\x12\x0f\n\x0bIN_PROGRESS\x10\x01\x12\x12\n\x0eIN_PROGRESS_OK\x10\x02\x12\n\n\x06\x46\x41ILED\x10\x03\x12\x0b\n\x07UNKNOWN\x10\x04*8\n\x04Role\x12\n\n\x06WORKER\x10\x00\x12\x0c\n\x08\x43OMBINER\x10\x01\x12\x0b\n\x07REDUCER\x10\x02\x12\t\n\x05OTHER\x10\x03*J\n\x07\x43ommand\x12\x08\n\x04IDLE\x10\x00\x12\t\n\x05START\x10\x01\x12\t\n\x05PAUSE\x10\x02\x12\x08\n\x04STOP\x10\x03\x12\t\n\x05RESET\x10\x04\x12\n\n\x06REPORT\x10\x05*I\n\x10\x43onnectionStatus\x12\x11\n\rNOT_ACCEPTING\x10\x00\x12\r\n\tACCEPTING\x10\x01\x12\x13\n\x0fTRY_AGAIN_LATER\x10\x02\x32z\n\x0cModelService\x12\x33\n\x06Upload\x12\x12.fedn.ModelRequest\x1a\x13.fedn.ModelResponse(\x01\x12\x35\n\x08\x44ownload\x12\x12.fedn.ModelRequest\x1a\x13.fedn.ModelResponse0\x01\x32\xf8\x01\n\x07\x43ontrol\x12\x34\n\x05Start\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12\x33\n\x04Stop\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12\x44\n\x15\x46lushAggregationQueue\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12<\n\rSetAggregator\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse2V\n\x07Reducer\x12K\n\x0eGetGlobalModel\x12\x1b.fedn.GetGlobalModelRequest\x1a\x1c.fedn.GetGlobalModelResponse2\xab\x03\n\tConnector\x12\x44\n\x14\x41llianceStatusStream\x12\x1c.fedn.ClientAvailableMessage\x1a\x0c.fedn.Status0\x01\x12*\n\nSendStatus\x12\x0c.fedn.Status\x1a\x0e.fedn.Response\x12?\n\x11ListActiveClients\x12\x18.fedn.ListClientsRequest\x1a\x10.fedn.ClientList\x12\x45\n\x10\x41\x63\x63\x65ptingClients\x12\x17.fedn.ConnectionRequest\x1a\x18.fedn.ConnectionResponse\x12\x30\n\rSendHeartbeat\x12\x0f.fedn.Heartbeat\x1a\x0e.fedn.Response\x12\x37\n\x0eReassignClient\x12\x15.fedn.ReassignRequest\x1a\x0e.fedn.Response\x12\x39\n\x0fReconnectClient
\x12\x16.fedn.ReconnectRequest\x1a\x0e.fedn.Response2\xfd\x01\n\x08\x43ombiner\x12?\n\nTaskStream\x12\x1c.fedn.ClientAvailableMessage\x1a\x11.fedn.TaskRequest0\x01\x12\x34\n\x0fSendModelUpdate\x12\x11.fedn.ModelUpdate\x1a\x0e.fedn.Response\x12<\n\x13SendModelValidation\x12\x15.fedn.ModelValidation\x1a\x0e.fedn.Response\x12<\n\x13SendModelPrediction\x12\x15.fedn.ModelPrediction\x1a\x0e.fedn.Responseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'network.grpc.fedn_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _globals['_STATUSTYPE']._serialized_start=2327 - _globals['_STATUSTYPE']._serialized_end=2459 - _globals['_QUEUE']._serialized_start=2461 - _globals['_QUEUE']._serialized_end=2497 - _globals['_MODELSTATUS']._serialized_start=2499 - _globals['_MODELSTATUS']._serialized_end=2582 - _globals['_ROLE']._serialized_start=2584 - _globals['_ROLE']._serialized_end=2640 - _globals['_COMMAND']._serialized_start=2642 - _globals['_COMMAND']._serialized_end=2716 - _globals['_CONNECTIONSTATUS']._serialized_start=2718 - _globals['_CONNECTIONSTATUS']._serialized_end=2791 + _globals['_STATUSTYPE']._serialized_start=2549 + _globals['_STATUSTYPE']._serialized_end=2688 + _globals['_QUEUE']._serialized_start=2690 + _globals['_QUEUE']._serialized_end=2726 + _globals['_MODELSTATUS']._serialized_start=2728 + _globals['_MODELSTATUS']._serialized_end=2811 + _globals['_ROLE']._serialized_start=2813 + _globals['_ROLE']._serialized_end=2869 + _globals['_COMMAND']._serialized_start=2871 + _globals['_COMMAND']._serialized_end=2945 + _globals['_CONNECTIONSTATUS']._serialized_start=2947 + _globals['_CONNECTIONSTATUS']._serialized_end=3020 _globals['_RESPONSE']._serialized_start=66 _globals['_RESPONSE']._serialized_end=124 _globals['_STATUS']._serialized_start=127 @@ -46,46 +46,48 @@ _globals['_MODELUPDATE']._serialized_end=856 _globals['_MODELVALIDATION']._serialized_start=859 _globals['_MODELVALIDATION']._serialized_end=1075 - _globals['_MODELREQUEST']._serialized_start=1078 - _globals['_MODELREQUEST']._serialized_end=1215 - _globals['_MODELRESPONSE']._serialized_start=1217 - _globals['_MODELRESPONSE']._serialized_end=1310 - _globals['_GETGLOBALMODELREQUEST']._serialized_start=1312 - _globals['_GETGLOBALMODELREQUEST']._serialized_end=1397 - _globals['_GETGLOBALMODELRESPONSE']._serialized_start=1399 - _globals['_GETGLOBALMODELRESPONSE']._serialized_end=1503 - _globals['_HEARTBEAT']._serialized_start=1505 - _globals['_HEARTBEAT']._serialized_end=1546 - _globals['_CLIENTAVAILABLEMESSAGE']._serialized_start=1548 - _globals['_CLIENTAVAILABLEMESSAGE']._serialized_end=1635 - _globals['_LISTCLIENTSREQUEST']._serialized_start=1637 - _globals['_LISTCLIENTSREQUEST']._serialized_end=1717 - _globals['_CLIENTLIST']._serialized_start=1719 - _globals['_CLIENTLIST']._serialized_end=1761 - _globals['_CLIENT']._serialized_start=1763 - _globals['_CLIENT']._serialized_end=1830 - _globals['_REASSIGNREQUEST']._serialized_start=1832 - _globals['_REASSIGNREQUEST']._serialized_end=1941 - _globals['_RECONNECTREQUEST']._serialized_start=1943 - _globals['_RECONNECTREQUEST']._serialized_end=2042 - _globals['_PARAMETER']._serialized_start=2044 - _globals['_PARAMETER']._serialized_end=2083 - _globals['_CONTROLREQUEST']._serialized_start=2085 - _globals['_CONTROLREQUEST']._serialized_end=2169 - _globals['_CONTROLRESPONSE']._serialized_start=2171 - _globals['_CONTROLRESPONSE']._serialized_end=2241 - 
_globals['_CONNECTIONREQUEST']._serialized_start=2243 - _globals['_CONNECTIONREQUEST']._serialized_end=2262 - _globals['_CONNECTIONRESPONSE']._serialized_start=2264 - _globals['_CONNECTIONRESPONSE']._serialized_end=2324 - _globals['_MODELSERVICE']._serialized_start=2793 - _globals['_MODELSERVICE']._serialized_end=2915 - _globals['_CONTROL']._serialized_start=2918 - _globals['_CONTROL']._serialized_end=3166 - _globals['_REDUCER']._serialized_start=3168 - _globals['_REDUCER']._serialized_end=3254 - _globals['_CONNECTOR']._serialized_start=3257 - _globals['_CONNECTOR']._serialized_end=3684 - _globals['_COMBINER']._serialized_start=3687 - _globals['_COMBINER']._serialized_end=3878 + _globals['_MODELPREDICTION']._serialized_start=1078 + _globals['_MODELPREDICTION']._serialized_end=1297 + _globals['_MODELREQUEST']._serialized_start=1300 + _globals['_MODELREQUEST']._serialized_end=1437 + _globals['_MODELRESPONSE']._serialized_start=1439 + _globals['_MODELRESPONSE']._serialized_end=1532 + _globals['_GETGLOBALMODELREQUEST']._serialized_start=1534 + _globals['_GETGLOBALMODELREQUEST']._serialized_end=1619 + _globals['_GETGLOBALMODELRESPONSE']._serialized_start=1621 + _globals['_GETGLOBALMODELRESPONSE']._serialized_end=1725 + _globals['_HEARTBEAT']._serialized_start=1727 + _globals['_HEARTBEAT']._serialized_end=1768 + _globals['_CLIENTAVAILABLEMESSAGE']._serialized_start=1770 + _globals['_CLIENTAVAILABLEMESSAGE']._serialized_end=1857 + _globals['_LISTCLIENTSREQUEST']._serialized_start=1859 + _globals['_LISTCLIENTSREQUEST']._serialized_end=1939 + _globals['_CLIENTLIST']._serialized_start=1941 + _globals['_CLIENTLIST']._serialized_end=1983 + _globals['_CLIENT']._serialized_start=1985 + _globals['_CLIENT']._serialized_end=2052 + _globals['_REASSIGNREQUEST']._serialized_start=2054 + _globals['_REASSIGNREQUEST']._serialized_end=2163 + _globals['_RECONNECTREQUEST']._serialized_start=2165 + _globals['_RECONNECTREQUEST']._serialized_end=2264 + _globals['_PARAMETER']._serialized_start=2266 + _globals['_PARAMETER']._serialized_end=2305 + _globals['_CONTROLREQUEST']._serialized_start=2307 + _globals['_CONTROLREQUEST']._serialized_end=2391 + _globals['_CONTROLRESPONSE']._serialized_start=2393 + _globals['_CONTROLRESPONSE']._serialized_end=2463 + _globals['_CONNECTIONREQUEST']._serialized_start=2465 + _globals['_CONNECTIONREQUEST']._serialized_end=2484 + _globals['_CONNECTIONRESPONSE']._serialized_start=2486 + _globals['_CONNECTIONRESPONSE']._serialized_end=2546 + _globals['_MODELSERVICE']._serialized_start=3022 + _globals['_MODELSERVICE']._serialized_end=3144 + _globals['_CONTROL']._serialized_start=3147 + _globals['_CONTROL']._serialized_end=3395 + _globals['_REDUCER']._serialized_start=3397 + _globals['_REDUCER']._serialized_end=3483 + _globals['_CONNECTOR']._serialized_start=3486 + _globals['_CONNECTOR']._serialized_end=3913 + _globals['_COMBINER']._serialized_start=3916 + _globals['_COMBINER']._serialized_end=4169 # @@protoc_insertion_point(module_scope) diff --git a/fedn/network/grpc/fedn_pb2_grpc.py b/fedn/network/grpc/fedn_pb2_grpc.py index 86ebea055..32ac134d7 100644 --- a/fedn/network/grpc/fedn_pb2_grpc.py +++ b/fedn/network/grpc/fedn_pb2_grpc.py @@ -608,6 +608,11 @@ def __init__(self, channel): request_serializer=network_dot_grpc_dot_fedn__pb2.ModelValidation.SerializeToString, response_deserializer=network_dot_grpc_dot_fedn__pb2.Response.FromString, ) + self.SendModelPrediction = channel.unary_unary( + '/fedn.Combiner/SendModelPrediction', + 
request_serializer=network_dot_grpc_dot_fedn__pb2.ModelPrediction.SerializeToString, + response_deserializer=network_dot_grpc_dot_fedn__pb2.Response.FromString, + ) class CombinerServicer(object): @@ -632,6 +637,12 @@ def SendModelValidation(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def SendModelPrediction(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_CombinerServicer_to_server(servicer, server): rpc_method_handlers = { @@ -650,6 +661,11 @@ def add_CombinerServicer_to_server(servicer, server): request_deserializer=network_dot_grpc_dot_fedn__pb2.ModelValidation.FromString, response_serializer=network_dot_grpc_dot_fedn__pb2.Response.SerializeToString, ), + 'SendModelPrediction': grpc.unary_unary_rpc_method_handler( + servicer.SendModelPrediction, + request_deserializer=network_dot_grpc_dot_fedn__pb2.ModelPrediction.FromString, + response_serializer=network_dot_grpc_dot_fedn__pb2.Response.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'fedn.Combiner', rpc_method_handlers) @@ -710,3 +726,20 @@ def SendModelValidation(request, network_dot_grpc_dot_fedn__pb2.Response.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def SendModelPrediction(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/fedn.Combiner/SendModelPrediction', + network_dot_grpc_dot_fedn__pb2.ModelPrediction.SerializeToString, + network_dot_grpc_dot_fedn__pb2.Response.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/fedn/network/grpc/server.py b/fedn/network/grpc/server.py index edd2fd6d5..8523f4a64 100644 --- a/fedn/network/grpc/server.py +++ b/fedn/network/grpc/server.py @@ -26,7 +26,20 @@ def __init__(self, servicer, config: ServerConfig): set_log_level_from_string(config.get("verbosity", "INFO")) set_log_stream(config.get("logfile", None)) - self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=350), interceptors=[JWTInterceptor()]) + # Keepalive settings: these detect if the client is alive + KEEPALIVE_TIME_MS = 60 * 1000 # send keepalive ping every 60 seconds + KEEPALIVE_TIMEOUT_MS = 20 * 1000 # wait 20 seconds for keepalive ping ack before considering connection dead + MAX_CONNECTION_IDLE_MS = 5 * 60 * 1000 # max idle time before server terminates the connection (5 minutes) + + self.server = grpc.server( + futures.ThreadPoolExecutor(max_workers=350), + interceptors=[JWTInterceptor()], + options=[ + ("grpc.keepalive_time_ms", KEEPALIVE_TIME_MS), + ("grpc.keepalive_timeout_ms", KEEPALIVE_TIMEOUT_MS), + ("grpc.max_connection_idle_ms", MAX_CONNECTION_IDLE_MS), + ], + ) self.certificate = None self.health_servicer = health.HealthServicer() diff --git a/fedn/network/storage/s3/repository.py b/fedn/network/storage/s3/repository.py index 0b08fe7b9..4ad11bc0b 100644 --- a/fedn/network/storage/s3/repository.py +++ b/fedn/network/storage/s3/repository.py @@ -12,9 +12,9 @@ def __init__(self, config, init_buckets=True): self.model_bucket = 
config["storage_bucket"] self.context_bucket = config["context_bucket"] try: - self.inference_bucket = config["inference_bucket"] + self.prediction_bucket = config["prediction_bucket"] except KeyError: - self.inference_bucket = "fedn-inference" + self.prediction_bucket = "fedn-prediction" # TODO: Make a plug-in solution self.client = MINIORepository(config) @@ -22,7 +22,7 @@ def __init__(self, config, init_buckets=True): if init_buckets: self.client.create_bucket(self.context_bucket) self.client.create_bucket(self.model_bucket) - self.client.create_bucket(self.inference_bucket) + self.client.create_bucket(self.prediction_bucket) def get_model(self, model_id): """Retrieve a model with id model_id. diff --git a/fedn/network/storage/statestore/stores/client_store.py b/fedn/network/storage/statestore/stores/client_store.py index c3c2e5225..dd521860f 100644 --- a/fedn/network/storage/statestore/stores/client_store.py +++ b/fedn/network/storage/statestore/stores/client_store.py @@ -7,7 +7,7 @@ from fedn.network.storage.statestore.stores.store import Store -from .shared import EntityNotFound +from .shared import EntityNotFound, from_document class Client: @@ -30,7 +30,7 @@ def from_dict(data: dict) -> "Client": ip=data["ip"] if "ip" in data else None, status=data["status"] if "status" in data else None, updated_at=data["updated_at"] if "updated_at" in data else None, - last_seen=data["last_seen"] if "last_seen" in data else None + last_seen=data["last_seen"] if "last_seen" in data else None, ) @@ -49,14 +49,34 @@ def get(self, id: str, use_typing: bool = False) -> Client: response = super().get(id, use_typing=use_typing) return Client.from_dict(response) if use_typing else response - def update(self, id: str, item: Client) -> bool: - raise NotImplementedError("Update not implemented for ClientStore") + def _get_client_by_client_id(self, client_id: str) -> Dict: + document = self.database[self.collection].find_one({"client_id": client_id}) + if document is None: + raise EntityNotFound(f"Entity with client_id {client_id} not found") + return document + + def _get_client_by_name(self, name: str) -> Dict: + document = self.database[self.collection].find_one({"name": name}) + if document is None: + raise EntityNotFound(f"Entity with name {name} not found") + return document - def add(self, item: Client)-> Tuple[bool, Any]: - raise NotImplementedError("Add not implemented for ClientStore") + def update(self, by_key: str, value: str, item: Client) -> bool: + try: + result = self.database[self.collection].update_one({by_key: value}, {"$set": item}) + if result.modified_count == 1: + document = self.database[self.collection].find_one({by_key: value}) + return True, from_document(document) + else: + return False, "Entity not found" + except Exception as e: + return False, str(e) + + def add(self, item: Client) -> Tuple[bool, Any]: + return super().add(item) def delete(self, id: str) -> bool: - kwargs = { "_id": ObjectId(id) } if ObjectId.is_valid(id) else { "client_id": id } + kwargs = {"_id": ObjectId(id)} if ObjectId.is_valid(id) else {"client_id": id} document = self.database[self.collection].find_one(kwargs) @@ -86,10 +106,7 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI result = [Client.from_dict(item) for item in response["result"]] if use_typing else response["result"] - return { - "count": response["count"], - "result": result - } + return {"count": response["count"], "result": result} def count(self, **kwargs) -> int: return super().count(**kwargs) @@ -111,13 
+128,13 @@ def connected_client_count(self, combiners): [ {"$match": {"combiner": {"$in": combiners}, "status": "online"}}, {"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, - {"$project": {"id": "$_id", "count": 1, "_id": 0}} + {"$project": {"id": "$_id", "count": 1, "_id": 0}}, ] if len(combiners) > 0 else [ - {"$match": { "status": "online"}}, + {"$match": {"status": "online"}}, {"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, - {"$project": {"id": "$_id", "count": 1, "_id": 0}} + {"$project": {"id": "$_id", "count": 1, "_id": 0}}, ] ) diff --git a/fedn/network/storage/statestore/stores/combiner_store.py b/fedn/network/storage/statestore/stores/combiner_store.py index 5fceea1b7..2ad6437ea 100644 --- a/fedn/network/storage/statestore/stores/combiner_store.py +++ b/fedn/network/storage/statestore/stores/combiner_store.py @@ -11,19 +11,19 @@ class Combiner: def __init__( - self, - id: str, - name: str, - address: str, - certificate: str, - config: dict, - fqdn: str, - ip: str, - key: str, - parent: dict, - port: int, - status: str, - updated_at: str + self, + id: str, + name: str, + address: str, + certificate: str, + config: dict, + fqdn: str, + ip: str, + key: str, + parent: dict, + port: int, + status: str, + updated_at: str, ): self.id = id self.name = name @@ -51,7 +51,7 @@ def from_dict(data: dict) -> "Combiner": parent=data["parent"] if "parent" in data else None, port=data["port"] if "port" in data else None, status=data["status"] if "status" in data else None, - updated_at=data["updated_at"] if "updated_at" in data else None + updated_at=data["updated_at"] if "updated_at" in data else None, ) @@ -82,12 +82,12 @@ def get(self, id: str, use_typing: bool = False) -> Combiner: def update(self, id: str, item: Combiner) -> bool: raise NotImplementedError("Update not implemented for CombinerStore") - def add(self, item: Combiner)-> Tuple[bool, Any]: - raise NotImplementedError("Add not implemented for CombinerStore") + def add(self, item: Combiner) -> Tuple[bool, Any]: + return super().add(item) def delete(self, id: str) -> bool: - if(ObjectId.is_valid(id)): - kwargs = { "_id": ObjectId(id)} + if ObjectId.is_valid(id): + kwargs = {"_id": ObjectId(id)} else: return False @@ -119,10 +119,7 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI result = [Combiner.from_dict(item) for item in response["result"]] if use_typing else response["result"] - return { - "count": response["count"], - "result": result - } + return {"count": response["count"], "result": result} def count(self, **kwargs) -> int: return super().count(**kwargs) diff --git a/fedn/network/storage/statestore/stores/prediction_store.py b/fedn/network/storage/statestore/stores/prediction_store.py new file mode 100644 index 000000000..3a14ec8b9 --- /dev/null +++ b/fedn/network/storage/statestore/stores/prediction_store.py @@ -0,0 +1,84 @@ +from typing import Any, Dict, List, Tuple + +import pymongo +from pymongo.database import Database + +from fedn.network.storage.statestore.stores.store import Store + + +class Prediction: + def __init__( + self, id: str, model_id: str, data: str, correlation_id: str, timestamp: str, prediction_id: str, meta: str, sender: dict = None, receiver: dict = None + ): + self.id = id + self.model_id = model_id + self.data = data + self.correlation_id = correlation_id + self.timestamp = timestamp + self.prediction_id = prediction_id + self.meta = meta + self.sender = sender + self.receiver = receiver + + def from_dict(data: dict) -> "Prediction": + return 
Prediction( + id=str(data["_id"]), + model_id=data["modelId"] if "modelId" in data else None, + data=data["data"] if "data" in data else None, + correlation_id=data["correlationId"] if "correlationId" in data else None, + timestamp=data["timestamp"] if "timestamp" in data else None, + prediction_id=data["predictionId"] if "predictionId" in data else None, + meta=data["meta"] if "meta" in data else None, + sender=data["sender"] if "sender" in data else None, + receiver=data["receiver"] if "receiver" in data else None, + ) + + +class PredictionStore(Store[Prediction]): + def __init__(self, database: Database, collection: str): + super().__init__(database, collection) + + def get(self, id: str, use_typing: bool = False) -> Prediction: + """Get an entity by id + param id: The id of the entity + type: str + description: The id of the entity, can be either the id or the prediction (property) + param use_typing: Whether to return the entity as a typed object or as a dict + type: bool + return: The entity + """ + response = super().get(id, use_typing=use_typing) + return Prediction.from_dict(response) if use_typing else response + + def update(self, id: str, item: Prediction) -> bool: + raise NotImplementedError("Update not implemented for PredictionStore") + + def add(self, item: Prediction) -> Tuple[bool, Any]: + return super().add(item) + + def delete(self, id: str) -> bool: + raise NotImplementedError("Delete not implemented for PredictionStore") + + def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, use_typing: bool = False, **kwargs) -> Dict[int, List[Prediction]]: + """List entities + param limit: The maximum number of entities to return + type: int + description: The maximum number of entities to return + param skip: The number of entities to skip + type: int + description: The number of entities to skip + param sort_key: The key to sort by + type: str + description: The key to sort by + param sort_order: The order to sort by + type: pymongo.DESCENDING + description: The order to sort by + param use_typing: Whether to return the entities as typed objects or as dicts + type: bool + description: Whether to return the entities as typed objects or as dicts + return: A dictionary with the count and a list of entities + """ + response = super().list(limit, skip, sort_key or "timestamp", sort_order, use_typing=use_typing, **kwargs) + + result = [Prediction.from_dict(item) for item in response["result"]] if use_typing else response["result"] + return {"count": response["count"], "result": result} diff --git a/fedn/network/storage/statestore/stores/status_store.py b/fedn/network/storage/statestore/stores/status_store.py index 3c79d2b8d..9233d0b23 100644 --- a/fedn/network/storage/statestore/stores/status_store.py +++ b/fedn/network/storage/statestore/stores/status_store.py @@ -8,17 +8,7 @@ class Status: def __init__( - self, - id: str, - status: str, - timestamp: str, - log_level: str, - data: str, - correlation_id: str, - type: str, - extra: str, - session_id: str, - sender: dict = None + self, id: str, status: str, timestamp: str, log_level: str, data: str, correlation_id: str, type: str, extra: str, session_id: str, sender: dict = None ): self.id = id self.status = status @@ -42,7 +32,7 @@ def from_dict(data: dict) -> "Status": type=data["type"] if "type" in data else None, extra=data["extra"] if "extra" in data else None, session_id=data["sessionId"] if "sessionId" in data else None, - sender=data["sender"] if "sender" in data else None + sender=data["sender"] if
"sender" in data else None, ) @@ -65,8 +55,8 @@ def get(self, id: str, use_typing: bool = False) -> Status: def update(self, id: str, item: Status) -> bool: raise NotImplementedError("Update not implemented for StatusStore") - def add(self, item: Status)-> Tuple[bool, Any]: - raise NotImplementedError("Add not implemented for StatusStore") + def add(self, item: Status) -> Tuple[bool, Any]: + return super().add(item) def delete(self, id: str) -> bool: raise NotImplementedError("Delete not implemented for StatusStore") diff --git a/fedn/network/storage/statestore/stores/store.py b/fedn/network/storage/statestore/stores/store.py index eb9d8b1bb..f6a8f67e0 100644 --- a/fedn/network/storage/statestore/stores/store.py +++ b/fedn/network/storage/statestore/stores/store.py @@ -77,10 +77,7 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI result = [document for document in cursor] if use_typing else [from_document(document) for document in cursor] - return { - "count": count, - "result": result - } + return {"count": count, "result": result} def count(self, **kwargs) -> int: """Count entities diff --git a/fedn/network/storage/statestore/stores/validation_store.py b/fedn/network/storage/statestore/stores/validation_store.py index 7e8548135..59b5a0730 100644 --- a/fedn/network/storage/statestore/stores/validation_store.py +++ b/fedn/network/storage/statestore/stores/validation_store.py @@ -8,16 +8,7 @@ class Validation: def __init__( - self, - id: str, - model_id: str, - data: str, - correlation_id: str, - timestamp: str, - session_id: str, - meta: str, - sender: dict = None, - receiver: dict = None + self, id: str, model_id: str, data: str, correlation_id: str, timestamp: str, session_id: str, meta: str, sender: dict = None, receiver: dict = None ): self.id = id self.model_id = model_id @@ -39,7 +30,7 @@ def from_dict(data: dict) -> "Validation": session_id=data["sessionId"] if "sessionId" in data else None, meta=data["meta"] if "meta" in data else None, sender=data["sender"] if "sender" in data else None, - receiver=data["receiver"] if "receiver" in data else None + receiver=data["receiver"] if "receiver" in data else None, ) @@ -62,8 +53,8 @@ def get(self, id: str, use_typing: bool = False) -> Validation: def update(self, id: str, item: Validation) -> bool: raise NotImplementedError("Update not implemented for ValidationStore") - def add(self, item: Validation)-> Tuple[bool, Any]: - raise NotImplementedError("Add not implemented for ValidationStore") + def add(self, item: Validation) -> Tuple[bool, Any]: + return super().add(item) def delete(self, id: str) -> bool: raise NotImplementedError("Delete not implemented for ValidationStore") @@ -90,7 +81,4 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI response = super().list(limit, skip, sort_key or "timestamp", sort_order, use_typing=use_typing, **kwargs) result = [Validation.from_dict(item) for item in response["result"]] if use_typing else response["result"] - return { - "count": response["count"], - "result": result - } + return {"count": response["count"], "result": result}