From 1cf4ebb231f2f7770b717a5e176d7bb5cbc66284 Mon Sep 17 00:00:00 2001 From: Adam Date: Thu, 8 Jul 2010 22:19:13 -0400 Subject: [PATCH] Added an epoll socket engine --- data/example.conf | 21 +- docs/Changes.conf | 1 + include/config.h | 3 + include/extern.h | 2 +- include/modules.h | 2 +- include/services.h | 7 +- include/socketengine.h | 54 +++ include/sockets.h | 216 +++++---- src/Makefile | 2 +- src/config.cpp | 2 + src/core/m_socketengine_epoll.cpp | 155 +++++++ src/core/m_socketengine_select.cpp | 134 ++++++ src/core/os_modlist.cpp | 22 + src/init.cpp | 3 + src/main.cpp | 12 +- src/modules/ssl/m_ssl.cpp | 10 +- src/sockets.cpp | 681 ++++++++++++++--------------- 17 files changed, 864 insertions(+), 463 deletions(-) create mode 100644 include/socketengine.h create mode 100644 src/core/m_socketengine_epoll.cpp create mode 100644 src/core/m_socketengine_select.cpp diff --git a/data/example.conf b/data/example.conf index ce38541bf..d66e98611 100644 --- a/data/example.conf +++ b/data/example.conf @@ -287,11 +287,6 @@ options */ encryption = "enc_none enc_sha1 enc_sha256 enc_md5 enc_old" - /* - * The maximum length of passwords - */ - passlen = 32 - /* * The database modules are used for saving and loading databases for Anope. * @@ -303,6 +298,22 @@ options */ database = "db_plain" + /* + * The socket engine modules are used for managing connections to and from Anope + * + * Supported: + * - m_socketengine_select + * - m_socketengine_epoll + * + * We recommend using epoll if your operating system supports it. + */ + socketengine = "m_socketengine_epoll" + + /* + * The maximum length of passwords + */ + passlen = 32 + /* * These keys are used to initiate the random number generator. These numbers * MUST be random as you want your passcodes to be random. Don't give these diff --git a/docs/Changes.conf b/docs/Changes.conf index cb613c434..bad0cf57b 100644 --- a/docs/Changes.conf +++ b/docs/Changes.conf @@ -8,6 +8,7 @@ Various nickserv/saset/* and chanserv/saset/* opertype command privileges added nickserv:modules added many new ns_set_command modules chanserv:modules added many new cs_set_command modules opertype:commands added nickserv/saset/* and chanserv/saset/* +options:socketengine added to choose what socket engine to use ** MODIFIED CONFIGURATION DIRECTIVES ** opertype:commands changed operserv/sgline to opserv/snline diff --git a/include/config.h b/include/config.h index f6f794a33..8b1d42527 100644 --- a/include/config.h +++ b/include/config.h @@ -802,6 +802,9 @@ class ServerConfig /* Reason to akill clients for defcon */ char *DefConAkillReason; + /* The socket engine in use */ + ci::string SocketEngine; + /* User keys to use for generating random hashes for pass codes etc */ long unsigned int UserKey1; long unsigned int UserKey2; diff --git a/include/extern.h b/include/extern.h index e8019a0b4..a59ebc2b4 100644 --- a/include/extern.h +++ b/include/extern.h @@ -400,7 +400,7 @@ E int exception_add(User *u, const char *mask, const int limit, const char *reas /**** sockets.cpp ****/ -E SocketEngine socketEngine; +E SocketEngineBase *SocketEngine; E int32 TotalRead; E int32 TotalWritten; diff --git a/include/modules.h b/include/modules.h index cb5e6dd28..d2f20dff4 100644 --- a/include/modules.h +++ b/include/modules.h @@ -155,7 +155,7 @@ enum ModuleReturn /** Priority types which can be returned from Module::Prioritize() */ enum Priority { PRIORITY_FIRST, PRIORITY_DONTCARE, PRIORITY_LAST, PRIORITY_BEFORE, PRIORITY_AFTER }; -enum MODType { CORE, PROTOCOL, THIRD, SUPPORTED, QATESTED, ENCRYPTION, DATABASE }; +enum MODType { CORE, PROTOCOL, THIRD, SUPPORTED, QATESTED, ENCRYPTION, DATABASE, SOCKETENGINE }; struct Message; extern CoreExport std::multimap MessageMap; diff --git a/include/services.h b/include/services.h index 34041df2f..36a0937c3 100644 --- a/include/services.h +++ b/include/services.h @@ -300,10 +300,6 @@ class DatabaseException : public CoreException /*************************************************************************/ -#include "sockets.h" - -/*************************************************************************/ - /** Class with the ability to keep flags on items, they should extend from this * where T is an enum. */ @@ -354,6 +350,9 @@ template class Flags } }; +#include "sockets.h" +#include "socketengine.h" + /*************************************************************************/ template diff --git a/include/socketengine.h b/include/socketengine.h new file mode 100644 index 000000000..042d4b410 --- /dev/null +++ b/include/socketengine.h @@ -0,0 +1,54 @@ +/* + * + * (C) 2003-2010 Anope Team + * Contact us at team@anope.org + * + * Please read COPYING and README for furhter details. + * + * Based on the original code of Epona by Lara. + * Based on the original code of Services by Andy Church. + */ + +#ifndef SOCKETENGINE_H +#define SOCKETENGINE_H + +class CoreExport SocketEngineBase +{ + public: + /* Map of sockets */ + std::map Sockets; + + /** Default constructor + */ + SocketEngineBase() { } + + /** Default destructor + */ + virtual ~SocketEngineBase() { } + + /** Add a socket to the internal list + * @param s The socket + */ + virtual void AddSocket(Socket *s) { } + + /** Delete a socket from the internal list + * @param s The socket + */ + virtual void DelSocket(Socket *s) { } + + /** Mark a socket as writeable + * @param s The socket + */ + virtual void MarkWriteable(Socket *s) { } + + /** Unmark a socket as writeable + * @param s The socket + */ + virtual void ClearWriteable(Socket *s) { } + + /** Read from sockets and do things + */ + virtual void Process() { } +}; + +#endif // SOCKETENGINE_H diff --git a/include/sockets.h b/include/sockets.h index a705baab6..92279f81f 100644 --- a/include/sockets.h +++ b/include/sockets.h @@ -1,5 +1,6 @@ -/* - * (C) 2004-2010 Anope Team +/* + * + * (C) 2003-2010 Anope Team * Contact us at team@anope.org * * Please read COPYING and README for furhter details. @@ -11,14 +12,14 @@ #ifndef SOCKETS_H #define SOCKETS_H +#define NET_BUFSIZE 65535 + #ifdef _WIN32 # define CloseSocket closesocket #else # define CloseSocket close #endif -#define NET_BUFSIZE 65536 - class SocketException : public CoreException { public: @@ -33,90 +34,69 @@ class SocketException : public CoreException virtual ~SocketException() throw() { } }; -class CoreExport Socket +enum SocketType +{ + SOCKTYPE_CLIENT, + SOCKTYPE_LISTEN +}; + +enum SocketFlag +{ + SF_DEAD +}; + +class CoreExport Socket : public Flags { private: - /** Read from the socket - * @param buf Buffer to read to + /** Really recieve something from the buffer + * @param buf The buf to read to * @param sz How much to read * @return Number of bytes recieved */ - virtual int RecvInternal(char *buf, size_t sz) const; + virtual const int RecvInternal(char *buf, size_t sz) const; - /** Write to the socket + /** Really write something to the socket * @param buf What to write - * @return Number of bytes sent, -1 on error + * @return Number of bytes written */ - virtual int SendInternal(const std::string &buf) const; + virtual const int SendInternal(const std::string &buf) const; protected: - /* Socket FD */ - int Sock; - /* Host this socket is connected to */ - std::string TargetHost; - /* Port we're connected to */ - int Port; - /* IP this socket is bound to */ - std::string BindHost; - /* Is this an IPv6 socket? */ + /* Socket FD */ + int sock; + /* IPv6? */ bool IPv6; - - /* Messages to be written to the socket */ + /* Things to be written to the socket */ std::string WriteBuffer; - /* Part of a message not totally yet recieved */ + /* Part of a message sent from the server, but not totally recieved */ std::string extrabuf; - /* How much data was recieved from the socket */ + /* How much data was received from this socket */ size_t RecvLen; public: - /** Default constructor - * @param nTargetHost Hostname to connect to - * @param nPort Port to connect to - * @param nBindHos Host to bind to when connecting - * @param nIPv6 true to use IPv6 - */ - Socket(const std::string &nTargetHost, int nPort, const std::string &nBindHost = "", bool nIPv6 = false); + /* Type this socket is */ + SocketType Type; + /** Default constructor + * @param nsock The socket to use, 0 if we need to create our own + * @param nIPv6 true if using ipv6 + */ + Socket(int nsock, bool nIPv6); + /** Default destructor */ virtual ~Socket(); /** Get the socket FD for this socket - * @return The fd + * @return the fd */ - virtual int GetSock() const; + int GetSock() const; /** Check if this socket is IPv6 * @return true or false */ bool IsIPv6() const; - /** Called when there is something to be read from thie socket - * @return true on success, false to kill this socket - */ - virtual bool ProcessRead(); - - /** Called when this socket becomes writeable - * @return true on success, false to drop this socket - */ - virtual bool ProcessWrite(); - - /** Called when there is an error on this socket - */ - virtual void ProcessError(); - - /** Called with a message recieved from the socket - * @param buf The message - * @return true on success, false to kill this socket - */ - virtual bool Read(const std::string &buf); - - /** Write to the socket - * @param message The message to write - */ - void Write(const char *message, ...); - void Write(std::string &message); - /** Get the length of the read buffer * @return The length of the read buffer */ @@ -126,59 +106,105 @@ class CoreExport Socket * @return The length of the write buffer */ size_t WriteBufferLen() const; + + /** Called when there is something to be recieved for this socket + * @return true on success, false to drop this socket + */ + virtual bool ProcessRead(); + + /** Called when there is something to be written to this socket + * @return true on success, false to drop this socket + */ + virtual bool ProcessWrite(); + + /** Called when there is an error for this socket + * @return true on success, false to drop this socket + */ + virtual void ProcessError(); + + /** Called with a line recieved from the socket + * @param buf The line + * @return true to continue reading, false to drop the socket + */ + virtual bool Read(const std::string &buf); + + /** Write to the socket + * @param message The message + */ + void Write(const char *message, ...); + void Write(const std::string &message); }; -class CoreExport SocketEngine +class CoreExport ClientSocket : public Socket { - private: - /* List of sockets that need to be deleted */ - std::set OldSockets; - /* FDs to read */ - fd_set ReadFDs; - /* FDs that want writing */ - fd_set WriteFDs; - /* Max FD */ - int MaxFD; + protected: + /* Target host we're connected to */ + std::string TargetHost; + /* Target port we're connected to */ + int Port; + /* The host to bind to */ + std::string BindHost; - /** Unmark a socket as writeable - * @param s The socket - */ - void ClearWriteable(Socket *s); public: - /* Set of sockets */ - std::set Sockets; + /** Constructor + * @param nTargetHost The target host to connect to + * @param nPort The target port to connect to + * @param nBindHost The host to bind to for connecting + * @param nIPv6 true to use IPv6 + */ + ClientSocket(const std::string &nTargetHost, int nPort, const std::string &nBindHost, bool nIPv6); + + /** Default destructor + */ + virtual ~ClientSocket(); + + /** Called with a line recieved from the socket + * @param buf The line + * @return true to continue reading, false to drop the socket + */ + virtual bool Read(const std::string &buf); +}; + +class CoreExport ListenSocket : public Socket +{ + protected: + /* Bind IP */ + std::string BindIP; + /* Port to bind to */ + int Port; + + public: /** Constructor + * @param bind The IP to bind to + * @param port The port to listen on */ - SocketEngine(); + ListenSocket(const std::string &bind, int port); /** Destructor */ - virtual ~SocketEngine(); + virtual ~ListenSocket(); - /** Add a socket to the socket engine - * @param s The socket + /** Process what has come in from the connection + * @return false to destory this socket */ - void AddSocket(Socket *s); + bool ProcessRead(); - /** Delete a socket from the socket engine - * @param s The socket + /** Called when a connection is accepted + * @param s The socket for the new connection + * @return true if the listen socket should remain alive */ - void DelSocket(Socket *s); + virtual bool OnAccept(Socket *s); - /** Mark a socket as wanting to be written to - * @param s The socket + /** Get the bind IP for this socket + * @return the bind ip */ - void MarkWriteable(Socket *s); + const std::string &GetBindIP() const; - /** Called to iterate through each socket and check for activity - */ - void Process(); - - /** Get the last socket error - * @return The error - */ - const std::string GetError() const; + /** Get the port this socket is bound to + * @return The port + */ + const int GetPort() const; }; -#endif // SOCKETS_H +#endif // SOCKET_H diff --git a/src/Makefile b/src/Makefile index d4ce6625e..e50218ab3 100644 --- a/src/Makefile +++ b/src/Makefile @@ -70,7 +70,7 @@ send.o: send.cpp $(INCLUDES) servers.o: servers.cpp $(INCLUDES) sessions.o: sessions.cpp $(INCLUDES) slist.o: slist.cpp $(INCLUDES) -sockets.o: sockets.cpp $(INCLUDES) +sockets.o: sockets.cpp $(INCLUDES) threadengine.o: threadengine.cpp $(INCLUDES) threadengine_pthread.o: threadengine_pthread.cpp $(INCLUDES) timers.o: timers.cpp $(INCLUDES) diff --git a/src/config.cpp b/src/config.cpp index 93a65f9ba..5785c4452 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -21,6 +21,7 @@ ServerConfig Config; static ci::string Modules; static ci::string EncModules; static ci::string DBModules; +static ci::string SocketEngineModule; static ci::string HostCoreModules; static ci::string MemoCoreModules; static ci::string BotCoreModules; @@ -629,6 +630,7 @@ int ServerConfig::Read(bool bail) {"options", "encryption", "", new ValueContainerCIString(&EncModules), DT_CISTRING | DT_NORELOAD, ValidateNotEmpty}, {"options", "passlen", "32", new ValueContainerUInt(&Config.PassLen), DT_UINTEGER | DT_NORELOAD, NoValidation}, {"options", "database", "", new ValueContainerCIString(&DBModules), DT_CISTRING | DT_NORELOAD, ValidateNotEmpty}, + {"options", "socketengine", "", new ValueContainerCIString(&Config.SocketEngine), DT_CISTRING | DT_NORELOAD, ValidateNotEmpty}, {"options", "userkey1", "0", new ValueContainerLUInt(&Config.UserKey1), DT_LUINTEGER, NoValidation}, {"options", "userkey2", "0", new ValueContainerLUInt(&Config.UserKey2), DT_LUINTEGER, NoValidation}, {"options", "userkey3", "0", new ValueContainerLUInt(&Config.UserKey3), DT_LUINTEGER, NoValidation}, diff --git a/src/core/m_socketengine_epoll.cpp b/src/core/m_socketengine_epoll.cpp new file mode 100644 index 000000000..95bc60926 --- /dev/null +++ b/src/core/m_socketengine_epoll.cpp @@ -0,0 +1,155 @@ +#include "module.h" +#include +#include + +class SocketEngineEPoll : public SocketEngineBase +{ + private: + long max; + int EngineHandle; + epoll_event *events; + unsigned SocketCount; + + public: + SocketEngineEPoll() + { + SocketCount = 0; + max = ulimit(4, 0); + + if (max <= 0) + { + Alog() << "Can't determine maximum number of open sockets"; + throw ModuleException("Can't determine maximum number of open sockets"); + } + + EngineHandle = epoll_create(max / 4); + + if (EngineHandle == -1) + { + Alog() << "Could not initialize epoll socket engine: " << strerror(errno); + throw ModuleException("Could not initialize epoll socket engine: " + std::string(strerror(errno))); + } + + events = new epoll_event[max]; + memset(events, 0, sizeof(epoll_event) * max); + } + + ~SocketEngineEPoll() + { + delete [] events; + } + + void AddSocket(Socket *s) + { + epoll_event ev; + + memset(&ev, 0, sizeof(ev)); + + ev.events = EPOLLIN | EPOLLOUT; + ev.data.fd = s->GetSock(); + + if (epoll_ctl(EngineHandle, EPOLL_CTL_ADD, ev.data.fd, &ev) == -1) + { + Alog() << "Unable to add fd " << ev.data.fd << " to socketengine epoll: " << strerror(errno); + return; + } + + Sockets.insert(std::make_pair(ev.data.fd, s)); + + ++SocketCount; + } + + void DelSocket(Socket *s) + { + epoll_event ev; + + memset(&ev, 0, sizeof(ev)); + + ev.data.fd = s->GetSock(); + + if (epoll_ctl(EngineHandle, EPOLL_CTL_DEL, ev.data.fd, &ev) == -1) + { + Alog() << "Unable to delete fd " << ev.data.fd << " from socketengine epoll: " << strerror(errno); + return; + } + + Sockets.erase(ev.data.fd); + + --SocketCount; + } + + void Process() + { + int total = epoll_wait(EngineHandle, events, max - 1, (Config.ReadTimeout * 1000)); + + if (total == -1) + { + Alog() << "SockEngine::Process(): error " << strerror(errno); + return; + } + + for (int i = 0; i < total; ++i) + { + epoll_event *ev = &events[i]; + Socket *s = Sockets[ev->data.fd]; + + if (ev->events & (EPOLLHUP | EPOLLERR)) + { + s->ProcessError(); + s->SetFlag(SF_DEAD); + continue; + } + + if (ev->events & EPOLLIN) + { + if (!s->ProcessRead()) + { + s->SetFlag(SF_DEAD); + } + } + + if (ev->events & EPOLLOUT) + { + if (!s->ProcessWrite()) + { + s->SetFlag(SF_DEAD); + } + } + } + + for (std::map::iterator it = Sockets.begin(), it_end = Sockets.end(); it != it_end;) + { + Socket *s = it->second; + ++it; + + if (s->HasFlag(SF_DEAD)) + { + delete s; + } + } + } +}; + +class ModuleSocketEngineEPoll : public Module +{ + SocketEngineEPoll *engine; + + public: + ModuleSocketEngineEPoll(const std::string &modname, const std::string &creator) : Module(modname, creator) + { + this->SetPermanent(true); + this->SetType(SOCKETENGINE); + + engine = new SocketEngineEPoll(); + SocketEngine = engine; + } + + ~ModuleSocketEngineEPoll() + { + delete engine; + SocketEngine = NULL; + } +}; + +MODULE_INIT(ModuleSocketEngineEPoll) + diff --git a/src/core/m_socketengine_select.cpp b/src/core/m_socketengine_select.cpp new file mode 100644 index 000000000..c7346f87c --- /dev/null +++ b/src/core/m_socketengine_select.cpp @@ -0,0 +1,134 @@ +#include "module.h" + +class SocketEngineSelect : public SocketEngineBase +{ + private: + /* Max Read FD */ + int MaxFD; + /* Read FDs */ + fd_set ReadFDs; + /* Write FDs */ + fd_set WriteFDs; + + public: + SocketEngineSelect() + { + MaxFD = 0; + FD_ZERO(&ReadFDs); + FD_ZERO(&WriteFDs); + } + + ~SocketEngineSelect() + { + FD_ZERO(&ReadFDs); + FD_ZERO(&WriteFDs); + } + + void AddSocket(Socket *s) + { + if (s->GetSock() > MaxFD) + MaxFD = s->GetSock(); + FD_SET(s->GetSock(), &ReadFDs); + Sockets.insert(std::make_pair(s->GetSock(), s)); + } + + void DelSocket(Socket *s) + { + if (s->GetSock() == MaxFD) + --MaxFD; + FD_CLR(s->GetSock(), &ReadFDs); + FD_CLR(s->GetSock(), &WriteFDs); + Sockets.erase(s->GetSock()); + } + + void MarkWriteable(Socket *s) + { + FD_SET(s->GetSock(), &WriteFDs); + } + + void ClearWriteable(Socket *s) + { + FD_CLR(s->GetSock(), &WriteFDs); + } + + void Process() + { + fd_set rfdset = ReadFDs, wfdset = WriteFDs, efdset = ReadFDs; + timeval tval; + tval.tv_sec = Config.ReadTimeout; + tval.tv_usec = 0; + + int sresult = select(MaxFD + 1, &rfdset, &wfdset, &efdset, &tval); + + if (sresult == -1) + { +#ifdef WIN32 + errno = WSAGetLastError(); +#endif + Alog() << "SockEngine::Process(): error" << strerror(errno); + } + else if (sresult) + { + for (std::map::const_iterator it = Sockets.begin(), it_end = Sockets.end(); it != it_end; ++it) + { + Socket *s = it->second; + + if (FD_ISSET(s->GetSock(), &efdset)) + { + s->ProcessError(); + s->SetFlag(SF_DEAD); + continue; + } + if (FD_ISSET(s->GetSock(), &rfdset)) + { + if (!s->ProcessRead()) + { + s->SetFlag(SF_DEAD); + } + } + if (FD_ISSET(s->GetSock(), &wfdset)) + { + if (!s->ProcessWrite()) + { + s->SetFlag(SF_DEAD); + } + } + } + + for (std::map::iterator it = Sockets.begin(), it_end = Sockets.end(); it != it_end;) + { + Socket *s = it->second; + ++it; + + if (s->HasFlag(SF_DEAD)) + { + delete s; + } + } + } + } +}; + +class ModuleSocketEngineSelect : public Module +{ + SocketEngineSelect *engine; + + public: + ModuleSocketEngineSelect(const std::string &modname, const std::string &creator) : Module(modname, creator) + { + this->SetPermanent(true); + this->SetType(SOCKETENGINE); + + engine = new SocketEngineSelect(); + SocketEngine = engine; + } + + ~ModuleSocketEngineSelect() + { + delete engine; + SocketEngine = NULL; + } +}; + +MODULE_INIT(ModuleSocketEngineSelect) + diff --git a/src/core/os_modlist.cpp b/src/core/os_modlist.cpp index c719b734d..8764188da 100644 --- a/src/core/os_modlist.cpp +++ b/src/core/os_modlist.cpp @@ -30,6 +30,7 @@ class CommandOSModList : public Command int showSupported = 1; int showQA = 1; int showDB = 1; + int showSocketEngine = 1; ci::string param = params.size() ? params[0] : ""; @@ -40,6 +41,7 @@ class CommandOSModList : public Command char supported[] = "Supported"; char qa[] = "QATested"; char db[] = "Database"; + char socketengine[] = "SocketEngine"; if (!param.empty()) { @@ -52,6 +54,7 @@ class CommandOSModList : public Command showSupported = 0; showQA = 0; showDB = 0; + showSocketEngine = 0; } else if (param == third) { @@ -62,6 +65,7 @@ class CommandOSModList : public Command showProto = 0; showEnc = 0; showDB = 0; + showSocketEngine = 0; } else if (param == proto) { @@ -72,6 +76,7 @@ class CommandOSModList : public Command showSupported = 0; showQA = 0; showDB = 0; + showSocketEngine = 0; } else if (param == supported) { @@ -82,6 +87,7 @@ class CommandOSModList : public Command showEnc = 0; showQA = 0; showDB = 0; + showSocketEngine = 0; } else if (param == qa) { @@ -92,6 +98,7 @@ class CommandOSModList : public Command showEnc = 0; showQA = 1; showDB = 0; + showSocketEngine = 0; } else if (param == enc) { @@ -102,6 +109,7 @@ class CommandOSModList : public Command showEnc = 1; showQA = 0; showDB = 0; + showSocketEngine = 0; } else if (param == db) { @@ -112,6 +120,12 @@ class CommandOSModList : public Command showEnc = 0; showQA = 0; showDB = 1; + showSocketEngine = 0; + } + else if (param == socketengine) + { + showCore = showThird = showProto = showSupported = showEnc = showQA = showDB = 0; + showSocketEngine = 1; } } @@ -171,6 +185,14 @@ class CommandOSModList : public Command notice_lang(Config.s_OperServ, u, OPER_MODULE_LIST, m->name.c_str(), m->version.c_str(), db); ++count; } + break; + case SOCKETENGINE: + if (showSocketEngine) + { + notice_lang(Config.s_OperServ, u, OPER_MODULE_LIST, m->name.c_str(), m->version.c_str(), socketengine); + ++count; + } + break; } } if (!count) diff --git a/src/init.cpp b/src/init.cpp index 07597c829..b0f019ad3 100644 --- a/src/init.cpp +++ b/src/init.cpp @@ -352,6 +352,9 @@ int init_primary(int ac, char **av) /* Add Database Modules */ ModuleManager::LoadModuleList(Config.DBModuleList); + /* Load the socket engine */ + ModuleManager::LoadModule(Config.SocketEngine, NULL); + return 0; } diff --git a/src/main.cpp b/src/main.cpp index 7634a0e15..dccfcc15f 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -109,16 +109,18 @@ class UpdateTimer : public Timer Socket *UplinkSock = NULL; -class UplinkSocket : public Socket +class UplinkSocket : public ClientSocket { public: - UplinkSocket(const std::string &nTargetHost, int nPort, const std::string &nBindHost = "", bool nIPv6 = false) : Socket(nTargetHost, nPort, nBindHost, nIPv6) + UplinkSocket(const std::string &nTargetHost, int nPort, const std::string &nBindHost = "", bool nIPv6 = false) : ClientSocket(nTargetHost, nPort, nBindHost, nIPv6) { UplinkSock = this; } ~UplinkSocket() { + /* Process the last bits of data before disconnecting */ + SocketEngine->Process(); UplinkSock = NULL; } @@ -189,8 +191,6 @@ void do_restart_services() UserListByUID.erase(it->second->GetUID().c_str()); } ircdproto->SendSquit(Config.ServerName, quitmsg); - /* Process to send the last bits of information before disconnecting */ - socketEngine.Process(); delete UplinkSock; close_log(); /* First don't unload protocol module, then do so */ @@ -239,8 +239,6 @@ static void services_shutdown() while (!UserListByNick.empty()) delete UserListByNick.begin()->second; } - /* Process to send the last bits of information before disconnecting */ - socketEngine.Process(); delete UplinkSock; FOREACH_MOD(I_OnShutdown, OnShutdown()); /* First don't unload protocol module, then do so */ @@ -538,7 +536,7 @@ int main(int ac, char **av, char **envp) ModeManager::ProcessModes(); /* Process the socket engine */ - socketEngine.Process(); + SocketEngine->Process(); } if (quitting) diff --git a/src/modules/ssl/m_ssl.cpp b/src/modules/ssl/m_ssl.cpp index 5bc9870da..0771e6d4c 100644 --- a/src/modules/ssl/m_ssl.cpp +++ b/src/modules/ssl/m_ssl.cpp @@ -14,22 +14,22 @@ static SSL_CTX *ctx; -class SSLSocket : public Socket +class SSLSocket : public ClientSocket { private: SSL *sslsock; - int RecvInternal(char *buf, size_t sz) const + const int RecvInternal(char *buf, size_t sz) const { return SSL_read(sslsock, buf, sz); } - int SendInternal(const std::string &buf) const + const int SendInternal(const std::string &buf) const { return SSL_write(sslsock, buf.c_str(), buf.size()); } public: - SSLSocket(const std::string &nTargetHost, int nPort, const std::string &nBindHost = "", bool nIPv6 = false) : Socket(nTargetHost, nPort, nBindHost, nIPv6) + SSLSocket(const std::string &nTargetHost, int nPort, const std::string &nBindHost = "", bool nIPv6 = false) : ClientSocket(nTargetHost, nPort, nBindHost, nIPv6) { sslsock = SSL_new(ctx); @@ -37,7 +37,7 @@ class SSLSocket : public Socket throw CoreException("Unable to initialize SSL socket"); SSL_set_connect_state(sslsock); - SSL_set_fd(sslsock, Sock); + SSL_set_fd(sslsock, sock); SSL_connect(sslsock); UplinkSock = this; diff --git a/src/sockets.cpp b/src/sockets.cpp index db51d7b03..15592c7d6 100644 --- a/src/sockets.cpp +++ b/src/sockets.cpp @@ -1,11 +1,11 @@ #include "services.h" -SocketEngine socketEngine; +SocketEngineBase *SocketEngine; int32 TotalRead = 0; int32 TotalWritten = 0; /** Trims all the \r and \ns from the begining and end of a string - * @return A string without trailing \r and \ns + * @param buffer The buffer to trim */ static void TrimBuf(std::string &buffer) { @@ -15,148 +15,54 @@ static void TrimBuf(std::string &buffer) buffer.erase(buffer.length() - 1); } -/** Default constructor - * @param nTargetHost Hostname to connect to - * @param nPort Port to connect to - * @param nBindHos Host to bind to when connecting - * @param nIPv6 true to use IPv6 +/** Constructor + * @param nsock The socket + * @param nIPv6 IPv6? */ -Socket::Socket(const std::string &nTargetHost, int nPort, const std::string &nBindHost, bool nIPv6) : TargetHost(nTargetHost), Port(nPort), BindHost(nBindHost), IPv6(nIPv6) +Socket::Socket(int nsock, bool nIPv6) { - if (!IPv6 && (TargetHost.find(':') != std::string::npos || BindHost.find(':') != std::string::npos)) - IPv6 = true; - - Sock = socket(IPv6 ? AF_INET6 : AF_INET, SOCK_STREAM, 0); - - addrinfo hints; - hints.ai_socktype = SOCK_STREAM; - hints.ai_flags = 0; - hints.ai_protocol = IPPROTO_TCP; - hints.ai_family = IPv6 ? AF_INET6 : AF_INET; - - if (!BindHost.empty()) - { - addrinfo *bindar; - sockaddr_in bindaddr; - sockaddr_in6 bindaddr6; - - int Bound = -1; - if (!getaddrinfo(BindHost.c_str(), NULL, &hints, &bindar)) - { - if (IPv6) - memcpy(&bindaddr6, bindar->ai_addr, bindar->ai_addrlen); - else - memcpy(&bindaddr, bindar->ai_addr, bindar->ai_addrlen); - - freeaddrinfo(bindar); - - Bound = bind(Sock, reinterpret_cast(&bindaddr), sizeof(bindaddr)); - } - if (Bound < 0) - { - if (IPv6) - { - bindaddr6.sin6_family = AF_INET6; - - if (inet_pton(AF_INET6, BindHost.c_str(), &bindaddr6.sin6_addr) < 1) - throw SocketException("Invalid bind host"); - - if (bind(Sock, reinterpret_cast(&bindaddr6), sizeof(bindaddr6)) == -1) - throw SocketException("Unable to bind to address"); - } - else - { - bindaddr.sin_family = AF_INET; - - if (inet_pton(bindaddr.sin_family, BindHost.c_str(), &bindaddr.sin_addr) < 1) - throw SocketException("Invalid bind host"); - - if (bind(Sock, reinterpret_cast(&bindaddr), sizeof(bindaddr)) == -1) - throw SocketException("Unable to bind to address"); - } - } - } - - addrinfo *conar; - sockaddr_in conaddr; - sockaddr_in6 conaddr6; - if (!getaddrinfo(TargetHost.c_str(), NULL, &hints, &conar)) - { - if (IPv6) - memcpy(&conaddr6, conar->ai_addr, conar->ai_addrlen); - else - memcpy(&conaddr, conar->ai_addr, conar->ai_addrlen); - - freeaddrinfo(conar); - } + Type = SOCKTYPE_CLIENT; + IPv6 = nIPv6; + if (nsock == 0) + sock = socket(IPv6 ? AF_INET6 : AF_INET, SOCK_STREAM, 0); else - { - if (IPv6) - { - if (inet_pton(AF_INET6, TargetHost.c_str(), &conaddr6.sin6_addr) < 1) - throw SocketException("Invalid server address"); - } - else - { - if (inet_pton(AF_INET, TargetHost.c_str(), &conaddr.sin_addr) < 1) - throw SocketException("Invalid server address"); - } - } - - if (IPv6) - { - conaddr6.sin6_family = AF_INET6; - conaddr6.sin6_port = htons(Port); - - if (connect(Sock, reinterpret_cast(&conaddr6), sizeof(conaddr6)) < 0) - throw SocketException("Error connecting to server"); - } - else - { - conaddr.sin_family = AF_INET; - conaddr.sin_port = htons(Port); - - if (connect(Sock, reinterpret_cast(&conaddr), sizeof(conaddr)) < 0) - throw SocketException("Error connecting to server"); - } - - socketEngine.AddSocket(this); + sock = nsock; + SocketEngine->AddSocket(this); } /** Default destructor - */ +*/ Socket::~Socket() { - CloseSocket(Sock); - - socketEngine.DelSocket(this); + SocketEngine->DelSocket(this); + CloseSocket(sock); } -/** Read from the socket - * @param buf Buffer to read to +/** Really recieve something from the buffer + * @param buf The buf to read to * @param sz How much to read * @return Number of bytes recieved */ -int Socket::RecvInternal(char *buf, size_t sz) const +const int Socket::RecvInternal(char *buf, size_t sz) const { return recv(GetSock(), buf, sz, 0); } -/** Write to the socket +/** Really write something to the socket * @param buf What to write - * @return Number of bytes sent, -1 on error + * @return Number of bytes written */ -int Socket::SendInternal(const std::string &buf) const +const int Socket::SendInternal(const std::string &buf) const { return send(GetSock(), buf.c_str(), buf.length(), 0); } /** Get the socket FD for this socket - * @return The fd + * @return the fd */ int Socket::GetSock() const { - return Sock; + return sock; } /** Check if this socket is IPv6 @@ -167,98 +73,6 @@ bool Socket::IsIPv6() const return IPv6; } -/** Called when there is something to be read from thie socket - * @return true on success, false to kill this socket - */ -bool Socket::ProcessRead() -{ - char buffer[NET_BUFSIZE]; - memset(&buffer, 0, sizeof(buffer)); - - RecvLen = RecvInternal(buffer, sizeof(buffer) - 1); - if (RecvLen <= 0) - return false; - TotalRead += RecvLen; - - std::string sbuffer = extrabuf; - sbuffer.append(buffer); - extrabuf.clear(); - size_t lastnewline = sbuffer.find_last_of('\n'); - if (lastnewline < sbuffer.size() - 1) - { - extrabuf = sbuffer.substr(lastnewline); - TrimBuf(extrabuf); - sbuffer = sbuffer.substr(0, lastnewline); - } - - sepstream stream(sbuffer, '\n'); - std::string buf; - - while (stream.GetToken(buf)) - { - TrimBuf(buf); - - if (!buf.empty()) - if (!Read(buf)) - return false; - } - - return true; -} - -/** Called when this socket becomes writeable - * @return true on success, false to drop this socket - */ -bool Socket::ProcessWrite() -{ - int Written = SendInternal(WriteBuffer); - if (Written == -1) - return false; - TotalWritten += Written; - - WriteBuffer.clear(); - return true; -} - -/** Called when there is an error on this socket - */ -void Socket::ProcessError() -{ -} - -/** Called with a message recieved from the socket - * @param buf The message - * @return true on success, false to kill this socket - */ -bool Socket::Read(const std::string &buf) -{ - return true; -} - -/** Write to the socket - * @param message The message to write - */ -void Socket::Write(const char *message, ...) -{ - char buf[BUFSIZE]; - va_list vi; - va_start(vi, message); - vsnprintf(buf, sizeof(buf), message, vi); - va_end(vi); - - std::string sbuf = buf; - Write(sbuf); -} - -/** Write to the socket - * @param message The message to write - */ -void Socket::Write(std::string &message) -{ - WriteBuffer.append(message + "\r\n"); - socketEngine.MarkWriteable(this); -} - /** Get the length of the read buffer * @return The length of the read buffer */ @@ -272,159 +86,338 @@ size_t Socket::ReadBufferLen() const */ size_t Socket::WriteBufferLen() const { - return WriteBuffer.size(); + return WriteBuffer.length(); +} + +/** Called when there is something to be recieved for this socket + * @return true on success, false to drop this socket + */ +bool Socket::ProcessRead() +{ + char tbuffer[NET_BUFSIZE]; + memset(&tbuffer, '\0', sizeof(tbuffer)); + + RecvLen = RecvInternal(tbuffer, sizeof(tbuffer) - 1); + if (RecvLen <= 0) + return false; + + std::string sbuffer = extrabuf; + sbuffer.append(tbuffer); + extrabuf.clear(); + size_t lastnewline = sbuffer.find_last_of('\n'); + if (lastnewline < sbuffer.size() - 1) + { + extrabuf = sbuffer.substr(lastnewline); + TrimBuf(extrabuf); + sbuffer = sbuffer.substr(0, lastnewline); + } + + sepstream stream(sbuffer, '\n'); + + std::string tbuf; + while (stream.GetToken(tbuf)) + { + TrimBuf(tbuf); + + if (!tbuf.empty()) + if (!Read(tbuf)) + return false; + } + + return true; +} + +/** Called when there is something to be written to this socket + * @return true on success, false to drop this socket + */ +bool Socket::ProcessWrite() +{ + if (WriteBuffer.empty()) + { + return true; + } + if (SendInternal(WriteBuffer) == -1) + { + return false; + } + WriteBuffer.clear(); + SocketEngine->ClearWriteable(this); + + return true; +} + +/** Called when there is an error for this socket + * @return true on success, false to drop this socket + */ +void Socket::ProcessError() +{ +} + +/** Called with a line recieved from the socket + * @param buf The line + * @return true to continue reading, false to drop the socket + */ +bool Socket::Read(const std::string &buf) +{ + return false; +} + +/** Write to the socket + * @param message The message + */ +void Socket::Write(const char *message, ...) +{ + va_list vi; + char tbuffer[BUFSIZE]; + std::string sbuf; + + if (!message) + return; + + va_start(vi, message); + vsnprintf(tbuffer, sizeof(tbuffer), message, vi); + va_end(vi); + + sbuf = tbuffer; + Write(sbuf); +} + +/** Write to the socket + * @param message The message + */ +void Socket::Write(const std::string &message) +{ + WriteBuffer.append(message + "\r\n"); + SocketEngine->MarkWriteable(this); } /** Constructor + * @param nLS The listen socket this connection came from + * @param nu The user using this socket + * @param nsock The socket + * @param nIPv6 IPv6 */ -SocketEngine::SocketEngine() +ClientSocket::ClientSocket(const std::string &nTargetHost, int nPort, const std::string &nBindHost, bool nIPv6) : Socket(0, nIPv6), TargetHost(nTargetHost), Port(nPort), BindHost(nBindHost) { - FD_ZERO(&ReadFDs); - FD_ZERO(&WriteFDs); - MaxFD = 0; + if (!IPv6 && (TargetHost.find(':') != std::string::npos || BindHost.find(':') != std::string::npos)) + IPv6 = true; + + addrinfo hints; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = 0; + hints.ai_protocol = IPPROTO_TCP; + hints.ai_family = IPv6 ? AF_INET6 : AF_INET; -#ifdef _WIN32 - WSADATA wsa; - if (WSAStartup(MAKEWORD(2, 0), &wsa)) - Alog() << "Failed to initialize WinSock library"; -#endif + if (!BindHost.empty()) + { + addrinfo *bindar; + sockaddr_in bindaddr; + sockaddr_in6 bindaddr6; + + if (getaddrinfo(BindHost.c_str(), NULL, &hints, &bindar) == 0) + { + if (IPv6) + memcpy(&bindaddr6, bindar->ai_addr, bindar->ai_addrlen); + else + memcpy(&bindaddr, bindar->ai_addr, bindar->ai_addrlen); + + freeaddrinfo(bindar); + } + else + { + if (IPv6) + { + bindaddr6.sin6_family = AF_INET6; + + if (inet_pton(AF_INET6, BindHost.c_str(), &bindaddr6.sin6_addr) < 1) + throw SocketException("Invalid bind host: " + std::string(strerror(errno))); + } + else + { + bindaddr.sin_family = AF_INET; + + if (inet_pton(AF_INET, BindHost.c_str(), &bindaddr.sin_addr) < 1) + throw SocketException("Invalid bind host: " + std::string(strerror(errno))); + } + } + + if (IPv6) + { + if (bind(sock, reinterpret_cast(&bindaddr6), sizeof(bindaddr6)) == -1) + throw SocketException("Unable to bind to address: " + std::string(strerror(errno))); + } + else + { + if (bind(sock, reinterpret_cast(&bindaddr), sizeof(bindaddr)) == -1) + throw SocketException("Unable to bind to address: " + std::string(strerror(errno))); + } + } + + addrinfo *conar; + sockaddr_in6 addr6; + sockaddr_in addr; + + if (getaddrinfo(TargetHost.c_str(), NULL, &hints, &conar) == 0) + { + if (IPv6) + memcpy(&addr6, conar->ai_addr, conar->ai_addrlen); + else + memcpy(&addr, conar->ai_addr, conar->ai_addrlen); + + freeaddrinfo(conar); + } + else + { + if (IPv6) + { + if (inet_pton(AF_INET6, TargetHost.c_str(), &addr6.sin6_addr) < 1) + throw SocketException("Invalid server host: " + std::string(strerror(errno))); + } + else + { + if (inet_pton(AF_INET, TargetHost.c_str(), &addr.sin_addr) < 1) + throw SocketException("Invalid server host: " + std::string(strerror(errno))); + } + } + + if (IPv6) + { + addr6.sin6_family = AF_INET6; + addr6.sin6_port = htons(nPort); + + if (connect(sock, reinterpret_cast(&addr6), sizeof(addr6)) == -1) + { + throw SocketException("Error connecting to server: " + std::string(strerror(errno))); + } + } + else + { + addr.sin_family = AF_INET; + addr.sin_port = htons(nPort); + + if (connect(sock, reinterpret_cast(&addr), sizeof(addr)) == -1) + { + throw SocketException("Error connecting to server: " + std::string(strerror(errno))); + } + } +} + +/** Default destructor + */ +ClientSocket::~ClientSocket() +{ +} + +/** Called with a line recieved from the socket + * @param buf The line + * @return true to continue reading, false to drop the socket + */ +bool ClientSocket::Read(const std::string &buf) +{ + return true; +} + +/** Constructor + * @param bind The IP to bind to + * @param port The port to listen on + */ +ListenSocket::ListenSocket(const std::string &bindip, int port) : Socket(0, (bindip.find(':') != std::string::npos ? true : false)) +{ + Type = SOCKTYPE_LISTEN; + BindIP = bindip; + Port = port; + + sockaddr_in sock_addr; + sockaddr_in6 sock_addr6; + + if (IPv6) + { + sock_addr6.sin6_family = AF_INET6; + sock_addr6.sin6_port = htons(port); + + if (inet_pton(AF_INET6, bindip.c_str(), &sock_addr6.sin6_addr) < 1) + { + throw SocketException("Invalid bind host: " + std::string(strerror(errno))); + } + } + else + { + sock_addr.sin_family = AF_INET; + sock_addr.sin_port = htons(port); + + if (inet_pton(AF_INET, bindip.c_str(), &sock_addr.sin_addr) < 1) + { + throw SocketException("Invalid bind host: " + std::string(strerror(errno))); + } + } + + if (IPv6) + { + if (bind(sock, reinterpret_cast(&sock_addr6), sizeof(sock_addr6)) == -1) + { + throw SocketException("Unable to bind to address: " + std::string(strerror(errno))); + } + } + else + { + if (bind(sock, reinterpret_cast(&sock_addr), sizeof(sock_addr)) == -1) + { + throw SocketException("Unable to bind to address: " + std::string(strerror(errno))); + } + } + + if (listen(sock, 5) == -1) + { + throw SocketException("Unable to listen: " + std::string(strerror(errno))); + } } /** Destructor */ -SocketEngine::~SocketEngine() +ListenSocket::~ListenSocket() { -#ifdef _WIN32 - WSACleanup(); +} + +/** Accept a connection in this sockets queue + */ +bool ListenSocket::ProcessRead() +{ + int newsock = accept(sock, NULL, NULL); + +#ifndef _WIN32 +# define INVALID_SOCKET 0 #endif -} -/** Add a socket to the socket engine - * @param s The socket - */ -void SocketEngine::AddSocket(Socket *s) -{ - if (s->GetSock() > MaxFD) - MaxFD = s->GetSock(); - FD_SET(s->GetSock(), &ReadFDs); - Sockets.insert(s); -} - -/** Delete a socket from the socket engine - * @param s The socket - */ -void SocketEngine::DelSocket(Socket *s) -{ - if (s->GetSock() == MaxFD) - --MaxFD; - FD_CLR(s->GetSock(), &ReadFDs); - FD_CLR(s->GetSock(), &WriteFDs); - Sockets.erase(s); -} - -/** Mark a socket as wanting to be written to - * @param s The socket - */ -void SocketEngine::MarkWriteable(Socket *s) -{ - FD_SET(s->GetSock(), &WriteFDs); -} - -/** Unmark a socket as writeable - * @param s The socket - */ -void SocketEngine::ClearWriteable(Socket *s) -{ - FD_CLR(s->GetSock(), &WriteFDs); -} - -/** Called to iterate through each socket and check for activity - */ -void SocketEngine::Process() -{ - fd_set rfdset = ReadFDs, wfdset = WriteFDs, efdset = ReadFDs; - timeval tval; - - tval.tv_sec = Config.ReadTimeout; - tval.tv_usec = 0; - - int sresult = select(MaxFD + 1, &rfdset, &wfdset, &efdset, &tval); - - if (sresult == -1) - Alog() << "SocketEngine::Process error, " << GetError(); - else if (sresult) + if (newsock > 0 && newsock != INVALID_SOCKET) { - for (std::set::iterator it = Sockets.begin(); it != Sockets.end(); ++it) - { - Socket *s = *it; - - if (FD_ISSET(s->GetSock(), &efdset)) - { - s->ProcessError(); - OldSockets.insert(s); - continue; - } - if (FD_ISSET(s->GetSock(), &rfdset)) - { - if (!s->ProcessRead()) - OldSockets.insert(s); - } - if (FD_ISSET(s->GetSock(), &wfdset)) - { - ClearWriteable(s); - if (!s->ProcessWrite()) - OldSockets.insert(s); - } - } + return this->OnAccept(new Socket(newsock, IPv6)); } - while (!OldSockets.empty()) - { - delete (*OldSockets.begin()); - OldSockets.erase(OldSockets.begin()); - } + return true; } -/** Get the last socket error - * @return The error +/** Called when a connection is accepted + * @param s The socket for the new connection + * @return true if the listen socket should remain alive */ -const std::string SocketEngine::GetError() const +bool ListenSocket::OnAccept(Socket *s) { -#ifdef _WIN32 - errno = WSAGetLastError(); -#endif - switch (errno) - { - case EBADF: - return "Socket error, invalid file descriptor given to select()"; - break; - case EINTR: - return "Socket engine caught signal"; - break; -#ifdef WIN32 - case WSANOTINITIALISED: - return "A successful WSAStartup call must occur before using this function."; - break; - case WSAEFAULT: - return "The Windows Sockets implementation was unable to allocate needed resources for its internal operations, or the readfds, writefds, exceptfds, or timeval parameters are not part of the user address space."; - break; - case WSAENETDOWN: - return "The network subsystem has failed."; - break; - case WSAEINVAL: - return "The time-out value is not valid, or all three descriptor parameters were null."; - break; - case WSAEINTR: - return "A blocking Windows Socket 1.1 call was canceled through WSACancelBlockingCall."; - break; - case WSAEINPROGRESS: - return "A blocking Windows Sockets 1.1 call is in progress, or the service provider is still processing a callback function."; - break; - case WSAENOTSOCK: - return "One of the descriptor sets contains an entry that is not a socket."; - break; -#endif - default: - return "Socket engine caught unknown error"; - } + return true; } + +/** Get the bind IP for this socket + * @return the bind ip + */ +const std::string &ListenSocket::GetBindIP() const +{ + return BindIP; +} + +/** Get the port this socket is bound to + * @return The port + */ +const int ListenSocket::GetPort() const +{ + return Port; +} +