自学内容网 自学内容网

【项目】仿muduo库One Thread One Loop式主从Reactor模型实现高并发服务器

本篇博客记录从0到1实现一个仿mudo库的One Thread One Loop式主从Reactor模型的高并发服务器组件。

在此之前我们要明确的是,该项目仅作为一个高并发服务器组件,因此该项目并不包含实际的业务需求处理内容。

前置知识背景

一、HTTP服务器

概念:

HTTP(Hyper Text Transfer Protocol),超文本传输协议是应用层协议,是一种简单的请求—响应协议(客户端根据自己的需要向服务器发送请求,服务器针对请求提供对应服务,响应完毕后通信结束)。

通过TCP/IP五层网络模型可以看出来,HTTP协议是⼀个运行在TCP协议之上的应用层协议。HTTP服务器本质上还是一个TCP服务器,只是在应用层上使用HTTP协议格式来进行数据的组织和解析,从而明确客户端的请求,并完成相应的业务处理

二、Reactor模型

Reactor模型可以分为三种类型:① 单Reactor单线程,② 单Reactor多线程,③ 多Reactor多线程

因为多Reactor多线程的Reactor模型更适合用来实现高并发服务器,所以当前知识点着重介绍多Reactor多线程模型

多Reactor多线程:

        一个Reactor线程仅进行获取新连接的处理,其他多个Reactor线程处理IO操作,IO Reactor线程拿到数据后,将数据分发给线程池种的业务线程进行业务处理。

| Reactor多线程模型框架流程图 |

1. 在主Reactor线程中处理新连接请求事件,有新连接到来则分发到从Reactor线程中监控。

2. 在从Reactor中进行客户端通信监控,有事件触发,则接收数据分发给业务线程池。

3. 业务线程池分配独立的线程进行具体的业务处理,业务线程处理完毕后,将响应交给从Reactor线程,让从Reactor线程向客户端进行数据响应。

多Reactor多线程模型的优点:充分利用CPU资源,并且可以合理分配(获取新连接的处理压力大时,可以适当增加主Reactor线程的数量;IO操作处理压力大时,可以适当增加从Reactor线程的数量。可根据服务器不同的环节来合理调配资源)。

※ 需要注意的是,执行流并不是越多越好,过多的执行流反而会增加CPU切换调度的成本。

三、项目目标定位

该项目旨在实现一个主从Reactor模型服务器组件,也就是主Reactor线程仅监控监听描述符,获取新建连接,保证获取新连接的高效性,提高服务器的并发性能。主Reactor线程获取到新连接后分发给从Reactor线程进行通信事件监控。而从Reactor线程监控各自的描述符的读写事件进行数据读写以及业务处理。

One Thread One Loop的思想:把所有的操作都放到一个线程中进行,一个线程对应一个事件处理的循环。

由于该项目仅仅是实现一个高并发的服务器组件,不涉及实际的业务,所以舍去了多Reactor多线程模型中的业务线程池部分。

让从线程循环以下三个操作:① IO事件监控,② IO操作,③ 业务处理

当前组件只实现主从Reactor线程,后续组件使用者可根据实际的业务需求来判断是否有必要完善业务线程池。


项目模块划分

基于以上的理解,我们要实现的是⼀个带有协议支持的Reactor模型高性能服务器组件,因此将整个项目的实现划分为以下两个大的模块:

  • SERVER模块:实现Reactor模型的TCP服务器
  • 协议模块:对当前的Reactor模型服务器提供应用层协议支持

SERVER模块

SERVER模块对所有的连接以及线程进行管理,具体管理分为以下三个方面:

  • 监听连接管理:对监听连接进行管理
  • 通信连接管理:对通信连接进行管理
  • 超时连接管理:对超时连接进行管理(处理一些恶意连接)

基于以上的管理思想,将SERVER模块进行细致划分又可以粉为以下多个子模块:

Buffer模块

Buffer模块是⼀个缓冲区模块,用于实现通信中用户态的接收缓冲区和发送缓冲区功能。

Socket模块

Socket模块是对套接字操作封装的一个模块,主要实现套接字的各项操作。

Channel模块

Channel模块是对⼀个描述符需要进行的IO事件管理的模块,实现对描述符可读,可写,错误事件的管理操作,以及Poller模块对描述符进行IO事件监控就绪后,根据不同的事件回调不同的处理函数功能。

Connection模块

Connection模块是对Buffer模块,Socket模块,Channel模块的一个整体封装,实现了对一个通信套接字的整体的管理,每一个进行数据通信的套接字(也就是accept获取到的新连接)都会使用Connection模块进行管理。

  • Connection模块内部包含有三个由组件使用者传入的回调函数:连接建立完成回调,事件回调,新数据回调,关闭回调
  • Connection模块内部包含有两个组件使用者提供的接口:数据发送接口,连接关闭接口
  • Connection模块内部包含有两个用户态缓冲区:用户态接收缓冲区,用户态发送缓冲区
  • Connection模块内部包含有一个Socket对象:完成描述符面向系统的IO操作
  • Connection模块内部包含有一个Channel对象:完成描述符IO事件就绪的处理

| 具体处理流程 |

  1. 实现向Channel提供可读,可写,错误等不同事件的IO事件回调函数,然后将Channel和对应的描述符添加到Poller事件监控中。
  2. 当描述符在Poller模块中就绪了IO可读事件,则调用描述符对应Channel中保存的读事件处理函数,进行数据读取,将socket接收缓冲区全部读取到Connection管理的用户态接收缓冲区中。然后调用由组件使用者传入的新数据到来回调函数进行处理。
  3. 组件使用者进行数据的业务处理完毕后,通过Connection向使用者提供的数据发送接口,将数据写入Connection的发送缓冲区中。
  4. 启动描述符在Poller模块中的IO写事件监控,就绪后调用Channel中保存的写事件处理函数,将发送缓冲区中的数据通过Socket进行面向系统的实际数据发送。

Acceptor模块

Acceptor模块是对Socket模块,Channel模块的⼀个整体封装,实现了对一个监听套接字的整体管理。

  • Acceptor模块内部包含有一个Socket对象:实现监听套接字的操作
  • Acceptor模块内部包含有一个Channel对象:实现监听套接字IO事件就绪的处理

| 具体处理流程 |

  1. 实现向Channel提供可读事件的IO事件处理回调函数,函数的功能其实也就是获取新连接。
  2. 为新连接实例化⼀个Connection对象。

TimerQueue模块

TimerQueue模块是实现固定时间定时任务的模块,可以理解为就是一个定时任务管理器,向定时任务管理器中添加一个任务,任务将在固定时间后被执行,同时也可以通过刷新定时任务来延迟任务的执行。

这个模块主要是对Connection对象的生命周期进行管理,对非活跃连接进行超时后的释放功能。

  • TimerQueue模块内部包含有一个timerfd:linux系统提供的定时器
  • TimerQueue模块内部包含有一个Channel对象:实现对timerfd的IO时间就绪回调处理

Poller模块

Poller模块是对epoll进行封装的一个模块,主要实现epoll的IO事件添加,修改,移除,获取活跃连接功能。

EventLoop模块

EventLoop模块可以理解就是我们上边所说的Reactor模块,它是对Poller模块,TimerQueue模块, Socket模块的一个整体封装,进行所有描述符的事件监控。

EventLoop模块必然是一个对象对应一个线程的模块,线程内部的目的就是运行EventLoop的启动函数。

EventLoop模块为了保证整个服务器的线程安全问题,因此要求使用者对于Connection的所有操作一定要在其对应的EventLoop线程内完成,不能在其他线程中进行(比如组件使用者使用Connection发送数据,以及关闭连接这种操作)。

EventLoop模块保证自己内部所监控的所有描述符,都要是活跃连接,非活跃连接就要及时释放避免资源浪费。

  • EventLoop模块内部包含有一个eventfd:eventfd其实就是linux内核提供的一个事件fd,专门用于事件通知
  • EventLoop模块内部包含有一个Poller对象:用于进行描述符的IO事件监控
  • EventLoop模块内部包含有一个TimerQueue对象:用于进行定时任务的管理
  • EventLoop模块内部包含有一个PendingTask队列:组件使用者将对Connection进行的所有操作, 都加入到任务队列中,由EventLoop模块进行管理,并在EventLoop对应的线程中进行执行

每一个Connection对象都会绑定到一个EventLoop线程上,这样能保证对这个连接的所有操作都是在一个线程中完成的。

| 具体处理流程 |

  1. 通过Poller模块对当前模块管理内的所有描述符进行IO事件监控,有描述符事件就绪后,通过描述符对应的Channel进行事件处理。
  2. 所有就绪的描述符IO事件处理完毕后,对任务队列中的所有操作顺序进行执行。
  3. 由于epoll的事件监控,有可能会因为没有事件到来而持续阻塞,导致任务队列中的任务不能及时得到执行,因此创建了eventfd,添加到Poller的事件监控中,用于实现每次向任务队列添加任务的时候,通过向eventfd写入数据来唤醒epoll的阻塞。

TcpServer模块

该模块是一个整体TCP服务器模块的封装,内部封装了Acceptor模块,EventLoopThreadPool模块。

  • TcpServer中包含有一个EventLoop对象:以备在超轻量使用场景中不需要EventLoop线程池,只需要在主线程中完成所有操作的情况
  • TcpServer模块内部包含有一个EventLoopThreadPool对象:其实就是EventLoop线程池,也就是从Reactor线程池
  • TcpServer模块内部包含有一个Acceptor对象:一个TcpServer服务器,必然对应有一个监听套接字,能够完成获取客户端新连接,并处理的任务
  • TcpServer模块内部包含有一个std::shared_ptr<Connection>的hash表:保存了所有的新建连接对应的Connection,注意!所有的Connection使用shared_ptr进行管理,这样能够保证在hash表中删除了Connection信息后,在shared_ptr计数器为0的情况下完成对Connection资源的释放操作。

| 具体处理流程 |

  1. 在实例化TcpServer对象过程中,完成BaseLoop的设置,Acceptor对象的实例化,以及EventLoop线程池的实例化,以及std::shared_ptr<Connection>的hash表的实例化。
  2. 为Acceptor对象设置回调函数,获取到新连接后,为新连接构建Connection对象,设置 Connection的各项回调,并使用shared_ptr进行管理,并添加到hash表中进行管理,并为Connection选择一个EventLoop线程,为Connection添加一个定时销毁任务,为Connection添加事件监控。
  3. 启动BaseLoop。

模块关系图


代码实现

前置知识技术及实例

简单的秒级定时任务

在当前的高并发服务器中,我们不得不考虑一个问题,那就是连接的超时关闭问题。我们需要避免一个连接长时间不通信,但是也不关闭,空耗资源的情况。 这时候我们就需要一个定时任务,定时的将超时过期的连接进行释放。

| Linux提供的定时器 |

创建定时器:

int timerfd_create(int clockid, int flags);

clockid:CLOCK_REALTIME--系统实时时间,如果修改了系统时间就会出问题; 
         CLOCK_MONOTONIC--从开机到现在的时间是一种相对时间;

flags:0--默认阻塞属性;

返回值:文件描述符

定时器的原理:每隔一段时间(定时器的超时时间),系统就会给这个描述符对应的定时器写入一个8字节数据

启动定时器:

int timerfd_settime(int fd, int flags, struct itimerspec *new, struct itimerspec *old);

fd:timerfd_create返回的⽂件描述符,即创建的定时器的标识符

flags:0--相对时间, 1--绝对时间;默认设置为0即可

new: ⽤于设置定时器的新超时时间

old: ⽤于接收原来的超时时间

struct timespec
{
    time_t tv_sec; // 秒
    long tv_nsec;  // 纳秒
};

struct itimerspec
{
    struct timespec it_interval; // 第⼀次之后的超时间隔时间(每隔多长时间提醒一次)
    struct timespec it_value;    // 第⼀次超时时间
};

| 实例 |

#include <iostream>
#include <unistd.h>
#include <fcntl.h>
#include <sys/timerfd.h>

int main()
{
    int timerfd = timerfd_create(CLOCK_MONOTONIC, 0);
    if (timerfd < 0)
    {
        perror("timer_create error");
        return -1;
    }

    struct itimerspec itime;

    // 第一次超时时间为1s后
    itime.it_value.tv_sec = 1;
    itime.it_value.tv_nsec = 0;

    // 第一次超时后,每次超时的间隔时间为5s
    itime.it_interval.tv_sec = 5;
    itime.it_interval.tv_nsec = 0;

    timerfd_settime(timerfd, 0, &itime, nullptr);

    while (true)
    {
        uint64_t times;
        int ret = read(timerfd, &times, 8);
        if (ret < 0)
        {
            perror("read error");
            return -1;
        }

        std::cout << "距离上一次超时了" << times << "次" << std::endl;
    }

    close(timerfd);

    return 0;
}

以上实例是一个简单的定时器使用示例,初次超时时间为1s后,往后每间隔5s触发一次定时器超时。

基于这个实例,则我们可以实现每隔5s,检测一下哪些连接超时了,然后将超时的连接释放掉,避免资源被占用。

时间轮思想

时间轮的思想来源于钟表,如果我们定了一个3点钟的闹铃,则当时针走到3的时候,就代表时间到 了。

同样的道理,如果我们定义了一个数组,并且有一个指针,指向数组起始位置,这个指针每秒钟向后走动一步,走到哪里,则代表哪里的任务该被执行了,那么如果我们想要定一个3s后的任务,则只需要将任务添加到tick+3位置,则每秒中走一步,三秒钟后tick走到对应位置,这时候执行对应位置的任务即可。

但是,同一时间可能会有大批量的定时任务,因此我们可以给数组对应位置下拉一个数组,这样就可以在同一个时刻上添加多个定时任务了。

上述操作也存在一些缺陷,比如我们如果要定义一个60s后的任务,则需要将数组的元素个数设置 为60才可以,如果设置一小时后的定时任务,则需要定义3600个元素的数组,这样无疑是比较麻烦的。

因此,可以采用多层级的时间轮,有秒针轮,分针轮,时针轮。60<time<3600,则time/60就是分针轮对应存储的位置,当tick/3600等于对应位置的时候,将其位置的任务向分针轮,秒针轮进行移动。

因为当前项目中不涉及长时间的定时任务,通常设置在30s内,因此使用简单的单层时间轮即可满足需求。

但是,我们也得考虑一个问题,当前的设计是时间到了,则主动去执行定时任务,释放连接。那能不能在时间到了后,自动执行定时任务呢?这时候我们就想到一个操作--类的析构函数

一个类的析构函数,在对象被释放时会自动被执行,那么我们如果将一个定时任务作为一个类的析构函数内的操作,则这个定时任务在对象被释放的时候就会执行。

但是仅仅为了这个目的,而设计一个额外的任务类,好像有些不划算。但是,这里我们又要考虑另一个问题,那就是假如有一个连接建立成功了,我们给这个连接设置了一个30s后的定时销毁任务,但是在第10s的时候,这个连接进行了一次通信,那么我们应该是在第30s的时候关闭,还是第40s的时候关闭呢?无疑应该是第40s的时候。也就是说,这时候我们需要让这个第30s的任务失效,但是我们该如何实现这个操作呢?

| 完善时间轮设计 |

这里,我们就用到了智能指针shared_ptr,shared_ptr有个计数器,当计数为0的时候,才会真正释放一个对象。那么如果该连接在第10s进行了一次通信,则我们继续向定时任务中,添加一个30s后(也就是第40s)的任务类对象的shared_ptr,则这时候两个任务shared_ptr计数为2,则第30s的定时任务被释放的时候,计数-1,变为1,并不为0,则并不会执行实际的析构函数,那么就相当于这个第30s的任务失效了,只有在第40s的时候,这个任务才会被真正释放,进而自动执行定时任务。

| 实例 |

#include <iostream>
#include <functional>
#include <memory>
#include <vector>
#include <unordered_map>
#include <unistd.h>

using TaskFunc = std::function<void()>;
using ReleaseFunc = std::function<void()>;

// 定时器任务类
class TimerTask
{
public:
    TimerTask(uint64_t id, uint32_t delay, const TaskFunc& cb)
        :_id(id)
        , _timeout(delay)
        , _canceled(false)
        , _task_cb(cb)
    {}

    void SetRelease(const ReleaseFunc& cb)
    {
        _release = cb;
    }

    uint32_t DelayTime()
    {
        return _timeout;
    }

    void Cancel()
    {
        _canceled = true;
    }

    ~TimerTask()
    {
        if (_canceled == false) _task_cb();
        _release();
    }

private:
    uint64_t _id;         // 定时任务对象id
    uint32_t _timeout;    // 定时任务的超时时间
    bool _canceled;        // true--取消定时任务,false--不取消
    TaskFunc _task_cb;    // 定时器要执行的定时任务
    ReleaseFunc _release; // 用于删除TimerWheel中保存的定时任务对象信息(删除unordered_map中的元素)
};

// 时间轮类
class TimerWheel
{
public:
    TimerWheel()
        :_tick(0)
        , _capacity(60)
        , _wheel(_capacity)
    {}

    // 添加定时任务
    void TimerAdd(uint64_t id, uint32_t delay, const TaskFunc& cb)
    {
        SharedTask st(new TimerTask(id, delay, cb));
        st->SetRelease(std::bind(&TimerWheel::RemoveTimer, this, id));
        _timers[id] = WeakTask(st);

        // 将定时任务添加到时间轮中
        int pos = (_tick + delay) % _capacity;
        _wheel[pos].push_back(st);
    }

    // 刷新/延迟定时任务
    void TimerRefresh(uint64_t id)
    {
        // 通过定时器对象的weak_ptr构造一个shared_ptr,再添加到时间轮中
        auto it = _timers.find(id);
        if (it == _timers.end()) return;

        SharedTask st = it->second.lock();
        int delay = st->DelayTime();
        int pos = (_tick + delay) % _capacity;
        _wheel[pos].push_back(st);
    }

    void TimerCancel(uint64_t id)
    {
        auto it = _timers.find(id);
        if (it == _timers.end()) return;

        SharedTask st = it->second.lock();
        if (st) st->Cancel();
    }

