package main

import (
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"encoding/xml"
	"fmt"
	"io"
	"net"
	"net/textproto"
	"net/url"
	"os"
	"strings"
	"time"
)

// xmppStreamsFeatures matches the server's <stream:features> element. Its
// children are not decoded; the element is consumed only so that the next
// element on the stream can be read.
type xmppStreamsFeatures struct {
	XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
}

// xmppTlsProceed matches the server's <proceed/> element in the
// urn:ietf:params:xml:ns:xmpp-tls namespace.
type xmppTlsProceed struct {
	XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls proceed"`
}

// xmppStartTLS runs the client side of the XMPP STARTTLS exchange on connRaw.
// It opens a jabber:client stream addressed to host, waits for the server's
// <stream> header and <features> element, sends <starttls/>, and expects
// <proceed/> back. On a nil return the caller may start the TLS handshake on
// connRaw.
//
// Write errors and XML decoding errors are returned unchanged. A root element
// other than <stream> yields an xml.UnmarshalError. A reply other than
// <proceed/> (for example <failure/>) is reported by DecodeElement, because
// the element name does not match the XMLName tag.
func xmppStartTLS(connRaw net.Conn, host string) error {
	dec := xml.NewDecoder(connRaw)

	opening := fmt.Sprintf("<stream:stream xmlns='jabber:client' xmlns:stream='http://etherx.jabber.org/streams' to='%s' version='1.0'>", host)
	if _, err := io.WriteString(connRaw, opening); err != nil {
		return err
	}

	// Skip whatever precedes the server's root element (e.g. the <?xml?>
	// declaration) and insist that the root is <stream>.
	for {
		tok, err := dec.Token()
		if err != nil || tok == nil {
			return err
		}
		start, isStart := tok.(xml.StartElement)
		if !isStart {
			continue
		}
		if start.Name.Local == "stream" {
			break
		}
		return xml.UnmarshalError(fmt.Sprintf("expected element of type <%s> but have <%s>", "stream", start.Name.Local))
	}

	var features xmppStreamsFeatures
	if err := dec.DecodeElement(&features, nil); err != nil {
		return err
	}

	if _, err := io.WriteString(connRaw, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>"); err != nil {
		return err
	}

	var proceed xmppTlsProceed
	return dec.DecodeElement(&proceed, nil)
}

// smtpCmd sends one command on tp and reads the server's reply. The command
// text is format applied to args; textproto appends the CRLF. The reply must
// carry expectCode. The reply text is discarded. A reply with a different
// code comes back from textproto as a *textproto.Error.
func smtpCmd(tp *textproto.Conn, expectCode int, format string, args ...interface{}) error {
	id, err := tp.Cmd(format, args...)
	if err != nil {
		return err
	}
	tp.StartResponse(id)
	defer tp.EndResponse(id)
	_, _, err = tp.ReadResponse(expectCode)
	return err
}

// smtpStartTLS runs the client side of the SMTP STARTTLS exchange on connRaw.
// It waits for the 220 greeting and introduces itself with EHLO, falling back
// to HELO only if the server answered and rejected EHLO. It then sends
// STARTTLS and expects 220. On a nil return the caller may start the TLS
// handshake on connRaw.
//
// host is currently unused. The EHLO/HELO argument is the local machine's
// hostname, or "localhost" if os.Hostname fails.
func smtpStartTLS(connRaw net.Conn, host string) error {
	tp := textproto.NewConn(connRaw)

	// Let the server introduce itself.
	if _, _, err := tp.ReadResponse(220); err != nil {
		return err
	}

	// Introduce ourselves.
	localhost, err := os.Hostname()
	if err != nil {
		localhost = "localhost"
	}
	if err := smtpCmd(tp, 250, "EHLO %s", localhost); err != nil {
		// Retry with HELO only when the server actually replied and refused
		// EHLO. An I/O or protocol error means the session is unusable, and a
		// HELO attempt would only replace the real cause with a secondary one.
		if _, refused := err.(*textproto.Error); !refused {
			return err
		}
		if err := smtpCmd(tp, 250, "HELO %s", localhost); err != nil {
			return err
		}
	}

	// Ask to switch to TLS; a 220 reply means the handshake may start.
	return smtpCmd(tp, 220, "STARTTLS")
}

