Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/SK-1081 | Use stores in Combiner + ModelPredict #718

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions fedn/cli/status_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
@main.group("status")
@click.pass_context
def status_cmd(ctx):
    """Status commands.

    :param ctx: Click context object, shared between CLI commands.
    """
    pass


Expand Down
15 changes: 9 additions & 6 deletions fedn/network/api/v1/inference_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,31 @@

from fedn.network.api.auth import jwt_auth_required
from fedn.network.api.shared import control
from fedn.network.api.v1.shared import api_version
from fedn.network.api.v1.shared import api_version, mdb
from fedn.network.storage.statestore.stores.inference_store import InferenceStore

bp = Blueprint("inference", __name__, url_prefix=f"/api/{api_version}/infer")

inference_store = InferenceStore(mdb, "control.inferences")


@bp.route("/start", methods=["POST"])
@jwt_auth_required(role="admin")
def start_session():
"""Start a new inference session.
param: session_id: The session id to start.
type: session_id: str
param: inference_id: The session id to start.
type: inference_id: str
param: rounds: The number of rounds to run.
type: rounds: int
"""
try:
data = request.json if request.headers["Content-Type"] == "application/json" else request.form.to_dict()
session_id: str = data.get("session_id")
inference_id: str = data.get("inference_id")

if not session_id or session_id == "":
if not inference_id or inference_id == "":
return jsonify({"message": "Session ID is required"}), 400

session_config = {"session_id": session_id}
session_config = {"inference_id": inference_id}

threading.Thread(target=control.inference_session, kwargs={"config": session_config}).start()

Expand Down
4 changes: 4 additions & 0 deletions fedn/network/api/v1/status_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,12 @@ def get_statuses():
limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers)
kwargs = request.args.to_dict()

# print all the typed headers
print(f"limit: {limit}, skip: {skip}, sort_key: {sort_key}, sort_order: {sort_order}, use_typing: {use_typing}")
print(f"kwargs: {kwargs}")
statuses = status_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs)

print(f"statuses: {statuses}")
result = [status.__dict__ for status in statuses["result"]] if use_typing else statuses["result"]

response = {"count": statuses["count"], "result": result}
Expand Down
28 changes: 27 additions & 1 deletion fedn/network/clients/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def connect(self, combiner_config):
logger.debug("Client using metadata: {}.".format(self.metadata))
port = combiner_config["port"]
secure = False
if combiner_config["fqdn"] is not None:
if "fqdn" in combiner_config.keys() and combiner_config["fqdn"] is not None:
host = combiner_config["fqdn"]
# assuming https if fqdn is used
port = 443
Expand Down Expand Up @@ -721,7 +721,33 @@ def process_request(self):
self.state = ClientState.idle
continue
presigned_url = presigned_url["presigned_url"]
# Obs that session_id in request is the inference_id
_ = self._process_inference_request(request.model_id, request.session_id, presigned_url)
inference = fedn.ModelInference()
inference.sender.name = self.name
inference.sender.role = fedn.WORKER
inference.receiver.name = request.sender.name
inference.receiver.name = request.sender.name
inference.receiver.role = request.sender.role
inference.model_id = str(request.model_id)
# TODO: Add inference data
inference.data = ""
inference.timestamp.GetCurrentTime()
inference.correlation_id = request.correlation_id
# Obs that session_id in request is the inference_id
inference.inference_id = request.session_id

try:
_ = self.combinerStub.SendModelInference(inference, metadata=self.metadata)
status_type = fedn.StatusType.INFERENCE
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can change StatusType.INFERENCE to StatusType.MODEL_PREDICTION in fedn.proto

self.send_status(
"Model inference completed.", log_level=fedn.Status.AUDIT, type=status_type, request=inference, sesssion_id=request.session_id
)
except grpc.RpcError as e:
status_code = e.code()
logger.error("GRPC error, {}.".format(status_code.name))
logger.debug(e)

self.state = ClientState.idle
except queue.Empty:
pass
Expand Down
106 changes: 89 additions & 17 deletions fedn/network/combiner/combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
from enum import Enum
from typing import TypedDict

