从零实现系列|RPC

0.序言

1.服务端与消息编码

2.高性能客户端

3.服务注册

这一章是干什么的?

服务端的主要工作:

  • 监听端口
  • 响应请求
    • 解析请求
    • 处理请求
      • 从请求中获取方法名、参数
      • 获取方法
      • 调用方法
    • 返回结果

这一章主要完成处理请求部分,其中,从请求中获取方法名、参数在上一章已完成。client 在发送请求时,方法名在请求头里 client.header.ServiceMethod,参数在请求体里 call.Args

所以我们重点关注

  • 获取方法
  • 调用方法

如何获取方法

服务端通过方法名获取实际的方法,需要两步:

  1. 在启动服务端服务时,我们可以将所有可调用的方法都加到一个 map 中。

  2. 当 client 调用 hello 时,我们从 map 中找到对应的方法皆可。

由于不同类型的结构体,可能有同名的结构体方法。我们的方法名,使用 结构体实例名.方法名 来作为方法名。比如上一章中的,结构体类型 Foo ,有 Sum 方法,客户端实例化 foo 对象后,远程调用 Client.Call 时,传入的就是 foo.Sum

如何方法注册到 map 中

首先,先说数据存储方式,简单来说是这样的,

flowchart LR
    A[Server.serviceMap] -->|map储存多个service| B[Service<br>一个 service 表示一个变量类型]
    B -->|service包含一个 map 变量| C[service.method]
    C -->|map 储存多个方法| D[methodType<br>一个 methodType 表示一个该变量类型的方法]
 

具体数据存储方式设计如下:

服务端保存着一个 serviceMap。使用 serviceMap 保存变量类型及其方法。key 是变量类型名,value 是 service 类型实例。

1
2
3
4
5
// server.go
// Server 定义 Server 结构体,封装了 Accept、ServeConn、serveCodec 方法
type Server struct {
serviceMap sync.Map
}

service 类型定义如下,一个 service 表示一个变量类型。其中,也包含一个 map,保存着这个变量的方法。key 是方法名,value 是 methodType 类型实例。

1
2
3
4
5
6
7
8
// service.go
// 一个 service 表示一个变量类型
type service struct {
name string // 变量类型名,用来打 log,例如字符串 "Foo"
typ reflect.Type // 变量类型,通过变量的类型,可以直接获取变量的方法,例如 Foo 类型,可获取到方法 Sum
rcvr reflect.Value // 变量的值,通过变量的值,可以调用变量的方法,例如 &foo,调用 Sum
method map[string]*methodType // map 储存变量的方法中所有可调用的方法
}

methodType 类型定义如下,一个 methodType 表示一个方法。

1
2
3
4
5
6
7
8
// service.go
// 一个 methodType 表示一个方法
type methodType struct {
method reflect.Method // 方法本身,用于调用方法,例如 Foo.Sum
ArgType reflect.Type // 参数的类型,用于判断参数是否正确
ReplyType reflect.Type // 返回的类型,用于判断返回值是否正确
numCalls uint64 // 方法调用次数,用于统计方法调用次数
}

最后,下面将以一个实际例子,演示如何将方法注册到 map 中。

graph LR
A[startServer, 启动服务器] --> B[Server.Register , 为 Foo 类型注册服务]
B --> C[newService, 为 Foo 类型创建 service 结构体]
C --> D[获取 rcvr 值, 类型名和类型]
C --> G[service.registerMethods, 获取方法列表]
G --> K[检查方法的参数和返回值]
K --> L[保存方法到 map 变量 service.method]
  1. 启动服务,将 Foo 类型注册到 rpc server 中。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
// main.go

type Foo int // Foo 类型,实现了 Sum 方法
type Args struct{ Num1, Num2 int }
func (f Foo) Sum(args Args, reply *int) error {
*reply = args.Num1 + args.Num2
return nil
}

func startServer(addr chan string) {

var foo Foo // 实例化 Foo 类型的对象
if err := geerpc.Register(&foo); err != nil { // 注册 Foo 类型的对象,注册的是 Foo 类型的对象,不是 Foo 类型的方法
log.Fatal("register error:", err)
}

l, err := net.Listen("tcp", ":0")
if err != nil {
log.Fatal("network error:", err)
}
log.Println("start rpc server on", l.Addr())
addr <- l.Addr().String()
geerpc.Accept(l)
}
  1. 为 Foo 类型创建 service 对象,并保存到 server 的 map 中。
