diff options
Diffstat (limited to 'nslcd_systemd')
-rw-r--r-- | nslcd_systemd/disable_nss_module.go | 12 | ||||
-rw-r--r-- | nslcd_systemd/nslcd_systemd.go | 44 |
2 files changed, 37 insertions, 19 deletions
diff --git a/nslcd_systemd/disable_nss_module.go b/nslcd_systemd/disable_nss_module.go index df22360..32b105a 100644 --- a/nslcd_systemd/disable_nss_module.go +++ b/nslcd_systemd/disable_nss_module.go @@ -23,7 +23,7 @@ import ( "fmt" "git.lukeshu.com/go/libgnulinux/dl" - "git.lukeshu.com/go/libsystemd/sd_daemon" + "git.lukeshu.com/go/libnslcd/nslcd_server" ) //static char *strary(char **ary, unsigned int n) { return ary[n]; } @@ -35,27 +35,27 @@ const ( nss_module_sym_enablelookups = "_nss_ldap_enablelookups" ) -func disable_nss_module() { +func disable_nss_module(log nslcd_server.Logger) { handle, err := dl.Open(nss_module_soname, dl.RTLD_LAZY|dl.RTLD_NODELETE) if err == nil { defer handle.Close() } else { - sd_daemon.Log.Warning(fmt.Sprintf("NSS module %s not loaded: %v", nss_module_soname, err)) + log.Warning(fmt.Sprintf("NSS module %s not loaded: %v", nss_module_soname, err)) return } c_version_info, err := handle.Sym(nss_module_sym_version) if err == nil { g_version_info := (**C.char)(c_version_info) - sd_daemon.Log.Debug(fmt.Sprintf("NSS module %s version %s %s", nss_module_soname, + log.Debug(fmt.Sprintf("NSS module %s version %s %s", nss_module_soname, C.GoString(C.strary(g_version_info, 0)), C.GoString(C.strary(g_version_info, 1)))) } else { - sd_daemon.Log.Warning(fmt.Sprintf("NSS module %s version missing: %v", nss_module_soname, err)) + log.Warning(fmt.Sprintf("NSS module %s version missing: %v", nss_module_soname, err)) } c_enable_flag, err := handle.Sym(nss_module_sym_enablelookups) if err != nil { - sd_daemon.Log.Warning(fmt.Sprintf("Unable to disable NSS ldap module for nslcd process: %v", err)) + log.Warning(fmt.Sprintf("Unable to disable NSS ldap module for nslcd process: %v", err)) return } g_enable_flag := (*C.int)(c_enable_flag) diff --git a/nslcd_systemd/nslcd_systemd.go b/nslcd_systemd/nslcd_systemd.go index b2f8e28..29d49d6 100644 --- a/nslcd_systemd/nslcd_systemd.go +++ b/nslcd_systemd/nslcd_systemd.go @@ -61,6 +61,15 @@ type Backend interface { Close() } +type contextKey struct { + name string +} + +var ( + // ConnectionIDKey is a context key. + ConnectionIDKey = &contextKey{"connection-id"} +) + func get_socket() (socket net.Listener, err error) { fds := sd_daemon.ListenFds(true) if fds == nil { @@ -102,17 +111,20 @@ func handler(backend nslcd_server.Backend, limits nslcd_server.Limits, conn *net ctx, cancel := context.WithCancel(ctx) defer cancel() + // TODO: override nslcd_server.LoggerKey with a logger that includes ConnectionIDKey + log := nslcd_server.LoggerFromContext(ctx) + cred, err := getpeercred(conn) if err != nil { - sd_daemon.Log.Debug("Connection from unknown client") + log.Debug("Connection from unknown client") } else { - sd_daemon.Log.Debug(fmt.Sprintf("Connection from pid=%v uid=%v gid=%v", + log.Debug(fmt.Sprintf("Connection from pid=%v uid=%v gid=%v", cred.Pid, cred.Uid, cred.Gid)) ctx = context.WithValue(ctx, nslcd_server.PeerCredKey, cred) } err = nslcd_server.HandleRequest(backend, limits, conn, ctx) if err != nil { - sd_daemon.Log.Notice(fmt.Sprintf("Error while handling request: %v", err)) + log.Notice(fmt.Sprintf("Error while handling request: %v", err)) } } @@ -123,11 +135,13 @@ func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint sigs := make(chan os.Signal) signal.Notify(sigs, unix.SIGTERM, unix.SIGHUP) - disable_nss_module() + log := nslcd_server.LoggerFromContext(ctx) + + disable_nss_module(log) err = backend.Init() if err != nil { - sd_daemon.Log.Err(fmt.Sprintf("Could not initialize backend: %v", err)) + log.Err(fmt.Sprintf("Could not initialize backend: %v", err)) sd_daemon.Notification{State: "STOPPING=1"}.Send(false) return sd_daemon.EXIT_FAILURE } @@ -138,7 +152,7 @@ func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint socket, err := get_socket() if err != nil { - sd_daemon.Log.Err(fmt.Sprintf("%v", err)) + log.Err(fmt.Sprintf("%v", err)) sd_daemon.Notification{State: "STOPPING=1"}.Send(false) return sd_daemon.EXIT_NOTRUNNING } @@ -153,6 +167,8 @@ func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint defer sd_daemon.Recover() defer wg.Done() + id := 0 + var tempDelay time.Duration last := false for !last { @@ -161,7 +177,7 @@ func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint if ne, ok := err.(net.Error); ok && ne.Timeout() { last = true } else if ne, ok := err.(net.Error); ok && ne.Temporary() { - sd_daemon.Log.Notice(fmt.Sprintf("temporary error %v", err)) + log.Notice(fmt.Sprintf("temporary error %v", err)) if tempDelay == 0 { tempDelay = 5 * time.Millisecond } else { @@ -178,10 +194,12 @@ func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint } if conn != nil { wg.Add(1) + id++ + hctx := context.WithValue(ctx, ConnectionIDKey, id) go func() { defer sd_daemon.Recover() defer wg.Done() - handler(backend, limits, conn.(*net.UnixConn), ctx) + handler(backend, limits, conn.(*net.UnixConn), hctx) }() } } @@ -194,23 +212,23 @@ func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint case sig := <-sigs: switch sig { case unix.SIGTERM: - sd_daemon.Log.Notice("Received SIGTERM, shutting down") + log.Notice("Received SIGTERM, shutting down") return sd_daemon.EXIT_SUCCESS case unix.SIGHUP: - sd_daemon.Log.Notice("Received SIGHUP, reloading") + log.Notice("Received SIGHUP, reloading") sd_daemon.Notification{State: "RELOADING=1"}.Send(false) err := backend.Reload() if err != nil { - sd_daemon.Log.Notice(fmt.Sprintf("Could not reload backend: %v", err)) + log.Notice(fmt.Sprintf("Could not reload backend: %v", err)) return sd_daemon.EXIT_NOTRUNNING } sd_daemon.Notification{State: "READY=1"}.Send(false) } case <-ctx.Done(): - sd_daemon.Log.Err(fmt.Sprintf("Context was canceled, shutting down: %v", ctx.Err())) + log.Err(fmt.Sprintf("Context was canceled, shutting down: %v", ctx.Err())) return sd_daemon.EXIT_FAILURE case err = <-socket_error: - sd_daemon.Log.Err(fmt.Sprintf("%v", err)) + log.Err(fmt.Sprintf("%v", err)) return sd_daemon.EXIT_NETWORK } } |