Skip to content

Commit

Permalink
Merge pull request #12 from ImogenBits/pydanticv2
Browse files Browse the repository at this point in the history
Pydantic v2
  • Loading branch information
Benezivas authored Aug 23, 2023
2 parents 1964b3b + 76c5dc8 commit 74309dd
Show file tree
Hide file tree
Showing 21 changed files with 171 additions and 190 deletions.
13 changes: 6 additions & 7 deletions algobattle_problems/biclique/problem.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
"""The Biclique problem class."""
from algobattle.problem import Problem, UndirectedGraph, SolutionModel, ValidationError, maximize
from algobattle.util import u64, Role
from algobattle.problem import Problem, SolutionModel, maximize
from algobattle.util import Role, ValidationError
from algobattle.types import UndirectedGraph, Vertex


class Solution(SolutionModel[UndirectedGraph]):
"""A solution to a bipartite clique problem."""

s_1: set[u64]
s_2: set[u64]
s_1: set[Vertex]
s_2: set[Vertex]

def validate_solution(self, instance: UndirectedGraph, role: Role) -> None:
edge_set = set(instance.edges) | set(edge[::-1] for edge in instance.edges)
super().validate_solution(instance, role)
if any(i >= instance.num_vertices for i in self.s_1 | self.s_2):
raise ValidationError("Solution contains vertices that aren't in the instance.")
edge_set = set(instance.edges) | set(edge[::-1] for edge in instance.edges)
if len(self.s_1.intersection(self.s_2)) != 0:
raise ValidationError("Solution contains vertex sets that aren't disjoint.")
if any((u, v) not in edge_set for u in self.s_1 for v in self.s_2):
Expand Down
13 changes: 8 additions & 5 deletions algobattle_problems/biclique/tests.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Tests for the biclique problem."""
import unittest

from algobattle_problems.biclique.problem import UndirectedGraph, Solution, ValidationError, Role
from pydantic import ValidationError as PydanticValidationError
from algobattle.util import Role

from algobattle_problems.biclique.problem import UndirectedGraph, Solution, ValidationError


class Tests(unittest.TestCase):
Expand All @@ -10,22 +13,22 @@ class Tests(unittest.TestCase):
def test_vertices_exist(self):
"""Tests that only valid vertex indices are allowed."""
graph = UndirectedGraph(num_vertices=10, edges=[(i, j) for i in range(10) for j in range(i)])
sol = Solution(s_1=set(), s_2={20})
with self.assertRaises(ValidationError):
with self.assertRaises(PydanticValidationError):
sol = Solution.model_validate({"s_1": set(), "s_2": {20}}, context={"instance": graph})
sol.validate_solution(graph, Role.generator)

def test_edges_exist(self):
"""Tests that solutions that arent complete bicliques are not allowed."""
graph = UndirectedGraph(num_vertices=10, edges=[])
sol = Solution(s_1={1}, s_2={2})
with self.assertRaises(ValidationError):
sol = Solution.model_validate({"s_1": {1}, "s_2": {2}}, context={"instance": graph})
sol.validate_solution(graph, Role.generator)

def test_edges_missing(self):
"""Asserts that solutions that aren't bipartite are not allowed."""
graph = UndirectedGraph(num_vertices=10, edges=[(i, j) for i in range(10) for j in range(i)])
sol = Solution(s_1={1, 2}, s_2={3, 4})
with self.assertRaises(ValidationError):
sol = Solution.model_validate({"s_1": {1, 2}, "s_2": {3, 4}}, context={"instance": graph})
sol.validate_solution(graph, Role.generator)


Expand Down
77 changes: 33 additions & 44 deletions algobattle_problems/c4subgraphiso/problem.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,46 @@
"""The C4subgraphiso problem class."""

from algobattle.problem import Problem, UndirectedGraph, SolutionModel, ValidationError, maximize
from algobattle.util import u64, Role
from itertools import combinations, cycle, islice
from typing import Annotated, Iterator
from pydantic import field_validator

from algobattle.problem import Problem, SolutionModel, maximize
from algobattle.util import Role, ValidationError
from algobattle.types import UndirectedGraph, Vertex, UniqueItems

