From 7d1a989ea7c2bb2b90a9198708253754c36fb3ee Mon Sep 17 00:00:00 2001 From: Louis Ruch Date: Wed, 20 Jul 2022 15:17:15 -0700 Subject: [PATCH] feat(cli): Support using brokered private key in ssh subcommand (#2267) * feat(cli): Support using brokered private key in ssh subcommand --- internal/cmd/commands/connect/connect.go | 89 +-- internal/cmd/commands/connect/credentials.go | 115 +++- .../cmd/commands/connect/credentials_test.go | 518 +++++++++++++----- internal/cmd/commands/connect/funcs.go | 7 +- internal/cmd/commands/connect/postgres.go | 35 +- internal/cmd/commands/connect/ssh.go | 89 ++- 6 files changed, 615 insertions(+), 238 deletions(-) diff --git a/internal/cmd/commands/connect/connect.go b/internal/cmd/commands/connect/connect.go index 90f99977c2..3cb7e04c37 100644 --- a/internal/cmd/commands/connect/connect.go +++ b/internal/cmd/commands/connect/connect.go @@ -483,28 +483,16 @@ func (c *Command) Run(args []string) (retCode int) { c.listenerAddr = c.listener.Addr().(*net.TCPAddr) - var creds []*targets.SessionCredential - if c.sessionAuthz != nil && len(c.sessionAuthz.Credentials) > 0 { - creds = c.sessionAuthz.Credentials - } - switch c.Func { - case "postgres": - // Credentials are brokered when connecting to the postgres db. - // TODO: Figure out how to handle cases where we don't automatically know how to - // broker the credentials like unrecognized or multiple credentials. - case "ssh": - if c.flagExec == "sshpass" { - // If we can broker ssh credentials using sshpass, skip displaying creds on client side. - // TODO: Figure out how to handle multiple credentials - break - } - - // We cannot broker credentials print to user - fallthrough - case "connect": + if c.Func == "connect" { // "connect" indicates there is no subcommand to the connect function. // The only way a user will be able to connect to the session is by // connecting directly to the port and address we report to them here. + + var creds []*targets.SessionCredential + if c.sessionAuthz != nil && len(c.sessionAuthz.Credentials) > 0 { + creds = c.sessionAuthz.Credentials + } + sessInfo := SessionInfo{ Protocol: c.sessionAuthzData.GetType(), Address: c.listenerAddr.IP.String(), @@ -525,25 +513,6 @@ func (c *Command) Run(args []string) (retCode int) { } c.UI.Output(string(out)) } - default: - if len(creds) == 0 { - break - } - switch base.Format(c.UI) { - case "table": - c.UI.Output(generateCredentialTableOutput(creds)) - case "json": - out, err := json.Marshal(&struct { - Credentials []*targets.SessionCredential `json:"credentials"` - }{ - Credentials: c.sessionAuthz.Credentials, - }) - if err != nil { - c.PrintCliError(fmt.Errorf("error marshaling session information: %w", err)) - return base.CommandCliError - } - c.UI.Output(string(out)) - } } c.connWg = new(sync.WaitGroup) @@ -681,6 +650,27 @@ func (c *Command) Run(args []string) (retCode int) { return } +func (c *Command) printCredentials(creds []*targets.SessionCredential) error { + if len(creds) == 0 { + return nil + } + switch base.Format(c.UI) { + case "table": + c.UI.Output(generateCredentialTableOutput(creds)) + case "json": + out, err := json.Marshal(&struct { + Credentials []*targets.SessionCredential `json:"credentials"` + }{ + Credentials: creds, + }) + if err != nil { + return fmt.Errorf("error marshaling credential information: %w", err) + } + c.UI.Output(string(out)) + } + return nil +} + func (c *Command) getWsConn( ctx context.Context, workerAddr string, @@ -824,6 +814,17 @@ func (c *Command) handleExec(passthroughArgs []string) { var envs []string var argsErr error + var creds credentials + if c.sessionAuthz != nil { + var err error + creds, err = parseCredentials(c.sessionAuthz.Credentials) + if err != nil { + c.PrintCliError(fmt.Errorf("Error interpreting secret: %w", err)) + c.execCmdReturnValue.Store(int32(3)) + return + } + } + switch c.Func { case "http": httpArgs, err := c.httpFlags.buildArgs(c, port, ip, addr) @@ -835,25 +836,27 @@ func (c *Command) handleExec(passthroughArgs []string) { args = append(args, httpArgs...) case "postgres": - pgArgs, pgEnvs, pgErr := c.postgresFlags.buildArgs(c, port, ip, addr) + pgArgs, pgEnvs, pgCreds, pgErr := c.postgresFlags.buildArgs(c, port, ip, addr, creds) if pgErr != nil { argsErr = pgErr break } args = append(args, pgArgs...) envs = append(envs, pgEnvs...) + creds = pgCreds case "rdp": args = append(args, c.rdpFlags.buildArgs(c, port, ip, addr)...) case "ssh": - sshArgs, sshEnvs, sshErr := c.sshFlags.buildArgs(c, port, ip, addr) + sshArgs, sshEnvs, sshCreds, sshErr := c.sshFlags.buildArgs(c, port, ip, addr, creds) if sshErr != nil { argsErr = sshErr break } args = append(args, sshArgs...) envs = append(envs, sshEnvs...) + creds = sshCreds case "kube": kubeArgs, err := c.kubeFlags.buildArgs(c, port, ip, addr) @@ -871,6 +874,12 @@ func (c *Command) handleExec(passthroughArgs []string) { return } + if err := c.printCredentials(creds.unconsumedSessionCredentials()); err != nil { + c.PrintCliError(fmt.Errorf("Failed to print credentials: %w", err)) + c.execCmdReturnValue.Store(int32(2)) + return + } + args = append(passthroughArgs, args...) stringReplacer := func(in, typ, replacer string) string { diff --git a/internal/cmd/commands/connect/credentials.go b/internal/cmd/commands/connect/credentials.go index 7ca65cf30f..db64b31cf7 100644 --- a/internal/cmd/commands/connect/credentials.go +++ b/internal/cmd/commands/connect/credentials.go @@ -8,54 +8,115 @@ import ( "github.com/mitchellh/mapstructure" ) -type usernamePasswordCredentials struct { +type usernamePassword struct { Username string `mapstructure:"username"` Password string `mapstructure:"password"` + + raw *targets.SessionCredential + consumed bool +} + +type sshPrivateKey struct { + Username string `mapstructure:"username"` + PrivateKey string `mapstructure:"private_key"` + + raw *targets.SessionCredential + consumed bool +} + +type credentials struct { + usernamePassword []usernamePassword + sshPrivateKey []sshPrivateKey + unspecified []*targets.SessionCredential } -// TODO: this should support multiple types of credentials and return an interface instead -// of only usernamepassword. Each subcommand can use the returned credentials as it needs. -func parseCredentials(creds []*targets.SessionCredential) ([]usernamePasswordCredentials, error) { +func (c credentials) unconsumedSessionCredentials() []*targets.SessionCredential { + out := make([]*targets.SessionCredential, 0, len(c.sshPrivateKey)+len(c.usernamePassword)+len(c.unspecified)) + + // Unspecified credentials cannot be consumed + out = append(out, c.unspecified...) + + for _, c := range c.sshPrivateKey { + if !c.consumed { + out = append(out, c.raw) + } + } + for _, c := range c.usernamePassword { + if !c.consumed { + out = append(out, c.raw) + } + } + return out +} + +func parseCredentials(creds []*targets.SessionCredential) (credentials, error) { if creds == nil { - return nil, nil + return credentials{}, nil } - var out []usernamePasswordCredentials + var out credentials for _, cred := range creds { if cred.CredentialSource == nil { - return nil, errors.New("missing credential source") + return credentials{}, errors.New("missing credential source") } - var c usernamePasswordCredentials - if cred.CredentialSource.CredentialType == string(credential.UsernamePasswordType) { + var upCred usernamePassword + var spkCred sshPrivateKey + switch credential.Type(cred.CredentialSource.CredentialType) { + case credential.UsernamePasswordType: // Decode attributes from credential struct - if err := mapstructure.Decode(cred.Credential, &c); err != nil { - return nil, err + if err := mapstructure.Decode(cred.Credential, &upCred); err != nil { + return credentials{}, err } - if c.Username != "" && c.Password != "" { - out = append(out, c) + if upCred.Username != "" && upCred.Password != "" { + upCred.raw = cred + out.usernamePassword = append(out.usernamePassword, upCred) continue } - } - // Credential type is unspecified, make a best effort attempt to parse both username - // and password from Decoded field if it exists - if cred.Secret == nil || cred.Secret.Decoded == nil { - // No secret data continue to next credential - continue - } + case credential.SshPrivateKeyType: + // Decode attributes from credential struct + if err := mapstructure.Decode(cred.Credential, &spkCred); err != nil { + return credentials{}, err + } - switch cred.CredentialSource.Type { - case "vault", "static": - // Attempt unmarshaling into creds - if err := mapstructure.Decode(cred.Secret.Decoded, &c); err != nil { - return nil, err + if spkCred.Username != "" && spkCred.PrivateKey != "" { + spkCred.raw = cred + out.sshPrivateKey = append(out.sshPrivateKey, spkCred) + continue } } - if c.Username != "" && c.Password != "" { - out = append(out, c) + // Credential type is unspecified, make a best effort attempt to parse + // a username_password credential from the Decoded field if it exists + if cred.Secret != nil && cred.Secret.Decoded != nil { + switch cred.CredentialSource.Type { + case "vault", "static": + // Attempt unmarshaling into username password creds + if err := mapstructure.Decode(cred.Secret.Decoded, &upCred); err != nil { + return credentials{}, err + } + if upCred.Username != "" && upCred.Password != "" { + upCred.raw = cred + out.usernamePassword = append(out.usernamePassword, upCred) + continue + } + + // Attempt unmarshaling into ssh private key creds + if err := mapstructure.Decode(cred.Secret.Decoded, &spkCred); err != nil { + return credentials{}, err + } + if spkCred.Username != "" && spkCred.PrivateKey != "" { + spkCred.raw = cred + out.sshPrivateKey = append(out.sshPrivateKey, spkCred) + continue + } + } } + + // We could not parse the credential + out.unspecified = append(out.unspecified, cred) } + return out, nil } diff --git a/internal/cmd/commands/connect/credentials_test.go b/internal/cmd/commands/connect/credentials_test.go index 5d1f8615f4..8b63433a5a 100644 --- a/internal/cmd/commands/connect/credentials_test.go +++ b/internal/cmd/commands/connect/credentials_test.go @@ -9,222 +9,472 @@ import ( "github.com/stretchr/testify/require" ) -func Test_parseUsernamePasswordCredentials(t *testing.T) { +var ( + typedUsernamePassword = &targets.SessionCredential{ + CredentialSource: &targets.CredentialSource{ + CredentialType: string(credential.UsernamePasswordType), + }, + Credential: map[string]interface{}{ + "username": "user", + "password": "pass", + }, + } + + typedSshPrivateKey = &targets.SessionCredential{ + CredentialSource: &targets.CredentialSource{ + CredentialType: string(credential.SshPrivateKeyType), + }, + Credential: map[string]interface{}{ + "username": "user", + "private_key": "my-pk", + }, + } + + vaultUsernamePassword = &targets.SessionCredential{ + CredentialSource: &targets.CredentialSource{ + Type: "vault", + }, + Secret: &targets.SessionSecret{ + Decoded: map[string]interface{}{ + "username": "vault-decoded-user", + "password": "vault-decoded-pass", + }, + }, + } + + vaultSshPrivateKey = &targets.SessionCredential{ + CredentialSource: &targets.CredentialSource{ + Type: "vault", + }, + Secret: &targets.SessionSecret{ + Decoded: map[string]interface{}{ + "username": "vault-decoded-user", + "private_key": "vault-decoded-pk", + }, + }, + } + + staticUsernamePassword = &targets.SessionCredential{ + CredentialSource: &targets.CredentialSource{ + Type: "static", + }, + Secret: &targets.SessionSecret{ + Decoded: map[string]interface{}{ + "username": "static-decoded-user", + "password": "static-decoded-pass", + }, + }, + } + + staticSshPrivateKey = &targets.SessionCredential{ + CredentialSource: &targets.CredentialSource{ + Type: "static", + }, + Secret: &targets.SessionSecret{ + Decoded: map[string]interface{}{ + "username": "static-decoded-user", + "private_key": "static-decoded-pk", + }, + }, + } + + unspecifiedCred = &targets.SessionCredential{ + CredentialSource: &targets.CredentialSource{ + Type: "static", + }, + Secret: &targets.SessionSecret{ + Decoded: map[string]interface{}{ + "username": "decoded-user", + "some-value": "decoded-some-value", + }, + }, + } + + unspecifiedCred1 = &targets.SessionCredential{ + CredentialSource: &targets.CredentialSource{ + Type: "static", + }, + Secret: &targets.SessionSecret{ + Decoded: map[string]interface{}{ + "username": "decoded-user", + "some-value1": "decoded-some-value1", + }, + }, + } +) + +func Test_parseCredentials(t *testing.T) { tests := []struct { name string creds []*targets.SessionCredential - wantCreds []usernamePasswordCredentials + wantCreds credentials wantErr bool }{ { - name: "no-creds", - wantCreds: nil, - wantErr: false, + name: "no-creds", + wantErr: false, }, { name: "no-credential-source", creds: []*targets.SessionCredential{ { - Credential: map[string]interface{}{ - "username": "user", - "password": "pass", + Secret: &targets.SessionSecret{ + Decoded: map[string]interface{}{ + "username": "decoded-user", + "private_key": "decoded-pk", + }, }, }, }, - wantCreds: nil, - wantErr: true, + wantErr: true, }, { - name: "valid-typed", + name: "username-password-typed", creds: []*targets.SessionCredential{ - { - CredentialSource: &targets.CredentialSource{ - CredentialType: string(credential.UsernamePasswordType), - }, - Credential: map[string]interface{}{ - "username": "user", - "password": "pass", - }, - }, + typedUsernamePassword, }, - wantCreds: []usernamePasswordCredentials{ - { - Username: "user", - Password: "pass", + wantCreds: credentials{ + usernamePassword: []usernamePassword{ + { + Username: "user", + Password: "pass", + raw: typedUsernamePassword, + }, }, }, wantErr: false, }, { - name: "valid-typed-instead-of-decoded", + name: "ssh-private-key-typed", creds: []*targets.SessionCredential{ - { - CredentialSource: &targets.CredentialSource{ - CredentialType: string(credential.UsernamePasswordType), - }, - Credential: map[string]interface{}{ - "username": "user", - "password": "pass", - }, - Secret: &targets.SessionSecret{ - Decoded: map[string]interface{}{ - "username": "secret-user", - "password": "secret-pass", - }, + typedSshPrivateKey, + }, + wantCreds: credentials{ + sshPrivateKey: []sshPrivateKey{ + { + Username: "user", + PrivateKey: "my-pk", + raw: typedSshPrivateKey, }, }, }, - wantCreds: []usernamePasswordCredentials{ - { - Username: "user", - Password: "pass", + wantErr: false, + }, + { + name: "vault-username-password-decoded", + creds: []*targets.SessionCredential{ + vaultUsernamePassword, + }, + wantCreds: credentials{ + usernamePassword: []usernamePassword{ + { + Username: "vault-decoded-user", + Password: "vault-decoded-pass", + raw: vaultUsernamePassword, + }, }, }, wantErr: false, }, { - name: "valid-vault-not-typed", + name: "vault-private-key-decoded", creds: []*targets.SessionCredential{ - { - CredentialSource: &targets.CredentialSource{ - Type: "vault", - }, - Secret: &targets.SessionSecret{ - Decoded: map[string]interface{}{ - "username": "user", - "password": "pass", - }, + vaultSshPrivateKey, + }, + wantCreds: credentials{ + sshPrivateKey: []sshPrivateKey{ + { + Username: "vault-decoded-user", + PrivateKey: "vault-decoded-pk", + raw: vaultSshPrivateKey, }, }, }, - wantCreds: []usernamePasswordCredentials{ - { - Username: "user", - Password: "pass", + wantErr: false, + }, + { + name: "static-username-password-decoded", + creds: []*targets.SessionCredential{ + staticUsernamePassword, + }, + wantCreds: credentials{ + usernamePassword: []usernamePassword{ + { + Username: "static-decoded-user", + Password: "static-decoded-pass", + raw: staticUsernamePassword, + }, }, }, wantErr: false, }, { - name: "valid-static-not-typed", + name: "static-private-key-decoded", creds: []*targets.SessionCredential{ - { - CredentialSource: &targets.CredentialSource{ - Type: "static", - }, - Secret: &targets.SessionSecret{ - Decoded: map[string]interface{}{ - "username": "user", - "password": "pass", - }, + staticSshPrivateKey, + }, + wantCreds: credentials{ + sshPrivateKey: []sshPrivateKey{ + { + Username: "static-decoded-user", + PrivateKey: "static-decoded-pk", + raw: staticSshPrivateKey, }, }, }, - wantCreds: []usernamePasswordCredentials{ - { - Username: "user", - Password: "pass", + wantErr: false, + }, + { + name: "unspecified", + creds: []*targets.SessionCredential{ + unspecifiedCred, + }, + wantCreds: credentials{ + unspecified: []*targets.SessionCredential{ + unspecifiedCred, }, }, wantErr: false, }, { - name: "valid-multiple", + name: "mixed", creds: []*targets.SessionCredential{ - { - CredentialSource: &targets.CredentialSource{ - CredentialType: string(credential.UsernamePasswordType), + staticSshPrivateKey, unspecifiedCred1, vaultSshPrivateKey, typedUsernamePassword, + unspecifiedCred, vaultUsernamePassword, typedSshPrivateKey, staticUsernamePassword, + }, + wantCreds: credentials{ + sshPrivateKey: []sshPrivateKey{ + { + Username: "static-decoded-user", + PrivateKey: "static-decoded-pk", + raw: staticSshPrivateKey, + }, + { + Username: "vault-decoded-user", + PrivateKey: "vault-decoded-pk", + raw: vaultSshPrivateKey, }, - Credential: map[string]interface{}{ - "username": "user", - "password": "pass", + { + Username: "user", + PrivateKey: "my-pk", + raw: typedSshPrivateKey, }, }, - { - CredentialSource: &targets.CredentialSource{ - CredentialType: string(credential.UsernamePasswordType), + usernamePassword: []usernamePassword{ + { + Username: "static-decoded-user", + Password: "static-decoded-pass", + raw: staticUsernamePassword, + }, + { + Username: "vault-decoded-user", + Password: "vault-decoded-pass", + raw: vaultUsernamePassword, }, - Credential: map[string]interface{}{ - "username": "user1", - "password": "pass1", + { + Username: "user", + Password: "pass", + raw: typedUsernamePassword, }, }, + unspecified: []*targets.SessionCredential{ + unspecifiedCred, unspecifiedCred1, + }, }, - wantCreds: []usernamePasswordCredentials{ - { - Username: "user", - Password: "pass", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + creds, err := parseCredentials(tt.creds) + if tt.wantErr { + require.Error(err) + assert.Empty(creds) + return + } + require.NoError(err) + + assert.ElementsMatch(tt.wantCreds.usernamePassword, creds.usernamePassword) + assert.ElementsMatch(tt.wantCreds.sshPrivateKey, creds.sshPrivateKey) + assert.ElementsMatch(tt.wantCreds.unspecified, creds.unspecified) + }) + } +} + +func Test_unconsumedSessionCredentials(t *testing.T) { + tests := []struct { + name string + creds credentials + wantCreds []*targets.SessionCredential + }{ + { + name: "no-creds", + wantCreds: nil, + }, + { + name: "spk-consumed", + creds: credentials{ + sshPrivateKey: []sshPrivateKey{ + { + raw: staticSshPrivateKey, + consumed: true, + }, }, - { - Username: "user1", - Password: "pass1", + }, + wantCreds: nil, + }, + { + name: "spk", + creds: credentials{ + sshPrivateKey: []sshPrivateKey{ + { + raw: staticSshPrivateKey, + }, }, }, - wantErr: false, + wantCreds: []*targets.SessionCredential{staticSshPrivateKey}, }, { - name: "valid-multiple-mixed", - creds: []*targets.SessionCredential{ - { - CredentialSource: &targets.CredentialSource{ - CredentialType: string(credential.UsernamePasswordType), + name: "up", + creds: credentials{ + usernamePassword: []usernamePassword{ + { + raw: vaultUsernamePassword, }, - Credential: map[string]interface{}{ - "username": "user", - "password": "pass", + }, + }, + wantCreds: []*targets.SessionCredential{vaultUsernamePassword}, + }, + { + name: "up-consumed", + creds: credentials{ + usernamePassword: []usernamePassword{ + { + raw: vaultUsernamePassword, + consumed: true, }, }, - { - CredentialSource: &targets.CredentialSource{ - Type: "vault", + }, + wantCreds: nil, + }, + { + name: "unspecified", + creds: credentials{ + unspecified: []*targets.SessionCredential{unspecifiedCred}, + }, + wantCreds: []*targets.SessionCredential{unspecifiedCred}, + }, + { + name: "mixed", + creds: credentials{ + sshPrivateKey: []sshPrivateKey{ + { + raw: staticSshPrivateKey, + consumed: true, }, - Secret: &targets.SessionSecret{ - Decoded: map[string]interface{}{ - "username": "user1", - "password": "pass1", - }, + { + raw: vaultSshPrivateKey, + }, + { + raw: typedSshPrivateKey, }, }, - { - CredentialSource: &targets.CredentialSource{ - Type: "static", + usernamePassword: []usernamePassword{ + { + raw: staticUsernamePassword, + consumed: true, }, - Secret: &targets.SessionSecret{ - Decoded: map[string]interface{}{ - "username": "user2", - "password": "pass2", - }, + { + raw: vaultUsernamePassword, + }, + { + raw: typedUsernamePassword, + consumed: true, }, }, + unspecified: []*targets.SessionCredential{unspecifiedCred, unspecifiedCred1}, }, - wantCreds: []usernamePasswordCredentials{ - { - Username: "user", - Password: "pass", + wantCreds: []*targets.SessionCredential{ + vaultSshPrivateKey, typedSshPrivateKey, vaultUsernamePassword, unspecifiedCred, unspecifiedCred1, + }, + }, + { + name: "mixed-all-consumed", + creds: credentials{ + sshPrivateKey: []sshPrivateKey{ + { + raw: staticSshPrivateKey, + consumed: true, + }, + { + raw: vaultSshPrivateKey, + consumed: true, + }, + { + raw: typedSshPrivateKey, + consumed: true, + }, }, - { - Username: "user1", - Password: "pass1", + usernamePassword: []usernamePassword{ + { + raw: staticUsernamePassword, + consumed: true, + }, + { + raw: vaultUsernamePassword, + consumed: true, + }, + { + raw: typedUsernamePassword, + consumed: true, + }, }, - { - Username: "user2", - Password: "pass2", + unspecified: []*targets.SessionCredential{unspecifiedCred, unspecifiedCred1}, + }, + wantCreds: []*targets.SessionCredential{ + unspecifiedCred1, unspecifiedCred, + }, + }, + { + name: "mixed-all-unconsumed", + creds: credentials{ + sshPrivateKey: []sshPrivateKey{ + { + raw: staticSshPrivateKey, + }, + { + raw: vaultSshPrivateKey, + }, + { + raw: typedSshPrivateKey, + }, }, + usernamePassword: []usernamePassword{ + { + raw: staticUsernamePassword, + }, + { + raw: vaultUsernamePassword, + }, + { + raw: typedUsernamePassword, + }, + }, + unspecified: []*targets.SessionCredential{unspecifiedCred, unspecifiedCred1}, + }, + wantCreds: []*targets.SessionCredential{ + staticSshPrivateKey, unspecifiedCred1, vaultSshPrivateKey, typedUsernamePassword, + unspecifiedCred, vaultUsernamePassword, typedSshPrivateKey, staticUsernamePassword, }, - wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert, require := assert.New(t), require.New(t) + assert := assert.New(t) - creds, err := parseCredentials(tt.creds) - if tt.wantErr { - require.Error(err) - assert.Nil(creds) - return - } - require.NoError(err) + creds := tt.creds.unconsumedSessionCredentials() assert.ElementsMatch(tt.wantCreds, creds) }) } diff --git a/internal/cmd/commands/connect/funcs.go b/internal/cmd/commands/connect/funcs.go index f44c789a52..555d1bf03b 100644 --- a/internal/cmd/commands/connect/funcs.go +++ b/internal/cmd/commands/connect/funcs.go @@ -46,6 +46,11 @@ func generateCredentialTableOutput(creds []*targets.SessionCredential) string { func generateCredentialTableOutputSlice(prefixIndent int, creds []*targets.SessionCredential) []string { var ret []string prefixString := strings.Repeat(" ", prefixIndent) + + if len(creds) > 0 { + // Add credential header + ret = append(ret, fmt.Sprintf("%sCredentials:", prefixString)) + } for _, crd := range creds { credMap := map[string]interface{}{ "Credential Store ID": crd.CredentialSource.CredentialStoreId, @@ -63,12 +68,12 @@ func generateCredentialTableOutputSlice(prefixIndent int, creds []*targets.Sessi } maxLength := base.MaxAttributesLength(credMap, nil, nil) ret = append(ret, - fmt.Sprintf("%sCredentials:", prefixString), base.WrapMap(2+prefixIndent, maxLength, credMap), fmt.Sprintf("%s Secret:", prefixString)) ret = append(ret, fmtSecretForTable(2+prefixIndent, crd)..., ) + ret = append(ret, "") } return ret diff --git a/internal/cmd/commands/connect/postgres.go b/internal/cmd/commands/connect/postgres.go index 9ab0abe2af..884dc8f701 100644 --- a/internal/cmd/commands/connect/postgres.go +++ b/internal/cmd/commands/connect/postgres.go @@ -51,18 +51,17 @@ func (p *postgresFlags) defaultExec() string { return strings.ToLower(p.flagPostgresStyle) } -func (p *postgresFlags) buildArgs(c *Command, port, ip, addr string) (args, envs []string, retErr error) { - var cred usernamePasswordCredentials - if c.sessionAuthz != nil { - creds, err := parseCredentials(c.sessionAuthz.Credentials) - if err != nil { - return nil, nil, fmt.Errorf("Error interpreting secret: %w", err) - } - if len(creds) > 0 { - // Just use first credentials returned - cred = creds[0] - } +func (p *postgresFlags) buildArgs(c *Command, port, ip, addr string, creds credentials) (args, envs []string, retCreds credentials, retErr error) { + var username, password string + + retCreds = creds + if len(retCreds.usernamePassword) > 0 { + // Mark credential as consumed so it is not printed to user + retCreds.usernamePassword[0].consumed = true + // For now just grab the first username password credential brokered + username = retCreds.usernamePassword[0].Username + password = retCreds.usernamePassword[0].Password } switch p.flagPostgresStyle { @@ -74,16 +73,16 @@ func (p *postgresFlags) buildArgs(c *Command, port, ip, addr string) (args, envs } switch { - case cred.Username != "": - args = append(args, "-U", cred.Username) + case username != "": + args = append(args, "-U", username) case c.flagUsername != "": args = append(args, "-U", c.flagUsername) } - if cred.Password != "" { + if password != "" { passfile, err := ioutil.TempFile("", "*") if err != nil { - return nil, nil, fmt.Errorf("Error saving postgres password to tmp file: %w", err) + return nil, nil, credentials{}, fmt.Errorf("Error saving postgres password to tmp file: %w", err) } c.cleanupFuncs = append(c.cleanupFuncs, func() error { if err := os.Remove(passfile.Name()); err != nil { @@ -91,12 +90,12 @@ func (p *postgresFlags) buildArgs(c *Command, port, ip, addr string) (args, envs } return nil }) - _, err = passfile.WriteString(fmt.Sprintf("*:*:*:*:%s", cred.Password)) + _, err = passfile.WriteString(fmt.Sprintf("*:*:*:*:%s", password)) if err != nil { - return nil, nil, fmt.Errorf("Error writing password file to %s: %w", passfile.Name(), err) + return nil, nil, credentials{}, fmt.Errorf("Error writing password file to %s: %w", passfile.Name(), err) } if err := passfile.Close(); err != nil { - return nil, nil, fmt.Errorf("Error closing password file after writing to %s: %w", passfile.Name(), err) + return nil, nil, credentials{}, fmt.Errorf("Error closing password file after writing to %s: %w", passfile.Name(), err) } envs = append(envs, fmt.Sprintf("PGPASSFILE=%s", passfile.Name())) diff --git a/internal/cmd/commands/connect/ssh.go b/internal/cmd/commands/connect/ssh.go index 38ae464965..bf7e309d4f 100644 --- a/internal/cmd/commands/connect/ssh.go +++ b/internal/cmd/commands/connect/ssh.go @@ -3,6 +3,8 @@ package connect import ( "errors" "fmt" + "io/ioutil" + "os" "strings" "github.com/hashicorp/boundary/internal/cmd/base" @@ -42,18 +44,9 @@ func (s *sshFlags) defaultExec() string { return strings.ToLower(s.flagSshStyle) } -func (s *sshFlags) buildArgs(c *Command, port, ip, addr string) (args, envs []string, retErr error) { - var cred usernamePasswordCredentials - if c.sessionAuthz != nil { - creds, err := parseCredentials(c.sessionAuthz.Credentials) - if err != nil { - return nil, nil, fmt.Errorf("Error interpreting secret: %w", err) - } - if len(creds) > 0 { - // Just use first credentials returned - cred = creds[0] - } - } +func (s *sshFlags) buildArgs(c *Command, port, ip, addr string, creds credentials) (args, envs []string, retCreds credentials, retErr error) { + var username string + retCreds = creds switch strings.ToLower(s.flagSshStyle) { case "ssh": @@ -65,12 +58,23 @@ func (s *sshFlags) buildArgs(c *Command, port, ip, addr string) (args, envs []st args = append(args, "-o", "NoHostAuthenticationForLocalhost=yes") case "sshpass": - if cred.Password == "" { - return nil, nil, errors.New("Password is required when using sshpass") + var password string + if len(retCreds.usernamePassword) > 0 { + // For now just grab the first username password credential brokered + // Mark credential as consumed so it is not printed to user + retCreds.usernamePassword[0].consumed = true + + username = retCreds.usernamePassword[0].Username + password = retCreds.usernamePassword[0].Password + } + + if password == "" { + return nil, nil, credentials{}, errors.New("Password is required when using sshpass") } + // sshpass requires that the password is passed as env-var "SSHPASS" // when the using env-vars. - envs = append(envs, fmt.Sprintf("SSHPASS=%s", cred.Password)) + envs = append(envs, fmt.Sprintf("SSHPASS=%s", password)) args = append(args, "-e", "ssh") args = append(args, "-p", port, ip) @@ -82,12 +86,61 @@ func (s *sshFlags) buildArgs(c *Command, port, ip, addr string) (args, envs []st args = append(args, "-P", port, ip) } + // Check if we got credentials to attempt to use for ssh or putty, + // sshpass style has already be handled above as username password. + switch strings.ToLower(s.flagSshStyle) { + case "putty", "ssh": + + switch { + // First check if we have a private key available + case len(creds.sshPrivateKey) > 0: + // For now just grab the first ssh private key credential brokered + // Mark credential as consumed so it is not printed to user + retCreds.sshPrivateKey[0].consumed = true + + username = retCreds.sshPrivateKey[0].Username + privateKey := retCreds.sshPrivateKey[0].PrivateKey + + pkFile, err := ioutil.TempFile("", "*") + if err != nil { + return nil, nil, credentials{}, fmt.Errorf("Error saving ssh private key to tmp file: %w", err) + } + c.cleanupFuncs = append(c.cleanupFuncs, func() error { + if err := os.Remove(pkFile.Name()); err != nil { + return fmt.Errorf("Error removing temporary ssh private key file; consider removing %s manually: %w", pkFile.Name(), err) + } + return nil + }) + _, err = pkFile.WriteString(privateKey) + if err != nil { + return nil, nil, credentials{}, fmt.Errorf("Error writing private key file to %s: %w", pkFile.Name(), err) + } + if err := pkFile.Close(); err != nil { + return nil, nil, credentials{}, fmt.Errorf("Error closing private key file after writing to %s: %w", pkFile.Name(), err) + } + args = append(args, "-i", pkFile.Name()) + + // Next check if we have a username password credential + case len(creds.usernamePassword) > 0: + // We cannot use the password of the credential outside of sshpass but we + // can use the username. + // N.B. Do not mark credential as consumed, as user will still need enter + // the password when prompted. + + if c.flagUsername == "" { + // If the user did not actively provide a username flag set the + // username to that of the first credential we got. + username = retCreds.usernamePassword[0].Username + } + } + } + switch { - case cred.Username != "": - args = append(args, "-l", cred.Username) + case username != "": + args = append(args, "-l", username) case c.flagUsername != "": args = append(args, "-l", c.flagUsername) } - return args, envs, nil + return args, envs, retCreds, nil }