Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rp): return oidc.Tokens on token refresh #423

Merged
merged 1 commit into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 21 additions & 22 deletions pkg/client/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"context"
"io"
"io/ioutil"
"math/rand"
"net/http"
"net/http/cookiejar"
Expand Down Expand Up @@ -56,23 +55,25 @@ func TestRelyingPartySession(t *testing.T) {
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)

t.Log("------- run authorization code flow ------")
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, "secret")
provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, "secret")

t.Log("------- refresh tokens ------")

newTokens, err := rp.RefreshAccessToken(CTX, provider, refreshToken, "", "")
newTokens, err := rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "")
require.NoError(t, err, "refresh token")
assert.NotNil(t, newTokens, "access token")
t.Logf("new access token %s", newTokens.AccessToken)
t.Logf("new refresh token %s", newTokens.RefreshToken)
t.Logf("new token type %s", newTokens.TokenType)
t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339))
require.NotEmpty(t, newTokens.AccessToken, "new accessToken")
assert.NotEmpty(t, newTokens.Extra("id_token"), "new idToken")
assert.NotEmpty(t, newTokens.IDToken, "new idToken")
assert.NotNil(t, newTokens.IDTokenClaims)
assert.Equal(t, newTokens.IDTokenClaims.Subject, tokens.IDTokenClaims.Subject)

t.Log("------ end session (logout) ------")

newLoc, err := rp.EndSession(CTX, provider, idToken, "", "")
newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "")
require.NoError(t, err, "logout")
if newLoc != nil {
t.Logf("redirect to %s", newLoc)
Expand All @@ -81,12 +82,12 @@ func TestRelyingPartySession(t *testing.T) {
}

t.Log("------ attempt refresh again (should fail) ------")
t.Log("trying original refresh token", refreshToken)
_, err = rp.RefreshAccessToken(CTX, provider, refreshToken, "", "")
t.Log("trying original refresh token", tokens.RefreshToken)
_, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "")
assert.Errorf(t, err, "refresh with original")
if newTokens.RefreshToken != "" {
t.Log("trying replacement refresh token", newTokens.RefreshToken)
_, err = rp.RefreshAccessToken(CTX, provider, newTokens.RefreshToken, "", "")
_, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, newTokens.RefreshToken, "", "")
assert.Errorf(t, err, "refresh with replacement")
}
}
Expand All @@ -106,7 +107,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
clientSecret := "secret"

t.Log("------- run authorization code flow ------")
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret)
provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret)

resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret)
require.NoError(t, err, "new resource server")
Expand All @@ -116,7 +117,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
tokenExchangeResponse, err := tokenexchange.ExchangeToken(
CTX,
resourceServer,
refreshToken,
tokens.RefreshToken,
oidc.RefreshTokenType,
"",
"",
Expand All @@ -134,7 +135,7 @@ func TestResourceServerTokenExchange(t *testing.T) {

t.Log("------ end session (logout) ------")

newLoc, err := rp.EndSession(CTX, provider, idToken, "", "")
newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "")
require.NoError(t, err, "logout")
if newLoc != nil {
t.Logf("redirect to %s", newLoc)
Expand All @@ -147,7 +148,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
tokenExchangeResponse, err = tokenexchange.ExchangeToken(
CTX,
resourceServer,
refreshToken,
tokens.RefreshToken,
oidc.RefreshTokenType,
"",
"",
Expand All @@ -161,7 +162,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
require.Nil(t, tokenExchangeResponse, "token exchange response")
}

func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, accessToken, refreshToken, idToken string) {
func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, tokens *oidc.Tokens[*oidc.IDTokenClaims]) {
targetURL := "http://local-site"
localURL, err := url.Parse(targetURL + "/login?requestID=1234")
require.NoError(t, err, "local url")
Expand Down Expand Up @@ -258,17 +259,15 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
}

var email string
redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) {
redirect := func(w http.ResponseWriter, r *http.Request, newTokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) {
tokens = newTokens
require.NotNil(t, tokens, "tokens")
require.NotNil(t, info, "info")
t.Log("access token", tokens.AccessToken)
t.Log("refresh token", tokens.RefreshToken)
t.Log("id token", tokens.IDToken)
t.Log("email", info.Email)

accessToken = tokens.AccessToken
refreshToken = tokens.RefreshToken
idToken = tokens.IDToken
email = info.Email
http.Redirect(w, r, targetURL, 302)
}
Expand All @@ -290,12 +289,12 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
require.NoError(t, err, "get fully-authorizied redirect location")
require.Equal(t, targetURL, authorizedURL.String(), "fully-authorizied redirect location")

require.NotEmpty(t, idToken, "id token")
assert.NotEmpty(t, refreshToken, "refresh token")
assert.NotEmpty(t, accessToken, "access token")
require.NotEmpty(t, tokens.IDToken, "id token")
assert.NotEmpty(t, tokens.RefreshToken, "refresh token")
assert.NotEmpty(t, tokens.AccessToken, "access token")
assert.NotEmpty(t, email, "email")

return provider, accessToken, refreshToken, idToken
return provider, tokens
}

