1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
|
// Copyright (C) 2022-2023 Luke Shumaker <lukeshu@lukeshu.com>
//
// SPDX-License-Identifier: GPL-2.0-or-later
package internal
import (
"encoding/base64"
"io"
"strings"
)
// base64Decoder is the io.WriteCloser implementation handed out by
// NewBase64Decoder. Base64 text is fed in through Write, decoded four
// characters at a time, and the resulting bytes are written to dst.
type base64Decoder struct {
// dst receives the decoded bytes.
dst io.Writer
// err is the first error returned by decodeTuple during a Write. Once it is
// set, every later non-empty Write returns it without doing any work.
err error
// pos counts the input characters that decodeByte has accepted so far; it is
// the offset reported in base64.CorruptInputError.
pos int64
// buf carries input bytes left over from a Write that did not make up a
// complete 4-character group.
buf [4]byte
// bufLen is the number of bytes currently held in buf.
bufLen int
}
// NewBase64Decoder returns an io.WriteCloser that accepts base64 text through
// Write and forwards the decoded bytes to w.
//
// The returned decoder starts with no buffered input and no recorded error.
// Call Close after the final Write to finish the stream.
func NewBase64Decoder(w io.Writer) io.WriteCloser {
	dec := new(base64Decoder)
	dec.dst = w
	return dec
}
// decodeByte maps one base64 character to its 6-bit value.
//
// When b belongs to the standard alphabet (A-Z, a-z, 0-9, '+', '/') it
// returns (value, true) and advances dec.pos by one. Any other byte,
// including the padding character '=', yields (0, false) and leaves dec.pos
// untouched, so dec.pos is the offset a caller can report for the rejected
// character.
func (dec *base64Decoder) decodeByte(b byte) (byte, bool) {
	const stdAlphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
	if idx := strings.IndexByte(stdAlphabet, b); idx >= 0 {
		dec.pos++
		return byte(idx), true
	}
	return 0, false
}
// decodeTuple decodes one 4-character base64 group and writes the resulting
// 1-3 bytes to dec.dst.
//
// A group consists of 2, 3 or 4 data characters followed by '=' padding to a
// total of four, producing 1, 2 or 3 output bytes respectively. Anything else
// is rejected with base64.CorruptInputError: a character outside the
// alphabet, a '=' that is followed by a data character, or a group with fewer
// than two data characters (a single character carries only 6 bits, which is
// not enough for one byte). Nothing is written to dec.dst when the group is
// rejected.
//
// The offset carried by the error is dec.pos, i.e. the number of data
// characters accepted so far. Any error returned by dec.dst.Write is passed
// through unchanged.
func (dec *base64Decoder) decodeTuple(a, b, c, d byte) error {
	quad := [4]byte{a, b, c, d}

	// Strip trailing padding; what remains must consist of data characters only.
	dataLen := len(quad)
	for dataLen > 0 && quad[dataLen-1] == '=' {
		dataLen--
	}

	var sextets [4]byte
	for i := 0; i < dataLen; i++ {
		v, ok := dec.decodeByte(quad[i])
		if !ok {
			// '=' is not in the alphabet, so this also catches padding that
			// appears before a data character.
			return base64.CorruptInputError(dec.pos)
		}
		sextets[i] = v
	}
	if dataLen < 2 {
		return base64.CorruptInputError(dec.pos)
	}

	// Pack the four 6-bit values into 24 bits; padded positions contribute 0.
	val := uint32(sextets[0])<<18 |
		uint32(sextets[1])<<12 |
		uint32(sextets[2])<<6 |
		uint32(sextets[3])
	decoded := [3]byte{byte(val >> 16), byte(val >> 8), byte(val)}

	// n data characters carry n-1 complete bytes.
	_, err := dec.dst.Write(decoded[:dataLen-1])
	return err
}
// Write decodes the base64 text in dat and forwards the decoded bytes to the
// underlying writer.
//
// Input is consumed in 4-character groups. Bytes that do not complete a group
// are stashed in dec.buf and combined with the start of the next Write. An
// empty dat is a no-op that returns (0, nil).
//
// The first error returned by decodeTuple is remembered in dec.err and is
// returned, with a count of 0, by every later non-empty Write. On the call
// that fails, the returned count is the offset in dat of the group that
// failed, or 0 if the failing group was the one completed from previously
// buffered input. On success the count is len(dat).
func (dec *base64Decoder) Write(dat []byte) (int, error) {
	switch {
	case len(dat) == 0:
		return 0, nil
	case dec.err != nil:
		return 0, dec.err
	}

	consumed := 0
	if dec.bufLen > 0 {
		// Top up the carried-over group with the start of dat.
		consumed = copy(dec.buf[dec.bufLen:], dat)
		dec.bufLen += consumed
		if dec.bufLen < len(dec.buf) {
			// Still incomplete; all of dat now sits in the buffer.
			return len(dat), nil
		}
		if err := dec.decodeTuple(dec.buf[0], dec.buf[1], dec.buf[2], dec.buf[3]); err != nil {
			dec.err = err
			return 0, err
		}
	}

	// Decode every complete group straight out of dat.
	for len(dat)-consumed >= 4 {
		group := dat[consumed : consumed+4]
		if err := dec.decodeTuple(group[0], group[1], group[2], group[3]); err != nil {
			dec.err = err
			return consumed, err
		}
		consumed += 4
	}

	// Keep the 0-3 trailing bytes for the next call.
	dec.bufLen = copy(dec.buf[:], dat[consumed:])
	return len(dat), nil
}
// Close finishes the stream by decoding any input still held in dec.buf.
//
// Input whose length is not a multiple of four is treated as unpadded base64:
// the buffered characters are kept in place and the rest of the group is
// filled with '=' before it is handed to decodeTuple. (Padding must start at
// dec.bufLen; writing it from index 0 would overwrite the buffered characters
// and silently drop the tail of the data.)
//
// Close returns:
//   - the remembered error, if an earlier Write already failed; the buffer is
//     not decoded again in that case;
//   - nil, if nothing is buffered;
//   - base64.CorruptInputError, if exactly one character is buffered, since a
//     single character carries only 6 bits and cannot encode a byte;
//   - otherwise whatever decodeTuple returns for the padded group.
//
// The buffer is emptied before decoding, so calling Close a second time does
// not emit the tail twice.
func (dec *base64Decoder) Close() error {
	if dec.err != nil {
		return dec.err
	}
	switch dec.bufLen {
	case 0:
		return nil
	case 1:
		dec.err = base64.CorruptInputError(dec.pos)
		return dec.err
	}

	// Pad after the buffered characters, never over them.
	copy(dec.buf[dec.bufLen:], "====")
	dec.bufLen = 0
	if err := dec.decodeTuple(dec.buf[0], dec.buf[1], dec.buf[2], dec.buf[3]); err != nil {
		dec.err = err
		return err
	}
	return nil
}
|