diff --git a/src/main/kotlin/Extension.kt b/src/main/kotlin/Extension.kt index 97a6042..0bc51ae 100644 --- a/src/main/kotlin/Extension.kt +++ b/src/main/kotlin/Extension.kt @@ -22,7 +22,7 @@ import co.nstant.`in`.cbor.CborEncoder import co.nstant.`in`.cbor.CborException import co.nstant.`in`.cbor.model.DataItem import co.nstant.`in`.cbor.model.MajorType -import co.nstant.`in`.cbor.model.Map +import co.nstant.`in`.cbor.model.Map as CborModelMap import co.nstant.`in`.cbor.model.NegativeInteger import co.nstant.`in`.cbor.model.SimpleValue import co.nstant.`in`.cbor.model.SimpleValueType @@ -68,7 +68,7 @@ data class ProvisioningInfoMap( val certificatesIssued: Int, ) { fun cborEncode(): ByteArray { - val map = Map() + val map = CborModelMap() map.put(UnsignedInteger(1L), certificatesIssued.asDataItem()) return cborEncode(map) } @@ -93,7 +93,7 @@ data class ProvisioningInfoMap( throw IllegalArgumentException(e) } - private fun from(seq: Map): ProvisioningInfoMap { + private fun from(seq: CborModelMap): ProvisioningInfoMap { require(seq.keys.size >= 1) return ProvisioningInfoMap( certificatesIssued = seq.get(UnsignedInteger(1L)).asInteger(), @@ -413,6 +413,33 @@ data class AuthorizationList( .let { DERSequence(it.toTypedArray()) } internal companion object { + + private class ASN1Converter( + private val objects: Map, + private val logFn: (String) -> Unit, + ) { + fun parse(tag: KeyMintTag, transform: (ASN1Encodable) -> T): T? { + return try { + objects[tag]?.let(transform) + } catch (e: ExtensionParsingException) { + logFn("Exception when parsing ${tag.name.lowercase()}: ${e.message}") + null + } + } + + fun parseInt(tag: KeyMintTag) = parse(tag) { it.toInt() } + + fun parseIntSet(tag: KeyMintTag) = + parse(tag) { it.toSet().map { innerIt -> innerIt.value }.toSet() } + + fun parseStr(tag: KeyMintTag) = parse(tag) { it.toStr() } + + fun parseByteString(tag: KeyMintTag) = parse(tag) { it.toByteString() } + + fun parsePatchLevel(tag: KeyMintTag, partition: String) = + parse(tag) { it.toPatchLevel(partition, logFn) } + } + fun from(seq: ASN1Sequence, logFn: (String) -> Unit = { _ -> }): AuthorizationList { val objects = seq.associate { @@ -437,48 +464,49 @@ data class AuthorizationList( logFn("AuthorizationList tags should appear in ascending order") } + val converter = ASN1Converter(objects, logFn) return AuthorizationList( - purposes = objects[KeyMintTag.PURPOSE]?.toSet()?.map { it.value }?.toSet(), - algorithms = objects[KeyMintTag.ALGORITHM]?.toInt(), - keySize = objects[KeyMintTag.KEY_SIZE]?.toInt(), - blockModes = - objects[KeyMintTag.BLOCK_MODE]?.toSet()?.map { it.value }?.toSet(), - digests = objects[KeyMintTag.DIGEST]?.toSet()?.map { it.value }?.toSet(), - paddings = objects[KeyMintTag.PADDING]?.toSet()?.map { it.value }?.toSet(), - ecCurve = objects[KeyMintTag.EC_CURVE]?.toInt(), - rsaPublicExponent = objects[KeyMintTag.RSA_PUBLIC_EXPONENT]?.toInt(), - rsaOaepMgfDigests = - objects[KeyMintTag.RSA_OAEP_MGF_DIGEST]?.toSet()?.map { it.value }?.toSet(), - activeDateTime = objects[KeyMintTag.ACTIVE_DATE_TIME]?.toInt(), - originationExpireDateTime = objects[KeyMintTag.ORIGINATION_EXPIRE_DATE_TIME]?.toInt(), - usageExpireDateTime = objects[KeyMintTag.USAGE_EXPIRE_DATE_TIME]?.toInt(), + purposes = converter.parseIntSet(KeyMintTag.PURPOSE), + algorithms = converter.parseInt(KeyMintTag.ALGORITHM), + keySize = converter.parseInt(KeyMintTag.KEY_SIZE), + blockModes = converter.parseIntSet(KeyMintTag.BLOCK_MODE), + digests = converter.parseIntSet(KeyMintTag.DIGEST), + paddings = converter.parseIntSet(KeyMintTag.PADDING), + ecCurve = converter.parseInt(KeyMintTag.EC_CURVE), + rsaPublicExponent = converter.parseInt(KeyMintTag.RSA_PUBLIC_EXPONENT), + rsaOaepMgfDigests = converter.parseIntSet(KeyMintTag.RSA_OAEP_MGF_DIGEST), + activeDateTime = converter.parseInt(KeyMintTag.ACTIVE_DATE_TIME), + originationExpireDateTime = converter.parseInt(KeyMintTag.ORIGINATION_EXPIRE_DATE_TIME), + usageExpireDateTime = converter.parseInt(KeyMintTag.USAGE_EXPIRE_DATE_TIME), noAuthRequired = if (objects.containsKey(KeyMintTag.NO_AUTH_REQUIRED)) true else null, - userAuthType = objects[KeyMintTag.USER_AUTH_TYPE]?.toInt(), - authTimeout = objects[KeyMintTag.AUTH_TIMEOUT]?.toInt(), + userAuthType = converter.parseInt(KeyMintTag.USER_AUTH_TYPE), + authTimeout = converter.parseInt(KeyMintTag.AUTH_TIMEOUT), trustedUserPresenceRequired = if (objects.containsKey(KeyMintTag.TRUSTED_USER_PRESENCE_REQUIRED)) true else null, unlockedDeviceRequired = if (objects.containsKey(KeyMintTag.UNLOCKED_DEVICE_REQUIRED)) true else null, - creationDateTime = objects[KeyMintTag.CREATION_DATE_TIME]?.toInt(), - origin = objects[KeyMintTag.ORIGIN]?.toOrigin(), + creationDateTime = converter.parseInt(KeyMintTag.CREATION_DATE_TIME), + origin = converter.parse(KeyMintTag.ORIGIN) { it.toOrigin() }, rollbackResistant = if (objects.containsKey(KeyMintTag.ROLLBACK_RESISTANT)) true else null, - rootOfTrust = objects[KeyMintTag.ROOT_OF_TRUST]?.toRootOfTrust(), - osVersion = objects[KeyMintTag.OS_VERSION]?.toInt(), - osPatchLevel = objects[KeyMintTag.OS_PATCH_LEVEL]?.toPatchLevel("OS", logFn), + rootOfTrust = converter.parse(KeyMintTag.ROOT_OF_TRUST) { it.toRootOfTrust() }, + osVersion = converter.parseInt(KeyMintTag.OS_VERSION), + osPatchLevel = converter.parsePatchLevel(KeyMintTag.OS_PATCH_LEVEL, "OS"), attestationApplicationId = - objects[KeyMintTag.ATTESTATION_APPLICATION_ID]?.toAttestationApplicationId(), - attestationIdBrand = objects[KeyMintTag.ATTESTATION_ID_BRAND]?.toStr(), - attestationIdDevice = objects[KeyMintTag.ATTESTATION_ID_DEVICE]?.toStr(), - attestationIdProduct = objects[KeyMintTag.ATTESTATION_ID_PRODUCT]?.toStr(), - attestationIdSerial = objects[KeyMintTag.ATTESTATION_ID_SERIAL]?.toStr(), - attestationIdImei = objects[KeyMintTag.ATTESTATION_ID_IMEI]?.toStr(), - attestationIdMeid = objects[KeyMintTag.ATTESTATION_ID_MEID]?.toStr(), - attestationIdManufacturer = objects[KeyMintTag.ATTESTATION_ID_MANUFACTURER]?.toStr(), - attestationIdModel = objects[KeyMintTag.ATTESTATION_ID_MODEL]?.toStr(), - vendorPatchLevel = objects[KeyMintTag.VENDOR_PATCH_LEVEL]?.toPatchLevel("vendor", logFn), - bootPatchLevel = objects[KeyMintTag.BOOT_PATCH_LEVEL]?.toPatchLevel("boot", logFn), - attestationIdSecondImei = objects[KeyMintTag.ATTESTATION_ID_SECOND_IMEI]?.toStr(), - moduleHash = objects[KeyMintTag.MODULE_HASH]?.toByteString(), + converter.parse(KeyMintTag.ATTESTATION_APPLICATION_ID) { + it.toAttestationApplicationId() + }, + attestationIdBrand = converter.parseStr(KeyMintTag.ATTESTATION_ID_BRAND), + attestationIdDevice = converter.parseStr(KeyMintTag.ATTESTATION_ID_DEVICE), + attestationIdProduct = converter.parseStr(KeyMintTag.ATTESTATION_ID_PRODUCT), + attestationIdSerial = converter.parseStr(KeyMintTag.ATTESTATION_ID_SERIAL), + attestationIdImei = converter.parseStr(KeyMintTag.ATTESTATION_ID_IMEI), + attestationIdMeid = converter.parseStr(KeyMintTag.ATTESTATION_ID_MEID), + attestationIdManufacturer = converter.parseStr(KeyMintTag.ATTESTATION_ID_MANUFACTURER), + attestationIdModel = converter.parseStr(KeyMintTag.ATTESTATION_ID_MODEL), + vendorPatchLevel = converter.parsePatchLevel(KeyMintTag.VENDOR_PATCH_LEVEL, "vendor"), + bootPatchLevel = converter.parsePatchLevel(KeyMintTag.BOOT_PATCH_LEVEL, "boot"), + attestationIdSecondImei = converter.parseStr(KeyMintTag.ATTESTATION_ID_SECOND_IMEI), + moduleHash = converter.parseByteString(KeyMintTag.MODULE_HASH), ) } } @@ -535,6 +563,7 @@ data class PatchLevel(val yearMonth: YearMonth, val version: Int? = null) { * https://source.android.com/docs/security/features/keystore/attestation#attestationapplicationid-schema */ @Immutable +@RequiresApi(24) data class AttestationApplicationId( @SuppressWarnings("Immutable") val packages: Set, @SuppressWarnings("Immutable") val signatures: Set, @@ -566,6 +595,7 @@ data class AttestationApplicationId( * @see * https://source.android.com/docs/security/features/keystore/attestation#attestationapplicationid-schema */ +@RequiresApi(24) data class AttestationPackageInfo(val name: String, val version: BigInteger) { fun toAsn1() = buildList { @@ -593,6 +623,7 @@ data class AttestationPackageInfo(val name: String, val version: BigInteger) { * @see https://source.android.com/docs/security/features/keystore/attestation#rootoftrust-fields */ @Immutable +@RequiresApi(24) data class RootOfTrust( val verifiedBootKey: ByteString, val deviceLocked: Boolean, @@ -644,40 +675,57 @@ enum class VerifiedBootState(val value: Int) { } } +@RequiresApi(24) private fun ASN1Encodable.toAttestationApplicationId(): AttestationApplicationId { - require(this is ASN1OctetString) { - "Object must be an ASN1OctetString, was ${this::class.simpleName}" + if (this !is ASN1OctetString) { + throw ExtensionParsingException( + "Object must be an ASN1OctetString, was ${this::class.simpleName}" + ) } return AttestationApplicationId.from(ASN1Sequence.getInstance(this.octets)) } @RequiresApi(24) private fun ASN1Encodable.toAuthorizationList(logFn: (String) -> Unit): AuthorizationList { - check(this is ASN1Sequence) { "Object must be an ASN1Sequence, was ${this::class.simpleName}" } + if (this !is ASN1Sequence) { + throw ExtensionParsingException("Object must be an ASN1Sequence, was ${this::class.simpleName}") + } return AuthorizationList.from(this, logFn) } +@RequiresApi(24) private fun ASN1Encodable.toBoolean(): Boolean { - check(this is ASN1Boolean) { "Must be an ASN1Boolean, was ${this::class.simpleName}" } + if (this !is ASN1Boolean) { + throw ExtensionParsingException("Must be an ASN1Boolean, was ${this::class.simpleName}") + } return this.isTrue } +@RequiresApi(24) private fun ASN1Encodable.toByteArray(): ByteArray { - check(this is ASN1OctetString) { "Must be an ASN1OctetString, was ${this::class.simpleName}" } + if (this !is ASN1OctetString) { + throw ExtensionParsingException("Must be an ASN1OctetString, was ${this::class.simpleName}") + } return this.octets } -private fun ASN1Encodable.toByteBuffer() = ByteBuffer.wrap(this.toByteArray()) +@RequiresApi(24) private fun ASN1Encodable.toByteBuffer() = ByteBuffer.wrap(this.toByteArray()) -private fun ASN1Encodable.toByteString() = ByteString.copyFrom(this.toByteArray()) +@RequiresApi(24) private fun ASN1Encodable.toByteString() = ByteString.copyFrom(this.toByteArray()) +@RequiresApi(24) private fun ASN1Encodable.toEnumerated(): ASN1Enumerated { - check(this is ASN1Enumerated) { "Must be an ASN1Enumerated, was ${this::class.simpleName}" } + if (this !is ASN1Enumerated) { + throw ExtensionParsingException("Must be an ASN1Enumerated, was ${this::class.simpleName}") + } return this } +@RequiresApi(24) private fun ASN1Encodable.toInt(): BigInteger { - check(this is ASN1Integer) { "Must be an ASN1Integer, was ${this::class.simpleName}" } + if (this !is ASN1Integer) { + throw ExtensionParsingException("Must be an ASN1Integer, was ${this::class.simpleName}") + } return this.value } @@ -686,28 +734,41 @@ private fun ASN1Encodable.toPatchLevel( logFn: (String) -> Unit = { _ -> }, ): PatchLevel? = PatchLevel.from(this, partitionName, logFn) +@RequiresApi(24) private fun ASN1Encodable.toRootOfTrust(): RootOfTrust { - check(this is ASN1Sequence) { "Object must be an ASN1Sequence, was ${this::class.simpleName}" } + if (this !is ASN1Sequence) { + throw ExtensionParsingException("Object must be an ASN1Sequence, was ${this::class.simpleName}") + } return RootOfTrust.from(this) } +@RequiresApi(24) private fun ASN1Encodable.toSecurityLevel(): SecurityLevel = SecurityLevel.values().firstOrNull { it.value.toBigInteger() == this.toEnumerated().value } ?: throw IllegalStateException("unknown value: ${this.toEnumerated().value}") +@RequiresApi(24) private fun ASN1Encodable.toOrigin(): Origin = Origin.values().firstOrNull { it.value.toBigInteger() == this.toInt() } ?: throw IllegalStateException("unknown value: ${this.toInt()}") +@RequiresApi(24) private inline fun ASN1Encodable.toSet(): Set { - check(this is ASN1Set) { "Object must be an ASN1Set, was ${this::class.simpleName}" } + if (this !is ASN1Set) { + throw ExtensionParsingException("Object must be an ASN1Set, was ${this::class.simpleName}") + } return this.map { - check(it is T) { "Object must be a ${T::class.simpleName}, was ${this::class.simpleName}" } + if (it !is T) { + throw ExtensionParsingException( + "Object must be a ${T::class.simpleName}, was ${this::class.simpleName}" + ) + } it } .toSet() } +@RequiresApi(24) private fun ASN1Encodable.toStr() = UTF_8.newDecoder() .onMalformedInput(CodingErrorAction.REPORT) @@ -763,12 +824,12 @@ fun Int.asDataItem() = fun String.asDataItem() = UnicodeString(this) -private fun DataItem.asMap(): Map { +private fun DataItem.asMap(): CborModelMap { if (this.majorType != MajorType.MAP) { throw CborException("Expected a map, got ${this.majorType.name}") } @Suppress("UNCHECKED_CAST") - return this as Map + return this as CborModelMap } fun DataItem.asUnicodeString(): UnicodeString { diff --git a/src/main/kotlin/testing/Certs.kt b/src/main/kotlin/testing/Certs.kt index 5bc1255..71f1de7 100644 --- a/src/main/kotlin/testing/Certs.kt +++ b/src/main/kotlin/testing/Certs.kt @@ -533,3 +533,38 @@ object Chains { ) } } + +object V3Extensions { + private fun ASN1Encodable.toTaggedObject(tag: KeyMintTag) = DERTaggedObject(tag.value, this) + + private val partialAuthorizationList: DERSequence = + DERSequence( + arrayOf( + ASN1Integer(1).toTaggedObject(KeyMintTag.ALGORITHM), + ASN1Integer(2).toTaggedObject(KeyMintTag.KEY_SIZE), + ) + ) + + private val partialMalformedAuthorizationList: DERSequence = + DERSequence( + arrayOf( + ASN1Integer(1).toTaggedObject(KeyMintTag.ALGORITHM), + DEROctetString(ByteArray(0)).toTaggedObject(KeyMintTag.KEY_SIZE), + ) + ) + + val keyDescriptionWithMalformedSoftwareAuthorizations: ByteArray = + DERSequence( + arrayOf( + ASN1Integer(1), // attestationVersion + ASN1Enumerated(SecurityLevel.SOFTWARE.value), // attestationSecurityLevel + ASN1Integer(1), // keyMintVersion + ASN1Enumerated(SecurityLevel.SOFTWARE.value), // keyMintSecurityLevel + DEROctetString(ByteArray(0)), // attestationChallenge + DEROctetString(ByteArray(0)), // uniqueId + partialMalformedAuthorizationList, // softwareEnforced + partialAuthorizationList, // hardwareEnforced + ) + ) + .encoded +} diff --git a/src/test/kotlin/ExtensionTest.kt b/src/test/kotlin/ExtensionTest.kt index dea5418..8983a21 100644 --- a/src/test/kotlin/ExtensionTest.kt +++ b/src/test/kotlin/ExtensionTest.kt @@ -17,8 +17,10 @@ package com.android.keyattestation.verifier import com.android.keyattestation.verifier.testing.Chains +import com.android.keyattestation.verifier.testing.FakeLogHook import com.android.keyattestation.verifier.testing.TestUtils.TESTDATA_PATH import com.android.keyattestation.verifier.testing.TestUtils.readCertPath +import com.android.keyattestation.verifier.testing.V3Extensions import com.android.keyattestation.verifier.testing.toKeyDescription import com.google.common.truth.Truth.assertThat @@ -82,7 +84,7 @@ class ExtensionTest { } @Test - @Ignore("TODO: b/356172932 - Reenable test once enabling tag order validator is configurable.") + @Ignore("TODO(google-internal bug): Reenable test once enabling tag order validator is configurable.") fun parseFrom_tagsNotInAscendingOrder_Throws() { assertFailsWith { KeyDescription.parseFrom(readCertPath("invalid/tags_not_in_ascending_order.pem").leafCert()) @@ -156,4 +158,38 @@ class ExtensionTest { ) assertThat(KeyDescription.parseFrom(keyDescription.encodeToAsn1())).isEqualTo(keyDescription) } + + @Test + fun keyDescriptionParseFrom_partialAuthorizationListExtension_success() { + val authorizationList = + AuthorizationList(purposes = setOf(1.toBigInteger()), algorithms = 1.toBigInteger()) + val keyDescription = + KeyDescription( + attestationVersion = 1.toBigInteger(), + attestationSecurityLevel = SecurityLevel.SOFTWARE, + keyMintVersion = 1.toBigInteger(), + keyMintSecurityLevel = SecurityLevel.SOFTWARE, + attestationChallenge = ByteString.empty(), + uniqueId = ByteString.empty(), + softwareEnforced = authorizationList, + hardwareEnforced = authorizationList, + ) + assertThat(KeyDescription.parseFrom(keyDescription.encodeToAsn1())).isEqualTo(keyDescription) + } + + @Test + fun keyDescriptionParseFrom_malformedAuthorizationListExtension_successAndLogs() { + val logHook = FakeLogHook() + assertThat( + KeyDescription.parseFrom( + V3Extensions.keyDescriptionWithMalformedSoftwareAuthorizations, + logFn = logHook.fakeVerifyRequestLog::logInfoMessage, + ) + .softwareEnforced + .keySize + ) + .isNull() + assertThat(logHook.fakeVerifyRequestLog.infoMessages) + .contains("Exception when parsing key_size: Must be an ASN1Integer, was DEROctetString") + } }