summaryrefslogtreecommitdiff
path: root/nslcd_systemd
diff options
context:
space:
mode:
Diffstat (limited to 'nslcd_systemd')
-rw-r--r--nslcd_systemd/disable_nss_module.go12
-rw-r--r--nslcd_systemd/nslcd_systemd.go44
2 files changed, 37 insertions, 19 deletions
diff --git a/nslcd_systemd/disable_nss_module.go b/nslcd_systemd/disable_nss_module.go
index df22360..32b105a 100644
--- a/nslcd_systemd/disable_nss_module.go
+++ b/nslcd_systemd/disable_nss_module.go
@@ -23,7 +23,7 @@ import (
"fmt"
"git.lukeshu.com/go/libgnulinux/dl"
- "git.lukeshu.com/go/libsystemd/sd_daemon"
+ "git.lukeshu.com/go/libnslcd/nslcd_server"
)
//static char *strary(char **ary, unsigned int n) { return ary[n]; }
@@ -35,27 +35,27 @@ const (
nss_module_sym_enablelookups = "_nss_ldap_enablelookups"
)
-func disable_nss_module() {
+func disable_nss_module(log nslcd_server.Logger) {
handle, err := dl.Open(nss_module_soname, dl.RTLD_LAZY|dl.RTLD_NODELETE)
if err == nil {
defer handle.Close()
} else {
- sd_daemon.Log.Warning(fmt.Sprintf("NSS module %s not loaded: %v", nss_module_soname, err))
+ log.Warning(fmt.Sprintf("NSS module %s not loaded: %v", nss_module_soname, err))
return
}
c_version_info, err := handle.Sym(nss_module_sym_version)
if err == nil {
g_version_info := (**C.char)(c_version_info)
- sd_daemon.Log.Debug(fmt.Sprintf("NSS module %s version %s %s", nss_module_soname,
+ log.Debug(fmt.Sprintf("NSS module %s version %s %s", nss_module_soname,
C.GoString(C.strary(g_version_info, 0)),
C.GoString(C.strary(g_version_info, 1))))
} else {
- sd_daemon.Log.Warning(fmt.Sprintf("NSS module %s version missing: %v", nss_module_soname, err))
+ log.Warning(fmt.Sprintf("NSS module %s version missing: %v", nss_module_soname, err))
}
c_enable_flag, err := handle.Sym(nss_module_sym_enablelookups)
if err != nil {
- sd_daemon.Log.Warning(fmt.Sprintf("Unable to disable NSS ldap module for nslcd process: %v", err))
+ log.Warning(fmt.Sprintf("Unable to disable NSS ldap module for nslcd process: %v", err))
return
}
g_enable_flag := (*C.int)(c_enable_flag)
diff --git a/nslcd_systemd/nslcd_systemd.go b/nslcd_systemd/nslcd_systemd.go
index b2f8e28..29d49d6 100644
--- a/nslcd_systemd/nslcd_systemd.go
+++ b/nslcd_systemd/nslcd_systemd.go
@@ -61,6 +61,15 @@ type Backend interface {
Close()
}
+type contextKey struct {
+ name string
+}
+
+var (
+ // ConnectionIDKey is a context key.
+ ConnectionIDKey = &contextKey{"connection-id"}
+)
+
func get_socket() (socket net.Listener, err error) {
fds := sd_daemon.ListenFds(true)
if fds == nil {
@@ -102,17 +111,20 @@ func handler(backend nslcd_server.Backend, limits nslcd_server.Limits, conn *net
ctx, cancel := context.WithCancel(ctx)
defer cancel()
+ // TODO: override nslcd_server.LoggerKey with a logger that includes ConnectionIDKey
+ log := nslcd_server.LoggerFromContext(ctx)
+
cred, err := getpeercred(conn)
if err != nil {
- sd_daemon.Log.Debug("Connection from unknown client")
+ log.Debug("Connection from unknown client")
} else {
- sd_daemon.Log.Debug(fmt.Sprintf("Connection from pid=%v uid=%v gid=%v",
+ log.Debug(fmt.Sprintf("Connection from pid=%v uid=%v gid=%v",
cred.Pid, cred.Uid, cred.Gid))
ctx = context.WithValue(ctx, nslcd_server.PeerCredKey, cred)
}
err = nslcd_server.HandleRequest(backend, limits, conn, ctx)
if err != nil {
- sd_daemon.Log.Notice(fmt.Sprintf("Error while handling request: %v", err))
+ log.Notice(fmt.Sprintf("Error while handling request: %v", err))
}
}
@@ -123,11 +135,13 @@ func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint
sigs := make(chan os.Signal)
signal.Notify(sigs, unix.SIGTERM, unix.SIGHUP)
- disable_nss_module()
+ log := nslcd_server.LoggerFromContext(ctx)
+
+ disable_nss_module(log)
err = backend.Init()
if err != nil {
- sd_daemon.Log.Err(fmt.Sprintf("Could not initialize backend: %v", err))
+ log.Err(fmt.Sprintf("Could not initialize backend: %v", err))
sd_daemon.Notification{State: "STOPPING=1"}.Send(false)
return sd_daemon.EXIT_FAILURE
}
@@ -138,7 +152,7 @@ func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint
socket, err := get_socket()
if err != nil {
- sd_daemon.Log.Err(fmt.Sprintf("%v", err))
+ log.Err(fmt.Sprintf("%v", err))
sd_daemon.Notification{State: "STOPPING=1"}.Send(false)
return sd_daemon.EXIT_NOTRUNNING
}
@@ -153,6 +167,8 @@ func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint
defer sd_daemon.Recover()
defer wg.Done()
+ id := 0
+
var tempDelay time.Duration
last := false
for !last {
@@ -161,7 +177,7 @@ func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint
if ne, ok := err.(net.Error); ok && ne.Timeout() {
last = true
} else if ne, ok := err.(net.Error); ok && ne.Temporary() {
- sd_daemon.Log.Notice(fmt.Sprintf("temporary error %v", err))
+ log.Notice(fmt.Sprintf("temporary error %v", err))
if tempDelay == 0 {
tempDelay = 5 * time.Millisecond
} else {
@@ -178,10 +194,12 @@ func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint
}
if conn != nil {
wg.Add(1)
+ id++
+ hctx := context.WithValue(ctx, ConnectionIDKey, id)
go func() {
defer sd_daemon.Recover()
defer wg.Done()
- handler(backend, limits, conn.(*net.UnixConn), ctx)
+ handler(backend, limits, conn.(*net.UnixConn), hctx)
}()
}
}
@@ -194,23 +212,23 @@ func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint
case sig := <-sigs:
switch sig {
case unix.SIGTERM:
- sd_daemon.Log.Notice("Received SIGTERM, shutting down")
+ log.Notice("Received SIGTERM, shutting down")
return sd_daemon.EXIT_SUCCESS
case unix.SIGHUP:
- sd_daemon.Log.Notice("Received SIGHUP, reloading")
+ log.Notice("Received SIGHUP, reloading")
sd_daemon.Notification{State: "RELOADING=1"}.Send(false)
err := backend.Reload()
if err != nil {
- sd_daemon.Log.Notice(fmt.Sprintf("Could not reload backend: %v", err))
+ log.Notice(fmt.Sprintf("Could not reload backend: %v", err))
return sd_daemon.EXIT_NOTRUNNING
}
sd_daemon.Notification{State: "READY=1"}.Send(false)
}
case <-ctx.Done():
- sd_daemon.Log.Err(fmt.Sprintf("Context was canceled, shutting down: %v", ctx.Err()))
+ log.Err(fmt.Sprintf("Context was canceled, shutting down: %v", ctx.Err()))
return sd_daemon.EXIT_FAILURE
case err = <-socket_error:
- sd_daemon.Log.Err(fmt.Sprintf("%v", err))
+ log.Err(fmt.Sprintf("%v", err))
return sd_daemon.EXIT_NETWORK
}
}