1
2
3
4
5
6
7
8
9
10
11
// server.go
var DefaultServer = NewServer()
func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) }

func (server *Server) Register(rcvr interface{}) error {
s := newService(rcvr) // 为 rcvr 变量的类型创建 service 结构体
if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup { // 调用 serviceMap.LoadOrStore 将 service 结构体保存到 map 中
return errors.New("rpc: service already defined: " + s.name)
}
return nil
}
  1. 创建 service 对象、获取对象的方法、检查方法参数和返回值,最后将方法保存到 map 中。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
// service.go
// newService 通过发射获取 rcvr 变量的类型及其方法
func newService(rcvr interface{}) *service {
s := new(service)
s.rcvr = reflect.ValueOf(rcvr) // 获取 rcvr 变量的值,例如 &foo
s.name = reflect.Indirect(s.rcvr).Type().Name() // 获取 rcvr 变量的类型名,例如 "Foo"
s.typ = reflect.TypeOf(rcvr) // 获取 rcvr 变量的类型,例如 Foo
if !ast.IsExported(s.name) {
log.Fatalf("rpc server: %s is not a valid service name", s.name)
}
s.registerMethods() // 获取 rcvr 变量的方法列表,例如 Foo.Sum,并将其保存到 service 的方法 map 中
return s
}

// registerMethods 获取 rcvr 变量的方法列表,例如 Foo.Sum,并将其保存到 service 的方法 map 中
func (s *service) registerMethods() {
s.method = make(map[string]*methodType)

for i := 0; i < s.typ.NumMethod(); i++ { // 遍历 rcvr 变量的方法列表
method := s.typ.Method(i) // 获取 rcvr 变量,第 i 个方法
mType := method.Type // 获取 rcvr 变量的方法的类型

// 检查方法的参数是否正确。rpc 的方法必须满足以下条件:
// 1. 方法有 3 个参数,第 1 个参数是 rcvr 变量(相当于 python 的 self,java 的 this),第 2 个参数是传入参数,第 3 个参数是传出参数,指针类型
// 2. 第 2 个参数和第 3 个参数都是导出的类型
// 3. 返回值只有一个,是 error 类型
// 例如这样 func (t *T) MethodName(argType T1, replyType *T2) error
if mType.NumIn() != 3 || mType.NumOut() != 1 { // 检查参数个数
continue
}
if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() { // 检测返回值个数和类型
continue
}
argType, replyType := mType.In(1), mType.In(2)
if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) { // 检查参数类型是否是导出的类型或者内置类型
continue
}

// 检查完毕,将方法保存到 service 的方法 map 中
s.method[method.Name] = &methodType{
method: method,
ArgType: argType,
ReplyType: replyType,
}
log.Printf("rpc server: register %s.%s\n", s.name, method.Name)
}
}

// isExportedOrBuiltinType 检查类型是否是导出的类型或者内置类型
func isExportedOrBuiltinType(t reflect.Type) bool {
return ast.IsExported(t.Name()) || t.PkgPath() == ""
}

如何从 map 获取方法

在 readRequest 时,从请求头中,获取方法名;根据方法名在对应的 service 的 map 中获取到方法、传入参数、传出参数。

graph LR
A[Server.readRequest] -->|解析请求头| B[Server.readRequestHeader]
B-->|返回请求头,包含方法名| A
A -->|获取方法| C[Server.findService]
A -->|获取方法传入参数,传出参数| D[GobCodec.ReadBody]
E[request]
C-->|方法保存到 request| E
D-->|传入参数,传出参数保存到 request| E
  1. request 储存了一次远程调用请求的信息,方法、传入参数、传出参数指针。它的定义如下。
1
2
3
4
5
6
7
8
// server.go
// request 表示一次调用的所有信息
type request struct {
h *codec.Header // 请求头
svc *service // 请求对应的服务,使用 svc.call 调用对应的方法
mtype *methodType // 请求对应的方法,是 svc.call 的第一个参数
argv, replyv reflect.Value // 方法的传入参数和传出参数,是 svc.call 的第二个和第三个参数
}
  1. readRequest 获取方法名、方法、传入参数、传出参数,并储存在 request 实例中,为方法调用做准备。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
