#!/usr/bin/env bash
# -*- Mode: Go -*-
# Copyright (C) 2015-2017 Luke Shumaker <lukeshu@sbcglobal.net>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301 USA

# Generate the Go source for nslcd_server.HandleRequest on stdout.
#
# Usage: <this script> REQUESTS_FILE
#
# REQUESTS_FILE lists one request name per line.  Each NAME becomes a
# switch case on p.NSLCD_ACTION_<NAME uppercased> that decodes a
# p.Request_NAME and dispatches it to backend.NAME.
#
# Fail loudly up front: otherwise a missing/unreadable file only yields
# a redirection error on stderr while a request-less (but syntactically
# valid) Go file is still emitted and the pipeline exits 0.
if [[ $# -lt 1 || ! -r $1 ]]; then
	printf 'Usage: %s REQUESTS_FILE\n' "$0" >&2
	exit 2
fi
requests=$1
printf '//'
printf ' %q' "$0" "$@"
printf '\n// MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT\n\n'
cat <<EOF | gofmt
package nslcd_server

import (
	"context"
	"fmt"
	"io"
	"time"

	p "git.lukeshu.com/go/libnslcd/nslcd_proto"
)

// sensitive is the placeholder written over secret request fields
// (passwords) in the copy of a request that gets logged.
const (
	sensitive = "<omitted-from-log>"
)

// maybePanic is a no-op for a nil err; for any other err it panics
// with err as the panic value.  HandleRequest uses it to bail out of
// request processing and sorts the panic out in its deferred recover.
func maybePanic(err error) {
	if err == nil {
		return
	}
	panic(err)
}

// Limits bounds the time and input size that HandleRequest will spend
// on a single request.  A zero value for a field disables that limit.
type Limits struct {
	// What is the maximum total amount of time that we spend
	// handling a single request.  This includes both the time
	// reading the request and the time creating and writing the
	// response.  Zero means this field imposes no overall limit;
	// a deadline already carried by the request context still
	// applies either way.
	Timeout time.Duration

	// How long can we spend reading a request?  Zero disables
	// this limit.
	ReadTimeout time.Duration

	// How long can we spend writing a response?  The clock starts
	// once the request has been read, and this deadline (when it is
	// the tightest one) also bounds the context handed to the
	// backend.  Zero disables this limit.
	WriteTimeout time.Duration

	// What is the maximum request length in bytes that we are
	// willing to handle?  Zero or negative disables this limit.
	RequestMaxSize int64
}

// Conn is the connection that HandleRequest reads a request from and
// writes the response to.  It is a subset of net.Conn, each method
// having the same semantics as the net.Conn method of the same name,
// so any net.Conn satisfies it.
type Conn interface {
	// Data transfer.
	Write(buf []byte) (int, error)
	Read(buf []byte) (int, error)

	// Deadline control.
	SetWriteDeadline(deadline time.Time) error
	SetReadDeadline(deadline time.Time) error
	SetDeadline(deadline time.Time) error
}

// Handle a request to nslcd.  The caller is responsible for
// initializing the context with PeerCredKey.
func HandleRequest(backend Backend, limits Limits, conn Conn, ctx context.Context) (err error) {
	defer func() {
		if r := recover(); r != nil {
			switch r := r.(type) {
			case p.NslcdError:
				err = r
			default:
				panic(r)
			}
		}
	}()

	now := time.Now()
	deadlineAll := time.Time{}
	deadlineRead := time.Time{}
	if limits.Timeout != 0 {
		deadlineAll = now.Add(limits.Timeout)
	}
	if deadline, ok := ctx.Deadline(); ok {
		if deadlineAll.IsZero() || deadline.Before(deadlineAll) {
			deadlineAll = deadline
		}
	}
	if limits.ReadTimeout != 0 {
		deadlineRead = now.Add(limits.ReadTimeout)
		if !deadlineAll.IsZero() && deadlineAll.Before(deadlineRead) {
			deadlineRead = deadlineAll
		}
	}
	deadlineWrite := deadlineAll
	if !deadlineRead.IsZero() {
		err = conn.SetReadDeadline(deadlineRead)
		if err != nil {
			return err
		}
	}

	log := LoggerFromContext(ctx)

	var in io.Reader = conn
	if limits.RequestMaxSize > 0 {
		in = &io.LimitedReader{R: in, N: limits.RequestMaxSize}
	}
	out := conn

	var version int32
	maybePanic(p.Read(in, &version))
	if version != p.NSLCD_VERSION {
		return p.NslcdError(fmt.Sprintf("Version mismatch: server=%#08x client=%#08x", p.NSLCD_VERSION, version))
	}
	var action int32
	maybePanic(p.Read(in, &action))

	switch action {
$(
while read -r request; do
	cat <<EOT
	case p.NSLCD_ACTION_${request^^}:
		var req p.Request_${request}
		maybePanic(p.Read(in, &req))
		$(
		case "$request" in
			PAM_Authentication)
				echo '_req := req'
				echo '_req.Password = sensitive'
				echo 'log.Info(fmt.Sprintf("Request: %#v\n", _req))'
				;;
			PAM_PwMod)
				echo '_req := req'
				echo 'if len(_req.OldPassword) > 0 {'
				echo '	_req.OldPassword = sensitive'
				echo '}'
				echo '_req.NewPassword = sensitive'
				echo 'log.Info(fmt.Sprintf("Request: %#v", _req))'
				;;
			PAM_UserMod)
				echo '_req := req'
				echo '_req.Password = sensitive'
				echo 'log.Info(fmt.Sprintf("Request: %#v", _req))'
				;;
			*)
				echo 'log.Info(fmt.Sprintf("Request: %#v", req))'
				;;
		esac
		)

		if limits.WriteTimeout != 0 {
			deadlineWrite = time.Now().Add(limits.WriteTimeout)
			if !deadlineAll.IsZero() && deadlineAll.Before(deadlineWrite) {
				deadlineWrite = deadlineAll
			}
		}
		if !deadlineWrite.IsZero() {
			err = out.SetWriteDeadline(deadlineWrite)
			if err != nil {
				return err
			}
		}

		var cancel context.CancelFunc
		if deadline, ok := ctx.Deadline(); !ok || (!deadlineWrite.IsZero() && deadline.After(deadlineWrite)) {
			ctx, cancel = context.WithDeadline(ctx, deadlineWrite)
		} else {
			ctx, cancel = context.WithCancel(ctx)
		}
		defer cancel()

		maybePanic(p.Write(out, p.NSLCD_VERSION))
		maybePanic(p.Write(out, action))
		ch := backend.${request}(ctx, req)
		n := 0
		for result := range ch {
			if err == nil {
				err = p.Write(out, p.NSLCD_RESULT_BEGIN)
			}
			if err == nil {
				err = p.Write(out, result)
			}
			n++
			log.Info(fmt.Sprintf("Wrote %d results / err = %v", n, err))
		}
		maybePanic(err)
		maybePanic(p.Write(out, p.NSLCD_RESULT_END))
		return ctx.Err() // probably nil
EOT
done < "$requests"
)
	default:
		return p.NslcdError(fmt.Sprintf("Unknown request action: %#08x", action))
	}
}
EOF