    // 时间轮定时任务执行函数(该函数每秒执行一次)
    void RunTimerTask()
    {
        _tick = (_tick + 1) % _capacity;
        _wheel[_tick].clear(); // 清空指针位置上的数组,将所有管理定时任务对象的shared_ptr给释放掉
    }

private:
    void RemoveTimer(uint64_t id)
    {
        auto it = _timers.find(id);
        if (it != _timers.end())
        {
            _timers.erase(it);
        }
    }

private:
    using SharedTask = std::shared_ptr<TimerTask>;
    using WeakTask = std::weak_ptr<TimerTask>;

    int _tick;                                      // 当前的秒针(走到哪里释放哪里,释放哪里就相当于执行哪里的定时任务)
    int _capacity;                                  // 轮的最大数量(最大延迟时间)
    std::vector<std::vector<SharedTask>> _wheel;    // 时间轮
    std::unordered_map<uint64_t, WeakTask> _timers; // 将定时任务id和管理定时任务的weak_ptr绑定
};

class Test
{
public:
    Test() { std::cout << "构造" << std::endl; }
    ~Test() { std::cout << "析构" << std::endl; }
};

void DeleteTest(Test* t)
{
    delete t;
}

int main()
{
    TimerWheel tw;
    Test* t = new Test();

    tw.TimerAdd(888, 5, std::bind(DeleteTest, t));

    for (int i = 0; i < 5; ++i)
    {
        sleep(1);
        tw.TimerRefresh(888);
        tw.RunTimerTask();
        std::cout << "刷新了一下定时任务,重新需要5s后才能执行定时任务" << std::endl;
    }

    tw.TimerCancel(888); // 将定时任务取消

    while (true)
    {
        sleep(1);
        std::cout << "--------------------" << std::endl;
        tw.RunTimerTask();
    }

    return 0;
}

正则库的简单使用

正则表达式(regular expression)描述了一种字符串匹配的模式(pattern),可以用来检查一个字符串是否含有某种子串、将匹配的子串替换或者从某个字符串中取出符合某个条件的子串等。

该项目中使用正则表达式能够让HTTP请求的解析变得简单,减少编码工作量。

bool std::regex_match(const std::string& src, std::smatch& matches, std::regex& e);

src:原始字符串

matches:正则表达式可以从原始字符串中匹配并提取符合某种规则的数据,提取的数据就放在matches中(类似于数组的容器)

e:正则表达式的匹配规则

返回值:判断匹配是否成功

| 实例 |

#include <iostream>
#include <string>
#include <regex>

int main()
{
    std::string str = "/numbers/123";
    std::regex e("/numbers/(\\d+)"); // 匹配以/numbers/开头后跟一个或多个数字的字符,将匹配到的数字字符串提取出来
    std::smatch matches;

    bool ret = std::regex_match(str, matches, e);
    if (ret == false) return -1;

    for (const auto& e : matches)
    {
        std::cout << e << std::endl;
    }

    return 0;
}

※ 注意,matches中的第一个元素是/numbers/123,第二个元素才是提取出来的数字字符串123。

| 正则表达式提取HTTP请求行中的元素 |

#include <iostream>
#include <string>
#include <regex>

int main()
{
    // HTTP请求行格式:GET /home/login?user=nk&pass=123123 HTTP/1.1\r\n
    std::string str = "GET /home/login?user=nk&pass=123123 HTTP/1.1\r\n";
    std::smatch matches;

    // HTTP请求的方法匹配
    // GET POST PUT HEAD DELETE OPTIONS TRACE CONNECT LINK UNLINE
    // std::regex e("(GET|POST|PUT|HEAD|DELETE|OPTIONS|TRACE|CONNECT|LINK|UNLINE) .*");

    // HTTP请求的资源路径匹配
    // std::regex e("(GET|POST|PUT|HEAD|DELETE|OPTIONS|TRACE|CONNECT|LINK|UNLINE) ([^?]*).*");

    // HTTP请求的查询字符串匹配
    // std::regex e("(GET|POST|PUT|HEAD|DELETE|OPTIONS|TRACE|CONNECT|LINK|UNLINE) ([^?]*)\\?(.*) .*");

    // HTTP请求的协议版本匹配
    std::regex e("(GET|POST|PUT|HEAD|DELETE|OPTIONS|TRACE|CONNECT|LINK|UNLINE) ([^?]*)(?:\\?(.*))? (HTTP/1\\.[01])(?:\n|\r\n)?");

    bool ret = std::regex_match(str, matches, e);
    if (ret == false) return -1;

    for (const auto& e : matches)
    {
        std::cout << e << std::endl;
    }

    return 0;
}

通用类Any类型的实现

每一个Connection对连接进行管理,最终都不可避免需要涉及到应用层协议的处理,因此在 Connection中需要设置协议处理的上下文来控制处理节奏。服务器支持的协议后续可能会不断增多,不同的协议的上下文在结构上会存在差异,因此就需要一个通用的类型来保存各种不同的数据结构。

| 通用类Any的设计思想 |

| 实例与简单测试 |

#include <iostream>
#include <typeinfo>
#include <cassert>
#include <string>

class Any
{
public:
    Any()
        :_content(nullptr)
    {}

    template<class T>
    Any(const T& val)
        :_content(new placeholder<T>(val))
    {}

    Any(const Any& other)
        :_content(other._content ? other._content->clone() : nullptr)
    {}

    ~Any() { delete _content; }

    Any& swap(Any& other)
    {
        std::swap(_content, other._content);
        return *this;
    }

    // 获取Any对象中保存的数据的地址
    template<class T>
    T* get()
    {
        assert(typeid(T) == _content->type()); // 想要获取的数据类型必须和保存的数据类型一致
        return &((placeholder<T>*)_content)->_val;
    }

    template<class T>
    Any& operator=(const T& val)
    {
        Any(val).swap(*this);
        return *this;
    }

    Any& operator=(const Any& other)
    {
        Any(other).swap(*this);
        return *this;
    }

private:
    class holder
    {
    public:
        virtual ~holder() {}

        virtual const std::type_info& type() = 0;

        virtual holder* clone() = 0;
    };

    template <class T>
    class placeholder : public holder
    {
    public:
        placeholder(const T& val)
            :_val(val)
        {}

        // 获取子类对象保存的数据类型
        virtual const std::type_info& type() { return typeid(T); }

        // 根据当前对象克隆出一个新的子类对象
        virtual holder* clone() { return new placeholder(_val); }

    public:
        T _val;
    };

    holder* _content;
};

class Test
{
public:
    Test() { std::cout << "构造" << std::endl; }
    Test(const Test& t) { std::cout << "拷贝" << std::endl; }
    ~Test() { std::cout << "析构" << std::endl; }
};

int main()
{
    {
        Any a;
        Test t;
        a = t;
    }

    // Any a;
    // a = 10;
    // int* pa = a.get<int>();
    // std::cout << *pa << std::endl;

    // a = std::string("hello");
    // std::string* ps = a.get<std::string>();
    // std::cout << *ps << std::endl;

    return 0;
}

SERVER模块实现

Buffer模块

提供的功能:数据的存、取

实现思想:

  1. 采用vector<char>作为缓冲区的内存空间(a. 避免手动管理内存空间的释放;b. 不采用string是因为string的某些操作在遇到'\0'时会停止,而网络数据传输中可能有包含'0'的数据)
  2. 要素:a. vector<char>默认空间的大小;b. 当前读取数据的位置;c. 当前写入数据的位置
  3. 操作:

| 代码实现 |

// Buffer类
#define BUFFER_DEFAULT_SIZE 1024
class Buffer
{
public:
    Buffer()
        :_buffer(BUFFER_DEFAULT_SIZE)
        , _read_index(0)
        , _write_index(0)
    {}

    // 返回缓冲区的起始地址
    char* Begin() { return &(*_buffer.begin()); }

    // 获取当前读取位置起始地址
    char* GetReadPos() { return Begin() + _read_index; }

    // 获取当前写入位置起始地址(缓冲区起始地址加上写偏移量)
    char* GetWritePos() { return Begin() + _write_index; }

    // 获取缓冲区头部空闲空间大小--读偏移之前的空闲空间
    uint64_t HeadIdleSize() { return _read_index; }

    // 获取缓冲区末尾空闲空间大小--写偏移之后的空闲空间
    uint64_t TailIdleSize() { return _buffer.size() - _write_index; }

    // 获取可读数据大小
    uint64_t ReadableSize() { return _write_index - _read_index; }

    // 将读偏移向后移动
    void MoveReadOffset(uint64_t len)
    {
        if (len == 0) return;

        assert(len <= ReadableSize()); // 向后移动的大小必须小于可读数据的大小
        _read_index += len;
    }

    // 将写偏移向后移动
    void MoveWriteOffset(uint64_t len)
    {
        assert(len <= TailIdleSize()); // 向后移动的大小必须小于缓冲区末尾空闲空间的大小
        _write_index += len;
    }

    // 确保可写空间足够(整体的空闲空间够存数据就将数据移动,否则扩容)
    void EnsureWriteSpace(uint64_t len)
    {
        // 若末尾空闲空间足够,则直接返回
        if (TailIdleSize() >= len) return;

        // 若末尾空闲空间不够,则判断末尾空闲空间加上头部空闲空间是否足够
        // 足够则将可读数据移动到缓冲区的起始位置,给后面留出完整的空间来写入数据
        if (HeadIdleSize() + TailIdleSize() >= len)
        {
            // 将可读数据移动到缓冲区起始位置
            uint64_t readsize = ReadableSize(); // 保存缓冲区当前可读数据的大小
            std::copy(GetReadPos(), GetReadPos() + readsize, Begin()); // 将可读数据拷贝到缓冲区起始位置
            _read_index = 0; // 将读偏移归零
            _write_index = readsize; // 将写偏移置为可读数据大小
        }
        else // 总体空间不够,则进行扩容,无需移动数据,直接在写偏移之后扩容足够的空间即可
        {
            _buffer.resize(_write_index + len);
        }
    }

    // 读取数据
    void Read(void* buffer, uint64_t len)
    {
        // 读取数据的大小必须小于可读数据大小
        assert(len <= ReadableSize());
        std::copy(GetReadPos(), GetReadPos() + len, (char*)buffer);
    }

    // 读取数据并将读偏移向后移动
    void ReadAndPop(void* buffer, uint64_t len)
    {
        Read(buffer, len);
        MoveReadOffset(len);
    }

    // 将数据作为字符串读取
    std::string ReadAsString(uint64_t len)
    {
        // 读取数据的大小必须小于可读数据大小
        assert(len <= ReadableSize());

        std::string str;
        str.resize(len);

        Read(&str[0], len);

        return str;
    }

    std::string ReadAsStringAndPop(uint64_t len)
    {
        std::string str = ReadAsString(len);
        MoveReadOffset(len);

        return str;
    }

    // 写入数据
    void Write(const void* data, uint64_t len)
    {
        if (len == 0) return;

        // 1.保证缓冲区有足够的可写空间
        EnsureWriteSpace(len);

        // 2.拷贝数据
        const char* d = (const char*)data;
        std::copy(d, d + len, GetWritePos());
    }

    void WriteAndPush(const void* data, uint64_t len)
    {
        Write(data, len);
        MoveWriteOffset(len);
    }

    // 往缓冲区中写入字符串
    void WriteString(const std::string& data)
    {
        return Write(data.c_str(), data.size());
    }

    void WriteStringAndPush(const std::string& data)
    {
        WriteString(data);
        MoveWriteOffset(data.size());
    }

    // 往缓冲区中写入另一个缓冲区的数据
    void WriteBuffer(Buffer& data)
    {
        return Write(data.GetReadPos(), data.ReadableSize());
    }

    void WriteBufferAndPush(Buffer& data)
    {
        WriteBuffer(data);
        MoveWriteOffset(data.ReadableSize());
    }

    char* FindCRLF()
    {
        char* pos = (char*)memchr(GetReadPos(), '\n', ReadableSize());
        return pos;
    }

    std::string GetLine()
    {
        char* pos = FindCRLF();
        if (pos == NULL) return "";

        return ReadAsString(pos - GetReadPos() + 1); // +1是为了把'\n'也取出来
    }

    std::string GetLineAndPop()
    {
        std::string str = GetLine();
        MoveReadOffset(str.size());
        
        return str;
    }

    // 清空缓冲区
    void Clear()
    {
        // 只需要将两个偏移量归零即可
        _read_index = 0;
        _write_index = 0;
    }

private:
    std::vector<char> _buffer; // 使用vector进行内存管理
    uint64_t _read_index;      // 读偏移
    uint64_t _write_index;     // 写偏移
};

日志宏

| 代码实现 |

