Skip to content

Commit

Permalink
Merge pull request #714 from akshayjshah/ajs/details
Browse files Browse the repository at this point in the history
Include error details in protovalidate responses
  • Loading branch information
johanbrandhorst authored May 25, 2024
2 parents 7da22cf + aec8785 commit 8036513
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 18 deletions.
23 changes: 20 additions & 3 deletions interceptors/protovalidate/protovalidate.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func UnaryServerInterceptor(validator *protovalidate.Validator, opts ...Option)
break
}
if err = validator.Validate(msg); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
return nil, validationErrToStatus(err).Err()
}
default:
return nil, errors.New("unsupported message type")
Expand Down Expand Up @@ -63,12 +63,15 @@ func (w *wrappedServerStream) RecvMsg(m interface{}) error {
return err
}

msg := m.(proto.Message)
msg, ok := m.(proto.Message)
if !ok {
return errors.New("unsupported message type")
}
if w.options.shouldIgnoreMessage(msg.ProtoReflect().Type()) {
return nil
}
if err := w.validator.Validate(msg); err != nil {
return status.Error(codes.InvalidArgument, err.Error())
return validationErrToStatus(err).Err()
}

return nil
Expand All @@ -93,3 +96,17 @@ func (w *wrappedServerStream) Context() context.Context {
func wrapServerStream(stream grpc.ServerStream) *wrappedServerStream {
return &wrappedServerStream{ServerStream: stream, wrappedContext: stream.Context()}
}

func validationErrToStatus(err error) *status.Status {
// Message is invalid.
if valErr := new(protovalidate.ValidationError); errors.As(err, &valErr) {
st := status.New(codes.InvalidArgument, err.Error())
ds, detErr := st.WithDetails(valErr.ToProto())
if detErr != nil {
return st
}
return ds
}
// CEL expression doesn't compile or type-check.
return status.New(codes.Unknown, err.Error())
}
50 changes: 35 additions & 15 deletions interceptors/protovalidate/protovalidate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@ import (
"net"
"testing"

"buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
"github.com/bufbuild/protovalidate-go"
protovalidate_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testvalidate"
testvalidatev1 "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testvalidate/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)

Expand All @@ -31,36 +34,34 @@ func TestUnaryServerInterceptor(t *testing.T) {
handler := func(ctx context.Context, req any) (any, error) {
return "good", nil
}
info := &grpc.UnaryServerInfo{FullMethod: "FakeMethod"}

t.Run("valid_email", func(t *testing.T) {
info := &grpc.UnaryServerInfo{
FullMethod: "FakeMethod",
}

resp, err := interceptor(context.TODO(), testvalidate.GoodUnaryRequest, info, handler)
assert.Nil(t, err)
assert.Equal(t, resp, "good")
})

t.Run("invalid_email", func(t *testing.T) {
info := &grpc.UnaryServerInfo{
FullMethod: "FakeMethod",
}

_, err = interceptor(context.TODO(), testvalidate.BadUnaryRequest, info, handler)
assertEqualViolation(t, &validate.Violation{
FieldPath: "message",
ConstraintId: "string.email",
Message: "value must be a valid email address",
}, err)
})

t.Run("not_protobuf", func(t *testing.T) {
_, err = interceptor(context.Background(), "not protobuf", info, handler)
assert.Error(t, err)
assert.Equal(t, codes.InvalidArgument, status.Code(err))
assert.Equal(t, codes.Unknown, status.Code(err))
})

interceptor = protovalidate_middleware.UnaryServerInterceptor(validator,
protovalidate_middleware.WithIgnoreMessages(testvalidate.BadUnaryRequest.ProtoReflect().Type()),
)

t.Run("invalid_email_ignored", func(t *testing.T) {
info := &grpc.UnaryServerInfo{
FullMethod: "FakeMethod",
}

resp, err := interceptor(context.TODO(), testvalidate.BadUnaryRequest, info, handler)
assert.Nil(t, err)
assert.Equal(t, resp, "good")
Expand Down Expand Up @@ -145,8 +146,11 @@ func TestStreamServerInterceptor(t *testing.T) {
assert.Nil(t, err)

_, err = out.Recv()
assert.Error(t, err)
assert.Equal(t, codes.InvalidArgument, status.Code(err))
assertEqualViolation(t, &validate.Violation{
FieldPath: "message",
ConstraintId: "string.email",
Message: "value must be a valid email address",
}, err)
})

t.Run("invalid_email_ignored", func(t *testing.T) {
Expand All @@ -161,3 +165,19 @@ func TestStreamServerInterceptor(t *testing.T) {
assert.Nil(t, err)
})
}

func assertEqualViolation(tb testing.TB, want *validate.Violation, got error) bool {
require.Error(tb, got)
st := status.Convert(got)
assert.Equal(tb, codes.InvalidArgument, st.Code())
details := st.Proto().GetDetails()
require.Len(tb, details, 1)
gotpb, unwrapErr := details[0].UnmarshalNew()
require.Nil(tb, unwrapErr)
violations := &validate.Violations{
Violations: []*validate.Violation{want},
}
tb.Logf("got: %v", gotpb)
tb.Logf("want: %v", violations)
return assert.True(tb, proto.Equal(gotpb, violations))
}

0 comments on commit 8036513

Please sign in to comment.