diff --git a/api/client.go b/api/client.go index f23edbc556..5654256ccb 100644 --- a/api/client.go +++ b/api/client.go @@ -228,35 +228,41 @@ func (c *Config) ConfigureTLS() error { // // This also removes any trailing "/v1"; we'll use that in our commands so we // don't require it from users. -func (c *Config) setAddress(addr string) { - defer func() { - c.Address = strings.TrimSuffix(c.Address, "/") - c.Address = strings.TrimSuffix(c.Address, "/v1") - }() - - split := strings.Split(strings.TrimSuffix("/", addr), "/") - if len(split) < 5 { - return - } - - // ..../org//project/ - if split[len(split)-4] == "org" { - c.Org = split[len(split)-3] - if split[len(split)-2] == "project" { - c.Project = split[len(split)-1] - } - c.Address = strings.Join(split[0:len(split)-4], "/") - return +func (c *Config) setAddress(addr string) error { + u, err := url.Parse(addr) + if err != nil { + return fmt.Errorf("error parsing address: %w", err) } + c.Address = fmt.Sprintf("%s://%s", u.Scheme, u.Host) - // ..../org//project/ - if split[len(split)-2] == "org" { - c.Org = split[len(split)-1] - c.Address = strings.Join(split[0:len(split)-2], "/") - return + path := strings.TrimPrefix(u.Path, "/v1") + path = strings.TrimPrefix(path, "/") + if path == "" { + return nil } - return + split := strings.Split(path, "/") + switch len(split) { + case 0: + case 2: + if split[0] != "org" { + return fmt.Errorf("expected org segment in address, found %q", split[0]) + } + c.Org = split[1] + case 4: + if split[0] != "org" { + return fmt.Errorf("expected org segment in address, found %q", split[0]) + } + c.Org = split[1] + if split[2] != "project" { + return fmt.Errorf("expected project segment in address, found %q", split[2]) + } + c.Project = split[3] + default: + return fmt.Errorf("unexpected number of segments in address") + } + + return nil } // ReadEnvironment reads configuration information from the environment. If @@ -425,7 +431,9 @@ func NewClient(c *Config) (*Client, error) { } } - c.setAddress(c.Address) + if err := c.setAddress(c.Address); err != nil { + return nil, err + } return &Client{ config: c, @@ -435,11 +443,11 @@ func NewClient(c *Config) (*Client, error) { // Sets the address of Watchtower in the client. The format of address should // be "://:". Setting this on a client will override the // value of the WATCHTOWER_ADDR environment variable. -func (c *Client) SetAddress(addr string) { +func (c *Client) SetAddress(addr string) error { c.modifyLock.Lock() defer c.modifyLock.Unlock() - c.config.setAddress(addr) + return c.config.setAddress(addr) } // SetLimiter will set the rate limiter for this client. This method is @@ -638,7 +646,7 @@ func (c *Client) NewRequest(ctx context.Context, method, requestPath string, bod User: u.User, Scheme: u.Scheme, Host: host, - Path: path.Join(u.Path, requestPath, "v1", "org", org, "project", project), + Path: path.Join(u.Path, "v1", "org", org, "project", project, requestPath), }, Host: u.Host, } diff --git a/api/client_test.go b/api/client_test.go new file mode 100644 index 0000000000..2d451c85b7 --- /dev/null +++ b/api/client_test.go @@ -0,0 +1,90 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConfigSetAddress(t *testing.T) { + type test struct { + name string + input string + address string + err string + org string + project string + } + + tests := []test{ + { + "bare", + "http://127.0.0.1:9200", + "http://127.0.0.1:9200", + "", + "", + "", + }, + { + "bare with version", + "http://127.0.0.1:9200/v1", + "http://127.0.0.1:9200", + "", + "", + "", + }, + { + "bare with version and trailing slash", + "http://127.0.0.1:9200/v1/", + "http://127.0.0.1:9200", + "", + "", + "", + }, + { + "invalid org", + "http://127.0.0.1:9200/v1/org", + "http://127.0.0.1:9200", + "unexpected number of segments in address", + "", + "", + }, + { + "valid org", + "http://127.0.0.1:9200/v1/org/orgid", + "http://127.0.0.1:9200", + "", + "orgid", + "", + }, + { + "invalid project", + "http://127.0.0.1:9200/v1/org/orgid/project", + "http://127.0.0.1:9200", + "unexpected number of segments in address", + "", + "", + }, + { + "valid project", + "http://127.0.0.1:9200/v1/org/orgid/project/projid", + "http://127.0.0.1:9200", + "", + "orgid", + "projid", + }, + } + + for _, v := range tests { + t.Run(v.name, func(t *testing.T) { + var c Config + err := c.setAddress(v.input) + if err != nil { + assert.Equal(t, v.err, err.Error()) + } + assert.Equal(t, v.address, c.Address) + assert.Equal(t, v.org, c.Org) + assert.Equal(t, v.project, c.Project) + }) + } +}