diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Common.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Common.java index f77f48d1..4f6ac436 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Common.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Common.java @@ -1,13 +1,13 @@ // SPDX-License-Identifier: Apache-2.0 package com.hedera.pbj.compiler.impl; -import com.hedera.pbj.compiler.impl.grammar.Protobuf3Parser; +import com.hedera.pbj.compiler.impl.Field.FieldType; +import com.hedera.pbj.compiler.impl.grammar.Protobuf3Parser.DocCommentContext; import edu.umd.cs.findbugs.annotations.NonNull; import java.io.File; import java.util.Arrays; import java.util.List; import java.util.Objects; -import java.util.regex.Pattern; import java.util.stream.Collectors; /** @@ -30,8 +30,6 @@ public final class Common { /** Number of bits used to represent the tag type */ static final int TAG_TYPE_BITS = 3; - private static final Pattern COMPARABLE_PATTERN = Pattern.compile("implements Comparable<\\w+>\\s*\\{"); - /** * Makes a tag value given a field number and wire type. * @@ -120,7 +118,7 @@ public static String camelToUpperSnake(String name) { * * @return clean comment */ - public static String buildCleanFieldJavaDoc(int fieldNumber, Protobuf3Parser.DocCommentContext docContext) { + public static String buildCleanFieldJavaDoc(int fieldNumber, DocCommentContext docContext) { final String cleanedComment = docContext == null ? "" : cleanJavaDocComment(docContext.getText()); final String fieldNumComment = "(" + fieldNumber + ") "; return fieldNumComment + cleanedComment; @@ -134,8 +132,7 @@ public static String buildCleanFieldJavaDoc(int fieldNumber, Protobuf3Parser.Doc * * @return clean comment */ - public static String buildCleanFieldJavaDoc( - List fieldNumbers, Protobuf3Parser.DocCommentContext docContext) { + public static String buildCleanFieldJavaDoc(List fieldNumbers, DocCommentContext docContext) { final String cleanedComment = docContext == null ? "" : cleanJavaDocComment(docContext.getText()); final String fieldNumComment = "(" + fieldNumbers.stream().map(Objects::toString).collect(Collectors.joining(", ")) + ") "; @@ -208,229 +205,6 @@ public static String javaPrimitiveToObjectType(String primitiveFieldType) { }; } - /** - * Recursively calculates the hashcode for a message fields. - * - * @param fields The fields of this object. - * @param generatedCodeSoFar The accumulated hash code so far. - * - * @return The generated code for getting the hashCode value. 
- */ - public static String getFieldsHashCode(final List fields, String generatedCodeSoFar) { - for (Field f : fields) { - if (f.parent() != null) { - final OneOfField oneOfField = f.parent(); - generatedCodeSoFar += getFieldsHashCode(oneOfField.fields(), generatedCodeSoFar); - } else if (f.optionalValueType()) { - generatedCodeSoFar = getPrimitiveWrapperHashCodeGeneration(generatedCodeSoFar, f); - } else if (f.repeated()) { - generatedCodeSoFar = getRepeatedHashCodeGeneration(generatedCodeSoFar, f); - } else { - if (f.type() == Field.FieldType.FIXED32 - || f.type() == Field.FieldType.INT32 - || f.type() == Field.FieldType.SFIXED32 - || f.type() == Field.FieldType.SINT32 - || f.type() == Field.FieldType.UINT32) { - generatedCodeSoFar += - (""" - if ($fieldName != DEFAULT.$fieldName) { - result = 31 * result + Integer.hashCode($fieldName); - } - """) - .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.FIXED64 - || f.type() == Field.FieldType.INT64 - || f.type() == Field.FieldType.SFIXED64 - || f.type() == Field.FieldType.SINT64 - || f.type() == Field.FieldType.UINT64) { - generatedCodeSoFar += - (""" - if ($fieldName != DEFAULT.$fieldName) { - result = 31 * result + Long.hashCode($fieldName); - } - """) - .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.BOOL) { - generatedCodeSoFar += - (""" - if ($fieldName != DEFAULT.$fieldName) { - result = 31 * result + Boolean.hashCode($fieldName); - } - """) - .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.FLOAT) { - generatedCodeSoFar += - (""" - if ($fieldName != DEFAULT.$fieldName) { - result = 31 * result + Float.hashCode($fieldName); - } - """) - .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.DOUBLE) { - generatedCodeSoFar += - (""" - if ($fieldName != DEFAULT.$fieldName) { - result = 31 * result + Double.hashCode($fieldName); - } - """) - .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.BYTES) { - generatedCodeSoFar += - (""" - if ($fieldName != null && !$fieldName.equals(DEFAULT.$fieldName)) { - result = 31 * result + $fieldName.hashCode(); - } - """) - .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.ENUM) { - generatedCodeSoFar += - (""" - if ($fieldName != null && !$fieldName.equals(DEFAULT.$fieldName)) { - result = 31 * result + Integer.hashCode($fieldName.protoOrdinal()); - } - """) - .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.MAP) { - generatedCodeSoFar += getMapHashCodeGeneration(generatedCodeSoFar, f); - } else if (f.type() == Field.FieldType.STRING || f.parent() == null) { // process sub message - generatedCodeSoFar += - (""" - if ($fieldName != null && !$fieldName.equals(DEFAULT.$fieldName)) { - result = 31 * result + $fieldName.hashCode(); - } - """) - .replace("$fieldName", f.nameCamelFirstLower()); - } else { - throw new RuntimeException("Unexpected field type for getting HashCode - " - + f.type().toString()); - } - } - } - return generatedCodeSoFar.indent(DEFAULT_INDENT * 3); - } - - /** - * Get the hashcode codegen for a optional field. - * - * @param generatedCodeSoFar The string that the codegen is generated into. - * @param f The field for which to generate the hash code. - * - * @return Updated codegen string. 
- */ - @NonNull - private static String getPrimitiveWrapperHashCodeGeneration(String generatedCodeSoFar, Field f) { - switch (f.messageType()) { - case "StringValue" -> generatedCodeSoFar += - (""" - if ($fieldName != null && !$fieldName.equals(DEFAULT.$fieldName)) { - result = 31 * result + $fieldName.hashCode(); - } - """) - .replace("$fieldName", f.nameCamelFirstLower()); - case "BoolValue" -> generatedCodeSoFar += - (""" - if ($fieldName != null && !$fieldName.equals(DEFAULT.$fieldName)) { - result = 31 * result + Boolean.hashCode($fieldName); - } - """) - .replace("$fieldName", f.nameCamelFirstLower()); - case "Int32Value", "UInt32Value" -> generatedCodeSoFar += - (""" - if ($fieldName != null && !$fieldName.equals(DEFAULT.$fieldName)) { - result = 31 * result + Integer.hashCode($fieldName); - } - """) - .replace("$fieldName", f.nameCamelFirstLower()); - case "Int64Value", "UInt64Value" -> generatedCodeSoFar += - (""" - if ($fieldName != null && !$fieldName.equals(DEFAULT.$fieldName)) { - result = 31 * result + Long.hashCode($fieldName); - } - """) - .replace("$fieldName", f.nameCamelFirstLower()); - case "FloatValue" -> generatedCodeSoFar += - (""" - if ($fieldName != null && !$fieldName.equals(DEFAULT.$fieldName)) { - result = 31 * result + Float.hashCode($fieldName); - } - """) - .replace("$fieldName", f.nameCamelFirstLower()); - case "DoubleValue" -> generatedCodeSoFar += - (""" - if ($fieldName != null && !$fieldName.equals(DEFAULT.$fieldName)) { - result = 31 * result + Double.hashCode($fieldName); - } - """) - .replace("$fieldName", f.nameCamelFirstLower()); - case "BytesValue" -> generatedCodeSoFar += - (""" - if ($fieldName != null && !$fieldName.equals(DEFAULT.$fieldName)) { - result = 31 * result + ($fieldName == null ? 0 : $fieldName.hashCode()); - } - """) - .replace("$fieldName", f.nameCamelFirstLower()); - default -> throw new UnsupportedOperationException("Unhandled optional message type:" + f.messageType()); - } - return generatedCodeSoFar; - } - - /** - * Get the hashcode codegen for a repeated field. - * - * @param generatedCodeSoFar The string that the codegen is generated into. - * @param f The field for which to generate the hash code. - * - * @return Updated codegen string. - */ - @NonNull - private static String getRepeatedHashCodeGeneration(String generatedCodeSoFar, Field f) { - generatedCodeSoFar += - (""" - java.util.List list$$fieldName = $fieldName; - if (list$$fieldName != null) { - for (Object o : list$$fieldName) { - if (o != null) { - result = 31 * result + o.hashCode(); - } else { - result = 31 * result; - } - } - } - """) - .replace("$fieldName", f.nameCamelFirstLower()); - return generatedCodeSoFar; - } - - /** - * Get the hashcode codegen for a map field. - * - * @param generatedCodeSoFar The string that the codegen is generated into. - * @param f The field for which to generate the hash code. - * - * @return Updated codegen string. - */ - @NonNull - private static String getMapHashCodeGeneration(String generatedCodeSoFar, final Field f) { - generatedCodeSoFar += - (""" - for (Object k : ((PbjMap) $fieldName).getSortedKeys()) { - if (k != null) { - result = 31 * result + k.hashCode(); - } else { - result = 31 * result; - } - Object v = $fieldName.get(k); - if (v != null) { - result = 31 * result + v.hashCode(); - } else { - result = 31 * result; - } - } - """) - .replace("$fieldName", f.nameCamelFirstLower()); - return generatedCodeSoFar; - } - /** * Recursively calculates `equals` statement for a message fields. 
* @@ -452,11 +226,11 @@ public static String getFieldsEqualsStatements(final List fields, String generatedCodeSoFar = getRepeatedEqualsGeneration(generatedCodeSoFar, f); } else { f.nameCamelFirstLower(); - if (f.type() == Field.FieldType.FIXED32 - || f.type() == Field.FieldType.INT32 - || f.type() == Field.FieldType.SFIXED32 - || f.type() == Field.FieldType.SINT32 - || f.type() == Field.FieldType.UINT32) { + if (f.type() == FieldType.FIXED32 + || f.type() == FieldType.INT32 + || f.type() == FieldType.SFIXED32 + || f.type() == FieldType.SINT32 + || f.type() == FieldType.UINT32) { generatedCodeSoFar += """ if ($fieldName != thatObj.$fieldName) { @@ -464,11 +238,11 @@ public static String getFieldsEqualsStatements(final List fields, String } """ .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.FIXED64 - || f.type() == Field.FieldType.INT64 - || f.type() == Field.FieldType.SFIXED64 - || f.type() == Field.FieldType.SINT64 - || f.type() == Field.FieldType.UINT64) { + } else if (f.type() == FieldType.FIXED64 + || f.type() == FieldType.INT64 + || f.type() == FieldType.SFIXED64 + || f.type() == FieldType.SINT64 + || f.type() == FieldType.UINT64) { generatedCodeSoFar += """ if ($fieldName != thatObj.$fieldName) { @@ -476,7 +250,7 @@ public static String getFieldsEqualsStatements(final List fields, String } """ .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.BOOL) { + } else if (f.type() == FieldType.BOOL) { generatedCodeSoFar += """ if ($fieldName != thatObj.$fieldName) { @@ -484,7 +258,7 @@ public static String getFieldsEqualsStatements(final List fields, String } """ .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.FLOAT) { + } else if (f.type() == FieldType.FLOAT) { generatedCodeSoFar += """ if ($fieldName != thatObj.$fieldName) { @@ -492,7 +266,7 @@ public static String getFieldsEqualsStatements(final List fields, String } """ .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.DOUBLE) { + } else if (f.type() == FieldType.DOUBLE) { generatedCodeSoFar += """ if ($fieldName != thatObj.$fieldName) { @@ -500,10 +274,10 @@ public static String getFieldsEqualsStatements(final List fields, String } """ .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.STRING - || f.type() == Field.FieldType.BYTES - || f.type() == Field.FieldType.ENUM - || f.type() == Field.FieldType.MAP + } else if (f.type() == FieldType.STRING + || f.type() == FieldType.BYTES + || f.type() == FieldType.ENUM + || f.type() == FieldType.MAP || f.parent() == null /* Process a sub-message */) { generatedCodeSoFar += (""" @@ -525,7 +299,7 @@ public static String getFieldsEqualsStatements(final List fields, String } /** - * Get the equals codegen for a optional field. + * Get the equals codegen for an optional field. * * @param generatedCodeSoFar The string that the codegen is generated into. * @param f The field for which to generate the equals code. 
@@ -543,8 +317,9 @@ private static String getPrimitiveWrapperEqualsGeneration(String generatedCodeSo "UInt64Value", "FloatValue", "DoubleValue", - "BytesValue" -> generatedCodeSoFar += - (""" + "BytesValue" -> + generatedCodeSoFar += + (""" if (this.$fieldName == null && thatObj.$fieldName != null) { return false; } @@ -552,7 +327,7 @@ private static String getPrimitiveWrapperEqualsGeneration(String generatedCodeSo return false; } """) - .replace("$fieldName", f.nameCamelFirstLower()); + .replace("$fieldName", f.nameCamelFirstLower()); default -> throw new UnsupportedOperationException("Unhandled optional message type:" + f.messageType()); } return generatedCodeSoFar; @@ -597,10 +372,10 @@ public static String getFieldsCompareToStatements(final List fields, Stri } else if (f.repeated()) { throw new UnsupportedOperationException("Repeated fields are not supported in compareTo method"); } else { - if (f.type() == Field.FieldType.FIXED32 - || f.type() == Field.FieldType.INT32 - || f.type() == Field.FieldType.SFIXED32 - || f.type() == Field.FieldType.SINT32) { + if (f.type() == FieldType.FIXED32 + || f.type() == FieldType.INT32 + || f.type() == FieldType.SFIXED32 + || f.type() == FieldType.SINT32) { generatedCodeSoFar += """ result = Integer.compare($fieldName, thatObj.$fieldName); @@ -609,7 +384,7 @@ public static String getFieldsCompareToStatements(final List fields, Stri } """ .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.UINT32) { + } else if (f.type() == FieldType.UINT32) { generatedCodeSoFar += """ result = Integer.compareUnsigned($fieldName, thatObj.$fieldName); @@ -619,10 +394,10 @@ public static String getFieldsCompareToStatements(final List fields, Stri """ .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.FIXED64 - || f.type() == Field.FieldType.INT64 - || f.type() == Field.FieldType.SFIXED64 - || f.type() == Field.FieldType.SINT64) { + } else if (f.type() == FieldType.FIXED64 + || f.type() == FieldType.INT64 + || f.type() == FieldType.SFIXED64 + || f.type() == FieldType.SINT64) { generatedCodeSoFar += """ result = Long.compare($fieldName, thatObj.$fieldName); @@ -631,7 +406,7 @@ public static String getFieldsCompareToStatements(final List fields, Stri } """ .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.UINT64) { + } else if (f.type() == FieldType.UINT64) { generatedCodeSoFar += """ result = Long.compareUnsigned($fieldName, thatObj.$fieldName); @@ -640,7 +415,7 @@ public static String getFieldsCompareToStatements(final List fields, Stri } """ .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.BOOL) { + } else if (f.type() == FieldType.BOOL) { generatedCodeSoFar += """ result = Boolean.compare($fieldName, thatObj.$fieldName); @@ -649,7 +424,7 @@ public static String getFieldsCompareToStatements(final List fields, Stri } """ .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.FLOAT) { + } else if (f.type() == FieldType.FLOAT) { generatedCodeSoFar += """ result = Float.compare($fieldName, thatObj.$fieldName); @@ -658,7 +433,7 @@ public static String getFieldsCompareToStatements(final List fields, Stri } """ .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.DOUBLE) { + } else if (f.type() == FieldType.DOUBLE) { generatedCodeSoFar += """ result = Double.compare($fieldName, thatObj.$fieldName); @@ -667,11 +442,9 @@ public static String 
getFieldsCompareToStatements(final List fields, Stri } """ .replace("$fieldName", f.nameCamelFirstLower()); - } else if (f.type() == Field.FieldType.STRING - || f.type() == Field.FieldType.BYTES - || f.type() == Field.FieldType.ENUM) { + } else if (f.type() == FieldType.STRING || f.type() == FieldType.BYTES || f.type() == FieldType.ENUM) { generatedCodeSoFar += generateCompareToForObject(f); - } else if (f.type() == Field.FieldType.MESSAGE || f.type() == Field.FieldType.ONE_OF) { + } else if (f.type() == FieldType.MESSAGE || f.type() == FieldType.ONE_OF) { verifyComparable(f); generatedCodeSoFar += generateCompareToForObject(f); } else { @@ -709,7 +482,7 @@ private static String generateCompareToForObject(Field f) { */ private static void verifyComparable(final Field field) { if (field instanceof final SingleField singleField) { - if (singleField.type() != Field.FieldType.MESSAGE) { + if (singleField.type() != FieldType.MESSAGE) { // everything else except message and bytes is comparable for sure return; } @@ -760,8 +533,8 @@ private static String getPrimitiveWrapperCompareToGeneration(Field f) { case "UInt64Value" -> "java.lang.Long.compareUnsigned($fieldName, thatObj.$fieldName)"; case "FloatValue" -> "java.lang.Float.compare($fieldName, thatObj.$fieldName)"; case "DoubleValue" -> "java.lang.Double.compare($fieldName, thatObj.$fieldName)"; - default -> throw new UnsupportedOperationException( - "Unhandled optional message type:" + f.messageType()); + default -> + throw new UnsupportedOperationException("Unhandled optional message type:" + f.messageType()); }; return template.replace("$compareStatement", compareStatement).replace("$fieldName", f.nameCamelFirstLower()); diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/MapField.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/MapField.java index cafbd838..ce3dd52b 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/MapField.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/MapField.java @@ -21,9 +21,9 @@ */ public record MapField( /* A synthetic "key" field in a map entry. */ - Field keyField, + SingleField keyField, /* A synthetic "value" field in a map entry. */ - Field valueField, + SingleField valueField, // The rest of the fields below simply implement the Field interface: boolean repeated, int fieldNumber, @@ -94,7 +94,7 @@ public MapField(Protobuf3Parser.MapFieldContext mapContext, final ContextualLook */ public String javaGenericType() { final String fieldTypeName = valueField().type() == FieldType.MESSAGE - ? ((SingleField) valueField()).messageType() + ? 
valueField().messageType() : valueField().type().boxedType; return "<%s, %s>".formatted(keyField.type().boxedType, fieldTypeName); } diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/SingleField.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/SingleField.java index 8fc30966..6ec567c8 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/SingleField.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/SingleField.java @@ -137,6 +137,17 @@ public String javaFieldTypeBase() { return javaFieldType(false); } + public String javaFieldTypeBoxed() { + return switch (type) { + case BOOL -> "Boolean"; + case INT32, UINT32, SINT32, FIXED32, SFIXED32 -> "Integer"; + case INT64, SINT64, UINT64, FIXED64, SFIXED64 -> "Long"; + case FLOAT -> "Float"; + case DOUBLE -> "Double"; + default -> javaFieldType(); + }; + } + @NonNull private String javaFieldType(boolean considerRepeated) { String fieldType = diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelGenerator.java index 8fff2c71..6b566e68 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelGenerator.java @@ -6,7 +6,6 @@ import static com.hedera.pbj.compiler.impl.Common.camelToUpperSnake; import static com.hedera.pbj.compiler.impl.Common.cleanDocStr; import static com.hedera.pbj.compiler.impl.Common.cleanJavaDocComment; -import static com.hedera.pbj.compiler.impl.Common.getFieldsHashCode; import static com.hedera.pbj.compiler.impl.Common.javaPrimitiveToObjectType; import static com.hedera.pbj.compiler.impl.generators.EnumGenerator.EnumValue; import static com.hedera.pbj.compiler.impl.generators.EnumGenerator.createEnum; @@ -42,21 +41,6 @@ public final class ModelGenerator implements Generator { private static final String NON_NULL_ANNOTATION = "@NonNull"; - private static final String HASH_CODE_MANIPULATION = - """ - // Shifts: 30, 27, 16, 20, 5, 18, 10, 24, 30 - hashCode += hashCode << 30; - hashCode ^= hashCode >>> 27; - hashCode += hashCode << 16; - hashCode ^= hashCode >>> 20; - hashCode += hashCode << 5; - hashCode ^= hashCode >>> 18; - hashCode += hashCode << 10; - hashCode ^= hashCode >>> 24; - hashCode += hashCode << 30; - """ - .indent(DEFAULT_INDENT * 2); - /** * {@inheritDoc} * @@ -94,6 +78,10 @@ public void generate( writer.addImport("static " + lookupHelper.getFullyQualifiedMessageClassname(FileType.SCHEMA, msgDef) + ".*"); writer.addImport("java.util.Collections"); writer.addImport("java.util.List"); + writer.addImport("com.hedera.pbj.runtime.hashing.XXH3_64"); + writer.addImport("com.hedera.pbj.runtime.hashing.XXH3FieldHash"); + writer.addImport("com.hedera.pbj.runtime.hashing.SixtyFourBitHashable"); + writer.addImport("static com.hedera.pbj.runtime.hashing.XXH3FieldHash.*"); // Iterate over all the items in the protobuf schema for (final var item : msgDef.messageBody().messageElement()) { @@ -125,7 +113,7 @@ public void generate( // add precomputed fields to fields fields.add(new SingleField( false, - FieldType.FIXED32, + FieldType.FIXED64, -1, "$hashCode", null, @@ -242,7 +230,7 @@ public void generate( bodyContent += "\n"; // hashCode method - bodyContent += generateHashCode(fieldsNoPrecomputed); + bodyContent += 
ModelHashCodeGenerator.generateHashCode(fieldsNoPrecomputed); bodyContent += "\n"; // equals method @@ -310,11 +298,11 @@ private void generateClass( final boolean isComparable, final ContextualLookupHelper lookupHelper) throws IOException { - final String implementsComparable; + final String implementsCode; if (isComparable) { - implementsComparable = "implements Comparable<$javaRecordName> "; + implementsCode = "implements SixtyFourBitHashable, Comparable<$javaRecordName> "; } else { - implementsComparable = ""; + implementsCode = "implements SixtyFourBitHashable "; } final String staticModifier = Generator.isInner(msgDef) ? " static" : ""; @@ -330,6 +318,7 @@ private void generateClass( // spotless:off writer.append(""" $javaDocComment$deprecated + @java.lang.SuppressWarnings("ForLoopReplaceableByForEach") public final$staticModifier class $javaRecordName $implementsComparable{ $bodyContent @@ -337,7 +326,7 @@ private void generateClass( .replace("$javaDocComment", javaDocComment) .replace("$deprecated", deprecated) .replace("$staticModifier", staticModifier) - .replace("$implementsComparable", implementsComparable) + .replace("$implementsComparable", implementsCode) .replace("$javaRecordName", javaRecordName) .replace("$bodyContent", bodyContent)); // spotless:on @@ -531,63 +520,6 @@ public boolean equals(Object that) { return bodyContent; } - /** - * Generates the hashCode method - * - * @param fields the fields to use for the code generation - * - * @return the generated code - */ - @NonNull - private static String generateHashCode(final List fields) { - // Generate a call to private method that iterates through fields and calculates the hashcode - final String statements = getFieldsHashCode(fields, ""); - // spotless:off - String bodyContent = - """ - /** - * Override the default hashCode method for to make hashCode better distributed and follows protobuf rules - * for default values. This is important for backward compatibility. This also lazy computes and caches the - * hashCode for future calls. It is designed to be thread safe. - */ - @Override - public int hashCode() { - // The $hashCode field is subject to a benign data race, making it crucial to ensure that any - // observable result of the calculation in this method stays correct under any possible read of this - // field. Necessary restrictions to allow this to be correct without explicit memory fences or similar - // concurrency primitives is that we can ever only write to this field for a given Model object - // instance, and that the computation is idempotent and derived from immutable state. - // This is the same trick used in java.lang.String.hashCode() to avoid synchronization. 
- - if($hashCode == -1) { - int result = 1; - """.indent(DEFAULT_INDENT); - - bodyContent += statements; - - bodyContent += - """ - if ($unknownFields != null) { - for (int i = 0; i < $unknownFields.size(); i++) { - result = 31 * result + $unknownFields.get(i).hashCode(); - } - } - """.indent(DEFAULT_INDENT); - - bodyContent += - """ - long hashCode = result; - $hashCodeManipulation - $hashCode = (int)hashCode; - } - return $hashCode; - } - """.replace("$hashCodeManipulation", HASH_CODE_MANIPULATION) - .indent(DEFAULT_INDENT); - // spotless:on - return bodyContent; - } - /** * Generates the toString method, based on how Java records generate toStrings * diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelHashCodeGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelHashCodeGenerator.java new file mode 100644 index 00000000..7fb0ea2e --- /dev/null +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelHashCodeGenerator.java @@ -0,0 +1,272 @@ +// SPDX-License-Identifier: Apache-2.0 +package com.hedera.pbj.compiler.impl.generators; + +import static com.hedera.pbj.compiler.impl.Common.DEFAULT_INDENT; + +import com.hedera.pbj.compiler.impl.Common; +import com.hedera.pbj.compiler.impl.Field; +import com.hedera.pbj.compiler.impl.Field.FieldType; +import com.hedera.pbj.compiler.impl.MapField; +import com.hedera.pbj.compiler.impl.OneOfField; +import com.hedera.pbj.compiler.impl.SingleField; +import edu.umd.cs.findbugs.annotations.NonNull; +import java.util.List; + +/** + * Generates the hashCode() and hashCode64() methods for a model class. + */ +@SuppressWarnings("StringConcatenationInLoop") +public class ModelHashCodeGenerator { + /** The mixing code for the hashCode generation. After each field */ + private static final String MIX_CODE = """ + if(($xx_fieldCount & 1) == 0) { + $xx_acc += mixPair($xx_value, $xx_carry, $xx_fieldCount++ >> 1); + } else { + $xx_carry = $xx_value; + $xx_fieldCount ++; + }"""; + + /** + * Generates the hashCode method + * + * @param fields the fields to use for the code generation + * @return the generated code + */ + @NonNull + static String generateHashCode(final List fields) { + // spotless:off + String bodyContent = + """ + /** + * Override the default hashCode method for to make hashCode better distributed and follows protobuf rules + * for default values. This is important for backward compatibility. This also lazy computes and caches the + * hashCode for future calls. It is designed to be thread safe. + */ + @Override + public int hashCode() { + return(int)hashCode64(); + } + + /** + * Extended 64bit hashCode method for to make hashCode better distributed and follows protobuf rules + * for default values. This is important for backward compatibility. This also lazy computes and caches the + * hashCode for future calls. It is designed to be thread safe. + */ + public long hashCode64() { + // The $hashCode field is subject to a benign data race, making it crucial to ensure that any + // observable result of the calculation in this method stays correct under any possible read of this + // field. Necessary restrictions to allow this to be correct without explicit memory fences or similar + // concurrency primitives is that we can ever only write to this field for a given Model object + // instance, and that the computation is idempotent and derived from immutable state. 
+ // This is the same trick used in java.lang.String.hashCode() to avoid synchronization. + + if($hashCode == -1) { + long $xx_acc = 0L; + int $xx_fieldCount = 0; // pairCount + long $xx_carry = 0L; + long $xx_value = 0L; + """.indent(DEFAULT_INDENT); + + // Generate a call to private method that iterates through fields and calculates the hashcode + bodyContent += fieldsContainMapsOrRepeatedListFields(fields) ? + getFieldsHashCode(fields) : getFieldsHashCodeFixedLength(fields); + bodyContent += + """ + if ($unknownFields != null) { + for (int i = 0; i < $unknownFields.size(); i++) { + // For unknown fields, if they are default value they will not be on the wire + // there for they will not be in the $unknownFields list. So we can safely + // assume that if we are here, the field is not default value. + if (($xx_fieldCount & 1) == 0) { + $xx_acc += mixPair($xx_carry, $unknownFields.get(i).hashCode64(), $xx_fieldCount++ >> 1); + } else { + $xx_carry = $unknownFields.get(i).hashCode64(); + $xx_fieldCount ++; + } + } + } + if (($xx_fieldCount & 1) == 0) { + // If we have an odd number of pairs, we need to mix the last carry value + $xx_acc += mixTail8($xx_carry, $xx_fieldCount >> 1); + } + $hashCode = finish($xx_acc, (long)$xx_fieldCount << 3); // $xx_fieldCount * 8 + } + return $hashCode; + } + """.indent(DEFAULT_INDENT); + // spotless:on + return bodyContent; + } + + /** + * Checks if the fields contain maps or repeated list fields. + * + * @param fields The fields to check. + * @return true if the fields contain maps or repeated list fields, false otherwise. + */ + private static boolean fieldsContainMapsOrRepeatedListFields(final List fields) { + for (final Field f : fields) { + if (f.type() == FieldType.MAP || f.repeated()) { + return true; + } + } + return false; + } + + /** + * Adds code for each field, fast path for case where there are no repeating field types like map or list + * + * @param fields The fields of this object. + * @return The generated code for getting the hashCode value. + */ + private static String getFieldsHashCodeFixedLength(final List fields) { + String generatedCode = ""; + for (int i = 0; i < fields.size(); i++) { + final Field f = fields.get(i); + final boolean isOddField = (i % 2) != 0; + final String dstVarName = isOddField ? "$xx_value":"$xx_carry"; + if (f instanceof OneOfField oneOfField) { + final String fieldName = f.nameCamelFirstLower() + ".value()"; + String caseStatements = ""; + for (final Field childField : oneOfField.fields()) { + if (!caseStatements.isEmpty()) { + caseStatements += "\n "; + } + caseStatements += "case " + Common.camelToUpperSnake(childField.name()) + " -> " + + getFieldGetLongCode(childField, fieldName, true) +";"; + } + generatedCode += + (""" + $dstVarName = switch($fieldName.kind()) { + $caseStatements + default -> throw new IllegalStateException("Unknown one-of kind: " + $fieldName.kind()); + };$mixCode + """) + .replace("$caseStatements", caseStatements) + .replace("$mixCode", isOddField ? "\n$xx_acc += mixPair($xx_value, $xx_carry, "+(i/2)+");" : "") + .replace("$dstVarName", dstVarName) + .replace("$fieldName", f.nameCamelFirstLower()); + } else { + generatedCode += + (""" + $dstVarName = $varSetCode;$mixCode + """) + .replace("$varSetCode", getFieldGetLongCode(f, f.nameCamelFirstLower(), false)) + .replace("$mixCode", isOddField ? 
"\n$xx_acc += mixPair($xx_value, $xx_carry, "+(i/2)+");" : "") + .replace("$dstVarName", dstVarName); + } + } + generatedCode += "$xx_fieldCount = " + fields.size() + ";\n"; + return generatedCode.indent(DEFAULT_INDENT * 3); + } + + /** + * Recursively calculates the hashcode for a message fields. + * + * @param fields The fields of this object. + * @return The generated code for getting the hashCode value. + */ + private static String getFieldsHashCode(final List fields) { + String generatedCode = ""; + for (int i = 0; i < fields.size(); i++) { + final Field f = fields.get(i); + if (f instanceof OneOfField oneOfField) { + final String fieldName = f.nameCamelFirstLower() + ".value()"; + String caseStatements = ""; + for (final Field childField : oneOfField.fields()) { + if (!caseStatements.isEmpty()) { + caseStatements += "\n "; + } + caseStatements += "case " + Common.camelToUpperSnake(childField.name()) + " -> " + + getFieldGetLongCode(childField, fieldName, true) + ";"; + } + generatedCode += + (""" + $xx_value = switch($fieldName.kind()) { + $caseStatements + default -> throw new IllegalStateException("Unknown one-of kind: " + $fieldName.kind()); + }; + $mixCode + """) + .replace("$caseStatements", caseStatements) + .replace("$mixCode", MIX_CODE) + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.repeated()) { + generatedCode += (""" + for (var o : $fieldName) { + $xx_value = $addValueCode; + $mixCode + } + """) + .replace("$addValueCode", getFieldGetLongCode(f, "o", false)) + .replace("$mixCode",MIX_CODE.indent(DEFAULT_INDENT*2)) + .replace("$fieldType", f.javaFieldType()) + .replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.type() == FieldType.MAP) { + final SingleField keyField = ((MapField) f).keyField(); + final SingleField valueField = ((MapField) f).valueField(); + String keyCode = getFieldGetLongCode(keyField, "$m_key", false); + String valueCode = getFieldGetLongCode(valueField, "$m_value", false); + generatedCode += + (""" + for ($keyType $m_key : ((PbjMap<$keyType,$valueType>) $fieldName).getSortedKeys()) { + if ($m_key != null) { + $xx_value = $keyCode; + $mixCode + } + final $valueType $m_value = $fieldName.get($m_key); + if ($m_value != null) { + $xx_value = $valueCode; + $mixCode + } + } + """) + .replace("$keyType", keyField.javaFieldTypeBoxed()) + .replace("$mixCode", MIX_CODE.indent(DEFAULT_INDENT * 2)) + .replace("$valueType", valueField.javaFieldTypeBoxed()) + .replace("$keyCode", keyCode) + .replace("$valueCode", valueCode) + .replace("$fieldName", ((MapField) f).nameCamelFirstLower()); + } else { + generatedCode += + (""" + $xx_value = $varSetCode; + $mixCode + """) + .replace("$varSetCode", getFieldGetLongCode(f, f.nameCamelFirstLower(), false)) + .replace("$mixCode", MIX_CODE); + } + } + return generatedCode.indent(DEFAULT_INDENT * 3); + } + + /** + * Get the hashcode codegen for an optional field. + * + * @param f The field for which to generate the hash code. + * + * @return Updated codegen string. + */ + @NonNull + private static String getFieldGetLongCode(final Field f, String fieldValueCode, final boolean needsCasting) { + final String fieldValueCodeCasted = needsCasting ? 
"((" + f.javaFieldType() + ")" + fieldValueCode + ")" : fieldValueCode; + return switch (f.type()) { + case FIXED32, INT32, SFIXED32, SINT32, UINT32, BOOL, FLOAT, DOUBLE, STRING -> "toLong("+fieldValueCodeCasted+")"; + case FIXED64, INT64, SFIXED64, SINT64, UINT64 -> fieldValueCodeCasted; + case BYTES -> fieldValueCodeCasted+".hashCode64()"; + case ENUM -> "toLong("+fieldValueCodeCasted+".protoOrdinal())"; + case MESSAGE -> switch (f.messageType()) { + case "StringValue" -> "toLong(" + fieldValueCodeCasted + ")"; + case "Int32Value", "UInt32Value" -> "toLong(" + fieldValueCodeCasted + ")"; + case "Int64Value", "UInt64Value" -> fieldValueCodeCasted ; + case "FloatValue" -> "toLong(" + fieldValueCodeCasted + ")"; + case "DoubleValue" -> "toLong(" + fieldValueCodeCasted + ")"; + case "BytesValue" -> fieldValueCodeCasted + ".hashCode64()"; + case "BoolValue" -> "toLong(" + fieldValueCodeCasted + ")"; + default -> fieldValueCode +"==null?0:"+fieldValueCodeCasted+ ".hashCode64()"; + }; + default -> throw new UnsupportedOperationException("Unhandled optional message type:" + f.messageType()); + }; + } + +} diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java index 3839b6a7..f0d20bcb 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java @@ -178,7 +178,7 @@ private static String generateFieldWriteLines( .replace("$map", getValueCode) .replace("$javaFieldType", mapField.javaFieldType()) .replace("$K", mapField.keyField().type().boxedType) - .replace("$V", mapField.valueField().type() == Field.FieldType.MESSAGE ? ((SingleField)mapField.valueField()).messageType() : mapField.valueField().type().boxedType) + .replace("$V", mapField.valueField().type() == Field.FieldType.MESSAGE ? mapField.valueField().messageType() : mapField.valueField().type().boxedType) .replace("$fieldWriteLines", fieldWriteLines.indent(DEFAULT_INDENT)) .replace("$fieldSizeOfLines", fieldSizeOfLines.indent(DEFAULT_INDENT)); } else { diff --git a/pbj-core/pbj-runtime/src/jmh/java/com/hedera/pbj/runtime/hashing/XxhBenchmark.java b/pbj-core/pbj-runtime/src/jmh/java/com/hedera/pbj/runtime/hashing/XxhBenchmark.java new file mode 100644 index 00000000..f1d4a123 --- /dev/null +++ b/pbj-core/pbj-runtime/src/jmh/java/com/hedera/pbj/runtime/hashing/XxhBenchmark.java @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: Apache-2.0 +package com.hedera.pbj.runtime.hashing; + +import java.util.Random; +import java.util.concurrent.TimeUnit; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OperationsPerInvocation; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +/** + * Benchmark for XXH3_64 hashing functions. 
+ * This benchmark tests the performance of hashing byte arrays and strings using the XXH3_64 + * hashing algorithm. + */ +@State(Scope.Benchmark) +@Fork(1) +@Warmup(iterations = 4, time = 2) +@Measurement(iterations = 5, time = 2) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@BenchmarkMode(Mode.AverageTime) +public class XxhBenchmark { + private static final String CHAR_POOL = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + public static final int SAMPLES = 10_000; + + + @Param({"4","8","16","32","48","64","120","1024"}) + public int length = 10000; + + private final byte[][] byteInputData = new byte[SAMPLES][]; + private final String[] stringInputData = new String[SAMPLES]; + + @Setup(Level.Trial) + public void init() { + final Random random = new Random(45155315113511L); + for (int i = 0; i < SAMPLES; i++) { + // byte[] + byteInputData[i] = new byte[length]; + random.nextBytes(byteInputData[i]); + // string + StringBuilder builder = new StringBuilder(length); + for (int j = 0; j < length; j++) { + builder.append(CHAR_POOL.charAt(random.nextInt(CHAR_POOL.length()))); + } + stringInputData[i] = builder.toString(); + } + } + + @Benchmark + @OperationsPerInvocation(SAMPLES) + public void testBytesHashing(final Blackhole blackhole) { + for (int i = 0; i < SAMPLES; i++) { + blackhole.consume(XXH3_64.DEFAULT_INSTANCE.hashBytesToLong(byteInputData[i], 0, byteInputData[i].length)); + } + } + + @Benchmark + @OperationsPerInvocation(SAMPLES) + public void testStringHashing(final Blackhole blackhole) { + for (int i = 0; i < SAMPLES; i++) { + blackhole.consume(XXH3_64.DEFAULT_INSTANCE.hashCharsToLong(stringInputData[i])); + } + } +} diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/UnknownField.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/UnknownField.java index 13781e22..b79ed9c9 100644 --- a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/UnknownField.java +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/UnknownField.java @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 package com.hedera.pbj.runtime; +import com.hedera.pbj.runtime.hashing.SixtyFourBitHashable; +import com.hedera.pbj.runtime.hashing.XXH3_64; import com.hedera.pbj.runtime.io.buffer.Bytes; import edu.umd.cs.findbugs.annotations.NonNull; import java.util.Objects; @@ -21,7 +23,7 @@ * @param bytes a list of the raw bytes of each occurrence of the field (e.g. for repeated fields) */ public record UnknownField(int field, @NonNull ProtoConstants wireType, @NonNull Bytes bytes) - implements Comparable { + implements Comparable, SixtyFourBitHashable { /** * A {@code Comparable} implementation that sorts UnknownField objects by their `field` numbers * in the increasing order. This comparator is used for maintaining a stable and deterministic order for any @@ -52,27 +54,28 @@ public boolean equals(final Object o) { * An `Object.hashCode()` implementation that computes a hash code using all the members of the UnknownField record: * the `field`, the `wireType`, and the `bytes`. * The implementation should remain stable over time because this is a public API. + *

+ * This hash code has to match how the field would be hashed if it were a normal field in the schema.
+ *

+ */
 @Override
 public int hashCode() {
-    int hashCode = 1;
-
-    hashCode = 31 * hashCode + Integer.hashCode(field);
-    hashCode = 31 * hashCode + Integer.hashCode(wireType.ordinal());
-    hashCode = 31 * hashCode + bytes.hashCode();
-
-    // Shifts: 30, 27, 16, 20, 5, 18, 10, 24, 30
-    hashCode += hashCode << 30;
-    hashCode ^= hashCode >>> 27;
-    hashCode += hashCode << 16;
-    hashCode ^= hashCode >>> 20;
-    hashCode += hashCode << 5;
-    hashCode ^= hashCode >>> 18;
-    hashCode += hashCode << 10;
-    hashCode ^= hashCode >>> 24;
-    hashCode += hashCode << 30;
-
-    return hashCode;
+    return (int)hashCode64();
 }
+
+ /**
+ * A `SixtyFourBitHashable.hashCode64()` implementation that computes a 64-bit hash code from the raw `bytes` of
+ * the UnknownField record.
+ * The implementation should remain stable over time because this is a public API.
+ *

+ * This hash code has to match how the field would be hashed if it were a normal field in the schema.
+ *
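+ * <p>
+ * Illustrative sketch (editorial; the field number is arbitrary and ProtoConstants.WIRE_TYPE_DELIMITED is assumed
+ * to be the delimited wire type constant): the hash is the XXH3-64 hash of the field's raw bytes, matching how a
+ * schema-defined BYTES field is hashed.
+ * <pre>{@code
+ * UnknownField f = new UnknownField(7, ProtoConstants.WIRE_TYPE_DELIMITED, Bytes.wrap(new byte[] {1, 2, 3}));
+ * long h = f.hashCode64(); // equal to f.bytes().hashCode64()
+ * }</pre>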

+ * + * @return a 64-bit hash code for this UnknownField object + */ + @Override + public long hashCode64() { + return bytes.hashCode64(); } /** diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/hashing/SixtyFourBitHashable.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/hashing/SixtyFourBitHashable.java new file mode 100644 index 00000000..8f9ef54a --- /dev/null +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/hashing/SixtyFourBitHashable.java @@ -0,0 +1,13 @@ +package com.hedera.pbj.runtime.hashing; + +/** + * Interface for objects that can be hashed to a 64-bit long value. + */ +public interface SixtyFourBitHashable { + /** + * Hash this object to a 64-bit long value. + * + * @return the 64-bit hash value + */ + long hashCode64(); +} diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/hashing/XXH3FieldHash.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/hashing/XXH3FieldHash.java new file mode 100644 index 00000000..f032dc45 --- /dev/null +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/hashing/XXH3FieldHash.java @@ -0,0 +1,253 @@ +// SPDX-License-Identifier: Apache-2.0 +package com.hedera.pbj.runtime.hashing; + +/** + * Ultra-fast XXH3-inspired combiner for 64-bit field values. + *

+ * Goal: + *

    + *
  • Combine a sequence of 0..N 64-bit values (produced from fields) into a single 64-bit hash.
  • + *
  • Skip default-valued fields by not adding them to the sequence.
  • + *
  • No per-call allocations: only local vars, static methods, and compile-time constants.
  • + *
  • "Close to" XXH3-64 distribution; not bit-for-bit identical for arbitrary lengths.
  • + *
  • Very codegen-friendly: just keep local {acc, totalBytes, carry, haveCarry, pairIndex} and call static methods.
  • + *
+ *

+ * Pattern: + *

    + *
  • Accumulator starts at 0, totalBytes at 0.
  • + *
  • For each non-default field, convert to a 64-bit value v. If you have no carry, set carry=v and haveCarry=true; + * otherwise combine the pair (carry, v) with mixPair() and clear carry.
  • + *
  • After all fields, if a single leftover carry exists, fold it with mixTail8().
  • + *
  • Finalize with finish().
  • + *
+ *

+ * If the sequence is empty (no non-defaults), finish() returns the canonical empty-hash for XXH3 with seed=0. + *

+ * Notes: + *

    + *
  • This uses the same core mix and avalanche functions as XXH3, with a simple rotating secret schedule.
  • + *
  • For best throughput, unroll in your generator (avoid loops if you can), but the call overhead is tiny.
  • + *
+ *
+ * {@code
+ * // Pseudo-example showing how your generator should produce code for a model:
+ * final class ExampleModel {
+ *     long a; // default 0
+ *     long b; // default 0
+ *     int  c; // default 0
+ *     double d; // default 0.0
+ *     // ...
+ *
+ *     // Convert each field to a 64-bit value deterministically (little-endian representation in spirit).
+ *     // Only include the field if it is non-default for that field.
+ *     public long hashCode64() {
+ *         long $xx_acc = 0L;
+ *         long $xx_total = 0L;
+ *         long $xx_carry = 0L;
+ *         boolean $xx_haveCarry = false;
+ *         int $xx_pairIndex = 0;
+ *
+ *         // Field a (long)
+ *         if (a != 0L) {
+ *             final long v = a; // already a 64-bit value
+ *             if (!$xx_haveCarry) { $xx_carry = v; $xx_haveCarry = true; $xx_total += 8; }
+ *             else { $xx_acc += XXH3FieldHash.mixPair($xx_carry, v, $xx_pairIndex++); $xx_haveCarry = false; $xx_total += 8; }
+ *         }
+ *
+ *         // Field b (long)
+ *         if (b != 0L) {
+ *             final long v = b;
+ *             if (!$xx_haveCarry) { $xx_carry = v; $xx_haveCarry = true; $xx_total += 8; }
+ *             else { $xx_acc += XXH3FieldHash.mixPair($xx_carry, v, $xx_pairIndex++); $xx_haveCarry = false; $xx_total += 8; }
+ *         }
+ *
+ *         // Field c (int) -> widen to 64-bits in a stable way
+ *         if (c != 0) {
+ *             final long v = ((long) c) & 0xFFFFFFFFL; // LE semantics if desired
+ *             if (!$xx_haveCarry) { $xx_carry = v; $xx_haveCarry = true; $xx_total += 8; }
+ *             else { $xx_acc += XXH3FieldHash.mixPair($xx_carry, v, $xx_pairIndex++); $xx_haveCarry = false; $xx_total += 8; }
+ *         }
+ *
+ *         // Field d (double)
+ *         if (Double.doubleToRawLongBits(d) != 0L) {
+ *             final long v = Double.doubleToRawLongBits(d);
+ *             if (!$xx_haveCarry) { $xx_carry = v; $xx_haveCarry = true; $xx_total += 8; }
+ *             else { $xx_acc += XXH3FieldHash.mixPair($xx_carry, v, $xx_pairIndex++); $xx_haveCarry = false; $xx_total += 8; }
+ *         }
+ *
+ *         // ... repeat for all fields ...
+ *
+ *         if ($xx_haveCarry) {
+ *             $xx_acc += XXH3FieldHash.mixTail8($xx_carry, $xx_pairIndex);
+ *         }
+ *         return XXH3FieldHash.finish($xx_acc, $xx_total);
+ *     }
+ * }
+ * }
+ * 
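+ * <p>
+ * A complementary editorial sketch (v0..v2 stand for already-converted long values) for an odd number of
+ * non-default fields, where the final leftover value is folded in with mixTail8() before finish():
+ * <pre>{@code
+ * long acc = 0L;
+ * acc += XXH3FieldHash.mixPair(v0, v1, 0);       // first complete pair uses pairIndex 0
+ * acc += XXH3FieldHash.mixTail8(v2, 1);          // leftover single value; one full pair already processed
+ * long hash = XXH3FieldHash.finish(acc, 3 * 8L); // 8 conceptual bytes per included field
+ * }</pre>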
+ */ +public final class XXH3FieldHash { + + private XXH3FieldHash() {} + + // XXH3 constants (seed=0), inlined to avoid touching an instance. + // We use only the first 16 for the rotating pair schedule. + private static final long SECRET_00 = 0xbe4ba423396cfeb8L; + private static final long SECRET_01 = 0x1cad21f72c81017cL; + private static final long SECRET_02 = 0xdb979083e96dd4deL; + private static final long SECRET_03 = 0x1f67b3b7a4a44072L; + private static final long SECRET_04 = 0x78e5c0cc4ee679cbL; + private static final long SECRET_05 = 0x2172ffcc7dd05a82L; + private static final long SECRET_06 = 0x8e2443f7744608b8L; + private static final long SECRET_07 = 0x4c263a81e69035e0L; + private static final long SECRET_08 = 0xcb00c391bb52283cL; + private static final long SECRET_09 = 0xa32e531b8b65d088L; + private static final long SECRET_10 = 0x4ef90da297486471L; + private static final long SECRET_11 = 0xd8acdea946ef1938L; + private static final long SECRET_12 = 0x3f349ce33f76faa8L; + private static final long SECRET_13 = 0x1d4f0bc7c7bbdcf9L; + private static final long SECRET_14 = 0x3159b4cd4be0518aL; + private static final long SECRET_15 = 0x647378d9c97e9fc8L; + + // INIT_ACC_1 from xxh3 (publicly visible value in the provided code) + private static final long INIT_ACC_1 = 0x9E3779B185EBCA87L; + + // Canonical hash for empty input with seed=0 (matches XXH3_64.DEFAULT_INSTANCE.hash0) + // Computed as avalanche64(0 ^ (SECRET_07 ^ SECRET_08)). + private static final long EMPTY_HASH = + XXH3_64.avalanche64(SECRET_07 ^ SECRET_08); + + // Small fixed table for rotating secret pairs. index = (pairIndex & 7) << 1 + private static final long[] SECRET_PAIRS = { + SECRET_00, SECRET_01, + SECRET_02, SECRET_03, + SECRET_04, SECRET_05, + SECRET_06, SECRET_07, + SECRET_08, SECRET_09, + SECRET_10, SECRET_11, + SECRET_12, SECRET_13, + SECRET_14, SECRET_15 + }; + + private static long mix2(long lo, long hi, long s0, long s1) { + // Equivalent to the XXH3 "mix2Accs" core: mix(lo ^ s0, hi ^ s1) + return XXH3_64.mix(lo ^ s0, hi ^ s1); + } + + /** + * Mix a pair of 64-bit values into the accumulator with a rotating secret schedule. + * + * @param lo first value in the pair + * @param hi second value in the pair + * @param pairIndex 0-based index of the pair (increment by 1 each time you call mixPair) + * @return updated accumulator + */ + public static long mixPair(long lo, long hi, int pairIndex) { + final int base = (pairIndex & 7) << 1; // 0..14 + final long s0 = SECRET_PAIRS[base]; + final long s1 = SECRET_PAIRS[base + 1]; + return XXH3_64.mix(lo ^ s0, hi ^ s1); + } + + /** + * Mix a single 64-bit leftover ("tail") value into the accumulator. + * Uses rrmxmx with a rotating secret; very fast and high quality. + * + * @param v tail value (one leftover long) + * @param pairIndex the next pair index (i.e., the count of full pairs already processed) + * @return updated accumulator + */ + public static long mixTail8(long v, int pairIndex) { + final int base = (pairIndex & 7) << 1; + final long s0 = SECRET_PAIRS[base]; + // rrmxmx gives excellent avalanching for a single 64-bit word with a "length" tweak of 8 + return XXH3_64.rrmxmx(v ^ s0, 8); + } + + /** + * Finish the hash. If totalBytes==0 returns the canonical XXH3 empty hash (seed=0). + * Otherwise applies a simple length bias like XXH3 and finishes with avalanche3. 
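+ * <p>
+ * For totalBytes > 0 this is, in effect (a restatement of the code below, not an external spec):
+ * <pre>{@code
+ * hash = XXH3_64.avalanche3(acc + totalBytes * 0x9E3779B185EBCA87L); // the constant is INIT_ACC_1
+ * }</pre>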
+ * + * @param acc running accumulator from mixPair/mixTail8 + * @param totalBytes total output bytes you've conceptually written (8 per non-default field) + * @return final 64-bit hash + */ + public static long finish(long acc, long totalBytes) { + if (totalBytes == 0) { + return EMPTY_HASH; + } + // XXH3-like finish: add a length bias and avalanche3 + return XXH3_64.avalanche3(acc + totalBytes * INIT_ACC_1); + } + + /** + * Convert int primitive to a 64-bit value in a stable way using little-endian semantics, + * suitable for combining with mixPair or mixTail8. + * + * @param value the value to convert + * @return a 64-bit representation of the value + */ + public static long toLong(final int value) { + // Convert an int to a long in a stable way (little-endian semantics) + return ((long) value) & 0xFFFFFFFFL; // zero-extend to 64 bits + } + + /** + * Convert double primitive to a 64-bit value in a stable way using little-endian semantics, + * suitable for combining with mixPair or mixTail8. + * + * @param value the value to convert + * @return a 64-bit representation of the value + */ + public static long toLong(final double value) { + // Convert a double to a long in a stable way (little-endian semantics) + return Double.doubleToRawLongBits(value); + } + + /** + * Convert float primitive to a 64-bit value in a stable way using little-endian semantics, + * suitable for combining with mixPair or mixTail8. + * + * @param value the value to convert + * @return a 64-bit representation of the value + */ + public static long toLong(final float value) { + // Convert a float to a long in a stable way (little-endian semantics) + return Float.floatToRawIntBits(value) & 0xFFFFFFFFL; // zero-extend to 64 bits + } + + /** + * Convert boolean primitive to a 64-bit value in a stable way, suitable for combining with mixPair or mixTail8. + * + * @param value the value to convert + * @return a 64-bit representation of the value + */ + public static long toLong(final boolean value) { + // Convert a boolean to a long in a stable way (little-endian semantics) + return value ? 1L : 0L; // 1 for true, 0 for false + } + + /** + * Convert a byte array to a 64-bit value in a stable way using XXH3 64 hashing, + * suitable for combining with mixPair or mixTail8. + * + * @param value the byte array to convert + * @return a 64-bit representation of the byte array + */ + public static long toLong(final byte[] value) { + // Convert a byte to a long in a stable way (little-endian semantics) + return XXH3_64.DEFAULT_INSTANCE.hashBytesToLong(value,0,value.length); + } + + /** + * Convert a String to a 64-bit value in a stable way using XXH3 64 hashing of UTF16 bytes, + * suitable for combining with mixPair or mixTail8. 
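+ * <p>
+ * Equivalently (a sketch that follows from the one-line implementation below):
+ * <pre>{@code
+ * long h = XXH3FieldHash.toLong("abc");
+ * assert h == XXH3_64.DEFAULT_INSTANCE.hashCharsToLong("abc"); // hashes UTF-16 chars, not UTF-8 bytes
+ * }</pre>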
+ * + * @param value the String to convert + * @return a 64-bit representation of the String + */ + public static long toLong(final String value) { + return XXH3_64.DEFAULT_INSTANCE.hashCharsToLong(value); + } +} \ No newline at end of file diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/hashing/XXH3_64.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/hashing/XXH3_64.java new file mode 100644 index 00000000..cb96e9f3 --- /dev/null +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/hashing/XXH3_64.java @@ -0,0 +1,1221 @@ +// SPDX-License-Identifier: Apache-2.0 +package com.hedera.pbj.runtime.hashing; + +import com.hedera.pbj.runtime.io.WritableSequentialData; +import com.hedera.pbj.runtime.io.buffer.BufferedData; +import com.hedera.pbj.runtime.io.buffer.Bytes; +import com.hedera.pbj.runtime.io.buffer.RandomAccessData; +import edu.umd.cs.findbugs.annotations.NonNull; +import java.io.UncheckedIOException; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; + +/** + * XXH3_64 is a 64-bit variant of the XXH3 hash function. + * It is designed to be fast and efficient, providing a good balance between speed and + * collision resistance. + * + *

It is recommended to use {@link #DEFAULT_INSTANCE} for most use cases. + * @see xxhash + */ +@SuppressWarnings({"DuplicatedCode", "NumericOverflow"}) +public final class XXH3_64 { + /** + * The default seed value used for hashing. ZERO is chosen as this is the default for xxhash and used in tools lile + * xxhsum command line tool. + */ + public static final long DEFAULT_SEED = 0; + /** Default instance of the XXH3_64 hasher with a seed of 0. */ + public static final XXH3_64 DEFAULT_INSTANCE = new XXH3_64(DEFAULT_SEED); + /** VarHandle for reading and writing longs in little-endian byte order. */ + private static final VarHandle LONG_HANDLE = + MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.LITTLE_ENDIAN); + /** VarHandle for reading and writing integers in little-endian byte order. */ + private static final VarHandle INT_HANDLE = + MethodHandles.byteArrayViewVarHandle(int[].class, ByteOrder.LITTLE_ENDIAN); + /** The block length exponent, used for processing input in blocks of 1024 bytes. */ + private static final int BLOCK_LEN_EXP = 10; + // Private constants for the secret values used in the hash function. + private static final long SECRET_00 = 0xbe4ba423396cfeb8L; + private static final long SECRET_01 = 0x1cad21f72c81017cL; + private static final long SECRET_02 = 0xdb979083e96dd4deL; + private static final long SECRET_03 = 0x1f67b3b7a4a44072L; + private static final long SECRET_04 = 0x78e5c0cc4ee679cbL; + private static final long SECRET_05 = 0x2172ffcc7dd05a82L; + private static final long SECRET_06 = 0x8e2443f7744608b8L; + private static final long SECRET_07 = 0x4c263a81e69035e0L; + private static final long SECRET_08 = 0xcb00c391bb52283cL; + private static final long SECRET_09 = 0xa32e531b8b65d088L; + private static final long SECRET_10 = 0x4ef90da297486471L; + private static final long SECRET_11 = 0xd8acdea946ef1938L; + private static final long SECRET_12 = 0x3f349ce33f76faa8L; + private static final long SECRET_13 = 0x1d4f0bc7c7bbdcf9L; + private static final long SECRET_14 = 0x3159b4cd4be0518aL; + private static final long SECRET_15 = 0x647378d9c97e9fc8L; + private static final long SECRET_16 = 0xc3ebd33483acc5eaL; + private static final long SECRET_17 = 0xeb6313faffa081c5L; + private static final long SECRET_18 = 0x49daf0b751dd0d17L; + private static final long SECRET_19 = 0x9e68d429265516d3L; + private static final long SECRET_20 = 0xfca1477d58be162bL; + private static final long SECRET_21 = 0xce31d07ad1b8f88fL; + private static final long SECRET_22 = 0x280416958f3acb45L; + private static final long SECRET_23 = 0x7e404bbbcafbd7afL; + private static final long INIT_ACC_0 = 0x00000000C2B2AE3DL; + private static final long INIT_ACC_1 = 0x9E3779B185EBCA87L; + private static final long INIT_ACC_2 = 0xC2B2AE3D27D4EB4FL; + private static final long INIT_ACC_3 = 0x165667B19E3779F9L; + private static final long INIT_ACC_4 = 0x85EBCA77C2B2AE63L; + private static final long INIT_ACC_5 = 0x0000000085EBCA77L; + private static final long INIT_ACC_6 = 0x27D4EB2F165667C5L; + private static final long INIT_ACC_7 = 0x000000009E3779B1L; + // Private constants for the mix and avalanche functions, derived from seed and secret values. 
+ public final long secret00; + public final long secret01; + public final long secret02; + public final long secret03; + private final long secret04; + private final long secret05; + private final long secret06; + private final long secret07; + private final long secret08; + private final long secret09; + private final long secret10; + private final long secret11; + private final long secret12; + private final long secret13; + private final long secret14; + private final long secret15; + private final long secret16; + private final long secret17; + private final long secret18; + private final long secret19; + private final long secret20; + private final long secret21; + private final long secret22; + private final long secret23; + private final long[] secret; + private final long secShift00; + private final long secShift01; + private final long secShift02; + private final long secShift03; + private final long secShift04; + private final long secShift05; + private final long secShift06; + private final long secShift07; + private final long secShift08; + private final long secShift09; + private final long secShift10; + private final long secShift11; + private final long secShift16; + private final long secShift17; + private final long secShift18; + private final long secShift19; + private final long secShift20; + private final long secShift21; + private final long secShift22; + private final long secShift23; + private final long secShiftFinal0; + private final long secShiftFinal1; + private final long secShiftFinal2; + private final long secShiftFinal3; + private final long secShiftFinal4; + private final long secShiftFinal5; + private final long secShiftFinal6; + private final long secShiftFinal7; + private final long secShift12; + private final long secShift13; + private final long secShift14; + private final long secShift15; + public final long bitflip00; + public final long bitflip12; + public final long bitflip34; + public final long bitflip56; + /** Precomputed hash value used when the input length is 0. */ + public final long hash0; + + /** + * Creates a new instance of {@link XXH3_64} with the specified seed. + * + *
+     * <p>It is recommended to use {@link #DEFAULT_INSTANCE} for most use cases.
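+     * <p>For example (illustrative), {@code new XXH3_64(42L)} produces different hashes than the default-seeded
+     * instance for the same input, because the seed is mixed into every secret constant.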
+ * + * @param seed the seed to use for hashing + */ + @SuppressWarnings("NumericOverflow") + public XXH3_64(long seed) { + this.secret00 = SECRET_00 + seed; + this.secret01 = SECRET_01 - seed; + this.secret02 = SECRET_02 + seed; + this.secret03 = SECRET_03 - seed; + this.secret04 = SECRET_04 + seed; + this.secret05 = SECRET_05 - seed; + this.secret06 = SECRET_06 + seed; + this.secret07 = SECRET_07 - seed; + this.secret08 = SECRET_08 + seed; + this.secret09 = SECRET_09 - seed; + this.secret10 = SECRET_10 + seed; + this.secret11 = SECRET_11 - seed; + this.secret12 = SECRET_12 + seed; + this.secret13 = SECRET_13 - seed; + this.secret14 = SECRET_14 + seed; + this.secret15 = SECRET_15 - seed; + this.secret16 = SECRET_16 + seed; + this.secret17 = SECRET_17 - seed; + this.secret18 = SECRET_18 + seed; + this.secret19 = SECRET_19 - seed; + this.secret20 = SECRET_20 + seed; + this.secret21 = SECRET_21 - seed; + this.secret22 = SECRET_22 + seed; + this.secret23 = SECRET_23 - seed; + + this.secShift00 = (SECRET_00 >>> 24) + (SECRET_01 << 40) + seed; + this.secShift01 = (SECRET_01 >>> 24) + (SECRET_02 << 40) - seed; + this.secShift02 = (SECRET_02 >>> 24) + (SECRET_03 << 40) + seed; + this.secShift03 = (SECRET_03 >>> 24) + (SECRET_04 << 40) - seed; + this.secShift04 = (SECRET_04 >>> 24) + (SECRET_05 << 40) + seed; + this.secShift05 = (SECRET_05 >>> 24) + (SECRET_06 << 40) - seed; + this.secShift06 = (SECRET_06 >>> 24) + (SECRET_07 << 40) + seed; + this.secShift07 = (SECRET_07 >>> 24) + (SECRET_08 << 40) - seed; + this.secShift08 = (SECRET_08 >>> 24) + (SECRET_09 << 40) + seed; + this.secShift09 = (SECRET_09 >>> 24) + (SECRET_10 << 40) - seed; + this.secShift10 = (SECRET_10 >>> 24) + (SECRET_11 << 40) + seed; + this.secShift11 = (SECRET_11 >>> 24) + (SECRET_12 << 40) - seed; + + this.secShift16 = secret15 >>> 8 | secret16 << 56; + this.secShift17 = secret16 >>> 8 | secret17 << 56; + this.secShift18 = secret17 >>> 8 | secret18 << 56; + this.secShift19 = secret18 >>> 8 | secret19 << 56; + this.secShift20 = secret19 >>> 8 | secret20 << 56; + this.secShift21 = secret20 >>> 8 | secret21 << 56; + this.secShift22 = secret21 >>> 8 | secret22 << 56; + this.secShift23 = secret22 >>> 8 | secret23 << 56; + + this.secShiftFinal0 = secret01 >>> 24 | secret02 << 40; + this.secShiftFinal1 = secret02 >>> 24 | secret03 << 40; + this.secShiftFinal2 = secret03 >>> 24 | secret04 << 40; + this.secShiftFinal3 = secret04 >>> 24 | secret05 << 40; + this.secShiftFinal4 = secret05 >>> 24 | secret06 << 40; + this.secShiftFinal5 = secret06 >>> 24 | secret07 << 40; + this.secShiftFinal6 = secret07 >>> 24 | secret08 << 40; + this.secShiftFinal7 = secret08 >>> 24 | secret09 << 40; + + this.secret = new long[] { + secret00, secret01, secret02, secret03, secret04, secret05, secret06, secret07, + secret08, secret09, secret10, secret11, secret12, secret13, secret14, secret15, + secret16, secret17, secret18, secret19, secret20, secret21, secret22, secret23 + }; + + this.secShift12 = (SECRET_12 >>> 24) + (SECRET_13 << 40) + seed; + this.secShift13 = (SECRET_13 >>> 24) + (SECRET_14 << 40) - seed; + this.secShift14 = (SECRET_14 >>> 56) + (SECRET_15 << 8) + seed; + this.secShift15 = (SECRET_15 >>> 56) + (SECRET_16 << 8) - seed; + + this.bitflip00 = ((SECRET_00 >>> 32) ^ (SECRET_00 & 0xFFFFFFFFL)) + seed; + this.bitflip12 = (SECRET_01 ^ SECRET_02) - (seed ^ Long.reverseBytes(seed & 0xFFFFFFFFL)); + this.bitflip34 = (SECRET_03 ^ SECRET_04) + seed; + this.bitflip56 = (SECRET_05 ^ SECRET_06) - seed; + + this.hash0 = avalanche64(seed ^ (SECRET_07 ^ 
SECRET_08)); + } + + /** + * Creates a new instance of {@link HashingWritableSequentialData} with the seed derived from the instance of + * {@link XXH3_64}. + * + * @return a new instance of {@link XXH3_64} + */ + public HashingWritableSequentialData hashingWritableSequentialData() { + return new HashingWritableSequentialData(); + } + + /** + * Hashes a pair of longs to a 64-bit {@code long} value. + * + *
+     * <p>Equivalent to {@code hashBytesToLong(input1 || input2)}, where {@code ||} denotes concatenation
+     * and the longs are interpreted as little-endian byte sequences.
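+     * <p>A hypothetical call (values are illustrative):
+     * <pre>{@code
+     * long h = XXH3_64.DEFAULT_INSTANCE.hash(0x0123456789ABCDEFL, 42L);
+     * }</pre>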
+ * + * @param input1 the first long data to hash + * @param input2 the second long data to hash + * @return the hash value + */ + public long hash(long input1, long input2) { + long lo = input1 ^ bitflip34; + long hi = input2 ^ bitflip56; + long acc = 16 + Long.reverseBytes(lo) + hi + mix(lo, hi); + return avalanche3(acc); + } + + /** + * Hashes a pair of long and double to a 64-bit {@code long} value. + * + * @param input1 the long data to hash + * @param input2 the double data to hash + * @return the hash value + */ + public long hash(long input1, double input2) { + long lo = input1 ^ bitflip34; + long hi = Double.doubleToRawLongBits(input2) ^ bitflip56; + long acc = 16 + Long.reverseBytes(lo) + hi + mix(lo, hi); + return avalanche3(acc); + } + + /** + * Hashes a pair of long and float to a 64-bit {@code long} value. + * + * @param input1 the long data to hash + * @param input2 the float data to hash + * @return the hash value + */ + public long hash(long input1, float input2) { + long lo = input1 ^ bitflip34; + long hi = Float.floatToRawIntBits(input2) ^ bitflip56; + long acc = 16 + Long.reverseBytes(lo) + hi + mix(lo, hi); + return avalanche3(acc); + } + + /** + * Hashes a pair of long and String to a 64-bit {@code long} value. + * + * @param input1 the long data to hash + * @param input2 the String data to hash + * @return the hash value + */ + public long hash(long input1, String input2) { + long lo = input1 ^ bitflip34; + byte[] stringBytes = input2.getBytes(StandardCharsets.UTF_8); + long hi = hashBytesToLong(stringBytes,0,stringBytes.length) ^ bitflip56; + long acc = 16 + Long.reverseBytes(lo) + hi + mix(lo, hi); + return avalanche3(acc); + } + + /** + * Hashes a pair of long and Bytes to a 64-bit {@code long} value. + * + * @param input1 the long data to hash + * @param input2 the Bytes data to hash + * @return the hash value + */ + public long hash(long input1, Bytes input2) { + long lo = input1 ^ bitflip34; + long hi = input2.hashCode64() ^ bitflip56; + long acc = 16 + Long.reverseBytes(lo) + hi + mix(lo, hi); + return avalanche3(acc); + } + + /** + * Hashes a byte array to a 64-bit {@code long} value. + * + *
+     * <p>
Equivalent to {@code hashToLong(input, (b, f) -> f.putBytes(b, off, len))}. + * + * @param input the byte array + * @param off the offset + * @param length the length + * @return the hash value + */ + public long hashBytesToLong(final byte[] input, final int off, final int length) { + if (length <= 16) { + if (length > 8) { + long lo = (long) LONG_HANDLE.get(input, off) ^ bitflip34; + long hi = (long) LONG_HANDLE.get(input, off + length - 8) ^ bitflip56; + long acc = length + Long.reverseBytes(lo) + hi + mix(lo, hi); + return avalanche3(acc); + } + if (length >= 4) { + long input1 = (int) INT_HANDLE.get(input, off); + long input2 = (int) INT_HANDLE.get(input, off + length - 4); + long keyed = (input2 & 0xFFFFFFFFL) ^ (input1 << 32) ^ bitflip12; + return XXH3_64.rrmxmx(keyed, length); + } + if (length != 0) { + int c1 = input[off] & 0xFF; + int c2 = input[off + (length >> 1)]; + int c3 = input[off + length - 1] & 0xFF; + long combined = ((c1 << 16) | (c2 << 24) | c3 | ((long) length << 8)) & 0xFFFFFFFFL; + return avalanche64(combined ^ bitflip00); + } + return hash0; + } + if (length <= 128) { + long acc = length * INIT_ACC_1; + + if (length > 32) { + if (length > 64) { + if (length > 96) { + acc += XXH3_64.mix16B(input, off + 48, secret12, secret13); + acc += XXH3_64.mix16B(input, off + length - 64, secret14, secret15); + } + acc += XXH3_64.mix16B(input, off + 32, secret08, secret09); + acc += XXH3_64.mix16B(input, off + length - 48, secret10, secret11); + } + acc += XXH3_64.mix16B(input, off + 16, secret04, secret05); + acc += XXH3_64.mix16B(input, off + length - 32, secret06, secret07); + } + acc += XXH3_64.mix16B(input, off, secret00, secret01); + acc += XXH3_64.mix16B(input, off + length - 16, secret02, secret03); + + return avalanche3(acc); + } + if (length <= 240) { + long acc = length * INIT_ACC_1; + acc += XXH3_64.mix16B(input, off, secret00, secret01); + acc += XXH3_64.mix16B(input, off + 16, secret02, secret03); + acc += XXH3_64.mix16B(input, off + 16 * 2, secret04, secret05); + acc += XXH3_64.mix16B(input, off + 16 * 3, secret06, secret07); + acc += XXH3_64.mix16B(input, off + 16 * 4, secret08, secret09); + acc += XXH3_64.mix16B(input, off + 16 * 5, secret10, secret11); + acc += XXH3_64.mix16B(input, off + 16 * 6, secret12, secret13); + acc += XXH3_64.mix16B(input, off + 16 * 7, secret14, secret15); + + acc = avalanche3(acc); + + if (length >= 144) { + acc += XXH3_64.mix16B(input, off + 128, secShift00, secShift01); + if (length >= 160) { + acc += XXH3_64.mix16B(input, off + 144, secShift02, secShift03); + if (length >= 176) { + acc += XXH3_64.mix16B(input, off + 160, secShift04, secShift05); + if (length >= 192) { + acc += XXH3_64.mix16B(input, off + 176, secShift06, secShift07); + if (length >= 208) { + acc += XXH3_64.mix16B(input, off + 192, secShift08, secShift09); + if (length >= 224) { + acc += XXH3_64.mix16B(input, off + 208, secShift10, secShift11); + if (length >= 240) acc += XXH3_64.mix16B(input, off + 224, secShift12, secShift13); + } + } + } + } + } + } + acc += XXH3_64.mix16B(input, off + length - 16, secShift14, secShift15); + return avalanche3(acc); + } + + long acc0 = INIT_ACC_0; + long acc1 = INIT_ACC_1; + long acc2 = INIT_ACC_2; + long acc3 = INIT_ACC_3; + long acc4 = INIT_ACC_4; + long acc5 = INIT_ACC_5; + long acc6 = INIT_ACC_6; + long acc7 = INIT_ACC_7; + + final int nbBlocks = (length - 1) >>> BLOCK_LEN_EXP; + for (int n = 0; n < nbBlocks; n++) { + final int offBlock = off + (n << BLOCK_LEN_EXP); + for (int s = 0; s < 16; s += 1) { + int offStripe = 
offBlock + (s << 6); + + long b0 = (long) LONG_HANDLE.get(input, offStripe); + long b1 = (long) LONG_HANDLE.get(input, offStripe + 8); + long b2 = (long) LONG_HANDLE.get(input, offStripe + 8 * 2); + long b3 = (long) LONG_HANDLE.get(input, offStripe + 8 * 3); + long b4 = (long) LONG_HANDLE.get(input, offStripe + 8 * 4); + long b5 = (long) LONG_HANDLE.get(input, offStripe + 8 * 5); + long b6 = (long) LONG_HANDLE.get(input, offStripe + 8 * 6); + long b7 = (long) LONG_HANDLE.get(input, offStripe + 8 * 7); + + acc0 += b1 + contrib(b0, secret[s]); + acc1 += b0 + contrib(b1, secret[s + 1]); + acc2 += b3 + contrib(b2, secret[s + 2]); + acc3 += b2 + contrib(b3, secret[s + 3]); + acc4 += b5 + contrib(b4, secret[s + 4]); + acc5 += b4 + contrib(b5, secret[s + 5]); + acc6 += b7 + contrib(b6, secret[s + 6]); + acc7 += b6 + contrib(b7, secret[s + 7]); + } + + acc0 = mixAcc(acc0, secret16); + acc1 = mixAcc(acc1, secret17); + acc2 = mixAcc(acc2, secret18); + acc3 = mixAcc(acc3, secret19); + acc4 = mixAcc(acc4, secret20); + acc5 = mixAcc(acc5, secret21); + acc6 = mixAcc(acc6, secret22); + acc7 = mixAcc(acc7, secret23); + } + + final int nbStripes = ((length - 1) - (nbBlocks << BLOCK_LEN_EXP)) >>> 6; + final int offBlock = off + (nbBlocks << BLOCK_LEN_EXP); + for (int s = 0; s < nbStripes; s++) { + int offStripe = offBlock + (s << 6); + + long b0 = (long) LONG_HANDLE.get(input, offStripe); + long b1 = (long) LONG_HANDLE.get(input, offStripe + 8); + long b2 = (long) LONG_HANDLE.get(input, offStripe + 8 * 2); + long b3 = (long) LONG_HANDLE.get(input, offStripe + 8 * 3); + long b4 = (long) LONG_HANDLE.get(input, offStripe + 8 * 4); + long b5 = (long) LONG_HANDLE.get(input, offStripe + 8 * 5); + long b6 = (long) LONG_HANDLE.get(input, offStripe + 8 * 6); + long b7 = (long) LONG_HANDLE.get(input, offStripe + 8 * 7); + + acc0 += b1 + contrib(b0, secret[s]); + acc1 += b0 + contrib(b1, secret[s + 1]); + acc2 += b3 + contrib(b2, secret[s + 2]); + acc3 += b2 + contrib(b3, secret[s + 3]); + acc4 += b5 + contrib(b4, secret[s + 4]); + acc5 += b4 + contrib(b5, secret[s + 5]); + acc6 += b7 + contrib(b6, secret[s + 6]); + acc7 += b6 + contrib(b7, secret[s + 7]); + } + + { + int offStripe = off + length - 64; + + long b0 = (long) LONG_HANDLE.get(input, offStripe); + long b1 = (long) LONG_HANDLE.get(input, offStripe + 8); + long b2 = (long) LONG_HANDLE.get(input, offStripe + 8 * 2); + long b3 = (long) LONG_HANDLE.get(input, offStripe + 8 * 3); + long b4 = (long) LONG_HANDLE.get(input, offStripe + 8 * 4); + long b5 = (long) LONG_HANDLE.get(input, offStripe + 8 * 5); + long b6 = (long) LONG_HANDLE.get(input, offStripe + 8 * 6); + long b7 = (long) LONG_HANDLE.get(input, offStripe + 8 * 7); + + acc0 += b1 + contrib(b0, secShift16); + acc1 += b0 + contrib(b1, secShift17); + acc2 += b3 + contrib(b2, secShift18); + acc3 += b2 + contrib(b3, secShift19); + acc4 += b5 + contrib(b4, secShift20); + acc5 += b4 + contrib(b5, secShift21); + acc6 += b7 + contrib(b6, secShift22); + acc7 += b6 + contrib(b7, secShift23); + } + + return finalizeHash(length, acc0, acc1, acc2, acc3, acc4, acc5, acc6, acc7); + } + + /** + * Hashes a CharSequence as raw UTF16 to a 64-bit {@code long} value. + * + *
+     * <p>
Equivalent to {@code hashToLong(input, (b, f) -> f.putBytes(b, off, len))}. + * + * @param charSequence the character sequence to hash + * @return the hash value + */ + public long hashCharsToLong(final CharSequence charSequence) { + int len = charSequence.length(); + if (len <= 8) { + if (len > 4) { + long lo = getLong(charSequence, 0) ^ bitflip34; + long hi = getLong(charSequence, len - 4) ^ bitflip56; + long acc = (len << 1) + Long.reverseBytes(lo) + hi + mix(lo, hi); + return avalanche3(acc); + } + if (len >= 2) { + long input1 = getInt(charSequence, 0); + long input2 = getInt(charSequence, len - 2); + long keyed = (input2 & 0xFFFFFFFFL) ^ (input1 << 32) ^ bitflip12; + return rrmxmx(keyed, len << 1); + } + if (len != 0) { + long c = charSequence.charAt(0); + long combined = (c << 16) | (c >>> 8) | 512L; + return avalanche64(combined ^ bitflip00); + } + return hash0; + } + if (len <= 64) { + long acc = len * (INIT_ACC_1 << 1); + + if (len > 16) { + if (len > 32) { + if (len > 48) { + acc += mix16B(charSequence, 24, secret12, secret13); + acc += mix16B(charSequence, len - 32, secret14, secret15); + } + acc += mix16B(charSequence, 16, secret08, secret09); + acc += mix16B(charSequence, len - 24, secret10, secret11); + } + acc += mix16B(charSequence, 8, secret04, secret05); + acc += mix16B(charSequence, len - 16, secret06, secret07); + } + acc += mix16B(charSequence, 0, secret00, secret01); + acc += mix16B(charSequence, len - 8, secret02, secret03); + + return avalanche3(acc); + } + if (len <= 120) { + long acc = len * (INIT_ACC_1 << 1); + acc += mix16B(charSequence, 0, secret00, secret01); + acc += mix16B(charSequence, 8, secret02, secret03); + acc += mix16B(charSequence, 16, secret04, secret05); + acc += mix16B(charSequence, 24, secret06, secret07); + acc += mix16B(charSequence, 32, secret08, secret09); + acc += mix16B(charSequence, 40, secret10, secret11); + acc += mix16B(charSequence, 48, secret12, secret13); + acc += mix16B(charSequence, 56, secret14, secret15); + + acc = avalanche3(acc); + + if (len >= 72) { + acc += mix16B(charSequence, 64, secShift00, secShift01); + if (len >= 80) { + acc += mix16B(charSequence, 72, secShift02, secShift03); + if (len >= 88) { + acc += mix16B(charSequence, 80, secShift04, secShift05); + if (len >= 96) { + acc += mix16B(charSequence, 88, secShift06, secShift07); + if (len >= 104) { + acc += mix16B(charSequence, 96, secShift08, secShift09); + if (len >= 112) { + acc += mix16B(charSequence, 104, secShift10, secShift11); + if (len >= 120) acc += mix16B(charSequence, 112, secShift12, secShift13); + } + } + } + } + } + } + acc += mix16B(charSequence, len - 8, secShift14, secShift15); + return avalanche3(acc); + } + + long acc0 = INIT_ACC_0; + long acc1 = INIT_ACC_1; + long acc2 = INIT_ACC_2; + long acc3 = INIT_ACC_3; + long acc4 = INIT_ACC_4; + long acc5 = INIT_ACC_5; + long acc6 = INIT_ACC_6; + long acc7 = INIT_ACC_7; + + final int nbBlocks = (len - 1) >>> (BLOCK_LEN_EXP - 1); + for (int n = 0; n < nbBlocks; n++) { + final int offBlock = n << (BLOCK_LEN_EXP - 1); + for (int s = 0; s < 16; s += 1) { + int offStripe = offBlock + (s << 5); + + long b0 = getLong(charSequence, offStripe); + long b1 = getLong(charSequence, offStripe + 4); + long b2 = getLong(charSequence, offStripe + 4 * 2); + long b3 = getLong(charSequence, offStripe + 4 * 3); + long b4 = getLong(charSequence, offStripe + 4 * 4); + long b5 = getLong(charSequence, offStripe + 4 * 5); + long b6 = getLong(charSequence, offStripe + 4 * 6); + long b7 = getLong(charSequence, offStripe + 4 * 7); + + 
acc0 += b1 + contrib(b0, secret[s]); + acc1 += b0 + contrib(b1, secret[s + 1]); + acc2 += b3 + contrib(b2, secret[s + 2]); + acc3 += b2 + contrib(b3, secret[s + 3]); + acc4 += b5 + contrib(b4, secret[s + 4]); + acc5 += b4 + contrib(b5, secret[s + 5]); + acc6 += b7 + contrib(b6, secret[s + 6]); + acc7 += b6 + contrib(b7, secret[s + 7]); + } + + acc0 = mixAcc(acc0, secret16); + acc1 = mixAcc(acc1, secret17); + acc2 = mixAcc(acc2, secret18); + acc3 = mixAcc(acc3, secret19); + acc4 = mixAcc(acc4, secret20); + acc5 = mixAcc(acc5, secret21); + acc6 = mixAcc(acc6, secret22); + acc7 = mixAcc(acc7, secret23); + } + + final int nbStripes = ((len - 1) - (nbBlocks << (BLOCK_LEN_EXP - 1))) >>> 5; + final int offBlock = nbBlocks << (BLOCK_LEN_EXP - 1); + for (int s = 0; s < nbStripes; s++) { + int offStripe = offBlock + (s << 5); + + long b0 = getLong(charSequence, offStripe); + long b1 = getLong(charSequence, offStripe + 4); + long b2 = getLong(charSequence, offStripe + 4 * 2); + long b3 = getLong(charSequence, offStripe + 4 * 3); + long b4 = getLong(charSequence, offStripe + 4 * 4); + long b5 = getLong(charSequence, offStripe + 4 * 5); + long b6 = getLong(charSequence, offStripe + 4 * 6); + long b7 = getLong(charSequence, offStripe + 4 * 7); + + acc0 += b1 + contrib(b0, secret[s]); + acc1 += b0 + contrib(b1, secret[s + 1]); + acc2 += b3 + contrib(b2, secret[s + 2]); + acc3 += b2 + contrib(b3, secret[s + 3]); + acc4 += b5 + contrib(b4, secret[s + 4]); + acc5 += b4 + contrib(b5, secret[s + 5]); + acc6 += b7 + contrib(b6, secret[s + 6]); + acc7 += b6 + contrib(b7, secret[s + 7]); + } + + { + int offStripe = len - 32; + + long b0 = getLong(charSequence, offStripe); + long b1 = getLong(charSequence, offStripe + 4); + long b2 = getLong(charSequence, offStripe + 4 * 2); + long b3 = getLong(charSequence, offStripe + 4 * 3); + long b4 = getLong(charSequence, offStripe + 4 * 4); + long b5 = getLong(charSequence, offStripe + 4 * 5); + long b6 = getLong(charSequence, offStripe + 4 * 6); + long b7 = getLong(charSequence, offStripe + 4 * 7); + + acc0 += b1 + contrib(b0, secShift16); + acc1 += b0 + contrib(b1, secShift17); + acc2 += b3 + contrib(b2, secShift18); + acc3 += b2 + contrib(b3, secShift19); + acc4 += b5 + contrib(b4, secShift20); + acc5 += b4 + contrib(b5, secShift21); + acc6 += b7 + contrib(b6, secShift22); + acc7 += b6 + contrib(b7, secShift23); + } + + return finalizeHash((long) len << 1, acc0, acc1, acc2, acc3, acc4, acc5, acc6, acc7); + } + + // ============================================================================================================= + // Private methods + // ============================================================================================================= + + + /** + * Reads a {@code long} value from four UTF16 characters from a {@link CharSequence} with given offset. + * + * @param charSequence a char sequence + * @param off an offset + * @return the value + */ + private static long getLong(final CharSequence charSequence, int off) { + return (long) charSequence.charAt(off) + | ((long) charSequence.charAt(off + 1) << 16) + | ((long) charSequence.charAt(off + 2) << 32) + | ((long) charSequence.charAt(off + 3) << 48); + } + + /** + * Reads an {@code int} value from two UTF16 characters from a {@link CharSequence} with given offset. 
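+     * <p>For example (illustrative): {@code getInt("ab", 0)} returns {@code 'a' | ('b' << 16)} = {@code 0x00620061}.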
+ * + * @param charSequence a char sequence + * @param off an offset + * @return the value + */ + private static int getInt(CharSequence charSequence, int off) { + return (int) charSequence.charAt(off) | ((int) charSequence.charAt(off + 1) << 16); + } + + static long rrmxmx(long h64, final long length) { + h64 ^= Long.rotateLeft(h64, 49) ^ Long.rotateLeft(h64, 24); + h64 *= 0x9FB21C651E98DF25L; + h64 ^= (h64 >>> 35) + length; + h64 *= 0x9FB21C651E98DF25L; + return h64 ^ (h64 >>> 28); + } + + private static long mix16B(final byte[] input, final int offIn, final long sec0, final long sec1) { + long lo = (long) LONG_HANDLE.get(input, offIn); + long hi = (long) LONG_HANDLE.get(input, offIn + 8); + return mix2Accs(lo, hi, sec0, sec1); + } + + + private static long mix16B( final CharSequence input, final int offIn, final long sec0, final long sec1) { + long lo = getLong(input, offIn); + long hi = getLong(input, offIn + 4); + return mix2Accs(lo, hi, sec0, sec1); + } + + static long avalanche64(long h64) { + h64 ^= h64 >>> 33; + h64 *= INIT_ACC_2; + h64 ^= h64 >>> 29; + h64 *= INIT_ACC_3; + return h64 ^ (h64 >>> 32); + } + + static long avalanche3(long h64) { + h64 ^= h64 >>> 37; + h64 *= 0x165667919E3779F9L; + return h64 ^ (h64 >>> 32); + } + + private static long mix2Accs(final long lh, final long rh, long sec0, long sec8) { + return mix(lh ^ sec0, rh ^ sec8); + } + + private static long contrib(long a, long b) { + long k = a ^ b; + return (0xFFFFFFFFL & k) * (k >>> 32); + } + + private static long mixAcc(long acc, long sec) { + return (acc ^ (acc >>> 47) ^ sec) * INIT_ACC_7; + } + + static long mix(long a, long b) { + long x = a * b; + long y = Math.unsignedMultiplyHigh(a, b); + return x ^ y; + } + + private long finalizeHash( + long length, long acc0, long acc1, long acc2, long acc3, long acc4, long acc5, long acc6, long acc7) { + + long result64 = length * INIT_ACC_1 + + mix2Accs(acc0, acc1, secShiftFinal0, secShiftFinal1) + + mix2Accs(acc2, acc3, secShiftFinal2, secShiftFinal3) + + mix2Accs(acc4, acc5, secShiftFinal4, secShiftFinal5) + + mix2Accs(acc6, acc7, secShiftFinal6, secShiftFinal7); + + return avalanche3(result64); + } + + /** + * A writable sequential data implementation that hashes data using the XXH3_64 algorithm. + * It buffers writes in bulk and processes them to compute the hash incrementally. + */ + public class HashingWritableSequentialData implements WritableSequentialData { + /** The size of the buffer used for writing data in bulk. */ + private static final int BULK_SIZE = 256; + /** The mask for the bulk size, used for wrapping around the buffer. */ + private static final int BULK_SIZE_MASK = BULK_SIZE - 1; + // Initial accumulator values for the hashing process. + private long acc0 = INIT_ACC_0; + private long acc1 = INIT_ACC_1; + private long acc2 = INIT_ACC_2; + private long acc3 = INIT_ACC_3; + private long acc4 = INIT_ACC_4; + private long acc5 = INIT_ACC_5; + private long acc6 = INIT_ACC_6; + private long acc7 = INIT_ACC_7; + /** The buffer used for writing data in bulk. */ + private final byte[] buffer = new byte[BULK_SIZE + 8]; + /** The offset in the buffer where the next write will occur. */ + private int offset = 0; + /** The total number of bytes written so far. */ + private long byteCount = 0; + + /** + * Constructs a new instance of the XXH3_64 hashing writable sequential data. 
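+     * <p>Illustrative use of the enclosing streaming hasher (a sketch; {@code someBytes} is any byte array):
+     * <pre>{@code
+     * var hashingOut = XXH3_64.DEFAULT_INSTANCE.hashingWritableSequentialData();
+     * hashingOut.writeLong(123L);
+     * hashingOut.writeBytes(someBytes, 0, someBytes.length);
+     * long hash = hashingOut.computeHash();
+     * }</pre>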
+     */
+    private HashingWritableSequentialData() {}
+
+    // =============================================================================================================
+    // WritableSequentialData implementation
+    // =============================================================================================================
+
+    @Override
+    public void skip(long count) throws UncheckedIOException {
+        // Skip is not supported in this implementation.
+        throw new UnsupportedOperationException("Skip operation is not supported in this implementation.");
+    }
+
+    @Override
+    public void limit(long limit) {
+        // Setting a limit is not supported in this implementation.
+        throw new UnsupportedOperationException("Limit operation is not supported in this implementation.");
+    }
+
+    @Override
+    public long limit() {
+        return Long.MAX_VALUE;
+    }
+
+    @Override
+    public long position() {
+        return byteCount;
+    }
+
+    @Override
+    public long capacity() {
+        return Long.MAX_VALUE;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void writeVarLong(long value, boolean zigZag) throws BufferOverflowException, UncheckedIOException {
+        // We do not need to protobuf encode for hashing, so we just write the long directly.
+        writeLong(value);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void writeVarInt(int value, boolean zigZag) throws BufferOverflowException, UncheckedIOException {
+        // We do not need to protobuf encode for hashing, so we just write the int directly.
+        writeInt(value);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void writeDouble(double value, @NonNull ByteOrder byteOrder)
+            throws BufferOverflowException, UncheckedIOException {
+        // We ignore the byte order here as hashes are always computed over little-endian bytes.
+        writeLong(Double.doubleToRawLongBits(value));
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void writeDouble(double value) throws BufferOverflowException, UncheckedIOException {
+        writeLong(Double.doubleToRawLongBits(value));
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void writeFloat(float value, @NonNull ByteOrder byteOrder)
+            throws BufferOverflowException, UncheckedIOException {
+        // We ignore the byte order here as hashes are always computed over little-endian bytes.
+        writeInt(Float.floatToRawIntBits(value));
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void writeFloat(float value) throws BufferOverflowException, UncheckedIOException {
+        writeInt(Float.floatToRawIntBits(value));
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void writeLong(long value, @NonNull ByteOrder byteOrder)
+            throws BufferOverflowException, UncheckedIOException {
+        writeLong(value);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void writeLong(long value) throws BufferOverflowException, UncheckedIOException {
+        LONG_HANDLE.set(buffer, offset, value);
+        if (offset >= BULK_SIZE - 7) {
+            processBuffer();
+            offset -= BULK_SIZE;
+            // Re-write the bytes that spilled past BULK_SIZE to the start of the buffer.
+            LONG_HANDLE.set(buffer, 0, value >>> (-offset << 3));
+        }
+        offset += 8;
+        byteCount += 8;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void writeUnsignedInt(long value, @NonNull ByteOrder byteOrder)
+            throws BufferOverflowException, UncheckedIOException {
+        writeInt((int) value);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void writeUnsignedInt(long value) throws BufferOverflowException, UncheckedIOException {
+        writeInt((int) value);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void writeInt(int value, @NonNull ByteOrder byteOrder)
+            throws BufferOverflowException, UncheckedIOException {
+        writeInt(value);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public 
void writeInt(int value) throws BufferOverflowException, UncheckedIOException { + INT_HANDLE.set(buffer, offset, value); + if (offset >= BULK_SIZE - 3) { + processBuffer(); + offset -= BULK_SIZE; + INT_HANDLE.set(buffer, 0, value >>> (-offset << 3)); + } + offset += 4; + byteCount += 4; + } + + /** {@inheritDoc} */ + @Override + public void writeBytes(@NonNull RandomAccessData src) throws BufferOverflowException, UncheckedIOException { + long offset = 0; + int length = Math.toIntExact(src.length()); + int remaining = length; + final int x = BULK_SIZE - this.offset; + byte[] temp = new byte[BULK_SIZE]; + if (length > x) { + int s = (int) ((byteCount - 1) >>> 6) & 12; + if (this.offset > 0) { + src.getBytes(offset, buffer, this.offset, x); + processBuffer(0, buffer, s); + this.offset = 0; + offset += x; + remaining -= x; + } + if (remaining > BULK_SIZE) { + do { + s += 4; + s &= 12; + src.getBytes(offset, temp, 0, BULK_SIZE); + processBuffer(0, temp, s); + offset += BULK_SIZE; + remaining -= BULK_SIZE; + } while (remaining > BULK_SIZE); + if (remaining < 64) { + int l = 64 - remaining; + src.getBytes(offset - l, buffer, BULK_SIZE - l, l); + } + } + } + src.getBytes(offset, buffer, this.offset, remaining); + this.offset += remaining; + byteCount += length; + } + + /** {@inheritDoc} */ + @Override + public void writeBytes(@NonNull BufferedData src) throws BufferOverflowException, UncheckedIOException { + long offset = src.position(); + int length = Math.toIntExact(src.remaining()); + int remaining = length; + final int x = BULK_SIZE - this.offset; + byte[] temp = new byte[BULK_SIZE]; + if (length > x) { + int s = (int) ((byteCount - 1) >>> 6) & 12; + if (this.offset > 0) { + src.getBytes(offset, buffer, this.offset, x); + processBuffer(0, buffer, s); + this.offset = 0; + offset += x; + remaining -= x; + } + if (remaining > BULK_SIZE) { + do { + s += 4; + s &= 12; + src.getBytes(offset, temp, 0, BULK_SIZE); + processBuffer(0, temp, s); + offset += BULK_SIZE; + remaining -= BULK_SIZE; + } while (remaining > BULK_SIZE); + if (remaining < 64) { + int l = 64 - remaining; + src.getBytes(offset - l, buffer, BULK_SIZE - l, l); + } + } + } + src.getBytes(offset, buffer, this.offset, remaining); + this.offset += remaining; + byteCount += length; + } + + /** {@inheritDoc} */ + @Override + public void writeBytes(@NonNull ByteBuffer srcBuffer) throws BufferOverflowException, UncheckedIOException { + int offset = srcBuffer.position(); + int length = srcBuffer.remaining(); + int remaining = length; + final int x = BULK_SIZE - this.offset; + if (length > x) { + int s = (int) ((byteCount - 1) >>> 6) & 12; + if (this.offset > 0) { + // Copy x bytes from srcBuffer to buffer at this.offset + srcBuffer.get(buffer, this.offset, x); + processBuffer(0, buffer, s); + this.offset = 0; + offset += x; + remaining -= x; + } + if (remaining > BULK_SIZE) { + do { + s += 4; + s &= 12; + // Copy BULK_SIZE bytes from srcBuffer to buffer + srcBuffer.get(buffer, 0, BULK_SIZE); + processBuffer(0, buffer, s); + offset += BULK_SIZE; + remaining -= BULK_SIZE; + } while (remaining > BULK_SIZE); + if (remaining < 64) { + int l = 64 - remaining; + // Copy l bytes from srcBuffer's previous position to buffer at BULK_SIZE - l + int prevPos = srcBuffer.position() - l; + int oldLimit = srcBuffer.limit(); + srcBuffer.limit(srcBuffer.position()); + srcBuffer.position(prevPos); + srcBuffer.get(buffer, BULK_SIZE - l, l); + srcBuffer.limit(oldLimit); + srcBuffer.position(offset + remaining); + } + } + } + // Copy remaining bytes from srcBuffer 
to buffer at this.offset + srcBuffer.get(buffer, this.offset, remaining); + this.offset += remaining; + byteCount += length; + } + + /** {@inheritDoc} */ + @Override + public void writeBytes(@NonNull byte[] src, int offset, int length) + throws BufferOverflowException, UncheckedIOException { + int remaining = length; + final int x = BULK_SIZE - this.offset; + if (length > x) { + int s = (int) ((byteCount - 1) >>> 6) & 12; + if (this.offset > 0) { + System.arraycopy(src, offset, buffer, this.offset, x); + processBuffer(0, buffer, s); + this.offset = 0; + offset += x; + remaining -= x; + } + if (remaining > BULK_SIZE) { + do { + s += 4; + s &= 12; + processBuffer(offset, src, s); + offset += BULK_SIZE; + remaining -= BULK_SIZE; + } while (remaining > BULK_SIZE); + if (remaining < 64) { + int l = 64 - remaining; + System.arraycopy(src, offset - l, buffer, BULK_SIZE - l, l); + } + } + } + System.arraycopy(src, offset, buffer, this.offset, remaining); + this.offset += remaining; + byteCount += length; + } + + /** {@inheritDoc} */ + @Override + public void writeByte(byte b) throws BufferOverflowException, UncheckedIOException { + if (offset >= BULK_SIZE) { + processBuffer(); + offset -= BULK_SIZE; + } + buffer[offset] = b; + offset += 1; + byteCount += 1; + } + + /** + * Resets the internal state of the hashing writable sequential data. + * This method clears the accumulated values and resets the buffer. + */ + public void reset() { + acc0 = INIT_ACC_0; + acc1 = INIT_ACC_1; + acc2 = INIT_ACC_2; + acc3 = INIT_ACC_3; + acc4 = INIT_ACC_4; + acc5 = INIT_ACC_5; + acc6 = INIT_ACC_6; + acc7 = INIT_ACC_7; + offset = 0; + byteCount = 0; + } + + /** + * Computes the hash of the data written so far. + * + * @return the computed hash as a 64-bit long value + */ + public long computeHash() { + if (byteCount >= 0 && byteCount <= BULK_SIZE) { + return hashBytesToLong(buffer, 0, (int) byteCount); + } + LONG_HANDLE.set(buffer, BULK_SIZE, (long) LONG_HANDLE.get(buffer, 0)); + + long acc0Loc = acc0; + long acc1Loc = acc1; + long acc2Loc = acc2; + long acc3Loc = acc3; + long acc4Loc = acc4; + long acc5Loc = acc5; + long acc6Loc = acc6; + long acc7Loc = acc7; + + for (int off = 0, s = (((int) byteCount - 1) >>> 6) & 12; + off + 64 <= (((int) byteCount - 1) & BULK_SIZE_MASK); + off += 64, s += 1) { + + long b0 = (long) LONG_HANDLE.get(buffer, off); + long b1 = (long) LONG_HANDLE.get(buffer, off + 8); + long b2 = (long) LONG_HANDLE.get(buffer, off + 8 * 2); + long b3 = (long) LONG_HANDLE.get(buffer, off + 8 * 3); + long b4 = (long) LONG_HANDLE.get(buffer, off + 8 * 4); + long b5 = (long) LONG_HANDLE.get(buffer, off + 8 * 5); + long b6 = (long) LONG_HANDLE.get(buffer, off + 8 * 6); + long b7 = (long) LONG_HANDLE.get(buffer, off + 8 * 7); + + acc0Loc += b1 + contrib(b0, secret[s]); + acc1Loc += b0 + contrib(b1, secret[s + 1]); + acc2Loc += b3 + contrib(b2, secret[s + 2]); + acc3Loc += b2 + contrib(b3, secret[s + 3]); + acc4Loc += b5 + contrib(b4, secret[s + 4]); + acc5Loc += b4 + contrib(b5, secret[s + 5]); + acc6Loc += b7 + contrib(b6, secret[s + 6]); + acc7Loc += b6 + contrib(b7, secret[s + 7]); + } + + { + long b0 = (long) LONG_HANDLE.get(buffer, (offset - (64)) & BULK_SIZE_MASK); + long b1 = (long) LONG_HANDLE.get(buffer, (offset - (64 - 8)) & BULK_SIZE_MASK); + long b2 = (long) LONG_HANDLE.get(buffer, (offset - (64 - 8 * 2)) & BULK_SIZE_MASK); + long b3 = (long) LONG_HANDLE.get(buffer, (offset - (64 - 8 * 3)) & BULK_SIZE_MASK); + long b4 = (long) LONG_HANDLE.get(buffer, (offset - (64 - 8 * 4)) & BULK_SIZE_MASK); + 
long b5 = (long) LONG_HANDLE.get(buffer, (offset - (64 - 8 * 5)) & BULK_SIZE_MASK); + long b6 = (long) LONG_HANDLE.get(buffer, (offset - (64 - 8 * 6)) & BULK_SIZE_MASK); + long b7 = (long) LONG_HANDLE.get(buffer, (offset - (64 - 8 * 7)) & BULK_SIZE_MASK); + + acc0Loc += b1 + contrib(b0, secShift16); + acc1Loc += b0 + contrib(b1, secShift17); + acc2Loc += b3 + contrib(b2, secShift18); + acc3Loc += b2 + contrib(b3, secShift19); + acc4Loc += b5 + contrib(b4, secShift20); + acc5Loc += b4 + contrib(b5, secShift21); + acc6Loc += b7 + contrib(b6, secShift22); + acc7Loc += b6 + contrib(b7, secShift23); + } + + return finalizeHash(byteCount, acc0Loc, acc1Loc, acc2Loc, acc3Loc, acc4Loc, acc5Loc, acc6Loc, acc7Loc); + } + + // ============================================================================================================= + // Internal methods for processing the buffer and computing the hash + // ============================================================================================================= + + private void processBuffer() { + int s = (int) ((byteCount - 1) >>> 6) & 12; + processBuffer(0, buffer, s); + } + + private void mixAcc() { + acc0 = XXH3_64.mixAcc(acc0, secret16); + acc1 = XXH3_64.mixAcc(acc1, secret17); + acc2 = XXH3_64.mixAcc(acc2, secret18); + acc3 = XXH3_64.mixAcc(acc3, secret19); + acc4 = XXH3_64.mixAcc(acc4, secret20); + acc5 = XXH3_64.mixAcc(acc5, secret21); + acc6 = XXH3_64.mixAcc(acc6, secret22); + acc7 = XXH3_64.mixAcc(acc7, secret23); + } + + private void processBuffer(int off, byte[] buffer, int s) { + for (int i = 0; i < 4; ++i) { + int o = off + (i << 6); + long b0 = (long) LONG_HANDLE.get(buffer, o); + long b1 = (long) LONG_HANDLE.get(buffer, o + 8); + long b2 = (long) LONG_HANDLE.get(buffer, o + 8 * 2); + long b3 = (long) LONG_HANDLE.get(buffer, o + 8 * 3); + long b4 = (long) LONG_HANDLE.get(buffer, o + 8 * 4); + long b5 = (long) LONG_HANDLE.get(buffer, o + 8 * 5); + long b6 = (long) LONG_HANDLE.get(buffer, o + 8 * 6); + long b7 = (long) LONG_HANDLE.get(buffer, o + 8 * 7); + processBuffer(b0, b1, b2, b3, b4, b5, b6, b7, s + i); + } + if (s == 12) { + mixAcc(); + } + } + + private void processBuffer(long b0, long b1, long b2, long b3, long b4, long b5, long b6, long b7, int s) { + acc0 += b1 + contrib(b0, secret[s]); + acc1 += b0 + contrib(b1, secret[s + 1]); + acc2 += b3 + contrib(b2, secret[s + 2]); + acc3 += b2 + contrib(b3, secret[s + 3]); + acc4 += b5 + contrib(b4, secret[s + 4]); + acc5 += b4 + contrib(b5, secret[s + 5]); + acc6 += b7 + contrib(b6, secret[s + 6]); + acc7 += b6 + contrib(b7, secret[s + 7]); + } + } +} diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/buffer/Bytes.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/buffer/Bytes.java index 59aee2c6..2a70d4c9 100644 --- a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/buffer/Bytes.java +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/buffer/Bytes.java @@ -3,6 +3,8 @@ import static java.util.Objects.requireNonNull; +import com.hedera.pbj.runtime.hashing.SixtyFourBitHashable; +import com.hedera.pbj.runtime.hashing.XXH3_64; import com.hedera.pbj.runtime.io.DataEncodingException; import com.hedera.pbj.runtime.io.ReadableSequentialData; import com.hedera.pbj.runtime.io.UnsafeUtils; @@ -29,7 +31,7 @@ * An immutable representation of a byte array. This class is designed to be efficient and usable across threads. 
 */
@SuppressWarnings("unused")
-public final class Bytes implements RandomAccessData, Comparable<Bytes> {
+public final class Bytes implements RandomAccessData, Comparable<Bytes>, SixtyFourBitHashable {
     /** An instance of an empty {@link Bytes} */
     public static final Bytes EMPTY = new Bytes(new byte[0]);
@@ -66,7 +68,7 @@ public final class Bytes implements RandomAccessData, Comparable<Bytes> {
     /**
      * The hash code of this {@link Bytes}. This is cached to avoid recomputing it multiple times.
      */
-    private int hashCode = 0;
+    private long hashCode = 0;
     /**
      * Create a new ByteOverByteBuffer over given byte array. This does not copy data it just wraps so
@@ -192,9 +194,10 @@ public static Bytes merge(@NonNull final Bytes bytes1, @NonNull final Bytes byte
     /**
      * Returns the first byte offset of {@code needle} inside {@code haystack},
      * or –1 if it is not present.
-     *
+     * <p>
      * Offsets are *relative to the start of the Bytes slice*, so 0 means
      * “starts exactly at haystack.start”.
+     * </p>
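+     * <p>For example (illustrative), searching for {@code {2, 3}} inside {@code {1, 2, 3}} returns offset 1.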
 */
    public static int indexOf(@NonNull final Bytes haystack, @NonNull final Bytes needle) {
        requireNonNull(haystack);
@@ -536,12 +539,17 @@ public boolean equals(@Nullable final Object o) {
      */
     @Override
     public int hashCode() {
+        return (int) hashCode64();
+    }
+
+    /**
+     * Compute a 64-bit hash code for this {@link Bytes} based on all bytes of content.
+     *
+     * @return a 64-bit hash of the content
+     */
+    public long hashCode64() {
         if (hashCode == 0) {
-            int h = 1;
-            for (int i = start + length - 1; i >= start; i--) {
-                h = 31 * h + UnsafeUtils.getArrayByteNoChecks(buffer, i);
-            }
-            hashCode = h;
+            hashCode = XXH3_64.DEFAULT_INSTANCE.hashBytesToLong(buffer, start, length);
         }
         return hashCode;
     }
diff --git a/pbj-core/pbj-runtime/src/main/java/module-info.java b/pbj-core/pbj-runtime/src/main/java/module-info.java
index 63d5f648..c0d1907e 100644
--- a/pbj-core/pbj-runtime/src/main/java/module-info.java
+++ b/pbj-core/pbj-runtime/src/main/java/module-info.java
@@ -12,4 +12,5 @@
     exports com.hedera.pbj.runtime.io.buffer;
     exports com.hedera.pbj.runtime.jsonparser;
     exports com.hedera.pbj.runtime.grpc;
+    exports com.hedera.pbj.runtime.hashing;
 }
diff --git a/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/hashing/XXH3StreamingTest.java b/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/hashing/XXH3StreamingTest.java
new file mode 100644
index 00000000..ce209300
--- /dev/null
+++ b/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/hashing/XXH3StreamingTest.java
@@ -0,0 +1,187 @@
+package com.hedera.pbj.runtime.hashing;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import com.hedera.pbj.runtime.hashing.XXH3_64.HashingWritableSequentialData;
+import com.hedera.pbj.runtime.io.buffer.BufferedData;
+import com.hedera.pbj.runtime.io.buffer.Bytes;
+import com.hedera.pbj.runtime.io.buffer.RandomAccessData;
+import java.lang.invoke.MethodHandles;
+import java.lang.invoke.VarHandle;
+import java.nio.ByteOrder;
+import java.util.Random;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+public class XXH3StreamingTest {
+    /** VarHandle for reading and writing longs in little-endian byte order. */
+    private static final VarHandle LONG_HANDLE =
+            MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.LITTLE_ENDIAN);
+    /** VarHandle for reading and writing integers in little-endian byte order. */
+    private static final VarHandle INT_HANDLE =
+            MethodHandles.byteArrayViewVarHandle(int[].class, ByteOrder.LITTLE_ENDIAN);
+
+    @Test
+    @DisplayName("Test limit, capacity and position handling in XXH3 streaming")
+    public void testLimitAndCapacity() {
+        final var hashingStream = XXH3_64.DEFAULT_INSTANCE.hashingWritableSequentialData();
+        assertEquals(Long.MAX_VALUE, hashingStream.limit());
+        assertEquals(Long.MAX_VALUE, hashingStream.capacity());
+        hashingStream.writeInt(123456789);
+        assertEquals(Integer.BYTES, hashingStream.position());
+        assertEquals(Long.MAX_VALUE, hashingStream.limit());
+        assertEquals(Long.MAX_VALUE, hashingStream.capacity());
+        hashingStream.reset();
+        assertEquals(0, hashingStream.position());
+    }
+
+    /**
+     * Test for 32-bit integer and float handling in XXH3 streaming.
+     */
+    @Test
+    @DisplayName("Test for 32-bit integer and float handling in XXH3 streaming")
+    public void test32Bit() {
+        final int value = 123456789;
+        final byte[] bytes = new byte[4];
+        INT_HANDLE.set(bytes, 0, value);
+        final long simpleHash = XXH3_64.DEFAULT_INSTANCE.hashBytesToLong(bytes, 0, 4);
+        final var hashingStream = XXH3_64.DEFAULT_INSTANCE.hashingWritableSequentialData();
+        hashingStream.writeInt(value);
+        assertEquals(simpleHash, hashingStream.computeHash());
+        assertEquals(Integer.BYTES, hashingStream.position());
+        hashingStream.reset();
+        hashingStream.writeInt(value, ByteOrder.LITTLE_ENDIAN);
+        assertEquals(simpleHash, hashingStream.computeHash());
+        hashingStream.reset();
+        hashingStream.writeInt(value, ByteOrder.BIG_ENDIAN);
+        assertEquals(simpleHash, hashingStream.computeHash());
+        hashingStream.reset();
+        hashingStream.writeUnsignedInt(value);
+        assertEquals(simpleHash, hashingStream.computeHash());
+        hashingStream.reset();
+        hashingStream.writeUnsignedInt(value, ByteOrder.LITTLE_ENDIAN);
+        assertEquals(simpleHash, hashingStream.computeHash());
+        hashingStream.reset();
+        hashingStream.writeUnsignedInt(value, ByteOrder.BIG_ENDIAN);
+        assertEquals(simpleHash, hashingStream.computeHash());
+        hashingStream.reset();
+        hashingStream.writeVarInt(value, true);
+        assertEquals(simpleHash, hashingStream.computeHash());
+        hashingStream.reset();
+        hashingStream.writeVarInt(value, false);
+        assertEquals(simpleHash, hashingStream.computeHash());
+        hashingStream.reset();
+        // Convert int to float and write it
+        final float floatValue = Float.intBitsToFloat(value);
+        hashingStream.writeFloat(floatValue);
+        assertEquals(simpleHash, hashingStream.computeHash());
+        hashingStream.reset();
+        hashingStream.writeFloat(floatValue, ByteOrder.LITTLE_ENDIAN);
+        assertEquals(simpleHash, hashingStream.computeHash());
+        hashingStream.reset();
+        hashingStream.writeFloat(floatValue, ByteOrder.BIG_ENDIAN);
+        assertEquals(simpleHash, hashingStream.computeHash());
+        hashingStream.reset();
+    }
+
+    @Test
+    @DisplayName("Test for 64-bit long handling in XXH3 streaming")
+    public void test64Bit() {
+        final long value = 1234567890123456789L;
+        final byte[] bytes = new byte[8];
+        LONG_HANDLE.set(bytes, 0, value);
+        final long simpleHash = XXH3_64.DEFAULT_INSTANCE.hashBytesToLong(bytes, 0, 8);
+        final var hashingStream = XXH3_64.DEFAULT_INSTANCE.hashingWritableSequentialData();
+        hashingStream.writeLong(value);
+        assertEquals(simpleHash, hashingStream.computeHash());
+        assertEquals(Long.BYTES, hashingStream.position());
+        hashingStream.reset();
+        hashingStream.writeLong(value, ByteOrder.LITTLE_ENDIAN);
+        assertEquals(simpleHash, hashingStream.computeHash());
+        hashingStream.reset();
+        hashingStream.writeLong(value, ByteOrder.BIG_ENDIAN);
+        assertEquals(simpleHash, hashingStream.computeHash());
+        hashingStream.reset();
+        hashingStream.writeVarLong(value, true);
+        assertEquals(simpleHash, hashingStream.computeHash());
+        hashingStream.reset();
+        hashingStream.writeVarLong(value, false);
+        assertEquals(simpleHash, hashingStream.computeHash());
+        hashingStream.reset();
+        // Convert long to double and write it
+        final double doubleValue = Double.longBitsToDouble(value);
+        hashingStream.writeDouble(doubleValue);
+        assertEquals(simpleHash, hashingStream.computeHash());
+        hashingStream.reset();
+        hashingStream.writeDouble(doubleValue, ByteOrder.LITTLE_ENDIAN);
+        assertEquals(simpleHash, hashingStream.computeHash());
+        hashingStream.reset();
+        hashingStream.writeDouble(doubleValue, ByteOrder.BIG_ENDIAN);
assertEquals(simpleHash, hashingStream.computeHash()); + hashingStream.reset(); + } + + @Test + @DisplayName("Test for byte methods in XXH3 streaming") + public void testByteMethods() { + final byte[] bytes = new byte[128]; + new Random(91824819480L).nextBytes(bytes); + final long simpleHash = XXH3_64.DEFAULT_INSTANCE.hashBytesToLong(bytes, 0, bytes.length); + final HashingWritableSequentialData hashingStream = XXH3_64.DEFAULT_INSTANCE.hashingWritableSequentialData(); + // byte arrays + hashingStream.writeBytes(bytes, 0, bytes.length); + assertEquals(simpleHash, hashingStream.computeHash()); + hashingStream.reset(); + hashingStream.writeBytes(bytes); + assertEquals(simpleHash, hashingStream.computeHash()); + hashingStream.reset(); + for (byte aByte : bytes) { + hashingStream.writeByte(aByte); + } + assertEquals(simpleHash, hashingStream.computeHash()); + hashingStream.reset(); + // BufferedData + BufferedData bufferedData = BufferedData.wrap(bytes); + hashingStream.writeBytes(bufferedData); + assertEquals(simpleHash, hashingStream.computeHash()); + hashingStream.reset(); + // ByteBuffer + java.nio.ByteBuffer byteBuffer = java.nio.ByteBuffer.wrap(bytes); + hashingStream.writeBytes(byteBuffer); + assertEquals(simpleHash, hashingStream.computeHash()); + hashingStream.reset(); + // RandomAccessData + RandomAccessData randomAccessData = Bytes.wrap(bytes); + hashingStream.writeBytes(randomAccessData); + assertEquals(simpleHash, hashingStream.computeHash()); + hashingStream.reset(); + // === subsets === + int offset = 10; + int length = 50; + final long simpleSubsetHash = XXH3_64.DEFAULT_INSTANCE.hashBytesToLong(bytes, offset, length); + // byte arrays + hashingStream.writeBytes(bytes, offset, length); + assertEquals(simpleSubsetHash, hashingStream.computeHash()); + hashingStream.reset(); + for (int i = offset; i < offset + length; i++) { + hashingStream.writeByte(bytes[i]); + } + assertEquals(simpleSubsetHash, hashingStream.computeHash()); + hashingStream.reset(); + // BufferedData + BufferedData bufferedSubsetData = BufferedData.wrap(bytes, offset, length); + hashingStream.writeBytes(bufferedSubsetData); + assertEquals(simpleSubsetHash, hashingStream.computeHash()); + hashingStream.reset(); + // ByteBuffer + java.nio.ByteBuffer byteSubsetBuffer = java.nio.ByteBuffer.wrap(bytes, offset, length); + hashingStream.writeBytes(byteSubsetBuffer); + assertEquals(simpleSubsetHash, hashingStream.computeHash()); + hashingStream.reset(); + // RandomAccessData + RandomAccessData randomSubsetAccessData = Bytes.wrap(bytes, offset, length); + hashingStream.writeBytes(randomSubsetAccessData); + assertEquals(simpleSubsetHash, hashingStream.computeHash()); + hashingStream.reset(); + } +} diff --git a/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/hashing/XXH3Test.java b/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/hashing/XXH3Test.java new file mode 100644 index 00000000..f489b968 --- /dev/null +++ b/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/hashing/XXH3Test.java @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: Apache-2.0 +package com.hedera.pbj.runtime.hashing; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.util.HexFormat; +import java.util.Random; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.IntStream; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + 
+public class XXH3Test {
+
+    /**
+     * This test checks the hash of the string "hello world". The expected hash is computed using the command line
+     * tool `xxhsum` with the `-H3` option.
+     */
+    @Test
+    @DisplayName("Test for the string 'hello world'")
+    public void helloWorldTest() {
+        byte[] inputBytes = "hello world".getBytes();
+        // Compute the hash with our XXH3_64 implementation
+        long hash = XXH3_64.DEFAULT_INSTANCE.hashBytesToLong(inputBytes, 0, inputBytes.length);
+        // "hello world" expected hash in hex produced with command line -> echo -n "hello world" | xxhsum -H3
+        String expectedHash = "d447b1ea40e6988b";
+        assertEquals(expectedHash, Long.toHexString(hash));
+    }
+
+    /**
+     * This test checks the hash of the byte sequence CAFEBABE, which is often used as a magic number in Java class
+     * files.
+     */
+    @Test
+    @DisplayName("Test for the CAFEBABE byte sequence")
+    public void cafeBabeTest() {
+        byte[] inputBytes = HexFormat.of().parseHex("CAFEBABE");
+        // Compute the hash with our XXH3_64 implementation
+        long hash = XXH3_64.DEFAULT_INSTANCE.hashBytesToLong(inputBytes, 0, inputBytes.length);
+        // CAFEBABE expected hash in hex produced with command line -> echo CAFEBABE | xxd -r -p | xxhsum -H3
+        String expectedHash = "36afb8d0770d97ea";
+        assertEquals(expectedHash, Long.toHexString(hash));
+    }
+
+    /**
+     * This test checks the hash of a large random data set against the `xxhsum` command line tool if it is available.
+     * It uses a large number of random byte arrays to ensure that the hash function behaves correctly across a wide
+     * range of inputs.
+     */
+    @Test
+    @DisplayName("Test random data against xxhsum if available")
+    void testRandomDataAgainstXxhsumIfAvailable() {
+        Assumptions.assumeTrue(isXXHSumAvailable(), "xxhsum not available, skipping test");
+        // test with a large random data set
+        Random random = new Random(18971947891479L);
+        final AtomicBoolean allMatch = new AtomicBoolean(true);
+        IntStream.range(0, 1_000).parallel().forEach(i -> {
+            byte[] randomData = new byte[1 + random.nextInt(128)];
+            random.nextBytes(randomData);
+            long testCodeHashResult = XXH3_64.DEFAULT_INSTANCE.hashBytesToLong(randomData, 0, randomData.length);
+            long referenceExpectedHash = xxh364HashWithCommandLine(randomData, 0, randomData.length);
+            assertEquals(
+                    referenceExpectedHash,
+                    testCodeHashResult,
+                    "Mismatch for random data " + i + ": Input: "
+                            + HexFormat.of().formatHex(randomData)
+                            + ", Expected xxhsum: " + Long.toHexString(referenceExpectedHash)
+                            + ", XXH3_64: " + Long.toHexString(testCodeHashResult));
+            if (testCodeHashResult != referenceExpectedHash) {
+                allMatch.set(false);
+            }
+        });
+        assertTrue(allMatch.get());
+    }
+
+    /**
+     * This method checks if the `xxhsum` command line tool is available on the system.
+     * It does this by trying to execute `xxhsum --version` and checking the exit code.
+     */
+    public static boolean isXXHSumAvailable() {
+        try {
+            Process process = new ProcessBuilder("xxhsum", "--version")
+                    .redirectErrorStream(true)
+                    .start();
+            int exitCode = process.waitFor();
+            return exitCode == 0;
+        } catch (IOException | InterruptedException e) {
+            return false;
+        }
+    }
+
+    /**
+     * This method computes the XXH3-64 hash of the given byte array using the `xxhsum` command line tool.
+     * It writes the bytes to the standard input of `xxhsum` and reads the output.
+     *
+     * @param bytes The byte array to hash.
+     * @param start The starting index in the byte array.
+     * @param length The number of bytes to hash.
+     * @return The computed hash as a long value.
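+     * <p>Illustrative check (assuming xxhsum is installed): {@code xxh364HashWithCommandLine(data, 0, data.length)}
+     * should equal {@code XXH3_64.DEFAULT_INSTANCE.hashBytesToLong(data, 0, data.length)} for any {@code data}.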
+     */
+    public static long xxh364HashWithCommandLine(final byte[] bytes, int start, int length) {
+        final String result;
+        ProcessBuilder pb = new ProcessBuilder("xxhsum", "-H3", "-");
+        try {
+            Process process = pb.start();
+            // Write input and close output to signal EOF to xxhsum
+            try (var out = process.getOutputStream()) {
+                out.write(bytes, start, length);
+                out.flush();
+            }
+            // Read result from input stream
+            try (var in = process.getInputStream()) {
+                result = new String(in.readAllBytes()).trim();
+            }
+            // Drain error stream to avoid blocking
+            try (var err = process.getErrorStream()) {
+                var errorBytes = err.readAllBytes();
+                if (errorBytes.length > 0) {
+                    String errorString = new String(errorBytes).trim();
+                    if (!errorString.isEmpty()) {
+                        throw new RuntimeException("Error from xxhsum: " + errorString);
+                    }
+                }
+            }
+            process.waitFor();
+        } catch (IOException | InterruptedException e) {
+            throw new RuntimeException(e);
+        }
+        // Extract the hex digest that xxhsum prints between the first '_' and the following space
+        final String resultHexString = result.substring(result.indexOf('_') + 1, result.indexOf(' '));
+        return Long.parseUnsignedLong(resultHexString, 16);
+    }
+}
diff --git a/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/ModelObjHashCodeBench.java b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/ModelObjHashCodeBench.java
new file mode 100644
index 00000000..e9485c53
--- /dev/null
+++ b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/ModelObjHashCodeBench.java
@@ -0,0 +1,84 @@
+// SPDX-License-Identifier: Apache-2.0
+package com.hedera.pbj.integration.jmh;
+
+import com.hedera.pbj.integration.EverythingTestData;
+import com.hedera.pbj.runtime.io.buffer.Bytes;
+import com.hedera.pbj.test.proto.pbj.Everything;
+import com.hedera.pbj.test.proto.pbj.Hasheval;
+import com.hedera.pbj.test.proto.pbj.Hasheval2;
+import com.hedera.pbj.test.proto.pbj.Suit;
+import com.hedera.pbj.test.proto.pbj.TimestampTest;
+import java.util.concurrent.TimeUnit;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OperationsPerInvocation;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+
+@SuppressWarnings("unused")
+@State(Scope.Benchmark)
+@Fork(1)
+@Warmup(iterations = 4, time = 2)
+@Measurement(iterations = 5, time = 2)
+@OutputTimeUnit(TimeUnit.NANOSECONDS)
+@BenchmarkMode(Mode.AverageTime)
+public class ModelObjHashCodeBench {
+
+    @Benchmark
+    public void simpleObject(Blackhole blackhole) {
+        TimestampTest tst = new TimestampTest(987L, 123);
+        blackhole.consume(tst.hashCode());
+    }
+
+    @Benchmark
+    public void everythingObject(Blackhole blackhole) {
+        Everything e = EverythingTestData.EVERYTHING.copyBuilder().build();
+        blackhole.consume(e.hashCode());
+    }
+
+    @Benchmark
+    public void bigObject(Blackhole blackhole) {
+        Hasheval2 complexObj = new Hasheval2(
+                13,
+                -1262,
+                2236,
+                326,
+                -27,
+                123f,
+                7L,
+                -7L,
+                123L,
+                234L,
+                -345L,
+                456.789D,
+                true,
+                Suit.ACES,
+                new Hasheval(
+                        1109840,
+                        -1414,
+                        25151,
+                        31515,
+                        -236,
+                        123f,
+                        7347L,
+                        -7L,
+                        1233474347347L,
+                        234L,
+                        -345L,
+                        456.789D,
+                        true,
+                        Suit.ACES,
+                        new TimestampTest(987L, 123),
+                        "FooBarKKKKHHHHOIOIOI",
+                        Bytes.wrap(new byte[] {127, 2, 3, 123, 48, 6, 7, (byte) 255})),
+                "FooBarKKKKHHHHOIOIOI",
+                Bytes.wrap(new byte[] {81, 52, 13, 94, 85, 66, 7, (byte) 255}));
+        blackhole.consume(complexObj.hashCode());
+    }
+}
diff --git a/pbj-integration-tests/src/main/java/com/hedera/pbj/integration/hashing/CountingArray.java b/pbj-integration-tests/src/main/java/com/hedera/pbj/integration/hashing/CountingArray.java
new file mode 100644
index 00000000..ef103a89
--- /dev/null
+++ b/pbj-integration-tests/src/main/java/com/hedera/pbj/integration/hashing/CountingArray.java
@@ -0,0 +1,143 @@
+// SPDX-License-Identifier: Apache-2.0
+package com.hedera.pbj.integration.hashing;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.stream.IntStream;
+
+/**
+ * An array that counts occurrences of indices in the range [0, 4,294,967,295]. It uses 4 byte arrays to store counts
+ * up to 250 and an overflow map for counts above 250.
+ */
+final class CountingArray {
+    /** Maximum value for the index, 2^32 */
+    private static final long MAX_VALUE = 4_294_967_296L; // 2^32
+    /** 4x 1 GB arrays to split the integer space into 4 parts */
+    private final byte[][] counts = new byte[4][1_073_741_824];
+    /** Overflow map for counts above 250, keyed by index */
+    private final Map<Long, Integer> overflowMap = new HashMap<>();
+
+    /**
+     * Clears all the counts
+     */
+    public void clear() {
+        for (byte[] subArray : counts) {
+            Arrays.fill(subArray, (byte) 0);
+        }
+        overflowMap.clear();
+    }
+
+    /**
+     * Returns the number of counts greater than zero across all indices.
+     * This includes counts in the overflow map.
+     *
+     * @return the number of counts greater than zero
+     */
+    public long numberOfGreaterThanZeroCounts() {
+        long count = Arrays.stream(counts)
+                .parallel()
+                .mapToLong(subArray ->
+                        // Count values > 0 and <= 250 in each subArray
+                        IntStream.range(0, subArray.length)
+                                .map(i -> Byte.toUnsignedInt(subArray[i]))
+                                .filter(unsignedValue -> unsignedValue > 0 && unsignedValue <= 250)
+                                .count())
+                .sum();
+        // each overflow entry is a single index whose count exceeded 250
+        return count + overflowMap.size();
+    }
+
+    /**
+     * Returns the number of counts greater than one across all indices.
+     * This includes counts in the overflow map.
+     *
+     * @return the number of counts greater than one
+     */
+    public long numberOfGreaterThanOneCounts() {
+        long count = Arrays.stream(counts)
+                .parallel()
+                .mapToLong(subArray ->
+                        // Count values > 1 and <= 250 in each subArray
+                        IntStream.range(0, subArray.length)
+                                .map(i -> Byte.toUnsignedInt(subArray[i]))
+                                .filter(unsignedValue -> unsignedValue > 1 && unsignedValue <= 250)
+                                .count())
+                .sum();
+        // each overflow entry is a single index whose count exceeded 250
+        return count + overflowMap.size();
+    }
+
+    /**
+     * Returns the number of 0 counts across all indices.
+     *
+     * @return the number of zero counts
+     */
+    public long numberOfZeroCounts() {
+        long count = 0;
+        for (byte[] subArray : counts) {
+            for (byte b : subArray) {
+                if (b == 0) {
+                    count++;
+                }
+            }
+        }
+        return count;
+    }
+
+    /**
+     * Increments the count for the given index.
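+     * Counts from 0 to 250 are stored directly in the byte arrays; once a count passes 250 the slot is marked
+     * as overflowed and the full count for that index is kept in the overflow map.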
+     *
+     * @param index the index to increment, must be in the range [0, 4,294,967,295]
+     */
+    public void increment(long index) {
+        if (index < 0 || index >= MAX_VALUE) {
+            throw new IndexOutOfBoundsException("index: " + index);
+        }
+        int subArrayIndex = (int) (index >>> 30); // which of the 4 sub-arrays, each covering 2^30 indices
+        int indexInSubArray = (int) (index & 0x3FFFFFFF); // 2^30 - 1
+        byte[] subArray = counts[subArrayIndex];
+        int currentValueUnsigned = Byte.toUnsignedInt(subArray[indexInSubArray]);
+        if (currentValueUnsigned < 250) {
+            // Increment the count in the sub-array, treating the byte as unsigned
+            subArray[indexInSubArray] = (byte) (currentValueUnsigned + 1);
+        } else {
+            // Handle overflow: 255 is never a valid in-array count, so it marks slots whose
+            // counts have moved to the overflow map
+            subArray[indexInSubArray] = (byte) 255;
+            overflowMap.compute(index, (key, value) -> value == null ? 251 : value + 1);
+        }
+    }
+
+    /**
+     * Prints the statistics of the counts, including the number of occurrences for each value from 0 to 250,
+     * and the overflow counts.
+     */
+    public void printStats(final StringBuilder resultStr) {
+        // count up number of bytes with each value 0 to 250
+        long[] valueCounts = new long[251]; // 0 to 250
+        for (byte[] subArray : counts) {
+            for (byte b : subArray) {
+                int unsignedValue = Byte.toUnsignedInt(b);
+                if (unsignedValue <= 250) {
+                    valueCounts[unsignedValue]++;
+                }
+            }
+        }
+        // print the counts
+        resultStr.append(" Counts:");
+        for (int i = 0; i <= 250; i++) {
+            long count = valueCounts[i];
+            if (count > 0) {
+                resultStr.append(String.format(" %d=%,d", i, count));
+            }
+        }
+        // print the number of overflowed indices
+        resultStr.append("\n Overflow counts: " + overflowMap.size());
+        //        overflowMap.entrySet().stream()
+        //                .sorted(Map.Entry.comparingByKey())
+        //                .forEach(entry -> resultStr.append(String.format(" %d=%,d", entry.getKey(),
+        //                        entry.getValue())));
+        resultStr.append("\n");
+    }
+}
diff --git a/pbj-integration-tests/src/main/java/com/hedera/pbj/integration/hashing/NonCryptographicHashQualityStateKeyTest.java b/pbj-integration-tests/src/main/java/com/hedera/pbj/integration/hashing/NonCryptographicHashQualityStateKeyTest.java
new file mode 100644
index 00000000..348289a1
--- /dev/null
+++ b/pbj-integration-tests/src/main/java/com/hedera/pbj/integration/hashing/NonCryptographicHashQualityStateKeyTest.java
@@ -0,0 +1,239 @@
+// SPDX-License-Identifier: Apache-2.0
+package com.hedera.pbj.integration.hashing;
+
+import com.hedera.hapi.node.base.AccountID;
+import com.hedera.hapi.node.base.NftID;
+import com.hedera.hapi.node.base.TokenID;
+import com.hedera.hapi.node.state.common.EntityIDPair;
+import com.hedera.pbj.runtime.hashing.XXH3_64;
+import com.hedera.pbj.runtime.io.buffer.BufferedData;
+import com.hedera.pbj.test.proto.java.teststate.pbj.integration.tests.StateKey;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.channels.FileChannel;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
+import java.time.ZoneOffset;
+import java.time.ZonedDateTime;
+import java.time.format.DateTimeFormatter;
+import java.util.Arrays;
+import java.util.Map;
+import java.util.Random;
+import java.util.stream.Collectors;
+
+/**
+ * A test to evaluate the quality of non-cryptographic hash functions by checking how many unique hashes can be
+ * generated from 4.5 billion StateKey inputs.
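+ * <p>With n = 4.5 billion inputs into a 32-bit hash space, even an ideal hash is only expected to produce roughly
+ * 2^32 * (1 - e^(-n/2^32)) ≈ 2.8 billion distinct values (a rough Poisson estimate), so some collisions are
+ * inevitable; the interesting signal is how close the observed counts come to that ideal.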
+ */
+public final class NonCryptographicHashQualityStateKeyTest {
+    private static final int NUM_BUCKETS = 33_554_432; // 2^25, about 33 million buckets
+    // Where to place result files
+    private static final Path OUTPUT_ROOT = Path.of("hash_quality_results");
+
+    public static void main(String[] args) throws Exception {
+        final Path outputDir = createOutputDirectory();
+        System.out.println("Testing non-cryptographic hash quality - Random StateKeys, 4.5 billion inputs");
+        final CountingArray counts = new CountingArray(); // one count slot per 32-bit hash value
+        final long START_TIME = System.currentTimeMillis();
+        final long NUM_INPUTS = 4_500_000_000L; // 4.5 billion inputs
+        //        final long NUM_INPUTS = 50_000_000L; // smaller input count for quick runs
+        final byte[] bufferArray = new byte[1024];
+        final BufferedData bufferedData = BufferedData.wrap(bufferArray);
+        final int[] bucketCounts = new int[NUM_BUCKETS];
+        final Random random = new Random(2518643515415654L); // Seed for reproducibility
+        long lengthSum = 0;
+        long minLength = Integer.MAX_VALUE;
+        long maxLength = Integer.MIN_VALUE;
+
+        for (long i = 0; i < NUM_INPUTS; i++) {
+            if (i % 10_000_000 == 0) {
+                long averageLength = lengthSum / (i + 1);
+                System.out.printf(
+                        "\r Progress: %.2f%% Length: avg=%,d, min=%,d, max=%,d",
+                        (i * 100.0) / NUM_INPUTS, averageLength, minLength, maxLength);
+                System.out.flush();
+            }
+            // create a sample StateKey that will be hashed
+            StateKey stateKey =
+                    switch (random.nextInt(4)) {
+                        case 0 ->
+                            StateKey.newBuilder()
+                                    .accountId(AccountID.newBuilder().accountNum(i))
+                                    .build();
+                        case 1 ->
+                            StateKey.newBuilder()
+                                    .tokenId(TokenID.newBuilder().tokenNum(i))
+                                    .build();
+                        case 2 ->
+                            StateKey.newBuilder()
+                                    .entityIdPair(EntityIDPair.newBuilder()
+                                            .accountId(AccountID.newBuilder().accountNum(i))
+                                            .tokenId(TokenID.newBuilder().tokenNum(i)))
+                                    .build();
+                        case 3 ->
+                            StateKey.newBuilder()
+                                    .nftId(NftID.newBuilder()
+                                            .tokenId(TokenID.newBuilder().tokenNum(i))
+                                            .serialNumber(random.nextLong(1_000_000)))
+                                    .build();
+                        default -> throw new IllegalStateException("Unexpected random key choice");
+                    };
+            bufferedData.position(0);
+            StateKey.PROTOBUF.write(stateKey, bufferedData);
+            int lengthWritten = (int) bufferedData.position();
+            lengthSum += lengthWritten;
+            if (lengthWritten < minLength) {
+                minLength = lengthWritten;
+            }
+            if (lengthWritten > maxLength) {
+                maxLength = lengthWritten;
+            }
+
+            final int hash32 = (int) XXH3_64.DEFAULT_INSTANCE.hashBytesToLong(bufferArray, 0, lengthWritten);
+            counts.increment(Integer.toUnsignedLong(hash32));
+            final int bucket = computeBucketIndex(hash32);
+            bucketCounts[bucket]++;
+        }
+
+        long numUniqueHashes = counts.numberOfGreaterThanZeroCounts();
+        long hashCollisions = counts.numberOfGreaterThanOneCounts();
+        double collisionRate = (double) hashCollisions / NUM_INPUTS * 100;
+        final long END_TIME = System.currentTimeMillis();
+        StringBuilder resultStr = new StringBuilder(String.format(
+                "%n%s => Number of unique hashes: %,d, hash collisions: %,d, collision rate: %.2f%% time taken: %.3f seconds%n",
+                "XXH3_64",
+                numUniqueHashes,
+                hashCollisions,
+                collisionRate,
+                (END_TIME - START_TIME) / 1000.0));
+        counts.printStats(resultStr);
+        // convert the bucketCounts into the number of buckets falling into each occupancy range
+        Map<String, Integer> bucketDistribution = Arrays.stream(bucketCounts)
+                .mapToObj(count -> {
+                    if (count == 0) {
+                        return "0";
+                    } else if (count <= 10) {
+                        return "1->10";
+                    } else if (count <= 100) {
+                        return "11->100";
+                    } else if (count <= 1000) {
+                        return "101->1,000";
+                    } else if (count <= 10000) {
+                        return "1,001->10,000";
+                    } else if (count <= 100_000) {
+                        return "10,001->100,000";
+                    } else if (count <= 250_000) {
+                        return "100,001->250,000";
+                    } else if (count <= 500_000) {
+                        return "250,001->500,000";
+                    } else {
+                        return "500,000+";
+                    }
+                })
+                .collect(Collectors.toMap(count -> count, count -> 1, Integer::sum));
+        resultStr.append(" Bucket distribution: ");
+        bucketDistribution.forEach((category, count) -> {
+            resultStr.append(String.format(" %s=%,d", category, count));
+        });
+        resultStr.append("\n");
+        // print the results
+        System.out.print(resultStr);
+        System.out.flush();
+
+        // Export detailed per-bucket counts for plotting
+        exportBucketCounts(outputDir, "XXH3_64", bucketCounts, NUM_INPUTS, NUM_BUCKETS);
+    }
+
+    /**
+     * Code direct from HalfDiskHashMap, only change is NUM_BUCKETS.
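+     * The mask (NUM_BUCKETS - 1) keeps the low 25 bits of the hash, for example
+     * 0xDEADBEEF & 0x1FFFFFF == 0xADBEEF.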
+     *
+     * Computes which bucket a key with the given hash falls into. Depends on the fact that the number of buckets
+     * is a power of two. Based on the same calculation that is used in the Java HashMap.
+     *
+     * @param keyHash the int hash for key
+     * @return the index of the bucket that key falls in
+     */
+    private static int computeBucketIndex(final int keyHash) {
+        return (NUM_BUCKETS - 1) & keyHash;
+    }
+
+    /**
+     * Creates a timestamped output directory like:
+     * hash_quality_results/run_YYYYMMDD_HHMMSSZ
+     */
+    private static Path createOutputDirectory() throws IOException {
+        final String ts = DateTimeFormatter.ofPattern("yyyyMMdd_HHmmssX").format(ZonedDateTime.now(ZoneOffset.UTC));
+        final Path dir = OUTPUT_ROOT.resolve("run_" + ts);
+        Files.createDirectories(dir);
+        return dir;
+    }
+
+    /**
+     * Exports the per-bucket counts in a compact binary format and writes a sidecar JSON metadata file.
+     *
+     * Format:
+     * - Data file: &lt;algorithm&gt;_counts_i32_le.bin (little-endian 32-bit signed ints), length == numBuckets.
+     * - Metadata: &lt;algorithm&gt;.meta.json
+     */
+    private static void exportBucketCounts(
+            final Path outputDir,
+            final String algorithmName,
+            final int[] bucketCounts,
+            final long numInputs,
+            final int numBuckets)
+            throws IOException {
+        final String safeAlg = algorithmName.replaceAll("[^A-Za-z0-9_.-]", "_");
+        final Path dataFile = outputDir.resolve(safeAlg + "_counts_i32_le.bin");
+        final Path metaFile = outputDir.resolve(safeAlg + ".meta.json");
+
+        // Write binary counts in little-endian in chunks to avoid large buffers
+        final int chunkSize = 1_048_576; // 1M ints (~4 MiB)
+        try (FileChannel ch = FileChannel.open(
+                dataFile, StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.WRITE)) {
+            final ByteBuffer buf =
+                    ByteBuffer.allocateDirect(chunkSize * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN);
+            int written = 0;
+            while (written < numBuckets) {
+                buf.clear();
+                final int end = Math.min(written + chunkSize, numBuckets);
+                for (int i = written; i < end; i++) {
+                    buf.putInt(bucketCounts[i]);
+                }
+                buf.flip();
+                while (buf.hasRemaining()) {
+                    ch.write(buf);
+                }
+                written = end;
+            }
+            ch.force(true);
+        }
+
+        // Metadata JSON
+        final double lambda = (double) numInputs / (double) numBuckets;
+        final String metaJson = "{\n" + "  \"algorithm\": \""
+                + escapeJson(algorithmName) + "\",\n" + "  \"numBuckets\": "
+                + numBuckets + ",\n" + "  \"numInputs\": "
+                + numInputs + ",\n" + "  \"hashBits\": 32,\n"
+                + "  \"bucketIndexFormula\": \"(NUM_BUCKETS - 1) & hash\",\n"
+                + "  \"countsFile\": \""
+                + escapeJson(dataFile.getFileName().toString()) + "\",\n" + "  \"countsDtype\": \"int32\",\n"
+                + "  \"endianness\": \"little\",\n"
+                + "  \"expectedMeanPerBucket\": "
+                + String.format("%.6f", lambda) + "\n" + "}\n";
+        Files.writeString(
+                metaFile,
+                metaJson,
+                StandardCharsets.UTF_8,
+                StandardOpenOption.CREATE,
+                StandardOpenOption.TRUNCATE_EXISTING,
+                StandardOpenOption.WRITE);
+    }
+
+    private static String escapeJson(String s) {
+        return s.replace("\\", "\\\\").replace("\"", "\\\"");
+    }
+}
diff --git a/pbj-integration-tests/src/main/java/com/hedera/pbj/integration/hashing/PbjObjHashQualityStateKeyTest.java b/pbj-integration-tests/src/main/java/com/hedera/pbj/integration/hashing/PbjObjHashQualityStateKeyTest.java
new file mode 100644
index 00000000..8ed672fd
--- /dev/null
+++ b/pbj-integration-tests/src/main/java/com/hedera/pbj/integration/hashing/PbjObjHashQualityStateKeyTest.java
@@ -0,0 +1,229 @@
+// SPDX-License-Identifier: Apache-2.0
+package com.hedera.pbj.integration.hashing;
+
+import com.hedera.hapi.node.base.AccountID;
+import com.hedera.hapi.node.base.NftID;
+import com.hedera.hapi.node.base.TokenID;
+import com.hedera.hapi.node.state.common.EntityIDPair;
+import com.hedera.pbj.test.proto.java.teststate.pbj.integration.tests.StateKey;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.channels.FileChannel;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
+import java.time.ZoneOffset;
+import java.time.ZonedDateTime;
+import java.time.format.DateTimeFormatter;
+import java.util.Arrays;
+import java.util.Map;
+import java.util.Random;
+import java.util.stream.Collectors;
+
+/**
+ * A test to evaluate the quality of the PBJ-generated model hashCode() as a non-cryptographic hash by checking how
+ * many unique hashes can be generated from 4.5 billion StateKey inputs.
+ */
+public final class PbjObjHashQualityStateKeyTest {
+    private static final int NUM_BUCKETS = 33_554_432; // 2^25, about 33 million buckets
+    // Where to place result files
+    private static final Path OUTPUT_ROOT = Path.of("hash_quality_results");
+
+    public static void main(String[] args) throws Exception {
+        final Path outputDir = createOutputDirectory();
+        System.out.println("Testing non-cryptographic hash quality - Random StateKeys, 4.5 billion inputs");
+        final CountingArray counts = new CountingArray(); // one count slot per 32-bit hash value
+        final long START_TIME = System.currentTimeMillis();
+        final long NUM_INPUTS = 4_500_000_000L; // 4.5 billion inputs
+        //        final long NUM_INPUTS = 50_000_000L; // smaller input count for quick runs
+        final int[] bucketCounts = new int[NUM_BUCKETS];
+        final Random random = new Random(2518643515415654L); // Seed for reproducibility
+
+        for (long i = 0; i < NUM_INPUTS; i++) {
+            if (i % 10_000_000 == 0) {
+                System.out.printf("\r Progress: %.2f%%", (i * 100.0) / NUM_INPUTS);
+                System.out.flush();
+            }
+            // create a sample StateKey that will be hashed
+            StateKey stateKey =
+                    switch (random.nextInt(4)) {
+                        case 0 ->
+                            StateKey.newBuilder()
+                                    .accountId(AccountID.newBuilder().accountNum(i))
+                                    .build();
+                        case 1 ->
+                            StateKey.newBuilder()
+                                    .tokenId(TokenID.newBuilder().tokenNum(i))
+                                    .build();
+                        case 2 ->
+                            StateKey.newBuilder()
+                                    .entityIdPair(EntityIDPair.newBuilder()
+                                            .accountId(AccountID.newBuilder().accountNum(i))
+                                            .tokenId(TokenID.newBuilder().tokenNum(i)))
+                                    .build();
+                        case 3 ->
+                            StateKey.newBuilder()
+                                    .nftId(NftID.newBuilder()
+                                            .tokenId(TokenID.newBuilder().tokenNum(i))
+                                            .serialNumber(random.nextLong(1_000_000)))
+                                    .build();
+                        default -> throw new IllegalStateException("Unexpected random key choice");
+                    };
+            // hash the PBJ model object directly rather than its serialized bytes
+            final int hash32 = stateKey.hashCode();
+            counts.increment(Integer.toUnsignedLong(hash32));
+            final int bucket = computeBucketIndex(hash32);
+            bucketCounts[bucket]++;
+        }
+
+        long numUniqueHashes = counts.numberOfGreaterThanZeroCounts();
+        long hashCollisions = counts.numberOfGreaterThanOneCounts();
+        double collisionRate = (double) hashCollisions / NUM_INPUTS * 100;
+        final long END_TIME = System.currentTimeMillis();
+        StringBuilder resultStr = new StringBuilder(String.format(
+                "%n%s => Number of unique hashes: %,d, hash collisions: %,d, collision rate: %.2f%% time taken: %.3f seconds%n",
+                "PBJ_XXH3",
+                numUniqueHashes,
+                hashCollisions,
+                collisionRate,
+                (END_TIME - START_TIME) / 1000.0));
+        counts.printStats(resultStr);
+        // convert the bucketCounts into the number of buckets falling into each occupancy range
+        Map<String, Integer> bucketDistribution = Arrays.stream(bucketCounts)
+                .mapToObj(count -> {
+                    if (count == 0) {
+                        return "0";
+                    } else if (count <= 10) {
+                        return "1->10";
+                    } else if (count <= 100) {
+                        return "11->100";
+                    } else if (count <= 1000) {
+                        return "101->1,000";
+                    } else if (count <= 10000) {
+                        return "1,001->10,000";
+                    } else if (count <= 100_000) {
+                        return "10,001->100,000";
+                    } else if (count <= 250_000) {
+                        return "100,001->250,000";
+                    } else if (count <= 500_000) {
+                        return "250,001->500,000";
+                    } else {
+                        return "500,000+";
+                    }
+                })
+                .collect(Collectors.toMap(count -> count, count -> 1, Integer::sum));
+        resultStr.append(" Bucket distribution: ");
+        bucketDistribution.forEach((category, count) -> {
+            resultStr.append(String.format(" %s=%,d", category, count));
+        });
+        resultStr.append("\n");
+        // print the results
+        System.out.print(resultStr);
+        System.out.flush();
+
+        // Export detailed per-bucket counts for plotting
+        exportBucketCounts(outputDir, "PBJ_XXH3", bucketCounts, NUM_INPUTS, NUM_BUCKETS);
+    }
+
+    /**
+     * Code direct from HalfDiskHashMap, only change is NUM_BUCKETS.
+     *
+     * Computes which bucket a key with the given hash falls into. Depends on the fact that the number of buckets
+     * is a power of two. Based on the same calculation that is used in the Java HashMap.
+     *
+     * @param keyHash the int hash for key
+     * @return the index of the bucket that key falls in
+     */
+    private static int computeBucketIndex(final int keyHash) {
+        return (NUM_BUCKETS - 1) & keyHash;
+    }
+
+    /**
+     * Creates a timestamped output directory like:
+     * hash_quality_results/run_YYYYMMDD_HHMMSSZ
+     */
+    private static Path createOutputDirectory() throws IOException {
+        final String ts = DateTimeFormatter.ofPattern("yyyyMMdd_HHmmssX").format(ZonedDateTime.now(ZoneOffset.UTC));
+        final Path dir = OUTPUT_ROOT.resolve("run_" + ts);
+        Files.createDirectories(dir);
+        return dir;
+    }
+
+    /**
+     * Exports the per-bucket counts in a compact binary format and writes a sidecar JSON metadata file.
+     *
+     * Format:
+     * - Data file: &lt;algorithm&gt;_counts_i32_le.bin (little-endian 32-bit signed ints), length == numBuckets.
+     * - Metadata: &lt;algorithm&gt;.meta.json
+     */
+    private static void exportBucketCounts(
+            final Path outputDir,
+            final String algorithmName,
+            final int[] bucketCounts,
+            final long numInputs,
+            final int numBuckets)
+            throws IOException {
+        final String safeAlg = algorithmName.replaceAll("[^A-Za-z0-9_.-]", "_");
+        final Path dataFile = outputDir.resolve(safeAlg + "_counts_i32_le.bin");
+        final Path metaFile = outputDir.resolve(safeAlg + ".meta.json");
+
+        // Write binary counts in little-endian in chunks to avoid large buffers
+        final int chunkSize = 1_048_576; // 1M ints (~4 MiB)
+        try (FileChannel ch = FileChannel.open(
+                dataFile, StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.WRITE)) {
+            final ByteBuffer buf =
+                    ByteBuffer.allocateDirect(chunkSize * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN);
+            int written = 0;
+            while (written < numBuckets) {
+                buf.clear();
+                final int end = Math.min(written + chunkSize, numBuckets);
+                for (int i = written; i < end; i++) {
+                    buf.putInt(bucketCounts[i]);
+                }
+                buf.flip();
+                while (buf.hasRemaining()) {
+                    ch.write(buf);
+                }
+                written = end;
+            }
+            ch.force(true);
+        }
+
+        // Metadata JSON
+        final double lambda = (double) numInputs / (double) numBuckets;
+        final String metaJson = "{\n" + "  \"algorithm\": \""
+                + escapeJson(algorithmName) + "\",\n" + "  \"numBuckets\": "
+                + numBuckets + ",\n" + "  \"numInputs\": "
+                + numInputs + ",\n" + "  \"hashBits\": 32,\n"
+                + "  \"bucketIndexFormula\": \"(NUM_BUCKETS - 1) & hash\",\n"
+                + "  \"countsFile\": \""
+                + escapeJson(dataFile.getFileName().toString()) + "\",\n" + "  \"countsDtype\": \"int32\",\n"
+                + "  \"endianness\": \"little\",\n"
+                + "  \"expectedMeanPerBucket\": "
+                + String.format("%.6f", lambda) + "\n" + "}\n";
+        Files.writeString(
+                metaFile,
+                metaJson,
+                StandardCharsets.UTF_8,
+                StandardOpenOption.CREATE,
+                StandardOpenOption.TRUNCATE_EXISTING,
+                StandardOpenOption.WRITE);
+    }
+
+    private static String escapeJson(String s) {
+        return s.replace("\\", "\\\\").replace("\"", "\\\"");
+    }
+}
diff --git a/pbj-integration-tests/src/main/java/com/hedera/pbj/integration/hashing/scripts_plot_hash_bucket_histograms_Version3.py b/pbj-integration-tests/src/main/java/com/hedera/pbj/integration/hashing/scripts_plot_hash_bucket_histograms_Version3.py
new file mode 100644
index 00000000..3f0146d0
--- /dev/null
+++ b/pbj-integration-tests/src/main/java/com/hedera/pbj/integration/hashing/scripts_plot_hash_bucket_histograms_Version3.py
@@ -0,0 +1,230 @@
+#!/usr/bin/env python3
+"""
+Reads per-algorithm per-bucket counts exported by NonCryptographicHashQualityStateKeyTest and
+PbjObjHashQualityStateKeyTest and plots bucket-occupancy histograms suitable for comparing hash quality.
+
+Input format (per algorithm):
+  - <ALGO>.meta.json         : metadata with fields {algorithm, numBuckets, numInputs, countsFile, countsDtype, endianness}
+  - <ALGO>_counts_i32_le.bin : little-endian int32 array of length numBuckets with counts per bucket
+
+Usage:
+  python scripts/plot_hash_bucket_histograms.py /path/to/hash_quality_results/run_YYYYMMDD_HHMMSSZ
+      [--max-k 400] [--overlay] [--logy]
+      [--dpi 300] [--figsize 12x7] [--format svg] [--transparent] [--tight]
+
+Outputs:
+  - One image per algorithm: hist_<algorithm>.<format>
+  - If --overlay: a combined overlay image: hist_overlay.<format>
+"""
+import argparse
+import glob
+import json
+import math
+from pathlib import Path
+from typing import Dict, Any, List, Tuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def load_algorithm(meta_path: Path) -> Tuple[Dict[str, Any], np.ndarray]:
+    with open(meta_path, "r", encoding="utf-8") as f:
+        meta = json.load(f)
+    counts_file = meta_path.parent / meta["countsFile"]
+    # honor the declared byte order explicitly rather than relying on native order
+    if str(meta.get("endianness", "little")).lower().startswith("little"):
+        dtype = np.dtype("<i4")
+    else:
+        dtype = np.dtype(">i4")
+    counts = np.fromfile(counts_file, dtype=dtype)
+    if counts.size != int(meta["numBuckets"]):
+        raise ValueError(f"Counts size {counts.size} != numBuckets {meta['numBuckets']} for {meta_path}")
+    return meta, counts
+
+
+def poisson_expected_counts(max_k: int, lam: float, num_buckets: int) -> np.ndarray:
+    """
+    Compute expected number of buckets with exactly k items for k=0..max_k under Poisson(lam).
+    Uses stable recurrence: P(k+1) = P(k) * lam / (k+1)
+    """
+    exp_counts = np.zeros(max_k + 1, dtype=np.float64)
+    p = math.exp(-lam)  # P(0)
+    exp_counts[0] = p * num_buckets
+    for k in range(0, max_k):
+        p = p * lam / (k + 1)
+        exp_counts[k + 1] = p * num_buckets
+    return exp_counts
+
+
+def compute_hist(counts: np.ndarray, max_k: int = None) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Returns (k_values, num_buckets_with_k) for k in [0..max_k]
+    """
+    hist = np.bincount(counts.astype(np.int64))
+    if max_k is None:
+        max_k = len(hist) - 1
+    else:
+        max_k = min(max_k, len(hist) - 1)
+    k = np.arange(0, max_k + 1, dtype=np.int64)
+    y = hist[: (max_k + 1)]
+    return k, y
+
+
+def plot_per_algorithm(
+    meta: Dict[str, Any],
+    k: np.ndarray,
+    y: np.ndarray,
+    out_dir: Path,
+    show_poisson: bool = True,
+    logy: bool = False,
+    figsize: Tuple[float, float] = (10.0, 6.0),
+    dpi: int = 300,
+    fmt: str = "png",
+    transparent: bool = False,
+    tight: bool = False,
+):
+    alg = meta["algorithm"]
+    num_buckets = int(meta["numBuckets"])
+    num_inputs = int(meta["numInputs"])
+    lam = num_inputs / num_buckets
+
+    fig, ax = plt.subplots(figsize=figsize)
+    ax.bar(k, y, width=1.0, color="#4e79a7", alpha=0.7, label=f"Observed ({alg})", edgecolor="none")
+
+    if show_poisson:
+        y_exp = poisson_expected_counts(k.max(), lam, num_buckets)
+        ax.plot(k, y_exp, color="#e15759", linewidth=2.0, label=f"Poisson λ={lam:.2f}")
+
+    ax.set_title(f"Bucket occupancy histogram — {alg}\n(numInputs={num_inputs:,}, numBuckets={num_buckets:,}, λ≈{lam:.2f})")
+    ax.set_xlabel("Items per bucket (k)")
+    ax.set_ylabel("Number of buckets with exactly k items")
+    if logy:
+        ax.set_yscale("log")
+        ax.set_ylabel("Number of buckets (log scale)")
+    ax.grid(True, which="both", axis="y", linestyle=":", alpha=0.5)
+    ax.legend()
+    if tight:
+        fig.tight_layout()
+
+    out_path = out_dir / f"hist_{sanitize_filename(alg)}.{fmt}"
+    fig.savefig(out_path, dpi=dpi, format=fmt,
+                transparent=transparent, bbox_inches="tight" if tight else None)
+    plt.close(fig)
+
+
+def plot_overlay(
+    alg_results: List[Tuple[Dict[str, Any], np.ndarray, np.ndarray]],
+    out_dir: Path,
+    normalize: bool = True,
+    logy: bool = False,
+    figsize: Tuple[float, float] = (11.0, 7.0),
+    dpi: int = 300,
+    fmt: str = "png",
+    transparent: bool = False,
+    tight: bool = False,
+):
+    """
+    Overlays histograms as lines for quick comparison.
+    If normalize=True, y is fraction of buckets instead of absolute count.
+    """
+    fig, ax = plt.subplots(figsize=figsize)
+    for meta, k, y in alg_results:
+        label = meta["algorithm"]
+        if normalize:
+            y_plot = y / y.sum() if y.sum() > 0 else y
+            ax.set_ylabel("Fraction of buckets with exactly k items")
+        else:
+            y_plot = y
+            ax.set_ylabel("Number of buckets with exactly k items")
+        ax.plot(k, y_plot, linewidth=1.8, label=label)
+    ax.set_xlabel("Items per bucket (k)")
+    if logy:
+        ax.set_yscale("log")
+    ax.set_title("Bucket occupancy histograms — overlay")
+    ax.grid(True, which="both", axis="y", linestyle=":", alpha=0.5)
+    ax.legend()
+    if tight:
+        fig.tight_layout()
+    out_path = out_dir / f"hist_overlay.{fmt}"
+    fig.savefig(out_path, dpi=dpi, format=fmt, transparent=transparent, bbox_inches="tight" if tight else None)
+    plt.close(fig)
+
+
+def sanitize_filename(s: str) -> str:
+    return "".join(c if c.isalnum() or c in "._-" else "_" for c in s)
+
+
+def parse_figsize(s: str) -> Tuple[float, float]:
+    try:
+        w, h = s.lower().replace(" ", "").split("x", 1)
+        return float(w), float(h)
+    except Exception:
+        raise argparse.ArgumentTypeError("figsize must be in the form WIDTHxHEIGHT, e.g., 12x7")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("results_dir", type=str, help="Path to run directory (hash_quality_results/run_YYYYMMDD_HHMMSSZ)")
+    parser.add_argument("--max-k", type=int, default=None, help="Maximum k to plot (default: auto up to max observed)")
+    parser.add_argument("--overlay", action="store_true", help="Also produce a combined overlay plot")
+    parser.add_argument("--logy", action="store_true", help="Use logarithmic y-axis")
+
+    # Export/quality options
+    parser.add_argument("--dpi", type=int, default=300, help="Output DPI for raster formats (PNG/JPG). Default: 300")
+    parser.add_argument("--figsize", type=parse_figsize, default=(10.0, 6.0),
+                        help="Figure size in inches as WxH, e.g., 12x7. Default: 10x6")
+    parser.add_argument("--overlay-figsize", type=parse_figsize, default=(11.0, 7.0),
+                        help="Overlay figure size in inches as WxH. Default: 11x7")
+    parser.add_argument("--format", type=str, default="png",
+                        choices=["png", "svg", "pdf", "jpg", "jpeg"],
+                        help="Output image format. For scalable vector output, use svg or pdf. Default: png")
+    parser.add_argument("--transparent", action="store_true", help="Save with transparent background")
+    parser.add_argument("--tight", action="store_true", help="Use tight layout and bbox_inches='tight'")
+    args = parser.parse_args()
+
+    run_dir = Path(args.results_dir)
+    if not run_dir.exists():
+        raise SystemExit(f"Directory not found: {run_dir}")
+
+    meta_files = sorted(glob.glob(str(run_dir / "*.meta.json")))
+    if not meta_files:
+        raise SystemExit(f"No *.meta.json files found in {run_dir}")
+
+    # Create an output subdir for plots
+    out_dir = run_dir / "plots"
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    overlay_data: List[Tuple[Dict[str, Any], np.ndarray, np.ndarray]] = []
+
+    for meta_path_str in meta_files:
+        meta_path = Path(meta_path_str)
+        meta, counts = load_algorithm(meta_path)
+        k, y = compute_hist(counts, max_k=args.max_k)
+        plot_per_algorithm(
+            meta, k, y, out_dir,
+            show_poisson=True, logy=args.logy,
+            figsize=args.figsize, dpi=args.dpi, fmt=args.format,
+            transparent=args.transparent, tight=args.tight
+        )
+        overlay_data.append((meta, k, y))
+
+    if args.overlay:
+        # Align k-range across algorithms to the minimum common max_k
+        min_max_k = min(int(k[-1]) for _, k, _ in overlay_data)
+        aligned = []
+        for meta, k, y in overlay_data:
+            if int(k[-1]) > min_max_k:
+                aligned.append((meta, k[: min_max_k + 1], y[: min_max_k + 1]))
+            else:
+                aligned.append((meta, k, y))
+        plot_overlay(
+            aligned, out_dir,
+            normalize=True, logy=args.logy,
+            figsize=args.overlay_figsize, dpi=args.dpi, fmt=args.format,
+            transparent=args.transparent, tight=args.tight
+        )
+
+    print(f"Done. Plots written to: {out_dir}")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/pbj-integration-tests/src/main/proto/hasheval.proto b/pbj-integration-tests/src/main/proto/hasheval.proto
index b819a142..008c08b5 100644
--- a/pbj-integration-tests/src/main/proto/hasheval.proto
+++ b/pbj-integration-tests/src/main/proto/hasheval.proto
@@ -33,3 +33,24 @@ message Hasheval {
     string text = 16;
     bytes bytesField = 17;
 }
+
+
+message Hasheval2 {
+    int32 int32Number = 1;
+    sint32 sint32Number = 2;
+    uint32 uint32Number = 3;
+    fixed32 fixed32Number = 4;
+    sfixed32 sfixed32Number = 5;
+    float floatNumber = 6;
+    int64 int64Number = 7;
+    sint64 sint64Number = 8;
+    uint64 uint64Number = 9;
+    fixed64 fixed64Number = 10;
+    sfixed64 sfixed64Number = 11;
+    double doubleNumber = 12;
+    bool booleanField = 13;
+    Suit enumSuit = 14;
+    Hasheval subObject = 15;
+    string text = 16;
+    bytes bytesField = 17;
+}
\ No newline at end of file
diff --git a/pbj-integration-tests/src/main/proto/teststate.proto b/pbj-integration-tests/src/main/proto/teststate.proto
new file mode 100644
index 00000000..4ab79e29
--- /dev/null
+++ b/pbj-integration-tests/src/main/proto/teststate.proto
@@ -0,0 +1,20 @@
+// SPDX-License-Identifier: Apache-2.0
+syntax = "proto3";
+
+package proto;
+
+import "basic_types.proto";
+import "state/common.proto";
+
+option java_package = "com.hedera.pbj.test.proto.java.teststate";
+option java_multiple_files = true;
+// <<>> This comment is special code for setting PBJ Compiler java package
+
+message StateKey {
+    oneof key {
+        AccountID account_id = 1;
+        TokenID token_id = 2;
+        EntityIDPair entity_id_pair = 3;
+        NftID nft_id = 4;
+    }
+}