diff options
-rw-r--r-- | nslcd_proto/io.go | 23 | ||||
-rwxr-xr-x | nslcd_server/func_handlerequest.go.gen | 29 | ||||
-rw-r--r-- | nslcd_systemd/misc_test.go | 3 |
3 files changed, 42 insertions, 13 deletions
diff --git a/nslcd_proto/io.go b/nslcd_proto/io.go index daced37..bf59282 100644 --- a/nslcd_proto/io.go +++ b/nslcd_proto/io.go @@ -76,7 +76,9 @@ func Write(fd io.Writer, data interface{}) (err error) { } // Read an object from a stream. Any errors returned are of type -// NslcdError. +// NslcdError. If the type assertion succeeds, then +// fd.(*io.LimitedReader).N is used to prevent an overly-large buffer +// from being allocated. func Read(fd io.Reader, data interface{}) (err error) { defer func() { if r := recover(); r != nil { @@ -156,6 +158,16 @@ func write(fd io.Writer, data interface{}) { } } +// Assert that we *will* read n bytes. If we know now that < n bytes +// will be available, then this will let us avoid an allocation. +func willread(fd io.Reader, n int64) { + if lfd, ok := fd.(*io.LimitedReader); ok { + if n > lfd.N { + npanic(NslcdError(io.EOF.Error())) + } + } +} + // Read an object from a stream. In the event of an error, this // function may panic(NslcdError)! Handle it! func read(fd io.Reader, data interface{}) { @@ -179,13 +191,15 @@ func read(fd io.Reader, data interface{}) { case *string: var len int32 read(fd, &len) - buf := make([]byte, len) // BUG(lukeshu): Read: `string` length needs sanity checked + willread(fd, int64(len)) + buf := make([]byte, len) read(fd, &buf) *data = string(buf) case *[]string: var num int32 read(fd, &num) - *data = make([]string, num) // BUG(lukeshu): Read: `[]string` length needs sanity checked + willread(fd, int64(num * /* min size of a string is: */4)) + *data = make([]string, num) for i := 0; i < int(num); i++ { read(fd, &((*data)[i])) } @@ -212,7 +226,8 @@ func read(fd io.Reader, data interface{}) { case *[]net.IP: var num int32 read(fd, &num) - *data = make([]net.IP, num) // BUG(lukeshu): Read: `[]net.IP` length needs sanity checked + willread(fd, int64(num * /* min size of an IP is: */net.IPv4len)) + *data = make([]net.IP, num) for i := 0; i < int(num); i++ { read(fd, &((*data)[i])) } diff --git a/nslcd_server/func_handlerequest.go.gen b/nslcd_server/func_handlerequest.go.gen index af36e84..7c28e7c 100755 --- a/nslcd_server/func_handlerequest.go.gen +++ b/nslcd_server/func_handlerequest.go.gen @@ -26,6 +26,7 @@ package nslcd_server import ( "fmt" + "io" "os" "time" @@ -53,6 +54,10 @@ type Limits struct { // How long can we spend writing a response? WriteTimeout time.Duration + + // What is the maximum request length in bytes that we are + // willing to handle? + RequestMaxSize int64 } type Conn interface { @@ -92,13 +97,19 @@ func HandleRequest(backend Backend, limits Limits, conn Conn, cred unix.Ucred) ( } } + var in io.Reader = conn + if limits.RequestMaxSize > 0 { + in = &io.LimitedReader{R: in, N: limits.RequestMaxSize} + } + out := conn + var version int32 - maybePanic(p.Read(conn, &version)) + maybePanic(p.Read(in, &version)) if version != p.NSLCD_VERSION { return p.NslcdError(fmt.Sprintf("Version mismatch: server=%#08x client=%#08x", p.NSLCD_VERSION, version)) } var action int32 - maybePanic(p.Read(conn, &action)) + maybePanic(p.Read(in, &action)) switch action { $( @@ -106,7 +117,7 @@ while read -r request; do cat <<EOT case p.NSLCD_ACTION_${request^^}: var req p.Request_${request} - maybePanic(p.Read(conn, &req)) + maybePanic(p.Read(in, &req)) $( case "$request" in PAM_Authentication) @@ -133,24 +144,24 @@ while read -r request; do esac ) if limits.WriteTimeout != 0 { - err = conn.SetWriteDeadline(time.Now().Add(limits.WriteTimeout)) + err = out.SetWriteDeadline(time.Now().Add(limits.WriteTimeout)) if err != nil { return err } } - maybePanic(p.Write(conn, p.NSLCD_VERSION)) - maybePanic(p.Write(conn, action)) + maybePanic(p.Write(out, p.NSLCD_VERSION)) + maybePanic(p.Write(out, action)) ch := backend.${request}(cred, req) for result := range ch { if err == nil { - err = p.Write(conn, p.NSLCD_RESULT_BEGIN) + err = p.Write(out, p.NSLCD_RESULT_BEGIN) } if err == nil { - err = p.Write(conn, result) + err = p.Write(out, result) } } maybePanic(err) - maybePanic(p.Write(conn, p.NSLCD_RESULT_END)) + maybePanic(p.Write(out, p.NSLCD_RESULT_END)) return nil EOT done < "$requests" diff --git a/nslcd_systemd/misc_test.go b/nslcd_systemd/misc_test.go index 0e6f2e9..be6d40c 100644 --- a/nslcd_systemd/misc_test.go +++ b/nslcd_systemd/misc_test.go @@ -293,13 +293,16 @@ func TestLargeRequest(t *testing.T) { defer sdActivatedReset() t.Run("large-request", func(t *testing.T) { testWithTimeout(t, 2*time.Second, func(t *testing.T, s chan<- string) { + KiB := 1024 GiB := 1024*1024*1024 + ctx := &testContext{T: t, tmpdir: tmpdir, status: s} backend := &LockingBackend{} limits := nslcd_server.Limits{ Timeout: 1 * time.Second, + RequestMaxSize: int64(1*KiB), } notifyHandler := func(dat []byte, oob []byte) error { return nil } |