8 #include "iwebsocketserverendpoint.h"
9 #include "websocketmessage.h"
10 #include "websocketserver.h"
16 #include <unordered_set>
17 #include <mathutils.h>
18 #include <rapidjson/document.h>
19 #include <rapidjson/istreamwrapper.h>
20 #include <rapidjson/prettywriter.h>
21 #include <rapidjson/error/en.h>
22 #include <nap/logger.h>
// Server end point parameterised on `config` (presumably a websocketpp configuration
// type -- confirm against the instantiations). Only part of the declaration is visible
// in this excerpt; the members below override IWebSocketServerEndPoint.
71 template<
typename config>
72 class WebSocketServerEndPointSetup :
public IWebSocketServerEndPoint
74 RTTI_ENABLE(IWebSocketServerEndPoint)
// Sets up the underlying end point (logging, asio, handlers); failures are reported
// through errorState.
84 bool init(utility::ErrorState& errorState)
override;
// Whether the end point is open (implementation not visible in this excerpt).
97 bool isOpen()
const override;
// Stops the end point.
102 void stop()
override;
// Remote host of a single connection, or an empty string when it cannot be resolved.
146 std::string
getHostName(
const WebSocketConnection& connection)
override;
// Appends the remote host of every tracked connection to outHosts.
152 void getHostNames(std::vector<std::string>& outHosts)
override;
// Opens the listening socket and starts accepting connections (signature not visible in
// this excerpt). With no mIPAddress configured only mPort is passed to listen();
// otherwise the end point binds to mIPAddress with the port formatted as a string.
// websocketpp errors are forwarded to `error`.
// NOTE(review): mPort is narrowed with static_cast<uint16>; out-of-range values would
// silently wrap -- confirm it is validated elsewhere.
292 template<
typename config>
299 std::error_code stdec;
300 if (mIPAddress.empty())
302 mEndPoint.listen(
static_cast<uint16>(mPort), stdec);
// The address overload takes the port as a service string.
306 mEndPoint.listen(mIPAddress, utility::stringFormat(
"%d",
static_cast<uint16>(mPort)), stdec);
312 error.fail(stdec.message());
317 mEndPoint.start_accept(stdec);
320 error.fail(stdec.message());
// NOTE(review): the definition introduced by this template header is not visible in
// this excerpt.
332 template<
typename config>
// Sends a text `message` to one connection using `code` as the websocket opcode. On
// failure the websocketpp error message is written to `error`.
339 template<
typename config>
342 std::error_code stdec;
343 mEndPoint.send(connection.mConnection, message,
static_cast<wspp::OpCode>(code), stdec);
346 error.fail(stdec.message());
// Same as above, for `length` bytes starting at `payload`.
353 template<
typename config>
356 std::error_code stdec;
357 mEndPoint.send(connection.mConnection, payload, length,
static_cast<wspp::OpCode>(code), stdec);
360 error.fail(stdec.message());
// Broadcast of a text message: sends to every tracked connection while holding
// mConnectionMutex. The handling of a failed send is on lines not visible here.
367 template<
typename config>
371 std::lock_guard<std::mutex> lock(mConnectionMutex);
372 for (
auto& connection : mConnections)
374 if (!send(connection, message, code, error))
// Broadcast of a raw payload: same as above, for `length` bytes at `payload`.
381 template<
typename config>
385 std::lock_guard<std::mutex> lock(mConnectionMutex);
386 for (
auto& connection : mConnections)
388 if (!send(connection, payload, length, code, error))
// Returns the remote host of `connection`, or "" when the handle no longer resolves to
// a connection.
395 template<
typename config>
398 std::error_code stdec;
399 auto cptr = mEndPoint.get_con_from_hdl(connection.mConnection, stdec);
400 return stdec ?
"" : cptr->get_host();
// Appends the remote host of every tracked connection to outHosts, under
// mConnectionMutex (no clearing of outHosts is visible).
// NOTE(review): stdec is declared once, outside the loop. websocketpp's
// get_con_from_hdl appears to set the code only on failure and not reset it on success
// -- verify. If so, and the hidden line between the lookup and emplace_back tests stdec,
// one stale handle would make every later connection look failed as well. Declaring
// stdec inside the loop would avoid that.
403 template<
typename config>
407 std::error_code stdec;
408 std::lock_guard<std::mutex> lock(mConnectionMutex);
409 outHosts.reserve(mConnections.size());
410 for (
auto& connection : mConnections)
412 auto cptr = mEndPoint.get_con_from_hdl(connection, stdec);
414 outHosts.emplace_back(cptr->get_host());
// Number of tracked connections, read under mConnectionMutex.
419 template<
typename config>
422 std::lock_guard<std::mutex> lock(mConnectionMutex);
423 return mConnections.size();
// Whether another client may connect. A negative mConnectionLimit is handled before the
// lock is taken; its result is on a line not visible here (presumably "unlimited").
// Otherwise the current count must be below the limit.
// NOTE(review): size() is unsigned and mConnectionLimit is signed. The comparison is
// only well-defined in meaning because negative limits are filtered out above.
426 template<
typename config>
429 if (mConnectionLimit < 0)
432 std::lock_guard<std::mutex> lock(mConnectionMutex);
433 return mConnections.size() < mConnectionLimit;
// NOTE(review): the definition introduced by this template header is not visible in
// this excerpt.
437 template<
typename config>
// Shutdown sequence: stops listening, then disconnects all clients. Errors from either
// step are logged with mID as prefix rather than propagated.
445 template<
typename config>
451 std::error_code stdec;
452 mEndPoint.stop_listening(stdec);
456 nap::Logger::error(
"%s: %s", mID.c_str(), stdec.message().c_str());
461 if (!disconnect(napec))
464 nap::Logger::error(
"%s: %s", mID.c_str(), napec.
toString().c_str());
// Requires mServerTask to be valid here (presumably the future of the run loop --
// confirm).
471 assert(mServerTask.valid());
// NOTE(review): the definition introduced by this template header is not visible in
// this excerpt.
478 template<
typename config>
// Initializes the underlying websocketpp end point, in this order:
//   - logging channels
//   - address reuse
//   - asio (failure reported via errorState)
//   - the message handler
//   - the cache of client ticket hashes
485 template<
typename config>
// When connection updates are logged, enable every access channel except frame
// payloads; otherwise log access failures only.
490 mAccessLogLevel = mLogConnectionUpdates ? websocketpp::log::alevel::all ^ websocketpp::log::alevel::frame_payload
491 : websocketpp::log::alevel::fail;
// Clear all channels first so that only the configured levels remain enabled.
494 mEndPoint.clear_error_channels(websocketpp::log::elevel::all);
495 mEndPoint.set_error_channels(mLogLevel);
497 mEndPoint.clear_access_channels(websocketpp::log::alevel::all);
498 mEndPoint.set_access_channels(mAccessLogLevel);
501 mEndPoint.set_reuse_addr(mAllowPortReuse);
504 std::error_code stdec;
505 mEndPoint.init_asio(stdec);
508 errorState.
fail(stdec.message());
521 mEndPoint.set_message_handler(std::bind(
523 std::placeholders::_1, std::placeholders::_2
// Cache the hash of every configured client ticket. Later access checks then only need
// a lookup in mClientHashes.
528 for (
const auto& ticket : mClients)
529 mClientHashes.emplace(ticket->toHash());
// Connection-opened handler. It resolves the handle (logging a failed lookup) and
// records the connection in mConnections under mConnectionMutex. The lines between the
// lookup and the insertion are not visible in this excerpt.
534 template<
typename config>
537 std::error_code stdec;
538 auto cptr = mEndPoint.get_con_from_hdl(connection, stdec);
541 nap::Logger::error(stdec.message());
547 std::lock_guard<std::mutex> lock(mConnectionMutex);
548 mConnections.emplace_back(cptr);
// Connection-closed handler. It forwards the connection's error code value and message
// to listeners, then removes the matching entry from mConnections under
// mConnectionMutex.
// NOTE(review): the find_if predicate captures and reuses the outer stdec.
// get_con_from_hdl appears not to clear the code on success (verify). If so, and the
// hidden check after the lookup inside the predicate tests stdec, one stale handle would
// make every later element look like an error, and the closed connection might never be
// erased. A predicate-local std::error_code would avoid that.
556 template<
typename config>
559 std::error_code stdec;
560 auto cptr = mEndPoint.get_con_from_hdl(connection, stdec);
563 nap::Logger::error(stdec.message());
569 listener->onConnectionClosed(
WebSocketConnection(connection), cptr->get_ec().value(), cptr->get_ec().message());
573 std::lock_guard<std::mutex> lock(mConnectionMutex);
// Match the stored handle by the connection object it resolves to.
574 auto found_it = std::find_if(mConnections.begin(), mConnections.end(), [&](
const auto& it)
576 auto client_ptr = mEndPoint.get_con_from_hdl(it, stdec);
579 nap::Logger::error(stdec.message());
582 return cptr == client_ptr;
586 if (found_it == mConnections.end())
591 mConnections.erase(found_it);
// Connection-failed handler: forwards the connection's error code value and message to
// listeners. A failed handle lookup is logged.
596 template<
typename config>
599 std::error_code stdec;
600 auto cptr = mEndPoint.get_con_from_hdl(connection, stdec);
603 nap::Logger::error(stdec.message());
608 listener->onConnectionFailed(
WebSocketConnection(connection), cptr->get_ec().value(), cptr->get_ec().message());
// NOTE(review): the definition introduced by this template header is not visible in
// this excerpt.
612 template<
typename config>
// HTTP handler that issues access tickets (signature not visible in this excerpt).
// - Every response carries Access-Control-Allow-Origin.
// - OPTIONS is answered as a CORS preflight (204); any method other than POST gets 405.
// - A POST body must be a JSON object with string members "user" and "pass". These
//   populate a WebSocketTicket, which in Reserved mode must match one of mClientHashes.
// - On success, ticket_str is returned as the body with 200 OK.
621 template<
typename config>
625 std::error_code stdec;
626 auto conp = mEndPoint.get_con_from_hdl(con, stdec);
// A failed lookup leaves conp null, so there is no connection to attach an HTTP status
// to: only log, and never dereference conp on this path.
629 nap::Logger::error(stdec.message());
635 conp->append_header(
"Access-Control-Allow-Origin", mAccessAllowControlOrigin);
638 std::string method = conp->get_request().get_method();
// CORS preflight: advertise the accepted methods and request headers.
641 if (method.compare(
"OPTIONS") == 0)
643 conp->set_status(websocketpp::http::status_code::no_content);
644 conp->append_header(
"Access-Control-Allow-Methods",
"OPTIONS, POST");
645 conp->append_header(
"Access-Control-Allow-Headers",
"Content-Type");
650 if (method.compare(
"POST") != 0)
652 conp->set_status(websocketpp::http::status_code::method_not_allowed,
653 "only OPTIONS and POST requests are allowed");
// Without an access policy there is nothing to issue a ticket for.
658 if (mMode == EAccessMode::EveryOne)
660 conp->set_status(websocketpp::http::status_code::conflict,
661 "unable to generate ticket, no access policy set");
666 std::string body = conp->get_request_body();
669 rapidjson::Document document;
670 rapidjson::ParseResult parse_result = document.Parse(body.c_str());
671 if (!parse_result || !document.IsObject())
673 conp->set_status(websocketpp::http::status_code::bad_request,
674 "unable to parse as JSON");
// "user" and "pass" are client-controlled. Besides being present they must be strings:
// GetString() below is only valid on string values (RapidJSON asserts otherwise).
679 if (!document.HasMember("user") || !document["user"].IsString())
681 conp->append_header(
"WWW-Authenticate",
"NAPUserPass");
682 conp->set_status(websocketpp::http::status_code::unauthorized,
683 "missing or invalid member: 'user'");
688 if (!document.HasMember("pass") || !document["pass"].IsString())
690 conp->append_header(
"WWW-Authenticate",
"NAPUserPass");
691 conp->set_status(websocketpp::http::status_code::unauthorized,
692 "missing or invalid member: 'pass'");
700 ticket.
mID = math::generateUUID();
701 ticket.
mUsername = document[
"user"].GetString();
702 ticket.
mPassword = document[
"pass"].GetString();
704 if (!ticket.
init(error))
706 conp->append_header(
"WWW-Authenticate",
"NAPUserPass");
707 conp->set_status(websocketpp::http::status_code::unauthorized,
708 utility::stringFormat(
"invalid username or password: %s", error.toString().c_str()));
// In Reserved mode only pre-registered credentials (matched by ticket hash) may obtain
// a ticket.
714 if (mMode == EAccessMode::Reserved)
718 if (mClientHashes.find(ticket.
toHash()) == mClientHashes.end())
720 conp->append_header(
"WWW-Authenticate",
"NAPUserPass");
721 conp->set_status(websocketpp::http::status_code::unauthorized,
722 "invalid username or password");
// ticket_str is filled on lines not visible here (presumably the serialized ticket);
// a failure at that step is logged and answered with 500.
729 std::string ticket_str;
732 nap::Logger::error(error.toString());
733 conp->set_status(websocketpp::http::status_code::internal_server_error);
738 conp->set_body(ticket_str);
739 conp->set_status(websocketpp::http::status_code::ok);
// Handshake validation handler (signature not visible in this excerpt).
// - When acceptsNewConnections() is false, answers 403.
// - In EveryOne mode, the sub-protocol handling is partly on lines not visible here.
// - In the other modes, the first requested sub-protocol must decode to a
//   WebSocketTicket; otherwise answers 403.
// - In Ticket mode, that ticket's hash must also be present in mClientHashes.
743 template<
typename config>
747 std::error_code stdec;
748 auto conp = mEndPoint.get_con_from_hdl(con, stdec);
// A failed lookup leaves conp null, so no HTTP status can be set on it: only log, and
// never dereference conp on this path.
751 nap::Logger::error(stdec.message());
757 if (!acceptsNewConnections())
759 conp->set_status(websocketpp::http::status_code::forbidden,
760 "client connection count exceeded");
765 const std::vector<std::string>& sub_protocol = conp->get_requested_subprotocols();
768 if (mMode == EAccessMode::EveryOne)
770 if (sub_protocol.empty())
774 conp->set_status(websocketpp::http::status_code::not_found,
"invalid sub-protocol");
780 if (sub_protocol.empty())
782 conp->set_status(websocketpp::http::status_code::forbidden,
783 "unable to extract ticket");
// The ticket travels as the first requested sub-protocol; select it for the handshake.
788 conp->select_subprotocol(sub_protocol[0]);
793 WebSocketTicket* client_ticket = WebSocketTicket::fromBinaryString(sub_protocol[0], result, error);
794 if (client_ticket ==
nullptr)
796 conp->set_status(websocketpp::http::status_code::forbidden,
797 "first sub-protocol argument is not a valid ticket object");
802 if (mMode == EAccessMode::Ticket)
806 if (mClientHashes.find(client_ticket->
toHash()) == mClientHashes.end())
808 conp->set_status(websocketpp::http::status_code::forbidden,
809 "not a valid ticket");
// NOTE(review): the definition introduced by this template header is not visible in
// this excerpt.
818 template<
typename config>
// Closes every tracked connection with status going_away and reason "disconnected",
// then clears mConnections, all under mConnectionMutex. Close failures are written to
// `error`.
825 template<
typename config>
828 std::lock_guard<std::mutex> lock(mConnectionMutex);
830 for (
auto& connection : mConnections)
// A fresh error code per connection, so one failure does not taint the next close.
832 std::error_code stdec;
833 mEndPoint.close(connection, websocketpp::close::status::going_away,
"disconnected", stdec);
836 error.fail(stdec.message());
840 mConnections.clear();
// Registers `server` as a listener (stored by address) under mListenerMutex.
845 template<
typename config>
848 std::unique_lock<std::mutex> lock(mListenerMutex);
849 mListeners.push_back(&server);
// Unregisters every entry equal to &server (erase-remove idiom) under mListenerMutex.
853 template<
typename config>
856 std::unique_lock<std::mutex> lock(mListenerMutex);
857 mListeners.erase(std::remove(mListeners.begin(), mListeners.end(), &server), mListeners.end());