From a1b1c63ae8e0d549d60061aecd99269847263bbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jacob=20=C3=98stergaard=20Nielsen?= Date: Mon, 16 Oct 2023 22:28:48 +0200 Subject: [PATCH] extendend post_save_hook for related fk --- backend/api/schemas/job.py | 4 +- .../api/utils/crud_hooks/post_save_hook.py | 51 +++++++++++++++---- backend/api/views/dependency.py | 4 +- backend/api/views/job.py | 4 +- backend/api/views/resource.py | 2 +- backend/api/views/resource_group.py | 2 +- backend/api/views/task.py | 4 +- 7 files changed, 52 insertions(+), 19 deletions(-) diff --git a/backend/api/schemas/job.py b/backend/api/schemas/job.py index 672df1f..533c481 100644 --- a/backend/api/schemas/job.py +++ b/backend/api/schemas/job.py @@ -5,7 +5,7 @@ from pydantic import Field from api.models import Job -from api.schemas.base import DependencyBaseOut, JobBaseOut +from api.schemas.base import DependencyBaseOut, JobBaseOut, TaskBaseOut class JobIn(ModelSchema): @@ -20,5 +20,5 @@ class Config: class JobOut(JobBaseOut): - # tasks: List["TaskBase"] + tasks: List[TaskBaseOut] dependencies: List[DependencyBaseOut] diff --git a/backend/api/utils/crud_hooks/post_save_hook.py b/backend/api/utils/crud_hooks/post_save_hook.py index 6ad90b7..23ab51b 100644 --- a/backend/api/utils/crud_hooks/post_save_hook.py +++ b/backend/api/utils/crud_hooks/post_save_hook.py @@ -10,23 +10,56 @@ def set_m2m_relations_from_ids( Set values of a ManyToManyField using given IDs. """ if hasattr(instance, many_to_many_field_name) and hasattr(instance, ids_field_name): - many_to_many_field = getattr(instance, many_to_many_field_name) + many_to_many_relation = getattr(instance, many_to_many_field_name) ids_field = getattr(instance, ids_field_name) # Assuming you want to set the ManyToManyField using the IDs - many_to_many_field.set(ids_field) + many_to_many_relation.set(ids_field) -def post_save_hook(*field_pairs: Tuple[str, str]): +def set_reverse_fk_relations_from_ids( + instance: Model, reverse_relation_name: str, ids_field_name: str +) -> None: + """ + Set values of a reverse ForeignKey using given IDs in a generic way. + """ + if hasattr(instance, reverse_relation_name) and hasattr(instance, ids_field_name): + # This is the manager for the related model (e.g., Task.objects for a Job instance) + related_model_manager = getattr(instance, reverse_relation_name).model.objects + ids_field = getattr(instance, ids_field_name) + + # Get the name of the ForeignKey field in the related model (e.g., "job" for Task model) + fk_field_name = getattr(instance, reverse_relation_name).field.name + + # Using the manager to fetch tasks regardless of their current ForeignKey relationship + related_model_manager.filter(id__in=ids_field).update( + **{fk_field_name: instance} + ) + + +def post_save_hook(*relationship_configs: Tuple[str, str, str]): """ - Return a post-save function to set multiple ManyToMany fields. - Each pair consists of (many_to_many_field_name, ids_field_name). + Return a post-save function to set relations based on relationship type. + Each config consists of (relationship_type, relation_field_name, ids_field_name). """ def post_save(request, instance) -> None: - for many_to_many_field_name, ids_field_name in field_pairs: - set_m2m_relations_from_ids( - instance, many_to_many_field_name, ids_field_name - ) + for ( + relationship_type, + relation_field_name, + ids_field_name, + ) in relationship_configs: + if relationship_type == "m2m": + set_m2m_relations_from_ids( + instance, relation_field_name, ids_field_name + ) + elif relationship_type == "reverse_fk": + print("reverse_fk") + set_reverse_fk_relations_from_ids( + instance, relation_field_name, ids_field_name + ) + else: + # Handle potential unsupported relationship types + raise ValueError(f"Unsupported relationship type: {relationship_type}") return post_save diff --git a/backend/api/views/dependency.py b/backend/api/views/dependency.py index 1ee71a4..c828ab0 100644 --- a/backend/api/views/dependency.py +++ b/backend/api/views/dependency.py @@ -75,8 +75,8 @@ class DependencyViewSet(ModelViewSet): output_schema=DependencyOut, pre_save=pre_save_hook(), post_save=post_save_hook( - ("jobs", "job_ids"), - ("tasks", "task_ids"), + ("m2m", "jobs", "job_ids"), + ("m2m", "tasks", "task_ids"), ), ) retrieve = RetrieveModelView(output_schema=DependencyOut) diff --git a/backend/api/views/job.py b/backend/api/views/job.py index 63e31a8..40ef252 100644 --- a/backend/api/views/job.py +++ b/backend/api/views/job.py @@ -51,8 +51,8 @@ class JobViewSet(ModelViewSet): output_schema=JobOut, pre_save=pre_save_hook(), post_save=post_save_hook( - ("tasks", "task_ids"), - ("dependencies", "dependency_ids"), + ("m2m", "dependencies", "dependency_ids"), + ("reverse_fk", "tasks", "task_ids"), ), ) retrieve = RetrieveModelView(output_schema=JobOut) diff --git a/backend/api/views/resource.py b/backend/api/views/resource.py index 80662e8..b9133a2 100644 --- a/backend/api/views/resource.py +++ b/backend/api/views/resource.py @@ -24,7 +24,7 @@ class ResourceViewSet(ModelViewSet): input_schema=ResourceIn, output_schema=ResourceOut, pre_save=pre_save_hook(), - post_save=post_save_hook(("resource_groups", "resource_group_ids")), + post_save=post_save_hook(("m2m", "resource_groups", "resource_group_ids")), ) retrieve = RetrieveModelView(output_schema=ResourceOut) update = UpdateModelView( diff --git a/backend/api/views/resource_group.py b/backend/api/views/resource_group.py index 38a0f68..5cb8100 100644 --- a/backend/api/views/resource_group.py +++ b/backend/api/views/resource_group.py @@ -25,7 +25,7 @@ class ResourceGroupsViewSet(ModelViewSet): input_schema=ResourceGroupIn, output_schema=ResourceGroupOut, pre_save=pre_save_hook(), - post_save=post_save_hook(("resources", "resource_ids")), + post_save=post_save_hook(("m2m", "resources", "resource_ids")), ) retrieve = RetrieveModelView(output_schema=ResourceGroupOut) update = UpdateModelView( diff --git a/backend/api/views/task.py b/backend/api/views/task.py index 65779a4..2833254 100644 --- a/backend/api/views/task.py +++ b/backend/api/views/task.py @@ -51,8 +51,8 @@ class TaskViewSet(ModelViewSet): output_schema=TaskOut, pre_save=pre_save_hook(), post_save=post_save_hook( - ("predecessors", "predecessor_ids"), - ("dependencies", "dependency_ids"), + ("m2m", "predecessors", "predecessor_ids"), + ("m2m", "dependencies", "dependency_ids"), ), ) retrieve = RetrieveModelView(output_schema=TaskOut)