summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2023-01-28 13:39:42 -0700
committerLuke Shumaker <lukeshu@lukeshu.com>2023-01-29 02:14:51 -0700
commit636311bafdb18da9851a668317a8d792f38ead5b (patch)
tree0dbd3ac35fb4f54487600f3ea19964043ea011bf /internal
parent2824310168b9dbe24c2d47cfb71d4283b1733642 (diff)
Move the base64 decode to the internal package
Diffstat (limited to 'internal')
-rw-r--r--internal/base64.go121
-rw-r--r--internal/base64_test.go44
2 files changed, 165 insertions, 0 deletions
diff --git a/internal/base64.go b/internal/base64.go
new file mode 100644
index 0000000..15adbf4
--- /dev/null
+++ b/internal/base64.go
@@ -0,0 +1,121 @@
+// Copyright (C) 2022-2023 Luke Shumaker <lukeshu@lukeshu.com>
+//
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+package internal
+
+import (
+ "encoding/base64"
+ "io"
+ "strings"
+)
+
+type base64Decoder struct {
+ dst io.Writer
+
+ err error
+ pos int64
+ buf [4]byte
+ bufLen int
+}
+
+func NewBase64Decoder(w io.Writer) io.WriteCloser {
+ return &base64Decoder{
+ dst: w,
+ }
+}
+
+func (dec *base64Decoder) decodeByte(b byte) (byte, bool) {
+ const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
+ n := strings.IndexByte(alphabet, b)
+ if n < 0 {
+ return 0, false
+ }
+ dec.pos++
+ return byte(n), true
+}
+
+func (dec *base64Decoder) decodeTuple(a, b, c, d byte) error {
+ var decodedLen int
+ var encoded [4]byte
+ var ok bool
+
+ if a != '=' {
+ encoded[0], ok = dec.decodeByte(a)
+ if !ok {
+ return base64.CorruptInputError(dec.pos)
+ }
+ decodedLen++
+ }
+ if b != '=' {
+ encoded[1], ok = dec.decodeByte(b)
+ if !ok {
+ return base64.CorruptInputError(dec.pos)
+ }
+ // do NOT increment decodedLen here
+ }
+ if c != '=' {
+ encoded[2], ok = dec.decodeByte(c)
+ if !ok {
+ return base64.CorruptInputError(dec.pos)
+ }
+ decodedLen++
+ }
+ if d != '=' {
+ encoded[3], ok = dec.decodeByte(d)
+ if !ok {
+ return base64.CorruptInputError(dec.pos)
+ }
+ decodedLen++
+ }
+
+ val := 0 |
+ uint32(encoded[0])<<18 |
+ uint32(encoded[1])<<12 |
+ uint32(encoded[2])<<6 |
+ uint32(encoded[3])<<0
+ var decoded [3]byte
+ decoded[0] = byte(val >> 16)
+ decoded[1] = byte(val >> 8)
+ decoded[2] = byte(val >> 0)
+
+ _, err := dec.dst.Write(decoded[:decodedLen])
+ return err
+}
+
+func (dec *base64Decoder) Write(dat []byte) (int, error) {
+ if len(dat) == 0 {
+ return 0, nil
+ }
+ if dec.err != nil {
+ return 0, dec.err
+ }
+ var n int
+ if dec.bufLen > 0 {
+ n = copy(dec.buf[dec.bufLen:], dat)
+ dec.bufLen += n
+ if dec.bufLen < 4 {
+ return len(dat), nil
+ }
+ if err := dec.decodeTuple(dec.buf[0], dec.buf[1], dec.buf[2], dec.buf[3]); err != nil {
+ dec.err = err
+ return 0, dec.err
+ }
+ }
+ for ; n+3 < len(dat); n += 4 {
+ if err := dec.decodeTuple(dat[n], dat[n+1], dat[n+2], dat[n+3]); err != nil {
+ dec.err = err
+ return n, dec.err
+ }
+ }
+ dec.bufLen = copy(dec.buf[:], dat[n:])
+ return len(dat), nil
+}
+
+func (dec *base64Decoder) Close() error {
+ if dec.bufLen == 0 {
+ return nil
+ }
+ copy(dec.buf[:], "====")
+ return dec.decodeTuple(dec.buf[0], dec.buf[1], dec.buf[2], dec.buf[3])
+}
diff --git a/internal/base64_test.go b/internal/base64_test.go
new file mode 100644
index 0000000..f18bcd7
--- /dev/null
+++ b/internal/base64_test.go
@@ -0,0 +1,44 @@
+// Copyright (C) 2022-2023 Luke Shumaker <lukeshu@lukeshu.com>
+//
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+package internal
+
+import (
+ "bytes"
+ "encoding/base64"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func b64encode(t *testing.T, input []byte) []byte {
+ var encoded bytes.Buffer
+ enc := base64.NewEncoder(base64.StdEncoding, &encoded)
+ _, err := enc.Write(input)
+ require.NoError(t, err)
+ require.NoError(t, enc.Close())
+ return encoded.Bytes()
+}
+
+func b64decode(t *testing.T, input []byte) []byte {
+ var decoded bytes.Buffer
+ dec := NewBase64Decoder(&decoded)
+ _, err := dec.Write(input)
+ require.NoError(t, err)
+ require.NoError(t, dec.Close())
+ return decoded.Bytes()
+}
+
+func FuzzBase64Decoder(f *testing.F) {
+ f.Fuzz(func(t *testing.T, input []byte) {
+ encoded := b64encode(t, input)
+ decoded := b64decode(t, encoded)
+ t.Logf("input b64 = %q", encoded)
+ t.Logf("expected decoded = %#v", input)
+ t.Logf("actual decoded = %#v", decoded)
+ if !bytes.Equal(input, decoded) {
+ t.Fail()
+ }
+ })
+}