diff --git a/cmd/step/main.go b/cmd/step/main.go index 82777390e..59e843375 100644 --- a/cmd/step/main.go +++ b/cmd/step/main.go @@ -43,6 +43,7 @@ import ( _ "github.com/smallstep/cli/command/oauth" _ "github.com/smallstep/cli/command/path" _ "github.com/smallstep/cli/command/ssh" + _ "github.com/smallstep/cli/command/tls" ) // Version is set by an LDFLAG at build time representing the git tag or commit diff --git a/command/certificate/fingerprint.go b/command/certificate/fingerprint.go index edac36938..fecdbebc9 100644 --- a/command/certificate/fingerprint.go +++ b/command/certificate/fingerprint.go @@ -132,7 +132,7 @@ func fingerprintAction(ctx *cli.Context) error { return err } - switch addr, isURL, err := trimURL(crtFile); { + switch addr, isURL, err := utils.TrimURL(crtFile); { case err != nil: return err case isURL: diff --git a/command/certificate/inspect.go b/command/certificate/inspect.go index 7d465b349..cba3a74d9 100644 --- a/command/certificate/inspect.go +++ b/command/certificate/inspect.go @@ -208,7 +208,7 @@ func inspectAction(ctx *cli.Context) error { return errs.IncompatibleFlagWithFlag(ctx, "short", "format "+format) } - switch addr, isURL, err := trimURL(crtFile); { + switch addr, isURL, err := utils.TrimURL(crtFile); { case err != nil: return err case isURL: diff --git a/command/certificate/lint.go b/command/certificate/lint.go index 8eecde2b3..c4c495790 100644 --- a/command/certificate/lint.go +++ b/command/certificate/lint.go @@ -13,6 +13,7 @@ import ( "github.com/smallstep/zlint" "github.com/smallstep/cli/flags" + "github.com/smallstep/cli/utils" ) func lintCommand() cli.Command { @@ -103,7 +104,7 @@ func lintAction(ctx *cli.Context) error { insecure = ctx.Bool("insecure") block *pem.Block ) - switch addr, isURL, err := trimURL(crtFile); { + switch addr, isURL, err := utils.TrimURL(crtFile); { case err != nil: return err case isURL: diff --git a/command/certificate/needsRenewal.go b/command/certificate/needsRenewal.go index 399ef0c32..66db6f0d1 100644 --- a/command/certificate/needsRenewal.go +++ b/command/certificate/needsRenewal.go @@ -15,6 +15,7 @@ import ( "go.step.sm/crypto/pemutil" "github.com/smallstep/cli/flags" + "github.com/smallstep/cli/utils" ) const defaultPercentUsedThreshold = 66 @@ -157,7 +158,7 @@ func needsRenewalAction(ctx *cli.Context) error { ) var certs []*x509.Certificate - switch addr, isURL, err := trimURL(certFile); { + switch addr, isURL, err := utils.TrimURL(certFile); { case err != nil: return errs.NewExitError(err, 255) case isURL: diff --git a/command/certificate/remote.go b/command/certificate/remote.go index 211de5047..200683f9a 100644 --- a/command/certificate/remote.go +++ b/command/certificate/remote.go @@ -4,9 +4,6 @@ import ( "crypto/tls" "crypto/x509" "net" - "net/url" - "strconv" - "strings" "github.com/pkg/errors" "go.step.sm/crypto/x509util" @@ -63,33 +60,3 @@ func getPeerCertificates(addr, serverName, roots string, insecure bool) ([]*x509 conn.Close() return conn.ConnectionState().PeerCertificates, nil } - -// trimURL returns the host[:port] if the input is a URL, otherwise returns an -// empty string (and 'isURL:false'). -// -// If the URL is valid and no port is specified, the default port determined -// by the URL prefix is used. -// -// Examples: -// trimURL("https://smallstep.com/onboarding") -> "smallstep.com:443", true, nil -// trimURL("https://ca.smallSTEP.com:8080") -> "ca.smallSTEP.com:8080", true, nil -// trimURL("./certs/root_ca.crt") -> "", false, nil -// trimURL("hTtPs://sMaLlStEp.cOm") -> "sMaLlStEp.cOm:443", true, nil -// trimURL("hTtPs://sMaLlStEp.cOm hello") -> "", false, err{"invalid url"} -func trimURL(ref string) (string, bool, error) { - tmp := strings.ToLower(ref) - for prefix := range urlPrefixes { - if strings.HasPrefix(tmp, prefix) { - u, err := url.Parse(ref) - if err != nil { - return "", false, errors.Wrapf(err, "error parsing URL '%s'", ref) - } - if _, _, err := net.SplitHostPort(u.Host); err != nil { - port := strconv.FormatUint(uint64(urlPrefixes[prefix]), 10) - u.Host = net.JoinHostPort(u.Host, port) - } - return u.Host, true, nil - } - } - return "", false, nil -} diff --git a/command/certificate/remote_test.go b/command/certificate/remote_test.go index 91037745a..001287ba3 100644 --- a/command/certificate/remote_test.go +++ b/command/certificate/remote_test.go @@ -8,37 +8,6 @@ import ( "github.com/smallstep/assert" ) -func TestTrimURL(t *testing.T) { - type newTest struct { - input, host string - isURL bool - err error - } - tests := map[string]newTest{ - "true-http": {"https://smallstep.com", "smallstep.com:443", true, nil}, - "true-tcp": {"tcp://smallstep.com:8080", "smallstep.com:8080", true, nil}, - "true-tls": {"tls://smallstep.com/onboarding", "smallstep.com:443", true, nil}, - "false": {"./certs/root_ca.crt", "", false, nil}, - "false-err": {"https://google.com hello", "", false, errors.New("error parsing URL 'https://google.com hello'")}, - "true-http-case": {"hTtPs://sMaLlStEp.cOm", "sMaLlStEp.cOm:443", true, nil}, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - host, isURL, err := trimURL(tc.input) - assert.Equals(t, tc.host, host) - assert.Equals(t, tc.isURL, isURL) - if err != nil { - if assert.NotNil(t, tc.err) { - assert.HasPrefix(t, err.Error(), tc.err.Error()) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - func TestGetPeerCertificateServerName(t *testing.T) { host := "smallstep.com" serverName := host diff --git a/command/certificate/verify.go b/command/certificate/verify.go index 81e34706a..bcfc2344e 100644 --- a/command/certificate/verify.go +++ b/command/certificate/verify.go @@ -19,6 +19,7 @@ import ( "github.com/smallstep/cli/flags" "github.com/smallstep/cli/internal/crlutil" + "github.com/smallstep/cli/utils" ) func verifyCommand() cli.Command { @@ -170,7 +171,7 @@ func verifyAction(ctx *cli.Context) error { httpClient *http.Client ) - switch addr, isURL, err := trimURL(crtFile); { + switch addr, isURL, err := utils.TrimURL(crtFile); { case err != nil: return err case isURL: diff --git a/command/tls/handshake.go b/command/tls/handshake.go new file mode 100644 index 000000000..892545511 --- /dev/null +++ b/command/tls/handshake.go @@ -0,0 +1,268 @@ +package tls + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "net" + "reflect" + + "github.com/smallstep/cli-utils/errs" + "github.com/smallstep/cli/flags" + "github.com/smallstep/cli/internal/cryptoutil" + "github.com/smallstep/cli/utils" + "github.com/urfave/cli" + "go.step.sm/crypto/pemutil" + "go.step.sm/crypto/x509util" +) + +func handshakeCommand() cli.Command { + return cli.Command{ + Name: "handshake", + Action: cli.ActionFunc(handshakeAction), + Usage: `print handshake details`, + UsageText: `**step tls handshake** `, + Description: `**step tls handshake** displays detailed handshake information for a TLS connection.`, + Flags: []cli.Flag{ + flags.ServerName, + cli.StringFlag{ + Name: "tls", + Usage: `Defines the TLS in the handshake. By default it will use TLS 1.3 or TLS 1.2. +: The supported versions are **1.3**, **1.2**, **1.1**, and **1.0**.`, + }, + cli.StringFlag{ + Name: "cert", + Usage: `The path to the containing the client certificate to use.`, + }, + cli.StringFlag{ + Name: "key", + Usage: `The path to the or KMS containing the certificate key to use.`, + }, + cli.StringFlag{ + Name: "roots", + Usage: `Root certificate(s) that will be used to verify the +authenticity of the remote server. + +: is a case-sensitive string and may be one of: + + **file** + : Relative or full path to a file. All certificates in the file will be used for path validation. + + **list of files** + : Comma-separated list of relative or full file paths. Every PEM encoded certificate from each file will be used for path validation. + + **directory** + : Relative or full path to a directory. Every PEM encoded certificate from each file in the directory will be used for path validation.`, + }, + + cli.StringFlag{ + Name: "password-file", + Usage: "The path to the containing the password to decrypt the private key.", + }, + cli.BoolFlag{ + Name: "chain", + Usage: "Print only the chain of verified certificates.", + }, + cli.BoolFlag{ + Name: "peer", + Usage: `Print only the peer certificates sent by the server.`, + }, + cli.BoolFlag{ + Name: "insecure", + Usage: `Use an insecure client to retrieve a remote peer certificate. Useful for +debugging invalid certificates remotely.`, + }, + }, + } +} + +func handshakeAction(c *cli.Context) error { + if err := errs.NumberOfArguments(c, 1); err != nil { + return err + } + + var ( + addr = c.Args().First() + tlsVersion = c.String("tls") + roots = c.String("roots") + serverName = c.String("servername") + certFile = c.String("cert") + keyFile = c.String("key") + passwordFile = c.String("password-file") + printChains = c.Bool("chain") + printPeer = c.Bool("peer") + insecure = c.Bool("insecure") + rootCAs *x509.CertPool + err error + ) + + switch { + case certFile != "" && keyFile == "": + return errs.RequiredWithFlag(c, "cert", "key") + case keyFile != "" && certFile == "": + return errs.RequiredWithFlag(c, "key", "cert") + } + + // Parse address + if u, ok, err := utils.TrimURL(addr); err != nil { + return err + } else if ok { + addr = u + } + if _, _, err := net.SplitHostPort(addr); err != nil { + addr = net.JoinHostPort(addr, "443") + } + + // Load client TLS certificate + var certificates []tls.Certificate + if certFile != "" && keyFile != "" { + opts := []pemutil.Options{} + if passwordFile != "" { + opts = append(opts, pemutil.WithPasswordFile(passwordFile)) + } + crt, err := cryptoutil.LoadTLSCertificate(certFile, keyFile, opts...) + if err != nil { + return err + } + certificates = []tls.Certificate{crt} + } + + // Get the list of roots used to validate the certificate. + if roots != "" { + rootCAs, err = x509util.ReadCertPool(roots) + if err != nil { + return fmt.Errorf("error loading root certificate pool from %q: %w", roots, err) + } + } else { + rootCAs, err = x509.SystemCertPool() + if err != nil { + return fmt.Errorf("error loading the system cert pool: %w", err) + } + } + + // Get the tls version to use. Defaults to TLS 1.2+ + minVersion, maxVersion, err := getTLSVersions(tlsVersion) + if err != nil { + return err + } + + tlsConfig := &tls.Config{ + MinVersion: minVersion, + MaxVersion: maxVersion, + RootCAs: rootCAs, + InsecureSkipVerify: insecure, + ServerName: serverName, + Certificates: certificates, + } + + cs, err := tlsDialWithFallback(context.Background(), addr, tlsConfig) + if err != nil { + return err + } + + // Print only the list of verified chains + if printChains { + if len(cs.VerifiedChains) == 0 { + return errors.New("failed to build a chain of verified certificates") + } + for _, chain := range cs.VerifiedChains { + for _, crt := range chain { + fmt.Print(string(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: crt.Raw, + }))) + } + } + return nil + } + + // Print only the peer certificates + if printPeer { + if len(cs.PeerCertificates) == 0 { + return errors.New("peer did not sent a certificate") + } + for _, crt := range cs.PeerCertificates { + fmt.Print(string(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", Bytes: crt.Raw, + }))) + } + return nil + } + + // Check if the certificate is valid + var intermediates *x509.CertPool + if len(cs.PeerCertificates) > 1 { + intermediates = x509.NewCertPool() + for _, crt := range cs.PeerCertificates[1:] { + intermediates.AddCert(crt) + } + } + _, verifyErr := cs.PeerCertificates[0].Verify(x509.VerifyOptions{ + Roots: rootCAs, + Intermediates: intermediates, + DNSName: serverName, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + }) + + connStateValue := reflect.ValueOf(cs) + curveIDField := connStateValue.FieldByName("testingOnlyCurveID") + + fmt.Printf("Server Name: %s\n", cs.ServerName) + fmt.Printf("Version: %s\n", tls.VersionName(cs.Version)) + fmt.Printf("Cipher Suite: %s\n", tls.CipherSuiteName(cs.CipherSuite)) + fmt.Printf("KEM: %s\n", curveIDName(curveIDField.Uint())) + fmt.Printf("Insecure: %v\n", tlsConfig.InsecureSkipVerify) + fmt.Printf("Verified: %v\n", verifyErr == nil) + + return nil +} + +func curveIDName(curveID uint64) string { + switch tls.CurveID(curveID) { + case tls.CurveP256: + return "P-256" + case tls.CurveP384: + return "P-384" + case tls.CurveP521: + return "P-521" + case tls.X25519: + return "X25519" + case tls.X25519MLKEM768: + return "X25519MLKEM768" + default: + return "Unknown" + } +} + +func getTLSVersions(s string) (uint16, uint16, error) { + switch s { + case "": + return tls.VersionTLS12, 0, nil + case "1.3": + return tls.VersionTLS13, tls.VersionTLS13, nil + case "1.2": + return tls.VersionTLS12, tls.VersionTLS12, nil + case "1.1": + return tls.VersionTLS11, tls.VersionTLS11, nil + case "1.0": + return tls.VersionTLS10, tls.VersionTLS10, nil + default: + return 0, 0, fmt.Errorf("unsupported TLS version %q", s) + } +} + +func tlsDialWithFallback(ctx context.Context, addr string, tlsConfig *tls.Config) (tls.ConnectionState, error) { + conn, err := tls.Dial("tcp", addr, tlsConfig) + if err != nil { + if tlsConfig.InsecureSkipVerify { + return tls.ConnectionState{}, fmt.Errorf("error connecting to %q: %w", addr, err) + } + tlsConfig.InsecureSkipVerify = true + return tlsDialWithFallback(ctx, addr, tlsConfig) + } + defer conn.Close() + return conn.ConnectionState(), conn.HandshakeContext(ctx) +} diff --git a/command/tls/tls.go b/command/tls/tls.go new file mode 100644 index 000000000..1ef78a30f --- /dev/null +++ b/command/tls/tls.go @@ -0,0 +1,32 @@ +package tls + +import ( + "github.com/urfave/cli" + + "github.com/smallstep/cli-utils/command" +) + +// Command returns the cli.Command for jwt and related subcommands. +func init() { + cmd := cli.Command{ + Name: "tls", + Usage: "tls inspection utilities", + UsageText: "step tls SUBCOMMAND [ARGUMENTS] [GLOBAL_FLAGS] [SUBCOMMAND_FLAGS]", + Description: `**step tls** command group provides facilities for +inspecting TLS services. + +## EXAMPLES + +Do a TLS handshake: +''' +$ step tls handshake https://smallstep.com +''' +`, + + Subcommands: cli.Commands{ + handshakeCommand(), + }, + } + + command.Register(cmd) +} diff --git a/internal/cryptoutil/cryptoutil.go b/internal/cryptoutil/cryptoutil.go index 005e74685..b4cd0287f 100644 --- a/internal/cryptoutil/cryptoutil.go +++ b/internal/cryptoutil/cryptoutil.go @@ -6,6 +6,7 @@ import ( "crypto/ed25519" "crypto/elliptic" "crypto/rsa" + "crypto/tls" "crypto/x509" "encoding/base64" "errors" @@ -115,6 +116,37 @@ func LoadCertificate(kmsURI, certPath string) ([]*x509.Certificate, error) { return cert, nil } +// LoadTLSCertificate returns a [tls.Certificate] from a certificate file and a +// key in a file or in a KMS. +func LoadTLSCertificate(certFile, keyName string, opts ...pemutil.Options) (tls.Certificate, error) { + bundle, err := pemutil.ReadCertificateBundle(certFile) + if err != nil { + return tls.Certificate{}, err + } + + var signer crypto.Signer + if IsKMS(keyName) { + if signer, err = CreateSigner(keyName, keyName, opts...); err != nil { + return tls.Certificate{}, err + } + } else { + if signer, err = CreateSigner("", keyName, opts...); err != nil { + return tls.Certificate{}, err + } + } + + cert := make([][]byte, len(bundle)) + for i, crt := range bundle { + cert[i] = crt.Raw + } + + return tls.Certificate{ + Certificate: cert, + PrivateKey: signer, + Leaf: bundle[0], + }, nil +} + // LoadJSONWebKey returns a jose.JSONWebKey from a KMS or a file. func LoadJSONWebKey(kmsURI, name string, opts ...jose.Option) (*jose.JSONWebKey, error) { if kmsURI == "" { diff --git a/utils/utils.go b/utils/utils.go index e888af4d7..9c601e93d 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -2,13 +2,23 @@ package utils import ( "fmt" + "net" "net/url" "os" + "strconv" "strings" "github.com/pkg/errors" ) +var urlPrefixes = map[string]uint16{ + "tcp://": 443, + "tls://": 443, + "https://": 443, + "smtps://": 465, + "ldaps://": 636, +} + // Fail prints out the error struct if STEPDEBUG is true otherwise it just // prints out the error message. Finally, it exits with an error code of 1. func Fail(err error) { @@ -59,3 +69,33 @@ func CompleteURL(rawurl string) (string, error) { // rawurl looks like ca.smallstep.com:443 or ca.smallstep.com:443/1.0/sign return CompleteURL("https://" + rawurl) } + +// TrimURL returns the host[:port] if the input is a URL, otherwise returns an +// empty string (and 'isURL:false'). +// +// If the URL is valid and no port is specified, the default port determined +// by the URL prefix is used. +// +// Examples: +// TrimURL("https://smallstep.com/onboarding") -> "smallstep.com:443", true, nil +// TrimURL("https://ca.smallSTEP.com:8080") -> "ca.smallSTEP.com:8080", true, nil +// TrimURL("./certs/root_ca.crt") -> "", false, nil +// TrimURL("hTtPs://sMaLlStEp.cOm") -> "sMaLlStEp.cOm:443", true, nil +// TrimURL("hTtPs://sMaLlStEp.cOm hello") -> "", false, err{"invalid url"} +func TrimURL(ref string) (string, bool, error) { + tmp := strings.ToLower(ref) + for prefix := range urlPrefixes { + if strings.HasPrefix(tmp, prefix) { + u, err := url.Parse(ref) + if err != nil { + return "", false, fmt.Errorf("error parsing %q: %w", ref, err) + } + if _, _, err := net.SplitHostPort(u.Host); err != nil { + port := strconv.FormatUint(uint64(urlPrefixes[prefix]), 10) + u.Host = net.JoinHostPort(u.Host, port) + } + return u.Host, true, nil + } + } + return "", false, nil +} diff --git a/utils/utils_test.go b/utils/utils_test.go new file mode 100644 index 000000000..910579e7f --- /dev/null +++ b/utils/utils_test.go @@ -0,0 +1,49 @@ +package utils + +import ( + "strings" + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func assertHasPrefix(t *testing.T, s, p string) bool { + if strings.HasPrefix(s, p) { + return true + } + t.Helper() + t.Errorf("%q is not a prefix of %q", p, s) + return false +} + +func TestTrimURL(t *testing.T) { + type newTest struct { + input, host string + isURL bool + err error + } + tests := map[string]newTest{ + "true-http": {"https://smallstep.com", "smallstep.com:443", true, nil}, + "true-tcp": {"tcp://smallstep.com:8080", "smallstep.com:8080", true, nil}, + "true-tls": {"tls://smallstep.com/onboarding", "smallstep.com:443", true, nil}, + "false": {"./certs/root_ca.crt", "", false, nil}, + "false-err": {"https://google.com hello", "", false, errors.New(`error parsing "https://google.com hello"`)}, + "true-http-case": {"hTtPs://sMaLlStEp.cOm", "sMaLlStEp.cOm:443", true, nil}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + host, isURL, err := TrimURL(tc.input) + assert.Equal(t, tc.host, host) + assert.Equal(t, tc.isURL, isURL) + if err != nil { + if assert.NotNil(t, tc.err) { + assertHasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +}