22 changes: 16 additions & 6 deletions include/neug/storages/graph/vertex_table.h
@@ -276,6 +276,15 @@ class VertexTable {

template <typename PK_T>
void insert_vertices_impl(std::shared_ptr<IRecordBatchSupplier> supplier) {
auto row_nums = supplier->row_num();
size_t new_size = indexer_.size() + row_nums;
if (new_size > indexer_.capacity()) {
size_t cap = indexer_.capacity();
while (new_size >= cap) {
cap = cap < 4096 ? 4096 : cap + cap / 4;
}
EnsureCapacity(cap);
}
Comment on lines +279 to +287

Action required

2. `row_num` size_t wraparound risk 🐞 Bug ⛯ Reliability

VertexTable::insert_vertices_impl adds int64_t row_num() to a size_t without validating
non-negativity or overflow, so negative/invalid values can wrap to huge sizes and trigger extreme
EnsureCapacity() growth or non-terminating capacity loops due to overflow.
Agent Prompt
## Issue description
`insert_vertices_impl()` uses `supplier->row_num()` (int64) in `size_t` arithmetic without validating sign/overflow, risking wraparound, huge allocations, or overflow in the capacity growth loop.

## Issue Context
`row_num()` has no documented contract here (e.g., may be “unknown”). Even if current implementations return non-negative counts, this is an unsafe assumption for a public interface.

## Fix
- Treat non-positive row counts as “unknown”: skip upfront preallocation when `row_nums <= 0`.
- Use checked arithmetic before adding:
  - if `row_nums > 0`, cast to `uint64_t`/`size_t` only after bounds checks.
  - guard `indexer_.size() + row_nums` overflow (`row_nums > SIZE_MAX - indexer_.size()`).
- Make the capacity growth loop overflow-safe:
  - before `cap + cap/4`, check `cap > SIZE_MAX - cap/4` and throw/stop.

## Fix Focus Areas
- include/neug/storages/graph/vertex_table.h[278-287]
- include/neug/storages/loader/loader_utils.h[77-82]
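
The checked-arithmetic preallocation described in the Fix above could be sketched as a standalone helper. Names here are illustrative, not from the PR; the 25%-growth/4096-floor policy mirrors the loop in `insert_vertices_impl`:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>

// Compute the capacity to reserve before a bulk insert.
// Non-positive row_nums is treated as "unknown": no preallocation.
// All additions are bounds-checked so a bad row count cannot wrap.
std::size_t checked_grow_capacity(std::size_t current_size,
                                  std::size_t current_cap,
                                  int64_t row_nums) {
  if (row_nums <= 0) return current_cap;  // unknown row count: skip preallocation
  uint64_t rows = static_cast<uint64_t>(row_nums);
  if (rows > std::numeric_limits<std::size_t>::max() - current_size) {
    throw std::overflow_error("row count overflows size_t");
  }
  std::size_t new_size = current_size + static_cast<std::size_t>(rows);
  std::size_t cap = current_cap;
  while (new_size >= cap) {
    std::size_t step = cap / 4;
    if (cap > std::numeric_limits<std::size_t>::max() - step) {
      throw std::overflow_error("capacity growth overflows size_t");
    }
    cap = cap < 4096 ? 4096 : cap + step;  // 4096 floor, then grow by 25%
  }
  return cap;
}
```

The caller would then invoke `EnsureCapacity()` only when the returned capacity exceeds the current one.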


while (true) {
auto batch = supplier->GetNextBatch();
if (batch == nullptr) {
@@ -290,13 +299,14 @@ class VertexTable {
auto ind = std::get<2>(vertex_schema_->primary_keys[0]);
auto pk_array = columns[ind];
columns.erase(columns.begin() + ind);
size_t new_size = indexer_.size() + pk_array->length();
if (new_size >= indexer_.capacity()) {
size_t new_cap = new_size;
while (new_size >= new_cap) {
new_cap = new_cap < 4096 ? 4096 : new_cap + new_cap / 4;
// Add capacity checking logic when performing the actual batch insert.
size_t new_size = indexer_.size() + batch->num_rows();
if (new_size > indexer_.capacity()) {
size_t cap = indexer_.capacity();
while (new_size >= cap) {
cap = cap < 4096 ? 4096 : cap + cap / 4;
}
EnsureCapacity(new_cap);
EnsureCapacity(cap);
}

auto vids = insert_primary_keys<PK_T>(pk_array);
31 changes: 29 additions & 2 deletions include/neug/storages/loader/loader_utils.h
@@ -78,6 +78,7 @@ class IRecordBatchSupplier {
public:
virtual ~IRecordBatchSupplier() = default;
virtual std::shared_ptr<arrow::RecordBatch> GetNextBatch() = 0;
virtual int64_t row_num() const = 0;
};
Comment on lines 78 to 82

Action required

1. ODPS suppliers won't compile 🐞 Bug ✓ Correctness

IRecordBatchSupplier now requires a pure-virtual row_num(), but ODPSStreamRecordBatchSupplier
and ODPSTableRecordBatchSupplier don’t override it, so instantiation is ill-formed and the build
will fail.
Agent Prompt
## Issue description
`IRecordBatchSupplier::row_num()` is now pure virtual, but ODPS suppliers don’t override it, leaving them abstract and causing compilation failures where they are instantiated.

## Issue Context
ODPS suppliers are constructed in `odps_fragment_loader.cc`, so they must be concrete.

## Fix
- Add `int64_t row_num() const override` to both ODPS suppliers.
  - For `ODPSTableRecordBatchSupplier`, return `table_ ? table_->num_rows() : 0`.
  - For `ODPSStreamRecordBatchSupplier`, if an exact total isn’t cheaply available, return `0` (or another agreed sentinel), and ensure callers treat it as “unknown”.

## Fix Focus Areas
- include/neug/storages/loader/odps_fragment_loader.h[109-151]
- src/storages/loader/odps_fragment_loader.cc[359-382]
- include/neug/storages/loader/loader_utils.h[77-82]
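
A minimal sketch of the missing overrides. The interface and table type below are simplified stand-ins (the real ODPS and Arrow types aren't shown in this diff, and `GetNextBatch()` is omitted to keep the focus on `row_num()`):

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>

// Stand-in for the backing table type (assumption for illustration).
struct MockTable {
  int64_t num_rows_ = 0;
  int64_t num_rows() const { return num_rows_; }
};

// Simplified stand-in for neug::IRecordBatchSupplier.
class IRecordBatchSupplier {
 public:
  virtual ~IRecordBatchSupplier() = default;
  virtual int64_t row_num() const = 0;
};

// Table-backed supplier: report the table's row count, 0 if no table.
class ODPSTableRecordBatchSupplier : public IRecordBatchSupplier {
 public:
  explicit ODPSTableRecordBatchSupplier(std::shared_ptr<MockTable> table)
      : table_(std::move(table)) {}
  int64_t row_num() const override { return table_ ? table_->num_rows() : 0; }

 private:
  std::shared_ptr<MockTable> table_;
};

// Streaming supplier: an exact total isn't cheaply available, so return
// the agreed sentinel 0, which callers must treat as "unknown".
class ODPSStreamRecordBatchSupplier : public IRecordBatchSupplier {
 public:
  int64_t row_num() const override { return 0; }
};
```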



class SupplierWrapperWithFirstBatch : public IRecordBatchSupplier {
@@ -111,6 +112,14 @@ class SupplierWrapperWithFirstBatch : public IRecordBatchSupplier {
return batch; // Return the batch from the current supplier
}

int64_t row_num() const override {
int64_t total_rows = 0;
for (const auto& supplier : suppliers_) {
total_rows += supplier->row_num();
}
return total_rows;
}

private:
std::vector<std::shared_ptr<IRecordBatchSupplier>> suppliers_;
std::shared_ptr<arrow::RecordBatch> first_batch_;
@@ -127,7 +136,10 @@ class CSVStreamRecordBatchSupplier : public IRecordBatchSupplier {

std::shared_ptr<arrow::RecordBatch> GetNextBatch() override;

int64_t row_num() const override { return row_num_; }

private:
int64_t row_num_;
std::string file_path_;
std::shared_ptr<arrow::csv::StreamingReader> reader_;
};
@@ -141,6 +153,8 @@ class CSVTableRecordBatchSupplier : public IRecordBatchSupplier {

std::shared_ptr<arrow::RecordBatch> GetNextBatch() override;

int64_t row_num() const override { return table_->num_rows(); }

private:
std::string file_path_;
std::shared_ptr<arrow::Table> table_;
@@ -166,6 +180,16 @@ class ArrowRecordBatchArraySupplier : public IRecordBatchSupplier {

std::shared_ptr<arrow::RecordBatch> GetNextBatch() override;

int64_t row_num() const override {
int64_t total_rows = 0;
if (!arrays_.empty()) {
for (const auto& batch : arrays_[0]) {
total_rows += batch->length();
}
}
return total_rows;
}

private:
// NUM_COLUMNS * NUM_BATCHES
std::vector<std::vector<std::shared_ptr<arrow::Array>>> arrays_;
@@ -182,12 +206,15 @@ class ArrowRecordBatchArraySupplier : public IRecordBatchSupplier {
class ArrowRecordBatchStreamSupplier : public IRecordBatchSupplier {
public:
ArrowRecordBatchStreamSupplier(
const std::shared_ptr<arrow::RecordBatchReader>& reader)
: reader_(reader) {}
const std::shared_ptr<arrow::RecordBatchReader>& reader, int64_t row_num)
: row_num_(row_num), reader_(reader) {}

std::shared_ptr<arrow::RecordBatch> GetNextBatch() override;

int64_t row_num() const override { return row_num_; }

private:
int64_t row_num_;
std::shared_ptr<arrow::RecordBatchReader> reader_;
};

27 changes: 26 additions & 1 deletion src/storages/loader/loader_utils.cc
@@ -242,10 +242,35 @@ CSVStreamRecordBatchSupplier::CSVStreamRecordBatchSupplier(
: file_path_(file_path) {
auto read_result = arrow::io::ReadableFile::Open(file_path);
if (!read_result.ok()) {
LOG(FATAL) << "Failed to open file: " << file_path
LOG(ERROR) << "Failed to open file: " << file_path
<< " error: " << read_result.status().message();
THROW_IO_EXCEPTION("Failed to open file: " + file_path +
" error: " + read_result.status().message());
}
auto file = read_result.ValueOrDie();
auto count_file_result = arrow::io::ReadableFile::Open(file_path);
if (count_file_result.ok()) {
auto count_file = count_file_result.ValueOrDie();
auto future = arrow::csv::CountRowsAsync(
arrow::io::default_io_context(), count_file,
arrow::internal::GetCpuThreadPool(), read_options, parse_options);
future.Wait();

auto count_result = future.result();
if (count_result.ok()) {
row_num_ = count_result.ValueUnsafe();
} else {
LOG(WARNING) << "Failed to count rows for " << file_path << ": "
<< count_result.status().message();
THROW_IO_EXCEPTION("Failed to count rows for " + file_path + ": " +
count_result.status().message());
}
} else {
LOG(WARNING) << "Failed to reopen file for counting: "
<< count_file_result.status().message();
THROW_IO_EXCEPTION("Failed to reopen file for counting: " +
count_file_result.status().message());
}
auto res = arrow::csv::StreamingReader::Make(arrow::io::default_io_context(),
file, read_options,
parse_options, convert_options);
15 changes: 13 additions & 2 deletions src/utils/reader/reader.cc
@@ -185,6 +185,17 @@ void ArrowReader::batch_read(std::shared_ptr<arrow::dataset::Scanner> scanner,
if (!scanner) {
THROW_INVALID_ARGUMENT_EXCEPTION("Scanner is null");
}
auto row_num_result = scanner->CountRows();
int64_t row_num = 0;
if (!row_num_result.ok()) {
LOG(WARNING) << "Failed to count rows via scanner: "
<< row_num_result.status().message();
THROW_IO_EXCEPTION("Failed to count rows via scanner: " +
row_num_result.status().message());
} else {
VLOG(10) << "Row count from scanner: " << row_num_result.ValueOrDie();
row_num = row_num_result.ValueOrDie();
}
Comment on lines +188 to +198

Action required

3. CountRows becomes a hard requirement 🐞 Bug ⛯ Reliability

ArrowReader::batch_read now throws if scanner->CountRows() fails, even though it separately
creates a RecordBatchReader afterward; this turns a preallocation hint into a hard failure path
and can break previously-working reads.
Agent Prompt
## Issue description
`ArrowReader::batch_read()` throws when `scanner->CountRows()` fails, making row counting mandatory even though batch reading can still proceed via `ToRecordBatchReader()`.

## Issue Context
The counted value is only used to construct `ArrowRecordBatchStreamSupplier(row_num)`. If row count is unknown, the system should still be able to stream batches.

## Fix
- Change CountRows handling to best-effort:
  - If `CountRows()` fails, log and set `row_num` to `0` (or a sentinel like `-1`) and continue.
- Ensure downstream preallocation code treats `row_num <= 0` as unknown (see VertexTable fix).
- Optionally, avoid calling `CountRows()` entirely to prevent double scanning.

## Fix Focus Areas
- src/utils/reader/reader.cc[180-211]
- include/neug/storages/loader/loader_utils.h[206-219]
- include/neug/storages/graph/vertex_table.h[278-287]
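
The best-effort variant suggested above might look like this. `CountResult` is a simplified stand-in for `arrow::Result<int64_t>`, used only so the sketch is self-contained:

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Stand-in for arrow::Result<int64_t> (assumption for illustration).
struct CountResult {
  bool ok_;
  int64_t value_;
  std::string msg_;
  bool ok() const { return ok_; }
  int64_t ValueOrDie() const { return value_; }
};

// Best-effort row counting: on failure, degrade gracefully to the
// "unknown" sentinel 0 instead of throwing, so the subsequent
// ToRecordBatchReader() path can still stream batches.
int64_t best_effort_row_count(const CountResult& r) {
  if (!r.ok()) {
    // LOG(WARNING) << "Failed to count rows via scanner: " << r.msg_;
    return 0;  // sentinel: unknown row count, skip preallocation downstream
  }
  return r.ValueOrDie();
}
```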



auto batch_reader_result = scanner->ToRecordBatchReader();
if (!batch_reader_result.ok()) {
@@ -195,8 +206,8 @@ }
}
auto batch_reader = batch_reader_result.ValueOrDie();

auto batch_supplier =
std::make_shared<neug::ArrowRecordBatchStreamSupplier>(batch_reader);
auto batch_supplier = std::make_shared<neug::ArrowRecordBatchStreamSupplier>(
batch_reader, row_num);

int num_cols = sharedState->columnNum();
output.clear();
22 changes: 1 addition & 21 deletions tests/storage/test_edge_table.cc
@@ -26,26 +26,6 @@
namespace neug {
namespace test {

class LocalGeneratedRecordBatchSupplier : public neug::IRecordBatchSupplier {
public:
explicit LocalGeneratedRecordBatchSupplier(
std::vector<std::shared_ptr<arrow::RecordBatch>>&& batches)
: batch_index_(0), batches_(std::move(batches)) {}

std::shared_ptr<arrow::RecordBatch> GetNextBatch() override {
if (batch_index_ >= batches_.size()) {
return nullptr;
}
auto batch = batches_[batch_index_];
batch_index_++;
return batch;
}

private:
size_t batch_index_ = 0;
std::vector<std::shared_ptr<arrow::RecordBatch>> batches_;
};

class EdgeTableTest : public ::testing::Test {
protected:
void SetUp() override {
@@ -148,7 +128,7 @@ class EdgeTableTest : public ::testing::Test {

void BatchInsert(std::vector<std::shared_ptr<arrow::RecordBatch>>&& batches) {
auto supplier =
std::make_shared<LocalGeneratedRecordBatchSupplier>(std::move(batches));
std::make_shared<GeneratedRecordBatchSupplier>(std::move(batches));
edge_table->BatchAddEdges(src_indexer, dst_indexer, supplier);
}

8 changes: 8 additions & 0 deletions tests/unittest/utils.h
@@ -40,6 +40,14 @@ class GeneratedRecordBatchSupplier : public neug::IRecordBatchSupplier {
}
}

int64_t row_num() const override {
int64_t total_rows = 0;
for (const auto& batch : batches_) {
total_rows += batch->num_rows();
}
return total_rows;
}

private:
std::vector<std::shared_ptr<arrow::RecordBatch>> batches_;
};