Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove oauth provider #15588

Open
wants to merge 4 commits into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 0 additions & 25 deletions awx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,25 +60,6 @@ def version_file():
from django.db import connection


def oauth2_getattribute(self, attr):
# Custom method to override
# oauth2_provider.settings.OAuth2ProviderSettings.__getattribute__
from django.conf import settings
from oauth2_provider.settings import DEFAULTS

val = None
if (isinstance(attr, str)) and (attr in DEFAULTS) and (not attr.startswith('_')):
# certain Django OAuth Toolkit migrations actually reference
# setting lookups for references to model classes (e.g.,
# oauth2_settings.REFRESH_TOKEN_MODEL)
# If we're doing an OAuth2 setting lookup *while running* a migration,
# don't do our usual database settings lookup
val = settings.OAUTH2_PROVIDER.get(attr)
if val is None:
val = object.__getattribute__(self, attr)
return val


def prepare_env():
# Update the default settings environment variable based on current mode.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'awx.settings.%s' % MODE)
Expand All @@ -89,12 +70,6 @@ def prepare_env():
if not settings.DEBUG: # pragma: no cover
warnings.simplefilter('ignore', DeprecationWarning)

# Monkeypatch Oauth2 toolkit settings class to check for settings
# in django.conf settings each time, not just once during import
import oauth2_provider.settings

oauth2_provider.settings.OAuth2ProviderSettings.__getattribute__ = oauth2_getattribute


def manage():
# Prepare the AWX environment.
Expand Down
16 changes: 0 additions & 16 deletions awx/api/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@
# Django REST Framework
from rest_framework import authentication

# Django-OAuth-Toolkit
from oauth2_provider.contrib.rest_framework import OAuth2Authentication

logger = logging.getLogger('awx.api.authentication')


Expand All @@ -36,16 +33,3 @@ def authenticate_header(self, request):
class SessionAuthentication(authentication.SessionAuthentication):
def authenticate_header(self, request):
return 'Session'


class LoggedOAuth2Authentication(OAuth2Authentication):
def authenticate(self, request):
ret = super(LoggedOAuth2Authentication, self).authenticate(request)
if ret:
user, token = ret
username = user.username if user else '<none>'
logger.info(
smart_str(u"User {} performed a {} to {} through the API using OAuth 2 token {}.".format(username, request.method, request.path, token.pk))
)
setattr(user, 'oauth_scopes', [x for x in token.scope.split() if x])
return ret
23 changes: 0 additions & 23 deletions awx/api/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

# AWX
from awx.conf import fields, register, register_validate
from awx.api.fields import OAuth2ProviderField
from oauth2_provider.settings import oauth2_settings


register(
Expand Down Expand Up @@ -46,27 +44,6 @@
category=_('Authentication'),
category_slug='authentication',
)
register(
'OAUTH2_PROVIDER',
field_class=OAuth2ProviderField,
default={
'ACCESS_TOKEN_EXPIRE_SECONDS': oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS,
'AUTHORIZATION_CODE_EXPIRE_SECONDS': oauth2_settings.AUTHORIZATION_CODE_EXPIRE_SECONDS,
'REFRESH_TOKEN_EXPIRE_SECONDS': oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS,
},
label=_('OAuth 2 Timeout Settings'),
help_text=_(
'Dictionary for customizing OAuth 2 timeouts, available items are '
'`ACCESS_TOKEN_EXPIRE_SECONDS`, the duration of access tokens in the number '
'of seconds, `AUTHORIZATION_CODE_EXPIRE_SECONDS`, the duration of '
'authorization codes in the number of seconds, and `REFRESH_TOKEN_EXPIRE_SECONDS`, '
'the duration of refresh tokens, after expired access tokens, '
'in the number of seconds.'
),
category=_('Authentication'),
category_slug='authentication',
unit=_('seconds'),
)
register(
'LOGIN_REDIRECT_OVERRIDE',
field_class=fields.CharField,
Expand Down
14 changes: 0 additions & 14 deletions awx/api/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from rest_framework import serializers

# AWX
from awx.conf import fields
from awx.main.models import Credential

__all__ = ['BooleanNullField', 'CharNullField', 'ChoiceNullField', 'VerbatimField']
Expand Down Expand Up @@ -79,19 +78,6 @@ def to_representation(self, value):
return value


class OAuth2ProviderField(fields.DictField):
default_error_messages = {'invalid_key_names': _('Invalid key names: {invalid_key_names}')}
valid_key_names = {'ACCESS_TOKEN_EXPIRE_SECONDS', 'AUTHORIZATION_CODE_EXPIRE_SECONDS', 'REFRESH_TOKEN_EXPIRE_SECONDS'}
child = fields.IntegerField(min_value=1)

def to_internal_value(self, data):
data = super(OAuth2ProviderField, self).to_internal_value(data)
invalid_flags = set(data.keys()) - self.valid_key_names
if invalid_flags:
self.fail('invalid_key_names', invalid_key_names=', '.join(list(invalid_flags)))
return data


class DeprecatedCredentialField(serializers.IntegerField):
def __init__(self, **kwargs):
kwargs['allow_null'] = True
Expand Down
208 changes: 1 addition & 207 deletions awx/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@
from datetime import timedelta
from uuid import uuid4

# OAuth2
from oauthlib import oauth2
from oauthlib.common import generate_token

# Jinja
from jinja2 import sandbox, StrictUndefined
from jinja2.exceptions import TemplateSyntaxError, UndefinedError, SecurityError
Expand Down Expand Up @@ -50,7 +46,7 @@

# AWX
from awx.main.access import get_user_capabilities
from awx.main.constants import ACTIVE_STATES, CENSOR_VALUE, org_role_to_permission
from awx.main.constants import ACTIVE_STATES, org_role_to_permission
from awx.main.models import (
ActivityStream,
AdHocCommand,
Expand Down Expand Up @@ -79,14 +75,11 @@
Label,
Notification,
NotificationTemplate,
OAuth2AccessToken,
OAuth2Application,
Organization,
Project,
ProjectUpdate,
ProjectUpdateEvent,
ReceptorAddress,
RefreshToken,
Role,
Schedule,
SystemJob,
Expand Down Expand Up @@ -1059,9 +1052,6 @@ def get_related(self, obj):
roles=self.reverse('api:user_roles_list', kwargs={'pk': obj.pk}),
activity_stream=self.reverse('api:user_activity_stream_list', kwargs={'pk': obj.pk}),
access_list=self.reverse('api:user_access_list', kwargs={'pk': obj.pk}),
tokens=self.reverse('api:o_auth2_token_list', kwargs={'pk': obj.pk}),
authorized_tokens=self.reverse('api:user_authorized_token_list', kwargs={'pk': obj.pk}),
personal_tokens=self.reverse('api:user_personal_token_list', kwargs={'pk': obj.pk}),
)
)
return res
Expand All @@ -1078,199 +1068,6 @@ class Meta:
fields = ('*', '-is_system_auditor')


class BaseOAuth2TokenSerializer(BaseSerializer):
refresh_token = serializers.SerializerMethodField()
token = serializers.SerializerMethodField()
ALLOWED_SCOPES = ['read', 'write']

class Meta:
model = OAuth2AccessToken
fields = ('*', '-name', 'description', 'user', 'token', 'refresh_token', 'application', 'expires', 'scope')
read_only_fields = ('user', 'token', 'expires', 'refresh_token')
extra_kwargs = {'scope': {'allow_null': False, 'required': False}, 'user': {'allow_null': False, 'required': True}}

def get_token(self, obj):
request = self.context.get('request', None)
try:
if request.method == 'POST':
return obj.token
else:
return CENSOR_VALUE
except ObjectDoesNotExist:
return ''

def get_refresh_token(self, obj):
request = self.context.get('request', None)
try:
if not obj.refresh_token:
return None
elif request.method == 'POST':
return getattr(obj.refresh_token, 'token', '')
else:
return CENSOR_VALUE
except ObjectDoesNotExist:
return None

def get_related(self, obj):
ret = super(BaseOAuth2TokenSerializer, self).get_related(obj)
if obj.user:
ret['user'] = self.reverse('api:user_detail', kwargs={'pk': obj.user.pk})
if obj.application:
ret['application'] = self.reverse('api:o_auth2_application_detail', kwargs={'pk': obj.application.pk})
ret['activity_stream'] = self.reverse('api:o_auth2_token_activity_stream_list', kwargs={'pk': obj.pk})
return ret

def _is_valid_scope(self, value):
if not value or (not isinstance(value, str)):
return False
words = value.split()
for word in words:
if words.count(word) > 1:
return False # do not allow duplicates
if word not in self.ALLOWED_SCOPES:
return False
return True

def validate_scope(self, value):
if not self._is_valid_scope(value):
raise serializers.ValidationError(_('Must be a simple space-separated string with allowed scopes {}.').format(self.ALLOWED_SCOPES))
return value

def create(self, validated_data):
validated_data['user'] = self.context['request'].user
try:
return super(BaseOAuth2TokenSerializer, self).create(validated_data)
except oauth2.AccessDeniedError as e:
raise PermissionDenied(str(e))


class UserAuthorizedTokenSerializer(BaseOAuth2TokenSerializer):
class Meta:
extra_kwargs = {
'scope': {'allow_null': False, 'required': False},
'user': {'allow_null': False, 'required': True},
'application': {'allow_null': False, 'required': True},
}