// 日志宏
#define INF 0
#define DBG 1
#define ERR 2
#define LOG_LEVEL DBG
#define LOG(level, format, ...)                                                             \
    do                                                                                      \
    {                                                                                       \
        if (level < LOG_LEVEL) break;                                                       \
        time_t t = time(NULL);                                                              \
        struct tm *ltm = localtime(&t);                                                     \
        char tmp[32] = {0};                                                                 \
        strftime(tmp, 31, "%H:%M:%S", ltm);                                                 \
        fprintf(stdout, "[%p %s %s:%d] " format "\n", (void*)pthread_self(), tmp, __FILE__, __LINE__, ##__VA_ARGS__); \
    } while (0)

#define INF_LOG(format, ...) LOG(INF, format, ##__VA_ARGS__)
#define DBG_LOG(format, ...) LOG(DBG, format, ##__VA_ARGS__)
#define ERR_LOG(format, ...) LOG(ERR, format, ##__VA_ARGS__)

Socket模块

Socket模块主要是对套接字的操作进行封装。

封装的接口:

  1. 创建套接字

  2. 绑定地址信息

  3. 开始监听

  4. 向服务器发起连接

  5. 获取新连接

  6. 接收数据

  7. 发送数据

  8. 关闭套接字

  9. 创建一个服务端连接

  10. 创建一个客户端连接

  11. 设置套接字选项——开启地址端口重用

  12. 设置套接字阻塞属性——非阻塞

| 代码实现 |

// Socket类
#define MAX_LISTEN 1024
class Socket
{
public:
    Socket() :_sockfd(-1) {}

    Socket(int sockfd) :_sockfd(sockfd) {}

    ~Socket() { Close(); }

    // 创建套接字
    bool Create()
    {
        // 调用socket()
        _sockfd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
        if (_sockfd == -1)
        {
            ERR_LOG("create socket failed");
            return false;
        }
        
        return true;
    }

    // 绑定地址信息
    bool Bind(const std::string& ip, uint16_t port)
    {
        // 定义一个struct sockaddr_in对象
        struct sockaddr_in addr;
        addr.sin_family = AF_INET;
        addr.sin_port = htons(port);
        addr.sin_addr.s_addr = inet_addr(ip.c_str());
        socklen_t len = sizeof(struct sockaddr_in);

        // 调用bind()
        int ret = bind(_sockfd, (struct sockaddr*)&addr, len);
        if (ret == -1)
        {
            ERR_LOG("bind address failed");
            return false;
        }

        return true;
    }

    // 开始监听
    bool Listen(int backlog = MAX_LISTEN)
    {
        // 调用listen()
        int ret = listen(_sockfd, backlog);
        if (ret == -1)
        {
            ERR_LOG("socket listen failed");
            return false;
        }

        return true;
    }

    // 向服务器发起连接(传入服务器的ip和port)
    bool Connect(const std::string& ip, uint16_t port)
    {
        // 定义一个struct sockaddr_in对象
        struct sockaddr_in addr;
        addr.sin_family = AF_INET;
        addr.sin_port = htons(port);
        addr.sin_addr.s_addr = inet_addr(ip.c_str());
        socklen_t len = sizeof(struct sockaddr_in);

        // 调用connect()
        int ret = connect(_sockfd, (struct sockaddr*)&addr, len);
        if (ret == -1)
        {
            ERR_LOG("connect server failed");
            return false;
        }

        return true;
    }

    // 获取新连接
    int Accept()
    {
        // 调用accept()
        int ret = accept(_sockfd, NULL, NULL);
        if (ret == -1)
        {
            ERR_LOG("socket accept failed");
            return -1;
        }

        return ret;
    }

    // 接收数据
    ssize_t Recv(void* buf, size_t len, int flag = 0) // flag = 0默认阻塞
    {
        // 调用recv()
        ssize_t ret = recv(_sockfd, buf, len, flag);        
        if (ret <= 0)
        {
            // ret = EAGAIN 表示当前socket的接收缓冲区中没有数据,在非阻塞的情况下才会有这个返回值
            // ret = EINTR 表示当前socket的阻塞等待被信号打断了
            if (errno == EAGAIN || errno == EINTR)
            {
                return 0; // 表示此次接收没有收到数据
            }

            ERR_LOG("socket receive failed");
            return -1;
        }

        return ret; // 返回实际接收到数据的字节数
    }

    ssize_t NonBlockRecv(void* buf, size_t len)
    {
        return Recv(buf, len, MSG_DONTWAIT); // MSG_DONTWAIT表示为非阻塞接收
    }

    // 发送数据
    ssize_t Send(const void* buf, size_t len, int flag = 0) // flag = 0默认阻塞
    {
        // 调用send()
        ssize_t ret = send(_sockfd, buf, len, flag);
        if (ret < 0)
        {
            if (errno == EAGAIN || errno == EINTR)
            {
                return 0; // 表示此次接收没有收到数据
            }

            ERR_LOG("socket send failed");
            return -1;
        }

        return ret; // 返回实际发送数据的字节数
    }

    ssize_t NonBlockSend(const void* buf, size_t len)
    {
        if (len == 0) return 0;

        return Send(buf, len, MSG_DONTWAIT); // MSG_DONTWAIT表示为非阻塞发送
    }
    
    // 关闭套接字
    void Close()
    {
        if (_sockfd != -1)
        {
            close(_sockfd);
            _sockfd = -1;
        }
    }

    // 创建一个服务端连接
    bool CreateServer(uint16_t port, const std::string& ip = "0.0.0.0", bool block_flag = false)
    {
        // 1.创建套接字,2.设置非阻塞,3.绑定ip和port,4.开始监听,5.启动地址重用
        if (Create() == false) return false;
        if (block_flag) NonBlock();
        if (Bind(ip, port) == false) return false;
        if (Listen() == false) return false;
        ReuseAddress();

        return true;
    }

    // 创建一个客户端连接
    bool CreateClient(uint16_t port, const std::string& ip)
    {
        // 1.创建套接字,2.连接服务器
        if (Create() == false) return false;
        if (Connect(ip, port) == false) return false;

        return true;
    }

    // 设置套接字选项——开启地址端口重用
    void ReuseAddress()
    {
        // 调用setsockopt()
        int opt = 1;
        setsockopt(_sockfd, SOL_SOCKET, SO_REUSEADDR, (void*)&opt, sizeof(opt));
        opt = 1;
        setsockopt(_sockfd, SOL_SOCKET, SO_REUSEPORT, (void*)&opt, sizeof(opt));
    }

    // 设置套接字阻塞属性——非阻塞
    void NonBlock()
    {
        int flag = fcntl(_sockfd, F_GETFL);
        fcntl(_sockfd, F_SETFL, flag | O_NONBLOCK);
    }

    int Fd() { return _sockfd; }

private:
    int _sockfd;
};

Channel模块

实现功能:对于一个文件描述符进行监控事件管理

模块功能:

模块成员:

// 事件处理总函数
void HandleEvent()
{}

一旦该文件描述符有事件被触发,就调用事件处理总函数,HandleEvent()内部
会对触发的事件进行判断,触发的是什么事件,从而调用对应的回调函数,进行
事件处理

| HandleEvent()中的细节 |

// 可能会释放连接的操作事件,一次只处理一个
if (_revents & EPOLLOUT)
{
    if (_write_callback) _write_callback();
}
else if (_revents & EPOLLHUP)
{
    if (_close_callback) _close_callback();
}
else if (_revents & EPOLLERR)
{
    if (_error_callback) _error_callback();
}

如果下面两个用的不是else if,而是if,可能会导致程序崩溃
如果调用写事件回调函数导致连接断开了,此时_revents & EPOLLHUP
还是为真,但连接已断开,无法进行正常的文件描述符挂断事件操作

| 代码实现 |

// Channel类
class Poller; // Poller类的声明,让Channel能够使用Poller类
class EventLoop; // EventLoop类的声明,让Channel能够使用EventLoop类
class Channel
{
    using EventCallback = std::function<void()>;
public:
    Channel(EventLoop* loop, int fd)
        :_loop(loop)
        , _fd(fd)
        , _events(0)
        , _revents(0)
    {}

    int Fd() { return _fd; }

    // 获取文件描述符想要监控的事件
    uint32_t Events() { return _events; } 

    // 设置实际就绪的事件
    void SetREvents(uint32_t events) { _revents = events; }  // EventLoop模块会调用该函数,EventLoop模块会将文件描述符实际触发的事件设置进_revents中

    // 设置文件描述符的读事件回调函数
    void SetReadCallback(const EventCallback& cb) { _read_Callback = cb; }

    // 设置文件描述符的写事件回调函数
    void SetWriteCallback(const EventCallback& cb) { _write_Callback = cb; }

    // 设置文件描述符的挂断事件回调函数
    void SetCloseCallback(const EventCallback& cb) { _close_Callback = cb; }

    // 设置文件描述符的错误事件回调函数
    void SetErrorCallback(const EventCallback& cb) { _error_Callback = cb; }

    // 设置文件描述符的任意事件回调函数
    void SetEventCallback(const EventCallback& cb) { _event_Callback = cb; }

    // 判断当前描述符是否监控了可读事件
    bool ReadAble() { return (_events & EPOLLIN); }

    // 判断当前描述符是否监控了可写事件
    bool WriteAble() { return (_events & EPOLLOUT); }

    // 启动描述符读事件监控
    void EnableRead()
    {
        _events |= EPOLLIN;

        Update();
    }

    // 启动描述符写事件监控
    void EnableWrite()
    {
        _events |= EPOLLOUT;

        Update();
    }

    // 解除描述符读事件监控
    void DisableRead()
    {
        _events &= ~EPOLLIN;

        Update();
    }

    // 解除描述符写事件监控
    void DisableWrite()
    {
        _events &= ~EPOLLOUT;

        Update();
    }

    // 解除描述符所有事件监控
    void DisableAll() { _events = 0; }

    void Update();

    // 将描述符从epoll模型中移除监控
    void Remove();

    // 事件处理总函数
    void HandleEvent()
    {
        if ((_revents & EPOLLIN) || (_revents & EPOLLRDHUP) || (_revents & EPOLLPRI))
        {
            // 文件描述符触发任意事件都要调用任意事件回调函数
            if (_read_Callback) _read_Callback();
        }

        // 可能会释放连接的操作事件,一次只处理一个
        if (_revents & EPOLLOUT)
        {
            // 文件描述符触发任意事件都要调用任意事件回调函数            
            if (_write_Callback) _write_Callback();
        }
        else if (_revents & EPOLLHUP)
        {
            if (_close_Callback) _close_Callback();
        }
        else if (_revents & EPOLLERR)
        {
            if (_error_Callback) _error_Callback();
        }

        if (_event_Callback) _event_Callback();
    }

private:
    EventLoop* _loop;

    int _fd;           // 文件描述符 
    uint32_t _events;  // 当前需要监控的事件
    uint32_t _revents; // 当前连接触发的事件

    EventCallback _read_Callback;  // 可读事件被触发的回调函数
    EventCallback _write_Callback; // 可写事件被触发的回调函数
    EventCallback _close_Callback; // 挂断事件被触发的回调函数
    EventCallback _error_Callback; // 错误事件被触发的回调函数
    EventCallback _event_Callback; // 任意事件被触发的回调函数
};

// Channel类中的两个成员函数
void Channel::Update() { return _loop->UpdateEvent(this); }
void Channel::Remove() { return _loop->RemoveEvent(this); }

Poller模块

Poller模块是对epoll操作进行封装,使描述符进行事件监控的操作更加简单。

通过epoll系列函数对文件描述符的IO事件监控。

模块功能:

  1. 添加/修改文件描述符的事件监控(不存在则添加;存在则修改)
  2. 移除文件描述符的事件监控

封装思想:

  1. 拥有一个epoll的操作句柄
  2. 拥有一个struct epoll_event结构数组,监控时保存所有的活跃事件
  3. 使用hash表管理文件描述符与文件描述符对应的Channel对象

| 代码实现 |

// Poller类
#define MAX_EPOLLEVENTS 1024
class Poller
{
public:
    Poller()
    {
        _epfd = epoll_create(MAX_EPOLLEVENTS);
        if (_epfd == -1)
        {
            ERR_LOG("epoll create failed");
            abort(); // 退出程序
        }
    }

    // 添加或修改文件描述符的监控事件
    void UpdateEvent(Channel* channel)
    {
        // 先判断是否在hash中管理,不在则添加;在则修改
        bool ret = HasChannel(channel);
        if (ret == false)
        {
            // 将该channel和其对应的文件描述符添加到_channels中
            _channels.insert(std::make_pair(channel->Fd(), channel));

            return Update(channel, EPOLL_CTL_ADD);
        }

        return Update(channel, EPOLL_CTL_MOD);
    }

    // 移除文件描述符的事件监控
    void RemoveEvent(Channel* channel)
    {
        // 从hash中移除
        auto it = _channels.find(channel->Fd());
        if (it != _channels.end())
        {
            _channels.erase(it);
        }

        // 删除监控事件
        Update(channel, EPOLL_CTL_DEL);
    }

    // 开始监控,返回活跃连接
    void Poll(std::vector<Channel*>* active)
    {
        // 调用epoll_wait()
        int nfds = epoll_wait(_epfd, _evs, MAX_EPOLLEVENTS, -1);
        if (nfds == -1)
        {
            if (errno == EINTR) return;

            ERR_LOG("epoll wait error:%s", strerror(errno));
            abort(); // 退出程序
        }

        for (int i = 0; i < nfds; ++i)
        {
            auto it = _channels.find(_evs[i].data.fd);
            assert(it != _channels.end());

            it->second->SetREvents(_evs[i].events); // 将实际就绪事件设置到对应文件描述符的Channel对象中
            active->push_back(it->second);
        }
    }

private:
    // 对epoll直接操作(增,删,改)
    void Update(Channel* channel, int op)
    {
        // 调用epoll_ctl()
        int fd = channel->Fd();
        struct epoll_event ev;
        ev.data.fd = fd;
        ev.events = channel->Events();

        int ret = epoll_ctl(_epfd, op, fd, &ev);
        if (ret == -1)
        {
            ERR_LOG("epoll_ctl failed");
        }
    }

    // 判断一个Channel对象是否添加了事件监控(是否在hash中管理)
    bool HasChannel(Channel* channel)
    {
        auto it = _channels.find(channel->Fd());
        if (it == _channels.end())
        {
            return false;
        }

        return true;
    }

private:
    int _epfd;
    struct epoll_event _evs[MAX_EPOLLEVENTS];
    std::unordered_map<int, Channel*> _channels; // 文件描述符和其对应的channel对象的关联关系
};

【Channel模块和Poller模块整合】

| 代码实现 |

// Channel类
class Poller; // Poller类的声明,让Channel能够使用Poller类
class EventLoop; // EventLoop类的声明,让Channel能够使用EventLoop类
class Channel
{
    using EventCallback = std::function<void()>;
public:
    Channel(EventLoop* loop, int fd)
        :_loop(loop)
        , _fd(fd)
        , _events(0)
        , _revents(0)
    {}

    int Fd() { return _fd; }

    // 获取文件描述符想要监控的事件
    uint32_t Events() { return _events; } 

    // 设置实际就绪的事件
    void SetREvents(uint32_t events) { _revents = events; }  // EventLoop模块会调用该函数,EventLoop模块会将文件描述符实际触发的事件设置进_revents中

    // 设置文件描述符的读事件回调函数
    void SetReadCallback(const EventCallback& cb) { _read_Callback = cb; }

    // 设置文件描述符的写事件回调函数
    void SetWriteCallback(const EventCallback& cb) { _write_Callback = cb; }

    // 设置文件描述符的挂断事件回调函数
    void SetCloseCallback(const EventCallback& cb) { _close_Callback = cb; }

    // 设置文件描述符的错误事件回调函数
    void SetErrorCallback(const EventCallback& cb) { _error_Callback = cb; }

    // 设置文件描述符的任意事件回调函数
    void SetEventCallback(const EventCallback& cb) { _event_Callback = cb; }

    // 判断当前描述符是否监控了可读事件
    bool ReadAble() { return (_events & EPOLLIN); }

    // 判断当前描述符是否监控了可写事件
    bool WriteAble() { return (_events & EPOLLOUT); }

    // 启动描述符读事件监控
    void EnableRead()
    {
        _events |= EPOLLIN;

        Update();
    }

    // 启动描述符写事件监控
    void EnableWrite()
    {
        _events |= EPOLLOUT;

        Update();
    }

    // 解除描述符读事件监控
    void DisableRead()
    {
        _events &= ~EPOLLIN;

        Update();
    }

    // 解除描述符写事件监控
    void DisableWrite()
    {
        _events &= ~EPOLLOUT;

        Update();
    }

    // 解除描述符所有事件监控
    void DisableAll() { _events = 0; }

    void Update();

    // 将描述符从epoll模型中移除监控
    void Remove();

    // 事件处理总函数
    void HandleEvent()
    {
        if ((_revents & EPOLLIN) || (_revents & EPOLLRDHUP) || (_revents & EPOLLPRI))
        {
            // 文件描述符触发任意事件都要调用任意事件回调函数
            if (_read_Callback) _read_Callback();
        }

        // 可能会释放连接的操作事件,一次只处理一个
        if (_revents & EPOLLOUT)
        {
            // 文件描述符触发任意事件都要调用任意事件回调函数            
            if (_write_Callback) _write_Callback();
        }
        else if (_revents & EPOLLHUP)
        {
            if (_close_Callback) _close_Callback();
        }
        else if (_revents & EPOLLERR)
        {
            if (_error_Callback) _error_Callback();
        }

        if (_event_Callback) _event_Callback();
    }

private:
    EventLoop* _loop;

    int _fd;           // 文件描述符 
    uint32_t _events;  // 当前需要监控的事件
    uint32_t _revents; // 当前连接触发的事件

    EventCallback _read_Callback;  // 可读事件被触发的回调函数
    EventCallback _write_Callback; // 可写事件被触发的回调函数
    EventCallback _close_Callback; // 挂断事件被触发的回调函数
    EventCallback _error_Callback; // 错误事件被触发的回调函数
    EventCallback _event_Callback; // 任意事件被触发的回调函数
};

// Poller类
#define MAX_EPOLLEVENTS 1024
class Poller
{
public:
    Poller()
    {
        _epfd = epoll_create(MAX_EPOLLEVENTS);
        if (_epfd == -1)
        {
            ERR_LOG("epoll create failed");
            abort(); // 退出程序
        }
    }

    // 添加或修改文件描述符的监控事件
    void UpdateEvent(Channel* channel)
    {
        // 先判断是否在hash中管理,不在则添加;在则修改
        bool ret = HasChannel(channel);
        if (ret == false)
        {
            // 将该channel和其对应的文件描述符添加到_channels中
            _channels.insert(std::make_pair(channel->Fd(), channel));

            return Update(channel, EPOLL_CTL_ADD);
        }

        return Update(channel, EPOLL_CTL_MOD);
    }

    // 移除文件描述符的事件监控
    void RemoveEvent(Channel* channel)
    {
        // 从hash中移除
        auto it = _channels.find(channel->Fd());
        if (it != _channels.end())
        {
            _channels.erase(it);
        }

        // 删除监控事件
        Update(channel, EPOLL_CTL_DEL);
    }

    // 开始监控,返回活跃连接
    void Poll(std::vector<Channel*>* active)
    {
        // 调用epoll_wait()
        int nfds = epoll_wait(_epfd, _evs, MAX_EPOLLEVENTS, -1);
        if (nfds == -1)
        {
            if (errno == EINTR) return;

            ERR_LOG("epoll wait error:%s", strerror(errno));
            abort(); // 退出程序
        }

        for (int i = 0; i < nfds; ++i)
        {
            auto it = _channels.find(_evs[i].data.fd);
            assert(it != _channels.end());

            it->second->SetREvents(_evs[i].events); // 将实际就绪事件设置到对应文件描述符的Channel对象中
            active->push_back(it->second);
        }
    }

private:
    // 对epoll直接操作(增,删,改)
    void Update(Channel* channel, int op)
    {
        // 调用epoll_ctl()
        int fd = channel->Fd();
        struct epoll_event ev;
        ev.data.fd = fd;
        ev.events = channel->Events();

        int ret = epoll_ctl(_epfd, op, fd, &ev);
        if (ret == -1)
        {
            ERR_LOG("epoll_ctl failed");
        }
    }

    // 判断一个Channel对象是否添加了事件监控(是否在hash中管理)
    bool HasChannel(Channel* channel)
    {
        auto it = _channels.find(channel->Fd());
        if (it == _channels.end())
        {
            return false;
        }

        return true;
    }

private:
    int _epfd;
    struct epoll_event _evs[MAX_EPOLLEVENTS];
    std::unordered_map<int, Channel*> _channels; // 文件描述符和其对应的channel对象的关联关系
};

EventLoop模块

EventLoop是对文件描述符进行事件监控以及事件处理的模块。

这个模块与线程是一一对应的!

如果一个文件描述符在多个线程中都触发了事件,并且进行事件处理,就会存在线程安全问题。

因此在实现的时候需要将一个连接的事件监控事件处理,以及其他的操作都放在同一个线程中进行。操作方法为:将连接与EventLoop模块进行绑定。

一个EventLoop对应着一个线程,如何保证一个连接的所有操作都在EventLoop对应的线程中?

解决方案:

  1. 在EventLoop模块中添加一个任务队列
  2. 将对连接的所有操作进行封装,并不直接执行对连接的操作,而是将这些操作当作是任务,添加到任务队列中

EventLoop模块处理流程:

以上操作能保证对于连接的所有操作都是在同一个线程中进行的,不涉及线程安全问题。因为最后对连接的操作都交给了EventLoop模块去执行操作,而一个EventLoop模块对应一个线程。

可能因为要等待文件描述符的IO事件就绪,导致执行流会阻塞,此时任务队列中的任务就无法执行了。为了解决上述问题,引入eventfd,用于唤醒事件监控的阻塞。

认识eventfd:

创建一个文件描述符用于实现事件通知,是一种事件通知机制。

eventfd的本质是内核中管理的一个计数器,创建出一个eventfd就会在内核中创建一个计数器(结构),由内核来维护。

调用write()向eventfd中写入一个数值,则表示事件通知的次数,可以调用read()来读取eventfd中的数据,读取到的数据就是事件通知的次数。

  • 假设每次给eventfd中写入一个1,就表示事件通知了1次。连续写入3次后,再调用read()进行数据读取,读出来的数据就是3。读取完后内核中eventfd的计数器清0。

在EventLoop模块中的作用:实现线程间的事件通知功能。

int eventfd(unsigned int initval, int flags);

| 代码实现 |

// EventLoop类
class EventLoop
{
    using Functor = std::function<void()>;
public:
    EventLoop()
        :_thread_id(std::this_thread::get_id())
        , _eventfd(CreateEventFd())
        , _event_channel(new Channel(this, _eventfd))
        , _timer_wheel(this)
    {
        // 设置_eventfd读事件回调函数,读取eventfd事件通知次数
        _event_channel->SetReadCallback(std::bind(&EventLoop::ReadEventFd, this));

        // 启动_eventfd的读事件监控
        _event_channel->EnableRead();
    }

    // 三步走:事件监控-> 就绪事件处理-> 执行任务池中的任务
    void Start()
    {
        while (true)
        {
            // 1.事件监控
            std::vector<Channel*> actives_channels;
            _poller.Poll(&actives_channels);

            // 2.就绪事件处理
            for (const auto &channel : actives_channels)
            {
                channel->HandleEvent();
            }

            // 3.执行线程池中的任务
            RunAllTask();
        }
    }

    // 判断当前线程是否是EventLoop对应的线程
    bool IsInLoop() { return (_thread_id == std::this_thread::get_id()); }

    void AssertInLoop() { assert(_thread_id == std::this_thread::get_id()); }

    // 判断将要执行的任务是否在当前线程中,是则执行;不是则加入任务池中
    void RunInLoop(const Functor& cb)
    {
        // 要执行任务在当前线程中,直接调用执行
        if (IsInLoop()) return cb();

        // 要执行任务不在当前线程中,将任务加入任务池中
        return QueueInLoop(cb);
    }

    // 将操作加入任务池中
    void QueueInLoop(const Functor& cb)
    {
        {
            std::unique_lock<std::mutex> _lock(_mutex);

            _tasks.push_back(cb);
        }

        // 唤醒IO事件监控有可能导致的阻塞
        WakeUpEventFd();
    }

    // 添加或修改文件描述符的监控事件
    void UpdateEvent(Channel* channel) { return _poller.UpdateEvent(channel); }

    // 移除文件描述符的事件监控
    void RemoveEvent(Channel* channel) { return _poller.RemoveEvent(channel); }

    void TimerAdd(uint64_t id, uint32_t delay, const TaskFunc& cb) { return _timer_wheel.TimerAdd(id, delay, cb); }

    void TimerRefresh(uint64_t id) { return _timer_wheel.TimerRefresh(id); }

    void TimerCancel(uint64_t id) { return _timer_wheel.TimerCancel(id); }

    bool HasTimer(uint64_t id) { return _timer_wheel.HasTimer(id); }

private:
    // 执行任务池中的所有任务
    void RunAllTask()
    {
        std::vector<Functor> functors;
        {
            std::unique_lock<std::mutex> _lock(_mutex);
             
            _tasks.swap(functors);
        }
        
        for (const auto& f : functors)
        {
            f();
        }
    }

    static int CreateEventFd()
    {
        int efd = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK);
        if (efd < 0)
        {
            ERR_LOG("create eventfd failed");
            abort(); // 让程序异常退出
        }

        return efd;
    }

    void ReadEventFd()
    {
        uint64_t res = 0;
        int ret = read(_eventfd, &res, sizeof(res));
        if (ret < 0)
        {
            // EINTR -- 表示被信号打断    EAGAIN -- 表示无数据可读
            if (errno == EINTR || errno == EAGAIN)
            {
                return;
            }

            ERR_LOG("read eventfd failed");
            abort();
        }
    }

    // 本质是向eventfd中写入数据
    void WakeUpEventFd()
    {
        uint64_t val = 1;
        int ret = write(_eventfd, &val, sizeof(val));
        if (ret < 0)
        {
            // EINTR -- 表示被信号打断
            if (errno == EINTR)
            {
                return;
            }

            ERR_LOG("write eventfd failed");
            abort();
        }
    }

private:
    std::thread::id _thread_id;                // 线程ID
    int _eventfd;                              // 用于唤醒IO事件监控有可能导致的阻塞
    std::unique_ptr<Channel> _event_channel;   // 用于管理_eventfd的事件监控
    Poller _poller;                            // 用于进行所有文件描述符的事件监控
    std::vector<Functor> _tasks;               // 任务池
    std::mutex _mutex;                         // 保证任务池操作的线程安全
    TimerWheel _timer_wheel;                   // 定时器模块
};

