Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/amd_detail/batch/batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ BatchContext::submit_operations(const hipFileIOParams_t *params, unsigned num_pa
auto param_copy = std::make_unique<const hipFileIOParams_t>(params[i]);
// flags currently unused. Ambiguous if flags in hipFileBatchIOSubmit is for buffer or
// file flags.
auto [_file, _buffer] = Context<DriverState>::get()->getFileAndBuffer(
param_copy->fh, param_copy->u.batch.devPtr_base, param_copy->u.batch.size, 0);
auto [_file, _buffer] =
Context<DriverState>::get()->getFileAndBuffer(param_copy->fh, param_copy->u.batch.devPtr_base);
auto op = std::make_shared<BatchOperation>(std::move(param_copy), _buffer, _file);

pending_ops.push_back(op);
Expand Down
31 changes: 21 additions & 10 deletions src/amd_detail/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,23 @@ Buffer::Buffer(const void *_buffer, size_t _length, int _flags, const PassKey<Bu
}
}

Buffer::Buffer(const void *_buffer, const PassKey<BufferMap> &) : buffer{const_cast<void *>(_buffer)}
{
if (!buffer) {
throw std::invalid_argument("Buffer pointer cannot be null.");
}

hipPointerAttribute_t _attrs = Context<Hip>::get()->hipPointerGetAttributes(buffer);
if (_attrs.type != hipMemoryTypeDevice) {
throw InvalidMemoryType();
}
type = _attrs.type;
gpu_id = _attrs.device;

HipMemAddressRange range{Context<Hip>::get()->hipMemGetAddressRange(buffer)};
length = range.size - (reinterpret_cast<uintptr_t>(buffer) - reinterpret_cast<uintptr_t>(range.base));
}

void *
Buffer::getBuffer() const
{
Expand Down Expand Up @@ -117,7 +134,7 @@ BufferMap::deregisterBuffer(const void *buf)
}

