diff --git a/actions/approve_csr_handler.go b/actions/approve_csr_handler.go index b8e3368..3365e4b 100644 --- a/actions/approve_csr_handler.go +++ b/actions/approve_csr_handler.go @@ -19,6 +19,7 @@ func newApproveCSRHandler(log logrus.FieldLogger, clientset kubernetes.Interface log: log, clientset: clientset, initialCSRFetchTimeout: 5 * time.Minute, + csrFetchInterval: 5 * time.Second, } } @@ -26,6 +27,7 @@ type approveCSRHandler struct { log logrus.FieldLogger clientset kubernetes.Interface initialCSRFetchTimeout time.Duration + csrFetchInterval time.Duration } func (h *approveCSRHandler) Handle(ctx context.Context, data interface{}) error { @@ -41,6 +43,11 @@ func (h *approveCSRHandler) Handle(ctx context.Context, data interface{}) error return fmt.Errorf("getting initial csr: %w", err) } + if cert.Approved() { + log.Debug("csr is already approved") + return nil + } + b := backoff.WithContext( newApproveCSRExponentialBackoff(), ctx, @@ -54,12 +61,7 @@ func (h *approveCSRHandler) Handle(ctx context.Context, data interface{}) error }) } -func (h *approveCSRHandler) handle(ctx context.Context, log logrus.FieldLogger, cert *csr.Certificate) error { - if cert.Approved() { - log.Debug("initial csr is already approved") - return nil - } - +func (h *approveCSRHandler) handle(ctx context.Context, log logrus.FieldLogger, cert *csr.Certificate) (reterr error) { // Since this new csr may be denied we need to delete it. log.Debug("deleting old csr") if err := csr.DeleteCertificate(ctx, h.clientset, cert); err != nil { @@ -68,7 +70,7 @@ func (h *approveCSRHandler) handle(ctx context.Context, log logrus.FieldLogger, // Create new csr with the same request data as original csr. log.Debug("requesting new csr") - cert, err := csr.RequestCertificate( + newCert, err := csr.RequestCertificate( ctx, h.clientset, cert, @@ -79,7 +81,7 @@ func (h *approveCSRHandler) handle(ctx context.Context, log logrus.FieldLogger, // Approve new csr. log.Debug("approving new csr") - resp, err := csr.ApproveCertificate(ctx, h.clientset, cert) + resp, err := csr.ApproveCertificate(ctx, h.clientset, newCert) if err != nil { return fmt.Errorf("approving csr: %w", err) } @@ -89,28 +91,45 @@ func (h *approveCSRHandler) handle(ctx context.Context, log logrus.FieldLogger, return errors.New("certificate signing request was not approved") } -func (h *approveCSRHandler) getInitialNodeCSR(ctx context.Context, log *logrus.Entry, nodeName string) (*csr.Certificate, error) { +func (h *approveCSRHandler) getInitialNodeCSR(ctx context.Context, log logrus.FieldLogger, nodeName string) (*csr.Certificate, error) { log.Debug("getting initial csr") ctx, cancel := context.WithTimeout(ctx, h.initialCSRFetchTimeout) defer cancel() + poll := func() (*csr.Certificate, error) { + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(h.csrFetchInterval): + cert, err := csr.GetCertificateByNodeName(ctx, h.clientset, nodeName) + if err != nil && !errors.Is(err, csr.ErrNodeCertificateNotFound) { + return nil, err + } + if cert != nil { + return cert, nil + } + } + } + } + var cert *csr.Certificate var err error - b := backoff.WithContext(backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 3), ctx) - err = backoff.Retry(func() error { - cert, err = csr.GetCertificateByNodeName(ctx, h.clientset, nodeName) + logRetry := func(err error, _ time.Duration) { + log.Warnf("getting initial csr, will retry: %v", err) + } + b := backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 3) + err = backoff.RetryNotify(func() error { + cert, err = poll() if errors.Is(err, context.DeadlineExceeded) { return backoff.Permanent(err) } return err - }, b) - if err != nil { - return nil, err - } + }, b, logRetry) - return cert, nil + return cert, err } func newApproveCSRExponentialBackoff() *backoff.ExponentialBackOff { diff --git a/actions/approve_csr_handler_test.go b/actions/approve_csr_handler_test.go index 49baad5..06f4cb2 100644 --- a/actions/approve_csr_handler_test.go +++ b/actions/approve_csr_handler_test.go @@ -34,8 +34,6 @@ func TestApproveCSRHandler(t *testing.T) { csrRes := getCSR() client := fake.NewSimpleClientset(csrRes) - watcher := watch.NewFake() - client.PrependWatchReactor("certificatesigningrequests", ktest.DefaultWatchReactor(watcher, nil)) var approveCalls int32 client.PrependReactor("update", "certificatesigningrequests", func(action ktest.Action) (handled bool, ret runtime.Object, err error) { @@ -61,12 +59,37 @@ func TestApproveCSRHandler(t *testing.T) { h := &approveCSRHandler{ log: log, clientset: client, + csrFetchInterval: 1 * time.Millisecond, + initialCSRFetchTimeout: 10 * time.Millisecond, + } + + ctx := context.Background() + err := h.Handle(ctx, &castai.ActionApproveCSR{NodeName: "gke-am-gcp-cast-5dc4f4ec"}) + r.NoError(err) + }) + + t.Run("return if csr is already approved", func(t *testing.T) { + r := require.New(t) + + csrRes := getCSR() + csrRes.Status.Conditions = []certv1.CertificateSigningRequestCondition{ + { + Type: certv1.CertificateApproved, + Reason: csr.ReasonApproved, + Message: "approved", + LastUpdateTime: metav1.Now(), + Status: v1.ConditionTrue, + }, + } + client := fake.NewSimpleClientset(csrRes) + + h := &approveCSRHandler{ + log: log, + clientset: client, + csrFetchInterval: 1 * time.Millisecond, initialCSRFetchTimeout: 10 * time.Millisecond, } - go func() { - watcher.Add(csrRes) - }() ctx := context.Background() err := h.Handle(ctx, &castai.ActionApproveCSR{NodeName: "gke-am-gcp-cast-5dc4f4ec"}) r.NoError(err) @@ -76,32 +99,28 @@ func TestApproveCSRHandler(t *testing.T) { r := require.New(t) csrRes := getCSR() - watcher := watch.NewFake() count := 0 - fn := ktest.WatchReactionFunc(func(action ktest.Action) (handled bool, ret watch.Interface, err error) { + fn := ktest.ReactionFunc(func(action ktest.Action) (handled bool, ret runtime.Object, err error) { if count == 0 { count++ - return true, watcher, errors.New("api server timeout") + return true, nil, errors.New("api server timeout") } - return true, watcher, err + out := certv1.CertificateSigningRequestList{Items: []certv1.CertificateSigningRequest{*csrRes}} + return true, &out, err }) - client := fake.NewSimpleClientset(csrRes) - client.PrependWatchReactor("certificatesigningrequests", fn) + client.PrependReactor("list", "certificatesigningrequests", fn) h := &approveCSRHandler{ log: log, clientset: client, + csrFetchInterval: 100 * time.Millisecond, initialCSRFetchTimeout: 1000 * time.Millisecond, } - go func() { - watcher.Add(csrRes) - }() ctx := context.Background() err := h.Handle(ctx, &castai.ActionApproveCSR{NodeName: "gke-am-gcp-cast-5dc4f4ec"}) r.NoError(err) - }) t.Run("approve v1beta1 csr successfully", func(t *testing.T) { @@ -124,9 +143,14 @@ AiAHVYZXHxxspoV0hcfn2Pdsl89fIPCOFy/K1PqSUR6QNAIgYdt51ZbQt9rgM2BD }, } client := fake.NewSimpleClientset(csrRes) - notFoundErr := apierrors.NewNotFound(schema.GroupResource{}, "csr") - client.PrependWatchReactor("certificatesigningrequests", ktest.DefaultWatchReactor(nil, notFoundErr)) - + // Return NotFound for all v1 resources. + client.PrependReactor("*", "*", func(action ktest.Action) (handled bool, ret runtime.Object, err error) { + if action.GetResource().Version == "v1" { + err = apierrors.NewNotFound(schema.GroupResource{}, action.GetResource().String()) + return true, nil, err + } + return + }) client.PrependReactor("update", "certificatesigningrequests", func(action ktest.Action) (handled bool, ret runtime.Object, err error) { approved := csrRes.DeepCopy() approved.Status.Conditions = []certv1beta1.CertificateSigningRequestCondition{ @@ -144,6 +168,7 @@ AiAHVYZXHxxspoV0hcfn2Pdsl89fIPCOFy/K1PqSUR6QNAIgYdt51ZbQt9rgM2BD h := &approveCSRHandler{ log: log, clientset: client, + csrFetchInterval: 1 * time.Millisecond, initialCSRFetchTimeout: 10 * time.Millisecond, } @@ -159,6 +184,7 @@ AiAHVYZXHxxspoV0hcfn2Pdsl89fIPCOFy/K1PqSUR6QNAIgYdt51ZbQt9rgM2BD h := &approveCSRHandler{ log: log, clientset: client, + csrFetchInterval: 1 * time.Millisecond, initialCSRFetchTimeout: 10 * time.Millisecond, } diff --git a/csr/csr.go b/csr/csr.go index b62f56d..df22828 100644 --- a/csr/csr.go +++ b/csr/csr.go @@ -212,29 +212,24 @@ func getNodeCSRV1(ctx context.Context, client kubernetes.Interface, nodeName str }).String(), } - watch, err := client.CertificatesV1().CertificateSigningRequests().Watch(ctx, options) + csrList, err := client.CertificatesV1().CertificateSigningRequests().List(ctx, options) if err != nil { return nil, err } - defer watch.Stop() - for r := range watch.ResultChan() { - csr, ok := r.Object.(*certv1.CertificateSigningRequest) - if !ok { - continue - } - - if len(csr.Status.Certificate) != 0 { - // If certificate is present - CSR is already approved. - continue - } + // Sort by newest first soo we don't need to parse old items. + sort.Slice(csrList.Items, func(i, j int) bool { + return csrList.Items[i].CreationTimestamp.After(csrList.Items[j].CreationTimestamp.Time) + }) + for _, csr := range csrList.Items { + csr := csr found, err := isNodeCSR(csr.Name, csr.Spec.Request, nodeName) if err != nil { return nil, err } if found { - return &Certificate{V1: csr}, nil + return &Certificate{V1: &csr}, nil } } @@ -259,11 +254,6 @@ func getNodeCSRV1Beta1(ctx context.Context, client kubernetes.Interface, nodeNam for _, item := range csrList.Items { item := item - if len(item.Status.Certificate) != 0 { - // If certificate is present - CSR is already approved. - continue - } - ok, err := isNodeCSR(item.Name, item.Spec.Request, nodeName) if err != nil { return nil, err