From 26bd86fb07a2577c4e1bd2fd27e0a6a1e519b515 Mon Sep 17 00:00:00 2001 From: Timothy Messier Date: Wed, 15 Mar 2023 14:50:42 +0000 Subject: [PATCH] feat(bsr): Add chunk encoder --- internal/bsr/chunk.go | 14 +- internal/bsr/chunk_end_test.go | 8 +- internal/bsr/chunk_header.go | 7 +- internal/bsr/chunk_header_test.go | 12 +- internal/bsr/chunk_test.go | 10 +- internal/bsr/encode.go | 107 ++++++++ internal/bsr/encode_test.go | 394 ++++++++++++++++++++++++++++++ 7 files changed, 526 insertions(+), 26 deletions(-) create mode 100644 internal/bsr/encode.go create mode 100644 internal/bsr/encode_test.go diff --git a/internal/bsr/chunk.go b/internal/bsr/chunk.go index c881cf6005..8ad8cf19f7 100644 --- a/internal/bsr/chunk.go +++ b/internal/bsr/chunk.go @@ -5,8 +5,7 @@ package bsr import ( "context" - - "github.com/hashicorp/boundary/internal/errors" + "fmt" ) // sizes @@ -16,7 +15,8 @@ const ( chunkTypeSize = 4 directionSize = 1 - chunkBaseSize = lengthSize + protocolSize + chunkTypeSize + directionSize + timestampSize + crcDataSize = protocolSize + chunkTypeSize + directionSize + timestampSize + chunkBaseSize = lengthSize + crcDataSize crcSize = 4 ) @@ -65,16 +65,16 @@ type BaseChunk struct { func NewBaseChunk(ctx context.Context, p Protocol, d Direction, t *Timestamp, typ ChunkType) (*BaseChunk, error) { const op = "bsr.NewBaseChunk" if !ValidProtocol(p) { - return nil, errors.New(ctx, errors.InvalidParameter, op, "protocol name cannot be greater than 4 characters") + return nil, fmt.Errorf("%s: protocol name cannot be greater than 4 characters: %w", op, ErrInvalidParameter) } if !ValidDirection(d) { - return nil, errors.New(ctx, errors.InvalidParameter, op, "invalid direction") + return nil, fmt.Errorf("%s: invalid direction: %w", op, ErrInvalidParameter) } if t == nil { - return nil, errors.New(ctx, errors.InvalidParameter, op, "timestamp must not be nil") + return nil, fmt.Errorf("%s: timestamp must not be nil: %w", op, ErrInvalidParameter) } if !ValidChunkType(typ) { - return nil, errors.New(ctx, errors.InvalidParameter, op, "chunk type cannot be greater than 4 characters") + return nil, fmt.Errorf("%s: chunk type cannot be greater than 4 characters: %w", op, ErrInvalidParameter) } return &BaseChunk{ diff --git a/internal/bsr/chunk_end_test.go b/internal/bsr/chunk_end_test.go index 69b4582e75..6bbc19afc8 100644 --- a/internal/bsr/chunk_end_test.go +++ b/internal/bsr/chunk_end_test.go @@ -5,11 +5,11 @@ package bsr_test import ( "context" + "errors" "testing" "time" "github.com/hashicorp/boundary/internal/bsr" - "github.com/hashicorp/boundary/internal/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -62,7 +62,7 @@ func TestNewEndChunk(t *testing.T) { bsr.Inbound, bsr.NewTimestamp(now), nil, - errors.New(ctx, errors.InvalidParameter, "bsr.NewBaseChunk", "protocol name cannot be greater than 4 characters"), + errors.New("bsr.NewBaseChunk: protocol name cannot be greater than 4 characters: invalid parameter"), }, { "invalid-direction", @@ -70,7 +70,7 @@ func TestNewEndChunk(t *testing.T) { bsr.UnknownDirection, bsr.NewTimestamp(now), nil, - errors.New(ctx, errors.InvalidParameter, "bsr.NewBaseChunk", "invalid direction"), + errors.New("bsr.NewBaseChunk: invalid direction: invalid parameter"), }, { "invalid-timestamp", @@ -78,7 +78,7 @@ func TestNewEndChunk(t *testing.T) { bsr.Inbound, nil, nil, - errors.New(ctx, errors.InvalidParameter, "bsr.NewBaseChunk", "timestamp must not be nil"), + errors.New("bsr.NewBaseChunk: timestamp must not be nil: invalid parameter"), }, } diff --git a/internal/bsr/chunk_header.go b/internal/bsr/chunk_header.go index 0b4be332f6..0a2383bb47 100644 --- a/internal/bsr/chunk_header.go +++ b/internal/bsr/chunk_header.go @@ -5,8 +5,7 @@ package bsr import ( "context" - - "github.com/hashicorp/boundary/internal/errors" + "fmt" ) // HeaderChunk is the first chunk in a BSR data file. @@ -47,11 +46,11 @@ func NewHeader(ctx context.Context, p Protocol, d Direction, t *Timestamp, c Com } if !ValidCompression(c) { - return nil, errors.New(ctx, errors.InvalidParameter, op, "invalid compression") + return nil, fmt.Errorf("%s: invalid compression: %w", op, ErrInvalidParameter) } if !ValidEncryption(e) { - return nil, errors.New(ctx, errors.InvalidParameter, op, "invalid encryption") + return nil, fmt.Errorf("%s: invalid encryption: %w", op, ErrInvalidParameter) } return &HeaderChunk{ diff --git a/internal/bsr/chunk_header_test.go b/internal/bsr/chunk_header_test.go index 2fd865e56c..ff0489a2b5 100644 --- a/internal/bsr/chunk_header_test.go +++ b/internal/bsr/chunk_header_test.go @@ -5,11 +5,11 @@ package bsr_test import ( "context" + "errors" "testing" "time" "github.com/hashicorp/boundary/internal/bsr" - "github.com/hashicorp/boundary/internal/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -80,7 +80,7 @@ func TestNewHeaderChunk(t *testing.T) { bsr.NoEncryption, "sess_123456789", nil, - errors.New(ctx, errors.InvalidParameter, "bsr.NewBaseChunk", "protocol name cannot be greater than 4 characters"), + errors.New("bsr.NewBaseChunk: protocol name cannot be greater than 4 characters: invalid parameter"), }, { "invalid-direction", @@ -91,7 +91,7 @@ func TestNewHeaderChunk(t *testing.T) { bsr.NoEncryption, "sess_123456789", nil, - errors.New(ctx, errors.InvalidParameter, "bsr.NewBaseChunk", "invalid direction"), + errors.New("bsr.NewBaseChunk: invalid direction: invalid parameter"), }, { "invalid-timestamp", @@ -102,7 +102,7 @@ func TestNewHeaderChunk(t *testing.T) { bsr.NoEncryption, "sess_123456789", nil, - errors.New(ctx, errors.InvalidParameter, "bsr.NewBaseChunk", "timestamp must not be nil"), + errors.New("bsr.NewBaseChunk: timestamp must not be nil: invalid parameter"), }, { "invalid-compression", @@ -113,7 +113,7 @@ func TestNewHeaderChunk(t *testing.T) { bsr.NoEncryption, "sess_123456789", nil, - errors.New(ctx, errors.InvalidParameter, "bsr.NewHeader", "invalid compression"), + errors.New("bsr.NewHeader: invalid compression: invalid parameter"), }, { "invalid-encryption", @@ -124,7 +124,7 @@ func TestNewHeaderChunk(t *testing.T) { bsr.Encryption(255), "sess_123456789", nil, - errors.New(ctx, errors.InvalidParameter, "bsr.NewHeader", "invalid encryption"), + errors.New("bsr.NewHeader: invalid encryption: invalid parameter"), }, } diff --git a/internal/bsr/chunk_test.go b/internal/bsr/chunk_test.go index 4e80346ef6..d2b0644641 100644 --- a/internal/bsr/chunk_test.go +++ b/internal/bsr/chunk_test.go @@ -5,11 +5,11 @@ package bsr_test import ( "context" + "errors" "testing" "time" "github.com/hashicorp/boundary/internal/bsr" - "github.com/hashicorp/boundary/internal/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -48,7 +48,7 @@ func TestNewBaseChunk(t *testing.T) { bsr.NewTimestamp(now), bsr.ChunkType("TEST"), nil, - errors.New(ctx, errors.InvalidParameter, "bsr.NewBaseChunk", "protocol name cannot be greater than 4 characters"), + errors.New("bsr.NewBaseChunk: protocol name cannot be greater than 4 characters: invalid parameter"), }, { "invalid-direction", @@ -57,7 +57,7 @@ func TestNewBaseChunk(t *testing.T) { bsr.NewTimestamp(now), bsr.ChunkType("TEST"), nil, - errors.New(ctx, errors.InvalidParameter, "bsr.NewBaseChunk", "invalid direction"), + errors.New("bsr.NewBaseChunk: invalid direction: invalid parameter"), }, { "invalid-timestamp", @@ -66,7 +66,7 @@ func TestNewBaseChunk(t *testing.T) { nil, bsr.ChunkType("TEST"), nil, - errors.New(ctx, errors.InvalidParameter, "bsr.NewBaseChunk", "timestamp must not be nil"), + errors.New("bsr.NewBaseChunk: timestamp must not be nil: invalid parameter"), }, { "invalid-chunk-type", @@ -75,7 +75,7 @@ func TestNewBaseChunk(t *testing.T) { bsr.NewTimestamp(now), bsr.ChunkType("TEST_INVALID"), nil, - errors.New(ctx, errors.InvalidParameter, "bsr.NewBaseChunk", "chunk type cannot be greater than 4 characters"), + errors.New("bsr.NewBaseChunk: chunk type cannot be greater than 4 characters: invalid parameter"), }, } diff --git a/internal/bsr/encode.go b/internal/bsr/encode.go new file mode 100644 index 0000000000..a4435523ea --- /dev/null +++ b/internal/bsr/encode.go @@ -0,0 +1,107 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package bsr + +import ( + "bytes" + "compress/gzip" + "context" + "encoding/binary" + "fmt" + "hash/crc32" + "io" +) + +// ChunkEncoder will encode a chunk and write it to the writer. +// It will compress the chunk data based on the compression. +type ChunkEncoder struct { + w io.Writer + compression Compression + encryption Encryption +} + +// NewChunkEncoder creates a ChunkEncoder. +func NewChunkEncoder(ctx context.Context, w io.Writer, c Compression, e Encryption) (*ChunkEncoder, error) { + const op = "bsr.NewChunkEncoder" + + if w == nil { + return nil, fmt.Errorf("%s: writer cannot be nil: %w", op, ErrInvalidParameter) + } + + if !ValidCompression(c) { + return nil, fmt.Errorf("%s: invalid compression: %w", op, ErrInvalidParameter) + } + + if !ValidEncryption(e) { + return nil, fmt.Errorf("%s: invalid encryption: %w", op, ErrInvalidParameter) + } + + return &ChunkEncoder{ + w: w, + compression: c, + encryption: e, + }, nil +} + +// Encode serializes a Chunk and writes it with the encoder's writer. +func (e ChunkEncoder) Encode(ctx context.Context, c Chunk) (int, error) { + data, err := c.MarshalData(ctx) + if err != nil { + return 0, err + } + + var buf bytes.Buffer + var compressor io.WriteCloser + switch c.GetType() { + // Header should not be compressed since we need to read it prior to knowing + // what compression was used to check the compression bit. + // End should not be compressed since it has no data and compressing an empty + // byte slice just adds data in the form of the compression magic strings. + case ChunkHeader, ChunkEnd: + compressor = newNullCompressionWriter(&buf) + default: + switch e.compression { + case GzipCompression: + compressor = gzip.NewWriter(&buf) + default: + compressor = newNullCompressionWriter(&buf) + } + } + + if _, err := compressor.Write(data); err != nil { + return 0, err + } + compressor.Close() + length := buf.Len() + + t := c.GetTimestamp().marshal() + + // calculate CRC for protocol+type+dir+timestamp+data + crced := make([]byte, 0, chunkBaseSize+length) + crced = append(crced, c.GetProtocol()...) + crced = append(crced, c.GetType()...) + crced = append(crced, byte(c.GetDirection())) + crced = append(crced, t...) + crced = append(crced, buf.Bytes()...) + + crc := crc32.NewIEEE() + crc.Write(crced) + + d := make([]byte, 0, chunkBaseSize+length+crcSize) + d = binary.BigEndian.AppendUint32(d, uint32(length)) + d = append(d, crced...) + d = binary.BigEndian.AppendUint32(d, crc.Sum32()) + + return e.w.Write(d) +} + +// Close closes the encoder. +func (e *ChunkEncoder) Close() error { + var i interface{} = e.w + v, ok := i.(io.WriteCloser) + if ok { + return v.Close() + } + return nil +} diff --git a/internal/bsr/encode_test.go b/internal/bsr/encode_test.go new file mode 100644 index 0000000000..fd33c609ac --- /dev/null +++ b/internal/bsr/encode_test.go @@ -0,0 +1,394 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package bsr_test + +import ( + "bytes" + "compress/gzip" + "context" + "errors" + "fmt" + "io" + "testing" + "time" + + "github.com/hashicorp/boundary/internal/bsr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testChunk struct { + *bsr.BaseChunk + Data []byte + err error +} + +// MarshalData serializes the data portion of a chunk. +func (t *testChunk) MarshalData(_ context.Context) ([]byte, error) { + if t.err != nil { + return nil, t.err + } + return t.Data, nil +} + +func gziped(d string) string { + var buf bytes.Buffer + w := gzip.NewWriter(&buf) + w.Write([]byte(d)) + w.Close() + return buf.String() +} + +func TestChunkEncoder(t *testing.T) { + ctx := context.Background() + + ts := time.Date(2023, time.March, 16, 10, 47, 3, 14, time.UTC) + + cases := []struct { + name string + c bsr.Compression + e bsr.Encryption + chunks []bsr.Chunk + want []byte + }{ + { + "header-no-compression", + bsr.NoCompression, + bsr.NoEncryption, + []bsr.Chunk{ + &bsr.HeaderChunk{ + BaseChunk: &bsr.BaseChunk{ + Protocol: "TEST", + Direction: bsr.Inbound, + Timestamp: bsr.NewTimestamp(ts), + Type: bsr.ChunkHeader, + }, + Compression: bsr.NoCompression, + Encryption: bsr.NoEncryption, + SessionId: "sess_123456789", + }, + }, + []byte( + "" + // so everything else aligns better + "\x00\x00\x00\x10" + // length + "TEST" + // protocol + "HEAD" + // type + "\x01" + // direction + "\x00\x00\x00\x00\x64\x12\xf3\xa7" + // time seconds + "\x00\x00\x00\x0e" + // time nanoseconds + "\x00" + // compression method + "\x00" + // encryption method + "sess_123456789" + // data + "\xbe\x4c\x7c\x20" + // crc + "", + ), + }, + { + "header-end-no-compression", + bsr.NoCompression, + bsr.NoEncryption, + []bsr.Chunk{ + &bsr.HeaderChunk{ + BaseChunk: &bsr.BaseChunk{ + Protocol: "TEST", + Direction: bsr.Inbound, + Timestamp: bsr.NewTimestamp(ts), + Type: bsr.ChunkHeader, + }, + Compression: bsr.NoCompression, + Encryption: bsr.NoEncryption, + SessionId: "sess_123456789", + }, + &bsr.EndChunk{ + BaseChunk: &bsr.BaseChunk{ + Protocol: "TEST", + Direction: bsr.Inbound, + Timestamp: bsr.NewTimestamp(ts.Add(time.Nanosecond * 5)), + Type: bsr.ChunkEnd, + }, + }, + }, + []byte( + "" + // header + "\x00\x00\x00\x10" + // length + "TEST" + // protocol + "HEAD" + // type + "\x01" + // direction + "\x00\x00\x00\x00\x64\x12\xf3\xa7" + // time seconds + "\x00\x00\x00\x0e" + // time nanoseconds + "\x00" + // compression method + "\x00" + // encryption method + "sess_123456789" + // data + "\xbe\x4c\x7c\x20" + // crc + "" + // end + "\x00\x00\x00\x00" + // length + "TEST" + // protocol + "DONE" + // type + "\x01" + // direction + "\x00\x00\x00\x00\x64\x12\xf3\xa7" + // time seconds + "\x00\x00\x00\x13" + // time nanoseconds + "\x50\x91\xfe\x72" + // crc + "", + ), + }, + { + "header-test-end-no-compression", + bsr.NoCompression, + bsr.NoEncryption, + []bsr.Chunk{ + &bsr.HeaderChunk{ + BaseChunk: &bsr.BaseChunk{ + Protocol: "TEST", + Direction: bsr.Inbound, + Timestamp: bsr.NewTimestamp(ts), + Type: bsr.ChunkHeader, + }, + Compression: bsr.NoCompression, + Encryption: bsr.NoEncryption, + SessionId: "sess_123456789", + }, + &testChunk{ + BaseChunk: &bsr.BaseChunk{ + Protocol: "TEST", + Direction: bsr.Inbound, + Timestamp: bsr.NewTimestamp(ts), + Type: "TEST", + }, + Data: []byte("foo"), + }, + &bsr.EndChunk{ + BaseChunk: &bsr.BaseChunk{ + Protocol: "TEST", + Direction: bsr.Inbound, + Timestamp: bsr.NewTimestamp(ts.Add(time.Nanosecond * 5)), + Type: bsr.ChunkEnd, + }, + }, + }, + []byte( + "" + // header + "\x00\x00\x00\x10" + // length + "TEST" + // protocol + "HEAD" + // type + "\x01" + // direction + "\x00\x00\x00\x00\x64\x12\xf3\xa7" + // time seconds + "\x00\x00\x00\x0e" + // time nanoseconds + "\x00" + // compression method + "\x00" + // encryption method + "sess_123456789" + // data + "\xbe\x4c\x7c\x20" + // crc + "" + // test + "\x00\x00\x00\x03" + // length + "TEST" + // protocol + "TEST" + // type + "\x01" + // direction + "\x00\x00\x00\x00\x64\x12\xf3\xa7" + // time seconds + "\x00\x00\x00\x0e" + // time nanoseconds + "foo" + // data + "\xa4\x6e\x48\x70" + // crc + "" + // end + "\x00\x00\x00\x00" + // length + "TEST" + // protocol + "DONE" + // type + "\x01" + // direction + "\x00\x00\x00\x00\x64\x12\xf3\xa7" + // time seconds + "\x00\x00\x00\x13" + // time nanoseconds + "\x50\x91\xfe\x72" + // crc + "", + ), + }, + { + "header-test-end-gzip-compression", + bsr.GzipCompression, + bsr.NoEncryption, + []bsr.Chunk{ + &bsr.HeaderChunk{ + BaseChunk: &bsr.BaseChunk{ + Protocol: "TEST", + Direction: bsr.Inbound, + Timestamp: bsr.NewTimestamp(ts), + Type: bsr.ChunkHeader, + }, + Compression: bsr.GzipCompression, + Encryption: bsr.NoEncryption, + SessionId: "sess_123456789", + }, + &testChunk{ + BaseChunk: &bsr.BaseChunk{ + Protocol: "TEST", + Direction: bsr.Inbound, + Timestamp: bsr.NewTimestamp(ts), + Type: "TEST", + }, + Data: []byte("foo"), + }, + &bsr.EndChunk{ + BaseChunk: &bsr.BaseChunk{ + Protocol: "TEST", + Direction: bsr.Inbound, + Timestamp: bsr.NewTimestamp(ts.Add(time.Nanosecond * 5)), + Type: bsr.ChunkEnd, + }, + }, + }, + []byte( + "" + // header + "\x00\x00\x00\x10" + // length + "TEST" + // protocol + "HEAD" + // type + "\x01" + // direction + "\x00\x00\x00\x00\x64\x12\xf3\xa7" + // time seconds + "\x00\x00\x00\x0e" + // time nanoseconds + "\x01" + // compression method + "\x00" + // encryption method + "sess_123456789" + // data + "\x10\x24\xed\xb1" + // crc + "" + // test + "\x00\x00\x00\x1b" + // length + "TEST" + // protocol + "TEST" + // type + "\x01" + // direction + "\x00\x00\x00\x00\x64\x12\xf3\xa7" + // time seconds + "\x00\x00\x00\x0e" + // time nanoseconds + gziped("foo") + // data + "\x29\x8e\x22\x12" + // crc + "" + // end + "\x00\x00\x00\x00" + // length + "TEST" + // protocol + "DONE" + // type + "\x01" + // direction + "\x00\x00\x00\x00\x64\x12\xf3\xa7" + // time seconds + "\x00\x00\x00\x13" + // time nanoseconds + "\x50\x91\xfe\x72" + // crc + "", + ), + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + enc, err := bsr.NewChunkEncoder(ctx, &buf, tc.c, tc.e) + require.NoError(t, err) + + var wrote int + for _, c := range tc.chunks { + w, err := enc.Encode(ctx, c) + require.NoError(t, err) + wrote += w + } + err = enc.Close() + require.NoError(t, err) + + got := buf.Bytes() + assert.Equal(t, len(tc.want), wrote) + assert.Equal(t, tc.want, got) + }) + } +} + +type errorWriter struct{} + +func (e errorWriter) Write(_ []byte) (int, error) { + return 0, fmt.Errorf("write error") +} + +func TestChunkEncoderEncodeError(t *testing.T) { + ctx := context.Background() + + ts := time.Date(2023, time.March, 16, 10, 47, 3, 14, time.UTC) + + cases := []struct { + name string + w io.Writer + c bsr.Compression + e bsr.Encryption + chunk bsr.Chunk + want error + }{ + { + "chunk-marshal-error", + func() io.Writer { var buf bytes.Buffer; return &buf }(), + bsr.NoCompression, + bsr.NoEncryption, + &testChunk{ + BaseChunk: &bsr.BaseChunk{ + Protocol: "TEST", + Direction: bsr.Inbound, + Timestamp: bsr.NewTimestamp(ts), + Type: "TEST", + }, + err: fmt.Errorf("marshal error"), + }, + fmt.Errorf("marshal error"), + }, + { + "writer-error", + func() io.Writer { return errorWriter{} }(), + bsr.NoCompression, + bsr.NoEncryption, + &testChunk{ + BaseChunk: &bsr.BaseChunk{ + Protocol: "TEST", + Direction: bsr.Inbound, + Timestamp: bsr.NewTimestamp(ts), + Type: "TEST", + }, + Data: []byte("foo"), + }, + fmt.Errorf("write error"), + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + enc, err := bsr.NewChunkEncoder(ctx, tc.w, tc.c, tc.e) + require.NoError(t, err) + + _, err = enc.Encode(ctx, tc.chunk) + assert.EqualError(t, tc.want, err.Error()) + }) + } +} + +func TestChunkEncoderErrors(t *testing.T) { + ctx := context.Background() + + cases := []struct { + name string + w io.Writer + c bsr.Compression + e bsr.Encryption + want error + }{ + { + "invalid-compression", + func() io.Writer { var buf bytes.Buffer; return &buf }(), + bsr.Compression(255), + bsr.NoEncryption, + errors.New("bsr.NewChunkEncoder: invalid compression: invalid parameter"), + }, + { + "invalid-encryption", + func() io.Writer { var buf bytes.Buffer; return &buf }(), + bsr.NoCompression, + bsr.Encryption(255), + errors.New("bsr.NewChunkEncoder: invalid encryption: invalid parameter"), + }, + { + "nil-writer", + nil, + bsr.NoCompression, + bsr.NoEncryption, + errors.New("bsr.NewChunkEncoder: writer cannot be nil: invalid parameter"), + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + _, err := bsr.NewChunkEncoder(ctx, tc.w, tc.c, tc.e) + require.EqualError(t, tc.want, err.Error()) + }) + } +}