// Copyright (C) 2015-2017 Luke Shumaker <lukeshu@sbcglobal.net>
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
// 02110-1301 USA

// Package nslcd_systemd does the legwork for implementing a systemd
// socket-activated nslcd server.
//
// You just need to implement the Backend interface, then pass it to
// Main, which will return the exit code for the process.  Everything
// but the backend is taken care of for you!
//
// 	package main
//
// 	import (
// 		"os"
//
// 		"git.lukeshu.com/go/libnslcd/nslcd_systemd"
// 	)
//
// 	func main() {
// 		backend := ...
// 		os.Exit(int(nslcd_systemd.Main(backend)))
// 	}
package nslcd_systemd // import "git.lukeshu.com/go/libnslcd/nslcd_systemd"

import (
	"fmt"
	"net"
	"os"
	"os/signal"
	"sync"

	"git.lukeshu.com/go/libnslcd/nslcd_server"
	"git.lukeshu.com/go/libsystemd/sd_daemon"
	"golang.org/x/sys/unix"
)

// Backend is what a program supplies to Main.  Besides answering
// lookups (the embedded nslcd_server.Backend), it provides hooks that
// Main invokes over the daemon's lifetime.
type Backend interface {
	nslcd_server.Backend

	// Init is called once by Main, before the listening socket is
	// taken from systemd.  If it returns an error, Main exits with
	// EXIT_FAILURE.
	Init() error

	// Reload is called by Main whenever the process receives SIGHUP.
	// If it returns an error, Main exits with EXIT_NOTRUNNING.
	Reload() error

	// Close is called by Main as it returns, but only if Init
	// succeeded.
	Close()
}

// get_socket returns the listening socket that systemd handed to this
// process through socket activation.
//
// Exactly one socket is expected.  If systemd passed none, or more
// than one, an error is returned instead.
//
// Every file descriptor received from systemd is closed before
// returning, on success and on failure alike.  net.FileListener works
// on its own duplicate of the descriptor, so the original is not
// needed afterwards, and any surplus descriptors would otherwise leak.
func get_socket() (socket net.Listener, err error) {
	fds := sd_daemon.ListenFds(true)
	if fds == nil {
		return nil, fmt.Errorf("Failed to acquire sockets from systemd")
	}
	defer func() {
		for _, fd := range fds {
			fd.Close()
		}
	}()
	if len(fds) != 1 {
		return nil, fmt.Errorf("Wrong number of sockets from systemd: expected %d but got %d", 1, len(fds))
	}
	return net.FileListener(fds[0])
}

// getpeercred returns the credentials (pid, uid, gid) of the process
// at the other end of conn, as reported by the SO_PEERCRED socket
// option.
//
// On error the returned cred is the zero value and must not be
// trusted as the identity of the peer.
func getpeercred(conn *net.UnixConn) (cred unix.Ucred, err error) {
	// conn.File hands back a duplicate of the connection's descriptor;
	// release it once the option has been read.
	file, err := conn.File()
	if err != nil {
		return unix.Ucred{}, err
	}
	defer file.Close()

	ucred, err := unix.GetsockoptUcred(int(file.Fd()), unix.SOL_SOCKET, unix.SO_PEERCRED)
	if err != nil {
		// Don't touch ucred: a (T, error) result carries no meaningful
		// value when the error is non-nil.
		return unix.Ucred{}, err
	}
	return *ucred, nil
}

// handler serves one nslcd request arriving on conn, using backend to
// answer it, and closes conn when finished.
//
// The peer's credentials are obtained with getpeercred and passed to
// nslcd_server.HandleRequest.  If they cannot be determined, a
// placeholder credential (pid -1, uid and gid (uid_t)-1) is used
// instead of the zero value.  A zero unix.Ucred would claim uid 0 and
// gid 0, i.e. root, for a client whose identity is actually unknown.
// (uid_t)-1 is never assigned to a real account.
//
// Errors from HandleRequest are logged rather than returned.
func handler(conn *net.UnixConn, backend nslcd_server.Backend) {
	defer conn.Close()

	cred, err := getpeercred(conn)
	if err != nil {
		sd_daemon.Log.Debug("Connection from unknown client")
		cred = unix.Ucred{Pid: -1, Uid: ^uint32(0), Gid: ^uint32(0)}
	} else {
		sd_daemon.Log.Debug(fmt.Sprintf("Connection from pid=%v uid=%v gid=%v",
			cred.Pid, cred.Uid, cred.Gid))
	}

	if err := nslcd_server.HandleRequest(backend, conn, conn, cred); err != nil {
		sd_daemon.Log.Notice(fmt.Sprintf("Error while handling request: %v", err))
	}
}

