Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions include/websocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ struct RequestContent {
Callback on_unsubscribe;
// params to be passed with the subscription string
json params;
bool subscribed = false;
// using promise to ensure this
std::promise<bool> subscribed_pr;
std::future<bool> subscription_future = subscribed_pr.get_future();
RequestIdType ws_id;

RequestContent() = default;
Expand Down Expand Up @@ -77,7 +79,7 @@ class session : public std::enable_shared_from_this<session> {

/// @brief push a function for subscription
/// @param req the request to call
void subscribe(const RequestContent &req);
void subscribe(RequestContent *req);

/// @brief push for unsubscription
/// @param id the id to unsubscribe on
Expand Down Expand Up @@ -147,7 +149,7 @@ class session : public std::enable_shared_from_this<session> {
std::atomic_bool is_connected;

// map of subscription id with callback
std::unordered_map<RequestIdType, RequestContent> callback_map;
std::unordered_map<RequestIdType, RequestContent *> callback_map;
std::unordered_map<RequestIdType, RequestIdType> maps_wsid_to_id;
std::shared_mutex mutex_for_maps;

Expand Down
8 changes: 4 additions & 4 deletions lib/solana.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1119,17 +1119,17 @@ int WebSocketSubscriber::onAccountChange(const solana::PublicKey &pub_key,
json param = {pub_key, {{"encoding", "base64"}, {"commitment", commitment}}};

// create a new request content
RequestContent req(curr_id, "accountSubscribe", "accountUnsubscribe",
account_change_callback, std::move(param), on_subscibe,
on_unsubscribe);
RequestContent *req = new RequestContent(
curr_id, "accountSubscribe", "accountUnsubscribe",
account_change_callback, std::move(param), on_subscibe, on_unsubscribe);

// subscribe the new request content
sess->subscribe(req);

// increase the curr_id so that it can be used for the next request content
curr_id += 2;

return req.id;
return req->id;
}

/// @brief remove the account change listener for the given id
Expand Down
81 changes: 60 additions & 21 deletions lib/websocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ RequestContent::RequestContent(RequestIdType id, std::string subscribe_method,
/// @return the json that can be used to make subscription
json RequestContent::get_subscription_request() const {
json req = {{"jsonrpc", "2.0"},
{"id", id},
{"method", subscribe_method},
{"params", params}};
{"id", this->id},
{"method", this->subscribe_method},
{"params", this->params}};
return req;
}

Expand All @@ -31,8 +31,8 @@ json RequestContent::get_unsubscription_request(
RequestIdType subscription_id) const {
json params = {subscription_id};
json req = {{"jsonrpc", "2.0"},
{"id", id + 1},
{"method", unsubscribe_method},
{"id", this->id + 1},
{"method", this->unsubscribe_method},
{"params", params}};
return req;
}
Expand Down Expand Up @@ -63,30 +63,48 @@ void session::run(std::string host, std::string port) {

/// @brief push a function for subscription
/// @param req the request to call
void session::subscribe(const RequestContent &req) {
void session::subscribe(RequestContent *req) {
// context for unique_lock
{
std::unique_lock lk(mutex_for_maps);
callback_map[req.id] = req;
callback_map[req->id] = req;
}
// get subscription request and then send it to the websocket
ws.write(net::buffer(req.get_subscription_request().dump()));
ws.write(net::buffer(req->get_subscription_request().dump()));
}

/// @brief push for unsubscription
/// @param id the id to unsubscribe on
void session::unsubscribe(RequestIdType id) {
// context for shared lock
std::string unsubsciption_request = "";
std::unordered_map<RequestIdType, RequestContent *>::iterator ite;
{
std::shared_lock lk(mutex_for_maps);

auto ite = callback_map.find(id);
if (ite == callback_map.end()) return;
unsubsciption_request =
ite->second.get_unsubscription_request(ite->second.ws_id).dump();
maps_wsid_to_id.erase(ite->second.ws_id);
ite = callback_map.find(id);
}
if (ite == callback_map.end()) return;

// wait for subscription to happen
bool subscription_successful = ite->second->subscription_future.get();

// if subscription wasn't successful then just remove the id from callback
if (!subscription_successful) {
{
std::unique_lock lk(mutex_for_maps);
const auto ite = callback_map.find(id);
if (ite != callback_map.end()) {
callback_map.erase(ite);
delete ite->second;
}
}
return;
}

unsubsciption_request =
ite->second->get_unsubscription_request(ite->second->ws_id).dump();
maps_wsid_to_id.erase(ite->second->ws_id);
if (!unsubsciption_request.empty()) {
// write it to the websocket
ws.write(net::buffer(unsubsciption_request));
Expand Down Expand Up @@ -197,10 +215,10 @@ void session::on_read(beast::error_code ec, std::size_t bytes_transferred) {
// get the data from the websocket and parse it to json
auto res = buffer.data();
json data = json::parse(net::buffers_begin(res), net::buffers_end(res));

// if data contains field result then it's either subscription or
// unsubscription response
static const char *result = "result";
static const char *error = "error";
try {
if (data.contains(std::string{result})) {
RequestIdType id = data["id"];
Expand All @@ -215,8 +233,9 @@ void session::on_read(beast::error_code ec, std::size_t bytes_transferred) {
std::unique_lock lk(mutex_for_maps);
const auto ite = callback_map.find(id);
if (ite != callback_map.end()) {
on_unsubscribe = ite->second.on_unsubscribe;
on_unsubscribe = ite->second->on_unsubscribe;
callback_map.erase(ite);
delete ite->second;
}
}
if (on_unsubscribe) {
Expand All @@ -231,17 +250,37 @@ void session::on_read(beast::error_code ec, std::size_t bytes_transferred) {

const auto ite = callback_map.find(id);
if (ite != callback_map.end()) {
on_subscribe = ite->second.on_subscribe;
ite->second.subscribed = true;
ite->second.ws_id = data[result];
maps_wsid_to_id[ite->second.ws_id] = id;
on_subscribe = ite->second->on_subscribe;
ite->second->subscribed_pr.set_value(true);
ite->second->ws_id = data[result];
maps_wsid_to_id[ite->second->ws_id] = id;
}
}
if (on_subscribe) {
on_subscribe(data);
}
}
}
// In case of erro
else if (data.contains(std::string{error})) {
RequestIdType id = data["id"];
json er_mess = data["error"];
// if id is even then error in subscribing else error in unsubscribing
if (id % 2 == 0) {
std::cout << "Some error happened while subscribing" << std::endl;
{
std::unique_lock lk(mutex_for_maps);

const auto ite = callback_map.find(id);
if (ite != callback_map.end()) {
ite->second->subscribed_pr.set_value(false);
}
}
} else {
std::cout << "Some error happened while unsubscribing" << std::endl;
}
std::cout << er_mess << std::endl;
}
// it's a notification process it sccordingly
else {
call_callback(data);
Expand Down Expand Up @@ -275,7 +314,7 @@ Callback session::get_callback(RequestIdType request_id) {
<< std::endl;
return nullptr;
}
return sub_ite->second.cb;
return sub_ite->second->cb;
}

/// @brief call the specified callback
Expand Down Expand Up @@ -307,4 +346,4 @@ void session::on_close(beast::error_code ec) {

// If we get here then the connection is closed gracefully
return;
}
}