From 0a5596290c8b5adc42c46c0a375d4b970230b550 Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Mon, 21 Nov 2016 14:26:43 -0500 Subject: tls-getcerts: add xmpp support --- tls-getcerts.go | 88 +++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 86 insertions(+), 2 deletions(-) diff --git a/tls-getcerts.go b/tls-getcerts.go index 49e15a2..7032199 100644 --- a/tls-getcerts.go +++ b/tls-getcerts.go @@ -4,17 +4,93 @@ import ( "crypto/tls" "crypto/x509" "encoding/pem" + "encoding/xml" "fmt" + "io" "net" "os" + "strings" ) +type xmppStreamsFeatures struct { + XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"` +} + +type xmppTlsProceed struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls proceed"` +} + +func tlsDial(snet, saddr string) (*tls.Conn, error) { + switch snet { + case "tcp": + conn, err := tls.Dial(snet, saddr, &tls.Config{InsecureSkipVerify: true}) + if err != nil { + return nil, err + } + return conn, nil + case "xmpp": + host, _, err := net.SplitHostPort(saddr) + connTCP, err := net.Dial("tcp", saddr) + if err != nil { + return nil, err + } + + decoder := xml.NewDecoder(connTCP) + + // send start + _, err = fmt.Fprintf(connTCP, "", host) + if err != nil { + return nil, err + } + // read start + for { + t, err := decoder.Token() + if err != nil || t == nil { + return nil, err + } + if se, ok := t.(xml.StartElement); ok { + if se.Name.Local != "stream" { + return nil, xml.UnmarshalError(fmt.Sprintf("expected element of type <%s> but have <%s>", "stream", se.Name.Local)) + } + break + } + } + // read + var features xmppStreamsFeatures + err = decoder.DecodeElement(&features, nil) + if err != nil { + return nil, err + } + // send + _, err = io.WriteString(connTCP, "") + if err != nil { + return nil, err + } + // read + var proceed xmppTlsProceed + err = decoder.DecodeElement(&proceed, nil) + if err != nil { + return nil, err + } + + connTLS := tls.Client(connTCP, &tls.Config{InsecureSkipVerify: true}) + err = connTLS.Handshake() + if err != nil { + return nil, err + } + return connTLS, nil + default: + return nil, fmt.Errorf("Unknown TLS network: %q", snet) + } +} + func getcert(socket string) (*x509.Certificate, error){ - host, _, err := net.SplitHostPort(socket) + snet, saddr := split(socket) + host, _, err := net.SplitHostPort(saddr) if err != nil { return nil, err } - conn, err := tls.Dial("tcp", socket, &tls.Config{InsecureSkipVerify: true}) + conn, err := tlsDial(snet, saddr) if err != nil { return nil, err } @@ -34,6 +110,14 @@ func getcert(socket string) (*x509.Certificate, error){ return cert, err } +func split(socket string) (net, addr string) { + ary := strings.SplitN(socket, ":", 2) + if len(ary) == 1 { + return "tcp", ary[0] + } + return ary[0], ary[1] +} + func main() { for _, socket := range os.Args[1:] { cert, err := getcert(socket) -- cgit v1.2.3-2-g168b