summaryrefslogtreecommitdiff
path: root/internal/base64.go
blob: 15adbf4a523264ae6faf4b5c26b78e553ad272b9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
// Copyright (C) 2022-2023  Luke Shumaker <lukeshu@lukeshu.com>
//
// SPDX-License-Identifier: GPL-2.0-or-later

package internal

import (
	"encoding/base64"
	"io"
	"strings"
)

// base64Decoder is an io.WriteCloser that decodes base64 text in the
// standard alphabet ("A-Za-z0-9+/") written to it and forwards the decoded
// bytes to dst.
//
// Input is consumed in 4-character groups; a trailing partial group is held
// in buf until more input arrives or Close is called. Trailing '=' padding is
// accepted but not required: Close treats a leftover 2- or 3-character group
// as if it were padded. Padding may only end the final group; any group after
// a padded one is reported as corrupt. Whitespace and line breaks are not
// skipped.
//
// Errors are sticky: once decoding or writing to dst fails, later Write and
// Close calls return that same error.
type base64Decoder struct {
	dst io.Writer

	err    error   // first error encountered; returned by all later calls
	pos    int64   // input characters consumed so far; used as the CorruptInputError offset
	buf    [4]byte // partial group carried over between Write calls
	bufLen int     // number of valid bytes in buf (always < 4 between calls)
	padded bool    // a group ending in '=' has been decoded; no further groups are allowed
}

// NewBase64Decoder returns an io.WriteCloser that decodes standard-alphabet
// base64 written to it and writes the decoded bytes to w.
//
// Callers must call Close once all input has been written: Close flushes a
// trailing group that was written without '=' padding and reports truncated
// input. Close does not close w.
func NewBase64Decoder(w io.Writer) io.WriteCloser {
	return &base64Decoder{
		dst: w,
	}
}

// decodeByte returns the 6-bit value of the standard-alphabet base64
// character b and advances dec.pos past it. It reports false, without
// advancing dec.pos, if b is not in the alphabet ('=' included).
func (dec *base64Decoder) decodeByte(b byte) (byte, bool) {
	const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
	n := strings.IndexByte(alphabet, b)
	if n < 0 {
		return 0, false
	}
	dec.pos++
	return byte(n), true
}

// decodeTuple decodes the 4-character group a b c d and writes the resulting
// 1 to 3 bytes to dec.dst.
//
// The group may end in "=" (2 output bytes) or "==" (1 output byte). A '='
// anywhere else, a non-alphabet character, or any group following a padded
// one yields a base64.CorruptInputError whose offset is the position of the
// offending character. Errors from dec.dst.Write are returned unchanged.
func (dec *base64Decoder) decodeTuple(a, b, c, d byte) error {
	if dec.padded {
		// Padding marks the end of the encoded data; nothing may follow it.
		return base64.CorruptInputError(dec.pos)
	}

	group := [4]byte{a, b, c, d}

	// Only the last one or two characters of a group may be padding.
	dataLen := len(group)
	if d == '=' {
		dataLen--
		if c == '=' {
			dataLen--
		}
	}

	var sextets [4]byte
	for i, ch := range group[:dataLen] {
		// '=' is not in the alphabet, so misplaced padding is rejected here.
		v, ok := dec.decodeByte(ch)
		if !ok {
			return base64.CorruptInputError(dec.pos)
		}
		sextets[i] = v
	}
	if dataLen < len(group) {
		dec.pos += int64(len(group) - dataLen) // count the '=' characters too
		dec.padded = true
	}

	val := uint32(sextets[0])<<18 |
		uint32(sextets[1])<<12 |
		uint32(sextets[2])<<6 |
		uint32(sextets[3])
	decoded := [3]byte{byte(val >> 16), byte(val >> 8), byte(val)}

	// n data characters carry 6n bits, i.e. n-1 whole bytes.
	_, err := dec.dst.Write(decoded[:dataLen-1])
	return err
}

// Write decodes every complete 4-character group formed by any previously
// buffered characters followed by dat, and buffers the (at most 3) leftover
// characters for the next Write or Close.
//
// On success it returns len(dat). On failure it records the error as sticky
// and returns the offset within dat of the group that failed (0 if the
// failing group was completed from the buffer).
func (dec *base64Decoder) Write(dat []byte) (int, error) {
	if len(dat) == 0 {
		return 0, nil
	}
	if dec.err != nil {
		return 0, dec.err
	}

	n := 0
	if dec.bufLen > 0 {
		n = copy(dec.buf[dec.bufLen:], dat)
		dec.bufLen += n
		if dec.bufLen < 4 {
			return len(dat), nil
		}
		// The buffered group is consumed whether or not it decodes, so a
		// later Close can never emit it a second time.
		dec.bufLen = 0
		if err := dec.decodeTuple(dec.buf[0], dec.buf[1], dec.buf[2], dec.buf[3]); err != nil {
			dec.err = err
			return 0, err
		}
	}

	for ; n+4 <= len(dat); n += 4 {
		if err := dec.decodeTuple(dat[n], dat[n+1], dat[n+2], dat[n+3]); err != nil {
			dec.err = err
			return n, err
		}
	}
	dec.bufLen = copy(dec.buf[:], dat[n:])
	return len(dat), nil
}

// Close flushes any trailing partial group, treating it as if it had been
// '='-padded. A single leftover character cannot encode a whole byte and is
// reported as a base64.CorruptInputError. If an earlier call failed, Close
// returns that error without writing anything further. Close does not close
// the underlying writer, and calling it again writes nothing further.
func (dec *base64Decoder) Close() error {
	if dec.err != nil {
		return dec.err
	}
	n := dec.bufLen
	if n == 0 {
		return nil
	}
	dec.bufLen = 0 // consume the leftover so a repeated Close does not re-emit it

	if n == 1 {
		dec.err = base64.CorruptInputError(dec.pos)
		return dec.err
	}

	// Pad only the unused tail; the first n bytes are real input.
	copy(dec.buf[n:], "==")
	if err := dec.decodeTuple(dec.buf[0], dec.buf[1], dec.buf[2], dec.buf[3]); err != nil {
		dec.err = err
		return err
	}
	return nil
}