From f60921ad4b4f9eceda11474c8f08ff81b9f51cb1 Mon Sep 17 00:00:00 2001 From: Christopher Boumenot Date: Wed, 11 Jul 2018 14:17:56 -0700 Subject: [PATCH] azure: upgrade Azure/go-autorest to v10.12.0 --- .../Azure/go-autorest/autorest/adal/config.go | 8 +- .../Azure/go-autorest/autorest/adal/msi.go | 20 - .../go-autorest/autorest/adal/msi_windows.go | 25 - .../Azure/go-autorest/autorest/adal/token.go | 541 +++++++--- .../go-autorest/autorest/authorization.go | 22 +- .../Azure/go-autorest/autorest/autorest.go | 18 + .../Azure/go-autorest/autorest/azure/async.go | 945 +++++++++++++----- .../Azure/go-autorest/autorest/azure/azure.go | 27 +- .../Azure/go-autorest/autorest/azure/rp.go | 12 +- .../Azure/go-autorest/autorest/date/date.go | 14 + .../Azure/go-autorest/autorest/date/time.go | 14 + .../go-autorest/autorest/date/timerfc1123.go | 14 + .../go-autorest/autorest/date/unixtime.go | 14 + .../go-autorest/autorest/date/utility.go | 14 + .../Azure/go-autorest/autorest/sender.go | 22 +- .../Azure/go-autorest/autorest/to/convert.go | 14 + .../Azure/go-autorest/autorest/utility.go | 10 + .../autorest/validation/validation.go | 23 +- .../Azure/go-autorest/autorest/version.go | 2 +- vendor/vendor.json | 60 +- 20 files changed, 1306 insertions(+), 513 deletions(-) delete mode 100644 vendor/github.com/Azure/go-autorest/autorest/adal/msi.go delete mode 100644 vendor/github.com/Azure/go-autorest/autorest/adal/msi_windows.go diff --git a/vendor/github.com/Azure/go-autorest/autorest/adal/config.go b/vendor/github.com/Azure/go-autorest/autorest/adal/config.go index f570d540a..bee5e61dd 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/adal/config.go +++ b/vendor/github.com/Azure/go-autorest/autorest/adal/config.go @@ -26,10 +26,10 @@ const ( // OAuthConfig represents the endpoints needed // in OAuth operations type OAuthConfig struct { - AuthorityEndpoint url.URL - AuthorizeEndpoint url.URL - TokenEndpoint url.URL - DeviceCodeEndpoint url.URL + AuthorityEndpoint url.URL `json:"authorityEndpoint"` + AuthorizeEndpoint url.URL `json:"authorizeEndpoint"` + TokenEndpoint url.URL `json:"tokenEndpoint"` + DeviceCodeEndpoint url.URL `json:"deviceCodeEndpoint"` } // IsZero returns true if the OAuthConfig object is zero-initialized. diff --git a/vendor/github.com/Azure/go-autorest/autorest/adal/msi.go b/vendor/github.com/Azure/go-autorest/autorest/adal/msi.go deleted file mode 100644 index 5e02d52ac..000000000 --- a/vendor/github.com/Azure/go-autorest/autorest/adal/msi.go +++ /dev/null @@ -1,20 +0,0 @@ -// +build !windows - -package adal - -// Copyright 2017 Microsoft Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// msiPath is the path to the MSI Extension settings file (to discover the endpoint) -var msiPath = "/var/lib/waagent/ManagedIdentity-Settings" diff --git a/vendor/github.com/Azure/go-autorest/autorest/adal/msi_windows.go b/vendor/github.com/Azure/go-autorest/autorest/adal/msi_windows.go deleted file mode 100644 index 261b56882..000000000 --- a/vendor/github.com/Azure/go-autorest/autorest/adal/msi_windows.go +++ /dev/null @@ -1,25 +0,0 @@ -// +build windows - -package adal - -// Copyright 2017 Microsoft Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -import ( - "os" - "strings" -) - -// msiPath is the path to the MSI Extension settings file (to discover the endpoint) -var msiPath = strings.Join([]string{os.Getenv("SystemDrive"), "WindowsAzure/Config/ManagedIdentity-Settings"}, "/") diff --git a/vendor/github.com/Azure/go-autorest/autorest/adal/token.go b/vendor/github.com/Azure/go-autorest/autorest/adal/token.go index d98504122..eec4dced7 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/adal/token.go +++ b/vendor/github.com/Azure/go-autorest/autorest/adal/token.go @@ -15,14 +15,18 @@ package adal // limitations under the License. import ( + "context" "crypto/rand" "crypto/rsa" "crypto/sha1" "crypto/x509" "encoding/base64" "encoding/json" + "errors" "fmt" "io/ioutil" + "math" + "net" "net/http" "net/url" "strconv" @@ -54,6 +58,12 @@ const ( // metadataHeader is the header required by MSI extension metadataHeader = "Metadata" + + // msiEndpoint is the well known endpoint for getting MSI authentications tokens + msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token" + + // the default number of attempts to refresh an MSI authentication token + defaultMaxMSIRefreshAttempts = 5 ) // OAuthTokenProvider is an interface which should be implemented by an access token retriever @@ -74,6 +84,13 @@ type Refresher interface { EnsureFresh() error } +// RefresherWithContext is an interface for token refresh functionality +type RefresherWithContext interface { + RefreshWithContext(ctx context.Context) error + RefreshExchangeWithContext(ctx context.Context, resource string) error + EnsureFreshWithContext(ctx context.Context) error +} + // TokenRefreshCallback is the type representing callbacks that will be called after // a successful token refresh type TokenRefreshCallback func(Token) error @@ -124,6 +141,12 @@ func (t *Token) OAuthToken() string { return t.AccessToken } +// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form +// that is submitted when acquiring an oAuth token. +type ServicePrincipalSecret interface { + SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error +} + // ServicePrincipalNoSecret represents a secret type that contains no secret // meaning it is not valid for fetching a fresh token. This is used by Manual type ServicePrincipalNoSecret struct { @@ -135,15 +158,19 @@ func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePr return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token") } -// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form -// that is submitted when acquiring an oAuth token. -type ServicePrincipalSecret interface { - SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error +// MarshalJSON implements the json.Marshaler interface. +func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) { + type tokenType struct { + Type string `json:"type"` + } + return json.Marshal(tokenType{ + Type: "ServicePrincipalNoSecret", + }) } // ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization. type ServicePrincipalTokenSecret struct { - ClientSecret string + ClientSecret string `json:"value"` } // SetAuthenticationValues is a method of the interface ServicePrincipalSecret. @@ -153,49 +180,24 @@ func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *Ser return nil } +// MarshalJSON implements the json.Marshaler interface. +func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) { + type tokenType struct { + Type string `json:"type"` + Value string `json:"value"` + } + return json.Marshal(tokenType{ + Type: "ServicePrincipalTokenSecret", + Value: tokenSecret.ClientSecret, + }) +} + // ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs. type ServicePrincipalCertificateSecret struct { Certificate *x509.Certificate PrivateKey *rsa.PrivateKey } -// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension. -type ServicePrincipalMSISecret struct { -} - -// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth. -type ServicePrincipalUsernamePasswordSecret struct { - Username string - Password string -} - -// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth. -type ServicePrincipalAuthorizationCodeSecret struct { - ClientSecret string - AuthorizationCode string - RedirectURI string -} - -// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. -func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { - v.Set("code", secret.AuthorizationCode) - v.Set("client_secret", secret.ClientSecret) - v.Set("redirect_uri", secret.RedirectURI) - return nil -} - -// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. -func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { - v.Set("username", secret.Username) - v.Set("password", secret.Password) - return nil -} - -// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. -func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { - return nil -} - // SignJwt returns the JWT signed with the certificate's private key. func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) { hasher := sha1.New() @@ -216,9 +218,9 @@ func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalTo token := jwt.New(jwt.SigningMethodRS256) token.Header["x5t"] = thumbprint token.Claims = jwt.MapClaims{ - "aud": spt.oauthConfig.TokenEndpoint.String(), - "iss": spt.clientID, - "sub": spt.clientID, + "aud": spt.inner.OauthConfig.TokenEndpoint.String(), + "iss": spt.inner.ClientID, + "sub": spt.inner.ClientID, "jti": base64.URLEncoding.EncodeToString(jti), "nbf": time.Now().Unix(), "exp": time.Now().Add(time.Hour * 24).Unix(), @@ -241,19 +243,151 @@ func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *Se return nil } +// MarshalJSON implements the json.Marshaler interface. +func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) { + return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported") +} + +// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension. +type ServicePrincipalMSISecret struct { +} + +// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. +func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { + return nil +} + +// MarshalJSON implements the json.Marshaler interface. +func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) { + return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported") +} + +// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth. +type ServicePrincipalUsernamePasswordSecret struct { + Username string `json:"username"` + Password string `json:"password"` +} + +// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. +func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { + v.Set("username", secret.Username) + v.Set("password", secret.Password) + return nil +} + +// MarshalJSON implements the json.Marshaler interface. +func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) { + type tokenType struct { + Type string `json:"type"` + Username string `json:"username"` + Password string `json:"password"` + } + return json.Marshal(tokenType{ + Type: "ServicePrincipalUsernamePasswordSecret", + Username: secret.Username, + Password: secret.Password, + }) +} + +// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth. +type ServicePrincipalAuthorizationCodeSecret struct { + ClientSecret string `json:"value"` + AuthorizationCode string `json:"authCode"` + RedirectURI string `json:"redirect"` +} + +// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. +func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { + v.Set("code", secret.AuthorizationCode) + v.Set("client_secret", secret.ClientSecret) + v.Set("redirect_uri", secret.RedirectURI) + return nil +} + +// MarshalJSON implements the json.Marshaler interface. +func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) { + type tokenType struct { + Type string `json:"type"` + Value string `json:"value"` + AuthCode string `json:"authCode"` + Redirect string `json:"redirect"` + } + return json.Marshal(tokenType{ + Type: "ServicePrincipalAuthorizationCodeSecret", + Value: secret.ClientSecret, + AuthCode: secret.AuthorizationCode, + Redirect: secret.RedirectURI, + }) +} + // ServicePrincipalToken encapsulates a Token created for a Service Principal. type ServicePrincipalToken struct { - token Token - secret ServicePrincipalSecret - oauthConfig OAuthConfig - clientID string - resource string - autoRefresh bool - refreshLock *sync.RWMutex - refreshWithin time.Duration - sender Sender - + inner servicePrincipalToken + refreshLock *sync.RWMutex + sender Sender refreshCallbacks []TokenRefreshCallback + // MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token. + MaxMSIRefreshAttempts int +} + +// MarshalTokenJSON returns the marshalled inner token. +func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) { + return json.Marshal(spt.inner.Token) +} + +// SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks. +func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) { + spt.refreshCallbacks = callbacks +} + +// MarshalJSON implements the json.Marshaler interface. +func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) { + return json.Marshal(spt.inner) +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error { + // need to determine the token type + raw := map[string]interface{}{} + err := json.Unmarshal(data, &raw) + if err != nil { + return err + } + secret := raw["secret"].(map[string]interface{}) + switch secret["type"] { + case "ServicePrincipalNoSecret": + spt.inner.Secret = &ServicePrincipalNoSecret{} + case "ServicePrincipalTokenSecret": + spt.inner.Secret = &ServicePrincipalTokenSecret{} + case "ServicePrincipalCertificateSecret": + return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported") + case "ServicePrincipalMSISecret": + return errors.New("unmarshalling ServicePrincipalMSISecret is not supported") + case "ServicePrincipalUsernamePasswordSecret": + spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{} + case "ServicePrincipalAuthorizationCodeSecret": + spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{} + default: + return fmt.Errorf("unrecognized token type '%s'", secret["type"]) + } + err = json.Unmarshal(data, &spt.inner) + if err != nil { + return err + } + spt.refreshLock = &sync.RWMutex{} + spt.sender = &http.Client{} + return nil +} + +// internal type used for marshalling/unmarshalling +type servicePrincipalToken struct { + Token Token `json:"token"` + Secret ServicePrincipalSecret `json:"secret"` + OauthConfig OAuthConfig `json:"oauth"` + ClientID string `json:"clientID"` + Resource string `json:"resource"` + AutoRefresh bool `json:"autoRefresh"` + RefreshWithin time.Duration `json:"refreshWithin"` } func validateOAuthConfig(oac OAuthConfig) error { @@ -278,13 +412,15 @@ func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, reso return nil, fmt.Errorf("parameter 'secret' cannot be nil") } spt := &ServicePrincipalToken{ - oauthConfig: oauthConfig, - secret: secret, - clientID: id, - resource: resource, - autoRefresh: true, + inner: servicePrincipalToken{ + OauthConfig: oauthConfig, + Secret: secret, + ClientID: id, + Resource: resource, + AutoRefresh: true, + RefreshWithin: defaultRefresh, + }, refreshLock: &sync.RWMutex{}, - refreshWithin: defaultRefresh, sender: &http.Client{}, refreshCallbacks: callbacks, } @@ -315,7 +451,39 @@ func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID s return nil, err } - spt.token = token + spt.inner.Token = token + + return spt, nil +} + +// NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret +func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { + if err := validateOAuthConfig(oauthConfig); err != nil { + return nil, err + } + if err := validateStringParam(clientID, "clientID"); err != nil { + return nil, err + } + if err := validateStringParam(resource, "resource"); err != nil { + return nil, err + } + if secret == nil { + return nil, fmt.Errorf("parameter 'secret' cannot be nil") + } + if token.IsZero() { + return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized") + } + spt, err := NewServicePrincipalTokenWithSecret( + oauthConfig, + clientID, + resource, + secret, + callbacks...) + if err != nil { + return nil, err + } + + spt.inner.Token = token return spt, nil } @@ -441,24 +609,7 @@ func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clie // GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines. func GetMSIVMEndpoint() (string, error) { - return getMSIVMEndpoint(msiPath) -} - -func getMSIVMEndpoint(path string) (string, error) { - // Read MSI settings - bytes, err := ioutil.ReadFile(path) - if err != nil { - return "", err - } - msiSettings := struct { - URL string `json:"url"` - }{} - err = json.Unmarshal(bytes, &msiSettings) - if err != nil { - return "", err - } - - return msiSettings.URL, nil + return msiEndpoint, nil } // NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension. @@ -491,24 +642,32 @@ func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedI return nil, err } - oauthConfig, err := NewOAuthConfig(msiEndpointURL.String(), "") - if err != nil { - return nil, err + v := url.Values{} + v.Set("resource", resource) + v.Set("api-version", "2018-02-01") + if userAssignedID != nil { + v.Set("client_id", *userAssignedID) } + msiEndpointURL.RawQuery = v.Encode() spt := &ServicePrincipalToken{ - oauthConfig: *oauthConfig, - secret: &ServicePrincipalMSISecret{}, - resource: resource, - autoRefresh: true, - refreshLock: &sync.RWMutex{}, - refreshWithin: defaultRefresh, - sender: &http.Client{}, - refreshCallbacks: callbacks, + inner: servicePrincipalToken{ + OauthConfig: OAuthConfig{ + TokenEndpoint: *msiEndpointURL, + }, + Secret: &ServicePrincipalMSISecret{}, + Resource: resource, + AutoRefresh: true, + RefreshWithin: defaultRefresh, + }, + refreshLock: &sync.RWMutex{}, + sender: &http.Client{}, + refreshCallbacks: callbacks, + MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts, } if userAssignedID != nil { - spt.clientID = *userAssignedID + spt.inner.ClientID = *userAssignedID } return spt, nil @@ -537,12 +696,18 @@ func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError // EnsureFresh will refresh the token if it will expire within the refresh window (as set by // RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. func (spt *ServicePrincipalToken) EnsureFresh() error { - if spt.autoRefresh && spt.token.WillExpireIn(spt.refreshWithin) { + return spt.EnsureFreshWithContext(context.Background()) +} + +// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by +// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. +func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error { + if spt.inner.AutoRefresh && spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) { // take the write lock then check to see if the token was already refreshed spt.refreshLock.Lock() defer spt.refreshLock.Unlock() - if spt.token.WillExpireIn(spt.refreshWithin) { - return spt.refreshInternal(spt.resource) + if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) { + return spt.refreshInternal(ctx, spt.inner.Resource) } } return nil @@ -552,7 +717,7 @@ func (spt *ServicePrincipalToken) EnsureFresh() error { func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error { if spt.refreshCallbacks != nil { for _, callback := range spt.refreshCallbacks { - err := callback(spt.token) + err := callback(spt.inner.Token) if err != nil { return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err) } @@ -564,21 +729,33 @@ func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error { // Refresh obtains a fresh token for the Service Principal. // This method is not safe for concurrent use and should be syncrhonized. func (spt *ServicePrincipalToken) Refresh() error { + return spt.RefreshWithContext(context.Background()) +} + +// RefreshWithContext obtains a fresh token for the Service Principal. +// This method is not safe for concurrent use and should be syncrhonized. +func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error { spt.refreshLock.Lock() defer spt.refreshLock.Unlock() - return spt.refreshInternal(spt.resource) + return spt.refreshInternal(ctx, spt.inner.Resource) } // RefreshExchange refreshes the token, but for a different resource. // This method is not safe for concurrent use and should be syncrhonized. func (spt *ServicePrincipalToken) RefreshExchange(resource string) error { + return spt.RefreshExchangeWithContext(context.Background(), resource) +} + +// RefreshExchangeWithContext refreshes the token, but for a different resource. +// This method is not safe for concurrent use and should be syncrhonized. +func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error { spt.refreshLock.Lock() defer spt.refreshLock.Unlock() - return spt.refreshInternal(resource) + return spt.refreshInternal(ctx, resource) } func (spt *ServicePrincipalToken) getGrantType() string { - switch spt.secret.(type) { + switch spt.inner.Secret.(type) { case *ServicePrincipalUsernamePasswordSecret: return OAuthGrantTypeUserPass case *ServicePrincipalAuthorizationCodeSecret: @@ -588,37 +765,64 @@ func (spt *ServicePrincipalToken) getGrantType() string { } } -func (spt *ServicePrincipalToken) refreshInternal(resource string) error { - v := url.Values{} - v.Set("client_id", spt.clientID) - v.Set("resource", resource) - - if spt.token.RefreshToken != "" { - v.Set("grant_type", OAuthGrantTypeRefreshToken) - v.Set("refresh_token", spt.token.RefreshToken) - } else { - v.Set("grant_type", spt.getGrantType()) - err := spt.secret.SetAuthenticationValues(spt, &v) - if err != nil { - return err - } +func isIMDS(u url.URL) bool { + imds, err := url.Parse(msiEndpoint) + if err != nil { + return false } + return u.Host == imds.Host && u.Path == imds.Path +} - s := v.Encode() - body := ioutil.NopCloser(strings.NewReader(s)) - req, err := http.NewRequest(http.MethodPost, spt.oauthConfig.TokenEndpoint.String(), body) +func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error { + req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil) if err != nil { return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err) } + req = req.WithContext(ctx) + if !isIMDS(spt.inner.OauthConfig.TokenEndpoint) { + v := url.Values{} + v.Set("client_id", spt.inner.ClientID) + v.Set("resource", resource) + + if spt.inner.Token.RefreshToken != "" { + v.Set("grant_type", OAuthGrantTypeRefreshToken) + v.Set("refresh_token", spt.inner.Token.RefreshToken) + // web apps must specify client_secret when refreshing tokens + // see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens + if spt.getGrantType() == OAuthGrantTypeAuthorizationCode { + err := spt.inner.Secret.SetAuthenticationValues(spt, &v) + if err != nil { + return err + } + } + } else { + v.Set("grant_type", spt.getGrantType()) + err := spt.inner.Secret.SetAuthenticationValues(spt, &v) + if err != nil { + return err + } + } + + s := v.Encode() + body := ioutil.NopCloser(strings.NewReader(s)) + req.ContentLength = int64(len(s)) + req.Header.Set(contentType, mimeTypeFormPost) + req.Body = body + } - req.ContentLength = int64(len(s)) - req.Header.Set(contentType, mimeTypeFormPost) - if _, ok := spt.secret.(*ServicePrincipalMSISecret); ok { + if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok { + req.Method = http.MethodGet req.Header.Set(metadataHeader, "true") } - resp, err := spt.sender.Do(req) + + var resp *http.Response + if isIMDS(spt.inner.OauthConfig.TokenEndpoint) { + resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts) + } else { + resp, err = spt.sender.Do(req) + } if err != nil { - return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err) + return newTokenRefreshError(fmt.Sprintf("adal: Failed to execute the refresh request. Error = '%v'", err), nil) } defer resp.Body.Close() @@ -626,11 +830,15 @@ func (spt *ServicePrincipalToken) refreshInternal(resource string) error { if resp.StatusCode != http.StatusOK { if err != nil { - return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body", resp.StatusCode), resp) + return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v", resp.StatusCode, err), resp) } return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp) } + // for the following error cases don't return a TokenRefreshError. the operation succeeded + // but some transient failure happened during deserialization. by returning a generic error + // the retry logic will kick in (we don't retry on TokenRefreshError). + if err != nil { return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err) } @@ -643,20 +851,99 @@ func (spt *ServicePrincipalToken) refreshInternal(resource string) error { return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb)) } - spt.token = token + spt.inner.Token = token return spt.InvokeRefreshCallbacks(token) } +// retry logic specific to retrieving a token from the IMDS endpoint +func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) { + // copied from client.go due to circular dependency + retries := []int{ + http.StatusRequestTimeout, // 408 + http.StatusTooManyRequests, // 429 + http.StatusInternalServerError, // 500 + http.StatusBadGateway, // 502 + http.StatusServiceUnavailable, // 503 + http.StatusGatewayTimeout, // 504 + } + // extra retry status codes specific to IMDS + retries = append(retries, + http.StatusNotFound, + http.StatusGone, + // all remaining 5xx + http.StatusNotImplemented, + http.StatusHTTPVersionNotSupported, + http.StatusVariantAlsoNegotiates, + http.StatusInsufficientStorage, + http.StatusLoopDetected, + http.StatusNotExtended, + http.StatusNetworkAuthenticationRequired) + + // see https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance + + const maxDelay time.Duration = 60 * time.Second + + attempt := 0 + delay := time.Duration(0) + + for attempt < maxAttempts { + resp, err = sender.Do(req) + // retry on temporary network errors, e.g. transient network failures. + // if we don't receive a response then assume we can't connect to the + // endpoint so we're likely not running on an Azure VM so don't retry. + if (err != nil && !isTemporaryNetworkError(err)) || resp == nil || resp.StatusCode == http.StatusOK || !containsInt(retries, resp.StatusCode) { + return + } + + // perform exponential backoff with a cap. + // must increment attempt before calculating delay. + attempt++ + // the base value of 2 is the "delta backoff" as specified in the guidance doc + delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second) + if delay > maxDelay { + delay = maxDelay + } + + select { + case <-time.After(delay): + // intentionally left blank + case <-req.Context().Done(): + err = req.Context().Err() + return + } + } + return +} + +// returns true if the specified error is a temporary network error or false if it's not. +// if the error doesn't implement the net.Error interface the return value is true. +func isTemporaryNetworkError(err error) bool { + if netErr, ok := err.(net.Error); !ok || (ok && netErr.Temporary()) { + return true + } + return false +} + +// returns true if slice ints contains the value n +func containsInt(ints []int, n int) bool { + for _, i := range ints { + if i == n { + return true + } + } + return false +} + // SetAutoRefresh enables or disables automatic refreshing of stale tokens. func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) { - spt.autoRefresh = autoRefresh + spt.inner.AutoRefresh = autoRefresh } // SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will // refresh the token. func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) { - spt.refreshWithin = d + spt.inner.RefreshWithin = d return } @@ -668,12 +955,12 @@ func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s } func (spt *ServicePrincipalToken) OAuthToken() string { spt.refreshLock.RLock() defer spt.refreshLock.RUnlock() - return spt.token.OAuthToken() + return spt.inner.Token.OAuthToken() } // Token returns a copy of the current token. func (spt *ServicePrincipalToken) Token() Token { spt.refreshLock.RLock() defer spt.refreshLock.RUnlock() - return spt.token + return spt.inner.Token } diff --git a/vendor/github.com/Azure/go-autorest/autorest/authorization.go b/vendor/github.com/Azure/go-autorest/autorest/authorization.go index c51eac0a7..77eff45bd 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/authorization.go +++ b/vendor/github.com/Azure/go-autorest/autorest/authorization.go @@ -113,17 +113,19 @@ func (ba *BearerAuthorizer) WithAuthorization() PrepareDecorator { return PreparerFunc(func(r *http.Request) (*http.Request, error) { r, err := p.Prepare(r) if err == nil { - refresher, ok := ba.tokenProvider.(adal.Refresher) - if ok { - err := refresher.EnsureFresh() - if err != nil { - var resp *http.Response - if tokError, ok := err.(adal.TokenRefreshError); ok { - resp = tokError.Response() - } - return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", resp, - "Failed to refresh the Token for request to %s", r.URL) + // the ordering is important here, prefer RefresherWithContext if available + if refresher, ok := ba.tokenProvider.(adal.RefresherWithContext); ok { + err = refresher.EnsureFreshWithContext(r.Context()) + } else if refresher, ok := ba.tokenProvider.(adal.Refresher); ok { + err = refresher.EnsureFresh() + } + if err != nil { + var resp *http.Response + if tokError, ok := err.(adal.TokenRefreshError); ok { + resp = tokError.Response() } + return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", resp, + "Failed to refresh the Token for request to %s", r.URL) } return Prepare(r, WithHeader(headerAuthorization, fmt.Sprintf("Bearer %s", ba.tokenProvider.OAuthToken()))) } diff --git a/vendor/github.com/Azure/go-autorest/autorest/autorest.go b/vendor/github.com/Azure/go-autorest/autorest/autorest.go index f86b66a41..aafdf021f 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/autorest.go +++ b/vendor/github.com/Azure/go-autorest/autorest/autorest.go @@ -72,6 +72,7 @@ package autorest // limitations under the License. import ( + "context" "net/http" "time" ) @@ -130,3 +131,20 @@ func NewPollingRequest(resp *http.Response, cancel <-chan struct{}) (*http.Reque return req, nil } + +// NewPollingRequestWithContext allocates and returns a new http.Request with the specified context to poll for the passed response. +func NewPollingRequestWithContext(ctx context.Context, resp *http.Response) (*http.Request, error) { + location := GetLocation(resp) + if location == "" { + return nil, NewErrorWithResponse("autorest", "NewPollingRequestWithContext", resp, "Location header missing from response that requires polling") + } + + req, err := Prepare((&http.Request{}).WithContext(ctx), + AsGet(), + WithBaseURL(location)) + if err != nil { + return nil, NewErrorWithError(err, "autorest", "NewPollingRequestWithContext", nil, "Failure creating poll request to %s", location) + } + + return req, nil +} diff --git a/vendor/github.com/Azure/go-autorest/autorest/azure/async.go b/vendor/github.com/Azure/go-autorest/autorest/azure/async.go index a18b61041..cda1e180a 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/azure/async.go +++ b/vendor/github.com/Azure/go-autorest/autorest/azure/async.go @@ -21,11 +21,11 @@ import ( "fmt" "io/ioutil" "net/http" + "net/url" "strings" "time" "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/date" ) const ( @@ -44,84 +44,85 @@ var pollingCodes = [...]int{http.StatusNoContent, http.StatusAccepted, http.Stat // Future provides a mechanism to access the status and results of an asynchronous request. // Since futures are stateful they should be passed by value to avoid race conditions. type Future struct { - req *http.Request - resp *http.Response - ps pollingState + req *http.Request // legacy + pt pollingTracker } // NewFuture returns a new Future object initialized with the specified request. +// Deprecated: Please use NewFutureFromResponse instead. func NewFuture(req *http.Request) Future { return Future{req: req} } -// Response returns the last HTTP response or nil if there isn't one. +// NewFutureFromResponse returns a new Future object initialized +// with the initial response from an asynchronous operation. +func NewFutureFromResponse(resp *http.Response) (Future, error) { + pt, err := createPollingTracker(resp) + if err != nil { + return Future{}, err + } + return Future{pt: pt}, nil +} + +// Response returns the last HTTP response. func (f Future) Response() *http.Response { - return f.resp + if f.pt == nil { + return nil + } + return f.pt.latestResponse() } // Status returns the last status message of the operation. func (f Future) Status() string { - if f.ps.State == "" { - return "Unknown" + if f.pt == nil { + return "" } - return f.ps.State + return f.pt.pollingStatus() } // PollingMethod returns the method used to monitor the status of the asynchronous operation. func (f Future) PollingMethod() PollingMethodType { - return f.ps.PollingMethod + if f.pt == nil { + return PollingUnknown + } + return f.pt.pollingMethod() } // Done queries the service to see if the operation has completed. func (f *Future) Done(sender autorest.Sender) (bool, error) { - // exit early if this future has terminated - if f.ps.hasTerminated() { - return true, f.errorInfo() + // support for legacy Future implementation + if f.req != nil { + resp, err := sender.Do(f.req) + if err != nil { + return false, err + } + pt, err := createPollingTracker(resp) + if err != nil { + return false, err + } + f.pt = pt + f.req = nil } - resp, err := sender.Do(f.req) - f.resp = resp - if err != nil { + // end legacy + if f.pt == nil { + return false, autorest.NewError("Future", "Done", "future is not initialized") + } + if f.pt.hasTerminated() { + return true, f.pt.pollingError() + } + if err := f.pt.pollForStatus(sender); err != nil { return false, err } - - if !autorest.ResponseHasStatusCode(resp, pollingCodes[:]...) { - // check response body for error content - if resp.Body != nil { - type respErr struct { - ServiceError ServiceError `json:"error"` - } - re := respErr{} - - defer resp.Body.Close() - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return false, err - } - err = json.Unmarshal(b, &re) - if err != nil { - return false, err - } - return false, re.ServiceError - } - - // try to return something meaningful - return false, ServiceError{ - Code: fmt.Sprintf("%v", resp.StatusCode), - Message: resp.Status, - } + if err := f.pt.checkForErrors(); err != nil { + return f.pt.hasTerminated(), err } - - err = updatePollingState(resp, &f.ps) - if err != nil { + if err := f.pt.updatePollingState(f.pt.provisioningStateApplicable()); err != nil { return false, err } - - if f.ps.hasTerminated() { - return true, f.errorInfo() + if err := f.pt.updateHeaders(); err != nil { + return false, err } - - f.req, err = newPollingRequest(f.ps) - return false, err + return f.pt.hasTerminated(), f.pt.pollingError() } // GetPollingDelay returns a duration the application should wait before checking @@ -129,11 +130,15 @@ func (f *Future) Done(sender autorest.Sender) (bool, error) { // the service via the Retry-After response header. If the header wasn't returned // then the function returns the zero-value time.Duration and false. func (f Future) GetPollingDelay() (time.Duration, bool) { - if f.resp == nil { + if f.pt == nil { + return 0, false + } + resp := f.pt.latestResponse() + if resp == nil { return 0, false } - retry := f.resp.Header.Get(autorest.HeaderRetryAfter) + retry := resp.Header.Get(autorest.HeaderRetryAfter) if retry == "" { return 0, false } @@ -150,14 +155,22 @@ func (f Future) GetPollingDelay() (time.Duration, bool) { // running operation has completed, the provided context is cancelled, or the client's // polling duration has been exceeded. It will retry failed polling attempts based on // the retry value defined in the client up to the maximum retry attempts. +// Deprecated: Please use WaitForCompletionRef() instead. func (f Future) WaitForCompletion(ctx context.Context, client autorest.Client) error { + return f.WaitForCompletionRef(ctx, client) +} + +// WaitForCompletionRef will return when one of the following conditions is met: the long +// running operation has completed, the provided context is cancelled, or the client's +// polling duration has been exceeded. It will retry failed polling attempts based on +// the retry value defined in the client up to the maximum retry attempts. +func (f *Future) WaitForCompletionRef(ctx context.Context, client autorest.Client) error { ctx, cancel := context.WithTimeout(ctx, client.PollingDuration) defer cancel() - done, err := f.Done(client) for attempts := 0; !done; done, err = f.Done(client) { if attempts >= client.RetryAttempts { - return autorest.NewErrorWithError(err, "azure", "WaitForCompletion", f.resp, "the number of retries has been exceeded") + return autorest.NewErrorWithError(err, "Future", "WaitForCompletion", f.pt.latestResponse(), "the number of retries has been exceeded") } // we want delayAttempt to be zero in the non-error case so // that DelayForBackoff doesn't perform exponential back-off @@ -181,317 +194,709 @@ func (f Future) WaitForCompletion(ctx context.Context, client autorest.Client) e // wait until the delay elapses or the context is cancelled delayElapsed := autorest.DelayForBackoff(delay, delayAttempt, ctx.Done()) if !delayElapsed { - return autorest.NewErrorWithError(ctx.Err(), "azure", "WaitForCompletion", f.resp, "context has been cancelled") + return autorest.NewErrorWithError(ctx.Err(), "Future", "WaitForCompletion", f.pt.latestResponse(), "context has been cancelled") } } return err } -// if the operation failed the polling state will contain -// error information and implements the error interface -func (f *Future) errorInfo() error { - if !f.ps.hasSucceeded() { - return f.ps - } - return nil -} - // MarshalJSON implements the json.Marshaler interface. func (f Future) MarshalJSON() ([]byte, error) { - return json.Marshal(&f.ps) + return json.Marshal(f.pt) } // UnmarshalJSON implements the json.Unmarshaler interface. func (f *Future) UnmarshalJSON(data []byte) error { - err := json.Unmarshal(data, &f.ps) + // unmarshal into JSON object to determine the tracker type + obj := map[string]interface{}{} + err := json.Unmarshal(data, &obj) if err != nil { return err } - f.req, err = newPollingRequest(f.ps) - return err + if obj["method"] == nil { + return autorest.NewError("Future", "UnmarshalJSON", "missing 'method' property") + } + method := obj["method"].(string) + switch strings.ToUpper(method) { + case http.MethodDelete: + f.pt = &pollingTrackerDelete{} + case http.MethodPatch: + f.pt = &pollingTrackerPatch{} + case http.MethodPost: + f.pt = &pollingTrackerPost{} + case http.MethodPut: + f.pt = &pollingTrackerPut{} + default: + return autorest.NewError("Future", "UnmarshalJSON", "unsupoorted method '%s'", method) + } + // now unmarshal into the tracker + return json.Unmarshal(data, &f.pt) } // PollingURL returns the URL used for retrieving the status of the long-running operation. -// For LROs that use the Location header the final URL value is used to retrieve the result. func (f Future) PollingURL() string { - return f.ps.URI + if f.pt == nil { + return "" + } + return f.pt.pollingURL() +} + +// GetResult should be called once polling has completed successfully. +// It makes the final GET call to retrieve the resultant payload. +func (f Future) GetResult(sender autorest.Sender) (*http.Response, error) { + if f.pt.finalGetURL() == "" { + // we can end up in this situation if the async operation returns a 200 + // with no polling URLs. in that case return the response which should + // contain the JSON payload (only do this for successful terminal cases). + if lr := f.pt.latestResponse(); lr != nil && f.pt.hasSucceeded() { + return lr, nil + } + return nil, autorest.NewError("Future", "GetResult", "missing URL for retrieving result") + } + req, err := http.NewRequest(http.MethodGet, f.pt.finalGetURL(), nil) + if err != nil { + return nil, err + } + return sender.Do(req) } -// DoPollForAsynchronous returns a SendDecorator that polls if the http.Response is for an Azure -// long-running operation. It will delay between requests for the duration specified in the -// RetryAfter header or, if the header is absent, the passed delay. Polling may be canceled by -// closing the optional channel on the http.Request. -func DoPollForAsynchronous(delay time.Duration) autorest.SendDecorator { - return func(s autorest.Sender) autorest.Sender { - return autorest.SenderFunc(func(r *http.Request) (resp *http.Response, err error) { - resp, err = s.Do(r) - if err != nil { - return resp, err - } - if !autorest.ResponseHasStatusCode(resp, pollingCodes[:]...) { - return resp, nil - } +type pollingTracker interface { + // these methods can differ per tracker - ps := pollingState{} - for err == nil { - err = updatePollingState(resp, &ps) - if err != nil { - break - } - if ps.hasTerminated() { - if !ps.hasSucceeded() { - err = ps - } - break - } + // checks the response headers and status code to determine the polling mechanism + updateHeaders() error - r, err = newPollingRequest(ps) - if err != nil { - return resp, err - } - r.Cancel = resp.Request.Cancel + // checks the response for tracker-specific error conditions + checkForErrors() error - delay = autorest.GetRetryAfter(resp, delay) - resp, err = autorest.SendWithSender(s, r, - autorest.AfterDelay(delay)) - } + // returns true if provisioning state should be checked + provisioningStateApplicable() bool - return resp, err - }) - } + // methods common to all trackers + + // initializes the tracker's internal state, call this when the tracker is created + initializeState() error + + // makes an HTTP request to check the status of the LRO + pollForStatus(sender autorest.Sender) error + + // updates internal tracker state, call this after each call to pollForStatus + updatePollingState(provStateApl bool) error + + // returns the error response from the service, can be nil + pollingError() error + + // returns the polling method being used + pollingMethod() PollingMethodType + + // returns the state of the LRO as returned from the service + pollingStatus() string + + // returns the URL used for polling status + pollingURL() string + + // returns the URL used for the final GET to retrieve the resource + finalGetURL() string + + // returns true if the LRO is in a terminal state + hasTerminated() bool + + // returns true if the LRO is in a failed terminal state + hasFailed() bool + + // returns true if the LRO is in a successful terminal state + hasSucceeded() bool + + // returns the cached HTTP response after a call to pollForStatus(), can be nil + latestResponse() *http.Response } -func getAsyncOperation(resp *http.Response) string { - return resp.Header.Get(http.CanonicalHeaderKey(headerAsyncOperation)) +type pollingTrackerBase struct { + // resp is the last response, either from the submission of the LRO or from polling + resp *http.Response + + // method is the HTTP verb, this is needed for deserialization + Method string `json:"method"` + + // rawBody is the raw JSON response body + rawBody map[string]interface{} + + // denotes if polling is using async-operation or location header + Pm PollingMethodType `json:"pollingMethod"` + + // the URL to poll for status + URI string `json:"pollingURI"` + + // the state of the LRO as returned from the service + State string `json:"lroState"` + + // the URL to GET for the final result + FinalGetURI string `json:"resultURI"` + + // used to hold an error object returned from the service + Err *ServiceError `json:"error,omitempty"` } -func hasSucceeded(state string) bool { - return strings.EqualFold(state, operationSucceeded) +func (pt *pollingTrackerBase) initializeState() error { + // determine the initial polling state based on response body and/or HTTP status + // code. this is applicable to the initial LRO response, not polling responses! + pt.Method = pt.resp.Request.Method + if err := pt.updateRawBody(); err != nil { + return err + } + switch pt.resp.StatusCode { + case http.StatusOK: + if ps := pt.getProvisioningState(); ps != nil { + pt.State = *ps + } else { + pt.State = operationSucceeded + } + case http.StatusCreated: + if ps := pt.getProvisioningState(); ps != nil { + pt.State = *ps + } else { + pt.State = operationInProgress + } + case http.StatusAccepted: + pt.State = operationInProgress + case http.StatusNoContent: + pt.State = operationSucceeded + default: + pt.State = operationFailed + pt.updateErrorFromResponse() + } + return nil } -func hasTerminated(state string) bool { - return strings.EqualFold(state, operationCanceled) || strings.EqualFold(state, operationFailed) || strings.EqualFold(state, operationSucceeded) +func (pt pollingTrackerBase) getProvisioningState() *string { + if pt.rawBody != nil && pt.rawBody["properties"] != nil { + p := pt.rawBody["properties"].(map[string]interface{}) + if ps := p["provisioningState"]; ps != nil { + s := ps.(string) + return &s + } + } + return nil } -func hasFailed(state string) bool { - return strings.EqualFold(state, operationFailed) +func (pt *pollingTrackerBase) updateRawBody() error { + pt.rawBody = map[string]interface{}{} + if pt.resp.ContentLength != 0 { + defer pt.resp.Body.Close() + b, err := ioutil.ReadAll(pt.resp.Body) + if err != nil { + return autorest.NewErrorWithError(err, "pollingTrackerBase", "updateRawBody", nil, "failed to read response body") + } + // put the body back so it's available to other callers + pt.resp.Body = ioutil.NopCloser(bytes.NewReader(b)) + if err = json.Unmarshal(b, &pt.rawBody); err != nil { + return autorest.NewErrorWithError(err, "pollingTrackerBase", "updateRawBody", nil, "failed to unmarshal response body") + } + } + return nil } -type provisioningTracker interface { - state() string - hasSucceeded() bool - hasTerminated() bool +func (pt *pollingTrackerBase) pollForStatus(sender autorest.Sender) error { + req, err := http.NewRequest(http.MethodGet, pt.URI, nil) + if err != nil { + return autorest.NewErrorWithError(err, "pollingTrackerBase", "pollForStatus", nil, "failed to create HTTP request") + } + // attach the context from the original request if available (it will be absent for deserialized futures) + if pt.resp != nil { + req = req.WithContext(pt.resp.Request.Context()) + } + pt.resp, err = sender.Do(req) + if err != nil { + return autorest.NewErrorWithError(err, "pollingTrackerBase", "pollForStatus", nil, "failed to send HTTP request") + } + if autorest.ResponseHasStatusCode(pt.resp, pollingCodes[:]...) { + // reset the service error on success case + pt.Err = nil + err = pt.updateRawBody() + } else { + // check response body for error content + pt.updateErrorFromResponse() + } + return err } -type operationResource struct { - // Note: - // The specification states services should return the "id" field. However some return it as - // "operationId". - ID string `json:"id"` - OperationID string `json:"operationId"` - Name string `json:"name"` - Status string `json:"status"` - Properties map[string]interface{} `json:"properties"` - OperationError ServiceError `json:"error"` - StartTime date.Time `json:"startTime"` - EndTime date.Time `json:"endTime"` - PercentComplete float64 `json:"percentComplete"` +// attempts to unmarshal a ServiceError type from the response body. +// if that fails then make a best attempt at creating something meaningful. +func (pt *pollingTrackerBase) updateErrorFromResponse() { + var err error + if pt.resp.ContentLength != 0 { + type respErr struct { + ServiceError *ServiceError `json:"error"` + } + re := respErr{} + defer pt.resp.Body.Close() + var b []byte + b, err = ioutil.ReadAll(pt.resp.Body) + if err != nil { + goto Default + } + if err = json.Unmarshal(b, &re); err != nil { + goto Default + } + // unmarshalling the error didn't yield anything, try unwrapped error + if re.ServiceError == nil { + err = json.Unmarshal(b, &re.ServiceError) + if err != nil { + goto Default + } + } + if re.ServiceError != nil { + pt.Err = re.ServiceError + return + } + } +Default: + se := &ServiceError{ + Code: fmt.Sprintf("HTTP status code %v", pt.resp.StatusCode), + Message: pt.resp.Status, + } + if err != nil { + se.InnerError = make(map[string]interface{}) + se.InnerError["unmarshalError"] = err.Error() + } + pt.Err = se } -func (or operationResource) state() string { - return or.Status +func (pt *pollingTrackerBase) updatePollingState(provStateApl bool) error { + if pt.Pm == PollingAsyncOperation && pt.rawBody["status"] != nil { + pt.State = pt.rawBody["status"].(string) + } else { + if pt.resp.StatusCode == http.StatusAccepted { + pt.State = operationInProgress + } else if provStateApl { + if ps := pt.getProvisioningState(); ps != nil { + pt.State = *ps + } else { + pt.State = operationSucceeded + } + } else { + return autorest.NewError("pollingTrackerBase", "updatePollingState", "the response from the async operation has an invalid status code") + } + } + // if the operation has failed update the error state + if pt.hasFailed() { + pt.updateErrorFromResponse() + } + return nil } -func (or operationResource) hasSucceeded() bool { - return hasSucceeded(or.state()) +func (pt pollingTrackerBase) pollingError() error { + if pt.Err == nil { + return nil + } + return pt.Err } -func (or operationResource) hasTerminated() bool { - return hasTerminated(or.state()) +func (pt pollingTrackerBase) pollingMethod() PollingMethodType { + return pt.Pm } -type provisioningProperties struct { - ProvisioningState string `json:"provisioningState"` +func (pt pollingTrackerBase) pollingStatus() string { + return pt.State } -type provisioningStatus struct { - Properties provisioningProperties `json:"properties,omitempty"` - ProvisioningError ServiceError `json:"error,omitempty"` +func (pt pollingTrackerBase) pollingURL() string { + return pt.URI } -func (ps provisioningStatus) state() string { - return ps.Properties.ProvisioningState +func (pt pollingTrackerBase) finalGetURL() string { + return pt.FinalGetURI } -func (ps provisioningStatus) hasSucceeded() bool { - return hasSucceeded(ps.state()) +func (pt pollingTrackerBase) hasTerminated() bool { + return strings.EqualFold(pt.State, operationCanceled) || strings.EqualFold(pt.State, operationFailed) || strings.EqualFold(pt.State, operationSucceeded) } -func (ps provisioningStatus) hasTerminated() bool { - return hasTerminated(ps.state()) +func (pt pollingTrackerBase) hasFailed() bool { + return strings.EqualFold(pt.State, operationCanceled) || strings.EqualFold(pt.State, operationFailed) } -func (ps provisioningStatus) hasProvisioningError() bool { - // code and message are required fields so only check them - return len(ps.ProvisioningError.Code) > 0 || - len(ps.ProvisioningError.Message) > 0 +func (pt pollingTrackerBase) hasSucceeded() bool { + return strings.EqualFold(pt.State, operationSucceeded) } -// PollingMethodType defines a type used for enumerating polling mechanisms. -type PollingMethodType string +func (pt pollingTrackerBase) latestResponse() *http.Response { + return pt.resp +} -const ( - // PollingAsyncOperation indicates the polling method uses the Azure-AsyncOperation header. - PollingAsyncOperation PollingMethodType = "AsyncOperation" +// error checking common to all trackers +func (pt pollingTrackerBase) baseCheckForErrors() error { + // for Azure-AsyncOperations the response body cannot be nil or empty + if pt.Pm == PollingAsyncOperation { + if pt.resp.Body == nil || pt.resp.ContentLength == 0 { + return autorest.NewError("pollingTrackerBase", "baseCheckForErrors", "for Azure-AsyncOperation response body cannot be nil") + } + if pt.rawBody["status"] == nil { + return autorest.NewError("pollingTrackerBase", "baseCheckForErrors", "missing status property in Azure-AsyncOperation response body") + } + } + return nil +} - // PollingLocation indicates the polling method uses the Location header. - PollingLocation PollingMethodType = "Location" +// DELETE - // PollingUnknown indicates an unknown polling method and is the default value. - PollingUnknown PollingMethodType = "" -) +type pollingTrackerDelete struct { + pollingTrackerBase +} -type pollingState struct { - PollingMethod PollingMethodType `json:"pollingMethod"` - URI string `json:"uri"` - State string `json:"state"` - ServiceError *ServiceError `json:"error,omitempty"` +func (pt *pollingTrackerDelete) updateHeaders() error { + // for 201 the Location header is required + if pt.resp.StatusCode == http.StatusCreated { + if lh, err := getURLFromLocationHeader(pt.resp); err != nil { + return err + } else if lh == "" { + return autorest.NewError("pollingTrackerDelete", "updateHeaders", "missing Location header in 201 response") + } else { + pt.URI = lh + } + pt.Pm = PollingLocation + pt.FinalGetURI = pt.URI + } + // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary + if pt.resp.StatusCode == http.StatusAccepted { + ao, err := getURLFromAsyncOpHeader(pt.resp) + if err != nil { + return err + } else if ao != "" { + pt.URI = ao + pt.Pm = PollingAsyncOperation + } + // if the Location header is invalid and we already have a polling URL + // then we don't care if the Location header URL is malformed. + if lh, err := getURLFromLocationHeader(pt.resp); err != nil && pt.URI == "" { + return err + } else if lh != "" { + if ao == "" { + pt.URI = lh + pt.Pm = PollingLocation + } + // when both headers are returned we use the value in the Location header for the final GET + pt.FinalGetURI = lh + } + // make sure a polling URL was found + if pt.URI == "" { + return autorest.NewError("pollingTrackerPost", "updateHeaders", "didn't get any suitable polling URLs in 202 response") + } + } + return nil } -func (ps pollingState) hasSucceeded() bool { - return hasSucceeded(ps.State) +func (pt pollingTrackerDelete) checkForErrors() error { + return pt.baseCheckForErrors() } -func (ps pollingState) hasTerminated() bool { - return hasTerminated(ps.State) +func (pt pollingTrackerDelete) provisioningStateApplicable() bool { + return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusNoContent } -func (ps pollingState) hasFailed() bool { - return hasFailed(ps.State) +// PATCH + +type pollingTrackerPatch struct { + pollingTrackerBase } -func (ps pollingState) Error() string { - s := fmt.Sprintf("Long running operation terminated with status '%s'", ps.State) - if ps.ServiceError != nil { - s = fmt.Sprintf("%s: %+v", s, *ps.ServiceError) +func (pt *pollingTrackerPatch) updateHeaders() error { + // by default we can use the original URL for polling and final GET + if pt.URI == "" { + pt.URI = pt.resp.Request.URL.String() + } + if pt.FinalGetURI == "" { + pt.FinalGetURI = pt.resp.Request.URL.String() } - return s + if pt.Pm == PollingUnknown { + pt.Pm = PollingRequestURI + } + // for 201 it's permissible for no headers to be returned + if pt.resp.StatusCode == http.StatusCreated { + if ao, err := getURLFromAsyncOpHeader(pt.resp); err != nil { + return err + } else if ao != "" { + pt.URI = ao + pt.Pm = PollingAsyncOperation + } + } + // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary + // note the absense of the "final GET" mechanism for PATCH + if pt.resp.StatusCode == http.StatusAccepted { + ao, err := getURLFromAsyncOpHeader(pt.resp) + if err != nil { + return err + } else if ao != "" { + pt.URI = ao + pt.Pm = PollingAsyncOperation + } + if ao == "" { + if lh, err := getURLFromLocationHeader(pt.resp); err != nil { + return err + } else if lh == "" { + return autorest.NewError("pollingTrackerPatch", "updateHeaders", "didn't get any suitable polling URLs in 202 response") + } else { + pt.URI = lh + pt.Pm = PollingLocation + } + } + } + return nil } -// updatePollingState maps the operation status -- retrieved from either a provisioningState -// field, the status field of an OperationResource, or inferred from the HTTP status code -- -// into a well-known states. Since the process begins from the initial request, the state -// always comes from either a the provisioningState returned or is inferred from the HTTP -// status code. Subsequent requests will read an Azure OperationResource object if the -// service initially returned the Azure-AsyncOperation header. The responseFormat field notes -// the expected response format. -func updatePollingState(resp *http.Response, ps *pollingState) error { - // Determine the response shape - // -- The first response will always be a provisioningStatus response; only the polling requests, - // depending on the header returned, may be something otherwise. - var pt provisioningTracker - if ps.PollingMethod == PollingAsyncOperation { - pt = &operationResource{} - } else { - pt = &provisioningStatus{} - } +func (pt pollingTrackerPatch) checkForErrors() error { + return pt.baseCheckForErrors() +} - // If this is the first request (that is, the polling response shape is unknown), determine how - // to poll and what to expect - if ps.PollingMethod == PollingUnknown { - req := resp.Request - if req == nil { - return autorest.NewError("azure", "updatePollingState", "Azure Polling Error - Original HTTP request is missing") - } +func (pt pollingTrackerPatch) provisioningStateApplicable() bool { + return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusCreated +} + +// POST + +type pollingTrackerPost struct { + pollingTrackerBase +} - // Prefer the Azure-AsyncOperation header - ps.URI = getAsyncOperation(resp) - if ps.URI != "" { - ps.PollingMethod = PollingAsyncOperation +func (pt *pollingTrackerPost) updateHeaders() error { + // 201 requires Location header + if pt.resp.StatusCode == http.StatusCreated { + if lh, err := getURLFromLocationHeader(pt.resp); err != nil { + return err + } else if lh == "" { + return autorest.NewError("pollingTrackerPost", "updateHeaders", "missing Location header in 201 response") } else { - ps.PollingMethod = PollingLocation + pt.URI = lh + pt.FinalGetURI = lh + pt.Pm = PollingLocation } - - // Else, use the Location header - if ps.URI == "" { - ps.URI = autorest.GetLocation(resp) + } + // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary + if pt.resp.StatusCode == http.StatusAccepted { + ao, err := getURLFromAsyncOpHeader(pt.resp) + if err != nil { + return err + } else if ao != "" { + pt.URI = ao + pt.Pm = PollingAsyncOperation } - - // Lastly, requests against an existing resource, use the last request URI - if ps.URI == "" { - m := strings.ToUpper(req.Method) - if m == http.MethodPatch || m == http.MethodPut || m == http.MethodGet { - ps.URI = req.URL.String() + // if the Location header is invalid and we already have a polling URL + // then we don't care if the Location header URL is malformed. + if lh, err := getURLFromLocationHeader(pt.resp); err != nil && pt.URI == "" { + return err + } else if lh != "" { + if ao == "" { + pt.URI = lh + pt.Pm = PollingLocation } + // when both headers are returned we use the value in the Location header for the final GET + pt.FinalGetURI = lh + } + // make sure a polling URL was found + if pt.URI == "" { + return autorest.NewError("pollingTrackerPost", "updateHeaders", "didn't get any suitable polling URLs in 202 response") } } + return nil +} - // Read and interpret the response (saving the Body in case no polling is necessary) - b := &bytes.Buffer{} - err := autorest.Respond(resp, - autorest.ByCopying(b), - autorest.ByUnmarshallingJSON(pt), - autorest.ByClosing()) - resp.Body = ioutil.NopCloser(b) - if err != nil { - return err - } +func (pt pollingTrackerPost) checkForErrors() error { + return pt.baseCheckForErrors() +} - // Interpret the results - // -- Terminal states apply regardless - // -- Unknown states are per-service inprogress states - // -- Otherwise, infer state from HTTP status code - if pt.hasTerminated() { - ps.State = pt.state() - } else if pt.state() != "" { - ps.State = operationInProgress - } else { - switch resp.StatusCode { - case http.StatusAccepted: - ps.State = operationInProgress +func (pt pollingTrackerPost) provisioningStateApplicable() bool { + return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusNoContent +} - case http.StatusNoContent, http.StatusCreated, http.StatusOK: - ps.State = operationSucceeded +// PUT - default: - ps.State = operationFailed - } - } +type pollingTrackerPut struct { + pollingTrackerBase +} - if strings.EqualFold(ps.State, operationInProgress) && ps.URI == "" { - return autorest.NewError("azure", "updatePollingState", "Azure Polling Error - Unable to obtain polling URI for %s %s", resp.Request.Method, resp.Request.URL) +func (pt *pollingTrackerPut) updateHeaders() error { + // by default we can use the original URL for polling and final GET + if pt.URI == "" { + pt.URI = pt.resp.Request.URL.String() } - - // For failed operation, check for error code and message in - // -- Operation resource - // -- Response - // -- Otherwise, Unknown - if ps.hasFailed() { - if or, ok := pt.(*operationResource); ok { - ps.ServiceError = &or.OperationError - } else if p, ok := pt.(*provisioningStatus); ok && p.hasProvisioningError() { - ps.ServiceError = &p.ProvisioningError - } else { - ps.ServiceError = &ServiceError{ - Code: "Unknown", - Message: "None", + if pt.FinalGetURI == "" { + pt.FinalGetURI = pt.resp.Request.URL.String() + } + if pt.Pm == PollingUnknown { + pt.Pm = PollingRequestURI + } + // for 201 it's permissible for no headers to be returned + if pt.resp.StatusCode == http.StatusCreated { + if ao, err := getURLFromAsyncOpHeader(pt.resp); err != nil { + return err + } else if ao != "" { + pt.URI = ao + pt.Pm = PollingAsyncOperation + } + } + // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary + if pt.resp.StatusCode == http.StatusAccepted { + ao, err := getURLFromAsyncOpHeader(pt.resp) + if err != nil { + return err + } else if ao != "" { + pt.URI = ao + pt.Pm = PollingAsyncOperation + } + // if the Location header is invalid and we already have a polling URL + // then we don't care if the Location header URL is malformed. + if lh, err := getURLFromLocationHeader(pt.resp); err != nil && pt.URI == "" { + return err + } else if lh != "" { + if ao == "" { + pt.URI = lh + pt.Pm = PollingLocation } + // when both headers are returned we use the value in the Location header for the final GET + pt.FinalGetURI = lh + } + // make sure a polling URL was found + if pt.URI == "" { + return autorest.NewError("pollingTrackerPut", "updateHeaders", "didn't get any suitable polling URLs in 202 response") } } return nil } -func newPollingRequest(ps pollingState) (*http.Request, error) { - reqPoll, err := autorest.Prepare(&http.Request{}, - autorest.AsGet(), - autorest.WithBaseURL(ps.URI)) +func (pt pollingTrackerPut) checkForErrors() error { + err := pt.baseCheckForErrors() if err != nil { - return nil, autorest.NewErrorWithError(err, "azure", "newPollingRequest", nil, "Failure creating poll request to %s", ps.URI) + return err + } + // if there are no LRO headers then the body cannot be empty + ao, err := getURLFromAsyncOpHeader(pt.resp) + if err != nil { + return err + } + lh, err := getURLFromLocationHeader(pt.resp) + if err != nil { + return err + } + if ao == "" && lh == "" && len(pt.rawBody) == 0 { + return autorest.NewError("pollingTrackerPut", "checkForErrors", "the response did not contain a body") } + return nil +} + +func (pt pollingTrackerPut) provisioningStateApplicable() bool { + return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusCreated +} + +// creates a polling tracker based on the verb of the original request +func createPollingTracker(resp *http.Response) (pollingTracker, error) { + var pt pollingTracker + switch strings.ToUpper(resp.Request.Method) { + case http.MethodDelete: + pt = &pollingTrackerDelete{pollingTrackerBase: pollingTrackerBase{resp: resp}} + case http.MethodPatch: + pt = &pollingTrackerPatch{pollingTrackerBase: pollingTrackerBase{resp: resp}} + case http.MethodPost: + pt = &pollingTrackerPost{pollingTrackerBase: pollingTrackerBase{resp: resp}} + case http.MethodPut: + pt = &pollingTrackerPut{pollingTrackerBase: pollingTrackerBase{resp: resp}} + default: + return nil, autorest.NewError("azure", "createPollingTracker", "unsupported HTTP method %s", resp.Request.Method) + } + if err := pt.initializeState(); err != nil { + return pt, err + } + // this initializes the polling header values, we do this during creation in case the + // initial response send us invalid values; this way the API call will return a non-nil + // error (not doing this means the error shows up in Future.Done) + return pt, pt.updateHeaders() +} + +// gets the polling URL from the Azure-AsyncOperation header. +// ensures the URL is well-formed and absolute. +func getURLFromAsyncOpHeader(resp *http.Response) (string, error) { + s := resp.Header.Get(http.CanonicalHeaderKey(headerAsyncOperation)) + if s == "" { + return "", nil + } + if !isValidURL(s) { + return "", autorest.NewError("azure", "getURLFromAsyncOpHeader", "invalid polling URL '%s'", s) + } + return s, nil +} - return reqPoll, nil +// gets the polling URL from the Location header. +// ensures the URL is well-formed and absolute. +func getURLFromLocationHeader(resp *http.Response) (string, error) { + s := resp.Header.Get(http.CanonicalHeaderKey(autorest.HeaderLocation)) + if s == "" { + return "", nil + } + if !isValidURL(s) { + return "", autorest.NewError("azure", "getURLFromLocationHeader", "invalid polling URL '%s'", s) + } + return s, nil +} + +// verify that the URL is valid and absolute +func isValidURL(s string) bool { + u, err := url.Parse(s) + return err == nil && u.IsAbs() } +// DoPollForAsynchronous returns a SendDecorator that polls if the http.Response is for an Azure +// long-running operation. It will delay between requests for the duration specified in the +// RetryAfter header or, if the header is absent, the passed delay. Polling may be canceled via +// the context associated with the http.Request. +// Deprecated: Prefer using Futures to allow for non-blocking async operations. +func DoPollForAsynchronous(delay time.Duration) autorest.SendDecorator { + return func(s autorest.Sender) autorest.Sender { + return autorest.SenderFunc(func(r *http.Request) (*http.Response, error) { + resp, err := s.Do(r) + if err != nil { + return resp, err + } + if !autorest.ResponseHasStatusCode(resp, pollingCodes[:]...) { + return resp, nil + } + future, err := NewFutureFromResponse(resp) + if err != nil { + return resp, err + } + // retry until either the LRO completes or we receive an error + var done bool + for done, err = future.Done(s); !done && err == nil; done, err = future.Done(s) { + // check for Retry-After delay, if not present use the specified polling delay + if pd, ok := future.GetPollingDelay(); ok { + delay = pd + } + // wait until the delay elapses or the context is cancelled + if delayElapsed := autorest.DelayForBackoff(delay, 0, r.Context().Done()); !delayElapsed { + return future.Response(), + autorest.NewErrorWithError(r.Context().Err(), "azure", "DoPollForAsynchronous", future.Response(), "context has been cancelled") + } + } + return future.Response(), err + }) + } +} + +// PollingMethodType defines a type used for enumerating polling mechanisms. +type PollingMethodType string + +const ( + // PollingAsyncOperation indicates the polling method uses the Azure-AsyncOperation header. + PollingAsyncOperation PollingMethodType = "AsyncOperation" + + // PollingLocation indicates the polling method uses the Location header. + PollingLocation PollingMethodType = "Location" + + // PollingRequestURI indicates the polling method uses the original request URI. + PollingRequestURI PollingMethodType = "RequestURI" + + // PollingUnknown indicates an unknown polling method and is the default value. + PollingUnknown PollingMethodType = "" +) + // AsyncOpIncompleteError is the type that's returned from a future that has not completed. type AsyncOpIncompleteError struct { // FutureType is the name of the type composed of a azure.Future. diff --git a/vendor/github.com/Azure/go-autorest/autorest/azure/azure.go b/vendor/github.com/Azure/go-autorest/autorest/azure/azure.go index 18d029526..a702ffe75 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/azure/azure.go +++ b/vendor/github.com/Azure/go-autorest/autorest/azure/azure.go @@ -279,16 +279,29 @@ func WithErrorUnlessStatusCode(codes ...int) autorest.RespondDecorator { resp.Body = ioutil.NopCloser(&b) if decodeErr != nil { return fmt.Errorf("autorest/azure: error response cannot be parsed: %q error: %v", b.String(), decodeErr) - } else if e.ServiceError == nil { + } + if e.ServiceError == nil { // Check if error is unwrapped ServiceError - if err := json.Unmarshal(b.Bytes(), &e.ServiceError); err != nil || e.ServiceError.Message == "" { - e.ServiceError = &ServiceError{ - Code: "Unknown", - Message: "Unknown service error", - } + if err := json.Unmarshal(b.Bytes(), &e.ServiceError); err != nil { + return err } } - + if e.ServiceError.Message == "" { + // if we're here it means the returned error wasn't OData v4 compliant. + // try to unmarshal the body as raw JSON in hopes of getting something. + rawBody := map[string]interface{}{} + if err := json.Unmarshal(b.Bytes(), &rawBody); err != nil { + return err + } + e.ServiceError = &ServiceError{ + Code: "Unknown", + Message: "Unknown service error", + } + if len(rawBody) > 0 { + e.ServiceError.Details = []map[string]interface{}{rawBody} + } + } + e.Response = resp e.RequestID = ExtractRequestID(resp) if e.StatusCode == nil { e.StatusCode = resp.StatusCode diff --git a/vendor/github.com/Azure/go-autorest/autorest/azure/rp.go b/vendor/github.com/Azure/go-autorest/autorest/azure/rp.go index 2c88d60ba..bd34f0ed5 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/azure/rp.go +++ b/vendor/github.com/Azure/go-autorest/autorest/azure/rp.go @@ -64,7 +64,7 @@ func DoRetryWithRegistration(client autorest.Client) autorest.SendDecorator { } } } - return resp, fmt.Errorf("failed request: %s", err) + return resp, err }) } } @@ -115,7 +115,7 @@ func register(client autorest.Client, originalReq *http.Request, re RequestError if err != nil { return err } - req.Cancel = originalReq.Cancel + req = req.WithContext(originalReq.Context()) resp, err := autorest.SendWithSender(client, req, autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...), @@ -154,7 +154,7 @@ func register(client autorest.Client, originalReq *http.Request, re RequestError if err != nil { return err } - req.Cancel = originalReq.Cancel + req = req.WithContext(originalReq.Context()) resp, err := autorest.SendWithSender(client, req, autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...), @@ -178,9 +178,9 @@ func register(client autorest.Client, originalReq *http.Request, re RequestError break } - delayed := autorest.DelayWithRetryAfter(resp, originalReq.Cancel) - if !delayed { - autorest.DelayForBackoff(client.PollingDelay, 0, originalReq.Cancel) + delayed := autorest.DelayWithRetryAfter(resp, originalReq.Context().Done()) + if !delayed && !autorest.DelayForBackoff(client.PollingDelay, 0, originalReq.Context().Done()) { + return originalReq.Context().Err() } } if !(time.Since(now) < client.PollingDuration) { diff --git a/vendor/github.com/Azure/go-autorest/autorest/date/date.go b/vendor/github.com/Azure/go-autorest/autorest/date/date.go index 80ca60e9b..c45710656 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/date/date.go +++ b/vendor/github.com/Azure/go-autorest/autorest/date/date.go @@ -5,6 +5,20 @@ time.Time types. And both convert to time.Time through a ToTime method. */ package date +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + import ( "fmt" "time" diff --git a/vendor/github.com/Azure/go-autorest/autorest/date/time.go b/vendor/github.com/Azure/go-autorest/autorest/date/time.go index c1af62963..b453fad04 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/date/time.go +++ b/vendor/github.com/Azure/go-autorest/autorest/date/time.go @@ -1,5 +1,19 @@ package date +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + import ( "regexp" "time" diff --git a/vendor/github.com/Azure/go-autorest/autorest/date/timerfc1123.go b/vendor/github.com/Azure/go-autorest/autorest/date/timerfc1123.go index 11995fb9f..48fb39ba9 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/date/timerfc1123.go +++ b/vendor/github.com/Azure/go-autorest/autorest/date/timerfc1123.go @@ -1,5 +1,19 @@ package date +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + import ( "errors" "time" diff --git a/vendor/github.com/Azure/go-autorest/autorest/date/unixtime.go b/vendor/github.com/Azure/go-autorest/autorest/date/unixtime.go index e085c77ee..7073959b2 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/date/unixtime.go +++ b/vendor/github.com/Azure/go-autorest/autorest/date/unixtime.go @@ -1,5 +1,19 @@ package date +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + import ( "bytes" "encoding/binary" diff --git a/vendor/github.com/Azure/go-autorest/autorest/date/utility.go b/vendor/github.com/Azure/go-autorest/autorest/date/utility.go index 207b1a240..12addf0eb 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/date/utility.go +++ b/vendor/github.com/Azure/go-autorest/autorest/date/utility.go @@ -1,5 +1,19 @@ package date +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + import ( "strings" "time" diff --git a/vendor/github.com/Azure/go-autorest/autorest/sender.go b/vendor/github.com/Azure/go-autorest/autorest/sender.go index c5efd59a2..cacbd8157 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/sender.go +++ b/vendor/github.com/Azure/go-autorest/autorest/sender.go @@ -86,7 +86,7 @@ func SendWithSender(s Sender, r *http.Request, decorators ...SendDecorator) (*ht func AfterDelay(d time.Duration) SendDecorator { return func(s Sender) Sender { return SenderFunc(func(r *http.Request) (*http.Response, error) { - if !DelayForBackoff(d, 0, r.Cancel) { + if !DelayForBackoff(d, 0, r.Context().Done()) { return nil, fmt.Errorf("autorest: AfterDelay canceled before full delay") } return s.Do(r) @@ -165,7 +165,7 @@ func DoPollForStatusCodes(duration time.Duration, delay time.Duration, codes ... resp, err = s.Do(r) if err == nil && ResponseHasStatusCode(resp, codes...) { - r, err = NewPollingRequest(resp, r.Cancel) + r, err = NewPollingRequestWithContext(r.Context(), resp) for err == nil && ResponseHasStatusCode(resp, codes...) { Respond(resp, @@ -198,7 +198,9 @@ func DoRetryForAttempts(attempts int, backoff time.Duration) SendDecorator { if err == nil { return resp, err } - DelayForBackoff(backoff, attempt, r.Cancel) + if !DelayForBackoff(backoff, attempt, r.Context().Done()) { + return nil, r.Context().Err() + } } return resp, err }) @@ -221,14 +223,18 @@ func DoRetryForStatusCodes(attempts int, backoff time.Duration, codes ...int) Se return resp, err } resp, err = s.Do(rr.Request()) + // if the error isn't temporary don't bother retrying + if err != nil && !IsTemporaryNetworkError(err) { + return nil, err + } // we want to retry if err is not nil (e.g. transient network failure). note that for failed authentication // resp and err will both have a value, so in this case we don't want to retry as it will never succeed. if err == nil && !ResponseHasStatusCode(resp, codes...) || IsTokenRefreshError(err) { return resp, err } - delayed := DelayWithRetryAfter(resp, r.Cancel) - if !delayed { - DelayForBackoff(backoff, attempt, r.Cancel) + delayed := DelayWithRetryAfter(resp, r.Context().Done()) + if !delayed && !DelayForBackoff(backoff, attempt, r.Context().Done()) { + return nil, r.Context().Err() } // don't count a 429 against the number of attempts // so that we continue to retry until it succeeds @@ -277,7 +283,9 @@ func DoRetryForDuration(d time.Duration, backoff time.Duration) SendDecorator { if err == nil { return resp, err } - DelayForBackoff(backoff, attempt, r.Cancel) + if !DelayForBackoff(backoff, attempt, r.Context().Done()) { + return nil, r.Context().Err() + } } return resp, err }) diff --git a/vendor/github.com/Azure/go-autorest/autorest/to/convert.go b/vendor/github.com/Azure/go-autorest/autorest/to/convert.go index 7b180b866..fdda2ce1a 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/to/convert.go +++ b/vendor/github.com/Azure/go-autorest/autorest/to/convert.go @@ -3,6 +3,20 @@ Package to provides helpers to ease working with pointer values of marshalled st */ package to +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + // String returns a string value for the passed string pointer. It returns the empty string if the // pointer is nil. func String(s *string) string { diff --git a/vendor/github.com/Azure/go-autorest/autorest/utility.go b/vendor/github.com/Azure/go-autorest/autorest/utility.go index afb3e4e16..bfddd90b5 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/utility.go +++ b/vendor/github.com/Azure/go-autorest/autorest/utility.go @@ -20,6 +20,7 @@ import ( "encoding/xml" "fmt" "io" + "net" "net/http" "net/url" "reflect" @@ -216,3 +217,12 @@ func IsTokenRefreshError(err error) bool { } return false } + +// IsTemporaryNetworkError returns true if the specified error is a temporary network error or false +// if it's not. If the error doesn't implement the net.Error interface the return value is true. +func IsTemporaryNetworkError(err error) bool { + if netErr, ok := err.(net.Error); !ok || (ok && netErr.Temporary()) { + return true + } + return false +} diff --git a/vendor/github.com/Azure/go-autorest/autorest/validation/validation.go b/vendor/github.com/Azure/go-autorest/autorest/validation/validation.go index d886e0b3f..ae987f8fa 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/validation/validation.go +++ b/vendor/github.com/Azure/go-autorest/autorest/validation/validation.go @@ -136,29 +136,29 @@ func validatePtr(x reflect.Value, v Constraint) error { func validateInt(x reflect.Value, v Constraint) error { i := x.Int() - r, ok := v.Rule.(int) + r, ok := toInt64(v.Rule) if !ok { return createError(x, v, fmt.Sprintf("rule must be integer value for %v constraint; got: %v", v.Name, v.Rule)) } switch v.Name { case MultipleOf: - if i%int64(r) != 0 { + if i%r != 0 { return createError(x, v, fmt.Sprintf("value must be a multiple of %v", r)) } case ExclusiveMinimum: - if i <= int64(r) { + if i <= r { return createError(x, v, fmt.Sprintf("value must be greater than %v", r)) } case ExclusiveMaximum: - if i >= int64(r) { + if i >= r { return createError(x, v, fmt.Sprintf("value must be less than %v", r)) } case InclusiveMinimum: - if i < int64(r) { + if i < r { return createError(x, v, fmt.Sprintf("value must be greater than or equal to %v", r)) } case InclusiveMaximum: - if i > int64(r) { + if i > r { return createError(x, v, fmt.Sprintf("value must be less than or equal to %v", r)) } default: @@ -388,6 +388,17 @@ func createError(x reflect.Value, v Constraint, err string) error { v.Target, v.Name, getInterfaceValue(x), err) } +func toInt64(v interface{}) (int64, bool) { + if i64, ok := v.(int64); ok { + return i64, true + } + // older generators emit max constants as int, so if int64 fails fall back to int + if i32, ok := v.(int); ok { + return int64(i32), true + } + return 0, false +} + // NewErrorWithValidationError appends package type and method name in // validation error. // diff --git a/vendor/github.com/Azure/go-autorest/autorest/version.go b/vendor/github.com/Azure/go-autorest/autorest/version.go index f900ef8b5..e32cd68fe 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/version.go +++ b/vendor/github.com/Azure/go-autorest/autorest/version.go @@ -16,5 +16,5 @@ package autorest // Version returns the semantic version (see http://semver.org). func Version() string { - return "v10.3.0" + return "v10.12.0" } diff --git a/vendor/vendor.json b/vendor/vendor.json index c519a1e0e..203593780 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -65,56 +65,56 @@ "versionExact": "v17.3.1" }, { - "checksumSHA1": "+P6HOINDh/n2z4GqEkluzuGP5p0=", + "checksumSHA1": "q3bhLdUVAz5dmDFvfKJJSTd/BaE=", "comment": "v7.0.7", "path": "github.com/Azure/go-autorest/autorest", - "revision": "ed4b7f5bf1ec0c9ede1fda2681d96771282f2862", - "revisionTime": "2018-03-26T17:06:54Z", - "version": "v10.4.0", - "versionExact": "v10.4.0" + "revision": "1f7cd6cfe0adea687ad44a512dfe76140f804318", + "revisionTime": "2018-06-28T21:22:21Z", + "version": "v10.12.0", + "versionExact": "v10.12.0" }, { - "checksumSHA1": "HzA52MbMWnsR31CFrub5biN90/Q=", + "checksumSHA1": "Ygkay0Aq7zeg8A9DSRWI+flRcFk=", "path": "github.com/Azure/go-autorest/autorest/adal", - "revision": "ed4b7f5bf1ec0c9ede1fda2681d96771282f2862", - "revisionTime": "2018-03-26T17:06:54Z", - "version": "v10.4.0", - "versionExact": "v10.4.0" + "revision": "1f7cd6cfe0adea687ad44a512dfe76140f804318", + "revisionTime": "2018-06-28T21:22:21Z", + "version": "v10.12.0", + "versionExact": "v10.12.0" }, { - "checksumSHA1": "bDFbLGwpCT8TRmqEKtPY/U1DAY8=", + "checksumSHA1": "O8hvZ5SZ6o/mGOTivIagc1uPdl4=", "comment": "v7.0.7", "path": "github.com/Azure/go-autorest/autorest/azure", - "revision": "ed4b7f5bf1ec0c9ede1fda2681d96771282f2862", - "revisionTime": "2018-03-26T17:06:54Z", - "version": "v10.4.0", - "versionExact": "v10.4.0" + "revision": "1f7cd6cfe0adea687ad44a512dfe76140f804318", + "revisionTime": "2018-06-28T21:22:21Z", + "version": "v10.12.0", + "versionExact": "v10.12.0" }, { - "checksumSHA1": "LSF/pNrjhIxl6jiS6bKooBFCOxI=", + "checksumSHA1": "KcxXWvXhVIkfqdGR+TQqWyCbgZk=", "comment": "v7.0.7", "path": "github.com/Azure/go-autorest/autorest/date", - "revision": "58f6f26e200fa5dfb40c9cd1c83f3e2c860d779d", - "revisionTime": "2017-04-28T17:52:31Z", - "version": "=v8.0.0", - "versionExact": "v8.0.0" + "revision": "1f7cd6cfe0adea687ad44a512dfe76140f804318", + "revisionTime": "2018-06-28T21:22:21Z", + "version": "v10.12.0", + "versionExact": "v10.12.0" }, { - "checksumSHA1": "Ev8qCsbFjDlMlX0N2tYAhYQFpUc=", + "checksumSHA1": "8XiRKZWE6XGsb0eJfG4P98JwaXw=", "comment": "v7.0.7", "path": "github.com/Azure/go-autorest/autorest/to", - "revision": "58f6f26e200fa5dfb40c9cd1c83f3e2c860d779d", - "revisionTime": "2017-04-28T17:52:31Z", - "version": "=v8.0.0", - "versionExact": "v8.0.0" + "revision": "1f7cd6cfe0adea687ad44a512dfe76140f804318", + "revisionTime": "2018-06-28T21:22:21Z", + "version": "v10.12.0", + "versionExact": "v10.12.0" }, { - "checksumSHA1": "CdDkG+J8wqXQVQ0f0xal+eolB1w=", + "checksumSHA1": "SVU/fTZ9osBm+WPdd5y71RabAzE=", "path": "github.com/Azure/go-autorest/autorest/validation", - "revision": "ed4b7f5bf1ec0c9ede1fda2681d96771282f2862", - "revisionTime": "2018-03-26T17:06:54Z", - "version": "v10.4.0", - "versionExact": "v10.4.0" + "revision": "1f7cd6cfe0adea687ad44a512dfe76140f804318", + "revisionTime": "2018-06-28T21:22:21Z", + "version": "v10.12.0", + "versionExact": "v10.12.0" }, { "checksumSHA1": "TgrN0l/E16deTlLYNt8wf66urSU=",