Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PYD003/4/5 #4

Merged
merged 2 commits into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
python-version: ['3.9', '3.10', '3.11', '3.12']

steps:
- uses: actions/checkout@v4
Expand Down
62 changes: 62 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,68 @@ class Model(BaseModel):
foo = 1 # Will error at runtime
```

### `PYD003` - *Unecessary Field call to specify a default value*

Raise an error if the [`Field`](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) function
is used only to specify a default value.

```python
class Model(BaseModel):
foo: int = Field(default=1)
```

Instead, consider specifying the default value directly:

```python
class Model(BaseModel):
foo: int = 1
```

### `PYD004` - *Default argument specified in annotated*

Raise an error if the `default` argument of the [`Field`](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) function is used together with [`Annotated`](https://docs.python.org/3/library/typing.html#typing.Annotated).

```python
class Model(BaseModel):
foo: Annotated[int, Field(default=1, description="desc")]
```

To make type checkers aware of the default value, consider specifying the default value directly:

```python
class Model(BaseModel):
foo: Annotated[int, Field(description="desc")] = 1
```

### `PYD005` - *Field name overrides annotation*

Raise an error if the field name clashes with the annotation.

```python
from datetime import date

class Model(BaseModel):
date: date | None = None
```

Because of how Python [evaluates](https://docs.python.org/3/reference/simple_stmts.html#annassign)
annotated assignments, unexpected issues can happen when using an annotation name that clashes with a field
name. Pydantic will try its best to warn you about such issues, but can fail in complex scenarios (and the
issue may even be silently ignored).

Instead, consider, using an [alias](https://docs.pydantic.dev/latest/concepts/fields/#field-aliases) or referencing your type under a different name:

```python
from datetime import date

date_ = date