from google.protobuf.json_format import MessageToDict

import fedn.network.grpc.fedn_pb2 as fedn
import fedn.network.grpc.fedn_pb2_grpc as rpc
from fedn.common.certificate.certificate import Certificate
from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream
from fedn.network.combiner.roundhandler import RoundConfig, RoundHandler
from fedn.network.combiner.shared import repository, statestore
from fedn.network.combiner.shared import client_store, combiner_store, inference_store, repository, statestore, status_store, validation_store
from fedn.network.grpc.server import Server, ServerConfig

VALID_NAME_REGEX = "^[a-zA-Z0-9_-]*$"
Expand Down Expand Up @@ -82,6 +84,12 @@ def __init__(self, config):
set_log_stream(config.get("logfile", None))

# Client queues
# Each client in the dict is stored with its client_id as key, and the value is a dict with keys:
# name: str
# status: str
# last_seen: str
# fedn.Queue.TASK_QUEUE: queue.Queue
# Obs that fedn.Queue.TASK_QUEUE is just str(1)
self.clients = {}

# Validate combiner name
Expand All @@ -106,19 +114,29 @@ def __init__(self, config):
"address": config["host"],
"parent": "localhost",
"ip": "",
"updated_at": str(datetime.now()),
}
self.statestore.set_combiner(interface_config)
combiner_store.add(interface_config)

# Fetch all clients previously connected to the combiner
# If a client and a combiner goes down at the same time,
# the client will be stuck listed as "online" in the statestore.
# Set the status to offline for previous clients.
previous_clients = self.statestore.clients.find({"combiner": config["name"]})
for client in previous_clients:
previous_clients = client_store.list(limit=0, skip=0, sort_key=None, kwargs={"combiner": self.id})
count = previous_clients["count"]
result = previous_clients["result"]
logger.info(f"Found {count} previous clients")
logger.info("Updating previous clients status to offline")
for client in result:
try:
self.statestore.set_client({"name": client["name"], "status": "offline", "client_id": client["client_id"]})
except KeyError:
self.statestore.set_client({"name": client["name"], "status": "offline"})
if "client_id" in client.keys():
client_store.update("client_id", client["client_id"], {"name": client["name"], "status": "offline"})
else:
# Old clients might not have a client_id
client_store.update("name", client["name"], {"name": client["name"], "status": "offline"})

except Exception as e:
logger.error("Failed to update previous client status: {}".format(str(e)))

# Set up gRPC server configuration
if config["secure"]:
Expand Down Expand Up @@ -191,7 +209,7 @@ def request_model_validation(self, session_id, model_id, clients=[]):
else:
logger.info("Sent model validation request for model {} to {} clients".format(model_id, len(clients)))

def request_model_inference(self, session_id: str, model_id: str, clients: list = []) -> None:
def request_model_inference(self, inference_id: str, model_id: str, clients: list = []) -> None:
"""Ask clients to perform inference on the model.

:param model_id: the model id to perform inference on
Expand All @@ -202,7 +220,7 @@ def request_model_inference(self, session_id: str, model_id: str, clients: list
:type clients: list

"""
clients = self._send_request_type(fedn.StatusType.INFERENCE, session_id, model_id, clients)
clients = self._send_request_type(fedn.StatusType.INFERENCE, inference_id, model_id, clients)

