diff options
-rw-r--r-- | nslcd_proto/io.go | 100 | ||||
-rw-r--r-- | nslcd_proto/nslcd_h.go | 24 | ||||
-rwxr-xr-x | nslcd_server/func_handlerequest.go.gen | 26 |
3 files changed, 98 insertions, 52 deletions
diff --git a/nslcd_proto/io.go b/nslcd_proto/io.go index a2adade..daced37 100644 --- a/nslcd_proto/io.go +++ b/nslcd_proto/io.go @@ -37,6 +37,10 @@ func (o NslcdError) Error() string { return string(o) } +func npanic(err NslcdError) { + panic(err) +} + // An nslcdObject is an object with a different network representation // than a naive structure. type nslcdObject interface { @@ -53,9 +57,45 @@ type nslcdObjectPtr interface { nslcdRead(fd io.Reader) } +// Write an object to a stream. Any errors returned are of type +// NslcdError. +func Write(fd io.Writer, data interface{}) (err error) { + defer func() { + if r := recover(); r != nil { + switch r := r.(type) { + case NslcdError: + err = r + default: + panic(r) + } + } + }() + write(fd, data) + + return err +} + +// Read an object from a stream. Any errors returned are of type +// NslcdError. +func Read(fd io.Reader, data interface{}) (err error) { + defer func() { + if r := recover(); r != nil { + switch r := r.(type) { + case NslcdError: + err = r + default: + panic(r) + } + } + }() + read(fd, data) + + return err +} + // Write an object to a stream. In the event of an error, this -// function may panic! Handle it! -func Write(fd io.Writer, data interface{}) { +// function may panic(NslcdError)! Handle it! +func write(fd io.Writer, data interface{}) { switch data := data.(type) { // basic data types case nslcdObject: @@ -64,22 +104,22 @@ func Write(fd io.Writer, data interface{}) { if len(data) > 0 { _, err := fd.Write(data) if err != nil { - panic(err) + npanic(NslcdError(err.Error())) } } case int32: err := binary.Write(fd, binary.BigEndian, data) if err != nil { - panic(err) + npanic(NslcdError(err.Error())) } // composite datatypes case string: - Write(fd, int32(len(data))) - Write(fd, []byte(data)) + write(fd, int32(len(data))) + write(fd, []byte(data)) case []string: - Write(fd, int32(len(data))) + write(fd, int32(len(data))) for _, item := range data { - Write(fd, item) + write(fd, item) } case net.IP: var af int32 = -1 @@ -95,20 +135,20 @@ func Write(fd io.Writer, data interface{}) { } else { bytes = data } - Write(fd, af) - Write(fd, int32(len(bytes))) - Write(fd, bytes) + write(fd, af) + write(fd, int32(len(bytes))) + write(fd, bytes) case []net.IP: - Write(fd, int32(len(data))) + write(fd, int32(len(data))) for _, item := range data { - Write(fd, item) + write(fd, item) } default: v := reflect.ValueOf(data) switch v.Kind() { case reflect.Struct: for i, n := 0, v.NumField(); i < n; i++ { - Write(fd, v.Field(i).Interface()) + write(fd, v.Field(i).Interface()) } default: panic(fmt.Sprintf("Invalid structure to write NSLCD protocol data from: %T ( %#v )", data, data)) @@ -117,8 +157,8 @@ func Write(fd io.Writer, data interface{}) { } // Read an object from a stream. In the event of an error, this -// function may panic! Handle it! -func Read(fd io.Reader, data interface{}) { +// function may panic(NslcdError)! Handle it! +func read(fd io.Reader, data interface{}) { switch data := data.(type) { // basic data types case nslcdObjectPtr: @@ -127,31 +167,31 @@ func Read(fd io.Reader, data interface{}) { if len(*data) > 0 { _, err := io.ReadFull(fd, *data) if err != nil { - panic(err) + npanic(NslcdError(err.Error())) } } case *int32: err := binary.Read(fd, binary.BigEndian, data) if err != nil { - panic(err) + npanic(NslcdError(err.Error())) } // composite datatypes case *string: var len int32 - Read(fd, &len) + read(fd, &len) buf := make([]byte, len) // BUG(lukeshu): Read: `string` length needs sanity checked - Read(fd, &buf) + read(fd, &buf) *data = string(buf) case *[]string: var num int32 - Read(fd, &num) + read(fd, &num) *data = make([]string, num) // BUG(lukeshu): Read: `[]string` length needs sanity checked for i := 0; i < int(num); i++ { - Read(fd, &((*data)[i])) + read(fd, &((*data)[i])) } case *net.IP: var af int32 - Read(fd, &af) + read(fd, &af) var _len int32 switch af { case unix.AF_INET: @@ -159,22 +199,22 @@ func Read(fd io.Reader, data interface{}) { case unix.AF_INET6: _len = net.IPv6len default: - panic(NslcdError(fmt.Sprintf("incorrect address family specified: %d", af))) + npanic(NslcdError(fmt.Sprintf("incorrect address family specified: %d", af))) } var len int32 - Read(fd, &len) + read(fd, &len) if len != _len { - panic(NslcdError(fmt.Sprintf("address length incorrect: %d", len))) + npanic(NslcdError(fmt.Sprintf("address length incorrect: %d", len))) } buf := make([]byte, len) - Read(fd, &buf) + read(fd, &buf) *data = buf case *[]net.IP: var num int32 - Read(fd, &num) + read(fd, &num) *data = make([]net.IP, num) // BUG(lukeshu): Read: `[]net.IP` length needs sanity checked for i := 0; i < int(num); i++ { - Read(fd, &((*data)[i])) + read(fd, &((*data)[i])) } default: p := reflect.ValueOf(data) @@ -183,7 +223,7 @@ func Read(fd io.Reader, data interface{}) { panic(fmt.Sprintf("The argument to nslcd_proto.Read() must be a pointer: %T ( %#v )", data, data)) } for i, n := 0, v.NumField(); i < n; i++ { - Read(fd, v.Field(i).Addr().Interface()) + read(fd, v.Field(i).Addr().Interface()) } } } diff --git a/nslcd_proto/nslcd_h.go b/nslcd_proto/nslcd_h.go index cb210cd..d8eee9f 100644 --- a/nslcd_proto/nslcd_h.go +++ b/nslcd_proto/nslcd_h.go @@ -1,5 +1,5 @@ // This file is based heavily on nslcd.h from nss-pam-ldapd -// Copyright (C) 2015 Luke Shumaker +// Copyright (C) 2015, 2017 Luke Shumaker /* nslcd.h - file describing client/server protocol @@ -171,19 +171,19 @@ func (data Netgroup_PartList) nslcdWrite(fd io.Writer) { t = NSLCD_NETGROUP_TYPE_TRIPLE } if t < 0 { - panic("unrecognized netgroup type") + panic(fmt.Sprintf("unrecognized netgroup type: %#08x", t)) } - Write(fd, t) - Write(fd, part) + write(fd, t) + write(fd, part) } - Write(fd, NSLCD_NETGROUP_TYPE_END) + write(fd, NSLCD_NETGROUP_TYPE_END) } func (data *Netgroup_PartList) nslcdRead(fd io.Reader) { *data = make([]interface{}, 0) for { var t int32 var v interface{} - Read(fd, &t) + read(fd, &t) switch t { case NSLCD_NETGROUP_TYPE_NETGROUP: v = Netgroup_Netgroup{} @@ -192,9 +192,9 @@ func (data *Netgroup_PartList) nslcdRead(fd io.Reader) { case NSLCD_NETGROUP_TYPE_END: return default: - panic(NslcdError(fmt.Sprintf("unrecognized netgroup type: %#08x", t))) + npanic(NslcdError(fmt.Sprintf("unrecognized netgroup type: %#08x", t))) } - Read(fd, &v) + read(fd, &v) *data = append(*data, v) } } @@ -384,20 +384,20 @@ type UserMod_Item struct { type UserMod_ItemList []UserMod_Item func (data UserMod_ItemList) nslcdWrite(fd io.Writer) { for _, item := range data { - Write(fd, item) + write(fd, item) } - Write(fd, NSLCD_USERMOD_END) + write(fd, NSLCD_USERMOD_END) } func (data *UserMod_ItemList) nslcdRead(fd io.Reader) { *data = make([]UserMod_Item, 0) for { var t int32 - Read(fd, &t) + read(fd, &t) if t == NSLCD_USERMOD_END { return } var v UserMod_Item - Read(fd, &v) + read(fd, &v) *data = append(*data, v) } } diff --git a/nslcd_server/func_handlerequest.go.gen b/nslcd_server/func_handlerequest.go.gen index e7e2dcc..40e00c0 100755 --- a/nslcd_server/func_handlerequest.go.gen +++ b/nslcd_server/func_handlerequest.go.gen @@ -1,6 +1,6 @@ #!/usr/bin/env bash # -*- Mode: Go -*- -# Copyright (C) 2015-2016 Luke Shumaker <lukeshu@sbcglobal.net> +# Copyright (C) 2015-2017 Luke Shumaker <lukeshu@sbcglobal.net> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -35,12 +35,18 @@ import ( const sensitive = "<omitted-from-log>" +func maybePanic(err error) { + if err != nil { + panic(err) + } +} + // Handle a request to nslcd func HandleRequest(backend Backend, in io.Reader, out io.Writer, cred unix.Ucred) (err error) { defer func() { if r := recover(); r != nil { switch r := r.(type) { - case error: + case p.NslcdError: err = r default: panic(r) @@ -53,12 +59,12 @@ func HandleRequest(backend Backend, in io.Reader, out io.Writer, cred unix.Ucred func handleRequest(backend Backend, in io.Reader, out io.Writer, cred unix.Ucred) { var version int32 - p.Read(in, &version) + maybePanic(p.Read(in, &version)) if version != p.NSLCD_VERSION { panic(p.NslcdError(fmt.Sprintf("Version mismatch: server=%#08x client=%#08x", p.NSLCD_VERSION, version))) } var action int32 - p.Read(in, &action) + maybePanic(p.Read(in, &action)) ch := make(chan interface{}) switch action { @@ -67,7 +73,7 @@ while read -r request; do cat <<EOT case p.NSLCD_ACTION_${request^^}: var req p.Request_${request} - p.Read(in, &req) + maybePanic(p.Read(in, &req)) $( case "$request" in PAM_Authentication) @@ -107,13 +113,13 @@ done < "$requests" close(ch) panic(p.NslcdError(fmt.Sprintf("Unknown request action: %#08x", action))) } - p.Write(out, p.NSLCD_VERSION) - p.Write(out, action) + maybePanic(p.Write(out, p.NSLCD_VERSION)) + maybePanic(p.Write(out, action)) for result := range ch { - p.Write(out, p.NSLCD_RESULT_BEGIN) - p.Write(out, result) + maybePanic(p.Write(out, p.NSLCD_RESULT_BEGIN)) + maybePanic(p.Write(out, result)) } - p.Write(out, p.NSLCD_RESULT_END) + maybePanic(p.Write(out, p.NSLCD_RESULT_END)) } EOF |