Skip to content
Open
146 changes: 96 additions & 50 deletions kms/capi/capi.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ const (
HashArg = "sha1"
StoreLocationArg = "store-location" // 'machine', 'user', etc
StoreNameArg = "store" // 'MY', 'CA', 'ROOT', etc
FriendlyNameArg = "friendly-name"
DescriptionArg = "description"
IntermediateStoreLocationArg = "intermediate-store-location"
IntermediateStoreNameArg = "intermediate-store"
KeyIDArg = "key-id"
Expand Down Expand Up @@ -90,6 +92,8 @@ type uriAttributes struct {
subjectCN string
serialNumber string
issuerName string
friendlyName string
description string
keySpec string
skipFindCertificateKey bool
pin string
Expand Down Expand Up @@ -126,6 +130,8 @@ func parseURI(rawuri string) (*uriAttributes, error) {
subjectCN: u.Get(SubjectCNArg),
serialNumber: u.Get(SerialNumberArg),
issuerName: u.Get(IssuerNameArg),
friendlyName: u.Get(FriendlyNameArg),
description: u.Get(DescriptionArg),
keySpec: u.Get(KeySpec),
skipFindCertificateKey: u.GetBool(SkipFindCertificateKey),
pin: u.Pin(),
Expand Down Expand Up @@ -386,11 +392,17 @@ func (k *CAPIKMS) getCertContext(u *uriAttributes) (*windows.CertContext, error)
0,
0,
certStoreLocation,
uintptr(unsafe.Pointer(wide(u.storeName))))
uintptr(unsafe.Pointer(wide(u.storeName))),
)
if err != nil {
return nil, fmt.Errorf("CertOpenStore for the %q store %q returned: %w", u.storeLocation, u.storeName, err)
}

// if issuer + any of the other fields in the list below is provided, then attempt a second certificate lookup when
// lookup by KeyID fails (not found). This fix an issue when looking up device certificates, as in that case the KeyID is
// derived from a randomly generate string each time agent runs, thus not being able to find certificates installed from
// a previous run.
canLookupByIssuer := u.issuerName != "" && (u.serialNumber != "" || u.subjectCN != "" || u.friendlyName != "" || u.description != "")
var handle *windows.CertContext

switch {
Expand Down Expand Up @@ -429,67 +441,91 @@ func (k *CAPIKMS) getCertContext(u *uriAttributes) (*windows.CertContext, error)
if err != nil {
return nil, fmt.Errorf("findCertificateInStore failed: %w", err)
}

if handle == nil && !canLookupByIssuer {
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%s not found", KeyIDArg, u.keyID)}
}
}

if handle != nil {
return handle, err
}

if !canLookupByIssuer {
return nil, fmt.Errorf("%q, %q, or %q and one of %q or %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg, SubjectCNArg)
}

// lookup certificate by issuer + another field (serial, CN, friendlyName, description)
var prevCert *windows.CertContext
for {
handle, err = findCertificateInStore(st,
encodingX509ASN|encodingPKCS7,
0,
findIssuerStr,
uintptr(unsafe.Pointer(wide(u.issuerName))), prevCert)
if err != nil {
return nil, fmt.Errorf("findCertificateInStore failed: %w", err)
}

if handle == nil {
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%x not found", KeyIDArg, u.keyID)}
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%q not found", IssuerNameArg, u.issuerName)}
}

x509Cert, err := certContextToX509(handle)
if err != nil {
return nil, fmt.Errorf("could not unmarshal certificate to DER: %w", err)
}
case u.issuerName != "" && (u.serialNumber != "" || u.subjectCN != ""):
var prevCert *windows.CertContext
for {
handle, err = findCertificateInStore(st,
encodingX509ASN|encodingPKCS7,
0,
findIssuerStr,
uintptr(unsafe.Pointer(wide(u.issuerName))), prevCert)
if err != nil {
return nil, fmt.Errorf("findCertificateInStore failed: %w", err)
}

if handle == nil {
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%q not found", IssuerNameArg, u.issuerName)}
switch {
case len(u.serialNumber) > 0:
// TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number
var bi *big.Int
if strings.HasPrefix(u.serialNumber, "0x") {
serialBytes, err := hex.DecodeString(strings.TrimPrefix(u.serialNumber, "0x"))
if err != nil {
return nil, fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err)
}

bi = new(big.Int).SetBytes(serialBytes)
} else {
bi := new(big.Int)
bi, ok := bi.SetString(u.serialNumber, 10)
if !ok {
return nil, fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg)
}
}

x509Cert, err := certContextToX509(handle)
if x509Cert.SerialNumber.Cmp(bi) == 0 {
return handle, nil
}
case len(u.subjectCN) > 0:
if x509Cert.Subject.CommonName == u.subjectCN {
return handle, nil
}
case len(u.friendlyName) > 0:
val, err := cryptFindCertificateFriendlyName(handle)
if err != nil {
return nil, fmt.Errorf("could not unmarshal certificate to DER: %w", err)
return nil, fmt.Errorf("cryptFindCertificateFriendlyName failed: %w", err)
}

switch {
case len(u.serialNumber) > 0:
// TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number
var bi *big.Int
if strings.HasPrefix(u.serialNumber, "0x") {
serialBytes, err := hex.DecodeString(strings.TrimPrefix(u.serialNumber, "0x"))
if err != nil {
return nil, fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err)
}

bi = new(big.Int).SetBytes(serialBytes)
} else {
bi := new(big.Int)
bi, ok := bi.SetString(u.serialNumber, 10)
if !ok {
return nil, fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg)
}
}

if x509Cert.SerialNumber.Cmp(bi) == 0 {
return handle, nil
}
case len(u.subjectCN) > 0:
if x509Cert.Subject.CommonName == u.subjectCN {
return handle, nil
}
if val == u.friendlyName {
return handle, nil
}
case len(u.description) > 0:
val, err := cryptFindCertificateDescription(handle)
if err != nil {
return nil, fmt.Errorf("cryptFindCertificateDescription failed: %w", err)
}

prevCert = handle
if val == u.description {
return handle, nil
}
}
default:
return nil, fmt.Errorf("%q, %q, or %q and one of %q or %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg, SubjectCNArg)
}

return handle, err
prevCert = handle
}
}

