8 #include "iwebsocketserverendpoint.h" 
    9 #include "websocketmessage.h" 
   10 #include "websocketserver.h" 
   16 #include <unordered_set> 
   17 #include <mathutils.h> 
   18 #include <rapidjson/document.h> 
   19 #include <rapidjson/istreamwrapper.h> 
   20 #include <rapidjson/prettywriter.h> 
   21 #include <rapidjson/error/en.h> 
   22 #include <nap/logger.h> 
   71     template<
typename config>
 
   72     class WebSocketServerEndPointSetup : 
public IWebSocketServerEndPoint
 
   74         RTTI_ENABLE(IWebSocketServerEndPoint)
 
   84         bool init(utility::ErrorState& errorState) 
override;
 
   97         bool isOpen() 
const override;
 
  102         void stop() 
override;
 
  146         std::string 
getHostName(
const WebSocketConnection& connection) 
override;
 
  152         void getHostNames(std::vector<std::string>& outHosts) 
override;
 
  292     template<
typename config>
 
  299         std::error_code stdec;
 
  300         if (mIPAddress.empty())
 
  302             mEndPoint.listen(
static_cast<uint16>(mPort), stdec);
 
  306             mEndPoint.listen(mIPAddress, utility::stringFormat(
"%d", 
static_cast<uint16>(mPort)), stdec);
 
  312             error.fail(stdec.message());
 
  317         mEndPoint.start_accept(stdec);
 
  320             error.fail(stdec.message());
 
  332     template<
typename config>
 
  339     template<
typename config>
 
  342         std::error_code stdec;
 
  343         mEndPoint.send(connection.mConnection, message, 
static_cast<wspp::OpCode>(code), stdec);
 
  346             error.fail(stdec.message());
 
  353     template<
typename config>
 
  356         std::error_code stdec;
 
  357         mEndPoint.send(connection.mConnection, payload, length, 
static_cast<wspp::OpCode>(code), stdec);
 
  360             error.fail(stdec.message());
 
  367     template<
typename config>
 
  371         std::lock_guard<std::mutex> lock(mConnectionMutex);
 
  372         for (
auto& connection : mConnections)
 
  374             if (!send(connection, message, code, error))
 
  381     template<
typename config>
 
  385         std::lock_guard<std::mutex> lock(mConnectionMutex);
 
  386         for (
auto& connection : mConnections)
 
  388             if (!send(connection, payload, length, code, error))
 
  395     template<
typename config>
 
  398         std::error_code stdec;
 
  399         auto cptr = mEndPoint.get_con_from_hdl(connection.mConnection, stdec);
 
  400         return stdec ? 
"" : cptr->get_host();
 
  403     template<
typename config>
 
  407         std::error_code stdec;
 
  408         std::lock_guard<std::mutex> lock(mConnectionMutex);
 
  409         outHosts.reserve(mConnections.size());
 
  410         for (
auto& connection : mConnections)
 
  412             auto cptr = mEndPoint.get_con_from_hdl(connection, stdec);
 
  414                 outHosts.emplace_back(cptr->get_host());
 
  419     template<
typename config>
 
  422         std::lock_guard<std::mutex> lock(mConnectionMutex);
 
  423         return mConnections.size();
 
  426     template<
typename config>
 
  429         if (mConnectionLimit < 0)
 
  432         std::lock_guard<std::mutex> lock(mConnectionMutex);
 
  433         return mConnections.size() < mConnectionLimit;
 
  437     template<
typename config>
 
  445     template<
typename config>
 
  451             std::error_code stdec;
 
  452             mEndPoint.stop_listening(stdec);
 
  456                 nap::Logger::error(
"%s: %s", mID.c_str(), stdec.message().c_str());
 
  461             if (!disconnect(napec))
 
  464                 nap::Logger::error(
"%s: %s", mID.c_str(), napec.
toString().c_str());
 
  471             assert(mServerTask.valid());
 
  478     template<
typename config>
 
  485     template<