type deferredHandler struct {
Expand Down Expand Up @@ -343,7 +342,7 @@ func getForm(t *testing.T, desc string, httpClient *http.Client, uri *url.URL) [

func fillForm(t *testing.T, desc string, httpClient *http.Client, body []byte, uri *url.URL, opts ...gosubmit.Option) *url.URL {
// TODO: switch to io.NopCloser when go1.15 support is dropped
req := gosubmit.ParseWithURL(ioutil.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest(
req := gosubmit.ParseWithURL(io.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest(
append([]gosubmit.Option{gosubmit.AutoFill()}, opts...)...,
)
if req.URL.Scheme == "" {
Expand Down
59 changes: 38 additions & 21 deletions pkg/client/rp/relying_party.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,25 @@
return oidc.NewSHACodeChallenge(codeVerifier), nil
}

// ErrMissingIDToken is returned when an id_token was expected,
// but not received in the token response.
var ErrMissingIDToken = errors.New("id_token missing")

func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Token, rp RelyingParty) (*oidc.Tokens[C], error) {
if rp.IsOAuth2Only() {
return &oidc.Tokens[C]{Token: token}, nil
}
idTokenString, ok := token.Extra(idTokenKey).(string)
if !ok {
return &oidc.Tokens[C]{Token: token}, ErrMissingIDToken
}
idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
if err != nil {
return nil, err
}
return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
}

// CodeExchange handles the oauth2 code exchange, extracting and validating the id_token
// returning it parsed together with the oauth2 tokens (access, refresh)
func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) {
Expand All @@ -369,22 +388,7 @@
if err != nil {
return nil, err
}

if rp.IsOAuth2Only() {
return &oidc.Tokens[C]{Token: token}, nil
}

idTokenString, ok := token.Extra(idTokenKey).(string)
if !ok {
return nil, errors.New("id_token missing")
}

idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
if err != nil {
return nil, err
}

return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
return verifyTokenResponse[C](ctx, token, rp)
}

type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty)
Expand Down Expand Up @@ -609,11 +613,14 @@
GrantType oidc.GrantType `schema:"grant_type"`
}

// RefreshAccessToken performs a token refresh. If it doesn't error, it will always
// RefreshTokens performs a token refresh. If it doesn't error, it will always
// provide a new AccessToken. It may provide a new RefreshToken, and if it does, then
// the old one should be considered invalid. It may also provide a new IDToken. The
// new IDToken can be retrieved with token.Extra("id_token").
func RefreshAccessToken(ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) {
// the old one should be considered invalid.
//
// In case the RP is not OAuth2 only and an IDToken was part of the response,
// the IDToken and AccessToken will be verfied
// and the IDToken and IDTokenClaims fields will be populated in the returned object.
func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) {
request := RefreshTokenRequest{
RefreshToken: refreshToken,
Scopes: rp.OAuthConfig().Scopes,
Expand All @@ -623,7 +630,17 @@
ClientAssertionType: clientAssertionType,
GrantType: oidc.GrantTypeRefreshToken,
}
return client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp})
newToken, err := client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp})
if err != nil {
return nil, err
}
tokens, err := verifyTokenResponse[C](ctx, newToken, rp)
if err == nil || errors.Is(err, ErrMissingIDToken) {
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse
// ...except that it might not contain an id_token.
return tokens, nil
}
return nil, err

Check warning on line 643 in pkg/client/rp/relying_party.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/relying_party.go#L643

Added line #L643 was not covered by tests
}

func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) {
Expand Down
107 changes: 107 additions & 0 deletions pkg/client/rp/relying_party_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package rp

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v3/internal/testutil"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
)

func Test_verifyTokenResponse(t *testing.T) {
verifier := &IDTokenVerifier{
Issuer: tu.ValidIssuer,
MaxAgeIAT: 2 * time.Minute,
ClientID: tu.ValidClientID,
Offset: time.Second,
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
KeySet: tu.KeySet{},
MaxAge: 2 * time.Minute,
ACR: tu.ACRVerify,
Nonce: func(context.Context) string { return tu.ValidNonce },
}
tests := []struct {
name string
oauth2Only bool
tokens func() (token *oauth2.Token, want *oidc.Tokens[*oidc.IDTokenClaims])
wantErr error
}{
{
name: "succes, oauth2 only",
oauth2Only: true,
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
accesToken, _ := tu.ValidAccessToken()
token := &oauth2.Token{
AccessToken: accesToken,
}
return token, &oidc.Tokens[*oidc.IDTokenClaims]{
Token: token,
}
},
},
{
name: "id_token missing error",
oauth2Only: false,
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
accesToken, _ := tu.ValidAccessToken()
token := &oauth2.Token{
AccessToken: accesToken,
}
return token, &oidc.Tokens[*oidc.IDTokenClaims]{
Token: token,
}
},
wantErr: ErrMissingIDToken,
},
{
name: "verify tokens error",
oauth2Only: false,
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
accesToken, _ := tu.ValidAccessToken()
token := &oauth2.Token{
AccessToken: accesToken,
}
token = token.WithExtra(map[string]any{
"id_token": "foobar",
})
return token, nil
},
wantErr: oidc.ErrParse,
},
{
name: "success, with id_token",
oauth2Only: false,
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
accesToken, _ := tu.ValidAccessToken()
token := &oauth2.Token{
AccessToken: accesToken,
}
idToken, claims := tu.ValidIDToken()
token = token.WithExtra(map[string]any{
"id_token": idToken,
})
return token, &oidc.Tokens[*oidc.IDTokenClaims]{
Token: token,
IDTokenClaims: claims,
IDToken: idToken,
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := &relyingParty{
oauth2Only: tt.oauth2Only,
idTokenVerifier: verifier,
}
token, want := tt.tokens()
got, err := verifyTokenResponse[*oidc.IDTokenClaims](context.Background(), token, rp)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, want, got)
})
}
}
Loading