You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/db/db.go

188 lines
5.0 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package db
import (
"context"
"database/sql"
"fmt"
"sync/atomic"
"time"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/observability/event"
"github.com/hashicorp/go-dbw"
_ "github.com/jackc/pgx/v4"
"gorm.io/driver/postgres"
)
func init() {
dbw.InitNonCreatableFields([]string{"CreateTime", "UpdateTime"})
dbw.InitNonUpdatableFields([]string{"PublicId", "CreateTime", "UpdateTime"})
}
type DbType int
const (
UnknownDB DbType = 0
Postgres DbType = 1
)
func (db DbType) String() string {
return [...]string{
"unknown",
"postgres",
}[db]
}
func StringToDbType(dialect string) (DbType, error) {
switch dialect {
case "postgres":
return Postgres, nil
default:
return UnknownDB, fmt.Errorf("%s is an unknown dialect", dialect)
}
}
// DB is a wrapper around the ORM
type DB struct {
wrapped *atomic.Pointer[dbw.DB]
}
type closeDbFn func(context.Context)
var CloseSwappedDbDuration = 5 * time.Minute
// Swap replaces *DB's underlying database object with the one from `newDB`.
// It returns a function that calls Close() on the outgoing database object.
// Note: Swap does not verify the incoming *DB object is correctly set-up.
func (db *DB) Swap(ctx context.Context, newDB *DB) (closeDbFn, error) {
const op = "db.(DB).Swap"
if newDB == nil || newDB.wrapped == nil || newDB.wrapped.Load() == nil {
return nil, fmt.Errorf("no new db object present")
}
if db == nil || db.wrapped == nil || db.wrapped.Load() == nil {
// TBD: We could be helpful here and set the new DB instead of err?
return nil, fmt.Errorf("no current db is present to swap, aborting")
}
// Grab the old db to allow for cleanup after swap.
oldDbw := db.wrapped.Swap(newDB.wrapped.Load())
closeOldDbFn := func(ctx context.Context) {
go func() {
maxTime := time.Now().Add(CloseSwappedDbDuration)
t := time.NewTicker(time.Second)
var done bool
for {
select {
case <-t.C:
if time.Now().After(maxTime) || done {
t.Stop()
if err := oldDbw.Close(ctx); err != nil {
event.WriteError(ctx, op, errors.Wrap(ctx, err, op))
}
return
}
sqlDb, err := oldDbw.SqlDB(ctx)
if err != nil {
event.WriteError(ctx, op, fmt.Errorf("unable to load old sqldb to check stats"))
continue
}
stats := sqlDb.Stats()
if stats.InUse == 0 {
done = true
}
case <-ctx.Done():
t.Stop()
event.WriteError(ctx, op, fmt.Errorf("context canceled before old database connection was closed, aborting"))
return
}
}
}()
}
return closeOldDbFn, nil
}
// Debug will enable/disable debug info for the connection
func (db *DB) Debug(on bool) {
db.wrapped.Load().Debug(on)
}
// SqlDB returns the underlying sql.DB
//
// Note: This func is not named DB(), because our choice of ORM (gorm) which is
// embedded for various reasons, already exports a DB() func and it's base type
// is also an export DB type
func (d *DB) SqlDB(ctx context.Context) (*sql.DB, error) {
const op = "db.(DB).SqlDB"
if d.wrapped == nil {
return nil, errors.New(ctx, errors.Internal, op, "missing underlying database")
}
return d.wrapped.Load().SqlDB(ctx)
}
// Close the underlying sql.DB
func (d *DB) Close(ctx context.Context) error {
const op = "db.(DB).Close"
if d.wrapped == nil {
return errors.New(ctx, errors.Internal, op, "missing underlying database")
}
return d.wrapped.Load().Close(ctx)
}
// Open a database connection which is long-lived. The options of
// WithGormFormatter and WithMaxOpenConnections are supported.
//
// Note: Consider if you need to call Close() on the returned DB. Typically the
// answer is no, but there are occasions when it's necessary. See the sql.DB
// docs for more information.
func Open(ctx context.Context, dbType DbType, connectionUrl string, opt ...Option) (*DB, error) {
const op = "db.Open"
var dialect dbw.Dialector
switch dbType {
case Postgres:
dialect = postgres.New(postgres.Config{
DSN: connectionUrl,
},
)
default:
return nil, fmt.Errorf("unable to open %s database type", dbType)
}
opts := GetOpts(opt...)
var wrappedOpts []dbw.Option
if opts.withGormFormatter != nil {
wrappedOpts = append(wrappedOpts, dbw.WithLogger(opts.withGormFormatter))
}
if opts.withMaxOpenConnections > 0 {
if opts.withMaxOpenConnections < 5 {
return nil, fmt.Errorf("max_open_connections cannot be below 5")
}
wrappedOpts = append(wrappedOpts, dbw.WithMaxOpenConnections(opts.withMaxOpenConnections))
}
wrapped, err := dbw.OpenWith(dialect, wrappedOpts...)
if err != nil {
return nil, fmt.Errorf("%s: %w", op, err)
}
sdb, err := wrapped.SqlDB(ctx)
if err != nil {
return nil, err
}
if opts.withMaxIdleConnections != nil {
sdb.SetMaxIdleConns(*opts.withMaxIdleConnections)
}
if opts.withConnMaxIdleTimeDuration != nil {
sdb.SetConnMaxIdleTime(*opts.withConnMaxIdleTimeDuration)
}
ret := &DB{wrapped: new(atomic.Pointer[dbw.DB])}
ret.wrapped.Store(wrapped)
return ret, nil
}