// Copyright (C) 2015, 2017 Luke Shumaker <lukeshu@sbcglobal.net>
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
// 02110-1301 USA

package nslcd_proto

import (
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"reflect"

	"golang.org/x/sys/unix"
)

// NslcdError is the error type used for normal, expected failures
// while speaking the nslcd protocol (for example, I/O errors on the
// underlying stream).
//
// Programming errors are deliberately *not* reported as an
// NslcdError: handing invalid data to a Write operation, or a
// non-pointer to a Read operation, results in a panic() instead.
type NslcdError string

// Error implements the error interface; the message is the string
// value of the NslcdError itself.
func (e NslcdError) Error() string {
	return string(e)
}

// npanic panics with err as the panic value.  Taking an NslcdError
// parameter keeps the panic value typed as NslcdError, which is what
// the recover() handlers in Write and Read match on to turn the panic
// back into an ordinary returned error.
func npanic(err NslcdError) {
	panic(err)
}

// An nslcdObject is an object with a different network representation
// than a naive structure.
//
// write() checks for this interface before any of its built-in
// encodings, so a type that implements it is in full control of the
// bytes that are emitted for it.
type nslcdObject interface {
	// nslcdWrite writes the object's network representation to fd.
	//
	// May panic(interface{}) if given invalid data.
	//
	// May panic(NslcdError) if encountering a network error.
	nslcdWrite(fd io.Writer)
}

// An nslcdObjectPtr is a pointer to an object with a different
// network representation than a naive structure.
//
// read() checks for this interface before any of its built-in
// decodings, so a type that implements it is in full control of how
// it is parsed from the stream.
type nslcdObjectPtr interface {
	// nslcdRead reads the object's network representation from fd.
	//
	// May panic(NslcdError) if encountering a network error.
	nslcdRead(fd io.Reader)
}

// Write serializes data to fd in nslcd wire format.
//
// Any error returned is of type NslcdError.  Internally, failures are
// signalled by panicking with an NslcdError; this wrapper converts
// such a panic into the returned error.  A panic carrying any other
// value (for example, one caused by passing unsupported data) is a
// programming error and is re-raised unchanged.
func Write(fd io.Writer, data interface{}) (err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		nerr, ok := r.(NslcdError)
		if !ok {
			// Not one of ours; let it keep propagating.
			panic(r)
		}
		err = nerr
	}()
	write(fd, data)

	return err
}

// Read deserializes an object in nslcd wire format from fd into data.
//
// Any error returned is of type NslcdError.  Internally, failures are
// signalled by panicking with an NslcdError; this wrapper converts
// such a panic into the returned error.  A panic carrying any other
// value (for example, one caused by passing a non-pointer) is a
// programming error and is re-raised unchanged.
//
// If fd is an *io.LimitedReader, then its N field is consulted before
// buffers are allocated, so that a length larger than the remaining
// input is rejected rather than allocated.
func Read(fd io.Reader, data interface{}) (err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		nerr, ok := r.(NslcdError)
		if !ok {
			// Not one of ours; let it keep propagating.
			panic(r)
		}
		err = nerr
	}()
	read(fd, data)

	return err
}

// write encodes a single object onto fd.  In the event of an I/O
// error, this function may panic(NslcdError)!  Handle it!
//
// Encodings, in the order they are tried:
//
//   - nslcdObject: delegated to its nslcdWrite method
//   - []byte:      the raw bytes (nothing at all if empty)
//   - int32:       big-endian
//   - string:      int32 byte length, then the bytes
//   - []string:    int32 count, then each string
//   - net.IP:      int32 address family, int32 length, then the bytes
//   - []net.IP:    int32 count, then each address
//   - any struct:  each field, in order, using the rules above
//
// Anything else is a programming error and results in a plain
// (non-NslcdError) panic.
func write(fd io.Writer, data interface{}) {
	switch data := data.(type) {
	// basic data types
	case nslcdObject:
		data.nslcdWrite(fd)
	case []byte:
		if len(data) == 0 {
			return
		}
		if _, err := fd.Write(data); err != nil {
			npanic(NslcdError(err.Error()))
		}
	case int32:
		if err := binary.Write(fd, binary.BigEndian, data); err != nil {
			npanic(NslcdError(err.Error()))
		}
	// composite datatypes
	case string:
		write(fd, int32(len(data)))
		write(fd, []byte(data))
	case []string:
		write(fd, int32(len(data)))
		for i := range data {
			write(fd, data[i])
		}
	case net.IP:
		// An address that is neither 4 nor 16 bytes long is sent
		// as family -1 with a zero-length payload.
		family, addr := int32(-1), []byte{}
		switch len(data) {
		case net.IPv4len:
			family, addr = unix.AF_INET, data
		case net.IPv6len:
			family, addr = unix.AF_INET6, data
		}
		write(fd, family)
		write(fd, int32(len(addr)))
		write(fd, addr)
	case []net.IP:
		write(fd, int32(len(data)))
		for i := range data {
			write(fd, data[i])
		}
	default:
		v := reflect.ValueOf(data)
		if v.Kind() != reflect.Struct {
			panic(fmt.Sprintf("Invalid structure to write NSLCD protocol data from: %T ( %#v )", data, data))
		}
		// Structs are written field by field.
		for i := 0; i < v.NumField(); i++ {
			write(fd, v.Field(i).Interface())
		}
	}
}

