feat(cli): Support using brokered private key in ssh subcommand (#2267)

* feat(cli): Support using brokered private key in ssh subcommand
pull/2238/head^2
Louis Ruch 4 years ago committed by GitHub
parent 6d5abba2e5
commit 7d1a989ea7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -483,28 +483,16 @@ func (c *Command) Run(args []string) (retCode int) {
c.listenerAddr = c.listener.Addr().(*net.TCPAddr)
var creds []*targets.SessionCredential
if c.sessionAuthz != nil && len(c.sessionAuthz.Credentials) > 0 {
creds = c.sessionAuthz.Credentials
}
switch c.Func {
case "postgres":
// Credentials are brokered when connecting to the postgres db.
// TODO: Figure out how to handle cases where we don't automatically know how to
// broker the credentials like unrecognized or multiple credentials.
case "ssh":
if c.flagExec == "sshpass" {
// If we can broker ssh credentials using sshpass, skip displaying creds on client side.
// TODO: Figure out how to handle multiple credentials
break
}
// We cannot broker credentials print to user
fallthrough
case "connect":
if c.Func == "connect" {
// "connect" indicates there is no subcommand to the connect function.
// The only way a user will be able to connect to the session is by
// connecting directly to the port and address we report to them here.
var creds []*targets.SessionCredential
if c.sessionAuthz != nil && len(c.sessionAuthz.Credentials) > 0 {
creds = c.sessionAuthz.Credentials
}
sessInfo := SessionInfo{
Protocol: c.sessionAuthzData.GetType(),
Address: c.listenerAddr.IP.String(),
@ -525,25 +513,6 @@ func (c *Command) Run(args []string) (retCode int) {
}
c.UI.Output(string(out))
}
default:
if len(creds) == 0 {
break
}
switch base.Format(c.UI) {
case "table":
c.UI.Output(generateCredentialTableOutput(creds))
case "json":
out, err := json.Marshal(&struct {
Credentials []*targets.SessionCredential `json:"credentials"`
}{
Credentials: c.sessionAuthz.Credentials,
})
if err != nil {
c.PrintCliError(fmt.Errorf("error marshaling session information: %w", err))
return base.CommandCliError
}
c.UI.Output(string(out))
}
}
c.connWg = new(sync.WaitGroup)
@ -681,6 +650,27 @@ func (c *Command) Run(args []string) (retCode int) {
return
}
func (c *Command) printCredentials(creds []*targets.SessionCredential) error {
if len(creds) == 0 {
return nil
}
switch base.Format(c.UI) {
case "table":
c.UI.Output(generateCredentialTableOutput(creds))
case "json":
out, err := json.Marshal(&struct {
Credentials []*targets.SessionCredential `json:"credentials"`
}{
Credentials: creds,
})
if err != nil {
return fmt.Errorf("error marshaling credential information: %w", err)
}
c.UI.Output(string(out))
}
return nil
}
func (c *Command) getWsConn(
ctx context.Context,
workerAddr string,
@ -824,6 +814,17 @@ func (c *Command) handleExec(passthroughArgs []string) {
var envs []string
var argsErr error
var creds credentials
if c.sessionAuthz != nil {
var err error
creds, err = parseCredentials(c.sessionAuthz.Credentials)
if err != nil {
c.PrintCliError(fmt.Errorf("Error interpreting secret: %w", err))
c.execCmdReturnValue.Store(int32(3))
return
}
}
switch c.Func {
case "http":
httpArgs, err := c.httpFlags.buildArgs(c, port, ip, addr)
@ -835,25 +836,27 @@ func (c *Command) handleExec(passthroughArgs []string) {
args = append(args, httpArgs...)
case "postgres":
pgArgs, pgEnvs, pgErr := c.postgresFlags.buildArgs(c, port, ip, addr)
pgArgs, pgEnvs, pgCreds, pgErr := c.postgresFlags.buildArgs(c, port, ip, addr, creds)
if pgErr != nil {
argsErr = pgErr
break
}
args = append(args, pgArgs...)
envs = append(envs, pgEnvs...)
creds = pgCreds
case "rdp":
args = append(args, c.rdpFlags.buildArgs(c, port, ip, addr)...)
case "ssh":
sshArgs, sshEnvs, sshErr := c.sshFlags.buildArgs(c, port, ip, addr)
sshArgs, sshEnvs, sshCreds, sshErr := c.sshFlags.buildArgs(c, port, ip, addr, creds)
if sshErr != nil {
argsErr = sshErr
break
}
args = append(args, sshArgs...)
envs = append(envs, sshEnvs...)
creds = sshCreds
case "kube":
kubeArgs, err := c.kubeFlags.buildArgs(c, port, ip, addr)
@ -871,6 +874,12 @@ func (c *Command) handleExec(passthroughArgs []string) {
return
}
if err := c.printCredentials(creds.unconsumedSessionCredentials()); err != nil {
c.PrintCliError(fmt.Errorf("Failed to print credentials: %w", err))
c.execCmdReturnValue.Store(int32(2))
return
}
args = append(passthroughArgs, args...)
stringReplacer := func(in, typ, replacer string) string {

@ -8,54 +8,115 @@ import (
"github.com/mitchellh/mapstructure"
)
type usernamePasswordCredentials struct {
type usernamePassword struct {
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
raw *targets.SessionCredential
consumed bool
}
type sshPrivateKey struct {
Username string `mapstructure:"username"`
PrivateKey string `mapstructure:"private_key"`
raw *targets.SessionCredential
consumed bool
}
type credentials struct {
usernamePassword []usernamePassword
sshPrivateKey []sshPrivateKey
unspecified []*targets.SessionCredential
}
// TODO: this should support multiple types of credentials and return an interface instead
// of only usernamepassword. Each subcommand can use the returned credentials as it needs.
func parseCredentials(creds []*targets.SessionCredential) ([]usernamePasswordCredentials, error) {
func (c credentials) unconsumedSessionCredentials() []*targets.SessionCredential {
out := make([]*targets.SessionCredential, 0, len(c.sshPrivateKey)+len(c.usernamePassword)+len(c.unspecified))
// Unspecified credentials cannot be consumed
out = append(out, c.unspecified...)
for _, c := range c.sshPrivateKey {
if !c.consumed {
out = append(out, c.raw)
}
}
for _, c := range c.usernamePassword {
if !c.consumed {
out = append(out, c.raw)
}
}
return out
}
func parseCredentials(creds []*targets.SessionCredential) (credentials, error) {
if creds == nil {
return nil, nil
return credentials{}, nil
}
var out []usernamePasswordCredentials
var out credentials
for _, cred := range creds {
if cred.CredentialSource == nil {
return nil, errors.New("missing credential source")
return credentials{}, errors.New("missing credential source")
}
var c usernamePasswordCredentials
if cred.CredentialSource.CredentialType == string(credential.UsernamePasswordType) {
var upCred usernamePassword
var spkCred sshPrivateKey
switch credential.Type(cred.CredentialSource.CredentialType) {
case credential.UsernamePasswordType:
// Decode attributes from credential struct
if err := mapstructure.Decode(cred.Credential, &c); err != nil {
return nil, err
if err := mapstructure.Decode(cred.Credential, &upCred); err != nil {
return credentials{}, err
}
if c.Username != "" && c.Password != "" {
out = append(out, c)
if upCred.Username != "" && upCred.Password != "" {
upCred.raw = cred
out.usernamePassword = append(out.usernamePassword, upCred)
continue
}
}
// Credential type is unspecified, make a best effort attempt to parse both username
// and password from Decoded field if it exists
if cred.Secret == nil || cred.Secret.Decoded == nil {
// No secret data continue to next credential
continue
}
case credential.SshPrivateKeyType:
// Decode attributes from credential struct
if err := mapstructure.Decode(cred.Credential, &spkCred); err != nil {
return credentials{}, err
}
switch cred.CredentialSource.Type {
case "vault", "static":
// Attempt unmarshaling into creds
if err := mapstructure.Decode(cred.Secret.Decoded, &c); err != nil {
return nil, err
if spkCred.Username != "" && spkCred.PrivateKey != "" {
spkCred.raw = cred
out.sshPrivateKey = append(out.sshPrivateKey, spkCred)
continue
}
}
if c.Username != "" && c.Password != "" {
out = append(out, c)
// Credential type is unspecified, make a best effort attempt to parse
// a username_password credential from the Decoded field if it exists
if cred.Secret != nil && cred.Secret.Decoded != nil {
switch cred.CredentialSource.Type {
case "vault", "static":
// Attempt unmarshaling into username password creds
if err := mapstructure.Decode(cred.Secret.Decoded, &upCred); err != nil {
return credentials{}, err
}
if upCred.Username != "" && upCred.Password != "" {
upCred.raw = cred
out.usernamePassword = append(out.usernamePassword, upCred)
continue
}
// Attempt unmarshaling into ssh private key creds
if err := mapstructure.Decode(cred.Secret.Decoded, &spkCred); err != nil {
return credentials{}, err
}
if spkCred.Username != "" && spkCred.PrivateKey != "" {
spkCred.raw = cred
out.sshPrivateKey = append(out.sshPrivateKey, spkCred)
continue
}
}
}
// We could not parse the credential
out.unspecified = append(out.unspecified, cred)
}
return out, nil
}

@ -9,222 +9,472 @@ import (
"github.com/stretchr/testify/require"
)
func Test_parseUsernamePasswordCredentials(t *testing.T) {
var (
typedUsernamePassword = &targets.SessionCredential{
CredentialSource: &targets.CredentialSource{
CredentialType: string(credential.UsernamePasswordType),
},
Credential: map[string]interface{}{
"username": "user",
"password": "pass",
},
}
typedSshPrivateKey = &targets.SessionCredential{
CredentialSource: &targets.CredentialSource{
CredentialType: string(credential.SshPrivateKeyType),
},
Credential: map[string]interface{}{
"username": "user",
"private_key": "my-pk",
},
}
vaultUsernamePassword = &targets.SessionCredential{
CredentialSource: &targets.CredentialSource{
Type: "vault",
},
Secret: &targets.SessionSecret{
Decoded: map[string]interface{}{
"username": "vault-decoded-user",
"password": "vault-decoded-pass",
},
},
}
vaultSshPrivateKey = &targets.SessionCredential{
CredentialSource: &targets.CredentialSource{
Type: "vault",
},
Secret: &targets.SessionSecret{
Decoded: map[string]interface{}{
"username": "vault-decoded-user",
"private_key": "vault-decoded-pk",
},
},
}
staticUsernamePassword = &targets.SessionCredential{
CredentialSource: &targets.CredentialSource{
Type: "static",
},
Secret: &targets.SessionSecret{
Decoded: map[string]interface{}{
"username": "static-decoded-user",
"password": "static-decoded-pass",
},
},
}
staticSshPrivateKey = &targets.SessionCredential{
CredentialSource: &targets.CredentialSource{
Type: "static",
},
Secret: &targets.SessionSecret{
Decoded: map[string]interface{}{
"username": "static-decoded-user",
"private_key": "static-decoded-pk",
},
},
}
unspecifiedCred = &targets.SessionCredential{
CredentialSource: &targets.CredentialSource{
Type: "static",
},
Secret: &targets.SessionSecret{
Decoded: map[string]interface{}{
"username": "decoded-user",
"some-value": "decoded-some-value",
},
},
}
unspecifiedCred1 = &targets.SessionCredential{
CredentialSource: &targets.CredentialSource{
Type: "static",
},
Secret: &targets.SessionSecret{
Decoded: map[string]interface{}{
"username": "decoded-user",
"some-value1": "decoded-some-value1",
},
},
}
)
func Test_parseCredentials(t *testing.T) {
tests := []struct {
name string
creds []*targets.SessionCredential
wantCreds []usernamePasswordCredentials
wantCreds credentials
wantErr bool
}{
{
name: "no-creds",
wantCreds: nil,
wantErr: false,
name: "no-creds",
wantErr: false,
},
{
name: "no-credential-source",
creds: []*targets.SessionCredential{
{
Credential: map[string]interface{}{
"username": "user",
"password": "pass",
Secret: &targets.SessionSecret{
Decoded: map[string]interface{}{
"username": "decoded-user",
"private_key": "decoded-pk",
},
},
},
},
wantCreds: nil,
wantErr: true,
wantErr: true,
},
{
name: "valid-typed",
name: "username-password-typed",
creds: []*targets.SessionCredential{
{
CredentialSource: &targets.CredentialSource{
CredentialType: string(credential.UsernamePasswordType),
},
Credential: map[string]interface{}{
"username": "user",
"password": "pass",
},
},
typedUsernamePassword,
},
wantCreds: []usernamePasswordCredentials{
{
Username: "user",
Password: "pass",
wantCreds: credentials{
usernamePassword: []usernamePassword{
{
Username: "user",
Password: "pass",
raw: typedUsernamePassword,
},
},
},
wantErr: false,
},
{
name: "valid-typed-instead-of-decoded",
name: "ssh-private-key-typed",
creds: []*targets.SessionCredential{
{
CredentialSource: &targets.CredentialSource{
CredentialType: string(credential.UsernamePasswordType),
},
Credential: map[string]interface{}{
"username": "user",
"password": "pass",
},
Secret: &targets.SessionSecret{
Decoded: map[string]interface{}{
"username": "secret-user",
"password": "secret-pass",
},
typedSshPrivateKey,
},
wantCreds: credentials{
sshPrivateKey: []sshPrivateKey{
{
Username: "user",
PrivateKey: "my-pk",
raw: typedSshPrivateKey,
},
},
},
wantCreds: []usernamePasswordCredentials{
{
Username: "user",
Password: "pass",
wantErr: false,
},
{
name: "vault-username-password-decoded",
creds: []*targets.SessionCredential{
vaultUsernamePassword,
},
wantCreds: credentials{
usernamePassword: []usernamePassword{
{
Username: "vault-decoded-user",
Password: "vault-decoded-pass",
raw: vaultUsernamePassword,
},
},
},
wantErr: false,
},
{
name: "valid-vault-not-typed",
name: "vault-private-key-decoded",
creds: []*targets.SessionCredential{
{
CredentialSource: &targets.CredentialSource{
Type: "vault",
},
Secret: &targets.SessionSecret{
Decoded: map[string]interface{}{
"username": "user",
"password": "pass",
},
vaultSshPrivateKey,
},
wantCreds: credentials{
sshPrivateKey: []sshPrivateKey{
{
Username: "vault-decoded-user",
PrivateKey: "vault-decoded-pk",
raw: vaultSshPrivateKey,
},
},
},
wantCreds: []usernamePasswordCredentials{
{
Username: "user",
Password: "pass",
wantErr: false,
},
{
name: "static-username-password-decoded",
creds: []*targets.SessionCredential{
staticUsernamePassword,
},
wantCreds: credentials{
usernamePassword: []usernamePassword{
{
Username: "static-decoded-user",
Password: "static-decoded-pass",
raw: staticUsernamePassword,
},
},
},
wantErr: false,
},
{
name: "valid-static-not-typed",
name: "static-private-key-decoded",
creds: []*targets.SessionCredential{
{
CredentialSource: &targets.CredentialSource{
Type: "static",
},
Secret: &targets.SessionSecret{
Decoded: map[string]interface{}{
"username": "user",
"password": "pass",
},
staticSshPrivateKey,
},
wantCreds: credentials{
sshPrivateKey: []sshPrivateKey{
{
Username: "static-decoded-user",
PrivateKey: "static-decoded-pk",
raw: staticSshPrivateKey,
},
},
},
wantCreds: []usernamePasswordCredentials{
{
Username: "user",
Password: "pass",
wantErr: false,
},
{
name: "unspecified",
creds: []*targets.SessionCredential{
unspecifiedCred,
},
wantCreds: credentials{
unspecified: []*targets.SessionCredential{
unspecifiedCred,
},
},
wantErr: false,
},
{
name: "valid-multiple",
name: "mixed",
creds: []*targets.SessionCredential{
{
CredentialSource: &targets.CredentialSource{
CredentialType: string(credential.UsernamePasswordType),
staticSshPrivateKey, unspecifiedCred1, vaultSshPrivateKey, typedUsernamePassword,
unspecifiedCred, vaultUsernamePassword, typedSshPrivateKey, staticUsernamePassword,
},
wantCreds: credentials{
sshPrivateKey: []sshPrivateKey{
{
Username: "static-decoded-user",
PrivateKey: "static-decoded-pk",
raw: staticSshPrivateKey,
},
{
Username: "vault-decoded-user",
PrivateKey: "vault-decoded-pk",
raw: vaultSshPrivateKey,
},
Credential: map[string]interface{}{
"username": "user",
"password": "pass",
{
Username: "user",
PrivateKey: "my-pk",
raw: typedSshPrivateKey,
},
},
{
CredentialSource: &targets.CredentialSource{
CredentialType: string(credential.UsernamePasswordType),
usernamePassword: []usernamePassword{
{
Username: "static-decoded-user",
Password: "static-decoded-pass",
raw: staticUsernamePassword,
},
{
Username: "vault-decoded-user",
Password: "vault-decoded-pass",
raw: vaultUsernamePassword,
},
Credential: map[string]interface{}{
"username": "user1",
"password": "pass1",
{
Username: "user",
Password: "pass",
raw: typedUsernamePassword,
},
},
unspecified: []*targets.SessionCredential{
unspecifiedCred, unspecifiedCred1,
},
},
wantCreds: []usernamePasswordCredentials{
{
Username: "user",
Password: "pass",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
creds, err := parseCredentials(tt.creds)
if tt.wantErr {
require.Error(err)
assert.Empty(creds)
return
}
require.NoError(err)
assert.ElementsMatch(tt.wantCreds.usernamePassword, creds.usernamePassword)
assert.ElementsMatch(tt.wantCreds.sshPrivateKey, creds.sshPrivateKey)
assert.ElementsMatch(tt.wantCreds.unspecified, creds.unspecified)
})
}
}
func Test_unconsumedSessionCredentials(t *testing.T) {
tests := []struct {
name string
creds credentials
wantCreds []*targets.SessionCredential
}{
{
name: "no-creds",
wantCreds: nil,
},
{
name: "spk-consumed",
creds: credentials{
sshPrivateKey: []sshPrivateKey{
{
raw: staticSshPrivateKey,
consumed: true,
},
},
{
Username: "user1",
Password: "pass1",
},
wantCreds: nil,
},
{
name: "spk",
creds: credentials{
sshPrivateKey: []sshPrivateKey{
{
raw: staticSshPrivateKey,
},
},
},
wantErr: false,
wantCreds: []*targets.SessionCredential{staticSshPrivateKey},
},
{
name: "valid-multiple-mixed",
creds: []*targets.SessionCredential{
{
CredentialSource: &targets.CredentialSource{
CredentialType: string(credential.UsernamePasswordType),
name: "up",
creds: credentials{
usernamePassword: []usernamePassword{
{
raw: vaultUsernamePassword,
},
Credential: map[string]interface{}{
"username": "user",
"password": "pass",
},
},
wantCreds: []*targets.SessionCredential{vaultUsernamePassword},
},
{
name: "up-consumed",
creds: credentials{
usernamePassword: []usernamePassword{
{
raw: vaultUsernamePassword,
consumed: true,
},
},
{
CredentialSource: &targets.CredentialSource{
Type: "vault",
},
wantCreds: nil,
},
{
name: "unspecified",
creds: credentials{
unspecified: []*targets.SessionCredential{unspecifiedCred},
},
wantCreds: []*targets.SessionCredential{unspecifiedCred},
},
{
name: "mixed",
creds: credentials{
sshPrivateKey: []sshPrivateKey{
{
raw: staticSshPrivateKey,
consumed: true,
},
Secret: &targets.SessionSecret{
Decoded: map[string]interface{}{
"username": "user1",
"password": "pass1",
},
{
raw: vaultSshPrivateKey,
},
{
raw: typedSshPrivateKey,
},
},
{
CredentialSource: &targets.CredentialSource{
Type: "static",
usernamePassword: []usernamePassword{
{
raw: staticUsernamePassword,
consumed: true,
},
Secret: &targets.SessionSecret{
Decoded: map[string]interface{}{
"username": "user2",
"password": "pass2",
},
{
raw: vaultUsernamePassword,
},
{
raw: typedUsernamePassword,
consumed: true,
},
},
unspecified: []*targets.SessionCredential{unspecifiedCred, unspecifiedCred1},
},
wantCreds: []usernamePasswordCredentials{
{
Username: "user",
Password: "pass",
wantCreds: []*targets.SessionCredential{
vaultSshPrivateKey, typedSshPrivateKey, vaultUsernamePassword, unspecifiedCred, unspecifiedCred1,
},
},
{
name: "mixed-all-consumed",
creds: credentials{
sshPrivateKey: []sshPrivateKey{
{
raw: staticSshPrivateKey,
consumed: true,
},
{
raw: vaultSshPrivateKey,
consumed: true,
},
{
raw: typedSshPrivateKey,
consumed: true,
},
},
{
Username: "user1",
Password: "pass1",
usernamePassword: []usernamePassword{
{
raw: staticUsernamePassword,
consumed: true,
},
{
raw: vaultUsernamePassword,
consumed: true,
},
{
raw: typedUsernamePassword,
consumed: true,
},
},
{
Username: "user2",
Password: "pass2",
unspecified: []*targets.SessionCredential{unspecifiedCred, unspecifiedCred1},
},
wantCreds: []*targets.SessionCredential{
unspecifiedCred1, unspecifiedCred,
},
},
{
name: "mixed-all-unconsumed",
creds: credentials{
sshPrivateKey: []sshPrivateKey{
{
raw: staticSshPrivateKey,
},
{
raw: vaultSshPrivateKey,
},
{
raw: typedSshPrivateKey,
},
},
usernamePassword: []usernamePassword{
{
raw: staticUsernamePassword,
},
{
raw: vaultUsernamePassword,
},
{
raw: typedUsernamePassword,
},
},
unspecified: []*targets.SessionCredential{unspecifiedCred, unspecifiedCred1},
},
wantCreds: []*targets.SessionCredential{
staticSshPrivateKey, unspecifiedCred1, vaultSshPrivateKey, typedUsernamePassword,
unspecifiedCred, vaultUsernamePassword, typedSshPrivateKey, staticUsernamePassword,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
assert := assert.New(t)
creds, err := parseCredentials(tt.creds)
if tt.wantErr {
require.Error(err)
assert.Nil(creds)
return
}
require.NoError(err)
creds := tt.creds.unconsumedSessionCredentials()
assert.ElementsMatch(tt.wantCreds, creds)
})
}

@ -46,6 +46,11 @@ func generateCredentialTableOutput(creds []*targets.SessionCredential) string {
func generateCredentialTableOutputSlice(prefixIndent int, creds []*targets.SessionCredential) []string {
var ret []string
prefixString := strings.Repeat(" ", prefixIndent)
if len(creds) > 0 {
// Add credential header
ret = append(ret, fmt.Sprintf("%sCredentials:", prefixString))
}
for _, crd := range creds {
credMap := map[string]interface{}{
"Credential Store ID": crd.CredentialSource.CredentialStoreId,
@ -63,12 +68,12 @@ func generateCredentialTableOutputSlice(prefixIndent int, creds []*targets.Sessi
}
maxLength := base.MaxAttributesLength(credMap, nil, nil)
ret = append(ret,
fmt.Sprintf("%sCredentials:", prefixString),
base.WrapMap(2+prefixIndent, maxLength, credMap),
fmt.Sprintf("%s Secret:", prefixString))
ret = append(ret,
fmtSecretForTable(2+prefixIndent, crd)...,
)
ret = append(ret, "")
}
return ret

@ -51,18 +51,17 @@ func (p *postgresFlags) defaultExec() string {
return strings.ToLower(p.flagPostgresStyle)
}
func (p *postgresFlags) buildArgs(c *Command, port, ip, addr string) (args, envs []string, retErr error) {
var cred usernamePasswordCredentials
if c.sessionAuthz != nil {
creds, err := parseCredentials(c.sessionAuthz.Credentials)
if err != nil {
return nil, nil, fmt.Errorf("Error interpreting secret: %w", err)
}
if len(creds) > 0 {
// Just use first credentials returned
cred = creds[0]
}
func (p *postgresFlags) buildArgs(c *Command, port, ip, addr string, creds credentials) (args, envs []string, retCreds credentials, retErr error) {
var username, password string
retCreds = creds
if len(retCreds.usernamePassword) > 0 {
// Mark credential as consumed so it is not printed to user
retCreds.usernamePassword[0].consumed = true
// For now just grab the first username password credential brokered
username = retCreds.usernamePassword[0].Username
password = retCreds.usernamePassword[0].Password
}
switch p.flagPostgresStyle {
@ -74,16 +73,16 @@ func (p *postgresFlags) buildArgs(c *Command, port, ip, addr string) (args, envs
}
switch {
case cred.Username != "":
args = append(args, "-U", cred.Username)
case username != "":
args = append(args, "-U", username)
case c.flagUsername != "":
args = append(args, "-U", c.flagUsername)
}
if cred.Password != "" {
if password != "" {
passfile, err := ioutil.TempFile("", "*")
if err != nil {
return nil, nil, fmt.Errorf("Error saving postgres password to tmp file: %w", err)
return nil, nil, credentials{}, fmt.Errorf("Error saving postgres password to tmp file: %w", err)
}
c.cleanupFuncs = append(c.cleanupFuncs, func() error {
if err := os.Remove(passfile.Name()); err != nil {
@ -91,12 +90,12 @@ func (p *postgresFlags) buildArgs(c *Command, port, ip, addr string) (args, envs
}
return nil
})
_, err = passfile.WriteString(fmt.Sprintf("*:*:*:*:%s", cred.Password))
_, err = passfile.WriteString(fmt.Sprintf("*:*:*:*:%s", password))
if err != nil {
return nil, nil, fmt.Errorf("Error writing password file to %s: %w", passfile.Name(), err)
return nil, nil, credentials{}, fmt.Errorf("Error writing password file to %s: %w", passfile.Name(), err)
}
if err := passfile.Close(); err != nil {
return nil, nil, fmt.Errorf("Error closing password file after writing to %s: %w", passfile.Name(), err)
return nil, nil, credentials{}, fmt.Errorf("Error closing password file after writing to %s: %w", passfile.Name(), err)
}
envs = append(envs, fmt.Sprintf("PGPASSFILE=%s", passfile.Name()))

@ -3,6 +3,8 @@ package connect
import (
"errors"
"fmt"
"io/ioutil"
"os"
"strings"
"github.com/hashicorp/boundary/internal/cmd/base"
@ -42,18 +44,9 @@ func (s *sshFlags) defaultExec() string {
return strings.ToLower(s.flagSshStyle)
}
func (s *sshFlags) buildArgs(c *Command, port, ip, addr string) (args, envs []string, retErr error) {
var cred usernamePasswordCredentials
if c.sessionAuthz != nil {
creds, err := parseCredentials(c.sessionAuthz.Credentials)
if err != nil {
return nil, nil, fmt.Errorf("Error interpreting secret: %w", err)
}
if len(creds) > 0 {
// Just use first credentials returned
cred = creds[0]
}
}
func (s *sshFlags) buildArgs(c *Command, port, ip, addr string, creds credentials) (args, envs []string, retCreds credentials, retErr error) {
var username string
retCreds = creds
switch strings.ToLower(s.flagSshStyle) {
case "ssh":
@ -65,12 +58,23 @@ func (s *sshFlags) buildArgs(c *Command, port, ip, addr string) (args, envs []st
args = append(args, "-o", "NoHostAuthenticationForLocalhost=yes")
case "sshpass":
if cred.Password == "" {
return nil, nil, errors.New("Password is required when using sshpass")
var password string
if len(retCreds.usernamePassword) > 0 {
// For now just grab the first username password credential brokered
// Mark credential as consumed so it is not printed to user
retCreds.usernamePassword[0].consumed = true
username = retCreds.usernamePassword[0].Username
password = retCreds.usernamePassword[0].Password
}
if password == "" {
return nil, nil, credentials{}, errors.New("Password is required when using sshpass")
}
// sshpass requires that the password is passed as env-var "SSHPASS"
// when the using env-vars.
envs = append(envs, fmt.Sprintf("SSHPASS=%s", cred.Password))
envs = append(envs, fmt.Sprintf("SSHPASS=%s", password))
args = append(args, "-e", "ssh")
args = append(args, "-p", port, ip)
@ -82,12 +86,61 @@ func (s *sshFlags) buildArgs(c *Command, port, ip, addr string) (args, envs []st
args = append(args, "-P", port, ip)
}
// Check if we got credentials to attempt to use for ssh or putty,
// sshpass style has already be handled above as username password.
switch strings.ToLower(s.flagSshStyle) {
case "putty", "ssh":
switch {
// First check if we have a private key available
case len(creds.sshPrivateKey) > 0:
// For now just grab the first ssh private key credential brokered
// Mark credential as consumed so it is not printed to user
retCreds.sshPrivateKey[0].consumed = true
username = retCreds.sshPrivateKey[0].Username
privateKey := retCreds.sshPrivateKey[0].PrivateKey
pkFile, err := ioutil.TempFile("", "*")
if err != nil {
return nil, nil, credentials{}, fmt.Errorf("Error saving ssh private key to tmp file: %w", err)
}
c.cleanupFuncs = append(c.cleanupFuncs, func() error {
if err := os.Remove(pkFile.Name()); err != nil {
return fmt.Errorf("Error removing temporary ssh private key file; consider removing %s manually: %w", pkFile.Name(), err)
}
return nil
})
_, err = pkFile.WriteString(privateKey)
if err != nil {
return nil, nil, credentials{}, fmt.Errorf("Error writing private key file to %s: %w", pkFile.Name(), err)
}
if err := pkFile.Close(); err != nil {
return nil, nil, credentials{}, fmt.Errorf("Error closing private key file after writing to %s: %w", pkFile.Name(), err)
}
args = append(args, "-i", pkFile.Name())
// Next check if we have a username password credential
case len(creds.usernamePassword) > 0:
// We cannot use the password of the credential outside of sshpass but we
// can use the username.
// N.B. Do not mark credential as consumed, as user will still need enter
// the password when prompted.
if c.flagUsername == "" {
// If the user did not actively provide a username flag set the
// username to that of the first credential we got.
username = retCreds.usernamePassword[0].Username
}
}
}
switch {
case cred.Username != "":
args = append(args, "-l", cred.Username)
case username != "":
args = append(args, "-l", username)
case c.flagUsername != "":
args = append(args, "-l", c.flagUsername)
}
return args, envs, nil
return args, envs, retCreds, nil
}

Loading…
Cancel
Save