diff --git a/build.gradle b/build.gradle index d8ab701..fe82960 100644 --- a/build.gradle +++ b/build.gradle @@ -11,6 +11,10 @@ dependencies { //compile 'net.portswigger.burp.extender:burp-extender-api:1.7.13' implementation 'org.apache.commons:commons-text:1.9' implementation files('bulkScan-all.jar') // this contains albinowaxUtils + + testImplementation( + 'org.junit.jupiter:junit-jupiter:5.10.5', + ) } sourceSets { @@ -22,6 +26,15 @@ sourceSets { srcDir 'resources' } } + test { + java { + srcDir 'test' + } + } +} + +tasks.test { + useJUnitPlatform() } archivesBaseName = ('active-scan-plus-plus-all') diff --git a/src/burp/XMLScan.java b/src/burp/XMLScan.java index 31377a7..8eab923 100644 --- a/src/burp/XMLScan.java +++ b/src/burp/XMLScan.java @@ -4,35 +4,10 @@ import burp.api.montoya.http.message.requests.HttpRequest; import burp.api.montoya.http.message.responses.analysis.Attribute; import burp.api.montoya.http.message.responses.analysis.AttributeType; -import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; -import org.w3c.dom.Attr; -import org.w3c.dom.Document; -import org.w3c.dom.Node; -import org.w3c.dom.NodeList; -import javax.xml.parsers.DocumentBuilder; -import javax.xml.parsers.DocumentBuilderFactory; import javax.xml.parsers.ParserConfigurationException; -import javax.xml.transform.OutputKeys; -import javax.xml.transform.Transformer; -import javax.xml.transform.TransformerFactory; -import javax.xml.transform.dom.DOMSource; -import javax.xml.transform.stream.StreamResult; -import javax.xml.xpath.XPath; -import javax.xml.xpath.XPathConstants; -import javax.xml.xpath.XPathExpression; -import javax.xml.xpath.XPathFactory; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.StringWriter; -import java.net.URLDecoder; -import java.nio.charset.StandardCharsets; import java.util.*; -import java.util.zip.DataFormatException; -import java.util.zip.Deflater; -import java.util.zip.Inflater; import static burp.Utilities.helpers; @@ -40,20 +15,16 @@ public class XMLScan extends ParamScan { private final Map checks; private final Set ATTRIBUTES; private final int confirmCount; - private boolean isCompressed; - private boolean isBase64Encoded; public XMLScan(String name) { super(name); this.checks = new HashMap<>(); - this.checks.put("DOCTYPE", new CheckDetails(this::detectUnsafeDOCTYPE, + this.checks.put("DOCTYPE", new CheckDetails(XMLUtilities.SAMLDocument::detectUnsafeDOCTYPE, List.of("https://portswigger.net/research/saml-roulette-the-hacker-always-wins"))); - this.checks.put("ENTITY", new CheckDetails(this::detectUnsafeENTITIES, + this.checks.put("ENTITY", new CheckDetails(XMLUtilities.SAMLDocument::detectUnsafeENTITIES, List.of("https://portswigger.net/research/saml-roulette-the-hacker-always-wins"))); this.confirmCount = 2; - this.isCompressed = false; - this.isBase64Encoded = false; this.ATTRIBUTES = new HashSet<>(); this.ATTRIBUTES.addAll(Set.of(AttributeType.values())); this.ATTRIBUTES.removeAll(Set.of( @@ -78,41 +49,6 @@ public static List getUniqueAttributeTypes(List firstA return mismatchedTypes; } - private Pair detectUnsafeDOCTYPE(Document document) { - if (document == null || document.getDoctype() != null) { - throw new IllegalArgumentException(); - } - String str = "" + transformDocument(document); - return new ImmutablePair<>(compressIfNeeded(str), ""); - } - - private Pair detectUnsafeENTITIES(Document document) { - if (document == null || document.getDoctype() != null) { - throw new IllegalArgumentException(); - } - try { - XPathFactory xPathFactory = XPathFactory.newInstance(); - XPath xpath = xPathFactory.newXPath(); - XPathExpression expr = xpath.compile("//*[@ID]"); - - Node node = (Node) expr.evaluate(document, XPathConstants.NODE); - if (node != null && node.getAttributes() != null) { - Attr idAttr = (Attr) node.getAttributes().getNamedItem("ID"); - if (idAttr != null) { - String uuid = idAttr.getValue(); - idAttr.setValue("PLACEHOLDER_UUID"); - String str = String.format(" ]>", uuid); - str += transformDocument(document); - str = str.replace("PLACEHOLDER_UUID", "&uuid;"); - return new ImmutablePair<>(compressIfNeeded(str), ""); - } - } - throw new IllegalArgumentException(); - } catch (Exception e) { - throw new IllegalArgumentException(); - } - } - private boolean areAttributesIdentical(List firstAttributes, List secondAttributes) { if (firstAttributes.size() != secondAttributes.size()) { return false; @@ -134,7 +70,7 @@ public List doActiveScan(IHttpRequestResponse basePair, IScannerInse String insertionPointName = insertionPoint.getInsertionPointName(); if (!(insertionPointName.equalsIgnoreCase("SAMLRequest") || insertionPointName.equalsIgnoreCase("SAMLResponse"))) return null; - Optional document = extractOptionalXMLDocument(base); + Optional document = XMLUtilities.SAMLDocument.parse(base); if (document.isEmpty()) return null; List issues = new ArrayList<>(); @@ -152,7 +88,7 @@ public List doActiveScan(IHttpRequestResponse basePair, IScannerInse if (unique.isEmpty()) { // Skip target as unpredictable return null; - }; + } originalAttributes = Utilities.buildMontoyaResp(new Resp(basePair)).response().attributes(unique.toArray(new AttributeType[]{})); @@ -160,20 +96,15 @@ public List doActiveScan(IHttpRequestResponse basePair, IScannerInse Map.Entry entry = checksCopy.entrySet().iterator().next(); checksCopy.remove(entry.getKey()); String name = entry.getKey(); - Check check = entry.getValue().getTransformation(); - List links = entry.getValue().getLinks(); + Check check = entry.getValue().transformation(); + List links = entry.getValue().links(); String probe; try { - Document copy = DocumentBuilderFactory.newInstance() - .newDocumentBuilder() - .newDocument(); - copy.appendChild(copy.importNode(document.get().getDocumentElement(), true)); - Pair result = check.apply(copy); + Pair result = check.apply(document.get().copy()); probe = result.getKey(); } catch (IllegalArgumentException | ParserConfigurationException e) { continue; } - Utilities.log("Trying " + probe); for (int attempt = 0; attempt < this.confirmCount; attempt++) { IHttpRequestResponse attack = OldUtilities.request2(basePair, insertionPoint, probe); @@ -196,173 +127,19 @@ public List doActiveScan(IHttpRequestResponse basePair, IScannerInse return issues; } - private String transformDocument(Document document) { - try { - TransformerFactory transformerFactory = TransformerFactory.newInstance(); - Transformer transformer = transformerFactory.newTransformer(); - transformer.setOutputProperty(OutputKeys.OMIT_XML_DECLARATION, "yes"); - StringWriter writer = new StringWriter(); - transformer.transform(new DOMSource(document), new StreamResult(writer)); - return writer.toString(); - } catch (Exception e) { - throw new IllegalArgumentException(e); - } - } - - public String compressIfNeeded(String data) { - byte[] resultingData = data.getBytes(StandardCharsets.UTF_8); - - if (isCompressed) { - byte[] compressedData = compress(resultingData); - if (compressedData != null) resultingData = compressedData; - } - - return isBase64Encoded - ? Base64.getEncoder().encodeToString(resultingData) - : new String(resultingData, StandardCharsets.ISO_8859_1); - - } - - private byte[] compress(byte[] input) { - Deflater deflater = new Deflater(5, true); - try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream(input.length)) { - deflater.setInput(input); - deflater.finish(); - byte[] buffer = new byte[1024]; - int maxLoops = 10000; - int loops = 0; - - while (!deflater.finished()) { - int count = deflater.deflate(buffer); - if (count == 0) { - if (loops++ >= maxLoops) { - throw new RuntimeException("Deflater made no progress — possible logic error or invalid input."); - } - } else { - loops = 0; // reset loop count on progress - outputStream.write(buffer, 0, count); - } - } - - return outputStream.toByteArray(); - } catch (IOException e) { - throw new IllegalArgumentException("Compression failed", e); - } finally { - deflater.end(); - } - } - - public byte[] decompress(byte[] data) throws DataFormatException { - Inflater inflater = new Inflater(true); - try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream(data.length)) { - inflater.setInput(data); - byte[] buffer = new byte[1024]; - int maxLoops = 10000; // prevent infinite loop - int loops = 0; - - while (!inflater.finished() && loops < maxLoops) { - int count = inflater.inflate(buffer); - if (count == 0 && inflater.needsInput()) { - break; - } - outputStream.write(buffer, 0, count); - loops++; - } - - if (loops >= maxLoops) { - throw new DataFormatException("Decompression exceeded safe iteration limit."); - } - - return outputStream.toByteArray(); - } catch (IOException e) { - throw new IllegalArgumentException(e); - } finally { - inflater.end(); - } - } - - private Optional tryURLDecode(String input) { - try { - String urlDecoded = URLDecoder.decode(input, StandardCharsets.UTF_8); - return Optional.of(urlDecoded); - } catch (Exception e) { - return Optional.empty(); - } - } - - private Optional tryBase64Decode(String input) { - try { - byte[] base64Decoded = Base64.getDecoder().decode(input); - this.isBase64Encoded = true; - return Optional.of(base64Decoded); - } catch (Exception e) { - return Optional.empty(); - } - } - - private Optional tryDecompress(byte[] input) { - try { - byte[] decompressed = decompress(input); - this.isCompressed = true; - return Optional.of(new String(decompressed, StandardCharsets.UTF_8)); - } catch (Exception e) { - return Optional.empty(); - } - } - - private Optional parseXML(String xmlString) { - try { - if (!xmlString.startsWith("<")) throw new IllegalArgumentException(); - DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance(); - factory.setFeature("http://xml.org/sax/features/external-general-entities", false); - factory.setFeature("http://xml.org/sax/features/external-parameter-entities", false); - factory.setFeature("http://apache.org/xml/features/nonvalidating/load-external-dtd", false); - factory.setFeature(javax.xml.XMLConstants.FEATURE_SECURE_PROCESSING, true); - factory.setNamespaceAware(true); - - DocumentBuilder builder = factory.newDocumentBuilder(); - try (ByteArrayInputStream inputStream = new ByteArrayInputStream(xmlString.getBytes(StandardCharsets.UTF_8))) { - Document document = builder.parse(inputStream); - return Optional.of(document); - } - } catch (Exception e) { - return Optional.empty(); - } - } - - public Optional extractOptionalXMLDocument(String input) { - String processedData = tryURLDecode(input).orElse(input); - Optional optionalBytes = tryBase64Decode(processedData); - if (optionalBytes.isPresent()) { - processedData = tryDecompress(optionalBytes.get()) - .orElse(new String(optionalBytes.get(), StandardCharsets.UTF_8)); - } - return parseXML(processedData); - } - @FunctionalInterface private interface Check { - Pair apply(Document base); + Pair apply(XMLUtilities.SAMLDocument base); } - private static class CheckDetails { - private final Check transformation; - private final List links; - - public CheckDetails(Check transformation, List usefulLinks) { - this.transformation = transformation; - this.links = usefulLinks; - } + private record CheckDetails(Check transformation, List links) { - public Check getTransformation() { - return transformation; - } - - public List getLinks() { - return links.stream() - .map(link -> String.format("%s", link, link)).toList(); + @Override + public List links() { + return links.stream() + .map(link -> String.format("%s", link, link)).toList(); + } } - } } diff --git a/src/burp/XMLUtilities.java b/src/burp/XMLUtilities.java new file mode 100644 index 0000000..1e14f7a --- /dev/null +++ b/src/burp/XMLUtilities.java @@ -0,0 +1,262 @@ +package burp; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; +import org.w3c.dom.Attr; +import org.w3c.dom.Document; +import org.w3c.dom.Node; + +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.parsers.ParserConfigurationException; +import javax.xml.transform.OutputKeys; +import javax.xml.transform.Transformer; +import javax.xml.transform.TransformerFactory; +import javax.xml.transform.dom.DOMSource; +import javax.xml.transform.stream.StreamResult; +import javax.xml.xpath.XPath; +import javax.xml.xpath.XPathConstants; +import javax.xml.xpath.XPathExpression; +import javax.xml.xpath.XPathFactory; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.StringWriter; +import java.net.URLDecoder; +import java.nio.charset.Charset; +import java.util.Base64; +import java.util.Optional; +import java.util.zip.DataFormatException; +import java.util.zip.Deflater; +import java.util.zip.Inflater; + + +public class XMLUtilities { + + public static Optional tryBase64Decode(String input) { + try { + String urlDecoded = URLDecoder.decode(input, Charset.defaultCharset()); + byte[] base64Decoded = Base64.getDecoder().decode(urlDecoded); + return Optional.of(base64Decoded); + } catch (Exception e) { + try { + byte[] base64Decoded = Base64.getDecoder().decode(input); + return Optional.of(base64Decoded); + } catch (Exception ex) { + return Optional.empty(); + } + } + } + + public static Optional tryDecompress(byte[] input) { + try { + byte[] decompressed = decompressDeflate(input); + return Optional.of(decompressed); + } catch (Exception e) { + return Optional.empty(); + } + } + + public static Optional tryCompress(byte[] input) { + try { + byte[] compressed = compressDeflate(input); + return Optional.of(compressed); + } catch (Exception e) { + return Optional.empty(); + } + } + + private static byte[] compressDeflate(byte[] input) { + Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true); + try { + deflater.setInput(input); + deflater.finish(); + + ByteArrayOutputStream baos = new ByteArrayOutputStream(Math.max(32, input.length / 2)); + byte[] buffer = new byte[8192]; + + while (!deflater.finished()) { + int written = deflater.deflate(buffer, 0, buffer.length, Deflater.NO_FLUSH); + + if (written > 0) { + baos.write(buffer, 0, written); + continue; + } + + if (deflater.needsInput()) { + throw new IllegalStateException("Deflater needs more input after finish(); input likely incomplete."); + } + + if (!deflater.finished()) { + throw new IllegalStateException("Deflater made no progress (possible invalid state)."); + } + } + + return baos.toByteArray(); + } finally { + deflater.end(); + } + } + + private static byte[] decompressDeflate(byte[] input) throws DataFormatException { + Inflater inflater = new Inflater(true); + try { + inflater.setInput(input); + + ByteArrayOutputStream baos = new ByteArrayOutputStream(input.length * 2); + byte[] buffer = new byte[8192]; + + while (!inflater.finished()) { + int read = inflater.inflate(buffer); + + if (read > 0) { + baos.write(buffer, 0, read); + continue; + } + + if (inflater.needsDictionary()) { + throw new DataFormatException("Preset dictionary required for this stream."); + } + if (inflater.needsInput()) { + throw new DataFormatException("Truncated deflate stream (needs more input)."); + } + + throw new DataFormatException("Inflater made no progress (corrupt or invalid stream)."); + } + + return baos.toByteArray(); + } finally { + inflater.end(); + } + } + + public static class SAMLDocument { + private boolean isBase64Encoded; + private boolean isCompressed; + private Document document; + + public SAMLDocument(boolean isBase64Encoded, boolean isCompressed, Document document) { + this.isBase64Encoded = isBase64Encoded; + this.isCompressed = isCompressed; + this.document = document; + } + + public void setBase64Encoded(boolean base64Encoded) { + isBase64Encoded = base64Encoded; + } + + public void setCompressed(boolean compressed) { + isCompressed = compressed; + } + + public Document getDocument() { + return document; + } + + public SAMLDocument copy() throws ParserConfigurationException { + Document copy = DocumentBuilderFactory.newInstance() + .newDocumentBuilder() + .newDocument(); + copy.appendChild(copy.importNode(document.getDocumentElement(), true)); + return new SAMLDocument(isBase64Encoded, isCompressed, copy); + } + + public String transformDocument() { + try { + TransformerFactory transformerFactory = TransformerFactory.newInstance(); + Transformer transformer = transformerFactory.newTransformer(); + transformer.setOutputProperty(OutputKeys.OMIT_XML_DECLARATION, "yes"); + StringWriter writer = new StringWriter(); + transformer.transform(new DOMSource(document), new StreamResult(writer)); + return writer.toString(); + } catch (Exception e) { + throw new IllegalArgumentException(e); + } + } + public static Pair detectUnsafeDOCTYPE(XMLUtilities.SAMLDocument document) { + if (document == null || document.getDocument().getDoctype() != null) { + throw new IllegalArgumentException(); + } + String str = "" + document.transformDocument(); + return new ImmutablePair<>(document.encode(str), ""); + } + + public static Pair detectUnsafeENTITIES(XMLUtilities.SAMLDocument document) { + if (document == null || document.getDocument().getDoctype() != null) { + throw new IllegalArgumentException(); + } + try { + XPathFactory xPathFactory = XPathFactory.newInstance(); + XPath xpath = xPathFactory.newXPath(); + XPathExpression expr = xpath.compile("//*[@ID]"); + + Node node = (Node) expr.evaluate(document.getDocument(), XPathConstants.NODE); + if (node != null && node.getAttributes() != null) { + Attr idAttr = (Attr) node.getAttributes().getNamedItem("ID"); + if (idAttr != null) { + String uuid = idAttr.getValue(); + idAttr.setValue("PLACEHOLDER_UUID"); + String str = String.format(" ]>", uuid); + str += document.transformDocument(); + str = str.replace("PLACEHOLDER_UUID", "&uuid;"); + return new ImmutablePair<>(document.encode(str), ""); + } + } + throw new IllegalArgumentException(); + } catch (Exception e) { + throw new IllegalArgumentException(); + } + } + + public static Optional parse(String xmlString) { + try { + boolean isB64 = false; + boolean isDeflated = false; + + String processedData = xmlString; + Optional optionalBytes = XMLUtilities.tryBase64Decode(processedData); + + if (optionalBytes.isPresent() && optionalBytes.get().length != 0) { + isB64 = true; + byte[] data = optionalBytes.get(); + if (data[0] == '<') { + processedData = new String(data, Charset.defaultCharset()); + } else { + Optional decompressed = XMLUtilities.tryDecompress(data); + if (decompressed.isPresent() && decompressed.get().length != 0) { + isDeflated = true; + processedData = new String(decompressed.get(), Charset.defaultCharset()); + } + } + } + + if (!processedData.startsWith("<")) throw new IllegalArgumentException(); + DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance(); + factory.setFeature("http://xml.org/sax/features/external-general-entities", false); + factory.setFeature("http://xml.org/sax/features/external-parameter-entities", false); + factory.setFeature("http://apache.org/xml/features/nonvalidating/load-external-dtd", false); + factory.setFeature(javax.xml.XMLConstants.FEATURE_SECURE_PROCESSING, true); + factory.setNamespaceAware(true); + + DocumentBuilder builder = factory.newDocumentBuilder(); + try (ByteArrayInputStream inputStream = new ByteArrayInputStream(processedData.getBytes(Charset.defaultCharset()))) { + Document document = builder.parse(inputStream); + return Optional.of(new SAMLDocument(isB64, isDeflated, document)); + } + } catch (Exception e) { + return Optional.empty(); + } + } + + public String encode(String input) { + byte[] result = input.getBytes(Charset.defaultCharset()); + if (isCompressed) { + result = XMLUtilities.compressDeflate(result); + } + if (isBase64Encoded) { + result = Base64.getEncoder().encode(result); + } + return new String(result, Charset.defaultCharset()); + } + + } +} diff --git a/test/burp/XMLTest.java b/test/burp/XMLTest.java new file mode 100644 index 0000000..300805d --- /dev/null +++ b/test/burp/XMLTest.java @@ -0,0 +1,87 @@ +package burp; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.*; + +public class XMLTest { + @Test + void testTransformer() { + String originalSamlRequest = "\n" + + "" + + "https://example.com" + + "\n" + + ""; + Optional parsedDocument = XMLUtilities.SAMLDocument.parse(originalSamlRequest); + assertTrue(parsedDocument.isPresent()); + XMLUtilities.SAMLDocument samlDocument = parsedDocument.get(); + samlDocument.setCompressed(true); + samlDocument.setBase64Encoded(true); + String transformedXml = samlDocument.transformDocument(); + String encodedResult = samlDocument.encode(transformedXml); + assertNotNull(encodedResult); + Optional roundTripDocument = XMLUtilities.SAMLDocument.parse(encodedResult); + assertTrue(roundTripDocument.isPresent()); + } + + @Test + void testEmptyString() { + Optional decompressed = XMLUtilities.tryDecompress(new byte[]{}); + assertFalse(decompressed.isPresent()); + } + + @Test + void testEmptyInput() { + Optional doc = XMLUtilities.SAMLDocument.parse(""); + assertFalse(doc.isPresent()); + } + + @Test + void testInvalidXML() { + Optional doc = XMLUtilities.SAMLDocument.parse("NOTXML"); + assertFalse(doc.isPresent()); + } + + @Test + @Timeout(value = 5, unit = TimeUnit.SECONDS) + void testDecompressDeflateWithInvalidInput() { + + byte[] randomBytes = new byte[]{0x00, 0x01, 0x02, 0x03, 0x04, 0x05}; + Optional result = XMLUtilities.tryDecompress(randomBytes); + assertFalse(result.isPresent()); + + byte[] truncatedDeflate = new byte[]{0x78, (byte) 0x9c}; + Optional truncatedResult = XMLUtilities.tryDecompress(truncatedDeflate); + assertFalse(truncatedResult.isPresent()); + + byte[] malformedData = new byte[]{0x78, (byte) 0x9c, (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff}; + Optional malformedResult = XMLUtilities.tryDecompress(malformedData); + assertFalse(malformedResult.isPresent()); + + } + + @Test + @Timeout(value = 5, unit = TimeUnit.SECONDS) + void testCompressDeflateWithInvalidInput() { + + byte[] emptyInput = new byte[0]; + Optional result = XMLUtilities.tryCompress(emptyInput); + assertTrue(result.isPresent()); + + byte[] randomBytes = new byte[]{0x00, 0x01, 0x02, 0x03, 0x04, 0x05}; + Optional randomResult = XMLUtilities.tryCompress(randomBytes); + assertTrue(randomResult.isPresent()); + + } + +}