shared_ptr<IBuffer>
BufferMap::getBuffer(const void *buf)
BufferMap::getRegisteredBuffer(const void *buf)
{
auto itr = from_ptr.find(buf);
if (from_ptr.end() == itr) {
Expand All @@ -128,21 +145,15 @@ BufferMap::getBuffer(const void *buf)
}

shared_ptr<IBuffer>
BufferMap::getBuffer(const void *buf, size_t length, int flags)
BufferMap::getBuffer(const void *buf)
{
auto itr = from_ptr.find(buf);

if (from_ptr.end() == itr) {
// If the buffer hasn't been registered, use an unregistered
// temporary Buffer object
return std::shared_ptr<IBuffer>(new Buffer(buf, length, flags, PassKey<BufferMap>{}));
// Create a temporary buffer
return std::shared_ptr<IBuffer>(new Buffer(buf, PassKey<BufferMap>{}));
}
else {
// If we found a registered buffer, it's an error if the
// length parameter doesn't match what we found
if (itr->second->getLength() < length) {
throw std::invalid_argument("bad length parameter");
}
return itr->second;
}
}
Expand Down
15 changes: 11 additions & 4 deletions src/amd_detail/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ class Buffer : public IBuffer {
/// @param k Key class instance (see passkey.h)
Buffer(const void *buf, size_t length, int flags, const PassKey<BufferMap> &k);

/// @brief Creates a buffer.
///
/// The length of the buffer is determined by querying the hip runtime
///
/// @param buf Buffer pointer
/// @param k Key class instance (see passkey.h)
Buffer(const void *buf, const PassKey<BufferMap> &k);

private:
/// @brief Pointer to a hip allocated buffer
void *buffer;
Expand Down Expand Up @@ -126,15 +134,14 @@ class BufferMap {
/// @attention A shared_lock on HipFileMutex must be held
/// @param buf Buffer pointer
/// @return A registered buffer
virtual std::shared_ptr<IBuffer> getBuffer(const void *buf);
virtual std::shared_ptr<IBuffer> getRegisteredBuffer(const void *buf);

/// @brief Look up a registered buffer. Returns a temporary unregistered
/// buffer (of size length, using flags) if no matching buffer is found.
/// buffer if no registered buffer is found.
/// @attention A shared_lock on HipFileMutex must be held
/// @param buf Buffer pointer
/// @param length Buffer length
/// @return A registered or temporary unregistered buffer
virtual std::shared_ptr<IBuffer> getBuffer(const void *buf, size_t length, int flags);
virtual std::shared_ptr<IBuffer> getBuffer(const void *buf);

virtual void clear();

Expand Down
2 changes: 1 addition & 1 deletion src/amd_detail/hipfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ ssize_t
hipFileIo(IoType type, hipFileHandle_t fh, const void *buffer_base, size_t size, hoff_t file_offset,
hoff_t buffer_offset, const vector<shared_ptr<Backend>> &backends)
try {
auto [file, buffer] = Context<DriverState>::get()->getFileAndBuffer(fh, buffer_base, size, 0);
auto [file, buffer] = Context<DriverState>::get()->getFileAndBuffer(fh, buffer_base);
int score{-1};
std::shared_ptr<Backend> backend{};

Expand Down
18 changes: 8 additions & 10 deletions src/amd_detail/state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,25 +92,25 @@ DriverState::deregisterBuffer(const void *buf)
}

shared_ptr<IBuffer>
DriverState::getBuffer(const void *buf)
DriverState::getRegisteredBuffer(const void *buf)
{
unique_lock<shared_mutex> ulock{state_mutex};

if (ref_count == 0) {
throw DriverNotInitialized();
}

return buffer_map->getBuffer(buf);
return buffer_map->getRegisteredBuffer(buf);
}

shared_ptr<IBuffer>
DriverState::getBuffer(const void *buf, size_t length, int flags)
DriverState::getBuffer(const void *buf)
{
// NOTE: This mutex only protects the map, so we'll
// also need to protect the data
shared_lock<shared_mutex> slock{state_mutex};

return buffer_map->getBuffer(buf, length, flags);
return buffer_map->getBuffer(buf);
}

//
Expand Down Expand Up @@ -190,7 +190,7 @@ DriverState::getStream(hipStream_t hip_stream)
//

file_buffer_pair
DriverState::getFileAndBuffer(hipFileHandle_t fh, const void *buf, size_t length, int flags)
DriverState::getFileAndBuffer(hipFileHandle_t fh, const void *buf)
{
// NOTE: This mutex only protects the map, so we'll
// also need to protect the data
Expand All @@ -200,25 +200,23 @@ DriverState::getFileAndBuffer(hipFileHandle_t fh, const void *buf, size_t length
throw DriverNotInitialized();
}

return {file_map->getFile(fh), buffer_map->getBuffer(buf, length, flags)};
return {file_map->getFile(fh), buffer_map->getBuffer(buf)};
}

//
// Buffer and file and stream calls
//

file_buffer_stream_tuple
DriverState::getFileBufferAndStream(hipFileHandle_t fh, const void *buf, size_t length, int flags,
hipStream_t hipStream)
DriverState::getFileBufferAndStream(hipFileHandle_t fh, const void *buf, hipStream_t hipStream)
{
shared_lock<shared_mutex> slock{state_mutex};

if (ref_count == 0) {
throw DriverNotInitialized();
}

return {file_map->getFile(fh), buffer_map->getBuffer(buf, length, flags),
stream_map->getStream(hipStream)};
return {file_map->getFile(fh), buffer_map->getBuffer(buf), stream_map->getStream(hipStream)};
}

//
Expand Down
16 changes: 6 additions & 10 deletions src/amd_detail/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,13 @@ class DriverState {
/// @brief Look up a registered buffer using the buffer pointer
/// @param [in] buf Buffer pointer
/// @return A registered buffer
virtual std::shared_ptr<IBuffer> getBuffer(const void *buf);
virtual std::shared_ptr<IBuffer> getRegisteredBuffer(const void *buf);

/// @brief Look up a registered buffer. Returns a temporary unregistered
/// buffer (of size length, using flags) if no matching buffer is found.
/// buffer if no matching buffer is found.
/// @param [in] buf Buffer pointer
/// @param [in] length Buffer length
/// @param [in] flags Buffer flags (unused)
/// @return A registered or temporary unregistered buffer
virtual std::shared_ptr<IBuffer> getBuffer(const void *buf, size_t length, int flags);
virtual std::shared_ptr<IBuffer> getBuffer(const void *buf);

//
// File interface
Expand Down Expand Up @@ -158,19 +156,17 @@ class DriverState {
/// This combined file + buffer getter reduces the number of lock calls.
///
/// Like the buffer getter, this function emits a temporary unregistered buffer
/// (of size length, using flags) if no matching buffer is found.
/// if no matching registered buffer is found.
///
/// @param [in] fh File handle
/// @param [in] buf Buffer pointer
/// @param [in] length Buffer length
/// @param [in] flags Buffer flags (unused)
virtual file_buffer_pair getFileAndBuffer(hipFileHandle_t fh, const void *buf, size_t length, int flags);
virtual file_buffer_pair getFileAndBuffer(hipFileHandle_t fh, const void *buf);

//
// Buffer, file, and stream calls
//
virtual file_buffer_stream_tuple getFileBufferAndStream(hipFileHandle_t fh, const void *buf,
size_t length, int flags, hipStream_t hipStream);
hipStream_t hipStream);

//
// Reference counts
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ set(SYSTEM_TEST_SOURCE_FILES
system/buffer.cpp
system/config.cpp
system/driver.cpp
system/io.cpp
system/main.cpp
system/version.cpp
)
Expand Down
32 changes: 11 additions & 21 deletions test/amd_detail/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ TEST_F(HipFileBuffer, deregister_internal_get_prevents_deregister)
expect_buffer_registration(mhip, hipMemoryTypeDevice);
Context<DriverState>::get()->registerBuffer(nonnull_ptr, 0, 0);
{
auto buffer = Context<DriverState>::get()->getBuffer(nonnull_ptr);
auto buffer = Context<DriverState>::get()->getRegisteredBuffer(nonnull_ptr);
ASSERT_THROW(Context<DriverState>::get()->deregisterBuffer(nonnull_ptr), BufferOperationsOutstanding);
}
Context<DriverState>::get()->deregisterBuffer(nonnull_ptr);
Expand All @@ -222,31 +222,31 @@ TEST_F(HipFileBuffer, deregister_get_prevents_deregister)
expect_buffer_registration(mhip, hipMemoryTypeDevice);
ASSERT_EQ(hipFileBufRegister(nonnull_ptr, 0, 0), HIPFILE_SUCCESS);
{
auto buffer = Context<DriverState>::get()->getBuffer(nonnull_ptr);
auto buffer = Context<DriverState>::get()->getRegisteredBuffer(nonnull_ptr);
ASSERT_EQ(hipFileBufDeregister(nonnull_ptr), HipFileOpError(hipFileInternalError));
}
ASSERT_EQ(hipFileBufDeregister(nonnull_ptr), HIPFILE_SUCCESS);
}

TEST_F(HipFileBuffer, get_not_registered)
{
ASSERT_THROW(Context<DriverState>::get()->getBuffer(nonnull_ptr), BufferNotRegistered);
ASSERT_THROW(Context<DriverState>::get()->getRegisteredBuffer(nonnull_ptr), BufferNotRegistered);
}

TEST_F(HipFileBuffer, get_internal_after_register)
{
StrictMock<MHip> mhip;
expect_buffer_registration(mhip, hipMemoryTypeDevice);
Context<DriverState>::get()->registerBuffer(nonnull_ptr, 0, 0);
auto buffer = Context<DriverState>::get()->getBuffer(nonnull_ptr);
auto buffer = Context<DriverState>::get()->getRegisteredBuffer(nonnull_ptr);
}

TEST_F(HipFileBuffer, get_after_register)
{
StrictMock<MHip> mhip;
expect_buffer_registration(mhip, hipMemoryTypeDevice);
ASSERT_EQ(hipFileBufRegister(nonnull_ptr, 0, 0), HIPFILE_SUCCESS);
auto buffer = Context<DriverState>::get()->getBuffer(nonnull_ptr);
auto buffer = Context<DriverState>::get()->getRegisteredBuffer(nonnull_ptr);
}

TEST_F(HipFileBuffer, get_internal_after_deregister)
Expand All @@ -255,7 +255,7 @@ TEST_F(HipFileBuffer, get_internal_after_deregister)
expect_buffer_registration(mhip, hipMemoryTypeDevice);
Context<DriverState>::get()->registerBuffer(nonnull_ptr, 0, 0);
Context<DriverState>::get()->deregisterBuffer(nonnull_ptr);
ASSERT_THROW(Context<DriverState>::get()->getBuffer(nonnull_ptr), BufferNotRegistered);
ASSERT_THROW(Context<DriverState>::get()->getRegisteredBuffer(nonnull_ptr), BufferNotRegistered);
}

TEST_F(HipFileBuffer, get_after_deregister)
Expand All @@ -264,14 +264,14 @@ TEST_F(HipFileBuffer, get_after_deregister)
expect_buffer_registration(mhip, hipMemoryTypeDevice);
ASSERT_EQ(hipFileBufRegister(nonnull_ptr, 0, 0), HIPFILE_SUCCESS);
ASSERT_EQ(hipFileBufDeregister(nonnull_ptr), HIPFILE_SUCCESS);
ASSERT_THROW(Context<DriverState>::get()->getBuffer(nonnull_ptr), BufferNotRegistered);
ASSERT_THROW(Context<DriverState>::get()->getRegisteredBuffer(nonnull_ptr), BufferNotRegistered);
}

TEST_F(HipFileBuffer, get_buffer_makes_temporary_buffer)
{
StrictMock<MHip> mhip;
expect_buffer_registration(mhip, hipMemoryTypeDevice);
auto buffer = Context<DriverState>::get()->getBuffer(nonnull_ptr, 0, 0);
auto buffer = Context<DriverState>::get()->getBuffer(nonnull_ptr);
ASSERT_EQ(buffer.use_count(), 1);
}

Expand All @@ -280,25 +280,15 @@ TEST_F(HipFileBuffer, get_buffer_returns_registered_buffer)
StrictMock<MHip> mhip;
expect_buffer_registration(mhip, hipMemoryTypeDevice);
Context<DriverState>::get()->registerBuffer(nonnull_ptr, 0, 0);
ASSERT_EQ(Context<DriverState>::get()->getBuffer(nonnull_ptr, 0, 0),
Context<DriverState>::get()->getBuffer(nonnull_ptr));
}

