// Copyright (C) 2015-2017 Luke Shumaker <lukeshu@sbcglobal.net>
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
// 02110-1301 USA

// Package nslcd_systemd does the legwork for implementing a systemd
// socket-activated nslcd server.
//
// You just need to implement the Backend interface, then pass it to
// Main, which will return the exit code for the process.  Everything
// but the backend is taken care of for you!
//
// 	package main
//
// 	import (
// 		"context"
// 		"os"
//
// 		"git.lukeshu.com/go/libnslcd/nslcd_server"
// 		"git.lukeshu.com/go/libnslcd/nslcd_systemd"
// 	)
//
// 	func main() {
// 		backend := ...
// 		limits := nslcd_server.Limits{ ... }
// 		ctx := context.Background()
// 		os.Exit(int(nslcd_systemd.Main(backend, limits, ctx)))
// 	}
package nslcd_systemd // import "git.lukeshu.com/go/libnslcd/nslcd_systemd"

import (
	"context"
	"fmt"
	"net"
	"os"
	"os/signal"
	"sync"
	"time"

	"git.lukeshu.com/go/libnslcd/nslcd_server"
	"git.lukeshu.com/go/libsystemd/sd_daemon"
	"golang.org/x/sys/unix"
)

// Backend is the interface a program must implement in order to use Main.
//
// Besides answering requests (the embedded nslcd_server.Backend), it
// provides lifecycle hooks that Main calls:
//
//   - Init is called once at startup, before the socket is taken from
//     systemd. If it returns an error, Main exits with EXIT_FAILURE.
//   - Reload is called each time the process receives SIGHUP. If it
//     returns an error, Main shuts down.
//   - Close is called once when Main returns (only if Init succeeded).
//     It runs after all in-flight connection handlers have finished.
type Backend interface {
	nslcd_server.Backend
	Init() error
	Reload() error
	Close()
}

// get_socket obtains the listening socket that systemd handed to this
// process through socket activation.
//
// Exactly one socket is expected, and it must be a Unix-domain socket.
// The rest of this package relies on *net.UnixListener/*net.UnixConn
// and on SO_PEERCRED. The inherited *os.File values are always closed
// before returning. The returned listener holds its own duplicate of
// the descriptor (see net.FileListener), so it stays valid.
func get_socket() (net.Listener, error) {
	fds := sd_daemon.ListenFds(true)
	if fds == nil {
		return nil, fmt.Errorf("Failed to acquire sockets from systemd")
	}
	if len(fds) != 1 {
		// Don't leak the descriptors we were given but won't use.
		for _, fd := range fds {
			fd.Close()
		}
		return nil, fmt.Errorf("Wrong number of sockets from systemd: expected %d but got %d", 1, len(fds))
	}

	socket, err := net.FileListener(fds[0])
	// FileListener dup()s the descriptor, so our copy is no longer needed.
	fds[0].Close()
	if err != nil {
		return nil, err
	}

	// Fail cleanly here instead of panicking later on a type assertion.
	if _, ok := socket.(*net.UnixListener); !ok {
		socket.Close()
		return nil, fmt.Errorf("Socket from systemd is not a unix socket: got %T", socket)
	}
	return socket, nil
}

// getpeercred asks the kernel, via the SO_PEERCRED socket option, for
// the credentials (pid, uid, gid) of the process at the other end of
// conn.
//
// On failure the zero Ucred is returned together with the error.
// That error comes from obtaining the raw connection, from running the
// control function, or from the getsockopt call itself.
func getpeercred(conn *net.UnixConn) (cred unix.Ucred, err error) {
	raw, err := conn.SyscallConn()
	if err != nil {
		return cred, err
	}

	var (
		ucred   *unix.Ucred
		sockErr error
	)
	ctrlErr := raw.Control(func(fd uintptr) {
		ucred, sockErr = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
	})
	switch {
	case ctrlErr != nil:
		return cred, ctrlErr
	case sockErr != nil:
		return cred, sockErr
	}
	return *ucred, nil
}

// handler serves one client connection and closes it when done.
//
// cid is a connection number; it is used only as a "[cid] " prefix on
// log messages. The context handed to nslcd_server.HandleRequest is a
// cancellable child of ctx that carries:
//   - the prefixed logger, under nslcd_server.LoggerKey;
//   - the peer's credentials, under nslcd_server.PeerCredKey, but only
//     when they could be read.
//
// Any error from HandleRequest is logged, not returned.
func handler(backend nslcd_server.Backend, limits nslcd_server.Limits, conn *net.UnixConn, cid int, ctx context.Context) {
	defer conn.Close()

	reqCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	logger := PrefixLogger{
		Prefix: fmt.Sprintf("[%v] ", cid),
		Logger: nslcd_server.LoggerFromContext(reqCtx),
	}
	reqCtx = context.WithValue(reqCtx, nslcd_server.LoggerKey, logger)
	defer logger.Info("Connection closed")

	if cred, credErr := getpeercred(conn); credErr == nil {
		logger.Info(fmt.Sprintf("Connection from pid=%v uid=%v gid=%v", cred.Pid, cred.Uid, cred.Gid))
		reqCtx = context.WithValue(reqCtx, nslcd_server.PeerCredKey, cred)
	} else {
		logger.Info("Connection from unknown client")
	}

	if reqErr := nslcd_server.HandleRequest(backend, limits, conn, reqCtx); reqErr != nil {
		logger.Notice(fmt.Sprintf("Error while handling request: %v", reqErr))
	}
}

