summaryrefslogtreecommitdiff
path: root/nslcd_server
diff options
context:
space:
mode:
Diffstat (limited to 'nslcd_server')
-rwxr-xr-xnslcd_server/func_handlerequest.go.gen64
1 files changed, 54 insertions, 10 deletions
diff --git a/nslcd_server/func_handlerequest.go.gen b/nslcd_server/func_handlerequest.go.gen
index d34db88..af36e84 100755
--- a/nslcd_server/func_handlerequest.go.gen
+++ b/nslcd_server/func_handlerequest.go.gen
@@ -26,8 +26,8 @@ package nslcd_server
import (
"fmt"
- "io"
"os"
+ "time"
"golang.org/x/sys/unix"
p "git.lukeshu.com/go/libnslcd/nslcd_proto"
@@ -41,8 +41,32 @@ func maybePanic(err error) {
}
}
+type Limits struct {
+ // What is the maximum total amount of time that we spend
+ // handling a single request. This includes both the time
+ // reading the request and the time creating and writing the
+ // response.
+ Timeout time.Duration
+
+ // How long can we spend reading a request?
+ ReadTimeout time.Duration
+
+ // How long can we spend writing a response?
+ WriteTimeout time.Duration
+}
+
+type Conn interface {
+ // This is a subset of net.Conn; semantics are the same.
+
+ Read(b []byte) (n int, err error)
+ Write(b []byte) (n int, err error)
+ SetDeadline(t time.Time) error
+ SetReadDeadline(t time.Time) error
+ SetWriteDeadline(t time.Time) error
+}
+
// Handle a request to nslcd
-func HandleRequest(backend Backend, in io.Reader, out io.Writer, cred unix.Ucred) (err error) {
+func HandleRequest(backend Backend, limits Limits, conn Conn, cred unix.Ucred) (err error) {
defer func() {
if r := recover(); r != nil {
switch r := r.(type) {
@@ -54,13 +78,27 @@ func HandleRequest(backend Backend, in io.Reader, out io.Writer, cred unix.Ucred
}
}()
+ now := time.Now()
+ if limits.Timeout != 0 {
+ err = conn.SetDeadline(now.Add(limits.Timeout))
+ if err != nil {
+ return err
+ }
+ }
+ if limits.ReadTimeout != 0 {
+ err = conn.SetReadDeadline(now.Add(limits.ReadTimeout))
+ if err != nil {
+ return err
+ }
+ }
+
var version int32
- maybePanic(p.Read(in, &version))
+ maybePanic(p.Read(conn, &version))
if version != p.NSLCD_VERSION {
return p.NslcdError(fmt.Sprintf("Version mismatch: server=%#08x client=%#08x", p.NSLCD_VERSION, version))
}
var action int32
- maybePanic(p.Read(in, &action))
+ maybePanic(p.Read(conn, &action))
switch action {
$(
@@ -68,7 +106,7 @@ while read -r request; do
cat <<EOT
case p.NSLCD_ACTION_${request^^}:
var req p.Request_${request}
- maybePanic(p.Read(in, &req))
+ maybePanic(p.Read(conn, &req))
$(
case "$request" in
PAM_Authentication)
@@ -94,19 +132,25 @@ while read -r request; do
;;
esac
)
- maybePanic(p.Write(out, p.NSLCD_VERSION))
- maybePanic(p.Write(out, action))
+ if limits.WriteTimeout != 0 {
+ err = conn.SetWriteDeadline(time.Now().Add(limits.WriteTimeout))
+ if err != nil {
+ return err
+ }
+ }
+ maybePanic(p.Write(conn, p.NSLCD_VERSION))
+ maybePanic(p.Write(conn, action))
ch := backend.${request}(cred, req)
for result := range ch {
if err == nil {
- err = p.Write(out, p.NSLCD_RESULT_BEGIN)
+ err = p.Write(conn, p.NSLCD_RESULT_BEGIN)
}
if err == nil {
- err = p.Write(out, result)
+ err = p.Write(conn, result)
}
}
maybePanic(err)
- maybePanic(p.Write(out, p.NSLCD_RESULT_END))
+ maybePanic(p.Write(conn, p.NSLCD_RESULT_END))
return nil
EOT
done < "$requests"