TEST_F(HipFileBuffer, get_buffer_throws_if_length_larger_than_registered_length)
{
StrictMock<MHip> mhip;
size_t buffer_length = 0;
expect_buffer_registration(mhip, hipMemoryTypeDevice);
Context<DriverState>::get()->registerBuffer(nonnull_ptr, buffer_length, 0);
ASSERT_THROW(Context<DriverState>::get()->getBuffer(nonnull_ptr, buffer_length + 1, 0),
std::invalid_argument);
ASSERT_EQ(Context<DriverState>::get()->getBuffer(nonnull_ptr),
Context<DriverState>::get()->getRegisteredBuffer(nonnull_ptr));
}

TEST_F(HipFileBuffer, get_buffer_throws_on_getPointerAttributes_error)
{
StrictMock<MHip> mhip;
EXPECT_CALL(mhip, hipPointerGetAttributes).WillOnce(testing::Throw(Hip::RuntimeError(hipErrorUnknown)));
ASSERT_THROW(Context<DriverState>::get()->getBuffer(nonnull_ptr, 1, 0), Hip::RuntimeError);
ASSERT_THROW(Context<DriverState>::get()->getBuffer(nonnull_ptr), Hip::RuntimeError);
}

HIPFILE_WARN_NO_GLOBAL_CTOR_ON
6 changes: 3 additions & 3 deletions test/amd_detail/fallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ struct FallbackIo : public HipFileOpened {

expect_buffer_registration(mhip, hipMemoryTypeDevice);
Context<DriverState>::get()->registerBuffer(buffer_data.data(), buffer_data.size(), 0);
buffer = Context<DriverState>::get()->getBuffer(buffer_data.data());
buffer = Context<DriverState>::get()->getRegisteredBuffer(buffer_data.data());

expect_file_registration(msys, mlibmounthelper);
file = Context<DriverState>::get()->getFile(Context<DriverState>::get()->registerFile(0xBADF00D));
Expand Down Expand Up @@ -182,7 +182,7 @@ struct FallbackParam : ::testing::TestWithParam<IoType> {
expect_buffer_registration(mhip, hipMemoryTypeDevice);
void *buf = reinterpret_cast<void *>(0xFEFEFEFE);
Context<DriverState>::get()->registerBuffer(buf, 4096, 0);
buffer = Context<DriverState>::get()->getBuffer(buf);
buffer = Context<DriverState>::get()->getRegisteredBuffer(buf);

expect_file_registration(msys, mlibmounthelper);
file = Context<DriverState>::get()->getFile(Context<DriverState>::get()->registerFile(0xBADF00D));
Expand Down Expand Up @@ -248,7 +248,7 @@ TEST_P(FallbackParam, fallback_io_truncates_size_to_MAX_RW_COUNT)
expect_buffer_registration(mhip, hipMemoryTypeDevice);
auto buf = reinterpret_cast<void *>(0xABABABAB);
Context<DriverState>::get()->registerBuffer(buf, MAX_RW_COUNT + 1, 0);
auto big_buffer = Context<DriverState>::get()->getBuffer(buf);
auto big_buffer = Context<DriverState>::get()->getRegisteredBuffer(buf);

EXPECT_CALL(msys, mmap).WillOnce(testing::Return(reinterpret_cast<void *>(0xFEFEFEFE)));
switch (io_type) {
Expand Down
Loading
Loading