Skip to content

Commit

Permalink
fix: use polling while handling csr (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
varnastadeus authored Oct 10, 2022
1 parent 9d2ebb5 commit b498a8f
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 53 deletions.
53 changes: 36 additions & 17 deletions actions/approve_csr_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ func newApproveCSRHandler(log logrus.FieldLogger, clientset kubernetes.Interface
log: log,
clientset: clientset,
initialCSRFetchTimeout: 5 * time.Minute,
csrFetchInterval: 5 * time.Second,
}
}

// approveCSRHandler handles the approve-CSR action: it looks up the initial
// certificate signing request for a node and gets it approved.
type approveCSRHandler struct {
// log is the base logger used by the handler.
log logrus.FieldLogger
// clientset is the Kubernetes client passed to the csr helpers.
clientset kubernetes.Interface
// initialCSRFetchTimeout bounds the total time spent looking for the
// node's initial CSR (the constructor sets it to 5 minutes).
initialCSRFetchTimeout time.Duration
// csrFetchInterval is the delay before each polling attempt while waiting
// for the node's initial CSR (the constructor sets it to 5 seconds).
csrFetchInterval time.Duration
}

func (h *approveCSRHandler) Handle(ctx context.Context, data interface{}) error {
Expand All @@ -41,6 +43,11 @@ func (h *approveCSRHandler) Handle(ctx context.Context, data interface{}) error
return fmt.Errorf("getting initial csr: %w", err)
}

if cert.Approved() {
log.Debug("csr is already approved")
return nil
}

b := backoff.WithContext(
newApproveCSRExponentialBackoff(),
ctx,
Expand All @@ -54,12 +61,7 @@ func (h *approveCSRHandler) Handle(ctx context.Context, data interface{}) error
})
}

