From 01f99486c842ffa3e1caf4d234ae0e52c7d933f6 Mon Sep 17 00:00:00 2001 From: Akshay Shah Date: Fri, 24 May 2024 20:46:02 -0700 Subject: [PATCH 1/3] protovalidate: refactor repetitive subtests The info struct is shared between all these subtests - there's no need to clutter up each test by redefining it. Signed-off-by: Akshay Shah --- interceptors/protovalidate/protovalidate_test.go | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/interceptors/protovalidate/protovalidate_test.go b/interceptors/protovalidate/protovalidate_test.go index f626f395c..0532e9d76 100644 --- a/interceptors/protovalidate/protovalidate_test.go +++ b/interceptors/protovalidate/protovalidate_test.go @@ -31,22 +31,15 @@ func TestUnaryServerInterceptor(t *testing.T) { handler := func(ctx context.Context, req any) (any, error) { return "good", nil } + info := &grpc.UnaryServerInfo{FullMethod: "FakeMethod"} t.Run("valid_email", func(t *testing.T) { - info := &grpc.UnaryServerInfo{ - FullMethod: "FakeMethod", - } - resp, err := interceptor(context.TODO(), testvalidate.GoodUnaryRequest, info, handler) assert.Nil(t, err) assert.Equal(t, resp, "good") }) t.Run("invalid_email", func(t *testing.T) { - info := &grpc.UnaryServerInfo{ - FullMethod: "FakeMethod", - } - _, err = interceptor(context.TODO(), testvalidate.BadUnaryRequest, info, handler) assert.Error(t, err) assert.Equal(t, codes.InvalidArgument, status.Code(err)) @@ -57,10 +50,6 @@ func TestUnaryServerInterceptor(t *testing.T) { ) t.Run("invalid_email_ignored", func(t *testing.T) { - info := &grpc.UnaryServerInfo{ - FullMethod: "FakeMethod", - } - resp, err := interceptor(context.TODO(), testvalidate.BadUnaryRequest, info, handler) assert.Nil(t, err) assert.Equal(t, resp, "good") From 6e75075a0e9f5b171b33d84624153395c0681567 Mon Sep 17 00:00:00 2001 From: Akshay Shah Date: Fri, 24 May 2024 20:48:12 -0700 Subject: [PATCH 2/3] protovalidate: don't panic in streaming interceptor The streaming interceptor should match the behavior of the unary interceptor and gracefully handle non-protobuf messages. Signed-off-by: Akshay Shah --- interceptors/protovalidate/protovalidate.go | 5 ++++- interceptors/protovalidate/protovalidate_test.go | 6 ++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/interceptors/protovalidate/protovalidate.go b/interceptors/protovalidate/protovalidate.go index cf337db11..86c14fa9e 100644 --- a/interceptors/protovalidate/protovalidate.go +++ b/interceptors/protovalidate/protovalidate.go @@ -63,7 +63,10 @@ func (w *wrappedServerStream) RecvMsg(m interface{}) error { return err } - msg := m.(proto.Message) + msg, ok := m.(proto.Message) + if !ok { + return errors.New("unsupported message type") + } if w.options.shouldIgnoreMessage(msg.ProtoReflect().Type()) { return nil } diff --git a/interceptors/protovalidate/protovalidate_test.go b/interceptors/protovalidate/protovalidate_test.go index 0532e9d76..5ed6c6812 100644 --- a/interceptors/protovalidate/protovalidate_test.go +++ b/interceptors/protovalidate/protovalidate_test.go @@ -45,6 +45,12 @@ func TestUnaryServerInterceptor(t *testing.T) { assert.Equal(t, codes.InvalidArgument, status.Code(err)) }) + t.Run("not_protobuf", func(t *testing.T) { + _, err = interceptor(context.Background(), "not protobuf", info, handler) + assert.Error(t, err) + assert.Equal(t, codes.Unknown, status.Code(err)) + }) + interceptor = protovalidate_middleware.UnaryServerInterceptor(validator, protovalidate_middleware.WithIgnoreMessages(testvalidate.BadUnaryRequest.ProtoReflect().Type()), ) From aec878525ba2576b998a4fb08330d3d9dca52fcc Mon Sep 17 00:00:00 2001 From: Akshay Shah Date: Fri, 24 May 2024 21:17:34 -0700 Subject: [PATCH 3/3] protovalidate: send violations as error details Amend the unary and streaming interceptors to send validation errors to the client as an error detail. This allows client code to easily parse and work with the structured validation information: for example, a UI might want to display validation errors next to the relevant fields in a form. Signed-off-by: Akshay Shah --- interceptors/protovalidate/protovalidate.go | 18 ++++++++-- .../protovalidate/protovalidate_test.go | 33 ++++++++++++++++--- 2 files changed, 45 insertions(+), 6 deletions(-) diff --git a/interceptors/protovalidate/protovalidate.go b/interceptors/protovalidate/protovalidate.go index 86c14fa9e..fac113ba2 100644 --- a/interceptors/protovalidate/protovalidate.go +++ b/interceptors/protovalidate/protovalidate.go @@ -29,7 +29,7 @@ func UnaryServerInterceptor(validator *protovalidate.Validator, opts ...Option) break } if err = validator.Validate(msg); err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, validationErrToStatus(err).Err() } default: return nil, errors.New("unsupported message type") @@ -71,7 +71,7 @@ func (w *wrappedServerStream) RecvMsg(m interface{}) error { return nil } if err := w.validator.Validate(msg); err != nil { - return status.Error(codes.InvalidArgument, err.Error()) + return validationErrToStatus(err).Err() } return nil @@ -96,3 +96,17 @@ func (w *wrappedServerStream) Context() context.Context { func wrapServerStream(stream grpc.ServerStream) *wrappedServerStream { return &wrappedServerStream{ServerStream: stream, wrappedContext: stream.Context()} } + +func validationErrToStatus(err error) *status.Status { + // Message is invalid. + if valErr := new(protovalidate.ValidationError); errors.As(err, &valErr) { + st := status.New(codes.InvalidArgument, err.Error()) + ds, detErr := st.WithDetails(valErr.ToProto()) + if detErr != nil { + return st + } + return ds + } + // CEL expression doesn't compile or type-check. + return status.New(codes.Unknown, err.Error()) +} diff --git a/interceptors/protovalidate/protovalidate_test.go b/interceptors/protovalidate/protovalidate_test.go index 5ed6c6812..052d2395b 100644 --- a/interceptors/protovalidate/protovalidate_test.go +++ b/interceptors/protovalidate/protovalidate_test.go @@ -9,16 +9,19 @@ import ( "net" "testing" + "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate" "github.com/bufbuild/protovalidate-go" protovalidate_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate" "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testvalidate" testvalidatev1 "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testvalidate/v1" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" "google.golang.org/grpc/test/bufconn" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" ) @@ -41,8 +44,11 @@ func TestUnaryServerInterceptor(t *testing.T) { t.Run("invalid_email", func(t *testing.T) { _, err = interceptor(context.TODO(), testvalidate.BadUnaryRequest, info, handler) - assert.Error(t, err) - assert.Equal(t, codes.InvalidArgument, status.Code(err)) + assertEqualViolation(t, &validate.Violation{ + FieldPath: "message", + ConstraintId: "string.email", + Message: "value must be a valid email address", + }, err) }) t.Run("not_protobuf", func(t *testing.T) { @@ -140,8 +146,11 @@ func TestStreamServerInterceptor(t *testing.T) { assert.Nil(t, err) _, err = out.Recv() - assert.Error(t, err) - assert.Equal(t, codes.InvalidArgument, status.Code(err)) + assertEqualViolation(t, &validate.Violation{ + FieldPath: "message", + ConstraintId: "string.email", + Message: "value must be a valid email address", + }, err) }) t.Run("invalid_email_ignored", func(t *testing.T) { @@ -156,3 +165,19 @@ func TestStreamServerInterceptor(t *testing.T) { assert.Nil(t, err) }) } + +func assertEqualViolation(tb testing.TB, want *validate.Violation, got error) bool { + require.Error(tb, got) + st := status.Convert(got) + assert.Equal(tb, codes.InvalidArgument, st.Code()) + details := st.Proto().GetDetails() + require.Len(tb, details, 1) + gotpb, unwrapErr := details[0].UnmarshalNew() + require.Nil(tb, unwrapErr) + violations := &validate.Violations{ + Violations: []*validate.Violation{want}, + } + tb.Logf("got: %v", gotpb) + tb.Logf("want: %v", violations) + return assert.True(tb, proto.Equal(gotpb, violations)) +}