diff options
author | Luke Shumaker <lukeshu@lukeshu.com> | 2017-09-07 23:28:47 -0400 |
---|---|---|
committer | Luke Shumaker <lukeshu@lukeshu.com> | 2017-09-08 16:55:55 -0400 |
commit | b58ea042394c66eabe67c3f58906c5d76b1e119d (patch) | |
tree | db1f55fb187504c7866b81c33ce0dc1489135da5 | |
parent | e7b6b3a7ae2e53d807e14697708c4110c038303b (diff) |
nslcd_{server,systemd}: FIX, BREAKING CHANGE: add limits
Added types:
nslcd_server: type Limits struct { ...}
nslcd_server: type Conn interface{ ... } // a subset of net.Conn
nslcd_server.HandleRequest() signature change:
-func HandleRequest(backend Backend, in io.Reader, out io.Writer, cred unix.Ucred) (err error) {
+func HandleRequest(backend Backend, limits Limits, conn Conn, cred unix.Ucred) (err error) {
The `limits Limits` argument is added, and `conn Conn` replaces `in io.Reader` and `out io.Writer`.
nslcd_systemd.Main() signature change:
-func Main(backend Backend) uint8 {
+func Main(backend Backend, limits nslcd_server.Limits) uint8 {
The `limits Limits` argument is added.
-rwxr-xr-x | nslcd_server/func_handlerequest.go.gen | 64 | ||||
-rw-r--r-- | nslcd_systemd/misc_test.go | 4 | ||||
-rw-r--r-- | nslcd_systemd/nslcd_systemd.go | 16 |
3 files changed, 67 insertions, 17 deletions
diff --git a/nslcd_server/func_handlerequest.go.gen b/nslcd_server/func_handlerequest.go.gen index d34db88..af36e84 100755 --- a/nslcd_server/func_handlerequest.go.gen +++ b/nslcd_server/func_handlerequest.go.gen @@ -26,8 +26,8 @@ package nslcd_server import ( "fmt" - "io" "os" + "time" "golang.org/x/sys/unix" p "git.lukeshu.com/go/libnslcd/nslcd_proto" @@ -41,8 +41,32 @@ func maybePanic(err error) { } } +type Limits struct { + // What is the maximum total amount of time that we spend + // handling a single request. This includes both the time + // reading the request and the time creating and writing the + // response. + Timeout time.Duration + + // How long can we spend reading a request? + ReadTimeout time.Duration + + // How long can we spend writing a response? + WriteTimeout time.Duration +} + +type Conn interface { + // This is a subset of net.Conn; semantics are the same. + + Read(b []byte) (n int, err error) + Write(b []byte) (n int, err error) + SetDeadline(t time.Time) error + SetReadDeadline(t time.Time) error + SetWriteDeadline(t time.Time) error +} + // Handle a request to nslcd -func HandleRequest(backend Backend, in io.Reader, out io.Writer, cred unix.Ucred) (err error) { +func HandleRequest(backend Backend, limits Limits, conn Conn, cred unix.Ucred) (err error) { defer func() { if r := recover(); r != nil { switch r := r.(type) { @@ -54,13 +78,27 @@ func HandleRequest(backend Backend, in io.Reader, out io.Writer, cred unix.Ucred } }() + now := time.Now() + if limits.Timeout != 0 { + err = conn.SetDeadline(now.Add(limits.Timeout)) + if err != nil { + return err + } + } + if limits.ReadTimeout != 0 { + err = conn.SetReadDeadline(now.Add(limits.ReadTimeout)) + if err != nil { + return err + } + } + var version int32 - maybePanic(p.Read(in, &version)) + maybePanic(p.Read(conn, &version)) if version != p.NSLCD_VERSION { return p.NslcdError(fmt.Sprintf("Version mismatch: server=%#08x client=%#08x", p.NSLCD_VERSION, version)) } var action int32 - maybePanic(p.Read(in, &action)) + maybePanic(p.Read(conn, &action)) switch action { $( @@ -68,7 +106,7 @@ while read -r request; do cat <<EOT case p.NSLCD_ACTION_${request^^}: var req p.Request_${request} - maybePanic(p.Read(in, &req)) + maybePanic(p.Read(conn, &req)) $( case "$request" in PAM_Authentication) @@ -94,19 +132,25 @@ while read -r request; do ;; esac ) - maybePanic(p.Write(out, p.NSLCD_VERSION)) - maybePanic(p.Write(out, action)) + if limits.WriteTimeout != 0 { + err = conn.SetWriteDeadline(time.Now().Add(limits.WriteTimeout)) + if err != nil { + return err + } + } + maybePanic(p.Write(conn, p.NSLCD_VERSION)) + maybePanic(p.Write(conn, action)) ch := backend.${request}(cred, req) for result := range ch { if err == nil { - err = p.Write(out, p.NSLCD_RESULT_BEGIN) + err = p.Write(conn, p.NSLCD_RESULT_BEGIN) } if err == nil { - err = p.Write(out, result) + err = p.Write(conn, result) } } maybePanic(err) - maybePanic(p.Write(out, p.NSLCD_RESULT_END)) + maybePanic(p.Write(conn, p.NSLCD_RESULT_END)) return nil EOT done < "$requests" diff --git a/nslcd_systemd/misc_test.go b/nslcd_systemd/misc_test.go index d2b2a7e..a910cd9 100644 --- a/nslcd_systemd/misc_test.go +++ b/nslcd_systemd/misc_test.go @@ -77,7 +77,9 @@ func testBadClient(t *testContext, backend nslcd_systemd.Backend, toclose bool) // server ////////////////////////////////////////////////////////////// errfatal(sdActivatedStream(t.tmpdir + "/nslcd.sock")) go func() { - evExitServer <- nslcd_systemd.Main(backend) + evExitServer <- nslcd_systemd.Main(backend, nslcd_server.Limits{ + Timeout: 1 * time.Second, + }) }() // client/driver /////////////////////////////////////////////////////// diff --git a/nslcd_systemd/nslcd_systemd.go b/nslcd_systemd/nslcd_systemd.go index 97991c8..8bae046 100644 --- a/nslcd_systemd/nslcd_systemd.go +++ b/nslcd_systemd/nslcd_systemd.go @@ -24,11 +24,15 @@ // // package main // -// import "nslcd/systemd" +// import ( +// "git.lukeshu.com/go/libnslcd/nslcd_server" +// "git.lukeshu.com/go/libnslcd/nslcd_systemd" +// ) // // func main() { // backend := ... -// os.Exit(int(nslcd_systemd.Main(backend))) +// limits := nslcd_server.Limits{ ... } +// os.Exit(int(nslcd_systemd.Main(backend, limits))) // } package nslcd_systemd // import "git.lukeshu.com/go/libnslcd/nslcd_systemd" @@ -88,7 +92,7 @@ func getpeercred(conn *net.UnixConn) (cred unix.Ucred, err error) { return } -func handler(conn *net.UnixConn, backend nslcd_server.Backend) { +func handler(backend nslcd_server.Backend, limits nslcd_server.Limits, conn *net.UnixConn) { defer conn.Close() cred, err := getpeercred(conn) if err != nil { @@ -97,13 +101,13 @@ func handler(conn *net.UnixConn, backend nslcd_server.Backend) { sd_daemon.Log.Debug(fmt.Sprintf("Connection from pid=%v uid=%v gid=%v", cred.Pid, cred.Uid, cred.Gid)) } - err = nslcd_server.HandleRequest(backend, conn, conn, cred) + err = nslcd_server.HandleRequest(backend, limits, conn, cred) if err != nil { sd_daemon.Log.Notice(fmt.Sprintf("Error while handling request: %v", err)) } } -func Main(backend Backend) uint8 { +func Main(backend Backend, limits nslcd_server.Limits) uint8 { defer sd_daemon.Recover() var err error = nil @@ -165,7 +169,7 @@ func Main(backend Backend) uint8 { go func() { defer sd_daemon.Recover() defer wg.Done() - handler(conn.(*net.UnixConn), backend) + handler(backend, limits, conn.(*net.UnixConn)) }() } } |