package session import ( "context" "errors" "testing" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/iam" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestConnectionState_Create(t *testing.T) { t.Parallel() conn, _ := db.TestSetup(t, "postgres") wrapper := db.TestWrapper(t) iamRepo := iam.TestRepo(t, conn, wrapper) session := TestDefaultSession(t, conn, wrapper, iamRepo) connection := TestConnection(t, conn, session.PublicId, "127.0.0.1", 443, "127.0.0.1", 4443) type args struct { connectionId string status ConnectionStatus } tests := []struct { name string args args want *ConnectionState wantErr bool wantIsErr error create bool wantCreateErr bool }{ { name: "valid", args: args{ connectionId: connection.PublicId, status: StatusClosed, }, want: &ConnectionState{ ConnectionId: connection.PublicId, Status: StatusClosed, }, create: true, }, { name: "empty-connectionId", args: args{ status: StatusClosed, }, wantErr: true, wantIsErr: db.ErrInvalidParameter, }, { name: "empty-status", args: args{ connectionId: connection.PublicId, }, wantErr: true, wantIsErr: db.ErrInvalidParameter, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assert, require := assert.New(t), require.New(t) got, err := NewConnectionState(tt.args.connectionId, tt.args.status) if tt.wantErr { require.Error(err) assert.True(errors.Is(err, tt.wantIsErr)) return } require.NoError(err) assert.Equal(tt.want, got) if tt.create { err = db.New(conn).Create(context.Background(), got) if tt.wantCreateErr { assert.Error(err) return } else { assert.NoError(err) } } }) } } func TestConnectionState_Delete(t *testing.T) { t.Parallel() conn, _ := db.TestSetup(t, "postgres") rw := db.New(conn) wrapper := db.TestWrapper(t) iamRepo := iam.TestRepo(t, conn, wrapper) s := TestDefaultSession(t, conn, wrapper, iamRepo) c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222) c2 := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222) tests := []struct { name string state *ConnectionState deleteConnectionStateId string wantRowsDeleted int wantErr bool wantErrMsg string }{ { name: "valid", state: TestConnectionState(t, conn, c.PublicId, StatusClosed), wantErr: false, wantRowsDeleted: 1, }, { name: "bad-id", state: TestConnectionState(t, conn, c2.PublicId, StatusClosed), deleteConnectionStateId: func() string { id, err := db.NewPublicId(ConnectionStatePrefix) require.NoError(t, err) return id }(), wantErr: false, wantRowsDeleted: 0, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assert, require := assert.New(t), require.New(t) var initialState ConnectionState err := rw.LookupWhere(context.Background(), &initialState, "connection_id = ? and state = ?", tt.state.ConnectionId, tt.state.Status) require.NoError(err) deleteState := allocConnectionState() if tt.deleteConnectionStateId != "" { deleteState.ConnectionId = tt.deleteConnectionStateId } else { deleteState.ConnectionId = tt.state.ConnectionId } deleteState.StartTime = initialState.StartTime deletedRows, err := rw.Delete(context.Background(), &deleteState) if tt.wantErr { require.Error(err) return } require.NoError(err) if tt.wantRowsDeleted == 0 { assert.Equal(tt.wantRowsDeleted, deletedRows) return } assert.Equal(tt.wantRowsDeleted, deletedRows) foundState := allocConnectionState() err = rw.LookupWhere(context.Background(), &foundState, "connection_id = ? and start_time = ?", tt.state.ConnectionId, initialState.StartTime) require.Error(err) assert.True(errors.Is(db.ErrRecordNotFound, err)) }) } } func TestConnectionState_Clone(t *testing.T) { t.Parallel() conn, _ := db.TestSetup(t, "postgres") wrapper := db.TestWrapper(t) iamRepo := iam.TestRepo(t, conn, wrapper) t.Run("valid", func(t *testing.T) { assert := assert.New(t) s := TestDefaultSession(t, conn, wrapper, iamRepo) c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222) state := TestConnectionState(t, conn, c.PublicId, StatusConnected) cp := state.Clone() assert.Equal(cp.(*ConnectionState), state) }) t.Run("not-equal", func(t *testing.T) { assert := assert.New(t) s := TestDefaultSession(t, conn, wrapper, iamRepo) c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222) state := TestConnectionState(t, conn, c.PublicId, StatusConnected) state2 := TestConnectionState(t, conn, c.PublicId, StatusConnected) cp := state.Clone() assert.NotEqual(cp.(*ConnectionState), state2) }) } func TestConnectionState_SetTableName(t *testing.T) { t.Parallel() defaultTableName := defaultConnectionStateTableName tests := []struct { name string setNameTo string want string }{ { name: "new-name", setNameTo: "new-name", want: "new-name", }, { name: "reset to default", setNameTo: "", want: defaultTableName, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assert, require := assert.New(t), require.New(t) def := allocConnectionState() require.Equal(defaultTableName, def.TableName()) s := allocConnectionState() s.SetTableName(tt.setNameTo) assert.Equal(tt.want, s.TableName()) }) } }