TimerWheel模块

将timerfd和时间轮的思想整合成一个模块,用于执行定时任务。

| 代码实现 |

using TaskFunc = std::function<void()>;
using ReleaseFunc = std::function<void()>;

// 定时器任务类
class TimerTask
{
public:
    TimerTask(uint64_t id, uint32_t delay, const TaskFunc& cb)
        :_id(id)
        , _timeout(delay)
        , _canceled(false)
        , _task_cb(cb)
    {}

    void SetRelease(const ReleaseFunc& cb)
    {
        _release = cb;
    }

    uint32_t DelayTime()
    {
        return _timeout;
    }

    void Cancel()
    {
        _canceled = true;
    }

    ~TimerTask()
    {
        if (_canceled == false) _task_cb();
        _release();
    }

private:
    uint64_t _id;         // 定时任务对象id
    uint32_t _timeout;    // 定时任务的超时时间
    bool _canceled;       // true--取消定时任务,false--不取消
    TaskFunc _task_cb;    // 定时器要执行的定时任务
    ReleaseFunc _release; // 用于删除TimerWheel中保存的定时任务对象信息(删除unordered_map中的元素)
};

// 时间轮类
class TimerWheel
{
public:
    TimerWheel(EventLoop* loop)
        :_tick(0)
        , _capacity(60)
        , _wheel(_capacity)
        , _loop(loop)
        , _timerfd(CreateTimerFd())
        , _timer_channel(new Channel(_loop, _timerfd))
    {
        _timer_channel->SetReadCallback(std::bind(&TimerWheel::OnTime, this));
        _timer_channel->EnableRead(); // 启动定时器文件描述符的读事件监控
    }

    bool HasTimer(uint64_t id) // 该接口存在线程安全问题!不能被外界使用者调用,只能在模块内,在对应的EventLoop线程内执行
    {
        auto it = _timers.find(id);
        if (it == _timers.end()) return false;

        return true;
    }

    void TimerAdd(uint64_t id, uint32_t delay, const TaskFunc &cb);

    void TimerRefresh(uint64_t id);

    void TimerCancel(uint64_t id);

private:
    void RemoveTimer(uint64_t id)
    {
        auto it = _timers.find(id);
        if (it != _timers.end())
        {
            _timers.erase(it);
        }
    }

    static int CreateTimerFd()
    {
        int timerfd = timerfd_create(CLOCK_MONOTONIC, 0);
        if (timerfd < 0)
        {
            ERR_LOG("timerfd create failede");
            abort();
        }

        struct itimerspec itime;

        // 第一次超时时间为1s后
        itime.it_value.tv_sec = 1;
        itime.it_value.tv_nsec = 0;

        // 第一次超时后,每次超时的间隔时间为1s
        itime.it_interval.tv_sec = 1;
        itime.it_interval.tv_nsec = 0;

        timerfd_settime(timerfd, 0, &itime, nullptr);

        return timerfd;
    }

    int ReadTimerFd()
    {
        uint64_t times;
        int ret = read(_timerfd, &times, 8); // read读取到的数据times就是从上一次read之后超时的次数
        if (ret < 0)
        {
            ERR_LOG("read timerfd failed");
            abort();
        }

        return times;
    }

    // 时间轮定时任务执行函数(该函数每秒执行一次)
    void RunTimerTask()
    {
        _tick = (_tick + 1) % _capacity;
        _wheel[_tick].clear(); // 清空指针位置上的数组,将所有管理定时任务对象的shared_ptr给释放掉
    }

    void OnTime()
    {
        // 根据实际的超时次数来执行对应的超时任务
        int times = ReadTimerFd();
        for (int i = 0; i < times; ++i)
        {
            RunTimerTask();
        }
    }

    // 添加定时任务
    void TimerAddInLoop(uint64_t id, uint32_t delay, const TaskFunc& cb)
    {
        SharedTask st(new TimerTask(id, delay, cb));
        st->SetRelease(std::bind(&TimerWheel::RemoveTimer, this, id));
        _timers[id] = WeakTask(st);

        // 将定时任务添加到时间轮中
        int pos = (_tick + delay) % _capacity;
        _wheel[pos].push_back(st);
    }

    // 刷新/延迟定时任务
    void TimerRefreshInLoop(uint64_t id)
    {
        // 通过定时器对象的weak_ptr构造一个shared_ptr,再添加到时间轮中
        auto it = _timers.find(id);
        if (it == _timers.end()) return;

        SharedTask st = it->second.lock();
        int delay = st->DelayTime();
        int pos = (_tick + delay) % _capacity;
        _wheel[pos].push_back(st);
    }

    void TimerCancelInLoop(uint64_t id)
    {
        auto it = _timers.find(id);
        if (it == _timers.end()) return;

        SharedTask st = it->second.lock();
        if (st) st->Cancel();
    }

private:
    using SharedTask = std::shared_ptr<TimerTask>;
    using WeakTask = std::weak_ptr<TimerTask>;

    int _tick;                                      // 当前的秒针(走到哪里释放哪里,释放哪里就相当于执行哪里的定时任务)
    int _capacity;                                  // 轮的最大数量(最大延迟时间)
    std::vector<std::vector<SharedTask>> _wheel;    // 时间轮
    std::unordered_map<uint64_t, WeakTask> _timers; // 将定时任务id和管理定时任务的weak_ptr绑定

    EventLoop* _loop;
    int _timerfd; // 定时器文件描述符
    std::unique_ptr<Channel> _timer_channel;
};

// TimerWheel类中的三个成员函数
void TimerWheel::TimerAdd(uint64_t id, uint32_t delay, const TaskFunc &cb)
{
    _loop->RunInLoop(std::bind(&TimerWheel::TimerAddInLoop, this, id, delay, cb));
}
void TimerWheel::TimerRefresh(uint64_t id)
{
    _loop->RunInLoop(std::bind(&TimerWheel::TimerRefreshInLoop, this, id));
}
void TimerWheel::TimerCancel(uint64_t id)
{
    _loop->RunInLoop(std::bind(&TimerWheel::TimerCancelInLoop, this, id));
}

| 定时器每隔1s执行一次时间轮中的任务的代码执行逻辑 |

【TimerWheel模块和EventLoop模块整合】

| 代码实现 |

using TaskFunc = std::function<void()>;
using ReleaseFunc = std::function<void()>;

// 定时器任务类
class TimerTask
{
public:
    TimerTask(uint64_t id, uint32_t delay, const TaskFunc& cb)
        :_id(id)
        , _timeout(delay)
        , _canceled(false)
        , _task_cb(cb)
    {}

    void SetRelease(const ReleaseFunc& cb)
    {
        _release = cb;
    }

    uint32_t DelayTime()
    {
        return _timeout;
    }

    void Cancel()
    {
        _canceled = true;
    }

    ~TimerTask()
    {
        if (_canceled == false) _task_cb();
        _release();
    }

private:
    uint64_t _id;         // 定时任务对象id
    uint32_t _timeout;    // 定时任务的超时时间
    bool _canceled;       // true--取消定时任务,false--不取消
    TaskFunc _task_cb;    // 定时器要执行的定时任务
    ReleaseFunc _release; // 用于删除TimerWheel中保存的定时任务对象信息(删除unordered_map中的元素)
};

// 时间轮类
class TimerWheel
{
public:
    TimerWheel(EventLoop* loop)
        :_tick(0)
        , _capacity(60)
        , _wheel(_capacity)
        , _loop(loop)
        , _timerfd(CreateTimerFd())
        , _timer_channel(new Channel(_loop, _timerfd))
    {
        _timer_channel->SetReadCallback(std::bind(&TimerWheel::OnTime, this));
        _timer_channel->EnableRead(); // 启动定时器文件描述符的读事件监控
    }

    bool HasTimer(uint64_t id) // 该接口存在线程安全问题!不能被外界使用者调用,只能在模块内,在对应的EventLoop线程内执行
    {
        auto it = _timers.find(id);
        if (it == _timers.end()) return false;

        return true;
    }

    void TimerAdd(uint64_t id, uint32_t delay, const TaskFunc &cb);

    void TimerRefresh(uint64_t id);

    void TimerCancel(uint64_t id);

private:
    void RemoveTimer(uint64_t id)
    {
        auto it = _timers.find(id);
        if (it != _timers.end())
        {
            _timers.erase(it);
        }
    }

    static int CreateTimerFd()
    {
        int timerfd = timerfd_create(CLOCK_MONOTONIC, 0);
        if (timerfd < 0)
        {
            ERR_LOG("timerfd create failede");
            abort();
        }

        struct itimerspec itime;

        // 第一次超时时间为1s后
        itime.it_value.tv_sec = 1;
        itime.it_value.tv_nsec = 0;

        // 第一次超时后,每次超时的间隔时间为1s
        itime.it_interval.tv_sec = 1;
        itime.it_interval.tv_nsec = 0;

        timerfd_settime(timerfd, 0, &itime, nullptr);

        return timerfd;
    }

    int ReadTimerFd()
    {
        uint64_t times;
        int ret = read(_timerfd, &times, 8); // read读取到的数据times就是从上一次read之后超时的次数
        if (ret < 0)
        {
            ERR_LOG("read timerfd failed");
            abort();
        }

        return times;
    }

    // 时间轮定时任务执行函数(该函数每秒执行一次)
    void RunTimerTask()
    {
        _tick = (_tick + 1) % _capacity;
        _wheel[_tick].clear(); // 清空指针位置上的数组,将所有管理定时任务对象的shared_ptr给释放掉
    }

    void OnTime()
    {
        // 根据实际的超时次数来执行对应的超时任务
        int times = ReadTimerFd();
        for (int i = 0; i < times; ++i)
        {
            RunTimerTask();
        }
    }

    // 添加定时任务
    void TimerAddInLoop(uint64_t id, uint32_t delay, const TaskFunc& cb)
    {
        SharedTask st(new TimerTask(id, delay, cb));
        st->SetRelease(std::bind(&TimerWheel::RemoveTimer, this, id));
        _timers[id] = WeakTask(st);

        // 将定时任务添加到时间轮中
        int pos = (_tick + delay) % _capacity;
        _wheel[pos].push_back(st);
    }

    // 刷新/延迟定时任务
    void TimerRefreshInLoop(uint64_t id)
    {
        // 通过定时器对象的weak_ptr构造一个shared_ptr,再添加到时间轮中
        auto it = _timers.find(id);
        if (it == _timers.end()) return;

        SharedTask st = it->second.lock();
        int delay = st->DelayTime();
        int pos = (_tick + delay) % _capacity;
        _wheel[pos].push_back(st);
    }

    void TimerCancelInLoop(uint64_t id)
    {
        auto it = _timers.find(id);
        if (it == _timers.end()) return;

        SharedTask st = it->second.lock();
        if (st) st->Cancel();
    }

private:
    using SharedTask = std::shared_ptr<TimerTask>;
    using WeakTask = std::weak_ptr<TimerTask>;

    int _tick;                                      // 当前的秒针(走到哪里释放哪里,释放哪里就相当于执行哪里的定时任务)
    int _capacity;                                  // 轮的最大数量(最大延迟时间)
    std::vector<std::vector<SharedTask>> _wheel;    // 时间轮
    std::unordered_map<uint64_t, WeakTask> _timers; // 将定时任务id和管理定时任务的weak_ptr绑定

    EventLoop* _loop;
    int _timerfd; // 定时器文件描述符
    std::unique_ptr<Channel> _timer_channel;
};

// EventLoop类
class EventLoop
{
    using Functor = std::function<void()>;
public:
    EventLoop()
        :_thread_id(std::this_thread::get_id())
        , _eventfd(CreateEventFd())
        , _event_channel(new Channel(this, _eventfd))
        , _timer_wheel(this)
    {
        // 设置_eventfd读事件回调函数,读取eventfd事件通知次数
        _event_channel->SetReadCallback(std::bind(&EventLoop::ReadEventFd, this));

        // 启动_eventfd的读事件监控
        _event_channel->EnableRead();
    }

    // 三步走:事件监控-> 就绪事件处理-> 执行任务池中的任务
    void Start()
    {
        while (true)
        {
            // 1.事件监控
            std::vector<Channel*> actives_channels;
            _poller.Poll(&actives_channels);

            // 2.就绪事件处理
            for (const auto &channel : actives_channels)
            {
                channel->HandleEvent();
            }

            // 3.执行线程池中的任务
            RunAllTask();
        }
    }

    // 判断当前线程是否是EventLoop对应的线程
    bool IsInLoop() { return (_thread_id == std::this_thread::get_id()); }

    void AssertInLoop() { assert(_thread_id == std::this_thread::get_id()); }

    // 判断将要执行的任务是否在当前线程中,是则执行;不是则加入任务池中
    void RunInLoop(const Functor& cb)
    {
        // 要执行任务在当前线程中,直接调用执行
        if (IsInLoop()) return cb();

        // 要执行任务不在当前线程中,将任务加入任务池中
        return QueueInLoop(cb);
    }

    // 将操作加入任务池中
    void QueueInLoop(const Functor& cb)
    {
        {
            std::unique_lock<std::mutex> _lock(_mutex);

            _tasks.push_back(cb);
        }

        // 唤醒IO事件监控有可能导致的阻塞
        WakeUpEventFd();
    }

    // 添加或修改文件描述符的监控事件
    void UpdateEvent(Channel* channel) { return _poller.UpdateEvent(channel); }

    // 移除文件描述符的事件监控
    void RemoveEvent(Channel* channel) { return _poller.RemoveEvent(channel); }

    void TimerAdd(uint64_t id, uint32_t delay, const TaskFunc& cb) { return _timer_wheel.TimerAdd(id, delay, cb); }

    void TimerRefresh(uint64_t id) { return _timer_wheel.TimerRefresh(id); }

    void TimerCancel(uint64_t id) { return _timer_wheel.TimerCancel(id); }

    bool HasTimer(uint64_t id) { return _timer_wheel.HasTimer(id); }

private:
    // 执行任务池中的所有任务
    void RunAllTask()
    {
        std::vector<Functor> functors;
        {
            std::unique_lock<std::mutex> _lock(_mutex);
             
            _tasks.swap(functors);
        }
        
        for (const auto& f : functors)
        {
            f();
        }
    }

    static int CreateEventFd()
    {
        int efd = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK);
        if (efd < 0)
        {
            ERR_LOG("create eventfd failed");
            abort(); // 让程序异常退出
        }

        return efd;
    }

    void ReadEventFd()
    {
        uint64_t res = 0;
        int ret = read(_eventfd, &res, sizeof(res));
        if (ret < 0)
        {
            // EINTR -- 表示被信号打断    EAGAIN -- 表示无数据可读
            if (errno == EINTR || errno == EAGAIN)
            {
                return;
            }

            ERR_LOG("read eventfd failed");
            abort();
        }
    }

    // 本质是向eventfd中写入数据
    void WakeUpEventFd()
    {
        uint64_t val = 1;
        int ret = write(_eventfd, &val, sizeof(val));
        if (ret < 0)
        {
            // EINTR -- 表示被信号打断
            if (errno == EINTR)
            {
                return;
            }

            ERR_LOG("write eventfd failed");
            abort();
        }
    }

