Skip to content

Commit

Permalink
Increase coverage (#722)
Browse files Browse the repository at this point in the history
* Add tests

* pre-commit

* improve remove blacklist app

* update setup for test_backends
  • Loading branch information
kiraware authored Dec 4, 2023
1 parent 8ae34dd commit c3cfd3e
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import builtins
import uuid
from datetime import datetime, timedelta
from importlib import reload
from json import JSONEncoder
from unittest import mock
from unittest.mock import patch
Expand Down Expand Up @@ -45,6 +47,7 @@ def default(self, obj):

class TestTokenBackend(TestCase):
def setUp(self):
self.realimport = builtins.__import__
self.hmac_token_backend = TokenBackend("HS256", SECRET)
self.hmac_leeway_token_backend = TokenBackend("HS256", SECRET, leeway=LEEWAY)
self.rsa_token_backend = TokenBackend("RS256", PRIVATE_KEY, PUBLIC_KEY)
Expand Down Expand Up @@ -76,6 +79,28 @@ def test_init_fails_for_rs_algorithms_when_crypto_not_installed(self):
):
TokenBackend(algo, "not_secret")

def test_jwk_client_not_available(self):
from rest_framework_simplejwt import backends

def myimport(name, globals=None, locals=None, fromlist=(), level=0):
if name == "jwt" and fromlist == ("PyJWKClient", "PyJWKClientError"):
raise ImportError
return self.realimport(name, globals, locals, fromlist, level)

builtins.__import__ = myimport

# Reload backends, mock jwk client is not available
reload(backends)

self.assertEqual(backends.JWK_CLIENT_AVAILABLE, False)
self.assertEqual(backends.TokenBackend("HS256").jwks_client, None)

builtins.__import__ = self.realimport

@patch("jwt.encode", mock.Mock(return_value=b"test"))
def test_token_encode_should_return_str_for_old_PyJWT(self):
self.assertIsInstance(TokenBackend("HS256").encode({}), str)

def test_encode_hmac(self):
# Should return a JSON web token for the given payload
payload = {"exp": make_utc(datetime(year=2000, month=1, day=1))}
Expand Down
19 changes: 19 additions & 0 deletions tests/test_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from importlib import reload
from unittest.mock import Mock, patch

from django.test import SimpleTestCase
from pkg_resources import DistributionNotFound


class TestInit(SimpleTestCase):
def test_package_is_not_installed(self):
with patch(
"pkg_resources.get_distribution", Mock(side_effect=DistributionNotFound)
):
# Import package mock package is not installed
import rest_framework_simplejwt.__init__

self.assertEqual(rest_framework_simplejwt.__init__.__version__, None)

# Restore origin package without mock
reload(rest_framework_simplejwt.__init__)
24 changes: 24 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from importlib import reload
from unittest.mock import patch

from django.test import TestCase

from rest_framework_simplejwt.models import TokenUser
Expand All @@ -15,6 +18,18 @@ def setUp(self):

self.user = TokenUser(self.token)

def test_type_checking(self):
from rest_framework_simplejwt import models

with patch("typing.TYPE_CHECKING", True):
# Reload models, mock type checking
reload(models)

self.assertEqual(models.TYPE_CHECKING, True)

# Restore origin module without mock
reload(models)

def test_username(self):
self.assertEqual(self.user.username, "deep-thought")

Expand Down Expand Up @@ -60,6 +75,12 @@ def test_eq(self):
self.assertNotEqual(user1, user2)
self.assertEqual(user1, user3)

def test_eq_not_implemented(self):
user1 = TokenUser({api_settings.USER_ID_CLAIM: 1})
user2 = "user2"

self.assertFalse(user1 == user2)

def test_hash(self):
self.assertEqual(hash(self.user), hash(self.user.id))

Expand Down Expand Up @@ -105,3 +126,6 @@ def test_is_authenticated(self):

def test_get_username(self):
self.assertEqual(self.user.get_username(), "deep-thought")

def test_get_custom_claims_through_backup_getattr(self):
self.assertEqual(self.user.some_other_stuff, "arstarst")
61 changes: 61 additions & 0 deletions tests/test_serializers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from datetime import timedelta
from importlib import reload
from unittest.mock import MagicMock, patch

from django.conf import settings
from django.contrib.auth import get_user_model
from django.test import TestCase
from rest_framework import exceptions as drf_exceptions
Expand Down Expand Up @@ -76,6 +78,17 @@ def test_it_should_not_validate_if_user_not_found(self):
with self.assertRaises(drf_exceptions.AuthenticationFailed):
s.is_valid()

def test_it_should_pass_validate_if_request_not_in_context(self):
s = TokenObtainSerializer(
context={},
data={
"username": self.username,
"password": self.password,
},
)

s.is_valid()

def test_it_should_raise_if_user_not_active(self):
self.user.is_active = False
self.user.save()
Expand Down Expand Up @@ -372,6 +385,32 @@ def test_it_should_blacklist_refresh_token_if_tokens_should_be_rotated_and_black
# Assert old refresh token is blacklisted
self.assertEqual(BlacklistedToken.objects.first().token.jti, old_jti)

@override_api_settings(
ROTATE_REFRESH_TOKENS=True,
BLACKLIST_AFTER_ROTATION=True,
)
def test_blacklist_app_not_installed_should_pass(self):
from rest_framework_simplejwt import serializers, tokens

