diff --git a/.github/workflows/bugCatcher.yml b/.github/workflows/bugCatcher.yml index 729c267a9..aee8e2de1 100644 --- a/.github/workflows/bugCatcher.yml +++ b/.github/workflows/bugCatcher.yml @@ -30,7 +30,7 @@ jobs: restore-keys: ${{ runner.os }}-m2 - name: Run Integration Tests - run: mvn -B test -Dtest=*e2e/OAuthTests* + run: mvn -pl jdbc-core -B test -Dtest='*e2e/OAuthTests*' env: DATABRICKS_HOST: ${{ secrets.JDBC_PAT_TEST_HOST_NAME }} DATABRICKS_HTTP_PATH: ${{ secrets.JDBC_PAT_TEST_HTTP_PATH }} diff --git a/.github/workflows/concurrencyExecutionTests.yml b/.github/workflows/concurrencyExecutionTests.yml index ebaa0249a..b7a7904f1 100644 --- a/.github/workflows/concurrencyExecutionTests.yml +++ b/.github/workflows/concurrencyExecutionTests.yml @@ -48,7 +48,7 @@ jobs: restore-keys: ${{ runner.os }}-m2 - name: Run Concurrency Execution Tests - run: mvn -B test -Dtest=com.databricks.jdbc.integration.e2e.ConcurrentExecutionTests -DargLine="-ea" + run: mvn -pl jdbc-core -B test -Dtest=com.databricks.jdbc.integration.e2e.ConcurrentExecutionTests -DargLine="-ea" env: DATABRICKS_TOKEN: ${{ secrets.JDBC_PAT_TEST_TOKEN }} DATABRICKS_USER: ${{ secrets.DATABRICKS_USER }} diff --git a/.github/workflows/coverageReport.yml b/.github/workflows/coverageReport.yml index adef7ad51..1b3ec131d 100644 --- a/.github/workflows/coverageReport.yml +++ b/.github/workflows/coverageReport.yml @@ -34,7 +34,7 @@ jobs: ${{ runner.os }}-m2- - name: Run tests with coverage - run: mvn clean test jacoco:report + run: mvn -pl jdbc-core clean test -Dgroups='!Jvm17PlusAndArrowToNioReflectionDisabled' jacoco:report - name: Check for coverage override id: override @@ -53,7 +53,7 @@ jobs: - name: Check coverage percentage if: steps.override.outputs.override == 'false' run: | - COVERAGE_FILE="target/site/jacoco/jacoco.xml" + COVERAGE_FILE="jdbc-core/target/site/jacoco/jacoco.xml" if [ ! 
-f "$COVERAGE_FILE" ]; then echo "ERROR: Coverage file not found at $COVERAGE_FILE" exit 1 diff --git a/.github/workflows/loggingTesting.yml b/.github/workflows/loggingTesting.yml index de2704fe0..61eae07db 100644 --- a/.github/workflows/loggingTesting.yml +++ b/.github/workflows/loggingTesting.yml @@ -45,13 +45,13 @@ jobs: - name: Find JAR file shell: bash run: | - # Find the main JAR file dynamically (fat JAR, not thin, not tests) - MAIN_JAR=$(find target -maxdepth 1 -name "databricks-jdbc-*.jar" \ + # Find the main JAR file dynamically (uber JAR from assembly-uber module) + MAIN_JAR=$(find assembly-uber/target -maxdepth 1 -name "databricks-jdbc-*.jar" \ -not -name "*-thin.jar" \ -not -name "*-tests.jar" | head -1) if [ -z "$MAIN_JAR" ]; then - echo "ERROR: Could not find main JAR file in target directory" - ls -la target/ + echo "ERROR: Could not find main JAR file in assembly-uber/target directory" + ls -la assembly-uber/target/ exit 1 fi echo "Using JAR file: $MAIN_JAR" @@ -88,18 +88,18 @@ jobs: - name: Clean & Compile LoggingTest shell: bash run: | - rm -rf target/test-classes - mkdir -p target/test-classes + rm -rf jdbc-core/target/test-classes + mkdir -p jdbc-core/target/test-classes echo "Using JAR file: $MAIN_JAR" javac \ -cp "$MAIN_JAR" \ - -d target/test-classes \ - src/test/java/com/databricks/client/jdbc/LoggingTest.java + -d jdbc-core/target/test-classes \ + jdbc-core/src/test/java/com/databricks/client/jdbc/LoggingTest.java echo "==== Checking compiled classes ====" - find target/test-classes -type f + find jdbc-core/target/test-classes -type f - name: Run LoggingTest shell: bash @@ -110,7 +110,7 @@ jobs: echo "Using classpath separator: '$SEP'" echo "Using JAR file: $MAIN_JAR" - CP="target/test-classes${SEP}$MAIN_JAR" + CP="jdbc-core/target/test-classes${SEP}$MAIN_JAR" java \ --add-opens=java.base/java.nio=ALL-UNNAMED \ diff --git a/.github/workflows/prCheck.yml b/.github/workflows/prCheck.yml index d20bd1544..57b00e198 100644 --- 
a/.github/workflows/prCheck.yml +++ b/.github/workflows/prCheck.yml @@ -76,21 +76,51 @@ jobs: key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }} restore-keys: ${{ runner.os }}-m2 + - name: Set up Maven Toolchains + shell: bash + run: | + mkdir -p ~/.m2 + cat > ~/.m2/toolchains.xml < + + + jdk + + ${{ matrix.java-version }} + + + $JAVA_HOME + + + + EOF + + - name: Check Arrow Patch Tests + shell: bash + if: matrix.java-version >= 17 + run: mvn -Pjdk${{ matrix.java-version }}-NioNotOpen -pl jdbc-core test -Dgroups='Jvm17PlusAndArrowToNioReflectionDisabled' + + - name: Check Arrow Allocator Manager Tests + shell: bash + if: matrix.java-version >= 17 + run: mvn -Pjdk${{ matrix.java-version }}-NioNotOpen -pl jdbc-core test -Dgroups='Jvm17PlusAndArrowToNioReflectionDisabled' -Dtest="ArrowBufferAllocatorNettyManagerTest,ArrowBufferAllocatorUnsafeManagerTest,ArrowBufferAllocatorUnknownManagerTest" -DforkCount=1 -DreuseForks=false + + - name: Check Arrow Memory Tests + shell: bash + run: mvn -Plow-memory -pl jdbc-core test -Dtest='DatabricksArrowPatchMemoryUsageTest' + - name: Check Unit Tests shell: bash - run: mvn test -Dtest='!**/integration/**,!**/DatabricksDriverExamples.java,!**/ProxyTest.java,!**/LoggingTest.java,!**/SSLTest.java' + run: mvn -pl jdbc-core clean test -Dgroups='!Jvm17PlusAndArrowToNioReflectionDisabled' jacoco:report - name: Install xmllint if: runner.os == 'Linux' run: sudo apt-get update && sudo apt-get install -y libxml2-utils - - name: JaCoCo report - run: mvn --batch-mode --errors jacoco:report --file pom.xml - - name: Extract codeCov percentage shell: bash run: | - COVERAGE_FILE="target/site/jacoco/jacoco.xml" + COVERAGE_FILE="jdbc-core/target/site/jacoco/jacoco.xml" COVERED=$(xmllint --xpath "string(//report/counter[@type='INSTRUCTION']/@covered)" "$COVERAGE_FILE") MISSED=$(xmllint --xpath "string(//report/counter[@type='INSTRUCTION']/@missed)" "$COVERAGE_FILE") TOTAL=$((COVERED + MISSED)) @@ -114,4 +144,80 @@ jobs: exit 1 else echo 
"Coverage is equal to or greater than 85%" - fi \ No newline at end of file + fi + + packaging-tests: + strategy: + fail-fast: false + matrix: + java-version: [ 17 ] + github-runner: [ linux-ubuntu-latest, windows-server-latest ] + + runs-on: + group: databricks-protected-runner-group + labels: ${{ matrix.github-runner }} + + steps: + - name: Set up JDK ${{ matrix.java-version }} + uses: actions/setup-java@v4 + with: + java-version: ${{ matrix.java-version }} + distribution: 'adopt' + + - name: Enable long paths + if: runner.os == 'Windows' + run: git config --system core.longpaths true + + - name: Checkout + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.ref || github.ref_name }} + repository: ${{ github.event.pull_request.head.repo.full_name || github.repository }} + + - name: Cache Maven packages + uses: actions/cache@v4 + with: + path: ~/.m2 + key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }} + restore-keys: ${{ runner.os }}-m2 + + - name: Set up Maven Toolchains + shell: bash + run: | + mkdir -p ~/.m2 + cat > ~/.m2/toolchains.xml < + + + jdk + + ${{ matrix.java-version }} + + + $JAVA_HOME + + + + EOF + + - name: Install JDBC artifacts into maven local + shell: bash + run: mvn -B -pl jdbc-core,assembly-uber,assembly-thin install -DskipTests -Dmaven.javadoc.skip=true -Dmaven.source.skip=true -Ddependency-check.skip=true + + - name: Check Uber Jar Packaging + shell: bash + run: mvn -pl test-assembly-uber test + env: + DATABRICKS_HOST: ${{ secrets.JDBC_PAT_TEST_HOST_NAME }} + DATABRICKS_HTTP_PATH: ${{ secrets.JDBC_PAT_TEST_HTTP_PATH }} + DATABRICKS_USER: ${{ secrets.DATABRICKS_USER }} + DATABRICKS_TOKEN: ${{ secrets.JDBC_PAT_TEST_TOKEN }} + + - name: Check Thin Jar Packaging + shell: bash + run: mvn -pl test-assembly-thin test + env: + DATABRICKS_HOST: ${{ secrets.JDBC_PAT_TEST_HOST_NAME }} + DATABRICKS_HTTP_PATH: ${{ secrets.JDBC_PAT_TEST_HTTP_PATH }} + DATABRICKS_USER: ${{ secrets.DATABRICKS_USER }} + DATABRICKS_TOKEN: ${{ 
secrets.JDBC_PAT_TEST_TOKEN }} diff --git a/.github/workflows/prCheckJDK8.yml b/.github/workflows/prCheckJDK8.yml index 5202b4e55..089986716 100644 --- a/.github/workflows/prCheckJDK8.yml +++ b/.github/workflows/prCheckJDK8.yml @@ -48,4 +48,8 @@ jobs: restore-keys: ${{ runner.os }}-jdk8-m2 - name: Run Unit Tests - run: mvn clean test + run: mvn -pl jdbc-core clean test -Dgroups='!Jvm17PlusAndArrowToNioReflectionDisabled' + + - name: Check Arrow Memory Tests + shell: bash + run: mvn -Plow-memory -pl jdbc-core test -Dtest='DatabricksArrowPatchMemoryUsageTest' diff --git a/.github/workflows/prIntegrationTests.yml b/.github/workflows/prIntegrationTests.yml index fada21c77..1715b86f1 100644 --- a/.github/workflows/prIntegrationTests.yml +++ b/.github/workflows/prIntegrationTests.yml @@ -16,10 +16,10 @@ jobs: include: # SQL_EXEC mode: Tests SEA client behavior # Note: CircuitBreakerIntegrationTests requires THRIFT_SERVER mode (tested in second matrix entry) - - test-command: mvn -B compile test -Dtest=*IntegrationTests,!M2MPrivateKeyCredentialsIntegrationTests,!M2MAuthIntegrationTests,!CircuitBreakerIntegrationTests,!ThriftCloudFetchFakeIntegrationTests + - test-command: mvn -pl jdbc-core -B compile test -Dtest=*IntegrationTests,!M2MPrivateKeyCredentialsIntegrationTests,!M2MAuthIntegrationTests,!CircuitBreakerIntegrationTests,!ThriftCloudFetchFakeIntegrationTests fake-service-type: 'SQL_EXEC' # THRIFT_SERVER mode: Tests Thrift client behavior and circuit breaker fallback - - test-command: mvn -B compile test -Dtest=*IntegrationTests,!M2MPrivateKeyCredentialsIntegrationTests,!SqlExecApiHybridResultsIntegrationTests,!DBFSVolumeIntegrationTests,!M2MAuthIntegrationTests,!UCVolumeIntegrationTests,!SqlExecApiIntegrationTests + - test-command: mvn -pl jdbc-core -B compile test 
-Dtest=*IntegrationTests,!M2MPrivateKeyCredentialsIntegrationTests,!SqlExecApiHybridResultsIntegrationTests,!DBFSVolumeIntegrationTests,!M2MAuthIntegrationTests,!UCVolumeIntegrationTests,!SqlExecApiIntegrationTests fake-service-type: 'THRIFT_SERVER' steps: - name: Checkout PR diff --git a/.github/workflows/proxyTesting.yml b/.github/workflows/proxyTesting.yml index d7c08735a..62c36cf2b 100644 --- a/.github/workflows/proxyTesting.yml +++ b/.github/workflows/proxyTesting.yml @@ -157,7 +157,7 @@ jobs: ################################################################ - name: Run ProxyTest run: | - mvn test -Dtest=**/ProxyTest.java + mvn -pl jdbc-core test -Dtest=**/ProxyTest.java ################################################################ # 14) Cleanup diff --git a/.github/workflows/release-thin.yml b/.github/workflows/release-thin.yml index 5406cbb3d..0f898f543 100644 --- a/.github/workflows/release-thin.yml +++ b/.github/workflows/release-thin.yml @@ -31,111 +31,37 @@ jobs: - name: Set up Java for publishing to Maven Central Repository uses: actions/setup-java@v4 + env: + GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }} with: java-version: 11 - distribution: "adopt" server-id: central + distribution: "adopt" server-username: MAVEN_CENTRAL_USERNAME server-password: MAVEN_CENTRAL_PASSWORD gpg-private-key: ${{ secrets.GPG_PRIVATE_KEY }} gpg-passphrase: GPG_PASSPHRASE - - name: Configure GPG + # Step 1: Build and install dependencies to local Maven repository + # This builds jdbc-core (and parent) without publishing them. + # The -am flag builds all dependencies needed by assembly-thin. + # We use -Prelease here to generate sources/javadoc JARs for jdbc-core, + # which assembly-thin needs for its own sources/javadoc artifacts. + # GPG signing is skipped since we're only installing locally, not publishing. 
+ - name: Build dependencies run: | - echo "allow-loopback-pinentry" >> ~/.gnupg/gpg-agent.conf - echo "pinentry-mode loopback" >> ~/.gnupg/gpg.conf - gpg-connect-agent reloadagent /bye + mvn -Prelease clean install --batch-mode -pl jdbc-core -am -Dgpg.skip=true - - name: Build thin JAR with sources and javadocs + # Step 2: Deploy only the thin JAR module to Maven Central + # We don't use -am here to avoid the central-publishing-maven-plugin + # from collecting parent/jdbc-core artifacts into the deployment bundle. + # The jdbc-core dependency is already available from Step 1. + - name: Publish thin JAR to Maven Central run: | - # Build main artifacts including sources and javadocs - mvn -B -DskipTests package source:jar javadoc:jar - - - name: Sign all thin JAR artifacts - run: | - VERSION=$(grep -m1 '' pom.xml | sed 's/.*\(.*\)<\/version>.*/\1/') - - # Sign thin JAR - echo "$GPG_PASSPHRASE" | gpg --batch --yes --passphrase-fd 0 --pinentry-mode loopback \ - --armor --detach-sign "target/databricks-jdbc-${VERSION}-thin.jar" - - # Sign sources JAR - echo "$GPG_PASSPHRASE" | gpg --batch --yes --passphrase-fd 0 --pinentry-mode loopback \ - --armor --detach-sign "target/databricks-jdbc-${VERSION}-sources.jar" - - # Sign javadoc JAR - echo "$GPG_PASSPHRASE" | gpg --batch --yes --passphrase-fd 0 --pinentry-mode loopback \ - --armor --detach-sign "target/databricks-jdbc-${VERSION}-javadoc.jar" - env: - GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }} - - - name: Verify all required artifacts exist - run: | - VERSION=$(grep -m1 '' pom.xml | sed 's/.*\(.*\)<\/version>.*/\1/') - test -f "target/databricks-jdbc-${VERSION}-thin.jar" - test -f "target/databricks-jdbc-${VERSION}-thin.jar.asc" - test -f "target/databricks-jdbc-${VERSION}-sources.jar" - test -f "target/databricks-jdbc-${VERSION}-sources.jar.asc" - test -f "target/databricks-jdbc-${VERSION}-javadoc.jar" - test -f "target/databricks-jdbc-${VERSION}-javadoc.jar.asc" - - - name: Publish Thin JAR as Separate Artifact to 
Maven Central - run: | - VERSION=$(grep -m1 '' pom.xml | sed 's/.*\(.*\)<\/version>.*/\1/') - - echo "Creating deployment bundle for thin JAR..." - - # Create staging directory - mkdir -p target/thin-staging/com/databricks/databricks-jdbc-thin/${VERSION} - - # Copy thin JAR and its signature - cp "target/databricks-jdbc-${VERSION}-thin.jar" \ - target/thin-staging/com/databricks/databricks-jdbc-thin/${VERSION}/databricks-jdbc-thin-${VERSION}.jar - cp "target/databricks-jdbc-${VERSION}-thin.jar.asc" \ - target/thin-staging/com/databricks/databricks-jdbc-thin/${VERSION}/databricks-jdbc-thin-${VERSION}.jar.asc - - # Copy sources JAR and its signature - cp "target/databricks-jdbc-${VERSION}-sources.jar" \ - target/thin-staging/com/databricks/databricks-jdbc-thin/${VERSION}/databricks-jdbc-thin-${VERSION}-sources.jar - cp "target/databricks-jdbc-${VERSION}-sources.jar.asc" \ - target/thin-staging/com/databricks/databricks-jdbc-thin/${VERSION}/databricks-jdbc-thin-${VERSION}-sources.jar.asc - - # Copy javadoc JAR and its signature - cp "target/databricks-jdbc-${VERSION}-javadoc.jar" \ - target/thin-staging/com/databricks/databricks-jdbc-thin/${VERSION}/databricks-jdbc-thin-${VERSION}-javadoc.jar - cp "target/databricks-jdbc-${VERSION}-javadoc.jar.asc" \ - target/thin-staging/com/databricks/databricks-jdbc-thin/${VERSION}/databricks-jdbc-thin-${VERSION}-javadoc.jar.asc - - # Copy POM and sign it - cp thin_public_pom.xml target/thin-staging/com/databricks/databricks-jdbc-thin/${VERSION}/databricks-jdbc-thin-${VERSION}.pom - echo "$GPG_PASSPHRASE" | gpg --batch --yes --passphrase-fd 0 --pinentry-mode loopback \ - --armor --detach-sign \ - target/thin-staging/com/databricks/databricks-jdbc-thin/${VERSION}/databricks-jdbc-thin-${VERSION}.pom - - # Generate checksums for all files - cd target/thin-staging/com/databricks/databricks-jdbc-thin/${VERSION} - for file in databricks-jdbc-thin-*; do - md5sum "$file" | awk '{print $1}' > "${file}.md5" - sha1sum "$file" | awk '{print 
$1}' > "${file}.sha1" - done - cd $GITHUB_WORKSPACE - - # Create bundle ZIP - cd target/thin-staging - zip -r ../central-thin-bundle.zip com/ - cd $GITHUB_WORKSPACE - - echo "Uploading bundle to Maven Central Portal..." - - # Upload to new Maven Central Portal - curl -X POST \ - -u "$MAVEN_CENTRAL_USERNAME:$MAVEN_CENTRAL_PASSWORD" \ - -F "bundle=@target/central-thin-bundle.zip" \ - -F "publishingType=AUTOMATIC" \ - -w "\nHTTP_STATUS:%{http_code}\n" \ - https://central.sonatype.com/api/v1/publisher/upload - - echo "Thin JAR published successfully!" + mvn -Prelease deploy --batch-mode -pl assembly-thin \ + -Dnvd.api.key=${{ secrets.NVD_API_KEY }} \ + -Dossindex.username=${{ secrets.OSSINDEX_USERNAME }} \ + -Dossindex.password=${{ secrets.OSSINDEX_PASSWORD }} env: GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }} MAVEN_CENTRAL_USERNAME: ${{ secrets.MAVEN_CENTRAL_USERNAME }} @@ -174,5 +100,4 @@ jobs: with: tag_name: ${{ steps.get_tag.outputs.tag }} files: | - target/*-thin.jar - + assembly-thin/target/databricks-jdbc-thin-*.jar \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b8dbb474b..aa56131e6 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -27,10 +27,23 @@ jobs: gpg-private-key: ${{ secrets.GPG_PRIVATE_KEY }} gpg-passphrase: GPG_PASSPHRASE - - name: Publish to the Maven Central Repository + # Step 1: Build and install dependencies to local Maven repository + # This builds jdbc-core (and parent) without publishing them. + # The -am flag builds all dependencies needed by assembly-uber. + # We use -Prelease here to generate sources/javadoc JARs for jdbc-core, + # which assembly-uber copies for its own sources/javadoc artifacts. + # GPG signing is skipped since we're only installing locally, not publishing. 
+ - name: Build dependencies run: | - # Deploy main artifacts (uber JAR, sources, javadoc) - mvn -Prelease --batch-mode deploy \ + mvn -Prelease clean install --batch-mode -pl jdbc-core -am -Dgpg.skip=true + + # Step 2: Deploy only the uber JAR module to Maven Central + # We don't use -am here to avoid the central-publishing-maven-plugin + # from collecting parent/jdbc-core artifacts into the deployment bundle. + # The jdbc-core dependency is already available from Step 1. + - name: Publish uber JAR to Maven Central + run: | + mvn -Prelease deploy --batch-mode -pl assembly-uber \ -Dnvd.api.key=${{ secrets.NVD_API_KEY }} \ -Dossindex.username=${{ secrets.OSSINDEX_USERNAME }} \ -Dossindex.password=${{ secrets.OSSINDEX_PASSWORD }} @@ -43,4 +56,5 @@ jobs: uses: softprops/action-gh-release@v1 with: files: | - files: target/*.jar + assembly-uber/target/databricks-jdbc-*.jar + diff --git a/.github/workflows/runIntegrationTests.yml b/.github/workflows/runIntegrationTests.yml index 8eb7557a2..e74947ef1 100644 --- a/.github/workflows/runIntegrationTests.yml +++ b/.github/workflows/runIntegrationTests.yml @@ -14,10 +14,10 @@ jobs: strategy: matrix: include: - - test-command: mvn -B compile test -Dtest=*IntegrationTests,!CircuitBreakerIntegrationTests,!ThriftCloudFetchFakeIntegrationTests + - test-command: mvn -pl jdbc-core -B compile test -Dtest=*IntegrationTests,!CircuitBreakerIntegrationTests,!ThriftCloudFetchFakeIntegrationTests token-secret: DATABRICKS_TOKEN fake-service-type: 'SQL_EXEC' - - test-command: mvn -B compile test -Dtest=*IntegrationTests,!M2MPrivateKeyCredentialsIntegrationTests,!SqlExecApiHybridResultsIntegrationTests,!DBFSVolumeIntegrationTests,!M2MAuthIntegrationTests,!UCVolumeIntegrationTests,!SqlExecApiIntegrationTests + - test-command: mvn -pl jdbc-core -B compile test 
-Dtest=*IntegrationTests,!M2MPrivateKeyCredentialsIntegrationTests,!SqlExecApiHybridResultsIntegrationTests,!DBFSVolumeIntegrationTests,!M2MAuthIntegrationTests,!UCVolumeIntegrationTests,!SqlExecApiIntegrationTests token-secret: THRIFT_DATABRICKS_TOKEN fake-service-type: 'THRIFT_SERVER' steps: diff --git a/.github/workflows/runJdbcComparator.yml b/.github/workflows/runJdbcComparator.yml index 78f278ccc..442b92951 100644 --- a/.github/workflows/runJdbcComparator.yml +++ b/.github/workflows/runJdbcComparator.yml @@ -47,7 +47,7 @@ jobs: echo "DATABRICKS_COMPARATOR_TOKEN=${DATABRICKS_COMPARATOR_TOKEN}" >> $GITHUB_ENV - name: Run Tests - run: mvn test -Dtest=JDBCDriverComparisonTest + run: mvn -pl jdbc-core test -Dtest=JDBCDriverComparisonTest - name: Format Email Content run: | diff --git a/.github/workflows/slt.yml b/.github/workflows/slt.yml index c631ac540..b3b3d48a8 100644 --- a/.github/workflows/slt.yml +++ b/.github/workflows/slt.yml @@ -31,6 +31,6 @@ jobs: distribution: 'temurin' cache: maven - name: Build with Maven - run: mvn -B package --file pom.xml + run: mvn -pl jdbc-core -B package --file pom.xml -Dgroups='!Jvm17PlusAndArrowToNioReflectionDisabled' - name: Run SQL Logic Tests - run: mvn exec:exec -Dslt.token=${{ inputs.token }} + run: mvn -pl jdbc-core exec:exec -Dslt.token=${{ inputs.token }} diff --git a/.github/workflows/sslTesting.yml b/.github/workflows/sslTesting.yml index 3e89f575b..ef15948ca 100644 --- a/.github/workflows/sslTesting.yml +++ b/.github/workflows/sslTesting.yml @@ -226,11 +226,11 @@ jobs: - name: Maven Build run: | - mvn clean package -DskipTests + mvn -pl jdbc-core clean package -DskipTests - name: Run SSL Tests run: | - mvn test -Dtest=**/SSLTest.java + mvn -pl jdbc-core test -Dtest=**/SSLTest.java - name: Cleanup if: always() diff --git a/.github/workflows/vulnerabilityCatcher.yml b/.github/workflows/vulnerabilityCatcher.yml index 7f4d5b370..7962de3bf 100644 --- a/.github/workflows/vulnerabilityCatcher.yml +++ 
b/.github/workflows/vulnerabilityCatcher.yml @@ -23,12 +23,12 @@ jobs: cache: maven - name: Run OWASP Dependency Check - run: mvn org.owasp:dependency-check-maven:check -Dnvd.api.key=${{ secrets.NVD_API_KEY }} + run: mvn -pl jdbc-core org.owasp:dependency-check-maven:check -Dnvd.api.key=${{ secrets.NVD_API_KEY }} - name: Check for vulnerabilities id: check_vulnerabilities run: | - if grep -q "CVSS score >= 7" target/dependency-check-report.html; then + if grep -q "CVSS score >= 7" jdbc-core/target/dependency-check-report.html; then echo "has_vulnerabilities=true" >> $GITHUB_OUTPUT echo "Critical or high vulnerabilities found (CVSS score >= 7)" # Generate a simple HTML report for email @@ -62,6 +62,6 @@ jobs: with: name: security-scan-reports path: | - target/dependency-check-report.html - target/dependency-check-report.json + jdbc-core/target/dependency-check-report.html + jdbc-core/target/dependency-check-report.json security-scan-report.html \ No newline at end of file diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 4f4263787..173ea5f8e 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -7,6 +7,7 @@ ### Updated - Fat jar now routes SDK and Apache HTTP client logs through Java Util Logging (JUL), removing the need for external logging libraries. +- PECOBLR-1121 Arrow patch to circumvent Arrow issues with JDK 16+. ### Fixed - Fixed `rollback()` to throw `SQLException` when called in auto-commit mode (no active transaction), aligning with JDBC spec. Previously it silently sent a ROLLBACK command to the server. 
diff --git a/NOTICE b/NOTICE index 6fd09f3c5..25eac9b38 100644 --- a/NOTICE +++ b/NOTICE @@ -11,6 +11,7 @@ Notice - https://github.com/databricks/databricks-sdk-java/blob/main/NOTICE apache/arrow - https://github.com/apache/arrow/tree/main Copyright 2016-2025 The Apache Software Foundation +*This software contains code modified by Databricks, Inc.* Notice - https://github.com/apache/arrow/blob/main/NOTICE.txt diffplug/spotless - https://github.com/diffplug/spotless/tree/main diff --git a/README.md b/README.md index 2febe0313..05ff72b1a 100644 --- a/README.md +++ b/README.md @@ -26,13 +26,23 @@ Add the following dependency to your `pom.xml`: ### Build from Source +This is a multi-module Maven project: + +| Module | Artifact | Description | +|--------|----------|-------------| +| `jdbc-core` | `databricks-jdbc-core` | Core driver code | +| `assembly-uber` | `databricks-jdbc` | Uber jar with all dependencies bundled | +| `assembly-thin` | `databricks-jdbc-thin` | Thin jar (dependencies not bundled) | +| `test-assembly-uber` | `test-databricks-jdbc-uber` | Packaging tests for the uber jar | +| `test-assembly-thin` | `test-databricks-jdbc-thin` | Packaging tests for the thin jar | + 1. Clone the repository 2. Run the following command: ```bash mvn clean package ``` -3. The jar file is generated as `target/databricks-jdbc-.jar` -4. The test coverage report is generated in `target/site/jacoco/index.html` +3. The uber jar is generated at `assembly-uber/target/databricks-jdbc-.jar` +4. The test coverage report is generated in `jdbc-core/target/site/jacoco/index.html` ## Usage diff --git a/assembly-thin/pom.xml b/assembly-thin/pom.xml new file mode 100644 index 000000000..83313186b --- /dev/null +++ b/assembly-thin/pom.xml @@ -0,0 +1,261 @@ + + + 4.0.0 + + + com.databricks + databricks-jdbc-parent + 3.2.2-SNAPSHOT + + + databricks-jdbc-thin + jar + Databricks JDBC thin jar + + Databricks JDBC thin jar. 
+ + https://github.com/databricks/databricks-jdbc + + + + Apache License, Version 2.0 + + https://github.com/databricks/databricks-jdbc/blob/main/LICENSE + + + + + + Databricks JDBC Team + eng-oss-sql-driver@databricks.com + Databricks + https://www.databricks.com + + + + scm:git:https://github.com/databricks/databricks-jdbc.git + + + scm:git:https://github.com/databricks/databricks-jdbc.git + + https://github.com/databricks/databricks-jdbc + + + GitHub Issues + https://github.com/databricks/databricks-jdbc/issues + + + + + com.databricks + databricks-jdbc-core + 3.2.2-SNAPSHOT + + + + + false + + + + + + org.codehaus.mojo + flatten-maven-plugin + 1.6.0 + + true + oss + + expand + remove + + + + + flatten + process-resources + + flatten + + + + flatten.clean + clean + + clean + + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.5.0 + + + shade and package jars + package + + shade + + + true + true + + true + + + true + + + false + + + true + + + + org.apache.arrow:* + + com.databricks:databricks-jdbc-core + + + + + + org.apache.arrow + + com.databricks.internal.apache.arrow + + + + + + *:* + + META-INF/*.DSA + META-INF/*.RSA + META-INF/*.SF + META-INF/DEPENDENCIES + META-INF/LICENSE.txt + META-INF/versions/** + + + + + + + com.databricks.client.jdbc.Driver + + + + ${project.artifactId} + + + ${project.version} + + + + + + + + + + + + + + release + + + + + org.sonatype.central + central-publishing-maven-plugin + + false + + + + + org.apache.maven.plugins + maven-source-plugin + + + attach-sources + none + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + + + attach-javadocs + none + + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + copy-javadoc + package + + copy + + + + + com.databricks + databricks-jdbc-core + ${project.version} + javadoc + jar + ${project.artifactId}-${project.version}-javadoc.jar + + + ${project.build.directory} + + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + attach-javadoc + package + + 
attach-artifact + + + + + ${project.build.directory}/${project.artifactId}-${project.version}-javadoc.jar + jar + javadoc + + + + + + + + + + + + \ No newline at end of file diff --git a/assembly-thin/src/main/resources/README.md b/assembly-thin/src/main/resources/README.md new file mode 100644 index 000000000..506fc8b4c --- /dev/null +++ b/assembly-thin/src/main/resources/README.md @@ -0,0 +1 @@ +Shading Arrow in the driver. \ No newline at end of file diff --git a/assembly-uber/pom.xml b/assembly-uber/pom.xml new file mode 100644 index 000000000..604bb61f6 --- /dev/null +++ b/assembly-uber/pom.xml @@ -0,0 +1,346 @@ + + + 4.0.0 + + + com.databricks + databricks-jdbc-parent + 3.2.2-SNAPSHOT + + + databricks-jdbc + jar + Databricks JDBC uber jar + + Databricks JDBC uber jar. + + https://github.com/databricks/databricks-jdbc + + + + Apache License, Version 2.0 + + https://github.com/databricks/databricks-jdbc/blob/main/LICENSE + + + + + + Databricks JDBC Team + eng-oss-sql-driver@databricks.com + Databricks + https://www.databricks.com + + + + scm:git:https://github.com/databricks/databricks-jdbc.git + + + scm:git:https://github.com/databricks/databricks-jdbc.git + + https://github.com/databricks/databricks-jdbc + + + GitHub Issues + https://github.com/databricks/databricks-jdbc/issues + + + + + com.databricks + databricks-jdbc-core + 3.2.2-SNAPSHOT + + + + + false + + + + + + org.codehaus.mojo + flatten-maven-plugin + 1.6.0 + + true + oss + + expand + remove + + + + + flatten + process-resources + + flatten + + + + flatten.clean + clean + + clean + + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.5.0 + + + shade and package jars + package + + shade + + + + false + + + + codegen + + com.databricks.internal.codegen + + + + com.databricks.sdk + com.databricks.internal.sdk + + + + com.fasterxml + + com.databricks.internal.fasterxml + + + + com.google + + com.databricks.internal.google + + + + com.nimbusds + + com.databricks.internal.nimbusds + + + + io + 
com.databricks.internal.io + + + + net.jpountz + + com.databricks.internal.jpountz + + + + org.apache + + com.databricks.internal.apache + + + + org.bouncycastle + + com.databricks.internal.bouncycastle + + + + org.checkerframework + + com.databricks.internal.checkerframework + + + + org.ini4j + + com.databricks.internal.ini4j + + + + org.json + + com.databricks.internal.json + + + + org.locationtech.jts + com.databricks.internal.jts + + + + org.osgi + + com.databricks.internal.osgi + + + + org.slf4j + + com.databricks.internal.slf4j + + + + + + *:* + + META-INF/*.DSA + META-INF/*.RSA + META-INF/*.SF + META-INF/DEPENDENCIES + META-INF/LICENSE.txt + META-INF/versions/** + + + + *:* + + edu/** + javax/** + jakarta/** + net/jcip/** + + + + + + + com.databricks.client.jdbc.Driver + + + + ${project.artifactId} + + + ${project.version} + + + + + + + + + + + + + + + release + + + + + org.sonatype.central + central-publishing-maven-plugin + + false + + + + + org.apache.maven.plugins + maven-source-plugin + + + attach-sources + none + + + + + org.apache.maven.plugins + maven-javadoc-plugin + + + attach-javadocs + none + + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + copy-sources-javadoc + package + + copy + + + + + com.databricks + databricks-jdbc-core + ${project.version} + sources + jar + ${project.artifactId}-${project.version}-sources.jar + + + com.databricks + databricks-jdbc-core + ${project.version} + javadoc + jar + ${project.artifactId}-${project.version}-javadoc.jar + + + ${project.build.directory} + + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + attach-sources-javadoc + package + + attach-artifact + + + + + ${project.build.directory}/${project.artifactId}-${project.version}-sources.jar + jar + sources + + + ${project.build.directory}/${project.artifactId}-${project.version}-javadoc.jar + jar + javadoc + + + + + + + + + + + + diff --git a/assembly-uber/src/main/resources/README.md b/assembly-uber/src/main/resources/README.md 
new file mode 100644 index 000000000..f8e46b480 --- /dev/null +++ b/assembly-uber/src/main/resources/README.md @@ -0,0 +1 @@ +Shaded version of the driver. \ No newline at end of file diff --git a/jdbc-core/pom.xml b/jdbc-core/pom.xml new file mode 100644 index 000000000..7b6c273be --- /dev/null +++ b/jdbc-core/pom.xml @@ -0,0 +1,527 @@ + + + 4.0.0 + + + com.databricks + databricks-jdbc-parent + 3.2.2-SNAPSHOT + + + databricks-jdbc-core + jar + Databricks JDBC Driver + Databricks JDBC Driver. + https://github.com/databricks/databricks-jdbc + + + Apache License, Version 2.0 + https://github.com/databricks/databricks-jdbc/blob/main/LICENSE + + + + + Databricks JDBC Team + eng-oss-sql-driver@databricks.com + Databricks + https://www.databricks.com + + + + scm:git:https://github.com/databricks/databricks-jdbc.git + scm:git:https://github.com/databricks/databricks-jdbc.git + https://github.com/databricks/databricks-jdbc + + + GitHub Issues + https://github.com/databricks/databricks-jdbc/issues + + + + local-test-repo + file://${project.build.directory}/local-repo + + + + + + + org.apache.commons + commons-lang3 + ${commons-lang3.version} + + + + com.google.code.gson + gson + ${gson.version} + + + + + + com.databricks + databricks-sdk-java + ${databricks-sdk.version} + + + org.apache.commons + commons-configuration2 + ${commons-configuration.version} + + + org.apache.arrow + arrow-memory-core + ${arrow.version} + + + org.apache.arrow + arrow-memory-unsafe + ${arrow.version} + + + org.apache.arrow + arrow-vector + ${arrow.version} + + + org.apache.arrow + arrow-memory-netty + ${arrow.version} + + + org.apache.httpcomponents + httpclient + ${httpclient.version} + + + org.apache.thrift + libthrift + ${thrift.version} + + + org.slf4j + slf4j-api + ${slf4j.version} + + + + org.slf4j + slf4j-jdk14 + ${slf4j.version} + + + commons-io + commons-io + ${commons-io.version} + + + com.google.code.findbugs + annotations + ${google.findbugs.annotations.version} + + + com.google.guava 
+ guava + ${google.guava.version} + + + org.junit.jupiter + junit-jupiter + ${junit.jupiter.version} + test + + + com.nimbusds + nimbus-jose-jwt + ${nimbusjose.version} + + + org.bouncycastle + bcprov-jdk18on + ${bouncycastle.version} + + + org.bouncycastle + bcpkix-jdk18on + ${bouncycastle.version} + + + org.mockito + mockito-inline + ${mockito.version} + test + + + org.mockito + mockito-junit-jupiter + ${mockito.version} + test + + + com.fasterxml.jackson.core + jackson-databind + ${jackson.version} + + + com.fasterxml.jackson.core + jackson-annotations + ${jackson.version} + + + com.fasterxml.jackson.core + jackson-core + ${jackson.version} + + + org.immutables + value + ${immutables.value.version} + provided + + + net.hydromatic + sql-logic-test + ${sql-logic-test.version} + test + + + org.lz4 + lz4-java + ${lz4-compression.version} + + + + io.grpc + grpc-context + ${grpc.version} + + + + io.netty + netty-common + ${netty.version} + + + + io.netty + netty-buffer + ${netty.version} + + + jakarta.annotation + jakarta.annotation-api + ${annotation.version} + + + org.wiremock + wiremock + ${wiremock.version} + test + + + commons-fileupload + commons-fileupload + + + + + org.apache.httpcomponents.client5 + httpclient5 + ${async-httpclient.version} + + + org.apache.httpcomponents.core5 + httpcore5 + ${async-httpclient.version} + + + io.github.resilience4j + resilience4j-circuitbreaker + ${resilience4j.version} + + + io.github.resilience4j + resilience4j-core + ${resilience4j.version} + + + org.locationtech.jts + jts-core + 1.20.0 + + + + org.openjdk.jmh + jmh-core + ${jmh.version} + test + + + + org.openjdk.jmh + jmh-generator-annprocess + ${jmh.version} + test + + + + + + ${project.artifactId}-${project.version} + + + org.apache.maven.plugins + maven-jar-plugin + + + + com.databricks.client.jdbc.Driver + true + + + + + + attach-test-jar + + test-jar + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + **/DatabricksDriverExamples.java + 
**/integration/**/*.java + **/ErrorTypes.java + **/ErrorCodes.java + **/ProxyTest.java + **/LoggingTest.java + **/SSLTest.java + + **/ArrowBufferAllocatorNettyManagerTest.java + **/ArrowBufferAllocatorUnsafeManagerTest.java + **/ArrowBufferAllocatorUnknownManagerTest.java + + + @{argLine} + -Xmx5g + --add-opens=java.base/java.nio=ALL-UNNAMED + -Dnet.bytebuddy.experimental=true + + + + + org.codehaus.mojo + exec-maven-plugin + + java + + --add-opens=java.base/java.nio=ALL-UNNAMED + -classpath + + com.databricks.jdbc.sqllogictest.SLTMain + -e + ${slt.executor} + -p + ${slt.token} + + test + + + + org.apache.maven.plugins + maven-compiler-plugin + + + + org.immutables + value + ${immutables.value.version} + + + org.openjdk.jmh + jmh-generator-annprocess + ${jmh.version} + + + + + + org.owasp + dependency-check-maven + + + HTML + JSON + + + 7 + ${nvd.api.key} + 10 + 4000 + true + ${ossindex.username} + ${ossindex.password} + + + + + check + + + + + + org.jacoco + jacoco-maven-plugin + + + + prepare-agent + + + + report + prepare-package + + report + + + + + + **/*Constants* + **/*Exception* + **/CommandName* + **/DatabricksJdbcConstants* + **/DatabricksJdbcUrlParams* + **/Driver* + **/EnvironmentVariables* + **/model/** + **/thrift/generated/** + + + org/apache/arrow/memory/util/MemoryUtil* + org/apache/arrow/memory/ArrowBuf* + org/apache/arrow/vector/util/DecimalUtility* + + + **/DatabricksArrowBuf* + + + + + + + + + low-memory + + + + org.apache.maven.plugins + maven-surefire-plugin + + + @{argLine} -Xmx100m + + + + + + + + + + jdk17-NioNotOpen + + + + org.apache.maven.plugins + maven-toolchains-plugin + + + + toolchain + + + + + + + 17 + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + + @{argLine} -Darrow.memory.debug.allocator=true + + + + + + + + + + jdk21-NioNotOpen + + + + org.apache.maven.plugins + maven-toolchains-plugin + + + + toolchain + + + + + + + 21 + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + + @{argLine} 
-Darrow.memory.debug.allocator=true + + + + + + + + + + diff --git a/jdbc-core/src b/jdbc-core/src new file mode 120000 index 000000000..5cd551cf2 --- /dev/null +++ b/jdbc-core/src @@ -0,0 +1 @@ +../src \ No newline at end of file diff --git a/pom.xml b/pom.xml index 65f53a3dc..7552124fc 100644 --- a/pom.xml +++ b/pom.xml @@ -3,13 +3,22 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 com.databricks - databricks-jdbc + databricks-jdbc-parent - 3.2.1 - jar - Databricks JDBC Driver - Databricks JDBC Driver. + 3.2.2-SNAPSHOT + pom + Databricks JDBC Parent + Parent POM for Databricks JDBC Driver. https://github.com/databricks/databricks-jdbc + + + jdbc-core + assembly-thin + assembly-uber + test-assembly-thin + test-assembly-uber + + Apache License, Version 2.0 @@ -33,383 +42,121 @@ GitHub Issues https://github.com/databricks/databricks-jdbc/issues - - - local-test-repo - file://${project.build.directory}/local-repo - - + + UTF-8 UTF-8 - 17.0.0 - 3.18.0 11 11 - 5.2.0 - 2.18.3 - 2.0.13 - 33.0.0-jre - 5.9.2 - 3.0.1 - 2.13.2 - 2.9.2 - 4.5.14 + true + + + 3.2.0 + 3.14.1 + 3.1.2 + 3.3.0 + 1.2.1 + 12.1.6 + 0.8.11 + 2.39.0 + 3.6.1 + + + 3.2.2-SNAPSHOT + 18.3.0 + 3.18.0 2.10.1 2.14.0 0.69.0 - 3.1.2 - 0.3 - 1.10.1 + 4.5.14 + 5.3.1 0.19.0 + 2.0.13 + 2.18.3 + 2.13.2 + 33.0.0-jre + 3.0.1 + 2.9.2 + 1.8.1 1.3.5 - dbsql - dummy-token - 3.5.4 - 10.0.2 - 1.79 - 5.3.1 4.2.6.Final 1.71.0 1.7.0 + 10.0.2 + 1.79 + 1.37 + + + 5.9.2 + + 5.2.0 + 3.5.4 + 0.3 + + + dbsql + dummy-token + + + true - - - - - org.apache.commons - commons-lang3 - ${commons-lang3.version} - - - - com.google.code.gson - gson - ${gson.version} - - - - - - com.databricks - databricks-sdk-java - ${databricks-sdk.version} - - - org.apache.commons - commons-configuration2 - ${commons-configuration.version} - - - org.apache.arrow - arrow-memory-core - ${arrow.version} - - - org.apache.arrow - arrow-memory-unsafe - ${arrow.version} - - - org.apache.arrow - 
arrow-vector - ${arrow.version} - - - org.apache.arrow - arrow-memory-netty - ${arrow.version} - - - org.apache.httpcomponents - httpclient - ${httpclient.version} - - - org.apache.thrift - libthrift - ${thrift.version} - - - org.slf4j - slf4j-api - ${slf4j.version} - - - - org.slf4j - slf4j-jdk14 - ${slf4j.version} - - - commons-io - commons-io - ${commons-io.version} - - - com.google.code.findbugs - annotations - ${google.findbugs.annotations.version} - - - com.google.guava - guava - ${google.guava.version} - - - org.junit.jupiter - junit-jupiter - ${junit.jupiter.version} - test - - - com.nimbusds - nimbus-jose-jwt - ${nimbusjose.version} - - - org.bouncycastle - bcprov-jdk18on - ${bouncycastle.version} - - - org.bouncycastle - bcpkix-jdk18on - ${bouncycastle.version} - - - org.mockito - mockito-inline - ${mockito.version} - test - - - org.mockito - mockito-junit-jupiter - ${mockito.version} - test - - - com.fasterxml.jackson.core - jackson-databind - ${jackson.version} - - - com.fasterxml.jackson.core - jackson-annotations - ${jackson.version} - - - com.fasterxml.jackson.core - jackson-core - ${jackson.version} - - - org.immutables - value - ${immutables.value.version} - provided - - - net.hydromatic - sql-logic-test - ${sql-logic-test.version} - test - - - at.yawk.lz4 - lz4-java - ${lz4-compression.version} - - - - io.grpc - grpc-context - ${grpc.version} - - - - io.netty - netty-common - ${netty.version} - - - - io.netty - netty-buffer - ${netty.version} - - - jakarta.annotation - jakarta.annotation-api - ${annotation.version} - - - org.wiremock - wiremock - ${wiremock.version} - test - - - commons-fileupload - commons-fileupload - - - - - org.apache.httpcomponents.client5 - httpclient5 - ${async-httpclient.version} - - - org.apache.httpcomponents.core5 - httpcore5 - ${async-httpclient.version} - - - io.github.resilience4j - resilience4j-circuitbreaker - ${resilience4j.version} - - - io.github.resilience4j - resilience4j-core - ${resilience4j.version} - - - 
org.locationtech.jts - jts-core - 1.20.0 - - - - ${project.artifactId}-${project.version} + + + + org.apache.maven.plugins + maven-compiler-plugin + ${maven-compiler-plugin.version} + + + org.apache.maven.plugins + maven-surefire-plugin + ${maven-surefire-plugin.version} + + + org.apache.maven.plugins + maven-jar-plugin + ${maven-jar-plugin.version} + + + org.codehaus.mojo + exec-maven-plugin + ${exec-maven-plugin.version} + + + org.owasp + dependency-check-maven + ${dependency-check-maven.version} + + + org.jacoco + jacoco-maven-plugin + ${jacoco-maven-plugin.version} + + + com.diffplug.spotless + spotless-maven-plugin + ${spotless-maven-plugin.version} + + + org.apache.maven.plugins + maven-toolchains-plugin + ${maven-toolchains-plugin.version} + + + org.codehaus.mojo + build-helper-maven-plugin + ${build-helper-maven-plugin.version} + + + + - - org.apache.maven.plugins - maven-jar-plugin - 3.3.0 - - - - com.databricks.client.jdbc.Driver - true - - - - - - attach-thin-jar - package - - jar - - - thin - - - - attach-test-jar - - test-jar - - - - - - org.apache.maven.plugins - maven-surefire-plugin - ${maven-surefire-plugin.version} - - - **/DatabricksDriverExamples.java - **/integration/**/*.java - **/ErrorTypes.java - **/ErrorCodes.java - **/ProxyTest.java - **/LoggingTest.java - **/SSLTest.java - - - @{argLine} - --add-opens=java.base/java.nio=ALL-UNNAMED - -Dnet.bytebuddy.experimental=true - - - - - org.codehaus.mojo - exec-maven-plugin - 1.2.1 - - java - - --add-opens=java.base/java.nio=ALL-UNNAMED - -classpath - - com.databricks.jdbc.sqllogictest.SLTMain - -e - ${slt.executor} - -p - ${slt.token} - - test - - - - org.apache.maven.plugins - maven-compiler-plugin - - - - org.immutables - value - ${immutables.value.version} - - - - - - org.owasp - dependency-check-maven - 12.1.6 - - - HTML - JSON - - - 7 - ${nvd.api.key} - 10 - 4000 - true - ${ossindex.username} - ${ossindex.password} - - - - - check - - - - com.diffplug.spotless spotless-maven-plugin - 2.39.0 
format @@ -424,155 +171,20 @@ 1.18.1 + + + + **/MemoryUtil.java + **/ArrowBuf.java + **/DecimalUtility.java + - - org.jacoco - jacoco-maven-plugin - 0.8.11 - - - - prepare-agent - - - - report - prepare-package - - report - - - - - - **/*Constants* - **/*Exception* - **/CommandName* - **/DatabricksJdbcConstants* - **/DatabricksJdbcUrlParams* - **/Driver* - **/EnvironmentVariables* - **/model/** - **/thrift/generated/** - - - - - org.apache.maven.plugins - maven-shade-plugin - 3.5.0 - - - shade and package jars - package - - shade - - - - false - - - codegen - com.databricks.internal.codegen - - - com.databricks.sdk - com.databricks.internal.sdk - - - com.fasterxml - com.databricks.internal.fasterxml - - - com.google - com.databricks.internal.google - - - com.nimbusds - com.databricks.internal.nimbusds - - - io - com.databricks.internal.io - - - net.jpountz - com.databricks.internal.jpountz - - - org.apache - com.databricks.internal.apache - - - org.bouncycastle - com.databricks.internal.bouncycastle - - - org.checkerframework - com.databricks.internal.checkerframework - - - org.ini4j - com.databricks.internal.ini4j - - - org.json - com.databricks.internal.json - - - org.locationtech.jts - com.databricks.internal.jts - - - org.osgi - com.databricks.internal.osgi - - - org.slf4j - com.databricks.internal.slf4j - - - - - *:* - - META-INF/*.DSA - META-INF/*.RSA - META-INF/*.SF - META-INF/DEPENDENCIES - META-INF/LICENSE.txt - META-INF/versions/** - - - - *:* - - edu/** - javax/** - jakarta/** - net/jcip/** - - - - - - com.databricks.client.jdbc.Driver - - ${project.artifactId} - ${project.version} - - - - - - - - + @@ -623,28 +235,6 @@ release - - org.codehaus.mojo - build-helper-maven-plugin - 3.6.1 - - - attach-uber-minimal-pom - - attach-artifact - - package - - - - ${project.basedir}/uber-minimal-pom.xml - pom - - - - - - org.apache.maven.plugins maven-source-plugin @@ -705,12 +295,12 @@ central true published + + true - - - + \ No newline at end of file diff --git 
a/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractArrowResultChunk.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractArrowResultChunk.java index 1a9664299..30dfdcac5 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractArrowResultChunk.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractArrowResultChunk.java @@ -23,7 +23,6 @@ import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ipc.ArrowStreamReader; @@ -99,7 +98,7 @@ protected AbstractArrowResultChunk( this.rowOffset = rowOffset; this.chunkIndex = chunkIndex; this.statementId = statementId; - this.rootAllocator = new RootAllocator(Integer.MAX_VALUE); + this.rootAllocator = ArrowBufferAllocator.getBufferAllocator(); this.chunkReadyFuture = new CompletableFuture<>(); this.chunkLink = chunkLink; this.expiryTime = expiryTime; diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowBufferAllocator.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowBufferAllocator.java new file mode 100644 index 000000000..26aff4c48 --- /dev/null +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowBufferAllocator.java @@ -0,0 +1,83 @@ +package com.databricks.jdbc.api.impl.arrow; + +import com.databricks.jdbc.log.JdbcLogger; +import com.databricks.jdbc.log.JdbcLoggerFactory; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.DatabricksBufferAllocator; +import org.apache.arrow.memory.RootAllocator; + +/** + * Creates {@link BufferAllocator} instances. + * + *

First tries to create a {@link RootAllocator} which uses off-heap memory and is faster. If + * that fails (usually due to JVM reflection restrictions), falls back to {@link + * DatabricksBufferAllocator} which uses heap memory. + */ +public class ArrowBufferAllocator { + /** Should the RootAllocator be used. */ + private static final boolean canUseRootAllocator; + + /** Logger instance. */ + private static final JdbcLogger LOGGER = JdbcLoggerFactory.getLogger(ArrowBufferAllocator.class); + + /* Check if the RootAllocator can be used. */ + static { + canUseRootAllocator = canUseRootAllocator(); + } + + /** + * @return true iff the RootAllocator can be used. + */ + static boolean canUseRootAllocator() { + RootAllocator rootAllocator = null; + ArrowBuf buffer = null; + boolean canWriteWithRootAllocator = false; + + try { + rootAllocator = new RootAllocator(); + buffer = rootAllocator.buffer(64); + buffer.writeByte(0); + canWriteWithRootAllocator = true; + } catch (Throwable t) { + String message = t.getMessage(); + if (message == null) { + message = t.getCause() != null ? t.getCause().getMessage() : ""; + } + LOGGER.info( + "Failed to create RootAllocator, will use DatabricksBufferAllocator as fallback: " + + message); + } + + if (rootAllocator != null) { + try { + if (buffer != null) { + buffer.close(); + } + rootAllocator.close(); + } catch (Throwable t) { + LOGGER.warn("RootAllocator could not be closed: " + t.getMessage()); + } + } + + return canWriteWithRootAllocator; + } + + /** + * @return an instance of the {@code BufferAllocator}. + */ + public static BufferAllocator getBufferAllocator() { + if (canUseRootAllocator) { + return new RootAllocator(); + } else { + return new DatabricksBufferAllocator(); + } + } + + /** + * @return true iff the patched Databricks allocator is being used. 
+ */ + public static boolean isUsingPatchedAllocator() { + return !canUseRootAllocator; + } +} diff --git a/src/main/java/com/databricks/jdbc/model/telemetry/SqlExecutionEvent.java b/src/main/java/com/databricks/jdbc/model/telemetry/SqlExecutionEvent.java index 1b65b3bac..73e890b60 100644 --- a/src/main/java/com/databricks/jdbc/model/telemetry/SqlExecutionEvent.java +++ b/src/main/java/com/databricks/jdbc/model/telemetry/SqlExecutionEvent.java @@ -33,6 +33,9 @@ public class SqlExecutionEvent { @JsonProperty("operation_detail") OperationDetail operationDetail; + @JsonProperty("java_uses_patched_arrow") + Boolean javaUsesPatchedArrow; + public SqlExecutionEvent setDriverStatementType(StatementType driverStatementType) { this.driverStatementType = driverStatementType; return this; @@ -73,6 +76,11 @@ public SqlExecutionEvent setOperationDetail(OperationDetail operationDetail) { return this; } + public SqlExecutionEvent setJavaUsesPatchedArrow(Boolean javaUsesPatchedArrow) { + this.javaUsesPatchedArrow = javaUsesPatchedArrow; + return this; + } + @Override public String toString() { return new ToStringer(SqlExecutionEvent.class) @@ -84,6 +92,7 @@ public String toString() { .add("chunk_details", chunkDetails) .add("result_latency", resultLatency) .add("operation_details", operationDetail) + .add("java_uses_patched_arrow", javaUsesPatchedArrow) .toString(); } } diff --git a/src/main/java/com/databricks/jdbc/telemetry/TelemetryHelper.java b/src/main/java/com/databricks/jdbc/telemetry/TelemetryHelper.java index 7dee9ca7a..72002357a 100644 --- a/src/main/java/com/databricks/jdbc/telemetry/TelemetryHelper.java +++ b/src/main/java/com/databricks/jdbc/telemetry/TelemetryHelper.java @@ -3,6 +3,7 @@ import static com.databricks.jdbc.common.DatabricksJdbcConstants.QUERY_TAGS; import static com.databricks.jdbc.common.util.WildcardUtil.isNullOrEmpty; +import com.databricks.jdbc.api.impl.arrow.ArrowBufferAllocator; import com.databricks.jdbc.api.internal.IDatabricksConnectionContext; 
import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; import com.databricks.jdbc.common.DatabricksClientConfiguratorManager; @@ -116,6 +117,7 @@ private static void exportTelemetryEvent( .setResultLatency(telemetryDetails.getResultLatency()) .setOperationDetail(telemetryDetails.getOperationDetail()) .setExecutionResultFormat(telemetryDetails.getExecutionResultFormat()) + .setJavaUsesPatchedArrow(ArrowBufferAllocator.isUsingPatchedAllocator()) .setChunkId(chunkIndex); // This is only set for chunk download failure logs telemetryEvent.setSqlOperation(sqlExecutionEvent); diff --git a/src/main/java/org/apache/arrow/memory/ArrowBuf.java b/src/main/java/org/apache/arrow/memory/ArrowBuf.java new file mode 100644 index 000000000..61536f8a6 --- /dev/null +++ b/src/main/java/org/apache/arrow/memory/ArrowBuf.java @@ -0,0 +1,1271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ------------------------------------------------------------------------- +// MODIFICATION NOTICE: +// This file was modified by Databricks, Inc on 16-December-2025. +// Description of changes: +// - Patched ArrowBuf to be non-final and extensible. 
+// - Patched ArrowBuf to remove dependency on BaseAllocator during static initialization. +// - Patched ArrowBuf to modify method visibility from public to private for +// `print(StringBuilder sb, int indent, Verbosity verbosity)` +// ------------------------------------------------------------------------- + +package org.apache.arrow.memory; + +import static org.apache.arrow.memory.util.LargeMemoryUtil.checkedCastToInt; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.ReadOnlyBufferException; +import java.util.concurrent.atomic.AtomicLong; +import org.apache.arrow.memory.BaseAllocator.Verbosity; +import org.apache.arrow.memory.util.CommonUtil; +import org.apache.arrow.memory.util.HistoricalLog; +import org.apache.arrow.memory.util.MemoryUtil; +import org.apache.arrow.util.Preconditions; + +/** + * ArrowBuf serves as a facade over underlying memory by providing several access APIs to read/write + * data into a chunk of direct memory. All the accounting, ownership and reference management is + * done by {@link ReferenceManager} and ArrowBuf can work with a custom user provided implementation + * of ReferenceManager + * + *

Two important instance variables of an ArrowBuf: (1) address - starting virtual address in the + * underlying memory chunk that this ArrowBuf has access to (2) length - length (in bytes) in the + * underlying memory chunk that this ArrowBuf has access to + * + *

The management (allocation, deallocation, reference counting etc) for the memory chunk is not + * done by ArrowBuf. Default implementation of ReferenceManager, allocation is in {@link + * BaseAllocator}, {@link BufferLedger} and {@link AllocationManager} + */ +public class ArrowBuf implements AutoCloseable { + + private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(ArrowBuf.class); + + // ---- Databricks patch start ---- + // ---- Copied verbatim from BaseAllocator. We avoid initializing static fields of BaseAllocator + // ---- to avoid unsafe allocation initialization errors. + public static final String DEBUG_ALLOCATOR = "arrow.memory.debug.allocator"; + public static final int DEBUG_LOG_LENGTH = 6; + public static final boolean DEBUG; + + // Initialize this before DEFAULT_CONFIG as DEFAULT_CONFIG will eventually initialize the + // allocation manager, + // which in turn allocates an ArrowBuf, which requires DEBUG to have been properly initialized + static { + // the system property takes precedence. + String propValue = System.getProperty(DEBUG_ALLOCATOR); + if (propValue != null) { + DEBUG = Boolean.parseBoolean(propValue); + } else { + DEBUG = false; + } + logger.info( + "Debug mode " + + (DEBUG + ? "enabled." + : "disabled. 
Enable with the VM option -Darrow.memory.debug.allocator=true.")); + } + + // ---- Databricks patch end ---- + + private static final int SHORT_SIZE = Short.BYTES; + private static final int INT_SIZE = Integer.BYTES; + private static final int FLOAT_SIZE = Float.BYTES; + private static final int DOUBLE_SIZE = Double.BYTES; + private static final int LONG_SIZE = Long.BYTES; + + private static final AtomicLong idGenerator = new AtomicLong(0); + private static final int LOG_BYTES_PER_ROW = 10; + private final long id = idGenerator.incrementAndGet(); + private final ReferenceManager referenceManager; + private final BufferManager bufferManager; + private final long addr; + private long readerIndex; + private long writerIndex; + + // ---- Databricks patch start ---- + private final HistoricalLog historicalLog = + DEBUG ? new HistoricalLog(DEBUG_LOG_LENGTH, "ArrowBuf[%d]", id) : null; + // ---- Databricks patch end ---- + + private volatile long capacity; + + /** + * Constructs a new ArrowBuf. + * + * @param referenceManager The memory manager to track memory usage and reference count of this + * buffer + * @param capacity The capacity in bytes of this buffer + */ + public ArrowBuf( + final ReferenceManager referenceManager, + final BufferManager bufferManager, + final long capacity, + final long memoryAddress) { + this.referenceManager = referenceManager; + this.bufferManager = bufferManager; + this.addr = memoryAddress; + this.capacity = capacity; + this.readerIndex = 0; + this.writerIndex = 0; + if (historicalLog != null) { + historicalLog.recordEvent("create()"); + } + } + + public int refCnt() { + return referenceManager.getRefCount(); + } + + /** + * Allows a function to determine whether not reading a particular string of bytes is valid. + * + *

Will throw an exception if the memory is not readable for some reason. Only doesn't + * something in the case that AssertionUtil.BOUNDS_CHECKING_ENABLED is true. + * + * @param start The starting position of the bytes to be read. + * @param end The exclusive endpoint of the bytes to be read. + */ + public void checkBytes(long start, long end) { + if (BoundsChecking.BOUNDS_CHECKING_ENABLED) { + checkIndexD(start, end - start); + } + } + + /** For get/set operations, reference count should be >= 1. */ + private void ensureAccessible() { + if (this.refCnt() == 0) { + throw new IllegalStateException("Ref count should be >= 1 for accessing the ArrowBuf"); + } + } + + /** + * Get reference manager for this ArrowBuf. + * + * @return user provided implementation of {@link ReferenceManager} + */ + public ReferenceManager getReferenceManager() { + return referenceManager; + } + + public long capacity() { + return capacity; + } + + /** + * Adjusts the capacity of this buffer. Size increases are NOT supported. + * + * @param newCapacity Must be in in the range [0, length). + */ + public synchronized ArrowBuf capacity(long newCapacity) { + + if (newCapacity == capacity) { + return this; + } + + Preconditions.checkArgument(newCapacity >= 0); + + if (newCapacity < capacity) { + capacity = newCapacity; + return this; + } + + throw new UnsupportedOperationException( + "Buffers don't support resizing that increases the size."); + } + + /** Returns the byte order of elements in this buffer. */ + public ByteOrder order() { + return ByteOrder.nativeOrder(); + } + + /** Returns the number of bytes still available to read in this buffer. */ + public long readableBytes() { + Preconditions.checkState( + writerIndex >= readerIndex, "Writer index cannot be less than reader index"); + return writerIndex - readerIndex; + } + + /** + * Returns the number of bytes still available to write into this buffer before capacity is + * reached. 
+ */ + public long writableBytes() { + return capacity() - writerIndex; + } + + /** Returns a slice of only the readable bytes in the buffer. */ + public ArrowBuf slice() { + return slice(readerIndex, readableBytes()); + } + + /** Returns a slice (view) starting at index with the given length. */ + public ArrowBuf slice(long index, long length) { + + Preconditions.checkPositionIndex(index, this.capacity); + Preconditions.checkPositionIndex(index + length, this.capacity); + + /* + * Re the behavior of reference counting, see http://netty.io/wiki/reference-counted-objects + * .html#wiki-h3-5, which + * explains that derived buffers share their reference count with their parent + */ + final ArrowBuf newBuf = referenceManager.deriveBuffer(this, index, length); + newBuf.writerIndex(length); + return newBuf; + } + + /** Make a nio byte buffer from this arrowbuf. */ + public ByteBuffer nioBuffer() { + return nioBuffer(readerIndex, checkedCastToInt(readableBytes())); + } + + /** Make a nio byte buffer from this ArrowBuf. */ + public ByteBuffer nioBuffer(long index, int length) { + chk(index, length); + return getDirectBuffer(index, length); + } + + private ByteBuffer getDirectBuffer(long index, int length) { + long address = addr(index); + return MemoryUtil.directBuffer(address, length); + } + + public long memoryAddress() { + return this.addr; + } + + @Override + public String toString() { + return String.format("ArrowBuf[%d], address:%d, capacity:%d", id, memoryAddress(), capacity); + } + + @Override + public int hashCode() { + return System.identityHashCode(this); + } + + @Override + public boolean equals(Object obj) { + // identity equals only. + return this == obj; + } + + /* + * IMPORTANT NOTE + * The data getters and setters work with a caller provided + * index. 
This index is 0 based and since ArrowBuf has access + * to a portion of underlying chunk of memory starting at + * some address, we convert the given relative index into + * absolute index as memory address + index. + * + * Example: + * + * Let's say we have an underlying chunk of memory of length 64 bytes + * Now let's say we have an ArrowBuf that has access to the chunk + * from offset 4 for length of 16 bytes. + * + * If the starting virtual address of chunk is MAR, then memory + * address of this ArrowBuf is MAR + offset -- this is what is stored + * in variable addr. See the BufferLedger and AllocationManager code + * for the implementation of ReferenceManager that manages a + * chunk of memory and creates ArrowBuf with access to a range of + * bytes within the chunk (or the entire chunk) + * + * So now to get/set data, we will do => addr + index + * This logic is put in method addr(index) and is frequently + * used in get/set data methods to compute the absolute + * byte address for get/set operation in the underlying chunk + * + * @param index the index at which we the user wants to read/write + * @return the absolute address within the memory + */ + private long addr(long index) { + return addr + index; + } + + /*-------------------------------------------------* + | Following are a set of fast path data set and | + | get APIs to write/read data from ArrowBuf | + | at a given index (0 based relative to this | + | ArrowBuf and not relative to the underlying | + | memory chunk). | + | | + *-------------------------------------------------*/ + + /** + * Helper function to do bounds checking at a particular index for particular length of data. 
+ * + * @param index index (0 based relative to this ArrowBuf) + * @param length provided length of data for get/set + */ + private void chk(long index, long length) { + if (BoundsChecking.BOUNDS_CHECKING_ENABLED) { + checkIndexD(index, length); + } + } + + private void checkIndexD(long index, long fieldLength) { + // check reference count + ensureAccessible(); + // check bounds + Preconditions.checkArgument(fieldLength >= 0, "expecting non-negative data length"); + if (index < 0 || index > capacity() - fieldLength) { + if (historicalLog != null) { + historicalLog.logHistory(logger); + } + throw new IndexOutOfBoundsException( + String.format( + "index: %d, length: %d (expected: range(0, %d))", index, fieldLength, capacity())); + } + } + + /** + * Get long value stored at a particular index in the underlying memory chunk this ArrowBuf has + * access to. + * + * @param index index (0 based relative to this ArrowBuf) where the value will be read from + * @return 8 byte long value + */ + public long getLong(long index) { + chk(index, LONG_SIZE); + return MemoryUtil.getLong(addr(index)); + } + + /** + * Set long value at a particular index in the underlying memory chunk this ArrowBuf has access + * to. + * + * @param index index (0 based relative to this ArrowBuf) where the value will be written + * @param value value to write + */ + public void setLong(long index, long value) { + chk(index, LONG_SIZE); + MemoryUtil.putLong(addr(index), value); + } + + /** + * Get float value stored at a particular index in the underlying memory chunk this ArrowBuf has + * access to. + * + * @param index index (0 based relative to this ArrowBuf) where the value will be read from + * @return 4 byte float value + */ + public float getFloat(long index) { + return Float.intBitsToFloat(getInt(index)); + } + + /** + * Set float value at a particular index in the underlying memory chunk this ArrowBuf has access + * to. 
+ * + * @param index index (0 based relative to this ArrowBuf) where the value will be written + * @param value value to write + */ + public void setFloat(long index, float value) { + chk(index, FLOAT_SIZE); + MemoryUtil.putInt(addr(index), Float.floatToRawIntBits(value)); + } + + /** + * Get double value stored at a particular index in the underlying memory chunk this ArrowBuf has + * access to. + * + * @param index index (0 based relative to this ArrowBuf) where the value will be read from + * @return 8 byte double value + */ + public double getDouble(long index) { + return Double.longBitsToDouble(getLong(index)); + } + + /** + * Set double value at a particular index in the underlying memory chunk this ArrowBuf has access + * to. + * + * @param index index (0 based relative to this ArrowBuf) where the value will be written + * @param value value to write + */ + public void setDouble(long index, double value) { + chk(index, DOUBLE_SIZE); + MemoryUtil.putLong(addr(index), Double.doubleToRawLongBits(value)); + } + + /** + * Get char value stored at a particular index in the underlying memory chunk this ArrowBuf has + * access to. + * + * @param index index (0 based relative to this ArrowBuf) where the value will be read from + * @return 2 byte char value + */ + public char getChar(long index) { + return (char) getShort(index); + } + + /** + * Set char value at a particular index in the underlying memory chunk this ArrowBuf has access + * to. + * + * @param index index (0 based relative to this ArrowBuf) where the value will be written + * @param value value to write + */ + public void setChar(long index, int value) { + chk(index, SHORT_SIZE); + MemoryUtil.putShort(addr(index), (short) value); + } + + /** + * Get int value stored at a particular index in the underlying memory chunk this ArrowBuf has + * access to. 
+ * + * @param index index (0 based relative to this ArrowBuf) where the value will be read from + * @return 4 byte int value + */ + public int getInt(long index) { + chk(index, INT_SIZE); + return MemoryUtil.getInt(addr(index)); + } + + /** + * Set int value at a particular index in the underlying memory chunk this ArrowBuf has access to. + * + * @param index index (0 based relative to this ArrowBuf) where the value will be written + * @param value value to write + */ + public void setInt(long index, int value) { + chk(index, INT_SIZE); + MemoryUtil.putInt(addr(index), value); + } + + /** + * Get short value stored at a particular index in the underlying memory chunk this ArrowBuf has + * access to. + * + * @param index index (0 based relative to this ArrowBuf) where the value will be read from + * @return 2 byte short value + */ + public short getShort(long index) { + chk(index, SHORT_SIZE); + return MemoryUtil.getShort(addr(index)); + } + + /** + * Set short value at a particular index in the underlying memory chunk this ArrowBuf has access + * to. + * + * @param index index (0 based relative to this ArrowBuf) where the value will be written + * @param value value to write + */ + public void setShort(long index, int value) { + setShort(index, (short) value); + } + + /** + * Set short value at a particular index in the underlying memory chunk this ArrowBuf has access + * to. + * + * @param index index (0 based relative to this ArrowBuf) where the value will be written + * @param value value to write + */ + public void setShort(long index, short value) { + chk(index, SHORT_SIZE); + MemoryUtil.putShort(addr(index), value); + } + + /** + * Set byte value at a particular index in the underlying memory chunk this ArrowBuf has access + * to. 
+ * + * @param index index (0 based relative to this ArrowBuf) where the value will be written + * @param value value to write + */ + public void setByte(long index, int value) { + chk(index, 1); + MemoryUtil.putByte(addr(index), (byte) value); + } + + /** + * Set byte value at a particular index in the underlying memory chunk this ArrowBuf has access + * to. + * + * @param index index (0 based relative to this ArrowBuf) where the value will be written + * @param value value to write + */ + public void setByte(long index, byte value) { + chk(index, 1); + MemoryUtil.putByte(addr(index), value); + } + + /** + * Get byte value stored at a particular index in the underlying memory chunk this ArrowBuf has + * access to. + * + * @param index index (0 based relative to this ArrowBuf) where the value will be read from + * @return byte value + */ + public byte getByte(long index) { + chk(index, 1); + return MemoryUtil.getByte(addr(index)); + } + + /*--------------------------------------------------* + | Following are another set of data set APIs | + | that directly work with writerIndex | + | | + *--------------------------------------------------*/ + + /** + * Helper function to do bound checking w.r.t writerIndex by checking if we can set "length" bytes + * of data at the writerIndex in this ArrowBuf. + * + * @param length provided length of data for set + */ + private void ensureWritable(final int length) { + if (BoundsChecking.BOUNDS_CHECKING_ENABLED) { + Preconditions.checkArgument(length >= 0, "expecting non-negative length"); + // check reference count + this.ensureAccessible(); + // check bounds + if (length > writableBytes()) { + throw new IndexOutOfBoundsException( + String.format( + "writerIndex(%d) + length(%d) exceeds capacity(%d)", + writerIndex, length, capacity())); + } + } + } + + /** + * Helper function to do bound checking w.r.t readerIndex by checking if we can read "length" + * bytes of data at the readerIndex in this ArrowBuf. 
+ * + * @param length provided length of data for get + */ + private void ensureReadable(final int length) { + if (BoundsChecking.BOUNDS_CHECKING_ENABLED) { + Preconditions.checkArgument(length >= 0, "expecting non-negative length"); + // check reference count + this.ensureAccessible(); + // check bounds + if (length > readableBytes()) { + throw new IndexOutOfBoundsException( + String.format( + "readerIndex(%d) + length(%d) exceeds writerIndex(%d)", + readerIndex, length, writerIndex)); + } + } + } + + /** + * Read the byte at readerIndex. + * + * @return byte value + */ + public byte readByte() { + ensureReadable(1); + final byte b = getByte(readerIndex); + ++readerIndex; + return b; + } + + /** + * Read dst.length bytes at readerIndex into dst byte array. + * + * @param dst byte array where the data will be written + */ + public void readBytes(byte[] dst) { + Preconditions.checkArgument(dst != null, "expecting valid dst bytearray"); + ensureReadable(dst.length); + getBytes(readerIndex, dst, 0, checkedCastToInt(dst.length)); + } + + /** + * Set the provided byte value at the writerIndex. + * + * @param value value to set + */ + public void writeByte(byte value) { + ensureWritable(1); + MemoryUtil.putByte(addr(writerIndex), value); + ++writerIndex; + } + + /** + * Set the lower order byte for the provided value at the writerIndex. + * + * @param value value to be set + */ + public void writeByte(int value) { + ensureWritable(1); + MemoryUtil.putByte(addr(writerIndex), (byte) value); + ++writerIndex; + } + + /** + * Write the bytes from given byte array into this ArrowBuf starting at writerIndex. + * + * @param src src byte array + */ + public void writeBytes(byte[] src) { + Preconditions.checkArgument(src != null, "expecting valid src array"); + writeBytes(src, 0, src.length); + } + + /** + * Write the bytes from given byte array starting at srcIndex into this ArrowBuf starting at + * writerIndex. 
+ * + * @param src src byte array + * @param srcIndex index in the byte array where the copy will being from + * @param length length of data to copy + */ + public void writeBytes(byte[] src, int srcIndex, int length) { + ensureWritable(length); + setBytes(writerIndex, src, srcIndex, length); + writerIndex += length; + } + + /** + * Set the provided int value as short at the writerIndex. + * + * @param value value to set + */ + public void writeShort(int value) { + ensureWritable(SHORT_SIZE); + MemoryUtil.putShort(addr(writerIndex), (short) value); + writerIndex += SHORT_SIZE; + } + + /** + * Set the provided int value at the writerIndex. + * + * @param value value to set + */ + public void writeInt(int value) { + ensureWritable(INT_SIZE); + MemoryUtil.putInt(addr(writerIndex), value); + writerIndex += INT_SIZE; + } + + /** + * Set the provided long value at the writerIndex. + * + * @param value value to set + */ + public void writeLong(long value) { + ensureWritable(LONG_SIZE); + MemoryUtil.putLong(addr(writerIndex), value); + writerIndex += LONG_SIZE; + } + + /** + * Set the provided float value at the writerIndex. + * + * @param value value to set + */ + public void writeFloat(float value) { + ensureWritable(FLOAT_SIZE); + MemoryUtil.putInt(addr(writerIndex), Float.floatToRawIntBits(value)); + writerIndex += FLOAT_SIZE; + } + + /** + * Set the provided double value at the writerIndex. 
+ * + * @param value value to set + */ + public void writeDouble(double value) { + ensureWritable(DOUBLE_SIZE); + MemoryUtil.putLong(addr(writerIndex), Double.doubleToRawLongBits(value)); + writerIndex += DOUBLE_SIZE; + } + + /*--------------------------------------------------* + | Following are another set of data set/get APIs | + | that read and write stream of bytes from/to byte | + | arrays, ByteBuffer, ArrowBuf etc | + | | + *--------------------------------------------------*/ + + /** + * Determine if the requested {@code index} and {@code length} will fit within {@code capacity}. + * + * @param index The starting index. + * @param length The length which will be utilized (starting from {@code index}). + * @param capacity The capacity that {@code index + length} is allowed to be within. + * @return {@code true} if the requested {@code index} and {@code length} will fit within {@code + * capacity}. {@code false} if this would result in an index out of bounds exception. + */ + private static boolean isOutOfBounds(long index, long length, long capacity) { + return (index | length | (index + length) | (capacity - (index + length))) < 0; + } + + private void checkIndex(long index, long fieldLength) { + if (BoundsChecking.BOUNDS_CHECKING_ENABLED) { + // check reference count + this.ensureAccessible(); + // check bounds + if (isOutOfBounds(index, fieldLength, this.capacity())) { + throw new IndexOutOfBoundsException( + String.format( + "index: %d, length: %d (expected: range(0, %d))", + index, fieldLength, this.capacity())); + } + } + } + + /** + * Copy data from this ArrowBuf at a given index in into destination byte array. + * + * @param index starting index (0 based relative to the portion of memory) this ArrowBuf has + * access to + * @param dst byte array to copy the data into + */ + public void getBytes(long index, byte[] dst) { + getBytes(index, dst, 0, dst.length); + } + + /** + * Copy data from this ArrowBuf at a given index into destination byte array. 
+ * + * @param index index (0 based relative to the portion of memory this ArrowBuf has access to) + * @param dst byte array to copy the data into + * @param dstIndex starting index in dst byte array to copy into + * @param length length of data to copy from this ArrowBuf + */ + public void getBytes(long index, byte[] dst, int dstIndex, int length) { + // bound check for this ArrowBuf where the data will be copied from + checkIndex(index, length); + // null check + Preconditions.checkArgument(dst != null, "expecting a valid dst byte array"); + // bound check for dst byte array where the data will be copied to + if (isOutOfBounds(dstIndex, length, dst.length)) { + // not enough space to copy "length" bytes into dst array from dstIndex onwards + throw new IndexOutOfBoundsException( + "Not enough space to copy data into destination" + dstIndex); + } + if (length != 0) { + // copy "length" bytes from this ArrowBuf starting at addr(index) address + // into dst byte array at dstIndex onwards + MemoryUtil.copyFromMemory(addr(index), dst, dstIndex, length); + } + } + + /** + * Copy data from a given byte array into this ArrowBuf starting at a given index. + * + * @param index starting index (0 based relative to the portion of memory) this ArrowBuf has + * access to + * @param src byte array to copy the data from + */ + public void setBytes(long index, byte[] src) { + setBytes(index, src, 0, src.length); + } + + /** + * Copy data from a given byte array starting at the given source index into this ArrowBuf at a + * given index. 
+ * + * @param index index (0 based relative to the portion of memory this ArrowBuf has access to) + * @param src src byte array to copy the data from + * @param srcIndex index in the byte array where the copy will start from + * @param length length of data to copy from byte array + */ + public void setBytes(long index, byte[] src, int srcIndex, long length) { + // bound check for this ArrowBuf where the data will be copied into + checkIndex(index, length); + // null check + Preconditions.checkArgument(src != null, "expecting a valid src byte array"); + // bound check for src byte array where the data will be copied from + if (isOutOfBounds(srcIndex, length, src.length)) { + // not enough space to copy "length" bytes into dst array from dstIndex onwards + throw new IndexOutOfBoundsException( + "Not enough space to copy data from byte array" + srcIndex); + } + if (length > 0) { + // copy "length" bytes from src byte array at the starting index (srcIndex) + // into this ArrowBuf starting at address "addr(index)" + MemoryUtil.copyToMemory(src, srcIndex, addr(index), length); + } + } + + /** + * Copy data from this ArrowBuf at a given index into the destination ByteBuffer. 
+ * + * @param index index (0 based relative to the portion of memory this ArrowBuf has access to) + * @param dst dst ByteBuffer where the data will be copied into + */ + public void getBytes(long index, ByteBuffer dst) { + // bound check for this ArrowBuf where the data will be copied from + checkIndex(index, dst.remaining()); + // dst.remaining() bytes of data will be copied into dst ByteBuffer + if (dst.remaining() != 0) { + // address in this ArrowBuf where the copy will begin from + final long srcAddress = addr(index); + if (dst.isDirect()) { + if (dst.isReadOnly()) { + throw new ReadOnlyBufferException(); + } + // copy dst.remaining() bytes of data from this ArrowBuf starting + // at address srcAddress into the dst ByteBuffer starting at + // address dstAddress + final long dstAddress = MemoryUtil.getByteBufferAddress(dst) + dst.position(); + MemoryUtil.copyMemory(srcAddress, dstAddress, dst.remaining()); + // after copy, bump the next write position for the dst ByteBuffer + dst.position(dst.position() + dst.remaining()); + } else if (dst.hasArray()) { + // copy dst.remaining() bytes of data from this ArrowBuf starting + // at address srcAddress into the dst ByteBuffer starting at + // index dstIndex + final int dstIndex = dst.arrayOffset() + dst.position(); + MemoryUtil.copyFromMemory(srcAddress, dst.array(), dstIndex, dst.remaining()); + // after copy, bump the next write position for the dst ByteBuffer + dst.position(dst.position() + dst.remaining()); + } else { + throw new UnsupportedOperationException( + "Copy from this ArrowBuf to ByteBuffer is not supported"); + } + } + } + + /** + * Copy data into this ArrowBuf at a given index onwards from a source ByteBuffer. 
+ * + * @param index index index (0 based relative to the portion of memory this ArrowBuf has access + * to) + * @param src src ByteBuffer where the data will be copied from + */ + public void setBytes(long index, ByteBuffer src) { + // bound check for this ArrowBuf where the data will be copied into + checkIndex(index, src.remaining()); + // length of data to copy + int length = src.remaining(); + // address in this ArrowBuf where the data will be copied to + long dstAddress = addr(index); + if (length != 0) { + if (src.isDirect()) { + // copy src.remaining() bytes of data from src ByteBuffer starting at + // address srcAddress into this ArrowBuf starting at address dstAddress + final long srcAddress = MemoryUtil.getByteBufferAddress(src) + src.position(); + MemoryUtil.copyMemory(srcAddress, dstAddress, length); + // after copy, bump the next read position for the src ByteBuffer + src.position(src.position() + length); + } else if (src.hasArray()) { + // copy src.remaining() bytes of data from src ByteBuffer starting at + // index srcIndex into this ArrowBuf starting at address dstAddress + final int srcIndex = src.arrayOffset() + src.position(); + MemoryUtil.copyToMemory(src.array(), srcIndex, dstAddress, length); + // after copy, bump the next read position for the src ByteBuffer + src.position(src.position() + length); + } else { + final ByteOrder originalByteOrder = src.order(); + src.order(order()); + try { + // copy word at a time + while (length - 128 >= LONG_SIZE) { + for (int x = 0; x < 16; x++) { + MemoryUtil.putLong(dstAddress, src.getLong()); + length -= LONG_SIZE; + dstAddress += LONG_SIZE; + } + } + while (length >= LONG_SIZE) { + MemoryUtil.putLong(dstAddress, src.getLong()); + length -= LONG_SIZE; + dstAddress += LONG_SIZE; + } + // copy last byte + while (length > 0) { + MemoryUtil.putByte(dstAddress, src.get()); + --length; + ++dstAddress; + } + } finally { + src.order(originalByteOrder); + } + } + } + } + + /** + * Copy data into this ArrowBuf 
at a given index onwards from a source ByteBuffer starting at a + * given srcIndex for a certain length. + * + * @param index index (0 based relative to the portion of memory this ArrowBuf has access to) + * @param src src ByteBuffer where the data will be copied from + * @param srcIndex starting index in the src ByteBuffer where the data copy will start from + * @param length length of data to copy from src ByteBuffer + */ + public void setBytes(long index, ByteBuffer src, int srcIndex, int length) { + // bound check for this ArrowBuf where the data will be copied into + checkIndex(index, length); + if (src.isDirect()) { + // copy length bytes of data from src ByteBuffer starting at address + // srcAddress into this ArrowBuf at address dstAddress + final long srcAddress = MemoryUtil.getByteBufferAddress(src) + srcIndex; + final long dstAddress = addr(index); + MemoryUtil.copyMemory(srcAddress, dstAddress, length); + } else { + if (srcIndex == 0 && src.capacity() == length) { + // copy the entire ByteBuffer from start to end of length + setBytes(index, src); + } else { + ByteBuffer newBuf = src.duplicate(); + newBuf.position(srcIndex); + newBuf.limit(srcIndex + length); + setBytes(index, newBuf); + } + } + } + + /** + * Copy a given length of data from this ArrowBuf starting at a given index into a dst ArrowBuf at + * dstIndex. 
+ * + * @param index index (0 based relative to the portion of memory this ArrowBuf has access to) + * @param dst dst ArrowBuf where the data will be copied into + * @param dstIndex index (0 based relative to the portion of memory dst ArrowBuf has access to) + * @param length length of data to copy + */ + public void getBytes(long index, ArrowBuf dst, long dstIndex, int length) { + // bound check for this ArrowBuf where the data will be copied from + checkIndex(index, length); + // bound check for this ArrowBuf where the data will be copied into + Preconditions.checkArgument(dst != null, "expecting a valid ArrowBuf"); + // bound check for dst ArrowBuf + if (isOutOfBounds(dstIndex, length, dst.capacity())) { + throw new IndexOutOfBoundsException( + String.format( + "index: %d, length: %d (expected: range(0, %d))", dstIndex, length, dst.capacity())); + } + if (length != 0) { + // copy length bytes of data from this ArrowBuf starting at + // address srcAddress into dst ArrowBuf starting at address + // dstAddress + final long srcAddress = addr(index); + final long dstAddress = dst.memoryAddress() + (long) dstIndex; + MemoryUtil.copyMemory(srcAddress, dstAddress, length); + } + } + + /** + * Copy data from src ArrowBuf starting at index srcIndex into this ArrowBuf at given index. 
+ * + * @param index index index (0 based relative to the portion of memory this ArrowBuf has access + * to) + * @param src src ArrowBuf where the data will be copied from + * @param srcIndex starting index in the src ArrowBuf where the copy will begin from + * @param length length of data to copy from src ArrowBuf + */ + public void setBytes(long index, ArrowBuf src, long srcIndex, long length) { + // bound check for this ArrowBuf where the data will be copied into + checkIndex(index, length); + // null check + Preconditions.checkArgument(src != null, "expecting a valid ArrowBuf"); + // bound check for src ArrowBuf + if (isOutOfBounds(srcIndex, length, src.capacity())) { + throw new IndexOutOfBoundsException( + String.format( + "index: %d, length: %d (expected: range(0, %d))", srcIndex, length, src.capacity())); + } + if (length != 0) { + // copy length bytes of data from src ArrowBuf starting at + // address srcAddress into this ArrowBuf starting at address + // dstAddress + final long srcAddress = src.memoryAddress() + srcIndex; + final long dstAddress = addr(index); + MemoryUtil.copyMemory(srcAddress, dstAddress, length); + } + } + + /** + * Copy readableBytes() number of bytes from src ArrowBuf starting from its readerIndex into this + * ArrowBuf starting at the given index. 
+ * + * @param index index index (0 based relative to the portion of memory this ArrowBuf has access + * to) + * @param src src ArrowBuf where the data will be copied from + */ + public void setBytes(long index, ArrowBuf src) { + // null check + Preconditions.checkArgument(src != null, "expecting valid ArrowBuf"); + final long length = src.readableBytes(); + // bound check for this ArrowBuf where the data will be copied into + checkIndex(index, length); + final long srcAddress = src.memoryAddress() + src.readerIndex; + final long dstAddress = addr(index); + MemoryUtil.copyMemory(srcAddress, dstAddress, length); + src.readerIndex(src.readerIndex + length); + } + + /** + * Copy a certain length of bytes from given InputStream into this ArrowBuf at the provided index. + * + * @param index index index (0 based relative to the portion of memory this ArrowBuf has access + * to) + * @param in src stream to copy from + * @param length length of data to copy + * @return number of bytes copied from stream into ArrowBuf + * @throws IOException on failing to read from stream + */ + public int setBytes(long index, InputStream in, int length) throws IOException { + Preconditions.checkArgument(in != null, "expecting valid input stream"); + checkIndex(index, length); + int readBytes = 0; + if (length > 0) { + byte[] tmp = new byte[length]; + // read the data from input stream into tmp byte array + readBytes = in.read(tmp); + if (readBytes > 0) { + // copy readBytes length of data from the tmp byte array starting + // at srcIndex 0 into this ArrowBuf starting at address addr(index) + MemoryUtil.copyToMemory(tmp, 0, addr(index), readBytes); + } + } + return readBytes; + } + + /** + * Copy a certain length of bytes from this ArrowBuf at a given index into the given OutputStream. 
+ * + * @param index index index (0 based relative to the portion of memory this ArrowBuf has access + * to) + * @param out dst stream to copy data into + * @param length length of data to copy + * @throws IOException on failing to write to stream + */ + public void getBytes(long index, OutputStream out, int length) throws IOException { + Preconditions.checkArgument(out != null, "expecting valid output stream"); + checkIndex(index, length); + if (length > 0) { + // copy length bytes of data from this ArrowBuf starting at + // address addr(index) into the tmp byte array starting at index 0 + byte[] tmp = new byte[length]; + MemoryUtil.copyFromMemory(addr(index), tmp, 0, length); + // write the copied data to output stream + out.write(tmp); + } + } + + @Override + public void close() { + referenceManager.release(); + } + + /** + * Returns the possible memory consumed by this ArrowBuf in the worse case scenario. (not shared, + * connected to larger underlying buffer of allocated memory) + * + * @return Size in bytes. + */ + public long getPossibleMemoryConsumed() { + return referenceManager.getSize(); + } + + /** + * Return that is Accounted for by this buffer (and its potentially shared siblings within the + * context of the associated allocator). + * + * @return Size in bytes. + */ + public long getActualMemoryConsumed() { + return referenceManager.getAccountedSize(); + } + + /** + * Return the buffer's byte contents in the form of a hex dump. + * + * @param start the starting byte index + * @param length how many bytes to log + * @return A hex dump in a String. 
+ */ + public String toHexString(final long start, final int length) { + final long roundedStart = (start / LOG_BYTES_PER_ROW) * LOG_BYTES_PER_ROW; + + final StringBuilder sb = new StringBuilder("buffer byte dump\n"); + long index = roundedStart; + for (long nLogged = 0; nLogged < length; nLogged += LOG_BYTES_PER_ROW) { + sb.append(String.format(" [%05d-%05d]", index, index + LOG_BYTES_PER_ROW - 1)); + for (int i = 0; i < LOG_BYTES_PER_ROW; ++i) { + try { + final byte b = getByte(index++); + sb.append(String.format(" 0x%02x", b)); + } catch (IndexOutOfBoundsException ioob) { + sb.append(" "); + } + } + sb.append('\n'); + } + return sb.toString(); + } + + /** + * Get the integer id assigned to this ArrowBuf for debugging purposes. + * + * @return integer id + */ + public long getId() { + return id; + } + + /** + * Print information of this buffer into sb at the given indentation and verbosity + * level. + * + *

It will include history if BaseAllocator.DEBUG is true and the + * verbosity.includeHistoricalLog are true. + */ + // ---- Databricks patch start ---- + // ---- Modify method visibility public -> private. It was annotated with @VisibleForTesting. + // ---- This ensures that DatabricksBufferAllocator need not have a public method with + // ---- Verbosity method parameter which might trigger unsafe path class loading and fail on + // ---- JVM 16+ which does not have "--add-opens=java.base/java.nio=ALL-UNNAMED" jvm arg present. + private void print(StringBuilder sb, int indent, Verbosity verbosity) { + // ---- Databricks patch end ---- + CommonUtil.indent(sb, indent).append(toString()); + + if (historicalLog != null && verbosity.includeHistoricalLog) { + sb.append("\n"); + historicalLog.buildHistory(sb, indent + 1, verbosity.includeStackTraces); + } + } + + /** + * Print detailed information of this buffer into sb. + * + *

Most information will only be present if BaseAllocator.DEBUG is true. + */ + public void print(StringBuilder sb, int indent) { + print(sb, indent, Verbosity.LOG_WITH_STACKTRACE); + } + + /** + * Get the index at which the next byte will be read from. + * + * @return reader index + */ + public long readerIndex() { + return readerIndex; + } + + /** + * Get the index at which next byte will be written to. + * + * @return writer index + */ + public long writerIndex() { + return writerIndex; + } + + /** + * Set the reader index for this ArrowBuf. + * + * @param readerIndex new reader index + * @return this ArrowBuf + */ + public ArrowBuf readerIndex(long readerIndex) { + this.readerIndex = readerIndex; + return this; + } + + /** + * Set the writer index for this ArrowBuf. + * + * @param writerIndex new writer index + * @return this ArrowBuf + */ + public ArrowBuf writerIndex(long writerIndex) { + this.writerIndex = writerIndex; + return this; + } + + /** + * Zero-out the bytes in this ArrowBuf starting at the given index for the given length. + * + * @param index index index (0 based relative to the portion of memory this ArrowBuf has access + * to) + * @param length length of bytes to zero-out + * @return this ArrowBuf + */ + public ArrowBuf setZero(long index, long length) { + if (length != 0) { + this.checkIndex(index, length); + MemoryUtil.setMemory(this.addr + index, length, (byte) 0); + } + return this; + } + + /** + * Sets all bits to one in the specified range. + * + * @param index index index (0 based relative to the portion of memory this ArrowBuf has access + * to) + * @param length length of bytes to set. + * @return this ArrowBuf + * @deprecated use {@link ArrowBuf#setOne(long, long)} instead. + */ + @Deprecated + public ArrowBuf setOne(int index, int length) { + if (length != 0) { + this.checkIndex(index, length); + MemoryUtil.setMemory(this.addr + index, length, (byte) 0xff); + } + return this; + } + + /** + * Sets all bits to one in the specified range. 
+ * + * @param index index index (0 based relative to the portion of memory this ArrowBuf has access + * to) + * @param length length of bytes to set. + * @return this ArrowBuf + */ + public ArrowBuf setOne(long index, long length) { + if (length != 0) { + this.checkIndex(index, length); + MemoryUtil.setMemory(this.addr + index, length, (byte) 0xff); + } + return this; + } + + /** + * Returns this if size is less than {@link #capacity()}, otherwise delegates to + * {@link BufferManager#replace(ArrowBuf, long)} to get a new buffer. + */ + public ArrowBuf reallocIfNeeded(final long size) { + Preconditions.checkArgument(size >= 0, "reallocation size must be non-negative"); + if (this.capacity() >= size) { + return this; + } + if (bufferManager != null) { + return bufferManager.replace(this, size); + } else { + throw new UnsupportedOperationException( + "Realloc is only available in the context of operator's UDFs"); + } + } + + public ArrowBuf clear() { + this.readerIndex = this.writerIndex = 0; + return this; + } +} diff --git a/src/main/java/org/apache/arrow/memory/DatabricksAllocationReservation.java b/src/main/java/org/apache/arrow/memory/DatabricksAllocationReservation.java new file mode 100644 index 000000000..353d45e55 --- /dev/null +++ b/src/main/java/org/apache/arrow/memory/DatabricksAllocationReservation.java @@ -0,0 +1,104 @@ +package org.apache.arrow.memory; + +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; + +/** An AllocationReservation implementation for cumulative allocation requests. 
*/ +class DatabricksAllocationReservation implements AllocationReservation { + + private final DatabricksBufferAllocator allocator; + private final AtomicLong reservedSize = new AtomicLong(0); + private final AtomicBoolean used = new AtomicBoolean(false); + private final AtomicBoolean closed = new AtomicBoolean(false); + + public DatabricksAllocationReservation(DatabricksBufferAllocator allocator) { + this.allocator = allocator; + } + + @SuppressWarnings("removal") + @Override + @Deprecated + public boolean add(int nBytes) { + return add((long) nBytes); + } + + @Override + public boolean add(long nBytes) { + assertNotUsed(); + if (nBytes < 0) { + return false; + } + reservedSize.addAndGet(nBytes); + return true; + } + + @SuppressWarnings("removal") + @Override + @Deprecated + public boolean reserve(int nBytes) { + return reserve((long) nBytes); + } + + @Override + public boolean reserve(long nBytes) { + assertNotUsed(); + if (nBytes < 0) { + return false; + } + // Check if reservation would exceed limits + long currentReservation = reservedSize.get(); + long newReservation = currentReservation + nBytes; + if (newReservation > allocator.getHeadroom() + currentReservation) { + return false; + } + reservedSize.addAndGet(nBytes); + return true; + } + + @Override + public ArrowBuf allocateBuffer() { + assertNotUsed(); + if (!used.compareAndSet(false, true)) { + throw new IllegalStateException("Reservation already used"); + } + long size = reservedSize.get(); + if (size == 0) { + return allocator.getEmpty(); + } + return allocator.buffer(size); + } + + @Override + public int getSize() { + return (int) Math.min(reservedSize.get(), Integer.MAX_VALUE); + } + + @Override + public long getSizeLong() { + return reservedSize.get(); + } + + @Override + public boolean isUsed() { + return used.get(); + } + + @Override + public boolean isClosed() { + return closed.get(); + } + + @Override + public void close() { + closed.set(true); + } + + private void assertNotUsed() { + if 
(used.get()) { + throw new IllegalStateException("Reservation already used"); + } + if (closed.get()) { + throw new IllegalStateException("Reservation is closed"); + } + } +} diff --git a/src/main/java/org/apache/arrow/memory/DatabricksArrowBuf.java b/src/main/java/org/apache/arrow/memory/DatabricksArrowBuf.java new file mode 100644 index 000000000..a03bc7a07 --- /dev/null +++ b/src/main/java/org/apache/arrow/memory/DatabricksArrowBuf.java @@ -0,0 +1,736 @@ +package org.apache.arrow.memory; + +import static org.apache.arrow.memory.util.LargeMemoryUtil.checkedCastToInt; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicLong; +import org.apache.arrow.memory.util.CommonUtil; +import org.apache.arrow.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A ByteBuffer-backed implementation of ArrowBuf that does not use unsafe memory operations. This + * implementation uses standard java.nio.ByteBuffer for all memory operations instead of + * MemoryUtil/Unsafe-based direct memory access. + */ +public class DatabricksArrowBuf extends ArrowBuf { + + private static final Logger logger = LoggerFactory.getLogger(DatabricksArrowBuf.class); + + /** Generate unique id for each buffer. Helpful in tracing logs. 
*/ + private static final AtomicLong bufferId = new AtomicLong(0); + + private static final int SHORT_SIZE = Short.BYTES; + private static final int INT_SIZE = Integer.BYTES; + private static final int FLOAT_SIZE = Float.BYTES; + private static final int DOUBLE_SIZE = Double.BYTES; + private static final int LONG_SIZE = Long.BYTES; + private static final int LOG_BYTES_PER_ROW = 10; + + private final ByteBuffer byteBuffer; + private final ReferenceManager referenceManager; + private final BufferManager bufferManager; + private final int offset; // offset within the underlying ByteBuffer for sliced buffers + private volatile long capacity; + private long readerIndex; + private long writerIndex; + + /** Memory address used to instantiate the super class {@code ArrowBuf}. Unused in this class. */ + private static final int MEMORY_ADDRESS = 0; + + /** ArrowBuf uses native order, copying the same logic here. */ + private static final ByteOrder BYTE_ORDER = ByteOrder.nativeOrder(); + + /** + * Constructs a new DatabricksArrowBuf backed by a heap ByteBuffer. 
+ * + * @param referenceManager The memory manager to track memory usage and reference count + * @param bufferManager The buffer manager for reallocation support + * @param capacity The capacity in bytes of this buffer + */ + public DatabricksArrowBuf( + ReferenceManager referenceManager, BufferManager bufferManager, long capacity) { + super(referenceManager, bufferManager, capacity, MEMORY_ADDRESS); + + this.referenceManager = referenceManager; + this.bufferManager = bufferManager; + this.capacity = capacity; + this.offset = 0; + this.readerIndex = 0; + this.writerIndex = 0; + + if (capacity > Integer.MAX_VALUE) { + throw new IllegalArgumentException( + "DatabricksArrowBuf does not support capacity > Integer.MAX_VALUE"); + } + + this.byteBuffer = ByteBuffer.allocate((int) capacity); + this.byteBuffer.order(BYTE_ORDER); + } + + /** + * Constructor for creating sliced views or derived buffers that share an underlying ByteBuffer. + * + * @param referenceManager The memory manager + * @param bufferManager The buffer manager + * @param byteBuffer The underlying ByteBuffer (shared with parent) + * @param offset The offset within the ByteBuffer + * @param capacity The capacity in bytes of this buffer + */ + DatabricksArrowBuf( + ReferenceManager referenceManager, + BufferManager bufferManager, + ByteBuffer byteBuffer, + int offset, + long capacity) { + super(referenceManager, bufferManager, capacity, MEMORY_ADDRESS); + + this.referenceManager = referenceManager; + this.bufferManager = bufferManager; + this.byteBuffer = byteBuffer; + this.offset = offset; + this.capacity = capacity; + this.readerIndex = 0; + this.writerIndex = 0; + } + + @Override + public int refCnt() { + return referenceManager.getRefCount(); + } + + @Override + public void checkBytes(long start, long end) { + if (BoundsChecking.BOUNDS_CHECKING_ENABLED) { + checkIndexD(start, end - start); + } + } + + @Override + public ReferenceManager getReferenceManager() { + return referenceManager; + } + + 
@Override + public long capacity() { + return capacity; + } + + @Override + public synchronized ArrowBuf capacity(long newCapacity) { + if (newCapacity == capacity) { + return this; + } + + Preconditions.checkArgument(newCapacity >= 0); + + if (newCapacity >= capacity) { + throw new UnsupportedOperationException( + "Buffers don't support resizing that increases the size."); + } + + this.capacity = newCapacity; + return this; + } + + @Override + public ByteOrder order() { + return BYTE_ORDER; + } + + @Override + public long readableBytes() { + Preconditions.checkState( + writerIndex >= readerIndex, "Writer index cannot be less than reader index"); + return writerIndex - readerIndex; + } + + @Override + public long writableBytes() { + return capacity() - writerIndex; + } + + @Override + public ArrowBuf slice() { + return slice(readerIndex, readableBytes()); + } + + @Override + public ArrowBuf slice(long index, long length) { + Preconditions.checkPositionIndex(index, this.capacity); + Preconditions.checkPositionIndex(index + length, this.capacity); + + // Delegate to reference manager's deriveBuffer to ensure consistent behavior + // with reference counting semantics (derived buffers share ref count with parent) + final ArrowBuf newBuf = referenceManager.deriveBuffer(this, index, length); + newBuf.writerIndex(length); + return newBuf; + } + + @Override + public ByteBuffer nioBuffer() { + return nioBuffer(readerIndex, checkedCastToInt(readableBytes())); + } + + @Override + public ByteBuffer nioBuffer(long index, int length) { + chk(index, length); + ByteBuffer duplicate = byteBuffer.duplicate(); + duplicate.order(BYTE_ORDER); + duplicate.position(offset + (int) index); + duplicate.limit(offset + (int) index + length); + return duplicate; + } + + @Override + public long memoryAddress() { + return MEMORY_ADDRESS; + } + + @Override + public String toString() { + return String.format( + "DatabricksArrowBuf id:%d capacity:%d, offset:%d", getId(), capacity, offset); + } + + 
@Override + public int hashCode() { + return System.identityHashCode(this); + } + + @Override + public boolean equals(Object obj) { + return this == obj; + } + + private int bufferIndex(long index) { + return offset + (int) index; + } + + private void chk(long index, long length) { + if (BoundsChecking.BOUNDS_CHECKING_ENABLED) { + checkIndexD(index, length); + } + } + + private void checkIndexD(long index, long fieldLength) { + Preconditions.checkArgument(fieldLength >= 0, "expecting non-negative data length"); + if (index < 0 || index > capacity() - fieldLength) { + throw new IndexOutOfBoundsException( + String.format( + "index: %d, length: %d (expected: range(0, %d))", index, fieldLength, capacity())); + } + } + + // --- Primitive get/set operations using ByteBuffer --- + + @Override + public long getLong(long index) { + chk(index, LONG_SIZE); + return byteBuffer.getLong(bufferIndex(index)); + } + + @Override + public void setLong(long index, long value) { + chk(index, LONG_SIZE); + byteBuffer.putLong(bufferIndex(index), value); + } + + @Override + public float getFloat(long index) { + chk(index, FLOAT_SIZE); + return byteBuffer.getFloat(bufferIndex(index)); + } + + @Override + public void setFloat(long index, float value) { + chk(index, FLOAT_SIZE); + byteBuffer.putFloat(bufferIndex(index), value); + } + + @Override + public double getDouble(long index) { + chk(index, DOUBLE_SIZE); + return byteBuffer.getDouble(bufferIndex(index)); + } + + @Override + public void setDouble(long index, double value) { + chk(index, DOUBLE_SIZE); + byteBuffer.putDouble(bufferIndex(index), value); + } + + @Override + public char getChar(long index) { + chk(index, SHORT_SIZE); + return byteBuffer.getChar(bufferIndex(index)); + } + + @Override + public void setChar(long index, int value) { + chk(index, SHORT_SIZE); + byteBuffer.putChar(bufferIndex(index), (char) value); + } + + @Override + public int getInt(long index) { + chk(index, INT_SIZE); + return 
byteBuffer.getInt(bufferIndex(index)); + } + + @Override + public void setInt(long index, int value) { + chk(index, INT_SIZE); + byteBuffer.putInt(bufferIndex(index), value); + } + + @Override + public short getShort(long index) { + chk(index, SHORT_SIZE); + return byteBuffer.getShort(bufferIndex(index)); + } + + @Override + public void setShort(long index, int value) { + setShort(index, (short) value); + } + + @Override + public void setShort(long index, short value) { + chk(index, SHORT_SIZE); + byteBuffer.putShort(bufferIndex(index), value); + } + + @Override + public void setByte(long index, int value) { + chk(index, 1); + byteBuffer.put(bufferIndex(index), (byte) value); + } + + @Override + public void setByte(long index, byte value) { + chk(index, 1); + byteBuffer.put(bufferIndex(index), value); + } + + @Override + public byte getByte(long index) { + chk(index, 1); + return byteBuffer.get(bufferIndex(index)); + } + + // --- Writer index based operations --- + + private void ensureWritable(final int length) { + if (BoundsChecking.BOUNDS_CHECKING_ENABLED) { + Preconditions.checkArgument(length >= 0, "expecting non-negative length"); + if (length > writableBytes()) { + throw new IndexOutOfBoundsException( + String.format( + "writerIndex(%d) + length(%d) exceeds capacity(%d)", + writerIndex, length, capacity())); + } + } + } + + private void ensureReadable(final int length) { + if (BoundsChecking.BOUNDS_CHECKING_ENABLED) { + Preconditions.checkArgument(length >= 0, "expecting non-negative length"); + if (length > readableBytes()) { + throw new IndexOutOfBoundsException( + String.format( + "readerIndex(%d) + length(%d) exceeds writerIndex(%d)", + readerIndex, length, writerIndex)); + } + } + } + + @Override + public byte readByte() { + ensureReadable(1); + final byte b = getByte(readerIndex); + ++readerIndex; + return b; + } + + @Override + public void readBytes(byte[] dst) { + Preconditions.checkArgument(dst != null, "expecting valid dst bytearray"); + 
ensureReadable(dst.length); + getBytes(readerIndex, dst, 0, dst.length); + readerIndex += dst.length; + } + + @Override + public void writeByte(byte value) { + ensureWritable(1); + byteBuffer.put(bufferIndex(writerIndex), value); + ++writerIndex; + } + + @Override + public void writeByte(int value) { + ensureWritable(1); + byteBuffer.put(bufferIndex(writerIndex), (byte) value); + ++writerIndex; + } + + @Override + public void writeBytes(byte[] src) { + Preconditions.checkArgument(src != null, "expecting valid src array"); + writeBytes(src, 0, src.length); + } + + @Override + public void writeBytes(byte[] src, int srcIndex, int length) { + ensureWritable(length); + setBytes(writerIndex, src, srcIndex, length); + writerIndex += length; + } + + @Override + public void writeShort(int value) { + ensureWritable(SHORT_SIZE); + byteBuffer.putShort(bufferIndex(writerIndex), (short) value); + writerIndex += SHORT_SIZE; + } + + @Override + public void writeInt(int value) { + ensureWritable(INT_SIZE); + byteBuffer.putInt(bufferIndex(writerIndex), value); + writerIndex += INT_SIZE; + } + + @Override + public void writeLong(long value) { + ensureWritable(LONG_SIZE); + byteBuffer.putLong(bufferIndex(writerIndex), value); + writerIndex += LONG_SIZE; + } + + @Override + public void writeFloat(float value) { + ensureWritable(FLOAT_SIZE); + byteBuffer.putFloat(bufferIndex(writerIndex), value); + writerIndex += FLOAT_SIZE; + } + + @Override + public void writeDouble(double value) { + ensureWritable(DOUBLE_SIZE); + byteBuffer.putDouble(bufferIndex(writerIndex), value); + writerIndex += DOUBLE_SIZE; + } + + // --- Bulk byte array operations --- + + private static boolean isOutOfBounds(long index, long length, long capacity) { + return (index | length | (index + length) | (capacity - (index + length))) < 0; + } + + private void checkIndex(long index, long fieldLength) { + if (BoundsChecking.BOUNDS_CHECKING_ENABLED) { + if (isOutOfBounds(index, fieldLength, this.capacity())) { + throw new 
IndexOutOfBoundsException( + String.format( + "index: %d, length: %d (expected: range(0, %d))", + index, fieldLength, this.capacity())); + } + } + } + + @Override + public void getBytes(long index, byte[] dst) { + getBytes(index, dst, 0, dst.length); + } + + @Override + public void getBytes(long index, byte[] dst, int dstIndex, int length) { + checkIndex(index, length); + Preconditions.checkArgument(dst != null, "expecting a valid dst byte array"); + if (isOutOfBounds(dstIndex, length, dst.length)) { + throw new IndexOutOfBoundsException( + "Not enough space to copy data into destination" + dstIndex); + } + if (length != 0) { + // Use absolute positioning to avoid affecting buffer state + ByteBuffer duplicate = byteBuffer.duplicate(); + duplicate.position(bufferIndex(index)); + duplicate.get(dst, dstIndex, length); + } + } + + @Override + public void setBytes(long index, byte[] src) { + setBytes(index, src, 0, src.length); + } + + @Override + public void setBytes(long index, byte[] src, int srcIndex, long length) { + checkIndex(index, length); + Preconditions.checkArgument(src != null, "expecting a valid src byte array"); + if (isOutOfBounds(srcIndex, length, src.length)) { + throw new IndexOutOfBoundsException( + "Not enough space to copy data from byte array" + srcIndex); + } + if (length > 0) { + ByteBuffer duplicate = byteBuffer.duplicate(); + duplicate.position(bufferIndex(index)); + duplicate.put(src, srcIndex, (int) length); + } + } + + @Override + public void getBytes(long index, ByteBuffer dst) { + checkIndex(index, dst.remaining()); + if (dst.remaining() != 0) { + int length = dst.remaining(); + ByteBuffer duplicate = byteBuffer.duplicate(); + duplicate.position(bufferIndex(index)); + duplicate.limit(bufferIndex(index) + length); + dst.put(duplicate); + } + } + + @Override + public void setBytes(long index, ByteBuffer src) { + checkIndex(index, src.remaining()); + int length = src.remaining(); + if (length != 0) { + ByteBuffer duplicate = 
byteBuffer.duplicate(); + duplicate.position(bufferIndex(index)); + duplicate.put(src); + } + } + + @Override + public void setBytes(long index, ByteBuffer src, int srcIndex, int length) { + checkIndex(index, length); + if (length != 0) { + ByteBuffer srcDuplicate = src.duplicate(); + srcDuplicate.position(srcIndex); + srcDuplicate.limit(srcIndex + length); + + ByteBuffer duplicate = byteBuffer.duplicate(); + duplicate.position(bufferIndex(index)); + duplicate.put(srcDuplicate); + } + } + + @Override + public void getBytes(long index, ArrowBuf dst, long dstIndex, int length) { + checkIndex(index, length); + Preconditions.checkArgument(dst != null, "expecting a valid ArrowBuf"); + checkBufferType(dst); + if (isOutOfBounds(dstIndex, length, dst.capacity())) { + throw new IndexOutOfBoundsException( + String.format( + "index: %d, length: %d (expected: range(0, %d))", dstIndex, length, dst.capacity())); + } + if (length != 0) { + byte[] tmp = new byte[length]; + getBytes(index, tmp, 0, length); + dst.setBytes(dstIndex, tmp, 0, length); + } + } + + @Override + public void setBytes(long index, ArrowBuf src, long srcIndex, long length) { + checkIndex(index, length); + Preconditions.checkArgument(src != null, "expecting a valid ArrowBuf"); + checkBufferType(src); + if (isOutOfBounds(srcIndex, length, src.capacity())) { + throw new IndexOutOfBoundsException( + String.format( + "index: %d, length: %d (expected: range(0, %d))", srcIndex, length, src.capacity())); + } + if (length != 0) { + byte[] tmp = new byte[(int) length]; + src.getBytes(srcIndex, tmp, 0, (int) length); + setBytes(index, tmp, 0, length); + } + } + + @Override + public void setBytes(long index, ArrowBuf src) { + Preconditions.checkArgument(src != null, "expecting valid ArrowBuf"); + checkBufferType(src); + + final long length = src.readableBytes(); + checkIndex(index, length); + byte[] tmp = new byte[(int) length]; + src.getBytes(src.readerIndex(), tmp, 0, (int) length); + setBytes(index, tmp, 0, length); + 
src.readerIndex(src.readerIndex() + length); + } + + @Override + public int setBytes(long index, InputStream in, int length) throws IOException { + Preconditions.checkArgument(in != null, "expecting valid input stream"); + checkIndex(index, length); + int readBytes = 0; + if (length > 0) { + byte[] tmp = new byte[length]; + readBytes = in.read(tmp); + if (readBytes > 0) { + setBytes(index, tmp, 0, readBytes); + } + } + return readBytes; + } + + @Override + public void getBytes(long index, OutputStream out, int length) throws IOException { + Preconditions.checkArgument(out != null, "expecting valid output stream"); + checkIndex(index, length); + if (length > 0) { + byte[] tmp = new byte[length]; + getBytes(index, tmp, 0, length); + out.write(tmp); + } + } + + @Override + public void close() { + referenceManager.release(); + } + + @Override + public long getPossibleMemoryConsumed() { + return referenceManager.getSize(); + } + + @Override + public long getActualMemoryConsumed() { + return referenceManager.getAccountedSize(); + } + + @Override + public String toHexString(final long start, final int length) { + final long roundedStart = (start / LOG_BYTES_PER_ROW) * LOG_BYTES_PER_ROW; + + final StringBuilder sb = new StringBuilder("buffer byte dump\n"); + long index = roundedStart; + for (long nLogged = 0; nLogged < length; nLogged += LOG_BYTES_PER_ROW) { + sb.append(String.format(" [%05d-%05d]", index, index + LOG_BYTES_PER_ROW - 1)); + for (int i = 0; i < LOG_BYTES_PER_ROW; ++i) { + try { + final byte b = getByte(index++); + sb.append(String.format(" 0x%02x", b)); + } catch (IndexOutOfBoundsException ioob) { + sb.append(" "); + } + } + sb.append('\n'); + } + return sb.toString(); + } + + @Override + public void print(StringBuilder sb, int indent) { + CommonUtil.indent(sb, indent).append(this); + ; + } + + @Override + public long readerIndex() { + return readerIndex; + } + + @Override + public long writerIndex() { + return writerIndex; + } + + @Override + public 
ArrowBuf readerIndex(long readerIndex) { + this.readerIndex = readerIndex; + return this; + } + + @Override + public ArrowBuf writerIndex(long writerIndex) { + this.writerIndex = writerIndex; + return this; + } + + @Override + public ArrowBuf setZero(long index, long length) { + if (length != 0) { + this.checkIndex(index, length); + // Fill with zeros using Arrays.fill on the backing array + int startIdx = bufferIndex(index); + int endIdx = startIdx + (int) length; + Arrays.fill(byteBuffer.array(), startIdx, endIdx, (byte) 0); + } + return this; + } + + @Override + @Deprecated + public ArrowBuf setOne(int index, int length) { + return setOne((long) index, (long) length); + } + + @Override + public ArrowBuf setOne(long index, long length) { + if (length != 0) { + this.checkIndex(index, length); + int startIdx = bufferIndex(index); + int endIdx = startIdx + (int) length; + Arrays.fill(byteBuffer.array(), startIdx, endIdx, (byte) 0xff); + } + return this; + } + + @Override + public ArrowBuf reallocIfNeeded(final long size) { + Preconditions.checkArgument(size >= 0, "reallocation size must be non-negative"); + if (this.capacity() >= size) { + return this; + } + if (bufferManager != null) { + return bufferManager.replace(this, size); + } else { + throw new UnsupportedOperationException( + "Realloc is only available in the context of operator's UDFs"); + } + } + + @Override + public ArrowBuf clear() { + logger.debug("Clearing buffer from {}", this); + this.readerIndex = this.writerIndex = 0; + return this; + } + + /** + * Returns the offset within the underlying ByteBuffer where this buffer's data starts. This is + * used for sliced buffers that share the same underlying ByteBuffer. + * + * @return the offset in bytes + */ + public int getOffset() { + return offset; + } + + /** + * @return the underlying ByteBuffer. 
+ */ + ByteBuffer getByteBuffer() { + return byteBuffer; + } + + private void checkBufferType(ArrowBuf buffer) { + if (!(buffer instanceof DatabricksArrowBuf)) { + throw new IllegalArgumentException("Buffer should be an instance of DatabricksArrowBuf"); + } + } +} diff --git a/src/main/java/org/apache/arrow/memory/DatabricksBufferAllocator.java b/src/main/java/org/apache/arrow/memory/DatabricksBufferAllocator.java new file mode 100644 index 000000000..43a10bce1 --- /dev/null +++ b/src/main/java/org/apache/arrow/memory/DatabricksBufferAllocator.java @@ -0,0 +1,236 @@ +package org.apache.arrow.memory; + +import java.util.Collection; +import java.util.Collections; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.arrow.memory.rounding.DefaultRoundingPolicy; +import org.apache.arrow.memory.rounding.RoundingPolicy; +import org.apache.arrow.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A BufferAllocator implementation that uses DatabricksArrowBuf for memory allocation. This + * allocator uses heap-based ByteBuffer storage instead of direct/off-heap memory, avoiding the need + * for sun.misc.Unsafe operations. + * + *

This implementation is suitable for environments where direct memory access is restricted or + * where heap-based memory management is preferred. + */ +public class DatabricksBufferAllocator implements BufferAllocator { + private static final Logger logger = LoggerFactory.getLogger(DatabricksBufferAllocator.class); + + private final String name; + private final AtomicBoolean closed = new AtomicBoolean(false); + private final DatabricksBufferAllocator parent; + private final Set children = ConcurrentHashMap.newKeySet(); + + // Empty buffer singleton + private final ArrowBuf emptyBuffer; + + /** Creates a root allocator with default settings. */ + public DatabricksBufferAllocator() { + this("DatabricksBufferAllocator"); + } + + /** + * Creates a root allocator with specified limit. + * + * @param name the allocator name + */ + public DatabricksBufferAllocator(String name) { + this(name, null); + } + + /** + * Creates an allocator with full configuration. + * + * @param name the allocator name + * @param parent the parent allocator (null for root) + */ + public DatabricksBufferAllocator(String name, DatabricksBufferAllocator parent) { + this.name = name; + this.parent = parent; + + // Create an empty buffer with a no-op reference manager + this.emptyBuffer = new DatabricksArrowBuf(DatabricksReferenceManagerNOOP.INSTANCE, null, 0); + } + + @Override + public ArrowBuf buffer(long size) { + return buffer(size, null); + } + + @Override + public ArrowBuf buffer(long size, BufferManager manager) { + assertOpen(); + Preconditions.checkArgument(size >= 0, "Buffer size must be non-negative"); + + if (size == 0) { + return getEmpty(); + } + + logger.debug("Allocating buffer of size {}", size); + + // Create the reference manager and buffer + DatabricksReferenceManager refManager = new DatabricksReferenceManager(this, size); + return new DatabricksArrowBuf(refManager, manager, size); + } + + @Override + public BufferAllocator getRoot() { + if (parent == null) { + return this; + 
} + return parent.getRoot(); + } + + @Override + public BufferAllocator newChildAllocator(String name, long initReservation, long maxAllocation) { + return newChildAllocator(name, AllocationListener.NOOP, initReservation, maxAllocation); + } + + @Override + public BufferAllocator newChildAllocator( + String name, AllocationListener listener, long initReservation, long maxAllocation) { + assertOpen(); + + DatabricksBufferAllocator child = new DatabricksBufferAllocator(name, this); + + children.add(child); + + return child; + } + + @Override + public void close() { + if (!closed.compareAndSet(false, true)) { + return; + } + + // Close all children first + for (DatabricksBufferAllocator child : children) { + child.close(); + } + children.clear(); + + // Remove from parent's children list + if (parent != null) { + parent.children.remove(this); + } + } + + @Override + public long getAllocatedMemory() { + return 0; + } + + @Override + public long getLimit() { + return Integer.MAX_VALUE; + } + + @Override + public long getInitReservation() { + return 0; + } + + @Override + public void setLimit(long newLimit) { + // Do nothing. + } + + @Override + public long getPeakMemoryAllocation() { + // Do nothing. + return 0; + } + + @Override + public long getHeadroom() { + return Integer.MAX_VALUE; + } + + @Override + public boolean forceAllocate(long size) { + if (parent != null) { + parent.forceAllocate(size); + } + return true; + } + + @Override + public void releaseBytes(long size) { + // Do nothing. 
+ } + + @Override + public AllocationListener getListener() { + return AllocationListener.NOOP; + } + + @Override + public BufferAllocator getParentAllocator() { + return parent; + } + + @Override + public Collection getChildAllocators() { + return Collections.unmodifiableSet(children); + } + + @Override + public AllocationReservation newReservation() { + assertOpen(); + return new DatabricksAllocationReservation(this); + } + + @Override + public ArrowBuf getEmpty() { + return emptyBuffer; + } + + @Override + public String getName() { + return name; + } + + @Override + public boolean isOverLimit() { + // Never over limit. + return false; + } + + @Override + public String toVerboseString() { + StringBuilder sb = new StringBuilder(); + sb.append("Allocator(").append(name).append(") "); + if (!children.isEmpty()) { + sb.append("\n Children:\n"); + for (DatabricksBufferAllocator child : children) { + sb.append(" ").append(child.toVerboseString().replace("\n", "\n ")).append("\n"); + } + } + return sb.toString(); + } + + @Override + public void assertOpen() { + if (closed.get()) { + throw new IllegalStateException("Allocator " + name + " is closed"); + } + } + + @Override + public RoundingPolicy getRoundingPolicy() { + return DefaultRoundingPolicy.DEFAULT_ROUNDING_POLICY; + } + + @Override + public ArrowBuf wrapForeignAllocation(ForeignAllocation allocation) { + throw new UnsupportedOperationException( + "DatabricksBufferAllocator does not support foreign allocations"); + } +} diff --git a/src/main/java/org/apache/arrow/memory/DatabricksReferenceManager.java b/src/main/java/org/apache/arrow/memory/DatabricksReferenceManager.java new file mode 100644 index 000000000..3aac9d670 --- /dev/null +++ b/src/main/java/org/apache/arrow/memory/DatabricksReferenceManager.java @@ -0,0 +1,138 @@ +package org.apache.arrow.memory; + +import org.apache.arrow.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A Databricks reference manager which 
acts as a no-op and does not reference count. All data is + * allocated on the heap and taken care of by the JVM garbage collector. + */ +class DatabricksReferenceManager implements ReferenceManager { + private static final Logger logger = LoggerFactory.getLogger(DatabricksReferenceManager.class); + + /** Allocator of this reference manager. */ + private final DatabricksBufferAllocator allocator; + + /** Size of this reference. */ + private final long size; + + /** The memory is heap allocated and taken care of by the JVM. Assuming value of one is safe. */ + private static final int REF_COUNT = 1; + + public DatabricksReferenceManager(DatabricksBufferAllocator allocator, long size) { + this.allocator = allocator; + this.size = size; + } + + @Override + public int getRefCount() { + return REF_COUNT; + } + + @Override + public boolean release() { + return release(1); + } + + @Override + public boolean release(int decrement) { + return getRefCount() == 0; + } + + @Override + public void retain() { + retain(1); + } + + @Override + public void retain(int increment) { + // Do nothing. + } + + @Override + public ArrowBuf retain(ArrowBuf srcBuffer, BufferAllocator targetAllocator) { + DatabricksArrowBuf buf = checkBufferType(srcBuffer); + return deriveBuffer(buf); + } + + private ArrowBuf deriveBuffer(DatabricksArrowBuf srcBuffer) { + return deriveBuffer(srcBuffer, 0, srcBuffer.capacity()); + } + + @Override + public ArrowBuf deriveBuffer(ArrowBuf sourceBuffer, long index, long length) { + Preconditions.checkArgument( + length <= Integer.MAX_VALUE, + "Length %s should be less than or equal to %s", + length, + Integer.MAX_VALUE); + + Preconditions.checkArgument( + index + length <= sourceBuffer.capacity(), + "Index=" + + index + + " and length=" + + length + + " exceeds source buffer capacity=" + + sourceBuffer.capacity()); + + // Create a new DatabricksArrowBuf sharing the same byte buffer. 
+ DatabricksArrowBuf buf = checkBufferType(sourceBuffer); + + logger.debug("Deriving buffer at index {} and length {} from buffer {}", index, length, buf); + + return new DatabricksArrowBuf( + this, null, buf.getByteBuffer(), buf.getOffset() + (int) index, length); + } + + @Override + public OwnershipTransferResult transferOwnership( + ArrowBuf sourceBuffer, BufferAllocator targetAllocator) { + DatabricksArrowBuf buf = checkBufferType(sourceBuffer); + checkAllocatorType(targetAllocator); + + final ArrowBuf newBuf = deriveBuffer(buf); + return new OwnershipTransferResult() { + @Override + public boolean getAllocationFit() { + return true; + } + + @Override + public ArrowBuf getTransferredBuffer() { + return newBuf; + } + }; + } + + @Override + public BufferAllocator getAllocator() { + return allocator; + } + + @Override + public long getSize() { + return size; + } + + @Override + public long getAccountedSize() { + return size; + } + + private DatabricksArrowBuf checkBufferType(ArrowBuf buffer) { + if (!(buffer instanceof DatabricksArrowBuf)) { + throw new IllegalArgumentException("Buffer should be an instance of DatabricksArrowBuf"); + } + return (DatabricksArrowBuf) buffer; + } + + private DatabricksBufferAllocator checkAllocatorType(BufferAllocator bufferAllocator) { + if (!(bufferAllocator instanceof DatabricksBufferAllocator)) { + throw new IllegalArgumentException( + "Allocator should be an instance of DatabricksBufferAllocator"); + } + return (DatabricksBufferAllocator) bufferAllocator; + } +} diff --git a/src/main/java/org/apache/arrow/memory/DatabricksReferenceManagerNOOP.java b/src/main/java/org/apache/arrow/memory/DatabricksReferenceManagerNOOP.java new file mode 100644 index 000000000..5a1aaf58b --- /dev/null +++ b/src/main/java/org/apache/arrow/memory/DatabricksReferenceManagerNOOP.java @@ -0,0 +1,63 @@ +package org.apache.arrow.memory; + +/** + * A Databricks specific no-op ReferenceManager that returns a DatabricksBufferAllocator + * . 
+ */ +class DatabricksReferenceManagerNOOP implements ReferenceManager { + public static final DatabricksReferenceManagerNOOP INSTANCE = new DatabricksReferenceManagerNOOP(); + + private DatabricksReferenceManagerNOOP() {} + + @Override + public int getRefCount() { + return 1; + } + + @Override + public boolean release() { + return false; + } + + @Override + public boolean release(int decrement) { + return false; + } + + @Override + public void retain() {} + + @Override + public void retain(int increment) {} + + @Override + public ArrowBuf retain(ArrowBuf srcBuffer, BufferAllocator targetAllocator) { + return srcBuffer; + } + + @Override + public ArrowBuf deriveBuffer(ArrowBuf sourceBuffer, long index, long length) { + return sourceBuffer; + } + + @Override + public OwnershipTransferResult transferOwnership( + ArrowBuf sourceBuffer, BufferAllocator targetAllocator) { + return new OwnershipTransferNOOP(sourceBuffer); + } + + @Override + public BufferAllocator getAllocator() { + return new DatabricksBufferAllocator(); + } + + @Override + public long getSize() { + return 0L; + } + + @Override + public long getAccountedSize() { + return 0L; + } +} diff --git a/src/main/java/org/apache/arrow/memory/util/MemoryUtil.java b/src/main/java/org/apache/arrow/memory/util/MemoryUtil.java new file mode 100644 index 000000000..3855bc62f --- /dev/null +++ b/src/main/java/org/apache/arrow/memory/util/MemoryUtil.java @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ------------------------------------------------------------------------- +// MODIFICATION NOTICE: +// This file was modified by Databricks, Inc on 16-December-2025. +// Description of changes: Patched static initializer to not printStackTrace on failure. +// ------------------------------------------------------------------------- + +package org.apache.arrow.memory.util; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.security.AccessController; +import java.security.PrivilegedAction; +import org.checkerframework.checker.nullness.qual.Nullable; +import sun.misc.Unsafe; + +/** Utilities for memory related operations. */ +public class MemoryUtil { + private static final org.slf4j.Logger logger = + org.slf4j.LoggerFactory.getLogger(MemoryUtil.class); + + private static final @Nullable Constructor DIRECT_BUFFER_CONSTRUCTOR; + + /** The unsafe object from which to access the off-heap memory. */ + private static final Unsafe UNSAFE; + + /** The start offset of array data relative to the start address of the array object. */ + private static final long BYTE_ARRAY_BASE_OFFSET; + + /** The offset of the address field with the {@link ByteBuffer} object. */ + private static final long BYTE_BUFFER_ADDRESS_OFFSET; + + /** If the native byte order is little-endian. 
*/ + public static final boolean LITTLE_ENDIAN = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; + + // Java 1.8, 9, 11, 17, 21 becomes 1, 9, 11, 17, and 21. + @SuppressWarnings("StringSplitter") + private static final int majorVersion = + Integer.parseInt(System.getProperty("java.specification.version").split("\\D+")[0]); + + static { + try { + // try to get the unsafe object + final Object maybeUnsafe = + AccessController.doPrivileged( + new PrivilegedAction() { + @Override + @SuppressWarnings({"nullness:argument", "nullness:return"}) + // incompatible argument for parameter obj of Field.get + // incompatible types in return + public Object run() { + try { + final Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); + unsafeField.setAccessible(true); + return unsafeField.get(null); + } catch (Throwable e) { + return e; + } + } + }); + + if (maybeUnsafe instanceof Throwable) { + throw (Throwable) maybeUnsafe; + } + + UNSAFE = (Unsafe) maybeUnsafe; + + // get the offset of the data inside a byte array object + BYTE_ARRAY_BASE_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); + + // get the offset of the address field in a java.nio.Buffer object + Field addressField = java.nio.Buffer.class.getDeclaredField("address"); + addressField.setAccessible(true); + BYTE_BUFFER_ADDRESS_OFFSET = UNSAFE.objectFieldOffset(addressField); + + Constructor directBufferConstructor; + long address = -1; + final ByteBuffer direct = ByteBuffer.allocateDirect(1); + try { + + final Object maybeDirectBufferConstructor = + AccessController.doPrivileged( + new PrivilegedAction() { + @Override + public Object run() { + try { + final Constructor constructor = + (majorVersion >= 21) + ? 
direct.getClass().getDeclaredConstructor(long.class, long.class) + : direct.getClass().getDeclaredConstructor(long.class, int.class); + constructor.setAccessible(true); + logger.debug("Constructor for direct buffer found and made accessible"); + return constructor; + } catch (NoSuchMethodException e) { + logger.debug("Cannot get constructor for direct buffer allocation", e); + return e; + } catch (SecurityException e) { + logger.debug("Cannot get constructor for direct buffer allocation", e); + return e; + } + } + }); + + if (maybeDirectBufferConstructor instanceof Constructor) { + address = UNSAFE.allocateMemory(1); + // try to use the constructor now + try { + ((Constructor) maybeDirectBufferConstructor).newInstance(address, 1); + directBufferConstructor = (Constructor) maybeDirectBufferConstructor; + logger.debug("direct buffer constructor: available"); + } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { + logger.warn("unable to instantiate a direct buffer via constructor", e); + directBufferConstructor = null; + } + } else { + logger.debug( + "direct buffer constructor: unavailable", (Throwable) maybeDirectBufferConstructor); + directBufferConstructor = null; + } + } finally { + if (address != -1) { + UNSAFE.freeMemory(address); + } + } + DIRECT_BUFFER_CONSTRUCTOR = directBufferConstructor; + } catch (Throwable e) { + // This exception will get swallowed, but it's necessary for the static analysis that ensures + // the static fields above get initialized + final RuntimeException failure = + new RuntimeException( + "Failed to initialize MemoryUtil. 
You must start Java with " + + "`--add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED` " + + "(See https://arrow.apache.org/docs/java/install.html)", + e); + // ---- Databricks patch start ---- + // ---- Remove 'failure.printStackTrace();' + // ---- Databricks patch end ---- + throw failure; + } + } + + /** + * Given a {@link ByteBuffer}, gets the address the underlying memory space. + * + * @param buf the byte buffer. + * @return address of the underlying memory. + */ + public static long getByteBufferAddress(ByteBuffer buf) { + return UNSAFE.getLong(buf, BYTE_BUFFER_ADDRESS_OFFSET); + } + + private MemoryUtil() {} + + /** Create nio byte buffer. */ + public static ByteBuffer directBuffer(long address, int capacity) { + if (DIRECT_BUFFER_CONSTRUCTOR != null) { + if (capacity < 0) { + throw new IllegalArgumentException("Capacity is negative, has to be positive or 0"); + } + try { + return (ByteBuffer) DIRECT_BUFFER_CONSTRUCTOR.newInstance(address, capacity); + } catch (Throwable cause) { + throw new Error(cause); + } + } + throw new UnsupportedOperationException( + "sun.misc.Unsafe or java.nio.DirectByteBuffer.(long, int) not available"); + } + + @SuppressWarnings( + "nullness:argument") // to handle null assignment on third party dependency: Unsafe + private static void copyMemory( + @Nullable Object srcBase, + long srcOffset, + @Nullable Object destBase, + long destOffset, + long bytes) { + UNSAFE.copyMemory(srcBase, srcOffset, destBase, destOffset, bytes); + } + + public static void copyMemory(long srcAddress, long destAddress, long bytes) { + UNSAFE.copyMemory(srcAddress, destAddress, bytes); + } + + public static void copyToMemory(byte[] src, long srcIndex, long destAddress, long bytes) { + copyMemory(src, BYTE_ARRAY_BASE_OFFSET + srcIndex, null, destAddress, bytes); + } + + public static void copyFromMemory(long srcAddress, byte[] dest, long destIndex, long bytes) { + copyMemory(null, srcAddress, dest, BYTE_ARRAY_BASE_OFFSET + destIndex, 
bytes); + } + + public static byte getByte(long address) { + return UNSAFE.getByte(address); + } + + public static void putByte(long address, byte value) { + UNSAFE.putByte(address, value); + } + + public static short getShort(long address) { + return UNSAFE.getShort(address); + } + + public static void putShort(long address, short value) { + UNSAFE.putShort(address, value); + } + + public static int getInt(long address) { + return UNSAFE.getInt(address); + } + + public static void putInt(long address, int value) { + UNSAFE.putInt(address, value); + } + + public static long getLong(long address) { + return UNSAFE.getLong(address); + } + + public static void putLong(long address, long value) { + UNSAFE.putLong(address, value); + } + + public static void setMemory(long address, long bytes, byte value) { + UNSAFE.setMemory(address, bytes, value); + } + + public static int getInt(byte[] bytes, int index) { + return UNSAFE.getInt(bytes, BYTE_ARRAY_BASE_OFFSET + index); + } + + public static long getLong(byte[] bytes, int index) { + return UNSAFE.getLong(bytes, BYTE_ARRAY_BASE_OFFSET + index); + } + + public static long allocateMemory(long bytes) { + return UNSAFE.allocateMemory(bytes); + } + + public static void freeMemory(long address) { + UNSAFE.freeMemory(address); + } +} diff --git a/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java b/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java new file mode 100644 index 000000000..ae76865c6 --- /dev/null +++ b/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ------------------------------------------------------------------------- +// MODIFICATION NOTICE: +// This file was modified by Databricks, Inc on 16-December-2025. +// Description of changes: Patched method writeLongToArrowBuf to handle DatabricksArrowBuf. +// ------------------------------------------------------------------------- + +package org.apache.arrow.vector.util; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.DatabricksArrowBuf; +import org.apache.arrow.memory.util.MemoryUtil; + +/** Utility methods for configurable precision Decimal values (e.g. {@link BigDecimal}). */ +public class DecimalUtility { + private DecimalUtility() {} + + public static final byte[] zeroes = + new byte[] { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + }; + public static final byte[] minus_one = + new byte[] { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 + }; + private static final boolean LITTLE_ENDIAN = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; + + /** + * Read an ArrowType.Decimal at the given value index in the ArrowBuf and convert to a BigDecimal + * with the given scale. 
+ */ + public static BigDecimal getBigDecimalFromArrowBuf( + ArrowBuf bytebuf, int index, int scale, int byteWidth) { + byte[] value = new byte[byteWidth]; + byte temp; + final long startIndex = (long) index * byteWidth; + + bytebuf.getBytes(startIndex, value, 0, byteWidth); + if (LITTLE_ENDIAN) { + // Decimal stored as native endian, need to swap bytes to make BigDecimal if native endian is + // LE + int stop = byteWidth / 2; + for (int i = 0, j; i < stop; i++) { + temp = value[i]; + j = (byteWidth - 1) - i; + value[i] = value[j]; + value[j] = temp; + } + } + BigInteger unscaledValue = new BigInteger(value); + return new BigDecimal(unscaledValue, scale); + } + + /** + * Read an ArrowType.Decimal from the ByteBuffer and convert to a BigDecimal with the given scale. + */ + public static BigDecimal getBigDecimalFromByteBuffer( + ByteBuffer bytebuf, int scale, int byteWidth) { + byte[] value = new byte[byteWidth]; + bytebuf.get(value); + BigInteger unscaledValue = new BigInteger(value); + return new BigDecimal(unscaledValue, scale); + } + + /** + * Read an ArrowType.Decimal from the ArrowBuf at the given value index and return it as a byte + * array. + */ + public static byte[] getByteArrayFromArrowBuf(ArrowBuf bytebuf, int index, int byteWidth) { + final byte[] value = new byte[byteWidth]; + final long startIndex = (long) index * byteWidth; + bytebuf.getBytes(startIndex, value, 0, byteWidth); + return value; + } + + /** + * Check that the BigDecimal scale equals the vectorScale and that the BigDecimal precision is + * less than or equal to the vectorPrecision. If not, then an UnsupportedOperationException is + * thrown, otherwise returns true. 
+ */ + public static boolean checkPrecisionAndScale( + BigDecimal value, int vectorPrecision, int vectorScale) { + if (value.scale() != vectorScale) { + throw new UnsupportedOperationException( + "BigDecimal scale must equal that in the Arrow vector: " + + value.scale() + + " != " + + vectorScale); + } + if (value.precision() > vectorPrecision) { + throw new UnsupportedOperationException( + "BigDecimal precision cannot be greater than that in the Arrow " + + "vector: " + + value.precision() + + " > " + + vectorPrecision); + } + return true; + } + + /** + * Check that the BigDecimal scale equals the vectorScale and that the BigDecimal precision is + * less than or equal to the vectorPrecision. Return true if so, otherwise return false. + */ + public static boolean checkPrecisionAndScaleNoThrow( + BigDecimal value, int vectorPrecision, int vectorScale) { + return value.scale() == vectorScale && value.precision() < vectorPrecision; + } + + /** + * Check that the decimal scale equals the vectorScale and that the decimal precision is less than + * or equal to the vectorPrecision. If not, then an UnsupportedOperationException is thrown, + * otherwise returns true. + */ + public static boolean checkPrecisionAndScale( + int decimalPrecision, int decimalScale, int vectorPrecision, int vectorScale) { + if (decimalScale != vectorScale) { + throw new UnsupportedOperationException( + "BigDecimal scale must equal that in the Arrow vector: " + + decimalScale + + " != " + + vectorScale); + } + if (decimalPrecision > vectorPrecision) { + throw new UnsupportedOperationException( + "BigDecimal precision cannot be greater than that in the Arrow " + + "vector: " + + decimalPrecision + + " > " + + vectorPrecision); + } + return true; + } + + /** + * Write the given BigDecimal to the ArrowBuf at the given value index. Will throw an + * UnsupportedOperationException if the decimal size is greater than the Decimal vector byte + * width. 
+ */ + public static void writeBigDecimalToArrowBuf( + BigDecimal value, ArrowBuf bytebuf, int index, int byteWidth) { + final byte[] bytes = value.unscaledValue().toByteArray(); + writeByteArrayToArrowBufHelper(bytes, bytebuf, index, byteWidth); + } + + /** + * Write the given long to the ArrowBuf at the given value index. This routine extends the + * original sign bit to a new upper area in 128-bit or 256-bit. + */ + public static void writeLongToArrowBuf(long value, ArrowBuf bytebuf, int index, int byteWidth) { + if (byteWidth != 16 && byteWidth != 32) { + throw new UnsupportedOperationException( + "DecimalUtility.writeLongToArrowBuf() currently supports " + + "128-bit or 256-bit width data"); + } + final long padValue = Long.signum(value) == -1 ? -1L : 0L; + + // ---- Databricks patch start ---- + if (bytebuf instanceof DatabricksArrowBuf) { + DatabricksArrowBuf buf = (DatabricksArrowBuf) bytebuf; + final int startIdx = index * byteWidth; + if (LITTLE_ENDIAN) { + buf.setLong(startIdx, value); + for (int i = 1; i <= (byteWidth - 8) / 8; i++) { + buf.setLong(startIdx + Long.BYTES * i, padValue); + } + } else { + for (int i = 0; i < (byteWidth - 8) / 8; i++) { + MemoryUtil.putLong(startIdx + Long.BYTES * i, padValue); + } + buf.setLong(startIdx + Long.BYTES * (byteWidth - 8) / 8, value); + } + } else { + final long addressOfValue = bytebuf.memoryAddress() + (long) index * byteWidth; + if (LITTLE_ENDIAN) { + MemoryUtil.putLong(addressOfValue, value); + for (int i = 1; i <= (byteWidth - 8) / 8; i++) { + MemoryUtil.putLong(addressOfValue + Long.BYTES * i, padValue); + } + } else { + for (int i = 0; i < (byteWidth - 8) / 8; i++) { + MemoryUtil.putLong(addressOfValue + Long.BYTES * i, padValue); + } + MemoryUtil.putLong(addressOfValue + Long.BYTES * (byteWidth - 8) / 8, value); + } + } + // ---- Databricks patch end ---- + } + + /** + * Write the given byte array to the ArrowBuf at the given value index. 
Will throw an + * UnsupportedOperationException if the decimal size is greater than the Decimal vector byte + * width. + */ + public static void writeByteArrayToArrowBuf( + byte[] bytes, ArrowBuf bytebuf, int index, int byteWidth) { + writeByteArrayToArrowBufHelper(bytes, bytebuf, index, byteWidth); + } + + private static void writeByteArrayToArrowBufHelper( + byte[] bytes, ArrowBuf bytebuf, int index, int byteWidth) { + final long startIndex = (long) index * byteWidth; + if (bytes.length > byteWidth) { + throw new UnsupportedOperationException( + "Decimal size greater than " + byteWidth + " bytes: " + bytes.length); + } + + byte[] padBytes = bytes[0] < 0 ? minus_one : zeroes; + if (LITTLE_ENDIAN) { + // Decimal stored as native-endian, need to swap data bytes before writing to ArrowBuf if LE + byte[] bytesLE = new byte[bytes.length]; + for (int i = 0; i < bytes.length; i++) { + bytesLE[i] = bytes[bytes.length - 1 - i]; + } + + // Write LE data + bytebuf.setBytes(startIndex, bytesLE, 0, bytes.length); + bytebuf.setBytes(startIndex + bytes.length, padBytes, 0, byteWidth - bytes.length); + } else { + // Write BE data + bytebuf.setBytes(startIndex + byteWidth - bytes.length, bytes, 0, bytes.length); + bytebuf.setBytes(startIndex, padBytes, 0, byteWidth - bytes.length); + } + } +} diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowBufferAllocatorNettyManagerTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowBufferAllocatorNettyManagerTest.java new file mode 100644 index 000000000..f2f5c1328 --- /dev/null +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowBufferAllocatorNettyManagerTest.java @@ -0,0 +1,48 @@ +package com.databricks.jdbc.api.impl.arrow; + +import static com.databricks.jdbc.api.impl.arrow.ArrowBufferAllocatorTest.readAndWriteArrowData; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + +import java.io.IOException; +import org.apache.arrow.memory.BufferAllocator; +import 
org.apache.arrow.memory.DatabricksBufferAllocator; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnJre; +import org.junit.jupiter.api.condition.JRE; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Test that the fallback {@code DatabricksBufferAllocator} is used with Netty allocation manager + * type when the creation of {@code RootAllocator} is not possible in the current JVM. + * + *

This test is in a separate class to ensure it runs in a fresh JVM with the system property set + * before Arrow's static initialization. + */ +public class ArrowBufferAllocatorNettyManagerTest { + private static final Logger logger = + LoggerFactory.getLogger(ArrowBufferAllocatorNettyManagerTest.class); + + private static final String ARROW_ALLOCATION_MANAGER_TYPE = "arrow.allocation.manager.type"; + + @BeforeAll + static void setUpAllocationManagerType() { + String originalValue = System.getProperty(ARROW_ALLOCATION_MANAGER_TYPE); + logger.info("Original value of {} is {}", ARROW_ALLOCATION_MANAGER_TYPE, originalValue); + System.setProperty(ARROW_ALLOCATION_MANAGER_TYPE, "Netty"); + logger.info("Setting system property {} to Netty", ARROW_ALLOCATION_MANAGER_TYPE); + } + + @Test + @Tag("Jvm17PlusAndArrowToNioReflectionDisabled") + @EnabledOnJre({JRE.JAVA_17, JRE.JAVA_21}) + public void testCreateDatabricksBufferAllocatorWithNettyManagerType() throws IOException { + try (BufferAllocator allocator = ArrowBufferAllocator.getBufferAllocator()) { + assertInstanceOf( + DatabricksBufferAllocator.class, allocator, "Should create DatabricksBufferAllocator"); + readAndWriteArrowData(allocator); + } + } +} diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowBufferAllocatorTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowBufferAllocatorTest.java new file mode 100644 index 000000000..34018c846 --- /dev/null +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowBufferAllocatorTest.java @@ -0,0 +1,104 @@ +package com.databricks.jdbc.api.impl.arrow; + +import static java.util.Arrays.asList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; 
+import java.nio.channels.Channels; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.DatabricksBufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnJre; +import org.junit.jupiter.api.condition.JRE; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Test the functionality of {@link ArrowBufferAllocator}. */ +public class ArrowBufferAllocatorTest { + /** Logger instance. */ + private static final Logger logger = LoggerFactory.getLogger(ArrowBufferAllocatorTest.class); + + /** Test that root allocator can be created. */ + @Test + public void testCreateRootAllocator() throws IOException { + try (BufferAllocator allocator = ArrowBufferAllocator.getBufferAllocator()) { + assertInstanceOf(RootAllocator.class, allocator, "Should create RootAllocator"); + readAndWriteArrowData(allocator); + } + + assertFalse(ArrowBufferAllocator.isUsingPatchedAllocator(), "Should use RootAllocator"); + } + + /** + * Test that the fallback {@code DatabricksBufferAllocator} is used when the creation of {@code + * RootAllocator} is not possible in the current JVM. 
+ */ + @Test + @Tag("Jvm17PlusAndArrowToNioReflectionDisabled") + @EnabledOnJre({JRE.JAVA_17, JRE.JAVA_21}) + public void testCreateDatabricksBufferAllocator() throws IOException { + try (BufferAllocator allocator = ArrowBufferAllocator.getBufferAllocator()) { + assertInstanceOf( + DatabricksBufferAllocator.class, allocator, "Should create DatabricksBufferAllocator"); + readAndWriteArrowData(allocator); + } + + assertTrue( + ArrowBufferAllocator.isUsingPatchedAllocator(), "Should use DatabricksBufferAllocator"); + } + + /** Write and read a sample arrow data to validate that the BufferAllocator works. */ + static void readAndWriteArrowData(BufferAllocator allocator) throws IOException { + // 1. Write sample data. + Field name = new Field("name", FieldType.nullable(new ArrowType.Utf8()), null); + Field age = new Field("age", FieldType.nullable(new ArrowType.Int(32, true)), null); + Schema schemaPerson = new Schema(asList(name, age)); + try (VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schemaPerson, allocator)) { + VarCharVector nameVector = (VarCharVector) vectorSchemaRoot.getVector("name"); + nameVector.allocateNew(3); + nameVector.set(0, "David".getBytes()); + nameVector.set(1, "Gladis".getBytes()); + nameVector.set(2, "Juan".getBytes()); + IntVector ageVector = (IntVector) vectorSchemaRoot.getVector("age"); + ageVector.allocateNew(3); + ageVector.set(0, 10); + ageVector.set(1, 20); + ageVector.set(2, 30); + vectorSchemaRoot.setRowCount(3); + ByteArrayOutputStream arrowData = new ByteArrayOutputStream(); + try (ArrowStreamWriter writer = + new ArrowStreamWriter(vectorSchemaRoot, null, Channels.newChannel(arrowData))) { + writer.start(); + writer.writeBatch(); + logger.info("Number of rows written: " + vectorSchemaRoot.getRowCount()); + } + + // 2. Read the sample data. 
+ int totalRecords = 0; + try (ArrowStreamReader reader = + new ArrowStreamReader(new ByteArrayInputStream(arrowData.toByteArray()), allocator)) { + while (reader.loadNextBatch()) { + totalRecords += reader.getVectorSchemaRoot().getRowCount(); + } + } + + assertEquals(3, totalRecords, "Read 3 records"); + } + } +} diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowBufferAllocatorUnknownManagerTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowBufferAllocatorUnknownManagerTest.java new file mode 100644 index 000000000..4fe133eee --- /dev/null +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowBufferAllocatorUnknownManagerTest.java @@ -0,0 +1,48 @@ +package com.databricks.jdbc.api.impl.arrow; + +import static com.databricks.jdbc.api.impl.arrow.ArrowBufferAllocatorTest.readAndWriteArrowData; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + +import java.io.IOException; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.DatabricksBufferAllocator; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnJre; +import org.junit.jupiter.api.condition.JRE; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Test that the fallback {@code DatabricksBufferAllocator} is used with an unknown allocation + * manager type when the creation of {@code RootAllocator} is not possible in the current JVM. + * + *

This test is in a separate class to ensure it runs in a fresh JVM with the system property set + * before Arrow's static initialization. + */ +public class ArrowBufferAllocatorUnknownManagerTest { + private static final Logger logger = + LoggerFactory.getLogger(ArrowBufferAllocatorUnknownManagerTest.class); + + private static final String ARROW_ALLOCATION_MANAGER_TYPE = "arrow.allocation.manager.type"; + + @BeforeAll + static void setUpAllocationManagerType() { + String originalValue = System.getProperty(ARROW_ALLOCATION_MANAGER_TYPE); + logger.info("Original value of {} is {}", ARROW_ALLOCATION_MANAGER_TYPE, originalValue); + System.setProperty(ARROW_ALLOCATION_MANAGER_TYPE, "Unknown"); + logger.info("Setting system property {} to Unknown", ARROW_ALLOCATION_MANAGER_TYPE); + } + + @Test + @Tag("Jvm17PlusAndArrowToNioReflectionDisabled") + @EnabledOnJre({JRE.JAVA_17, JRE.JAVA_21}) + public void testCreateDatabricksBufferAllocatorWithUnknownManagerType() throws IOException { + try (BufferAllocator allocator = ArrowBufferAllocator.getBufferAllocator()) { + assertInstanceOf( + DatabricksBufferAllocator.class, allocator, "Should create DatabricksBufferAllocator"); + readAndWriteArrowData(allocator); + } + } +} diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowBufferAllocatorUnsafeManagerTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowBufferAllocatorUnsafeManagerTest.java new file mode 100644 index 000000000..7b0abeba4 --- /dev/null +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowBufferAllocatorUnsafeManagerTest.java @@ -0,0 +1,48 @@ +package com.databricks.jdbc.api.impl.arrow; + +import static com.databricks.jdbc.api.impl.arrow.ArrowBufferAllocatorTest.readAndWriteArrowData; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + +import java.io.IOException; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.DatabricksBufferAllocator; +import org.junit.jupiter.api.BeforeAll; +import 
org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnJre; +import org.junit.jupiter.api.condition.JRE; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Test that the fallback {@code DatabricksBufferAllocator} is used with Unsafe allocation manager + * type when the creation of {@code RootAllocator} is not possible in the current JVM. + * + *

This test is in a separate class to ensure it runs in a fresh JVM with the system property set + * before Arrow's static initialization. + */ +public class ArrowBufferAllocatorUnsafeManagerTest { + private static final Logger logger = + LoggerFactory.getLogger(ArrowBufferAllocatorUnsafeManagerTest.class); + + private static final String ARROW_ALLOCATION_MANAGER_TYPE = "arrow.allocation.manager.type"; + + @BeforeAll + static void setUpAllocationManagerType() { + String originalValue = System.getProperty(ARROW_ALLOCATION_MANAGER_TYPE); + logger.info("Original value of {} is {}", ARROW_ALLOCATION_MANAGER_TYPE, originalValue); + System.setProperty(ARROW_ALLOCATION_MANAGER_TYPE, "Unsafe"); + logger.info("Setting system property {} to Unsafe", ARROW_ALLOCATION_MANAGER_TYPE); + } + + @Test + @Tag("Jvm17PlusAndArrowToNioReflectionDisabled") + @EnabledOnJre({JRE.JAVA_17, JRE.JAVA_21}) + public void testCreateDatabricksBufferAllocatorWithUnsafeManagerType() throws IOException { + try (BufferAllocator allocator = ArrowBufferAllocator.getBufferAllocator()) { + assertInstanceOf( + DatabricksBufferAllocator.class, allocator, "Should create DatabricksBufferAllocator"); + readAndWriteArrowData(allocator); + } + } +} diff --git a/src/test/java/org/apache/arrow/memory/AbstractDatabricksArrowPatchTypesTest.java b/src/test/java/org/apache/arrow/memory/AbstractDatabricksArrowPatchTypesTest.java new file mode 100644 index 000000000..6030335ae --- /dev/null +++ b/src/test/java/org/apache/arrow/memory/AbstractDatabricksArrowPatchTypesTest.java @@ -0,0 +1,407 @@ +package org.apache.arrow.memory; + +import com.databricks.jdbc.api.impl.arrow.ArrowBufferAllocator; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.Collections; +import java.util.stream.Stream; +import org.apache.arrow.vector.VectorSchemaRoot; +import 
org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.IntervalUnit; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.params.provider.Arguments; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Base class for Arrow data type read/write tests with the patched Arrow classes. + * + *

Provides shared infrastructure including parameter sources, core test utilities, and data + * generation helpers. + */ +public abstract class AbstractDatabricksArrowPatchTypesTest { + protected static final Logger logger = + LoggerFactory.getLogger(AbstractDatabricksArrowPatchTypesTest.class); + + /** Provide different buffer allocators. */ + protected static Stream getBufferAllocators() { + // Large enough value which fits within the heap space for tests. + int totalRows = (int) Math.pow(2, 17); // A large enough value. + return Stream.of( + Arguments.of(new DatabricksBufferAllocator(), new DatabricksBufferAllocator(), totalRows)); + } + + /** Provide different buffer allocators with smaller row count. */ + protected static Stream getBufferAllocatorsSmallRows() { + int totalRows = 100_000; // A large enough value. + return Stream.of( + Arguments.of(new DatabricksBufferAllocator(), new DatabricksBufferAllocator(), totalRows)); + } + + @BeforeAll + public static void logDetails() { + logger.info("Using allocator: {}", ArrowBufferAllocator.getBufferAllocator().getName()); + } + + protected byte[] writeData(DataTester dataTester, int totalRowCount, BufferAllocator allocator) + throws IOException { + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(dataTester.getSchema(), allocator); + ArrowStreamWriter streamWriter = + new ArrowStreamWriter(vectorSchemaRoot, null, byteArrayOutputStream); + + streamWriter.start(); + for (int batchSize = 1; batchSize <= totalRowCount; batchSize *= 2) { + dataTester.writeData(vectorSchemaRoot, batchSize); + + // Write batch. 
+ vectorSchemaRoot.setRowCount(batchSize); + streamWriter.writeBatch(); + vectorSchemaRoot.clear(); + } + + streamWriter.end(); + streamWriter.close(); + vectorSchemaRoot.close(); + allocator.close(); + + return byteArrayOutputStream.toByteArray(); + } + + protected void readAndValidate(DataTester dataTester, byte[] data, BufferAllocator allocator) + throws IOException { + try (ArrowStreamReader reader = + new ArrowStreamReader(new ByteArrayInputStream(data), allocator)) { + while (reader.loadNextBatch()) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + logger.info("Validating {} rows", root.getRowCount()); + dataTester.validateData(root); + } + } finally { + allocator.close(); + } + } + + /** Interface for test data writers and validators. */ + protected interface DataTester { + Schema getSchema(); + + void writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize); + + void validateData(VectorSchemaRoot vectorSchemaRoot); + } + + // ============================================================================================ + // Field creation helper methods + // ============================================================================================ + + protected Field newSignedByteIntField() { + return new Field("signed-byte-int", FieldType.nullable(new ArrowType.Int(8, true)), null); + } + + protected Field newSignedShortIntField() { + return new Field("signed-short-int", FieldType.nullable(new ArrowType.Int(16, true)), null); + } + + protected Field newSignedIntField() { + return new Field("signed-int", FieldType.nullable(new ArrowType.Int(32, true)), null); + } + + protected Field newSignedLongField() { + return new Field("signed-long", FieldType.nullable(new ArrowType.Int(64, true)), null); + } + + protected Field newUnsignedByteIntField() { + return new Field("unsigned-byte-int", FieldType.nullable(new ArrowType.Int(8, false)), null); + } + + protected Field newUnsignedShortIntField() { + return new Field("unsigned-short-int", 
FieldType.nullable(new ArrowType.Int(16, false)), null); + } + + protected Field newUnsignedIntField() { + return new Field("unsigned-int", FieldType.nullable(new ArrowType.Int(32, false)), null); + } + + protected Field newUnsignedLongField() { + return new Field("unsigned-long", FieldType.nullable(new ArrowType.Int(64, false)), null); + } + + protected Field newFloatField() { + return new Field( + "float", + FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), + null); + } + + protected Field newDoubleField() { + return new Field( + "double", + FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + null); + } + + protected Field newDecimalField(int precision, int scale, int bitWidth) { + String name = "decimal-" + precision + "-" + scale + "-" + bitWidth; + return new Field( + name, FieldType.nullable(new ArrowType.Decimal(precision, scale, bitWidth)), null); + } + + protected Field newDateDayField() { + return new Field("date-day", FieldType.nullable(new ArrowType.Date(DateUnit.DAY)), null); + } + + protected Field newTimestampMilliField() { + return new Field( + "timestamp-milli", + FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC")), + null); + } + + protected Field newTimestampMicroField() { + return new Field( + "timestamp-micro", + FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC")), + null); + } + + protected Field newTimeSecField() { + return new Field("time-sec", FieldType.nullable(new ArrowType.Time(TimeUnit.SECOND, 32)), null); + } + + protected Field newTimeNanoField() { + return new Field( + "time-nano", FieldType.nullable(new ArrowType.Time(TimeUnit.NANOSECOND, 64)), null); + } + + protected Field newDurationMicrosecondField() { + return new Field( + "duration-micro-sec", + FieldType.nullable(new ArrowType.Duration(TimeUnit.MICROSECOND)), + null); + } + + protected Field newIntervalYearField() { + return new Field( + "interval-year", 
FieldType.nullable(new ArrowType.Interval(IntervalUnit.YEAR_MONTH)), null); + } + + protected Field newIntervalDayField() { + return new Field( + "interval-day", FieldType.nullable(new ArrowType.Interval(IntervalUnit.DAY_TIME)), null); + } + + protected Field newIntervalMonthDayNanoField() { + return new Field( + "interval-month-day-nano", + FieldType.nullable(new ArrowType.Interval(IntervalUnit.MONTH_DAY_NANO)), + null); + } + + protected Field newFixedSizeBinaryField() { + return new Field( + "fixed-size-binary", FieldType.nullable(new ArrowType.FixedSizeBinary(16)), null); + } + + protected Field newVarBinaryField() { + return new Field("variable-binary", FieldType.nullable(new ArrowType.Binary()), null); + } + + protected Field newLargeVarBinaryField() { + return new Field( + "large-variable-binary", FieldType.nullable(new ArrowType.LargeBinary()), null); + } + + protected Field newUtf8Field() { + return new Field("utf8-string", FieldType.nullable(new ArrowType.Utf8()), null); + } + + protected Field newLargeUtf8Field() { + return new Field("large-utf8-string", FieldType.nullable(new ArrowType.LargeUtf8()), null); + } + + protected Field newListIntField() { + return new Field( + "list-int", + FieldType.nullable(ArrowType.List.INSTANCE), + Collections.singletonList( + new Field("$data$", FieldType.nullable(new ArrowType.Int(32, true)), null))); + } + + protected Field newLargeListIntField() { + return new Field( + "large-list-int", + FieldType.nullable(ArrowType.LargeList.INSTANCE), + Collections.singletonList( + new Field("$data$", FieldType.nullable(new ArrowType.Int(32, true)), null))); + } + + protected Field newDictStringField() { + return new Field("dict-string", FieldType.nullable(new ArrowType.Utf8()), null); + } + + protected Field newDictIntField() { + return new Field("dict-int", FieldType.nullable(new ArrowType.Int(32, true)), null); + } + + // ============================================================================================ + // Data 
generation helper methods + // ============================================================================================ + + protected int getSignedByte(int index) { + return (index % 256) - 128; + } + + protected int getSignedShort(int index) { + return (index % 65536) - 32768; + } + + protected int getSignedInt(int index) { + return index * (index % 3 == 0 ? -1 : 1); + } + + protected long getSignedLong(int index) { + return (long) index + Integer.MAX_VALUE * (index % 3 == 0 ? -1 : 1); + } + + protected int getUnsignedByte(int index) { + return index % 256; + } + + protected int getUnsignedShort(int index) { + return index % 65536; + } + + protected int getUnsignedInt(int index) { + return index; + } + + protected long getUnsignedLong(int index) { + return (long) index * 2; + } + + protected float getFloat(int index) { + return (float) index * (float) Math.PI * (index % 3 == 0 ? -1 : 1); + } + + protected double getDouble(int index) { + return index * Math.PI * (index % 3 == 0 ? -1 : 1); + } + + protected BigDecimal getDecimal(int index, int scale) { + BigDecimal bigDecimal = new BigDecimal(index % 100 * (index % 3 == 0 ? 
-1 : 1)); + return bigDecimal.setScale(scale, RoundingMode.HALF_DOWN); + } + + protected int getDateDay(int index) { + return 18000 + (index % 10000); + } + + protected long getTimestampMilli(int index) { + return 1577836800000L + ((long) index * 1000L); + } + + protected long getTimestampMicro(int index) { + return getTimestampMilli(index) * 1000L; + } + + protected int getTimeSec(int index) { + return index % 86400; + } + + protected long getTimeNano(int index) { + return (long) (index % 86400) * 1_000_000_000L; + } + + protected long getDurationSec(int index) { + return (long) index * 3600L; + } + + protected long getDurationMicroseconds(int index) { + return (long) index % 1000; + } + + protected int getIntervalYearMonth(int index) { + return index % 600; + } + + protected int getIntervalDayDays(int index) { + return index % 365; + } + + protected int getIntervalDayMillis(int index) { + return (index * 1000) % 86400000; + } + + protected byte[] getFixedSizeBinary(int index, int size) { + byte[] data = new byte[size]; + for (int i = 0; i < size; i++) { + data[i] = (byte) ((index + i) % 256); + } + return data; + } + + protected byte[] getVarBinary(int index, int maxLength) { + int length = (index % maxLength) + 1; + byte[] data = new byte[length]; + for (int i = 0; i < length; i++) { + data[i] = (byte) ((index * 3 + i) % 256); + } + return data; + } + + protected byte[] getLargeVarBinary(int index, int maxLength) { + int length = (index % maxLength) + 1; + byte[] data = new byte[length]; + for (int i = 0; i < length; i++) { + data[i] = (byte) ((index * 7 + i) % 256); + } + return data; + } + + protected String getUtf8String(int index) { + return "UTF8-String-" + index + "-Data"; + } + + protected String getLargeUtf8String(int index) { + return "LargeUTF8-String-" + index + "-DataWithMoreContent-" + (index * 3); + } + + protected int getListSize(int index) { + return index % 32 + 1; + } + + protected int getListElement(int rowIndex, int elementIndex) { + return 
rowIndex * 100 + elementIndex; + } + + protected int getLargeListSize(int index) { + return (index % 128) + 1; + } + + protected int getLargeListElement(int rowIndex, int elementIndex) { + return rowIndex * 1000 + elementIndex; + } + + protected String getDictString(int index) { + // Return repeating values to simulate dictionary-encoded data + String[] dictValues = {"Red", "Green", "Blue", "Yellow", "Orange"}; + return dictValues[index % dictValues.length]; + } + + protected int getDictInt(int index) { + // Return repeating values to simulate dictionary-encoded data + return (index % 10) * 100; + } +} diff --git a/src/test/java/org/apache/arrow/memory/ArrowParsingBenchmark.java b/src/test/java/org/apache/arrow/memory/ArrowParsingBenchmark.java new file mode 100644 index 000000000..c77e99e84 --- /dev/null +++ b/src/test/java/org/apache/arrow/memory/ArrowParsingBenchmark.java @@ -0,0 +1,161 @@ +package org.apache.arrow.memory; + +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import net.jpountz.lz4.LZ4FrameInputStream; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.util.TransferPair; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import 
org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; + +@State(Scope.Benchmark) +@BenchmarkMode(Mode.AverageTime) +@Fork(value = 1) +@Measurement(iterations = 20, time = 100, timeUnit = TimeUnit.MILLISECONDS) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Warmup(iterations = 20, time = 100, timeUnit = TimeUnit.MILLISECONDS) +public class ArrowParsingBenchmark { + /** Path to an arrow chunk. */ + private static final Path ARROW_CHUNK_PATH = Path.of("arrow", "chunk_all_types.arrow"); + + /** Path to a LZ4 compressed arrow chunk. */ + private static final Path ARROW_CHUNK_COMPRESSED_PATH = + Path.of("arrow", "chunk_all_types.arrow.lz4"); + + /** Compressed Arrow file suffix. */ + private static final String ARROW_CHUNK_COMPRESSED_FILE_SUFFIX = ".lz4"; + + public static void main(String[] args) throws RunnerException { + Options options = + new OptionsBuilder().include(ArrowParsingBenchmark.class.getSimpleName()).build(); + new Runner(options).run(); + } + + // Pre-loaded file contents + private byte[] arrowChunkBytes; + private byte[] arrowChunkCompressedBytes; + + @Setup(Level.Trial) + public void setup() throws IOException { + // Load files into memory once before all benchmark iterations + arrowChunkBytes = loadFileToMemory(ARROW_CHUNK_PATH); + arrowChunkCompressedBytes = loadFileToMemory(ARROW_CHUNK_COMPRESSED_PATH); + } + + private byte[] loadFileToMemory(Path filePath) throws IOException { + try (InputStream stream = + getClass().getClassLoader().getResourceAsStream(filePath.toString())) { + assertNotNull(stream, filePath + " not found"); + return stream.readAllBytes(); + } + } + + @Benchmark + public List<Map<String, Object>> parseArrowChunk() throws IOException { + try (BufferAllocator allocator = new RootAllocator()) { + return parseArrowStream(arrowChunkBytes, false,
allocator); + } + } + + @Benchmark + public List<Map<String, Object>> parseArrowCompressedChunk() throws IOException { + try (BufferAllocator allocator = new RootAllocator()) { + return parseArrowStream(arrowChunkCompressedBytes, true, allocator); + } + } + + @Benchmark + public List<Map<String, Object>> parsePatchedArrowChunk() throws IOException { + try (BufferAllocator allocator = new DatabricksBufferAllocator()) { + return parseArrowStream(arrowChunkBytes, false, allocator); + } + } + + @Benchmark + public List<Map<String, Object>> parsePatchedArrowCompressedChunk() throws IOException { + try (BufferAllocator allocator = new DatabricksBufferAllocator()) { + return parseArrowStream(arrowChunkCompressedBytes, true, allocator); + } + } + + /** Parse the Arrow stream file stored at {@code filePath} and return the records in the file. */ + private List<Map<String, Object>> parseArrowStream( + byte[] arrowChunkBytes, boolean isCompressed, BufferAllocator allocator) throws IOException { + ArrayList<Map<String, Object>> records = new ArrayList<>(); + + InputStream arrowStream = new ByteArrayInputStream(arrowChunkBytes); + if (isCompressed) { + arrowStream = new LZ4FrameInputStream(arrowStream); + } + + try (ArrowStreamReader reader = new ArrowStreamReader(arrowStream, allocator)) { + // Iterate over batches. + while (reader.loadNextBatch()) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + + // Transfer all vectors. + List<ValueVector> valueVectors = + root.getFieldVectors().stream() + .map( + fieldVector -> { + TransferPair transferPair = fieldVector.getTransferPair(allocator); + transferPair.transfer(); + return transferPair.getTo(); + }) + .collect(Collectors.toList()); + + // Parse and populate each record/row in this batch. 
+ try { + for (int recordIndex = 0; recordIndex < root.getRowCount(); recordIndex++) { + HashMap<String, Object> record = new HashMap<>(); + for (ValueVector valueVector : valueVectors) { + record.put(valueVector.getField().getName(), valueVector.getObject(recordIndex)); + } + records.add(record); + } + } finally { + // Close all transferred vectors to prevent memory leak + valueVectors.forEach(ValueVector::close); + } + } + } + + return records; + } + + /** + * @return an input stream for the filePath. + */ + private InputStream getStream(Path filePath) throws IOException { + InputStream arrowStream = getClass().getClassLoader().getResourceAsStream(filePath.toString()); + assertNotNull(arrowStream, filePath + " not found"); + return filePath.toString().endsWith(ARROW_CHUNK_COMPRESSED_FILE_SUFFIX) + ? new LZ4FrameInputStream(arrowStream) + : arrowStream; + } +} diff --git a/src/test/java/org/apache/arrow/memory/DatabricksAllocationReservationTest.java b/src/test/java/org/apache/arrow/memory/DatabricksAllocationReservationTest.java new file mode 100644 index 000000000..98af1df0f --- /dev/null +++ b/src/test/java/org/apache/arrow/memory/DatabricksAllocationReservationTest.java @@ -0,0 +1,84 @@ +package org.apache.arrow.memory; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Random; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnJre; +import org.junit.jupiter.api.condition.JRE; + +/** Test allocation reservation */ +@Tag("Jvm17PlusAndArrowToNioReflectionDisabled") +@EnabledOnJre({JRE.JAVA_17, JRE.JAVA_21}) +public class DatabricksAllocationReservationTest { + /** Test reserve and allocate */ + @Test + public void testReservation() { + DatabricksBufferAllocator allocator = new
DatabricksBufferAllocator(); + DatabricksAllocationReservation reservation = new DatabricksAllocationReservation(allocator); + + Random random = new Random(); + long totalReservation = 0; + for (int i = 0; i < 1000; i++) { + long size = random.nextInt(1000); + assertTrue(reservation.reserve(size), "Reservation should return true"); + totalReservation += size; + assertEquals(totalReservation, reservation.getSizeLong(), "Reservation should match"); + } + + ArrowBuf buffer = reservation.allocateBuffer(); + assertInstanceOf(DatabricksArrowBuf.class, buffer, "Buffer type should match"); + assertTrue(buffer.capacity() >= totalReservation, "Reservation should be allocated"); + } + + /** Test fail on reuse */ + @Test + public void testFailureOnReuse() { + long bufferSize = 1024; + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + DatabricksAllocationReservation reservation = new DatabricksAllocationReservation(allocator); + reservation.reserve(bufferSize); + + ArrowBuf buffer = reservation.allocateBuffer(); + assertInstanceOf(DatabricksArrowBuf.class, buffer, "Buffer type should match"); + assertTrue(buffer.capacity() >= bufferSize, "Reservation should be allocated"); + + assertTrue(reservation.isUsed(), "Reservation should be used"); + + assertThrows( + IllegalStateException.class, + () -> reservation.reserve(10L), + "Reuse after allocate should fail"); + assertThrows( + IllegalStateException.class, + reservation::allocateBuffer, + "Reuse after allocate should fail"); + } + + /** Test fail on reuse after close */ + @Test + public void testFailureOnClose() { + long bufferSize = 1024; + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + DatabricksAllocationReservation reservation = new DatabricksAllocationReservation(allocator); + reservation.reserve(bufferSize); + + ArrowBuf buffer = reservation.allocateBuffer(); + assertInstanceOf(DatabricksArrowBuf.class, buffer, "Buffer type should match"); + assertTrue(buffer.capacity() >= 
bufferSize, "Reservation should be allocated"); + + reservation.close(); + assertTrue(reservation.isClosed(), "Reservation should have been closed"); + + assertThrows( + IllegalStateException.class, + () -> reservation.reserve(10L), + "Reuse after close should fail"); + assertThrows( + IllegalStateException.class, reservation::allocateBuffer, "Reuse after close should fail"); + } +} diff --git a/src/test/java/org/apache/arrow/memory/DatabricksArrowBufTest.java b/src/test/java/org/apache/arrow/memory/DatabricksArrowBufTest.java new file mode 100644 index 000000000..2e95ae69e --- /dev/null +++ b/src/test/java/org/apache/arrow/memory/DatabricksArrowBufTest.java @@ -0,0 +1,1143 @@ +package org.apache.arrow.memory; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnJre; +import org.junit.jupiter.api.condition.JRE; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Test all the public API of {@code DatabricksArrowBuf}. 
*/ +@Tag("Jvm17PlusAndArrowToNioReflectionDisabled") +@EnabledOnJre({JRE.JAVA_17, JRE.JAVA_21}) +public class DatabricksArrowBufTest { + + private static final ByteOrder BYTE_ORDER = ByteOrder.nativeOrder(); + + private static final Logger logger = LoggerFactory.getLogger(DatabricksArrowBufTest.class); + + /** Test the constructor fails on invalid capacity arguments. */ + @Test + public void testConstructorFailsOnInvalidCapacity() { + final int bufferSize = 32; + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + DatabricksReferenceManager refManager = new DatabricksReferenceManager(allocator, bufferSize); + + // Should fail when capacity is exceeded. + assertThrows( + IllegalArgumentException.class, + () -> { + try (DatabricksArrowBuf arrowBuf = + new DatabricksArrowBuf(refManager, null, Integer.MAX_VALUE + 1L)) { + logger.info("Should not reach here. {}", arrowBuf); + } + }, + "Constructor should fail when capacity is greater than Integer.MAX_VALUE"); + } + + /** + * Test ref count is always positive as long as there is a reference to the DatabricksArrowBuffer. + */ + @Test + public void testRefCountIsAlwaysPositive() { + final int bufferSize = 32; + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + DatabricksArrowBuf buffer = (DatabricksArrowBuf) allocator.buffer(bufferSize); + assertTrue(buffer.refCnt() > 0, "Refcount should be positive"); + + // Even after allocator is closed, if there is a reference to the buffer it should be + // positive. + allocator.close(); + assertTrue(buffer.refCnt() > 0, "Refcount should be positive even after allocator is closed"); + } + + /** Test checkBytes behaviour. */ + @Test + public void testCheckBytes() { + final int bufferSize = 32; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + for (int i = 0; i < bufferSize; i++) { + buffer.checkBytes(0, bufferSize); + } + + // Negative should throw an exception. 
+ assertThrows( + IndexOutOfBoundsException.class, + () -> buffer.checkBytes(-1, bufferSize), + "Negative start index should fail."); + + // Past end should throw an exception. + assertThrows( + IndexOutOfBoundsException.class, + () -> buffer.checkBytes(0, bufferSize + 1), + "Out of bounds end index should fail."); + } + } + + /** Test setting buffer capacity. */ + @Test + public void testSetCapacity() { + final int bufferSize = 32; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + + assertThrows( + UnsupportedOperationException.class, + () -> buffer.capacity(bufferSize + 1), + "Increasing buffer capacity should fail."); + + // Reducing buffer size capacity should be supported. + for (int capacity = bufferSize - 1; capacity >= 0; capacity--) { + buffer.capacity(capacity); + } + } + } + + /** Test byte order is as expected. */ + @Test + public void testByteOrder() { + final int bufferSize = 32; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + assertEquals(BYTE_ORDER, buffer.order(), "ByteOrder should be " + BYTE_ORDER); + } + } + + /** Test readable bytes behaviour is correct. */ + @Test + public void testReadableBytes() { + final int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + for (int i = 0; i < bufferSize; i++) { + buffer.writeByte((byte) i); + assertEquals(i + 1, buffer.readableBytes(), "Readable bytes should be correct."); + } + + for (int i = 0; i < bufferSize; i++) { + buffer.readByte(); + assertEquals( + bufferSize - 1 - i, buffer.readableBytes(), "Readable bytes should be correct."); + } + } + } + + /** Test writable bytes behaviour is correct. 
*/ + @Test + public void testWritableBytes() { + final int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + assertEquals(bufferSize, buffer.writableBytes(), "Writable bytes should be correct."); + for (int i = 0; i < bufferSize; i++) { + buffer.writeByte((byte) i); + assertEquals( + bufferSize - 1 - i, buffer.writableBytes(), "Writable bytes should be correct."); + } + } + } + + /** Test slice works as expected. */ + @Test + public void testSlice() { + final int bufferSize = 1024; + DatabricksArrowBuf buffer = newBuffer(bufferSize); + + // Write zeroes into the original buffer. + for (int i = 0; i < bufferSize; i++) { + buffer.writeByte((byte) 0); + } + + // Test whole slice. + DatabricksArrowBuf wholeSlice = (DatabricksArrowBuf) buffer.slice(); + testWriteAffectsOriginalBufferAndSlice(buffer, wholeSlice, 0); + + // Write data to part of a slice and check that the original buffer is affected as well. + for (int sliceSize = 1; sliceSize < bufferSize; sliceSize++) { + DatabricksArrowBuf slice = (DatabricksArrowBuf) buffer.slice(0, sliceSize); + testWriteAffectsOriginalBufferAndSlice(buffer, slice, 0); + + int index = bufferSize - sliceSize; + slice = (DatabricksArrowBuf) buffer.slice(index, sliceSize); + testWriteAffectsOriginalBufferAndSlice(buffer, slice, index); + } + + // Test some random offsets and length. + Random random = new Random(); + for (int i = 0; i < 10; i++) { + int startOffset = random.nextInt(bufferSize); + int size = random.nextInt(bufferSize - startOffset); + DatabricksArrowBuf slice = (DatabricksArrowBuf) buffer.slice(startOffset, size); + testWriteAffectsOriginalBufferAndSlice(buffer, slice, startOffset); + } + } + + private void testWriteAffectsOriginalBufferAndSlice( + DatabricksArrowBuf buffer, DatabricksArrowBuf slice, int sliceStartIndex) { + final int bufferSize = (int) buffer.capacity(); + + // Write zeroes into the original buffer. 
+ buffer.clear(); + for (int i = 0; i < bufferSize; i++) { + buffer.writeByte((byte) 0); + } + + // Write data to the slice and check that the original buffer is affected as well. + for (int i = 0; i < slice.capacity(); i++) { + slice.setByte(i, getByteValue(i)); + } + for (int i = 0; i < slice.capacity(); i++) { + assertEquals( + getByteValue(i), + buffer.getByte(sliceStartIndex + i), + "Readable bytes should be correct at index " + i); + } + + // Write data to the original buffer and check that the slice is affected. + for (int i = sliceStartIndex; i < sliceStartIndex + slice.capacity(); i++) { + buffer.setByte(i, getByteValue(i)); + } + for (int i = 0; i < slice.capacity(); i++) { + assertEquals( + getByteValue(sliceStartIndex + i), + slice.getByte(i), + "Readable bytes should be correct at index " + i); + } + } + + /** Should fail on incorrect indices. */ + @Test + public void testSliceFailsOnIncorrectIndices() { + final int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + // Write zeroes into the original buffer. + for (int i = 0; i < bufferSize; i++) { + buffer.writeByte((byte) 0); + } + + assertThrows( + IndexOutOfBoundsException.class, + () -> buffer.slice(-1, bufferSize), + "Should fail on negative index."); + assertThrows( + IndexOutOfBoundsException.class, + () -> buffer.slice(0, bufferSize + 1), + "Should fail on out of bounds length."); + assertThrows( + IndexOutOfBoundsException.class, + () -> buffer.slice(1, bufferSize), + "Should fail on out of bounds length"); + } + } + + /** Test nio buffer behaviour. 
*/ + @Test + public void testNioBuffer() { + final int bufferSize = 1024; + ByteBuffer nioBuffer; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + for (int i = 0; i < bufferSize; i++) { + buffer.writeByte(getByteValue(i)); + } + + nioBuffer = buffer.nioBuffer(); + } + assertEquals( + bufferSize, + nioBuffer.remaining(), + "NioBuffer should have " + bufferSize + " " + "bytes to read."); + + for (int i = nioBuffer.position(); i < nioBuffer.limit(); i++) { + assertEquals( + getByteValue(i), nioBuffer.get(i), "Readable bytes should be correct at index " + i); + } + } + + /** Test nio buffer slices. */ + @Test + public void testNioBufferSlices() { + final int bufferSize = 1024; + DatabricksArrowBuf buffer = newBuffer(bufferSize); + + // Write data to part of a slice and check that the original buffer is affected as well. + for (int i = 0; i < bufferSize; i++) { + int size = bufferSize - i; + testWriteAffectsOriginalBufferAndNioBuffer(buffer, i, size); + + int index = bufferSize - i; + testWriteAffectsOriginalBufferAndNioBuffer(buffer, index, i); + } + + // Test some random offsets and length. + Random random = new Random(); + for (int i = 0; i < 10; i++) { + int startIndex = random.nextInt(bufferSize); + int size = random.nextInt(bufferSize - startIndex); + testWriteAffectsOriginalBufferAndNioBuffer(buffer, startIndex, size); + } + } + + private void testWriteAffectsOriginalBufferAndNioBuffer( + DatabricksArrowBuf buffer, int startIndex, int length) { + final int bufferSize = (int) buffer.capacity(); + + // Write zeroes into the original buffer. 
+ buffer.clear(); + for (int i = 0; i < bufferSize; i++) { + buffer.writeByte(getByteValue(i)); + } + + ByteBuffer nioBuffer = buffer.nioBuffer(startIndex, length); + assertEquals( + length, nioBuffer.remaining(), "NioBuffer should have " + length + " bytes to " + "read."); + for (int i = nioBuffer.position(); i < nioBuffer.limit(); i++) { + assertEquals( + getByteValue(i), nioBuffer.get(i), "Readable bytes should be correct at index " + i); + } + } + + /** Test memory address returned is correct. */ + @Test + public void testMemoryAddress() { + final int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + assertEquals(0, buffer.memoryAddress(), "Memory address should be correct."); + } + } + + /** Test toString works. */ + @Test + public void testToString() { + final int bufferSize = 1024; + DatabricksArrowBuf buffer; + try (DatabricksBufferAllocator allocator = new DatabricksBufferAllocator()) { + buffer = (DatabricksArrowBuf) allocator.buffer(bufferSize); + } + //noinspection Convert2MethodRef + assertDoesNotThrow(() -> buffer.toString()); + } + + /** Test reference equality. */ + @Test + public void testEquals() { + final int bufferSize = 1024; + DatabricksArrowBuf buffer = newBuffer(bufferSize); + //noinspection EqualsWithItself + assertEquals(buffer, buffer, "Same object should be equal"); + + DatabricksArrowBuf slice = (DatabricksArrowBuf) buffer.slice(0, bufferSize); + assertNotEquals(buffer, slice, "Different object should not be equal"); + } + + /** Test hash code. */ + @Test + public void testHashCode() { + final int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + assertEquals(buffer.hashCode(), buffer.hashCode(), "Same object should have same hashcode."); + + DatabricksArrowBuf slice = (DatabricksArrowBuf) buffer.slice(0, bufferSize); + assertNotEquals( + buffer.hashCode(), slice.hashCode(), "Different object should have different hashcode. "); + } + } + + /** Test get and set on long. 
*/ + @Test + public void testGetAndSetLong() { + final int bufferSize = 1024; + ByteBuffer byteBuffer = newByteBuffer(bufferSize); + try (DatabricksArrowBuf buffer = newBuffer(byteBuffer)) { + for (int i = 0; i < bufferSize - Long.BYTES; i++) { + @SuppressWarnings("UnnecessaryLocalVariable") + long value = i; + byteBuffer.putLong(i, value); + assertEquals(value, buffer.getLong(i), "Long values should be same."); + } + + for (int i = 0; i < bufferSize - Long.BYTES; i++) { + long value = i + bufferSize; + buffer.setLong(i, value); + assertEquals(value, byteBuffer.getLong(i), "Long values should be same."); + } + } + } + + /** Test get and set on float. */ + @Test + public void testGetAndSetFloat() { + final int bufferSize = 1024; + ByteBuffer byteBuffer = newByteBuffer(bufferSize); + try (DatabricksArrowBuf buffer = newBuffer(byteBuffer)) { + for (int i = 0; i < bufferSize - Float.BYTES; i++) { + float value = (float) (i * Math.PI); + byteBuffer.putFloat(i, value); + assertEquals(value, buffer.getFloat(i), "Float values should be same."); + } + + for (int i = 0; i < bufferSize - Float.BYTES; i++) { + float value = (float) ((i + bufferSize) * Math.PI); + buffer.setFloat(i, value); + assertEquals(value, byteBuffer.getFloat(i), "Float values should be same."); + } + } + } + + /** Test get and set on double */ + @Test + public void testGetAndSetDouble() { + final int bufferSize = 1024; + ByteBuffer byteBuffer = newByteBuffer(bufferSize); + try (DatabricksArrowBuf buffer = newBuffer(byteBuffer)) { + for (int i = 0; i < bufferSize - Double.BYTES; i++) { + double value = i * Math.PI; + byteBuffer.putDouble(i, value); + assertEquals(value, buffer.getDouble(i), "Double values should be same."); + } + + for (int i = 0; i < bufferSize - Double.BYTES; i++) { + double value = (i + bufferSize) * Math.PI; + buffer.setDouble(i, value); + assertEquals(value, byteBuffer.getDouble(i), "Double values should be same."); + } + } + } + + /** Test get and set on char */ + @Test + public 
void testGetAndSetChar() { + final int bufferSize = 1024; + ByteBuffer byteBuffer = newByteBuffer(bufferSize); + try (DatabricksArrowBuf buffer = newBuffer(byteBuffer)) { + for (int i = 0; i < bufferSize - Character.BYTES; i++) { + char value = (char) getByteValue(i); + byteBuffer.putChar(i, value); + assertEquals(value, buffer.getChar(i), "Character values should be same."); + } + + for (int i = 0; i < bufferSize - Character.BYTES; i++) { + char value = (char) getByteValue(i + bufferSize); + buffer.setChar(i, value); + assertEquals(value, byteBuffer.getChar(i), "Character values should be same."); + } + } + } + + /** Test get and set on int */ + @Test + public void testGetAndSetInt() { + final int bufferSize = 1024; + ByteBuffer byteBuffer = newByteBuffer(bufferSize); + try (DatabricksArrowBuf buffer = newBuffer(byteBuffer)) { + for (int i = 0; i < bufferSize - Integer.BYTES; i++) { + @SuppressWarnings("UnnecessaryLocalVariable") + int value = i; + byteBuffer.putInt(i, value); + assertEquals(value, buffer.getInt(i), "Integer values should be same."); + } + + for (int i = 0; i < bufferSize - Integer.BYTES; i++) { + int value = i + bufferSize; + buffer.setInt(i, value); + assertEquals(value, byteBuffer.getInt(i), "Integer values should be same."); + } + } + } + + /** Test get and set on short */ + @Test + public void testGetAndSetShort() { + final int bufferSize = 1024; + ByteBuffer byteBuffer = newByteBuffer(bufferSize); + try (DatabricksArrowBuf buffer = newBuffer(byteBuffer)) { + for (int i = 0; i < bufferSize - Short.BYTES; i++) { + short value = (short) i; + byteBuffer.putShort(i, value); + assertEquals(value, buffer.getShort(i), "Short values should be same."); + } + + for (int i = 0; i < bufferSize - Short.BYTES; i++) { + short value = (short) (i + bufferSize); + buffer.setShort(i, value); + assertEquals(value, byteBuffer.getShort(i), "Short values should be same."); + } + } + } + + /** Test get and set on byte */ + @Test + public void testGetAndSetByte() { + 
final int bufferSize = 1024; + ByteBuffer byteBuffer = newByteBuffer(bufferSize); + try (DatabricksArrowBuf buffer = newBuffer(byteBuffer)) { + for (int i = 0; i < bufferSize - Byte.BYTES; i++) { + byte value = getByteValue(i); + byteBuffer.put(i, value); + assertEquals(value, buffer.getByte(i), "Byte values should be same."); + } + + for (int i = 0; i < bufferSize - Byte.BYTES; i++) { + byte value = getByteValue(i + bufferSize); + buffer.setByte(i, value); + assertEquals(value, byteBuffer.get(i), "Byte values should be same."); + } + } + } + + /** Test read byte and write byte. */ + @Test + public void testReadByteAndWriteByte() { + final int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + // Write bytes. + for (int i = 0; i < bufferSize - Byte.BYTES; i++) { + byte value = getByteValue(i); + if (i % 2 == 0) { + buffer.writeByte(value); + } else { + buffer.writeByte(i); + } + } + + // Read back the same bytes. + for (int i = 0; i < bufferSize - Byte.BYTES; i++) { + byte value = getByteValue(i); + assertEquals(value, buffer.readByte(), "Byte values should be same."); + } + } + } + + /** Test readBytes and writeBytes. */ + @Test + public void testWriteBytesAndReadBytes() { + final int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + for (int size = 1; size < bufferSize; size++) { + logger.info("Testing bytes of length {}", size); + // Fill the write buffer. + byte[] writeBytes = new byte[size]; + for (int i = 0; i < size; i++) { + writeBytes[i] = getByteValue(i); + } + + // Write data. + buffer.clear(); + for (int i = 0; i + writeBytes.length < bufferSize; i += writeBytes.length) { + buffer.writeBytes(writeBytes); + } + + // Read the same data and validate. 
+ for (int i = 0; i + writeBytes.length < bufferSize; i += writeBytes.length) { + byte[] readBytes = new byte[writeBytes.length]; + buffer.readBytes(readBytes); + assertArrayEquals(writeBytes, readBytes, "Byte values should be same."); + } + } + } + } + + /** Test write methods - writeShort, writeInt, writeLong, writeFloat, writeDouble. */ + @Test + public void testWriteOfNumbers() { + final int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + // Write random numbers. + Random random = new Random(); + List values = new ArrayList<>(); + for (int i = 0; i < bufferSize; /* incremented in loop */ ) { + int bytesAvailable = bufferSize - i; + if (bytesAvailable >= Double.BYTES) { + int rand = random.nextInt(5); + switch (rand) { + case 0: + short shortValue = (short) random.nextInt(Short.MAX_VALUE); + buffer.writeShort(shortValue); + values.add(shortValue); + i += Short.BYTES; + break; + case 1: + int intValue = random.nextInt(); + buffer.writeInt(intValue); + values.add(intValue); + i += Integer.BYTES; + break; + case 2: + long longValue = random.nextLong(); + buffer.writeLong(longValue); + values.add(longValue); + i += Long.BYTES; + break; + case 3: + float floatValue = random.nextFloat(); + buffer.writeFloat(floatValue); + values.add(floatValue); + i += Float.BYTES; + break; + case 4: + double doubleValue = random.nextDouble(); + buffer.writeDouble(doubleValue); + values.add(doubleValue); + i += Double.BYTES; + break; + default: + throw new IllegalArgumentException("Invalid random number " + rand); + } + } else { + for (int j = 0; j < bytesAvailable; j++) { + byte value = getByteValue(j); + buffer.writeByte(value); + values.add(value); + } + i += bytesAvailable; + } + } + + // Read and validate the numbers. 
+ int index = 0; + for (Object value : values) { + if (value instanceof Byte) { + assertEquals( + (Byte) value, buffer.getByte(index), "Byte values should be same at index " + index); + index += Byte.BYTES; + } else if (value instanceof Short) { + assertEquals( + (short) value, + buffer.getShort(index), + "Short values should be same at index " + index); + index += Short.BYTES; + } else if (value instanceof Integer) { + assertEquals( + (int) value, buffer.getInt(index), "Integer values should be same at index " + index); + index += Integer.BYTES; + } else if (value instanceof Long) { + assertEquals( + (long) value, buffer.getLong(index), "Long values should be same at index " + index); + index += Long.BYTES; + } else if (value instanceof Float) { + assertEquals( + (float) value, + buffer.getFloat(index), + "Float values should be same at index " + index); + index += Float.BYTES; + } else if (value instanceof Double) { + assertEquals( + (double) value, + buffer.getDouble(index), + "Double values should be same at index " + index); + index += Double.BYTES; + } else { + throw new IllegalArgumentException("Invalid value " + value + " at index " + index); + } + } + } + } + + /** Test get and set on native byte arrays. */ + @Test + public void testGetAndSetBytesOnNativeByteArrays() { + int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + for (int size = 1; size < bufferSize; size++) { + logger.info("Testing bytes of length {}", size); + // Fill the write buffer. + byte[] writeBytes = new byte[size]; + for (int i = 0; i < size; i++) { + writeBytes[i] = getByteValue(i); + } + + // Write data. + buffer.clear(); + for (int i = 0; i + writeBytes.length < bufferSize; i += writeBytes.length) { + buffer.setBytes(i, writeBytes); + } + + // Read the same data and validate. 
+ for (int i = 0; i + writeBytes.length < bufferSize; i += writeBytes.length) { + byte[] readBytes = new byte[writeBytes.length]; + buffer.getBytes(i, readBytes); + assertArrayEquals(writeBytes, readBytes, "Byte values should be same."); + } + } + } + } + + /** Test get and set on native byte arrays with index. */ + @Test + public void testGetAndSetBytesOnNativeByteArraysWithIndex() { + int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + for (int size = 1; size < bufferSize; size++) { + logger.info("Testing bytes of length {}", size); + // Fill the write buffer. + byte[] writeBytes = new byte[size]; + for (int i = 0; i < size; i++) { + writeBytes[i] = getByteValue(i); + } + + // Write data. + buffer.clear(); + for (int i = 0; i + writeBytes.length < bufferSize; i += writeBytes.length) { + // Write total=size data in stages of length 1, 2, 3, ... + int windex = 0; + int wlen = 1; + while (windex < writeBytes.length) { + int len = Math.min(writeBytes.length - windex, wlen - windex); + buffer.setBytes(i + windex, writeBytes, windex, len); + windex += len; + wlen++; + } + } + + // Read the same data and validate. + byte[] readBytes = new byte[writeBytes.length]; + for (int i = 0; i + readBytes.length < bufferSize; i += readBytes.length) { + // Read total=size data in stages of length 1, 2, 3, ... + int rindex = 0; + int rlen = 1; + while (rindex < readBytes.length) { + int len = Math.min(readBytes.length - rindex, rlen - rindex); + buffer.getBytes(i + rindex, readBytes, rindex, len); + rindex += len; + rlen++; + } + assertArrayEquals(writeBytes, readBytes, "Byte values should be same."); + } + } + } + } + + /** Test get and set on byte buffers. */ + @Test + public void testGetAndSetBytesOnByteBuffers() { + int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + for (int size = 1; size < bufferSize; size++) { + logger.info("Testing byte buffer of length {}", size); + + // Fill the write buffer. 
+ ByteBuffer writeByteBuffer = newByteBuffer(size); + for (int i = 0; i < size; i++) { + writeByteBuffer.put(getByteValue(i)); + } + + // Write data. + buffer.clear(); + writeByteBuffer.flip(); + for (int i = 0; + i + writeByteBuffer.capacity() < bufferSize; + i += writeByteBuffer.capacity()) { + writeByteBuffer.rewind(); + buffer.setBytes(i, writeByteBuffer); + } + + // Read the same data and validate. + for (int i = 0; + i + writeByteBuffer.capacity() < bufferSize; + i += writeByteBuffer.capacity()) { + ByteBuffer readByteBuffer = newByteBuffer(writeByteBuffer.capacity()); + buffer.getBytes(i, readByteBuffer); + assertArrayEquals( + writeByteBuffer.array(), readByteBuffer.array(), "Byte values should be same."); + } + } + } + } + + /** Test get and set on byte buffers with index. */ + @Test + public void testGetAndSetBytesOnByteBuffersWithIndex() { + int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + for (int size = 1; size < bufferSize; size++) { + logger.info("Testing bytes of length {}", size); + + // Fill the write buffer. + ByteBuffer writeByteBuffer = newByteBuffer(size); + for (int i = 0; i < size; i++) { + writeByteBuffer.put(getByteValue(i)); + } + + // Write data. + buffer.clear(); + writeByteBuffer.flip(); + for (int i = 0; + i + writeByteBuffer.capacity() < bufferSize; + i += writeByteBuffer.capacity()) { + // Write total=size data in stages of length 1, 2, 3, ... + int windex = 0; + int wlen = 1; + while (windex < writeByteBuffer.capacity()) { + int len = Math.min(writeByteBuffer.capacity() - windex, wlen - windex); + buffer.setBytes(i + windex, writeByteBuffer, windex, len); + windex += len; + wlen++; + } + } + + // Read the same data and validate. 
+ ByteBuffer readByteBuffer = newByteBuffer(writeByteBuffer.capacity()); + for (int i = 0; + i + readByteBuffer.capacity() < bufferSize; + i += readByteBuffer.capacity()) { + buffer.getBytes(i, readByteBuffer); + assertArrayEquals( + writeByteBuffer.array(), readByteBuffer.array(), "Byte values should be same."); + } + } + } + } + + /** Test get and set bytes on Arrow buffer. */ + @Test + public void testGetAndSetBytesOnArrowBuf() { + int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + for (int size = 1; size < bufferSize; size++) { + logger.info("Testing buffers of length {}", size); + + // Set data in write buffer. + DatabricksArrowBuf writeBuffer = newBuffer(size); + for (int i = 0; i < size; i++) { + writeBuffer.writeByte(getByteValue(i)); + } + + // Copy data to buffer. + buffer.clear(); + for (int i = 0; + i + writeBuffer.capacity() < bufferSize; + i += (int) writeBuffer.capacity()) { + writeBuffer.readerIndex(0); + buffer.setBytes(i, writeBuffer); + } + + // Read the same data and validate. + DatabricksArrowBuf readBuffer = newBuffer(size); + for (int i = 0; i + readBuffer.capacity() < bufferSize; i += (int) readBuffer.capacity()) { + readBuffer.clear(); + buffer.getBytes(i, readBuffer, 0, (int) readBuffer.capacity()); + + byte[] readBytes = new byte[(int) readBuffer.capacity()]; + readBuffer.getBytes(0, readBytes); + + byte[] writeBytes = new byte[readBytes.length]; + writeBuffer.getBytes(0, writeBytes); + + assertArrayEquals(writeBytes, readBytes, "Byte values should be same for size " + size); + } + } + } + } + + @Test + public void testGetAndSetBytesOnArrowBufWithIndex() { + int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + for (int size = 1; size < bufferSize; size++) { + logger.info("Testing buffers of length {}", size); + + // Set data in write buffer. 
+ DatabricksArrowBuf writeBuffer = newBuffer(size); + for (int i = 0; i < size; i++) { + writeBuffer.writeByte(getByteValue(i)); + } + + // Copy data to buffer. + buffer.clear(); + for (int i = 0; + i + writeBuffer.capacity() < bufferSize; + i += (int) writeBuffer.capacity()) { + writeBuffer.readerIndex(0); + // Write total=size data in stages of length 1, 2, 3, ... + int windex = 0; + int wlen = 1; + while (windex < writeBuffer.capacity()) { + int len = Math.min((int) writeBuffer.capacity() - windex, wlen - windex); + buffer.setBytes(i + windex, writeBuffer, windex, len); + windex += len; + wlen++; + } + } + + // Read the same data and validate. + DatabricksArrowBuf readBuffer = newBuffer(size); + for (int i = 0; i + readBuffer.capacity() < bufferSize; i += (int) readBuffer.capacity()) { + readBuffer.clear(); + buffer.getBytes(i, readBuffer, 0, (int) readBuffer.capacity()); + + byte[] readBytes = new byte[(int) readBuffer.capacity()]; + readBuffer.getBytes(0, readBytes); + + byte[] writeBytes = new byte[readBytes.length]; + writeBuffer.getBytes(0, writeBytes); + + assertArrayEquals(writeBytes, readBytes, "Byte values should be same for size " + size); + } + } + } + } + + /** Test get and set bytes on streams. */ + @Test + public void testGetAndSetBytesOnInputAndOutputStream() throws IOException { + int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + for (int size = 1; size < bufferSize; size++) { + logger.info("Testing streams of length {}", size); + // Fill the write buffer. + byte[] writeBytes = new byte[size]; + for (int i = 0; i < size; i++) { + writeBytes[i] = getByteValue(i); + } + + // Write data. + buffer.clear(); + for (int i = 0; i + writeBytes.length < bufferSize; i += writeBytes.length) { + buffer.setBytes(i, new ByteArrayInputStream(writeBytes), writeBytes.length); + } + + // Read the same data and validate. 
+ for (int i = 0; i + writeBytes.length < bufferSize; i += writeBytes.length) { + ByteArrayOutputStream readBytes = new ByteArrayOutputStream(writeBytes.length); + buffer.getBytes(i, readBytes, writeBytes.length); + assertArrayEquals(writeBytes, readBytes.toByteArray(), "Byte values should be same."); + } + } + } + } + + /** Test possible memory consumed. */ + @Test + public void testPossibleMemoryConsumed() { + int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + assertEquals( + buffer.capacity(), + buffer.getPossibleMemoryConsumed(), + "Memory consumed should be same for size " + buffer.capacity()); + } + } + + /** Test actual memory consumed. */ + @Test + public void testActualMemoryConsumed() { + int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + assertEquals( + buffer.capacity(), + buffer.getActualMemoryConsumed(), + "Memory consumed should be same for size " + buffer.capacity()); + } + } + + /** Test hex string does not throw exception. */ + @Test + public void testToHexString() { + int bufferSize = 1024; + DatabricksArrowBuf buffer = newBuffer(bufferSize); + + assertDoesNotThrow( + () -> buffer.toHexString(0, bufferSize), "To hex string should not throw exception."); + for (int i = 0; i < buffer.capacity(); i++) { + buffer.writeByte(getByteValue(i)); + assertDoesNotThrow( + () -> buffer.toHexString(0, bufferSize), "To hex string should not throw exception."); + } + + buffer.clear(); + assertDoesNotThrow( + () -> buffer.toHexString(0, bufferSize), "To hex string should not throw exception."); + + buffer.close(); + assertDoesNotThrow( + () -> buffer.toHexString(0, bufferSize), "To hex string should not throw exception."); + } + + /** Test print. 
*/ + @Test + public void testPrint() { + int bufferSize = 1024; + DatabricksArrowBuf buffer = newBuffer(bufferSize); + + testPrint(buffer); + for (int i = 0; i < buffer.capacity(); i++) { + buffer.writeByte(getByteValue(i)); + testPrint(buffer); + } + } + + private void testPrint(DatabricksArrowBuf buffer) { + for (int indent = 0; indent <= 8; indent++) { + StringBuilder sb = new StringBuilder(); + buffer.print(sb, indent); + assertTrue(sb.length() > 0, "Print failed"); + } + } + + /** Test reader and writer index. */ + @Test + public void testReaderAndWriterIndex() { + int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + for (int i = 0; i < buffer.capacity(); i++) { + buffer.writeByte(getByteValue(i)); + assertEquals(i + 1, buffer.writerIndex(), "writerIndex should be same"); + } + + for (int i = 0; i < buffer.capacity(); i++) { + buffer.readByte(); + assertEquals(i + 1, buffer.readerIndex(), "readerIndex should be same"); + } + } + } + + /** Test empty buffer (size 0) behavior. 
*/ + @Test + public void testEmptyBuffer() { + try (DatabricksArrowBuf buffer = newBuffer(0)) { + // Basic properties + assertEquals(0, buffer.capacity(), "Empty buffer should have 0 capacity"); + assertEquals(0, buffer.readableBytes(), "Empty buffer should have 0 readable bytes"); + assertEquals(0, buffer.writableBytes(), "Empty buffer should have 0 writable bytes"); + + // checkBytes with 0 length should succeed + assertDoesNotThrow(() -> buffer.checkBytes(0, 0), "checkBytes(0, 0) should not throw"); + + // nioBuffer should return empty buffer + ByteBuffer nioBuffer = buffer.nioBuffer(); + assertEquals(0, nioBuffer.remaining(), "NIO buffer should be empty"); + + // slice() should work + ArrowBuf slice = buffer.slice(); + assertEquals(0, slice.capacity(), "Slice of empty buffer should be empty"); + + // slice(0, 0) should work + ArrowBuf slice2 = buffer.slice(0, 0); + assertEquals(0, slice2.capacity(), "slice(0, 0) should return empty buffer"); + } + } + + /** Test set zero. */ + @Test + public void testSetZero() { + int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + for (int i = 0; i < buffer.capacity(); i++) { + buffer.writeByte(getByteValue(i)); + } + + buffer.clear(); + int index = 0; + int size = 1; + while (index < buffer.capacity()) { + int len = Math.min((int) buffer.capacity() - index, size); + buffer.setZero(index, len); + index += len; + size += 1; + } + + for (int i = 0; i < buffer.capacity(); i++) { + assertEquals(0, buffer.getByte(i), "Byte values should be same at index " + index); + } + } + } + + /** Test set zero. 
*/ + @Test + public void testSetOne() { + int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + for (int i = 0; i < buffer.capacity(); i++) { + buffer.writeByte(getByteValue(i)); + } + + buffer.clear(); + int index = 0; + int size = 1; + while (index < buffer.capacity()) { + int len = Math.min((int) buffer.capacity() - index, size); + //noinspection deprecation + buffer.setOne(index, len); + index += len; + size += 1; + } + + for (int i = 0; i < buffer.capacity(); i++) { + assertEquals( + (byte) 0xff, buffer.getByte(i), "Byte values should be same at index " + index); + } + } + } + + /** Test realloc. */ + @Test + public void testRealloc() { + int bufferSize = 1024; + DatabricksArrowBuf buffer = newBuffer(bufferSize); + + for (int size = 0; size < buffer.capacity(); size++) { + ArrowBuf realloced = buffer.reallocIfNeeded(size); + assertEquals(buffer, realloced, "Should be the same"); + } + + assertThrows( + UnsupportedOperationException.class, + () -> buffer.reallocIfNeeded(buffer.capacity() + 1), + "Realloc above capacity should fail."); + } + + /** Test clear. 
*/ + @Test + public void testClear() { + int bufferSize = 1024; + try (DatabricksArrowBuf buffer = newBuffer(bufferSize)) { + for (int i = 0; i < buffer.capacity(); i++) { + buffer.writeByte(getByteValue(i)); + } + for (int i = 0; i < buffer.capacity(); i++) { + buffer.readByte(); + } + + assertEquals(buffer.capacity(), buffer.writerIndex(), "Write index should match"); + assertEquals(buffer.capacity(), buffer.readerIndex(), "Read index should match"); + + buffer.clear(); + assertEquals(0, buffer.writerIndex(), "Write index should be zero"); + assertEquals(0, buffer.readerIndex(), "Write index should be zero"); + } + } + + @SuppressWarnings("SameParameterValue") + private ByteBuffer newByteBuffer(int size) { + ByteBuffer byteBuffer = ByteBuffer.allocate(size); + byteBuffer.order(BYTE_ORDER); + return byteBuffer; + } + + private DatabricksArrowBuf newBuffer(ByteBuffer byteBuffer) { + final int bufferSize = byteBuffer.capacity(); + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + DatabricksReferenceManager refManager = new DatabricksReferenceManager(allocator, bufferSize); + return new DatabricksArrowBuf(refManager, null, byteBuffer, 0, bufferSize); + } + + @SuppressWarnings("resource") + private DatabricksArrowBuf newBuffer(int size) { + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + return (DatabricksArrowBuf) allocator.buffer(size); + } + + private byte getByteValue(int index) { + return (byte) (index % 256); + } +} diff --git a/src/test/java/org/apache/arrow/memory/DatabricksArrowPatchBinaryStringTypesTest.java b/src/test/java/org/apache/arrow/memory/DatabricksArrowPatchBinaryStringTypesTest.java new file mode 100644 index 000000000..78cc3fdfe --- /dev/null +++ b/src/test/java/org/apache/arrow/memory/DatabricksArrowPatchBinaryStringTypesTest.java @@ -0,0 +1,413 @@ +package org.apache.arrow.memory; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static 
org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; +import org.apache.arrow.vector.FixedSizeBinaryVector; +import org.apache.arrow.vector.LargeVarBinaryVector; +import org.apache.arrow.vector.LargeVarCharVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ViewVarBinaryVector; +import org.apache.arrow.vector.ViewVarCharVector; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.condition.EnabledOnJre; +import org.junit.jupiter.api.condition.JRE; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +/** Test binary and string Arrow data types. */ +@Tag("Jvm17PlusAndArrowToNioReflectionDisabled") +@EnabledOnJre({JRE.JAVA_17, JRE.JAVA_21}) +public class DatabricksArrowPatchBinaryStringTypesTest + extends AbstractDatabricksArrowPatchTypesTest { + + /** Test read and write of binary types. */ + @ParameterizedTest + @MethodSource("getBufferAllocators") + public void testBinaryTypes( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testBinary = new TestBinaryTypes(); + byte[] data = writeData(testBinary, totalRows, writeAllocator); + readAndValidate(testBinary, data, readAllocator); + } + + /** Test read and write of UTF-8 string types. 
*/ + @ParameterizedTest + @MethodSource("getBufferAllocators") + public void testUtf8Types( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testUtf8 = new TestUtf8Types(); + byte[] data = writeData(testUtf8, totalRows, writeAllocator); + readAndValidate(testUtf8, data, readAllocator); + } + + /** Test read and write of UTF8 view types. */ + @ParameterizedTest + @MethodSource("getBufferAllocators") + public void testUtf8ViewTypes( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testUtf8View = new TestUtf8ViewTypes(); + byte[] data = writeData(testUtf8View, totalRows, writeAllocator); + readAndValidate(testUtf8View, data, readAllocator); + } + + /** Test read and write of binary view types. */ + @ParameterizedTest + @MethodSource("getBufferAllocators") + public void testBinaryViewTypes( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testBinaryView = new TestBinaryViewTypes(); + byte[] data = writeData(testBinaryView, totalRows, writeAllocator); + readAndValidate(testBinaryView, data, readAllocator); + } + + /** Test binary types */ + private class TestBinaryTypes implements DataTester { + private final Field fixedSizeBinaryField; + private final Field varBinaryField; + private final Field largeVarBinaryField; + private final Schema schema; + private final int FIXED_SIZE_BINARY_LENGTH = 16; + private final int VAR_BINARY_LENGTH = 32; + private final int LARGE_VAR_BINARY_LENGTH = 50; + + TestBinaryTypes() { + fixedSizeBinaryField = newFixedSizeBinaryField(); + varBinaryField = newVarBinaryField(); + largeVarBinaryField = newLargeVarBinaryField(); + schema = new Schema(Arrays.asList(fixedSizeBinaryField, varBinaryField, largeVarBinaryField)); + } + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot 
vectorSchemaRoot, int batchSize) { + FixedSizeBinaryVector fixedSizeBinaryVector = + (FixedSizeBinaryVector) vectorSchemaRoot.getVector(fixedSizeBinaryField.getName()); + VarBinaryVector varBinaryVector = + (VarBinaryVector) vectorSchemaRoot.getVector(varBinaryField.getName()); + LargeVarBinaryVector largeVarBinaryVector = + (LargeVarBinaryVector) vectorSchemaRoot.getVector(largeVarBinaryField.getName()); + + // Set fixed-size binary (16 bytes). + fixedSizeBinaryVector.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + fixedSizeBinaryVector.setNull(i); + } else { + fixedSizeBinaryVector.set(i, getFixedSizeBinary(i, FIXED_SIZE_BINARY_LENGTH)); + } + } + + // Set variable binary. + varBinaryVector.allocateNew(batchSize * 20L, batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + varBinaryVector.setNull(i); + } else { + varBinaryVector.set(i, getVarBinary(i, VAR_BINARY_LENGTH)); + } + } + + // Set large variable binary. + largeVarBinaryVector.clear(); + largeVarBinaryVector.allocateNew(batchSize * 50L, batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + largeVarBinaryVector.setNull(i); + } else { + largeVarBinaryVector.set(i, getLargeVarBinary(i, LARGE_VAR_BINARY_LENGTH)); + } + } + } + + @Override + public void validateData(VectorSchemaRoot vectorSchemaRoot) { + FixedSizeBinaryVector fixedSizeBinaryVector = + (FixedSizeBinaryVector) vectorSchemaRoot.getVector(fixedSizeBinaryField.getName()); + VarBinaryVector varBinaryVector = + (VarBinaryVector) vectorSchemaRoot.getVector(varBinaryField.getName()); + LargeVarBinaryVector largeVarBinaryVector = + (LargeVarBinaryVector) vectorSchemaRoot.getVector(largeVarBinaryField.getName()); + + int rowCount = vectorSchemaRoot.getRowCount(); + + for (int i = 0; i < rowCount; i++) { + // Validate fixed-size binary (16 bytes) + if (i % 2 == 0) { + assertTrue( + fixedSizeBinaryVector.isNull(i), "Fixed-size binary should be null at index " + i); + } else { + 
byte[] expected = getFixedSizeBinary(i, FIXED_SIZE_BINARY_LENGTH); + byte[] actual = fixedSizeBinaryVector.get(i); + assertNotNull(actual, "Fixed-size binary should not be null at index " + i); + assertArrayEquals(expected, actual, "Fixed-size binary mismatch at index " + i); + } + + // Validate variable binary + if (i % 2 == 0) { + assertTrue(varBinaryVector.isNull(i), "Variable binary should be null at index " + i); + } else { + byte[] expected = getVarBinary(i, VAR_BINARY_LENGTH); + byte[] actual = varBinaryVector.get(i); + assertNotNull(actual, "Variable binary should not be null at index " + i); + assertArrayEquals(expected, actual, "Variable binary mismatch at index " + i); + } + + // Validate large variable binary + if (i % 2 == 0) { + assertTrue( + largeVarBinaryVector.isNull(i), "Large variable binary should be null at index " + i); + } else { + byte[] expected = getLargeVarBinary(i, LARGE_VAR_BINARY_LENGTH); + byte[] actual = largeVarBinaryVector.get(i); + assertNotNull(actual, "Large variable binary should not be null at index " + i); + assertArrayEquals(expected, actual, "Large variable binary mismatch at index " + i); + } + } + } + } + + /** Test UTF-8 string types */ + private class TestUtf8Types implements DataTester { + private final Field utf8Field; + private final Field largeUtf8Field; + private final Schema schema; + + TestUtf8Types() { + utf8Field = newUtf8Field(); + largeUtf8Field = newLargeUtf8Field(); + schema = new Schema(Arrays.asList(utf8Field, largeUtf8Field)); + } + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize) { + VarCharVector utf8Vector = (VarCharVector) vectorSchemaRoot.getVector(utf8Field.getName()); + LargeVarCharVector largeUtf8Vector = + (LargeVarCharVector) vectorSchemaRoot.getVector(largeUtf8Field.getName()); + + // Set UTF-8 strings. 
+ utf8Vector.allocateNew(batchSize * 50L, batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + utf8Vector.setNull(i); + } else { + utf8Vector.set(i, getUtf8String(i).getBytes(StandardCharsets.UTF_8)); + } + } + + // Set large UTF-8 strings. + largeUtf8Vector.clear(); + largeUtf8Vector.allocateNew(batchSize * 100L, batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + largeUtf8Vector.setNull(i); + } else { + largeUtf8Vector.set(i, getLargeUtf8String(i).getBytes(StandardCharsets.UTF_8)); + } + } + } + + @Override + public void validateData(VectorSchemaRoot vectorSchemaRoot) { + VarCharVector utf8Vector = (VarCharVector) vectorSchemaRoot.getVector(utf8Field.getName()); + LargeVarCharVector largeUtf8Vector = + (LargeVarCharVector) vectorSchemaRoot.getVector(largeUtf8Field.getName()); + + int rowCount = vectorSchemaRoot.getRowCount(); + + for (int i = 0; i < rowCount; i++) { + // Validate UTF-8 string + if (i % 2 == 0) { + assertTrue(utf8Vector.isNull(i), "UTF-8 string should be null at index " + i); + } else { + String expected = getUtf8String(i); + byte[] actualBytes = utf8Vector.get(i); + assertNotNull(actualBytes, "UTF-8 string should not be null at index " + i); + String actual = new String(actualBytes, StandardCharsets.UTF_8); + assertEquals(expected, actual, "UTF-8 string mismatch at index " + i); + } + + // Validate large UTF-8 string + if (i % 2 == 0) { + assertTrue(largeUtf8Vector.isNull(i), "Large UTF-8 string should be null at index " + i); + } else { + String expected = getLargeUtf8String(i); + byte[] actualBytes = largeUtf8Vector.get(i); + assertNotNull(actualBytes, "Large UTF-8 string should not be null at index " + i); + String actual = new String(actualBytes, StandardCharsets.UTF_8); + assertEquals(expected, actual, "Large UTF-8 string mismatch at index " + i); + } + } + } + } + + /** Test UTF8 view types */ + private class TestUtf8ViewTypes implements DataTester { + private final Field utf8ViewField; + private 
final Schema schema; + + TestUtf8ViewTypes() { + utf8ViewField = newUtf8ViewField(); + schema = new Schema(Collections.singletonList(utf8ViewField)); + } + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize) { + ViewVarCharVector utf8ViewVector = + (ViewVarCharVector) vectorSchemaRoot.getVector(utf8ViewField.getName()); + + // Calculate the total bytes needed for the data buffer + long totalDataBytes = 0; + for (int i = 0; i < batchSize; i++) { + byte[] bytes = getUtf8ViewString(i).getBytes(StandardCharsets.UTF_8); + // Round to nearest power of 64. + totalDataBytes += (bytes.length + 63) & ~63; + } + + utf8ViewVector.allocateNew(totalDataBytes, batchSize); + + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + utf8ViewVector.setNull(i); + } else { + utf8ViewVector.set(i, getUtf8ViewString(i).getBytes(StandardCharsets.UTF_8)); + } + } + utf8ViewVector.setValueCount(batchSize); + } + + @Override + public void validateData(VectorSchemaRoot vectorSchemaRoot) { + ViewVarCharVector utf8ViewVector = + (ViewVarCharVector) vectorSchemaRoot.getVector(utf8ViewField.getName()); + int rowCount = vectorSchemaRoot.getRowCount(); + + for (int i = 0; i < rowCount; i++) { + if (i % 2 == 0) { + assertTrue(utf8ViewVector.isNull(i), "UTF8 view string should be null at index " + i); + } else { + String expected = getUtf8ViewString(i); + byte[] actualBytes = utf8ViewVector.get(i); + assertNotNull(actualBytes, "UTF8 view string should not be null at index " + i); + String actual = new String(actualBytes, StandardCharsets.UTF_8); + assertEquals(expected, actual, "UTF8 view string mismatch at index " + i); + } + } + } + + private Field newUtf8ViewField() { + return new Field("utf8-view-string", FieldType.nullable(new ArrowType.Utf8View()), null); + } + + private String getUtf8ViewString(int index) { + // Strings of length <= 12 are inlined. 
+ // See https://arrow.apache.org/docs/format/Columnar.html#variable-size-binary-view-layout + if (index % 3 == 0) { + return "short-" + index; + } else { + return "Utf8View-" + index + "-StringData"; + } + } + } + + /** Test binary view types */ + private class TestBinaryViewTypes implements DataTester { + private final Field binaryViewField; + private final Schema schema; + + TestBinaryViewTypes() { + binaryViewField = newBinaryViewField(); + schema = new Schema(Collections.singletonList(binaryViewField)); + } + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize) { + ViewVarBinaryVector binaryViewVector = + (ViewVarBinaryVector) vectorSchemaRoot.getVector(binaryViewField.getName()); + + // Calculate the total bytes needed for the data buffer + long totalDataBytes = 0; + for (int i = 0; i < batchSize; i++) { + byte[] bytes = getBinaryViewData(i); + // Round to nearest power of 64. + totalDataBytes += (bytes.length + 63) & ~63; + } + + binaryViewVector.allocateNew(totalDataBytes, batchSize); + + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + binaryViewVector.setNull(i); + } else { + binaryViewVector.set(i, getBinaryViewData(i)); + } + } + binaryViewVector.setValueCount(batchSize); + } + + @Override + public void validateData(VectorSchemaRoot vectorSchemaRoot) { + ViewVarBinaryVector binaryViewVector = + (ViewVarBinaryVector) vectorSchemaRoot.getVector(binaryViewField.getName()); + int rowCount = vectorSchemaRoot.getRowCount(); + + for (int i = 0; i < rowCount; i++) { + if (i % 2 == 0) { + assertTrue(binaryViewVector.isNull(i), "Binary view should be null at index " + i); + } else { + byte[] expected = getBinaryViewData(i); + byte[] actual = binaryViewVector.get(i); + assertNotNull(actual, "Binary view should not be null at index " + i); + assertArrayEquals(expected, actual, "Binary view mismatch at index " + i); + } + } + } + + private Field 
newBinaryViewField() { + return new Field("binary-view", FieldType.nullable(new ArrowType.BinaryView()), null); + } + + private byte[] getBinaryViewData(int index) { + int length = (index % 20) + 1; + byte[] data = new byte[length]; + for (int i = 0; i < length; i++) { + data[i] = (byte) ((index * 5 + i) % 256); + } + return data; + } + } +} diff --git a/src/test/java/org/apache/arrow/memory/DatabricksArrowPatchComplexTypesTest.java b/src/test/java/org/apache/arrow/memory/DatabricksArrowPatchComplexTypesTest.java new file mode 100644 index 000000000..4b3c0dd94 --- /dev/null +++ b/src/test/java/org/apache/arrow/memory/DatabricksArrowPatchComplexTypesTest.java @@ -0,0 +1,1959 @@ +package org.apache.arrow.memory; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Period; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.DurationVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.IntervalYearVector; +import org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.DenseUnionVector; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.LargeListVector; 
+import org.apache.arrow.vector.complex.LargeListViewVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.ListViewVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.RunEndEncodedVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.complex.impl.UnionLargeListWriter; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.IntervalUnit; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.UnionMode; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.condition.EnabledOnJre; +import org.junit.jupiter.api.condition.JRE; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +/** Test complex Arrow data types (lists, structs, maps, unions, dictionary, REE). */ +@Tag("Jvm17PlusAndArrowToNioReflectionDisabled") +@EnabledOnJre({JRE.JAVA_17, JRE.JAVA_21}) +public class DatabricksArrowPatchComplexTypesTest extends AbstractDatabricksArrowPatchTypesTest { + + /** Test read and write of list types. */ + @ParameterizedTest + @MethodSource("getBufferAllocators") + public void testListTypes( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testList = new TestListTypes(); + byte[] data = writeData(testList, totalRows, writeAllocator); + readAndValidate(testList, data, readAllocator); + } + + /** Test read and write of list view types. 
*/ + @ParameterizedTest + @MethodSource("getBufferAllocatorsSmallRows") + public void testListViewTypes( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testListView = new TestListViewTypes(); + byte[] data = writeData(testListView, totalRows, writeAllocator); + readAndValidate(testListView, data, readAllocator); + } + + /** Test read and write of large list view types. */ + @ParameterizedTest + @MethodSource("getBufferAllocatorsSmallRows") + public void testLargeListViewTypes( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testLargeListView = new TestLargeListViewTypes(); + byte[] data = writeData(testLargeListView, totalRows, writeAllocator); + readAndValidate(testLargeListView, data, readAllocator); + } + + /** Test read and write of fixed-size list types. */ + @ParameterizedTest + @MethodSource("getBufferAllocators") + public void testFixedSizeListTypes( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testFixedSizeList = new TestFixedSizeListTypes(); + byte[] data = writeData(testFixedSizeList, totalRows, writeAllocator); + readAndValidate(testFixedSizeList, data, readAllocator); + } + + /** Test read and write of struct types. */ + @ParameterizedTest + @MethodSource("getBufferAllocators") + public void testStructTypes( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testStruct = new TestStructTypes(); + byte[] data = writeData(testStruct, totalRows, writeAllocator); + readAndValidate(testStruct, data, readAllocator); + } + + /** Test read and write of map types. 
*/ + @ParameterizedTest + @MethodSource("getBufferAllocators") + public void testMapTypes( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testMap = new TestMapTypes(); + byte[] data = writeData(testMap, totalRows, writeAllocator); + readAndValidate(testMap, data, readAllocator); + } + + /** Test read and write of union types. */ + @ParameterizedTest + @MethodSource("getBufferAllocators") + public void testUnionTypes( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testUnion = new TestUnionTypes(); + byte[] data = writeData(testUnion, totalRows, writeAllocator); + readAndValidate(testUnion, data, readAllocator); + } + + /** Test read and write of dictionary types. */ + @ParameterizedTest + @MethodSource("getBufferAllocators") + public void testDictionaryTypes( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testDictionary = new TestDictionaryTypes(); + byte[] data = writeData(testDictionary, totalRows, writeAllocator); + readAndValidate(testDictionary, data, readAllocator); + } + + /** Test read and write of run-end encoded types. 
*/ + @ParameterizedTest + @MethodSource("getBufferAllocators") + public void testRunEndEncodedTypes( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testRunEndEncoded = new TestRunEndEncodedTypes(); + byte[] data = writeData(testRunEndEncoded, totalRows, writeAllocator); + readAndValidate(testRunEndEncoded, data, readAllocator); + } + + /** Test list types */ + private class TestListTypes implements DataTester { + private final Field listIntField; + private final Field largeListIntField; + private final Schema schema; + + TestListTypes() { + listIntField = newListIntField(); + largeListIntField = newLargeListIntField(); + schema = new Schema(Arrays.asList(listIntField, largeListIntField)); + } + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize) { + ListVector listIntVector = (ListVector) vectorSchemaRoot.getVector(listIntField.getName()); + LargeListVector largeListIntVector = + (LargeListVector) vectorSchemaRoot.getVector(largeListIntField.getName()); + + // Set list of integers. + listIntVector.allocateNew(); + UnionListWriter listWriter = listIntVector.getWriter(); + for (int i = 0; i < batchSize; i++) { + listWriter.setPosition(i); + if (i % 2 == 0) { + // Null list + listWriter.writeNull(); + } else { + // Write list with varying number of elements + listWriter.startList(); + int listSize = getListSize(i); + for (int j = 0; j < listSize; j++) { + listWriter.integer().writeInt(getListElement(i, j)); + } + listWriter.endList(); + } + } + listWriter.setValueCount(batchSize); + + // Set large list of integers. 
+ largeListIntVector.allocateNew(); + UnionLargeListWriter largeListWriter = largeListIntVector.getWriter(); + for (int i = 0; i < batchSize; i++) { + largeListWriter.setPosition(i); + if (i % 2 == 0) { + // Null list + largeListWriter.writeNull(); + } else { + // Write list with varying number of elements + largeListWriter.startList(); + int listSize = getLargeListSize(i); + for (int j = 0; j < listSize; j++) { + largeListWriter.integer().writeInt(getLargeListElement(i, j)); + } + largeListWriter.endList(); + } + } + largeListWriter.setValueCount(batchSize); + } + + @Override + public void validateData(VectorSchemaRoot vectorSchemaRoot) { + ListVector listIntVector = (ListVector) vectorSchemaRoot.getVector(listIntField.getName()); + LargeListVector largeListIntVector = + (LargeListVector) vectorSchemaRoot.getVector(largeListIntField.getName()); + + int rowCount = vectorSchemaRoot.getRowCount(); + + for (int i = 0; i < rowCount; i++) { + // Validate list of integers + if (i % 2 == 0) { + assertTrue(listIntVector.isNull(i), "List should be null at index " + i); + } else { + Object listObj = listIntVector.getObject(i); + assertNotNull(listObj, "List should not be null at index " + i); + int expectedSize = getListSize(i); + @SuppressWarnings("unchecked") + java.util.List list = (java.util.List) listObj; + assertEquals(expectedSize, list.size(), "List size mismatch at index " + i); + for (int j = 0; j < expectedSize; j++) { + Integer element = list.get(j); + assertEquals( + getListElement(i, j), + element, + "List element mismatch at index " + i + "[" + j + "]"); + } + } + + // Validate large list of integers + if (i % 2 == 0) { + assertTrue(largeListIntVector.isNull(i), "Large list should be null at index " + i); + } else { + Object listObj = largeListIntVector.getObject(i); + assertNotNull(listObj, "Large list should not be null at index " + i); + int expectedSize = getLargeListSize(i); + @SuppressWarnings("unchecked") + java.util.List list = (java.util.List) 
listObj; + assertEquals(expectedSize, list.size(), "Large list size mismatch at index " + i); + for (int j = 0; j < expectedSize; j++) { + Integer element = list.get(j); + assertEquals( + getLargeListElement(i, j), + element, + "Large list element mismatch at index " + i + "[" + j + "]"); + } + } + } + } + } + + /** Test list view types */ + private class TestListViewTypes implements DataTester { + private final Field listViewField; + private final Schema schema; + + TestListViewTypes() { + listViewField = newListViewField(); + schema = new Schema(Collections.singletonList(listViewField)); + } + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize) { + ListViewVector listViewVector = + (ListViewVector) vectorSchemaRoot.getVector(listViewField.getName()); + listViewVector.allocateNew(); + + IntVector childVector = (IntVector) listViewVector.getDataVector(); + + // Calculate total child elements needed + int totalElements = 0; + for (int i = 0; i < batchSize; i++) { + if (i % 2 != 0) { + totalElements += getListViewSize(i); + } + } + childVector.allocateNew(totalElements); + + // Populate child vector with all data first + int childIndex = 0; + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + listViewVector.setNull(i); + } else { + int startOffset = listViewVector.startNewValue(i); + int listSize = getListViewSize(i); + for (int j = 0; j < listSize; j++) { + childVector.set(startOffset + j, getListViewElement(i, j)); + } + listViewVector.endValue(i, listSize); + } + } + listViewVector.setValueCount(batchSize); + childVector.setValueCount(childIndex); + } + + @Override + public void validateData(VectorSchemaRoot vectorSchemaRoot) { + ListViewVector listViewVector = + (ListViewVector) vectorSchemaRoot.getVector(listViewField.getName()); + int rowCount = vectorSchemaRoot.getRowCount(); + + for (int i = 0; i < rowCount; i++) { + if (i % 2 == 0) { + 
assertTrue(listViewVector.isNull(i), "List view should be null at index " + i); + } else { + assertFalse(listViewVector.isNull(i), "List view should not be null at index " + i); + List list = listViewVector.getObject(i); + assertNotNull(list, "List view should not be null at index " + i); + int expectedSize = getListViewSize(i); + assertEquals(expectedSize, list.size(), "List view size mismatch at index " + i); + for (int j = 0; j < expectedSize; j++) { + assertEquals( + getListViewElement(i, j), + list.get(j), + "List view element mismatch at index " + i + "[" + j + "]"); + } + } + } + } + + private Field newListViewField() { + return new Field( + "list-view-int", + FieldType.nullable(new ArrowType.ListView()), + Collections.singletonList( + new Field("$data$", FieldType.nullable(new ArrowType.Int(32, true)), null))); + } + + private int getListViewSize(int index) { + return (index % 5) + 1; + } + + private int getListViewElement(int rowIndex, int elementIndex) { + return rowIndex * 50 + elementIndex; + } + } + + /** Test large list view types */ + private class TestLargeListViewTypes implements DataTester { + private final Field largeListViewField; + private final Schema schema; + + TestLargeListViewTypes() { + largeListViewField = newLargeListViewField(); + schema = new Schema(Collections.singletonList(largeListViewField)); + } + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize) { + LargeListViewVector largeListViewVector = + (LargeListViewVector) vectorSchemaRoot.getVector(largeListViewField.getName()); + largeListViewVector.allocateNew(); + + IntVector childVector = (IntVector) largeListViewVector.getDataVector(); + + // Calculate total child elements needed + int totalElements = 0; + for (int i = 0; i < batchSize; i++) { + if (i % 2 != 0) { + totalElements += getLargeListViewSize(i); + } + } + childVector.allocateNew(totalElements); + + // Populate child 
vector with all data first + int childIndex = 0; + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + largeListViewVector.setNull(i); + } else { + long startOffset = largeListViewVector.startNewValue(i); + int listSize = getLargeListViewSize(i); + for (int j = 0; j < listSize; j++) { + childVector.set((int) startOffset + j, getLargeListViewElement(i, j)); + } + largeListViewVector.endValue(i, listSize); + } + } + largeListViewVector.setValueCount(batchSize); + childVector.setValueCount(childIndex); + } + + @Override + public void validateData(VectorSchemaRoot vectorSchemaRoot) { + LargeListViewVector largeListViewVector = + (LargeListViewVector) vectorSchemaRoot.getVector(largeListViewField.getName()); + int rowCount = vectorSchemaRoot.getRowCount(); + + for (int i = 0; i < rowCount; i++) { + if (i % 2 == 0) { + assertTrue(largeListViewVector.isNull(i), "Large list view should be null at index " + i); + } else { + assertFalse( + largeListViewVector.isNull(i), "Large list view should not be null at index " + i); + List list = largeListViewVector.getObject(i); + assertNotNull(list, "Large list view should not be null at index " + i); + int expectedSize = getLargeListViewSize(i); + assertEquals(expectedSize, list.size(), "Large list view size mismatch at index " + i); + for (int j = 0; j < expectedSize; j++) { + assertEquals( + getLargeListViewElement(i, j), + list.get(j), + "Large list view element mismatch at index " + i + "[" + j + "]"); + } + } + } + } + + private Field newLargeListViewField() { + return new Field( + "large-list-view-int", + FieldType.nullable(new ArrowType.LargeListView()), + Collections.singletonList( + new Field("$data$", FieldType.nullable(new ArrowType.Int(32, true)), null))); + } + + private int getLargeListViewSize(int index) { + return (index % 7) + 1; + } + + private int getLargeListViewElement(int rowIndex, int elementIndex) { + return rowIndex * 100 + elementIndex; + } + } + + /** Test fixed-size list types */ + private class 
TestFixedSizeListTypes implements DataTester { + private final Field fixedSizeListField; + private final Schema schema; + private final int LIST_SIZE = 3; + + TestFixedSizeListTypes() { + fixedSizeListField = newFixedSizeListField(); + schema = new Schema(Collections.singletonList(fixedSizeListField)); + } + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize) { + FixedSizeListVector fixedSizeListVector = + (FixedSizeListVector) vectorSchemaRoot.getVector(fixedSizeListField.getName()); + fixedSizeListVector.allocateNew(); + + IntVector childVector = (IntVector) fixedSizeListVector.getDataVector(); + childVector.allocateNew(batchSize * LIST_SIZE); + + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + fixedSizeListVector.setNull(i); + } else { + for (int j = 0; j < LIST_SIZE; j++) { + childVector.set(i * LIST_SIZE + j, i * 10 + j); + } + fixedSizeListVector.setNotNull(i); + } + } + fixedSizeListVector.setValueCount(batchSize); + } + + @Override + public void validateData(VectorSchemaRoot vectorSchemaRoot) { + FixedSizeListVector fixedSizeListVector = + (FixedSizeListVector) vectorSchemaRoot.getVector(fixedSizeListField.getName()); + int rowCount = vectorSchemaRoot.getRowCount(); + + for (int i = 0; i < rowCount; i++) { + if (i % 2 == 0) { + assertTrue(fixedSizeListVector.isNull(i), "Fixed-size list should be null at index " + i); + } else { + assertFalse( + fixedSizeListVector.isNull(i), "Fixed-size list should not be null at index " + i); + List list = fixedSizeListVector.getObject(i); + assertNotNull(list, "Fixed-size list should not be null at index " + i); + assertEquals(LIST_SIZE, list.size(), "Fixed-size list size mismatch at index " + i); + for (int j = 0; j < LIST_SIZE; j++) { + assertEquals( + i * 10 + j, + list.get(j), + "Fixed-size list element mismatch at index " + i + "[" + j + "]"); + } + } + } + } + + private Field newFixedSizeListField() { + 
return new Field( + "fixed-size-list", + FieldType.nullable(new ArrowType.FixedSizeList(LIST_SIZE)), + Collections.singletonList( + new Field("$data$", FieldType.nullable(new ArrowType.Int(32, true)), null))); + } + } + + /** Test struct types */ + private class TestStructTypes implements DataTester { + private final Field structField; + private final Schema schema; + private final int VAR_BINARY_MAX_LENGTH = 30; + + TestStructTypes() { + structField = newStructField(); + //noinspection ArraysAsListWithZeroOrOneArgument + schema = new Schema(Arrays.asList(structField)); + } + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize) { + StructVector structVector = (StructVector) vectorSchemaRoot.getVector(structField.getName()); + + // Allocate struct vector + structVector.allocateNew(); + + // Get child vectors + IntVector structIntVector = + structVector.addOrGet( + "s_int", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class); + BigIntVector structLongVector = + structVector.addOrGet( + "s_long", FieldType.nullable(new ArrowType.Int(64, true)), BigIntVector.class); + Float4Vector structFloatVector = + structVector.addOrGet( + "s_float", + FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), + Float4Vector.class); + Float8Vector structDoubleVector = + structVector.addOrGet( + "s_double", + FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + Float8Vector.class); + DecimalVector structDecimalVector = + structVector.addOrGet( + "s_decimal", + FieldType.nullable(new ArrowType.Decimal(16, 0, 128)), + DecimalVector.class); + DateDayVector structDateVector = + structVector.addOrGet( + "s_date", FieldType.nullable(new ArrowType.Date(DateUnit.DAY)), DateDayVector.class); + TimeStampMilliTZVector structTimestampVector = + structVector.addOrGet( + "s_timestamp", + FieldType.nullable(new 
ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC")), + TimeStampMilliTZVector.class); + DurationVector structDurationVector = + structVector.addOrGet( + "s_duration", + FieldType.nullable(new ArrowType.Duration(TimeUnit.SECOND)), + DurationVector.class); + IntervalYearVector structIntervalVector = + structVector.addOrGet( + "s_interval", + FieldType.nullable(new ArrowType.Interval(IntervalUnit.YEAR_MONTH)), + IntervalYearVector.class); + VarBinaryVector structBinaryVector = + structVector.addOrGet( + "s_binary", FieldType.nullable(new ArrowType.Binary()), VarBinaryVector.class); + VarCharVector structUtf8Vector = + structVector.addOrGet( + "s_utf8", FieldType.nullable(new ArrowType.Utf8()), VarCharVector.class); + ListVector structListVector = + structVector.addOrGet( + "s_list", FieldType.nullable(ArrowType.List.INSTANCE), ListVector.class); + + // Allocate child vectors + structIntVector.allocateNew(batchSize); + structLongVector.allocateNew(batchSize); + structFloatVector.allocateNew(batchSize); + structDoubleVector.allocateNew(batchSize); + structDecimalVector.allocateNew(batchSize); + structDateVector.allocateNew(batchSize); + structTimestampVector.allocateNew(batchSize); + structDurationVector.allocateNew(batchSize); + structIntervalVector.allocateNew(batchSize); + structBinaryVector.allocateNew(batchSize * 20L, batchSize); + structUtf8Vector.allocateNew(batchSize * 50L, batchSize); + structListVector.allocateNew(); + + // Set struct values + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + structVector.setNull(i); + } else { + structVector.setIndexDefined(i); + structIntVector.set(i, getSignedInt(i)); + structLongVector.set(i, getSignedLong(i)); + structFloatVector.set(i, getFloat(i)); + structDoubleVector.set(i, getDouble(i)); + structDecimalVector.set(i, getDecimal(i, structDecimalVector.getScale())); + structDateVector.set(i, getDateDay(i)); + structTimestampVector.set(i, getTimestampMilli(i)); + structDurationVector.set(i, getDurationSec(i)); + 
structIntervalVector.set(i, getIntervalYearMonth(i)); + structBinaryVector.set(i, getVarBinary(i, VAR_BINARY_MAX_LENGTH)); + structUtf8Vector.set(i, getUtf8String(i).getBytes(StandardCharsets.UTF_8)); + + // Set list in struct + UnionListWriter listWriter = structListVector.getWriter(); + listWriter.setPosition(i); + listWriter.startList(); + int listSize = getListSize(i); + for (int j = 0; j < listSize; j++) { + listWriter.integer().writeInt(getListElement(i, j)); + } + listWriter.endList(); + } + } + structListVector.getWriter().setValueCount(batchSize); + structVector.setValueCount(batchSize); + } + + @Override + public void validateData(VectorSchemaRoot vectorSchemaRoot) { + StructVector structVector = (StructVector) vectorSchemaRoot.getVector(structField.getName()); + + int rowCount = vectorSchemaRoot.getRowCount(); + + for (int i = 0; i < rowCount; i++) { + // Validate struct + if (i % 2 == 0) { + assertTrue(structVector.isNull(i), "Struct should be null at index " + i); + } else { + java.util.Map structMap = structVector.getObject(i); + assertNotNull(structMap, "Struct should not be null at index " + i); + + // Validate int + assertEquals( + getSignedInt(i), + ((Integer) structMap.get("s_int")).intValue(), + "Struct int mismatch at index " + i); + + // Validate long + assertEquals( + getSignedLong(i), + ((Long) structMap.get("s_long")).longValue(), + "Struct long mismatch at index " + i); + + // Validate float + assertEquals( + getFloat(i), + (Float) structMap.get("s_float"), + 0.0001f, + "Struct float mismatch at index " + i); + + // Validate double + assertEquals( + getDouble(i), + (Double) structMap.get("s_double"), + 0.0001, + "Struct double mismatch at index " + i); + + // Validate decimal + BigDecimal decimalValue = (BigDecimal) structMap.get("s_decimal"); + assertEquals( + getDecimal(i, decimalValue.scale()), + decimalValue, + "Struct decimal mismatch at index " + i); + + // Validate date + assertEquals( + getDateDay(i), + ((Integer) 
structMap.get("s_date")).intValue(), + "Struct date mismatch at index " + i); + + // Validate timestamp + assertEquals( + getTimestampMilli(i), + ((Long) structMap.get("s_timestamp")).longValue(), + "Struct timestamp mismatch at index " + i); + + // Validate duration + Duration durationValue = (Duration) structMap.get("s_duration"); + assertEquals( + getDurationSec(i), + durationValue.getSeconds(), + "Struct duration mismatch at index " + i); + + // Validate interval + assertEquals( + getIntervalYearMonth(i), + ((Period) structMap.get("s_interval")).getMonths(), + "Struct interval mismatch at index " + i); + + // Validate binary + byte[] binaryValue = (byte[]) structMap.get("s_binary"); + assertArrayEquals( + getVarBinary(i, VAR_BINARY_MAX_LENGTH), + binaryValue, + "Struct binary mismatch at index " + i); + + // Validate utf8 + String utf8Value = + new String(((Text) structMap.get("s_utf8")).getBytes(), StandardCharsets.UTF_8); + assertEquals(getUtf8String(i), utf8Value, "Struct utf8 mismatch at index " + i); + + // Validate list + @SuppressWarnings("unchecked") + List listValue = (List) structMap.get("s_list"); + int expectedSize = getListSize(i); + assertEquals(expectedSize, listValue.size(), "Struct list size mismatch at index " + i); + for (int j = 0; j < expectedSize; j++) { + assertEquals( + getListElement(i, j), + listValue.get(j), + "Struct list element mismatch at index " + i + "[" + j + "]"); + } + } + } + } + + private Field newStructField() { + return new Field( + "struct-all-types", + FieldType.nullable(ArrowType.Struct.INSTANCE), + Arrays.asList( + new Field("s_int", FieldType.nullable(new ArrowType.Int(32, true)), null), + new Field("s_long", FieldType.nullable(new ArrowType.Int(64, true)), null), + new Field( + "s_float", + FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), + null), + new Field( + "s_double", + FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + null), + new 
Field("s_decimal", FieldType.nullable(new ArrowType.Decimal(16, 0, 128)), null), + new Field( + "s_date", + FieldType.nullable( + new ArrowType.Date(org.apache.arrow.vector.types.DateUnit.DAY)), + null), + new Field( + "s_timestamp", + FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC")), + null), + new Field( + "s_duration", FieldType.nullable(new ArrowType.Duration(TimeUnit.SECOND)), null), + new Field( + "s_interval", + FieldType.nullable(new ArrowType.Interval(IntervalUnit.YEAR_MONTH)), + null), + new Field("s_binary", FieldType.nullable(new ArrowType.Binary()), null), + new Field("s_utf8", FieldType.nullable(new ArrowType.Utf8()), null), + new Field( + "s_list", + FieldType.nullable(ArrowType.List.INSTANCE), + java.util.Collections.singletonList( + new Field( + "$data$", FieldType.nullable(new ArrowType.Int(32, true)), null))))); + } + } + + /** Test map types with all value types from struct test */ + private class TestMapTypes implements DataTester { + private final Field mapStringIntField; + private final Field mapStringLongField; + private final Field mapStringFloatField; + private final Field mapStringDoubleField; + private final Field mapStringDecimalField; + private final Field mapStringDateField; + private final Field mapStringTimestampField; + private final Field mapStringDurationField; + private final Field mapStringIntervalField; + private final Field mapStringBinaryField; + private final Field mapStringUtf8Field; + private final Field mapStringListField; + private final Schema schema; + private final int VAR_BINARY_MAX_LENGTH = 32; + private final int DECIMAL_PRECISION = 16; + private final int DECIMAL_SCALE = 0; + + TestMapTypes() { + mapStringIntField = newMapStringIntField(); + mapStringLongField = newMapStringLongField(); + mapStringFloatField = newMapStringFloatField(); + mapStringDoubleField = newMapStringDoubleField(); + mapStringDecimalField = newMapStringDecimalField(); + mapStringDateField = newMapStringDateField(); + 
mapStringTimestampField = newMapStringTimestampField(); + mapStringDurationField = newMapStringDurationField(); + mapStringIntervalField = newMapStringIntervalField(); + mapStringBinaryField = newMapStringBinaryField(); + mapStringUtf8Field = newMapStringUtf8Field(); + mapStringListField = newMapStringListField(); + schema = + new Schema( + Arrays.asList( + mapStringIntField, + mapStringLongField, + mapStringFloatField, + mapStringDoubleField, + mapStringDecimalField, + mapStringDateField, + mapStringTimestampField, + mapStringDurationField, + mapStringIntervalField, + mapStringBinaryField, + mapStringUtf8Field, + mapStringListField)); + } + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize) { + writeMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringIntField.getName()), batchSize, "int"); + writeMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringLongField.getName()), batchSize, "long"); + writeMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringFloatField.getName()), + batchSize, + "float"); + writeMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringDoubleField.getName()), + batchSize, + "double"); + writeMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringDecimalField.getName()), + batchSize, + "decimal"); + writeMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringDateField.getName()), batchSize, "date"); + writeMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringTimestampField.getName()), + batchSize, + "timestamp"); + writeMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringDurationField.getName()), + batchSize, + "duration"); + writeMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringIntervalField.getName()), + batchSize, + "interval"); + writeMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringBinaryField.getName()), + batchSize, + "binary"); + writeMapVector( + 
(MapVector) vectorSchemaRoot.getVector(mapStringUtf8Field.getName()), batchSize, "utf8"); + writeMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringListField.getName()), batchSize, "list"); + } + + private void writeMapVector(MapVector mapVector, int batchSize, String valueType) { + mapVector.allocateNew(); + StructVector structVector = (StructVector) mapVector.getDataVector(); + + VarCharVector keyVector = + structVector.addOrGet( + "key", FieldType.notNullable(new ArrowType.Utf8()), VarCharVector.class); + + int maxMapEntries = (batchSize / 2) * 3 + 10; + keyVector.allocateNew((long) batchSize * 30, maxMapEntries); + + switch (valueType) { + case "int": + { + IntVector valueVector = + structVector.addOrGet( + "value", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class); + valueVector.allocateNew(maxMapEntries); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + mapVector.setNull(i); + } else { + int start = mapVector.startNewValue(i); + int mapSize = getMapSize(i); + for (int j = 0; j < mapSize; j++) { + structVector.setIndexDefined(start + j); + keyVector.set(start + j, getMapStringKey(i, j).getBytes(StandardCharsets.UTF_8)); + valueVector.set(start + j, getSignedInt(i * 10 + j)); + } + mapVector.endValue(i, mapSize); + } + } + break; + } + case "long": + { + BigIntVector valueVector = + structVector.addOrGet( + "value", FieldType.nullable(new ArrowType.Int(64, true)), BigIntVector.class); + valueVector.allocateNew(maxMapEntries); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + mapVector.setNull(i); + } else { + int start = mapVector.startNewValue(i); + int mapSize = getMapSize(i); + for (int j = 0; j < mapSize; j++) { + structVector.setIndexDefined(start + j); + keyVector.set(start + j, getMapStringKey(i, j).getBytes(StandardCharsets.UTF_8)); + valueVector.set(start + j, getSignedLong(i * 10 + j)); + } + mapVector.endValue(i, mapSize); + } + } + break; + } + case "float": + { + Float4Vector valueVector = + 
structVector.addOrGet( + "value", + FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), + Float4Vector.class); + valueVector.allocateNew(maxMapEntries); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + mapVector.setNull(i); + } else { + int start = mapVector.startNewValue(i); + int mapSize = getMapSize(i); + for (int j = 0; j < mapSize; j++) { + structVector.setIndexDefined(start + j); + keyVector.set(start + j, getMapStringKey(i, j).getBytes(StandardCharsets.UTF_8)); + valueVector.set(start + j, getFloat(i * 10 + j)); + } + mapVector.endValue(i, mapSize); + } + } + break; + } + case "double": + { + Float8Vector valueVector = + structVector.addOrGet( + "value", + FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + Float8Vector.class); + valueVector.allocateNew(maxMapEntries); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + mapVector.setNull(i); + } else { + int start = mapVector.startNewValue(i); + int mapSize = getMapSize(i); + for (int j = 0; j < mapSize; j++) { + structVector.setIndexDefined(start + j); + keyVector.set(start + j, getMapStringKey(i, j).getBytes(StandardCharsets.UTF_8)); + valueVector.set(start + j, getDouble(i * 10 + j)); + } + mapVector.endValue(i, mapSize); + } + } + break; + } + case "decimal": + { + DecimalVector valueVector = + structVector.addOrGet( + "value", + FieldType.nullable(new ArrowType.Decimal(16, 0, 128)), + DecimalVector.class); + valueVector.allocateNew(maxMapEntries); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + mapVector.setNull(i); + } else { + int start = mapVector.startNewValue(i); + int mapSize = getMapSize(i); + for (int j = 0; j < mapSize; j++) { + structVector.setIndexDefined(start + j); + keyVector.set(start + j, getMapStringKey(i, j).getBytes(StandardCharsets.UTF_8)); + valueVector.set(start + j, getDecimal(i * 10 + j, valueVector.getScale())); + } + mapVector.endValue(i, mapSize); + } + } + break; + } + case "date": 
+ { + DateDayVector valueVector = + structVector.addOrGet( + "value", + FieldType.nullable( + new ArrowType.Date(org.apache.arrow.vector.types.DateUnit.DAY)), + DateDayVector.class); + valueVector.allocateNew(maxMapEntries); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + mapVector.setNull(i); + } else { + int start = mapVector.startNewValue(i); + int mapSize = getMapSize(i); + for (int j = 0; j < mapSize; j++) { + structVector.setIndexDefined(start + j); + keyVector.set(start + j, getMapStringKey(i, j).getBytes(StandardCharsets.UTF_8)); + valueVector.set(start + j, getDateDay(i * 10 + j)); + } + mapVector.endValue(i, mapSize); + } + } + break; + } + case "timestamp": + { + TimeStampMilliTZVector valueVector = + structVector.addOrGet( + "value", + FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC")), + TimeStampMilliTZVector.class); + valueVector.allocateNew(maxMapEntries); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + mapVector.setNull(i); + } else { + int start = mapVector.startNewValue(i); + int mapSize = getMapSize(i); + for (int j = 0; j < mapSize; j++) { + structVector.setIndexDefined(start + j); + keyVector.set(start + j, getMapStringKey(i, j).getBytes(StandardCharsets.UTF_8)); + valueVector.set(start + j, getTimestampMilli(i * 10 + j)); + } + mapVector.endValue(i, mapSize); + } + } + break; + } + case "duration": + { + DurationVector valueVector = + structVector.addOrGet( + "value", + FieldType.nullable(new ArrowType.Duration(TimeUnit.SECOND)), + DurationVector.class); + valueVector.allocateNew(maxMapEntries); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + mapVector.setNull(i); + } else { + int start = mapVector.startNewValue(i); + int mapSize = getMapSize(i); + for (int j = 0; j < mapSize; j++) { + structVector.setIndexDefined(start + j); + keyVector.set(start + j, getMapStringKey(i, j).getBytes(StandardCharsets.UTF_8)); + valueVector.set(start + j, getDurationSec(i * 10 + j)); + } + 
mapVector.endValue(i, mapSize); + } + } + break; + } + case "interval": + { + IntervalYearVector valueVector = + structVector.addOrGet( + "value", + FieldType.nullable(new ArrowType.Interval(IntervalUnit.YEAR_MONTH)), + IntervalYearVector.class); + valueVector.allocateNew(maxMapEntries); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + mapVector.setNull(i); + } else { + int start = mapVector.startNewValue(i); + int mapSize = getMapSize(i); + for (int j = 0; j < mapSize; j++) { + structVector.setIndexDefined(start + j); + keyVector.set(start + j, getMapStringKey(i, j).getBytes(StandardCharsets.UTF_8)); + valueVector.set(start + j, getIntervalYearMonth(i * 10 + j)); + } + mapVector.endValue(i, mapSize); + } + } + break; + } + case "binary": + { + VarBinaryVector valueVector = + structVector.addOrGet( + "value", FieldType.nullable(new ArrowType.Binary()), VarBinaryVector.class); + valueVector.allocateNew((long) batchSize * 20, maxMapEntries); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + mapVector.setNull(i); + } else { + int start = mapVector.startNewValue(i); + int mapSize = getMapSize(i); + for (int j = 0; j < mapSize; j++) { + structVector.setIndexDefined(start + j); + keyVector.set(start + j, getMapStringKey(i, j).getBytes(StandardCharsets.UTF_8)); + valueVector.set(start + j, getVarBinary(i * 10 + j, VAR_BINARY_MAX_LENGTH)); + } + mapVector.endValue(i, mapSize); + } + } + break; + } + case "utf8": + { + VarCharVector valueVector = + structVector.addOrGet( + "value", FieldType.nullable(new ArrowType.Utf8()), VarCharVector.class); + valueVector.allocateNew((long) batchSize * 50, maxMapEntries); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + mapVector.setNull(i); + } else { + int start = mapVector.startNewValue(i); + int mapSize = getMapSize(i); + for (int j = 0; j < mapSize; j++) { + structVector.setIndexDefined(start + j); + keyVector.set(start + j, getMapStringKey(i, j).getBytes(StandardCharsets.UTF_8)); + 
valueVector.set( + start + j, getUtf8String(i * 10 + j).getBytes(StandardCharsets.UTF_8)); + } + mapVector.endValue(i, mapSize); + } + } + break; + } + case "list": + { + ListVector valueVector = + structVector.addOrGet( + "value", FieldType.nullable(ArrowType.List.INSTANCE), ListVector.class); + valueVector.allocateNew(); + UnionListWriter listWriter = valueVector.getWriter(); + + int entryIndex = 0; + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + mapVector.setNull(i); + } else { + int start = mapVector.startNewValue(i); + int mapSize = getMapSize(i); + for (int j = 0; j < mapSize; j++) { + structVector.setIndexDefined(start + j); + keyVector.set(start + j, getMapStringKey(i, j).getBytes(StandardCharsets.UTF_8)); + + listWriter.setPosition(start + j); + listWriter.startList(); + int listSize = getListSize(i * 10 + j); + for (int k = 0; k < listSize; k++) { + listWriter.integer().writeInt(getListElement(i * 10 + j, k)); + } + listWriter.endList(); + entryIndex++; + } + mapVector.endValue(i, mapSize); + } + } + listWriter.setValueCount(entryIndex); + break; + } + } + mapVector.setValueCount(batchSize); + } + + @Override + public void validateData(VectorSchemaRoot vectorSchemaRoot) { + validateMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringIntField.getName()), + vectorSchemaRoot.getRowCount(), + "int"); + validateMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringLongField.getName()), + vectorSchemaRoot.getRowCount(), + "long"); + validateMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringFloatField.getName()), + vectorSchemaRoot.getRowCount(), + "float"); + validateMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringDoubleField.getName()), + vectorSchemaRoot.getRowCount(), + "double"); + validateMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringDecimalField.getName()), + vectorSchemaRoot.getRowCount(), + "decimal"); + validateMapVector( + (MapVector) 
vectorSchemaRoot.getVector(mapStringDateField.getName()), + vectorSchemaRoot.getRowCount(), + "date"); + validateMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringTimestampField.getName()), + vectorSchemaRoot.getRowCount(), + "timestamp"); + validateMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringDurationField.getName()), + vectorSchemaRoot.getRowCount(), + "duration"); + validateMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringIntervalField.getName()), + vectorSchemaRoot.getRowCount(), + "interval"); + validateMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringBinaryField.getName()), + vectorSchemaRoot.getRowCount(), + "binary"); + validateMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringUtf8Field.getName()), + vectorSchemaRoot.getRowCount(), + "utf8"); + validateMapVector( + (MapVector) vectorSchemaRoot.getVector(mapStringListField.getName()), + vectorSchemaRoot.getRowCount(), + "list"); + } + + @SuppressWarnings("unchecked") + private void validateMapVector(MapVector mapVector, int rowCount, String valueType) { + for (int i = 0; i < rowCount; i++) { + if (i % 2 == 0) { + assertTrue( + mapVector.isNull(i), "Map should be null at index " + i); + } else { + List mapList = mapVector.getObject(i); + assertNotNull(mapList, "Map should not be null at index " + i); + int expectedSize = getMapSize(i); + assertEquals( + expectedSize, + mapList.size(), + "Map size mismatch at index " + i); + + for (int j = 0; j < expectedSize; j++) { + java.util.Map entry = (java.util.Map) mapList.get(j); + String expectedKey = getMapStringKey(i, j); + assertEquals( + expectedKey, + new String(((Text) entry.get("key")).getBytes(), StandardCharsets.UTF_8), + "Map key mismatch at index " + i + "[" + j + "]"); + + Object value = entry.get("value"); + switch (valueType) { + case "int": + assertEquals( + getSignedInt(i * 10 + j), + ((Integer) value).intValue(), + "Map int value mismatch at index " + i + "[" + j + "]"); + break; + case 
"long": + assertEquals( + getSignedLong(i * 10 + j), + ((Long) value).longValue(), + "Map long value mismatch at index " + i + "[" + j + "]"); + break; + case "float": + assertEquals( + getFloat(i * 10 + j), + (Float) value, + 0.0001f, + "Map float value mismatch at index " + i + "[" + j + "]"); + break; + case "double": + assertEquals( + getDouble(i * 10 + j), + (Double) value, + 0.0001, + "Map double value mismatch at index " + i + "[" + j + "]"); + break; + case "decimal": + assertEquals( + getDecimal(i * 10 + j, DECIMAL_SCALE), + value, + "Map decimal value mismatch at index " + i + "[" + j + "]"); + break; + case "date": + assertEquals( + getDateDay(i * 10 + j), + ((Integer) value).intValue(), + "Map date value mismatch at index " + i + "[" + j + "]"); + break; + case "timestamp": + assertEquals( + getTimestampMilli(i * 10 + j), + ((Long) value).longValue(), + "Map timestamp value mismatch at index " + i + "[" + j + "]"); + break; + case "duration": + assertEquals( + getDurationSec(i * 10 + j), + ((Duration) value).getSeconds(), + "Map duration value mismatch at index " + i + "[" + j + "]"); + break; + case "interval": + assertEquals( + getIntervalYearMonth(i * 10 + j), + ((Period) value).getMonths(), + "Map interval value mismatch at index " + i + "[" + j + "]"); + break; + case "binary": + assertArrayEquals( + getVarBinary(i * 10 + j, VAR_BINARY_MAX_LENGTH), + (byte[]) value, + "Map binary value mismatch at index " + i + "[" + j + "]"); + break; + case "utf8": + assertEquals( + getUtf8String(i * 10 + j), + new String(((Text) value).getBytes(), StandardCharsets.UTF_8), + "Map utf8 value mismatch at index " + i + "[" + j + "]"); + break; + case "list": + java.util.List listValue = (java.util.List) value; + int expectedListSize = getListSize(i * 10 + j); + assertEquals( + expectedListSize, + listValue.size(), + "Map list value size mismatch at index " + i + "[" + j + "]"); + for (int k = 0; k < expectedListSize; k++) { + assertEquals( + getListElement(i * 10 + 
j, k), + listValue.get(k), + "Map list element mismatch at index " + i + "[" + j + "][" + k + "]"); + } + break; + } + } + } + } + } + + private Field newMapStringIntField() { + return new Field( + "map-string-int", + FieldType.nullable(new ArrowType.Map(false)), + java.util.Collections.singletonList( + new Field( + "entries", + FieldType.notNullable(ArrowType.Struct.INSTANCE), + Arrays.asList( + new Field("key", FieldType.notNullable(new ArrowType.Utf8()), null), + new Field("value", FieldType.nullable(new ArrowType.Int(32, true)), null))))); + } + + private Field newMapStringLongField() { + return new Field( + "map-string-long", + FieldType.nullable(new ArrowType.Map(false)), + java.util.Collections.singletonList( + new Field( + "entries", + FieldType.notNullable(ArrowType.Struct.INSTANCE), + Arrays.asList( + new Field("key", FieldType.notNullable(new ArrowType.Utf8()), null), + new Field("value", FieldType.nullable(new ArrowType.Int(64, true)), null))))); + } + + private Field newMapStringFloatField() { + return new Field( + "map-string-float", + FieldType.nullable(new ArrowType.Map(false)), + java.util.Collections.singletonList( + new Field( + "entries", + FieldType.notNullable(ArrowType.Struct.INSTANCE), + Arrays.asList( + new Field("key", FieldType.notNullable(new ArrowType.Utf8()), null), + new Field( + "value", + FieldType.nullable( + new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), + null))))); + } + + private Field newMapStringDoubleField() { + return new Field( + "map-string-double", + FieldType.nullable(new ArrowType.Map(false)), + java.util.Collections.singletonList( + new Field( + "entries", + FieldType.notNullable(ArrowType.Struct.INSTANCE), + Arrays.asList( + new Field("key", FieldType.notNullable(new ArrowType.Utf8()), null), + new Field( + "value", + FieldType.nullable( + new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + null))))); + } + + private Field newMapStringDecimalField() { + return new Field( + 
"map-string-decimal", + FieldType.nullable(new ArrowType.Map(false)), + java.util.Collections.singletonList( + new Field( + "entries", + FieldType.notNullable(ArrowType.Struct.INSTANCE), + Arrays.asList( + new Field("key", FieldType.notNullable(new ArrowType.Utf8()), null), + new Field( + "value", + FieldType.nullable( + new ArrowType.Decimal(DECIMAL_PRECISION, DECIMAL_SCALE, 128)), + null))))); + } + + private Field newMapStringDateField() { + return new Field( + "map-string-date", + FieldType.nullable(new ArrowType.Map(false)), + java.util.Collections.singletonList( + new Field( + "entries", + FieldType.notNullable(ArrowType.Struct.INSTANCE), + Arrays.asList( + new Field("key", FieldType.notNullable(new ArrowType.Utf8()), null), + new Field( + "value", + FieldType.nullable( + new ArrowType.Date(org.apache.arrow.vector.types.DateUnit.DAY)), + null))))); + } + + private Field newMapStringTimestampField() { + return new Field( + "map-string-timestamp", + FieldType.nullable(new ArrowType.Map(false)), + java.util.Collections.singletonList( + new Field( + "entries", + FieldType.notNullable(ArrowType.Struct.INSTANCE), + Arrays.asList( + new Field("key", FieldType.notNullable(new ArrowType.Utf8()), null), + new Field( + "value", + FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC")), + null))))); + } + + private Field newMapStringDurationField() { + return new Field( + "map-string-duration", + FieldType.nullable(new ArrowType.Map(false)), + java.util.Collections.singletonList( + new Field( + "entries", + FieldType.notNullable(ArrowType.Struct.INSTANCE), + Arrays.asList( + new Field("key", FieldType.notNullable(new ArrowType.Utf8()), null), + new Field( + "value", + FieldType.nullable(new ArrowType.Duration(TimeUnit.SECOND)), + null))))); + } + + private Field newMapStringIntervalField() { + return new Field( + "map-string-interval", + FieldType.nullable(new ArrowType.Map(false)), + java.util.Collections.singletonList( + new Field( + "entries", + 
FieldType.notNullable(ArrowType.Struct.INSTANCE), + Arrays.asList( + new Field("key", FieldType.notNullable(new ArrowType.Utf8()), null), + new Field( + "value", + FieldType.nullable(new ArrowType.Interval(IntervalUnit.YEAR_MONTH)), + null))))); + } + + private Field newMapStringBinaryField() { + return new Field( + "map-string-binary", + FieldType.nullable(new ArrowType.Map(false)), + java.util.Collections.singletonList( + new Field( + "entries", + FieldType.notNullable(ArrowType.Struct.INSTANCE), + Arrays.asList( + new Field("key", FieldType.notNullable(new ArrowType.Utf8()), null), + new Field("value", FieldType.nullable(new ArrowType.Binary()), null))))); + } + + private Field newMapStringUtf8Field() { + return new Field( + "map-string-utf8", + FieldType.nullable(new ArrowType.Map(false)), + java.util.Collections.singletonList( + new Field( + "entries", + FieldType.notNullable(ArrowType.Struct.INSTANCE), + Arrays.asList( + new Field("key", FieldType.notNullable(new ArrowType.Utf8()), null), + new Field("value", FieldType.nullable(new ArrowType.Utf8()), null))))); + } + + private Field newMapStringListField() { + return new Field( + "map-string-list", + FieldType.nullable(new ArrowType.Map(false)), + java.util.Collections.singletonList( + new Field( + "entries", + FieldType.notNullable(ArrowType.Struct.INSTANCE), + Arrays.asList( + new Field("key", FieldType.notNullable(new ArrowType.Utf8()), null), + new Field( + "value", + FieldType.nullable(ArrowType.List.INSTANCE), + java.util.Collections.singletonList( + new Field( + "$data$", + FieldType.nullable(new ArrowType.Int(32, true)), + null))))))); + } + + private int getMapSize(int index) { + return (index % 3) + 1; + } + + private String getMapStringKey(int rowIndex, int entryIndex) { + return "key_" + rowIndex + "_" + entryIndex; + } + } + + private class TestUnionTypes implements DataTester { + private final Field unionField; + private final Schema schema; + private final int VAR_BINARY_LENGTH = 64; + private 
final int DECIMAL_PRECISION = 16; + private final int DECIMAL_SCALE = 0; + + TestUnionTypes() { + unionField = newUnionField(); + schema = new Schema(java.util.Collections.singletonList(unionField)); + } + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot root, int totalRows) { + DenseUnionVector unionVector = (DenseUnionVector) root.getVector(unionField.getName()); + + // Set Union - cycles through all types + unionVector.allocateNew(); + + // Get child vectors by type ID (0-10) + IntVector unionIntVector = (IntVector) unionVector.getVectorByType((byte) 0); + BigIntVector unionLongVector = (BigIntVector) unionVector.getVectorByType((byte) 1); + Float4Vector unionFloatVector = (Float4Vector) unionVector.getVectorByType((byte) 2); + Float8Vector unionDoubleVector = (Float8Vector) unionVector.getVectorByType((byte) 3); + DecimalVector unionDecimalVector = (DecimalVector) unionVector.getVectorByType((byte) 4); + DateDayVector unionDateVector = (DateDayVector) unionVector.getVectorByType((byte) 5); + TimeStampMilliTZVector unionTimestampVector = + (TimeStampMilliTZVector) unionVector.getVectorByType((byte) 6); + DurationVector unionDurationVector = (DurationVector) unionVector.getVectorByType((byte) 7); + IntervalYearVector unionIntervalVector = + (IntervalYearVector) unionVector.getVectorByType((byte) 8); + VarBinaryVector unionBinaryVector = (VarBinaryVector) unionVector.getVectorByType((byte) 9); + VarCharVector unionUtf8Vector = (VarCharVector) unionVector.getVectorByType((byte) 10); + + // Track offsets for each type (11 types: 0-10) + int[] typeOffsets = new int[11]; + + for (int i = 0; i < totalRows; i++) { + // Cycle through different types based on index mod 11 + int typeIndex = i % 11; + byte typeId = (byte) typeIndex; + + // Get the current offset for this type + int offset = typeOffsets[typeIndex]; + + // Set type ID and offset for this position + unionVector.setTypeId(i, typeId); + 
unionVector.setOffset(i, offset); + + switch (typeIndex) { + case 0: // Int + unionIntVector.setSafe(offset, getSignedInt(i)); + break; + case 1: // Long + unionLongVector.setSafe(offset, getSignedLong(i)); + break; + case 2: // Float + unionFloatVector.setSafe(offset, getFloat(i)); + break; + case 3: // Double + unionDoubleVector.setSafe(offset, getDouble(i)); + break; + case 4: // Decimal + unionDecimalVector.setSafe(offset, getDecimal(i, unionDecimalVector.getScale())); + break; + case 5: // Date + unionDateVector.setSafe(offset, getDateDay(i)); + break; + case 6: // Timestamp + unionTimestampVector.setSafe(offset, getTimestampMilli(i)); + break; + case 7: // Duration + unionDurationVector.setSafe(offset, getDurationSec(i)); + break; + case 8: // Interval Year-Month + unionIntervalVector.setSafe(offset, getIntervalYearMonth(i)); + break; + case 9: // Binary + unionBinaryVector.setSafe(offset, getVarBinary(i, VAR_BINARY_LENGTH)); + break; + case 10: // UTF8 + unionUtf8Vector.setSafe(offset, getUtf8String(i).getBytes(StandardCharsets.UTF_8)); + break; + } + + // Increment the offset for this type + typeOffsets[typeIndex]++; + } + unionVector.setValueCount(totalRows); + } + + @Override + public void validateData(VectorSchemaRoot root) { + DenseUnionVector unionVector = (DenseUnionVector) root.getVector(unionField.getName()); + int rowCount = root.getRowCount(); + + for (int i = 0; i < rowCount; i++) { + // Validate Union + assertFalse(unionVector.isNull(i), "Union should not be null at index " + i); + int typeIndex = i % 11; + Object unionValue = unionVector.getObject(i); + assertNotNull(unionValue, "Union value should not be null at index " + i); + + switch (typeIndex) { + case 0: // Int + assertEquals( + getSignedInt(i), + ((Integer) unionValue).intValue(), + "Union int mismatch at index " + i); + break; + case 1: // Long + assertEquals( + getSignedLong(i), + ((Long) unionValue).longValue(), + "Union long mismatch at index " + i); + break; + case 2: // Float + 
assertEquals( + getFloat(i), (Float) unionValue, 0.0001f, "Union float mismatch at index " + i); + break; + case 3: // Double + assertEquals( + getDouble(i), (Double) unionValue, 0.0001, "Union double mismatch at index " + i); + break; + case 4: // Decimal + assertEquals( + getDecimal(i, DECIMAL_SCALE), unionValue, "Union decimal mismatch at index " + i); + break; + case 5: // Date + assertEquals( + getDateDay(i), + ((Integer) unionValue).intValue(), + "Union date mismatch at index " + i); + break; + case 6: // Timestamp + assertEquals( + getTimestampMilli(i), + ((Long) unionValue).longValue(), + "Union timestamp mismatch at index " + i); + break; + case 7: // Duration + assertEquals( + getDurationSec(i), + ((java.time.Duration) unionValue).getSeconds(), + "Union duration mismatch at index " + i); + break; + case 8: // Interval + assertEquals( + getIntervalYearMonth(i), + ((Period) unionValue).getMonths(), + "Union interval mismatch at index " + i); + break; + case 9: // Binary + assertArrayEquals( + getVarBinary(i, VAR_BINARY_LENGTH), + (byte[]) unionValue, + "Union binary mismatch at index " + i); + break; + case 10: // UTF8 + assertEquals( + getUtf8String(i), + new String(((Text) unionValue).getBytes(), StandardCharsets.UTF_8), + "Union utf8 mismatch at index " + i); + break; + } + } + } + + private Field newUnionField() { + return new Field( + "union-all-types", + FieldType.nullable( + new ArrowType.Union(UnionMode.Dense, new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10})), + Arrays.asList( + new Field("u_int", FieldType.nullable(new ArrowType.Int(32, true)), null), + new Field("u_long", FieldType.nullable(new ArrowType.Int(64, true)), null), + new Field( + "u_float", + FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), + null), + new Field( + "u_double", + FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + null), + new Field( + "u_decimal", + FieldType.nullable(new ArrowType.Decimal(DECIMAL_PRECISION, 
DECIMAL_SCALE, 128)), + null), + new Field( + "u_date", + FieldType.nullable( + new ArrowType.Date(org.apache.arrow.vector.types.DateUnit.DAY)), + null), + new Field( + "u_timestamp", + FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC")), + null), + new Field( + "u_duration", FieldType.nullable(new ArrowType.Duration(TimeUnit.SECOND)), null), + new Field( + "u_interval", + FieldType.nullable(new ArrowType.Interval(IntervalUnit.YEAR_MONTH)), + null), + new Field("u_binary", FieldType.nullable(new ArrowType.Binary()), null), + new Field("u_utf8", FieldType.nullable(new ArrowType.Utf8()), null))); + } + } + + /** Test dictionary types */ + private class TestDictionaryTypes implements DataTester { + private final Field dictStringField; + private final Field dictIntField; + private final Schema schema; + + TestDictionaryTypes() { + dictStringField = newDictStringField(); + dictIntField = newDictIntField(); + schema = new Schema(Arrays.asList(dictStringField, dictIntField)); + } + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize) { + VarCharVector dictStringVector = + (VarCharVector) vectorSchemaRoot.getVector(dictStringField.getName()); + IntVector dictIntVector = (IntVector) vectorSchemaRoot.getVector(dictIntField.getName()); + + // Set dictionary-encoded strings (repeating pattern of values) + dictStringVector.allocateNew(batchSize * 20L, batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + dictStringVector.setNull(i); + } else { + dictStringVector.set(i, getDictString(i).getBytes(StandardCharsets.UTF_8)); + } + } + + // Set dictionary-encoded integers (repeating pattern of values) + dictIntVector.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + dictIntVector.setNull(i); + } else { + dictIntVector.set(i, getDictInt(i)); + } + } + } + + @Override + public void validateData(VectorSchemaRoot 
vectorSchemaRoot) { + VarCharVector dictStringVector = + (VarCharVector) vectorSchemaRoot.getVector(dictStringField.getName()); + IntVector dictIntVector = (IntVector) vectorSchemaRoot.getVector(dictIntField.getName()); + + int rowCount = vectorSchemaRoot.getRowCount(); + + for (int i = 0; i < rowCount; i++) { + // Validate dictionary string + if (i % 2 == 0) { + assertTrue(dictStringVector.isNull(i), "Dictionary string should be null at index " + i); + } else { + String expected = getDictString(i); + byte[] actualBytes = dictStringVector.get(i); + assertNotNull(actualBytes, "Dictionary string should not be null at index " + i); + String actual = new String(actualBytes, StandardCharsets.UTF_8); + assertEquals(expected, actual, "Dictionary string mismatch at index " + i); + } + + // Validate dictionary int + if (i % 2 == 0) { + assertTrue(dictIntVector.isNull(i), "Dictionary int should be null at index " + i); + } else { + assertEquals( + getDictInt(i), dictIntVector.get(i), "Dictionary int mismatch at index " + i); + } + } + } + } + + /** Test run-end encoded types */ + private class TestRunEndEncodedTypes implements DataTester { + private final Field reeIntField; + private final Field reeLongField; + private final Schema schema; + + TestRunEndEncodedTypes() { + reeIntField = newRunEndEncodedIntField(); + reeLongField = newRunEndEncodedLongField(); + schema = new Schema(Arrays.asList(reeIntField, reeLongField)); + } + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize) { + RunEndEncodedVector reeIntVector = + (RunEndEncodedVector) vectorSchemaRoot.getVector(reeIntField.getName()); + RunEndEncodedVector reeLongVector = + (RunEndEncodedVector) vectorSchemaRoot.getVector(reeLongField.getName()); + + // Set run-end encoded integers + // Calculate required capacity: runs every 3 rows means we need (batchSize + 2) / 3 runs + int reeIntCapacity = (batchSize + 2) / 3; + 
reeIntVector.allocateNew(); + IntVector reeIntRunEnds = (IntVector) reeIntVector.getRunEndsVector(); + IntVector reeIntValues = (IntVector) reeIntVector.getValuesVector(); + reeIntRunEnds.allocateNew(reeIntCapacity); + reeIntValues.allocateNew(reeIntCapacity); + + // Create runs with repeating values (every 3 rows have the same value) + int runCount = 0; + for (int i = 0; i < batchSize; i += 3) { + int runEnd = Math.min(i + 3, batchSize); + reeIntRunEnds.set(runCount, runEnd); + reeIntValues.set(runCount, getSignedInt(i)); + runCount++; + } + reeIntRunEnds.setValueCount(runCount); + reeIntValues.setValueCount(runCount); + reeIntVector.setValueCount(batchSize); + + // Set run-end encoded longs + // Calculate required capacity: runs every 5 rows means we need (batchSize + 4) / 5 runs + int reeLongCapacity = (batchSize + 4) / 5; + reeLongVector.allocateNew(); + IntVector reeLongRunEnds = (IntVector) reeLongVector.getRunEndsVector(); + BigIntVector reeLongValues = (BigIntVector) reeLongVector.getValuesVector(); + reeLongRunEnds.allocateNew(reeLongCapacity); + reeLongValues.allocateNew(reeLongCapacity); + + // Create runs with repeating values (every 5 rows have the same value) + runCount = 0; + for (int i = 0; i < batchSize; i += 5) { + int runEnd = Math.min(i + 5, batchSize); + reeLongRunEnds.set(runCount, runEnd); + reeLongValues.set(runCount, getSignedLong(i)); + runCount++; + } + reeLongRunEnds.setValueCount(runCount); + reeLongValues.setValueCount(runCount); + reeLongVector.setValueCount(batchSize); + } + + @Override + public void validateData(VectorSchemaRoot vectorSchemaRoot) { + RunEndEncodedVector reeIntVector = + (RunEndEncodedVector) vectorSchemaRoot.getVector(reeIntField.getName()); + RunEndEncodedVector reeLongVector = + (RunEndEncodedVector) vectorSchemaRoot.getVector(reeLongField.getName()); + + int rowCount = vectorSchemaRoot.getRowCount(); + + for (int i = 0; i < rowCount; i++) { + // Validate run-end encoded int + Object reeIntValue = 
reeIntVector.getObject(i); + assertNotNull(reeIntValue, "REE int should not be null at index " + i); + int expectedInt = getSignedInt((i / 3) * 3); + assertEquals( + expectedInt, ((Integer) reeIntValue).intValue(), "REE int mismatch at index " + i); + + // Validate run-end encoded long + Object reeLongValue = reeLongVector.getObject(i); + assertNotNull(reeLongValue, "REE long should not be null at index " + i); + long expectedLong = getSignedLong((i / 5) * 5); + assertEquals( + expectedLong, ((Long) reeLongValue).longValue(), "REE long mismatch at index " + i); + } + } + + private Field newRunEndEncodedIntField() { + Field runEndField = + new Field("run_ends", FieldType.notNullable(new ArrowType.Int(32, true)), null); + Field valueField = new Field("values", FieldType.nullable(new ArrowType.Int(32, true)), null); + return new Field( + "ree-int", + FieldType.nullable(new ArrowType.RunEndEncoded()), + Arrays.asList(runEndField, valueField)); + } + + private Field newRunEndEncodedLongField() { + Field runEndField = + new Field("run_ends", FieldType.notNullable(new ArrowType.Int(32, true)), null); + Field valueField = new Field("values", FieldType.nullable(new ArrowType.Int(64, true)), null); + return new Field( + "ree-long", + FieldType.nullable(new ArrowType.RunEndEncoded()), + Arrays.asList(runEndField, valueField)); + } + } +} diff --git a/src/test/java/org/apache/arrow/memory/DatabricksArrowPatchMemoryUsageTest.java b/src/test/java/org/apache/arrow/memory/DatabricksArrowPatchMemoryUsageTest.java new file mode 100644 index 000000000..9655f6413 --- /dev/null +++ b/src/test/java/org/apache/arrow/memory/DatabricksArrowPatchMemoryUsageTest.java @@ -0,0 +1,113 @@ +package org.apache.arrow.memory; + +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.List; +import java.util.Random; +import java.util.stream.Collectors; +import 
org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.util.TransferPair; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Test the patched allocator does not put the JVM into GC pressure and cause it to OOM + * (OutOfMemoryError). + */ +public class DatabricksArrowPatchMemoryUsageTest { + /** Path to an arrow chunk. */ + private static final Path ARROW_CHUNK_PATH = Path.of("arrow", "chunk_all_types.arrow"); + + private static final Logger logger = + LoggerFactory.getLogger(DatabricksArrowPatchMemoryUsageTest.class); + + private interface BufferAllocatorFactory { + BufferAllocator create(); + } + + /** + * Repeatedly parse an Arrow stream file with low JVM memory -Xmx100m and verify no OOM occurs. + */ + @Test + public void testMemoryUsageOfDatabricksBufferAllocator() throws Exception { + logger.info("Testing memory usage of DatabricksBufferAllocator"); + testMemoryUsageOfBufferAllocator(DatabricksBufferAllocator::new); + } + + public void testMemoryUsageOfBufferAllocator(BufferAllocatorFactory factory) throws Exception { + for (int i = 0; i < 1000; i++) { + try (BufferAllocator allocator = factory.create()) { + long recordCount = parseArrowStream(ARROW_CHUNK_PATH, allocator); + if (i % 100 == 0) { + logger.info("Iteration {}: Parsed {} records.", i, recordCount); + } + } + } + } + + /** + * Parse the Arrow stream file stored at {@code filePath}, access every value, and return the + * record count. + */ + private long parseArrowStream(Path filePath, BufferAllocator allocator) throws IOException { + long recordCount = 0; + + Random random = new Random(); + try (InputStream arrowStream = getStream(filePath); + ArrowStreamReader reader = new ArrowStreamReader(arrowStream, allocator)) { + // Iterate over batches. 
+ while (reader.loadNextBatch()) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + + // Transfer all vectors. + List valueVectors = + root.getFieldVectors().stream() + .map( + fieldVector -> { + TransferPair transferPair = fieldVector.getTransferPair(allocator); + transferPair.transfer(); + return transferPair.getTo(); + }) + .collect(Collectors.toList()); + + // Access each value without retaining references to avoid heap pressure. + try { + for (int recordIndex = 0; recordIndex < root.getRowCount(); recordIndex++) { + // Add logging side effects to prevent JVM from optimizing out this code path. + HashMap record = new HashMap<>(); + for (ValueVector valueVector : valueVectors) { + record.put(valueVector.getField().getName(), valueVector.getObject(recordIndex)); + } + if (random.nextInt(10_000) < 2) { + logger.trace("Read record with {} keys", record.size()); + } + + recordCount++; + } + } finally { + // Close all transferred vectors to prevent memory leak + valueVectors.forEach(ValueVector::close); + } + } + } + + return recordCount; + } + + /** + * @return an input stream for the filePath. 
+ */ + private InputStream getStream(Path filePath) throws IOException { + InputStream arrowStream = + this.getClass().getClassLoader().getResourceAsStream(filePath.toString()); + assertNotNull(arrowStream, filePath + " not found"); + return arrowStream; + } +} diff --git a/src/test/java/org/apache/arrow/memory/DatabricksArrowPatchNumericTypesTest.java b/src/test/java/org/apache/arrow/memory/DatabricksArrowPatchNumericTypesTest.java new file mode 100644 index 000000000..d05896d05 --- /dev/null +++ b/src/test/java/org/apache/arrow/memory/DatabricksArrowPatchNumericTypesTest.java @@ -0,0 +1,644 @@ +package org.apache.arrow.memory; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.Collections; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.Decimal256Vector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.NullVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.Tag; +import 
org.junit.jupiter.api.condition.EnabledOnJre; +import org.junit.jupiter.api.condition.JRE; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +/** Test numeric Arrow data types (integers, floats, decimals, booleans, nulls). */ +@Tag("Jvm17PlusAndArrowToNioReflectionDisabled") +@EnabledOnJre({JRE.JAVA_17, JRE.JAVA_21}) +public class DatabricksArrowPatchNumericTypesTest extends AbstractDatabricksArrowPatchTypesTest { + + /** Test read and write of integer types. */ + @ParameterizedTest + @MethodSource("getBufferAllocators") + public void testIntegerTypes( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testFloats = new TestIntegers(); + byte[] data = writeData(testFloats, totalRows, writeAllocator); + readAndValidate(testFloats, data, readAllocator); + } + + /** Test read and write of float types. */ + @ParameterizedTest + @MethodSource("getBufferAllocators") + public void testFloatTypes( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testFloats = new TestFloats(); + byte[] data = writeData(testFloats, totalRows, writeAllocator); + readAndValidate(testFloats, data, readAllocator); + } + + /** Test read and write of decimal types . */ + @ParameterizedTest + @MethodSource("getBufferAllocators") + public void testDecimalTypes( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testFloats = new TestDecimal(); + byte[] data = writeData(testFloats, totalRows, writeAllocator); + readAndValidate(testFloats, data, readAllocator); + } + + /** Test read and write of decimal256 types. 
*/ + @ParameterizedTest + @MethodSource("getBufferAllocators") + public void testDecimal256Types( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testDecimal256 = new TestDecimal256(); + byte[] data = writeData(testDecimal256, totalRows, writeAllocator); + readAndValidate(testDecimal256, data, readAllocator); + } + + /** Test read and write of boolean types. */ + @ParameterizedTest + @MethodSource("getBufferAllocators") + public void testBoolTypes( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testBool = new TestBoolTypes(); + byte[] data = writeData(testBool, totalRows, writeAllocator); + readAndValidate(testBool, data, readAllocator); + } + + /** Test read and write of null types. */ + @ParameterizedTest + @MethodSource("getBufferAllocators") + public void testNullTypes( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testNull = new TestNullTypes(); + byte[] data = writeData(testNull, totalRows, writeAllocator); + readAndValidate(testNull, data, readAllocator); + } + + /** Test integers */ + private class TestIntegers implements DataTester { + private final Field signedByteField = newSignedByteIntField(); + private final Field signedShortField = newSignedShortIntField(); + private final Field signedIntField = newSignedIntField(); + private final Field signedLongField = newSignedLongField(); + private final Field unsignedByteField = newUnsignedByteIntField(); + private final Field unsignedShortField = newUnsignedShortIntField(); + private final Field unsignedIntField = newUnsignedIntField(); + private final Field unsignedLongField = newUnsignedLongField(); + private final Schema schema = + new Schema( + Arrays.asList( + signedByteField, + signedShortField, + signedIntField, + signedLongField, + unsignedByteField, + unsignedShortField, + unsignedIntField, + 
unsignedLongField)); + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize) { + TinyIntVector signedByteInt = + (TinyIntVector) vectorSchemaRoot.getVector(signedByteField.getName()); + SmallIntVector signedShortInt = + (SmallIntVector) vectorSchemaRoot.getVector(signedShortField.getName()); + IntVector signedInt = (IntVector) vectorSchemaRoot.getVector(signedIntField.getName()); + BigIntVector signedLong = + (BigIntVector) vectorSchemaRoot.getVector(signedLongField.getName()); + UInt1Vector unsignedByteInt = + (UInt1Vector) vectorSchemaRoot.getVector(unsignedByteField.getName()); + UInt2Vector unsignedShortInt = + (UInt2Vector) vectorSchemaRoot.getVector(unsignedShortField.getName()); + UInt4Vector unsignedInt = + (UInt4Vector) vectorSchemaRoot.getVector(unsignedIntField.getName()); + UInt8Vector unsignedLong = + (UInt8Vector) vectorSchemaRoot.getVector(unsignedLongField.getName()); + // Set signed bytes. + signedByteInt.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + signedByteInt.setNull(i); + } else { + signedByteInt.set(i, getSignedByte(i)); + } + } + + // Set signed shorts. + signedShortInt.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + signedShortInt.setNull(i); + } else { + signedShortInt.set(i, getSignedShort(i)); + } + } + + // Set signed ints. + signedInt.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + signedInt.setNull(i); + } else { + signedInt.set(i, getSignedInt(i)); + } + } + + // Set signed longs. + signedLong.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + signedLong.setNull(i); + } else { + signedLong.set(i, getSignedLong(i)); + } + } + + // Set unsigned bytes. 
+ unsignedByteInt.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + unsignedByteInt.setNull(i); + } else { + unsignedByteInt.set(i, getUnsignedByte(i)); + } + } + + // Set unsigned shorts. + unsignedShortInt.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + unsignedShortInt.setNull(i); + } else { + unsignedShortInt.set(i, getUnsignedShort(i)); + } + } + + // Set unsigned ints. + unsignedInt.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + unsignedInt.setNull(i); + } else { + unsignedInt.set(i, getUnsignedInt(i)); + } + } + + // Set unsigned longs. + unsignedLong.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + unsignedLong.setNull(i); + } else { + unsignedLong.set(i, getUnsignedLong(i)); + } + } + } + + @Override + public void validateData(VectorSchemaRoot vectorSchemaRoot) { + TinyIntVector signedByteInt = + (TinyIntVector) vectorSchemaRoot.getVector(signedByteField.getName()); + SmallIntVector signedShortInt = + (SmallIntVector) vectorSchemaRoot.getVector(signedShortField.getName()); + IntVector signedInt = (IntVector) vectorSchemaRoot.getVector(signedIntField.getName()); + BigIntVector signedLong = + (BigIntVector) vectorSchemaRoot.getVector(signedLongField.getName()); + UInt1Vector unsignedByteInt = + (UInt1Vector) vectorSchemaRoot.getVector(unsignedByteField.getName()); + UInt2Vector unsignedShortInt = + (UInt2Vector) vectorSchemaRoot.getVector(unsignedShortField.getName()); + UInt4Vector unsignedInt = + (UInt4Vector) vectorSchemaRoot.getVector(unsignedIntField.getName()); + UInt8Vector unsignedLong = + (UInt8Vector) vectorSchemaRoot.getVector(unsignedLongField.getName()); + + int rowCount = vectorSchemaRoot.getRowCount(); + // Validate all rows + for (int i = 0; i < rowCount; i++) { + // Validate signed byte + if (i % 2 == 0) { + assertTrue(signedByteInt.isNull(i), "Signed byte should be null at index " + i); + } 
else { + assertEquals( + getSignedByte(i), signedByteInt.get(i), "Signed byte mismatch at index " + i); + } + + // Validate signed short + if (i % 2 == 0) { + assertTrue(signedShortInt.isNull(i), "Signed short should be null at index " + i); + } else { + assertEquals( + getSignedShort(i), signedShortInt.get(i), "Signed short mismatch at index " + i); + } + + // Validate signed int + if (i % 2 == 0) { + assertTrue(signedInt.isNull(i), "Signed int should be null at index " + i); + } else { + assertEquals(getSignedInt(i), signedInt.get(i), "Signed int mismatch at index " + i); + } + + // Validate signed long + if (i % 2 == 0) { + assertTrue(signedLong.isNull(i), "Signed long should be null at index " + i); + } else { + assertEquals(getSignedLong(i), signedLong.get(i), "Signed long mismatch at index " + i); + } + + // Validate unsigned byte (convert to unsigned using Byte.toUnsignedInt) + if (i % 2 == 0) { + assertTrue(unsignedByteInt.isNull(i), "Unsigned byte should be null at index " + i); + } else { + assertEquals( + getUnsignedByte(i), + Byte.toUnsignedInt(unsignedByteInt.get(i)), + "Unsigned byte mismatch at index " + i); + } + + // Validate unsigned short (char is already unsigned in Java) + if (i % 2 == 0) { + assertTrue(unsignedShortInt.isNull(i), "Unsigned short should be null at index " + i); + } else { + assertEquals( + getUnsignedShort(i), + unsignedShortInt.get(i), + "Unsigned short mismatch at index " + i); + } + + // Validate unsigned int (convert to unsigned long for comparison) + if (i % 2 == 0) { + assertTrue(unsignedInt.isNull(i), "Unsigned int should be null at index " + i); + } else { + assertEquals( + getUnsignedInt(i), unsignedInt.get(i), "Unsigned int mismatch at index " + i); + } + + // Validate unsigned long + if (i % 2 == 0) { + assertTrue(unsignedLong.isNull(i), "Unsigned long should be null at index " + i); + } else { + assertEquals( + getUnsignedLong(i), unsignedLong.get(i), "Unsigned long mismatch at index " + i); + } + } + } + } + + /** 
Test floats */ + private class TestFloats implements DataTester { + Field floatField = newFloatField(); + Field doubleField = newDoubleField(); + Schema schema = new Schema(Arrays.asList(floatField, doubleField)); + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize) { + Float4Vector floatVector = (Float4Vector) vectorSchemaRoot.getVector(floatField.getName()); + Float8Vector doubleVector = (Float8Vector) vectorSchemaRoot.getVector(doubleField.getName()); + + // Set floats. + floatVector.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + floatVector.setNull(i); + } else { + floatVector.set(i, getFloat(i)); + } + } + + // Set doubles. + doubleVector.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + doubleVector.setNull(i); + } else { + doubleVector.set(i, getDouble(i)); + } + } + } + + @Override + public void validateData(VectorSchemaRoot vectorSchemaRoot) { + try (Float4Vector floatVector = + (Float4Vector) vectorSchemaRoot.getVector(floatField.getName()); + Float8Vector doubleVector = + (Float8Vector) vectorSchemaRoot.getVector(doubleField.getName())) { + for (int i = 0; i < vectorSchemaRoot.getRowCount(); i++) { + // Validate float + if (i % 2 == 0) { + assertTrue(floatVector.isNull(i), "Float should be null at index " + i); + } else { + assertEquals(getFloat(i), floatVector.get(i), 0.0001f, "Float mismatch at index " + i); + } + + // Validate double + if (i % 2 == 0) { + assertTrue(doubleVector.isNull(i), "Double should be null at index " + i); + } else { + assertEquals( + getDouble(i), doubleVector.get(i), 0.0001, "Double mismatch at index " + i); + } + } + } + } + } + + /** Test decimals */ + private class TestDecimal implements DataTester { + private final Field decimalFullPrecisionField = newDecimalField(38, 0, 128); + private final Field decimalTenPrecisionField = newDecimalField(10, 5, 
128); + private final Field decimalTwentyPrecisionField = newDecimalField(20, 10, 128); + private final Field decimalZeroScaleField = newDecimalField(16, 0, 128); + private final Schema schema = + new Schema( + Arrays.asList( + decimalFullPrecisionField, + decimalTenPrecisionField, + decimalTwentyPrecisionField, + decimalZeroScaleField)); + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize) { + DecimalVector decimalFullPrecisionVector = + (DecimalVector) vectorSchemaRoot.getVector(decimalFullPrecisionField.getName()); + DecimalVector decimalTenPrecisionVector = + (DecimalVector) vectorSchemaRoot.getVector(decimalTenPrecisionField.getName()); + DecimalVector decimalTwentyPrecisionVector = + (DecimalVector) vectorSchemaRoot.getVector(decimalTwentyPrecisionField.getName()); + DecimalVector decimalZeroScaleVector = + (DecimalVector) vectorSchemaRoot.getVector(decimalZeroScaleField.getName()); + + writeDecimals(decimalFullPrecisionVector, batchSize); + writeDecimals(decimalTenPrecisionVector, batchSize); + writeDecimals(decimalTwentyPrecisionVector, batchSize); + writeDecimals(decimalZeroScaleVector, batchSize); + } + + private void writeDecimals(DecimalVector decimalVector, int batchSize) { + // Set decimals. 
+ decimalVector.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + decimalVector.setNull(i); + } else { + decimalVector.set(i, getDecimal(i, decimalVector.getScale())); + } + } + } + + @Override + public void validateData(VectorSchemaRoot vectorSchemaRoot) { + DecimalVector decimalFullPrecisionVector = + (DecimalVector) vectorSchemaRoot.getVector(decimalFullPrecisionField.getName()); + DecimalVector decimalTenPrecisionVector = + (DecimalVector) vectorSchemaRoot.getVector(decimalTenPrecisionField.getName()); + DecimalVector decimalTwentyPrecisionVector = + (DecimalVector) vectorSchemaRoot.getVector(decimalTwentyPrecisionField.getName()); + DecimalVector decimalZeroScaleVector = + (DecimalVector) vectorSchemaRoot.getVector(decimalZeroScaleField.getName()); + + int rowCount = vectorSchemaRoot.getRowCount(); + validateDecimals(decimalFullPrecisionVector, rowCount); + validateDecimals(decimalTenPrecisionVector, rowCount); + validateDecimals(decimalTwentyPrecisionVector, rowCount); + validateDecimals(decimalZeroScaleVector, rowCount); + } + + private void validateDecimals(DecimalVector decimalVector, int rowCount) { + for (int i = 0; i < rowCount; i++) { + // Validate decimal + BigDecimal bigDecimal = decimalVector.getObject(i); + if (i % 2 == 0) { + assertNull(bigDecimal, "Decimal should be null at index " + i); + } else { + assertNotNull(bigDecimal, "Decimal should not be null at index " + i); + assertEquals( + getDecimal(i, decimalVector.getScale()), + bigDecimal, + "Decimal mismatch at index " + i); + } + } + } + } + + /** Test decimal256 */ + private class TestDecimal256 implements DataTester { + private final Field decimalFullPrecisionField = newDecimalField(76, 10, 256); + private final Field decimalTenPrecisionField = newDecimalField(10, 5, 256); + private final Field decimalTwentyPrecisionField = newDecimalField(20, 10, 256); + private final Field decimalZeroScaleField = newDecimalField(32, 0, 256); + private final Schema 
schema = + new Schema( + Arrays.asList( + decimalFullPrecisionField, + decimalTenPrecisionField, + decimalTwentyPrecisionField, + decimalZeroScaleField)); + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize) { + Decimal256Vector decimalFullPrecisionVector = + (Decimal256Vector) vectorSchemaRoot.getVector(decimalFullPrecisionField.getName()); + Decimal256Vector decimalTenPrecisionVector = + (Decimal256Vector) vectorSchemaRoot.getVector(decimalTenPrecisionField.getName()); + Decimal256Vector decimalTwentyPrecisionVector = + (Decimal256Vector) vectorSchemaRoot.getVector(decimalTwentyPrecisionField.getName()); + Decimal256Vector decimalZeroScaleVector = + (Decimal256Vector) vectorSchemaRoot.getVector(decimalZeroScaleField.getName()); + + writeDecimals(decimalFullPrecisionVector, batchSize); + writeDecimals(decimalTenPrecisionVector, batchSize); + writeDecimals(decimalTwentyPrecisionVector, batchSize); + writeDecimals(decimalZeroScaleVector, batchSize); + } + + private void writeDecimals(Decimal256Vector decimalVector, int batchSize) { + // Set decimals. 
+ decimalVector.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + decimalVector.setNull(i); + } else { + BigDecimal value = getDecimal(i, decimalVector.getScale()); + decimalVector.set(i, value); + } + } + } + + @Override + public void validateData(VectorSchemaRoot vectorSchemaRoot) { + Decimal256Vector decimalFullPrecisionVector = + (Decimal256Vector) vectorSchemaRoot.getVector(decimalFullPrecisionField.getName()); + Decimal256Vector decimalTenPrecisionVector = + (Decimal256Vector) vectorSchemaRoot.getVector(decimalTenPrecisionField.getName()); + Decimal256Vector decimalTwentyPrecisionVector = + (Decimal256Vector) vectorSchemaRoot.getVector(decimalTwentyPrecisionField.getName()); + Decimal256Vector decimalZeroScaleVector = + (Decimal256Vector) vectorSchemaRoot.getVector(decimalZeroScaleField.getName()); + + int rowCount = vectorSchemaRoot.getRowCount(); + validateDecimals(decimalFullPrecisionVector, rowCount); + validateDecimals(decimalTenPrecisionVector, rowCount); + validateDecimals(decimalTwentyPrecisionVector, rowCount); + validateDecimals(decimalZeroScaleVector, rowCount); + } + + private void validateDecimals(Decimal256Vector decimalVector, int rowCount) { + for (int i = 0; i < rowCount; i++) { + // Validate decimal + BigDecimal bigDecimal = decimalVector.getObject(i); + if (i % 2 == 0) { + assertNull(bigDecimal, "Decimal256 should be null at index " + i); + } else { + assertNotNull(bigDecimal, "Decimal256 should not be null at index " + i); + assertEquals( + getDecimal(i, decimalVector.getScale()), + bigDecimal, + "Decimal256 mismatch at index " + i); + } + } + } + } + + /** Test boolean types */ + private class TestBoolTypes implements DataTester { + private final Field boolField; + private final Schema schema; + + TestBoolTypes() { + boolField = newBoolField(); + schema = new Schema(Collections.singletonList(boolField)); + } + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void 
writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize) { + BitVector boolVector = (BitVector) vectorSchemaRoot.getVector(boolField.getName()); + boolVector.allocateNew(batchSize); + + for (int i = 0; i < batchSize; i++) { + if (i % 3 == 0) { + boolVector.setNull(i); + } else if (i % 2 == 0) { + boolVector.set(i, 0); // false + } else { + boolVector.set(i, 1); // true + } + } + boolVector.setValueCount(batchSize); + } + + @Override + public void validateData(VectorSchemaRoot vectorSchemaRoot) { + BitVector boolVector = (BitVector) vectorSchemaRoot.getVector(boolField.getName()); + int rowCount = vectorSchemaRoot.getRowCount(); + + for (int i = 0; i < rowCount; i++) { + if (i % 3 == 0) { + assertTrue(boolVector.isNull(i), "Bool should be null at index " + i); + } else if (i % 2 == 0) { + assertFalse(boolVector.isNull(i), "Bool should not be null at index " + i); + assertEquals(0, boolVector.get(i), "Bool should be false (0) at index " + i); + } else { + assertFalse(boolVector.isNull(i), "Bool should not be null at index " + i); + assertEquals(1, boolVector.get(i), "Bool should be true (1) at index " + i); + } + } + } + + private Field newBoolField() { + return new Field("bool-field", FieldType.nullable(new ArrowType.Bool()), null); + } + } + + /** Test null types */ + private class TestNullTypes implements DataTester { + private final Field nullField; + private final Schema schema; + + TestNullTypes() { + nullField = newNullField(); + schema = new Schema(Collections.singletonList(nullField)); + } + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize) { + NullVector nullVector = (NullVector) vectorSchemaRoot.getVector(nullField.getName()); + nullVector.allocateNew(); + nullVector.setValueCount(batchSize); + } + + @Override + public void validateData(VectorSchemaRoot vectorSchemaRoot) { + NullVector nullVector = (NullVector) 
vectorSchemaRoot.getVector(nullField.getName()); + int rowCount = vectorSchemaRoot.getRowCount(); + + for (int i = 0; i < rowCount; i++) { + assertTrue(nullVector.isNull(i), "Null vector should be null at index " + i); + } + } + + private Field newNullField() { + return new Field("null-field", FieldType.nullable(new ArrowType.Null()), null); + } + } +} diff --git a/src/test/java/org/apache/arrow/memory/DatabricksArrowPatchTemporalTypesTest.java b/src/test/java/org/apache/arrow/memory/DatabricksArrowPatchTemporalTypesTest.java new file mode 100644 index 000000000..6ceb48f7e --- /dev/null +++ b/src/test/java/org/apache/arrow/memory/DatabricksArrowPatchTemporalTypesTest.java @@ -0,0 +1,341 @@ +package org.apache.arrow.memory; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.time.Duration; +import java.util.Arrays; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DurationVector; +import org.apache.arrow.vector.IntervalDayVector; +import org.apache.arrow.vector.IntervalMonthDayNanoVector; +import org.apache.arrow.vector.IntervalYearVector; +import org.apache.arrow.vector.PeriodDuration; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.holders.NullableIntervalDayHolder; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.condition.EnabledOnJre; +import org.junit.jupiter.api.condition.JRE; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +/** Test temporal Arrow data types (dates, 
timestamps, time, duration, intervals). */ +@Tag("Jvm17PlusAndArrowToNioReflectionDisabled") +@EnabledOnJre({JRE.JAVA_17, JRE.JAVA_21}) +public class DatabricksArrowPatchTemporalTypesTest extends AbstractDatabricksArrowPatchTypesTest { + + /** Test read and write of temporal types. */ + @ParameterizedTest + @MethodSource("getBufferAllocators") + public void testTemporalTypes( + BufferAllocator readAllocator, BufferAllocator writeAllocator, int totalRows) + throws Exception { + DataTester testTemporal = new TestTemporalTypes(); + byte[] data = writeData(testTemporal, totalRows, writeAllocator); + readAndValidate(testTemporal, data, readAllocator); + } + + /** Test temporal types */ + private class TestTemporalTypes implements DataTester { + private final Field dateDayField; + private final Field timestampField; + private final Field timestampMicroField; + private final Field timeSecField; + private final Field timeNanoField; + private final Field durationMicrosecondField; + private final Field intervalYearField; + private final Field intervalDayField; + private final Field intervalMonthDayNanoField; + private final Schema schema; + + TestTemporalTypes() { + dateDayField = newDateDayField(); + timestampField = newTimestampMilliField(); + timestampMicroField = newTimestampMicroField(); + timeSecField = newTimeSecField(); + timeNanoField = newTimeNanoField(); + durationMicrosecondField = newDurationMicrosecondField(); + intervalYearField = newIntervalYearField(); + intervalDayField = newIntervalDayField(); + intervalMonthDayNanoField = newIntervalMonthDayNanoField(); + schema = + new Schema( + Arrays.asList( + dateDayField, + timestampField, + timestampMicroField, + timeSecField, + timeNanoField, + durationMicrosecondField, + intervalYearField, + intervalDayField, + intervalMonthDayNanoField)); + } + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public void writeData(VectorSchemaRoot vectorSchemaRoot, int batchSize) { + DateDayVector 
dateDayVector = + (DateDayVector) vectorSchemaRoot.getVector(dateDayField.getName()); + TimeStampMilliTZVector timestampVector = + (TimeStampMilliTZVector) vectorSchemaRoot.getVector(timestampField.getName()); + TimeStampMicroTZVector timestampMicroVector = + (TimeStampMicroTZVector) vectorSchemaRoot.getVector(timestampMicroField.getName()); + TimeSecVector timeSecVector = + (TimeSecVector) vectorSchemaRoot.getVector(timeSecField.getName()); + TimeNanoVector timeNanoVector = + (TimeNanoVector) vectorSchemaRoot.getVector(timeNanoField.getName()); + DurationVector durationVector = + (DurationVector) vectorSchemaRoot.getVector(durationMicrosecondField.getName()); + IntervalYearVector intervalYearVector = + (IntervalYearVector) vectorSchemaRoot.getVector(intervalYearField.getName()); + IntervalDayVector intervalDayVector = + (IntervalDayVector) vectorSchemaRoot.getVector(intervalDayField.getName()); + IntervalMonthDayNanoVector intervalMonthDayNanoVector = + (IntervalMonthDayNanoVector) + vectorSchemaRoot.getVector(intervalMonthDayNanoField.getName()); + + // Set dates (days since epoch). + dateDayVector.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + dateDayVector.setNull(i); + } else { + dateDayVector.set(i, getDateDay(i)); + } + } + + // Set timestamps (milliseconds since epoch). + timestampVector.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + timestampVector.setNull(i); + } else { + timestampVector.set(i, getTimestampMilli(i)); + } + } + + // Set timestamps (microseconds since epoch). + timestampMicroVector.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + timestampMicroVector.setNull(i); + } else { + timestampMicroVector.set(i, getTimestampMicro(i)); + } + } + + // Set times (seconds since midnight). 
+ timeSecVector.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + timeSecVector.setNull(i); + } else { + timeSecVector.set(i, getTimeSec(i)); + } + } + + // Set times (nanoseconds since midnight). + timeNanoVector.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + timeNanoVector.setNull(i); + } else { + timeNanoVector.set(i, getTimeNano(i)); + } + } + + // Set durations (microseconds). + durationVector.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + durationVector.setNull(i); + } else { + durationVector.set(i, getDurationMicroseconds(i)); + } + } + + // Set interval year-months. + intervalYearVector.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + intervalYearVector.setNull(i); + } else { + intervalYearVector.set(i, getIntervalYearMonth(i)); + } + } + + // Set interval day-times. + intervalDayVector.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + intervalDayVector.setNull(i); + } else { + intervalDayVector.set(i, getIntervalDayDays(i), getIntervalDayMillis(i)); + } + } + + // Set interval month-day-nanos. 
+ intervalMonthDayNanoVector.allocateNew(batchSize); + for (int i = 0; i < batchSize; i++) { + if (i % 2 == 0) { + intervalMonthDayNanoVector.setNull(i); + } else { + intervalMonthDayNanoVector.set( + i, + getIntervalMonthDayNanoMonths(i), + getIntervalMonthDayNanoDays(i), + getIntervalMonthDayNanoNanos(i)); + } + } + } + + @Override + public void validateData(VectorSchemaRoot vectorSchemaRoot) { + DateDayVector dateDayVector = + (DateDayVector) vectorSchemaRoot.getVector(dateDayField.getName()); + TimeStampMilliTZVector timestampVector = + (TimeStampMilliTZVector) vectorSchemaRoot.getVector(timestampField.getName()); + TimeStampMicroTZVector timestampMicroVector = + (TimeStampMicroTZVector) vectorSchemaRoot.getVector(timestampMicroField.getName()); + TimeSecVector timeSecVector = + (TimeSecVector) vectorSchemaRoot.getVector(timeSecField.getName()); + TimeNanoVector timeNanoVector = + (TimeNanoVector) vectorSchemaRoot.getVector(timeNanoField.getName()); + DurationVector durationVector = + (DurationVector) vectorSchemaRoot.getVector(durationMicrosecondField.getName()); + IntervalYearVector intervalYearVector = + (IntervalYearVector) vectorSchemaRoot.getVector(intervalYearField.getName()); + IntervalDayVector intervalDayVector = + (IntervalDayVector) vectorSchemaRoot.getVector(intervalDayField.getName()); + IntervalMonthDayNanoVector intervalMonthDayNanoVector = + (IntervalMonthDayNanoVector) + vectorSchemaRoot.getVector(intervalMonthDayNanoField.getName()); + + int rowCount = vectorSchemaRoot.getRowCount(); + + for (int i = 0; i < rowCount; i++) { + // Validate date (days since epoch) + if (i % 2 == 0) { + assertTrue(dateDayVector.isNull(i), "Date should be null at index " + i); + } else { + assertEquals(getDateDay(i), dateDayVector.get(i), "Date mismatch at index " + i); + } + + // Validate timestamp (milliseconds since epoch) + if (i % 2 == 0) { + assertTrue(timestampVector.isNull(i), "Timestamp should be null at index " + i); + } else { + assertEquals( + 
getTimestampMilli(i), timestampVector.get(i), "Timestamp mismatch at index " + i); + } + + // Validate timestamp (microseconds since epoch) + if (i % 2 == 0) { + assertTrue( + timestampMicroVector.isNull(i), "Timestamp micro should be null at index " + i); + } else { + assertEquals( + getTimestampMicro(i), + timestampMicroVector.get(i), + "Timestamp micro mismatch at index " + i); + } + + // Validate time (seconds since midnight) + if (i % 2 == 0) { + assertTrue(timeSecVector.isNull(i), "Time should be null at index " + i); + } else { + assertEquals(getTimeSec(i), timeSecVector.get(i), "Time mismatch at index " + i); + } + + // Validate time (nanoseconds since midnight) + if (i % 2 == 0) { + assertTrue(timeNanoVector.isNull(i), "Time nano should be null at index " + i); + } else { + assertEquals(getTimeNano(i), timeNanoVector.get(i), "Time nano mismatch at index " + i); + } + + // Validate duration (seconds) + if (i % 2 == 0) { + assertTrue(durationVector.isNull(i), "Duration should be null at index " + i); + } else { + Duration durationValue = durationVector.getObject(i); + assertNotNull(durationValue, "Duration should not be null at index " + i); + assertEquals( + getDurationMicroseconds(i), + durationValue.getNano() / (1000), + "Duration mismatch at index " + i); + } + + // Validate interval year-month + if (i % 2 == 0) { + assertTrue( + intervalYearVector.isNull(i), "Interval year-month should be null at index " + i); + } else { + assertEquals( + getIntervalYearMonth(i), + intervalYearVector.get(i), + "Interval year-month mismatch at index " + i); + } + + // Validate interval day-time + if (i % 2 == 0) { + assertTrue(intervalDayVector.isNull(i), "Interval day-time should be null at index " + i); + } else { + NullableIntervalDayHolder holder = new NullableIntervalDayHolder(); + intervalDayVector.get(i, holder); + assertEquals(getIntervalDayDays(i), holder.days, "Interval days mismatch at index " + i); + assertEquals( + getIntervalDayMillis(i), + 
holder.milliseconds, + "Interval milliseconds mismatch at index " + i); + } + + // Validate interval month-day-nano + if (i % 2 == 0) { + assertTrue( + intervalMonthDayNanoVector.isNull(i), + "Interval month-day-nano should be null at index " + i); + } else { + PeriodDuration value = intervalMonthDayNanoVector.getObject(i); + assertNotNull(value, "Interval month-day-nano should not be null at index " + i); + assertEquals( + getIntervalMonthDayNanoMonths(i), + value.getPeriod().toTotalMonths(), + "Interval months mismatch at index " + i); + assertEquals( + getIntervalMonthDayNanoDays(i), + value.getPeriod().getDays(), + "Interval days mismatch at index " + i); + assertEquals( + getIntervalMonthDayNanoNanos(i), + value.getDuration().toNanos(), + "Interval nanoseconds mismatch at index " + i); + } + } + } + + private int getIntervalMonthDayNanoMonths(int index) { + return index % 12; + } + + private int getIntervalMonthDayNanoDays(int index) { + return index % 30; + } + + private long getIntervalMonthDayNanoNanos(int index) { + return ((long) index * 1_000_000_000L) % 86_400_000_000_000L; + } + } +} diff --git a/src/test/java/org/apache/arrow/memory/DatabricksArrowPatchTest.java b/src/test/java/org/apache/arrow/memory/DatabricksArrowPatchTest.java new file mode 100644 index 000000000..f26fc1ffc --- /dev/null +++ b/src/test/java/org/apache/arrow/memory/DatabricksArrowPatchTest.java @@ -0,0 +1,293 @@ +package org.apache.arrow.memory; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import 
java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import net.jpountz.lz4.LZ4FrameInputStream; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.util.TransferPair; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnJre; +import org.junit.jupiter.api.condition.JRE; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Test the patched allocator works. */ +public class DatabricksArrowPatchTest { + private static final Logger logger = LoggerFactory.getLogger(DatabricksArrowPatchTest.class); + + /** Path to an arrow chunk. */ + private static final Path ARROW_CHUNK_PATH = Path.of("arrow", "chunk_all_types.arrow"); + + /** Path to a LZ4 compressed arrow chunk. */ + private static final Path ARROW_CHUNK_COMPRESSED_PATH = + Path.of("arrow", "chunk_all_types.arrow.lz4"); + + /** Compressed Arrow file suffix. */ + private static final String ARROW_CHUNK_COMPRESSED_FILE_SUFFIX = ".lz4"; + + /** Default number of concurrent threads. */ + private static final int DEFAULT_NUM_THREADS = Runtime.getRuntime().availableProcessors(); + + /** System property name to set the number of threads */ + private static final String NUM_THREADS_PROPERTY_NAME = "test.arrow.num.threads"; + + /** Default iterations per thread. */ + private static final int DEFAULT_CONCURRENT_ITERATIONS_PER_THREAD = 100; + + /** System property name to set the number of iterations per thread. */ + private static final String CONCURRENT_ITERATIONS_PER_THREAD_PROPERTY_NAME = + "test.arrow.iterations.per.thread"; + + /** + * Test exception is thrown when jvm arg "--add-opens=java.base/java.nio=ALL-UNNAMED" is missing + * on JVM >= 17. 
+ */ + @Test + @Tag("Jvm17PlusAndArrowToNioReflectionDisabled") + @EnabledOnJre({JRE.JAVA_17, JRE.JAVA_21}) + public void testArrowThrowsExceptionOnMissingAddOpensJvmArgs() { + Throwable throwable = null; + try { + RootAllocator allocator = new RootAllocator(); + ArrowBuf buffer = allocator.buffer(64); + buffer.writeByte(0); + allocator.close(); // Unreachable code. + } catch (Throwable t) { + throwable = t; + } + + assertNotNull(throwable); + for (var cause = throwable; cause != null; cause = cause.getCause()) { + logger.info("Throwable in chain: {} - {}", cause.getClass().getName(), cause.getMessage()); + } + } + + /** + * Test patched Arrow buffer allocator works when jvm arg + * "--add-opens=java.base/java.nio=ALL-UNNAMED" is missing. + */ + @Test + @Tag("Jvm17PlusAndArrowToNioReflectionDisabled") + @EnabledOnJre({JRE.JAVA_17, JRE.JAVA_21}) + public void testPatchedArrowWorksWithMissingAddOpensJvmArgs() throws IOException { + for (Path path : Arrays.asList(ARROW_CHUNK_PATH, ARROW_CHUNK_COMPRESSED_PATH)) { + try (DatabricksBufferAllocator allocator = new DatabricksBufferAllocator()) { + List> records = parseArrowStream(path, allocator); + assertFalse(records.isEmpty(), "Some records should be parsed"); + + logger.info("Parsed {} records from path {}", records.size(), path); + } + } + } + + /** Parse files concurrently. Test for any memory leaks */ + @Test + public void testConcurrentExecution() { + int numThreads = + System.getProperty(NUM_THREADS_PROPERTY_NAME) == null + ? DEFAULT_NUM_THREADS + : Integer.parseInt(System.getProperty(NUM_THREADS_PROPERTY_NAME)); + int iterationsPerThread = + System.getProperty(CONCURRENT_ITERATIONS_PER_THREAD_PROPERTY_NAME) == null + ? 
DEFAULT_CONCURRENT_ITERATIONS_PER_THREAD + : Integer.parseInt(System.getProperty(CONCURRENT_ITERATIONS_PER_THREAD_PROPERTY_NAME)); + + int totalIterations = numThreads * iterationsPerThread; + logger.info("Num threads {}, Total iterations: {}", numThreads, totalIterations); + + ExecutorService executor = Executors.newFixedThreadPool(numThreads); + try { + IntStream.range(0, totalIterations) + .mapToObj( + i -> + executor.submit( + () -> { + try { + parseArrowStream( + i % 2 == 0 ? ARROW_CHUNK_PATH : ARROW_CHUNK_COMPRESSED_PATH, + new DatabricksBufferAllocator()); + } catch (IOException e) { + throw new RuntimeException(e); + } + })) + .forEach( + future -> { + try { + future.get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }); + } finally { + executor.shutdownNow(); + } + } + + /** Test that the patched DatabricksArrowBuf parses records correctly. */ + @Test + public void testPatchedArrowParsing() throws IOException { + testParsing(ARROW_CHUNK_PATH); + testParsing(ARROW_CHUNK_COMPRESSED_PATH); + } + + /** + * Parse the Arrow stream file at {@code filePath} and compare the records returned by native + * Arrow and patched Arrow. + */ + private void testParsing(Path filePath) throws IOException { + try (RootAllocator rootAllocator = new RootAllocator(); + DatabricksBufferAllocator patchedAllocator = new DatabricksBufferAllocator()) { + + // Parse with Arrow. + logger.info("Parsing {} with Arrow RootAllocator", filePath); + List> rootAllocatorRecords = parseArrowStream(filePath, rootAllocator); + logger.info("RootAllocator records: {}", rootAllocatorRecords.size()); + + // Parse with Patched Arrow. 
+ logger.info("Parsing {} with Arrow patched DatabricksBufferAllocator", filePath); + List> patchedAllocatorRecords = + parseArrowStream(filePath, patchedAllocator); + logger.info("DatabricksBufferAllocator records: {}", patchedAllocatorRecords.size()); + + // Assert that records exist and same number of records are parsed. + assertFalse(rootAllocatorRecords.isEmpty(), "Some records should be parsed"); + assertEquals( + rootAllocatorRecords.size(), + patchedAllocatorRecords.size(), + "Both should parse same number of records"); + + // Log a sample record. + logger.info("Sample record {}", rootAllocatorRecords.get(0)); + + // Compare all records parsed using both allocators. + for (int i = 0; i < patchedAllocatorRecords.size(); i++) { + Map patchedRecord = patchedAllocatorRecords.get(i); + Map rootRecord = rootAllocatorRecords.get(i); + assertEquals( + rootRecord.keySet(), patchedRecord.keySet(), "Same number of columns should be parsed"); + + for (String key : patchedRecord.keySet()) { + Object patchedColumn = patchedRecord.get(key); + Object rootColumn = rootRecord.get(key); + assertTrue( + deepEquals(rootColumn, patchedColumn), + "Column " + key + " should be same at row " + i); + } + } + } + } + + /** + * Deep equality check that handles byte[] (which uses reference equality by default) and + * recursively checks Lists and Maps that may contain byte[] values (e.g., struct or map columns + * with binary fields). 
+ */ + @SuppressWarnings("unchecked") + private boolean deepEquals(Object a, Object b) { + if (a == b) return true; + if (a == null || b == null) return false; + if (a instanceof byte[] && b instanceof byte[]) { + return Arrays.equals((byte[]) a, (byte[]) b); + } + if (a instanceof List && b instanceof List) { + List listA = (List) a; + List listB = (List) b; + if (listA.size() != listB.size()) return false; + for (int i = 0; i < listA.size(); i++) { + if (!deepEquals(listA.get(i), listB.get(i))) return false; + } + return true; + } + if (a instanceof Map && b instanceof Map) { + Map mapA = (Map) a; + Map mapB = (Map) b; + if (mapA.size() != mapB.size()) return false; + for (Map.Entry entry : mapA.entrySet()) { + // For binary map keys, we need to find a matching key in mapB + Object matchedValue = null; + boolean found = false; + for (Map.Entry entryB : mapB.entrySet()) { + if (deepEquals(entry.getKey(), entryB.getKey())) { + matchedValue = entryB.getValue(); + found = true; + break; + } + } + if (!found || !deepEquals(entry.getValue(), matchedValue)) return false; + } + return true; + } + return a.equals(b); + } + + /** Parse the Arrow stream file stored at {@code filePath} and return the records in the file. */ + private List> parseArrowStream(Path filePath, BufferAllocator allocator) + throws IOException { + ArrayList> records = new ArrayList<>(); + + try (InputStream arrowStream = getStream(filePath); + ArrowStreamReader reader = new ArrowStreamReader(arrowStream, allocator)) { + // Iterate over batches. + while (reader.loadNextBatch()) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + + // Transfer all vectors. + List valueVectors = + root.getFieldVectors().stream() + .map( + fieldVector -> { + TransferPair transferPair = fieldVector.getTransferPair(allocator); + transferPair.transfer(); + return transferPair.getTo(); + }) + .collect(Collectors.toList()); + + // Parse and populate each record/row in this batch. 
+ try { + for (int recordIndex = 0; recordIndex < root.getRowCount(); recordIndex++) { + HashMap record = new HashMap<>(); + for (ValueVector valueVector : valueVectors) { + record.put(valueVector.getField().getName(), valueVector.getObject(recordIndex)); + } + records.add(record); + } + } finally { + // Close all transferred vectors to prevent memory leak + valueVectors.forEach(ValueVector::close); + } + } + } + + return records; + } + + /** + * @return an input stream for the filePath. + */ + private InputStream getStream(Path filePath) throws IOException { + InputStream arrowStream = getClass().getClassLoader().getResourceAsStream(filePath.toString()); + assertNotNull(arrowStream, filePath + " not found"); + return filePath.toString().endsWith(ARROW_CHUNK_COMPRESSED_FILE_SUFFIX) + ? new LZ4FrameInputStream(arrowStream) + : arrowStream; + } +} diff --git a/src/test/java/org/apache/arrow/memory/DatabricksBufferAllocatorTest.java b/src/test/java/org/apache/arrow/memory/DatabricksBufferAllocatorTest.java new file mode 100644 index 000000000..cc2f98482 --- /dev/null +++ b/src/test/java/org/apache/arrow/memory/DatabricksBufferAllocatorTest.java @@ -0,0 +1,156 @@ +package org.apache.arrow.memory; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.apache.arrow.memory.rounding.DefaultRoundingPolicy; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnJre; +import org.junit.jupiter.api.condition.JRE; + +/** Test buffer allocator */ +@Tag("Jvm17PlusAndArrowToNioReflectionDisabled") +@EnabledOnJre({JRE.JAVA_17, JRE.JAVA_21}) +public class DatabricksBufferAllocatorTest { + private static int LIMIT = Integer.MAX_VALUE; + private 
static long ALLOCATED_MEMORY = 0; + private static long INIT_RESERVATION = 0; + private static long PEAK_MEMORY_ALLOCATION = 0; + private static long HEADROOM = Integer.MAX_VALUE; + private static boolean IS_OVERLIMIT = false; + + /** Test allocation. */ + @Test + public void testAllocation() { + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + + for (int size = 0; size < 4096; size++) { + ArrowBuf buffer = allocator.buffer(size); + assertInstanceOf(DatabricksArrowBuf.class, buffer, "Should be of type DatabricksArrowBuf"); + assertTrue(buffer.capacity() >= size, "Should have expected capacity"); + } + } + + /** Test parent child allocations. */ + @Test + public void testParentChildAllocation() { + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + + for (int i = 0; i < 5_000; i++) { + allocator.newChildAllocator("child" + i, 0, 0); + + allocator.getChildAllocators().stream() + .forEach( + c -> { + assertInstanceOf( + DatabricksBufferAllocator.class, + c, + "Should be of type DatabricksBufferAllocator"); + assertEquals( + allocator, c.getParentAllocator(), "Allocator parent should be the same"); + assertEquals(allocator, c.getRoot(), "Allocator parent should be the same"); + }); + + assertEquals( + i + 1, + allocator.getChildAllocators().size(), + "Allocator children should " + "be" + " the same"); + } + } + + /** Test get root in deeply nested allocators */ + @Test + public void testGetRoot() { + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + + DatabricksBufferAllocator currentNode = allocator; + for (int i = 0; i < 1000; i++) { + currentNode = (DatabricksBufferAllocator) currentNode.newChildAllocator("child" + 1, 0, 0); + assertEquals(allocator, currentNode.getRoot(), "Allocator root should be the same"); + } + } + + /** Test constants returned by DatabricksBufferAllocator. 
*/ + @Test + public void testConstants() { + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + + for (int i = 0; i < 1000; i++) { + allocator.newChildAllocator("child" + i, 0, 0); + allocator.buffer(i); + + assertEquals(LIMIT, allocator.getLimit(), "Limit should be constant"); + assertEquals( + ALLOCATED_MEMORY, allocator.getAllocatedMemory(), "Allocated memory should be constant"); + assertEquals( + INIT_RESERVATION, allocator.getInitReservation(), "Init reservations should be constant"); + assertEquals(HEADROOM, allocator.getHeadroom(), "Headroom should be constant"); + assertEquals(LIMIT, allocator.getLimit(), "Limit should be constant"); + assertEquals( + PEAK_MEMORY_ALLOCATION, + allocator.getPeakMemoryAllocation(), + "Peak memory should be constant"); + assertEquals(IS_OVERLIMIT, allocator.isOverLimit(), "Over limit should be constant"); + assertEquals(AllocationListener.NOOP, allocator.getListener(), "Listener should be NO-OP"); + assertEquals( + DefaultRoundingPolicy.DEFAULT_ROUNDING_POLICY, + allocator.getRoundingPolicy(), + "Rounding policy should be default"); + } + } + + /** Test verbose string */ + @Test + public void testToVerboseString() { + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + + for (int i = 0; i < 1000; i++) { + allocator.newChildAllocator("child" + i, 0, 0); + allocator.buffer(i); + assertDoesNotThrow(allocator::toVerboseString, "Verbose string should faile"); + } + } + + /** Test force allocate */ + @Test + public void testForceAllocate() { + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + + DatabricksBufferAllocator currentNode = allocator; + for (int i = 0; i < 1000; i++) { + currentNode = (DatabricksBufferAllocator) allocator.newChildAllocator("child" + i, 0, 0); + currentNode.buffer(i); + + assertTrue(currentNode.forceAllocate(i), "Force allocate should succeed"); + } + } + + /** Test assert open */ + @Test + public void testAssertOpen() { + 
DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + assertDoesNotThrow(allocator::assertOpen, "Assert should not throw exception"); + allocator.close(); + assertThrows( + IllegalStateException.class, allocator::assertOpen, "Assert should throw exception"); + } + + /** Test wrap foreign allocation fails */ + @Test + public void testWrapForeignAllocationFails() { + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + ForeignAllocation foreignAllocation = + new ForeignAllocation(0, 0) { + @Override + protected void release0() {} + }; + assertThrows( + UnsupportedOperationException.class, + () -> allocator.wrapForeignAllocation(foreignAllocation), + "Wrap should fail"); + } +} diff --git a/src/test/java/org/apache/arrow/memory/DatabricksReferenceManagerNOOPTest.java b/src/test/java/org/apache/arrow/memory/DatabricksReferenceManagerNOOPTest.java new file mode 100644 index 000000000..2a826937c --- /dev/null +++ b/src/test/java/org/apache/arrow/memory/DatabricksReferenceManagerNOOPTest.java @@ -0,0 +1,143 @@ +package org.apache.arrow.memory; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnJre; +import org.junit.jupiter.api.condition.JRE; + +/** Test for DatabricksReferenceManagerNOOP. */ +@Tag("Jvm17PlusAndArrowToNioReflectionDisabled") +@EnabledOnJre({JRE.JAVA_17, JRE.JAVA_21}) +public class DatabricksReferenceManagerNOOPTest { + /** Test that getRefCount always returns 1. */ + @Test + public void testGetRefCount() { + assertEquals(1, DatabricksReferenceManagerNOOP.INSTANCE.getRefCount()); + } + + /** Test that release operations always return false. 
*/ + @Test + public void testRelease() { + DatabricksReferenceManagerNOOP noop = DatabricksReferenceManagerNOOP.INSTANCE; + + assertFalse(noop.release(), "release() should always return false"); + assertFalse(noop.release(1), "release(int) should always return false"); + assertFalse(noop.release(100), "release(int) should always return false"); + + // Ref count should remain 1 after release operations. + assertEquals(1, noop.getRefCount(), "Ref count should remain 1"); + } + + /** Test that retain operations don't throw and don't change ref count. */ + @Test + public void testRetain() { + DatabricksReferenceManagerNOOP noop = DatabricksReferenceManagerNOOP.INSTANCE; + + // These should not throw. + noop.retain(); + noop.retain(1); + noop.retain(100); + + // Ref count should remain 1 after retain operations. + assertEquals(1, noop.getRefCount(), "Ref count should remain 1"); + } + + /** Test that getSize and getAccountedSize return 0. */ + @Test + public void testSizeReturnsZero() { + DatabricksReferenceManagerNOOP noop = DatabricksReferenceManagerNOOP.INSTANCE; + + assertEquals(0L, noop.getSize(), "getSize() should return 0"); + assertEquals(0L, noop.getAccountedSize(), "getAccountedSize() should return 0"); + } + + /** Test that getAllocator returns a DatabricksBufferAllocator. */ + @Test + public void testGetAllocator() { + DatabricksReferenceManagerNOOP noop = DatabricksReferenceManagerNOOP.INSTANCE; + + BufferAllocator allocator = noop.getAllocator(); + assertNotNull(allocator, "getAllocator() should not return null"); + assertTrue( + allocator instanceof DatabricksBufferAllocator, + "getAllocator() should return a DatabricksBufferAllocator"); + } + + /** Test that retain(ArrowBuf, BufferAllocator) returns the same buffer. 
*/ + @Test + public void testRetainBuffer() { + DatabricksReferenceManagerNOOP noop = DatabricksReferenceManagerNOOP.INSTANCE; + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + ArrowBuf buffer = allocator.buffer(1024); + + ArrowBuf retained = noop.retain(buffer, allocator); + assertSame(buffer, retained, "retain() should return the same buffer"); + } + + /** Test that deriveBuffer returns the same buffer. */ + @Test + public void testDeriveBuffer() { + DatabricksReferenceManagerNOOP noop = DatabricksReferenceManagerNOOP.INSTANCE; + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + ArrowBuf buffer = allocator.buffer(1024); + + ArrowBuf derived = noop.deriveBuffer(buffer, 0, 512); + assertSame(buffer, derived, "deriveBuffer() should return the same buffer"); + + // Test with different index and length - should still return the same buffer. + derived = noop.deriveBuffer(buffer, 256, 256); + assertSame(buffer, derived, "deriveBuffer() should return the same buffer regardless of index"); + } + + /** Test that transferOwnership returns a valid result. */ + @Test + public void testTransferOwnership() { + DatabricksReferenceManagerNOOP noop = DatabricksReferenceManagerNOOP.INSTANCE; + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + ArrowBuf buffer = allocator.buffer(1024); + + OwnershipTransferResult result = + noop.transferOwnership(buffer, new DatabricksBufferAllocator()); + + assertNotNull(result, "transferOwnership() should not return null"); + assertTrue(result.getAllocationFit(), "getAllocationFit() should return true"); + assertSame( + buffer, + result.getTransferredBuffer(), + "getTransferredBuffer() should return the original buffer"); + } + + /** Test all operations remain consistent after multiple calls. 
*/ + @Test + public void testConsistencyAcrossMultipleCalls() { + DatabricksReferenceManagerNOOP noop = DatabricksReferenceManagerNOOP.INSTANCE; + + for (int i = 0; i < 100; i++) { + switch (i % 4) { + case 0: + noop.retain(); + break; + case 1: + noop.release(); + break; + case 2: + noop.retain(i); + break; + case 3: + default: + noop.release(i); + break; + } + + assertEquals(1, noop.getRefCount(), "Ref count should always be 1"); + assertEquals(0L, noop.getSize(), "Size should always be 0"); + assertEquals(0L, noop.getAccountedSize(), "Accounted size should always be 0"); + } + } +} diff --git a/src/test/java/org/apache/arrow/memory/DatabricksReferenceManagerTest.java b/src/test/java/org/apache/arrow/memory/DatabricksReferenceManagerTest.java new file mode 100644 index 000000000..2f34eb739 --- /dev/null +++ b/src/test/java/org/apache/arrow/memory/DatabricksReferenceManagerTest.java @@ -0,0 +1,152 @@ +package org.apache.arrow.memory; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Random; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnJre; +import org.junit.jupiter.api.condition.JRE; + +/** Test reference manager */ +@Tag("Jvm17PlusAndArrowToNioReflectionDisabled") +@EnabledOnJre({JRE.JAVA_17, JRE.JAVA_21}) +public class DatabricksReferenceManagerTest { + private static int REF_COUNT = 1; + + /** Test accounting of Reference manager */ + @Test + public void testAccounting() { + final long size = 2048; + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + DatabricksReferenceManager referenceManager = new DatabricksReferenceManager(allocator, size); + + for (int i = 0; i < 100; i++) { + switch (i % 4) { + case 0: + referenceManager.retain(); + break; + case 1: + referenceManager.release(); + break; + case 2: + 
referenceManager.retain(i); + break; + case 3: + default: + referenceManager.release(i); + break; + } + assertEquals(REF_COUNT, referenceManager.getRefCount(), "Ref count should be constant"); + assertEquals(size, referenceManager.getSize(), "Size should be constant"); + assertEquals(size, referenceManager.getAccountedSize(), "Size should be constant"); + assertEquals(allocator, referenceManager.getAllocator(), "Allocator should be the same"); + } + } + + /** Test derive buffer fails on invalid arguments */ + @Test + public void testDeriveBufferFailsOnPreconditions() { + final int bufferSize = 1024; + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + DatabricksReferenceManager referenceManager = + new DatabricksReferenceManager(allocator, bufferSize); + DatabricksArrowBuf buffer = (DatabricksArrowBuf) allocator.buffer(bufferSize); + + assertThrows( + IllegalArgumentException.class, + () -> referenceManager.deriveBuffer(buffer, 0, Integer.MAX_VALUE + 1L), + "Should fail on invalid length"); + assertThrows( + IllegalArgumentException.class, + () -> referenceManager.deriveBuffer(buffer, bufferSize - 1, 2), + "Should fail on invalid length"); + } + + /** Test derive buffer. */ + @Test + public void testDeriveBuffer() { + final int bufferSize = 4096; + DatabricksBufferAllocator allocator = new DatabricksBufferAllocator(); + DatabricksReferenceManager referenceManager = + new DatabricksReferenceManager(allocator, bufferSize); + DatabricksArrowBuf buffer = (DatabricksArrowBuf) allocator.buffer(bufferSize); + + // Write zeroes into the original buffer. + for (int i = 0; i < bufferSize; i++) { + buffer.writeByte((byte) 0); + } + + // Test whole slice. + DatabricksArrowBuf wholeSlice = (DatabricksArrowBuf) referenceManager.retain(buffer, allocator); + testWriteAffectsOriginalAndDerivedBuffer(buffer, wholeSlice, 0); + + // Test transfer ownership. 
+ OwnershipTransferResult ownershipTransferResult = + referenceManager.transferOwnership(buffer, new DatabricksBufferAllocator()); + assertTrue(ownershipTransferResult.getAllocationFit(), "Should fit"); + DatabricksArrowBuf transferredBuffer = + (DatabricksArrowBuf) ownershipTransferResult.getTransferredBuffer(); + testWriteAffectsOriginalAndDerivedBuffer(transferredBuffer, buffer, 0); + + // Write data to part of a slice and check that the original buffer is affected as well. + for (int sliceSize = 1; sliceSize < bufferSize; sliceSize++) { + DatabricksArrowBuf slice = + (DatabricksArrowBuf) referenceManager.deriveBuffer(buffer, 0, sliceSize); + testWriteAffectsOriginalAndDerivedBuffer(buffer, slice, 0); + + int index = bufferSize - sliceSize; + slice = (DatabricksArrowBuf) buffer.slice(index, sliceSize); + testWriteAffectsOriginalAndDerivedBuffer(buffer, slice, index); + } + + // Test some random offsets and length. + Random random = new Random(); + for (int i = 0; i < 10; i++) { + int startOffset = random.nextInt(bufferSize); + int size = random.nextInt(bufferSize - startOffset); + DatabricksArrowBuf slice = + (DatabricksArrowBuf) referenceManager.deriveBuffer(buffer, startOffset, size); + testWriteAffectsOriginalAndDerivedBuffer(buffer, slice, startOffset); + } + } + + private void testWriteAffectsOriginalAndDerivedBuffer( + DatabricksArrowBuf buffer, DatabricksArrowBuf slice, int sliceStartIndex) { + final int bufferSize = (int) buffer.capacity(); + + // Write zeroes into the original buffer. + buffer.clear(); + for (int i = 0; i < bufferSize; i++) { + buffer.writeByte((byte) 0); + } + + // Write data to the slice and check that the original buffer is affected as well. 
+ for (int i = 0; i < slice.capacity(); i++) { + slice.setByte(i, getByteValue(i)); + } + for (int i = 0; i < slice.capacity(); i++) { + assertEquals( + getByteValue(i), + buffer.getByte(sliceStartIndex + i), + "Readable bytes should be correct at index " + i); + } + + // Write data to the original buffer and check that the slice is affected. + for (int i = sliceStartIndex; i < sliceStartIndex + slice.capacity(); i++) { + buffer.setByte(i, getByteValue(i)); + } + for (int i = 0; i < slice.capacity(); i++) { + assertEquals( + getByteValue(sliceStartIndex + i), + slice.getByte(i), + "Readable bytes should be correct at index " + i); + } + } + + private byte getByteValue(int i) { + return (byte) (i % 256); + } +} diff --git a/src/test/resources/arrow/chunk_all_types.arrow b/src/test/resources/arrow/chunk_all_types.arrow new file mode 100644 index 000000000..794d65040 Binary files /dev/null and b/src/test/resources/arrow/chunk_all_types.arrow differ diff --git a/src/test/resources/arrow/chunk_all_types.arrow.lz4 b/src/test/resources/arrow/chunk_all_types.arrow.lz4 new file mode 100644 index 000000000..953431209 Binary files /dev/null and b/src/test/resources/arrow/chunk_all_types.arrow.lz4 differ diff --git a/test-assembly-thin/pom.xml b/test-assembly-thin/pom.xml new file mode 100644 index 000000000..cab4b45ff --- /dev/null +++ b/test-assembly-thin/pom.xml @@ -0,0 +1,36 @@ + + + 4.0.0 + + + com.databricks + databricks-jdbc-parent + 3.2.2-SNAPSHOT + + + test-databricks-jdbc-thin + jar + Test thin + Test JDBC databricks driver thin jar. 
+ + + + true + + + + + com.databricks + databricks-jdbc-thin + 3.2.2-SNAPSHOT + + + org.junit.jupiter + junit-jupiter + ${junit.jupiter.version} + test + + + \ No newline at end of file diff --git a/test-assembly-thin/src/test/java/com/databricks/jdbc/TestThinPackaging.java b/test-assembly-thin/src/test/java/com/databricks/jdbc/TestThinPackaging.java new file mode 100644 index 000000000..b74a5bd79 --- /dev/null +++ b/test-assembly-thin/src/test/java/com/databricks/jdbc/TestThinPackaging.java @@ -0,0 +1,198 @@ +package com.databricks.jdbc; + +import com.databricks.jdbc.api.impl.arrow.ArrowBufferAllocator; +import com.databricks.jdbc.common.DatabricksJdbcUrlParams; +import com.databricks.sdk.core.DatabricksConfig; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.flatbuffers.FlatBufferBuilder; +import com.google.gson.Gson; +import com.google.protobuf.ByteString; +import com.nimbusds.jose.JWSAlgorithm; +import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig; +import io.grpc.Context; +import io.netty.buffer.ByteBufAllocator; +import io.vavr.collection.List; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.Properties; +import java.util.logging.Logger; +import org.apache.commons.codec.binary.Base64; +import org.apache.commons.io.IOUtils; +import org.apache.commons.lang3.StringUtils; +import org.apache.hc.client5.http.impl.classic.HttpClientBuilder; +import org.apache.hc.client5.http.impl.classic.HttpClients; +import org.apache.hc.core5.http.HttpStatus; +import org.apache.thrift.TException; +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import org.json.JSONObject; +import org.junit.jupiter.api.Test; +import org.locationtech.jts.geom.GeometryFactory; + 
+/** Test artifacts are packaged properly. */ +public class TestThinPackaging { + /** Logger instance. */ + private static final Logger logger = Logger.getLogger(TestThinPackaging.class.getName()); + + /** Test packages are shaded as expected. */ + @Test + public void testThinPackaging() { + // Test that Arrow packages is relocated. + com.databricks.internal.apache.arrow.memory.BufferAllocator bufferAllocator = + ArrowBufferAllocator.getBufferAllocator(); + logger.info("Shaded buffer allocator " + bufferAllocator); + + // Test that jackson packages are not relocated. + ObjectMapper jacksonMapper = new ObjectMapper(); + logger.info("Jackson ObjectMapper: " + jacksonMapper); + + // Test that guava is not relocated. + ImmutableList guavaList = ImmutableList.of("test"); + logger.info("Guava ImmutableList: " + guavaList); + + // Test that protobuf is not relocated. + ByteString protoByteString = ByteString.copyFromUtf8("test"); + logger.info("Protobuf ByteString: " + protoByteString); + + // Test that commons-lang3 is not relocated. + String commonsResult = StringUtils.upperCase("test"); + logger.info("Commons-Lang3 result: " + commonsResult); + + // Test that commons-codec is not relocated. + byte[] commonsCodec = Base64.encodeBase64("test".getBytes()); + logger.info("Commons-Codec Base64: " + new String(commonsCodec)); + + // Test that commons-io is not relocated. + try { + String commonsIo = IOUtils.toString(new ByteArrayInputStream("test".getBytes()), "UTF-8"); + logger.info("Commons-IO result: " + commonsIo); + } catch (IOException e) { + throw new RuntimeException("Failed to test Commons-IO", e); + } + + // Test that httpclient5 is not relocated. + HttpClientBuilder httpClientBuilder = HttpClients.custom(); + logger.info("HttpClient5 builder: " + httpClientBuilder); + + // Test that httpcore5 is not relocated. + int httpStatus = HttpStatus.SC_OK; + logger.info("HttpCore5 status: " + httpStatus); + + // Test that thrift is not relocated. 
+ TException thriftException = new TException("test"); + logger.info("Thrift TException: " + thriftException.getMessage()); + + // Test that gson is not relocated. + Gson gson = new Gson(); + logger.info("Gson: " + gson); + + // Test that flatbuffers is not relocated. + FlatBufferBuilder flatBuilder = new FlatBufferBuilder(); + logger.info("FlatBuffers: " + flatBuilder); + + // Test that netty is not relocated. + ByteBufAllocator nettyAllocator = ByteBufAllocator.DEFAULT; + logger.info("Netty ByteBufAllocator: " + nettyAllocator); + + // Test that grpc is not relocated. + Context grpcContext = Context.current(); + logger.info("gRPC Context: " + grpcContext); + + // Test that bouncycastle is not relocated. + BouncyCastleProvider bcProvider = new BouncyCastleProvider(); + logger.info("BouncyCastle Provider: " + bcProvider.getName()); + + // Test that resilience4j is not relocated. + CircuitBreakerConfig cbConfig = CircuitBreakerConfig.ofDefaults(); + logger.info("Resilience4j CircuitBreakerConfig: " + cbConfig); + + // Test that vavr is not relocated. + List vavrList = List.of("test"); + logger.info("Vavr List: " + vavrList); + + // Test that JTS is not relocated. + GeometryFactory jtsFactory = new GeometryFactory(); + logger.info("JTS GeometryFactory: " + jtsFactory); + + // Test that Databricks SDK is not relocated. + Class sdkClass = DatabricksConfig.class; + logger.info("Databricks SDK class: " + sdkClass.getName()); + + // Test that JSON is not relocated. + JSONObject jsonObject = new JSONObject().put("key", "value"); + logger.info("JSON object: " + jsonObject); + + // Test that Nimbus JOSE JWT is not relocated. + JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS256; + logger.info("Nimbus JWSAlgorithm: " + jwsAlgorithm); + } + + /** Test large query execution with Arrow result format works. 
*/ + @Test + public void executeLargeQuery() throws SQLException { + Map params = new HashMap<>(); + params.put(DatabricksJdbcUrlParams.ENABLE_ARROW.getParamName(), "1"); + params.put(DatabricksJdbcUrlParams.USE_THRIFT_CLIENT.getParamName(), "0"); + + try (Connection connection = connect(params)) { + try (Statement statement = connection.createStatement()) { + final String sql = "SELECT * FROM samples.tpch.lineitem where 1 = 0"; + ResultSet result = statement.executeQuery(sql); + int totalRows = 0; + while (result.next()) { + if (totalRows % 100_000 == 0) { + logger.info("Processed " + totalRows + " rows"); + } + totalRows++; + } + + logger.info("Total " + totalRows + " rows processed"); + } + } + } + + private Connection connect(Map urlParams) throws SQLException { + Properties props = new Properties(); + props.setProperty("user", getDatabricksUser()); + props.setProperty("password", getDatabricksToken()); + for (Map.Entry entry : urlParams.entrySet()) { + props.setProperty(entry.getKey(), entry.getValue().toString()); + } + + String url = getDogfoodJDBCUrl(); + + return new com.databricks.client.jdbc.Driver().connect(url, props); + } + + private String getDogfoodJDBCUrl() { + String template = + "jdbc:databricks://%s/default;transportMode=http;ssl=1;AuthMech=3;httpPath=%s"; + String host = getDatabricksHost(); + String httpPath = getDatabricksHttpPath(); + + return String.format(template, host, httpPath); + } + + private String getDatabricksHttpPath() { + return System.getenv("DATABRICKS_HTTP_PATH"); + } + + private String getDatabricksHost() { + return System.getenv("DATABRICKS_HOST"); + } + + private String getDatabricksUser() { + return Optional.ofNullable(System.getenv("DATABRICKS_USER")).orElse("token"); + } + + private String getDatabricksToken() { + return System.getenv("DATABRICKS_TOKEN"); + } +} diff --git a/test-assembly-uber/pom.xml b/test-assembly-uber/pom.xml new file mode 100644 index 000000000..9155e3f8c --- /dev/null +++ 
b/test-assembly-uber/pom.xml @@ -0,0 +1,36 @@ + + + 4.0.0 + + + com.databricks + databricks-jdbc-parent + 3.2.2-SNAPSHOT + + + test-databricks-jdbc-uber + jar + Test Uber + Test JDBC databricks driver uber jar. + + + + true + + + + + com.databricks + databricks-jdbc + 3.2.2-SNAPSHOT + + + org.junit.jupiter + junit-jupiter + ${junit.jupiter.version} + test + + + \ No newline at end of file diff --git a/test-assembly-uber/src/test/java/com/databricks/jdbc/TestUberPackaging.java b/test-assembly-uber/src/test/java/com/databricks/jdbc/TestUberPackaging.java new file mode 100644 index 000000000..b8254cec3 --- /dev/null +++ b/test-assembly-uber/src/test/java/com/databricks/jdbc/TestUberPackaging.java @@ -0,0 +1,204 @@ +package com.databricks.jdbc; + +import com.databricks.internal.apache.commons.codec.binary.Base64; +import com.databricks.internal.apache.commons.io.IOUtils; +import com.databricks.internal.apache.commons.lang3.StringUtils; +import com.databricks.internal.apache.hc.client5.http.impl.classic.HttpClientBuilder; +import com.databricks.internal.apache.hc.client5.http.impl.classic.HttpClients; +import com.databricks.internal.apache.hc.core5.http.HttpStatus; +import com.databricks.internal.apache.thrift.TException; +import com.databricks.internal.bouncycastle.jce.provider.BouncyCastleProvider; +import com.databricks.internal.fasterxml.jackson.databind.ObjectMapper; +import com.databricks.internal.google.common.collect.ImmutableList; +import com.databricks.internal.google.flatbuffers.FlatBufferBuilder; +import com.databricks.internal.google.gson.Gson; +import com.databricks.internal.google.protobuf.ByteString; +import com.databricks.internal.io.github.resilience4j.circuitbreaker.CircuitBreakerConfig; +import com.databricks.internal.io.grpc.Context; +import com.databricks.internal.io.netty.buffer.ByteBufAllocator; +import com.databricks.internal.io.vavr.collection.List; +import com.databricks.internal.json.JSONObject; +import 
com.databricks.internal.jts.geom.GeometryFactory; +import com.databricks.internal.nimbusds.jose.JWSAlgorithm; +import com.databricks.internal.sdk.core.DatabricksConfig; +import com.databricks.jdbc.api.impl.arrow.ArrowBufferAllocator; +import com.databricks.jdbc.common.DatabricksJdbcUrlParams; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.Properties; +import java.util.logging.Logger; +import org.junit.jupiter.api.Test; + +/** Test artifacts are packaged properly. */ +public class TestUberPackaging { + /** Logger instance. */ + private static final java.util.logging.Logger logger = + Logger.getLogger(TestUberPackaging.class.getName()); + + /** Test packages are shaded as expected. */ + @Test + public void testThinPackaging() { + // Test that the "arrow" package is relocated. + com.databricks.internal.apache.arrow.memory.BufferAllocator bufferAllocator = + ArrowBufferAllocator.getBufferAllocator(); + logger.info("Shaded buffer allocator " + bufferAllocator); + + // Test that jackson packages are relocated. + ObjectMapper jacksonMapper = new ObjectMapper(); + logger.info("Shaded Jackson ObjectMapper: " + jacksonMapper); + + // Test that guava is relocated. + ImmutableList guavaList = ImmutableList.of("test"); + logger.info("Shaded Guava ImmutableList: " + guavaList); + + // Test that protobuf is relocated. + ByteString protoByteString = ByteString.copyFromUtf8("test"); + logger.info("Shaded Protobuf ByteString: " + protoByteString); + + // Test that commons-lang3 is relocated. + String commonsResult = StringUtils.upperCase("test"); + logger.info("Shaded Commons-Lang3 result: " + commonsResult); + + // Test that commons-codec is relocated (org.apache.commons.codec -> + // com.databricks.internal.apache.commons.codec). 
+ byte[] commonsCodec = Base64.encodeBase64("test".getBytes()); + logger.info("Shaded Commons-Codec Base64: " + new String(commonsCodec)); + + // Test that commons-io is relocated. + try { + String commonsIo = IOUtils.toString(new ByteArrayInputStream("test".getBytes()), "UTF-8"); + logger.info("Shaded Commons-IO result: " + commonsIo); + } catch (IOException e) { + throw new RuntimeException("Failed to test Commons-IO shading", e); + } + + // Test that httpclient5 is relocated (org.apache.hc.client5 -> + // com.databricks.internal.apache.hc.client5). + HttpClientBuilder httpClientBuilder = HttpClients.custom(); + logger.info("Shaded HttpClient5 builder: " + httpClientBuilder); + + // Test that httpcore5 is relocated (org.apache.hc.core5 -> + // com.databricks.internal.apache.hc.core5). + int httpStatus = HttpStatus.SC_OK; + logger.info("Shaded HttpCore5 status: " + httpStatus); + + // Test that thrift is relocated. + TException thriftException = new TException("test"); + logger.info("Shaded Thrift TException: " + thriftException.getMessage()); + + // Test that gson is relocated. + Gson gson = new Gson(); + logger.info("Shaded Gson: " + gson); + + // Test that flatbuffers is relocated. + FlatBufferBuilder flatBuilder = new FlatBufferBuilder(); + logger.info("Shaded FlatBuffers: " + flatBuilder); + + // Test that netty is relocated (io.netty -> com.databricks.internal.io.netty). + ByteBufAllocator nettyAllocator = ByteBufAllocator.DEFAULT; + logger.info("Shaded Netty ByteBufAllocator: " + nettyAllocator); + + // Test that grpc is relocated (io.grpc -> com.databricks.internal.io.grpc). + Context grpcContext = Context.current(); + logger.info("Shaded gRPC Context: " + grpcContext); + + // Test that bouncycastle is relocated (org.bouncycastle -> + // com.databricks.internal.bouncycastle). 
+ BouncyCastleProvider bcProvider = new BouncyCastleProvider(); + logger.info("Shaded BouncyCastle Provider: " + bcProvider.getName()); + + // Test that resilience4j is relocated (io.github.resilience4j -> + // com.databricks.internal.io.github.resilience4j). + CircuitBreakerConfig cbConfig = CircuitBreakerConfig.ofDefaults(); + logger.info("Shaded Resilience4j CircuitBreakerConfig: " + cbConfig); + + // Test that vavr is relocated (io.vavr -> com.databricks.internal.io.vavr). + List vavrList = List.of("test"); + logger.info("Shaded Vavr List: " + vavrList); + + // Test that JTS is relocated (org.locationtech.jts -> com.databricks.internal.jts). + GeometryFactory jtsFactory = new GeometryFactory(); + logger.info("Shaded JTS GeometryFactory: " + jtsFactory); + + // Test that Databricks SDK is relocated (com.databricks.sdk -> com.databricks.internal.sdk). + Class sdkClass = DatabricksConfig.class; + logger.info("Shaded Databricks SDK class: " + sdkClass.getName()); + + // Test that JSON is relocated (org.json -> com.databricks.internal.json). + JSONObject jsonObject = new JSONObject().put("key", "value"); + logger.info("Shaded JSON object: " + jsonObject); + + // Test that Nimbus JOSE JWT is relocated (com.nimbusds -> com.databricks.internal.nimbusds). + JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS256; + logger.info("Shaded Nimbus JWSAlgorithm: " + jwsAlgorithm); + } + + /** Test large query execution with Arrow result format works. 
*/ + @Test + public void executeLargeQuery() throws SQLException { + Map params = new HashMap<>(); + params.put(DatabricksJdbcUrlParams.ENABLE_ARROW.getParamName(), "1"); + params.put(DatabricksJdbcUrlParams.USE_THRIFT_CLIENT.getParamName(), "0"); + + try (Connection connection = connect(params)) { + try (Statement statement = connection.createStatement()) { + final String sql = "SELECT * FROM samples.tpch.lineitem where 1 = 0"; + ResultSet result = statement.executeQuery(sql); + int totalRows = 0; + while (result.next()) { + if (totalRows % 100_000 == 0) { + logger.info("Processed " + totalRows + " rows"); + } + totalRows++; + } + + logger.info("Total " + totalRows + " rows processed"); + } + } + } + + private Connection connect(Map urlParams) throws SQLException { + Properties props = new Properties(); + props.setProperty("user", getDatabricksUser()); + props.setProperty("password", getDatabricksToken()); + for (Map.Entry entry : urlParams.entrySet()) { + props.setProperty(entry.getKey(), entry.getValue().toString()); + } + + String url = getDogfoodJDBCUrl(); + + return new com.databricks.client.jdbc.Driver().connect(url, props); + } + + private String getDogfoodJDBCUrl() { + String template = + "jdbc:databricks://%s/default;transportMode=http;ssl=1;AuthMech=3;httpPath=%s"; + String host = getDatabricksHost(); + String httpPath = getDatabricksHttpPath(); + + return String.format(template, host, httpPath); + } + + private String getDatabricksHttpPath() { + return System.getenv("DATABRICKS_HTTP_PATH"); + } + + private String getDatabricksHost() { + return System.getenv("DATABRICKS_HOST"); + } + + private String getDatabricksUser() { + return Optional.ofNullable(System.getenv("DATABRICKS_USER")).orElse("token"); + } + + private String getDatabricksToken() { + return System.getenv("DATABRICKS_TOKEN"); + } +} diff --git a/thin_public_pom.xml b/thin_public_pom.xml deleted file mode 100644 index 5b2afec7f..000000000 --- a/thin_public_pom.xml +++ /dev/null @@ -1,214 +0,0 
@@ - - - 4.0.0 - com.databricks - databricks-jdbc-thin - 3.1.2 - jar - Databricks JDBC Driver Thin - Databricks JDBC Driver Thin JAR - requires external dependencies. - https://github.com/databricks/databricks-jdbc - - - - Apache License, Version 2.0 - https://github.com/databricks/databricks-jdbc/blob/main/LICENSE - - - - - - Databricks JDBC Team - eng-oss-sql-driver@databricks.com - Databricks - https://www.databricks.com - - - - - scm:git:https://github.com/databricks/databricks-jdbc.git - scm:git:https://github.com/databricks/databricks-jdbc.git - https://github.com/databricks/databricks-jdbc - - - - GitHub Issues - https://github.com/databricks/databricks-jdbc/issues - - - - - - com.databricks - databricks-sdk-java - 0.69.0 - - - - - org.apache.commons - commons-lang3 - 3.18.0 - - - org.apache.commons - commons-configuration2 - 2.10.1 - - - commons-io - commons-io - 2.14.0 - - - - - org.apache.arrow - arrow-memory-core - 17.0.0 - - - org.apache.arrow - arrow-memory-unsafe - 17.0.0 - - - org.apache.arrow - arrow-vector - 17.0.0 - - - org.apache.arrow - arrow-memory-netty - 17.0.0 - - - - - org.apache.httpcomponents - httpclient - 4.5.14 - - - org.apache.httpcomponents.client5 - httpclient5 - 5.3.1 - - - org.apache.httpcomponents.core5 - httpcore5 - 5.3.1 - - - - - org.apache.thrift - libthrift - 0.19.0 - - - - - org.slf4j - slf4j-api - 2.0.13 - - - - - com.google.code.findbugs - annotations - 3.0.1 - - - com.google.guava - guava - 33.0.0-jre - - - - - com.nimbusds - nimbus-jose-jwt - 10.0.2 - - - org.bouncycastle - bcprov-jdk18on - 1.79 - - - org.bouncycastle - bcpkix-jdk18on - 1.79 - - - - - com.fasterxml.jackson.core - jackson-databind - 2.18.3 - - - com.fasterxml.jackson.core - jackson-annotations - 2.18.3 - - - com.fasterxml.jackson.core - jackson-core - 2.18.3 - - - com.google.code.gson - gson - 2.13.2 - - - - - at.yawk.lz4 - lz4-java - 1.10.1 - - - - - io.grpc - grpc-context - 1.71.0 - - - - - io.netty - netty-common - 4.2.6.Final - - - io.netty - 
netty-buffer - 4.2.6.Final - - - - - jakarta.annotation - jakarta.annotation-api - 1.3.5 - - - - - io.github.resilience4j - resilience4j-circuitbreaker - 1.7.0 - - - io.github.resilience4j - resilience4j-core - 1.7.0 - - - \ No newline at end of file diff --git a/uber-minimal-pom.xml b/uber-minimal-pom.xml deleted file mode 100644 index b2bd5cbe9..000000000 --- a/uber-minimal-pom.xml +++ /dev/null @@ -1,36 +0,0 @@ - - - 4.0.0 - com.databricks - databricks-jdbc - - 3.2.1 - jar - Databricks JDBC Driver - Databricks JDBC Driver. - https://github.com/databricks/databricks-jdbc - - - Apache License, Version 2.0 - https://github.com/databricks/databricks-jdbc/blob/main/LICENSE - - - - - Databricks JDBC Team - eng-oss-sql-driver@databricks.com - Databricks - https://www.databricks.com - - - - scm:git:https://github.com/databricks/databricks-jdbc.git - scm:git:https://github.com/databricks/databricks-jdbc.git - https://github.com/databricks/databricks-jdbc - - - GitHub Issues - https://github.com/databricks/databricks-jdbc/issues - -