// server.go
// readRequest 读取请求:调用 readRequestHeader 方法读取请求头,调用 ReadBody 方法读取请求参数,返回 request 结构体
func (server *Server) readRequest(cc codec.Codec) (*request, error) {
// 读取请求头
h, err := server.readRequestHeader(cc)
if err != nil {
return nil, err
}

// 初始化请求结构体
req := &request{h: h}

// 根据请求头中的 ServiceMethod 字段找到对应的服务和方法类型
req.svc, req.mtype, err = server.findService(h.ServiceMethod)
if err != nil {
return req, err
}

// 创建传入参数和传出参数的反射对象
req.argv = req.mtype.newArgv()
req.replyv = req.mtype.newReplyv()

// 检查请求传入参数的类型是否为指针类型,如果不是,则使用 Addr() 方法将 req.argv 转换为指针类型
// 为什么?
// 因为如果传入值是值类型,传入后,是值拷贝,不会修改传入变量的原值,所以需要使用 Addr() 获取地址后传入。
argvi := req.argv.Interface() // 使用 interface() 方法将 req.argv 转换为 interface{} 类型,这样可以传入任意类型的参数
if req.argv.Type().Kind() != reflect.Ptr {
argvi = req.argv.Addr().Interface()
}

// ReadBody 方法会将请求参数解码到 argvi 中储存
if err = cc.ReadBody(argvi); err != nil {
log.Println("rpc server: read body err:", err)
return req, err
}
return req, nil
}

另外,需要注意一下,

  • 创建传入参数时,需区分值类型还是指针类型。
  • 创建传出参数时,需对 map、slice 特殊处理。(因为 reflect.New 初始化时,map、slice 都是 nil)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
// service.go
// newArgv 创建一个参数变量
func (m *methodType) newArgv() reflect.Value {
var argv reflect.Value
// 参数可能是指针类型也可能是值类型
if m.ArgType.Kind() == reflect.Ptr {
argv = reflect.New(m.ArgType.Elem()) // 指针类型,则创建指针类型变量
} else {
argv = reflect.New(m.ArgType).Elem() // 值类型,则创建值类型变量
}
return argv
}

// newReplyv 创建一个传出参数变量
func (m *methodType) newReplyv() reflect.Value {
// 传出参数必须是指针类型
replyv := reflect.New(m.ReplyType.Elem())

// 为什么需要对 map 和 slice 特殊处理?
// reflect.New 创建的 map 和 slice 都是 nil,需要先初始化,再使用。
// 为什么 newArgv 方法不需要对 map 和 slice 特殊处理?
// 因为 argv 是用于接收调用方传入的参数,这些参数由调用方提供且已经初始化。
switch m.ReplyType.Elem().Kind() {
case reflect.Map:
replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem()))
case reflect.Slice:
replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0))
}
return replyv
}

如何调用方法

在 handleRequest 时,通过 service.call 利用反射来调用函数。这里主要学习这个用法。

graph LR
A[Server.serveCodec]-->|获取请求信息| B[Server.readRequest] 
A -->|调用方法| C[Server.handleRequest]
C --> D[request.svc.call, 也就是 Service.call]
D -->|使用反射调用| E[methodType.method.Func.Call]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
// server.go
// handleRequest 处理请求:构造请求响应信息,调用 sendResponse 方法发送响应
func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) {
defer wg.Done()

err := req.svc.call(req.mtype, req.argv, req.replyv) // 调用

if err != nil {
req.h.Error = err.Error()
server.sendResponse(cc, req.h, invalidRequest, sending)
return
}
server.sendResponse(cc, req.h, req.replyv.Interface(), sending)
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
// service.go
// call 调用 rcvr 变量的方法,例如 Foo.Sum
func (s *service) call(m *methodType, argv, replyv reflect.Value) error {
atomic.AddUint64(&m.numCalls, 1)
f := m.method.Func
// 用反射的方式调用方法
// 第一个参数是 rcvr 变量,例如 &foo,类似于 java 的 this,python 的 self
// 第 2 个参数是传入参数,第 3 个参数是传出参数,指针类型
returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv})
if errInter := returnValues[0].Interface(); errInter != nil {
return errInter.(error)
}
return nil
}

更深入了解反射

请看这篇 https://www.aimtao.net/go#11-reflect

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
# 代码结构
version_3_service[geerpc]
├── client.go
├── codec
│   ├── codec.go
│   └── gob.go
├── go.mod
├── main
│   └── main.go (改动)
├── server.go (改动)
├── service.go (改动)
└── service_test.go (改动)

4.超时处理

哪些地方需要超时处理

