Skip to content

Commit 231788b

Browse files
suzannajiwanicopybara-github
authored andcommitted
Be more lenient with malformed AttestationList extensions.
Move casting logic into a runCatching block. If the value is malformed, this will prevent errors from being thrown up the stack. Instead, log and set the value to null PiperOrigin-RevId: 831599286
1 parent ab49edf commit 231788b

File tree

3 files changed

+184
-52
lines changed

3 files changed

+184
-52
lines changed

src/main/kotlin/Extension.kt

Lines changed: 112 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import co.nstant.`in`.cbor.CborEncoder
2222
import co.nstant.`in`.cbor.CborException
2323
import co.nstant.`in`.cbor.model.DataItem
2424
import co.nstant.`in`.cbor.model.MajorType
25-
import co.nstant.`in`.cbor.model.Map
25+
import co.nstant.`in`.cbor.model.Map as CborModelMap
2626
import co.nstant.`in`.cbor.model.NegativeInteger
2727
import co.nstant.`in`.cbor.model.SimpleValue
2828
import co.nstant.`in`.cbor.model.SimpleValueType
@@ -68,7 +68,7 @@ data class ProvisioningInfoMap(
6868
val certificatesIssued: Int,
6969
) {
7070
fun cborEncode(): ByteArray {
71-
val map = Map()
71+
val map = CborModelMap()
7272
map.put(UnsignedInteger(1L), certificatesIssued.asDataItem())
7373
return cborEncode(map)
7474
}
@@ -93,7 +93,7 @@ data class ProvisioningInfoMap(
9393
throw IllegalArgumentException(e)
9494
}
9595

96-
private fun from(seq: Map): ProvisioningInfoMap {
96+
private fun from(seq: CborModelMap): ProvisioningInfoMap {
9797
require(seq.keys.size >= 1)
9898
return ProvisioningInfoMap(
9999
certificatesIssued = seq.get(UnsignedInteger(1L)).asInteger(),
@@ -413,6 +413,33 @@ data class AuthorizationList(
413413
.let { DERSequence(it.toTypedArray()) }
414414

415415
internal companion object {
416+
417+
private class ASN1Converter(
418+
private val objects: Map<KeyMintTag, ASN1Encodable>,
419+
private val logFn: (String) -> Unit,
420+
) {
421+
fun <T> parse(tag: KeyMintTag, transform: (ASN1Encodable) -> T): T? {
422+
return try {
423+
objects[tag]?.let(transform)
424+
} catch (e: ExtensionParsingException) {
425+
logFn("Exception when parsing ${tag.name.lowercase()}: ${e.message}")
426+
null
427+
}
428+
}
429+
430+
fun parseInt(tag: KeyMintTag) = parse(tag) { it.toInt() }
431+
432+
fun parseIntSet(tag: KeyMintTag) =
433+
parse(tag) { it.toSet<ASN1Integer>().map { innerIt -> innerIt.value }.toSet() }
434+
435+
fun parseStr(tag: KeyMintTag) = parse(tag) { it.toStr() }
436+
437+
fun parseByteString(tag: KeyMintTag) = parse(tag) { it.toByteString() }
438+
439+
fun parsePatchLevel(tag: KeyMintTag, partition: String) =
440+
parse(tag) { it.toPatchLevel(partition, logFn) }
441+
}
442+
416443
fun from(seq: ASN1Sequence, logFn: (String) -> Unit = { _ -> }): AuthorizationList {
417444
val objects =
418445
seq.associate {
@@ -437,48 +464,49 @@ data class AuthorizationList(
437464
logFn("AuthorizationList tags should appear in ascending order")
438465
}
439466

467+
val converter = ASN1Converter(objects, logFn)
440468
return AuthorizationList(
441-
purposes = objects[KeyMintTag.PURPOSE]?.toSet<ASN1Integer>()?.map { it.value }?.toSet(),
442-
algorithms = objects[KeyMintTag.ALGORITHM]?.toInt(),
443-
keySize = objects[KeyMintTag.KEY_SIZE]?.toInt(),
444-
blockModes =
445-
objects[KeyMintTag.BLOCK_MODE]?.toSet<ASN1Integer>()?.map { it.value }?.toSet(),
446-
digests = objects[KeyMintTag.DIGEST]?.toSet<ASN1Integer>()?.map { it.value }?.toSet(),
447-
paddings = objects[KeyMintTag.PADDING]?.toSet<ASN1Integer>()?.map { it.value }?.toSet(),
448-
ecCurve = objects[KeyMintTag.EC_CURVE]?.toInt(),
449-
rsaPublicExponent = objects[KeyMintTag.RSA_PUBLIC_EXPONENT]?.toInt(),
450-
rsaOaepMgfDigests =
451-
objects[KeyMintTag.RSA_OAEP_MGF_DIGEST]?.toSet<ASN1Integer>()?.map { it.value }?.toSet(),
452-
activeDateTime = objects[KeyMintTag.ACTIVE_DATE_TIME]?.toInt(),
453-
originationExpireDateTime = objects[KeyMintTag.ORIGINATION_EXPIRE_DATE_TIME]?.toInt(),
454-
usageExpireDateTime = objects[KeyMintTag.USAGE_EXPIRE_DATE_TIME]?.toInt(),
469+
purposes = converter.parseIntSet(KeyMintTag.PURPOSE),
470+
algorithms = converter.parseInt(KeyMintTag.ALGORITHM),
471+
keySize = converter.parseInt(KeyMintTag.KEY_SIZE),
472+
blockModes = converter.parseIntSet(KeyMintTag.BLOCK_MODE),
473+
digests = converter.parseIntSet(KeyMintTag.DIGEST),
474+
paddings = converter.parseIntSet(KeyMintTag.PADDING),
475+
ecCurve = converter.parseInt(KeyMintTag.EC_CURVE),
476+
rsaPublicExponent = converter.parseInt(KeyMintTag.RSA_PUBLIC_EXPONENT),
477+
rsaOaepMgfDigests = converter.parseIntSet(KeyMintTag.RSA_OAEP_MGF_DIGEST),
478+
activeDateTime = converter.parseInt(KeyMintTag.ACTIVE_DATE_TIME),
479+
originationExpireDateTime = converter.parseInt(KeyMintTag.ORIGINATION_EXPIRE_DATE_TIME),
480+
usageExpireDateTime = converter.parseInt(KeyMintTag.USAGE_EXPIRE_DATE_TIME),
455481
noAuthRequired = if (objects.containsKey(KeyMintTag.NO_AUTH_REQUIRED)) true else null,
456-
userAuthType = objects[KeyMintTag.USER_AUTH_TYPE]?.toInt(),
457-
authTimeout = objects[KeyMintTag.AUTH_TIMEOUT]?.toInt(),
482+
userAuthType = converter.parseInt(KeyMintTag.USER_AUTH_TYPE),
483+
authTimeout = converter.parseInt(KeyMintTag.AUTH_TIMEOUT),
458484
trustedUserPresenceRequired =
459485
if (objects.containsKey(KeyMintTag.TRUSTED_USER_PRESENCE_REQUIRED)) true else null,
460486
unlockedDeviceRequired =
461487
if (objects.containsKey(KeyMintTag.UNLOCKED_DEVICE_REQUIRED)) true else null,
462-
creationDateTime = objects[KeyMintTag.CREATION_DATE_TIME]?.toInt(),
463-
origin = objects[KeyMintTag.ORIGIN]?.toOrigin(),
488+
creationDateTime = converter.parseInt(KeyMintTag.CREATION_DATE_TIME),
489+
origin = converter.parse(KeyMintTag.ORIGIN) { it.toOrigin() },
464490
rollbackResistant = if (objects.containsKey(KeyMintTag.ROLLBACK_RESISTANT)) true else null,
465-
rootOfTrust = objects[KeyMintTag.ROOT_OF_TRUST]?.toRootOfTrust(),
466-
osVersion = objects[KeyMintTag.OS_VERSION]?.toInt(),
467-
osPatchLevel = objects[KeyMintTag.OS_PATCH_LEVEL]?.toPatchLevel("OS", logFn),
491+
rootOfTrust = converter.parse(KeyMintTag.ROOT_OF_TRUST) { it.toRootOfTrust() },
492+
osVersion = converter.parseInt(KeyMintTag.OS_VERSION),
493+
osPatchLevel = converter.parsePatchLevel(KeyMintTag.OS_PATCH_LEVEL, "OS"),
468494
attestationApplicationId =
469-
objects[KeyMintTag.ATTESTATION_APPLICATION_ID]?.toAttestationApplicationId(),
470-
attestationIdBrand = objects[KeyMintTag.ATTESTATION_ID_BRAND]?.toStr(),
471-
attestationIdDevice = objects[KeyMintTag.ATTESTATION_ID_DEVICE]?.toStr(),
472-
attestationIdProduct = objects[KeyMintTag.ATTESTATION_ID_PRODUCT]?.toStr(),
473-
attestationIdSerial = objects[KeyMintTag.ATTESTATION_ID_SERIAL]?.toStr(),
474-
attestationIdImei = objects[KeyMintTag.ATTESTATION_ID_IMEI]?.toStr(),
475-
attestationIdMeid = objects[KeyMintTag.ATTESTATION_ID_MEID]?.toStr(),
476-
attestationIdManufacturer = objects[KeyMintTag.ATTESTATION_ID_MANUFACTURER]?.toStr(),
477-
attestationIdModel = objects[KeyMintTag.ATTESTATION_ID_MODEL]?.toStr(),
478-
vendorPatchLevel = objects[KeyMintTag.VENDOR_PATCH_LEVEL]?.toPatchLevel("vendor", logFn),
479-
bootPatchLevel = objects[KeyMintTag.BOOT_PATCH_LEVEL]?.toPatchLevel("boot", logFn),
480-
attestationIdSecondImei = objects[KeyMintTag.ATTESTATION_ID_SECOND_IMEI]?.toStr(),
481-
moduleHash = objects[KeyMintTag.MODULE_HASH]?.toByteString(),
495+
converter.parse(KeyMintTag.ATTESTATION_APPLICATION_ID) {
496+
it.toAttestationApplicationId()
497+
},
498+
attestationIdBrand = converter.parseStr(KeyMintTag.ATTESTATION_ID_BRAND),
499+
attestationIdDevice = converter.parseStr(KeyMintTag.ATTESTATION_ID_DEVICE),
500+
attestationIdProduct = converter.parseStr(KeyMintTag.ATTESTATION_ID_PRODUCT),
501+
attestationIdSerial = converter.parseStr(KeyMintTag.ATTESTATION_ID_SERIAL),
502+
attestationIdImei = converter.parseStr(KeyMintTag.ATTESTATION_ID_IMEI),
503+
attestationIdMeid = converter.parseStr(KeyMintTag.ATTESTATION_ID_MEID),
504+
attestationIdManufacturer = converter.parseStr(KeyMintTag.ATTESTATION_ID_MANUFACTURER),
505+
attestationIdModel = converter.parseStr(KeyMintTag.ATTESTATION_ID_MODEL),
506+
vendorPatchLevel = converter.parsePatchLevel(KeyMintTag.VENDOR_PATCH_LEVEL, "vendor"),
507+
bootPatchLevel = converter.parsePatchLevel(KeyMintTag.BOOT_PATCH_LEVEL, "boot"),
508+
attestationIdSecondImei = converter.parseStr(KeyMintTag.ATTESTATION_ID_SECOND_IMEI),
509+
moduleHash = converter.parseByteString(KeyMintTag.MODULE_HASH),
482510
)
483511
}
484512
}
@@ -535,6 +563,7 @@ data class PatchLevel(val yearMonth: YearMonth, val version: Int? = null) {
535563
* https://source.android.com/docs/security/features/keystore/attestation#attestationapplicationid-schema
536564
*/
537565
@Immutable
566+
@RequiresApi(24)
538567
data class AttestationApplicationId(
539568
@SuppressWarnings("Immutable") val packages: Set<AttestationPackageInfo>,
540569
@SuppressWarnings("Immutable") val signatures: Set<ByteString>,
@@ -566,6 +595,7 @@ data class AttestationApplicationId(
566595
* @see
567596
* https://source.android.com/docs/security/features/keystore/attestation#attestationapplicationid-schema
568597
*/
598+
@RequiresApi(24)
569599
data class AttestationPackageInfo(val name: String, val version: BigInteger) {
570600
fun toAsn1() =
571601
buildList {
@@ -593,6 +623,7 @@ data class AttestationPackageInfo(val name: String, val version: BigInteger) {
593623
* @see https://source.android.com/docs/security/features/keystore/attestation#rootoftrust-fields
594624
*/
595625
@Immutable
626+
@RequiresApi(24)
596627
data class RootOfTrust(
597628
val verifiedBootKey: ByteString,
598629
val deviceLocked: Boolean,
@@ -644,40 +675,57 @@ enum class VerifiedBootState(val value: Int) {
644675
}
645676
}
646677

678+
@RequiresApi(24)
647679
private fun ASN1Encodable.toAttestationApplicationId(): AttestationApplicationId {
648-
require(this is ASN1OctetString) {
649-
"Object must be an ASN1OctetString, was ${this::class.simpleName}"
680+
if (this !is ASN1OctetString) {
681+
throw ExtensionParsingException(
682+
"Object must be an ASN1OctetString, was ${this::class.simpleName}"
683+
)
650684
}
651685
return AttestationApplicationId.from(ASN1Sequence.getInstance(this.octets))
652686
}
653687

654688
@RequiresApi(24)
655689
private fun ASN1Encodable.toAuthorizationList(logFn: (String) -> Unit): AuthorizationList {
656-
check(this is ASN1Sequence) { "Object must be an ASN1Sequence, was ${this::class.simpleName}" }
690+
if (this !is ASN1Sequence) {
691+
throw ExtensionParsingException("Object must be an ASN1Sequence, was ${this::class.simpleName}")
692+
}
657693
return AuthorizationList.from(this, logFn)
658694
}
659695

696+
@RequiresApi(24)
660697
private fun ASN1Encodable.toBoolean(): Boolean {
661-
check(this is ASN1Boolean) { "Must be an ASN1Boolean, was ${this::class.simpleName}" }
698+
if (this !is ASN1Boolean) {
699+
throw ExtensionParsingException("Must be an ASN1Boolean, was ${this::class.simpleName}")
700+
}
662701
return this.isTrue
663702
}
664703

704+
@RequiresApi(24)
665705
private fun ASN1Encodable.toByteArray(): ByteArray {
666-
check(this is ASN1OctetString) { "Must be an ASN1OctetString, was ${this::class.simpleName}" }
706+
if (this !is ASN1OctetString) {
707+
throw ExtensionParsingException("Must be an ASN1OctetString, was ${this::class.simpleName}")
708+
}
667709
return this.octets
668710
}
669711

670-
private fun ASN1Encodable.toByteBuffer() = ByteBuffer.wrap(this.toByteArray())
712+
@RequiresApi(24) private fun ASN1Encodable.toByteBuffer() = ByteBuffer.wrap(this.toByteArray())
671713

672-
private fun ASN1Encodable.toByteString() = ByteString.copyFrom(this.toByteArray())
714+
@RequiresApi(24) private fun ASN1Encodable.toByteString() = ByteString.copyFrom(this.toByteArray())
673715

716+
@RequiresApi(24)
674717
private fun ASN1Encodable.toEnumerated(): ASN1Enumerated {
675-
check(this is ASN1Enumerated) { "Must be an ASN1Enumerated, was ${this::class.simpleName}" }
718+
if (this !is ASN1Enumerated) {
719+
throw ExtensionParsingException("Must be an ASN1Enumerated, was ${this::class.simpleName}")
720+
}
676721
return this
677722
}
678723

724+
@RequiresApi(24)
679725
private fun ASN1Encodable.toInt(): BigInteger {
680-
check(this is ASN1Integer) { "Must be an ASN1Integer, was ${this::class.simpleName}" }
726+
if (this !is ASN1Integer) {
727+
throw ExtensionParsingException("Must be an ASN1Integer, was ${this::class.simpleName}")
728+
}
681729
return this.value
682730
}
683731

@@ -686,28 +734,41 @@ private fun ASN1Encodable.toPatchLevel(
686734
logFn: (String) -> Unit = { _ -> },
687735
): PatchLevel? = PatchLevel.from(this, partitionName, logFn)
688736

737+
@RequiresApi(24)
689738
private fun ASN1Encodable.toRootOfTrust(): RootOfTrust {
690-
check(this is ASN1Sequence) { "Object must be an ASN1Sequence, was ${this::class.simpleName}" }
739+
if (this !is ASN1Sequence) {
740+
throw ExtensionParsingException("Object must be an ASN1Sequence, was ${this::class.simpleName}")
741+
}
691742
return RootOfTrust.from(this)
692743
}
693744

745+
@RequiresApi(24)
694746
private fun ASN1Encodable.toSecurityLevel(): SecurityLevel =
695747
SecurityLevel.values().firstOrNull { it.value.toBigInteger() == this.toEnumerated().value }
696748
?: throw IllegalStateException("unknown value: ${this.toEnumerated().value}")
697749

750+
@RequiresApi(24)
698751
private fun ASN1Encodable.toOrigin(): Origin =
699752
Origin.values().firstOrNull { it.value.toBigInteger() == this.toInt() }
700753
?: throw IllegalStateException("unknown value: ${this.toInt()}")
701754

755+
@RequiresApi(24)
702756
private inline fun <reified T> ASN1Encodable.toSet(): Set<T> {
703-
check(this is ASN1Set) { "Object must be an ASN1Set, was ${this::class.simpleName}" }
757+
if (this !is ASN1Set) {
758+
throw ExtensionParsingException("Object must be an ASN1Set, was ${this::class.simpleName}")
759+
}
704760
return this.map {
705-
check(it is T) { "Object must be a ${T::class.simpleName}, was ${this::class.simpleName}" }
761+
if (it !is T) {
762+
throw ExtensionParsingException(
763+
"Object must be a ${T::class.simpleName}, was ${this::class.simpleName}"
764+
)
765+
}
706766
it
707767
}
708768
.toSet()
709769
}
710770

771+
@RequiresApi(24)
711772
private fun ASN1Encodable.toStr() =
712773
UTF_8.newDecoder()
713774
.onMalformedInput(CodingErrorAction.REPORT)
@@ -763,12 +824,12 @@ fun Int.asDataItem() =
763824

764825
fun String.asDataItem() = UnicodeString(this)
765826

766-
private fun DataItem.asMap(): Map {
827+
private fun DataItem.asMap(): CborModelMap {
767828
if (this.majorType != MajorType.MAP) {
768829
throw CborException("Expected a map, got ${this.majorType.name}")
769830
}
770831
@Suppress("UNCHECKED_CAST")
771-
return this as Map
832+
return this as CborModelMap
772833
}
773834

774835
fun DataItem.asUnicodeString(): UnicodeString {

src/main/kotlin/testing/Certs.kt

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,3 +533,38 @@ object Chains {
533533
)
534534
}
535535
}
536+
537+
object V3Extensions {
538+
private fun ASN1Encodable.toTaggedObject(tag: KeyMintTag) = DERTaggedObject(tag.value, this)
539+
540+
private val partialAuthorizationList: DERSequence =
541+
DERSequence(
542+
arrayOf(
543+
ASN1Integer(1).toTaggedObject(KeyMintTag.ALGORITHM),
544+
ASN1Integer(2).toTaggedObject(KeyMintTag.KEY_SIZE),
545+
)
546+
)
547+
548+
private val partialMalformedAuthorizationList: DERSequence =
549+
DERSequence(
550+
arrayOf(
551+
ASN1Integer(1).toTaggedObject(KeyMintTag.ALGORITHM),
552+
DEROctetString(ByteArray(0)).toTaggedObject(KeyMintTag.KEY_SIZE),
553+
)
554+
)
555+
556+
val keyDescriptionWithMalformedSoftwareAuthorizations: ByteArray =
557+
DERSequence(
558+
arrayOf(
559+
ASN1Integer(1), // attestationVersion
560+
ASN1Enumerated(SecurityLevel.SOFTWARE.value), // attestationSecurityLevel
561+
ASN1Integer(1), // keyMintVersion
562+
ASN1Enumerated(SecurityLevel.SOFTWARE.value), // keyMintSecurityLevel
563+
DEROctetString(ByteArray(0)), // attestationChallenge
564+
DEROctetString(ByteArray(0)), // uniqueId
565+
partialMalformedAuthorizationList, // softwareEnforced
566+
partialAuthorizationList, // hardwareEnforced
567+
)
568+
)
569+
.encoded
570+
}

src/test/kotlin/ExtensionTest.kt

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
package com.android.keyattestation.verifier
1818

1919
import com.android.keyattestation.verifier.testing.Chains
20+
import com.android.keyattestation.verifier.testing.FakeLogHook
2021
import com.android.keyattestation.verifier.testing.TestUtils.TESTDATA_PATH
2122
import com.android.keyattestation.verifier.testing.TestUtils.readCertPath
23+
import com.android.keyattestation.verifier.testing.V3Extensions
2224
import com.android.keyattestation.verifier.testing.toKeyDescription
2325
import com.google.common.truth.Truth.assertThat
2426

@@ -82,7 +84,7 @@ class ExtensionTest {
8284
}
8385

8486
@Test
85-
@Ignore("TODO: b/356172932 - Reenable test once enabling tag order validator is configurable.")
87+
@Ignore("TODO(google-internal bug): Reenable test once enabling tag order validator is configurable.")
8688
fun parseFrom_tagsNotInAscendingOrder_Throws() {
8789
assertFailsWith<IllegalArgumentException> {
8890
KeyDescription.parseFrom(readCertPath("invalid/tags_not_in_ascending_order.pem").leafCert())
@@ -156,4 +158,38 @@ class ExtensionTest {
156158
)
157159
assertThat(KeyDescription.parseFrom(keyDescription.encodeToAsn1())).isEqualTo(keyDescription)
158160
}
161+
162+
@Test
163+
fun keyDescriptionParseFrom_partialAuthorizationListExtension_success() {
164+
val authorizationList =
165+
AuthorizationList(purposes = setOf(1.toBigInteger()), algorithms = 1.toBigInteger())
166+
val keyDescription =
167+
KeyDescription(
168+
attestationVersion = 1.toBigInteger(),
169+
attestationSecurityLevel = SecurityLevel.SOFTWARE,
170+
keyMintVersion = 1.toBigInteger(),
171+
keyMintSecurityLevel = SecurityLevel.SOFTWARE,
172+
attestationChallenge = ByteString.empty(),
173+
uniqueId = ByteString.empty(),
174+
softwareEnforced = authorizationList,
175+
hardwareEnforced = authorizationList,
176+
)
177+
assertThat(KeyDescription.parseFrom(keyDescription.encodeToAsn1())).isEqualTo(keyDescription)
178+
}
179+
180+
@Test
181+
fun keyDescriptionParseFrom_malformedAuthorizationListExtension_successAndLogs() {
182+
val logHook = FakeLogHook()
183+
assertThat(
184+
KeyDescription.parseFrom(
185+
V3Extensions.keyDescriptionWithMalformedSoftwareAuthorizations,
186+
logFn = logHook.fakeVerifyRequestLog::logInfoMessage,
187+
)
188+
.softwareEnforced
189+
.keySize
190+
)
191+
.isNull()
192+
assertThat(logHook.fakeVerifyRequestLog.infoMessages)
193+
.contains("Exception when parsing key_size: Must be an ASN1Integer, was DEROctetString")
194+
}
159195
}

0 commit comments

Comments
 (0)