private:
    std::thread::id _thread_id;                // 线程ID
    int _eventfd;                              // 用于唤醒IO事件监控有可能导致的阻塞
    std::unique_ptr<Channel> _event_channel;   // 用于管理_eventfd的事件监控
    Poller _poller;                            // 用于进行所有文件描述符的事件监控
    std::vector<Functor> _tasks;               // 任务池
    std::mutex _mutex;                         // 保证任务池操作的线程安全
    TimerWheel _timer_wheel;                   // 定时器模块
};

// TimerWheel类中的三个成员函数
void TimerWheel::TimerAdd(uint64_t id, uint32_t delay, const TaskFunc &cb)
{
    _loop->RunInLoop(std::bind(&TimerWheel::TimerAddInLoop, this, id, delay, cb));
}
void TimerWheel::TimerRefresh(uint64_t id)
{
    _loop->RunInLoop(std::bind(&TimerWheel::TimerRefreshInLoop, this, id));
}
void TimerWheel::TimerCancel(uint64_t id)
{
    _loop->RunInLoop(std::bind(&TimerWheel::TimerCancelInLoop, this, id));
}

Connection模块

封装Connection模块的目的是对连接进行全方位的管理。对通信连接的所有操作都是通过这个模块提供的功能完成的。

管理:

  1. 套接字的管理----进行套接字的相关操作
  2. 连接时间的管理----管理可读,可写,挂断,错误,任意事件
  3. 缓冲区的管理----接收缓冲区和发送缓冲区,便于套接字接收和发送数据
  4. 协议上下文的管理----记录请求数据的处理过程
  5. 回调函数的管理

功能:

  1. 发送数据----给用户提供调用的发送数据接口,并不是真正的发送接口,而只是把数据放到发送缓冲区,然后启动写事件监控
  2. 关闭连接----给用户提供调用的关闭连接接口,应该在实际释放连接之间,判断接收缓冲区和发送缓冲区是否有数据待处理
  3. 启动非活跃连接的超时销毁功能
  4. 取消非活跃连接的超时销毁功能
  5. 协议切换----一个连接接收数据后如何进行业务处理,取决于上下文,以及数据的业务处理回调函数

Connection模块是对连接的管理模块,对于连接的所有操作都是通过该模块完成的。

可能存在以下场景:

  • 对一个连接进行操作时,但是该连接已经被释放,再对已释放的连接操作,会导致内存访问错误,最终程序崩溃

解决方案:

  • 使用shared_ptr对Connection对象进行管理,这样就能保证任意一个地方对Connection对象进行操作的时候保存一份shared_ptr的管理,因此就算在其他地方进行了释放操作,也只是对shared_ptr的计数器-1,而不会导致Connection对象的实际释放

| 代码实现 |

// Any类
class Any
{
public:
    Any()
        :_content(nullptr)
    {}

    template<class T>
    Any(const T& val)
        :_content(new placeholder<T>(val))
    {}

    Any(const Any& other)
        :_content(other._content ? other._content->clone() : nullptr)
    {}

    ~Any() { delete _content; }

    Any& swap(Any& other)
    {
        std::swap(_content, other._content);
        return *this;
    }

    // 获取Any对象中保存的数据的地址
    template<class T>
    T* get()
    {
        assert(typeid(T) == _content->type()); // 想要获取的数据类型必须和保存的数据类型一致
        return &((placeholder<T>*)_content)->_val;
    }

    template<class T>
    Any& operator=(const T& val)
    {
        Any(val).swap(*this);
        return *this;
    }

    Any& operator=(const Any& other)
    {
        Any(other).swap(*this);
        return *this;
    }

private:
    class holder
    {
    public:
        virtual ~holder() {}

        virtual const std::type_info& type() = 0;

        virtual holder* clone() = 0;
    };

    template <class T>
    class placeholder : public holder
    {
    public:
        placeholder(const T& val)
            :_val(val)
        {}

        // 获取子类对象保存的数据类型
        virtual const std::type_info& type() { return typeid(T); }

        // 根据当前对象克隆出一个新的子类对象
        virtual holder* clone() { return new placeholder(_val); }

    public:
        T _val;
    };

    holder* _content;
};

// Connection类
class Connection;
typedef enum
{
    DISCONNECTED, // 连接关闭状态
    CONNECTING,   // 连接建立成功待处理状态
    CONNECTED,    // 连接建立已完成,各种设置已完成可以通信的状态
    DISCONNECTING // 待关闭状态
} ConnStatu;

using SharedConnection = std::shared_ptr<Connection>; // 使用shared_ptr管理Connection对象,避免对已释放的Connection对象操作,导致内存访问错误
class Connection : public std::enable_shared_from_this<Connection>
{
    using ConnectedCallback = std::function<void(const SharedConnection&)>;
    using MessageCallback = std::function<void(const SharedConnection&, Buffer*)>;
    using ClosedCallback = std::function<void(const SharedConnection&)>;
    using AnyEventCallback = std::function<void(const SharedConnection&)>;
public:
    Connection(EventLoop* loop, uint64_t conn_id, int sockfd)
        :_conn_id(conn_id)
        , _sockfd(sockfd)
        , _enable_inactive_release(false)
        , _loop(loop)
        , _statu(CONNECTING)
        , _socket(_sockfd)
        , _channel(loop, _sockfd)
    {
        _channel.SetCloseCallback(std::bind(&Connection::HandleClose, this));
        _channel.SetEventCallback(std::bind(&Connection::HandleEvent, this));
        _channel.SetReadCallback(std::bind(&Connection::HandleRead, this));
        _channel.SetWriteCallback(std::bind(&Connection::HandleWrite, this));
        _channel.SetErrorCallback(std::bind(&Connection::HandleError, this));
    }

    ~Connection()
    {
        DBG_LOG("release connection: %p", this);
    }

    // 获取连接id
    int Id() { return _conn_id; }

    // 获取管理的文件描述符
    int Fd() { return _sockfd; }

    // 判断是否处于CONNECTED状态
    bool Connected() { return _statu == CONNECTED; }

    // 设置上下文(连接建立完成时进行调用)
    void SetContext(const Any& context) { _context = context; }

    // 获取上下文
    Any* GetContext() { return &_context; }

    void SetConnectedCallback(const ConnectedCallback& cb) { _connected_callback = cb; }

    void SetMessageCallback(const MessageCallback& cb) { _message_callback = cb; }

    void SetClosedCallback(const ClosedCallback& cb) { _closed_callback = cb; }

    void SetAnyEventCallback(const AnyEventCallback& cb) { _event_callback = cb; }

    void SetServerClosedCallback(const ClosedCallback& cb) { _server_closeed_callback = cb; }

    // 连接就绪后,对channel就绪回调设置,启动读事件监控,调用_connected_callback
    void Established()
    {
        _loop->RunInLoop(std::bind(&Connection::EstablishedInLoop, this));
    }

    // 发送数据,将数据放到发送缓冲区,启动写事件监控
    void Send(const char* data, size_t len)
    {
        // 外界传入的data,可能是个临时的空间,我们现在只是把发送操作压入了任务池,有可能并没有被立即执行
        // 因此有可能执行的时候,data指向的空间有可能已经被释放了。
        Buffer buf;
        buf.WriteAndPush(data, len);
        _loop->RunInLoop(std::bind(&Connection::SendInLoop, this, std::move(buf)));
    }

    // 提供给组件使用者的关闭接口(并不实际关闭,需要判断有没有数据在缓冲区中待处理)
    void Shutdown()
    {
        _loop->RunInLoop(std::bind(&Connection::ShutdownInLoop, this));
    }

    void Release()
    {
        _loop->QueueInLoop(std::bind(&Connection::ReleaseInLoop, this));
    }

    // 启动非活跃连接销毁,并定义多长时间无通信就是非活跃,添加定时任务
    void EnableInactiveRelease(int sec)
    {
        _loop->RunInLoop(std::bind(&Connection::EnableInactiveReleaseInLoop, this, sec));
    }

    // 取消非活跃连接销毁
    void CancelInactiveRelease()
    {
        _loop->RunInLoop(std::bind(&Connection::CancelInactiveReleaseInLoop, this));
    }

    // 切换协议(重置上下文以及阶段性处理函数)
    // 这个接口必须在EventLoop线程中立即执行,防备新的事件触发后,处理的时候,切换协议的任务还没有被执行,会导致使用旧有协议去处理数据
    void Upgrade(const Any& context, const ConnectedCallback& conn, const MessageCallback& msg, const ClosedCallback& closed, const AnyEventCallback& event)
    {
        _loop->AssertInLoop();
        _loop->RunInLoop(std::bind(&Connection::UpgradeInLoop, this, context, conn, msg, closed, event));
    }

private:
    // 文件描述符可读事件触发后调用的函数,将接收到的socket数据放到接收缓冲区中,然后调用_message_callback
    void HandleRead()
    {
        // 1.接收socket的数据,放到接收缓冲区中
        char buf[65536] = { 0 };
        ssize_t ret = _socket.NonBlockRecv(buf, 65535);
        if (ret < 0)
        {
            // 出错了,不能直接关闭!
            return ShutdownInLoop();
        }
        // 这里ret=0表示没有读取到数据,并不是连接断开,连接断开返回的是-1

        _in_buffer.WriteAndPush(buf, ret); // 将数据放入输入缓冲区,并将写偏移向后移动

        // 2.调用_message_callback进行业务处理
        if (_in_buffer.ReadableSize() > 0)
        {
            // shared_from_this---从当前对象自身获取自身的shared_ptr管理对象
            return _message_callback(shared_from_this(), &_in_buffer);
        }
    }

    // 文件描述符可写事件触发后调用的函数,将发送缓冲区中的数据进行发送
    void HandleWrite()
    {
        ssize_t ret = _socket.NonBlockSend(_out_buffer.GetReadPos(), _out_buffer.ReadableSize());
        if (ret < 0)
        {
            // 发送错误,关闭连接
            if (_in_buffer.ReadableSize() > 0) // 接收缓冲区中还有数据
            {
                _message_callback(shared_from_this(), &_in_buffer);
            }
            return Release(); // 真正的释放连接
        }

        _out_buffer.MoveReadOffset(ret); // 将读偏移量向后移动
        
        if (_out_buffer.ReadableSize() == 0)
        {
            _channel.DisableWrite(); // 没有数据发送了,关闭文件描述符的写事件监控

            // 如果当前是连接待关闭状态,发送缓冲区中有数据,则发送完数据后再释放连接;发送缓冲区中没有数据,则直接释放
            if (_statu == DISCONNECTING)
            {
                return Release();
            }
        }
    }

    // 文件描述符挂断事件触发后调用的函数
    void HandleClose()
    {
        // 一旦连接触发挂断事件,套接字就什么也干不了了,因此将待处理的数据处理完毕后,关闭连接即可
        if (_in_buffer.ReadableSize() > 0) // 接收缓冲区中还有数据
        {
            _message_callback(shared_from_this(), &_in_buffer);
        }
        return Release();
    }

    // 文件描述符错误事件触发后调用的函数
    void HandleError() { return HandleClose(); }

    // 文件描述符任意事件触发后调用的函数
    void HandleEvent()
    {
        // 1.刷新连接的活跃度(延迟定时销毁任务)
        if (_enable_inactive_release == true) _loop->TimerRefresh(_conn_id);

        // 2.调用组件使用者的任意事件回调
        if (_event_callback) _event_callback(shared_from_this());
    }

    // 连接获取之后,所处的状态下要进行各种设置(启动读事件监控,调用回调函数)
    void EstablishedInLoop()
    {
        // 1.修改连接状态
        assert(_statu == CONNECTING);
        _statu = CONNECTED;

        // 2.启动读事件监控
        _channel.EnableRead();

        // 3.调用回调函数
        if (_connected_callback) _connected_callback(shared_from_this());
    }

    // 该接口并不是实际的发送接口,而只是把数据放到了发送缓冲区,启动可写事件监控
    void SendInLoop(Buffer& buf)
    {
        if (_statu == DISCONNECTED) return;

        // 1.将数据放到输出缓冲区中,并将写偏移向后移动
        _out_buffer.WriteBufferAndPush(buf);

        // 2.启动可写事件监控
        if (_channel.WriteAble() == false)
        {
            _channel.EnableWrite();
        }
    }

    // 该接口并非实际的连接释放操作,接口内判断缓冲区中还有无待处理数据
    void ShutdownInLoop()
    {
        _statu = DISCONNECTING; // 将连接状态设置为半关闭状态

        // 接收缓冲区中还有数据待处理
        if (_in_buffer.ReadableSize() > 0)
        {
            if (_message_callback) _message_callback(shared_from_this(), &_in_buffer);
        }

        // 发送缓冲区中还有数据没发送给对端
        if (_out_buffer.ReadableSize() > 0)
        {
            if (_channel.WriteAble() == false)
            {
                _channel.EnableWrite();
            }
        }

        // 发送缓冲区中没有待发送数据,直接关闭连接
        if (_out_buffer.ReadableSize() == 0)
        {
            Release();
        }
    }

    // 实际的释放连接接口
    void ReleaseInLoop()
    {
        // 1.修改呢连接状态,将其设置为DISCONNECTED
        _statu = DISCONNECTED;

        // 2.移除连接的事件监控
        _channel.Remove();

        // 3.关闭文件描述符
        _socket.Close();

        // 4.若当前定时器队列中还有定时销毁任务,则取消任务
        if (_loop->HasTimer(_conn_id)) CancelInactiveReleaseInLoop();

        // 5.调用关闭回调函数
        if (_closed_callback) _closed_callback(shared_from_this());
        // 移除服务器内部管理的连接信息
        if (_server_closeed_callback) _server_closeed_callback(shared_from_this());
    }

    // 启动非活跃连接超时释放
    void EnableInactiveReleaseInLoop(int sec)
    {
        // 1.将判断标志_enable_inactive_release设置为true
        _enable_inactive_release = true;

        // 2.若当前已存在定时销毁任务,则刷新延迟即可
        if (_loop->HasTimer(_conn_id)) return _loop->TimerRefresh(_conn_id);

        // 3.若不存在定时销毁任务,则新增
        _loop->TimerAdd(_conn_id, sec, std::bind(&Connection::Release, this));
    }

    // 关闭非活跃连接超时释放
    void CancelInactiveReleaseInLoop()
    {
        _enable_inactive_release = false;

        if (_loop->HasTimer(_conn_id)) _loop->TimerCancel(_conn_id);
    }

    void UpgradeInLoop(const Any& context, const ConnectedCallback& conn, const MessageCallback& msg, const ClosedCallback& closed, const AnyEventCallback& event)
    {
        _context = context;
        _connected_callback = conn;
        _message_callback = msg;
        _closed_callback = closed;
        _event_callback = event;
    }

private:
    uint64_t _conn_id;             // 标识连接的唯一id,便于连接的管理和查找
    // uint64_t _timer_id;         // 定时器id,必须是唯一的(为了简化操作,使用_conn_id作为定时器id)
    int _sockfd;                   // 连接关联的文件描述符
    bool _enable_inactive_release; // 连接是否启动非活跃销毁的标志
    EventLoop* _loop;              // 连接所关联的一个EventLoop
    ConnStatu _statu;              // 连接状态
    Socket _socket;                // 套接字操作管理
    Channel _channel;              // 连接的事件管理
    Buffer _in_buffer;             // 输入缓冲区----存放从socket中读取到的数据
    Buffer _out_buffer;            // 输出缓冲区----存放要发送给对端的数据
    Any _context;                  // 请求的接收处理上下文

    ConnectedCallback _connected_callback;
    MessageCallback _message_callback;
    ClosedCallback _closed_callback;
    AnyEventCallback _event_callback;
    // 组件内的连接关闭回调--组件内设置的,因为服务器组件内会把所有的连接
    // 管理起来,一旦某个连接要关闭,就应该从管理的地方移除掉自己的信息
    ClosedCallback _server_closeed_callback;
};

Acceptor模块

Acceptor模块用于对监听套接字进行管理,在获取一个新连接的文件描述符后,为该通信连接实例化一个Connection对象,设置各自不同的回调函数。

  1. 创建一个监听套接字
  2. 启动监听套接字的读事件监控
  3. 读事件触发后,获取新连接
  4. 调用新连接获取成功后的回调函数

| 代码实现 |

// Acceptor类
class Acceptor
{
    using AcceptCallback = std::function<void(int)>;
public:
    Acceptor(EventLoop* loop, uint16_t port)
        :_socket(CreateServer(port))
        , _loop(loop)
        , _channel(loop, _socket.Fd())
    {
        _channel.SetReadCallback(std::bind(&Acceptor::HandleRead, this));
    }

    void Listen() { _channel.EnableRead(); }

    void SetAcceptCallback(const AcceptCallback& cb) { _accept_callback = cb; }

private:
    void HandleRead()
    {
        int newfd = _socket.Accept();
        if (newfd < 0)
        {
            return;
        }

        if (_accept_callback) _accept_callback(newfd);
    }

    int CreateServer(uint16_t port)
    {
        bool ret = _socket.CreateServer(port);
        assert(ret == true);
        
        return _socket.Fd();
    }

private:
    Socket _socket;   // 用于创建监听套接字
    EventLoop* _loop; // 用于对监听套接字进行事件监控
    Channel _channel; // 用于对监听套接字进行事件管理

    AcceptCallback _accept_callback;
};

LoopThread模块

LoopThread模块实现的功能是,将EventLoop模块与线程整合到一起。

思想:

  1. 创建一个线程
  2. 在线程中实例化一个EventLoop对象

提供的功能接口:向外部返回一个实例化的EventLoop对象

※LoopThread模块成员变量中要有互斥锁和条件变量,避免线程创建了,EventLoop对象还没实例化就调用GetLoop()去获取EventLoop对象的情况。