// Main runs a systemd socket-activated nslcd server around backend and
// returns the exit code that the process should exit with.
//
// Startup:
//   - initializes the backend;
//   - takes the listening socket from systemd;
//   - notifies systemd that it is ready;
//   - serves each incoming connection in its own goroutine.
//
// While running, SIGHUP reloads the backend and SIGTERM triggers a
// clean shutdown. The server also shuts down when ctx is canceled, or
// when accepting connections fails with a non-temporary error.
//
// Shutdown, before Main returns:
//   - tells systemd it is stopping;
//   - cancels the context passed to in-flight handlers;
//   - stops the accept loop;
//   - waits for all handlers to finish;
//   - closes the backend.
func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint8 {
	defer sd_daemon.Recover()
	var err error

	// signal.Notify never blocks when delivering. The channel must be
	// buffered, or a signal arriving while we are busy (e.g. inside
	// backend.Reload) would be dropped.
	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, unix.SIGTERM, unix.SIGHUP)
	defer signal.Stop(sigs)

	log := nslcd_server.LoggerFromContext(ctx)

	disable_nss_module(log)

	err = backend.Init()
	if err != nil {
		log.Err(fmt.Sprintf("Could not initialize backend: %v", err))
		sd_daemon.Notification{State: "STOPPING=1"}.Send(false)
		return sd_daemon.EXIT_FAILURE
	}
	defer backend.Close()

	var wg sync.WaitGroup
	defer wg.Wait()

	socket, err := get_socket()
	if err != nil {
		log.Err(fmt.Sprintf("%v", err))
		sd_daemon.Notification{State: "STOPPING=1"}.Send(false)
		return sd_daemon.EXIT_NOTRUNNING
	}
	// Peer credentials (SO_PEERCRED) need a Unix-domain socket. Check the
	// type here instead of panicking on an unchecked type assertion.
	listener, ok := socket.(*net.UnixListener)
	if !ok {
		socket.Close()
		log.Err(fmt.Sprintf("Socket from systemd is not a unix socket: got %T", socket))
		sd_daemon.Notification{State: "STOPPING=1"}.Send(false)
		return sd_daemon.EXIT_NOTRUNNING
	}
	// Setting a deadline in the past makes a pending Accept return a
	// timeout error. That is how the accept loop is told to stop.
	defer func() { listener.SetDeadline(time.Now()) }()

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Buffered so the accept loop can always hand off its single fatal
	// error and exit, even if the select loop below has already returned.
	// An unbuffered send would block forever and deadlock wg.Wait().
	socket_error := make(chan error, 1)
	wg.Add(1)
	go func() {
		defer sd_daemon.Recover()
		defer wg.Done()
		// The accept loop owns the listener; release it once we stop accepting.
		defer listener.Close()

		id := 0
		var tempDelay time.Duration
		for {
			conn, err := listener.AcceptUnix()
			if err != nil {
				ne, isNetErr := err.(net.Error)
				switch {
				case isNetErr && ne.Timeout():
					// The shutdown deadline set by Main.
					return
				case isNetErr && ne.Temporary():
					log.Notice(fmt.Sprintf("temporary error %v", err))
					if tempDelay == 0 {
						tempDelay = 5 * time.Millisecond
					} else {
						tempDelay *= 2
					}
					if maxDelay := 1 * time.Second; tempDelay > maxDelay {
						tempDelay = maxDelay
					}
					time.Sleep(tempDelay)
					continue
				default:
					socket_error <- err
					return
				}
			}
			// A successful accept ends any run of temporary errors, so
			// the next failure starts backing off from the minimum again.
			tempDelay = 0

			id++
			wg.Add(1)
			go func(conn *net.UnixConn, id int) {
				defer sd_daemon.Recover()
				defer wg.Done()
				handler(backend, limits, conn, id, ctx)
			}(conn, id)
		}
	}()

	defer sd_daemon.Notification{State: "STOPPING=1"}.Send(false)
	sd_daemon.Notification{State: "READY=1"}.Send(false)
	for {
		select {
		case sig := <-sigs:
			switch sig {
			case unix.SIGTERM:
				log.Notice("Received SIGTERM, shutting down")
				return sd_daemon.EXIT_SUCCESS
			case unix.SIGHUP:
				log.Notice("Received SIGHUP, reloading")
				sd_daemon.Notification{State: "RELOADING=1"}.Send(false)
				err := backend.Reload()
				if err != nil {
					log.Notice(fmt.Sprintf("Could not reload backend: %v", err))
					return sd_daemon.EXIT_NOTRUNNING
				}
				sd_daemon.Notification{State: "READY=1"}.Send(false)
			}
		case <-ctx.Done():
			log.Err(fmt.Sprintf("Context was canceled, shutting down: %v", ctx.Err()))
			return sd_daemon.EXIT_FAILURE
		case err = <-socket_error:
			log.Err(fmt.Sprintf("%v", err))
			return sd_daemon.EXIT_NETWORK
		}
	}
}