class Solution(SolutionModel[UndirectedGraph]):
"""A solution to a Square Subgraph Isomorphism problem."""
Square = Annotated[tuple[Vertex, Vertex, Vertex, Vertex], UniqueItems()]

squares: set[tuple[u64, u64, u64, u64]]

def validate_solution(self, instance: UndirectedGraph, role: Role) -> None:
super().validate_solution(instance, role)
self._all_entries_bounded_in_size(instance)
self._all_squares_in_instance(instance)
self._all_squares_node_disjoint()
self._all_squares_induced(instance)
def edges(square: Square) -> Iterator[tuple[Vertex, Vertex]]:
"""Returns all edges of a square."""
return zip(square, islice(cycle(square), 1, None))

def _all_entries_bounded_in_size(self, instance: UndirectedGraph) -> None:
for square in self.squares:
if any(node >= instance.num_vertices for node in square):
raise ValidationError("An element of the solution doesn't index an instance vertex.")

def _all_squares_node_disjoint(self) -> None:
used_nodes = set()
for square in self.squares:
for node in square:
if node in used_nodes:
raise ValidationError("A square in the solution is not node-disjoint to at least one other square.")
used_nodes.add(node)
def diagonals(square: Square) -> Iterator[tuple[Vertex, Vertex]]:
"""Returns the diagonals of a square."""
yield square[0], square[2]
yield square[1], square[3]

def _all_squares_induced(self, instance: UndirectedGraph) -> None:
edge_set = set(instance.edges)
for square in self.squares:
# Edges between opposing nodes of a square would mean the square is not induced by its nodes
unwanted_edges = [
(square[0], square[2]),
(square[2], square[0]),
(square[1], square[3]),
(square[3], square[1]),
]
if any(edge in edge_set for edge in unwanted_edges):
raise ValidationError("A square in the solution is not induced in the instance.")

def _all_squares_in_instance(self, instance: UndirectedGraph) -> None:
edge_set = set(instance.edges)
for square in self.squares:
if (
not ((square[0], square[1]) in edge_set or (square[1], square[0]) in edge_set)
or not ((square[1], square[2]) in edge_set or (square[2], square[1]) in edge_set)
or not ((square[2], square[3]) in edge_set or (square[3], square[2]) in edge_set)
or not ((square[3], square[0]) in edge_set or (square[0], square[3]) in edge_set)
):
raise ValidationError("A square is not part of the instance.")
class Solution(SolutionModel[UndirectedGraph]):
"""A solution to a Square Subgraph Isomorphism problem."""

squares: set[Square]

@field_validator("squares", mode="after")
@classmethod
def check_squares(cls, value: set[Square]) -> set[Square]:
if any(set(a) & set(b) for a, b in combinations(value, 2)):
raise ValueError("A square in the solution is not node-disjoint to at least one other square")
return value

def validate_solution(self, instance: UndirectedGraph, role: Role) -> None:
super().validate_solution(instance, role)
edge_set = set(instance.edges) | {(v, u) for u, v in instance.edges}
if any(edge not in edge_set for square in self.squares for edge in edges(square)):
raise ValidationError("A square is not part of the instance.")
if any(edge in edge_set for square in self.squares for edge in diagonals(square)):
raise ValidationError("A square in the solution is not induced in the instance.")

@maximize
def score(self, instance: UndirectedGraph, role: Role) -> float:
Expand Down
6 changes: 3 additions & 3 deletions algobattle_problems/c4subgraphiso/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def setUpClass(cls) -> None:

def test_no_duplicate_squares(self):
with self.assertRaises(PydanticValidationError):
UndirectedGraph.parse_obj(
UndirectedGraph.model_validate(
{
"squares": {
(0, 1, 2, 3),
Expand Down Expand Up @@ -65,8 +65,8 @@ def test_score(self):
self.assertEqual(solution.score(self.instance, Role.solver), 2)

def test_squares_disjoin(self):
solution = Solution(squares={(0, 1, 2, 3), (0, 1, 8, 9)})
with self.assertRaises(ValidationError):
with self.assertRaises(PydanticValidationError):
solution = Solution(squares={(0, 1, 2, 3), (0, 1, 8, 9)})
solution.validate_solution(self.instance, Role.generator)


Expand Down
10 changes: 6 additions & 4 deletions algobattle_problems/clusterediting/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@
from collections import defaultdict
from itertools import combinations

from algobattle.problem import Problem, UndirectedGraph, SolutionModel, ValidationError, minimize
from algobattle.util import u64, Role
from algobattle.problem import Problem, SolutionModel, minimize
from algobattle.util import Role, ValidationError
from algobattle.types import Vertex, UndirectedGraph


class Solution(SolutionModel[UndirectedGraph]):
"""A solution to a Cluster Editing problem."""

add: set[tuple[u64, u64]]
delete: set[tuple[u64, u64]]
add: set[tuple[Vertex, Vertex]]
delete: set[tuple[Vertex, Vertex]]

def validate_solution(self, instance: UndirectedGraph, role: Role) -> None:
super().validate_solution(instance, role)
edge_set = set(instance.edges)

# Apply modifications to graph
Expand Down
11 changes: 5 additions & 6 deletions algobattle_problems/domset/problem.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
"""The Clusterediting problem class."""
from algobattle.problem import Problem, UndirectedGraph, SolutionModel, ValidationError, minimize
from algobattle.util import u64, Role
from algobattle.problem import Problem, SolutionModel, minimize
from algobattle.util import Role, ValidationError
from algobattle.types import Vertex, UndirectedGraph


class Solution(SolutionModel[UndirectedGraph]):
"""A solution to a Dominating Set problem."""

domset: set[u64]
domset: set[Vertex]

def validate_solution(self, instance: UndirectedGraph, role: Role) -> None:
if any(u >= instance.num_vertices for u in self.domset):
raise ValidationError("A number in the domset is too large to be a vertex")

super().validate_solution(instance, role)
dominated = set(self.domset)
for u, v in instance.edges:
if u in self.domset:
Expand Down
22 changes: 12 additions & 10 deletions algobattle_problems/hikers/problem.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
"""The Hikers problem class."""
from algobattle.problem import Problem, InstanceModel, SolutionModel, ValidationError, maximize
from algobattle.util import u64, Role
from collections import Counter
from algobattle.problem import Problem, InstanceModel, SolutionModel, maximize
from algobattle.util import Role, ValidationError
from algobattle.types import u64, SizeIndex


Hiker = SizeIndex


class HikersInstance(InstanceModel):
"""The Tsptimewindows problem class."""
"""The Hikers instance class."""

hikers: list[tuple[u64, u64]]

Expand All @@ -14,22 +19,19 @@ def size(self) -> int:
return len(self.hikers)

def validate_instance(self) -> None:
super().validate_instance()
if any(min_size > max_size for min_size, max_size in self.hikers):
raise ValidationError("One hiker's minimum group size is larger than their maximum group size.")


class Solution(SolutionModel[HikersInstance]):
"""A solution to a Hikers problem."""

assignments: dict[u64, u64]
assignments: dict[Hiker, u64]

def validate_solution(self, instance: HikersInstance, role: Role) -> None:
if any(hiker >= len(instance.hikers) for hiker in self.assignments):
raise ValidationError("Solution contains hiker that is not in the instance.")

group_sizes: dict[int, int] = {}
for group in self.assignments.values():
group_sizes[group] = group_sizes.get(group, 0) + 1
super().validate_solution(instance, role)
group_sizes = Counter(self.assignments.values())

for hiker, group in self.assignments.items():
min_size, max_size = instance.hikers[hiker]
Expand Down
14 changes: 7 additions & 7 deletions algobattle_problems/hikers/tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Tests for the hikers problem."""
import unittest

from pydantic import ValidationError as PydanticValidationError

from algobattle_problems.hikers.problem import HikersInstance, Solution, ValidationError, Role


Expand All @@ -20,8 +22,7 @@ def setUpClass(cls) -> None:
)

def test_solution_empty(self):
solution = Solution(assignments={})
solution.validate_solution(self.instance, Role.generator)
Solution.model_validate({"assignments": {}}, context={"instance": self.instance})

def test_solution_correct(self):
solution = Solution(
Expand All @@ -35,14 +36,13 @@ def test_solution_correct(self):
solution.validate_solution(self.instance, Role.generator)

def test_solution_wrong_hiker(self):
solution = Solution(assignments={10: 1})
with self.assertRaises(ValidationError):
solution.validate_solution(self.instance, Role.generator)
with self.assertRaises(PydanticValidationError):
Solution.model_validate({"assignments": {10: 1}}, context={"instance": self.instance})

def test_solution_hiker_unhappy(self):
solution = Solution(assignments={1: 1})
with self.assertRaises(ValidationError):
solution.validate_solution(self.instance, Role.generator)
sol = Solution.model_validate({"assignments": {1: 1}}, context={"instance": self.instance})
sol.validate_solution(self.instance, Role.generator)


if __name__ == "__main__":
Expand Down
10 changes: 6 additions & 4 deletions algobattle_problems/longestpathboundedfvs/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from networkx.algorithms.tree.recognition import is_forest
from networkx.classes.function import is_empty

from algobattle.problem import Problem, UndirectedGraph, SolutionModel, ValidationError, maximize
from algobattle.util import u64, Role
from algobattle.problem import Problem, SolutionModel, maximize
from algobattle.util import Role, ValidationError
from algobattle.types import Vertex, UndirectedGraph


class Instance(UndirectedGraph):
"""The Longestpathboundedfvs problem class."""

fvs: set[u64] = Field(hidden="solver")
fvs: set[Vertex] = Field(exclude=True)

def validate_instance(self) -> None:
super().validate_instance()
Expand All @@ -38,9 +39,10 @@ def valid_fvs_on_input(self) -> bool:
class Solution(SolutionModel[Instance]):
"""A solution to a Longest Path with Bounded Feedback Vertex Set problem."""

path: list[u64]
path: list[Vertex]

def validate_solution(self, instance: Instance, role: Role) -> None:
super().validate_solution(instance, role)
if not self._nodes_are_walk(instance):
raise ValidationError("The given path is not a walk in the instance graph.")
if not self._no_revisited_nodes():
Expand Down
30 changes: 10 additions & 20 deletions algobattle_problems/oscm3/problem.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,32 @@
"""The OSCM3 problem class."""
from algobattle.problem import Problem, InstanceModel, SolutionModel, ValidationError, minimize
from algobattle.util import u64, Role
from typing import Annotated
from algobattle.problem import Problem, InstanceModel, SolutionModel, minimize
from algobattle.util import Role
from algobattle.types import Vertex, MaxLen, UniqueItems, SizeLen


Neighbors = Annotated[set[Vertex], MaxLen(3)]


class Instance(InstanceModel):
"""The OSCM3 problem class."""

neighbors: dict[u64, set[u64]]
neighbors: dict[Vertex, Neighbors]

@property
def size(self) -> int:
return max(self.neighbors.keys()) + 1

def validate_instance(self) -> None:
super().validate_instance()
size = self.size
if any(not 0 <= v < size for v in self.neighbors):
raise ValidationError("Instance contains element of V_1 out of the permitted range.")
if any(not 0 <= v < size for neighbors in self.neighbors.values() for v in neighbors):
raise ValidationError("Instance contains element of V_2 out of the permitted range.")
if any(len(neighbors) > 3 for neighbors in self.neighbors.values()):
raise ValidationError("A vertex of V_1 has more than 3 neighbors.")
for u in range(size):
for u in range(self.size):
self.neighbors.setdefault(u, set())


class Solution(SolutionModel[Instance]):
"""A solution to a One-Sided Crossing Minimization-3 problem."""

vertex_order: list[u64]

def validate_solution(self, instance: Instance, role: Role) -> None:
if any(not 0 <= i < instance.size for i in self.vertex_order):
raise ValidationError("An element of the solution is not in the permitted range.")
if len(self.vertex_order) != len(set(self.vertex_order)):
raise ValidationError("The solution contains duplicate numbers.")
if len(self.vertex_order) != instance.size:
raise ValidationError("The solution does not order the whole instance.")
vertex_order: Annotated[list[Vertex], UniqueItems, SizeLen]

@minimize
def score(self, instance: Instance, role: Role) -> float:
Expand Down
Loading

0 comments on commit 74309dd

Please sign in to comment.