| 代码实现 |

// LoopThread类
class LoopThread
{
public:
    // 创建线程,设定线程入口函数
    LoopThread()
        :_loop(NULL)
        , _thread(&LoopThread::ThreadEntry, this)
    {}

    // 返回当前线程关联的EventLoop对象指针
    EventLoop* GetLoop()
    {
        EventLoop* loop = NULL;

        {
            std::unique_lock<std::mutex> lock(_mutex); // 加锁
            _cond.wait(lock, [&](){ return _loop != NULL; }); // _loop为空就一直阻塞等待

            // 代码执行到这,说明_loop已经实例化出来了
            loop = _loop;
        }

        return loop;
    }

private:
    // 实例化EventLoop对象,并且开始执行EventLoop模块的功能
    void ThreadEntry()
    {
        EventLoop loop; // 实例化一个EventLoop对象

        {
            std::unique_lock<std::mutex> lock(_mutex); // 加锁
            _loop = &loop;

            _cond.notify_all(); // 唤醒有可能阻塞的线程
        }

        _loop->Start();
    }

private:
    std::mutex _mutex;             // 互斥锁
    std::condition_variable _cond; // 条件变量
    EventLoop* _loop;              // EventLoop指针变量,这个变量需要在线程内实例化
    std::thread _thread;           // EventLoop对象对应的线程
};

LoopThreadPool模块

针对LoopThread模块设计一个线程池,对所有LoopThread模块进行管理和分配。

 功能:

  1. 线程数量可配置(0个或多个)注意:在服务器中,主从Reactor模型是主线程只负责获取新连接,从线程负责新连接的事件监控和事件处理。因此当前的线程池,有可能从线程的数量为0,也就是实现单Reactor服务器,一个线程既负责获取新连接,也负责连接的事件处理
  2. 对所有的线程进行管理:其实就是管理0个或多个LoopThread对象
  3. 线程分配功能:当主线程获取了一个新连接,需要将新连接挂到从线程上进行事件监控和事件处理。若有0个从线程,则直接分配给主线程的EventLoop进行处理;若有多个从线程,则采用RR轮转思想,进行从线程的分配(将对应线程的EventLoop获取到,设置给对应的Connection)

 | 代码实现 |

// LoopThreadPool类
class LoopThreadPool
{
public:
    LoopThreadPool(EventLoop* base_loop)
        :_thread_count(0)
        , _next_loop_idx(0)
        , _base_loop(base_loop)
    {}

    void SetThreadCount(int count) { _thread_count = count; }

    void Create()
    {
        if (_thread_count > 0)
        {
            _threads.resize(_thread_count);
            _loops.resize(_thread_count);

            for (int i = 0; i < _thread_count; ++i)
            {
                _threads[i] = new LoopThread();
                _loops[i] = _threads[i]->GetLoop();
            }
        }
    }

    EventLoop* NextLoop()
    {
        if (_thread_count == 0) return _base_loop;

        _next_loop_idx = (_next_loop_idx + 1) % _thread_count;
        return _loops[_next_loop_idx];
    }

private:
    int _thread_count; // 从线程的数量
    int _next_loop_idx;
    EventLoop* _base_loop;             // 主EventLoop,运行在主线程;若从线程数量为0,则所有操作都在_base_loop中进行
    std::vector<LoopThread*> _threads; // 保存所有的LoopThread对象
    std::vector<EventLoop*> _loops;    // 从线程数量大于0,则从_loops中进行线程EventLoop分配
};

TcpServer模块

TcpServer模块是对所有模块的一个整合,通过TcpServer模块实例化对象,对该对象进行操作可以非常简单的完成一个服务器的搭建。

管理:

  • Acceptor对象:创建一个监听套接字
  • EventLoop对象:baseloop对象,实现对监听套接字的事件监控
  • unordered_map<uint64_t, SharedConnection> _conns:实现对所有新建连接的管理
  • LoopThreadPool对象:创建loop线程池,对新建连接进行事件监控和事件处理

功能:

  • 设置从线程数量
  • 启动服务器
  • 设置各种回调函数(连接建立完成,消息,关闭,任意),由用户设置给TcpServer对象,TcpServer对象设置给获取的新连接
  • 是否启动非活跃连接超时销毁功能
  • 添加定时任务功能

流程:

  1. 在TcpServer中实例化一个Acceptor对象,以及一个EventLoop对象(baseloop运行在主线程上)
  2. 将Acceptor挂到baseloop上进行事件监控
  3. 一旦Acceptor对象可读事件就绪,则执行读事件回调函数获取新连接
  4. 为新连接创建一个Connection对象进行管理
  5. 对连接对应的Connection对象设置功能回调(连接完成回调,消息回调,关闭回调,任意事件回调)

| 代码实现 |

// TcpServer类
class TcpServer
{
    using ConnectedCallback = std::function<void(const SharedConnection&)>;
    using MessageCallback = std::function<void(const SharedConnection&, Buffer*)>;
    using ClosedCallback = std::function<void(const SharedConnection&)>;
    using AnyEventCallback = std::function<void(const SharedConnection&)>;

    using Functor = std::function<void()>;
public:
    TcpServer(uint16_t port)
        :_next_id(0)
        , _port(port)
        , _enable_inactive_release(false)
        , _acceptor(&_base_loop, port)
        , _pool(&_base_loop)
    {
        _acceptor.SetAcceptCallback(std::bind(&TcpServer::NewConnection, this, std::placeholders::_1));
        _acceptor.Listen(); // 将监听套接字挂到_base_loop上
    }

    // 设置从线程个数
    void SetThreadCount(int count) { return _pool.SetThreadCount(count); }

    void SetConnectedCallback(const ConnectedCallback& cb) { _connected_callback = cb; }

    void SetMessageCallback(const MessageCallback& cb) { _message_callback = cb; }

    void SetClosedCallback(const ClosedCallback& cb) { _closed_callback = cb; }

    void SetAnyEventCallback(const AnyEventCallback& cb) { _event_callback = cb; }

    void EnableInactiveRelease(int timeout) { _timeout = timeout; _enable_inactive_release = true; }

    void RunAfter(const Functor& task, int delay)
    {
        _base_loop.RunInLoop(std::bind(&TcpServer::RunAfterInLoop, this, task, delay));
    }

    void Start()
    {
        _pool.Create(); // 创建线程池中的从线程
        _base_loop.Start();
    }

private:
    // 为新连接构造一个Connection对象进行管理
    void NewConnection(int fd)
    {
        ++_next_id;

        SharedConnection conn(new Connection(_pool.NextLoop(), _next_id, fd));
        conn->SetConnectedCallback(_connected_callback);
        conn->SetMessageCallback(_message_callback);
        conn->SetClosedCallback(_closed_callback);
        conn->SetAnyEventCallback(_event_callback);
        conn->SetServerClosedCallback(std::bind(&TcpServer::RemoveConnection, this, std::placeholders::_1));

        if (_enable_inactive_release) conn->EnableInactiveRelease(_timeout); // 启动非活跃连接销毁
        conn->Established(); // 就绪初始化

        _conns.insert(std::make_pair(_next_id, conn));
    }

    // 从管理Connection的_conns中移除连接信息
    void RemoveConnection(const SharedConnection& conn)
    {
        _base_loop.RunInLoop(std::bind(&TcpServer::RemoveConnectionInLoop, this, conn));
    }

    // 用于添加一个定时任务
    void RunAfterInLoop(const Functor& task, uint32_t delay)
    {
        ++_next_id;
        _base_loop.TimerAdd(_next_id, delay, task);
    }

    void RemoveConnectionInLoop(const SharedConnection& conn)
    {
        int id = conn->Id(); // 获取连接id

        auto it = _conns.find(id);
        if (it != _conns.end()) // 找到了
        {
            _conns.erase(it);
        }
    }

private:
    uint64_t _next_id;                                     // 自动增长的连接id,连接的唯一标识
    uint16_t _port;                                        // 连接的端口号
    int _timeout;                                          // 非活跃连接的统计时间----多长时间无通信就是非活跃
    bool _enable_inactive_release;                         // 是否启动非活跃连接超时销毁的判断标志
    EventLoop _base_loop;                                  // 主线程的EventLoop对象,负责监听套接字的事件的处理
    Acceptor _acceptor;                                    // 监听套接字的管理对象
    LoopThreadPool _pool;                                  // 从属EventLoop线程池
    std::unordered_map<uint64_t, SharedConnection> _conns; // 保存管理所有连接对应的shared_ptr对象

    ConnectedCallback _connected_callback;
    MessageCallback _message_callback;
    ClosedCallback _closed_callback;
    AnyEventCallback _event_callback;
};

class NetWork {
public:
    NetWork() 
    {
        DBG_LOG("SIGPIPE INIT");
        signal(SIGPIPE, SIG_IGN);
    }
};
static NetWork nw;

HTTP协议支持模块实现

Util模块

该模块用于整合一些零碎的功能性接口,在协议支持模块中,当需要某些零碎功能的时候,便于使用。

提供功能:

  • 读取文件内容

  • 向文件写入内容

  • URL编码

  • URL解码

  • HTTP状态码&描述信息

  • 根据文件后缀获取mime

  • 判断一个文件是否是目录

  • 判断一个文件是否是普通文件

  • HTTP资源路径有效性判断

| URL编码 |

避免URL中资源路径与查询字符串中的特殊字符与HTTP请求中的特殊字符产生歧义

编码格式:将特殊字符的ascii码转换为两个16进制字符,前缀为%(C++ --> C%2B%2B)

RFC3986文档规定,(.  -  _  ~  字母  数字)属于绝对不编码字符

RFC3986文档规定,编码格式为:%HH

W3C标准中规定,查询字符串中的空格需要编码为+,解码则是将+转空格

| URL解码 |

遇到%,则将紧随其后的两个字符分别转换为数字,然后将第一个数字左移4位,再加上第二个数字即可(%2B --> (2 << 4) + 11)

| 代码实现 |

std::unordered_map<int, std::string> _statu_msg = {
    {100, "Continue"},
    {101, "Switching Protocol"},
    {102, "Processing"},
    {103, "Early Hints"},
    {200, "OK"},
    {201, "Created"},
    {202, "Accepted"},
    {203, "Non-Authoritative Information"},
    {204, "No Content"},
    {205, "Reset Content"},
    {206, "Partial Content"},
    {207, "Multi-Status"},
    {208, "Already Reported"},
    {226, "IM Used"},
    {300, "Multiple Choice"},
    {301, "Moved Permanently"},
    {302, "Found"},
    {303, "See Other"},
    {304, "Not Modified"},
    {305, "Use Proxy"},
    {306, "unused"},
    {307, "Temporary Redirect"},
    {308, "Permanent Redirect"},
    {400, "Bad Request"},
    {401, "Unauthorized"},
    {402, "Payment Required"},
    {403, "Forbidden"},
    {404, "Not Found"},
    {405, "Method Not Allowed"},
    {406, "Not Acceptable"},
    {407, "Proxy Authentication Required"},
    {408, "Request Timeout"},
    {409, "Conflict"},
    {410, "Gone"},
    {411, "Length Required"},
    {412, "Precondition Failed"},
    {413, "Payload Too Large"},
    {414, "URI Too Long"},
    {415, "Unsupported Media Type"},
    {416, "Range Not Satisfiable"},
    {417, "Expectation Failed"},
    {418, "I'm a teapot"},
    {421, "Misdirected Request"},
    {422, "Unprocessable Entity"},
    {423, "Locked"},
    {424, "Failed Dependency"},
    {425, "Too Early"},
    {426, "Upgrade Required"},
    {428, "Precondition Required"},
    {429, "Too Many Requests"},
    {431, "Request Header Fields Too Large"},
    {451, "Unavailable For Legal Reasons"},
    {501, "Not Implemented"},
    {502, "Bad Gateway"},
    {503, "Service Unavailable"},
    {504, "Gateway Timeout"},
    {505, "HTTP Version Not Supported"},
    {506, "Variant Also Negotiates"},
    {507, "Insufficient Storage"},
    {508, "Loop Detected"},
    {510, "Not Extended"},
    {511, "Network Authentication Required"}};