if len(clients) < 20:
logger.info("Sent model inference request for model {} to clients {}".format(model_id, clients))
Expand All @@ -214,6 +232,8 @@ def _send_request_type(self, request_type, session_id, model_id, config=None, cl

:param request_type: the type of request
:type request_type: :class:`fedn.network.grpc.fedn_pb2.StatusType`
:param session_id: the session id to send in the request. Obs that for inference, this is the inference id.
:type session_id: str
:param model_id: the model id to send in the request
:type model_id: str
:param config: the model configuration to send to clients
Expand Down Expand Up @@ -354,7 +374,7 @@ def _list_active_clients(self, channel):
then = self.clients[client]["last_seen"]
if (now - then) < timedelta(seconds=10):
clients["active_clients"].append(client)
# If client has changed status, update statestore
# If client has changed status, update client queue
if status != "online":
self.clients[client]["status"] = "online"
clients["update_active_clients"].append(client)
Expand All @@ -363,9 +383,11 @@ def _list_active_clients(self, channel):
clients["update_offline_clients"].append(client)
# Update statestore with client status
if len(clients["update_active_clients"]) > 0:
self.statestore.update_client_status(clients["update_active_clients"], "online")
for client in clients["update_active_clients"]:
client_store.update("client_id", client, {"status": "online"})
if len(clients["update_offline_clients"]) > 0:
self.statestore.update_client_status(clients["update_offline_clients"], "offline")
for client in clients["update_offline_clients"]:
client_store.update("client_id", client, {"status": "offline"})

return clients["active_clients"]

Expand Down Expand Up @@ -395,10 +417,11 @@ def _put_request_to_client_queue(self, request, queue_name):
def _send_status(self, status):
    """Persist a status message to the backend database.

    :param status: the status message to report
    :type status: :class:`fedn.network.grpc.fedn_pb2.Status`
    """
    # Protobuf -> plain dict; keep fields still at their default value so
    # the stored document has a stable schema across messages.
    data = MessageToDict(status, including_default_value_fields=True)
    _ = status_store.add(data)

def _flush_model_update_queue(self):
"""Clear the model update queue (aggregator).
Expand Down Expand Up @@ -627,14 +650,38 @@ def TaskStream(self, response, context):

self.__whoami(status.sender, self)

# Subscribe client, this also adds the client to self.clients
self._subscribe_client_to_queue(client, fedn.Queue.TASK_QUEUE)
q = self.__get_queue(client, fedn.Queue.TASK_QUEUE)

self._send_status(status)

# Set client status to online
self.clients[client.client_id]["status"] = "online"
self.statestore.set_client({"name": client.name, "status": "online", "client_id": client.client_id, "last_seen": datetime.now()})
try:
# If the client is already in the client store, update the status
success, result = client_store.update(
"client_id",
client.client_id,
{"name": client.name, "status": "online", "client_id": client.client_id, "last_seen": datetime.now(), "combiner": self.id},
)
if not success and result == "Entity not found":
# If the client is not in the client store, add the client
success, result = client_store.add(
{
"name": client.name,
"status": "online",
"client_id": client.client_id,
"last_seen": datetime.now(),
"combiner": self.id,
"combiner_preferred": self.id,
"updated_at": datetime.now(),
}
)
elif not success:
logger.error(result)
except Exception as e:
logger.error(f"Failed to update client status: {str(e)}")

# Keep track of the time context has been active
start_time = time.time()
Expand Down Expand Up @@ -673,7 +720,12 @@ def register_model_validation(self, validation):
:param validation: the model validation
:type validation: :class:`fedn.network.grpc.fedn_pb2.ModelValidation`
"""
self.statestore.report_validation(validation)
data = MessageToDict(validation, including_default_value_fields=True)
success, result = validation_store.add(data)
if not success:
logger.error(result)
else:
logger.info("Model validation registered: {}".format(result))

def SendModelValidation(self, request, context):
    """Receive a model validation result from a client and persist it.

    :param request: the model validation message
    :type request: :class:`fedn.network.grpc.fedn_pb2.ModelValidation`
    :param context: the context
    :type context: :class:`grpc._server._Context`
    :return: the response
    :rtype: :class:`fedn.network.grpc.fedn_pb2.Response`
    """
    logger.info("Received ModelValidation from {}".format(request.sender.name))

    # Delegate to register_model_validation so the store call and its
    # error logging live in one place instead of being duplicated here.
    self.register_model_validation(request)

    response = fedn.Response()
    # Echo the incoming request (not the half-built response object)
    # back to the sender.
    response.response = "RECEIVED ModelValidation {} from client {}".format(request, request.sender.name)
    return response

def SendModelInference(self, request, context):
    """Receive a model inference result from a client and persist it.

    :param request: the model inference message
    :type request: :class:`fedn.network.grpc.fedn_pb2.ModelInference`
    :param context: the context
    :type context: :class:`grpc._server._Context`
    :return: the response
    :rtype: :class:`fedn.network.grpc.fedn_pb2.Response`
    """
    logger.info("Received ModelInference from {}".format(request.sender.name))

    result = MessageToDict(request, including_default_value_fields=True)
    # The stores return (success, result); log failures instead of
    # silently dropping them (matches validation_store usage elsewhere).
    success, add_result = inference_store.add(result)
    if not success:
        logger.error(add_result)

    response = fedn.Response()
    # Format the incoming request, not the response object being built.
    response.response = "RECEIVED ModelInference {} from client {}".format(request, request.sender.name)
    return response

####################################################################################################################

def run(self):
Expand Down
20 changes: 13 additions & 7 deletions fedn/network/combiner/roundhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class RoundConfig(TypedDict):
:type model_metadata: dict
:param session_id: The session identifier. Set by (Controller?).
:type session_id: str
:param inference_id: The inference identifier. Only used for inference tasks.
:type inference_id: str
:param helper_type: The helper type.
:type helper_type: str
:param aggregator: The aggregator type.
Expand Down Expand Up @@ -250,7 +252,7 @@ def _validation_round(self, session_id, model_id, clients):
"""
self.server.request_model_validation(session_id, model_id, clients=clients)

def _inference_round(self, inference_id: str, model_id: str, clients: list):
    """Send model inference requests to the given clients.

    :param inference_id: The id of the inference session.
    :type inference_id: str
    :param model_id: The ID of the model to use for inference.
    :type model_id: str
    :param clients: The clients to send the request to.
    :type clients: list
    """
    self.server.request_model_inference(inference_id, model_id, clients=clients)

def stage_model(self, model_id, timeout_retry=3, retry=2):
"""Download a model from persistent storage and set in modelservice.
Expand Down Expand Up @@ -348,7 +350,7 @@ def execute_validation_round(self, session_id, model_id):
validators = self._assign_round_clients(self.server.max_clients, type="validators")
self._validation_round(session_id, model_id, validators)

def execute_inference_round(self, inference_id: str, model_id: str) -> None:
    """Coordinate an inference round: stage the model, then fan out requests.

    :param inference_id: The id of the inference session.
    :type inference_id: str
    :param model_id: The ID of the model to run inference on.
    :type model_id: str
    """
    self.stage_model(model_id)
    # TODO: Implement a dedicated inference client type; validators are
    # reused for inference tasks for now.
    clients = self._assign_round_clients(self.server.max_clients, type="validators")
    self._inference_round(inference_id, model_id, clients)

def execute_training_round(self, config):
"""Coordinates clients to execute training tasks.
Expand Down Expand Up @@ -407,25 +409,29 @@ def run(self, polling_interval=1.0):
while True:
try:
round_config = self.round_configs.get(block=False)
session_id = round_config["session_id"]
model_id = round_config["model_id"]

# Check that the minimum allowed number of clients are connected
ready = self._check_nr_round_clients(round_config)
round_meta = {}

if ready:
if round_config["task"] == "training":
session_id = round_config["session_id"]
model_id = round_config["model_id"]
tic = time.time()
round_meta = self.execute_training_round(round_config)
round_meta["time_exec_training"] = time.time() - tic
round_meta["status"] = "Success"
round_meta["name"] = self.server.id
self.server.statestore.set_round_combiner_data(round_meta)
elif round_config["task"] == "validation":
session_id = round_config["session_id"]
model_id = round_config["model_id"]
self.execute_validation_round(session_id, model_id)
elif round_config["task"] == "inference":
self.execute_inference_round(session_id, model_id)
inference_id = round_config["inference_id"]
model_id = round_config["model_id"]
self.execute_inference_round(inference_id, model_id)
else:
logger.warning("config contains unkown task type.")
else:
Expand Down
Loading
Loading