diff options
Diffstat (limited to 'nslcd_proto')
-rw-r--r-- | nslcd_proto/io.go | 100 | ||||
-rw-r--r-- | nslcd_proto/nslcd_h.go | 24 |
2 files changed, 82 insertions, 42 deletions
diff --git a/nslcd_proto/io.go b/nslcd_proto/io.go index a2adade..daced37 100644 --- a/nslcd_proto/io.go +++ b/nslcd_proto/io.go @@ -37,6 +37,10 @@ func (o NslcdError) Error() string { return string(o) } +func npanic(err NslcdError) { + panic(err) +} + // An nslcdObject is an object with a different network representation // than a naive structure. type nslcdObject interface { @@ -53,9 +57,45 @@ type nslcdObjectPtr interface { nslcdRead(fd io.Reader) } +// Write an object to a stream. Any errors returned are of type +// NslcdError. +func Write(fd io.Writer, data interface{}) (err error) { + defer func() { + if r := recover(); r != nil { + switch r := r.(type) { + case NslcdError: + err = r + default: + panic(r) + } + } + }() + write(fd, data) + + return err +} + +// Read an object from a stream. Any errors returned are of type +// NslcdError. +func Read(fd io.Reader, data interface{}) (err error) { + defer func() { + if r := recover(); r != nil { + switch r := r.(type) { + case NslcdError: + err = r + default: + panic(r) + } + } + }() + read(fd, data) + + return err +} + // Write an object to a stream. In the event of an error, this -// function may panic! Handle it! -func Write(fd io.Writer, data interface{}) { +// function may panic(NslcdError)! Handle it! +func write(fd io.Writer, data interface{}) { switch data := data.(type) { // basic data types case nslcdObject: @@ -64,22 +104,22 @@ func Write(fd io.Writer, data interface{}) { if len(data) > 0 { _, err := fd.Write(data) if err != nil { - panic(err) + npanic(NslcdError(err.Error())) } } case int32: err := binary.Write(fd, binary.BigEndian, data) if err != nil { - panic(err) + npanic(NslcdError(err.Error())) } // composite datatypes case string: - Write(fd, int32(len(data))) - Write(fd, []byte(data)) + write(fd, int32(len(data))) + write(fd, []byte(data)) case []string: - Write(fd, int32(len(data))) + write(fd, int32(len(data))) for _, item := range data { - Write(fd, item) + write(fd, item) } case net.IP: var af int32 = -1 @@ -95,20 +135,20 @@ func Write(fd io.Writer, data interface{}) { } else { bytes = data } - Write(fd, af) - Write(fd, int32(len(bytes))) - Write(fd, bytes) + write(fd, af) + write(fd, int32(len(bytes))) + write(fd, bytes) case []net.IP: - Write(fd, int32(len(data))) + write(fd, int32(len(data))) for _, item := range data { - Write(fd, item) + write(fd, item) } default: v := reflect.ValueOf(data) switch v.Kind() { case reflect.Struct: for i, n := 0, v.NumField(); i < n; i++ { - Write(fd, v.Field(i).Interface()) + write(fd, v.Field(i).Interface()) } default: panic(fmt.Sprintf("Invalid structure to write NSLCD protocol data from: %T ( %#v )", data, data)) @@ -117,8 +157,8 @@ func Write(fd io.Writer, data interface{}) { } // Read an object from a stream. In the event of an error, this -// function may panic! Handle it! -func Read(fd io.Reader, data interface{}) { +// function may panic(NslcdError)! Handle it! +func read(fd io.Reader, data interface{}) { switch data := data.(type) { // basic data types case nslcdObjectPtr: @@ -127,31 +167,31 @@ func Read(fd io.Reader, data interface{}) { if len(*data) > 0 { _, err := io.ReadFull(fd, *data) if err != nil { - panic(err) + npanic(NslcdError(err.Error())) } } case *int32: err := binary.Read(fd, binary.BigEndian, data) if err != nil { - panic(err) + npanic(NslcdError(err.Error())) } // composite datatypes case *string: var len int32 - Read(fd, &len) + read(fd, &len) buf := make([]byte, len) // BUG(lukeshu): Read: `string` length needs sanity checked - Read(fd, &buf) + read(fd, &buf) *data = string(buf) case *[]string: var num int32 - Read(fd, &num) + read(fd, &num) *data = make([]string, num) // BUG(lukeshu): Read: `[]string` length needs sanity checked for i := 0; i < int(num); i++ { - Read(fd, &((*data)[i])) + read(fd, &((*data)[i])) } case *net.IP: var af int32 - Read(fd, &af) + read(fd, &af) var _len int32 switch af { case unix.AF_INET: @@ -159,22 +199,22 @@ func Read(fd io.Reader, data interface{}) { case unix.AF_INET6: _len = net.IPv6len default: - panic(NslcdError(fmt.Sprintf("incorrect address family specified: %d", af))) + npanic(NslcdError(fmt.Sprintf("incorrect address family specified: %d", af))) } var len int32 - Read(fd, &len) + read(fd, &len) if len != _len { - panic(NslcdError(fmt.Sprintf("address length incorrect: %d", len))) + npanic(NslcdError(fmt.Sprintf("address length incorrect: %d", len))) } buf := make([]byte, len) - Read(fd, &buf) + read(fd, &buf) *data = buf case *[]net.IP: var num int32 - Read(fd, &num) + read(fd, &num) *data = make([]net.IP, num) // BUG(lukeshu): Read: `[]net.IP` length needs sanity checked for i := 0; i < int(num); i++ { - Read(fd, &((*data)[i])) + read(fd, &((*data)[i])) } default: p := reflect.ValueOf(data) @@ -183,7 +223,7 @@ func Read(fd io.Reader, data interface{}) { panic(fmt.Sprintf("The argument to nslcd_proto.Read() must be a pointer: %T ( %#v )", data, data)) } for i, n := 0, v.NumField(); i < n; i++ { - Read(fd, v.Field(i).Addr().Interface()) + read(fd, v.Field(i).Addr().Interface()) } } } diff --git a/nslcd_proto/nslcd_h.go b/nslcd_proto/nslcd_h.go index cb210cd..d8eee9f 100644 --- a/nslcd_proto/nslcd_h.go +++ b/nslcd_proto/nslcd_h.go @@ -1,5 +1,5 @@ // This file is based heavily on nslcd.h from nss-pam-ldapd -// Copyright (C) 2015 Luke Shumaker +// Copyright (C) 2015, 2017 Luke Shumaker /* nslcd.h - file describing client/server protocol @@ -171,19 +171,19 @@ func (data Netgroup_PartList) nslcdWrite(fd io.Writer) { t = NSLCD_NETGROUP_TYPE_TRIPLE } if t < 0 { - panic("unrecognized netgroup type") + panic(fmt.Sprintf("unrecognized netgroup type: %#08x", t)) } - Write(fd, t) - Write(fd, part) + write(fd, t) + write(fd, part) } - Write(fd, NSLCD_NETGROUP_TYPE_END) + write(fd, NSLCD_NETGROUP_TYPE_END) } func (data *Netgroup_PartList) nslcdRead(fd io.Reader) { *data = make([]interface{}, 0) for { var t int32 var v interface{} - Read(fd, &t) + read(fd, &t) switch t { case NSLCD_NETGROUP_TYPE_NETGROUP: v = Netgroup_Netgroup{} @@ -192,9 +192,9 @@ func (data *Netgroup_PartList) nslcdRead(fd io.Reader) { case NSLCD_NETGROUP_TYPE_END: return default: - panic(NslcdError(fmt.Sprintf("unrecognized netgroup type: %#08x", t))) + npanic(NslcdError(fmt.Sprintf("unrecognized netgroup type: %#08x", t))) } - Read(fd, &v) + read(fd, &v) *data = append(*data, v) } } @@ -384,20 +384,20 @@ type UserMod_Item struct { type UserMod_ItemList []UserMod_Item func (data UserMod_ItemList) nslcdWrite(fd io.Writer) { for _, item := range data { - Write(fd, item) + write(fd, item) } - Write(fd, NSLCD_USERMOD_END) + write(fd, NSLCD_USERMOD_END) } func (data *UserMod_ItemList) nslcdRead(fd io.Reader) { *data = make([]UserMod_Item, 0) for { var t int32 - Read(fd, &t) + read(fd, &t) if t == NSLCD_USERMOD_END { return } var v UserMod_Item - Read(fd, &v) + read(fd, &v) *data = append(*data, v) } } |