From 1f62249097ddc1a2820e36c060155c811a6fd0a7 Mon Sep 17 00:00:00 2001 From: sylviamoss Date: Tue, 29 Sep 2020 12:31:10 +0200 Subject: [PATCH] add retry terminated session chan --- builder/amazon/common/ssm_driver.go | 66 ++++++++++++++++++++--------- 1 file changed, 46 insertions(+), 20 deletions(-) diff --git a/builder/amazon/common/ssm_driver.go b/builder/amazon/common/ssm_driver.go index 2627ce0a8..4eb62c931 100644 --- a/builder/amazon/common/ssm_driver.go +++ b/builder/amazon/common/ssm_driver.go @@ -37,8 +37,9 @@ type SSMDriver struct { sessionParams ssm.StartSessionInput pluginCmdFunc func(context.Context) error - retryConnection chan bool - wg sync.WaitGroup + retryConnection chan bool + retryAfterTermination chan bool + wg sync.WaitGroup } func NewSSMDriver(config SSMDriverConfig) *SSMDriver { @@ -55,33 +56,40 @@ func (d *SSMDriver) StartSession(ctx context.Context, input ssm.StartSessionInpu defer d.wg.Done() log.Printf("Starting PortForwarding session to instance %q", aws.StringValue(input.Target)) - var output *ssm.StartSessionOutput - err := retry.Config{ - ShouldRetry: func(err error) bool { return IsAWSErr(err, "TargetNotConnected", "") }, - RetryDelay: (&retry.Backoff{InitialBackoff: 200 * time.Millisecond, MaxBackoff: 60 * time.Second, Multiplier: 2}).Linear, - }.Run(ctx, func(ctx context.Context) (err error) { - output, err = d.SvcClient.StartSessionWithContext(ctx, &input) - return err - }) + output, err := d.StartSessionWithContext(ctx, input) if err != nil { return nil, fmt.Errorf("error encountered in starting session for instance %q: %s", aws.StringValue(input.Target), err) } d.retryConnection = make(chan bool, 1) + d.retryAfterTermination = make(chan bool, 1) // Starts go routine that will keep listening to a retry channel and retry the session creation when needed. - // The log loop will add data to the retry channel whenever a retryable error happens to session. + // The log polling process will add data to the retry channel whenever a retryable error happens to session. go func(ctx context.Context, driver *SSMDriver, input ssm.StartSessionInput) { for { select { case <-ctx.Done(): return - case <-driver.retryConnection: - d.wg.Wait() - log.Printf("[DEBUG] Retrying to establish SSM connection") - _, err := driver.StartSession(ctx, input) - if err != nil { + case r := <-driver.retryAfterTermination: + if r { + d.wg.Wait() + log.Printf("[DEBUG] Restablishing SSM connection") + _, _ = driver.StartSession(ctx, input) + // Close channels and end goroutine. Another goroutine will start + // and the channels wil be reopened. + close(driver.retryConnection) + close(driver.retryAfterTermination) return } + case r := <-driver.retryConnection: + if r { + d.wg.Wait() + log.Printf("[DEBUG] Retrying to establish SSM connection") + _, err := driver.StartSessionWithContext(ctx, input) + if err != nil { + return + } + } } } }(ctx, d, input) @@ -100,6 +108,20 @@ func (d *SSMDriver) StartSession(ctx context.Context, input ssm.StartSessionInpu return d.session, nil } +func (d *SSMDriver) StartSessionWithContext(ctx context.Context, input ssm.StartSessionInput) (*ssm.StartSessionOutput, error) { + d.wg.Add(1) + defer d.wg.Done() + var output *ssm.StartSessionOutput + err := retry.Config{ + ShouldRetry: func(err error) bool { return IsAWSErr(err, "TargetNotConnected", "") }, + RetryDelay: (&retry.Backoff{InitialBackoff: 200 * time.Millisecond, MaxBackoff: 60 * time.Second, Multiplier: 2}).Linear, + }.Run(ctx, func(ctx context.Context) (err error) { + output, err = d.SvcClient.StartSessionWithContext(ctx, &input) + return err + }) + return output, err +} + func (d *SSMDriver) openTunnelForSession(ctx context.Context) error { args, err := d.Args() if err != nil { @@ -163,6 +185,7 @@ func (d *SSMDriver) openTunnelForSession(ctx context.Context) error { if stdoutCh == nil && stderrCh == nil { log.Printf("[DEBUG] %s: %s", prefix, "active session has been terminated; stopping all log polling processes.") + d.retryAfterTermination <- true return } } @@ -198,10 +221,6 @@ func (d *SSMDriver) StopSession() error { d.wg.Add(1) defer d.wg.Done() - if d.retryConnection != nil { - close(d.retryConnection) - } - if d.session == nil || d.session.SessionId == nil { return fmt.Errorf("Unable to find a valid session to instance %q; skipping the termination step", aws.StringValue(d.sessionParams.Target)) @@ -211,6 +230,13 @@ func (d *SSMDriver) StopSession() error { if err != nil { err = fmt.Errorf("Error terminating SSM Session %q. Please terminate the session manually: %s", aws.StringValue(d.session.SessionId), err) } + + if d.retryConnection != nil { + close(d.retryConnection) + } + if d.retryAfterTermination != nil { + close(d.retryAfterTermination) + } return err }