# Remove blacklist app
new_apps = list(settings.INSTALLED_APPS)
new_apps.remove("rest_framework_simplejwt.token_blacklist")

with self.settings(INSTALLED_APPS=tuple(new_apps)):
# Reload module that blacklist app not installed
reload(tokens)
reload(serializers)

refresh = tokens.RefreshToken()

# Serializer validates
ser = serializers.TokenRefreshSerializer(data={"refresh": str(refresh)})
ser.validate({"refresh": str(refresh)})

# Restore origin module without mock
reload(tokens)
reload(serializers)


class TestTokenVerifySerializer(TestCase):
def test_it_should_raise_token_error_if_token_invalid(self):
Expand Down Expand Up @@ -489,3 +528,25 @@ def test_it_should_blacklist_refresh_token_if_everything_ok(self):

# Assert old refresh token is blacklisted
self.assertEqual(BlacklistedToken.objects.first().token.jti, old_jti)

def test_blacklist_app_not_installed_should_pass(self):
from rest_framework_simplejwt import serializers, tokens

# Remove blacklist app
new_apps = list(settings.INSTALLED_APPS)
new_apps.remove("rest_framework_simplejwt.token_blacklist")

with self.settings(INSTALLED_APPS=tuple(new_apps)):
# Reload module that blacklist app not installed
reload(tokens)
reload(serializers)

refresh = tokens.RefreshToken()

# Serializer validates
ser = serializers.TokenBlacklistSerializer(data={"refresh": str(refresh)})
ser.validate({"refresh": str(refresh)})

# Restore origin module without mock
reload(tokens)
reload(serializers)
32 changes: 32 additions & 0 deletions tests/test_token_blacklist.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from importlib import reload
from unittest.mock import patch

from django.contrib.auth.models import User
from django.core.management import call_command
from django.db.models import BigAutoField
from django.test import TestCase
from django.utils import timezone

from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.serializers import TokenVerifySerializer
Expand All @@ -25,6 +27,19 @@ def setUp(self):
password="test_password",
)

def test_token_blacklist_old_django(self):
with patch("django.VERSION", (3, 1)):
# Import package mock blacklist old django
import rest_framework_simplejwt.token_blacklist.__init__ as blacklist

self.assertEqual(
blacklist.default_app_config,
("rest_framework_simplejwt.token_blacklist.apps.TokenBlacklistConfig"),
)

# Restore origin module without mock
reload(blacklist)

def test_sliding_tokens_are_added_to_outstanding_list(self):
token = SlidingToken.for_user(self.user)

Expand Down Expand Up @@ -114,6 +129,23 @@ def test_tokens_can_be_manually_blacklisted(self):

self.assertEqual(OutstandingToken.objects.count(), 2)

def test_outstanding_token_and_blacklisted_token_expected_str(self):
outstanding = OutstandingToken.objects.create(
user=self.user,
jti="abc",
token="xyz",
expires_at=timezone.now(),
)
blacklisted = BlacklistedToken.objects.create(token=outstanding)

expected_outstanding_str = "Token for {} ({})".format(
outstanding.user, outstanding.jti
)
expected_blacklisted_str = f"Blacklisted token for {blacklisted.token.user}"

self.assertEqual(str(outstanding), expected_outstanding_str)
self.assertEqual(str(blacklisted), expected_blacklisted_str)


class TestTokenBlacklistFlushExpiredTokens(TestCase):
def setUp(self):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_tokens.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime, timedelta
from importlib import reload
from unittest.mock import patch

from django.contrib.auth import get_user_model
Expand Down Expand Up @@ -39,6 +40,18 @@ def setUpTestData(cls):
password="test_password",
)

def test_type_checking(self):
from rest_framework_simplejwt import tokens

with patch("typing.TYPE_CHECKING", True):
# Reload tokens, mock type checking
reload(tokens)

self.assertEqual(tokens.TYPE_CHECKING, True)

# Restore origin module without mock
reload(tokens)

def test_init_no_token_type_or_lifetime(self):
class MyTestToken(Token):
pass
Expand Down Expand Up @@ -377,6 +390,11 @@ def test_for_user_with_username(self):
token = MyToken.for_user(self.user)
self.assertEqual(token[api_settings.USER_ID_CLAIM], self.username)

@override_api_settings(CHECK_REVOKE_TOKEN=True)
def test_revoke_token_claim_included_in_authorization_token(self):
token = MyToken.for_user(self.user)
self.assertIn(api_settings.REVOKE_TOKEN_CLAIM, token)

def test_get_token_backend(self):
token = MyToken()

Expand Down
13 changes: 13 additions & 0 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,16 @@ class CustomTokenView(TokenViewBase):
request = factory.post("/", {}, format="json")
res = view(request)
self.assertEqual(res.status_code, 400)


class TestTokenViewBase(APIViewTestCase):
def test_serializer_class_not_set_in_settings_and_class_attribute_or_wrong_path(
self,
):
view = TokenViewBase()
msg = "Could not import serializer '%s'" % view._serializer_class

with self.assertRaises(ImportError) as e:
view.get_serializer_class()

self.assertEqual(e.exception.msg, msg)

0 comments on commit c3cfd3e

Please sign in to comment.