From 7b8aefea056f995ee2d00a79c22277c09cda5363 Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Mon, 4 Sep 2017 19:42:01 -0400 Subject: add tests --- nslcd_systemd/misc_test.go | 251 +++++++++++++++++++++++++++++++++++++++++++++ nslcd_systemd/util_test.go | 193 ++++++++++++++++++++++++++++++++++ 2 files changed, 444 insertions(+) create mode 100644 nslcd_systemd/misc_test.go create mode 100644 nslcd_systemd/util_test.go (limited to 'nslcd_systemd') diff --git a/nslcd_systemd/misc_test.go b/nslcd_systemd/misc_test.go new file mode 100644 index 0000000..d2b2a7e --- /dev/null +++ b/nslcd_systemd/misc_test.go @@ -0,0 +1,251 @@ +// Copyright (C) 2017 Luke Shumaker +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License as published by the Free Software Foundation; either +// version 2.1 of the License, or (at your option) any later version. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License along with this library; if not, write to the Free Software +// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +// 02110-1301 USA + +package nslcd_systemd_test + +import ( + "fmt" + "io/ioutil" + "net" + "os" + "strings" + "sync" + "syscall" + "testing" + "time" + + "git.lukeshu.com/go/libnslcd/nslcd_proto" + "git.lukeshu.com/go/libnslcd/nslcd_server" + "git.lukeshu.com/go/libnslcd/nslcd_systemd" + "golang.org/x/sys/unix" +) + +type testContext struct { + *testing.T + tmpdir string + status chan<- string +} + +func testBadClient(t *testContext, backend nslcd_systemd.Backend, toclose bool) { + t.status <- "setting up" + + errfatal := func(err error) { + if err != nil { + t.Fatal(err) + } + } + + evExitSupervisor := make(chan error) + evExitServer := make(chan uint8) + evReload := make(chan bool) + + // supervisor ////////////////////////////////////////////////////////// + notify_sock, err := sdNotifyListen(t.tmpdir + "/notify.sock") + errfatal(err) + go func() { + reloading := false + evExitSupervisor <- sdNotifyHandle(notify_sock, func(dat []byte, oob []byte) error { + for _, line := range strings.Split(string(dat), "\n") { + switch line { + case "RELOADING=1": + reloading = true + case "READY=1": + if reloading { + evReload <- true + } + reloading = false + } + } + return nil + }) + }() + + // server ////////////////////////////////////////////////////////////// + errfatal(sdActivatedStream(t.tmpdir + "/nslcd.sock")) + go func() { + evExitServer <- nslcd_systemd.Main(backend) + }() + + // client/driver /////////////////////////////////////////////////////// + + t.status <- "talking with server" + conn, err := net.Dial("unix", t.tmpdir+"/nslcd.sock") + errfatal(err) + errfatal(nslcd_proto.Write(conn, nslcd_proto.NSLCD_VERSION)) + errfatal(nslcd_proto.Write(conn, nslcd_proto.NSLCD_ACTION_PASSWD_ALL)) + // Wait for NSLCD_RESULT_*, to make sure the server has made + // it in to backend code. + var n int32 + errfatal(nslcd_proto.Read(conn, &n)) + if n != nslcd_proto.NSLCD_VERSION { + t.Fatal("server version wrong") + } + errfatal(nslcd_proto.Read(conn, &n)) + if n != nslcd_proto.NSLCD_ACTION_PASSWD_ALL { + t.Fatal("server action wrong") + } + errfatal(nslcd_proto.Read(conn, &n)) + if n != nslcd_proto.NSLCD_RESULT_BEGIN && n != nslcd_proto.NSLCD_RESULT_END { + t.Fatal("server result malformed") + } + if toclose { + errfatal(conn.Close()) + } + + t.status <- "waiting for server reload" + errfatal(unix.Kill(unix.Getpid(), unix.SIGHUP)) + <-evReload + + // A limitation of Unix sockets is that some may get dropped + // if they arrive close together. So give it some (a half + // second is probably generous by a couple orders of + // magnitude) time to handle SIGHUP before sending SIGTERM, so + // that we are sure it gets both. + time.Sleep(time.Second / 2) + + t.status <- "waiting for server exit" + errfatal(unix.Kill(unix.Getpid(), unix.SIGTERM)) + status := <-evExitServer + if status != 0 { + t.Fatalf("Main() exited with %d", status) + } + + t.status <- "waiting for supervisor exit" + errfatal(notify_sock.SetReadDeadline(time.Now())) + err = <-evExitSupervisor + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + err = nil + } + errfatal(err) + errfatal(notify_sock.Close()) +} + +type NonLockingBackend struct { + nslcd_server.NilBackend +} + +func (o *NonLockingBackend) Init() error { return nil } +func (o *NonLockingBackend) Reload() error { return nil } +func (o *NonLockingBackend) Close() {} + +func (o *NonLockingBackend) Passwd_All(cred unix.Ucred, req nslcd_proto.Request_Passwd_All) <-chan nslcd_proto.Passwd { + ret := make(chan nslcd_proto.Passwd) + go func() { + defer close(ret) + + for i := 0; i < 500; i++ { + ret <- nslcd_proto.Passwd{ + Name: fmt.Sprintf("user%d", i), + PwHash: "x", + UID: int32(1000 + i), + GID: 1000, + GECOS: fmt.Sprintf("User %d", i), + HomeDir: fmt.Sprintf("/home/user%d", i), + Shell: "/bin/sh", + } + } + }() + return ret +} + +type LockingBackend struct { + NonLockingBackend + lock sync.RWMutex +} + +func (o *LockingBackend) Reload() error { + o.lock.Lock() + defer o.lock.Unlock() + return o.NonLockingBackend.Reload() +} + +func (o *LockingBackend) Close() { + o.lock.Lock() + defer o.lock.Unlock() + o.NonLockingBackend.Close() +} + +func (o *LockingBackend) Passwd_All(cred unix.Ucred, req nslcd_proto.Request_Passwd_All) <-chan nslcd_proto.Passwd { + o.lock.RLock() + ret := make(chan nslcd_proto.Passwd) + go func() { + defer o.lock.RUnlock() + defer close(ret) + + for i := 0; i < 500; i++ { + ret <- nslcd_proto.Passwd{ + Name: fmt.Sprintf("user%d", i), + PwHash: "x", + UID: int32(1000 + i), + GID: 1000, + GECOS: fmt.Sprintf("User %d", i), + HomeDir: fmt.Sprintf("/home/user%d", i), + Shell: "/bin/sh", + } + } + }() + return ret +} + +func init() { + if fdIsDevNull(3) == nil { + return + } + + devnull, err := os.OpenFile("/dev/null", os.O_RDWR, 0666) + if err != nil { + panic(err) + } + if devnull.Fd() == 3 { + return + } + + fmt.Fprintln(os.Stderr, "Could not open /dev/null on FD 3; calling dup2 and re-exec()ing") + // shell out to do the FD manipulation--If we made it here, + // there's a good chance that FD3 was managed by the go + // runtime, and would be changed before we execve(2). + panic(syscall.Exec("/bin/sh", append([]string{"sh", "-c", "exec -- \"$@\" 3<>/dev/null"}, os.Args...), os.Environ())) +} + +func TestBadClient(t *testing.T) { + testcases := []struct { + name string + backend nslcd_systemd.Backend + toclose bool + }{ + {"NoLocks-ClientOpen", &NonLockingBackend{}, false}, + {"NoLocks-ClientClose", &NonLockingBackend{}, true}, + {"Locking-ClientOpen", &LockingBackend{}, false}, + {"Locking-ClientClose", &LockingBackend{}, true}, + } + for _, testcase := range testcases { + func() { + tmpdir, err := ioutil.TempDir("", "go-test-libnslcd-bad-client.") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpdir) + defer sdActivatedReset() + t.Run(testcase.name, func(t *testing.T) { + testWithTimeout(t, 2*time.Second, func(t *testing.T, s chan<- string) { + ctx := &testContext{T: t, tmpdir: tmpdir, status: s} + testBadClient(ctx, testcase.backend, testcase.toclose) + }) + }) + }() + } +} diff --git a/nslcd_systemd/util_test.go b/nslcd_systemd/util_test.go new file mode 100644 index 0000000..15147b7 --- /dev/null +++ b/nslcd_systemd/util_test.go @@ -0,0 +1,193 @@ +// Copyright (C) 2017 Luke Shumaker +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License as published by the Free Software Foundation; either +// version 2.1 of the License, or (at your option) any later version. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License along with this library; if not, write to the Free Software +// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +// 02110-1301 USA + +package nslcd_systemd_test + +import ( + "fmt" + "net" + "os" + "strings" + "testing" + "time" + + "github.com/pkg/errors" + "golang.org/x/sys/unix" +) + +func testWithTimeout(t *testing.T, timeout time.Duration, fn func(t *testing.T, s chan<- string)) { + finished := make(chan bool) + status := make(chan string) + cur_status := "" + go func() { + t.Run("timed", func(t *testing.T) { fn(t, status) }) + finished <- true + }() + for { + select { + case cur_status = <-status: + case <-finished: + close(status) + return + case <-time.After(timeout): + close(status) + if cur_status != "" { + t.Fatal("timed out: " + cur_status) + } else { + t.Fatal("timed out") + } + return + } + } +} + +var sdListenFds = uintptr(0) +var sdListenFdNames = []string{} + +type filer interface { + Close() error + File() (*os.File, error) +} + +func sdOpenStream(streampath string) (filer, error) { + // I should have this change based on type of stream + listener, err := net.ListenUnix("unix", &net.UnixAddr{Net: "unix", Name: streampath}) + if err != nil { + return nil, errors.Wrap(err, "net.ListenUnix()") + } + listener.SetUnlinkOnClose(false) + return listener, nil +} + +func fdIsDevNull(fd uintptr) error { + file := os.NewFile(fd, fmt.Sprintf("/dev/fd/%d", fd)) + if file == nil { + return errors.Errorf("not a valid file descriptor: %d", fd) + } + + statFd, err := file.Stat() + if err != nil { + return err + } + + statNull, err := os.Stat("/dev/null") + if err != nil { + return err + } + + if !os.SameFile(statFd, statNull) { + return errors.Errorf("FD %d is not /dev/null", fd) + } + return nil +} + +func sdActivatedStream(streampath string) error { + // Set up the file descriptor + err := func(streampath string) error { + file, err := func(streampath string) (*os.File, error) { + listener, err := sdOpenStream(streampath) + if err != nil { + return nil, err + } + defer listener.Close() + file, err := listener.File() + if err != nil { + return nil, errors.Wrap(err, "listener.File()") + } + return file, nil + }(streampath) + if err != nil { + return err + } + defer file.Close() + + fd := sdListenFds + 3 + err = fdIsDevNull(fd) + if err != nil { + return err + } + err = unix.Dup2(int(file.Fd()), int(fd)) + if err != nil { + return errors.Wrap(err, "Dup2()") + } + return nil + }(streampath) + if err != nil { + return err + } + + sdListenFds++ + sdListenFdNames = append(sdListenFdNames, streampath) + + err = os.Setenv("LISTEN_PID", fmt.Sprintf("%d", os.Getpid())) + if err != nil { + return errors.Wrap(err, "os.Setenv()") + } + err = os.Setenv("LISTEN_FDS", fmt.Sprintf("%d", sdListenFds)) + if err != nil { + return errors.Wrap(err, "os.Setenv()") + } + err = os.Setenv("LISTEN_FDNAMES", strings.Join(sdListenFdNames, ":")) + if err != nil { + return errors.Wrap(err, "os.Setenv()") + } + return nil +} + +func sdActivatedReset() error { + devnull, err := os.OpenFile("/dev/null", os.O_RDWR, 0666) + if err != nil { + return err + } + defer devnull.Close() + for i := uintptr(0); i < sdListenFds; i++ { + err = unix.Dup2(int(devnull.Fd()), int(3+i)) + if err != nil { + return err + } + } + sdListenFds = 0 + sdListenFdNames = []string{} + return nil +} + +func sdNotifyListen(sockname string) (*net.UnixConn, error) { + err := os.Setenv("NOTIFY_SOCKET", sockname) + if err != nil { + return nil, err + } + return net.ListenUnixgram("unixgram", &net.UnixAddr{Net: "unixgram", Name: sockname}) +} + +func sdNotifyHandle(sock *net.UnixConn, fn func(dat []byte, oob []byte) error) error { + var dat [4096]byte + oob := make([]byte, unix.CmsgSpace(unix.SizeofUcred)+unix.CmsgSpace(8*768)) + for { + n, oobn, flags, _, err := sock.ReadMsgUnix(dat[:], oob[:]) + if err != nil { + return err + } + if flags&unix.MSG_TRUNC != 0 { + // Received notify message exceeded maximum size. Ignoring." + continue + } + err = fn(dat[:n], oob[:oobn]) + if err != nil { + return err + } + } +} -- cgit v1.2.3-2-g168b