diff options
Diffstat (limited to 'bin-src/tls-getcerts.go')
-rw-r--r-- | bin-src/tls-getcerts.go | 192 |
1 files changed, 192 insertions, 0 deletions
diff --git a/bin-src/tls-getcerts.go b/bin-src/tls-getcerts.go new file mode 100644 index 0000000..34e25e5 --- /dev/null +++ b/bin-src/tls-getcerts.go @@ -0,0 +1,192 @@ +package main + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "encoding/xml" + "fmt" + "io" + "net" + "net/textproto" + "net/url" + "os" + "strings" + "time" +) + +type xmppStreamsFeatures struct { + XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"` +} + +type xmppTlsProceed struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls proceed"` +} + +func xmppStartTLS(connRaw net.Conn, host string) error { + decoder := xml.NewDecoder(connRaw) + + // send <stream> start + _, err := fmt.Fprintf(connRaw, "<stream:stream xmlns='jabber:client' xmlns:stream='http://etherx.jabber.org/streams' to='%s' version='1.0'>", host) + if err != nil { + return err + } + // read <stream> start + for { + t, err := decoder.Token() + if err != nil || t == nil { + return err + } + if se, ok := t.(xml.StartElement); ok { + if se.Name.Local != "stream" { + return xml.UnmarshalError(fmt.Sprintf("expected element of type <%s> but have <%s>", "stream", se.Name.Local)) + } + break + } + } + // read <features> + var features xmppStreamsFeatures + err = decoder.DecodeElement(&features, nil) + if err != nil { + return err + } + // send <starttls> + _, err = io.WriteString(connRaw, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>") + if err != nil { + return err + } + // read <proceed> + var proceed xmppTlsProceed + err = decoder.DecodeElement(&proceed, nil) + if err != nil { + return err + } + return nil +} + +// smtpCmd is a convenience function that sends a command, and reads +// (but discards) the response +func smtpCmd(tp *textproto.Conn, expectCode int, format string, args ...interface{}) error { + id, err := tp.Cmd(format, args...) + if err != nil { + return err + } + tp.StartResponse(id) + defer tp.EndResponse(id) + _, _, err = tp.ReadResponse(expectCode) + return err +} + +func smtpStartTLS(connRaw net.Conn, host string) error { + tp := textproto.NewConn(connRaw) + + // let the server introduce itself + _, _, err := tp.ReadResponse(220) + if err != nil { + return err + } + // introduce ourself + localhost, err := os.Hostname() + if err != nil { + localhost = "localhost" + } + err = smtpCmd(tp, 250, "EHLO %s", localhost) + if err != nil { + err := smtpCmd(tp, 250, "HELO %s", localhost) + if err != nil { + return err + } + } + // starttls + err = smtpCmd(tp, 220, "STARTTLS") + if err != nil { + return err + } + return nil +} + +func getcert(socket string) (*x509.Certificate, error) { + u, err := url.Parse(socket) + if err != nil { + return nil, err + } + host, _, err := net.SplitHostPort(u.Host) + if err != nil { + return nil, err + } + + connRaw, err := net.Dial(u.Scheme, u.Host) + if err != nil { + return nil, err + } + err = connRaw.SetDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + return nil, err + } + + switch u.Path { + case "", "/": + // do nothing + case "/xmpp": + err = xmppStartTLS(connRaw, host) + if err != nil { + return nil, err + } + case "/smtp": + err = smtpStartTLS(connRaw, host) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("Unknown negotiation path: %q", u.Path) + } + + connTLS := tls.Client(connRaw, &tls.Config{InsecureSkipVerify: true}) + defer connTLS.Close() + err = connTLS.Handshake() + if err != nil { + return nil, err + } + + cstate := connTLS.ConnectionState() + + opts := x509.VerifyOptions{ + DNSName: host, + Intermediates: x509.NewCertPool(), + } + for _, cert := range cstate.PeerCertificates[1:] { + opts.Intermediates.AddCert(cert) + } + + cert := cstate.PeerCertificates[0] + _, err = cert.Verify(opts) + return cert, err +} + +func split(socket string) (net, addr string) { + ary := strings.SplitN(socket, ":", 2) + if len(ary) == 1 { + return "tcp", ary[0] + } + return ary[0], ary[1] +} + +func main() { + for _, socket := range os.Args[1:] { + fmt.Fprintf(os.Stderr, "Getting %q... ", socket) + block := pem.Block{ + Type: "CERTIFICATE", + Headers: map[string]string{"X-Socket": socket}, + Bytes: nil, + } + cert, err := getcert(socket) + if cert != nil { + block.Bytes = cert.Raw + } + if err != nil { + block.Headers["X-Error"] = err.Error() + } + pem.Encode(os.Stdout, &block) + fmt.Fprintln(os.Stderr, "[done]") + } +} |