@ -9,6 +9,9 @@ import (
"crypto/rand"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os"
"sync"
"testing"
"time"
@ -33,11 +36,34 @@ import (
"github.com/hashicorp/nodeenrollment/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
)
func TestWorkerNew ( t * testing . T ) {
knownHostsPath := t . TempDir ( ) + "/known_hosts"
nonexistantKnownHostsPath := t . TempDir ( ) + "/does_not_exist"
corruptedKnownHostsPath := t . TempDir ( ) + "/corrupted_known_hosts"
file , err := os . Create ( knownHostsPath )
require . NoError ( t , err )
defer file . Close ( )
signer , err := ssh . NewSignerFromKey ( ed25519 . NewKeyFromSeed ( [ ] byte ( "foobfoobfoobfoobfoobfoobfoobfoob" ) ) )
require . NoError ( t , err )
line := fmt . Sprintf ( "::1 %s" , string ( ssh . MarshalAuthorizedKey ( signer . PublicKey ( ) ) ) )
_ , err = file . WriteString ( line )
require . NoError ( t , err )
corruptFile , err := os . Create ( corruptedKnownHostsPath )
require . NoError ( t , err )
defer corruptFile . Close ( )
_ , err = corruptFile . WriteString ( "this is not valid known hosts content" )
require . NoError ( t , err )
tests := [ ] struct {
name string
in * Config
@ -190,6 +216,87 @@ func TestWorkerNew(t *testing.T) {
assert . Equal ( t , wpbs . UnimplementedHostServiceServer { } , w . HostServiceServer )
} ,
} ,
{
name : "valid with no known hosts path" ,
in : & Config {
Server : & base . Server {
Listeners : [ ] * base . ServerListener {
{ Config : & listenerutil . ListenerConfig { Purpose : [ ] string { "proxy" } } } ,
} ,
} ,
RawConfig : & config . Config {
SharedConfig : & configutil . SharedConfig {
DisableMlock : true ,
} ,
} ,
} ,
expErr : false ,
assertions : func ( t * testing . T , w * Worker ) {
assert . Nil ( t , w . SshKnownHostsCallback . Load ( ) )
} ,
} ,
{
name : "valid known hosts path" ,
in : & Config {
Server : & base . Server {
Listeners : [ ] * base . ServerListener {
{ Config : & listenerutil . ListenerConfig { Purpose : [ ] string { "proxy" } } } ,
} ,
} ,
RawConfig : & config . Config {
Worker : & config . Worker {
SshKnownHostsPath : knownHostsPath ,
} ,
SharedConfig : & configutil . SharedConfig {
DisableMlock : true ,
} ,
} ,
} ,
expErr : false ,
assertions : func ( t * testing . T , w * Worker ) {
assert . NotNil ( t , w . SshKnownHostsCallback . Load ( ) )
} ,
} ,
{
name : "invalid known hosts path" ,
in : & Config {
Server : & base . Server {
Listeners : [ ] * base . ServerListener {
{ Config : & listenerutil . ListenerConfig { Purpose : [ ] string { "proxy" } } } ,
} ,
} ,
RawConfig : & config . Config {
Worker : & config . Worker {
SshKnownHostsPath : nonexistantKnownHostsPath ,
} ,
SharedConfig : & configutil . SharedConfig {
DisableMlock : true ,
} ,
} ,
} ,
expErr : true ,
expErrMsg : "no such file or directory" ,
} ,
{
name : "corrupted known hosts file" ,
in : & Config {
Server : & base . Server {
Listeners : [ ] * base . ServerListener {
{ Config : & listenerutil . ListenerConfig { Purpose : [ ] string { "proxy" } } } ,
} ,
} ,
RawConfig : & config . Config {
Worker : & config . Worker {
SshKnownHostsPath : corruptedKnownHostsPath ,
} ,
SharedConfig : & configutil . SharedConfig {
DisableMlock : true ,
} ,
} ,
} ,
expErr : true ,
expErrMsg : "illegal base64 data at input byte" ,
} ,
}
for _ , tt := range tests {
@ -211,7 +318,7 @@ func TestWorkerNew(t *testing.T) {
w , err := New ( context . Background ( ) , tt . in )
if tt . expErr {
require . E qualE rror( t , err , tt . expErrMsg )
require . E rrorContains ( t , err , tt . expErrMsg )
require . Nil ( t , w )
return
}
@ -225,6 +332,19 @@ func TestWorkerNew(t *testing.T) {
}
func TestWorkerReload ( t * testing . T ) {
knownHostsPath := t . TempDir ( ) + "/known_hosts"
file , err := os . Create ( knownHostsPath )
require . NoError ( t , err )
defer file . Close ( )
signer , err := ssh . NewSignerFromKey ( ed25519 . NewKeyFromSeed ( [ ] byte ( "foobfoobfoobfoobfoobfoobfoobfoob" ) ) )
require . NoError ( t , err )
dummyAddr := & net . TCPAddr { IP : net . ParseIP ( "127.0.0.1" ) , Port : 22 }
line := fmt . Sprintf ( "github.com %s" , string ( ssh . MarshalAuthorizedKey ( signer . PublicKey ( ) ) ) )
_ , err = file . WriteString ( line )
require . NoError ( t , err )
t . Run ( "default config is the same as the reload config" , func ( t * testing . T ) {
require , assert := require . New ( t ) , assert . New ( t )
cfg := & Config {
@ -254,6 +374,7 @@ func TestWorkerReload(t *testing.T) {
assert . Equal ( int64 ( common . DefaultSessionInfoTimeout ) , w . sessionInfoCallTimeoutDuration . Load ( ) )
assert . Equal ( int64 ( server . DefaultLiveness ) , w . getDownstreamWorkersTimeoutDuration . Load ( ) )
assert . Nil ( w . SshKnownHostsCallback . Load ( ) )
w . Reload ( context . Background ( ) , cfg . RawConfig )
@ -266,6 +387,7 @@ func TestWorkerReload(t *testing.T) {
assert . Equal ( int64 ( common . DefaultSessionInfoTimeout ) , w . sessionInfoCallTimeoutDuration . Load ( ) )
assert . Equal ( int64 ( server . DefaultLiveness ) , w . getDownstreamWorkersTimeoutDuration . Load ( ) )
assert . Nil ( w . SshKnownHostsCallback . Load ( ) )
} )
t . Run ( "new config is the same as the reload config" , func ( t * testing . T ) {
@ -286,6 +408,7 @@ func TestWorkerReload(t *testing.T) {
SuccessfulControllerRPCGracePeriodDuration : 5 * time . Second ,
ControllerRPCCallTimeoutDuration : 10 * time . Second ,
GetDownstreamWorkersTimeoutDuration : 20 * time . Second ,
SshKnownHostsPath : knownHostsPath ,
} ,
} ,
}
@ -301,6 +424,10 @@ func TestWorkerReload(t *testing.T) {
assert . Equal ( int64 ( 10 * time . Second ) , w . sessionInfoCallTimeoutDuration . Load ( ) )
assert . Equal ( int64 ( 20 * time . Second ) , w . getDownstreamWorkersTimeoutDuration . Load ( ) )
cb := w . SshKnownHostsCallback . Load ( )
require . NotNil ( cb )
err = ( * cb ) ( "github.com:22" , dummyAddr , signer . PublicKey ( ) )
assert . NoError ( err )
w . Reload ( context . Background ( ) , cfg . RawConfig )
@ -313,6 +440,89 @@ func TestWorkerReload(t *testing.T) {
assert . Equal ( int64 ( 10 * time . Second ) , w . sessionInfoCallTimeoutDuration . Load ( ) )
assert . Equal ( int64 ( 20 * time . Second ) , w . getDownstreamWorkersTimeoutDuration . Load ( ) )
cb = w . SshKnownHostsCallback . Load ( )
require . NotNil ( cb )
err = ( * cb ) ( "github.com:22" , dummyAddr , signer . PublicKey ( ) )
assert . NoError ( err )
} )
t . Run ( "new config is different" , func ( t * testing . T ) {
require , assert := require . New ( t ) , assert . New ( t )
cfg := & Config {
Server : & base . Server {
Logger : hclog . Default ( ) ,
Eventer : & event . Eventer { } ,
Listeners : [ ] * base . ServerListener {
{ Config : & listenerutil . ListenerConfig { Purpose : [ ] string { "api" } } } ,
{ Config : & listenerutil . ListenerConfig { Purpose : [ ] string { "proxy" } } } ,
{ Config : & listenerutil . ListenerConfig { Purpose : [ ] string { "cluster" } } } ,
} ,
} ,
RawConfig : & config . Config {
SharedConfig : & configutil . SharedConfig { DisableMlock : true } ,
Worker : & config . Worker {
SuccessfulControllerRPCGracePeriodDuration : 5 * time . Second ,
ControllerRPCCallTimeoutDuration : 10 * time . Second ,
GetDownstreamWorkersTimeoutDuration : 20 * time . Second ,
SshKnownHostsPath : knownHostsPath ,
} ,
} ,
}
w , err := New ( context . Background ( ) , cfg )
require . NoError ( err )
assert . Equal ( int64 ( 5 * time . Second ) , w . successfulRoutingInfoGracePeriod . Load ( ) )
assert . Equal ( int64 ( 5 * time . Second ) , w . successfulSessionInfoGracePeriod . Load ( ) )
assert . Equal ( w . successfulRoutingInfoGracePeriod . Load ( ) , session . CloseCallTimeout . Load ( ) )
assert . Equal ( int64 ( 10 * time . Second ) , w . routingInfoCallTimeoutDuration . Load ( ) )
assert . Equal ( int64 ( 10 * time . Second ) , w . statisticsCallTimeoutDuration . Load ( ) )
assert . Equal ( int64 ( 10 * time . Second ) , w . sessionInfoCallTimeoutDuration . Load ( ) )
assert . Equal ( int64 ( 20 * time . Second ) , w . getDownstreamWorkersTimeoutDuration . Load ( ) )
cb := w . SshKnownHostsCallback . Load ( )
require . NotNil ( cb )
err = ( * cb ) ( "github.com:22" , dummyAddr , signer . PublicKey ( ) )
assert . NoError ( err )
// Update the config with new values
newKnownHostsFile := t . TempDir ( ) + "/new_known_hosts"
newFile , err := os . Create ( newKnownHostsFile )
require . NoError ( err )
defer newFile . Close ( )
newSigner , err := ssh . NewSignerFromKey ( ed25519 . NewKeyFromSeed ( [ ] byte ( "noobnoobnoobnoobnoobnoobnoobnoob" ) ) )
require . NoError ( err )
line := fmt . Sprintf ( "github.com %s" , string ( ssh . MarshalAuthorizedKey ( newSigner . PublicKey ( ) ) ) )
_ , err = newFile . WriteString ( line )
require . NoError ( err )
cfg . RawConfig . Worker . SuccessfulControllerRPCGracePeriodDuration = 30 * time . Second
cfg . RawConfig . Worker . ControllerRPCCallTimeoutDuration = 35 * time . Second
cfg . RawConfig . Worker . GetDownstreamWorkersTimeoutDuration = 40 * time . Second
cfg . RawConfig . Worker . SshKnownHostsPath = newKnownHostsFile
w . Reload ( context . Background ( ) , cfg . RawConfig )
assert . Equal ( int64 ( 30 * time . Second ) , w . successfulRoutingInfoGracePeriod . Load ( ) )
assert . Equal ( int64 ( 30 * time . Second ) , w . successfulSessionInfoGracePeriod . Load ( ) )
assert . Equal ( w . successfulRoutingInfoGracePeriod . Load ( ) , session . CloseCallTimeout . Load ( ) )
assert . Equal ( int64 ( 35 * time . Second ) , w . routingInfoCallTimeoutDuration . Load ( ) )
assert . Equal ( int64 ( 35 * time . Second ) , w . statisticsCallTimeoutDuration . Load ( ) )
assert . Equal ( int64 ( 35 * time . Second ) , w . sessionInfoCallTimeoutDuration . Load ( ) )
assert . Equal ( int64 ( 40 * time . Second ) , w . getDownstreamWorkersTimeoutDuration . Load ( ) )
cb = w . SshKnownHostsCallback . Load ( )
require . NotNil ( cb )
// Old signer should fail
err = ( * cb ) ( "github.com:22" , dummyAddr , signer . PublicKey ( ) )
assert . Error ( err )
// New signer should work
err = ( * cb ) ( "github.com:22" , dummyAddr , newSigner . PublicKey ( ) )
assert . NoError ( err )
} )
}