// CreateSigner returns a crypto.Signer that will sign using the key passed in via the URI.
Expand Down Expand Up @@ -827,6 +863,14 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
cryptFindCertificateKeyProvInfo(certContext)
}

if u.friendlyName != "" {
cryptSetCertificateFriendlyName(certContext, u.friendlyName)
}

if u.description != "" {
cryptSetCertificateDescription(certContext, u.description)
}

st, err := windows.CertOpenStore(
certStoreProvSystem,
0,
Expand Down Expand Up @@ -862,6 +906,8 @@ func (k *CAPIKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest)
HashArg: []string{fp},
StoreLocationArg: []string{u.storeLocation},
StoreNameArg: []string{u.storeName},
FriendlyNameArg: []string{u.friendlyName},
DescriptionArg: []string{u.description},
SkipFindCertificateKey: []string{strconv.FormatBool(u.skipFindCertificateKey)},
}).String(),
Certificate: leaf,
Expand Down
101 changes: 101 additions & 0 deletions kms/capi/ncrypt_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,11 @@ const (
compareShift = 16 // CERT_COMPARE_SHIFT
compareSHA1Hash = 1 // CERT_COMPARE_SHA1_HASH
compareCertID = 16 // CERT_COMPARE_CERT_ID
compareProp = 5 // CERT_COMPARE_CERT_ID
findIssuerStr = compareNameStrW<<compareShift | infoIssuerFlag // CERT_FIND_ISSUER_STR_W
findIssuerName = compareName<<compareShift | infoIssuerFlag // CERT_FIND_ISSUER_NAME
findHash = compareSHA1Hash << compareShift // CERT_FIND_HASH
findProperty = compareProp << compareShift // CERT_FIND_PROPERTY
findCertID = compareCertID << compareShift // CERT_FIND_CERT_ID

signatureKeyUsage = 0x80 // CERT_DIGITAL_SIGNATURE_KEY_USAGE
Expand All @@ -82,6 +84,8 @@ const (
CERT_ID_SHA1_HASH = uint32(3)

CERT_KEY_PROV_INFO_PROP_ID = uint32(2)
CERT_FRIENDLY_NAME_PROP_ID = uint32(11)
CERT_DESCRIPTION_PROP_ID = uint32(13)

CERT_NAME_STR_COMMA_FLAG = uint32(0x04000000)
CERT_SIMPLE_NAME_STR = uint32(1)
Expand Down Expand Up @@ -151,6 +155,7 @@ var (
procCertFindCertificateInStore = crypt32.MustFindProc("CertFindCertificateInStore")
procCryptFindCertificateKeyProvInfo = crypt32.MustFindProc("CryptFindCertificateKeyProvInfo")
procCertGetCertificateContextProperty = crypt32.MustFindProc("CertGetCertificateContextProperty")
procCertSetCertificateContextProperty = crypt32.MustFindProc("CertSetCertificateContextProperty")
procCertStrToName = crypt32.MustFindProc("CertStrToNameW")
)

Expand Down Expand Up @@ -606,6 +611,102 @@ func cryptFindCertificateKeyContainerName(certContext *windows.CertContext) (str
return "", nil
}

func certSetCertificateContextProperty(certContext *windows.CertContext, propID uint32, pvData uintptr) error {
r0, _, err := procCertSetCertificateContextProperty.Call(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is r0? I think this could use a more descriptive name.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's actually r1, this invokes a system call on windows (whose procedure is found in a DLL), the system call returns r1, r2, error, in general, where r1 represents the return value status from the procedure stored in a register (e.g. on Linux the %rax value), the semantic value of this depends on the procedure invoked, r2 is usually not used but kept for compatibility with platforms that return status on more registers, and the 3rd value return is actually an error, on windows the error is always non nil so we must check r0(r1).
I think we can leave this low level stuff as is, or perhaps considering refactor this in another pr.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, that all makes sense. Can you add a comment to these funcs explaining it? With the information you provided above, it's clear, but without it, it's not clear what's happening.

uintptr(unsafe.Pointer(certContext)),
uintptr(propID),
0,
pvData,
)

if r0 == 0 {
return err
}
return nil
Comment on lines +622 to +625

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is err assumed to be non-nil here?

If err != nil, and r0 is != 0, is it ok to drop the error?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This invokes CertSetCertificateContextProperty which returns a bool, on success it returns true.
But since this is system call invocation the bool is cast to int, false is zero, so in that case we actually return the error.
We handle this as syscall on windows always return a non nil error even if the function succeeds

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the above, I think comments helps a lot here.

}

func cryptSetCertificateFriendlyName(certContext *windows.CertContext, val string) error {
data := CRYPTOAPI_BLOB{
len: uint32(len(val)+1) * 2,
data: uintptr(unsafe.Pointer(wide(val))),
}

return certSetCertificateContextProperty(certContext, CERT_FRIENDLY_NAME_PROP_ID, uintptr(unsafe.Pointer(&data)))
}

func cryptSetCertificateDescription(certContext *windows.CertContext, val string) error {
data := CRYPTOAPI_BLOB{
len: uint32(len(val)+1) * 2,
data: uintptr(unsafe.Pointer(wide(val))),
}

return certSetCertificateContextProperty(certContext, CERT_DESCRIPTION_PROP_ID, uintptr(unsafe.Pointer(&data)))
}

func certGetCertificateContextProperty(certContext *windows.CertContext, propID uint32, pvData *byte, pcbData *uint32) error {
r0, _, err := procCertGetCertificateContextProperty.Call(
uintptr(unsafe.Pointer(certContext)),
uintptr(propID),
uintptr(unsafe.Pointer(pvData)),
uintptr(unsafe.Pointer(pcbData)),
)
if r0 == 0 {
return err
}
return nil
}

func cryptFindCertificateFriendlyName(certContext *windows.CertContext) (string, error) {
var size uint32

err := certGetCertificateContextProperty(certContext, CERT_FRIENDLY_NAME_PROP_ID, nil, &size)
if err != nil {
if errno, ok := err.(windows.Errno); ok && uint32(errno) == CRYPT_E_NOT_FOUND {
return "", nil
}

return "", err
}

if size == 0 {
return "", nil
}

buf := make([]byte, size)
err = certGetCertificateContextProperty(certContext, CERT_FRIENDLY_NAME_PROP_ID, &buf[0], &size)
if err != nil {
return "", err
}

uc := bytes.ReplaceAll(buf, []byte{0x00}, []byte(""))
return string(uc), nil
}

func cryptFindCertificateDescription(certContext *windows.CertContext) (string, error) {
var size uint32

err := certGetCertificateContextProperty(certContext, CERT_DESCRIPTION_PROP_ID, nil, &size)
if err != nil {
if errno, ok := err.(windows.Errno); ok && uint32(errno) == CRYPT_E_NOT_FOUND {
return "", nil
}

return "", err
}
if size == 0 {
return "", nil
}

buf := make([]byte, size)
err = certGetCertificateContextProperty(certContext, CERT_DESCRIPTION_PROP_ID, &buf[0], &size)
if err != nil {
return "", err
}

uc := bytes.ReplaceAll(buf, []byte{0x00}, []byte(""))
return string(uc), nil
}

func certStrToName(x500Str string) ([]byte, error) {
var size uint32

Expand Down
19 changes: 13 additions & 6 deletions kms/tpmkms/tpmkms.go
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,9 @@ func (k *TPMKMS) loadCertificateChainFromWindowsCertificateStore(req *apiv1.Load
"store": []string{store},
"intermediate-store-location": []string{intermediateCAStoreLocation},
"intermediate-store": []string{intermediateCAStore},
"issuer": []string{o.issuer},
"friendly-name": []string{o.friendlyName},
"description": []string{o.description},
}).String(),
})
}
Expand Down Expand Up @@ -966,6 +969,8 @@ func (k *TPMKMS) storeCertificateChainToWindowsCertificateStore(req *apiv1.Store
Name: uri.New("capi", url.Values{
"store-location": []string{location},
"store": []string{store},
"friendly-name": []string{o.friendlyName},
"description": []string{o.description},
"skip-find-certificate-key": []string{skipFindCertificateKey},
"intermediate-store-location": []string{intermediateCAStoreLocation},
"intermediate-store": []string{intermediateCAStore},
Expand Down Expand Up @@ -1435,9 +1440,11 @@ type deletingCertificateChainManager interface {
DeleteCertificate(req *apiv1.DeleteCertificateRequest) error
}

var _ apiv1.KeyManager = (*TPMKMS)(nil)
var _ apiv1.Attester = (*TPMKMS)(nil)
var _ apiv1.CertificateManager = (*TPMKMS)(nil)
var _ apiv1.CertificateChainManager = (*TPMKMS)(nil)
var _ deletingCertificateChainManager = (*TPMKMS)(nil)
var _ apiv1.AttestationClient = (*attestationClient)(nil)
var (
_ apiv1.KeyManager = (*TPMKMS)(nil)
_ apiv1.Attester = (*TPMKMS)(nil)
_ apiv1.CertificateManager = (*TPMKMS)(nil)
_ apiv1.CertificateChainManager = (*TPMKMS)(nil)
_ deletingCertificateChainManager = (*TPMKMS)(nil)
_ apiv1.AttestationClient = (*attestationClient)(nil)
)
Loading
Loading