diff --git a/globals/globals.go b/globals/globals.go index b5b9acbe61..daba120ccf 100644 --- a/globals/globals.go +++ b/globals/globals.go @@ -5,6 +5,7 @@ import "time" const ( TcpProxyV1 = "boundary-tcp-proxy-v1" ServiceTokenV1 = "s1" + SessionPrefix = "s_" ) type ( diff --git a/go.mod b/go.mod index 441d88cfcb..a521af6f6b 100644 --- a/go.mod +++ b/go.mod @@ -70,10 +70,10 @@ require ( github.com/stretchr/testify v1.7.0 github.com/zalando/go-keyring v0.2.1 go.uber.org/atomic v1.9.0 - golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000 - golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a + golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 + golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 - golang.org/x/tools v0.1.8-0.20211102182255-bb4add04ddef + golang.org/x/tools v0.1.10 google.golang.org/genproto v0.0.0-20220208230804-65c12eb4c068 google.golang.org/grpc v1.44.0 google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0 @@ -189,7 +189,7 @@ require ( go.opencensus.io v0.23.0 // indirect go.uber.org/multierr v1.7.0 // indirect go.uber.org/zap v1.19.0 // indirect - golang.org/x/mod v0.5.1 // indirect + golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd // indirect golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 // indirect golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect diff --git a/go.sum b/go.sum index d1d920014b..6c19c0cd52 100644 --- a/go.sum +++ b/go.sum @@ -963,8 +963,9 @@ golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5 golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000 h1:SL+8VVnkqyshUSz5iNnXtrBQzvFF2SkROm6t5RczFAE= golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 h1:tkVvjkPTB7pnW3jnid7kNyAMPVWllTNOf/qKDze4p9o= +golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -1001,8 +1002,9 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.5.1 h1:OJxoQ/rynoF0dcCdI7cLPktw/hR2cueqYfjm43oqK38= golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= +golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 h1:kQgndtyPBW/JIYERgdxfwMYh3AVStj88WQTlNDi2a+o= +golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180530234432-1e491301e022/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1166,8 +1168,9 @@ golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211102192858-4dd72447c267/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a h1:ppl5mZgokTT8uPkmYOyEUmPTr3ypaKkg5eFOGrAmxxE= golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f h1:rlezHXNlxYWvBCzNses9Dlc7nGFaNMJeqLolcmQSSZY= +golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -1257,8 +1260,9 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.1.8-0.20211102182255-bb4add04ddef h1:/DaKawnTFFxdq/mJT3pM+OkeJlq5gc3ZhkbGVYbqOCw= golang.org/x/tools v0.1.8-0.20211102182255-bb4add04ddef/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= +golang.org/x/tools v0.1.10 h1:QjFRCZxdOhBJ/UNgnBZLbNV13DlbnK0quyivTnXJM20= +golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/cmd/commands/connect/connect.go b/internal/cmd/commands/connect/connect.go index 2d8f2d7837..5def593f07 100644 --- a/internal/cmd/commands/connect/connect.go +++ b/internal/cmd/commands/connect/connect.go @@ -427,6 +427,15 @@ func (c *Command) Run(args []string) (retCode int) { c.connectionsLeft.Store(c.sessionAuthzData.ConnectionLimit) workerAddr := c.sessionAuthzData.GetWorkerInfo()[0].GetAddress() + workerHost, _, err := net.SplitHostPort(workerAddr) + if err != nil { + if strings.Contains(err.Error(), "missing port") { + workerHost = workerAddr + } else { + c.PrintCliError(fmt.Errorf("Error splitting worker adddress host/port: %w", err)) + return base.CommandUserError + } + } parsedCert, err := x509.ParseCertificate(c.sessionAuthzData.Certificate) if err != nil { @@ -434,11 +443,6 @@ func (c *Command) Run(args []string) (retCode int) { return base.CommandUserError } - if len(parsedCert.DNSNames) != 1 { - c.PrintCliError(fmt.Errorf("mTLS certificate has invalid parameters: %w", err)) - return base.CommandUserError - } - c.expiration = parsedCert.NotAfter // We don't _rely_ on client-side timeout verification but this prevents us @@ -459,16 +463,20 @@ func (c *Command) Run(args []string) (retCode int) { }, }, RootCAs: certPool, - ServerName: parsedCert.DNSNames[0], + ServerName: workerHost, MinVersion: tls.VersionTLS13, + NextProtos: []string{"http/1.1", c.sessionAuthzData.SessionId}, } transport := cleanhttp.DefaultTransport() transport.DisableKeepAlives = false - transport.TLSClientConfig = tlsConf // This isn't/shouldn't used anyways really because the connection is // hijacked, just setting for completeness transport.IdleConnTimeout = 0 + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + dialer := &tls.Dialer{Config: tlsConf} + return dialer.DialContext(ctx, network, addr) + } c.listener, err = net.ListenTCP("tcp", &net.TCPAddr{ IP: listenAddr, @@ -691,7 +699,7 @@ func (c *Command) getWsConn( ) (*websocket.Conn, error) { conn, resp, err := websocket.Dial( ctx, - fmt.Sprintf("wss://%s/v1/proxy", workerAddr), + fmt.Sprintf("ws://%s/v1/proxy", workerAddr), &websocket.DialOptions{ HTTPClient: &http.Client{ Transport: transport, diff --git a/internal/servers/controller/handlers/targets/target_service.go b/internal/servers/controller/handlers/targets/target_service.go index 2b1278d034..3a09ed6c1b 100644 --- a/internal/servers/controller/handlers/targets/target_service.go +++ b/internal/servers/controller/handlers/targets/target_service.go @@ -100,7 +100,8 @@ func NewService( sessionRepoFn common.SessionRepoFactory, pluginHostRepoFn common.PluginHostRepoFactory, staticHostRepoFn common.StaticRepoFactory, - vaultCredRepoFn common.VaultCredentialRepoFactory) (Service, error) { + vaultCredRepoFn common.VaultCredentialRepoFactory, +) (Service, error) { const op = "targets.NewService" if repoFn == nil { return Service{}, errors.New(ctx, errors.InvalidParameter, op, "missing target repository") @@ -939,6 +940,10 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession codes.FailedPrecondition, "No workers are available to handle this session, or all have been filtered.") } + workerAddresses := make([]string, 0, len(workers)) + for _, worker := range workers { + workerAddresses = append(workerAddresses, worker.GetAddress()) + } requestedId := req.GetHostId() staticHostRepo, err := s.staticHostRepoFn() @@ -1046,7 +1051,7 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession if err != nil { return nil, err } - sess, privKey, err := sessionRepo.CreateSession(ctx, wrapper, sess) + sess, privKey, err := sessionRepo.CreateSession(ctx, wrapper, sess, workerAddresses) if err != nil { return nil, err } diff --git a/internal/servers/worker/handler.go b/internal/servers/worker/handler.go index ae26217242..48e7ea6de0 100644 --- a/internal/servers/worker/handler.go +++ b/internal/servers/worker/handler.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "strconv" + "strings" "github.com/hashicorp/boundary/globals" pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" @@ -53,11 +54,26 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig) (http.Han return func(wr http.ResponseWriter, r *http.Request) { ctx := r.Context() if r.TLS == nil { - event.WriteError(ctx, op, errors.New("no request TLS information found")) + event.WriteError(ctx, op, errors.New("no request tls information found")) + wr.WriteHeader(http.StatusInternalServerError) + return + } + + var sessionId string + outerCertLoop: + for _, cert := range r.TLS.PeerCertificates { + for _, name := range cert.DNSNames { + if strings.HasPrefix(name, globals.SessionPrefix) { + sessionId = name + break outerCertLoop + } + } + } + if sessionId == "" { + event.WriteError(ctx, op, errors.New("no session id could be found in peer certificates")) wr.WriteHeader(http.StatusInternalServerError) return } - sessionId := r.TLS.ServerName clientIp, clientPort, err := net.SplitHostPort(r.RemoteAddr) if err != nil { @@ -78,7 +94,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig) (http.Han userClientIp, err := common.ClientIpFromRequest(ctx, listenerCfg, r) if err != nil { - event.WriteError(ctx, op, err, event.WithInfoMsg("unable to determine user IP")) + event.WriteError(ctx, op, err, event.WithInfoMsg("unable to determine user ip")) wr.WriteHeader(http.StatusInternalServerError) } diff --git a/internal/servers/worker/worker.go b/internal/servers/worker/worker.go index 332b017aba..4ca8cd7cb4 100644 --- a/internal/servers/worker/worker.go +++ b/internal/servers/worker/worker.go @@ -14,6 +14,7 @@ import ( "sync/atomic" "time" + "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/cmd/config" pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" @@ -267,11 +268,20 @@ func (w *Worker) getSessionTls(hello *tls.ClientHelloInfo) (*tls.Config, error) ctx := w.baseContext var sessionId string switch { - case strings.HasPrefix(hello.ServerName, "s_"): + case strings.HasPrefix(hello.ServerName, globals.SessionPrefix): sessionId = hello.ServerName default: - event.WriteSysEvent(ctx, op, "invalid session in SNI", "session_id", hello.ServerName) - return nil, fmt.Errorf("could not find session ID in SNI") + for _, proto := range hello.SupportedProtos { + if strings.HasPrefix(proto, globals.SessionPrefix) { + sessionId = proto + break + } + } + } + + if sessionId == "" { + event.WriteSysEvent(ctx, op, "session_id not found in either SNI or ALPN protos", "server_name", hello.ServerName, "alpn_protos", hello.SupportedProtos) + return nil, fmt.Errorf("could not find session ID in SNI or ALPN protos") } conn, err := w.ControllerSessionConn() @@ -299,10 +309,6 @@ func (w *Worker) getSessionTls(hello *tls.ClientHelloInfo) (*tls.Config, error) return nil, fmt.Errorf("error parsing session certificate: %w", err) } - if len(parsedCert.DNSNames) != 1 { - return nil, fmt.Errorf("invalid length of DNS names (%d) in parsed certificate", len(parsedCert.DNSNames)) - } - certPool := x509.NewCertPool() certPool.AddCert(parsedCert) @@ -314,7 +320,8 @@ func (w *Worker) getSessionTls(hello *tls.ClientHelloInfo) (*tls.Config, error) Leaf: parsedCert, }, }, - ServerName: parsedCert.DNSNames[0], + NextProtos: []string{"http/1.1"}, + ServerName: sessionId, ClientAuth: tls.RequireAndVerifyClientCert, ClientCAs: certPool, MinVersion: tls.VersionTLS13, diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index 47657cca4d..6e28a42ab6 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -19,7 +19,7 @@ import ( // its State of "Pending". The following fields must be empty when creating a // session: ServerId, ServerType, and PublicId. No options are // currently supported. -func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping.Wrapper, newSession *Session, _ ...Option) (*Session, ed25519.PrivateKey, error) { +func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping.Wrapper, newSession *Session, workerAddresses []string, _ ...Option) (*Session, ed25519.PrivateKey, error) { const op = "session.(Repository).CreateSession" if newSession == nil { return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing session") @@ -63,13 +63,16 @@ func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping. if newSession.ExpirationTime == nil || newSession.ExpirationTime.Timestamp.AsTime().IsZero() { return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing expiration time") } + if len(workerAddresses) == 0 { + return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing addresses") + } id, err := newId() if err != nil { return nil, nil, errors.Wrap(ctx, err, op) } - privKey, certBytes, err := newCert(ctx, sessionWrapper, newSession.UserId, id, newSession.ExpirationTime.Timestamp.AsTime()) + privKey, certBytes, err := newCert(ctx, sessionWrapper, newSession.UserId, id, workerAddresses, newSession.ExpirationTime.Timestamp.AsTime()) if err != nil { return nil, nil, errors.Wrap(ctx, err, op) } diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index e667cd6b55..8f7d1b7384 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -274,8 +274,10 @@ func TestRepository_CreateSession(t *testing.T) { repo, err := NewRepository(rw, rw, kms) require.NoError(t, err) + workerAddresses := []string{"1.2.3.4"} type args struct { - composedOf ComposedOf + composedOf ComposedOf + workerAddresses []string } tests := []struct { name string @@ -286,14 +288,16 @@ func TestRepository_CreateSession(t *testing.T) { { name: "valid", args: args{ - composedOf: TestSessionParams(t, conn, wrapper, iamRepo), + composedOf: TestSessionParams(t, conn, wrapper, iamRepo), + workerAddresses: workerAddresses, }, wantErr: false, }, { name: "valid-with-credentials", args: args{ - composedOf: testSessionCredentialParams(t, conn, wrapper, iamRepo), + composedOf: testSessionCredentialParams(t, conn, wrapper, iamRepo), + workerAddresses: workerAddresses, }, wantErr: false, }, @@ -305,6 +309,7 @@ func TestRepository_CreateSession(t *testing.T) { c.UserId = "" return c }(), + workerAddresses: workerAddresses, }, wantErr: true, wantIsError: errors.InvalidParameter, @@ -317,6 +322,7 @@ func TestRepository_CreateSession(t *testing.T) { c.HostId = "" return c }(), + workerAddresses: workerAddresses, }, wantErr: true, wantIsError: errors.InvalidParameter, @@ -329,6 +335,7 @@ func TestRepository_CreateSession(t *testing.T) { c.TargetId = "" return c }(), + workerAddresses: workerAddresses, }, wantErr: true, wantIsError: errors.InvalidParameter, @@ -341,6 +348,7 @@ func TestRepository_CreateSession(t *testing.T) { c.HostSetId = "" return c }(), + workerAddresses: workerAddresses, }, wantErr: true, wantIsError: errors.InvalidParameter, @@ -353,6 +361,7 @@ func TestRepository_CreateSession(t *testing.T) { c.AuthTokenId = "" return c }(), + workerAddresses: workerAddresses, }, wantErr: true, wantIsError: errors.InvalidParameter, @@ -365,6 +374,18 @@ func TestRepository_CreateSession(t *testing.T) { c.ScopeId = "" return c }(), + workerAddresses: workerAddresses, + }, + wantErr: true, + wantIsError: errors.InvalidParameter, + }, + { + name: "empty-worker-addresses", + args: args{ + composedOf: func() ComposedOf { + c := TestSessionParams(t, conn, wrapper, iamRepo) + return c + }(), }, wantErr: true, wantIsError: errors.InvalidParameter, @@ -385,7 +406,7 @@ func TestRepository_CreateSession(t *testing.T) { ConnectionLimit: tt.args.composedOf.ConnectionLimit, DynamicCredentials: tt.args.composedOf.DynamicCredentials, } - ses, privKey, err := repo.CreateSession(context.Background(), wrapper, s) + ses, privKey, err := repo.CreateSession(context.Background(), wrapper, s, tt.args.workerAddresses) if tt.wantErr { assert.Error(err) assert.Nil(ses) diff --git a/internal/session/session.go b/internal/session/session.go index 6367a8f0e3..baaa5331f0 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -7,6 +7,7 @@ import ( "crypto/x509" "math/big" mathrand "math/rand" + "net" "strings" "time" @@ -351,7 +352,8 @@ func contains(ss []string, t string) bool { return false } -func newCert(ctx context.Context, wrapper wrapping.Wrapper, userId, jobId string, exp time.Time) (ed25519.PrivateKey, []byte, error) { +// newCert creates a new session certificate. If addresses are supplied, they will be parsed and added to IP or DNS SANs as appropriate. +func newCert(ctx context.Context, wrapper wrapping.Wrapper, userId, jobId string, addresses []string, exp time.Time) (ed25519.PrivateKey, []byte, error) { const op = "session.newCert" if wrapper == nil { return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing wrapper") @@ -362,6 +364,9 @@ func newCert(ctx context.Context, wrapper wrapping.Wrapper, userId, jobId string if jobId == "" { return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing job id") } + if len(addresses) == 0 { + return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing addresses") + } pubKey, privKey, err := DeriveED25519Key(ctx, wrapper, userId, jobId) if err != nil { return nil, nil, errors.Wrap(ctx, err, op) @@ -380,6 +385,27 @@ func newCert(ctx context.Context, wrapper wrapping.Wrapper, userId, jobId string IsCA: true, } + for _, addr := range addresses { + // First ensure we aren't looking at ports, regardless of IP or not + host, _, err := net.SplitHostPort(addr) + if err != nil { + if strings.Contains(err.Error(), "missing port") { + host = addr + } else { + return nil, nil, errors.Wrap(ctx, err, op) + } + } + // Now figure out if it's an IP address or not. If ParseIP likes it, add + // to IP SANs. Otherwise DNS SANs. + ip := net.ParseIP(host) + switch ip { + case nil: + template.DNSNames = append(template.DNSNames, host) + default: + template.IPAddresses = append(template.IPAddresses, ip) + } + } + certBytes, err := x509.CreateCertificate(rand.Reader, template, template, pubKey, privKey) if err != nil { return nil, nil, errors.Wrap(ctx, err, op, errors.WithCode(errors.GenCert)) diff --git a/internal/session/session_test.go b/internal/session/session_test.go index 334f48aba4..a5699f14df 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -2,6 +2,7 @@ package session import ( "context" + "crypto/x509" "testing" "time" @@ -23,8 +24,10 @@ func TestSession_Create(t *testing.T) { composedOf := testSessionCredentialParams(t, conn, wrapper, iamRepo) exp := ×tamp.Timestamp{Timestamp: timestamppb.New(time.Now().Add(time.Hour))} + defaultAddresses := []string{"1.2.3.4", "a.b.c.d"} type args struct { composedOf ComposedOf + addresses []string opt []Option } tests := []struct { @@ -32,6 +35,7 @@ func TestSession_Create(t *testing.T) { args args want *Session wantErr bool + wantAddrErr bool wantIsErr errors.Code create bool wantCreateErr bool @@ -41,6 +45,7 @@ func TestSession_Create(t *testing.T) { args: args{ composedOf: composedOf, opt: []Option{WithExpirationTime(exp)}, + addresses: defaultAddresses, }, want: &Session{ UserId: composedOf.UserId, @@ -64,6 +69,7 @@ func TestSession_Create(t *testing.T) { c.UserId = "" return c }(), + addresses: defaultAddresses, }, wantErr: true, wantIsErr: errors.InvalidParameter, @@ -76,6 +82,7 @@ func TestSession_Create(t *testing.T) { c.HostId = "" return c }(), + addresses: defaultAddresses, }, wantErr: true, wantIsErr: errors.InvalidParameter, @@ -88,6 +95,7 @@ func TestSession_Create(t *testing.T) { c.TargetId = "" return c }(), + addresses: defaultAddresses, }, wantErr: true, wantIsErr: errors.InvalidParameter, @@ -100,6 +108,7 @@ func TestSession_Create(t *testing.T) { c.HostSetId = "" return c }(), + addresses: defaultAddresses, }, wantErr: true, wantIsErr: errors.InvalidParameter, @@ -112,6 +121,7 @@ func TestSession_Create(t *testing.T) { c.AuthTokenId = "" return c }(), + addresses: defaultAddresses, }, wantErr: true, wantIsErr: errors.InvalidParameter, @@ -124,10 +134,34 @@ func TestSession_Create(t *testing.T) { c.ScopeId = "" return c }(), + addresses: defaultAddresses, }, wantErr: true, wantIsErr: errors.InvalidParameter, }, + { + name: "empty-addresses", + args: args{ + composedOf: func() ComposedOf { + c := composedOf + return c + }(), + }, + want: &Session{ + UserId: composedOf.UserId, + HostId: composedOf.HostId, + TargetId: composedOf.TargetId, + HostSetId: composedOf.HostSetId, + AuthTokenId: composedOf.AuthTokenId, + ScopeId: composedOf.ScopeId, + Endpoint: "tcp://127.0.0.1:22", + ExpirationTime: composedOf.ExpirationTime, + ConnectionLimit: composedOf.ConnectionLimit, + DynamicCredentials: composedOf.DynamicCredentials, + }, + wantAddrErr: true, + wantIsErr: errors.InvalidParameter, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -145,7 +179,12 @@ func TestSession_Create(t *testing.T) { id, err := db.NewPublicId(SessionPrefix) require.NoError(err) got.PublicId = id - _, certBytes, err := newCert(ctx, wrapper, got.UserId, id, composedOf.ExpirationTime.Timestamp.AsTime()) + _, certBytes, err := newCert(ctx, wrapper, got.UserId, id, tt.args.addresses, composedOf.ExpirationTime.Timestamp.AsTime()) + if tt.wantAddrErr { + require.Error(err) + assert.True(errors.Match(errors.T(tt.wantIsErr), err)) + return + } require.NoError(err) got.Certificate = certBytes err = db.New(conn).Create(ctx, got) @@ -155,6 +194,13 @@ func TestSession_Create(t *testing.T) { } else { assert.NoError(err) } + + if len(tt.args.addresses) > 0 { + cert, err := x509.ParseCertificate(certBytes) + require.NoError(err) + // Session ID is always encoded, hence the +1 + assert.Equal(len(tt.args.addresses)+1, len(cert.DNSNames)+len(cert.IPAddresses)) + } } }) } diff --git a/internal/session/testing.go b/internal/session/testing.go index 7d2197b88b..c306c1d819 100644 --- a/internal/session/testing.go +++ b/internal/session/testing.go @@ -84,7 +84,7 @@ func TestSession(t *testing.T, conn *db.DB, wrapper wrapping.Wrapper, c Composed id, err := newId() require.NoError(err) s.PublicId = id - _, certBytes, err := newCert(ctx, wrapper, c.UserId, id, c.ExpirationTime.Timestamp.AsTime()) + _, certBytes, err := newCert(ctx, wrapper, c.UserId, id, []string{"127.0.0.1", "localhost"}, c.ExpirationTime.Timestamp.AsTime()) require.NoError(err) s.Certificate = certBytes s.ServerId = opts.withServerId @@ -205,5 +205,5 @@ func TestWorker(t *testing.T, conn *db.DB, wrapper wrapping.Wrapper, opt ...Opti // as a parameter. It's currently used in controller.jobTestingHandler() and // should be deprecated once that function is refactored to use sessions properly. func TestCert(wrapper wrapping.Wrapper, userId, jobId string) (ed25519.PrivateKey, []byte, error) { - return newCert(context.Background(), wrapper, userId, jobId, time.Now().Add(5*time.Minute)) + return newCert(context.Background(), wrapper, userId, jobId, []string{"127.0.0.1", "localhost"}, time.Now().Add(5*time.Minute)) }