From 58a448fc6a944a1e28ea5a6010a3383cc0f6613a Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Fri, 1 Apr 2022 16:32:20 -0400 Subject: [PATCH] Put session ID in ALPN (#1966) This allows us to use SNI for actual host (e.g. routing) information. We don't have the client cert yet so we can't look at that, so this provides a way for us to convey the information needed to look up that session's TLS stack. Using SNI for hosts means we also run into the fact that we don't have automatic agreement in terms of SANs. So when generating the certs we now also pass worker address information to the function to be encoded in the cert. Finally, there is a change in how the websocket dialing happens, because http.RoundTripper tries to be too clever for its own good and overwrites NextProtos on a whim. --- globals/globals.go | 1 + go.mod | 8 ++-- go.sum | 12 +++-- internal/cmd/commands/connect/connect.go | 24 ++++++---- .../handlers/targets/target_service.go | 9 +++- internal/servers/worker/handler.go | 22 +++++++-- internal/servers/worker/worker.go | 23 +++++---- internal/session/repository_session.go | 7 ++- internal/session/repository_session_test.go | 29 +++++++++-- internal/session/session.go | 28 ++++++++++- internal/session/session_test.go | 48 ++++++++++++++++++- internal/session/testing.go | 4 +- 12 files changed, 176 insertions(+), 39 deletions(-) diff --git a/globals/globals.go b/globals/globals.go index b5b9acbe61..daba120ccf 100644 --- a/globals/globals.go +++ b/globals/globals.go @@ -5,6 +5,7 @@ import "time" const ( TcpProxyV1 = "boundary-tcp-proxy-v1" ServiceTokenV1 = "s1" + SessionPrefix = "s_" ) type ( diff --git a/go.mod b/go.mod index 441d88cfcb..a521af6f6b 100644 --- a/go.mod +++ b/go.mod @@ -70,10 +70,10 @@ require ( github.com/stretchr/testify v1.7.0 github.com/zalando/go-keyring v0.2.1 go.uber.org/atomic v1.9.0 - golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000 - golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a + golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 + golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 - golang.org/x/tools v0.1.8-0.20211102182255-bb4add04ddef + golang.org/x/tools v0.1.10 google.golang.org/genproto v0.0.0-20220208230804-65c12eb4c068 google.golang.org/grpc v1.44.0 google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0 @@ -189,7 +189,7 @@ require ( go.opencensus.io v0.23.0 // indirect go.uber.org/multierr v1.7.0 // indirect go.uber.org/zap v1.19.0 // indirect - golang.org/x/mod v0.5.1 // indirect + golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd // indirect golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 // indirect golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect diff --git a/go.sum b/go.sum index d1d920014b..6c19c0cd52 100644 --- a/go.sum +++ b/go.sum @@ -963,8 +963,9 @@ golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5 golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000 h1:SL+8VVnkqyshUSz5iNnXtrBQzvFF2SkROm6t5RczFAE= golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 h1:tkVvjkPTB7pnW3jnid7kNyAMPVWllTNOf/qKDze4p9o= +golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -1001,8 +1002,9 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.5.1 h1:OJxoQ/rynoF0dcCdI7cLPktw/hR2cueqYfjm43oqK38= golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= +golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 h1:kQgndtyPBW/JIYERgdxfwMYh3AVStj88WQTlNDi2a+o= +golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180530234432-1e491301e022/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1166,8 +1168,9 @@ golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211102192858-4dd72447c267/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a h1:ppl5mZgokTT8uPkmYOyEUmPTr3ypaKkg5eFOGrAmxxE= golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f h1:rlezHXNlxYWvBCzNses9Dlc7nGFaNMJeqLolcmQSSZY= +golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -1257,8 +1260,9 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.1.8-0.20211102182255-bb4add04ddef h1:/DaKawnTFFxdq/mJT3pM+OkeJlq5gc3ZhkbGVYbqOCw= golang.org/x/tools v0.1.8-0.20211102182255-bb4add04ddef/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= +golang.org/x/tools v0.1.10 h1:QjFRCZxdOhBJ/UNgnBZLbNV13DlbnK0quyivTnXJM20= +golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/cmd/commands/connect/connect.go b/internal/cmd/commands/connect/connect.go index 2d8f2d7837..5def593f07 100644 --- a/internal/cmd/commands/connect/connect.go +++ b/internal/cmd/commands/connect/connect.go @@ -427,6 +427,15 @@ func (c *Command) Run(args []string) (retCode int) { c.connectionsLeft.Store(c.sessionAuthzData.ConnectionLimit) workerAddr := c.sessionAuthzData.GetWorkerInfo()[0].GetAddress() + workerHost, _, err := net.SplitHostPort(workerAddr) + if err != nil { + if strings.Contains(err.Error(), "missing port") { + workerHost = workerAddr + } else { + c.PrintCliError(fmt.Errorf("Error splitting worker adddress host/port: %w", err)) + return base.CommandUserError + } + } parsedCert, err := x509.ParseCertificate(c.sessionAuthzData.Certificate) if err != nil { @@ -434,11 +443,6 @@ func (c *Command) Run(args []string) (retCode int) { return base.CommandUserError } - if len(parsedCert.DNSNames) != 1 { - c.PrintCliError(fmt.Errorf("mTLS certificate has invalid parameters: %w", err)) - return base.CommandUserError - } - c.expiration = parsedCert.NotAfter // We don't _rely_ on client-side timeout verification but this prevents us @@ -459,16 +463,20 @@ func (c *Command) Run(args []string) (retCode int) { }, }, RootCAs: certPool, - ServerName: parsedCert.DNSNames[0], + ServerName: workerHost, MinVersion: tls.VersionTLS13, + NextProtos: []string{"http/1.1", c.sessionAuthzData.SessionId}, } transport := cleanhttp.DefaultTransport() transport.DisableKeepAlives = false - transport.TLSClientConfig = tlsConf // This isn't/shouldn't used anyways really because the connection is // hijacked, just setting for completeness transport.IdleConnTimeout = 0 + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + dialer := &tls.Dialer{Config: tlsConf} + return dialer.DialContext(ctx, network, addr) + } c.listener, err = net.ListenTCP("tcp", &net.TCPAddr{ IP: listenAddr, @@ -691,7 +699,7 @@ func (c *Command) getWsConn( ) (*websocket.Conn, error) { conn, resp, err := websocket.Dial( ctx, - fmt.Sprintf("wss://%s/v1/proxy", workerAddr), + fmt.Sprintf("ws://%s/v1/proxy", workerAddr), &websocket.DialOptions{ HTTPClient: &http.Client{ Transport: transport, diff --git a/internal/servers/controller/handlers/targets/target_service.go b/internal/servers/controller/handlers/targets/target_service.go index 2b1278d034..3a09ed6c1b 100644 --- a/internal/servers/controller/handlers/targets/target_service.go +++ b/internal/servers/controller/handlers/targets/target_service.go @@ -100,7 +100,8 @@ func NewService( sessionRepoFn common.SessionRepoFactory, pluginHostRepoFn common.PluginHostRepoFactory, staticHostRepoFn common.StaticRepoFactory, - vaultCredRepoFn common.VaultCredentialRepoFactory) (Service, error) { + vaultCredRepoFn common.VaultCredentialRepoFactory, +) (Service, error) { const op = "targets.NewService" if repoFn == nil { return Service{}, errors.New(ctx, errors.InvalidParameter, op, "missing target repository") @@ -939,6 +940,10 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession codes.FailedPrecondition, "No workers are available to handle this session, or all have been filtered.") } + workerAddresses := make([]string, 0, len(workers)) + for _, worker := range workers { + workerAddresses = append(workerAddresses, worker.GetAddress()) + } requestedId := req.GetHostId() staticHostRepo, err := s.staticHostRepoFn() @@ -1046,7 +1051,7 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession if err != nil { return nil, err } - sess, privKey, err := sessionRepo.CreateSession(ctx, wrapper, sess) + sess, privKey, err := sessionRepo.CreateSession(ctx, wrapper, sess, workerAddresses) if err != nil { return nil, err } diff --git a/internal/servers/worker/handler.go b/internal/servers/worker/handler.go index ae26217242..48e7ea6de0 100644 --- a/internal/servers/worker/handler.go +++ b/internal/servers/worker/handler.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "strconv" + "strings" "github.com/hashicorp/boundary/globals" pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" @@ -53,11 +54,26 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig) (http.Han return func(wr http.ResponseWriter, r *http.Request) { ctx := r.Context() if r.TLS == nil { - event.WriteError(ctx, op, errors.New("no request TLS information found")) + event.WriteError(ctx, op, errors.New("no request tls information found")) + wr.WriteHeader(http.StatusInternalServerError) + return + } + + var sessionId string + outerCertLoop: + for _, cert := range r.TLS.PeerCertificates { + for _, name := range cert.DNSNames { + if strings.HasPrefix(name, globals.SessionPrefix) { + sessionId = name + break outerCertLoop + } + } + } + if sessionId == "" { + event.WriteError(ctx, op, errors.New("no session id could be found in peer certificates")) wr.WriteHeader(http.StatusInternalServerError) return } - sessionId := r.TLS.ServerName clientIp, clientPort, err := net.SplitHostPort(r.RemoteAddr) if err != nil { @@ -78,7 +94,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig) (http.Han userClientIp, err := common.ClientIpFromRequest(ctx, listenerCfg, r) if err != nil { - event.WriteError(ctx, op, err, event.WithInfoMsg("unable to determine user IP")) + event.WriteError(ctx, op, err, event.WithInfoMsg("unable to determine user ip")) wr.WriteHeader(http.StatusInternalServerError) } diff --git a/internal/servers/worker/worker.go b/internal/servers/worker/worker.go index 332b017aba..4ca8cd7cb4 100644 --- a/internal/servers/worker/worker.go +++ b/internal/servers/worker/worker.go @@ -14,6 +14,7 @@ import ( "sync/atomic" "time" + "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/cmd/config" pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" @@ -267,11 +268,20 @@ func (w *Worker) getSessionTls(hello *tls.ClientHelloInfo) (*tls.Config, error) ctx := w.baseContext var sessionId string switch { - case strings.HasPrefix(hello.ServerName, "s_"): + case strings.HasPrefix(hello.ServerName, globals.SessionPrefix): sessionId = hello.ServerName default: - event.WriteSysEvent(ctx, op, "invalid session in SNI", "session_id", hello.ServerName) - return nil, fmt.Errorf("could not find session ID in SNI") + for _, proto := range hello.SupportedProtos { + if strings.HasPrefix(proto, globals.SessionPrefix) { + sessionId = proto + break + } + } + } + + if sessionId == "" { + event.WriteSysEvent(ctx, op, "session_id not found in either SNI or ALPN protos", "server_name", hello.ServerName, "alpn_protos", hello.SupportedProtos) + return nil, fmt.Errorf("could not find session ID in SNI or ALPN protos") } conn, err := w.ControllerSessionConn() @@ -299,10 +309,6 @@ func (w *Worker) getSessionTls(hello *tls.ClientHelloInfo) (*tls.Config, error) return nil, fmt.Errorf("error parsing session certificate: %w", err) } - if len(parsedCert.DNSNames) != 1 { - return nil, fmt.Errorf("invalid length of DNS names (%d) in parsed certificate", len(parsedCert.DNSNames)) - } - certPool := x509.NewCertPool() certPool.AddCert(parsedCert) @@ -314,7 +320,8 @@ func (w *Worker) getSessionTls(hello *tls.ClientHelloInfo) (*tls.Config, error) Leaf: parsedCert, }, }, - ServerName: parsedCert.DNSNames[0], + NextProtos: []string{"http/1.1"}, + ServerName: sessionId, ClientAuth: tls.RequireAndVerifyClientCert, ClientCAs: certPool, MinVersion: tls.VersionTLS13, diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index 47657cca4d..6e28a42ab6 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -19,7 +19,7 @@ import ( // its State of "Pending". The following fields must be empty when creating a // session: ServerId, ServerType, and PublicId. No options are // currently supported. -func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping.Wrapper, newSession *Session, _ ...Option) (*Session, ed25519.PrivateKey, error) { +func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping.Wrapper, newSession *Session, workerAddresses []string, _ ...Option) (*Session, ed25519.PrivateKey, error) { const op = "session.(Repository).CreateSession" if newSession == nil { return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing session") @@ -63,13 +63,16 @@ func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping. if newSession.ExpirationTime == nil || newSession.ExpirationTime.Timestamp.AsTime().IsZero() { return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing expiration time") } + if len(workerAddresses) == 0 { + return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing addresses") + } id, err := newId() if err != nil { return nil, nil, errors.Wrap(ctx, err, op) } - privKey, certBytes, err := newCert(ctx, sessionWrapper, newSession.UserId, id, newSession.ExpirationTime.Timestamp.AsTime()) + privKey, certBytes, err := newCert(ctx, sessionWrapper, newSession.UserId, id, workerAddresses, newSession.ExpirationTime.Timestamp.AsTime()) if err != nil { return nil, nil, errors.Wrap(ctx, err, op) } diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index e667cd6b55..8f7d1b7384 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -274,8 +274,10 @@ func TestRepository_CreateSession(t *testing.T) { repo, err := NewRepository(rw, rw, kms) require.NoError(t, err) + workerAddresses := []string{"1.2.3.4"} type args struct { - composedOf ComposedOf + composedOf ComposedOf + workerAddresses []string } tests := []struct { name string @@ -286,14 +288,16 @@ func TestRepository_CreateSession(t *testing.T) { { name: "valid", args: args{ - composedOf: TestSessionParams(t, conn, wrapper, iamRepo), + composedOf: TestSessionParams(t, conn, wrapper, iamRepo), + workerAddresses: workerAddresses, }, wantErr: false, }, { name: "valid-with-credentials", args: args{ - composedOf: testSessionCredentialParams(t, conn, wrapper, iamRepo), + composedOf: testSessionCredentialParams(t, conn, wrapper, iamRepo), + workerAddresses: workerAddresses, }, wantErr: false, }, @@ -305,6 +309,7 @@ func TestRepository_CreateSession(t *testing.T) { c.UserId = "" return c }(), + workerAddresses: workerAddresses, }, wantErr: true, wantIsError: errors.InvalidParameter, @@ -317,6 +322,7 @@ func TestRepository_CreateSession(t *testing.T) { c.HostId = "" return c }(), + workerAddresses: workerAddresses, }, wantErr: true, wantIsError: errors.InvalidParameter, @@ -329,6 +335,7 @@ func TestRepository_CreateSession(t *testing.T) { c.TargetId = "" return c }(), + workerAddresses: workerAddresses, }, wantErr: true, wantIsError: errors.InvalidParameter, @@ -341,6 +348,7 @@ func TestRepository_CreateSession(t *testing.T) { c.HostSetId = "" return c }(), + workerAddresses: workerAddresses, }, wantErr: true, wantIsError: errors.InvalidParameter, @@ -353,6 +361,7 @@ func TestRepository_CreateSession(t *testing.T) { c.AuthTokenId = "" return c }(), + workerAddresses: workerAddresses, }, wantErr: true, wantIsError: errors.InvalidParameter, @@ -365,6 +374,18 @@ func TestRepository_CreateSession(t *testing.T) { c.ScopeId = "" return c }(), + workerAddresses: workerAddresses, + }, + wantErr: true, + wantIsError: errors.InvalidParameter, + }, + { + name: "empty-worker-addresses", + args: args{ + composedOf: func() ComposedOf { + c := TestSessionParams(t, conn, wrapper, iamRepo) + return c + }(), }, wantErr: true, wantIsError: errors.InvalidParameter, @@ -385,7 +406,7 @@ func TestRepository_CreateSession(t *testing.T) { ConnectionLimit: tt.args.composedOf.ConnectionLimit, DynamicCredentials: tt.args.composedOf.DynamicCredentials, } - ses, privKey, err := repo.CreateSession(context.Background(), wrapper, s) + ses, privKey, err := repo.CreateSession(context.Background(), wrapper, s, tt.args.workerAddresses) if tt.wantErr { assert.Error(err) assert.Nil(ses) diff --git a/internal/session/session.go b/internal/session/session.go index 6367a8f0e3..baaa5331f0 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -7,6 +7,7 @@ import ( "crypto/x509" "math/big" mathrand "math/rand" + "net" "strings" "time" @@ -351,7 +352,8 @@ func contains(ss []string, t string) bool { return false } -func newCert(ctx context.Context, wrapper wrapping.Wrapper, userId, jobId string, exp time.Time) (ed25519.PrivateKey, []byte, error) { +// newCert creates a new session certificate. If addresses are supplied, they will be parsed and added to IP or DNS SANs as appropriate. +func newCert(ctx context.Context, wrapper wrapping.Wrapper, userId, jobId string, addresses []string, exp time.Time) (ed25519.PrivateKey, []byte, error) { const op = "session.newCert" if wrapper == nil { return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing wrapper") @@ -362,6 +364,9 @@ func newCert(ctx context.Context, wrapper wrapping.Wrapper, userId, jobId string if jobId == "" { return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing job id") } + if len(addresses) == 0 { + return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing addresses") + } pubKey, privKey, err := DeriveED25519Key(ctx, wrapper, userId, jobId) if err != nil { return nil, nil, errors.Wrap(ctx, err, op) @@ -380,6 +385,27 @@ func newCert(ctx context.Context, wrapper wrapping.Wrapper, userId, jobId string IsCA: true, } + for _, addr := range addresses { + // First ensure we aren't looking at ports, regardless of IP or not + host, _, err := net.SplitHostPort(addr) + if err != nil { + if strings.Contains(err.Error(), "missing port") { + host = addr + } else { + return nil, nil, errors.Wrap(ctx, err, op) + } + } + // Now figure out if it's an IP address or not. If ParseIP likes it, add + // to IP SANs. Otherwise DNS SANs. + ip := net.ParseIP(host) + switch ip { + case nil: + template.DNSNames = append(template.DNSNames, host) + default: + template.IPAddresses = append(template.IPAddresses, ip) + } + } + certBytes, err := x509.CreateCertificate(rand.Reader, template, template, pubKey, privKey) if err != nil { return nil, nil, errors.Wrap(ctx, err, op, errors.WithCode(errors.GenCert)) diff --git a/internal/session/session_test.go b/internal/session/session_test.go index 334f48aba4..a5699f14df 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -2,6 +2,7 @@ package session import ( "context" + "crypto/x509" "testing" "time" @@ -23,8 +24,10 @@ func TestSession_Create(t *testing.T) { composedOf := testSessionCredentialParams(t, conn, wrapper, iamRepo) exp := ×tamp.Timestamp{Timestamp: timestamppb.New(time.Now().Add(time.Hour))} + defaultAddresses := []string{"1.2.3.4", "a.b.c.d"} type args struct { composedOf ComposedOf + addresses []string opt []Option } tests := []struct { @@ -32,6 +35,7 @@ func TestSession_Create(t *testing.T) { args args want *Session wantErr bool + wantAddrErr bool wantIsErr errors.Code create bool wantCreateErr bool @@ -41,6 +45,7 @@ func TestSession_Create(t *testing.T) { args: args{ composedOf: composedOf, opt: []Option{WithExpirationTime(exp)}, + addresses: defaultAddresses, }, want: &Session{ UserId: composedOf.UserId, @@ -64,6 +69,7 @@ func TestSession_Create(t *testing.T) { c.UserId = "" return c }(), + addresses: defaultAddresses, }, wantErr: true, wantIsErr: errors.InvalidParameter, @@ -76,6 +82,7 @@ func TestSession_Create(t *testing.T) { c.HostId = "" return c }(), + addresses: defaultAddresses, }, wantErr: true, wantIsErr: errors.InvalidParameter, @@ -88,6 +95,7 @@ func TestSession_Create(t *testing.T) { c.TargetId = "" return c }(), + addresses: defaultAddresses, }, wantErr: true, wantIsErr: errors.InvalidParameter, @@ -100,6 +108,7 @@ func TestSession_Create(t *testing.T) { c.HostSetId = "" return c }(), + addresses: defaultAddresses, }, wantErr: true, wantIsErr: errors.InvalidParameter, @@ -112,6 +121,7 @@ func TestSession_Create(t *testing.T) { c.AuthTokenId = "" return c }(), + addresses: defaultAddresses, }, wantErr: true, wantIsErr: errors.InvalidParameter, @@ -124,10 +134,34 @@ func TestSession_Create(t *testing.T) { c.ScopeId = "" return c }(), + addresses: defaultAddresses, }, wantErr: true, wantIsErr: errors.InvalidParameter, }, + { + name: "empty-addresses", + args: args{ + composedOf: func() ComposedOf { + c := composedOf + return c + }(), + }, + want: &Session{ + UserId: composedOf.UserId, + HostId: composedOf.HostId, + TargetId: composedOf.TargetId, + HostSetId: composedOf.HostSetId, + AuthTokenId: composedOf.AuthTokenId, + ScopeId: composedOf.ScopeId, + Endpoint: "tcp://127.0.0.1:22", + ExpirationTime: composedOf.ExpirationTime, + ConnectionLimit: composedOf.ConnectionLimit, + DynamicCredentials: composedOf.DynamicCredentials, + }, + wantAddrErr: true, + wantIsErr: errors.InvalidParameter, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -145,7 +179,12 @@ func TestSession_Create(t *testing.T) { id, err := db.NewPublicId(SessionPrefix) require.NoError(err) got.PublicId = id - _, certBytes, err := newCert(ctx, wrapper, got.UserId, id, composedOf.ExpirationTime.Timestamp.AsTime()) + _, certBytes, err := newCert(ctx, wrapper, got.UserId, id, tt.args.addresses, composedOf.ExpirationTime.Timestamp.AsTime()) + if tt.wantAddrErr { + require.Error(err) + assert.True(errors.Match(errors.T(tt.wantIsErr), err)) + return + } require.NoError(err) got.Certificate = certBytes err = db.New(conn).Create(ctx, got) @@ -155,6 +194,13 @@ func TestSession_Create(t *testing.T) { } else { assert.NoError(err) } + + if len(tt.args.addresses) > 0 { + cert, err := x509.ParseCertificate(certBytes) + require.NoError(err) + // Session ID is always encoded, hence the +1 + assert.Equal(len(tt.args.addresses)+1, len(cert.DNSNames)+len(cert.IPAddresses)) + } } }) } diff --git a/internal/session/testing.go b/internal/session/testing.go index 7d2197b88b..c306c1d819 100644 --- a/internal/session/testing.go +++ b/internal/session/testing.go @@ -84,7 +84,7 @@ func TestSession(t *testing.T, conn *db.DB, wrapper wrapping.Wrapper, c Composed id, err := newId() require.NoError(err) s.PublicId = id - _, certBytes, err := newCert(ctx, wrapper, c.UserId, id, c.ExpirationTime.Timestamp.AsTime()) + _, certBytes, err := newCert(ctx, wrapper, c.UserId, id, []string{"127.0.0.1", "localhost"}, c.ExpirationTime.Timestamp.AsTime()) require.NoError(err) s.Certificate = certBytes s.ServerId = opts.withServerId @@ -205,5 +205,5 @@ func TestWorker(t *testing.T, conn *db.DB, wrapper wrapping.Wrapper, opt ...Opti // as a parameter. It's currently used in controller.jobTestingHandler() and // should be deprecated once that function is refactored to use sessions properly. func TestCert(wrapper wrapping.Wrapper, userId, jobId string) (ed25519.PrivateKey, []byte, error) { - return newCert(context.Background(), wrapper, userId, jobId, time.Now().Add(5*time.Minute)) + return newCert(context.Background(), wrapper, userId, jobId, []string{"127.0.0.1", "localhost"}, time.Now().Add(5*time.Minute)) }