Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_library(cpr
auth.cpp
callback.cpp
cert_info.cpp
connection_pool.cpp
cookies.cpp
cprtypes.cpp
curl_container.cpp
Expand Down
39 changes: 39 additions & 0 deletions cpr/connection_pool.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "cpr/connection_pool.h"
#include <curl/curl.h>
#include <memory>
#include <mutex>

namespace cpr {
ConnectionPool::ConnectionPool() {
auto* curl_share = curl_share_init();
Comment thread
cleaton marked this conversation as resolved.
Outdated
this->connection_mutex_ = std::make_shared<std::mutex>();

auto lock_f = +[](__attribute__((unused)) CURL* handle, __attribute__((unused)) curl_lock_data data, __attribute__((unused)) curl_lock_access access, void* userptr) {
std::mutex* lock = static_cast<std::mutex*>(userptr);
lock->lock();
};
Comment thread
cleaton marked this conversation as resolved.
Outdated

auto unlock_f = +[](__attribute__((unused)) CURL* handle, __attribute__((unused)) curl_lock_data data, void* userptr) {
Comment thread
cleaton marked this conversation as resolved.
Outdated
std::mutex* lock = static_cast<std::mutex*>(userptr);
lock->unlock();
};

curl_share_setopt(curl_share, CURLSHOPT_SHARE, CURL_LOCK_DATA_CONNECT);
curl_share_setopt(curl_share, CURLSHOPT_USERDATA, this->connection_mutex_.get());
curl_share_setopt(curl_share, CURLSHOPT_LOCKFUNC, lock_f);
curl_share_setopt(curl_share, CURLSHOPT_UNLOCKFUNC, unlock_f);

this->curl_sh_ = std::shared_ptr<CURLSH>(curl_share,
[](CURLSH* ptr) {
// Make sure to reset callbacks before cleanup to avoid deadlocks
curl_share_setopt(ptr, CURLSHOPT_LOCKFUNC, nullptr);
curl_share_setopt(ptr, CURLSHOPT_UNLOCKFUNC, nullptr);
curl_share_cleanup(ptr);
});
}

void ConnectionPool::SetupHandler(CURL* easy_handler) const {
curl_easy_setopt(easy_handler, CURLOPT_SHARE, this->curl_sh_.get());
}

} // namespace cpr
9 changes: 9 additions & 0 deletions cpr/session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "cpr/body.h"
#include "cpr/callback.h"
#include "cpr/connect_timeout.h"
#include "cpr/connection_pool.h"
#include "cpr/cookies.h"
#include "cpr/cprtypes.h"
#include "cpr/curlholder.h"
Expand Down Expand Up @@ -390,6 +391,13 @@ void Session::SetConnectTimeout(const ConnectTimeout& timeout) {
curl_easy_setopt(curl_->handle, CURLOPT_CONNECTTIMEOUT_MS, timeout.Milliseconds());
}

void Session::SetConnectionPool(const ConnectionPool& pool) {
auto * curl = curl_->handle;
Comment thread
cleaton marked this conversation as resolved.
Outdated
if (curl) {
pool.SetupHandler(curl);
}
Comment thread
cleaton marked this conversation as resolved.
Outdated
}

void Session::SetAuth(const Authentication& auth) {
// Ignore here since this has been defined by libcurl.
switch (auth.GetAuthMode()) {
Expand Down Expand Up @@ -1071,6 +1079,7 @@ void Session::SetOption(const MultiRange& multi_range) { SetMultiRange(multi_ran
void Session::SetOption(const ReserveSize& reserve_size) { SetReserveSize(reserve_size.size); }
void Session::SetOption(const AcceptEncoding& accept_encoding) { SetAcceptEncoding(accept_encoding); }
void Session::SetOption(AcceptEncoding&& accept_encoding) { SetAcceptEncoding(std::move(accept_encoding)); }
void Session::SetOption(const ConnectionPool& pool) { SetConnectionPool(pool); }
// clang-format on

void Session::SetCancellationParam(std::shared_ptr<std::atomic_bool> param) {
Expand Down
21 changes: 21 additions & 0 deletions include/cpr/connection_pool.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef CPR_CONNECTION_POOL_H
#define CPR_CONNECTION_POOL_H

#include <curl/curl.h>
#include <memory>
#include <mutex>

namespace cpr {
class ConnectionPool {
public:
ConnectionPool();
ConnectionPool(const ConnectionPool&) = default;
ConnectionPool& operator=(const ConnectionPool&) = delete;
void SetupHandler(CURL* easy_handler) const;

private:
std::shared_ptr<std::mutex> connection_mutex_;
std::shared_ptr<CURLSH> curl_sh_;
};
} // namespace cpr
#endif
Comment thread
cleaton marked this conversation as resolved.
1 change: 1 addition & 0 deletions include/cpr/cpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "cpr/callback.h"
#include "cpr/cert_info.h"
#include "cpr/connect_timeout.h"
#include "cpr/connection_pool.h"
#include "cpr/cookies.h"
#include "cpr/cprtypes.h"
#include "cpr/cprver.h"
Expand Down
3 changes: 3 additions & 0 deletions include/cpr/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "cpr/body.h"
#include "cpr/callback.h"
#include "cpr/connect_timeout.h"
#include "cpr/connection_pool.h"
#include "cpr/cookies.h"
#include "cpr/cprtypes.h"
#include "cpr/curlholder.h"
Expand Down Expand Up @@ -71,6 +72,7 @@ class Session : public std::enable_shared_from_this<Session> {
[[nodiscard]] const Header& GetHeader() const;
void SetTimeout(const Timeout& timeout);
void SetConnectTimeout(const ConnectTimeout& timeout);
void SetConnectionPool(const ConnectionPool& pool);
void SetAuth(const Authentication& auth);
// Only supported with libcurl >= 7.61.0.
// As an alternative use SetHeader and add the token manually.
Expand Down Expand Up @@ -135,6 +137,7 @@ class Session : public std::enable_shared_from_this<Session> {
void SetOption(const Timeout& timeout);
void SetOption(const ConnectTimeout& timeout);
void SetOption(const Authentication& auth);
void SetOption(const ConnectionPool& pool);
// Only supported with libcurl >= 7.61.0.
// As an alternative use SetHeader and add the token manually.
#if LIBCURL_VERSION_NUM >= 0x073D00
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ add_cpr_test(multiasync)
add_cpr_test(file_upload)
add_cpr_test(singleton)
add_cpr_test(threadpool)
add_cpr_test(connection_pool)

if (ENABLE_SSL_TESTS)
add_cpr_test(ssl)
Expand Down
15 changes: 15 additions & 0 deletions test/abstractServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ static void EventHandler(mg_connection* conn, int event, void* event_data, void*

case MG_EV_HTTP_MSG: {
AbstractServer* server = static_cast<AbstractServer*>(context);
// Use the connection address as unique identifier instead
int port = AbstractServer::GetRemotePort(conn);
server->AddConnection(port);
server->OnRequest(conn, static_cast<mg_http_message*>(event_data));
} break;

Expand Down Expand Up @@ -79,6 +82,18 @@ void AbstractServer::Run() {
server_stop_cv.notify_all();
}

void AbstractServer::AddConnection(int remote_port) {
unique_connections.insert(remote_port);
Comment thread
COM8 marked this conversation as resolved.
}

size_t AbstractServer::GetConnectionCount() {
return unique_connections.size();
}

void AbstractServer::ResetConnectionCount() {
unique_connections.clear();
}

static const std::string base64_chars =
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
Expand Down
12 changes: 9 additions & 3 deletions test/abstractServer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <memory>
#include <mutex>
#include <string>
#include <set>

#include "cpr/cpr.h"
#include "mongoose.h"
Expand Down Expand Up @@ -38,18 +39,26 @@ class AbstractServer : public testing::Environment {
void Start();
void Stop();

size_t GetConnectionCount();
void ResetConnectionCount();
void AddConnection(int remote_port);

virtual std::string GetBaseUrl() = 0;
virtual uint16_t GetPort() = 0;

virtual void acceptConnection(mg_connection* conn) = 0;
virtual void OnRequest(mg_connection* conn, mg_http_message* msg) = 0;

static uint16_t GetRemotePort(const mg_connection* conn);
static uint16_t GetLocalPort(const mg_connection* conn);

private:
std::shared_ptr<std::thread> serverThread{nullptr};
std::mutex server_mutex;
std::condition_variable server_start_cv;
std::condition_variable server_stop_cv;
std::atomic<bool> should_run{false};
std::set<int> unique_connections;

void Run();

Expand All @@ -61,9 +70,6 @@ class AbstractServer : public testing::Environment {
static std::string Base64Decode(const std::string& in);
static void SendError(mg_connection* conn, int code, std::string& reason);
static bool IsConnectionActive(mg_mgr* mgr, mg_connection* conn);

static uint16_t GetRemotePort(const mg_connection* conn);
static uint16_t GetLocalPort(const mg_connection* conn);
};
} // namespace cpr

Expand Down
83 changes: 83 additions & 0 deletions test/connection_pool_tests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#include <gtest/gtest.h>

#include <string>
#include <vector>

#include <cpr/cpr.h>

#include "httpServer.hpp"

using namespace cpr;

static HttpServer* server = new HttpServer();
const size_t NUM_REQUESTS = 100;

TEST(MultipleGetTests, PoolBasicMultipleGetTest) {
Url url{server->GetBaseUrl() + "/hello.html"};
ConnectionPool pool;
server->ResetConnectionCount();

// Without shared connection pool
for (size_t i = 0; i < NUM_REQUESTS; ++i) {
Response response = cpr::Get(url);
EXPECT_EQ(std::string{"Hello world!"}, response.text);
EXPECT_EQ(url, response.url);
EXPECT_EQ(std::string{"text/html"}, response.header["content-type"]);
EXPECT_EQ(200, response.status_code);
}
EXPECT_EQ(server->GetConnectionCount(), NUM_REQUESTS);

// With shared connection pool
server->ResetConnectionCount();
for (size_t i = 0; i < NUM_REQUESTS; ++i) {
Response response = cpr::Get(url, pool);
EXPECT_EQ(std::string{"Hello world!"}, response.text);
EXPECT_EQ(url, response.url);
EXPECT_EQ(std::string{"text/html"}, response.header["content-type"]);
EXPECT_EQ(200, response.status_code);
}
EXPECT_LT(server->GetConnectionCount(), NUM_REQUESTS);
}

TEST(MultipleGetTests, PoolAsyncGetMultipleTest) {
Url url{server->GetBaseUrl() + "/hello.html"};
ConnectionPool pool;
std::vector<AsyncResponse> responses;
server->ResetConnectionCount();

// Without shared connection pool
responses.reserve(NUM_REQUESTS);
for (size_t i = 0; i < NUM_REQUESTS; ++i) {
responses.emplace_back(cpr::GetAsync(url));
}
for (AsyncResponse& future : responses) {
Response response = future.get();
EXPECT_EQ(std::string{"Hello world!"}, response.text);
EXPECT_EQ(url, response.url);
EXPECT_EQ(std::string{"text/html"}, response.header["content-type"]);
EXPECT_EQ(200, response.status_code);
}
EXPECT_EQ(server->GetConnectionCount(), NUM_REQUESTS);

// With shared connection pool
server->ResetConnectionCount();
responses.clear();
for (size_t i = 0; i < NUM_REQUESTS; ++i) {
responses.emplace_back(cpr::GetAsync(url, pool));
}
for (AsyncResponse& future : responses) {
Response response = future.get();

EXPECT_EQ(std::string{"Hello world!"}, response.text);
EXPECT_EQ(url, response.url);
EXPECT_EQ(std::string{"text/html"}, response.header["content-type"]);
EXPECT_EQ(200, response.status_code);
}
EXPECT_LT(server->GetConnectionCount(), NUM_REQUESTS);
}

int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
::testing::AddGlobalTestEnvironment(server);
return RUN_ALL_TESTS();
}