summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--backend.go26
1 files changed, 20 insertions, 6 deletions
diff --git a/backend.go b/backend.go
index ff40f5b..fc67d82 100644
--- a/backend.go
+++ b/backend.go
@@ -10,23 +10,33 @@ import (
// A Backend is something that consumes a fast-import stream; the
// Backend object provides methods for writing to it.
type Backend struct {
- w *bufio.Writer
+ w io.WriteCloser
+ bw *bufio.Writer
fiw *textproto.FIWriter
cbr *textproto.CatBlobReader
- err error
+ err error
onErr func(error) error
}
-func NewBackend(fastImport io.Writer, catBlob io.Reader, onErr func(error) error) *Backend {
+func NewBackend(fastImport io.WriteCloser, catBlob io.Reader, onErr func(error) error) *Backend {
ret := &Backend{}
- ret.w = bufio.NewWriter(fastImport)
- ret.fiw = textproto.NewFIWriter(ret.w)
+ ret.w = fastImport
+ ret.bw = bufio.NewWriter(ret.w)
+ ret.fiw = textproto.NewFIWriter(ret.bw)
if catBlob != nil {
ret.cbr = textproto.NewCatBlobReader(catBlob)
}
ret.onErr = func(err error) error {
ret.err = err
+
+ // Close the underlying writer, but don't let the
+ // error mask the previous error.
+ err = ret.w.Close()
+ if ret.err == nil {
+ ret.err = err
+ }
+
if onErr != nil {
ret.err = onErr(ret.err)
}
@@ -44,11 +54,15 @@ func (b *Backend) Do(cmd Cmd) error {
if err != nil {
return b.onErr(err)
}
- err = b.w.Flush()
+ err = b.bw.Flush()
if err != nil {
return b.onErr(err)
}
+ if _, ok := cmd.(CmdDone); ok {
+ return b.onErr(nil)
+ }
+
return nil
}