// Copyright (C) 2015 Luke Shumaker <lukeshu@sbcglobal.net>
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
// 02110-1301 USA

package nslcd_proto

import (
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"reflect"
	"syscall"
)

// NslcdError is an error whose entire content is its message text.
// Read panics with a value of this type when the stream carries an
// unknown address family or an address of the wrong length.
type NslcdError string

// Error implements the error interface by returning the message
// text unchanged.
func (e NslcdError) Error() string {
	return string(e)
}

// nslcdObject is implemented by values that know how to serialize
// themselves.  Write checks for this interface before any of its
// built-in encodings and, when it is satisfied, hands the whole job
// to nslcdWrite.
type nslcdObject interface {
	// nslcdWrite writes the receiver to w.
	nslcdWrite(w io.Writer)
}

// nslcdObjectPtr is implemented by values that know how to
// deserialize themselves.  Read checks for this interface before any
// of its built-in decodings and, when it is satisfied, hands the
// whole job to nslcdRead.
type nslcdObjectPtr interface {
	// nslcdRead populates the receiver from r.
	nslcdRead(r io.Reader)
}

// Write serializes data onto fd.
//
// Encodings, chosen by the dynamic type of data:
//
//   - nslcdObject: delegated to its nslcdWrite method
//   - []byte:      the raw bytes (nothing at all when empty)
//   - int32:       four bytes, big-endian
//   - string:      int32 byte length, then the bytes
//   - []string:    int32 element count, then each string
//   - net.IP:      int32 address family, int32 length, then the
//     address bytes; a slice that is neither IPv4len nor IPv6len
//     bytes long is sent as family -1 with a zero-length address
//   - []net.IP:    int32 element count, then each address
//   - any struct:  every field in declaration order, recursively
//
// There is no error return: a failed write panics with the
// underlying error, and an unsupported type panics with a string
// message.  Callers must be prepared to recover.
func Write(fd io.Writer, data interface{}) {
	switch val := data.(type) {
	// basic data types
	case nslcdObject:
		val.nslcdWrite(fd)
	case []byte:
		if len(val) == 0 {
			return
		}
		if _, err := fd.Write(val); err != nil {
			panic(err)
		}
	case int32:
		if err := binary.Write(fd, binary.BigEndian, val); err != nil {
			panic(err)
		}
	// composite datatypes
	case string:
		Write(fd, int32(len(val)))
		Write(fd, []byte(val))
	case []string:
		Write(fd, int32(len(val)))
		for i := range val {
			Write(fd, val[i])
		}
	case net.IP:
		// Start from the "invalid address" encoding and upgrade it
		// only when the slice has a recognised length.
		family := int32(-1)
		payload := []byte{}
		switch len(val) {
		case net.IPv4len:
			family, payload = syscall.AF_INET, val
		case net.IPv6len:
			family, payload = syscall.AF_INET6, val
		}
		Write(fd, family)
		Write(fd, int32(len(payload)))
		Write(fd, payload)
	case []net.IP:
		Write(fd, int32(len(val)))
		for i := range val {
			Write(fd, val[i])
		}
	default:
		// Anything else must be a struct, which is written field by
		// field.
		rv := reflect.ValueOf(data)
		if rv.Kind() != reflect.Struct {
			panic(fmt.Sprintf("Invalid structure to write NSLCD protocol data from: %T ( %#v )", data, data))
		}
		for i := 0; i < rv.NumField(); i++ {
			Write(fd, rv.Field(i).Interface())
		}
	}
}

// Read deserializes one value from fd into data, which must be a
// pointer.
//
// Decodings, chosen by the dynamic type of data:
//
//   - nslcdObjectPtr: delegated to its nslcdRead method
//   - *[]byte:        fills the slice at its current length (nothing
//     is read when it is empty)
//   - *int32:         four bytes, big-endian
//   - *string:        int32 byte length, then the bytes
//   - *[]string:      int32 element count, then each string
//   - *net.IP:        int32 address family, int32 length, then the
//     address bytes
//   - *[]net.IP:      int32 element count, then each address
//   - pointer to any struct: every field in declaration order,
//     recursively
//
// There is no error return; callers must be prepared to recover.
// Read panics with:
//
//   - the underlying error when the stream fails or ends early
//     (io.EOF if it ends before any byte of a field is read,
//     io.ErrUnexpectedEOF if it ends part-way through one);
//   - an NslcdError when the stream violates the protocol: a negative
//     length or count, an unknown address family, or an address
//     length that does not match its family;
//   - a string message when data is not a supported pointer type.
func Read(fd io.Reader, data interface{}) {
	switch data := data.(type) {
	// basic data types
	case nslcdObjectPtr:
		data.nslcdRead(fd)
	case *[]byte:
		if len(*data) > 0 {
			// A single fd.Read call may legally return fewer bytes
			// than requested without an error (routine on sockets),
			// so keep reading until the buffer is full.
			if _, err := io.ReadFull(fd, *data); err != nil {
				panic(err)
			}
		}
	case *int32:
		if err := binary.Read(fd, binary.BigEndian, data); err != nil {
			panic(err)
		}
	// composite datatypes
	case *string:
		var size int32
		Read(fd, &size)
		if size < 0 {
			panic(NslcdError(fmt.Sprintf("negative string length: %d", size)))
		}
		buf := make([]byte, size)
		Read(fd, &buf)
		*data = string(buf)
	case *[]string:
		var num int32
		Read(fd, &num)
		if num < 0 {
			panic(NslcdError(fmt.Sprintf("negative string count: %d", num)))
		}
		*data = make([]string, num)
		for i := range *data {
			Read(fd, &((*data)[i]))
		}
	case *net.IP:
		var af int32
		Read(fd, &af)
		// The address family dictates how long the address must be.
		var want int32
		switch af {
		case syscall.AF_INET:
			want = net.IPv4len
		case syscall.AF_INET6:
			want = net.IPv6len
		default:
			panic(NslcdError(fmt.Sprintf("incorrect address family specified: %d", af)))
		}
		var got int32
		Read(fd, &got)
		if got != want {
			panic(NslcdError(fmt.Sprintf("address length incorrect: %d", got)))
		}
		buf := make([]byte, got)
		Read(fd, &buf)
		*data = buf
	case *[]net.IP:
		var num int32
		Read(fd, &num)
		if num < 0 {
			panic(NslcdError(fmt.Sprintf("negative address count: %d", num)))
		}
		*data = make([]net.IP, num)
		for i := range *data {
			Read(fd, &((*data)[i]))
		}
	default:
		// Anything else must be a non-nil pointer to a struct, which
		// is filled in field by field.
		p := reflect.ValueOf(data)
		v := reflect.Indirect(p)
		if p.Kind() != reflect.Ptr || v.Kind() != reflect.Struct {
			panic(fmt.Sprintf("The argument to nslcd_proto/internal.Read() must be a pointer: %T ( %#v )", data, data))
		}
		for i, n := 0, v.NumField(); i < n; i++ {
			Read(fd, v.Field(i).Addr().Interface())
		}
	}
}