diff --git a/builder/amazon/common/ssm_driver.go b/builder/amazon/common/ssm_driver.go index 82ad8887b..8fc69b90f 100644 --- a/builder/amazon/common/ssm_driver.go +++ b/builder/amazon/common/ssm_driver.go @@ -3,52 +3,50 @@ package common import ( "bytes" "context" + "encoding/json" "fmt" "log" "os/exec" - "github.com/hashicorp/packer/packer" + "github.com/aws/aws-sdk-go/service/ssm" ) -const SessionManagerPluginName string = "session-manager-plugin" +const sessionManagerPluginName string = "session-manager-plugin" + +//sessionCommand is the AWS-SDK equivalent to the command you would specify to `aws ssm ...` +const sessionCommand string = "StartSession" type SSMDriver struct { - Ui packer.Ui - Ctx context.Context - // Provided for testing purposes; if not specified it defaults to SessionManagerPluginName + Region string + ProfileName string + Session *ssm.StartSessionOutput + SessionParams ssm.StartSessionInput + SessionEndpoint string + // Provided for testing purposes; if not specified it defaults to sessionManagerPluginName PluginName string } // StartSession starts an interactive Systems Manager session with a remote instance via the AWS session-manager-plugin -func (s *SSMDriver) StartSession(sessionData, region, profile, params, endpoint string) error { +func (sd *SSMDriver) StartSession(ctx context.Context) error { var stdout bytes.Buffer var stderr bytes.Buffer - args := []string{ - sessionData, - region, - "StartSession", - profile, - params, - endpoint, - } - - if s.PluginName == "" { - s.PluginName = SessionManagerPluginName + if sd.PluginName == "" { + sd.PluginName = sessionManagerPluginName } - if _, err := exec.LookPath(s.PluginName); err != nil { + args, err := sd.Args() + if err != nil { + err = fmt.Errorf("error encountered validating session details: %s", err) return err } - log.Printf("Attempting to start session with the following args: %v", args) - cmd := exec.CommandContext(s.Ctx, s.PluginName, args...) + cmd := exec.CommandContext(ctx, sd.PluginName, args...) cmd.Stdout = &stdout cmd.Stderr = &stderr if err := cmd.Start(); err != nil { - err = fmt.Errorf("error encountered when calling %s: %s\nStderr: %s", s.PluginName, err, stderr.String()) - s.Ui.Error(err.Error()) + err = fmt.Errorf("error encountered when calling %s: %s\nStderr: %s", sd.PluginName, err, stderr.String()) return err } // TODO capture logging for testing @@ -56,3 +54,31 @@ func (s *SSMDriver) StartSession(sessionData, region, profile, params, endpoint return nil } +func (sd *SSMDriver) Args() ([]string, error) { + if sd.Session == nil { + return nil, fmt.Errorf("an active Amazon SSM Session is required before trying to open a session tunnel") + } + + // AWS session-manager-plugin requires a valid session be passed in JSON. + sessionDetails, err := json.Marshal(sd.Session) + if err != nil { + return nil, fmt.Errorf("error encountered in reading session details %s", err) + } + + // AWS session-manager-plugin requires the parameters used in the session to be passed in JSON as well. + sessionParameters, err := json.Marshal(sd.SessionParams) + if err != nil { + return nil, fmt.Errorf("error encountered in reading session parameter details %s", err) + } + + args := []string{ + string(sessionDetails), + sd.Region, + sessionCommand, + sd.ProfileName, + string(sessionParameters), + sd.SessionEndpoint, + } + + return args, nil +} diff --git a/builder/amazon/common/ssm_driver_test.go b/builder/amazon/common/ssm_driver_test.go index b7680187f..426adf251 100644 --- a/builder/amazon/common/ssm_driver_test.go +++ b/builder/amazon/common/ssm_driver_test.go @@ -1,10 +1,15 @@ package common import ( + "context" + "fmt" + "strings" "testing" + + "github.com/aws/aws-sdk-go/service/ssm" ) -func TestStartSession(t *testing.T) { +func TestSSMDriver_StartSession(t *testing.T) { tt := []struct { Name string PluginName string @@ -17,13 +22,75 @@ func TestStartSession(t *testing.T) { for _, tc := range tt { tc := tc t.Run(tc.Name, func(t *testing.T) { - driver := SSMDriver{PluginName: "someboguspluginname"} + driver := SSMDriver{ + Region: "region", + Session: new(ssm.StartSessionOutput), + SessionParams: ssm.StartSessionInput{}, + SessionEndpoint: "endpoint", + PluginName: tc.PluginName} - err := driver.StartSession("sessionData", "region", "profile", "params", "bogus-endpoint") + ctx := context.TODO() + err := driver.StartSession(ctx) if tc.ErrorExpected && err == nil { t.Fatalf("Executing %q should have failed but instead no error was returned", tc.PluginName) } + + }) + } +} + +func TestSSMDriver_Args(t *testing.T) { + tt := []struct { + Name string + Session *ssm.StartSessionOutput + ProfileName string + ErrorExpected bool + }{ + { + Name: "NilSession", + ErrorExpected: true, + }, + { + Name: "NonNilSession", + Session: new(ssm.StartSessionOutput), + ErrorExpected: false, + }, + { + Name: "SessionWithProfileName", + Session: new(ssm.StartSessionOutput), + ProfileName: "default", + ErrorExpected: false, + }, + } + + for _, tc := range tt { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + driver := SSMDriver{ + Region: "region", + ProfileName: tc.ProfileName, + Session: tc.Session, + SessionParams: ssm.StartSessionInput{}, + SessionEndpoint: "amazon.com/sessions", + } + + args, err := driver.Args() + if tc.ErrorExpected && err == nil { + t.Fatalf("SSMDriver.Args with a %q should have failed but instead no error was returned", tc.Name) + } + + if tc.ErrorExpected { + return + } + + // validate launch script + expectedArgString := fmt.Sprintf(`{"SessionId":null,"StreamUrl":null,"TokenValue":null} %s StartSession %s {"DocumentName":null,"Parameters":null,"Target":null} %s`, driver.Region, driver.ProfileName, driver.SessionEndpoint) + argString := strings.Join(args, " ") + if argString != expectedArgString { + t.Errorf("Expected launch script to be %q but got %q", expectedArgString, argString) + } + }) } } diff --git a/builder/amazon/common/step_create_ssm_tunnel.go b/builder/amazon/common/step_create_ssm_tunnel.go index d06feca14..06af19700 100644 --- a/builder/amazon/common/step_create_ssm_tunnel.go +++ b/builder/amazon/common/step_create_ssm_tunnel.go @@ -2,8 +2,8 @@ package common import ( "context" - "encoding/json" "fmt" + "log" "strconv" "time" @@ -18,61 +18,46 @@ import ( ) type StepCreateSSMTunnel struct { - AWSSession *session.Session - DstPort int - SSMAgentEnabled bool - instanceId string - ssmSession *ssm.StartSessionOutput + AWSSession *session.Session + Region string + LocalPortNumber int + RemotePortNumber int + SSMAgentEnabled bool + instanceId string + session *ssm.StartSessionOutput } +// Run executes the Packer build step that creates a session tunnel. func (s *StepCreateSSMTunnel) Run(ctx context.Context, state multistep.StateBag) multistep.StepAction { if !s.SSMAgentEnabled { return multistep.ActionContinue } ui := state.Get("ui").(packer.Ui) - // Find an available TCP port for our HTTP server - l, err := net.ListenRangeConfig{ - Min: 8000, - Max: 9000, - Addr: "0.0.0.0", - Network: "tcp", - }.Listen(ctx) - if err != nil { + if err := s.ConfigureLocalHostPort(ctx); err != nil { err := fmt.Errorf("error finding an available port to initiate a session tunnel: %s", err) state.Put("error", err) ui.Error(err.Error()) return multistep.ActionHalt } - dst, src := strconv.Itoa(s.DstPort), strconv.Itoa(l.Port) - params := map[string][]*string{ - "portNumber": []*string{aws.String(dst)}, - "localPortNumber": []*string{aws.String(src)}, - } - instance, ok := state.Get("instance").(*ec2.Instance) if !ok { - err := fmt.Errorf("error encountered in obtaining target instance id for SSM tunnel") + err := fmt.Errorf("error encountered in obtaining target instance id for session tunnel") ui.Error(err.Error()) state.Put("error", err) return multistep.ActionHalt } - s.instanceId = aws.StringValue(instance.InstanceId) - ssmconn := ssm.New(s.AWSSession) - input := ssm.StartSessionInput{ - DocumentName: aws.String("AWS-StartPortForwardingSession"), - Parameters: params, - Target: aws.String(s.instanceId), - } - ui.Message(fmt.Sprintf("Starting PortForwarding session to instance %q on local port %q to remote port %q", s.instanceId, src, dst)) + log.Printf("Starting PortForwarding session to instance %q on local port %q to remote port %q", s.instanceId, s.LocalPortNumber, s.RemotePortNumber) + input := s.BuildTunnelInputForInstance(s.instanceId) + ssmconn := ssm.New(s.AWSSession) var output *ssm.StartSessionOutput - err = retry.Config{ + err := retry.Config{ ShouldRetry: func(err error) bool { return isAWSErr(err, "TargetNotConnected", "") }, RetryDelay: (&retry.Backoff{InitialBackoff: 200 * time.Millisecond, MaxBackoff: 60 * time.Second, Multiplier: 2}).Linear, - }.Run(ctx, func(ctx context.Context) error { + }.Run(ctx, func(ctx context.Context) (err error) { output, err = ssmconn.StartSessionWithContext(ctx, &input) return err }) @@ -83,56 +68,82 @@ func (s *StepCreateSSMTunnel) Run(ctx context.Context, state multistep.StateBag) state.Put("error", err) return multistep.ActionHalt } - s.ssmSession = output - - // AWS session-manager-plugin requires a valid session be passed in JSON - sessionDetails, err := json.Marshal(s.ssmSession) - if err != nil { - ui.Error(err.Error()) - state.Put("error encountered in reading session details", err) - return multistep.ActionHalt - } - sessionParameters, err := json.Marshal(input) - if err != nil { - ui.Error(err.Error()) - state.Put("error encountered in reading session parameter details", err) - return multistep.ActionHalt + driver := SSMDriver{ + Region: s.Region, + Session: output, + SessionParams: input, + SessionEndpoint: ssmconn.Endpoint, } - // Stop listening on selected port so that the AWS session-manager-plugin can use it. - // The port is closed right before we start the session to avoid two Packer builds from getting the same port - fingers-crossed - l.Close() - - driver := SSMDriver{Ui: ui, Ctx: ctx} - // sessionDetails, region, "StartSession", profile, paramJson, endpoint - region := aws.StringValue(s.AWSSession.Config.Region) - // how to best get Profile name - if err := driver.StartSession(string(sessionDetails), region, "", string(sessionParameters), ssmconn.Endpoint); err != nil { - err = fmt.Errorf("error encountered in establishing a tunnel with the session-manager-plugin: %s", err) + if err := driver.StartSession(ctx); err != nil { + err = fmt.Errorf("error encountered in establishing a tunnel %s", err) ui.Error(err.Error()) state.Put("error", err) return multistep.ActionHalt } - ui.Message(fmt.Sprintf("PortForwarding session to instance %q established!", s.instanceId)) - state.Put("sessionPort", l.Port) + ui.Message(fmt.Sprintf("PortForwarding session tunnel to instance %q established!", s.instanceId)) + state.Put("sessionPort", s.LocalPortNumber) return multistep.ActionContinue } +// Cleanup terminates an active session on AWS, which in turn terminates the associated tunnel process running on the local machine. func (s *StepCreateSSMTunnel) Cleanup(state multistep.StateBag) { - if s.ssmSession == nil { + if s.session == nil { return } ui := state.Get("ui").(packer.Ui) ssmconn := ssm.New(s.AWSSession) - _, err := ssmconn.TerminateSession(&ssm.TerminateSessionInput{SessionId: s.ssmSession.SessionId}) + _, err := ssmconn.TerminateSession(&ssm.TerminateSessionInput{SessionId: s.session.SessionId}) if err != nil { msg := fmt.Sprintf("Error terminating SSM Session %q. Please terminate the session manually: %s", - aws.StringValue(s.ssmSession.SessionId), err) + aws.StringValue(s.session.SessionId), err) ui.Error(msg) } } + +// ConfigureLocalHostPort finds an available port on the localhost that can be used for the remote tunnel. +// Defaults to using s.LocalPortNumber if it is set. +func (s *StepCreateSSMTunnel) ConfigureLocalHostPort(ctx context.Context) error { + if s.LocalPortNumber != 0 { + return nil + } + // Find an available TCP port for our HTTP server + l, err := net.ListenRangeConfig{ + Min: 8000, + Max: 9000, + Addr: "0.0.0.0", + Network: "tcp", + }.Listen(ctx) + if err != nil { + return err + } + + s.LocalPortNumber = l.Port + // Stop listening on selected port so that the AWS session-manager-plugin can use it. + // The port is closed right before we start the session to avoid two Packer builds from getting the same port - fingers-crossed + l.Close() + + return nil + +} + +func (s *StepCreateSSMTunnel) BuildTunnelInputForInstance(instance string) ssm.StartSessionInput { + dst, src := strconv.Itoa(s.RemotePortNumber), strconv.Itoa(s.LocalPortNumber) + params := map[string][]*string{ + "portNumber": []*string{aws.String(dst)}, + "localPortNumber": []*string{aws.String(src)}, + } + + input := ssm.StartSessionInput{ + DocumentName: aws.String("AWS-StartPortForwardingSession"), + Parameters: params, + Target: aws.String(instance), + } + + return input +} diff --git a/builder/amazon/common/step_create_ssm_tunnel_test.go b/builder/amazon/common/step_create_ssm_tunnel_test.go new file mode 100644 index 000000000..5f027c662 --- /dev/null +++ b/builder/amazon/common/step_create_ssm_tunnel_test.go @@ -0,0 +1,48 @@ +package common + +import ( + "context" + "reflect" + "testing" + + "github.com/aws/aws-sdk-go/aws" +) + +func TestStepCreateSSMTunnel_BuildTunnelInputForInstance(t *testing.T) { + step := StepCreateSSMTunnel{ + Region: "region", + LocalPortNumber: 8001, + RemotePortNumber: 22, + SSMAgentEnabled: true, + } + + input := step.BuildTunnelInputForInstance("i-something") + + target := aws.StringValue(input.Target) + if target != "i-something" { + t.Errorf("input should contain instance id as target but it got %q", target) + } + + params := map[string][]*string{ + "portNumber": []*string{aws.String("22")}, + "localPortNumber": []*string{aws.String("8001")}, + } + if !reflect.DeepEqual(input.Parameters, params) { + t.Errorf("input should contain the expected port parameters but it got %v", input.Parameters) + } + +} + +func TestStepCreateSSMTunnel_ConfigureLocalHostPort(t *testing.T) { + tun := StepCreateSSMTunnel{} + + ctx := context.TODO() + if err := tun.ConfigureLocalHostPort(ctx); err != nil { + t.Errorf("failed to configure a port on localhost") + } + + if tun.LocalPortNumber == 0 { + t.Errorf("failed to configure a port on localhost") + } + +}