// getcert connects to the endpoint described by socket and returns the leaf
// certificate the server presents.
//
// socket is parsed as a URL of the form network://host:port[/path]:
//   - The scheme is passed to net.Dial as the network (e.g. "tcp").
//   - The path selects a pre-TLS negotiation: "" or "/" for none (direct TLS),
//     "/xmpp" for XMPP STARTTLS, "/smtp" for SMTP STARTTLS.
//
// Any other path is an error. The whole exchange is bounded by a 5-second
// connection deadline.
//
// The handshake does not verify the chain, so the certificate can be captured
// even when it is untrusted. Afterwards, the leaf is verified against host.
// The other certificates the server sent serve as intermediates, and the
// system roots are used because VerifyOptions.Roots is left nil. If that
// verification fails, the certificate is returned together with the
// verification error.
func getcert(socket string) (*x509.Certificate, error) {
	u, err := url.Parse(socket)
	if err != nil {
		return nil, err
	}
	host, _, err := net.SplitHostPort(u.Host)
	if err != nil {
		return nil, err
	}

	connRaw, err := net.Dial(u.Scheme, u.Host)
	if err != nil {
		return nil, err
	}
	// Close the raw connection on every return path. The negotiation failures
	// below happen before a TLS wrapper exists to close it. A second Close
	// after connTLS.Close is harmless.
	defer connRaw.Close()

	if err := connRaw.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
		return nil, err
	}

	switch u.Path {
	case "", "/":
		// Direct TLS: nothing to negotiate.
	case "/xmpp":
		if err := xmppStartTLS(connRaw, host); err != nil {
			return nil, err
		}
	case "/smtp":
		if err := smtpStartTLS(connRaw, host); err != nil {
			return nil, err
		}
	default:
		return nil, fmt.Errorf("Unknown negotiation path: %q", u.Path)
	}

	// ServerName makes the client send SNI, so virtual-hosted servers present
	// the certificate for host; crypto/tls omits SNI for IP literals.
	// InsecureSkipVerify stays on because verification is done manually below,
	// once the certificate has been captured.
	connTLS := tls.Client(connRaw, &tls.Config{
		ServerName:         host,
		InsecureSkipVerify: true,
	})
	defer connTLS.Close()
	if err := connTLS.Handshake(); err != nil {
		return nil, err
	}

	peers := connTLS.ConnectionState().PeerCertificates
	if len(peers) == 0 {
		return nil, fmt.Errorf("server at %s presented no certificate", u.Host)
	}

	// Verify the leaf against host, offering every further certificate the
	// server sent as a candidate intermediate.
	opts := x509.VerifyOptions{
		DNSName:       host,
		Intermediates: x509.NewCertPool(),
	}
	for _, cert := range peers[1:] {
		opts.Intermediates.AddCert(cert)
	}

	leaf := peers[0]
	_, err = leaf.Verify(opts)
	return leaf, err
}

// split separates an optional network prefix from socket. Everything before
// the first ':' is returned as the network and everything after it as the
// address. When socket contains no ':', the network defaults to "tcp" and the
// whole string is the address.
func split(socket string) (net, addr string) {
	colon := strings.Index(socket, ":")
	if colon < 0 {
		return "tcp", socket
	}
	return socket[:colon], socket[colon+1:]
}

// main fetches the certificate for each socket URL given on the command line
// (see getcert for the accepted form). For each argument it writes one PEM
// "CERTIFICATE" block to stdout:
//   - Every block carries an X-Socket header naming the argument.
//   - When getcert returned an error, the block also has an X-Error header
//     with the error text.
//   - The block body is empty if no certificate was obtained.
//
// Progress is reported on stderr. The exit status is 1 if any block could not
// be written to stdout.
func main() {
	// PEM headers are written as newline-terminated "Key: value" lines, so a
	// value containing a line break would corrupt the output. Multi-line SMTP
	// error replies are one source of such values. Fold line breaks into
	// spaces.
	oneLine := strings.NewReplacer("\r\n", " ", "\r", " ", "\n", " ").Replace

	status := 0
	for _, socket := range os.Args[1:] {
		fmt.Fprintf(os.Stderr, "Getting %q... ", socket)
		block := pem.Block{
			Type:    "CERTIFICATE",
			Headers: map[string]string{"X-Socket": oneLine(socket)},
		}
		cert, err := getcert(socket)
		if cert != nil {
			block.Bytes = cert.Raw
		}
		if err != nil {
			block.Headers["X-Error"] = oneLine(err.Error())
		}
		if err := pem.Encode(os.Stdout, &block); err != nil {
			fmt.Fprintf(os.Stderr, "[write failed: %v]\n", err)
			status = 1
			continue
		}
		fmt.Fprintln(os.Stderr, "[done]")
	}
	if status != 0 {
		os.Exit(status)
	}
}