mirror of https://github.com/hashicorp/boundary
Add API Error Handler (#74)
* Add error handler that converts from rpc.Status errors to Watchtower API errors. * Better testing for ErrorHandler. * Rename testcase and remove empty default switch case. * Change the error status back to int64. * Converting everything to int32s. jsonpb wraps int64s as a string which we dont like. We'll figure out how to use values larger than int64s when it comes up. * Remove special casing for TraceId which isn't needed anymore. * Removing wrappers from error details since we never need to know when they are set or unset from the end user. * Using the helper error functions inside the project service. * Correct usage of hclog, replace panic with Error log. * Adding periods to all API returned errors, correct spelling, fix missed invalid argument error not using helper function. * Change our logged errors to Error instead of Warn. * Add TODO for defining and using our own defined API error codes. * Add TODO to remove internal error messages.pull/80/head
parent
b718091419
commit
44152ae63e
@ -0,0 +1,88 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/grpc-ecosystem/grpc-gateway/runtime"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
pb "github.com/hashicorp/watchtower/internal/gen/controller/api"
|
||||
"google.golang.org/genproto/googleapis/rpc/errdetails"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func NotFoundErrorf(msg string, a ...interface{}) error {
|
||||
return status.Errorf(codes.NotFound, msg, a...)
|
||||
}
|
||||
|
||||
func InvalidArgumentErrorf(msg string, fields []string) error {
|
||||
st := status.New(codes.InvalidArgument, msg)
|
||||
br := &errdetails.BadRequest{}
|
||||
for _, f := range fields {
|
||||
br.FieldViolations = append(br.FieldViolations, &errdetails.BadRequest_FieldViolation{Field: f})
|
||||
}
|
||||
st, err := st.WithDetails(br)
|
||||
if err != nil {
|
||||
hclog.Default().Error("failure building status with details", "details", br, "error", err)
|
||||
return status.Error(codes.Internal, "Failed to build InvalidArgument error.")
|
||||
}
|
||||
return st.Err()
|
||||
}
|
||||
|
||||
func statusErrorToApiError(s *status.Status) *pb.Error {
|
||||
apiErr := &pb.Error{}
|
||||
apiErr.Status = int32(runtime.HTTPStatusFromCode(s.Code()))
|
||||
apiErr.Message = s.Message()
|
||||
// TODO(ICU-193): Decouple from the status codes and instead use codes defined specifically for our API.
|
||||
apiErr.Code = s.Code().String()
|
||||
|
||||
for _, ed := range s.Details() {
|
||||
switch ed.(type) {
|
||||
case *errdetails.BadRequest:
|
||||
br := ed.(*errdetails.BadRequest)
|
||||
for _, fv := range br.GetFieldViolations() {
|
||||
if apiErr.Details == nil {
|
||||
apiErr.Details = &pb.ErrorDetails{}
|
||||
}
|
||||
apiErr.Details.RequestFields = append(apiErr.Details.RequestFields, fv.GetField())
|
||||
}
|
||||
}
|
||||
}
|
||||
return apiErr
|
||||
}
|
||||
|
||||
// TODO(ICU-194): Remove all information from internal errors.
|
||||
func ErrorHandler(logger hclog.Logger) runtime.ProtoErrorHandlerFunc {
|
||||
const errorFallback = `{"error": "failed to marshal error message"}`
|
||||
return func(ctx context.Context, _ *runtime.ServeMux, marshaler runtime.Marshaler, w http.ResponseWriter, r *http.Request, inErr error) {
|
||||
if inErr == runtime.ErrUnknownURI {
|
||||
// grpc gateway uses this error when the path was not matched, but the error uses codes.Unimplemented which doesn't match the intention.
|
||||
// Overwrite the error to match our expected behavior.
|
||||
inErr = status.Error(codes.NotFound, http.StatusText(http.StatusNotFound))
|
||||
}
|
||||
s, ok := status.FromError(inErr)
|
||||
if !ok {
|
||||
s = status.New(codes.Unknown, inErr.Error())
|
||||
}
|
||||
apiErr := statusErrorToApiError(s)
|
||||
buf, merr := marshaler.Marshal(apiErr)
|
||||
if merr != nil {
|
||||
logger.Error("failed to marshal error response", "response", fmt.Sprintf("%#v", apiErr), "error", merr)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
if _, err := io.WriteString(w, errorFallback); err != nil {
|
||||
logger.Error("failed to write response", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", marshaler.ContentType())
|
||||
w.WriteHeader(int(apiErr.GetStatus()))
|
||||
if _, err := w.Write(buf); err != nil {
|
||||
logger.Error("failed to send response chunk", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,101 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/grpc-ecosystem/grpc-gateway/runtime"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
pb "github.com/hashicorp/watchtower/internal/gen/controller/api"
|
||||
"github.com/stretchr/testify/assert"
|
||||
sdpb "google.golang.org/genproto/googleapis/rpc/errdetails"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestApiErrorHandler(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
req, err := http.NewRequest("GET", "madeup/for/the/test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Couldn't create test request")
|
||||
}
|
||||
mux := runtime.NewServeMux()
|
||||
_, outMarsh := runtime.MarshalerForRequest(mux, req)
|
||||
|
||||
tested := ErrorHandler(hclog.L())
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
err error
|
||||
statusDetails []proto.Message
|
||||
expected *pb.Error
|
||||
}{
|
||||
{
|
||||
name: "Not Found",
|
||||
err: status.Error(codes.NotFound, "test"),
|
||||
expected: &pb.Error{
|
||||
Status: 404,
|
||||
Code: codes.NotFound.String(),
|
||||
Message: "test",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "GrpcGateway Routing Error",
|
||||
err: runtime.ErrUnknownURI,
|
||||
expected: &pb.Error{
|
||||
Status: 404,
|
||||
Code: codes.NotFound.String(),
|
||||
Message: http.StatusText(http.StatusNotFound),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid Fields",
|
||||
err: status.Error(codes.InvalidArgument, "test"),
|
||||
statusDetails: []proto.Message{
|
||||
&sdpb.BadRequest{
|
||||
FieldViolations: []*sdpb.BadRequest_FieldViolation{
|
||||
{Field: "first"},
|
||||
{Field: "second"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &pb.Error{
|
||||
Status: 400,
|
||||
Code: codes.InvalidArgument.String(),
|
||||
Message: "test",
|
||||
Details: &pb.ErrorDetails{
|
||||
RequestFields: []string{"first", "second"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
if tc.statusDetails != nil {
|
||||
s, ok := status.FromError(tc.err)
|
||||
assert.True(ok)
|
||||
s, err := s.WithDetails(tc.statusDetails...)
|
||||
assert.NoError(err)
|
||||
tc.err = s.Err()
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
tested(ctx, mux, outMarsh, w, req, tc.err)
|
||||
resp := w.Result()
|
||||
assert.EqualValues(tc.expected.Status, resp.StatusCode)
|
||||
|
||||
got, err := ioutil.ReadAll(resp.Body)
|
||||
assert.NoError(err)
|
||||
want, err := outMarsh.Marshal(tc.expected)
|
||||
t.Logf("Got marshaled error: %q", want)
|
||||
assert.NoError(err)
|
||||
assert.JSONEq(string(want), string(got))
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in new issue