You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/bsr/decode.go

215 lines
6.0 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package bsr
import (
"bytes"
"compress/gzip"
"context"
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
"io"
"log"
"github.com/hashicorp/boundary/internal/bsr/internal/is"
"github.com/hashicorp/boundary/internal/bsr/kms"
)
// DecodeChunkFunc is a function that given a BaseChunk and the data portion
// of a chunk, will decode the data into a Chunk.
type DecodeChunkFunc func(ctx context.Context, bc *BaseChunk, data []byte) (Chunk, error)
type decodeFuncsRegistry map[Protocol]map[ChunkType]DecodeChunkFunc
func (r decodeFuncsRegistry) Get(p Protocol, t ChunkType) (DecodeChunkFunc, bool) {
switch t {
case ChunkHeader:
return DecodeHeader, true
case ChunkEnd:
return DecodeEnd, true
default:
protocol, ok := r[p]
if !ok {
return nil, false
}
df, ok := protocol[t]
return df, ok
}
}
var chunkTypes decodeFuncsRegistry
// RegisterChunkType registers a DecodeChunkFunc for the given Protocol and
// ChunkType. A given Protocol and ChunkType can only have one decode function
// registered.
func RegisterChunkType(p Protocol, t ChunkType, df DecodeChunkFunc) error {
const op = "bsr.RegisterChunkType"
if chunkTypes == nil {
chunkTypes = make(map[Protocol]map[ChunkType]DecodeChunkFunc)
}
protocol, ok := chunkTypes[p]
if !ok {
protocol = make(map[ChunkType]DecodeChunkFunc)
}
_, ok = protocol[t]
if ok {
return fmt.Errorf("%s: %s %s: %w", op, p, t, ErrAlreadyRegistered)
}
protocol[t] = df
chunkTypes[p] = protocol
return nil
}
// ChunkDecoder is used to decode the data read from an io.Reader into Chunks.
type ChunkDecoder struct {
r io.Reader
compression Compression
encryption Encryption
keys *kms.Keys
}
// NewChunkDecoder creates a ChunkDecoder that can decode the data read from
// the given io.Reader. Supports the WithKeys option which will be used when
// support for encrypted chunks is added.
func NewChunkDecoder(_ context.Context, r io.Reader, options ...Option) (*ChunkDecoder, error) {
const op = "bsr.NewChunkDecoder"
if is.Nil(r) {
return nil, fmt.Errorf("%s: reader cannot be nil: %w", op, ErrInvalidParameter)
}
opts := getOpts(options...)
return &ChunkDecoder{
r: r,
compression: NoCompression,
encryption: NoEncryption,
keys: opts.withKeys,
}, nil
}
// Decode will read from the io.Reader and return the next Chunk that it
// decodes. If the io.Reader reaches EOF, Decode will return an error of
// io.EOF. Note that this is not a wrapped error so it can be checked for with
// err == io.EOF like with most io.Readers. If there is an unexpected error
// while decoding, such as an unsupported chunk type or corrupted data, an
// ErrChunkDecode error will be returned. This will be a wrapped error and
// should be checked for with errors.Is.
func (d *ChunkDecoder) Decode(ctx context.Context) (Chunk, error) {
const op = "bsr.(ChunkDecoder).Decode"
var b *BaseChunk
buf := make([]byte, chunkBaseSize)
crcBuf := make([]byte, crcSize)
_, err := io.ReadAtLeast(d.r, buf, chunkBaseSize)
if err != nil {
if err == io.EOF || errors.Is(err, io.EOF) {
return nil, io.EOF
}
return nil, fmt.Errorf("%s: %w: %w", op, err, ErrChunkDecode)
}
var length uint32
var protocol Protocol
var chunkType ChunkType
var direction Direction
crc := crc32.NewIEEE()
length, buf = binary.BigEndian.Uint32(buf[:lengthSize]), buf[lengthSize:]
if length > MaxChunkDataLength {
return nil, fmt.Errorf("%s: chunk length %d exceeds max chunk length of %d: %w", op, length, MaxChunkDataLength, ErrChunkDecode)
}
databuf := make([]byte, length)
_, err = io.ReadAtLeast(d.r, databuf, int(length))
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return nil, fmt.Errorf("%s: %w: missing data: %w", op, err, ErrChunkDecode)
}
crc.Write(append(buf, databuf...))
_, err = io.ReadAtLeast(d.r, crcBuf, crcSize)
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return nil, fmt.Errorf("%s: %w: missing crc: %w", op, err, ErrChunkDecode)
}
if crc.Sum32() != binary.BigEndian.Uint32(crcBuf) {
log.Printf("%x", crc.Sum32())
return nil, fmt.Errorf("%s: chunk crc did not match: %w", op, ErrChunkDecode)
}
protocol, buf = Protocol(string(buf[:protocolSize])), buf[protocolSize:]
chunkType, buf = ChunkType(string(buf[:chunkTypeSize])), buf[chunkTypeSize:]
direction, buf = Direction(buf[:directionSize][0]), buf[directionSize:]
timestamp, err := decodeTimestamp(buf[:timestampSize])
if err != nil {
return nil, fmt.Errorf("%s: error parsing timestamp: %w: %w", op, err, ErrChunkDecode)
}
buf = buf[timestampSize:]
if len(buf) != 0 {
return nil, fmt.Errorf("%s: extra data in chunk: %w", op, ErrChunkDecode)
}
b, err = NewBaseChunk(ctx, protocol, direction, timestamp, chunkType)
if err != nil {
return nil, fmt.Errorf("%s: %w: %w", op, err, ErrChunkDecode)
}
decompressBuf := bytes.NewBuffer(databuf)
var decompressor io.ReadCloser
switch chunkType {
// HEAD and END are never compressed
case ChunkHeader, ChunkEnd:
decompressor = newNullCompressionReader(decompressBuf)
default:
switch d.compression {
case GzipCompression:
decompressor, err = gzip.NewReader(decompressBuf)
if err != nil {
return nil, fmt.Errorf("%s: %w: %w", op, err, ErrChunkDecode)
}
default:
decompressor = newNullCompressionReader(decompressBuf)
}
}
// Decompressed chunk data should not go beyond MaxChunkDataLength/ prevent allocations beyond this limit
limitedDecompressionReader := io.LimitReader(decompressor, MaxChunkDataLength)
decompressed, err := io.ReadAll(limitedDecompressionReader)
if err != nil {
return nil, fmt.Errorf("%s: %w: %w", op, err, ErrChunkDecode)
}
df, ok := chunkTypes.Get(b.Protocol, b.Type)
if !ok {
return nil, fmt.Errorf("%s: unsupported chunk type %s for protocol %s: %w", op, b.Type, b.Protocol, ErrChunkDecode)
}
c, err := df(ctx, b, decompressed)
if err != nil {
return nil, fmt.Errorf("%s: %w: %w", op, err, ErrChunkDecode)
}
switch cc := c.(type) {
case *HeaderChunk:
d.compression = cc.Compression
d.encryption = cc.Encryption
}
return c, nil
}