diff options
author | Luke Shumaker <lukeshu@lukeshu.com> | 2023-01-28 13:39:42 -0700 |
---|---|---|
committer | Luke Shumaker <lukeshu@lukeshu.com> | 2023-01-29 02:14:51 -0700 |
commit | 636311bafdb18da9851a668317a8d792f38ead5b (patch) | |
tree | 0dbd3ac35fb4f54487600f3ea19964043ea011bf /internal | |
parent | 2824310168b9dbe24c2d47cfb71d4283b1733642 (diff) |
Move the base64 decode to the internal package
Diffstat (limited to 'internal')
-rw-r--r-- | internal/base64.go | 121 | ||||
-rw-r--r-- | internal/base64_test.go | 44 |
2 files changed, 165 insertions, 0 deletions
diff --git a/internal/base64.go b/internal/base64.go new file mode 100644 index 0000000..15adbf4 --- /dev/null +++ b/internal/base64.go @@ -0,0 +1,121 @@ +// Copyright (C) 2022-2023 Luke Shumaker <lukeshu@lukeshu.com> +// +// SPDX-License-Identifier: GPL-2.0-or-later + +package internal + +import ( + "encoding/base64" + "io" + "strings" +) + +type base64Decoder struct { + dst io.Writer + + err error + pos int64 + buf [4]byte + bufLen int +} + +func NewBase64Decoder(w io.Writer) io.WriteCloser { + return &base64Decoder{ + dst: w, + } +} + +func (dec *base64Decoder) decodeByte(b byte) (byte, bool) { + const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" + n := strings.IndexByte(alphabet, b) + if n < 0 { + return 0, false + } + dec.pos++ + return byte(n), true +} + +func (dec *base64Decoder) decodeTuple(a, b, c, d byte) error { + var decodedLen int + var encoded [4]byte + var ok bool + + if a != '=' { + encoded[0], ok = dec.decodeByte(a) + if !ok { + return base64.CorruptInputError(dec.pos) + } + decodedLen++ + } + if b != '=' { + encoded[1], ok = dec.decodeByte(b) + if !ok { + return base64.CorruptInputError(dec.pos) + } + // do NOT increment decodedLen here + } + if c != '=' { + encoded[2], ok = dec.decodeByte(c) + if !ok { + return base64.CorruptInputError(dec.pos) + } + decodedLen++ + } + if d != '=' { + encoded[3], ok = dec.decodeByte(d) + if !ok { + return base64.CorruptInputError(dec.pos) + } + decodedLen++ + } + + val := 0 | + uint32(encoded[0])<<18 | + uint32(encoded[1])<<12 | + uint32(encoded[2])<<6 | + uint32(encoded[3])<<0 + var decoded [3]byte + decoded[0] = byte(val >> 16) + decoded[1] = byte(val >> 8) + decoded[2] = byte(val >> 0) + + _, err := dec.dst.Write(decoded[:decodedLen]) + return err +} + +func (dec *base64Decoder) Write(dat []byte) (int, error) { + if len(dat) == 0 { + return 0, nil + } + if dec.err != nil { + return 0, dec.err + } + var n int + if dec.bufLen > 0 { + n = copy(dec.buf[dec.bufLen:], dat) + dec.bufLen += n + if dec.bufLen < 4 { + return len(dat), nil + } + if err := dec.decodeTuple(dec.buf[0], dec.buf[1], dec.buf[2], dec.buf[3]); err != nil { + dec.err = err + return 0, dec.err + } + } + for ; n+3 < len(dat); n += 4 { + if err := dec.decodeTuple(dat[n], dat[n+1], dat[n+2], dat[n+3]); err != nil { + dec.err = err + return n, dec.err + } + } + dec.bufLen = copy(dec.buf[:], dat[n:]) + return len(dat), nil +} + +func (dec *base64Decoder) Close() error { + if dec.bufLen == 0 { + return nil + } + copy(dec.buf[:], "====") + return dec.decodeTuple(dec.buf[0], dec.buf[1], dec.buf[2], dec.buf[3]) +} diff --git a/internal/base64_test.go b/internal/base64_test.go new file mode 100644 index 0000000..f18bcd7 --- /dev/null +++ b/internal/base64_test.go @@ -0,0 +1,44 @@ +// Copyright (C) 2022-2023 Luke Shumaker <lukeshu@lukeshu.com> +// +// SPDX-License-Identifier: GPL-2.0-or-later + +package internal + +import ( + "bytes" + "encoding/base64" + "testing" + + "github.com/stretchr/testify/require" +) + +func b64encode(t *testing.T, input []byte) []byte { + var encoded bytes.Buffer + enc := base64.NewEncoder(base64.StdEncoding, &encoded) + _, err := enc.Write(input) + require.NoError(t, err) + require.NoError(t, enc.Close()) + return encoded.Bytes() +} + +func b64decode(t *testing.T, input []byte) []byte { + var decoded bytes.Buffer + dec := NewBase64Decoder(&decoded) + _, err := dec.Write(input) + require.NoError(t, err) + require.NoError(t, dec.Close()) + return decoded.Bytes() +} + +func FuzzBase64Decoder(f *testing.F) { + f.Fuzz(func(t *testing.T, input []byte) { + encoded := b64encode(t, input) + decoded := b64decode(t, encoded) + t.Logf("input b64 = %q", encoded) + t.Logf("expected decoded = %#v", input) + t.Logf("actual decoded = %#v", decoded) + if !bytes.Equal(input, decoded) { + t.Fail() + } + }) +} |