Put session ID in ALPN (#1966)

This allows us to use SNI for actual host (e.g. routing) information. We
don't have the client cert yet so we can't look at that, so this
provides a way for us to convey the information needed to look up that
session's TLS stack.

Using SNI for hosts means we also run into the fact that we don't have
automatic agreement in terms of SANs. So when generating the certs we
now also pass worker address information to the function to be encoded
in the cert.

Finally, there is a change in how the websocket dialing happens, because
http.RoundTripper tries to be too clever for its own good and overwrites
NextProtos on a whim.
pull/1968/head
Jeff Mitchell 4 years ago committed by GitHub
parent 62f0d18885
commit 58a448fc6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,6 +5,7 @@ import "time"
const (
TcpProxyV1 = "boundary-tcp-proxy-v1"
ServiceTokenV1 = "s1"
SessionPrefix = "s_"
)
type (

@ -70,10 +70,10 @@ require (
github.com/stretchr/testify v1.7.0
github.com/zalando/go-keyring v0.2.1
go.uber.org/atomic v1.9.0
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000
golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29
golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
golang.org/x/tools v0.1.8-0.20211102182255-bb4add04ddef
golang.org/x/tools v0.1.10
google.golang.org/genproto v0.0.0-20220208230804-65c12eb4c068
google.golang.org/grpc v1.44.0
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0
@ -189,7 +189,7 @@ require (
go.opencensus.io v0.23.0 // indirect
go.uber.org/multierr v1.7.0 // indirect
go.uber.org/zap v1.19.0 // indirect
golang.org/x/mod v0.5.1 // indirect
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd // indirect
golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect

@ -963,8 +963,9 @@ golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000 h1:SL+8VVnkqyshUSz5iNnXtrBQzvFF2SkROm6t5RczFAE=
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 h1:tkVvjkPTB7pnW3jnid7kNyAMPVWllTNOf/qKDze4p9o=
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -1001,8 +1002,9 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.5.1 h1:OJxoQ/rynoF0dcCdI7cLPktw/hR2cueqYfjm43oqK38=
golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro=
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 h1:kQgndtyPBW/JIYERgdxfwMYh3AVStj88WQTlNDi2a+o=
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180530234432-1e491301e022/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@ -1166,8 +1168,9 @@ golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211102192858-4dd72447c267/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a h1:ppl5mZgokTT8uPkmYOyEUmPTr3ypaKkg5eFOGrAmxxE=
golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f h1:rlezHXNlxYWvBCzNses9Dlc7nGFaNMJeqLolcmQSSZY=
golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@ -1257,8 +1260,9 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f
golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0=
golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.8-0.20211102182255-bb4add04ddef h1:/DaKawnTFFxdq/mJT3pM+OkeJlq5gc3ZhkbGVYbqOCw=
golang.org/x/tools v0.1.8-0.20211102182255-bb4add04ddef/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU=
golang.org/x/tools v0.1.10 h1:QjFRCZxdOhBJ/UNgnBZLbNV13DlbnK0quyivTnXJM20=
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

@ -427,6 +427,15 @@ func (c *Command) Run(args []string) (retCode int) {
c.connectionsLeft.Store(c.sessionAuthzData.ConnectionLimit)
workerAddr := c.sessionAuthzData.GetWorkerInfo()[0].GetAddress()
workerHost, _, err := net.SplitHostPort(workerAddr)
if err != nil {
if strings.Contains(err.Error(), "missing port") {
workerHost = workerAddr
} else {
c.PrintCliError(fmt.Errorf("Error splitting worker adddress host/port: %w", err))
return base.CommandUserError
}
}
parsedCert, err := x509.ParseCertificate(c.sessionAuthzData.Certificate)
if err != nil {
@ -434,11 +443,6 @@ func (c *Command) Run(args []string) (retCode int) {
return base.CommandUserError
}
if len(parsedCert.DNSNames) != 1 {
c.PrintCliError(fmt.Errorf("mTLS certificate has invalid parameters: %w", err))
return base.CommandUserError
}
c.expiration = parsedCert.NotAfter
// We don't _rely_ on client-side timeout verification but this prevents us
@ -459,16 +463,20 @@ func (c *Command) Run(args []string) (retCode int) {
},
},
RootCAs: certPool,
ServerName: parsedCert.DNSNames[0],
ServerName: workerHost,
MinVersion: tls.VersionTLS13,
NextProtos: []string{"http/1.1", c.sessionAuthzData.SessionId},
}
transport := cleanhttp.DefaultTransport()
transport.DisableKeepAlives = false
transport.TLSClientConfig = tlsConf
// This isn't/shouldn't used anyways really because the connection is
// hijacked, just setting for completeness
transport.IdleConnTimeout = 0
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &tls.Dialer{Config: tlsConf}
return dialer.DialContext(ctx, network, addr)
}
c.listener, err = net.ListenTCP("tcp", &net.TCPAddr{
IP: listenAddr,
@ -691,7 +699,7 @@ func (c *Command) getWsConn(
) (*websocket.Conn, error) {
conn, resp, err := websocket.Dial(
ctx,
fmt.Sprintf("wss://%s/v1/proxy", workerAddr),
fmt.Sprintf("ws://%s/v1/proxy", workerAddr),
&websocket.DialOptions{
HTTPClient: &http.Client{
Transport: transport,

@ -100,7 +100,8 @@ func NewService(
sessionRepoFn common.SessionRepoFactory,
pluginHostRepoFn common.PluginHostRepoFactory,
staticHostRepoFn common.StaticRepoFactory,
vaultCredRepoFn common.VaultCredentialRepoFactory) (Service, error) {
vaultCredRepoFn common.VaultCredentialRepoFactory,
) (Service, error) {
const op = "targets.NewService"
if repoFn == nil {
return Service{}, errors.New(ctx, errors.InvalidParameter, op, "missing target repository")
@ -939,6 +940,10 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession
codes.FailedPrecondition,
"No workers are available to handle this session, or all have been filtered.")
}
workerAddresses := make([]string, 0, len(workers))
for _, worker := range workers {
workerAddresses = append(workerAddresses, worker.GetAddress())
}
requestedId := req.GetHostId()
staticHostRepo, err := s.staticHostRepoFn()
@ -1046,7 +1051,7 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession
if err != nil {
return nil, err
}
sess, privKey, err := sessionRepo.CreateSession(ctx, wrapper, sess)
sess, privKey, err := sessionRepo.CreateSession(ctx, wrapper, sess, workerAddresses)
if err != nil {
return nil, err
}

@ -9,6 +9,7 @@ import (
"net/http"
"net/url"
"strconv"
"strings"
"github.com/hashicorp/boundary/globals"
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
@ -53,11 +54,26 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig) (http.Han
return func(wr http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if r.TLS == nil {
event.WriteError(ctx, op, errors.New("no request TLS information found"))
event.WriteError(ctx, op, errors.New("no request tls information found"))
wr.WriteHeader(http.StatusInternalServerError)
return
}
var sessionId string
outerCertLoop:
for _, cert := range r.TLS.PeerCertificates {
for _, name := range cert.DNSNames {
if strings.HasPrefix(name, globals.SessionPrefix) {
sessionId = name
break outerCertLoop
}
}
}
if sessionId == "" {
event.WriteError(ctx, op, errors.New("no session id could be found in peer certificates"))
wr.WriteHeader(http.StatusInternalServerError)
return
}
sessionId := r.TLS.ServerName
clientIp, clientPort, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
@ -78,7 +94,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig) (http.Han
userClientIp, err := common.ClientIpFromRequest(ctx, listenerCfg, r)
if err != nil {
event.WriteError(ctx, op, err, event.WithInfoMsg("unable to determine user IP"))
event.WriteError(ctx, op, err, event.WithInfoMsg("unable to determine user ip"))
wr.WriteHeader(http.StatusInternalServerError)
}

@ -14,6 +14,7 @@ import (
"sync/atomic"
"time"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/cmd/config"
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
@ -267,11 +268,20 @@ func (w *Worker) getSessionTls(hello *tls.ClientHelloInfo) (*tls.Config, error)
ctx := w.baseContext
var sessionId string
switch {
case strings.HasPrefix(hello.ServerName, "s_"):
case strings.HasPrefix(hello.ServerName, globals.SessionPrefix):
sessionId = hello.ServerName
default:
event.WriteSysEvent(ctx, op, "invalid session in SNI", "session_id", hello.ServerName)
return nil, fmt.Errorf("could not find session ID in SNI")
for _, proto := range hello.SupportedProtos {
if strings.HasPrefix(proto, globals.SessionPrefix) {
sessionId = proto
break
}
}
}
if sessionId == "" {
event.WriteSysEvent(ctx, op, "session_id not found in either SNI or ALPN protos", "server_name", hello.ServerName, "alpn_protos", hello.SupportedProtos)
return nil, fmt.Errorf("could not find session ID in SNI or ALPN protos")
}
conn, err := w.ControllerSessionConn()
@ -299,10 +309,6 @@ func (w *Worker) getSessionTls(hello *tls.ClientHelloInfo) (*tls.Config, error)
return nil, fmt.Errorf("error parsing session certificate: %w", err)
}
if len(parsedCert.DNSNames) != 1 {
return nil, fmt.Errorf("invalid length of DNS names (%d) in parsed certificate", len(parsedCert.DNSNames))
}
certPool := x509.NewCertPool()
certPool.AddCert(parsedCert)
@ -314,7 +320,8 @@ func (w *Worker) getSessionTls(hello *tls.ClientHelloInfo) (*tls.Config, error)
Leaf: parsedCert,
},
},
ServerName: parsedCert.DNSNames[0],
NextProtos: []string{"http/1.1"},
ServerName: sessionId,
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: certPool,
MinVersion: tls.VersionTLS13,

@ -19,7 +19,7 @@ import (
// its State of "Pending". The following fields must be empty when creating a
// session: ServerId, ServerType, and PublicId. No options are
// currently supported.
func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping.Wrapper, newSession *Session, _ ...Option) (*Session, ed25519.PrivateKey, error) {
func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping.Wrapper, newSession *Session, workerAddresses []string, _ ...Option) (*Session, ed25519.PrivateKey, error) {
const op = "session.(Repository).CreateSession"
if newSession == nil {
return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing session")
@ -63,13 +63,16 @@ func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping.
if newSession.ExpirationTime == nil || newSession.ExpirationTime.Timestamp.AsTime().IsZero() {
return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing expiration time")
}
if len(workerAddresses) == 0 {
return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing addresses")
}
id, err := newId()
if err != nil {
return nil, nil, errors.Wrap(ctx, err, op)
}
privKey, certBytes, err := newCert(ctx, sessionWrapper, newSession.UserId, id, newSession.ExpirationTime.Timestamp.AsTime())
privKey, certBytes, err := newCert(ctx, sessionWrapper, newSession.UserId, id, workerAddresses, newSession.ExpirationTime.Timestamp.AsTime())
if err != nil {
return nil, nil, errors.Wrap(ctx, err, op)
}

@ -274,8 +274,10 @@ func TestRepository_CreateSession(t *testing.T) {
repo, err := NewRepository(rw, rw, kms)
require.NoError(t, err)
workerAddresses := []string{"1.2.3.4"}
type args struct {
composedOf ComposedOf
composedOf ComposedOf
workerAddresses []string
}
tests := []struct {
name string
@ -286,14 +288,16 @@ func TestRepository_CreateSession(t *testing.T) {
{
name: "valid",
args: args{
composedOf: TestSessionParams(t, conn, wrapper, iamRepo),
composedOf: TestSessionParams(t, conn, wrapper, iamRepo),
workerAddresses: workerAddresses,
},
wantErr: false,
},
{
name: "valid-with-credentials",
args: args{
composedOf: testSessionCredentialParams(t, conn, wrapper, iamRepo),
composedOf: testSessionCredentialParams(t, conn, wrapper, iamRepo),
workerAddresses: workerAddresses,
},
wantErr: false,
},
@ -305,6 +309,7 @@ func TestRepository_CreateSession(t *testing.T) {
c.UserId = ""
return c
}(),
workerAddresses: workerAddresses,
},
wantErr: true,
wantIsError: errors.InvalidParameter,
@ -317,6 +322,7 @@ func TestRepository_CreateSession(t *testing.T) {
c.HostId = ""
return c
}(),
workerAddresses: workerAddresses,
},
wantErr: true,
wantIsError: errors.InvalidParameter,
@ -329,6 +335,7 @@ func TestRepository_CreateSession(t *testing.T) {
c.TargetId = ""
return c
}(),
workerAddresses: workerAddresses,
},
wantErr: true,
wantIsError: errors.InvalidParameter,
@ -341,6 +348,7 @@ func TestRepository_CreateSession(t *testing.T) {
c.HostSetId = ""
return c
}(),
workerAddresses: workerAddresses,
},
wantErr: true,
wantIsError: errors.InvalidParameter,
@ -353,6 +361,7 @@ func TestRepository_CreateSession(t *testing.T) {
c.AuthTokenId = ""
return c
}(),
workerAddresses: workerAddresses,
},
wantErr: true,
wantIsError: errors.InvalidParameter,
@ -365,6 +374,18 @@ func TestRepository_CreateSession(t *testing.T) {
c.ScopeId = ""
return c
}(),
workerAddresses: workerAddresses,
},
wantErr: true,
wantIsError: errors.InvalidParameter,
},
{
name: "empty-worker-addresses",
args: args{
composedOf: func() ComposedOf {
c := TestSessionParams(t, conn, wrapper, iamRepo)
return c
}(),
},
wantErr: true,
wantIsError: errors.InvalidParameter,
@ -385,7 +406,7 @@ func TestRepository_CreateSession(t *testing.T) {
ConnectionLimit: tt.args.composedOf.ConnectionLimit,
DynamicCredentials: tt.args.composedOf.DynamicCredentials,
}
ses, privKey, err := repo.CreateSession(context.Background(), wrapper, s)
ses, privKey, err := repo.CreateSession(context.Background(), wrapper, s, tt.args.workerAddresses)
if tt.wantErr {
assert.Error(err)
assert.Nil(ses)

@ -7,6 +7,7 @@ import (
"crypto/x509"
"math/big"
mathrand "math/rand"
"net"
"strings"
"time"
@ -351,7 +352,8 @@ func contains(ss []string, t string) bool {
return false
}
func newCert(ctx context.Context, wrapper wrapping.Wrapper, userId, jobId string, exp time.Time) (ed25519.PrivateKey, []byte, error) {
// newCert creates a new session certificate. If addresses are supplied, they will be parsed and added to IP or DNS SANs as appropriate.
func newCert(ctx context.Context, wrapper wrapping.Wrapper, userId, jobId string, addresses []string, exp time.Time) (ed25519.PrivateKey, []byte, error) {
const op = "session.newCert"
if wrapper == nil {
return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing wrapper")
@ -362,6 +364,9 @@ func newCert(ctx context.Context, wrapper wrapping.Wrapper, userId, jobId string
if jobId == "" {
return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing job id")
}
if len(addresses) == 0 {
return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing addresses")
}
pubKey, privKey, err := DeriveED25519Key(ctx, wrapper, userId, jobId)
if err != nil {
return nil, nil, errors.Wrap(ctx, err, op)
@ -380,6 +385,27 @@ func newCert(ctx context.Context, wrapper wrapping.Wrapper, userId, jobId string
IsCA: true,
}
for _, addr := range addresses {
// First ensure we aren't looking at ports, regardless of IP or not
host, _, err := net.SplitHostPort(addr)
if err != nil {
if strings.Contains(err.Error(), "missing port") {
host = addr
} else {
return nil, nil, errors.Wrap(ctx, err, op)
}
}
// Now figure out if it's an IP address or not. If ParseIP likes it, add
// to IP SANs. Otherwise DNS SANs.
ip := net.ParseIP(host)
switch ip {
case nil:
template.DNSNames = append(template.DNSNames, host)
default:
template.IPAddresses = append(template.IPAddresses, ip)
}
}
certBytes, err := x509.CreateCertificate(rand.Reader, template, template, pubKey, privKey)
if err != nil {
return nil, nil, errors.Wrap(ctx, err, op, errors.WithCode(errors.GenCert))

@ -2,6 +2,7 @@ package session
import (
"context"
"crypto/x509"
"testing"
"time"
@ -23,8 +24,10 @@ func TestSession_Create(t *testing.T) {
composedOf := testSessionCredentialParams(t, conn, wrapper, iamRepo)
exp := &timestamp.Timestamp{Timestamp: timestamppb.New(time.Now().Add(time.Hour))}
defaultAddresses := []string{"1.2.3.4", "a.b.c.d"}
type args struct {
composedOf ComposedOf
addresses []string
opt []Option
}
tests := []struct {
@ -32,6 +35,7 @@ func TestSession_Create(t *testing.T) {
args args
want *Session
wantErr bool
wantAddrErr bool
wantIsErr errors.Code
create bool
wantCreateErr bool
@ -41,6 +45,7 @@ func TestSession_Create(t *testing.T) {
args: args{
composedOf: composedOf,
opt: []Option{WithExpirationTime(exp)},
addresses: defaultAddresses,
},
want: &Session{
UserId: composedOf.UserId,
@ -64,6 +69,7 @@ func TestSession_Create(t *testing.T) {
c.UserId = ""
return c
}(),
addresses: defaultAddresses,
},
wantErr: true,
wantIsErr: errors.InvalidParameter,
@ -76,6 +82,7 @@ func TestSession_Create(t *testing.T) {
c.HostId = ""
return c
}(),
addresses: defaultAddresses,
},
wantErr: true,
wantIsErr: errors.InvalidParameter,
@ -88,6 +95,7 @@ func TestSession_Create(t *testing.T) {
c.TargetId = ""
return c
}(),
addresses: defaultAddresses,
},
wantErr: true,
wantIsErr: errors.InvalidParameter,
@ -100,6 +108,7 @@ func TestSession_Create(t *testing.T) {
c.HostSetId = ""
return c
}(),
addresses: defaultAddresses,
},
wantErr: true,
wantIsErr: errors.InvalidParameter,
@ -112,6 +121,7 @@ func TestSession_Create(t *testing.T) {
c.AuthTokenId = ""
return c
}(),
addresses: defaultAddresses,
},
wantErr: true,
wantIsErr: errors.InvalidParameter,
@ -124,10 +134,34 @@ func TestSession_Create(t *testing.T) {
c.ScopeId = ""
return c
}(),
addresses: defaultAddresses,
},
wantErr: true,
wantIsErr: errors.InvalidParameter,
},
{
name: "empty-addresses",
args: args{
composedOf: func() ComposedOf {
c := composedOf
return c
}(),
},
want: &Session{
UserId: composedOf.UserId,
HostId: composedOf.HostId,
TargetId: composedOf.TargetId,
HostSetId: composedOf.HostSetId,
AuthTokenId: composedOf.AuthTokenId,
ScopeId: composedOf.ScopeId,
Endpoint: "tcp://127.0.0.1:22",
ExpirationTime: composedOf.ExpirationTime,
ConnectionLimit: composedOf.ConnectionLimit,
DynamicCredentials: composedOf.DynamicCredentials,
},
wantAddrErr: true,
wantIsErr: errors.InvalidParameter,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -145,7 +179,12 @@ func TestSession_Create(t *testing.T) {
id, err := db.NewPublicId(SessionPrefix)
require.NoError(err)
got.PublicId = id
_, certBytes, err := newCert(ctx, wrapper, got.UserId, id, composedOf.ExpirationTime.Timestamp.AsTime())
_, certBytes, err := newCert(ctx, wrapper, got.UserId, id, tt.args.addresses, composedOf.ExpirationTime.Timestamp.AsTime())
if tt.wantAddrErr {
require.Error(err)
assert.True(errors.Match(errors.T(tt.wantIsErr), err))
return
}
require.NoError(err)
got.Certificate = certBytes
err = db.New(conn).Create(ctx, got)
@ -155,6 +194,13 @@ func TestSession_Create(t *testing.T) {
} else {
assert.NoError(err)
}
if len(tt.args.addresses) > 0 {
cert, err := x509.ParseCertificate(certBytes)
require.NoError(err)
// Session ID is always encoded, hence the +1
assert.Equal(len(tt.args.addresses)+1, len(cert.DNSNames)+len(cert.IPAddresses))
}
}
})
}

@ -84,7 +84,7 @@ func TestSession(t *testing.T, conn *db.DB, wrapper wrapping.Wrapper, c Composed
id, err := newId()
require.NoError(err)
s.PublicId = id
_, certBytes, err := newCert(ctx, wrapper, c.UserId, id, c.ExpirationTime.Timestamp.AsTime())
_, certBytes, err := newCert(ctx, wrapper, c.UserId, id, []string{"127.0.0.1", "localhost"}, c.ExpirationTime.Timestamp.AsTime())
require.NoError(err)
s.Certificate = certBytes
s.ServerId = opts.withServerId
@ -205,5 +205,5 @@ func TestWorker(t *testing.T, conn *db.DB, wrapper wrapping.Wrapper, opt ...Opti
// as a parameter. It's currently used in controller.jobTestingHandler() and
// should be deprecated once that function is refactored to use sessions properly.
func TestCert(wrapper wrapping.Wrapper, userId, jobId string) (ed25519.PrivateKey, []byte, error) {
return newCert(context.Background(), wrapper, userId, jobId, time.Now().Add(5*time.Minute))
return newCert(context.Background(), wrapper, userId, jobId, []string{"127.0.0.1", "localhost"}, time.Now().Add(5*time.Minute))
}

Loading…
Cancel
Save