diff --git a/orbit/pkg/certificate/certificate.go b/orbit/pkg/certificate/certificate.go index 2f9d29e88..f704d61ec 100644 --- a/orbit/pkg/certificate/certificate.go +++ b/orbit/pkg/certificate/certificate.go @@ -2,42 +2,14 @@ package certificate import ( - "bytes" "crypto/tls" "crypto/x509" - "encoding/pem" "io/ioutil" "net/url" "github.com/pkg/errors" ) -// FetchPEM retrieves the certificate chain presented by the server listening at -// hostname in PEM format. -// -// Adapted from https://stackoverflow.com/a/46735876/491710 -func FetchPEM(hostname string) ([]byte, error) { - conn, err := tls.Dial("tcp", hostname, &tls.Config{ - InsecureSkipVerify: true, - }) - if err != nil { - return nil, errors.Wrap(err, "dial server to fetch PEM") - } - defer conn.Close() - - var b bytes.Buffer - for _, cert := range conn.ConnectionState().PeerCertificates { - err := pem.Encode(&b, &pem.Block{ - Type: "CERTIFICATE", - Bytes: cert.Raw, - }) - if err != nil { - return nil, errors.Wrap(err, "encode PEM") - } - } - return b.Bytes(), nil -} - // LoadPEM loads certificates from a PEM file and returns a cert pool containing // the certificates. func LoadPEM(path string) (*x509.CertPool, error) { diff --git a/orbit/pkg/certificate/certificate_test.go b/orbit/pkg/certificate/certificate_test.go index 80c1b8418..9db7f92b4 100644 --- a/orbit/pkg/certificate/certificate_test.go +++ b/orbit/pkg/certificate/certificate_test.go @@ -1,52 +1,13 @@ package certificate import ( - "io/ioutil" - "net" - "net/http" "path/filepath" - "strconv" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestFetchPEMInvalidHostname(t *testing.T) { - t.Parallel() - - _, err := FetchPEM("foobar") - require.Error(t, err) -} - -func TestFetchPEM(t *testing.T) { - t.Parallel() - - certPath := filepath.Join("testdata", "test.crt") - keyPath := filepath.Join("testdata", "test.key") - expectedCert, err := ioutil.ReadFile(certPath) - require.NoError(t, err) - - var port int - go func() { - // Assign any available port - listener, err := net.Listen("tcp", "localhost:0") - require.NoError(t, err) - port = listener.Addr().(*net.TCPAddr).Port - defer listener.Close() - - err = http.ServeTLS(listener, nil, certPath, keyPath) - require.NoError(t, err) - }() - // Sleep to allow the goroutine to run and start the server. - time.Sleep(10 * time.Millisecond) - - pem, err := FetchPEM("localhost:" + strconv.Itoa(port)) - require.NoError(t, err) - assert.Equal(t, expectedCert, pem) -} - func TestLoadPEM(t *testing.T) { t.Parallel()