简单来看整个流程:

  • 客户端:

    1. 拨号连接 ✓

    2. 远程调用(发送数据)✓

    3. 等待服务端处理 ✓

    4. 接收返回的结果(接收数据)✓

  • 服务端:

    1. 端口监听

    2. 接收请求信息(接收数据)✓

    3. 调用方法处理请求信息(处理数据)✓

    4. 返回请求结果(发送数据)✓

以上打 ✓ 的步骤,均可能出现超时,主要为,

  • 客户端建立连接超时
  • 发送数据
    • 客户端/服务端写报文时超时
  • 处理数据
    • 服务端调用方法超时
  • 接收数据
    • 客户端等待服务端响应超时
    • 客户端/服务端读取报文超时

基于这个超时的情景,我们可以在以下三个地方设置超时处理机制。

客户端:

  • 建立连接
  • Client.call() 的整个过程(包含发送数据、等待处理、接收数据)

服务端:

  • Server.handleRequest() 的整个过程(包括接收数据、处理数据、发送数据)

建立连接的超时处理

启动一个 goroutine 来拨号建立连接,使用 select 设置超时器,阻塞等待拨号结果,如果在接收拨号结果之前,先收到了超时信号,则进行超时处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
// client.go
type NewClientFunc func(conn net.Conn, opt *Option) (*Client, error)

func Dial(network, address string, opts ...*Option) (client *Client, err error) {
return dialTimeout(NewClient, network, address, opts...)
}

type dialResult struct {
client *Client
err error
}

func dialTimeout(f NewClientFunc, network, address string, opts ...*Option) (client *Client, err error) {
// 默认使用 Gob 编码
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}

// 声明一个通道,用于传输拨号建立连接的结果
result := make(chan dialResult, 1) // 设置缓冲区为 1,防止在超时后,无人接收 channel 数据,导致 channel 发送时阻塞,导致 goroutine 泄漏

// 当发生错误时,保证 client 为 nil
defer func() {
if err != nil {
client = nil
}
}()

// 启动一个 goroutine 连接服务器,连接成功后,调用 f 创建 Client 实例,并将结果发送到 result 通道中
go func() {
conn, err := net.DialTimeout(network, address, opt.ConnectTimeout)
if err != nil {
result <- dialResult{client: nil, err: err}
return
}
client, err = f(conn, opt)
result <- dialResult{client: client, err: err}
}()

// 如果超时时间为 0,表示没限制,直接等待 result 通道返回结果
if opt.ConnectTimeout == 0 {
result := <-result
return result.client, result.err
}

// 超时处理,阻塞等待,等待超时或收到结果
select {
case <-time.After(opt.ConnectTimeout):
return nil, fmt.Errorf("rpc client: connect timeout: %v", opt.ConnectTimeout)
case result := <-result:
return result.client, result.err
}
}

超时的时长,设置在 Option struct 内,便于用户自定义。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// server.go

// Option 定义 Option 结构体,封装了 MagicNumber 和 CodecType 字段,从 conn 中解析出 Option 的信息,表示 RPC 消息的编码方式
type Option struct {
MagicNumber int
CodecType codec.Type
ConnectTimeout time.Duration // Client 建立连接的超时时间
HandleTimeout time.Duration // Client.Call() 整个过程的超时时间
}

var DefaultOption = &Option{
MagicNumber: MagicNumber,
CodecType: codec.GobType,
ConnectTimeout: time.Second * 10,
//HandleTimeout: time.Second * 10, // 默认为 0,不设置超时时间
}

Call 的超时处理

把选择权交给用户。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// client.go

// Call 同步调用,阻塞等待响应结果
func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
call := client.Go(serviceMethod, args, reply, make(chan *Call, 1))

// 等待响应结果/超时
select {
case <-ctx.Done():
client.removeCall(call.Seq)
return errors.New("rpc client: call failed: " + ctx.Err().Error())
case call := <-call.Done:
return call.Error
}
}

用户如何使用?

1
2
3
ctx, _ := context.WithTimeout(context.Background(), time.Second)
var reply int
err := client.Call(ctx, "Foo.Sum", &Args{1, 2}, &reply)

handleRequest 超时处理

这里先不管读取数据了,仅对 handleRequest 做超时处理。handleRequest 其中又分为两个步骤,处理数据、发送数据,仅对处理数据做超时处理。

设置两个 channel,一个表示 req.svc.call 完成,一个表示 server.sendResponse 完成。当 req.svc.call 完成后,就不用管超时时间是否到达,继续让 server.sendResponse 发送数据。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
// server.go