class Model(BaseModel):
date_aliased: date | None = Field(default=None, alias="date")
# or
date: date_ | None = None
```

### `PYD010` - *Usage of `__pydantic_config__`*

Raise an error if a Pydantic configuration is set with [`__pydantic_config__`](https://docs.pydantic.dev/dev/concepts/config/#configuration-with-dataclass-from-the-standard-library-or-typeddict).
Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@ readme = "README.md"
authors = [
{name = "Victorien", email = "contact@vctrn.dev"}
]
requires-python = ">=3.8"
requires-python = ">=3.9"
classifiers = [
"Development Status :: 4 - Beta",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
Expand Down Expand Up @@ -54,7 +53,7 @@ where = ["src"]
[tool.ruff]
line-length = 120
src = ["src"]
target-version = "py38"
target-version = "py39"

[tool.ruff.lint]
preview = true
Expand Down
49 changes: 42 additions & 7 deletions src/flake8_pydantic/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _has_pydantic_decorator(node: ast.ClassDef) -> bool:
for stmt in node.body:
if isinstance(stmt, ast.FunctionDef):
decorator_names = get_decorator_names(stmt.decorator_list)
if PYDANTIC_DECORATORS.intersection(decorator_names):
if PYDANTIC_DECORATORS & decorator_names:
return True
return False

Expand All @@ -91,12 +91,6 @@ def _has_pydantic_method(node: ast.ClassDef) -> bool:
return False


def is_dataclass(node: ast.ClassDef) -> bool:
"""Determine if a class is a dataclass."""

return bool({"dataclass", "pydantic_dataclass"}.intersection(get_decorator_names(node.decorator_list)))


def is_pydantic_model(node: ast.ClassDef, include_root_model: bool = True) -> bool:
"""Determine if a class definition is a Pydantic model.

Expand All @@ -119,3 +113,44 @@ def is_pydantic_model(node: ast.ClassDef, include_root_model: bool = True) -> bo
or _has_pydantic_decorator(node)
or _has_pydantic_method(node)
)


def is_dataclass(node: ast.ClassDef) -> bool:
"""Determine if a class is a dataclass."""

return bool({"dataclass", "pydantic_dataclass"} & get_decorator_names(node.decorator_list))


def is_function(node: ast.Call, function_name: str) -> bool:
return (
isinstance(node.func, ast.Name)
and node.func.id == function_name
or isinstance(node.func, ast.Attribute)
and node.func.attr == function_name
)


def is_name(node: ast.expr, name: str) -> bool:
return isinstance(node, ast.Name) and node.id == name or isinstance(node, ast.Attribute) and node.attr == name


def extract_annotations(node: ast.expr) -> set[str]:
annotations: set[str] = set()

if isinstance(node, ast.Name):
# foo: date = ...
annotations.add(node.id)
if isinstance(node, ast.BinOp):
# foo: date | None = ...
annotations |= extract_annotations(node.left)
annotations |= extract_annotations(node.right)
if isinstance(node, ast.Subscript):
# foo: dict[str, date]
# foo: Annotated[list[date], ...]
if isinstance(node.slice, ast.Tuple):
for elt in node.slice.elts:
annotations |= extract_annotations(elt)
if isinstance(node.slice, ast.Name):
annotations.add(node.slice.id)

return annotations
15 changes: 15 additions & 0 deletions src/flake8_pydantic/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,21 @@ class PYD002(Error):
message = "Non-annotated attribute inside Pydantic model"


class PYD003(Error):
error_code = "PYD003"
message = "Unecessary Field call to specify a default value"


class PYD004(Error):
error_code = "PYD004"
message = "Default argument specified in annotated"


class PYD005(Error):
error_code = "PYD005"
message = "Field name overrides annotation"


class PYD010(Error):
error_code = "PYD010"
message = "Usage of __pydantic_config__"
3 changes: 2 additions & 1 deletion src/flake8_pydantic/plugin.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import ast
from collections.abc import Iterator
from importlib.metadata import version
from typing import Any, Iterator
from typing import Any

from .visitor import Visitor

Expand Down
55 changes: 49 additions & 6 deletions src/flake8_pydantic/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import Literal

from ._compat import TypeAlias
from ._utils import is_dataclass, is_pydantic_model
from .errors import PYD001, PYD002, PYD010, Error
from ._utils import extract_annotations, is_dataclass, is_function, is_name, is_pydantic_model
from .errors import PYD001, PYD002, PYD003, PYD004, PYD005, PYD010, Error

ClassType: TypeAlias = Literal["pydantic_model", "dataclass", "other_class"]

Expand Down Expand Up @@ -35,10 +35,7 @@ def _check_pyd_001(self, node: ast.AnnAssign) -> None:
if (
self.current_class in {"pydantic_model", "dataclass"}
and isinstance(node.value, ast.Call)
and (
(isinstance(node.value.func, ast.Name) and node.value.func.id == "Field")
or (isinstance(node.value.func, ast.Attribute) and node.value.func.attr == "Field")
)
and is_function(node.value, "Field")
and len(node.value.args) >= 1
):
self.errors.append(PYD001.from_node(node))
Expand All @@ -55,6 +52,49 @@ def _check_pyd_002(self, node: ast.ClassDef) -> None:
for assignment in invalid_assignments:
self.errors.append(PYD002.from_node(assignment))

def _check_pyd_003(self, node: ast.AnnAssign) -> None:
if (
self.current_class in {"pydantic_model", "dataclass"}
and isinstance(node.value, ast.Call)
and is_function(node.value, "Field")
and len(node.value.keywords) == 1
and node.value.keywords[0].arg == "default"
):
self.errors.append(PYD003.from_node(node))

def _check_pyd_004(self, node: ast.AnnAssign) -> None:
if (
self.current_class in {"pydantic_model", "dataclass"}
and isinstance(node.annotation, ast.Subscript)
and is_name(node.annotation.value, "Annotated")
and isinstance(node.annotation.slice, ast.Tuple)
):
field_call = next(
(
elt
for elt in node.annotation.slice.elts
if isinstance(elt, ast.Call)
and is_function(elt, "Field")
and any(k.arg == "default" for k in elt.keywords)
),
None,
)
if field_call is not None:
self.errors.append(PYD004.from_node(node))

def _check_pyd_005(self, node: ast.ClassDef) -> None:
if self.current_class in {"pydantic_model", "dataclass"}:
previous_targets: set[str] = set()

for stmt in node.body:
if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
# TODO only add before if AnnAssign?
# the following seems to work:
# date: date
previous_targets.add(stmt.target.id)
if previous_targets & extract_annotations(stmt.annotation):
self.errors.append(PYD005.from_node(stmt))

def _check_pyd_010(self, node: ast.ClassDef) -> None:
if self.current_class == "other_class":
for stmt in node.body:
Expand All @@ -74,10 +114,13 @@ def _check_pyd_010(self, node: ast.ClassDef) -> None:
def visit_ClassDef(self, node: ast.ClassDef) -> None:
self.enter_class(node)
self._check_pyd_002(node)
self._check_pyd_005(node)
self._check_pyd_010(node)
self.generic_visit(node)
self.leave_class()

def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
self._check_pyd_001(node)
self._check_pyd_003(node)
self._check_pyd_004(node)
self.generic_visit(node)
2 changes: 1 addition & 1 deletion tests/rules/test_pyd001.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Model:

PYD001_OK = """
class Model(BaseModel):
a: int = Field(default=1)
a: int = Field(default=1, description="")
"""


Expand Down
33 changes: 33 additions & 0 deletions tests/rules/test_pyd003.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

import ast

import pytest

from flake8_pydantic.errors import PYD003, Error
from flake8_pydantic.visitor import Visitor

PYD003_NOT_OK = """
class Model(BaseModel):
a: int = Field(default=1)
"""

PYD003_OK = """
class Model(BaseModel):
a: int = Field(default=1, description="")
"""


@pytest.mark.parametrize(
["source", "expected"],
[
(PYD003_NOT_OK, [PYD003(3, 4)]),
(PYD003_OK, []),
],
)
def test_pyd003(source: str, expected: list[Error]) -> None:
module = ast.parse(source)
visitor = Visitor()
visitor.visit(module)

assert visitor.errors == expected
33 changes: 33 additions & 0 deletions tests/rules/test_pyd004.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

import ast

import pytest

from flake8_pydantic.errors import PYD004, Error
from flake8_pydantic.visitor import Visitor

PYD004_1 = """
class Model(BaseModel):
a: Annotated[int, Field(default=1, description="")]
"""

PYD004_2 = """
class Model(BaseModel):
a: Annotated[int, Unrelated(), Field(default=1)]
"""


@pytest.mark.parametrize(
["source", "expected"],
[
(PYD004_1, [PYD004(3, 4)]),
(PYD004_2, [PYD004(3, 4)]),
],
)
def test_pyd004(source: str, expected: list[Error]) -> None:
module = ast.parse(source)
visitor = Visitor()
visitor.visit(module)

assert visitor.errors == expected
Loading