std::unordered_map<std::string, std::string> _mime_msg = {
    {".aac", "audio/aac"},
    {".abw", "application/x-abiword"},
    {".arc", "application/x-freearc"},
    {".avi", "video/x-msvideo"},
    {".azw", "application/vnd.amazon.ebook"},
    {".bin", "application/octet-stream"},
    {".bmp", "image/bmp"},
    {".bz", "application/x-bzip"},
    {".bz2", "application/x-bzip2"},
    {".csh", "application/x-csh"},
    {".css", "text/css"},
    {".csv", "text/csv"},
    {".doc", "application/msword"},
    {".docx", "application/vnd.openxmlformats-officedocument.wordprocessingml.document"},
    {".eot", "application/vnd.ms-fontobject"},
    {".epub", "application/epub+zip"},
    {".gif", "image/gif"},
    {".htm", "text/html"},
    {".html", "text/html"},
    {".ico", "image/vnd.microsoft.icon"},
    {".ics", "text/calendar"},
    {".jar", "application/java-archive"},
    {".jpeg", "image/jpeg"},
    {".jpg", "image/jpeg"},
    {".js", "text/javascript"},
    {".json", "application/json"},
    {".jsonld", "application/ld+json"},
    {".mid", "audio/midi"},
    {".midi", "audio/x-midi"},
    {".mjs", "text/javascript"},
    {".mp3", "audio/mpeg"},
    {".mpeg", "video/mpeg"},
    {".mpkg", "application/vnd.apple.installer+xml"},
    {".odp", "application/vnd.oasis.opendocument.presentation"},
    {".ods", "application/vnd.oasis.opendocument.spreadsheet"},
    {".odt", "application/vnd.oasis.opendocument.text"},
    {".oga", "audio/ogg"},
    {".ogv", "video/ogg"},
    {".ogx", "application/ogg"},
    {".otf", "font/otf"},
    {".png", "image/png"},
    {".pdf", "application/pdf"},
    {".ppt", "application/vnd.ms-powerpoint"},
    {".pptx", "application/vnd.openxmlformats-officedocument.presentationml.presentation"},
    {".rar", "application/x-rar-compressed"},
    {".rtf", "application/rtf"},
    {".sh", "application/x-sh"},
    {".svg", "image/svg+xml"},
    {".swf", "application/x-shockwave-flash"},
    {".tar", "application/x-tar"},
    {".tif", "image/tiff"},
    {".tiff", "image/tiff"},
    {".ttf", "font/ttf"},
    {".txt", "text/plain"},
    {".vsd", "application/vnd.visio"},
    {".wav", "audio/wav"},
    {".weba", "audio/webm"},
    {".webm", "video/webm"},
    {".webp", "image/webp"},
    {".woff", "font/woff"},
    {".woff2", "font/woff2"},
    {".xhtml", "application/xhtml+xml"},
    {".xls", "application/vnd.ms-excel"},
    {".xlsx", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"},
    {".xml", "application/xml"},
    {".xul", "application/vnd.mozilla.xul+xml"},
    {".zip", "application/zip"},
    {".3gp", "video/3gpp"},
    {".3g2", "video/3gpp2"},
    {".7z", "application/x-7z-compressed"}};

class Util
{
public:
    // 字符串分割 ---- 将源字符串src按照sep进行分割,将分割得到的各个字符串放入array,最终返回子串的数量
    static size_t Split(const std::string& src, const std::string& sep, std::vector<std::string>* array)
    {
        size_t offset = 0;
        // abc,
        while (offset <= src.size())
        {
            size_t pos = src.find(sep, offset); // 在字符串src中从offset位置开始往后查找sep,找到返回sep最小下标
            if (pos == std::string::npos) // 从offset开始往后没有找到分隔字符串sep
            {
                // 将剩余部分当作整体,放入array中
                if (offset == src.size()) break;
                array->push_back(src.substr(offset));
                return array->size();
            }

            // 找到了分隔字符串sep
            if (pos - offset != 0)
            {
                array->push_back(src.substr(offset, pos - offset));
            }
            offset = pos + sep.size();
        }

        return array->size();
    }

    // 读取文件所有内容,将文件内容放到bu中
    static bool ReadFile(const std::string& filename, std::string* buf)
    {
        std::ifstream ifs(filename.c_str(), std::ios::binary);
        if (ifs.is_open() == false) // 打开文件失败
        {
            ERR_LOG("open %s file failed", filename.c_str());
            return false;
        }

        // 打开文件成功
        // 读取文件大小
        size_t fsize = 0;
        ifs.seekg(0, ifs.end); // 文件指针跳转到文件末尾
        fsize = ifs.tellg(); // 获取文件大小
        ifs.seekg(0, ifs.beg); // 文件指针跳转到文件开头

        // 将文件内容读取到str中
        buf->resize(fsize);
        ifs.read(&(*buf)[0], fsize);

        if (ifs.good() == false)
        {
            ERR_LOG("read %s file failed", filename.c_str());
            ifs.close();

            return false;
        }

        ifs.close();
        return true;
    }

    // 向文件中写入数据
    static bool WriteFile(const std::string& filename, const std::string& buf)
    {
        std::ofstream ofs(filename, std::ios::binary | std::ios::trunc);
        if (ofs.is_open() == false) // 打开文件失败
        {
            ERR_LOG("open %s file failed", filename.c_str());
            return false;
        }

        // 打开文件成功
        ofs.write(buf.c_str(), buf.size()); // 将buf中的内容写入ofs指向的文件中

        if (ofs.good() == false)
        {
            ERR_LOG("write %s file failed", filename.c_str());
            ofs.close();

            return false;
        }

        ofs.close();
        return true;
    }

    // URL编码
    static std::string URLEncode(const std::string& url, bool convert_space_to_plus)
    {
        std::string res;
        for (auto c : url)
        {
            if (c == '.' || c == '-' || c == '_' || c == '~' || isalnum(c))
            {
                res += c;
                continue;
            }

            if (c == ' ' && convert_space_to_plus)
            {
                res += '+';
                continue;
            }

            // 剩下的字符要编码成 %HH 的格式
            char tmp[4] = { 0 };
            snprintf(tmp, 4, "%%%02X", c);
            res += tmp;
        }

        return res;
    }

    static char HEXTOI(char c)
    {
        if (c >= '0' && c <= '9')
        {
            return c - '0';
        }
        else if (c >= 'a' && c <= 'z')
        {
            return c - 'a' + 10;
        }
        else if (c >= 'A' && c <= 'Z')
        {
            return c - 'A' + 10;
        }

        return -1;
    }

    // URL解码
    static std::string URLDecode(const std::string& url, bool convert_plus_to_space)
    {
        std::string res;
        for (int i = 0; i < url.size(); ++i)
        {
            if (url[i] == '+' && convert_plus_to_space)
            {
                res += ' ';
                continue;
            }

            if (url[i] == '%' && (i + 2) < url.size())
            {
                char v1 = HEXTOI(url[i + 1]);
                char v2 = HEXTOI(url[i + 2]);
                char v = (v1 << 4) + v2;
                res += v;
                i += 2;
                continue;
            }

            res += url[i];
        }

        return res;
    }

    // 获取响应状态码的描述信息
    static std::string StatuDesc(int statu)
    {
        auto it = _statu_msg.find(statu);
        if (it != _statu_msg.end())
        {
            return it->second;
        }

        return "Unkonw";
    }

    // 根据文件后缀名获取mime
    static std::string ExtMime(const std::string& filename)
    {
        // 获取文件扩展名
        size_t pos = filename.find_last_of('.');
        if (pos == std::string::npos)
        {
            return "application/octet-stream";
        }

        // 根据文件扩展名获取mime
        std::string ext = filename.substr(pos);
        auto it = _mime_msg.find(ext);
        if (it == _mime_msg.end())
        {
            return "application/octet-stream";
        }

        return it->second;
    }

    // 判断一个文件是否是目录
    static bool IsDirectory(const std::string& filename)
    {
        struct stat st;
        int ret = stat(filename.c_str(), &st);
        if (ret < 0)
        {
            return false;
        }

        return S_ISDIR(st.st_mode);
    }

    // 判断一个文件是否是普通文件
    static bool IsRegular(const std::string& filename)
    {
        struct stat st;
        int ret = stat(filename.c_str(), &st);
        if (ret < 0)
        {
            return false;
        }

        return S_ISREG(st.st_mode);
    }

    // http请求的资源路径有效性判断
    static bool ValidPath(const std::string& path)
    {
        std::vector<std::string> subdir;
        Split(path, "/", &subdir);

        int level = 0;
        for (const auto& dir : subdir)
        {
            if (dir == "..")
            {
                --level;
                if (level < 0)
                {
                    return false;
                }

                continue;
            }
            ++level;
        }

        return true;
    }
};

HttpRequest模块

该模块用于存储HTTP请求信息的要素,提供简单的功能性接口。

请求信息要素:

  • 请求行:请求方法,URL,协议版本
  • URL:资源路径,查询字符串
  • 请求头部
  • 正文

要素:请求方法,资源路径,查询字符串,头部字段,正文,协议版本

| 功能性接口 |

  1. 将成员变量设置为公有成员,便于直接访问
  2. 提供查询字符串,以及头部字段的单个查询,获取和插入功能
  3. 获取正文长度
  4. 判断长连接和短链接(Connection:close / keep-alive)

| 代码实现 |

class HttpRequest
{
public:
    HttpRequest() : _version("HTTP/1.1") {}

    // 插入头部字段
    void SetHeader(const std::string& key, const std::string& val)
    {
        _headers.insert(std::make_pair(key, val));
    }

    // 判断是否存在指定的头部字段
    bool HasHeader(const std::string& key) const
    {
        auto it = _headers.find(key);
        if (it == _headers.end())
        {
            return false;
        }

        return true;
    }

    // 获取指定的头部字段的值
    std::string GetHeader(const std::string& key) const
    {
        auto it = _headers.find(key);
        if (it == _headers.end())
        {
            return "";
        }

        return it->second;
    }

    // 插入查询字符串
    void SetParam(const std::string& key, const std::string& val)
    {
        _params.insert(std::make_pair(key, val));
    }

    // 判断是否存在指定的查询字符串
    bool HasParam(const std::string& key) const
    {
        auto it = _params.find(key);
        if (it == _params.end())
        {
            return false;
        }

        return true;
    }

    // 获取指定的查询字符串
    std::string GetParam(const std::string& key) const
    {
        auto it = _params.find(key);
        if (it == _params.end())
        {
            return "";
        }

        return it->second;
    }

    // 获取正文长度
    size_t ContentLength() const
    {
        bool ret = HasHeader("Content-Length");
        if (ret == false)
        {
            return 0;
        }

        std::string clen = GetHeader("Content-Length");
        return std::stol(clen);
    }

    // 判断是否是短链接
    bool Close() const
    {
        if (HasHeader("Connection") == true && GetHeader("Connection") == "keep-alive")
        {
            return false; // 长连接
        }

        return true; // 短链接
    }

    // 重置
    void ReSet()
    {
        _method.clear();
        _path.clear();
        _version = "HTTP/1.1";
        _body.clear();
        std::smatch match;
        _match.swap(match);
        _headers.clear();
        _params.clear();
    }

public:
    std::string _method;                                   // 请求方法
    std::string _path;                                     // 资源路径
    std::string _version;                                  // 协议版本
    std::string _body;                                     // 请求正文
    std::smatch _match;                                    // 资源路径的正则提取数据
    std::unordered_map<std::string, std::string> _headers; // 头部字段
    std::unordered_map<std::string, std::string> _params;  // 查询字符串
};

HttpResponse模块

该模块用于存储HTTP响应信息要素,提供简单的功能性接口。

响应信息要素:

  • 响应状态码
  • 头部字段
  • 响应正文
  • 重要定向信息(是否进行了重定向的标志,重定向的路径)

功能性接口:

  • 头部字段的新增,查询,获取
  • 正文的设置
  • 重定向的设置
  • 长短连接的判断

| 代码实现 |

class HttpResponse
{
public:
    HttpResponse()
        :_statu(200)
        , _redirect_flag(false)
    {}

    HttpResponse(int statu)
        :_statu(statu)
        , _redirect_flag(false)
    {}

    void ReSet()
    {
        _statu = 200;
        _redirect_flag = false;
        _body.clear();
        _redirect_url.clear();
        _headers.clear();
    }

    // 插入头部字段
    void SetHeader(const std::string& key, const std::string& val)
    {
        _headers.insert(std::make_pair(key, val));
    }

    // 判断是否存在指定的头部字段
    bool HasHeader(const std::string& key)
    {
        auto it = _headers.find(key);
        if (it == _headers.end())
        {
            return false;
        }

        return true;
    }

    // 获取指定的头部字段的值
    std::string GetHeader(const std::string& key)
    {
        auto it = _headers.find(key);
        if (it == _headers.end())
        {
            return "";
        }

        return it->second;
    }

    // 设置正文
    void SetContent(const std::string& body, const std::string& type = "text/html")
    {
        _body = body;
        SetHeader("Content-Type", type);
    }

    // 设置重定向
    void SetRedirect(const std::string& url, int statu = 302)
    {
        _statu = statu;
        _redirect_flag = true;
        _redirect_url = url;
    }

    // 判断是否是短链接
    bool Close()
    {
        if (HasHeader("Connection") == true && GetHeader("Connection") == "keep-alive")
        {
            return false; // 长连接
        }

        return true; // 短链接
    }

public:
    int _statu;                                            // 响应状态码
    bool _redirect_flag;                                   // 重定向标志
    std::string _body;                                     // 正文
    std::string _redirect_url;                             // 重定向url
    std::unordered_map<std::string, std::string> _headers; // 头部字段
};

HttpContext模块

该模块用于记录HTTP请求的接收和处理进度。

有可能出现接收的数据并不是一条完整的HTTP请求数据,也就是请求的处理需要再多次受到数据后才能处理完成,因此再每次处理的时候,就需要将处理的进度记录下来,以便下次从当前进度继续向下处理。

提供的功能性接口

  • 接受请求行

  • 解析请求行

  • 接收头部

  • 解析头部

  • 接收正文

  • 返回解析完毕的请求信息

  • 返回响应状态码

  • 返回接收解析状态

| 代码实现 |

typedef enum
{
    RECV_HTTP_ERROR,
    RECV_HTTP_LINE,
    RECV_HTTP_HEAD,
    RECV_HTTP_BODY,
    RECV_HTTP_OVER
} HttpRecvStatu;

#define MAX_LINE 8192
class HttpContext
{
public:
    HttpContext()
        :_resp_statu(200)
        , _recv_statu(RECV_HTTP_LINE)
    {}

    void ReSet()
    {
        _resp_statu = 200;
        _recv_statu = RECV_HTTP_LINE;
        _request.ReSet();
    } 

    int RespStatu() { return _resp_statu; }

    HttpRecvStatu RecvStatu() { return _recv_statu; }

    HttpRequest& Request() { return _request; }
    
    // 接收并解析HTTP请求
    void RecvHttpRequest(Buffer* buf)
    {
        switch (_recv_statu)
        {
        case RECV_HTTP_LINE:
            RecvHttpLine(buf);
        case RECV_HTTP_HEAD:
            RecvHttpHead(buf);
        case RECV_HTTP_BODY:
            RecvHttpBody(buf);
        }

        return;
    }

private:
    bool RecvHttpLine(Buffer* buf)
    {
        if (_recv_statu != RECV_HTTP_LINE) return false;
        
        // 1.获取一行数据(数据中包含末尾换行符)
        std::string line = buf->GetLineAndPop();
        
        // 2.需要考虑两个因素:缓冲区的数据不足一行/获取一行数据的大小很大
        if (line.size() == 0)
        {
            // 缓冲区中的数据不足一行,则需要判断缓冲区中可读数据的长度,若可读数据很长了都不足一行,那就是出问题了
            if (buf->ReadableSize() > MAX_LINE)
            {
                _recv_statu = RECV_HTTP_ERROR;
                _resp_statu = 414; // URI TOO LONG
                return false;
            }

            // 缓冲区中数据不足一行,但也不多,那就等后续数据的到来
            return true;
        }

        if (line.size() > MAX_LINE)
        {
            _recv_statu = RECV_HTTP_ERROR;
            _resp_statu = 414; // URI TOO LONG
            return false;
        }

        bool ret = ParseHttpLine(line);
        if (ret == false)
        {
            return false;
        }

        // 获取首行完毕,进入获取头部字段阶段
        _recv_statu = RECV_HTTP_HEAD;
        return true;
    }

    bool ParseHttpLine(const std::string& line)
    {
        std::smatch matches;
        std::regex e("(GET|POST|PUT|HEAD|DELETE|OPTIONS|TRACE|CONNECT|LINK|UNLINE) ([^?]*)(?:\\?(.*))? (HTTP/1\\.[01])(?:\n|\r\n)?", std::regex::icase);
        bool ret = std::regex_match(line, matches, e);
        if (ret == false)
        {
            _recv_statu = RECV_HTTP_ERROR;
            _resp_statu = 400; // BAD REQUEST
            return false;
        }

        // 0 : GET /home/login?user=nk&pass=123123 HTTP/1.1
        // 1 : GET
        // 2 : /home/login
        // 3 : user=nk&pass=123123
        // 4 : HTTP/1.1

        // 请求方法的获取
        _request._method = matches[1];
        std::transform(_request._method.begin(), _request._method.end(), _request._method.begin(), ::toupper); // 将请求行中的请求方法全部改成大写

        // 资源路径的获取,需要进行URL的解码操作,但不需要将+转为空格
        _request._path = Util::URLDecode(matches[2], false);

        // 协议版本的获取
        _request._version = matches[4];

        // 查询字符串的获取与处理
        std::vector<std::string> query_string_array;
        std::string query_string = matches[3]; // 获取查询字符串

        // 查询字符串的格式为:key=val&key=val...,先以&进行分隔,得到各个子串,格式为:key=val
        Util::Split(query_string, "&", &query_string_array);

        // 针对各个子串,以=进行分隔,得到key和val,对key和val也要进行解码操作
        for (auto& str : query_string_array)
        {
            size_t pos = str.find("=");
            if (pos == std::string::npos) // 在子串中没找到=,出错返回
            {
                _recv_statu = RECV_HTTP_ERROR;
                _resp_statu = 400; // BAD REQUEST
                return false;
            }

            // 获取key和val
            std::string key = Util::URLDecode(str.substr(0, pos), true);
            std::string val = Util::URLDecode(str.substr(pos + 1), true);

            // 插入查询字符串
            _request.SetParam(key, val);
        }

        return true;
    }

    bool RecvHttpHead(Buffer* buf)
    {
        if(_recv_statu != RECV_HTTP_HEAD) return false;
        // 一行一行取出数据,直到遇到空行为止,头部字段的格式为:key: val\r\nkey: val\r\n...

        while (true) // 在读取到空行之前要一直循环读取头部字段
        {
            // 1.获取一行数据
            std::string line = buf->GetLineAndPop();
            
            // 2.需要考虑两个因素:缓冲区的数据不足一行/获取一行数据的大小很大
            if (line.size() == 0)
            {
                // 缓冲区中的数据不足一行,则需要判断缓冲区中可读数据的长度,若可读数据很长了都不足一行,那就是出问题了
                if (buf->ReadableSize() > MAX_LINE)
                {
                    _recv_statu = RECV_HTTP_ERROR;
                    _resp_statu = 414; // URI TOO LONG
                    return false;
                }

                // 缓冲区中数据不足一行,但也不多,那就等后续数据的到来
                return true;
            }

            if (line.size() > MAX_LINE)
            {
                _recv_statu = RECV_HTTP_ERROR;
                _resp_statu = 414; // URI TOO LONG
                return false;
            }

            if (line == "\n" || line == "\r\n") // 读到了空行,获取所有头部字段完毕
            {
                break;
            }

            bool ret = ParseHttpHead(line);
            if (ret == false)
            {
                return false;
            }
        }

        // 获取头部完毕,进入获取正文阶段
        _recv_statu = RECV_HTTP_BODY;
        return true;
    }

    bool ParseHttpHead(std::string& line)
    {
        if (line.back() == '\n') line.pop_back(); // 末尾是换行则去掉换行字符
        if (line.back() == '\r') line.pop_back(); // 末尾是回车则去掉回车字符

        // 头部字段的格式:key: val\r\n
        size_t pos = line.find(": ");
        if (pos == std::string::npos)
        {
            _recv_statu = RECV_HTTP_ERROR;
            _resp_statu = 400; // BAD REQUEST
            return false;
        }

        // 获取key和val
        std::string key = line.substr(0, pos);
        std::string val = line.substr(pos + 2);

        // 插入头部字段
        _request.SetHeader(key, val);

        return true;
    }

    bool RecvHttpBody(Buffer* buf)
    {
        if (_recv_statu != RECV_HTTP_BODY) return false;

        // 1.获取正文长度
        size_t content_lenth = _request.ContentLength();
        if (content_lenth == 0) // 没有正文,则请求接收解析完毕
        {
            _recv_statu = RECV_HTTP_OVER;
            return true;
        }

        // 2.当前已经接收到多少正文,其实就是往_request._body中放了多少数据
        size_t real_len = content_lenth - _request._body.size(); // 实际还需要接收的长度

        // 3.接收正文放到body中,当时也要考虑当前缓冲区的数据是否是全部正文
        //    3.1.缓冲区中的数据包含了当前请求的所有正文,则取出所需的数据
        if (buf->ReadableSize() >= real_len)
        {
            _request._body.append(buf->GetReadPos(), real_len);
            buf->MoveReadOffset(real_len);

            _recv_statu = RECV_HTTP_OVER;
            return true;
        }

        //    3.2.缓冲区中的数据无法满足当前正文的完整性,数据不足,则取出数据,等待新数据的到来
        _request._body.append(buf->GetReadPos(), buf->ReadableSize());
        buf->MoveReadOffset(buf->ReadableSize());

        return true;
    }

private:
    int _resp_statu;           // 响应状态码
    HttpRecvStatu _recv_statu; // 当前接收及解析的阶段状态
    HttpRequest _request;      // 已经解析得到的请求信息
};

HttpServer模块

该模块用于实现HTTP服务器的搭建。

该模块中设计一张请求路由表,当服务器收到了一个请求,就在请求路由表中查找有没有对应请求的处理函数,如果有,则执行对应的请求处理函数即可。

说白了,收到什么请求,对应如何处理,由用户来设定,服务器收到请求只需要执行对应处理函数即可。

这样设计的好处:用户只需要实现业务处理函数,然后将请求与对应的处理函数的映射关系添加到服务器中的路由表中即可。而服务器只需要接收数据,解析数据,查找路由表映射关系,执行业务处理函数即可。

要素:

  • GET请求的路由映射表
  • POST请求的路由映射表
  • PUT请求的路由映射表
  • DELETE请求的路由映射表

(路由映射表记录请求方法与其对应的请求处理函数的映射关系,更多的是功能性请求的处理)

  • 静态资源相对根目录(实现静态资源请求的处理)
  • 高性能TCP服务器(进行连接的IO操作)

服务器处理流程:

  1. 从socket接收数据,放到接收缓冲区
  2. 调用OnMessage()回调函数进行业务处理
  3. 对请求进行解析,得到了一个HttpRequest结构,包含了所有的请求要素
  4. 进行请求的路由查找,找到对应请求的处理方法
  5. 对静态资源请求、功能性请求进行处理完毕后,得到了一个填充了响应信息的HttpResponse对象,组织HTTP格式响应, 进行发送

接口:

  • 添加请求-处理函数的映射关系信息(GET/POST/PUT/DELETE)
  • 设置静态资源根目录
  • 设置是否启动超时连接关闭
  • 设置线程池中线程的数量
  • 启动服务器
  • OnConnected() --- 用于给TcpServer设置协议上下文
  • OnMessage() --- 用于进行缓冲区数据解析处理
  • 请求的路由查找(静态资源请求的查找和处理,功能性请求的查找和处理)
  • 组织响应进行回复

| 代码实现 |

#define DEFAULT_TIMEOUT 10

class HttpServer
{
    using Handler = std::function<void(const HttpRequest&, HttpResponse*)>;
    using Handlers = std::vector<std::pair<std::regex, Handler>>;
public:
    HttpServer(int port, int timeout = DEFAULT_TIMEOUT)
        :_server(port)
    {
        _server.EnableInactiveRelease(timeout); // 启用非活跃连接超时销毁功能,默认为30s
        _server.SetConnectedCallback(std::bind(&HttpServer::OnConnected, this, std::placeholders::_1));
        _server.SetMessageCallback(std::bind(&HttpServer::OnMessage, this, std::placeholders::_1, std::placeholders::_2));
    }

    // 设置静态资源根目录
    void SetBaseDir(const std::string& path)
    {
        assert(Util::IsDirectory(path) == true);
        _basedir = path;
    }

    // 设置GET请求与对应处理函数的映射关系
    void Get(const std::string& pattern, const Handler& handler)
    {
        _get_route.push_back(make_pair(std::regex(pattern), handler));
    }

    // 设置POST请求与对应处理函数的映射关系
    void Post(const std::string& pattern, const Handler& handler)
    {
        _post_route.push_back(make_pair(std::regex(pattern), handler));
    }

    // 设置PUT请求与对应处理函数的映射关系
    void Put(const std::string& pattern, const Handler& handler)
    {
        _put_route.push_back(make_pair(std::regex(pattern), handler));
    }

    // 设置DELETE请求与对应处理函数的映射关系
    void Delete(const std::string& pattern, const Handler& handler)
    {
        _delete_route.push_back(make_pair(std::regex(pattern), handler));
    }

    // 设置线程数量
    void SetThreadCount(int count) { _server.SetThreadCount(count); }

    // 监听
    void Listen() { _server.Start(); }

private:
    // 将HttpResponse中的要素按照HTTP协议格式进行组织,发回
    void WriteResponse(const SharedConnection& conn, HttpRequest& req, HttpResponse& resp)
    {
        // 1.先完善头部字段
        if (req.Close() == true)
        {
            resp.SetHeader("Connection", "close");
        }
        else
        {
            resp.SetHeader("Connection", "keep-alive");
        }

        if (resp._body.empty() == false && resp.HasHeader("Content-Length") == false) // 响应正文不为空,但响应的头部字段中没有正文长度
        {
            resp.SetHeader("Content-Length", std::to_string(resp._body.size()));
        }

        if (resp._body.empty() == false && resp.HasHeader("Content-Type") == false) // 响应正文不为空,但响应的头部字段中没有正文类型
        {
            resp.SetHeader("Content-Type", "application/octet-stream");
        }

        if (resp._redirect_flag == true)
        {
            resp.SetHeader("Location", resp._redirect_url);
        }

        // 2.将resp中的要素按照HTTP协议格式进行组织
        std::stringstream resp_str;
        // 组织响应首行
        resp_str << req._version << " " << std::to_string(resp._statu) << " " << Util::StatuDesc(resp._statu) << "\r\n";
        
        // 组织响应头部字段
        for (const auto& it : resp._headers)
        {
            resp_str << it.first << ": " << it.second << "\r\n";
        }
        resp_str << "\r\n"; // 添加空行

        // 组织响应正文
        resp_str << resp._body;

        // 3.发送数据给客户端
        conn->Send(resp_str.str().c_str(), resp_str.str().size());
    }

    // 静态资源的请求处理
    void FileHandler(const HttpRequest& req, HttpResponse* resp)
    {
        std::string req_path = _basedir + req._path;
        if (req._path.back() == '/')
        {
            req_path += "index.html";
        }

        bool ret = Util::ReadFile(req_path, &resp->_body);
        if (ret == false) return;

        std::string mime = Util::ExtMime(req_path);
        resp->SetHeader("Content-Type", mime);

        return;
    }

    // 功能性请求的分类处理
    void Dispatcher(HttpRequest& req, HttpResponse* resp, Handlers& handlers)
    {
        // 在对应请求方法路由表中查找是否存在对应的请求处理函数,有则调用该函数,没有则返回404
        // handlers的本质是一个unordered_map,存放请求和对应的处理函数的映射关系

        // 思想:路由表存储的键值对 ---- 正则表达式 & 处理函数
        // 使用正则表达式对请求的资源路径进行正则匹配,匹配成功就调用对应函数进行处理
        for (auto& handler : handlers)
        {
            const std::regex& re = handler.first;
            const Handler& functor = handler.second;

            bool ret = std::regex_match(req._path, req._match, re);
            if(ret == false) // 正则匹配失败
            {
                continue; // 换一个路由请求表继续匹配
            }

            return functor(req, resp); // 传入请求信息和空的resp,执行处理函数
        }

        resp->_statu = 404; // 所有的路由表都没有匹配的请求处理函数,则返回404
    }

    // 判断是否是静态资源请求
    bool IsFileHandler(const HttpRequest& req)
    {
        // 1.必须设置静态资源根目录
        if (_basedir.empty()) return false;

        // 2.请求方法必须是GET/HEAD
        if (req._method != "GET" && req._method != "HEAD") return false;

        // 3.请求的资源路径必须是一个合法的路径
        if (Util::ValidPath(req._path) == false) return false;

        // 4.请求的资源必须存在,且必须是一个普通文件
        std::string req_path = _basedir + req._path;

        //   特殊情况:资源路径结尾是根目录的 -- 目录:/  /image/,这种情况默认在结尾追加一个index/html
        if (req._path.back() == '/')
        {
            req_path += "index.html";
        }

        // 判断资源路径是否是一个普通文件
        if (Util::IsRegular(req_path) == false) return false;

        return true;
    }

    void Route(HttpRequest& req, HttpResponse* resp)
    {
        // 1.对请求继续分辨,判断是静态资源请求还是功能性请求
        //   静态资源请求,则进行静态资源的处理
        //   功能性请求,则需要通过几个路由表来确定是否有对应的处理函数
        //   既不是静态资源请求,也没有设置对应的功能性请求处理函数,则返回405

        if (IsFileHandler(req) == true)
        {
            return FileHandler(req, resp);
        }

        if (req._method == "GET" || req._method == "HEAD")
        {
            return Dispatcher(req, resp, _get_route);
        }
        else if (req._method == "POST")
        {
            return Dispatcher(req, resp, _post_route);
        }
        else if (req._method == "PUT")
        {
            return Dispatcher(req, resp, _put_route);
        }
        else if (req._method == "DELETE")
        {
            return Dispatcher(req, resp, _delete_route);
        }

        resp->_statu = 405; // Method Not Allowed
    }

    // 设置上下文
    void OnConnected(const SharedConnection& conn)
    {
        conn->SetContext(HttpContext());
        DBG_LOG("new connection %p", conn.get());
    }

    void ErrorHandler(HttpResponse* resp)
    {
        // 1.组织一个错误的展示页面
        std::string body;
        body += "<html>";
        body += "<head>";
        body += "<meta http-equiv='Content-Type' content='text/html;charset=utf-8'>";
        body += "</head>";
        body += "<body>";
        body += "<h1>";
        body += std::to_string(resp->_statu);
        body += " ";
        body += Util::StatuDesc(resp->_statu);
        body += "</h1>";
        body += "</body>";
        body += "</html>";

        // 2.将页面数据,当作响应正文,放入resp中
        resp->SetContent(body, "text/html");
    }

    // 缓冲区数据解析+处理
    void OnMessage(const SharedConnection& conn, Buffer* buf)
    {
        while (buf->ReadableSize() > 0) // 只要缓冲区中有数据,解析+处理就是一个循环处理的过程
        {
            // 1.获取上下文
            HttpContext* context = conn->GetContext()->get<HttpContext>();

            // 2.通过上下文对缓冲区数据进行解析,得到HttpRequest对象
            // 2.1.若缓冲区的数据解析出错,则直接回复出错响应
            // 2.2.若解析正常,且请求已经获取完毕,才开始去进行处理
            context->RecvHttpRequest(buf);
            HttpRequest& req = context->Request();
            HttpResponse resp(context->RespStatu());

            if (context->RespStatu() >= 400)
            {
                // 进行错误响应,关闭连接
                ErrorHandler(&resp); // 填充一个错误显示页面到resp中
                WriteResponse(conn, req, resp); // 组织响应发送给客户端

                context->ReSet(); // 重置RecvStatu和RespStatu
                buf->MoveReadOffset(buf->ReadableSize()); // 出错就清空缓冲区中的数据
                conn->Shutdown();
                return;
            }

            if (context->RecvStatu() != RECV_HTTP_OVER)
            {
                // 当前请求还没有接收完整,先返回等新数据的到来再重新处理
                return;
            }

            // 3.请求路由 + 业务处理
            Route(req, &resp);

            // 4.对HttpResponse进行组织发送
            WriteResponse(conn, req, resp);

            // 5.重置上下文
            context->ReSet();

            // 6.根据长短连接判断是否关闭连接或者继续处理
            if (resp.Close() == true) conn->Shutdown(); // 短连接则直接关闭
        }

        return;
    }

private:
    Handlers _get_route;    // GET请求的路由映射表
    Handlers _post_route;   // POST请求的路由映射表
    Handlers _put_route;    // PUT请求的路由映射表
    Handlers _delete_route; // DELETE请求的路由映射表
    std::string _basedir;   // 静态资源路径
    TcpServer _server;      // tcp服务器
};

服务器搭建与测试

服务器搭建

| 代码实现 |

#include "http.hpp"

#define WWWROOT "./wwwroot/"

std::string RequestStr(const HttpRequest& req)
{
    std::stringstream ss;

    ss << req._method << " " << req._path << " " << req._version << "\r\n";

    for (auto& it : req._params)
    {
        ss << it.first << ": " << it .second << "\r\n";
    }

    for (auto& it : req._headers)
    {
        ss << it.first << ": " << it.second << "\r\n";
    }

    ss << "\r\n";
    ss << req._body;

    return ss.str();
}

void Hello(const HttpRequest& req, HttpResponse* resp)
{
    resp->SetContent(RequestStr(req), "text/plain");
}

void Login(const HttpRequest& req, HttpResponse* resp)
{
    resp->SetContent(RequestStr(req), "text/plain");
}

void PutFile(const HttpRequest& req, HttpResponse* resp)
{
    resp->SetContent(RequestStr(req), "text/plain");
}

void DeleteFile(const HttpRequest& req, HttpResponse* resp)
{
    resp->SetContent(RequestStr(req), "text/plain");
}

int main()
{
    HttpServer server(8080);
    server.SetThreadCount(3);
    server.SetBaseDir(WWWROOT);
    server.Get("/hello", Hello);
    server.Post("/login", Login);
    server.Put("/1234.txt", PutFile);
    server.Delete("/1234.txt", DeleteFile);
    server.Listen();

    return 0;
}

服务器测试

测试一:长连接测试

创建一个客户端持续给服务器发送数据,直到超过超时时间,看看是否正常运行。

| 代码实现 |

#include "../server.hpp"

int main()
{
    Socket cli_sock;
    cli_sock.CreateClient(8080, "127.0.0.1");

    std::string req = "GET /hello HTTP/1.1\r\nConnection: keep-alive\r\nContent-Length: 0\r\n\r\n";

    while (true)
    {
        assert(cli_sock.Send(req.c_str(), req.size()) != -1);

        char buf[1024] = { 0 };
        assert(cli_sock.Recv(buf, 1023) != -1);
        DBG_LOG("[%s]", buf);
        
        sleep(3);
    }
    cli_sock.Close();

    return 0;
}

测试二:超时连接测试

建一个客户端,给服务器发送一次数据后就停止操作,查看服务器是否会正常的超时关闭连接。

| 代码实现 |

#include "../server.hpp"

int main()
{
    Socket cli_sock;
    cli_sock.CreateClient(8080, "127.0.0.1");

    std::string req = "GET /hello HTTP/1.1\r\nConnection: keep-alive\r\nContent-Length: 0\r\n\r\n";

    while (true)
    {
        assert(cli_sock.Send(req.c_str(), req.size()) != -1);

        char buf[1024] = { 0 };
        assert(cli_sock.Recv(buf, 1023) != -1);
        DBG_LOG("[%s]", buf);
        
        sleep(15);
    }
    cli_sock.Close();

    return 0;
}

测试三:错误请求测试

给服务器发送一个数据,告诉服务器要发送1024字节的数据,但实际发送的数据不足1024字节,查看服务器处理结果。

  1. 如果数据只发送一次,服务器将得不到完整的请求就不会进行业务处理,客户端也就得不到响应,最终超时关闭连接。

  2. 连着给服务器发送了多次小的请求,服务器会将后边的请求当作是前面请求的正文呢来处理,而后面处理请求时可能会业务错误的处理而关闭连接。

| 代码实现 |

#include "../server.hpp"

int main()
{
    Socket cli_sock;
    cli_sock.CreateClient(8080, "127.0.0.1");

    std::string req = "GET /hello HTTP/1.1\r\nConnection: keep-alive\r\nContent-Length: 50\r\n\r\nnKzbc";

    while (true)
    {
        assert(cli_sock.Send(req.c_str(), req.size()) != -1);
        assert(cli_sock.Send(req.c_str(), req.size()) != -1);
        assert(cli_sock.Send(req.c_str(), req.size()) != -1);
        assert(cli_sock.Send(req.c_str(), req.size()) != -1);
        assert(cli_sock.Send(req.c_str(), req.size()) != -1);

        char buf[1024] = { 0 };
        assert(cli_sock.Recv(buf, 1023));
        DBG_LOG("[%s]", buf);
        
        sleep(3);
    }
    cli_sock.Close();

    return 0;
}

测试四:业务处理超时测试

业务处理超时,查看服务器的处理情况。

当服务器达到了一个性能瓶颈,在一次业务处理中花费了太长的时间(超过了服务器设置的非活跃超时时间),在一次业务处理中耗费太长时间,导致其他的连接也被连累超时,其他的连接有可能会被拖累超时释放。

假设现在  12345描述符就绪了, 在处理1的时候花费了30s处理完,超时了,导致2345描述符因为长时间没有刷新活跃度。

  1. 如果接下来的2345描述符都是通信连接描述符,如果都就绪了,则并不影响,因为接下来就会进行处理并刷新活跃度。

  2. 如果接下来的2号描述符是定时器事件描述符,定时器触发超时,执行定时任务,就会将345描述符给释放掉,这时候一旦345描述符对应的连接被释放,接下来在处理345事件的时候就会导致程序崩溃(内存访问错误)。因此这时候,在本次事件处理中,并不能直接对连接进行释放,而应该将释放操作压入到任务池中,等到事件处理完了执行任务池中的任务的时候,再去释放。

| 代码实现 |

#include "../server.hpp"

int main()
{
    signal(SIGCHLD, SIG_IGN);
    for (int i = 0; i < 10; ++i) // 父进程循环创建10个子进程
    {
        pid_t pid = fork();
        if (pid < 0)
        {
            DBG_LOG("fork() error");
            return -1;
        }
        else if (pid == 0)
        {
            Socket cli_sock;
            cli_sock.CreateClient(8080, "127.0.0.1");

            std::string req = "GET /hello HTTP/1.1\r\nConnection: keep-alive\r\nContent-Length: 0\r\n\r\n";

            while (true)
            {
                assert(cli_sock.Send(req.c_str(), req.size()) != -1);

                char buf[1024] = {0};
                assert(cli_sock.Recv(buf, 1023));
                DBG_LOG("[%s]", buf);
            }
            cli_sock.Close();
            exit(0);
        }
    }

    while (true) sleep(1);

    return 0;
}

测试五:处理多条请求测试

一次性给服务器发送多条数据,然后查看服务器的处理结果。每一条请求都要得到正常处理。

| 代码实现 |

#include "../server.hpp"

int main()
{
    Socket cli_sock;
    cli_sock.CreateClient(8080, "127.0.0.1");

    std::string req = "GET /hello HTTP/1.1\r\nConnection: keep-alive\r\nContent-Length: 0\r\n\r\n";
    req += "GET /hello HTTP/1.1\r\nConnection: keep-alive\r\nContent-Length: 0\r\n\r\n";
    req += "GET /hello HTTP/1.1\r\nConnection: keep-alive\r\nContent-Length: 0\r\n\r\n";

    while (true)
    {
        assert(cli_sock.Send(req.c_str(), req.size()) != -1);

        char buf[1024] = { 0 };
        assert(cli_sock.Recv(buf, 1023));
        DBG_LOG("[%s]", buf);
        
        sleep(3);
    }
    cli_sock.Close();

    return 0;
}

 

测试六:大文件传输测试

给服务器上传一个大文件,服务器将文件保存下来,观察处理结果。

预期:上传的文件和服务器保存的文件一致。

#include "../http/http.hpp"

int main()
{
    Socket cli_sock;
    cli_sock.CreateClient(8080, "127.0.0.1");

    std::string req = "PUT /1234.txt HTTP/1.1\r\nConnection: keep-alive\r\n";
    std::string body;
    Util::ReadFile("./hello.txt", &body);
    req += "Content-Length: " + std::to_string(body.size()) + "\r\n\r\n";

    assert(cli_sock.Send(req.c_str(), req.size()) != -1);
    assert(cli_sock.Send(body.c_str(), body.size()) != -1);

    char buf[1024] = {0};
    assert(cli_sock.Recv(buf, 1023) != -1);
    DBG_LOG("[%s]", buf);

    cli_sock.Close();

    return 0;
}

测试七:性能压力测试

借助webbench工具,创建大量的进程,在进程中,创建客户端连接服务器,发送请求,收到响应后关闭连接,开始写一个连接的建立。

测试环境:

云服务器 CPU 2核 - 内存 2GB - 带宽4Mbps - 系统盘 50GB

客户端同使用云服务器配置 - CentOS 7.6

使用webbench以3000并发量,向服务器发送请求,进行了3分钟的测试

测试结果:

并发量:3000

QPS:1471


项目源码

 项目完整源代码 → https://gitee.com/ning-jiahao_NICK/project/tree/master/TcpServer/src


原文地址:https://blog.csdn.net/m0_61064157/article/details/134270462

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!