diff options
-rw-r--r-- | nslcd_server/ctx.go | 46 | ||||
-rwxr-xr-x | nslcd_server/func_handlerequest.go.gen | 26 | ||||
-rwxr-xr-x | nslcd_server/interface_backend.go.gen | 9 | ||||
-rwxr-xr-x | nslcd_server/type_nilbackend.go.gen | 7 | ||||
-rw-r--r-- | nslcd_systemd/misc_test.go | 7 | ||||
-rw-r--r-- | nslcd_systemd/nslcd_systemd.go | 23 |
6 files changed, 98 insertions, 20 deletions
diff --git a/nslcd_server/ctx.go b/nslcd_server/ctx.go new file mode 100644 index 0000000..5214adc --- /dev/null +++ b/nslcd_server/ctx.go @@ -0,0 +1,46 @@ +// Copyright (C) 2017 Luke Shumaker <lukeshu@sbcglobal.net> +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License as published by the Free Software Foundation; either +// version 2.1 of the License, or (at your option) any later version. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License along with this library; if not, write to the Free Software +// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +// 02110-1301 USA + +package nslcd_server + +import ( + "context" + + "golang.org/x/sys/unix" +) + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. +type contextKey struct { + name string +} + +var ( + // PeerCredKey is a context key. It can be used in backend + // methods to access the credentials of the client process. + // The associated value will be of type + // "golang.org/x/sys/unix".Ucred + PeerCredKey = &contextKey{"peercred"} +) + +// PeerCredFromContext is a convenience function for +// +// cred, ok := ctx.Value(nslcd_server.PeerCredKey).(unix.Ucred) +func PeerCredFromContext(ctx context.Context) (unix.Ucred, bool) { + cred, ok := ctx.Value(PeerCredKey).(unix.Ucred) + return cred, ok +} diff --git a/nslcd_server/func_handlerequest.go.gen b/nslcd_server/func_handlerequest.go.gen index 00e9663..750a7b0 100755 --- a/nslcd_server/func_handlerequest.go.gen +++ b/nslcd_server/func_handlerequest.go.gen @@ -25,12 +25,12 @@ cat <<EOF | gofmt package nslcd_server import ( + "context" "fmt" "io" "os" "time" - "golang.org/x/sys/unix" p "git.lukeshu.com/go/libnslcd/nslcd_proto" ) @@ -70,8 +70,9 @@ type Conn interface { SetWriteDeadline(t time.Time) error } -// Handle a request to nslcd -func HandleRequest(backend Backend, limits Limits, conn Conn, cred unix.Ucred) (err error) { +// Handle a request to nslcd. The caller is responsible for +// initializing the context with PeerCredKey. +func HandleRequest(backend Backend, limits Limits, conn Conn, ctx context.Context) (err error) { defer func() { if r := recover(); r != nil { switch r := r.(type) { @@ -89,6 +90,11 @@ func HandleRequest(backend Backend, limits Limits, conn Conn, cred unix.Ucred) ( if limits.Timeout != 0 { deadlineAll = now.Add(limits.Timeout) } + if deadline, ok := ctx.Deadline(); ok { + if deadlineAll.IsZero() || deadline.Before(deadlineAll) { + deadlineAll = deadline + } + } if limits.ReadTimeout != 0 { deadlineRead = now.Add(limits.ReadTimeout) if !deadlineAll.IsZero() && deadlineAll.Before(deadlineRead) { @@ -149,6 +155,7 @@ while read -r request; do ;; esac ) + if limits.WriteTimeout != 0 { deadlineWrite = time.Now().Add(limits.WriteTimeout) if !deadlineAll.IsZero() && deadlineAll.Before(deadlineWrite) { @@ -161,9 +168,18 @@ while read -r request; do return err } } + + var cancel context.CancelFunc + if deadline, ok := ctx.Deadline(); !ok || (!deadlineWrite.IsZero() && deadline.After(deadlineWrite)) { + ctx, cancel = context.WithDeadline(ctx, deadlineWrite) + } else { + ctx, cancel = context.WithCancel(ctx) + } + defer cancel() + maybePanic(p.Write(out, p.NSLCD_VERSION)) maybePanic(p.Write(out, action)) - ch := backend.${request}(cred, req) + ch := backend.${request}(ctx, req) for result := range ch { if err == nil { err = p.Write(out, p.NSLCD_RESULT_BEGIN) @@ -174,7 +190,7 @@ while read -r request; do } maybePanic(err) maybePanic(p.Write(out, p.NSLCD_RESULT_END)) - return nil + return ctx.Err() // probably nil EOT done < "$requests" ) diff --git a/nslcd_server/interface_backend.go.gen b/nslcd_server/interface_backend.go.gen index 4749d0c..4ced5e7 100755 --- a/nslcd_server/interface_backend.go.gen +++ b/nslcd_server/interface_backend.go.gen @@ -1,5 +1,5 @@ #!/usr/bin/env bash -# Copyright (C) 2015 Luke Shumaker <lukeshu@sbcglobal.net> +# Copyright (C) 2015, 2017 Luke Shumaker <lukeshu@sbcglobal.net> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -24,7 +24,8 @@ cat <<EOF | gofmt package nslcd_server import ( - "golang.org/x/sys/unix" + "context" + p "git.lukeshu.com/go/libnslcd/nslcd_proto" ) @@ -33,7 +34,7 @@ import ( // that the nslcd server may reply to is implemented simply as a // method that returns a channel of the resulting values. type Backend interface { - $(sed -rn 's/([^_]+)(.*)/\1\2(unix.Ucred, p.Request_\1\2) <-chan p.\1/p' "$requests" | grep -v PAM) - $(sed -rn 's/(PAM)(.*)/\1\2(unix.Ucred, p.Request_\1\2) <-chan p.\1\2/p' "$requests") + $(sed -rn 's/([^_]+)(.*)/\1\2(context.Context, p.Request_\1\2) <-chan p.\1/p' "$requests" | grep -v PAM) + $(sed -rn 's/(PAM)(.*)/\1\2(context.Context, p.Request_\1\2) <-chan p.\1\2/p' "$requests") } EOF diff --git a/nslcd_server/type_nilbackend.go.gen b/nslcd_server/type_nilbackend.go.gen index 0c6f4b5..b7ea372 100755 --- a/nslcd_server/type_nilbackend.go.gen +++ b/nslcd_server/type_nilbackend.go.gen @@ -24,7 +24,8 @@ cat <<EOF | gofmt package nslcd_server import ( - "golang.org/x/sys/unix" + "context" + p "git.lukeshu.com/go/libnslcd/nslcd_proto" ) @@ -35,8 +36,8 @@ import ( type NilBackend struct{} $( - re_in='^\t([^(]+)\(unix\.Ucred, ([^)]+)\) <-chan (\S+)$' - re_out='func (o NilBackend) \1(unix.Ucred, \2) <-chan \3 { r := make(chan \3); close(r); return r }' + re_in='^\t([^(]+)\(context\.Context, ([^)]+)\) <-chan (\S+)$' + re_out='func (o NilBackend) \1(context.Context, \2) <-chan \3 { r := make(chan \3); close(r); return r }' < "$interface" sed -rn "s/$re_in/$re_out/p" ) diff --git a/nslcd_systemd/misc_test.go b/nslcd_systemd/misc_test.go index a75083d..bc7ace6 100644 --- a/nslcd_systemd/misc_test.go +++ b/nslcd_systemd/misc_test.go @@ -18,6 +18,7 @@ package nslcd_systemd_test import ( + "context" "fmt" "io/ioutil" "net" @@ -89,7 +90,7 @@ func testDriver( // server ////////////////////////////////////////////////////////////// errfatal(sdActivatedStream(t.tmpdir + "/nslcd.sock")) go func() { - evExitServer <- nslcd_systemd.Main(backend, limits) + evExitServer <- nslcd_systemd.Main(backend, limits, context.Background()) }() // client/driver /////////////////////////////////////////////////////// @@ -133,7 +134,7 @@ func (o *NonLockingBackend) Init() error { return nil } func (o *NonLockingBackend) Reload() error { return nil } func (o *NonLockingBackend) Close() {} -func (o *NonLockingBackend) Passwd_All(cred unix.Ucred, req nslcd_proto.Request_Passwd_All) <-chan nslcd_proto.Passwd { +func (o *NonLockingBackend) Passwd_All(ctx context.Context, req nslcd_proto.Request_Passwd_All) <-chan nslcd_proto.Passwd { ret := make(chan nslcd_proto.Passwd) go func() { defer close(ret) @@ -170,7 +171,7 @@ func (o *LockingBackend) Close() { o.NonLockingBackend.Close() } -func (o *LockingBackend) Passwd_All(cred unix.Ucred, req nslcd_proto.Request_Passwd_All) <-chan nslcd_proto.Passwd { +func (o *LockingBackend) Passwd_All(ctx context.Context, req nslcd_proto.Request_Passwd_All) <-chan nslcd_proto.Passwd { o.lock.RLock() ret := make(chan nslcd_proto.Passwd) go func() { diff --git a/nslcd_systemd/nslcd_systemd.go b/nslcd_systemd/nslcd_systemd.go index 0999106..b2f8e28 100644 --- a/nslcd_systemd/nslcd_systemd.go +++ b/nslcd_systemd/nslcd_systemd.go @@ -25,6 +25,7 @@ // package main // // import ( +// "context" // "os" // // "git.lukeshu.com/go/libnslcd/nslcd_server" @@ -34,11 +35,13 @@ // func main() { // backend := ... // limits := nslcd_server.Limits{ ... } -// os.Exit(int(nslcd_systemd.Main(backend, limits))) +// ctx := context.Background() +// os.Exit(int(nslcd_systemd.Main(backend, limits, ctx))) // } package nslcd_systemd // import "git.lukeshu.com/go/libnslcd/nslcd_systemd" import ( + "context" "fmt" "net" "os" @@ -94,22 +97,26 @@ func getpeercred(conn *net.UnixConn) (cred unix.Ucred, err error) { return } -func handler(backend nslcd_server.Backend, limits nslcd_server.Limits, conn *net.UnixConn) { +func handler(backend nslcd_server.Backend, limits nslcd_server.Limits, conn *net.UnixConn, ctx context.Context) { defer conn.Close() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + cred, err := getpeercred(conn) if err != nil { sd_daemon.Log.Debug("Connection from unknown client") } else { sd_daemon.Log.Debug(fmt.Sprintf("Connection from pid=%v uid=%v gid=%v", cred.Pid, cred.Uid, cred.Gid)) + ctx = context.WithValue(ctx, nslcd_server.PeerCredKey, cred) } - err = nslcd_server.HandleRequest(backend, limits, conn, cred) + err = nslcd_server.HandleRequest(backend, limits, conn, ctx) if err != nil { sd_daemon.Log.Notice(fmt.Sprintf("Error while handling request: %v", err)) } } -func Main(backend Backend, limits nslcd_server.Limits) uint8 { +func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint8 { defer sd_daemon.Recover() var err error = nil @@ -137,6 +144,9 @@ func Main(backend Backend, limits nslcd_server.Limits) uint8 { } defer func() { socket.(*net.UnixListener).SetDeadline(time.Now()) }() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + socket_error := make(chan error) wg.Add(1) go func() { @@ -171,7 +181,7 @@ func Main(backend Backend, limits nslcd_server.Limits) uint8 { go func() { defer sd_daemon.Recover() defer wg.Done() - handler(backend, limits, conn.(*net.UnixConn)) + handler(backend, limits, conn.(*net.UnixConn), ctx) }() } } @@ -196,6 +206,9 @@ func Main(backend Backend, limits nslcd_server.Limits) uint8 { } sd_daemon.Notification{State: "READY=1"}.Send(false) } + case <-ctx.Done(): + sd_daemon.Log.Err(fmt.Sprintf("Context was canceled, shutting down: %v", ctx.Err())) + return sd_daemon.EXIT_FAILURE case err = <-socket_error: sd_daemon.Log.Err(fmt.Sprintf("%v", err)) return sd_daemon.EXIT_NETWORK |