Skip to content
This repository has been archived by the owner on Nov 20, 2023. It is now read-only.

Commit

Permalink
extendend post_save_hook for related fk
Browse files Browse the repository at this point in the history
  • Loading branch information
Yacobolo committed Oct 16, 2023
1 parent 9bde747 commit a1b1c63
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 19 deletions.
4 changes: 2 additions & 2 deletions backend/api/schemas/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import Field

from api.models import Job
from api.schemas.base import DependencyBaseOut, JobBaseOut
from api.schemas.base import DependencyBaseOut, JobBaseOut, TaskBaseOut


class JobIn(ModelSchema):
Expand All @@ -20,5 +20,5 @@ class Config:


class JobOut(JobBaseOut):
# tasks: List["TaskBase"]
tasks: List[TaskBaseOut]
dependencies: List[DependencyBaseOut]
51 changes: 42 additions & 9 deletions backend/api/utils/crud_hooks/post_save_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,56 @@ def set_m2m_relations_from_ids(
Set values of a ManyToManyField using given IDs.
"""
if hasattr(instance, many_to_many_field_name) and hasattr(instance, ids_field_name):
many_to_many_field = getattr(instance, many_to_many_field_name)
many_to_many_relation = getattr(instance, many_to_many_field_name)
ids_field = getattr(instance, ids_field_name)

# Assuming you want to set the ManyToManyField using the IDs
many_to_many_field.set(ids_field)
many_to_many_relation.set(ids_field)


def post_save_hook(*field_pairs: Tuple[str, str]):
def set_reverse_fk_relations_from_ids(
instance: Model, reverse_relation_name: str, ids_field_name: str
) -> None:
"""
Set values of a reverse ForeignKey using given IDs in a generic way.
"""
if hasattr(instance, reverse_relation_name) and hasattr(instance, ids_field_name):
# This is the manager for the related model (e.g., Task.objects for a Job instance)
related_model_manager = getattr(instance, reverse_relation_name).model.objects
ids_field = getattr(instance, ids_field_name)

# Get the name of the ForeignKey field in the related model (e.g., "job" for Task model)
fk_field_name = getattr(instance, reverse_relation_name).field.name

# Using the manager to fetch tasks regardless of their current ForeignKey relationship
related_model_manager.filter(id__in=ids_field).update(
**{fk_field_name: instance}
)


def post_save_hook(*relationship_configs: Tuple[str, str, str]):
"""
Return a post-save function to set multiple ManyToMany fields.
Each pair consists of (many_to_many_field_name, ids_field_name).
Return a post-save function to set relations based on relationship type.
Each config consists of (relationship_type, relation_field_name, ids_field_name).
"""

def post_save(request, instance) -> None:
for many_to_many_field_name, ids_field_name in field_pairs:
set_m2m_relations_from_ids(
instance, many_to_many_field_name, ids_field_name
)
for (
relationship_type,
relation_field_name,
ids_field_name,
) in relationship_configs:
if relationship_type == "m2m":
set_m2m_relations_from_ids(
instance, relation_field_name, ids_field_name
)
elif relationship_type == "reverse_fk":
print("reverse_fk")
set_reverse_fk_relations_from_ids(
instance, relation_field_name, ids_field_name
)
else:
# Handle potential unsupported relationship types
raise ValueError(f"Unsupported relationship type: {relationship_type}")

return post_save
4 changes: 2 additions & 2 deletions backend/api/views/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ class DependencyViewSet(ModelViewSet):
output_schema=DependencyOut,
pre_save=pre_save_hook(),
post_save=post_save_hook(
("jobs", "job_ids"),
("tasks", "task_ids"),
("m2m", "jobs", "job_ids"),
("m2m", "tasks", "task_ids"),
),
)
retrieve = RetrieveModelView(output_schema=DependencyOut)
Expand Down
4 changes: 2 additions & 2 deletions backend/api/views/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ class JobViewSet(ModelViewSet):
output_schema=JobOut,
pre_save=pre_save_hook(),
post_save=post_save_hook(
("tasks", "task_ids"),
("dependencies", "dependency_ids"),
("m2m", "dependencies", "dependency_ids"),
("reverse_fk", "tasks", "task_ids"),
),
)
retrieve = RetrieveModelView(output_schema=JobOut)
Expand Down
2 changes: 1 addition & 1 deletion backend/api/views/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ResourceViewSet(ModelViewSet):
input_schema=ResourceIn,
output_schema=ResourceOut,
pre_save=pre_save_hook(),
post_save=post_save_hook(("resource_groups", "resource_group_ids")),
post_save=post_save_hook(("m2m", "resource_groups", "resource_group_ids")),
)
retrieve = RetrieveModelView(output_schema=ResourceOut)
update = UpdateModelView(
Expand Down
2 changes: 1 addition & 1 deletion backend/api/views/resource_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class ResourceGroupsViewSet(ModelViewSet):
input_schema=ResourceGroupIn,
output_schema=ResourceGroupOut,
pre_save=pre_save_hook(),
post_save=post_save_hook(("resources", "resource_ids")),
post_save=post_save_hook(("m2m", "resources", "resource_ids")),
)
retrieve = RetrieveModelView(output_schema=ResourceGroupOut)
update = UpdateModelView(
Expand Down
4 changes: 2 additions & 2 deletions backend/api/views/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ class TaskViewSet(ModelViewSet):
output_schema=TaskOut,
pre_save=pre_save_hook(),
post_save=post_save_hook(
("predecessors", "predecessor_ids"),
("dependencies", "dependency_ids"),
("m2m", "predecessors", "predecessor_ids"),
("m2m", "dependencies", "dependency_ids"),
),
)
retrieve = RetrieveModelView(output_schema=TaskOut)
Expand Down

0 comments on commit a1b1c63

Please sign in to comment.