diff --git a/kms/capi/capi.go b/kms/capi/capi.go index 3a7ea50d..47e14ea5 100644 --- a/kms/capi/capi.go +++ b/kms/capi/capi.go @@ -43,6 +43,8 @@ const ( HashArg = "sha1" StoreLocationArg = "store-location" // 'machine', 'user', etc StoreNameArg = "store" // 'MY', 'CA', 'ROOT', etc + FriendlyNameArg = "friendly-name" + DescriptionArg = "description" IntermediateStoreLocationArg = "intermediate-store-location" IntermediateStoreNameArg = "intermediate-store" KeyIDArg = "key-id" @@ -90,6 +92,8 @@ type uriAttributes struct { subjectCN string serialNumber string issuerName string + friendlyName string + description string keySpec string skipFindCertificateKey bool pin string @@ -126,6 +130,8 @@ func parseURI(rawuri string) (*uriAttributes, error) { subjectCN: u.Get(SubjectCNArg), serialNumber: u.Get(SerialNumberArg), issuerName: u.Get(IssuerNameArg), + friendlyName: u.Get(FriendlyNameArg), + description: u.Get(DescriptionArg), keySpec: u.Get(KeySpec), skipFindCertificateKey: u.GetBool(SkipFindCertificateKey), pin: u.Pin(), @@ -386,11 +392,17 @@ func (k *CAPIKMS) getCertContext(u *uriAttributes) (*windows.CertContext, error) 0, 0, certStoreLocation, - uintptr(unsafe.Pointer(wide(u.storeName)))) + uintptr(unsafe.Pointer(wide(u.storeName))), + ) if err != nil { return nil, fmt.Errorf("CertOpenStore for the %q store %q returned: %w", u.storeLocation, u.storeName, err) } + // if issuer + any of the other fields in the list below is provided, then attempt a second certificate lookup when + // lookup by KeyID fails (not found). This fix an issue when looking up device certificates, as in that case the KeyID is + // derived from a randomly generate string each time agent runs, thus not being able to find certificates installed from + // a previous run. + canLookupByIssuer := u.issuerName != "" && (u.serialNumber != "" || u.subjectCN != "" || u.friendlyName != "" || u.description != "") var handle *windows.CertContext switch { @@ -429,67 +441,91 @@ func (k *CAPIKMS) getCertContext(u *uriAttributes) (*windows.CertContext, error) if err != nil { return nil, fmt.Errorf("findCertificateInStore failed: %w", err) } + + if handle == nil && !canLookupByIssuer { + return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%s not found", KeyIDArg, u.keyID)} + } + } + + if handle != nil { + return handle, err + } + + if !canLookupByIssuer { + return nil, fmt.Errorf("%q, %q, or %q and one of %q or %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg, SubjectCNArg) + } + + // lookup certificate by issuer + another field (serial, CN, friendlyName, description) + var prevCert *windows.CertContext + for { + handle, err = findCertificateInStore(st, + encodingX509ASN|encodingPKCS7, + 0, + findIssuerStr, + uintptr(unsafe.Pointer(wide(u.issuerName))), prevCert) + if err != nil { + return nil, fmt.Errorf("findCertificateInStore failed: %w", err) + } + if handle == nil { - return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%x not found", KeyIDArg, u.keyID)} + return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%q not found", IssuerNameArg, u.issuerName)} + } + + x509Cert, err := certContextToX509(handle) + if err != nil { + return nil, fmt.Errorf("could not unmarshal certificate to DER: %w", err) } - case u.issuerName != "" && (u.serialNumber != "" || u.subjectCN != ""): - var prevCert *windows.CertContext - for { - handle, err = findCertificateInStore(st, - encodingX509ASN|encodingPKCS7, - 0, - findIssuerStr, - uintptr(unsafe.Pointer(wide(u.issuerName))), prevCert) - if err != nil { - return nil, fmt.Errorf("findCertificateInStore failed: %w", err) - } - if handle == nil { - return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%q not found", IssuerNameArg, u.issuerName)} + switch { + case len(u.serialNumber) > 0: + // TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead + // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id + // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number + var bi *big.Int + if strings.HasPrefix(u.serialNumber, "0x") { + serialBytes, err := hex.DecodeString(strings.TrimPrefix(u.serialNumber, "0x")) + if err != nil { + return nil, fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err) + } + + bi = new(big.Int).SetBytes(serialBytes) + } else { + bi := new(big.Int) + bi, ok := bi.SetString(u.serialNumber, 10) + if !ok { + return nil, fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg) + } } - x509Cert, err := certContextToX509(handle) + if x509Cert.SerialNumber.Cmp(bi) == 0 { + return handle, nil + } + case len(u.subjectCN) > 0: + if x509Cert.Subject.CommonName == u.subjectCN { + return handle, nil + } + case len(u.friendlyName) > 0: + val, err := cryptFindCertificateFriendlyName(handle) if err != nil { - return nil, fmt.Errorf("could not unmarshal certificate to DER: %w", err) + return nil, fmt.Errorf("cryptFindCertificateFriendlyName failed: %w", err) } - switch { - case len(u.serialNumber) > 0: - // TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead - // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id - // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number - var bi *big.Int - if strings.HasPrefix(u.serialNumber, "0x") { - serialBytes, err := hex.DecodeString(strings.TrimPrefix(u.serialNumber, "0x")) - if err != nil { - return nil, fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err) - } - - bi = new(big.Int).SetBytes(serialBytes) - } else { - bi := new(big.Int) - bi, ok := bi.SetString(u.serialNumber, 10) - if !ok { - return nil, fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg) - } - } - - if x509Cert.SerialNumber.Cmp(bi) == 0 { - return handle, nil - } - case len(u.subjectCN) > 0: - if x509Cert.Subject.CommonName == u.subjectCN { - return handle, nil - } + if val == u.friendlyName { + return handle, nil + } + case len(u.description) > 0: + val, err := cryptFindCertificateDescription(handle) + if err != nil { + return nil, fmt.Errorf("cryptFindCertificateDescription failed: %w", err) } - prevCert = handle + if val == u.description { + return handle, nil + } } - default: - return nil, fmt.Errorf("%q, %q, or %q and one of %q or %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg, SubjectCNArg) - } - return handle, err + prevCert = handle + } } // CreateSigner returns a crypto.Signer that will sign using the key passed in via the URI. @@ -827,6 +863,14 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { cryptFindCertificateKeyProvInfo(certContext) } + if u.friendlyName != "" { + cryptSetCertificateFriendlyName(certContext, u.friendlyName) + } + + if u.description != "" { + cryptSetCertificateDescription(certContext, u.description) + } + st, err := windows.CertOpenStore( certStoreProvSystem, 0, @@ -862,6 +906,8 @@ func (k *CAPIKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) HashArg: []string{fp}, StoreLocationArg: []string{u.storeLocation}, StoreNameArg: []string{u.storeName}, + FriendlyNameArg: []string{u.friendlyName}, + DescriptionArg: []string{u.description}, SkipFindCertificateKey: []string{strconv.FormatBool(u.skipFindCertificateKey)}, }).String(), Certificate: leaf, diff --git a/kms/capi/ncrypt_windows.go b/kms/capi/ncrypt_windows.go index 9366f08b..25dd37af 100644 --- a/kms/capi/ncrypt_windows.go +++ b/kms/capi/ncrypt_windows.go @@ -60,9 +60,11 @@ const ( compareShift = 16 // CERT_COMPARE_SHIFT compareSHA1Hash = 1 // CERT_COMPARE_SHA1_HASH compareCertID = 16 // CERT_COMPARE_CERT_ID + compareProp = 5 // CERT_COMPARE_CERT_ID findIssuerStr = compareNameStrW<