def create(self, validated_data):
current_user = self.context['request'].user
validated_data['token'] = generate_token()
validated_data['expires'] = now() + timedelta(seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS'])
obj = super(UserAuthorizedTokenSerializer, self).create(validated_data)
obj.save()
if obj.application:
RefreshToken.objects.create(user=current_user, token=generate_token(), application=obj.application, access_token=obj)
return obj


class OAuth2TokenSerializer(BaseOAuth2TokenSerializer):
def create(self, validated_data):
current_user = self.context['request'].user
validated_data['token'] = generate_token()
validated_data['expires'] = now() + timedelta(seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS'])
obj = super(OAuth2TokenSerializer, self).create(validated_data)
if obj.application and obj.application.user:
obj.user = obj.application.user
obj.save()
if obj.application:
RefreshToken.objects.create(user=current_user, token=generate_token(), application=obj.application, access_token=obj)
return obj


class OAuth2TokenDetailSerializer(OAuth2TokenSerializer):
class Meta:
read_only_fields = ('*', 'user', 'application')


class UserPersonalTokenSerializer(BaseOAuth2TokenSerializer):
class Meta:
read_only_fields = ('user', 'token', 'expires', 'application')

def create(self, validated_data):
validated_data['token'] = generate_token()
validated_data['expires'] = now() + timedelta(seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS'])
validated_data['application'] = None
obj = super(UserPersonalTokenSerializer, self).create(validated_data)
obj.save()
return obj


class OAuth2ApplicationSerializer(BaseSerializer):
show_capabilities = ['edit', 'delete']

class Meta:
model = OAuth2Application
fields = (
'*',
'description',
'-user',
'client_id',
'client_secret',
'client_type',
'redirect_uris',
'authorization_grant_type',
'skip_authorization',
'organization',
)
read_only_fields = ('client_id', 'client_secret')
read_only_on_update_fields = ('user', 'authorization_grant_type')
extra_kwargs = {
'user': {'allow_null': True, 'required': False},
'organization': {'allow_null': False},
'authorization_grant_type': {'allow_null': False, 'label': _('Authorization Grant Type')},
'client_secret': {'label': _('Client Secret')},
'client_type': {'label': _('Client Type')},
'redirect_uris': {'label': _('Redirect URIs')},
'skip_authorization': {'label': _('Skip Authorization')},
}

def to_representation(self, obj):
ret = super(OAuth2ApplicationSerializer, self).to_representation(obj)
request = self.context.get('request', None)
if request.method != 'POST' and obj.client_type == 'confidential':
ret['client_secret'] = CENSOR_VALUE
if obj.client_type == 'public':
ret.pop('client_secret', None)
return ret

def get_related(self, obj):
res = super(OAuth2ApplicationSerializer, self).get_related(obj)
res.update(
dict(
tokens=self.reverse('api:o_auth2_application_token_list', kwargs={'pk': obj.pk}),
activity_stream=self.reverse('api:o_auth2_application_activity_stream_list', kwargs={'pk': obj.pk}),
)
)
if obj.organization_id:
res.update(
dict(
organization=self.reverse('api:organization_detail', kwargs={'pk': obj.organization_id}),
)
)
return res

def get_modified(self, obj):
if obj is None:
return None
return obj.updated

def _summary_field_tokens(self, obj):
token_list = [{'id': x.pk, 'token': CENSOR_VALUE, 'scope': x.scope} for x in obj.oauth2accesstoken_set.all()[:10]]
if has_model_field_prefetched(obj, 'oauth2accesstoken_set'):
token_count = len(obj.oauth2accesstoken_set.all())
else:
if len(token_list) < 10:
token_count = len(token_list)
else:
token_count = obj.oauth2accesstoken_set.count()
return {'count': token_count, 'results': token_list}

def get_summary_fields(self, obj):
ret = super(OAuth2ApplicationSerializer, self).get_summary_fields(obj)
ret['tokens'] = self._summary_field_tokens(obj)
return ret


class OrganizationSerializer(BaseSerializer):
show_capabilities = ['edit', 'delete']

Expand All @@ -1291,7 +1088,6 @@ def get_related(self, obj):
admins=self.reverse('api:organization_admins_list', kwargs={'pk': obj.pk}),
teams=self.reverse('api:organization_teams_list', kwargs={'pk': obj.pk}),
credentials=self.reverse('api:organization_credential_list', kwargs={'pk': obj.pk}),
applications=self.reverse('api:organization_applications_list', kwargs={'pk': obj.pk}),
activity_stream=self.reverse('api:organization_activity_stream_list', kwargs={'pk': obj.pk}),
notification_templates=self.reverse('api:organization_notification_templates_list', kwargs={'pk': obj.pk}),
notification_templates_started=self.reverse('api:organization_notification_templates_started_list', kwargs={'pk': obj.pk}),
Expand Down Expand Up @@ -6074,8 +5870,6 @@ def _local_summarizable_fk_fields(self, obj):
('workflow_job_template_node', ('id', 'unified_job_template_id')),
('label', ('id', 'name', 'organization_id')),
('notification', ('id', 'status', 'notification_type', 'notification_template_id')),
('o_auth2_access_token', ('id', 'user_id', 'description', 'application_id', 'scope')),
('o_auth2_application', ('id', 'name', 'description')),
('credential_type', ('id', 'name', 'description', 'kind', 'managed')),
('ad_hoc_command', ('id', 'name', 'status', 'limit')),
('workflow_approval', ('id', 'name', 'unified_job_id')),
Expand Down
Loading
Loading