summaryrefslogtreecommitdiff
path: root/bin-src/tls-getcerts.go
diff options
context:
space:
mode:
Diffstat (limited to 'bin-src/tls-getcerts.go')
-rw-r--r--bin-src/tls-getcerts.go192
1 files changed, 192 insertions, 0 deletions
diff --git a/bin-src/tls-getcerts.go b/bin-src/tls-getcerts.go
new file mode 100644
index 0000000..34e25e5
--- /dev/null
+++ b/bin-src/tls-getcerts.go
@@ -0,0 +1,192 @@
+package main
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/pem"
+ "encoding/xml"
+ "fmt"
+ "io"
+ "net"
+ "net/textproto"
+ "net/url"
+ "os"
+ "strings"
+ "time"
+)
+
+type xmppStreamsFeatures struct {
+ XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
+}
+
+type xmppTlsProceed struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls proceed"`
+}
+
+func xmppStartTLS(connRaw net.Conn, host string) error {
+ decoder := xml.NewDecoder(connRaw)
+
+ // send <stream> start
+ _, err := fmt.Fprintf(connRaw, "<stream:stream xmlns='jabber:client' xmlns:stream='http://etherx.jabber.org/streams' to='%s' version='1.0'>", host)
+ if err != nil {
+ return err
+ }
+ // read <stream> start
+ for {
+ t, err := decoder.Token()
+ if err != nil || t == nil {
+ return err
+ }
+ if se, ok := t.(xml.StartElement); ok {
+ if se.Name.Local != "stream" {
+ return xml.UnmarshalError(fmt.Sprintf("expected element of type <%s> but have <%s>", "stream", se.Name.Local))
+ }
+ break
+ }
+ }
+ // read <features>
+ var features xmppStreamsFeatures
+ err = decoder.DecodeElement(&features, nil)
+ if err != nil {
+ return err
+ }
+ // send <starttls>
+ _, err = io.WriteString(connRaw, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
+ if err != nil {
+ return err
+ }
+ // read <proceed>
+ var proceed xmppTlsProceed
+ err = decoder.DecodeElement(&proceed, nil)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+// smtpCmd is a convenience function that sends a command, and reads
+// (but discards) the response
+func smtpCmd(tp *textproto.Conn, expectCode int, format string, args ...interface{}) error {
+ id, err := tp.Cmd(format, args...)
+ if err != nil {
+ return err
+ }
+ tp.StartResponse(id)
+ defer tp.EndResponse(id)
+ _, _, err = tp.ReadResponse(expectCode)
+ return err
+}
+
+func smtpStartTLS(connRaw net.Conn, host string) error {
+ tp := textproto.NewConn(connRaw)
+
+ // let the server introduce itself
+ _, _, err := tp.ReadResponse(220)
+ if err != nil {
+ return err
+ }
+ // introduce ourself
+ localhost, err := os.Hostname()
+ if err != nil {
+ localhost = "localhost"
+ }
+ err = smtpCmd(tp, 250, "EHLO %s", localhost)
+ if err != nil {
+ err := smtpCmd(tp, 250, "HELO %s", localhost)
+ if err != nil {
+ return err
+ }
+ }
+ // starttls
+ err = smtpCmd(tp, 220, "STARTTLS")
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func getcert(socket string) (*x509.Certificate, error) {
+ u, err := url.Parse(socket)
+ if err != nil {
+ return nil, err
+ }
+ host, _, err := net.SplitHostPort(u.Host)
+ if err != nil {
+ return nil, err
+ }
+
+ connRaw, err := net.Dial(u.Scheme, u.Host)
+ if err != nil {
+ return nil, err
+ }
+ err = connRaw.SetDeadline(time.Now().Add(5 * time.Second))
+ if err != nil {
+ return nil, err
+ }
+
+ switch u.Path {
+ case "", "/":
+ // do nothing
+ case "/xmpp":
+ err = xmppStartTLS(connRaw, host)
+ if err != nil {
+ return nil, err
+ }
+ case "/smtp":
+ err = smtpStartTLS(connRaw, host)
+ if err != nil {
+ return nil, err
+ }
+ default:
+ return nil, fmt.Errorf("Unknown negotiation path: %q", u.Path)
+ }
+
+ connTLS := tls.Client(connRaw, &tls.Config{InsecureSkipVerify: true})
+ defer connTLS.Close()
+ err = connTLS.Handshake()
+ if err != nil {
+ return nil, err
+ }
+
+ cstate := connTLS.ConnectionState()
+
+ opts := x509.VerifyOptions{
+ DNSName: host,
+ Intermediates: x509.NewCertPool(),
+ }
+ for _, cert := range cstate.PeerCertificates[1:] {
+ opts.Intermediates.AddCert(cert)
+ }
+
+ cert := cstate.PeerCertificates[0]
+ _, err = cert.Verify(opts)
+ return cert, err
+}
+
+func split(socket string) (net, addr string) {
+ ary := strings.SplitN(socket, ":", 2)
+ if len(ary) == 1 {
+ return "tcp", ary[0]
+ }
+ return ary[0], ary[1]
+}
+
+func main() {
+ for _, socket := range os.Args[1:] {
+ fmt.Fprintf(os.Stderr, "Getting %q... ", socket)
+ block := pem.Block{
+ Type: "CERTIFICATE",
+ Headers: map[string]string{"X-Socket": socket},
+ Bytes: nil,
+ }
+ cert, err := getcert(socket)
+ if cert != nil {
+ block.Bytes = cert.Raw
+ }
+ if err != nil {
+ block.Headers["X-Error"] = err.Error()
+ }
+ pem.Encode(os.Stdout, &block)
+ fmt.Fprintln(os.Stderr, "[done]")
+ }
+}