Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ cc_test(
"testdata/codepoints_with_freq_invalid.txt",
"testdata/test_freq_data.riegeli",
"testdata/invalid_test_freq_data.riegeli",
],
] + glob([
"testdata/sharded/*",
]),
deps = [
":load_codepoints",
"@googletest//:gtest_main",
Expand Down
7 changes: 6 additions & 1 deletion util/freq_data_to_sorted_codepoints.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <iomanip>
#include <iostream>
#include <locale>
#include <ostream>
#include <string>
#include <vector>

Expand Down Expand Up @@ -40,7 +41,11 @@ int main(int argc, char** argv) {

if (args.size() != 2) {
std::cerr << "Usage:" << std::endl
<< "freq_data_to_sorted_codepoints <riegeli_file>" << std::endl;
<< "freq_data_to_sorted_codepoints <riegeli_file>" << std::endl
<< std::endl
<< "Append @* to the file name to load sharded data files. "
<< "For example \"<path>@*\" will load all files of the form <path>-?????-of-?????"
<< std::endl;
return -1;
}

Expand Down
56 changes: 40 additions & 16 deletions util/generate_riegeli_test_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "riegeli/bytes/fd_writer.h"
#include "riegeli/records/record_writer.h"
#include "util/unicode_count.pb.h"

using absl::StrCat;

// Command line flags for the Riegeli test data generator.

// Destination of the generated data. In the sharded branch of Main() the
// value of `output_path` is used as the prefix of every shard file name
// (presumably `output_path` is read from this flag; the read is not visible
// in this view).
ABSL_FLAG(std::string, output_path, "", "Path to write the output file.");

// When set, the single-file branch of Main() writes one additional record
// that carries three codepoints.
// NOTE(review): the sharded branch visible in Main() never consults this
// flag, so combining it with --shard appears to have no effect.
ABSL_FLAG(bool, include_invalid_record, false, "If set add an invalid record.");
// When set, Main() writes three files named
// <output_path>-0000N-of-00003 (N = 0, 1, 2) instead of one file.
ABSL_FLAG(bool, shard, false, "If set shard into multiple files.");

namespace {

Expand Down Expand Up @@ -41,24 +45,44 @@ absl::Status Main() {
message4.add_codepoints(0x44);
message4.set_count(75);

riegeli::RecordWriter writer{riegeli::FdWriter(output_path)};
writer.WriteRecord(message1);
writer.WriteRecord(message2);
writer.WriteRecord(message3);
writer.WriteRecord(message4);

if (absl::GetFlag(FLAGS_include_invalid_record)) {
CodepointCount message5;
message5.add_codepoints(0x46);
message5.add_codepoints(0x46);
message5.add_codepoints(0x46);
message5.set_count(75);
writer.WriteRecord(message5);
}
if (!absl::GetFlag(FLAGS_shard)) {
riegeli::RecordWriter writer{riegeli::FdWriter(output_path)};
writer.WriteRecord(message1);
writer.WriteRecord(message2);
writer.WriteRecord(message3);
writer.WriteRecord(message4);

if (!writer.Close()) {
return writer.status();
if (absl::GetFlag(FLAGS_include_invalid_record)) {
CodepointCount message5;
message5.add_codepoints(0x46);
message5.add_codepoints(0x46);
message5.add_codepoints(0x46);
message5.set_count(75);
writer.WriteRecord(message5);
}

if (!writer.Close()) {
return writer.status();
}
} else {
{
riegeli::RecordWriter writer{riegeli::FdWriter(StrCat(output_path, "-00000-of-00003"))};
writer.WriteRecord(message1);
writer.WriteRecord(message2);
writer.Close();
}
{
riegeli::RecordWriter writer{riegeli::FdWriter(StrCat(output_path, "-00001-of-00003"))};
writer.WriteRecord(message3);
writer.Close();
}
{
riegeli::RecordWriter writer{riegeli::FdWriter(StrCat(output_path, "-00002-of-00003"))};
writer.WriteRecord(message4);
writer.Close();
}
}

return absl::OkStatus();
}

Expand Down
62 changes: 60 additions & 2 deletions util/load_codepoints.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#include "load_codepoints.h"

#include <algorithm>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <optional>
#include <regex>
#include <sstream>

#include "absl/strings/str_cat.h"
Expand All @@ -14,6 +17,7 @@
#include "riegeli/records/record_reader.h"
#include "util/unicode_count.pb.h"

using absl::Status;
using absl::StatusOr;
using absl::StrCat;
using absl::string_view;
Expand Down Expand Up @@ -122,8 +126,53 @@ StatusOr<std::vector<CodepointAndFrequency>> LoadCodepointsOrdered(
return out;
}

StatusOr<UnicodeFrequencies> LoadFrequenciesFromRiegeli(const char* path) {
UnicodeFrequencies frequencies;
// Expands a possibly sharded path into the concrete list of files to load.
//
// If `path` does not end with "@*" it names a single file: returns a
// one-element vector holding it, or NotFoundError if it does not exist.
//
// If `path` ends with "@*", that suffix is stripped and the containing
// directory is scanned for entries named exactly
//   <base name>-NNNNN-of-NNNNN      (each N is a decimal digit)
// The matching paths are returned in sorted order. Returns NotFoundError if
// the directory does not exist, is not a directory, or has no matching entry.
StatusOr<std::vector<std::string>> ExpandShardedPath(const char* path) {
  const std::string full_path(path);

  if (!full_path.ends_with("@*")) {
    if (!std::filesystem::exists(full_path)) {
      return absl::NotFoundError(StrCat("Path does not exist: ", full_path));
    }
    return std::vector<std::string>{full_path};
  }

  // Drop the trailing "@*" marker.
  const std::filesystem::path file_path =
      full_path.substr(0, full_path.size() - 2);
  const std::string base_name = file_path.filename().string();
  const std::filesystem::path parent = file_path.parent_path();

  // A bare file name ("data.riegeli@*") has an empty parent path. Scan the
  // current directory in that case rather than reporting it as missing.
  const std::filesystem::path directory =
      parent.empty() ? std::filesystem::path(".") : parent;

  if (!std::filesystem::exists(directory) ||
      !std::filesystem::is_directory(directory)) {
    return absl::NotFoundError(StrCat(
        "Path does not exist or is not a directory: ", directory.string()));
  }

  // The shard suffix has to follow the base name immediately. Matching only
  // the remainder of the file name keeps siblings that merely share a prefix
  // (e.g. "<base name>.bak-00000-of-00003") out of the result.
  const std::regex shard_suffix("-[0-9]{5}-of-[0-9]{5}");

  // Collect into an ordered set so the output is sorted.
  absl::btree_set<std::string> files;
  for (const auto& entry : std::filesystem::directory_iterator(directory)) {
    const std::string name = entry.path().filename().string();
    if (!name.starts_with(base_name)) {
      continue;
    }

    const auto suffix_begin =
        name.cbegin() +
        static_cast<std::string::difference_type>(base_name.size());
    if (std::regex_match(suffix_begin, name.cend(), shard_suffix)) {
      // Rebuild the path from the caller supplied parent so a bare file name
      // input yields bare file names (no "./" prefix).
      files.insert((parent / name).string());
    }
  }

  if (files.empty()) {
    return absl::NotFoundError(
        StrCat("No files matched the shard pattern: ", full_path));
  }

  return std::vector<std::string>(files.begin(), files.end());
}

static Status LoadFrequenciesFromRiegeliIndividual(
const char* path, UnicodeFrequencies& frequencies
) {
riegeli::RecordReader reader{riegeli::FdReader(path)};
if (!reader.ok()) {
return absl::InvalidArgumentError(
Expand All @@ -144,6 +193,15 @@ StatusOr<UnicodeFrequencies> LoadFrequenciesFromRiegeli(const char* path) {
if (!reader.Close()) {
return absl::InternalError(reader.status().message());
}
return absl::OkStatus();
}

// Loads frequency data from `path` into a single UnicodeFrequencies.
//
// `path` is first resolved with ExpandShardedPath(), so it may resolve to one
// file or to several shard files. Each resolved file is read in turn into the
// same UnicodeFrequencies instance, which is then returned.
// TRY / TRYV are project macros; presumably they return early with the
// failing status.
StatusOr<UnicodeFrequencies> LoadFrequenciesFromRiegeli(const char* path) {
  const auto shard_files = TRY(ExpandShardedPath(path));

  UnicodeFrequencies merged;
  for (const auto& shard_file : shard_files) {
    TRYV(LoadFrequenciesFromRiegeliIndividual(shard_file.c_str(), merged));
  }

  return merged;
}

Expand Down
12 changes: 12 additions & 0 deletions util/load_codepoints.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,21 @@ absl::StatusOr<common::FontData> LoadFile(const char* path);

// Loads a Riegeli file of CodepointCount protos and returns a
// UnicodeFrequencies instance.
//
// Append "@*" to the path to load all sharded files for this path.
// For example "FrequencyData.riegeli@*" will load all files of the
// form FrequencyData.riegeli-*-of-* into the frequency data set.
absl::StatusOr<ift::freq::UnicodeFrequencies> LoadFrequenciesFromRiegeli(
const char* path);

// Given a file path: if it ends with "@*", expands it into the sorted list
// of existing files matching the pattern <path>-?????-of-????? (each ? is a
// decimal digit). Otherwise returns just the input path.
//
// Checks that the input path exists and will return a NotFoundError if
// it does not.
absl::StatusOr<std::vector<std::string>> ExpandShardedPath(const char* path);

struct CodepointAndFrequency {
uint32_t codepoint;
std::optional<uint64_t> frequency;
Expand Down
63 changes: 63 additions & 0 deletions util/load_codepoints_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,73 @@ TEST_F(LoadCodepointsTest, LoadFrequenciesFromRiegeli) {
EXPECT_EQ(result->ProbabilityFor(0x44, 0x45), 0.25);
}

// Loads the sharded test data set through the "@*" path form and checks
// ProbabilityFor() against fixed expected values.
TEST_F(LoadCodepointsTest, LoadFrequenciesFromRiegeli_Sharded) {
  const char* const sharded_path =
      "util/testdata/sharded/test_freq_data.riegeli@*";
  auto frequencies = util::LoadFrequenciesFromRiegeli(sharded_path);
  ASSERT_TRUE(frequencies.ok()) << frequencies.status();

  // Same codepoint passed for both arguments.
  EXPECT_EQ(frequencies->ProbabilityFor(0x43, 0x43), 1.0);
  EXPECT_EQ(frequencies->ProbabilityFor(0x44, 0x44), 75.0 / 200.0);

  // Two different codepoints.
  EXPECT_EQ(frequencies->ProbabilityFor(0x41, 0x42), 0.5);
  EXPECT_EQ(frequencies->ProbabilityFor(0x44, 0x45), 0.25);
}

// A sharded path with no matching shard files must report NotFound.
TEST_F(LoadCodepointsTest, LoadFrequenciesFromRiegeli_Sharded_DoesNotExist) {
  const char* const missing_shards =
      "util/testdata/sharded/notfound.riegeli@*";
  auto loaded = util::LoadFrequenciesFromRiegeli(missing_shards);
  ASSERT_TRUE(absl::IsNotFound(loaded.status())) << loaded.status();
}

// Loading the invalid test data file must report InvalidArgument.
TEST_F(LoadCodepointsTest, LoadFrequenciesFromRiegeli_BadData) {
  const char* const bad_data_path =
      "util/testdata/invalid_test_freq_data.riegeli";
  auto loaded = util::LoadFrequenciesFromRiegeli(bad_data_path);
  ASSERT_TRUE(absl::IsInvalidArgument(loaded.status())) << loaded.status();
}

// Exercises ExpandShardedPath() on a plain path, on sharded ("@*") paths and
// on inputs that are expected to yield NotFound.
TEST_F(LoadCodepointsTest, ExpandShardedPath) {
  using PathList = std::vector<std::string>;

  // A plain path to an existing file is returned as the only element.
  auto expanded = ExpandShardedPath("util/testdata/test_freq_data.riegeli");
  ASSERT_TRUE(expanded.ok()) << expanded.status();
  ASSERT_EQ(*expanded, (PathList{"util/testdata/test_freq_data.riegeli"}));

  // Inputs that must all be reported as NotFound, checked in this order.
  const char* const not_found_inputs[] = {
      "util/testdata/test_freq_data.riegeli@*",
      "util/testdata/sharded/BadSuffix@*",
      "does/not/exist.file@*",
      "util/testdata/sharded/notfound.file@*",
      "does/not/exist.file",
  };
  for (const char* input : not_found_inputs) {
    // Identifies the failing input, since every case shares one assertion.
    SCOPED_TRACE(input);
    expanded = ExpandShardedPath(input);
    ASSERT_TRUE(absl::IsNotFound(expanded.status())) << expanded.status();
  }

  // A complete shard set is returned in sorted order.
  expanded = ExpandShardedPath("util/testdata/sharded/Language_ja.riegeli@*");
  ASSERT_TRUE(expanded.ok()) << expanded.status();
  const PathList expected_ja = {
      "util/testdata/sharded/Language_ja.riegeli-00000-of-00003",
      "util/testdata/sharded/Language_ja.riegeli-00001-of-00003",
      "util/testdata/sharded/Language_ja.riegeli-00002-of-00003",
  };
  ASSERT_EQ(*expanded, expected_ja);

  // Only a subset of the 100 shard indices is expected for this data set.
  expanded = ExpandShardedPath("util/testdata/sharded/Language_ko.riegeli@*");
  ASSERT_TRUE(expanded.ok()) << expanded.status();
  const PathList expected_ko = {
      "util/testdata/sharded/Language_ko.riegeli-00000-of-00100",
      "util/testdata/sharded/Language_ko.riegeli-00008-of-00100",
      "util/testdata/sharded/Language_ko.riegeli-00011-of-00100",
      "util/testdata/sharded/Language_ko.riegeli-00013-of-00100",
      "util/testdata/sharded/Language_ko.riegeli-00020-of-00100",
  };
  ASSERT_EQ(*expanded, expected_ko);

  // Without the "@*" marker the shard base name itself must exist as a file.
  expanded = ExpandShardedPath("util/testdata/sharded/Language_ja.riegeli");
  ASSERT_TRUE(absl::IsNotFound(expanded.status())) << expanded.status();
}

} // namespace util
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading