diff --git a/internal/cmd/commands/controller/dev_flags.go b/internal/cmd/commands/controller/dev_flags.go index 5c43524fb3..2723bb3a5e 100644 --- a/internal/cmd/commands/controller/dev_flags.go +++ b/internal/cmd/commands/controller/dev_flags.go @@ -13,6 +13,6 @@ func addDevOnlyControllerFlags(c *Command, f *base.FlagSet) { Name: "dev-passthrough-directory", Target: &c.flagDevPassthroughDirectory, EnvVar: "WATCHTOWER_DEV_PASSTHROUGH_DIRECTORY", - Usage: "Enables a passthrough directory in the webserver at /passthrough", + Usage: "Enables a passthrough directory in the webserver at /", }) } diff --git a/internal/cmd/commands/dev/dev_flags.go b/internal/cmd/commands/dev/dev_flags.go index 9cb3d30215..0cb86cb2fa 100644 --- a/internal/cmd/commands/dev/dev_flags.go +++ b/internal/cmd/commands/dev/dev_flags.go @@ -13,6 +13,6 @@ func addDevOnlyControllerFlags(c *Command, f *base.FlagSet) { Name: "dev-passthrough-directory", Target: &c.flagDevPassthroughDirectory, EnvVar: "WATCHTOWER_DEV_PASSTHROUGH_DIRECTORY", - Usage: "Enables a passthrough directory in the webserver at /passthrough", + Usage: "Enables a passthrough directory in the webserver at /", }) } diff --git a/internal/servers/controller/handler.go b/internal/servers/controller/handler.go index eedf9a9f3d..7a5e73a9bf 100644 --- a/internal/servers/controller/handler.go +++ b/internal/servers/controller/handler.go @@ -1,10 +1,13 @@ package controller import ( + "bytes" "context" "encoding/json" "fmt" + "io/ioutil" "net/http" + "os" "path" "path/filepath" "strings" @@ -37,32 +40,135 @@ func (c *Controller) handler(props HandlerProperties) (http.Handler, error) { // Create the muxer to handle the actual endpoints mux := http.NewServeMux() - if c.conf.RawConfig.PassthroughDirectory != "" { - // Panic may not be ideal but this is never a production call and it'll - // panic on startup. We could also just change the function to return - // an error. - abs, err := filepath.Abs(c.conf.RawConfig.PassthroughDirectory) - if err != nil { - panic(err) - } - c.logger.Warn("serving passthrough files at /", "path", abs) - fs := http.FileServer(http.Dir(abs)) - prefixHandler := http.StripPrefix("/", fs) - mux.Handle("/", prefixHandler) - } - h, err := handleGrpcGateway(c) if err != nil { return nil, err } mux.Handle("/v1/", h) + // TODO: enable when not in this mode, when we bundle the assets + if c.conf.RawConfig.PassthroughDirectory != "" { + mux.Handle("/", handleUi(c)) + } + corsWrappedHandler := wrapHandlerWithCors(mux, props) commonWrappedHandler := wrapHandlerWithCommonFuncs(corsWrappedHandler, c, props) return commonWrappedHandler, nil } +func handleUi(c *Controller) http.Handler { + // TODO: Do stuff with real UI data when it's bundled. We may also have to + // do a similar thing with fetching index.html in advance. + var nextHandler http.Handler + var indexBytes []byte + var modTime time.Time + if c.conf.RawConfig.PassthroughDirectory != "" { + nextHandler, indexBytes, modTime = devPassthroughHandler(c) + } + + returnIndexBytes := func(w http.ResponseWriter, r *http.Request) { + _, file := filepath.Split(r.URL.Path) + rw := newIndexResponseWriter() + http.ServeContent(rw, r, file, modTime, bytes.NewReader(indexBytes)) + for k, v := range rw.header { + for _, i := range v { + w.Header().Add(k, i) + } + } + w.Header().Set("content-type", "text/html; charset=utf-8") + w.WriteHeader(rw.statusCode) + w.Write(rw.body.Bytes()) + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + dotIndex := strings.LastIndex(r.URL.Path, ".") + switch dotIndex { + case -1: + // For all paths without an extension serve /index.html + returnIndexBytes(w, r) + return + + default: + switch r.URL.Path { + case "/index.html": + // Because of the special handling of http.FileServer this fails + // in dev passthrough mode so we handle it specifically + returnIndexBytes(w, r) + return + + case "/favicon.png", "/assets/styles.css": + // This is purely an optimization, it'd fall through below + // outside of this case + nextHandler.ServeHTTP(w, r) + return + + default: + for i := dotIndex + 1; i < len(r.URL.Path); i++ { + intVal := r.URL.Path[i] + // Current guidance from FE is if it's only alphanum after + // the last dot, treat it as an extension + if intVal < '0' || + (intVal > '9' && intVal < 'A') || + (intVal > 'Z' && intVal < 'a') || + intVal > 'z' { + // Not an extension. Serve the contents of index.html + returnIndexBytes(w, r) + return + } + } + } + } + + // Fall through to the next handler + nextHandler.ServeHTTP(w, r) + }) +} + +func devPassthroughHandler(c *Controller) (http.Handler, []byte, time.Time) { + // Panic may not be ideal but this is never a production call and it'll + // panic on startup. We could also just change the function to return + // an error. + abs, err := filepath.Abs(c.conf.RawConfig.PassthroughDirectory) + if err != nil { + panic(err) + } + c.logger.Warn("serving passthrough files at /", "path", abs) + fs := http.FileServer(http.Dir(abs)) + prefixHandler := http.StripPrefix("/", fs) + + // We need to read index.html because http.ServeFile has special handling + // for that file that we don't want + file, err := os.Open(filepath.Join(abs, "index.html")) + if err != nil { + c.logger.Warn("unable to open index.html in the dev passthrough directory, if it exists") + return prefixHandler, nil, time.Time{} + } + defer file.Close() + + fileInfo, err := file.Stat() + if err != nil { + c.logger.Warn("unable to stat index.html in the dev passthrough directory, if it exists") + return prefixHandler, nil, time.Time{} + } + modTime := fileInfo.ModTime() + + // Easier to just do an ioutil.ReadAll than deal with the lower level read + // methods, even though we're opening twice + indexBytes, err := ioutil.ReadFile(filepath.Join(abs, "index.html")) + if err != nil { + c.logger.Warn("unable to read index.html bytes in the dev passthrough directory, if it exists") + return prefixHandler, nil, time.Time{} + } + + return prefixHandler, indexBytes, modTime +} + func handleGrpcGateway(c *Controller) (http.Handler, error) { // Register*ServiceHandlerServer methods ignore the passed in ctx. Using the baseContext now just in case this changes // in the future, at which point we'll want to be using the baseContext. diff --git a/internal/servers/controller/handler_test.go b/internal/servers/controller/handler_test.go index 3fea983aa9..08fcedbc66 100644 --- a/internal/servers/controller/handler_test.go +++ b/internal/servers/controller/handler_test.go @@ -1,9 +1,16 @@ package controller import ( + "bytes" "fmt" + "io/ioutil" "net/http" + "os" + "path/filepath" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestHandleGrpcGateway(t *testing.T) { @@ -39,3 +46,114 @@ func TestHandleGrpcGateway(t *testing.T) { }) } } + +func TestHandleDevPassthrough(t *testing.T) { + // Create a temporary directory + tempDir, err := ioutil.TempDir("", "watchtower-test-") + require.NoError(t, err) + defer func() { + assert.NoError(t, os.RemoveAll(tempDir)) + }() + + nameContentsMap := map[string]string{ + "index.html": `index`, + "favicon.png": `favicon`, + "/assets/styles.css": `css`, + "index.htm": `badindex`, + } + + for k, v := range nameContentsMap { + dir := filepath.Dir(k) + if dir != "/" { + require.NoError(t, os.MkdirAll(filepath.Join(tempDir, dir), 0755)) + } + require.NoError(t, ioutil.WriteFile(filepath.Join(tempDir, k), []byte(v), 0644)) + } + + c := NewTestController(t, &TestControllerOpts{DisableAutoStart: true}) + + c.c.conf.RawConfig.PassthroughDirectory = tempDir + require.NoError(t, c.c.Start()) + defer c.Shutdown() + + cases := []struct { + name string + path string + contentsKey string + code int + mimeType string + }{ + { + "direct index", + "index.html", + "index.html", + http.StatusOK, + "text/html; charset=utf-8", + }, + { + "no extension", + "orgs", + "index.html", + http.StatusOK, + "text/html; charset=utf-8", + }, + { + "favicon", + "favicon.png", + "favicon.png", + http.StatusOK, + "image/png", + }, + { + "bad index", + "index.htm", + "index.htm", + http.StatusOK, + "text/html; charset=utf-8", + }, + { + "bad path", + "index.ht", + "index.ht", + http.StatusNotFound, + "text/plain; charset=utf-8", + }, + { + "css", + "assets/styles.css", + "assets/styles.css", + http.StatusOK, + "text/css; charset=utf-8", + }, + { + "invalid extension", + "foo.bāb", + "index.html", + http.StatusOK, + "text/html; charset=utf-8", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + assert := assert.New(t) + + url := fmt.Sprintf("%s/%s", c.ApiAddrs()[0], tc.path) + resp, err := http.Post(url, "", nil) + assert.NoError(err) + assert.Equal(http.StatusMethodNotAllowed, resp.StatusCode) + + resp, err = http.Get(url) + assert.NoError(err) + assert.Equal(tc.code, resp.StatusCode) + assert.Equal(tc.mimeType, resp.Header.Get("content-type")) + + contents, ok := nameContentsMap[tc.contentsKey] + if ok { + reader := new(bytes.Buffer) + _, err = reader.ReadFrom(resp.Body) + assert.NoError(err) + assert.Equal(contents, reader.String()) + } + }) + } +} diff --git a/internal/servers/controller/index_response_writer.go b/internal/servers/controller/index_response_writer.go new file mode 100644 index 0000000000..3aee140f3e --- /dev/null +++ b/internal/servers/controller/index_response_writer.go @@ -0,0 +1,32 @@ +package controller + +import ( + "bytes" + "net/http" +) + +type indexResponseWriter struct { + statusCode int + header http.Header + body *bytes.Buffer +} + +// newindexResponseWriter returns an initialized indexResponseWriter +func newIndexResponseWriter() *indexResponseWriter { + return &indexResponseWriter{ + header: make(http.Header), + body: new(bytes.Buffer), + } +} + +func (w *indexResponseWriter) Header() http.Header { + return w.header +} + +func (w *indexResponseWriter) Write(buf []byte) (int, error) { + return w.body.Write(buf) +} + +func (w *indexResponseWriter) WriteHeader(code int) { + w.statusCode = code +} diff --git a/internal/servers/controller/testing.go b/internal/servers/controller/testing.go index f4907d983c..194071caa1 100644 --- a/internal/servers/controller/testing.go +++ b/internal/servers/controller/testing.go @@ -117,6 +117,9 @@ type TestControllerOpts struct { // DisableDatabaseCreation can be set true to disable creating a dev // database DisableDatabaseCreation bool + + // If true, the controller will not be started + DisableAutoStart bool } func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController { @@ -192,9 +195,11 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController { tc.buildClient() - if err := tc.c.Start(); err != nil { - tc.Shutdown() - t.Fatal(err) + if !opts.DisableAutoStart { + if err := tc.c.Start(); err != nil { + tc.Shutdown() + t.Fatal(err) + } } return tc