mirror of https://github.com/hashicorp/packer
parent
9ddec470a0
commit
7e34579b7e
@ -0,0 +1,36 @@
|
||||
package arm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Azure/go-autorest/autorest/adal"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
packerAzureCommon "github.com/hashicorp/packer/builder/azure/common"
|
||||
)
|
||||
|
||||
func NewDeviceFlowOAuthTokenProvider(env azure.Environment, say func(string), tenantID string) oAuthTokenProvider {
|
||||
return &deviceflowOauthTokenProvider{}
|
||||
}
|
||||
|
||||
type deviceflowOauthTokenProvider struct {
|
||||
env azure.Environment
|
||||
say func(string)
|
||||
tenantID string
|
||||
}
|
||||
|
||||
func (tp *deviceflowOauthTokenProvider) getServicePrincipalToken() (*adal.ServicePrincipalToken, error) {
|
||||
return tp.getServicePrincipalTokenWithResource(tp.env.ResourceManagerEndpoint)
|
||||
}
|
||||
|
||||
func (tp *deviceflowOauthTokenProvider) getServicePrincipalTokenWithResource(resource string) (*adal.ServicePrincipalToken, error) {
|
||||
if resource == tp.env.ServiceManagementEndpoint {
|
||||
tp.say("Getting auth token for Service management endpoint")
|
||||
} else if resource == strings.TrimRight(tp.env.KeyVaultEndpoint, "/") {
|
||||
tp.say("Getting token for Vault resource")
|
||||
} else {
|
||||
tp.say(fmt.Sprintf("Getting token for %s", resource))
|
||||
}
|
||||
|
||||
return packerAzureCommon.Authenticate(tp.env, tp.tenantID, tp.say, resource)
|
||||
}
|
||||
@ -0,0 +1,23 @@
|
||||
package arm
|
||||
|
||||
import (
|
||||
"github.com/Azure/go-autorest/autorest/adal"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
)
|
||||
|
||||
// for managed identity auth
|
||||
type msiOAuthTokenProvider struct {
|
||||
env azure.Environment
|
||||
}
|
||||
|
||||
func NewMSIOAuthTokenProvider(env azure.Environment) oAuthTokenProvider {
|
||||
return &msiOAuthTokenProvider{env}
|
||||
}
|
||||
|
||||
func (tp *msiOAuthTokenProvider) getServicePrincipalToken() (*adal.ServicePrincipalToken, error) {
|
||||
return tp.getServicePrincipalTokenWithResource(tp.env.ResourceManagerEndpoint)
|
||||
}
|
||||
|
||||
func (tp *msiOAuthTokenProvider) getServicePrincipalTokenWithResource(resource string) (*adal.ServicePrincipalToken, error) {
|
||||
return adal.NewServicePrincipalTokenFromMSI("http://169.254.169.254/metadata/identity/oauth2/token", resource)
|
||||
}
|
||||
@ -0,0 +1,35 @@
|
||||
package arm
|
||||
|
||||
import (
|
||||
"github.com/Azure/go-autorest/autorest/adal"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
)
|
||||
|
||||
// for clientID/secret auth
|
||||
type secretOAuthTokenProvider struct {
|
||||
env azure.Environment
|
||||
clientID, clientSecret, tenantID string
|
||||
}
|
||||
|
||||
func NewSecretOAuthTokenProvider(env azure.Environment, clientID, clientSecret, tenantID string) oAuthTokenProvider {
|
||||
return &secretOAuthTokenProvider{env, clientID, clientSecret, tenantID}
|
||||
}
|
||||
|
||||
func (tp *secretOAuthTokenProvider) getServicePrincipalToken() (*adal.ServicePrincipalToken, error) {
|
||||
return tp.getServicePrincipalTokenWithResource(tp.env.ResourceManagerEndpoint)
|
||||
}
|
||||
|
||||
func (tp *secretOAuthTokenProvider) getServicePrincipalTokenWithResource(resource string) (*adal.ServicePrincipalToken, error) {
|
||||
oauthConfig, err := adal.NewOAuthConfig(tp.env.ActiveDirectoryEndpoint, tp.tenantID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
spt, err := adal.NewServicePrincipalToken(
|
||||
*oauthConfig,
|
||||
tp.clientID,
|
||||
tp.clientSecret,
|
||||
resource)
|
||||
|
||||
return spt, err
|
||||
}
|
||||
@ -0,0 +1,165 @@
|
||||
package arm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Azure/go-autorest/autorest/adal"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"github.com/hashicorp/packer/packer"
|
||||
)
|
||||
|
||||
// ClientConfig allows for various ways to authenticate Azure clients
|
||||
type ClientConfig struct {
|
||||
// Describes where API's are
|
||||
|
||||
CloudEnvironmentName string `mapstructure:"cloud_environment_name"`
|
||||
cloudEnvironment *azure.Environment
|
||||
|
||||
// Authentication fields
|
||||
|
||||
// Client ID
|
||||
ClientID string `mapstructure:"client_id"`
|
||||
// Client secret/password
|
||||
ClientSecret string `mapstructure:"client_secret"`
|
||||
ObjectID string `mapstructure:"object_id"`
|
||||
TenantID string `mapstructure:"tenant_id"`
|
||||
SubscriptionID string `mapstructure:"subscription_id"`
|
||||
}
|
||||
|
||||
const DefaultCloudEnvironmentName = "Public"
|
||||
|
||||
func (c *ClientConfig) provideDefaultValues() {
|
||||
if c.CloudEnvironmentName == "" {
|
||||
c.CloudEnvironmentName = DefaultCloudEnvironmentName
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientConfig) setCloudEnvironment() error {
|
||||
lookup := map[string]string{
|
||||
"CHINA": "AzureChinaCloud",
|
||||
"CHINACLOUD": "AzureChinaCloud",
|
||||
"AZURECHINACLOUD": "AzureChinaCloud",
|
||||
|
||||
"GERMAN": "AzureGermanCloud",
|
||||
"GERMANCLOUD": "AzureGermanCloud",
|
||||
"AZUREGERMANCLOUD": "AzureGermanCloud",
|
||||
|
||||
"GERMANY": "AzureGermanCloud",
|
||||
"GERMANYCLOUD": "AzureGermanCloud",
|
||||
"AZUREGERMANYCLOUD": "AzureGermanCloud",
|
||||
|
||||
"PUBLIC": "AzurePublicCloud",
|
||||
"PUBLICCLOUD": "AzurePublicCloud",
|
||||
"AZUREPUBLICCLOUD": "AzurePublicCloud",
|
||||
|
||||
"USGOVERNMENT": "AzureUSGovernmentCloud",
|
||||
"USGOVERNMENTCLOUD": "AzureUSGovernmentCloud",
|
||||
"AZUREUSGOVERNMENTCLOUD": "AzureUSGovernmentCloud",
|
||||
}
|
||||
|
||||
name := strings.ToUpper(c.CloudEnvironmentName)
|
||||
envName, ok := lookup[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("There is no cloud environment matching the name '%s'!", c.CloudEnvironmentName)
|
||||
}
|
||||
|
||||
env, err := azure.EnvironmentFromName(envName)
|
||||
c.cloudEnvironment = &env
|
||||
return err
|
||||
}
|
||||
|
||||
func (c ClientConfig) assertRequiredParametersSet(errs *packer.MultiError) {
|
||||
/////////////////////////////////////////////
|
||||
// Authentication via OAUTH
|
||||
|
||||
// Check if device login is being asked for, and is allowed.
|
||||
//
|
||||
// Device login is enabled if the user only defines SubscriptionID and not
|
||||
// ClientID, ClientSecret, and TenantID.
|
||||
//
|
||||
// Device login is not enabled for Windows because the WinRM certificate is
|
||||
// readable by the ObjectID of the App. There may be another way to handle
|
||||
// this case, but I am not currently aware of it - send feedback.
|
||||
|
||||
if c.useMSI() {
|
||||
return
|
||||
}
|
||||
|
||||
if c.SubscriptionID == "" {
|
||||
errs = packer.MultiErrorAppend(errs, fmt.Errorf("A subscription_id must be specified"))
|
||||
}
|
||||
|
||||
if c.useDeviceLogin() {
|
||||
return
|
||||
}
|
||||
|
||||
if c.SubscriptionID != "" && c.ClientID != "" && c.ClientSecret != "" {
|
||||
// Service principal using secret
|
||||
return
|
||||
}
|
||||
|
||||
errs = packer.MultiErrorAppend(errs, fmt.Errorf("No valid set of authentication values specified:\n"+
|
||||
"* to use the Managed Identity of teh current machine, do not specify any of the fields below\n"+
|
||||
"* to use interactive user authentication, specify only subscription_id\n"+
|
||||
"* to use an Azure Active Directory service principal, specify subscription_id, client_id and client_secret."))
|
||||
}
|
||||
|
||||
func (c ClientConfig) useDeviceLogin() bool {
|
||||
return c.SubscriptionID != "" &&
|
||||
c.ClientID == "" &&
|
||||
c.ClientSecret == "" &&
|
||||
c.TenantID == ""
|
||||
}
|
||||
|
||||
func (c ClientConfig) useMSI() bool {
|
||||
return c.SubscriptionID == "" &&
|
||||
c.ClientID == "" &&
|
||||
c.ClientSecret == "" &&
|
||||
c.TenantID == ""
|
||||
}
|
||||
|
||||
func (c ClientConfig) getServicePrincipalTokens(
|
||||
say func(string)) (
|
||||
servicePrincipalToken *adal.ServicePrincipalToken,
|
||||
servicePrincipalTokenVault *adal.ServicePrincipalToken,
|
||||
err error) {
|
||||
|
||||
tenantID := c.TenantID
|
||||
|
||||
var auth oAuthTokenProvider
|
||||
|
||||
if c.useDeviceLogin() {
|
||||
say("Getting tokens using device flow")
|
||||
auth = NewDeviceFlowOAuthTokenProvider(*c.cloudEnvironment, say, tenantID)
|
||||
} else if c.useMSI() {
|
||||
say("Getting tokens using Managed Identity for Azure")
|
||||
auth = NewMSIOAuthTokenProvider(*c.cloudEnvironment)
|
||||
} else {
|
||||
say("Getting tokens using client secret")
|
||||
auth = NewSecretOAuthTokenProvider(*c.cloudEnvironment, c.ClientID, c.ClientSecret, tenantID)
|
||||
}
|
||||
|
||||
servicePrincipalToken, err = auth.getServicePrincipalToken()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = servicePrincipalToken.EnsureFresh()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
servicePrincipalTokenVault, err = auth.getServicePrincipalTokenWithResource(
|
||||
strings.TrimRight(c.cloudEnvironment.KeyVaultEndpoint, "/"))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = servicePrincipalTokenVault.EnsureFresh()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return servicePrincipalToken, servicePrincipalTokenVault, nil
|
||||
}
|
||||
@ -0,0 +1,199 @@
|
||||
package arm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"github.com/hashicorp/packer/packer"
|
||||
)
|
||||
|
||||
func Test_ClientConfig_RequiredParametersSet(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config ClientConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no client_id, client_secret or subscription_id should enable MSI auth",
|
||||
config: ClientConfig{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "subscription_id is set will trigger device flow",
|
||||
config: ClientConfig{
|
||||
SubscriptionID: "error",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "client_id without client_secret should error",
|
||||
config: ClientConfig{
|
||||
ClientID: "error",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "client_secret without client_id should error",
|
||||
config: ClientConfig{
|
||||
ClientSecret: "error",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing subscription_id when using secret",
|
||||
config: ClientConfig{
|
||||
ClientID: "ok",
|
||||
ClientSecret: "ok",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "tenant_id alone should fail",
|
||||
config: ClientConfig{
|
||||
TenantID: "ok",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
errs := &packer.MultiError{}
|
||||
tt.config.assertRequiredParametersSet(errs)
|
||||
if (len(errs.Errors) != 0) != tt.wantErr {
|
||||
t.Errorf("newConfig() error = %v, wantErr %v", errs, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ClientConfig_DeviceLogin(t *testing.T) {
|
||||
getEnvOrSkip(t, "AZURE_DEVICE_LOGIN")
|
||||
cfg := ClientConfig{
|
||||
SubscriptionID: getEnvOrSkip(t, "AZURE_SUBSCRIPTION"),
|
||||
cloudEnvironment: getCloud(),
|
||||
}
|
||||
assertValid(t, cfg)
|
||||
|
||||
spt, sptkv, err := cfg.getServicePrincipalTokens(
|
||||
func(s string) { fmt.Printf("SAY: %s\n", s) })
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil err, but got: %v", err)
|
||||
}
|
||||
token := spt.Token()
|
||||
if token.AccessToken == "" {
|
||||
t.Fatal("Expected management token to have non-nil access token")
|
||||
}
|
||||
if token.RefreshToken == "" {
|
||||
t.Fatal("Expected management token to have non-nil refresh token")
|
||||
}
|
||||
kvtoken := sptkv.Token()
|
||||
if kvtoken.AccessToken == "" {
|
||||
t.Fatal("Expected keyvault token to have non-nil access token")
|
||||
}
|
||||
if kvtoken.RefreshToken == "" {
|
||||
t.Fatal("Expected keyvault token to have non-nil refresh token")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ClientConfig_ClientPassword(t *testing.T) {
|
||||
cfg := ClientConfig{
|
||||
SubscriptionID: getEnvOrSkip(t, "AZURE_SUBSCRIPTION"),
|
||||
ClientID: getEnvOrSkip(t, "AZURE_CLIENTID"),
|
||||
ClientSecret: getEnvOrSkip(t, "AZURE_CLIENTSECRET"),
|
||||
TenantID: getEnvOrSkip(t, "AZURE_TENANTID"),
|
||||
cloudEnvironment: getCloud(),
|
||||
}
|
||||
assertValid(t, cfg)
|
||||
|
||||
spt, sptkv, err := cfg.getServicePrincipalTokens(func(s string) { fmt.Printf("SAY: %s\n", s) })
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil err, but got: %v", err)
|
||||
}
|
||||
token := spt.Token()
|
||||
if token.AccessToken == "" {
|
||||
t.Fatal("Expected management token to have non-nil access token")
|
||||
}
|
||||
if token.RefreshToken != "" {
|
||||
t.Fatal("Expected management token to have no refresh token")
|
||||
}
|
||||
kvtoken := sptkv.Token()
|
||||
if kvtoken.AccessToken == "" {
|
||||
t.Fatal("Expected keyvault token to have non-nil access token")
|
||||
}
|
||||
if kvtoken.RefreshToken != "" {
|
||||
t.Fatal("Expected keyvault token to have no refresh token")
|
||||
}
|
||||
}
|
||||
|
||||
func getEnvOrSkip(t *testing.T, envVar string) string {
|
||||
v := os.Getenv(envVar)
|
||||
if v == "" {
|
||||
t.Skipf("%s is empty, skipping", envVar)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func getCloud() *azure.Environment {
|
||||
cloudName := os.Getenv("AZURE_CLOUD")
|
||||
if cloudName == "" {
|
||||
cloudName = "AZUREPUBLICCLOUD"
|
||||
}
|
||||
c, _ := azure.EnvironmentFromName(cloudName)
|
||||
return &c
|
||||
}
|
||||
|
||||
// tests for assertRequiredParametersSet
|
||||
|
||||
func Test_ClientConfig_CanUseDeviceCode(t *testing.T) {
|
||||
cfg := emptyClientConfig()
|
||||
cfg.SubscriptionID = "12345"
|
||||
// TenantID is optional
|
||||
|
||||
assertValid(t, cfg)
|
||||
}
|
||||
|
||||
func assertValid(t *testing.T, cfg ClientConfig) {
|
||||
errs := &packer.MultiError{}
|
||||
cfg.assertRequiredParametersSet(errs)
|
||||
if len(errs.Errors) != 0 {
|
||||
t.Fatal("Expected errs to be empty: ", errs)
|
||||
}
|
||||
}
|
||||
|
||||
func assertInvalid(t *testing.T, cfg ClientConfig) {
|
||||
errs := &packer.MultiError{}
|
||||
cfg.assertRequiredParametersSet(errs)
|
||||
if len(errs.Errors) == 0 {
|
||||
t.Fatal("Expected errs to be non-empty")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ClientConfig_CanUseClientSecret(t *testing.T) {
|
||||
cfg := emptyClientConfig()
|
||||
cfg.SubscriptionID = "12345"
|
||||
cfg.ClientID = "12345"
|
||||
cfg.ClientSecret = "12345"
|
||||
|
||||
assertValid(t, cfg)
|
||||
}
|
||||
|
||||
func Test_ClientConfig_CanUseClientSecretWithTenantID(t *testing.T) {
|
||||
cfg := emptyClientConfig()
|
||||
cfg.SubscriptionID = "12345"
|
||||
cfg.ClientID = "12345"
|
||||
cfg.ClientSecret = "12345"
|
||||
cfg.TenantID = "12345"
|
||||
|
||||
assertValid(t, cfg)
|
||||
}
|
||||
|
||||
func emptyClientConfig() ClientConfig {
|
||||
cfg := ClientConfig{}
|
||||
_ = cfg.setCloudEnvironment()
|
||||
return cfg
|
||||
}
|
||||
Loading…
Reference in new issue