// handleRequest 处理请求:构造请求响应信息,调用 sendResponse 方法发送响应
func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) {
defer wg.Done()

called := make(chan struct{}, 1) // 设置缓冲区为 1,防止在超时后,无人接收 channel 数据,导致 channel 发送时阻塞,导致 goroutine 泄漏
sent := make(chan struct{}, 1) // 设置缓冲区为 1,防止在超时后,无人接收 channel 数据,导致 channel 发送时阻塞,导致 goroutine 泄漏


go func() {
err := req.svc.call(req.mtype, req.argv, req.replyv) // 调用
called <- struct{}{} // 调用完成, 不管是否超时,继续发送数据
if err != nil {
req.h.Error = err.Error()
server.sendResponse(cc, req.h, invalidRequest, sending)
sent <- struct{}{}
return
}
server.sendResponse(cc, req.h, req.replyv.Interface(), sending)
sent <- struct{}{}
}()

if timeout == 0 { // 没有超时时间,直接等待
<-called
<-sent
return
}

// 有超时时间,使用 select 等待超时或调用完成
select {
case <-time.After(timeout):
server.sendResponse(cc, req.h, invalidRequest, sending)
case <-called: // 如果调用完成,则不管超时时间,等待 sent(仅对 req.svc.call 做超时处理)
<-sent
}
}

测试代码

测试客户端处理连接超时的情况

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
// client_test.go
func TestClient_dialTimeout(t *testing.T) {
t.Parallel() // 设置测试项并行执行

f := func(conn net.Conn, opt *Option) (client *Client, err error) {
_ = conn.Close()
time.Sleep(time.Second * 2)
return nil, nil // 模拟了一个没有错误的客户端连接
}

l, _ := net.Listen("tcp", ":0")

// 测试客户端连接有超时处理的情况
t.Run("connect timeout", func(t *testing.T) {
_, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: time.Second})
_assert(err != nil && strings.Contains(err.Error(), "connect timeout"), "expect a timeout error")
})

// 测试客户端连接没有超时处理的情况
t.Run("0", func(t *testing.T) {
_, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: 0})
_assert(err == nil, "0 means no limit")
})
}

测试客户端处理调用超时和服务端处理超时的情况

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
// server_test.go
type ServiceTemp int

// ServiceTemp 有一个方法 Timeout,该方法耗时2s
func (s ServiceTemp) Timeout(args int, reply *int) error {
time.Sleep(time.Second * time.Duration(args))
*reply = 0
return nil
}

func TestClient_Call(t *testing.T) {
t.Parallel()

addrCh := make(chan string)
go func(chan string) { // 启动一个服务器,监听 0 端口,注册 ServiceTemp 类型的对象,然后启动 Accept 方法,等待客户端连接
var s ServiceTemp
_ = Register(&s)
l, _ := net.Listen("tcp", ":0")
addrCh <- l.Addr().String()
Accept(l)
}(addrCh)
addr := <-addrCh

// 测试客户端处理调用超时的情况
t.Run("client call timeout", func(t *testing.T) {
client, _ := Dial("tcp", addr)

ctx, _ := context.WithTimeout(context.Background(), time.Second) // 创建一个超时的 context,如果 1s 内没有返回结果,context.Done() 会传出信号 struct{}{}
var reply int
err := client.Call(ctx, "ServiceTemp.Timeout", 20, &reply) // 调用 ServiceTemp.Timeout 方法,Call 发生超时
_assert(err != nil && strings.Contains(err.Error(), ctx.Err().Error()), "expect a timeout error")
})

// 测试服务端处理超时的情况
t.Run("server handle timeout", func(t *testing.T) {
client, _ := Dial("tcp", addr, &Option{HandleTimeout: time.Second})
var reply int
err := client.Call(context.Background(), "ServiceTemp.Timeout", 20, &reply) // 调用 ServiceTemp.Timeout 方法,服务端处理超时
_assert(err != nil && strings.Contains(err.Error(), "handle timeout"), "expect a timeout error")
})
}

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
# 代码结构
version_4_timeout
├── client.go (改动)
├── client_test.go (改动)
├── codec
│   ├── codec.go
│   └── gob.go
├── go.mod
├── main
│   └── main.go
├── server.go
├── service.go (改动)
└── service_test.go (改动)

5.支持 HTTP 协议

6.负载均衡

7.服务发现与注册中心


从零实现系列|RPC
https://www.aimtao.net/7days-rpc/
Posted on
2024-10-18
Licensed under