Add test for setAddress

apigen
Jeff Mitchell 6 years ago
parent a2e550729f
commit e8a35e2e11

@ -228,35 +228,41 @@ func (c *Config) ConfigureTLS() error {
//
// This also removes any trailing "/v1"; we'll use that in our commands so we
// don't require it from users.
func (c *Config) setAddress(addr string) {
defer func() {
c.Address = strings.TrimSuffix(c.Address, "/")
c.Address = strings.TrimSuffix(c.Address, "/v1")
}()
split := strings.Split(strings.TrimSuffix("/", addr), "/")
if len(split) < 5 {
return
}
// ..../org/<org id>/project/<project id>
if split[len(split)-4] == "org" {
c.Org = split[len(split)-3]
if split[len(split)-2] == "project" {
c.Project = split[len(split)-1]
}
c.Address = strings.Join(split[0:len(split)-4], "/")
return
func (c *Config) setAddress(addr string) error {
u, err := url.Parse(addr)
if err != nil {
return fmt.Errorf("error parsing address: %w", err)
}
c.Address = fmt.Sprintf("%s://%s", u.Scheme, u.Host)
// ..../org/<org id>/project/<project id>
if split[len(split)-2] == "org" {
c.Org = split[len(split)-1]
c.Address = strings.Join(split[0:len(split)-2], "/")
return
path := strings.TrimPrefix(u.Path, "/v1")
path = strings.TrimPrefix(path, "/")
if path == "" {
return nil
}
return
split := strings.Split(path, "/")
switch len(split) {
case 0:
case 2:
if split[0] != "org" {
return fmt.Errorf("expected org segment in address, found %q", split[0])
}
c.Org = split[1]
case 4:
if split[0] != "org" {
return fmt.Errorf("expected org segment in address, found %q", split[0])
}
c.Org = split[1]
if split[2] != "project" {
return fmt.Errorf("expected project segment in address, found %q", split[2])
}
c.Project = split[3]
default:
return fmt.Errorf("unexpected number of segments in address")
}
return nil
}
// ReadEnvironment reads configuration information from the environment. If
@ -425,7 +431,9 @@ func NewClient(c *Config) (*Client, error) {
}
}
c.setAddress(c.Address)
if err := c.setAddress(c.Address); err != nil {
return nil, err
}
return &Client{
config: c,
@ -435,11 +443,11 @@ func NewClient(c *Config) (*Client, error) {
// Sets the address of Watchtower in the client. The format of address should
// be "<Scheme>://<Host>:<Port>". Setting this on a client will override the
// value of the WATCHTOWER_ADDR environment variable.
func (c *Client) SetAddress(addr string) {
func (c *Client) SetAddress(addr string) error {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.config.setAddress(addr)
return c.config.setAddress(addr)
}
// SetLimiter will set the rate limiter for this client. This method is
@ -638,7 +646,7 @@ func (c *Client) NewRequest(ctx context.Context, method, requestPath string, bod
User: u.User,
Scheme: u.Scheme,
Host: host,
Path: path.Join(u.Path, requestPath, "v1", "org", org, "project", project),
Path: path.Join(u.Path, "v1", "org", org, "project", project, requestPath),
},
Host: u.Host,
}

@ -0,0 +1,90 @@
package api
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestConfigSetAddress(t *testing.T) {
type test struct {
name string
input string
address string
err string
org string
project string
}
tests := []test{
{
"bare",
"http://127.0.0.1:9200",
"http://127.0.0.1:9200",
"",
"",
"",
},
{
"bare with version",
"http://127.0.0.1:9200/v1",
"http://127.0.0.1:9200",
"",
"",
"",
},
{
"bare with version and trailing slash",
"http://127.0.0.1:9200/v1/",
"http://127.0.0.1:9200",
"",
"",
"",
},
{
"invalid org",
"http://127.0.0.1:9200/v1/org",
"http://127.0.0.1:9200",
"unexpected number of segments in address",
"",
"",
},
{
"valid org",
"http://127.0.0.1:9200/v1/org/orgid",
"http://127.0.0.1:9200",
"",
"orgid",
"",
},
{
"invalid project",
"http://127.0.0.1:9200/v1/org/orgid/project",
"http://127.0.0.1:9200",
"unexpected number of segments in address",
"",
"",
},
{
"valid project",
"http://127.0.0.1:9200/v1/org/orgid/project/projid",
"http://127.0.0.1:9200",
"",
"orgid",
"projid",
},
}
for _, v := range tests {
t.Run(v.name, func(t *testing.T) {
var c Config
err := c.setAddress(v.input)
if err != nil {
assert.Equal(t, v.err, err.Error())
}
assert.Equal(t, v.address, c.Address)
assert.Equal(t, v.org, c.Org)
assert.Equal(t, v.project, c.Project)
})
}
}
Loading…
Cancel
Save