You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/db/schema/postgres/postgres.go

477 lines
15 KiB

// The MIT License (MIT)
//
// Original Work
// Copyright (c) 2016 Matthias Kadenbach
// https://github.com/mattes/migrate
//
// Modified Work
// Copyright (c) 2018 Dale Hui
// https://github.com/golang-migrate/migrate
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package postgres
import (
"context"
"database/sql"
"fmt"
"io"
"io/ioutil"
"strconv"
"strings"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/go-multierror"
"github.com/lib/pq"
)
// schemaAccessLockId is a Lock key used to ensure a single boundary binary is operating
// on a postgres server at a time. The value has no meaning and was picked randomly.
const (
schemaAccessLockId int64 = 3865661975
nilVersion = -1
)
var defaultMigrationsTable = "boundary_schema_version"
// Postgres is a driver usable by a boundary schema manager.
// This struct is not thread safe.
type Postgres struct {
// Locking and unlocking need to use the same connection
conn *sql.Conn
db *sql.DB
tx *sql.Tx
}
// New returns a postgres pointer with the provided db verified as
// connectable and a version table being initialized.
func New(ctx context.Context, instance *sql.DB) (*Postgres, error) {
const op = "postgres.New"
if err := instance.PingContext(ctx); err != nil {
return nil, errors.Wrap(ctx, err, op)
}
conn, err := instance.Conn(ctx)
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
px := &Postgres{
conn: conn,
db: instance,
}
return px, nil
}
// TrySharedLock attempts to capture a shared lock. If it is not successful it returns an error.
// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
func (p *Postgres) TrySharedLock(ctx context.Context) error {
const op = "postgres.(Postgres).TrySharedLock"
const query = "select pg_try_advisory_lock_shared($1)"
r := p.conn.QueryRowContext(ctx, query, schemaAccessLockId)
if r.Err() != nil {
return errors.Wrap(ctx, r.Err(), op)
}
var gotLock bool
if err := r.Scan(&gotLock); err != nil {
return errors.Wrap(ctx, err, op)
}
if !gotLock {
return errors.New(ctx, errors.MigrationLock, op, "Lock failed")
}
return nil
}
// TryLock attempts to capture an exclusive lock. If it is not successful it returns an error.
// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
func (p *Postgres) TryLock(ctx context.Context) error {
const op = "postgres.(Postgres).TryLock"
const query = "select pg_try_advisory_lock($1)"
r := p.conn.QueryRowContext(ctx, query, schemaAccessLockId)
if r.Err() != nil {
return errors.Wrap(ctx, r.Err(), op)
}
var gotLock bool
if err := r.Scan(&gotLock); err != nil {
return errors.Wrap(ctx, err, op)
}
if !gotLock {
return errors.New(ctx, errors.MigrationLock, op, "Lock failed")
}
return nil
}
// Lock calls pg_advisory_lock with the provided context and returns an error
// if we were unable to get the lock before the context cancels.
func (p *Postgres) Lock(ctx context.Context) error {
const op = "postgres.(Postgres).Lock"
const query = "select pg_advisory_lock($1)"
if _, err := p.conn.ExecContext(ctx, query, schemaAccessLockId); err != nil {
return errors.Wrap(ctx, err, op)
}
return nil
}
// Unlock calls pg_advisory_unlock and returns an error if we were unable to
// release the lock before the context cancels.
func (p *Postgres) Unlock(ctx context.Context) error {
const op = "postgres.(Postgres).Unlock"
const query = `select pg_advisory_unlock($1)`
if _, err := p.conn.ExecContext(ctx, query, schemaAccessLockId); err != nil {
return errors.Wrap(ctx, err, op)
}
return nil
}
// UnlockShared calls pg_advisory_unlock_shared and returns an error if we were unable to
// release the lock before the context cancels.
func (p *Postgres) UnlockShared(ctx context.Context) error {
const op = "postgres.(Postgres).UnlockShared"
query := `select pg_advisory_unlock_shared($1)`
if _, err := p.conn.ExecContext(ctx, query, schemaAccessLockId); err != nil {
return errors.Wrap(ctx, err, op)
}
return nil
}
// Rollback rolls back the outstanding transaction.
// Calling Rollback when there is not an outstanding transaction is an error.
func (p *Postgres) Rollback() error {
const op = "postgres.(Postgres).Rollback"
defer func() {
p.tx = nil
}()
if p.tx == nil {
return errors.NewDeprecated(errors.MigrationIntegrity, op, "no pending transaction")
}
if err := p.tx.Rollback(); err != nil {
return errors.WrapDeprecated(err, op)
}
return nil
}
// StartRun starts a transaction that all subsequent calls to Run will use.
func (p *Postgres) StartRun(ctx context.Context) error {
tx, err := p.conn.BeginTx(ctx, nil)
if err != nil {
return err
}
p.tx = tx
return nil
}
// CommitRun commits the pending transaction if there is one
func (p *Postgres) CommitRun() error {
const op = "postgres.(Postgres).CommitRun"
defer func() {
p.tx = nil
}()
if p.tx == nil {
return errors.NewDeprecated(errors.MigrationIntegrity, op, "no pending transaction")
}
if err := p.tx.Commit(); err != nil {
if errRollback := p.tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return errors.WrapDeprecated(err, op)
}
return nil
}
type execContexter interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
// Run executes the sql provided in the passed in io.Reader. The contents of the reader must
// fit in memory as the full content is read into a string before being passed to the
// backing database. EnsureVersionTable should be ran prior to this call.
func (p *Postgres) Run(ctx context.Context, migration io.Reader, version int) error {
const op = "postgres.(Postgres).Run"
migr, err := ioutil.ReadAll(migration)
if err != nil {
return errors.Wrap(ctx, err, op)
}
// Run migration
query := string(migr)
var extr execContexter = p.conn
rollback := func() error { return nil }
if p.tx != nil {
extr = p.tx
rollback = func() error {
defer func() { p.tx = nil }()
return p.tx.Rollback()
}
}
// set the version first, so logs will be associated with this new version.
// if there's an error, it will get rollback
if err := p.setVersion(ctx, version, false); err != nil {
return errors.Wrap(ctx, err, op)
}
if _, err := extr.ExecContext(ctx, query); err != nil {
if rollbackErr := rollback(); rollbackErr != nil {
err = multierror.Append(err, rollbackErr)
}
if pgErr, ok := err.(*pq.Error); ok {
var line uint
var col uint
var lineColOK bool
if pgErr.Position != "" {
if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
line, col, lineColOK = computeLineFromPos(query, int(pos))
}
}
message := "migration failed"
if lineColOK {
message = fmt.Sprintf("%s (column %d)", message, col)
}
if pgErr.Detail != "" {
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
}
message = fmt.Sprintf("%s, on line %v: %s", message, line, migr)
return errors.Wrap(ctx, err, op, errors.WithMsg(message))
}
return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("migration failed: %s", migr)))
}
return nil
}
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
// replace crlf with lf
s = strings.Replace(s, "\r\n", "\n", -1)
// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
runes := []rune(s)
if pos > len(runes) {
return 0, 0, false
}
sel := runes[:pos]
line = uint(runesCount(sel, newLine) + 1)
col = uint(pos - 1 - runesLastIndex(sel, newLine))
return line, col, true
}
const newLine = '\n'
func runesCount(input []rune, target rune) int {
var count int
for _, r := range input {
if r == target {
count++
}
}
return count
}
func runesLastIndex(input []rune, target rune) int {
for i := len(input) - 1; i >= 0; i-- {
if input[i] == target {
return i
}
}
return -1
}
// setVersion sets the version number, and whether the database is in a dirty state.
// A version value of -1 indicates no version is set.
func (p *Postgres) setVersion(ctx context.Context, version int, dirty bool) error {
const op = "postgres.(Postgres).SetVersion"
tx := p.tx
var err error
if tx == nil {
tx, err = p.conn.BeginTx(ctx, nil)
if err != nil {
return err
}
}
rollback := func() error {
defer func() { p.tx = nil }()
return tx.Rollback()
}
query := `truncate ` + pq.QuoteIdentifier(defaultMigrationsTable)
if _, err := tx.ExecContext(ctx, query); err != nil {
if errRollback := rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return errors.Wrap(ctx, err, op)
}
// Also re-write the schema Version for nil dirty versions to prevent
// empty schema Version for failed down migration on the first migration
// See: https://github.com/golang-migrate/migrate/issues/330
if version >= 0 || (version == nilVersion && dirty) {
query = `insert into ` + pq.QuoteIdentifier(defaultMigrationsTable) +
` (version, dirty) values ($1, $2)`
if _, err := tx.ExecContext(ctx, query, version, dirty); err != nil {
if errRollback := rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return errors.Wrap(ctx, err, op)
}
}
if p.tx == nil {
if err := tx.Commit(); err != nil {
return errors.Wrap(ctx, err, op)
}
}
return nil
}
// CurrentState returns the version, if the database was ever initialized
// previously, if it is currently in a dirty state, and any error. A version
// value of -1 indicates no version is set.
func (p *Postgres) CurrentState(ctx context.Context) (version int, previouslyRan, dirty bool, err error) {
const op = "postgres.(Postgres).CurrentState"
version = nilVersion
previouslyRan, dirty = false, false
tableQuery := `select table_name from information_schema.tables where table_schema=(select current_schema()) and table_name in ('schema_migrations', '` + defaultMigrationsTable + `')`
tableResult, err := p.conn.QueryContext(ctx, tableQuery)
if err != nil {
return nilVersion, previouslyRan, dirty, errors.Wrap(ctx, err, op)
}
defer tableResult.Close()
if !tableResult.Next() {
// No version table found
return nilVersion, previouslyRan, dirty, nil
}
tableName := defaultMigrationsTable
if err := tableResult.Scan(&tableName); err != nil {
return nilVersion, previouslyRan, dirty, errors.Wrap(ctx, err, op)
}
previouslyRan = true
if tableResult.Next() {
return nilVersion, previouslyRan, dirty, errors.New(ctx, errors.MigrationIntegrity, op, "both old and new migration tables exist")
}
query := `select version, dirty from ` + pq.QuoteIdentifier(tableName)
results, err := p.conn.QueryContext(ctx, query)
if err != nil {
return nilVersion, previouslyRan, dirty, errors.Wrap(ctx, err, op)
}
defer results.Close()
if !results.Next() {
// no version recorded
return nilVersion, previouslyRan, dirty, nil
}
if err := results.Scan(&version, &dirty); err != nil {
return nilVersion, previouslyRan, dirty, errors.Wrap(ctx, err, op)
}
if results.Next() {
return nilVersion, previouslyRan, dirty, errors.New(ctx, errors.MigrationIntegrity, op, "to many versions in version table")
}
return version, previouslyRan, dirty, nil
}
func (p *Postgres) drop(ctx context.Context) (err error) {
const op = "postgres.(Postgres).drop"
// select all tables in current schema
query := `select table_name from information_schema.tables where table_schema=(select current_schema()) and table_type='BASE TABLE'`
tables, err := p.conn.QueryContext(ctx, query)
if err != nil {
return errors.Wrap(ctx, err, op)
}
defer func() {
if errClose := tables.Close(); errClose != nil {
err = multierror.Append(err, errClose)
err = errors.Wrap(ctx, err, op)
}
}()
// delete one table after another
tableNames := make([]string, 0)
for tables.Next() {
var tableName string
if err := tables.Scan(&tableName); err != nil {
return errors.Wrap(ctx, err, op)
}
if len(tableName) > 0 {
tableNames = append(tableNames, tableName)
}
}
if err := tables.Err(); err != nil {
return errors.Wrap(ctx, err, op)
}
if len(tableNames) > 0 {
// delete one by one ...
for _, t := range tableNames {
query = `drop table if exists ` + pq.QuoteIdentifier(t) + ` cascade`
if _, err := p.conn.ExecContext(ctx, query); err != nil {
return errors.Wrap(ctx, err, op)
}
}
}
return nil
}
// EnsureVersionTable checks if versions table exists and, if not, creates it.
func (p *Postgres) EnsureVersionTable(ctx context.Context) (err error) {
const op = "postgres.(Postgres).EnsureVersionTable"
var extr execContexter = p.conn
rollback := func() error { return nil }
if p.tx != nil {
extr = p.tx
rollback = func() error {
defer func() { p.tx = nil }()
return p.tx.Rollback()
}
}
query := `select exists (select 1 from information_schema.tables where table_schema=(select current_schema()) and table_name = '` + defaultMigrationsTable + `');`
exists := false
if err := extr.QueryRowContext(ctx, query).Scan(&exists); err != nil {
if wpErr := rollback(); wpErr != nil {
err = multierror.Append(err, wpErr)
}
return errors.Wrap(ctx, err, op)
}
if exists {
return nil
}
updateQuery := `alter table if exists schema_migrations rename to ` + defaultMigrationsTable + `;`
if _, err = extr.ExecContext(ctx, updateQuery); err != nil {
if wpErr := rollback(); wpErr != nil {
err = multierror.Append(err, wpErr)
}
return errors.Wrap(ctx, err, op)
}
createStmt := `create table if not exists ` + pq.QuoteIdentifier(defaultMigrationsTable) + ` (version bigint primary key, dirty boolean not null)`
if _, err = extr.ExecContext(ctx, createStmt); err != nil {
if wpErr := rollback(); wpErr != nil {
err = multierror.Append(err, wpErr)
}
return errors.Wrap(ctx, err, op)
}
return nil
}