diff --git a/util/BUILD b/util/BUILD index 16f9316..1c245de 100644 --- a/util/BUILD +++ b/util/BUILD @@ -217,7 +217,9 @@ cc_test( "testdata/codepoints_with_freq_invalid.txt", "testdata/test_freq_data.riegeli", "testdata/invalid_test_freq_data.riegeli", - ], + ] + glob([ + "testdata/sharded/*", + ]), deps = [ ":load_codepoints", "@googletest//:gtest_main", diff --git a/util/freq_data_to_sorted_codepoints.cc b/util/freq_data_to_sorted_codepoints.cc index aae9ce1..dded655 100644 --- a/util/freq_data_to_sorted_codepoints.cc +++ b/util/freq_data_to_sorted_codepoints.cc @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -40,7 +41,11 @@ int main(int argc, char** argv) { if (args.size() != 2) { std::cerr << "Usage:" << std::endl - << "freq_data_to_sorted_codepoints " << std::endl; + << "freq_data_to_sorted_codepoints " << std::endl + << std::endl + << "Append @* to the file name to load sharded data files. " + << "For example \"@*\" will load all files of the form -?????-of-?????" + << std::endl; return -1; } diff --git a/util/generate_riegeli_test_data.cc b/util/generate_riegeli_test_data.cc index ee4666e..b30560c 100644 --- a/util/generate_riegeli_test_data.cc +++ b/util/generate_riegeli_test_data.cc @@ -6,13 +6,17 @@ #include "absl/flags/flag.h" #include "absl/flags/parse.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "riegeli/bytes/fd_writer.h" #include "riegeli/records/record_writer.h" #include "util/unicode_count.pb.h" +using absl::StrCat; + ABSL_FLAG(std::string, output_path, "", "Path to write the output file."); ABSL_FLAG(bool, include_invalid_record, false, "If set add an invalid record."); +ABSL_FLAG(bool, shard, false, "If set shard into multiple files."); namespace { @@ -41,24 +45,44 @@ absl::Status Main() { message4.add_codepoints(0x44); message4.set_count(75); - riegeli::RecordWriter writer{riegeli::FdWriter(output_path)}; - writer.WriteRecord(message1); - writer.WriteRecord(message2); - writer.WriteRecord(message3); - writer.WriteRecord(message4); - - if (absl::GetFlag(FLAGS_include_invalid_record)) { - CodepointCount message5; - message5.add_codepoints(0x46); - message5.add_codepoints(0x46); - message5.add_codepoints(0x46); - message5.set_count(75); - writer.WriteRecord(message5); - } + if (!absl::GetFlag(FLAGS_shard)) { + riegeli::RecordWriter writer{riegeli::FdWriter(output_path)}; + writer.WriteRecord(message1); + writer.WriteRecord(message2); + writer.WriteRecord(message3); + writer.WriteRecord(message4); - if (!writer.Close()) { - return writer.status(); + if (absl::GetFlag(FLAGS_include_invalid_record)) { + CodepointCount message5; + message5.add_codepoints(0x46); + message5.add_codepoints(0x46); + message5.add_codepoints(0x46); + message5.set_count(75); + writer.WriteRecord(message5); + } + + if (!writer.Close()) { + return writer.status(); + } + } else { + { + riegeli::RecordWriter writer{riegeli::FdWriter(StrCat(output_path, "-00000-of-00003"))}; + writer.WriteRecord(message1); + writer.WriteRecord(message2); + writer.Close(); + } + { + riegeli::RecordWriter writer{riegeli::FdWriter(StrCat(output_path, "-00001-of-00003"))}; + writer.WriteRecord(message3); + writer.Close(); + } + { + riegeli::RecordWriter writer{riegeli::FdWriter(StrCat(output_path, "-00002-of-00003"))}; + writer.WriteRecord(message4); + writer.Close(); + } } + return absl::OkStatus(); } diff --git a/util/load_codepoints.cc b/util/load_codepoints.cc index f79d725..8e1f826 100644 --- a/util/load_codepoints.cc +++ b/util/load_codepoints.cc @@ -1,8 +1,11 @@ #include "load_codepoints.h" +#include +#include #include #include #include +#include #include #include "absl/strings/str_cat.h" @@ -14,6 +17,7 @@ #include "riegeli/records/record_reader.h" #include "util/unicode_count.pb.h" +using absl::Status; using absl::StatusOr; using absl::StrCat; using absl::string_view; @@ -122,8 +126,53 @@ StatusOr> LoadCodepointsOrdered( return out; } -StatusOr LoadFrequenciesFromRiegeli(const char* path) { - UnicodeFrequencies frequencies; +StatusOr> ExpandShardedPath(const char* path) { + std::string full_path(path); + + if (!full_path.ends_with("@*")) { + if (!std::filesystem::exists(full_path)) { + return absl::NotFoundError(StrCat("Path does not exist: ", full_path)); + } + return std::vector{full_path}; + } + + std::filesystem::path file_path = full_path.substr(0, full_path.size() - 2); + std::string base_name = file_path.filename(); + std::filesystem::path directory = file_path.parent_path(); + + // Find the list of files matching the pattern: + // -?????-of-????? + std::regex file_pattern("^.*-[0-9]{5}-of-[0-9]{5}$"); + + if (!std::filesystem::exists(directory) || + !std::filesystem::is_directory(directory)) { + return absl::NotFoundError(StrCat( + "Path does not exist or is not a directory: ", directory.string())); + } + + // Collect into a set to ensure the output is sorted. + absl::btree_set files; + for (const auto& entry : std::filesystem::directory_iterator(directory)) { + std::string name = entry.path().filename(); + if (!name.starts_with(base_name)) { + continue; + } + + if (std::regex_match(name, file_pattern)) { + files.insert(entry.path()); + } + } + + if (files.empty()) { + return absl::NotFoundError(StrCat("No files matched the shard pattern: ", full_path)); + } + + return std::vector(files.begin(), files.end()); +} + +static Status LoadFrequenciesFromRiegeliIndividual( + const char* path, UnicodeFrequencies& frequencies +) { riegeli::RecordReader reader{riegeli::FdReader(path)}; if (!reader.ok()) { return absl::InvalidArgumentError( @@ -144,6 +193,15 @@ StatusOr LoadFrequenciesFromRiegeli(const char* path) { if (!reader.Close()) { return absl::InternalError(reader.status().message()); } + return absl::OkStatus(); +} + +StatusOr LoadFrequenciesFromRiegeli(const char* path) { + auto paths = TRY(ExpandShardedPath(path)); + UnicodeFrequencies frequencies; + for (const auto& path : paths) { + TRYV(LoadFrequenciesFromRiegeliIndividual(path.c_str(), frequencies)); + } return frequencies; } diff --git a/util/load_codepoints.h b/util/load_codepoints.h index 85ada48..21521d7 100644 --- a/util/load_codepoints.h +++ b/util/load_codepoints.h @@ -36,9 +36,21 @@ absl::StatusOr LoadFile(const char* path); // Loads a Riegeli file of CodepointCount protos and returns a // UnicodeFrequencies instance. +// +// Append "@*" to the path to load all sharded files for this path. +// For example "FrequencyData.riegeli@*" will load all files of the +// form FrequencyData.riegeli-*-of-* into the frequency data set. absl::StatusOr LoadFrequenciesFromRiegeli( const char* path); +// Given a filepath if it ends with @* this will expand the path into +// the list of paths matching the pattern: -?????-of-????? +// Otherwise returns just the input path. +// +// Checks that the input path exists and will return a NotFoundError if +// it does not. +absl::StatusOr> ExpandShardedPath(const char* path); + struct CodepointAndFrequency { uint32_t codepoint; std::optional frequency; diff --git a/util/load_codepoints_test.cc b/util/load_codepoints_test.cc index 4c94529..dab594e 100644 --- a/util/load_codepoints_test.cc +++ b/util/load_codepoints_test.cc @@ -86,10 +86,73 @@ TEST_F(LoadCodepointsTest, LoadFrequenciesFromRiegeli) { EXPECT_EQ(result->ProbabilityFor(0x44, 0x45), 0.25); } +TEST_F(LoadCodepointsTest, LoadFrequenciesFromRiegeli_Sharded) { + auto result = + util::LoadFrequenciesFromRiegeli("util/testdata/sharded/test_freq_data.riegeli@*"); + ASSERT_TRUE(result.ok()) << result.status(); + + EXPECT_EQ(result->ProbabilityFor(0x43, 0x43), 1.0); + EXPECT_EQ(result->ProbabilityFor(0x44, 0x44), 75.0 / 200.0); + + EXPECT_EQ(result->ProbabilityFor(0x41, 0x42), 0.5); + EXPECT_EQ(result->ProbabilityFor(0x44, 0x45), 0.25); +} + +TEST_F(LoadCodepointsTest, LoadFrequenciesFromRiegeli_Sharded_DoesNotExist) { + auto result = + util::LoadFrequenciesFromRiegeli("util/testdata/sharded/notfound.riegeli@*"); + ASSERT_TRUE(absl::IsNotFound(result.status())) << result.status(); +} + TEST_F(LoadCodepointsTest, LoadFrequenciesFromRiegeli_BadData) { auto result = util::LoadFrequenciesFromRiegeli( "util/testdata/invalid_test_freq_data.riegeli"); ASSERT_TRUE(absl::IsInvalidArgument(result.status())) << result.status(); } +TEST_F(LoadCodepointsTest, ExpandShardedPath) { + auto result = ExpandShardedPath("util/testdata/test_freq_data.riegeli"); + ASSERT_TRUE(result.ok()) << result.status(); + ASSERT_EQ(*result, + (std::vector{"util/testdata/test_freq_data.riegeli"})); + + result = ExpandShardedPath("util/testdata/test_freq_data.riegeli@*"); + ASSERT_TRUE(absl::IsNotFound(result.status())) << result.status(); + + result = ExpandShardedPath("util/testdata/sharded/BadSuffix@*"); + ASSERT_TRUE(absl::IsNotFound(result.status())) << result.status(); + + result = ExpandShardedPath("does/not/exist.file@*"); + ASSERT_TRUE(absl::IsNotFound(result.status())) << result.status(); + + result = ExpandShardedPath("util/testdata/sharded/notfound.file@*"); + ASSERT_TRUE(absl::IsNotFound(result.status())) << result.status(); + + result = ExpandShardedPath("does/not/exist.file"); + ASSERT_TRUE(absl::IsNotFound(result.status())) << result.status(); + + result = ExpandShardedPath("util/testdata/sharded/Language_ja.riegeli@*"); + ASSERT_TRUE(result.ok()) << result.status(); + ASSERT_EQ(*result, + (std::vector{ + "util/testdata/sharded/Language_ja.riegeli-00000-of-00003", + "util/testdata/sharded/Language_ja.riegeli-00001-of-00003", + "util/testdata/sharded/Language_ja.riegeli-00002-of-00003", + })); + + result = ExpandShardedPath("util/testdata/sharded/Language_ko.riegeli@*"); + ASSERT_TRUE(result.ok()) << result.status(); + ASSERT_EQ(*result, + (std::vector{ + "util/testdata/sharded/Language_ko.riegeli-00000-of-00100", + "util/testdata/sharded/Language_ko.riegeli-00008-of-00100", + "util/testdata/sharded/Language_ko.riegeli-00011-of-00100", + "util/testdata/sharded/Language_ko.riegeli-00013-of-00100", + "util/testdata/sharded/Language_ko.riegeli-00020-of-00100", + })); + + result = ExpandShardedPath("util/testdata/sharded/Language_ja.riegeli"); + ASSERT_TRUE(absl::IsNotFound(result.status())) << result.status(); +} + } // namespace util diff --git a/util/testdata/sharded/BadSuffix-00-of-02 b/util/testdata/sharded/BadSuffix-00-of-02 new file mode 100644 index 0000000..e69de29 diff --git a/util/testdata/sharded/BadSuffix-01-of-02 b/util/testdata/sharded/BadSuffix-01-of-02 new file mode 100644 index 0000000..e69de29 diff --git a/util/testdata/sharded/Language_ja.riegeli-00000-of-00003 b/util/testdata/sharded/Language_ja.riegeli-00000-of-00003 new file mode 100644 index 0000000..e69de29 diff --git a/util/testdata/sharded/Language_ja.riegeli-00001-of-00003 b/util/testdata/sharded/Language_ja.riegeli-00001-of-00003 new file mode 100644 index 0000000..e69de29 diff --git a/util/testdata/sharded/Language_ja.riegeli-00002-of-00003 b/util/testdata/sharded/Language_ja.riegeli-00002-of-00003 new file mode 100644 index 0000000..e69de29 diff --git a/util/testdata/sharded/Language_ko.riegeli-00000-of-00100 b/util/testdata/sharded/Language_ko.riegeli-00000-of-00100 new file mode 100644 index 0000000..e69de29 diff --git a/util/testdata/sharded/Language_ko.riegeli-00008-of-00100 b/util/testdata/sharded/Language_ko.riegeli-00008-of-00100 new file mode 100644 index 0000000..e69de29 diff --git a/util/testdata/sharded/Language_ko.riegeli-00011-of-00100 b/util/testdata/sharded/Language_ko.riegeli-00011-of-00100 new file mode 100644 index 0000000..e69de29 diff --git a/util/testdata/sharded/Language_ko.riegeli-00013-of-00100 b/util/testdata/sharded/Language_ko.riegeli-00013-of-00100 new file mode 100644 index 0000000..e69de29 diff --git a/util/testdata/sharded/Language_ko.riegeli-00020-of-00100 b/util/testdata/sharded/Language_ko.riegeli-00020-of-00100 new file mode 100644 index 0000000..e69de29 diff --git a/util/testdata/sharded/test_freq_data.riegeli-00000-of-00003 b/util/testdata/sharded/test_freq_data.riegeli-00000-of-00003 new file mode 100644 index 0000000..9b91222 Binary files /dev/null and b/util/testdata/sharded/test_freq_data.riegeli-00000-of-00003 differ diff --git a/util/testdata/sharded/test_freq_data.riegeli-00001-of-00003 b/util/testdata/sharded/test_freq_data.riegeli-00001-of-00003 new file mode 100644 index 0000000..f3d7994 Binary files /dev/null and b/util/testdata/sharded/test_freq_data.riegeli-00001-of-00003 differ diff --git a/util/testdata/sharded/test_freq_data.riegeli-00002-of-00003 b/util/testdata/sharded/test_freq_data.riegeli-00002-of-00003 new file mode 100644 index 0000000..92e4f48 Binary files /dev/null and b/util/testdata/sharded/test_freq_data.riegeli-00002-of-00003 differ