diff --git a/builder/azure/common/client/config.go b/builder/azure/common/client/config.go index 9a4a736e8..230188eb5 100644 --- a/builder/azure/common/client/config.go +++ b/builder/azure/common/client/config.go @@ -198,60 +198,64 @@ func (c Config) UseMSI() bool { c.TenantID == "" } -func (c Config) GetServicePrincipalTokens( - say func(string)) ( +func (c Config) GetServicePrincipalTokens(say func(string)) ( servicePrincipalToken *adal.ServicePrincipalToken, servicePrincipalTokenVault *adal.ServicePrincipalToken, err error) { - tenantID := c.TenantID + servicePrincipalToken, err = c.GetServicePrincipalToken(say, + c.CloudEnvironment.ResourceManagerEndpoint) + if err != nil { + return nil, nil, err + } + servicePrincipalTokenVault, err = c.GetServicePrincipalToken(say, + strings.TrimRight(c.CloudEnvironment.KeyVaultEndpoint, "/")) + if err != nil { + return nil, nil, err + } + return servicePrincipalToken, servicePrincipalTokenVault, nil +} + +func (c Config) GetServicePrincipalToken( + say func(string), forResource string) ( + servicePrincipalToken *adal.ServicePrincipalToken, + err error) { var auth oAuthTokenProvider switch c.authType { case authTypeDeviceLogin: say("Getting tokens using device flow") - auth = NewDeviceFlowOAuthTokenProvider(*c.CloudEnvironment, say, tenantID) + auth = NewDeviceFlowOAuthTokenProvider(*c.CloudEnvironment, say, c.TenantID) case authTypeMSI: say("Getting tokens using Managed Identity for Azure") auth = NewMSIOAuthTokenProvider(*c.CloudEnvironment) case authTypeClientSecret: say("Getting tokens using client secret") - auth = NewSecretOAuthTokenProvider(*c.CloudEnvironment, c.ClientID, c.ClientSecret, tenantID) + auth = NewSecretOAuthTokenProvider(*c.CloudEnvironment, c.ClientID, c.ClientSecret, c.TenantID) case authTypeClientCert: say("Getting tokens using client certificate") - auth, err = NewCertOAuthTokenProvider(*c.CloudEnvironment, c.ClientID, c.ClientCertPath, tenantID) + auth, err = NewCertOAuthTokenProvider(*c.CloudEnvironment, c.ClientID, c.ClientCertPath, c.TenantID) if err != nil { - return nil, nil, err + return nil, err } case authTypeClientBearerJWT: say("Getting tokens using client bearer JWT") - auth = NewJWTOAuthTokenProvider(*c.CloudEnvironment, c.ClientID, c.ClientJWT, tenantID) + auth = NewJWTOAuthTokenProvider(*c.CloudEnvironment, c.ClientID, c.ClientJWT, c.TenantID) default: panic("authType not set, call FillParameters, or set explicitly") } - servicePrincipalToken, err = auth.getServicePrincipalToken() + servicePrincipalToken, err = auth.getServicePrincipalTokenWithResource(forResource) if err != nil { - return nil, nil, err + return nil, err } err = servicePrincipalToken.EnsureFresh() if err != nil { - return nil, nil, err - } - - servicePrincipalTokenVault, err = auth.getServicePrincipalTokenWithResource( - strings.TrimRight(c.CloudEnvironment.KeyVaultEndpoint, "/")) - if err != nil { - return nil, nil, err - } - - err = servicePrincipalTokenVault.EnsureFresh() - if err != nil { - return nil, nil, err + return nil, err } - return servicePrincipalToken, servicePrincipalTokenVault, nil + return servicePrincipalToken, nil } // FillParameters capture the user intent from the supplied parameter set in authType, retrieves the TenantID and CloudEnvironment if not specified.