// Copyright (C) 2017 Luke Shumaker <lukeshu@sbcglobal.net>
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
// 02110-1301 USA

package nslcd_systemd_test

import (
	"fmt"
	"net"
	"os"
	"strings"
	"testing"
	"time"

	"github.com/pkg/errors"
	"golang.org/x/sys/unix"
)

// testWithTimeout runs fn as a subtest of t named "timed" and fails t if fn
// has not returned within timeout of this call.
//
// fn may send short progress descriptions on s; the most recent one is
// included in the failure message when the deadline passes, to show where fn
// got stuck.
//
// On timeout, t is failed via t.Fatal while fn may still be running in the
// background; from then on its status sends are discarded rather than
// blocking forever or panicking.
func testWithTimeout(t *testing.T, timeout time.Duration, fn func(t *testing.T, s chan<- string)) {
	// finished is closed by the worker goroutine (its only sender) once the
	// subtest returns. Closing instead of sending means the worker never
	// blocks, even if we have already stopped waiting for it.
	finished := make(chan struct{})
	status := make(chan string)
	go func() {
		defer close(finished)
		t.Run("timed", func(t *testing.T) { fn(t, status) })
	}()

	// One timer for the whole run: status updates must not push the deadline
	// back, or a chatty fn could never time out.
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	curStatus := ""
	for {
		select {
		case curStatus = <-status:
		case <-finished:
			return
		case <-deadline.C:
			// status is deliberately never closed: fn may still send on it,
			// and a send on a closed channel panics. Drain it until fn
			// finally returns so those sends do not block forever.
			go func() {
				for {
					select {
					case <-status:
					case <-finished:
						return
					}
				}
			}()
			if curStatus != "" {
				t.Fatal("timed out: " + curStatus)
			}
			t.Fatal("timed out")
		}
	}
}

// Bookkeeping for emulated socket activation. sdListenFds is how many
// descriptors have been installed starting at fd 3; sdListenFdNames holds the
// matching entries later joined into LISTEN_FDNAMES.
var (
	sdListenFds     uintptr
	sdListenFdNames = []string{}
)

// filer is the part of a listener's method set that this file relies on:
// obtaining an *os.File for its descriptor, and closing the listener itself.
type filer interface {
	File() (*os.File, error)
	Close() error
}

// sdOpenStream starts listening on a Unix stream socket at streampath and
// returns the listener. Closing the returned listener leaves the socket file
// on disk.
func sdOpenStream(streampath string) (filer, error) {
	// TODO: choose the listener kind from the type of stream requested
	// instead of always assuming a Unix stream socket.
	addr := &net.UnixAddr{Name: streampath, Net: "unix"}
	ln, err := net.ListenUnix("unix", addr)
	if err != nil {
		return nil, errors.Wrap(err, "net.ListenUnix()")
	}
	// Callers in this file close the listener as soon as they have an
	// *os.File for it, so the path must survive that Close.
	ln.SetUnlinkOnClose(false)
	return ln, nil
}

func fdIsDevNull(fd uintptr) error {
	file := os.NewFile(fd, fmt.Sprintf("/dev/fd/%d", fd))
	if file == nil {
		return errors.Errorf("not a valid file descriptor: %d", fd)
	}

	statFd, err := file.Stat()
	if err != nil {
		return err
	}

	statNull, err := os.Stat("/dev/null")
	if err != nil {
		return err
	}

	if !os.SameFile(statFd, statNull) {
		return errors.Errorf("FD %d is not /dev/null", fd)
	}
	return nil
}

