Skip to content

Commit

Permalink
Merge pull request #98 from yjinjo/master
Browse files Browse the repository at this point in the history
Add SAML feature and optimize the code
  • Loading branch information
yjinjo authored Jun 7, 2024
2 parents 414cec1 + da877a9 commit 6c139c2
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 59 deletions.
66 changes: 10 additions & 56 deletions src/cloudforet/console_api_v2/interface/rest/extension/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,13 @@

from fastapi import Depends, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import RedirectResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi_utils.cbv import cbv
from fastapi_utils.inferring_router import InferringRouter
from spaceone.core import config
from spaceone.core.error import ERROR_REQUIRED_PARAMETER
from spaceone.core.fastapi.api import BaseAPI, exception_handler

from cloudforet.console_api_v2.manager.cloudforet_manager import CloudforetManager
from cloudforet.console_api_v2.service.auth_service import AuthService
from cloudforet.console_api_v2.service.proxy_service import ProxyService

_LOGGER = logging.getLogger(__name__)
_AUTH_SCHEME = HTTPBasic()
Expand Down Expand Up @@ -40,57 +36,15 @@ async def basic(
@router.post("/saml/{domain_id}")
@exception_handler
async def saml(self, request: Request, domain_id: str):
saml_service: AuthService = AuthService()
form_data = await request.form()
credentials = self._extract_credentials(request, dict(form_data))
refresh_token = self._issue_token(credentials, domain_id)
domain_name = self._get_domain_name(domain_id)
return self._redirect_response(domain_name, refresh_token)
params = {"request": request, "form_data": form_data, "domain_id": domain_id}
response = await run_in_threadpool(saml_service.saml, params)
return response

@staticmethod
def _extract_credentials(request: Request, form_data: dict) -> dict:
return {
"http_host": request.client.host,
"server_port": request.url.port,
"script_name": request.url.path,
"post_data": form_data,
}

@staticmethod
def _issue_token(credentials: dict, domain_id: str) -> str:
dispatch_params = {
"grpc_method": "identity.Token.issue",
"credentials": credentials,
"auth_type": "EXTERNAL",
"domain_id": domain_id,
}

proxy_service = ProxyService()
response = proxy_service.dispatch_api(dispatch_params)

return response.get("refresh_token")

@staticmethod
def _get_domain_name(domain_id: str) -> str:
cloudforet_mgr = CloudforetManager()
grpc_method = "identity.Domain.get"
dispatch_params = {"domain_id": domain_id}
system_token = config.get_global("TOKEN")

response = cloudforet_mgr.dispatch_api(
grpc_method, dispatch_params, system_token
)

return response.get("name")

@staticmethod
def _redirect_response(domain_name: str, refresh_token: str) -> RedirectResponse:
console_domain: str = config.get_global("CONSOLE_DOMAIN").format(
domain_name=domain_name
)
if refresh_token:
return RedirectResponse(
f"{console_domain}/saml?refresh_token={refresh_token}",
status_code=302,
)

return RedirectResponse(f"{console_domain}", status_code=302)
@router.get("/saml/{domain_id}/metadata")
@exception_handler
async def saml_sp_metadata(self, domain_id: str):
saml_service: AuthService = AuthService()
response = await run_in_threadpool(saml_service.saml_sp_metadata, domain_id)
return response
107 changes: 104 additions & 3 deletions src/cloudforet/console_api_v2/service/auth_service.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import json
import logging

from spaceone.core import cache
from spaceone.core import config
from fastapi import Request, Response
from fastapi.responses import RedirectResponse
from spaceone.core import cache, config
from spaceone.core.auth.jwt import JWTAuthenticator, JWTUtil
from spaceone.core.error import ERROR_AUTHENTICATE_FAILURE
from spaceone.core.service import *
from spaceone.core.service import BaseService, event_handler, transaction

from cloudforet.console_api_v2.manager.cloudforet_manager import CloudforetManager
from cloudforet.console_api_v2.service.proxy_service import ProxyService

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -51,6 +53,23 @@ def basic(self, params: dict) -> None:
self._check_app(client_id, domain_id)
self._authenticate(token, domain_id)

def saml(self, params: dict) -> RedirectResponse:
request = params.get("request")
form_data = params.get("form_data")
domain_id = params.get("domain_id")

credentials = self._extract_credentials(request, dict(form_data))
refresh_token = self._issue_token(credentials, domain_id)
domain_name = self._get_domain_name(domain_id)
return self._redirect_response(domain_name, refresh_token)

def saml_sp_metadata(self, domain_id: str) -> Response:
sp_entity_id = domain_id
domain_name = self._get_domain_name(domain_id)
acs_url = self._get_acs_url(domain_name, domain_id)
metadata_xml = self._generate_sp_metadata(sp_entity_id, acs_url)
return Response(content=metadata_xml, media_type="application/xml")

def _authenticate(self, token: str, domain_id: str) -> dict:
public_key = self._get_public_key(domain_id)
return JWTAuthenticator(json.loads(public_key)).validate(token)
Expand Down Expand Up @@ -100,3 +119,85 @@ def _check_app(client_id: str, domain_id: str):
{"client_id": client_id, "domain_id": domain_id},
token=system_token,
)

@staticmethod
def _extract_credentials(request: Request, form_data: dict) -> dict:
return {
"http_host": request.client.host,
"server_port": str(request.url.port),
"script_name": request.url.path,
"post_data": form_data,
}

@staticmethod
def _issue_token(credentials: dict, domain_id: str) -> str:
dispatch_params = {
"grpc_method": "identity.Token.issue",
"credentials": credentials,
"auth_type": "EXTERNAL",
"domain_id": domain_id,
}

proxy_service = ProxyService()
response = proxy_service.dispatch_api(dispatch_params)

return response.get("refresh_token")

@staticmethod
def _get_domain_name(domain_id: str) -> str:
cloudforet_mgr = CloudforetManager()
grpc_method = "identity.Domain.get"
dispatch_params = {"domain_id": domain_id}
system_token = config.get_global("TOKEN")

response = cloudforet_mgr.dispatch_api(
grpc_method, dispatch_params, system_token
)

return response.get("name")

@staticmethod
def _redirect_response(domain_name: str, refresh_token: str) -> RedirectResponse:
console_domain: str = config.get_global("CONSOLE_DOMAIN").format(
domain_name=domain_name
)

if refresh_token:
return RedirectResponse(
f"{console_domain}/saml?refresh_token={refresh_token}",
status_code=302,
)

return RedirectResponse(f"{console_domain}", status_code=302)

@staticmethod
def _get_acs_url(domain_name: str, domain_id: str) -> str:
console_api_v2_endpoint = config.get_global("CONSOLE_API_V2_ENDPOINT")
acs_url = (
f"{console_api_v2_endpoint}/console-api/extension/auth/saml/{domain_id}"
)

return acs_url

@staticmethod
def _generate_sp_metadata(sp_entity_id: str, acs_url: str) -> str:
"""Generates SP metadata XML.
Args:
'sp_entity_id': 'str' (Service Provider Entity ID),
'acs_url': 'str' (Assertion Consumer Service URL),
'x509_cert': 'str' (X.509 certificate),
Returns:
'metadata_template': 'str' (SP metadata XML)
"""
metadata_template = f"""
<?xml version="1.0" encoding="UTF-8"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="{sp_entity_id}">
<md:SPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="{acs_url}" index="1"/>
</md:SPSSODescriptor>
</md:EntityDescriptor>
"""

return metadata_template.strip()

0 comments on commit 6c139c2

Please sign in to comment.