func (h *approveCSRHandler) handle(ctx context.Context, log logrus.FieldLogger, cert *csr.Certificate) error {
if cert.Approved() {
log.Debug("initial csr is already approved")
return nil
}

func (h *approveCSRHandler) handle(ctx context.Context, log logrus.FieldLogger, cert *csr.Certificate) (reterr error) {
// Since this new csr may be denied we need to delete it.
log.Debug("deleting old csr")
if err := csr.DeleteCertificate(ctx, h.clientset, cert); err != nil {
Expand All @@ -68,7 +70,7 @@ func (h *approveCSRHandler) handle(ctx context.Context, log logrus.FieldLogger,

// Create new csr with the same request data as original csr.
log.Debug("requesting new csr")
cert, err := csr.RequestCertificate(
newCert, err := csr.RequestCertificate(
ctx,
h.clientset,
cert,
Expand All @@ -79,7 +81,7 @@ func (h *approveCSRHandler) handle(ctx context.Context, log logrus.FieldLogger,

// Approve new csr.
log.Debug("approving new csr")
resp, err := csr.ApproveCertificate(ctx, h.clientset, cert)
resp, err := csr.ApproveCertificate(ctx, h.clientset, newCert)
if err != nil {
return fmt.Errorf("approving csr: %w", err)
}
Expand All @@ -89,28 +91,45 @@ func (h *approveCSRHandler) handle(ctx context.Context, log logrus.FieldLogger,
return errors.New("certificate signing request was not approved")
}

func (h *approveCSRHandler) getInitialNodeCSR(ctx context.Context, log *logrus.Entry, nodeName string) (*csr.Certificate, error) {
func (h *approveCSRHandler) getInitialNodeCSR(ctx context.Context, log logrus.FieldLogger, nodeName string) (*csr.Certificate, error) {
log.Debug("getting initial csr")

ctx, cancel := context.WithTimeout(ctx, h.initialCSRFetchTimeout)
defer cancel()

poll := func() (*csr.Certificate, error) {
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(h.csrFetchInterval):
cert, err := csr.GetCertificateByNodeName(ctx, h.clientset, nodeName)
if err != nil && !errors.Is(err, csr.ErrNodeCertificateNotFound) {
return nil, err
}
if cert != nil {
return cert, nil
}
}
}
}

var cert *csr.Certificate
var err error

b := backoff.WithContext(backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 3), ctx)
err = backoff.Retry(func() error {
cert, err = csr.GetCertificateByNodeName(ctx, h.clientset, nodeName)
logRetry := func(err error, _ time.Duration) {
log.Warnf("getting initial csr, will retry: %v", err)
}
b := backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 3)
err = backoff.RetryNotify(func() error {
cert, err = poll()
if errors.Is(err, context.DeadlineExceeded) {
return backoff.Permanent(err)
}
return err
}, b)
if err != nil {
return nil, err
}
}, b, logRetry)

return cert, nil
return cert, err
}

func newApproveCSRExponentialBackoff() *backoff.ExponentialBackOff {
Expand Down
62 changes: 44 additions & 18 deletions actions/approve_csr_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ func TestApproveCSRHandler(t *testing.T) {

csrRes := getCSR()
client := fake.NewSimpleClientset(csrRes)
watcher := watch.NewFake()
client.PrependWatchReactor("certificatesigningrequests", ktest.DefaultWatchReactor(watcher, nil))

var approveCalls int32
client.PrependReactor("update", "certificatesigningrequests", func(action ktest.Action) (handled bool, ret runtime.Object, err error) {
Expand All @@ -61,12 +59,37 @@ func TestApproveCSRHandler(t *testing.T) {
h := &approveCSRHandler{
log: log,
clientset: client,
csrFetchInterval: 1 * time.Millisecond,
initialCSRFetchTimeout: 10 * time.Millisecond,
}

ctx := context.Background()
err := h.Handle(ctx, &castai.ActionApproveCSR{NodeName: "gke-am-gcp-cast-5dc4f4ec"})
r.NoError(err)
})

t.Run("return if csr is already approved", func(t *testing.T) {
r := require.New(t)

csrRes := getCSR()
csrRes.Status.Conditions = []certv1.CertificateSigningRequestCondition{
{
Type: certv1.CertificateApproved,
Reason: csr.ReasonApproved,
Message: "approved",
LastUpdateTime: metav1.Now(),
Status: v1.ConditionTrue,
},
}
client := fake.NewSimpleClientset(csrRes)

h := &approveCSRHandler{
log: log,
clientset: client,
csrFetchInterval: 1 * time.Millisecond,
initialCSRFetchTimeout: 10 * time.Millisecond,
}

go func() {
watcher.Add(csrRes)
}()
ctx := context.Background()
err := h.Handle(ctx, &castai.ActionApproveCSR{NodeName: "gke-am-gcp-cast-5dc4f4ec"})
r.NoError(err)
Expand All @@ -76,32 +99,28 @@ func TestApproveCSRHandler(t *testing.T) {
r := require.New(t)

csrRes := getCSR()
watcher := watch.NewFake()
count := 0
fn := ktest.WatchReactionFunc(func(action ktest.Action) (handled bool, ret watch.Interface, err error) {
fn := ktest.ReactionFunc(func(action ktest.Action) (handled bool, ret runtime.Object, err error) {
if count == 0 {
count++
return true, watcher, errors.New("api server timeout")
return true, nil, errors.New("api server timeout")
}
return true, watcher, err
out := certv1.CertificateSigningRequestList{Items: []certv1.CertificateSigningRequest{*csrRes}}
return true, &out, err
})

client := fake.NewSimpleClientset(csrRes)
client.PrependWatchReactor("certificatesigningrequests", fn)
client.PrependReactor("list", "certificatesigningrequests", fn)

h := &approveCSRHandler{
log: log,
clientset: client,
csrFetchInterval: 100 * time.Millisecond,
initialCSRFetchTimeout: 1000 * time.Millisecond,
}

go func() {
watcher.Add(csrRes)
}()
ctx := context.Background()
err := h.Handle(ctx, &castai.ActionApproveCSR{NodeName: "gke-am-gcp-cast-5dc4f4ec"})
r.NoError(err)

})

t.Run("approve v1beta1 csr successfully", func(t *testing.T) {
Expand All @@ -124,9 +143,14 @@ AiAHVYZXHxxspoV0hcfn2Pdsl89fIPCOFy/K1PqSUR6QNAIgYdt51ZbQt9rgM2BD
},
}
client := fake.NewSimpleClientset(csrRes)
notFoundErr := apierrors.NewNotFound(schema.GroupResource{}, "csr")
client.PrependWatchReactor("certificatesigningrequests", ktest.DefaultWatchReactor(nil, notFoundErr))

// Return NotFound for all v1 resources.
client.PrependReactor("*", "*", func(action ktest.Action) (handled bool, ret runtime.Object, err error) {
if action.GetResource().Version == "v1" {
err = apierrors.NewNotFound(schema.GroupResource{}, action.GetResource().String())
return true, nil, err
}
return
})
client.PrependReactor("update", "certificatesigningrequests", func(action ktest.Action) (handled bool, ret runtime.Object, err error) {
approved := csrRes.DeepCopy()
approved.Status.Conditions = []certv1beta1.CertificateSigningRequestCondition{
Expand All @@ -144,6 +168,7 @@ AiAHVYZXHxxspoV0hcfn2Pdsl89fIPCOFy/K1PqSUR6QNAIgYdt51ZbQt9rgM2BD
h := &approveCSRHandler{
log: log,
clientset: client,
csrFetchInterval: 1 * time.Millisecond,
initialCSRFetchTimeout: 10 * time.Millisecond,
}

Expand All @@ -159,6 +184,7 @@ AiAHVYZXHxxspoV0hcfn2Pdsl89fIPCOFy/K1PqSUR6QNAIgYdt51ZbQt9rgM2BD
h := &approveCSRHandler{
log: log,
clientset: client,
csrFetchInterval: 1 * time.Millisecond,
initialCSRFetchTimeout: 10 * time.Millisecond,
}

Expand Down
26 changes: 8 additions & 18 deletions csr/csr.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,29 +212,24 @@ func getNodeCSRV1(ctx context.Context, client kubernetes.Interface, nodeName str
}).String(),
}

watch, err := client.CertificatesV1().CertificateSigningRequests().Watch(ctx, options)
csrList, err := client.CertificatesV1().CertificateSigningRequests().List(ctx, options)
if err != nil {
return nil, err
}
defer watch.Stop()

for r := range watch.ResultChan() {
csr, ok := r.Object.(*certv1.CertificateSigningRequest)
if !ok {
continue
}

if len(csr.Status.Certificate) != 0 {
// If certificate is present - CSR is already approved.
continue
}
// Sort by newest first so we don't need to parse old items.
sort.Slice(csrList.Items, func(i, j int) bool {
return csrList.Items[i].CreationTimestamp.After(csrList.Items[j].CreationTimestamp.Time)
})

for _, csr := range csrList.Items {
csr := csr
found, err := isNodeCSR(csr.Name, csr.Spec.Request, nodeName)
if err != nil {
return nil, err
}
if found {
return &Certificate{V1: csr}, nil
return &Certificate{V1: &csr}, nil
}
}

Expand All @@ -259,11 +254,6 @@ func getNodeCSRV1Beta1(ctx context.Context, client kubernetes.Interface, nodeNam

for _, item := range csrList.Items {
item := item
if len(item.Status.Certificate) != 0 {
// If certificate is present - CSR is already approved.
continue
}

ok, err := isNodeCSR(item.Name, item.Spec.Request, nodeName)
if err != nil {
return nil, err
Expand Down

0 comments on commit b498a8f

Please sign in to comment.