// Main runs a socket-activated nslcd server on top of backend and
// returns the exit code the process should exit with.
//
// Startup:
//   - Initializes the backend.
//   - Takes the single AF_UNIX listening socket that systemd passed in.
//   - Notifies systemd that the service is ready.
//
// Each incoming connection is then served in its own goroutine.
// SIGHUP reloads the backend; SIGTERM shuts the server down.
//
// On shutdown, Main:
//  1. tells systemd it is stopping;
//  2. waits for in-flight requests to finish;
//  3. stops the accept loop and closes the listener;
//  4. closes the backend.
//
// Exit codes:
//   - EXIT_FAILURE if backend.Init fails;
//   - EXIT_NOTRUNNING if no usable unix socket was received from
//     systemd, or if backend.Reload fails;
//   - EXIT_SUCCESS after SIGTERM.
func Main(backend Backend) uint8 {
	// signal.Notify never blocks when delivering a signal, so the
	// channel must be buffered.  Otherwise a signal arriving while the
	// loop below is busy would be dropped.
	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, unix.SIGTERM, unix.SIGHUP)

	disable_nss_module()

	if err := backend.Init(); err != nil {
		sd_daemon.Log.Err(fmt.Sprintf("Could not initialize backend: %v", err))
		sd_daemon.Notification{State: "STOPPING=1"}.Send(false)
		return sd_daemon.EXIT_FAILURE
	}
	defer backend.Close()

	listener, err := get_socket()
	if err != nil {
		sd_daemon.Log.Err(fmt.Sprintf("%v", err))
		sd_daemon.Notification{State: "STOPPING=1"}.Send(false)
		return sd_daemon.EXIT_NOTRUNNING
	}
	defer listener.Close()

	// handler needs *net.UnixConn (peer credentials come from
	// SO_PEERCRED).  Refuse to start on any other kind of socket,
	// rather than failing on every connection.
	socket, ok := listener.(*net.UnixListener)
	if !ok {
		sd_daemon.Log.Err(fmt.Sprintf("Expected a unix socket from systemd, but got a %q socket",
			listener.Addr().Network()))
		sd_daemon.Notification{State: "STOPPING=1"}.Send(false)
		return sd_daemon.EXIT_NOTRUNNING
	}

	// done is closed as Main returns.  Deferred calls run in reverse
	// order, so this happens after in-flight requests finish and before
	// the listener is closed.  It tells the accept loop to exit instead
	// of spinning on "closed listener" errors or blocking forever on a
	// send nobody receives.
	done := make(chan struct{})
	defer close(done)

	sock := make(chan *net.UnixConn)
	go func() {
		defer sd_daemon.Recover()
		for {
			conn, err := socket.AcceptUnix()
			if err != nil {
				select {
				case <-done:
					// Main has returned; the error is most likely the
					// listener being closed.
					return
				default:
				}
				sd_daemon.Log.Notice(fmt.Sprintf("%v", err))
				continue
			}
			select {
			case sock <- conn:
			case <-done:
				// Nobody will serve this connection; don't leak it.
				conn.Close()
				return
			}
		}
	}()

	var wg sync.WaitGroup
	defer wg.Wait()
	defer sd_daemon.Notification{State: "STOPPING=1"}.Send(false)
	sd_daemon.Notification{State: "READY=1"}.Send(false)
	for {
		select {
		case sig := <-sigs:
			switch sig {
			case unix.SIGTERM:
				sd_daemon.Log.Notice("Received SIGTERM, shutting down")
				return sd_daemon.EXIT_SUCCESS
			case unix.SIGHUP:
				sd_daemon.Notification{State: "RELOADING=1"}.Send(false)
				if err := backend.Reload(); err != nil {
					sd_daemon.Log.Notice(fmt.Sprintf("Could not reload backend: %s", err.Error()))
					return sd_daemon.EXIT_NOTRUNNING
				}
				sd_daemon.Notification{State: "READY=1"}.Send(false)
			}
		case conn := <-sock:
			wg.Add(1)
			go func() {
				defer sd_daemon.Recover()
				defer wg.Done()
				handler(conn, backend)
			}()
		}
	}
}