// willread asserts that we *will* go on to read n bytes from fd.
//
// If fd is an *io.LimitedReader and fewer than n bytes remain in it
// (n > N), then the read is already known to come up short, so this
// panics with an NslcdError carrying io.EOF's message; that lets the
// caller avoid allocating a buffer for data that can never arrive.
// For any other kind of reader nothing is known, and this is a no-op.
func willread(fd io.Reader, n int64) {
	limited, ok := fd.(*io.LimitedReader)
	if !ok {
		return
	}
	if limited.N < n {
		npanic(NslcdError(io.EOF.Error()))
	}
}

// read decodes a single object from fd into data, which must be a
// pointer.  In the event of an error, this function may
// panic(NslcdError)!  Handle it!
//
// Decodings, in the order they are tried:
//
//   - nslcdObjectPtr: delegated to its nslcdRead method
//   - *[]byte:        filled to its existing length (nothing if empty)
//   - *int32:         big-endian
//   - *string:        int32 byte length, then the bytes
//   - *[]string:      int32 count, then each string
//   - *net.IP:        int32 address family, int32 length, then the bytes
//   - *[]net.IP:      int32 count, then each address
//   - *struct:        each field, in order, using the rules above
//
// Anything else is a programming error and results in a plain
// (non-NslcdError) panic.
//
// Lengths and counts come from the peer and are used as allocation
// sizes, so they are validated first: a negative value is reported as
// an NslcdError instead of being allowed to crash make(), and the
// size handed to willread is computed in 64 bits so that a huge count
// cannot wrap around and slip past that check.
func read(fd io.Reader, data interface{}) {
	switch data := data.(type) {
	// basic data types
	case nslcdObjectPtr:
		data.nslcdRead(fd)
	case *[]byte:
		if len(*data) > 0 {
			if _, err := io.ReadFull(fd, *data); err != nil {
				npanic(NslcdError(err.Error()))
			}
		}
	case *int32:
		if err := binary.Read(fd, binary.BigEndian, data); err != nil {
			npanic(NslcdError(err.Error()))
		}
	// composite datatypes
	case *string:
		var size int32
		read(fd, &size)
		if size < 0 {
			npanic(NslcdError(fmt.Sprintf("negative string length: %d", size)))
		}
		willread(fd, int64(size))
		buf := make([]byte, size)
		read(fd, &buf)
		*data = string(buf)
	case *[]string:
		var num int32
		read(fd, &num)
		if num < 0 {
			npanic(NslcdError(fmt.Sprintf("negative string count: %d", num)))
		}
		// Widen before multiplying; the product may not fit in an int32.
		willread(fd, int64(num)* /* min size of a string is: */ 4)
		*data = make([]string, num)
		for i := range *data {
			read(fd, &((*data)[i]))
		}
	case *net.IP:
		var af int32
		read(fd, &af)
		var want int32
		switch af {
		case unix.AF_INET:
			want = net.IPv4len
		case unix.AF_INET6:
			want = net.IPv6len
		default:
			npanic(NslcdError(fmt.Sprintf("incorrect address family specified: %d", af)))
		}
		var size int32
		read(fd, &size)
		if size != want {
			npanic(NslcdError(fmt.Sprintf("address length incorrect: %d", size)))
		}
		// size is now known to be exactly 4 or 16.
		buf := make([]byte, size)
		read(fd, &buf)
		*data = buf
	case *[]net.IP:
		var num int32
		read(fd, &num)
		if num < 0 {
			npanic(NslcdError(fmt.Sprintf("negative address count: %d", num)))
		}
		// Widen before multiplying; the product may not fit in an int32.
		willread(fd, int64(num)* /* min size of an IP is: */ net.IPv4len)
		*data = make([]net.IP, num)
		for i := range *data {
			read(fd, &((*data)[i]))
		}
	default:
		p := reflect.ValueOf(data)
		v := reflect.Indirect(p)
		// Indirect() hands back its argument untouched unless it is
		// a pointer, so checking the kind of p is how we tell that
		// we were actually given a pointer.
		if p.Kind() != reflect.Ptr || v.Kind() != reflect.Struct {
			panic(fmt.Sprintf("The argument to nslcd_proto.Read() must be a pointer: %T ( %#v )", data, data))
		}
		// Structs are read field by field, in place.
		for i, n := 0, v.NumField(); i < n; i++ {
			read(fd, v.Field(i).Addr().Interface())
		}
	}
}