diff --git a/cpp/fory/serialization/context.cc b/cpp/fory/serialization/context.cc index 12c22deb7f..28fbc6b1d0 100644 --- a/cpp/fory/serialization/context.cc +++ b/cpp/fory/serialization/context.cc @@ -406,6 +406,8 @@ void WriteContext::reset() { has_first_type_info_ = false; type_info_index_map_active_ = false; current_dyn_depth_ = 0; + buffer_.clear_output_stream(); + output_stream_ = nullptr; // reset buffer indices for reuse - no memory operations needed buffer_.writer_index(0); buffer_.reader_index(0); diff --git a/cpp/fory/serialization/context.h b/cpp/fory/serialization/context.h index 18e5e68bd5..e080604f67 100644 --- a/cpp/fory/serialization/context.h +++ b/cpp/fory/serialization/context.h @@ -118,6 +118,10 @@ class WriteContext { /// get const reference to internal output buffer. inline const Buffer &buffer() const { return buffer_; } + inline void set_output_stream(OutputStream *output_stream) { + output_stream_ = output_stream; + } + /// get reference writer for tracking shared references. inline RefWriter &ref_writer() { return ref_writer_; } @@ -167,70 +171,107 @@ class WriteContext { } } + inline uint32_t flush_barrier_depth() const { + return output_stream_ == nullptr ? 0 + : output_stream_->flush_barrier_depth(); + } + + inline void enter_flush_barrier() { + if (output_stream_ != nullptr) { + output_stream_->enter_flush_barrier(); + } + } + + inline void exit_flush_barrier() { + if (output_stream_ != nullptr) { + output_stream_->exit_flush_barrier(); + } + } + + inline void try_flush() { + if (output_stream_ == nullptr || buffer_.writer_index() <= 4096) { + return; + } + output_stream_->try_flush(); + if (FORY_PREDICT_FALSE(output_stream_->has_error())) { + set_error(output_stream_->error()); + } + } + + inline void force_flush() { + if (output_stream_ == nullptr) { + return; + } + output_stream_->force_flush(); + if (FORY_PREDICT_FALSE(output_stream_->has_error())) { + set_error(output_stream_->error()); + } + } + /// write uint8_t value to buffer. FORY_ALWAYS_INLINE void write_uint8(uint8_t value) { - buffer().write_uint8(value); + buffer_.write_uint8(value); } /// write int8_t value to buffer. FORY_ALWAYS_INLINE void write_int8(int8_t value) { - buffer().write_int8(value); + buffer_.write_int8(value); } /// write uint16_t value to buffer. FORY_ALWAYS_INLINE void write_uint16(uint16_t value) { - buffer().write_uint16(value); + buffer_.write_uint16(value); } /// write uint32_t value to buffer. FORY_ALWAYS_INLINE void write_uint32(uint32_t value) { - buffer().write_uint32(value); + buffer_.write_uint32(value); } /// write int64_t value to buffer. FORY_ALWAYS_INLINE void write_int64(int64_t value) { - buffer().write_int64(value); + buffer_.write_int64(value); } /// write uint32_t value as varint to buffer. FORY_ALWAYS_INLINE void write_var_uint32(uint32_t value) { - buffer().write_var_uint32(value); + buffer_.write_var_uint32(value); } /// write int32_t value as zigzag varint to buffer. FORY_ALWAYS_INLINE void write_varint32(int32_t value) { - buffer().write_var_int32(value); + buffer_.write_var_int32(value); } /// write uint64_t value as varint to buffer. FORY_ALWAYS_INLINE void write_var_uint64(uint64_t value) { - buffer().write_var_uint64(value); + buffer_.write_var_uint64(value); } /// write int64_t value as zigzag varint to buffer. FORY_ALWAYS_INLINE void write_varint64(int64_t value) { - buffer().write_var_int64(value); + buffer_.write_var_int64(value); } /// write uint64_t value using tagged encoding to buffer. FORY_ALWAYS_INLINE void write_tagged_uint64(uint64_t value) { - buffer().write_tagged_uint64(value); + buffer_.write_tagged_uint64(value); } /// write int64_t value using tagged encoding to buffer. FORY_ALWAYS_INLINE void write_tagged_int64(int64_t value) { - buffer().write_tagged_int64(value); + buffer_.write_tagged_int64(value); } /// write uint64_t value as varuint36small to buffer. /// This is the special variable-length encoding used for string headers. FORY_ALWAYS_INLINE void write_var_uint36_small(uint64_t value) { - buffer().write_var_uint36_small(value); + buffer_.write_var_uint36_small(value); } /// write raw bytes to buffer. FORY_ALWAYS_INLINE void write_bytes(const void *data, uint32_t length) { - buffer().write_bytes(data, length); + buffer_.write_bytes(data, length); } /// write TypeMeta inline using streaming protocol. @@ -329,6 +370,7 @@ class WriteContext { std::unique_ptr type_resolver_; RefWriter ref_writer_; uint32_t current_dyn_depth_; + OutputStream *output_stream_ = nullptr; // Meta sharing state (for streaming inline TypeMeta) // Maps TypeInfo* to index for reference tracking - uses map size as counter diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index 09a0bfa487..1c5f19522c 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -41,6 +41,7 @@ #include #include #include +#include #include #include #include @@ -501,6 +502,52 @@ class Fory : public BaseFory { return result; } + /// Serialize an object to an output stream. + /// + /// @tparam T The type of object to serialize. + /// @param output_stream The output stream. + /// @param obj The object to serialize. + /// @return Number of bytes written, or error. + template + Result serialize(OutputStream &output_stream, const T &obj) { + if (FORY_PREDICT_FALSE(!finalized_)) { + ensure_finalized(); + } + WriteContextGuard guard(*write_ctx_); + output_stream.reset(); + write_ctx_->set_output_stream(&output_stream); + Buffer &buffer = write_ctx_->buffer(); + buffer.bind_output_stream(&output_stream); + auto serialize_result = serialize_impl(obj, buffer); + if (FORY_PREDICT_FALSE(!serialize_result.ok())) { + buffer.clear_output_stream(); + write_ctx_->set_output_stream(nullptr); + return Unexpected(std::move(serialize_result).error()); + } + output_stream.force_flush(); + buffer.clear_output_stream(); + write_ctx_->set_output_stream(nullptr); + if (FORY_PREDICT_FALSE(output_stream.has_error())) { + return Unexpected(output_stream.error()); + } + if (FORY_PREDICT_FALSE(write_ctx_->has_error())) { + return Unexpected(write_ctx_->take_error()); + } + return output_stream.flushed_bytes(); + } + + /// Serialize an object to a std::ostream. + /// + /// @tparam T The type of object to serialize. + /// @param ostream The output stream. + /// @param obj The object to serialize. + /// @return Number of bytes written, or error. + template + Result serialize(std::ostream &ostream, const T &obj) { + StdOutputStream output_stream(ostream); + return serialize(output_stream, obj); + } + /// Serialize an object to an existing Buffer (fastest path). /// /// @tparam T The type of object to serialize. @@ -627,36 +674,36 @@ class Fory : public BaseFory { return deserialize_impl(buffer); } - /// Deserialize an object from a stream reader. + /// Deserialize an object from an input stream. /// /// This overload obtains the reader-owned Buffer via get_buffer() and /// continues deserialization on that buffer. /// /// @tparam T The type of object to deserialize. - /// @param stream_reader Stream reader to read from. + /// @param input_stream Input stream to read from. /// @return Deserialized object, or error. template - Result deserialize(StreamReader &stream_reader) { + Result deserialize(InputStream &input_stream) { struct StreamShrinkGuard { - StreamReader *stream_reader = nullptr; + InputStream *input_stream = nullptr; ~StreamShrinkGuard() { - if (stream_reader != nullptr) { - stream_reader->shrink_buffer(); + if (input_stream != nullptr) { + input_stream->shrink_buffer(); } } }; - StreamShrinkGuard shrink_guard{&stream_reader}; - Buffer &buffer = stream_reader.get_buffer(); + StreamShrinkGuard shrink_guard{&input_stream}; + Buffer &buffer = input_stream.get_buffer(); return deserialize(buffer); } - /// Deserialize an object from ForyInputStream. + /// Deserialize an object from StdInputStream. /// /// @tparam T The type of object to deserialize. /// @param stream Input stream wrapper to read from. /// @return Deserialized object, or error. - template Result deserialize(ForyInputStream &stream) { - return deserialize(static_cast(stream)); + template Result deserialize(StdInputStream &stream) { + return deserialize(static_cast(stream)); } // ========================================================================== @@ -805,6 +852,18 @@ class ThreadSafeFory : public BaseFory { return fory_handle->serialize(obj); } + template + Result serialize(OutputStream &output_stream, const T &obj) { + auto fory_handle = fory_pool_.acquire(); + return fory_handle->serialize(output_stream, obj); + } + + template + Result serialize(std::ostream &ostream, const T &obj) { + auto fory_handle = fory_pool_.acquire(); + return fory_handle->serialize(ostream, obj); + } + template Result serialize_to(Buffer &buffer, const T &obj) { auto fory_handle = fory_pool_.acquire(); @@ -830,12 +889,12 @@ class ThreadSafeFory : public BaseFory { } template - Result deserialize(StreamReader &stream_reader) { + Result deserialize(InputStream &input_stream) { auto fory_handle = fory_pool_.acquire(); - return fory_handle->template deserialize(stream_reader); + return fory_handle->template deserialize(input_stream); } - template Result deserialize(ForyInputStream &stream) { + template Result deserialize(StdInputStream &stream) { auto fory_handle = fory_pool_.acquire(); return fory_handle->template deserialize(stream); } diff --git a/cpp/fory/serialization/map_serializer.h b/cpp/fory/serialization/map_serializer.h index 5bd9bea51b..dd2952da99 100644 --- a/cpp/fory/serialization/map_serializer.h +++ b/cpp/fory/serialization/map_serializer.h @@ -138,6 +138,7 @@ inline void write_map_data_fast(const MapType &map, WriteContext &ctx, // If nullability is needed, use the slow path if (need_write_header) { + ctx.enter_flush_barrier(); // reserve space for header (1 byte) + chunk size (1 byte) header_offset = ctx.buffer().writer_index(); ctx.write_uint16(0); // Placeholder for header and chunk size @@ -174,6 +175,8 @@ inline void write_map_data_fast(const MapType &map, WriteContext &ctx, pair_counter++; if (pair_counter == MAX_CHUNK_SIZE) { write_chunk_size(ctx, header_offset, pair_counter); + ctx.exit_flush_barrier(); + ctx.try_flush(); pair_counter = 0; need_write_header = true; } @@ -182,6 +185,8 @@ inline void write_map_data_fast(const MapType &map, WriteContext &ctx, // write final chunk size if (pair_counter > 0) { write_chunk_size(ctx, header_offset, pair_counter); + ctx.exit_flush_barrier(); + ctx.try_flush(); } } @@ -238,6 +243,7 @@ inline void write_map_data_slow(const MapType &map, WriteContext &ctx, // Finish current chunk if any if (pair_counter > 0) { write_chunk_size(ctx, header_offset, pair_counter); + ctx.exit_flush_barrier(); pair_counter = 0; need_write_header = true; } @@ -394,9 +400,12 @@ inline void write_map_data_slow(const MapType &map, WriteContext &ctx, // Finish previous chunk if types changed if (types_changed && pair_counter > 0) { write_chunk_size(ctx, header_offset, pair_counter); + ctx.exit_flush_barrier(); + ctx.try_flush(); pair_counter = 0; } + ctx.enter_flush_barrier(); // write new chunk header header_offset = ctx.buffer().writer_index(); ctx.write_uint16(0); // Placeholder for header and chunk size @@ -513,6 +522,8 @@ inline void write_map_data_slow(const MapType &map, WriteContext &ctx, pair_counter++; if (pair_counter == MAX_CHUNK_SIZE) { write_chunk_size(ctx, header_offset, pair_counter); + ctx.exit_flush_barrier(); + ctx.try_flush(); pair_counter = 0; need_write_header = true; current_key_type_info = nullptr; @@ -523,6 +534,8 @@ inline void write_map_data_slow(const MapType &map, WriteContext &ctx, // write final chunk size if (pair_counter > 0) { write_chunk_size(ctx, header_offset, pair_counter); + ctx.exit_flush_barrier(); + ctx.try_flush(); } } diff --git a/cpp/fory/serialization/stream_test.cc b/cpp/fory/serialization/stream_test.cc index edbefd63de..98783a37e7 100644 --- a/cpp/fory/serialization/stream_test.cc +++ b/cpp/fory/serialization/stream_test.cc @@ -110,6 +110,44 @@ class OneByteIStream final : public std::istream { OneByteStreamBuf buf_; }; +class OneByteOutputStreamBuf final : public std::streambuf { +public: + OneByteOutputStreamBuf() = default; + + const std::vector &data() const { return data_; } + +protected: + std::streamsize xsputn(const char *s, std::streamsize count) override { + if (count <= 0) { + return 0; + } + data_.insert(data_.end(), reinterpret_cast(s), + reinterpret_cast(s + count)); + return count; + } + + int_type overflow(int_type ch) override { + if (traits_type::eq_int_type(ch, traits_type::eof())) { + return traits_type::not_eof(ch); + } + data_.push_back(static_cast(traits_type::to_char_type(ch))); + return ch; + } + +private: + std::vector data_; +}; + +class OneByteOStream final : public std::ostream { +public: + OneByteOStream() : std::ostream(nullptr) { rdbuf(&buf_); } + + std::vector data() const { return buf_.data(); } + +private: + OneByteOutputStreamBuf buf_; +}; + static inline void register_stream_types(Fory &fory) { uint32_t type_id = 1; fory.register_struct(type_id++); @@ -124,7 +162,7 @@ TEST(StreamSerializationTest, PrimitiveAndStringRoundTrip) { ASSERT_TRUE(number_bytes_result.ok()) << number_bytes_result.error().to_string(); OneByteIStream number_source(std::move(number_bytes_result).value()); - ForyInputStream number_stream(number_source, 8); + StdInputStream number_stream(number_source, 8); auto number_result = fory.deserialize(number_stream); ASSERT_TRUE(number_result.ok()) << number_result.error().to_string(); EXPECT_EQ(number_result.value(), -9876543212345LL); @@ -133,7 +171,7 @@ TEST(StreamSerializationTest, PrimitiveAndStringRoundTrip) { ASSERT_TRUE(string_bytes_result.ok()) << string_bytes_result.error().to_string(); OneByteIStream string_source(std::move(string_bytes_result).value()); - ForyInputStream string_stream(string_source, 8); + StdInputStream string_stream(string_source, 8); auto string_result = fory.deserialize(string_stream); ASSERT_TRUE(string_result.ok()) << string_result.error().to_string(); EXPECT_EQ(string_result.value(), "stream-hello-世界"); @@ -155,7 +193,7 @@ TEST(StreamSerializationTest, StructRoundTrip) { ASSERT_TRUE(bytes_result.ok()) << bytes_result.error().to_string(); OneByteIStream source(std::move(bytes_result).value()); - ForyInputStream stream(source, 4); + StdInputStream stream(source, 4); auto result = fory.deserialize(stream); ASSERT_TRUE(result.ok()) << result.error().to_string(); EXPECT_EQ(result.value(), original); @@ -175,7 +213,7 @@ TEST(StreamSerializationTest, SequentialDeserializeFromSingleStream) { ASSERT_TRUE(fory.serialize_to(bytes, envelope).ok()); OneByteIStream source(bytes); - ForyInputStream stream(source, 3); + StdInputStream stream(source, 3); auto first = fory.deserialize(stream); ASSERT_TRUE(first.ok()) << first.error().to_string(); @@ -206,7 +244,7 @@ TEST(StreamSerializationTest, SharedPointerIdentityRoundTrip) { ASSERT_TRUE(bytes_result.ok()) << bytes_result.error().to_string(); OneByteIStream source(std::move(bytes_result).value()); - ForyInputStream stream(source, 2); + StdInputStream stream(source, 2); auto result = fory.deserialize(stream); ASSERT_TRUE(result.ok()) << result.error().to_string(); ASSERT_NE(result.value().first, nullptr); @@ -230,11 +268,48 @@ TEST(StreamSerializationTest, TruncatedStreamReturnsError) { truncated.pop_back(); OneByteIStream source(truncated); - ForyInputStream stream(source, 4); + StdInputStream stream(source, 4); auto result = fory.deserialize(stream); EXPECT_FALSE(result.ok()); } +TEST(StreamSerializationTest, SerializeToOutputStreamRoundTrip) { + auto fory = Fory::builder().xlang(true).track_ref(true).build(); + register_stream_types(fory); + + StreamEnvelope original{ + "writer-roundtrip", {2, 4, 6, 8}, {{"x", 1}, {"y", 2}}, {5, -9}, true, + }; + + OneByteOStream out; + StdOutputStream writer(out); + auto write_result = fory.serialize(writer, original); + ASSERT_TRUE(write_result.ok()) << write_result.error().to_string(); + ASSERT_GT(write_result.value(), 0U); + + auto bytes = out.data(); + auto roundtrip = fory.deserialize(bytes); + ASSERT_TRUE(roundtrip.ok()) << roundtrip.error().to_string(); + EXPECT_EQ(roundtrip.value(), original); +} + +TEST(StreamSerializationTest, SerializeToOStreamOverloadParity) { + auto fory = Fory::builder().xlang(true).track_ref(true).build(); + register_stream_types(fory); + + StreamEnvelope original{ + "ostream-overload", {11, 22, 33}, {{"k", 99}}, {1, 2}, false, + }; + + auto expected = fory.serialize(original); + ASSERT_TRUE(expected.ok()) << expected.error().to_string(); + + OneByteOStream out; + auto write_result = fory.serialize(out, original); + ASSERT_TRUE(write_result.ok()) << write_result.error().to_string(); + EXPECT_EQ(out.data(), expected.value()); +} + } // namespace test } // namespace serialization } // namespace fory diff --git a/cpp/fory/serialization/struct_serializer.h b/cpp/fory/serialization/struct_serializer.h index 8125861149..ccde0e480a 100644 --- a/cpp/fory/serialization/struct_serializer.h +++ b/cpp/fory/serialization/struct_serializer.h @@ -2997,6 +2997,10 @@ struct Serializer>> { constexpr size_t field_count = FieldDescriptor::Size; detail::write_struct_fields_impl( obj, ctx, std::make_index_sequence{}, false); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return; + } + ctx.try_flush(); } static void write_data_generic(const T &obj, WriteContext &ctx, @@ -3025,6 +3029,10 @@ struct Serializer>> { constexpr size_t field_count = FieldDescriptor::Size; detail::write_struct_fields_impl( obj, ctx, std::make_index_sequence{}, has_generics); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return; + } + ctx.try_flush(); } static T read(ReadContext &ctx, RefMode ref_mode, bool read_type) { diff --git a/cpp/fory/util/buffer.cc b/cpp/fory/util/buffer.cc index e7006ae38b..eea3833548 100644 --- a/cpp/fory/util/buffer.cc +++ b/cpp/fory/util/buffer.cc @@ -31,20 +31,25 @@ Buffer::Buffer() { writer_index_ = 0; reader_index_ = 0; wrapped_vector_ = nullptr; - stream_reader_ = nullptr; + input_stream_ = nullptr; + output_stream_ = nullptr; } Buffer::Buffer(Buffer &&buffer) noexcept { + FORY_CHECK(buffer.output_stream_ == nullptr) + << "Cannot move stream-writer-owned Buffer"; data_ = buffer.data_; size_ = buffer.size_; own_data_ = buffer.own_data_; writer_index_ = buffer.writer_index_; reader_index_ = buffer.reader_index_; wrapped_vector_ = buffer.wrapped_vector_; - stream_reader_ = buffer.stream_reader_; - stream_reader_owner_ = std::move(buffer.stream_reader_owner_); - rebind_stream_reader_to_this(); - buffer.stream_reader_ = nullptr; + input_stream_ = buffer.input_stream_; + input_stream_owner_ = std::move(buffer.input_stream_owner_); + output_stream_ = buffer.output_stream_; + rebind_input_stream_to_this(); + buffer.input_stream_ = nullptr; + buffer.output_stream_ = nullptr; buffer.data_ = nullptr; buffer.size_ = 0; buffer.own_data_ = false; @@ -52,7 +57,11 @@ Buffer::Buffer(Buffer &&buffer) noexcept { } Buffer &Buffer::operator=(Buffer &&buffer) noexcept { - detach_stream_reader_from_this(); + FORY_CHECK(buffer.output_stream_ == nullptr) + << "Cannot move stream-writer-owned Buffer"; + FORY_CHECK(output_stream_ == nullptr) + << "Cannot assign to stream-writer-owned Buffer"; + detach_input_stream_from_this(); if (own_data_) { free(data_); data_ = nullptr; @@ -63,10 +72,12 @@ Buffer &Buffer::operator=(Buffer &&buffer) noexcept { writer_index_ = buffer.writer_index_; reader_index_ = buffer.reader_index_; wrapped_vector_ = buffer.wrapped_vector_; - stream_reader_ = buffer.stream_reader_; - stream_reader_owner_ = std::move(buffer.stream_reader_owner_); - rebind_stream_reader_to_this(); - buffer.stream_reader_ = nullptr; + input_stream_ = buffer.input_stream_; + input_stream_owner_ = std::move(buffer.input_stream_owner_); + output_stream_ = buffer.output_stream_; + rebind_input_stream_to_this(); + buffer.input_stream_ = nullptr; + buffer.output_stream_ = nullptr; buffer.data_ = nullptr; buffer.size_ = 0; buffer.own_data_ = false; @@ -75,7 +86,8 @@ Buffer &Buffer::operator=(Buffer &&buffer) noexcept { } Buffer::~Buffer() { - detach_stream_reader_from_this(); + clear_output_stream(); + detach_input_stream_from_this(); if (own_data_) { free(data_); data_ = nullptr; diff --git a/cpp/fory/util/buffer.h b/cpp/fory/util/buffer.h index 857902fc4e..a44de2d759 100644 --- a/cpp/fory/util/buffer.h +++ b/cpp/fory/util/buffer.h @@ -35,8 +35,8 @@ namespace fory { -class ForyInputStream; -class PythonStreamReader; +class StdInputStream; +class PyInputStream; // A buffer class for storing raw bytes with various methods for reading and // writing the bytes. @@ -46,7 +46,7 @@ class Buffer { Buffer(uint8_t *data, uint32_t size, bool own_data = true) : data_(data), size_(size), own_data_(own_data), wrapped_vector_(nullptr), - stream_reader_(nullptr) { + input_stream_(nullptr), output_stream_(nullptr) { writer_index_ = 0; reader_index_ = 0; } @@ -59,16 +59,17 @@ class Buffer { explicit Buffer(std::vector &vec) : data_(vec.data()), size_(static_cast(vec.size())), own_data_(false), writer_index_(static_cast(vec.size())), - reader_index_(0), wrapped_vector_(&vec), stream_reader_(nullptr) {} + reader_index_(0), wrapped_vector_(&vec), input_stream_(nullptr), + output_stream_(nullptr) {} - explicit Buffer(StreamReader &stream_reader) + explicit Buffer(InputStream &input_stream) : data_(nullptr), size_(0), own_data_(false), writer_index_(0), reader_index_(0), wrapped_vector_(nullptr), - stream_reader_(&stream_reader) { - stream_reader_->bind_buffer(this); - stream_reader_owner_ = stream_reader_->weak_from_this().lock(); - FORY_CHECK(&stream_reader_->get_buffer() == this) - << "StreamReader must hold and return the same Buffer instance"; + input_stream_(&input_stream), output_stream_(nullptr) { + input_stream_->bind_buffer(this); + input_stream_owner_ = input_stream_->weak_from_this().lock(); + FORY_CHECK(&input_stream_->get_buffer() == this) + << "InputStream must hold and return the same Buffer instance"; } Buffer(Buffer &&buffer) noexcept; @@ -81,6 +82,8 @@ class Buffer { if (this == &other) { return; } + FORY_CHECK(output_stream_ == nullptr && other.output_stream_ == nullptr) + << "Cannot swap stream-writer-owned Buffer"; using std::swap; swap(data_, other.data_); swap(size_, other.size_); @@ -88,10 +91,11 @@ class Buffer { swap(writer_index_, other.writer_index_); swap(reader_index_, other.reader_index_); swap(wrapped_vector_, other.wrapped_vector_); - swap(stream_reader_, other.stream_reader_); - swap(stream_reader_owner_, other.stream_reader_owner_); - rebind_stream_reader_to_this(); - other.rebind_stream_reader_to_this(); + swap(input_stream_, other.input_stream_); + swap(input_stream_owner_, other.input_stream_owner_); + swap(output_stream_, other.output_stream_); + rebind_input_stream_to_this(); + other.rebind_input_stream_to_this(); } /// \brief Return a pointer to the buffer's data @@ -103,12 +107,37 @@ class Buffer { FORY_ALWAYS_INLINE bool own_data() const { return own_data_; } FORY_ALWAYS_INLINE bool is_stream_backed() const { - return stream_reader_ != nullptr; + return input_stream_ != nullptr; + } + + FORY_ALWAYS_INLINE bool is_output_stream_backed() const { + return output_stream_ != nullptr; + } + + FORY_ALWAYS_INLINE void bind_output_stream(OutputStream *output_stream) { + if (output_stream_ == output_stream) { + return; + } + if (output_stream_ != nullptr) { + output_stream_->unbind_buffer(this); + } + output_stream_ = output_stream; + if (output_stream_ != nullptr) { + output_stream_->bind_buffer(this); + } + } + + FORY_ALWAYS_INLINE void clear_output_stream() { + if (output_stream_ == nullptr) { + return; + } + output_stream_->unbind_buffer(this); + output_stream_ = nullptr; } FORY_ALWAYS_INLINE void shrink_stream_buffer() { - if (stream_reader_ != nullptr) { - stream_reader_->shrink_buffer(); + if (input_stream_ != nullptr) { + input_stream_->shrink_buffer(); } } @@ -152,7 +181,7 @@ class Buffer { } FORY_ALWAYS_INLINE bool reader_index(uint32_t reader_index, Error &error) { - if (FORY_PREDICT_FALSE(reader_index > size_ && stream_reader_ != nullptr)) { + if (FORY_PREDICT_FALSE(reader_index > size_ && input_stream_ != nullptr)) { if (FORY_PREDICT_FALSE( !fill_buffer(reader_index - reader_index_, error))) { return false; @@ -806,6 +835,9 @@ class Buffer { grow(length); unsafe_put(writer_index_, data, length); increase_writer_index(length); + if (FORY_PREDICT_FALSE(output_stream_ != nullptr && writer_index_ > 4096)) { + output_stream_->try_flush(); + } } // =========================================================================== @@ -1223,24 +1255,25 @@ class Buffer { std::string hex() const; private: - friend class ForyInputStream; - friend class PythonStreamReader; + friend class StdInputStream; + friend class PyInputStream; + friend class OutputStream; - FORY_ALWAYS_INLINE void rebind_stream_reader_to_this() { - if (stream_reader_ == nullptr) { + FORY_ALWAYS_INLINE void rebind_input_stream_to_this() { + if (input_stream_ == nullptr) { return; } - stream_reader_->bind_buffer(this); - FORY_CHECK(&stream_reader_->get_buffer() == this) - << "StreamReader must hold and return the same Buffer instance"; + input_stream_->bind_buffer(this); + FORY_CHECK(&input_stream_->get_buffer() == this) + << "InputStream must hold and return the same Buffer instance"; } - FORY_ALWAYS_INLINE void detach_stream_reader_from_this() { - if (stream_reader_ == nullptr) { + FORY_ALWAYS_INLINE void detach_input_stream_from_this() { + if (input_stream_ == nullptr) { return; } - if (&stream_reader_->get_buffer() == this) { - stream_reader_->bind_buffer(nullptr); + if (&input_stream_->get_buffer() == this) { + input_stream_->bind_buffer(nullptr); } } @@ -1248,11 +1281,11 @@ class Buffer { if (FORY_PREDICT_TRUE(min_fill_size <= size_ - reader_index_)) { return true; } - if (FORY_PREDICT_TRUE(stream_reader_ == nullptr)) { + if (FORY_PREDICT_TRUE(input_stream_ == nullptr)) { error.set_buffer_out_of_bound(reader_index_, min_fill_size, size_); return false; } - auto fill_result = stream_reader_->fill_buffer(min_fill_size); + auto fill_result = input_stream_->fill_buffer(min_fill_size); if (FORY_PREDICT_FALSE(!fill_result.ok())) { error = std::move(fill_result).error(); return false; @@ -1362,8 +1395,9 @@ class Buffer { uint32_t writer_index_; uint32_t reader_index_; std::vector *wrapped_vector_ = nullptr; - StreamReader *stream_reader_ = nullptr; - std::shared_ptr stream_reader_owner_; + InputStream *input_stream_ = nullptr; + std::shared_ptr input_stream_owner_; + OutputStream *output_stream_ = nullptr; }; /// \brief Allocate a fixed-size mutable buffer from the default memory pool diff --git a/cpp/fory/util/buffer_test.cc b/cpp/fory/util/buffer_test.cc index d06a65ba93..7f32c06850 100644 --- a/cpp/fory/util/buffer_test.cc +++ b/cpp/fory/util/buffer_test.cc @@ -70,6 +70,33 @@ class OneByteIStream : public std::istream { OneByteStreamBuf buf_; }; +class CountingOutputStream final : public OutputStream { +public: + Result write_to_stream(const uint8_t *src, + uint32_t length) override { + write_calls_++; + if (length == 0) { + return Result(); + } + data_.insert(data_.end(), src, src + length); + return Result(); + } + + Result flush_stream() override { + flush_calls_++; + return Result(); + } + + const std::vector &data() const { return data_; } + uint32_t write_calls() const { return write_calls_; } + uint32_t flush_calls() const { return flush_calls_; } + +private: + std::vector data_; + uint32_t write_calls_ = 0; + uint32_t flush_calls_ = 0; +}; + TEST(Buffer, to_string) { std::shared_ptr buffer; allocate_buffer(16, &buffer); @@ -226,7 +253,7 @@ TEST(Buffer, StreamReadFromOneByteSource) { raw.resize(writer.writer_index()); OneByteIStream one_byte_stream(raw); - ForyInputStream stream(one_byte_stream, 8); + StdInputStream stream(one_byte_stream, 8); Buffer reader(stream); Error error; @@ -247,7 +274,7 @@ TEST(Buffer, StreamReadFromOneByteSource) { TEST(Buffer, StreamGetAndReaderIndexFromOneByteSource) { std::vector raw{0x11, 0x22, 0x33, 0x44, 0x55}; OneByteIStream one_byte_stream(raw); - ForyInputStream stream(one_byte_stream, 2); + StdInputStream stream(one_byte_stream, 2); Buffer reader(stream); Error error; ASSERT_TRUE(reader.ensure_readable(4, error)) << error.to_string(); @@ -261,7 +288,7 @@ TEST(Buffer, StreamGetAndReaderIndexFromOneByteSource) { TEST(Buffer, StreamReadBytesAndSkipAdvanceReaderIndex) { std::vector raw{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; OneByteIStream one_byte_stream(raw); - ForyInputStream stream(one_byte_stream, 2); + StdInputStream stream(one_byte_stream, 2); Buffer reader(stream); Error error; uint8_t out[5] = {0}; @@ -282,7 +309,7 @@ TEST(Buffer, StreamReadBytesAndSkipAdvanceReaderIndex) { TEST(Buffer, StreamSkipAndUnread) { std::vector raw{0x01, 0x02, 0x03, 0x04, 0x05}; OneByteIStream one_byte_stream(raw); - ForyInputStream stream(one_byte_stream, 2); + StdInputStream stream(one_byte_stream, 2); auto fill_result = stream.fill_buffer(4); ASSERT_TRUE(fill_result.ok()) << fill_result.error().to_string(); @@ -306,13 +333,90 @@ TEST(Buffer, StreamSkipAndUnread) { TEST(Buffer, StreamReadErrorWhenInsufficientData) { std::vector raw{0x01, 0x02, 0x03}; OneByteIStream one_byte_stream(raw); - ForyInputStream stream(one_byte_stream, 2); + StdInputStream stream(one_byte_stream, 2); Buffer reader(stream); Error error; EXPECT_EQ(reader.read_uint32(error), 0U); EXPECT_FALSE(error.ok()); EXPECT_EQ(error.code(), ErrorCode::BufferOutOfBound); } + +TEST(Buffer, OutputStreamThresholdFlushOnWriteBytes) { + CountingOutputStream writer; + Buffer *buffer = writer.get_buffer(); + ASSERT_NE(buffer, nullptr); + + std::vector payload(5000, 7); + buffer->write_bytes(payload.data(), static_cast(payload.size())); + + EXPECT_EQ(buffer->writer_index(), 0U); + EXPECT_EQ(writer.data().size(), payload.size()); + EXPECT_GE(writer.write_calls(), 1U); +} + +TEST(Buffer, OutputStreamThresholdFlushCanBeTemporarilyDisabled) { + CountingOutputStream writer; + Buffer *buffer = writer.get_buffer(); + ASSERT_NE(buffer, nullptr); + writer.enter_flush_barrier(); + + std::vector payload(5000, 7); + buffer->write_bytes(payload.data(), static_cast(payload.size())); + + EXPECT_EQ(buffer->writer_index(), payload.size()); + EXPECT_EQ(writer.data().size(), 0U); + + writer.exit_flush_barrier(); + writer.try_flush(); + ASSERT_FALSE(writer.has_error()) << writer.error().to_string(); + EXPECT_EQ(buffer->writer_index(), 0U); + EXPECT_EQ(writer.data().size(), payload.size()); +} + +TEST(Buffer, OutputStreamForceFlush) { + CountingOutputStream writer; + Buffer *buffer = writer.get_buffer(); + ASSERT_NE(buffer, nullptr); + + std::vector payload{1, 2, 3, 4, 5}; + buffer->write_bytes(payload.data(), static_cast(payload.size())); + EXPECT_EQ(buffer->writer_index(), payload.size()); + + writer.force_flush(); + ASSERT_FALSE(writer.has_error()) << writer.error().to_string(); + EXPECT_EQ(buffer->writer_index(), 0U); + EXPECT_EQ(writer.data(), payload); + EXPECT_EQ(writer.flush_calls(), 1U); +} + +TEST(Buffer, OutputStreamRebindDetachesPreviousBufferBacklink) { + CountingOutputStream writer; + std::shared_ptr first; + std::shared_ptr second; + allocate_buffer(16, &first); + allocate_buffer(16, &second); + + first->bind_output_stream(&writer); + second->bind_output_stream(&writer); + EXPECT_FALSE(first->is_output_stream_backed()); + EXPECT_TRUE(second->is_output_stream_backed()); + + writer.enter_flush_barrier(); + std::vector second_payload(5000, 7); + second->write_bytes(second_payload.data(), + static_cast(second_payload.size())); + EXPECT_EQ(second->writer_index(), second_payload.size()); + writer.exit_flush_barrier(); + + std::vector first_payload(5000, 3); + first->write_bytes(first_payload.data(), + static_cast(first_payload.size())); + + // A stale backlink on `first` would call try_flush() and flush `second` + // because `second` is still the stream's active buffer. + EXPECT_EQ(second->writer_index(), second_payload.size()); + EXPECT_EQ(writer.data().size(), 0U); +} } // namespace fory int main(int argc, char **argv) { diff --git a/cpp/fory/util/stream.cc b/cpp/fory/util/stream.cc index a8466ab833..0d06ec7dee 100644 --- a/cpp/fory/util/stream.cc +++ b/cpp/fory/util/stream.cc @@ -28,7 +28,78 @@ namespace fory { -ForyInputStream::ForyInputStream(std::istream &stream, uint32_t buffer_size) +OutputStream::OutputStream(uint32_t buffer_size) + : buffer_(std::make_unique()) { + const uint32_t actual_size = std::max(buffer_size, 1U); + buffer_->reserve(actual_size); + buffer_->writer_index(0); + buffer_->reader_index(0); + buffer_->bind_output_stream(this); + active_buffer_ = buffer_.get(); +} + +OutputStream::~OutputStream() { + if (active_buffer_ != nullptr && active_buffer_ != buffer_.get()) { + active_buffer_->clear_output_stream(); + } + if (buffer_ != nullptr) { + buffer_->clear_output_stream(); + } + active_buffer_ = nullptr; +} + +void OutputStream::reset() { + flushed_bytes_ = 0; + flush_barrier_depth_ = 0; + error_.reset(); + Buffer *buffer = active_buffer(); + if (buffer != nullptr) { + buffer->writer_index(0); + buffer->reader_index(0); + } +} + +void OutputStream::bind_buffer(Buffer *buffer) { + Buffer *next = buffer == nullptr ? buffer_.get() : buffer; + if (active_buffer_ == next) { + return; + } + if (active_buffer_ != nullptr && active_buffer_ != buffer_.get()) { + // Rebinding must detach the previous external buffer to avoid stale + // backlinks that can trigger misdirected flushes and dangling pointers. + active_buffer_->output_stream_ = nullptr; + } + active_buffer_ = next; +} + +void OutputStream::unbind_buffer(Buffer *buffer) { + if (active_buffer_ == buffer) { + active_buffer_ = buffer_.get(); + } +} + +uint32_t OutputStream::active_buffer_writer_index() { + Buffer *buffer = active_buffer(); + return buffer == nullptr ? 0U : buffer->writer_index(); +} + +void OutputStream::flush_buffer_data() { + Buffer *buffer = active_buffer(); + if (buffer == nullptr || buffer->writer_index() == 0) { + return; + } + const uint32_t bytes_to_flush = buffer->writer_index(); + auto write_result = write_to_stream(buffer->data(), bytes_to_flush); + if (FORY_PREDICT_FALSE(!write_result.ok())) { + set_error(std::move(write_result).error()); + return; + } + flushed_bytes_ += bytes_to_flush; + buffer->writer_index(0); + buffer->reader_index(0); +} + +StdInputStream::StdInputStream(std::istream &stream, uint32_t buffer_size) : stream_(&stream), data_(std::max(buffer_size, static_cast(1))), initial_buffer_size_( @@ -37,8 +108,8 @@ ForyInputStream::ForyInputStream(std::istream &stream, uint32_t buffer_size) bind_buffer(owned_buffer_.get()); } -ForyInputStream::ForyInputStream(std::shared_ptr stream, - uint32_t buffer_size) +StdInputStream::StdInputStream(std::shared_ptr stream, + uint32_t buffer_size) : stream_owner_(std::move(stream)), stream_(stream_owner_.get()), data_(std::max(buffer_size, static_cast(1))), initial_buffer_size_( @@ -48,9 +119,9 @@ ForyInputStream::ForyInputStream(std::shared_ptr stream, bind_buffer(owned_buffer_.get()); } -ForyInputStream::~ForyInputStream() = default; +StdInputStream::~StdInputStream() = default; -Result ForyInputStream::fill_buffer(uint32_t min_fill_size) { +Result StdInputStream::fill_buffer(uint32_t min_fill_size) { if (min_fill_size == 0 || remaining_size() >= min_fill_size) { return Result(); } @@ -92,7 +163,7 @@ Result ForyInputStream::fill_buffer(uint32_t min_fill_size) { return Result(); } -Result ForyInputStream::read_to(uint8_t *dst, uint32_t length) { +Result StdInputStream::read_to(uint8_t *dst, uint32_t length) { if (length == 0) { return Result(); } @@ -106,7 +177,7 @@ Result ForyInputStream::read_to(uint8_t *dst, uint32_t length) { return Result(); } -Result ForyInputStream::skip(uint32_t size) { +Result StdInputStream::skip(uint32_t size) { if (size == 0) { return Result(); } @@ -118,7 +189,7 @@ Result ForyInputStream::skip(uint32_t size) { return Result(); } -Result ForyInputStream::unread(uint32_t size) { +Result StdInputStream::unread(uint32_t size) { if (FORY_PREDICT_FALSE(size > buffer_->reader_index_)) { return Unexpected(Error::buffer_out_of_bound(buffer_->reader_index_, size, buffer_->size_)); @@ -127,7 +198,7 @@ Result ForyInputStream::unread(uint32_t size) { return Result(); } -void ForyInputStream::shrink_buffer() { +void StdInputStream::shrink_buffer() { if (buffer_ == nullptr) { return; } @@ -167,22 +238,22 @@ void ForyInputStream::shrink_buffer() { } } -Buffer &ForyInputStream::get_buffer() { return *buffer_; } +Buffer &StdInputStream::get_buffer() { return *buffer_; } -uint32_t ForyInputStream::remaining_size() const { +uint32_t StdInputStream::remaining_size() const { return buffer_->size_ - buffer_->reader_index_; } -void ForyInputStream::reserve(uint32_t new_size) { +void StdInputStream::reserve(uint32_t new_size) { data_.resize(new_size); buffer_->data_ = data_.data(); } -void ForyInputStream::bind_buffer(Buffer *buffer) { +void StdInputStream::bind_buffer(Buffer *buffer) { Buffer *target = buffer == nullptr ? owned_buffer_.get() : buffer; if (target == nullptr) { if (buffer_ != nullptr) { - buffer_->stream_reader_ = nullptr; + buffer_->input_stream_ = nullptr; } buffer_ = nullptr; return; @@ -192,7 +263,7 @@ void ForyInputStream::bind_buffer(Buffer *buffer) { buffer_->data_ = data_.data(); buffer_->own_data_ = false; buffer_->wrapped_vector_ = nullptr; - buffer_->stream_reader_ = this; + buffer_->input_stream_ = this; return; } @@ -201,7 +272,7 @@ void ForyInputStream::bind_buffer(Buffer *buffer) { target->size_ = source->size_; target->writer_index_ = source->writer_index_; target->reader_index_ = source->reader_index_; - source->stream_reader_ = nullptr; + source->input_stream_ = nullptr; } else { target->size_ = 0; target->writer_index_ = 0; @@ -212,7 +283,46 @@ void ForyInputStream::bind_buffer(Buffer *buffer) { buffer_->data_ = data_.data(); buffer_->own_data_ = false; buffer_->wrapped_vector_ = nullptr; - buffer_->stream_reader_ = this; + buffer_->input_stream_ = this; +} + +StdOutputStream::StdOutputStream(std::ostream &stream) : stream_(&stream) {} + +StdOutputStream::StdOutputStream(std::shared_ptr stream) + : stream_owner_(std::move(stream)), stream_(stream_owner_.get()) { + FORY_CHECK(stream_owner_ != nullptr) << "stream must not be null"; +} + +StdOutputStream::~StdOutputStream() = default; + +Result StdOutputStream::write_to_stream(const uint8_t *src, + uint32_t length) { + if (length == 0) { + return Result(); + } + if (src == nullptr) { + return Unexpected(Error::invalid("output source pointer is null")); + } + if (stream_ == nullptr) { + return Unexpected(Error::io_error("output stream is null")); + } + stream_->write(reinterpret_cast(src), + static_cast(length)); + if (!(*stream_)) { + return Unexpected(Error::io_error("failed to write to output stream")); + } + return Result(); +} + +Result StdOutputStream::flush_stream() { + if (stream_ == nullptr) { + return Unexpected(Error::io_error("output stream is null")); + } + stream_->flush(); + if (!(*stream_)) { + return Unexpected(Error::io_error("failed to flush output stream")); + } + return Result(); } } // namespace fory diff --git a/cpp/fory/util/stream.h b/cpp/fory/util/stream.h index be37e66b87..c97ceafe04 100644 --- a/cpp/fory/util/stream.h +++ b/cpp/fory/util/stream.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include "fory/util/error.h" @@ -31,9 +32,98 @@ namespace fory { class Buffer; -class StreamReader : public std::enable_shared_from_this { +class OutputStream { public: - virtual ~StreamReader() = default; + explicit OutputStream(uint32_t buffer_size = 4096); + + virtual ~OutputStream(); + + FORY_ALWAYS_INLINE Buffer *get_buffer() { return buffer_.get(); } + + FORY_ALWAYS_INLINE const Buffer *get_buffer() const { return buffer_.get(); } + + FORY_ALWAYS_INLINE void enter_flush_barrier() { flush_barrier_depth_++; } + + FORY_ALWAYS_INLINE void exit_flush_barrier() { flush_barrier_depth_--; } + + FORY_ALWAYS_INLINE bool try_flush() { + if (FORY_PREDICT_FALSE(flush_barrier_depth_ != 0)) { + return false; + } + const uint32_t bytes_before_flush = active_buffer_writer_index(); + if (FORY_PREDICT_FALSE(bytes_before_flush <= 4096)) { + return false; + } + flush_buffer_data(); + if (FORY_PREDICT_FALSE(!error_.ok())) { + return false; + } + return bytes_before_flush != 0; + } + + FORY_ALWAYS_INLINE void force_flush() { + if (FORY_PREDICT_FALSE(!error_.ok())) { + return; + } + flush_buffer_data(); + if (FORY_PREDICT_FALSE(!error_.ok())) { + return; + } + auto flush_result = flush_stream(); + if (FORY_PREDICT_FALSE(!flush_result.ok())) { + set_error(std::move(flush_result).error()); + } + } + + FORY_ALWAYS_INLINE uint32_t flush_barrier_depth() const { + return flush_barrier_depth_; + } + + FORY_ALWAYS_INLINE size_t flushed_bytes() const { return flushed_bytes_; } + + void reset(); + + FORY_ALWAYS_INLINE bool has_error() const { return !error_.ok(); } + + FORY_ALWAYS_INLINE const Error &error() const { return error_; } + +protected: + virtual Result write_to_stream(const uint8_t *src, + uint32_t length) = 0; + + virtual Result flush_stream() = 0; + +private: + void bind_buffer(Buffer *buffer); + + void unbind_buffer(Buffer *buffer); + + FORY_ALWAYS_INLINE Buffer *active_buffer() { + return active_buffer_ == nullptr ? buffer_.get() : active_buffer_; + } + + void flush_buffer_data(); + + uint32_t active_buffer_writer_index(); + + FORY_ALWAYS_INLINE void set_error(Error error) { + if (error_.ok()) { + error_ = std::move(error); + } + } + + std::unique_ptr buffer_; + Buffer *active_buffer_ = nullptr; + size_t flushed_bytes_ = 0; + uint32_t flush_barrier_depth_ = 0; + Error error_; + + friend class Buffer; +}; + +class InputStream : public std::enable_shared_from_this { +public: + virtual ~InputStream() = default; virtual Result fill_buffer(uint32_t min_fill_size) = 0; @@ -52,14 +142,14 @@ class StreamReader : public std::enable_shared_from_this { virtual void bind_buffer(Buffer *buffer) = 0; }; -class ForyInputStream final : public StreamReader { +class StdInputStream final : public InputStream { public: - explicit ForyInputStream(std::istream &stream, uint32_t buffer_size = 4096); + explicit StdInputStream(std::istream &stream, uint32_t buffer_size = 4096); - explicit ForyInputStream(std::shared_ptr stream, - uint32_t buffer_size = 4096); + explicit StdInputStream(std::shared_ptr stream, + uint32_t buffer_size = 4096); - ~ForyInputStream() override; + ~StdInputStream() override; Result fill_buffer(uint32_t min_fill_size) override; @@ -88,4 +178,23 @@ class ForyInputStream final : public StreamReader { std::unique_ptr owned_buffer_; }; +class StdOutputStream final : public OutputStream { +public: + explicit StdOutputStream(std::ostream &stream); + + explicit StdOutputStream(std::shared_ptr stream); + + ~StdOutputStream() override; + +protected: + Result write_to_stream(const uint8_t *src, + uint32_t length) override; + + Result flush_stream() override; + +private: + std::shared_ptr stream_owner_; + std::ostream *stream_ = nullptr; +}; + } // namespace fory diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py index 6a75abb031..2ccb8115b2 100644 --- a/python/pyfory/_fory.py +++ b/python/pyfory/_fory.py @@ -157,6 +157,7 @@ class Fory: "is_peer_out_of_band_enabled", "max_depth", "depth", + "_output_stream", "field_nullable", "policy", ) @@ -242,6 +243,7 @@ def __init__( self.is_peer_out_of_band_enabled = False self.max_depth = max_depth self.depth = 0 + self._output_stream = None def register( self, @@ -394,6 +396,37 @@ def dumps( """ return self.serialize(obj, buffer, buffer_callback, unsupported_callback) + def dump(self, obj, stream): + """ + Serialize an object directly to a writable stream. + + Args: + obj: The object to serialize + stream: Writable stream implementing write(...) + + Notes: + The stream must be a non-retaining sink: ``write(data)`` must + synchronously consume ``data`` before returning. Fory may reuse or + modify the underlying buffer after ``write`` returns, so retaining + the passed object (or a view of it) is unsupported. If your sink + needs retention, copy bytes inside ``write``. + """ + try: + self.buffer.set_writer_index(0) + self._output_stream = Buffer.wrap_output_stream(stream) + self.buffer.bind_output_stream(self._output_stream) + self._serialize( + obj, + self.buffer, + buffer_callback=None, + unsupported_callback=None, + ) + self.force_flush() + finally: + self.buffer.bind_output_stream(None) + self._output_stream = None + self.reset_write() + def loads( self, buffer: Union[Buffer, bytes], @@ -435,12 +468,17 @@ def serialize( """ try: - return self._serialize( + write_buffer = self._serialize( obj, buffer, buffer_callback=buffer_callback, unsupported_callback=unsupported_callback, ) + if write_buffer is not self.buffer: + return write_buffer + if write_buffer.get_output_stream() is not None: + return write_buffer + return write_buffer.to_bytes(0, write_buffer.get_writer_index()) finally: self.reset_write() @@ -450,7 +488,7 @@ def _serialize( buffer: Buffer = None, buffer_callback=None, unsupported_callback=None, - ) -> Union[Buffer, bytes]: + ) -> Buffer: assert self.depth == 0, "Nested serialization should use write_ref/write_no_ref." self.depth += 1 self.buffer_callback = buffer_callback @@ -462,6 +500,7 @@ def _serialize( # 1byte used for bit mask buffer.grow(1) buffer.set_writer_index(mask_index + 1) + buffer.put_int8(mask_index, 0) if obj is None: set_bit(buffer, mask_index, 0) else: @@ -476,10 +515,29 @@ def _serialize( # Type definitions are now written inline (streaming) instead of deferred to end self.write_ref(buffer, obj) - if buffer is not self.buffer: - return buffer - else: - return buffer.to_bytes(0, buffer.get_writer_index()) + return buffer + + def enter_flush_barrier(self): + output_stream = self._output_stream + if output_stream is not None: + output_stream.enter_flush_barrier() + + def exit_flush_barrier(self): + output_stream = self._output_stream + if output_stream is not None: + output_stream.exit_flush_barrier() + + def try_flush(self): + if self._output_stream is None or self.buffer.get_writer_index() <= 4096: + return + output_stream = self._output_stream + output_stream.try_flush() + + def force_flush(self): + output_stream = self._output_stream + if output_stream is None: + return + output_stream.force_flush() def write_ref(self, buffer, obj, typeinfo=None, serializer=None): if serializer is None and typeinfo is not None: @@ -687,6 +745,7 @@ def reset_write(self): self.metastring_resolver.reset_write() self.buffer_callback = None self._unsupported_callback = None + self._output_stream = None def reset_read(self): """ @@ -907,3 +966,20 @@ def loads( unsupported_objects: Iterable = None, ): return self.deserialize(buffer, buffers, unsupported_objects) + + def dump(self, obj, stream): + """ + Serialize an object directly to a writable stream. + + Notes: + The stream must be a non-retaining sink: ``write(data)`` must + synchronously consume ``data`` before returning. Fory may reuse or + modify the underlying buffer after ``write`` returns, so retaining + the passed object (or a view of it) is unsupported. If your sink + needs retention, copy bytes inside ``write``. + """ + fory = self._get_fory() + try: + return fory.dump(obj, stream) + finally: + self._return_fory(fory) diff --git a/python/pyfory/buffer.pxi b/python/pyfory/buffer.pxi index 74795d8239..d9c66e2b5b 100644 --- a/python/pyfory/buffer.pxi +++ b/python/pyfory/buffer.pxi @@ -14,17 +14,22 @@ from cpython.unicode cimport ( PyUnicode_DecodeUTF8, ) from cpython.bytes cimport PyBytes_AsString, PyBytes_FromStringAndSize, PyBytes_AS_STRING -from libcpp.memory cimport shared_ptr +from libcpp.memory cimport shared_ptr, unique_ptr from libcpp.utility cimport move from cython.operator cimport dereference as deref from libcpp.string cimport string as c_string from libc.stdint cimport * from libcpp cimport bool as c_bool from pyfory.includes.libutil cimport( - CBuffer, allocate_buffer, get_bit as c_get_bit, set_bit as c_set_bit, clear_bit as c_clear_bit, + CBuffer, COutputStream, allocate_buffer, get_bit as c_get_bit, set_bit as c_set_bit, clear_bit as c_clear_bit, set_bit_to as c_set_bit_to, CError, CErrorCode, CResultVoidError, utf16_has_surrogate_pairs ) -from pyfory.includes.libpyfory cimport Fory_PyCreateBufferFromStream +from pyfory.includes.libpyfory cimport ( + Fory_PyCreateBufferFromStream, + Fory_PyCreateOutputStream, + Fory_PyBindBufferToOutputStream, + Fory_PyClearBufferOutputStream +) import os from pyfory.error import raise_fory_error @@ -39,6 +44,76 @@ cdef class _SharedBufferOwner: cdef shared_ptr[CBuffer] buffer +cdef class Buffer + + +@cython.final +cdef class PyOutputStream: + cdef object stream + cdef unique_ptr[COutputStream] c_output_stream + + @staticmethod + cdef inline PyOutputStream from_stream(object stream): + cdef c_string stream_error + cdef COutputStream* raw_writer = NULL + if stream is None: + raise ValueError("stream must not be None") + if Fory_PyCreateOutputStream( + stream, &raw_writer, &stream_error + ) != 0: + raise ValueError(stream_error.decode("UTF-8")) + cdef PyOutputStream writer = PyOutputStream.__new__(PyOutputStream) + writer.stream = stream + writer.c_output_stream.reset(raw_writer) + if raw_writer != NULL: + raw_writer.reset() + return writer + + cdef inline COutputStream* get_c_output_stream(self): + return self.c_output_stream.get() + + cpdef inline object get_output_stream(self): + return self.stream + + cpdef inline void reset(self): + cdef COutputStream* output_stream = self.c_output_stream.get() + if output_stream == NULL: + raise ValueError("OutputStream is null") + output_stream.reset() + + cpdef inline void enter_flush_barrier(self): + cdef COutputStream* output_stream = self.c_output_stream.get() + if output_stream == NULL: + raise ValueError("OutputStream is null") + output_stream.enter_flush_barrier() + + cpdef inline void exit_flush_barrier(self): + cdef COutputStream* output_stream = self.c_output_stream.get() + if output_stream == NULL: + raise ValueError("OutputStream is null") + output_stream.exit_flush_barrier() + + cpdef inline void try_flush(self): + cdef COutputStream* output_stream = self.c_output_stream.get() + if output_stream == NULL: + raise ValueError("OutputStream is null") + output_stream.try_flush() + if output_stream.has_error(): + raise ValueError(output_stream.error().to_string().decode("UTF-8")) + + cpdef inline void force_flush(self): + cdef COutputStream* output_stream = self.c_output_stream.get() + if output_stream == NULL: + raise ValueError("OutputStream is null") + output_stream.force_flush() + if output_stream.has_error(): + raise ValueError(output_stream.error().to_string().decode("UTF-8")) + + +cpdef inline PyOutputStream _wrap_output_stream(object stream): + return PyOutputStream.from_stream(stream) + + @cython.final cdef class Buffer: cdef: @@ -46,6 +121,7 @@ cdef class Buffer: CError _error # hold python buffer reference count object data + object output_stream Py_ssize_t shape[1] Py_ssize_t stride[1] @@ -67,6 +143,7 @@ cdef class Buffer: self.c_buffer = CBuffer(address, length_, False) self.c_buffer.reader_index(0) self.c_buffer.writer_index(0) + self.output_stream = None @classmethod def from_stream(cls, stream not None, uint32_t buffer_size=4096): @@ -82,6 +159,7 @@ cdef class Buffer: buffer.c_buffer = move(deref(stream_buffer)) del stream_buffer buffer.data = stream + buffer.output_stream = None buffer.c_buffer.reader_index(0) buffer.c_buffer.writer_index(0) return buffer @@ -94,6 +172,7 @@ cdef class Buffer: cdef _SharedBufferOwner owner = _SharedBufferOwner.__new__(_SharedBufferOwner) owner.buffer = c_buffer buffer.data = owner + buffer.output_stream = None buffer.c_buffer.reader_index(0) buffer.c_buffer.writer_index(0) return buffer @@ -107,10 +186,37 @@ cdef class Buffer: buffer.c_buffer = move(deref(buf)) del buf buffer.data = None + buffer.output_stream = None buffer.c_buffer.reader_index(0) buffer.c_buffer.writer_index(0) return buffer + @staticmethod + def wrap_output_stream(stream): + return _wrap_output_stream(stream) + + cpdef inline void bind_output_stream(self, object output): + cdef c_string stream_error + cdef PyOutputStream output_stream + if Fory_PyClearBufferOutputStream(&self.c_buffer, &stream_error) != 0: + raise ValueError(stream_error.decode("UTF-8")) + if output is None: + self.output_stream = None + return + if isinstance(output, PyOutputStream): + output_stream = output + else: + output_stream = _wrap_output_stream(output) + output_stream.reset() + if Fory_PyBindBufferToOutputStream( + &self.c_buffer, output_stream.get_c_output_stream(), &stream_error + ) != 0: + raise ValueError(stream_error.decode("UTF-8")) + self.output_stream = output_stream + + cpdef inline object get_output_stream(self): + return self.output_stream + cdef inline void _raise_if_error(self): cdef CErrorCode code cdef c_string message diff --git a/python/pyfory/collection.pxi b/python/pyfory/collection.pxi index 48bd3efca6..150f354e68 100644 --- a/python/pyfory/collection.pxi +++ b/python/pyfory/collection.pxi @@ -800,6 +800,7 @@ cdef class MapSerializer(Serializer): break key_cls = type(key) value_cls = type(value) + fory.enter_flush_barrier() buffer.write_int16(-1) chunk_size_offset = buffer.get_writer_index() - 1 chunk_header = 0 @@ -888,6 +889,8 @@ cdef class MapSerializer(Serializer): key_serializer = self.key_serializer value_serializer = self.value_serializer buffer.put_int8(chunk_size_offset, chunk_size) + fory.exit_flush_barrier() + fory.try_flush() cpdef inline read(self, Buffer buffer): cdef Fory fory = self.fory diff --git a/python/pyfory/collection.py b/python/pyfory/collection.py index c7d6b9c376..5bb88cbaa2 100644 --- a/python/pyfory/collection.py +++ b/python/pyfory/collection.py @@ -420,6 +420,7 @@ def write(self, buffer, o): key_cls = type(key) value_cls = type(value) + fory.enter_flush_barrier() buffer.write_int16(-1) chunk_size_offset = buffer.get_writer_index() - 1 chunk_header = 0 @@ -472,6 +473,8 @@ def write(self, buffer, o): key_serializer = self.key_serializer value_serializer = self.value_serializer buffer.put_uint8(chunk_size_offset, chunk_size) + fory.exit_flush_barrier() + fory.try_flush() def read(self, buffer): fory = self.fory diff --git a/python/pyfory/cpp/pyfory.cc b/python/pyfory/cpp/pyfory.cc index 62251b7f66..13c28103b4 100644 --- a/python/pyfory/cpp/pyfory.cc +++ b/python/pyfory/cpp/pyfory.cc @@ -135,10 +135,145 @@ static bool resolve_python_stream_read_method(PyObject *stream, return false; } -class PythonStreamReader final : public StreamReader { +static bool resolve_python_stream_write_method(PyObject *stream, + std::string *error_message) { + const int has_write = PyObject_HasAttrString(stream, "write"); + if (has_write < 0) { + *error_message = fetch_python_error_message(); + return false; + } + if (has_write == 0) { + *error_message = "stream object must provide write(data) method"; + return false; + } + PyObject *method_obj = PyObject_GetAttrString(stream, "write"); + if (method_obj == nullptr) { + *error_message = fetch_python_error_message(); + return false; + } + const bool is_callable = PyCallable_Check(method_obj) != 0; + Py_DECREF(method_obj); + if (!is_callable) { + *error_message = "stream.write must be callable"; + return false; + } + return true; +} + +class PyOutputStream final : public OutputStream { public: - explicit PythonStreamReader(PyObject *stream, uint32_t buffer_size, - PythonStreamReadMethod read_method) + explicit PyOutputStream(PyObject *stream, uint32_t buffer_size = 4096) + : OutputStream(buffer_size), stream_(stream) { + FORY_CHECK(stream_ != nullptr) << "stream must not be null"; + Py_INCREF(stream_); + } + + ~PyOutputStream() override { + if (stream_ != nullptr) { + PyGILState_STATE gil_state = PyGILState_Ensure(); + Py_DECREF(stream_); + PyGILState_Release(gil_state); + stream_ = nullptr; + } + } + +protected: + Result write_to_stream(const uint8_t *src, + uint32_t length) override { + if (length == 0) { + return Result(); + } + if (src == nullptr) { + return Unexpected(Error::invalid("output source pointer is null")); + } + if (stream_ == nullptr) { + return Unexpected(Error::io_error("output stream is null")); + } + PyGILState_STATE gil_state = PyGILState_Ensure(); + uint32_t total_written = 0; + while (total_written < length) { + const uint32_t remaining = length - total_written; + // Contract: stream.write must consume bytes synchronously before return. + // The memoryview below is a transient view over serializer-managed + // storage and is not safe to retain after write(...) returns. + PyObject *chunk = PyMemoryView_FromMemory( + reinterpret_cast( + const_cast(src + static_cast(total_written))), + static_cast(remaining), PyBUF_READ); + if (chunk == nullptr) { + const std::string message = fetch_python_error_message(); + PyGILState_Release(gil_state); + return Unexpected(Error::io_error(message)); + } + PyObject *written_obj = PyObject_CallMethod(stream_, "write", "O", chunk); + Py_DECREF(chunk); + if (written_obj == nullptr) { + const std::string message = fetch_python_error_message(); + PyGILState_Release(gil_state); + return Unexpected(Error::io_error(message)); + } + if (written_obj == Py_None) { + Py_DECREF(written_obj); + total_written = length; + break; + } + const long long wrote_value = PyLong_AsLongLong(written_obj); + Py_DECREF(written_obj); + if (wrote_value == -1 && PyErr_Occurred() != nullptr) { + const std::string message = fetch_python_error_message(); + PyGILState_Release(gil_state); + return Unexpected(Error::io_error(message)); + } + if (wrote_value <= 0) { + PyGILState_Release(gil_state); + return Unexpected( + Error::io_error("stream write returned non-positive bytes")); + } + const uint64_t wrote_u64 = static_cast(wrote_value); + if (wrote_u64 >= remaining) { + total_written = length; + } else { + total_written += static_cast(wrote_u64); + } + } + PyGILState_Release(gil_state); + return Result(); + } + + Result flush_stream() override { + if (stream_ == nullptr) { + return Unexpected(Error::io_error("output stream is null")); + } + PyGILState_STATE gil_state = PyGILState_Ensure(); + const int has_flush = PyObject_HasAttrString(stream_, "flush"); + if (has_flush < 0) { + const std::string message = fetch_python_error_message(); + PyGILState_Release(gil_state); + return Unexpected(Error::io_error(message)); + } + if (has_flush == 0) { + PyGILState_Release(gil_state); + return Result(); + } + PyObject *result = PyObject_CallMethod(stream_, "flush", nullptr); + if (result == nullptr) { + const std::string message = fetch_python_error_message(); + PyGILState_Release(gil_state); + return Unexpected(Error::io_error(message)); + } + Py_DECREF(result); + PyGILState_Release(gil_state); + return Result(); + } + +private: + PyObject *stream_ = nullptr; +}; + +class PyInputStream final : public InputStream { +public: + explicit PyInputStream(PyObject *stream, uint32_t buffer_size, + PythonStreamReadMethod read_method) : stream_(stream), read_method_(read_method), read_method_name_(python_stream_read_method_name(read_method)), data_(std::max(buffer_size, static_cast(1))), @@ -150,7 +285,7 @@ class PythonStreamReader final : public StreamReader { bind_buffer(owned_buffer_.get()); } - ~PythonStreamReader() override { + ~PyInputStream() override { if (stream_ != nullptr) { PyGILState_STATE gil_state = PyGILState_Ensure(); Py_DECREF(stream_); @@ -279,7 +414,7 @@ class PythonStreamReader final : public StreamReader { Buffer *target = buffer == nullptr ? owned_buffer_.get() : buffer; if (target == nullptr) { if (buffer_ != nullptr) { - buffer_->stream_reader_ = nullptr; + buffer_->input_stream_ = nullptr; } buffer_ = nullptr; return; @@ -289,7 +424,7 @@ class PythonStreamReader final : public StreamReader { buffer_->data_ = data_.data(); buffer_->own_data_ = false; buffer_->wrapped_vector_ = nullptr; - buffer_->stream_reader_ = this; + buffer_->input_stream_ = this; return; } @@ -298,7 +433,7 @@ class PythonStreamReader final : public StreamReader { target->size_ = source->size_; target->writer_index_ = source->writer_index_; target->reader_index_ = source->reader_index_; - source->stream_reader_ = nullptr; + source->input_stream_ = nullptr; } else { target->size_ = 0; target->writer_index_ = 0; @@ -309,7 +444,7 @@ class PythonStreamReader final : public StreamReader { buffer_->data_ = data_.data(); buffer_->own_data_ = false; buffer_->wrapped_vector_ = nullptr; - buffer_->stream_reader_ = this; + buffer_->input_stream_ = this; } private: @@ -1409,13 +1544,56 @@ int Fory_PyCreateBufferFromStream(PyObject *stream, uint32_t buffer_size, return -1; } try { - auto stream_reader = - std::make_shared(stream, buffer_size, read_method); - *out = new Buffer(*stream_reader); + auto input_stream = + std::make_shared(stream, buffer_size, read_method); + *out = new Buffer(*input_stream); return 0; } catch (const std::exception &e) { *error_message = e.what(); return -1; } } + +int Fory_PyCreateOutputStream(PyObject *stream, OutputStream **out, + std::string *error_message) { + if (stream == nullptr) { + *error_message = "stream must not be null"; + return -1; + } + // See PyOutputStream::write_to_stream contract: the provided sink must not + // retain passed write buffers after write(...) returns. + if (!resolve_python_stream_write_method(stream, error_message)) { + return -1; + } + try { + *out = new PyOutputStream(stream, 4096); + return 0; + } catch (const std::exception &e) { + *error_message = e.what(); + return -1; + } +} + +int Fory_PyBindBufferToOutputStream(Buffer *buffer, OutputStream *output_stream, + std::string *error_message) { + if (buffer == nullptr) { + *error_message = "buffer must not be null"; + return -1; + } + if (output_stream == nullptr) { + *error_message = "output stream must not be null"; + return -1; + } + buffer->bind_output_stream(output_stream); + return 0; +} + +int Fory_PyClearBufferOutputStream(Buffer *buffer, std::string *error_message) { + if (buffer == nullptr) { + *error_message = "buffer must not be null"; + return -1; + } + buffer->clear_output_stream(); + return 0; +} } // namespace fory diff --git a/python/pyfory/cpp/pyfory.h b/python/pyfory/cpp/pyfory.h index 1003733ccf..3801929c98 100644 --- a/python/pyfory/cpp/pyfory.h +++ b/python/pyfory/cpp/pyfory.h @@ -24,6 +24,7 @@ #include "Python.h" #include "fory/type/type.h" #include "fory/util/buffer.h" +#include "fory/util/stream.h" namespace fory { inline constexpr bool Fory_IsInternalTypeId(uint8_t type_id) { @@ -70,4 +71,9 @@ int Fory_PyWriteBasicFieldToBuffer(PyObject *value, Buffer *buffer, PyObject *Fory_PyReadBasicFieldFromBuffer(Buffer *buffer, uint8_t type_id); int Fory_PyCreateBufferFromStream(PyObject *stream, uint32_t buffer_size, Buffer **out, std::string *error_message); +int Fory_PyCreateOutputStream(PyObject *stream, OutputStream **out, + std::string *error_message); +int Fory_PyBindBufferToOutputStream(Buffer *buffer, OutputStream *output_stream, + std::string *error_message); +int Fory_PyClearBufferOutputStream(Buffer *buffer, std::string *error_message); } // namespace fory diff --git a/python/pyfory/includes/libpyfory.pxd b/python/pyfory/includes/libpyfory.pxd index c0e888c18d..6e1d0542f7 100644 --- a/python/pyfory/includes/libpyfory.pxd +++ b/python/pyfory/includes/libpyfory.pxd @@ -19,8 +19,13 @@ from cpython.object cimport PyObject from libc.stdint cimport uint32_t from libcpp.string cimport string as c_string -from pyfory.includes.libutil cimport CBuffer +from pyfory.includes.libutil cimport CBuffer, COutputStream cdef extern from "fory/python/pyfory.h" namespace "fory": int Fory_PyCreateBufferFromStream(PyObject* stream, uint32_t buffer_size, CBuffer** out, c_string* error_message) + int Fory_PyCreateOutputStream(PyObject* stream, COutputStream** out, + c_string* error_message) + int Fory_PyBindBufferToOutputStream(CBuffer* buffer, COutputStream* output_stream, + c_string* error_message) + int Fory_PyClearBufferOutputStream(CBuffer* buffer, c_string* error_message) diff --git a/python/pyfory/includes/libutil.pxd b/python/pyfory/includes/libutil.pxd index aeaa60a7db..b0d8914e33 100644 --- a/python/pyfory/includes/libutil.pxd +++ b/python/pyfory/includes/libutil.pxd @@ -16,6 +16,7 @@ # under the License. from libc.stdint cimport * +from libc.stddef cimport size_t from libcpp cimport bool as c_bool from libcpp.memory cimport shared_ptr from libcpp.string cimport string as c_string @@ -212,6 +213,18 @@ cdef extern from "fory/util/buffer.h" namespace "fory" nogil: CBuffer* allocate_buffer(uint32_t size) c_bool allocate_buffer(uint32_t size, shared_ptr[CBuffer]* out) +cdef extern from "fory/util/stream.h" namespace "fory" nogil: + cdef cppclass COutputStream "fory::OutputStream": + CBuffer* get_buffer() + void enter_flush_barrier() + void exit_flush_barrier() + c_bool try_flush() + void force_flush() + size_t flushed_bytes() const + void reset() + c_bool has_error() const + const CError& error() const + cdef extern from "fory/util/bit_util.h" namespace "fory::util" nogil: c_bool get_bit(const uint8_t *bits, uint32_t i) diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx index 9b8c730988..acaca722ca 100644 --- a/python/pyfory/serialization.pyx +++ b/python/pyfory/serialization.pyx @@ -1078,6 +1078,7 @@ cdef class Fory: cdef public bint is_peer_out_of_band_enabled cdef int32_t max_depth cdef int32_t depth + cdef object _output_stream def __init__( self, @@ -1156,6 +1157,7 @@ cdef class Fory: self.is_peer_out_of_band_enabled = False self.depth = 0 self.max_depth = max_depth + self._output_stream = None def register_serializer(self, cls: Union[type, TypeVar], Serializer serializer): """ @@ -1289,6 +1291,37 @@ cdef class Fory: """ return self.serialize(obj, buffer, buffer_callback, unsupported_callback) + def dump(self, obj, stream): + """ + Serialize an object directly to a writable stream. + + Args: + obj: The object to serialize + stream: Writable stream implementing write(...) + + Notes: + The stream must be a non-retaining sink: ``write(data)`` must + synchronously consume ``data`` before returning. Fory may reuse or + modify the underlying buffer after ``write`` returns, so retaining + the passed object (or a view of it) is unsupported. If your sink + needs retention, copy bytes inside ``write``. + """ + try: + self.buffer.set_writer_index(0) + self._output_stream = Buffer.wrap_output_stream(stream) + self.buffer.bind_output_stream(self._output_stream) + self._serialize( + obj, + self.buffer, + buffer_callback=None, + unsupported_callback=None, + ) + self.force_flush() + finally: + self.buffer.bind_output_stream(None) + self._output_stream = None + self.reset_write() + def loads( self, buffer: Union[Buffer, bytes], @@ -1328,12 +1361,18 @@ cdef class Fory: >>> print(type(data)) """ + cdef Buffer write_buffer try: - return self._serialize( + write_buffer = self._serialize( obj, buffer, buffer_callback=buffer_callback, unsupported_callback=unsupported_callback) + if write_buffer is not self.buffer: + return write_buffer + if write_buffer.get_output_stream() is not None: + return write_buffer + return write_buffer.to_bytes(0, write_buffer.get_writer_index()) finally: self.reset_write() @@ -1350,6 +1389,7 @@ cdef class Fory: # 1byte used for bit mask buffer.grow(1) buffer.set_writer_index(mask_index + 1) + buffer.put_int8(mask_index, 0) if obj is None: set_bit(buffer, mask_index, 0) else: @@ -1362,11 +1402,35 @@ cdef class Fory: else: clear_bit(buffer, mask_index, 2) self.write_ref(buffer, obj) + return buffer - if buffer is not self.buffer: - return buffer - else: - return buffer.to_bytes(0, buffer.get_writer_index()) + cpdef inline enter_flush_barrier(self): + cdef PyOutputStream output_stream + if self._output_stream is None: + return + output_stream = self._output_stream + output_stream.enter_flush_barrier() + + cpdef inline exit_flush_barrier(self): + cdef PyOutputStream output_stream + if self._output_stream is None: + return + output_stream = self._output_stream + output_stream.exit_flush_barrier() + + cpdef inline try_flush(self): + cdef PyOutputStream output_stream + if self._output_stream is None or self.buffer.get_writer_index() <= 4096: + return + output_stream = self._output_stream + output_stream.try_flush() + + cpdef inline force_flush(self): + cdef PyOutputStream output_stream + if self._output_stream is None: + return + output_stream = self._output_stream + output_stream.force_flush() cpdef inline write_ref( self, Buffer buffer, obj, TypeInfo typeinfo=None, Serializer serializer=None): @@ -1623,6 +1687,7 @@ cdef class Fory: self.metastring_resolver.reset_write() self.serialization_context.reset_write() self._unsupported_callback = None + self._output_stream = None cpdef inline reset_read(self): """ diff --git a/python/pyfory/struct.pxi b/python/pyfory/struct.pxi index 63a3871cc7..ce43dc723c 100644 --- a/python/pyfory/struct.pxi +++ b/python/pyfory/struct.pxi @@ -287,6 +287,7 @@ cdef class DataClassSerializer(Serializer): self._write_slots(buffer, value) else: self._write_dict(buffer, value) + self.fory.try_flush() cdef inline void _write_dict(self, Buffer buffer, object value): cdef dict value_dict = value.__dict__ diff --git a/python/pyfory/struct.py b/python/pyfory/struct.py index 7923dd6739..cb61db3429 100644 --- a/python/pyfory/struct.py +++ b/python/pyfory/struct.py @@ -536,6 +536,7 @@ def write(self, buffer: Buffer, value): is_tracking_ref = self._ref_fields.get(field_name, False) is_basic = self._basic_field_flags[index] self._write_field_value(buffer, serializer, field_value, is_nullable, is_dynamic, is_basic, is_tracking_ref) + self.fory.try_flush() def read(self, buffer): if not self.fory.strict: diff --git a/python/pyfory/tests/test_buffer.py b/python/pyfory/tests/test_buffer.py index 6533ed94de..021412e570 100644 --- a/python/pyfory/tests/test_buffer.py +++ b/python/pyfory/tests/test_buffer.py @@ -65,6 +65,22 @@ def recvinto(self, buffer, size=-1): return read_size +class PartialWriteStream: + def __init__(self): + self._data = bytearray() + + def write(self, payload): + if not payload: + return 0 + view = memoryview(payload).cast("B") + wrote = min(2, len(view)) + self._data.extend(view[:wrote]) + return wrote + + def to_bytes(self): + return bytes(self._data) + + def test_buffer(): buffer = Buffer.allocate(8) buffer.write_bool(True) @@ -253,15 +269,68 @@ def test_write_var_uint64(): def check_varuint64(buf: Buffer, value: int, bytes_written: int): - reader_index = buf.get_reader_index() assert buf.get_writer_index() == buf.get_reader_index() actual_bytes_written = buf.write_var_uint64(value) assert actual_bytes_written == bytes_written varint = buf.read_var_uint64() assert buf.get_writer_index() == buf.get_reader_index() assert value == varint - # test slow read branch in `read_varint64` - assert buf.slice(reader_index, buf.get_reader_index() - reader_index).read_var_uint64() == value + + +def test_buffer_flush_stream(): + stream = PartialWriteStream() + buffer = Buffer.allocate(16) + output_stream = Buffer.wrap_output_stream(stream) + buffer.bind_output_stream(output_stream) + payload = b"stream-flush-buffer" + buffer.write_bytes(payload) + output_stream.force_flush() + assert stream.to_bytes() == payload + assert buffer.get_writer_index() == 0 + + +def test_wrap_output_stream_invalid_target_raises(): + with pytest.raises(ValueError): + Buffer.wrap_output_stream(object()) + + +def test_output_stream_try_flush_preserves_bound_buffer_when_barrier_active(): + stream = PartialWriteStream() + output_stream = Buffer.wrap_output_stream(stream) + buffer = Buffer.allocate(32) + buffer.bind_output_stream(output_stream) + payload = b"x" * 5000 + + output_stream.enter_flush_barrier() + buffer.write_bytes(payload) + output_stream.try_flush() + output_stream.try_flush() + assert buffer.get_writer_index() == len(payload) + assert stream.to_bytes() == b"" + + output_stream.exit_flush_barrier() + output_stream.try_flush() + assert buffer.get_writer_index() == 0 + + output_stream.force_flush() + assert stream.to_bytes() == payload + + +def test_output_stream_try_flush_small_payload_needs_force_flush(): + stream = PartialWriteStream() + output_stream = Buffer.wrap_output_stream(stream) + buffer = Buffer.allocate(32) + buffer.bind_output_stream(output_stream) + payload = b"small-payload" + buffer.write_bytes(payload) + + output_stream.try_flush() + assert buffer.get_writer_index() == len(payload) + assert stream.to_bytes() == b"" + + output_stream.force_flush() + assert buffer.get_writer_index() == 0 + assert stream.to_bytes() == payload def test_write_buffer(): diff --git a/python/pyfory/tests/test_stream.py b/python/pyfory/tests/test_stream.py index c567420e44..3276c40542 100644 --- a/python/pyfory/tests/test_stream.py +++ b/python/pyfory/tests/test_stream.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +from dataclasses import dataclass + import pytest import pyfory @@ -68,6 +70,47 @@ def recvinto(self, buffer, size=-1): return self.recv_into(buffer, size) +class OneByteWriteStream: + def __init__(self): + self._data = bytearray() + + def write(self, payload): + if not payload: + return 0 + view = memoryview(payload).cast("B") + self._data.append(view[0]) + return 1 + + def to_bytes(self): + return bytes(self._data) + + +class CountingWriteStream: + def __init__(self): + self._data = bytearray() + self.write_calls = 0 + self.flush_calls = 0 + + def write(self, payload): + view = memoryview(payload).cast("B") + self.write_calls += 1 + self._data.extend(view) + return len(view) + + def flush(self): + self.flush_calls += 1 + + def to_bytes(self): + return bytes(self._data) + + +@dataclass +class StreamStructValue: + idx: int + name: str + values: list + + @pytest.mark.parametrize("xlang", [False, True]) def test_stream_roundtrip_primitives_and_strings(xlang): fory = pyfory.Fory(xlang=xlang, ref=True) @@ -148,3 +191,62 @@ def test_stream_deserialize_truncated_error(xlang): with pytest.raises(Exception): fory.deserialize(Buffer.from_stream(OneByteStream(truncated))) + + +@pytest.mark.parametrize("xlang", [False, True]) +def test_dump_matches_dumps_bytes(xlang): + fory = pyfory.Fory(xlang=xlang, ref=True) + value = { + "k": [1, 2, 3, 4], + "nested": {"x": True, "y": "hello"}, + "f": 3.14, + } + + sink = OneByteWriteStream() + fory.dump(value, sink) + expected = fory.dumps(value) + assert sink.to_bytes() == expected + + +@pytest.mark.parametrize("xlang", [False, True]) +def test_dump_map_chunk_path_matches_dumps(xlang): + fory = pyfory.Fory(xlang=xlang, ref=True) + value = {f"k{i}": i for i in range(300)} + + sink = OneByteWriteStream() + fory.dump(value, sink) + expected = fory.dumps(value) + assert sink.to_bytes() == expected + + restored = fory.deserialize(Buffer.from_stream(OneByteStream(sink.to_bytes()))) + assert restored == value + + +def test_dump_large_list_of_structs_multiple_flushes_matches_dumps(): + fory = pyfory.Fory(xlang=False, ref=True, strict=False) + fory.register(StreamStructValue) + value = [StreamStructValue(i, f"item-{i}-{'x' * 56}", [i, i + 1, i + 2, i + 3, i + 4]) for i in range(1800)] + + sink = CountingWriteStream() + fory.dump(value, sink) + expected = fory.dumps(value) + assert sink.to_bytes() == expected + assert len(expected) > 4096 * 4 + assert sink.write_calls >= 4 + + restored = fory.deserialize(Buffer.from_stream(OneByteStream(sink.to_bytes()))) + assert restored == value + + +def test_dump_large_map_with_struct_values_matches_dumps(): + fory = pyfory.Fory(xlang=False, ref=True, strict=False) + fory.register(StreamStructValue) + value = {f"k{i}": StreamStructValue(i, "y" * 96, [i, i + 1, i + 2, i + 3]) for i in range(900)} + + sink = OneByteWriteStream() + fory.dump(value, sink) + expected = fory.dumps(value) + assert sink.to_bytes() == expected + + restored = fory.deserialize(Buffer.from_stream(OneByteStream(sink.to_bytes()))) + assert restored == value