func sdActivatedStream(streampath string) error {
	// Set up the file descriptor
	err := func(streampath string) error {
		file, err := func(streampath string) (*os.File, error) {
			listener, err := sdOpenStream(streampath)
			if err != nil {
				return nil, err
			}
			defer listener.Close()
			file, err := listener.File()
			if err != nil {
				return nil, errors.Wrap(err, "listener.File()")
			}
			return file, nil
		}(streampath)
		if err != nil {
			return err
		}
		defer file.Close()

		fd := sdListenFds + 3
		err = fdIsDevNull(fd)
		if err != nil {
			return err
		}
		err = unix.Dup2(int(file.Fd()), int(fd))
		if err != nil {
			return errors.Wrap(err, "Dup2()")
		}
		return nil
	}(streampath)
	if err != nil {
		return err
	}

	sdListenFds++
	sdListenFdNames = append(sdListenFdNames, streampath)

	err = os.Setenv("LISTEN_PID", fmt.Sprintf("%d", os.Getpid()))
	if err != nil {
		return errors.Wrap(err, "os.Setenv()")
	}
	err = os.Setenv("LISTEN_FDS", fmt.Sprintf("%d", sdListenFds))
	if err != nil {
		return errors.Wrap(err, "os.Setenv()")
	}
	err = os.Setenv("LISTEN_FDNAMES", strings.Join(sdListenFdNames, ":"))
	if err != nil {
		return errors.Wrap(err, "os.Setenv()")
	}
	return nil
}

// sdActivatedReset points every descriptor in the activation range
// (3 .. 3+sdListenFds-1) back at /dev/null. dup2 closes whatever was
// installed there first. It then clears the bookkeeping used by
// sdActivatedStream. The LISTEN_* environment variables are not touched.
func sdActivatedReset() error {
	nullFile, err := os.OpenFile("/dev/null", os.O_RDWR, 0666)
	if err != nil {
		return err
	}
	defer nullFile.Close()

	for target := uintptr(3); target < 3+sdListenFds; target++ {
		if err := unix.Dup2(int(nullFile.Fd()), int(target)); err != nil {
			return err
		}
	}

	sdListenFds, sdListenFdNames = 0, []string{}
	return nil
}

// sdNotifyListen sets NOTIFY_SOCKET to sockname and binds a Unix datagram
// socket at that address, returning the connection to read notifications
// from. The environment variable is set first and stays set even if binding
// fails.
func sdNotifyListen(sockname string) (*net.UnixConn, error) {
	if err := os.Setenv("NOTIFY_SOCKET", sockname); err != nil {
		return nil, err
	}
	addr := net.UnixAddr{Name: sockname, Net: "unixgram"}
	return net.ListenUnixgram("unixgram", &addr)
}

// sdNotifyHandle reads datagrams from sock in a loop, passing each one's
// payload and raw ancillary (control) data to fn.
//
// Datagrams the kernel flags as truncated (MSG_TRUNC) are dropped, because
// the notify message exceeded the maximum size. It returns the first error
// from reading sock or from fn.
func sdNotifyHandle(sock *net.UnixConn, fn func(dat []byte, oob []byte) error) error {
	var payload [4096]byte
	// Room for one ucred credentials control message plus one more control
	// message carrying up to 8*768 bytes of data.
	control := make([]byte, unix.CmsgSpace(unix.SizeofUcred)+unix.CmsgSpace(8*768))
	for {
		n, controlLen, flags, _, err := sock.ReadMsgUnix(payload[:], control)
		if err != nil {
			return err
		}
		if flags&unix.MSG_TRUNC == 0 {
			if err := fn(payload[:n], control[:controlLen]); err != nil {
				return err
			}
		}
	}
}

// humanizeU64 formats n in decimal with a comma between each group of three
// digits, counting from the right; for example, 1234567 becomes "1,234,567".
func humanizeU64(n uint64) string {
	digits := fmt.Sprintf("%d", n)

	// Length of the leading group; every following group is exactly 3 digits.
	head := len(digits) % 3
	if head == 0 {
		head = 3
	}

	out := make([]byte, 0, len(digits)+(len(digits)-1)/3)
	out = append(out, digits[:head]...)
	for i := head; i < len(digits); i += 3 {
		out = append(out, ',')
		out = append(out, digits[i:i+3]...)
	}
	return string(out)
}