typename config>
 
  490         mAccessLogLevel = mLogConnectionUpdates ? websocketpp::log::alevel::all ^ websocketpp::log::alevel::frame_payload
 
  491             : websocketpp::log::alevel::fail;
 
  494         mEndPoint.clear_error_channels(websocketpp::log::elevel::all);
 
  495         mEndPoint.set_error_channels(mLogLevel);
 
  497         mEndPoint.clear_access_channels(websocketpp::log::alevel::all);
 
  498         mEndPoint.set_access_channels(mAccessLogLevel);
 
  501         mEndPoint.set_reuse_addr(mAllowPortReuse);
 
  504         std::error_code stdec;
 
  505         mEndPoint.init_asio(stdec);
 
  508             errorState.
fail(stdec.message());
 
  521         mEndPoint.set_message_handler(std::bind(
 
  523             std::placeholders::_1, std::placeholders::_2
 
  528         for (
const auto& ticket : mClients)
 
  529             mClientHashes.emplace(ticket->toHash());
 
  534     template<
typename config>
 
  537         std::error_code stdec;
 
  538         auto cptr = mEndPoint.get_con_from_hdl(connection, stdec);
 
  541             nap::Logger::error(stdec.message());
 
  547             std::lock_guard<std::mutex> lock(mConnectionMutex);
 
  548             mConnections.emplace_back(cptr);
 
  556     template<
typename config>
 
  559         std::error_code stdec;
 
  560         auto cptr = mEndPoint.get_con_from_hdl(connection, stdec);
 
  563             nap::Logger::error(stdec.message());
 
  569             listener->onConnectionClosed(
WebSocketConnection(connection), cptr->get_ec().value(), cptr->get_ec().message());
 
  573             std::lock_guard<std::mutex> lock(mConnectionMutex);
 
  574             auto found_it = std::find_if(mConnections.begin(), mConnections.end(), [&](
const auto& it)
 
  576                     auto client_ptr = mEndPoint.get_con_from_hdl(it, stdec);
 
  579                         nap::Logger::error(stdec.message());
 
  582                     return cptr == client_ptr;
 
  586             if (found_it == mConnections.end())
 
  591             mConnections.erase(found_it);
 
  596     template<
typename config>
 
  599         std::error_code stdec;
 
  600         auto cptr = mEndPoint.get_con_from_hdl(connection, stdec);
 
  603             nap::Logger::error(stdec.message());
 
  608             listener->onConnectionFailed(
WebSocketConnection(connection), cptr->get_ec().value(), cptr->get_ec().message());
 
  612     template<
typename config>
 
  621     template<
typename config>
 
  625         std::error_code stdec;
 
  626         auto conp = mEndPoint.get_con_from_hdl(con, stdec);
 
  629             nap::Logger::error(stdec.message());
 
  630             conp->set_status(websocketpp::http::status_code::internal_server_error);
 
  635         conp->append_header(
"Access-Control-Allow-Origin", mAccessAllowControlOrigin);
 
  638         std::string method = conp->get_request().get_method();
 
  641         if (method.compare(
"OPTIONS") == 0)
 
  643             conp->set_status(websocketpp::http::status_code::no_content);
 
  644             conp->append_header(
"Access-Control-Allow-Methods", 
"OPTIONS, POST");
 
  645             conp->append_header(
"Access-Control-Allow-Headers", 
"Content-Type");
 
  650         if (method.compare(
"POST") != 0)
 
  652             conp->set_status(websocketpp::http::status_code::method_not_allowed,
 
  653                 "only OPTIONS and POST requests are allowed");
 
  658         if (mMode == EAccessMode::EveryOne)
 
  660             conp->set_status(websocketpp::http::status_code::conflict,
 
  661                 "unable to generate ticket, no access policy set");
 
  666         std::string body = conp->get_request_body();
 
  669         rapidjson::Document document;
 
  670         rapidjson::ParseResult parse_result = document.Parse(body.c_str());
 
  671         if (!parse_result || !document.IsObject())
 
  673             conp->set_status(websocketpp::http::status_code::bad_request,
 
  674                 "unable to parse as JSON");
 
  679         if (!document.HasMember(
"user"))
 
  681             conp->append_header(
"WWW-Authenticate", 
"NAPUserPass");
 
  682             conp->set_status(websocketpp::http::status_code::unauthorized,
 
  683                 "missing required member: 'user");
 
  688         if (!document.HasMember(
"pass"))
 
  690             conp->append_header(
"WWW-Authenticate", 
"NAPUserPass");
 
  691             conp->set_status(websocketpp::http::status_code::unauthorized,
 
  692                 "missing required member: 'pass");
 
  700         ticket.
mID = math::generateUUID();
 
  701         ticket.
mUsername = document[
"user"].GetString();
 
  702         ticket.
mPassword = document[
"pass"].GetString();
 
  704         if (!ticket.
init(error))
 
  706             conp->append_header(
"WWW-Authenticate", 
"NAPUserPass");
 
  707             conp->set_status(websocketpp::http::status_code::unauthorized,
 
  708                 utility::stringFormat(
"invalid username or password: %s", error.toString().c_str()));
 
  714         if (mMode == EAccessMode::Reserved)
 
  718             if (mClientHashes.find(ticket.
toHash()) == mClientHashes.end())
 
  720                 conp->append_header(
"WWW-Authenticate", 
"NAPUserPass");
 
  721                 conp->set_status(websocketpp::http::status_code::unauthorized,
 
  722                     "invalid username or password");
 
  729         std::string ticket_str;
 
  732             nap::Logger::error(error.toString());
 
  733             conp->set_status(websocketpp::http::status_code::internal_server_error);
 
  738         conp->set_body(ticket_str);
 
  739         conp->set_status(websocketpp::http::status_code::ok);
 
  743     template<
typename config>
 
  747         std::error_code stdec;
 
  748         auto conp = mEndPoint.get_con_from_hdl(con, stdec);
 
  751             nap::Logger::error(stdec.message());
 
  752             conp->set_status(websocketpp::http::status_code::internal_server_error);
 
  757         if (!acceptsNewConnections())
 
  759             conp->set_status(websocketpp::http::status_code::forbidden,
 
  760                 "client connection count exceeded");
 
  765         const std::vector<std::string>& sub_protocol = conp->get_requested_subprotocols();
 
  768         if (mMode == EAccessMode::EveryOne)
 
  770             if (sub_protocol.empty())
 
  774             conp->set_status(websocketpp::http::status_code::not_found, 
"invalid sub-protocol");
 
  780         if (sub_protocol.empty())
 
  782             conp->set_status(websocketpp::http::status_code::forbidden,
 
  783                 "unable to extract ticket");
 
  788         conp->select_subprotocol(sub_protocol[0]);
 
  793         WebSocketTicket* client_ticket = WebSocketTicket::fromBinaryString(sub_protocol[0], result, error);
 
  794         if (client_ticket == 
nullptr)
 
  796             conp->set_status(websocketpp::http::status_code::forbidden,
 
  797                 "first sub-protocol argument is not a valid ticket object");
 
  802         if (mMode == EAccessMode::Ticket)
 
  806         if (mClientHashes.find(client_ticket->
toHash()) == mClientHashes.end())
 
  808             conp->set_status(websocketpp::http::status_code::forbidden,
 
  809                 "not a valid ticket");
 
  818     template<
typename config>
 
  825     template<
typename config>
 
  828         std::lock_guard<std::mutex> lock(mConnectionMutex);
 
  830         for (
auto& connection : mConnections)
 
  832             std::error_code stdec;
 
  833             mEndPoint.close(connection, websocketpp::close::status::going_away, 
"disconnected", stdec);
 
  836                 error.fail(stdec.message());
 
  840         mConnections.clear();
 
  845     template<
typename config>
 
  848         std::unique_lock<std::mutex> lock(mListenerMutex);
 
  849         mListeners.push_back(&server);
 
  853     template<
typename config>
 
  856         std::unique_lock<std::mutex> lock(mListenerMutex);
 
  857         mListeners.erase(std::remove(mListeners.begin(), mListeners.end(), &server), mListeners.end());