0.序言 为什么需要 RPC 最直观得是,客户端可以像调用本地程序一样,进行远程调用,使用者无需关注内部的实现细节。
另外一种广泛使用的调用方式是基于 HTTP 协议得 Restful API,让 gpt 总结一下。
对比维度 Restful API RPC 协议 基于 HTTP 协议(如 HTTPS) 通常使用自定义协议(如 TCP 或高效二进制协议) 通信方式 基于 HTTP 的请求-响应模型(如 GET/POST) 类似本地方法调用的请求-响应模型(对开发者透明) 报文格式 文本格式(JSON/XML),冗余较多 二进制编码(如 Protobuf),精简高效 性能 较低(文本解析、冗余数据) 较高(二进制压缩、高效序列化) 使用场景 通用性强,适合跨语言、对外的开放接口 高性能要求高,适合内部服务间通信 可扩展性 扩展依赖网关等中间件,功能相对固定 原生支持注册中心、负载均衡、超时处理等扩展功能 抽象模型 面向资源 的抽象,通过 URI 唯一标识资源,通过 HTTP 方法定义操作面向过程 的抽象,客户端直接调用服务端的函数或方法
RPC 框架需要解决哪些问题 以上种种,业务之外的公共能力,RPC 框架均需具备。
市面上的 RPC 框架 有哪些? 框架名称 特点 net/rpc Go 标准库,轻量级,支持 TCP/HTTP 协议,默认使用 Gob 编码 gRPC Google 开源,基于 HTTP/2 和 Protobuf,跨语言,高性能,支持流式通信 rpcx 高性能、支持多种编码(JSON/Protobuf),集成注册中心、负载均衡 go-micro 微服务框架,包含 RPC 模块,支持插件化(注册中心、编码协议等)
本文如何从零实现 RPC 框架? 从零实现标准库 net/rpc
新增了协议交换(protocol exchange)、注册中心(registry)、服务发现(service discovery)、负载均衡(load balance)、超时处理(timeout processing)等特性 1.服务端与消息编码 如何设计一个 RPC 请求 这个问题也可以问成:一个 RPC 请求,应该传入传出什么数据,选择使用什么数据结构?
举个例子:一个典型的 RPC 调用。
1 err = client.Call("Arith.Multiply" , args, &reply)
Arith.Multiply
:服务名 Arith
和方法名 Multipy
。args
:传入参数reply
:传出参数我们将服务名、方法名放在请求的 header 中,传输传出参数放在 body 中。
其中 header 定义为:
1 2 3 4 5 6 type Header struct { ServiceMethod string Seq uint64 Error string }
用什么编解码方式进行编解码 编码方式可以选择 json、gob、protobuf 等等,为了更好的兼容性,我们抽象出来 Codec 的类型。
1 2 3 4 5 6 7 type Codec interface { io.Closer ReadHeader(*Header) error ReadBody(interface {}) error Write(*Header, interface {}) error }
之后想使用哪个编码方式都可以,只要实现了 Codec 接口即可。
将各个编解码方式名称硬编码保存下来,并使用 map 保存各个编解码方式的构造函数。所以使用的流程是,通过编解码方式名称获取到该编解码方式的构造函数,调用该构造函数,创建编解码的实例。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 type Type string const ( GobType Type = "application/gob" JsonType Type = "application/json" )type NewCodecFunc func (io.ReadWriteCloser) Codecvar NewCodecFuncMap map [Type]NewCodecFuncfunc init () { NewCodecFuncMap = make (map [Type]NewCodecFunc) NewCodecFuncMap[GobType] = NewGobCodec }
如何实现编解码方式的接口 下面将以 Gob 为例,说明如何实现 Codec 接口。先定义 Gob 编解码的结构体 GobCodec 。
1 2 3 4 5 6 type GobCodec struct { conn io.ReadWriteCloser buf *bufio.Writer dec *gob.Decoder enc *gob.Encoder }
实现 Codec 的接口函数,主要是封装 encoding/gob 的方法。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 func (c *GobCodec) ReadHeader(header *Header) error { return c.dec.Decode(header) }func (c *GobCodec) ReadBody(body interface {}) error { return c.dec.Decode(body) }func (c *GobCodec) Write(header *Header, body interface {}) (err error ) { defer func () { _ = c.buf.Flush() if err != nil { _ = c.Close() } }() if err = c.enc.Encode(header); err != nil { log.Panicln("rpc codec: gob error encoding header: " , err) return err } if err = c.enc.Encode(body); err != nil { log.Panicln("rpc codec: gob error encoding body: " , err) return err } return nil }func (c *GobCodec) Close() error { return c.conn.Close() }
最后再实现 GobCodec 的构造函数,方便初始化时,穿给 Codec 接口变量。
1 2 3 4 5 6 7 8 9 func NewGobCodec (conn io.ReadWriteCloser) Codec { buf := bufio.NewWriter(conn) return &GobCodec{ conn: conn, buf: buf, dec: gob.NewDecoder(conn), enc: gob.NewEncoder(buf), } }
P.S. 为什么 GobCodec 需要加一个 buf ?
用于缓冲写入操作,减少系统调用次数。如果没有 buf,每次小数据写入直接触发系统调用,增加延迟;频繁的 I/O 操作,影响性能。
通信过程如何协商编码方式 以 HTTP 报文为例,HTTP 报文分为 header 和 body 两个部分。客户端和服务端收发消息时,只需要先解析 header 部分,就知道 body 的格式 Content-Type、长度 Content-Length。
在 RPC 协议的报文里,为了提升性能,仅在报文最开始规划固定的字节,来协商编码方式。
1 2 3 4 5 type Option struct { MagicNumber int CodecType codec.Type }
所以实际的报文是这样的。
1 | Option | Header1 | Body1 | Header2 | Body2 | ...
那 Option 使用什么编码方式呢?
为了方便,我们可以规定,Option 使用 JSON 编码,后面的 header、body 使用 Option 规定的编解码方式编解码。
1 2 | Option {MagicNumber: xxx, CodecType: xxx} | Header {ServiceMethod ...} | Body interface{} || <------ Option 固定使用 JSON 编码 ------> | <------- Header /Body 编码方式由 CodeType 决定 ------->|
如何实现一个服务端 先看一下,服务端要做什么,要实现哪些功能。
建立连接:实现 Accept 方法,等待 socket 连接,并开启协程处理请求。 处理请求:解析 Option 信息,检查 Option.MagicNumber 是匹配,根据 Option.CodecType 实例化编解码器。 读取请求:使用编解码器实例解码 header 和 body。 处理请求。 回复请求。 搞清楚需要实现哪些功能,就可以开始具体实现了。(可以先看主要的调用流程,具体实现实现细节晚点再看)
1 2 3 4 5 6 7 8 const MagicNumber = 0x3bef5c type Server struct {} func NewServer () *Server { return &Server{} }
Accept 处理连接:建立 socket 连接,使用 goroutine 处理连接 1 2 3 4 5 6 7 8 9 10 func (server *Server) Accept(lis net.Listener) { for { conn, err := lis.Accept() if err != nil { log.Println("rpc server: accept error: " , err) return } go server.ServeConn(conn) } }
ServeConn 处理消息:解析出 Option 信息,根据 CodecType 选择对应的 codec,调用 serveCodec 方法处理剩下的消息 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 func (server *Server) ServeConn(conn io.ReadWriteCloser) { defer func () { _ = conn.Close() }() var opt Option if err := json.NewDecoder(conn).Decode(&opt); err != nil { log.Println("rpc server: options error: " , err) return } if opt.MagicNumber != MagicNumber { log.Printf("rpc server: invalid magic number %x" , opt.MagicNumber) return } f := codec.NewCodecFuncMap[opt.CodecType] if f == nil { log.Printf("rpc server: invalid codec type %s" , opt.CodecType) return } server.serveCodec(f(conn)) }
serveCodec 处理请求:调用 readRequest 方法读取请求,调用 handleRequest 方法处理请求 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 func (server *Server) serveCodec(cc codec.Codec) { sending := new (sync.Mutex) wg := new (sync.WaitGroup) for { req, err := server.readRequest(cc) if err != nil { break } wg.Add(1 ) go server.handleRequest(cc, req, sending, wg) } wg.Wait() _ = cc.Close() }
readRequest 读取请求:调用 readRequestHeader 方法读取请求头,调用 ReadBody 方法读取请求参数,返回 request 结构体这里定义了请求体的结构体,h 储存 header,argv 储存传入参数,replyValue 储存传出参数,也就是返回值。 argv 和 replyValue 为什么都是 reflect.Value
类型呢? 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 type request struct { h *codec.Header argv, replyValue reflect.Value }func (server *Server) readRequest(cc codec.Codec) (*request, error ) { h, err := server.readRequestHeader(cc) if err != nil { return nil , err } req := request{ h: h, } req.argv = reflect.New(reflect.TypeOf("" )) if err = cc.ReadBody(req.argv.Interface()); err != nil { log.Println("rpc server: read body error: " , err) } return &req, nil }func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error ) { var h codec.Header if err := cc.ReadHeader(&h); err != nil { if err != io.EOF && !errors.Is(err, io.ErrUnexpectedEOF) { log.Println("rpc server: read header error: " , err) } return nil , err } return &h, nil }
如何实现一个客户端 完整代码 2.高性能客户端 3.服务注册 这一章是干什么的? 服务端的主要工作:
这一章主要完成处理请求部分,其中,从请求中获取方法名、参数在上一章已完成。client 在发送请求时,方法名在请求头里 client.header.ServiceMethod
,参数在请求体里 call.Args
。
所以我们重点关注
如何获取方法 服务端通过方法名获取实际的方法,需要两步:
在启动服务端服务时,我们可以将所有可调用的方法都加到一个 map 中。
当 client 调用 hello 时,我们从 map 中找到对应的方法皆可。
由于不同类型的结构体,可能有同名的结构体方法。我们的方法名,使用 结构体实例名.方法名
来作为方法名。比如上一章中的,结构体类型 Foo
,有 Sum
方法,客户端实例化 foo
对象后,远程调用 Client.Call 时,传入的就是 foo.Sum
。
如何方法注册到 map 中 首先,先说数据存储方式,简单来说是这样的,
flowchart LR
A[Server.serviceMap] -->|map储存多个service| B[Service<br>一个 service 表示一个变量类型]
B -->|service包含一个 map 变量| C[service.method]
C -->|map 储存多个方法| D[methodType<br>一个 methodType 表示一个该变量类型的方法]
具体数据存储方式设计如下:
服务端保存着一个 serviceMap。使用 serviceMap 保存变量类型及其方法。key 是变量类型名,value 是 service 类型实例。
1 2 3 4 5 type Server struct { serviceMap sync.Map }
service 类型定义如下,一个 service 表示一个变量类型。其中,也包含一个 map,保存着这个变量的方法。key 是方法名,value 是 methodType 类型实例。
1 2 3 4 5 6 7 8 type service struct { name string typ reflect.Type rcvr reflect.Value method map [string ]*methodType }
methodType 类型定义如下,一个 methodType 表示一个方法。
1 2 3 4 5 6 7 8 type methodType struct { method reflect.Method ArgType reflect.Type ReplyType reflect.Type numCalls uint64 }
最后,下面将以一个实际例子,演示如何将方法注册到 map 中。
graph LR
A[startServer, 启动服务器] --> B[Server.Register , 为 Foo 类型注册服务]
B --> C[newService, 为 Foo 类型创建 service 结构体]
C --> D[获取 rcvr 值, 类型名和类型]
C --> G[service.registerMethods, 获取方法列表]
G --> K[检查方法的参数和返回值]
K --> L[保存方法到 map 变量 service.method]
启动服务,将 Foo 类型注册到 rpc server 中。 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 type Foo int type Args struct { Num1, Num2 int }func (f Foo) Sum(args Args, reply *int ) error { *reply = args.Num1 + args.Num2 return nil }func startServer (addr chan string ) { var foo Foo if err := geerpc.Register(&foo); err != nil { log.Fatal("register error:" , err) } l, err := net.Listen("tcp" , ":0" ) if err != nil { log.Fatal("network error:" , err) } log.Println("start rpc server on" , l.Addr()) addr <- l.Addr().String() geerpc.Accept(l) }
为 Foo 类型创建 service 对象,并保存到 server 的 map 中。 1 2 3 4 5 6 7 8 9 10 11 var DefaultServer = NewServer()func Register (rcvr interface {}) error { return DefaultServer.Register(rcvr) }func (server *Server) Register(rcvr interface {}) error { s := newService(rcvr) if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup { return errors.New("rpc: service already defined: " + s.name) } return nil }
创建 service 对象、获取对象的方法、检查方法参数和返回值,最后将方法保存到 map 中。 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 func newService (rcvr interface {}) *service { s := new (service) s.rcvr = reflect.ValueOf(rcvr) s.name = reflect.Indirect(s.rcvr).Type().Name() s.typ = reflect.TypeOf(rcvr) if !ast.IsExported(s.name) { log.Fatalf("rpc server: %s is not a valid service name" , s.name) } s.registerMethods() return s }func (s *service) registerMethods() { s.method = make (map [string ]*methodType) for i := 0 ; i < s.typ.NumMethod(); i++ { method := s.typ.Method(i) mType := method.Type if mType.NumIn() != 3 || mType.NumOut() != 1 { continue } if mType.Out(0 ) != reflect.TypeOf((*error )(nil )).Elem() { continue } argType, replyType := mType.In(1 ), mType.In(2 ) if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) { continue } s.method[method.Name] = &methodType{ method: method, ArgType: argType, ReplyType: replyType, } log.Printf("rpc server: register %s.%s\n" , s.name, method.Name) } }func isExportedOrBuiltinType (t reflect.Type) bool { return ast.IsExported(t.Name()) || t.PkgPath() == "" }
如何从 map 获取方法 在 readRequest 时,从请求头中,获取方法名;根据方法名在对应的 service 的 map 中获取到方法、传入参数、传出参数。
graph LR
A[Server.readRequest] -->|解析请求头| B[Server.readRequestHeader]
B-->|返回请求头,包含方法名| A
A -->|获取方法| C[Server.findService]
A -->|获取方法传入参数,传出参数| D[GobCodec.ReadBody]
E[request]
C-->|方法保存到 request| E
D-->|传入参数,传出参数保存到 request| E
request 储存了一次远程调用请求的信息,方法、传入参数、传出参数指针。它的定义如下。 1 2 3 4 5 6 7 8 type request struct { h *codec.Header svc *service mtype *methodType argv, replyv reflect.Value }
readRequest 获取方法名、方法、传入参数、传出参数,并储存在 request 实例中,为方法调用做准备。 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 func (server *Server) readRequest(cc codec.Codec) (*request, error ) { h, err := server.readRequestHeader(cc) if err != nil { return nil , err } req := &request{h: h} req.svc, req.mtype, err = server.findService(h.ServiceMethod) if err != nil { return req, err } req.argv = req.mtype.newArgv() req.replyv = req.mtype.newReplyv() argvi := req.argv.Interface() if req.argv.Type().Kind() != reflect.Ptr { argvi = req.argv.Addr().Interface() } if err = cc.ReadBody(argvi); err != nil { log.Println("rpc server: read body err:" , err) return req, err } return req, nil }
另外,需要注意一下,
创建传入参数时,需区分值类型还是指针类型。 创建传出参数时,需对 map、slice 特殊处理。(因为 reflect.New 初始化时,map、slice 都是 nil) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 func (m *methodType) newArgv() reflect.Value { var argv reflect.Value if m.ArgType.Kind() == reflect.Ptr { argv = reflect.New(m.ArgType.Elem()) } else { argv = reflect.New(m.ArgType).Elem() } return argv }func (m *methodType) newReplyv() reflect.Value { replyv := reflect.New(m.ReplyType.Elem()) switch m.ReplyType.Elem().Kind() { case reflect.Map: replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem())) case reflect.Slice: replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0 , 0 )) } return replyv }
如何调用方法 在 handleRequest 时,通过 service.call 利用反射来调用函数。这里主要学习这个用法。
graph LR
A[Server.serveCodec]-->|获取请求信息| B[Server.readRequest]
A -->|调用方法| C[Server.handleRequest]
C --> D[request.svc.call, 也就是 Service.call]
D -->|使用反射调用| E[methodType.method.Func.Call]
1 2 3 4 5 6 7 8 9 10 11 12 13 14 func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) { defer wg.Done() err := req.svc.call(req.mtype, req.argv, req.replyv) if err != nil { req.h.Error = err.Error() server.sendResponse(cc, req.h, invalidRequest, sending) return } server.sendResponse(cc, req.h, req.replyv.Interface(), sending) }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 func (s *service) call(m *methodType, argv, replyv reflect.Value) error { atomic.AddUint64(&m.numCalls, 1 ) f := m.method.Func returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv}) if errInter := returnValues[0 ].Interface(); errInter != nil { return errInter.(error ) } return nil }
更深入了解反射 请看这篇 https://www.aimtao.net/go#11-reflect
完整代码 1 2 3 4 5 6 7 8 9 10 11 12 version_3_service ├── client.go ├── codec │ ├── codec.go │ └── gob.go ├── go.mod ├── main │ └── main.go ├── server.go ├── service.go └── service_test.go
main/main.go
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 package mainimport ( geerpc "GeeRPC" "log" "net" "sync" "time" )type Foo int type Args struct { Num1, Num2 int }func (f Foo) Sum(args Args, reply *int ) error { *reply = args.Num1 + args.Num2 return nil }func startServer (addr chan string ) { var foo Foo if err := geerpc.Register(&foo); err != nil { log.Fatal("register error:" , err) } l, err := net.Listen("tcp" , ":0" ) if err != nil { log.Fatal("network error:" , err) } log.Println("start rpc server on" , l.Addr()) addr <- l.Addr().String() geerpc.Accept(l) }func main () { log.SetFlags(0 ) addr := make (chan string ) go startServer(addr) client, _ := geerpc.Dial("tcp" , <-addr) defer func () { _ = client.Close() }() time.Sleep(time.Second) var wg sync.WaitGroup for i := 0 ; i < 5 ; i++ { wg.Add(1 ) go func (i int ) { defer wg.Done() args := &Args{Num1: i, Num2: i * i} var reply int if err := client.Call("Foo.Sum" , args, &reply); err != nil { log.Fatal("call Foo.Sum error:" , err) } log.Printf("%d + %d = %d" , args.Num1, args.Num2, reply) }(i) } wg.Wait() }
server.go
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 package geerpcimport ( "GeeRPC/codec" "encoding/json" "errors" "io" "log" "net" "reflect" "strings" "sync" )const MagicNumber = 0x3bef5c type Option struct { MagicNumber int CodecType codec.Type }type Server struct { serviceMap sync.Map }func NewServer () *Server { return &Server{} }func (server *Server) Accept(lis net.Listener) { for { conn, err := lis.Accept() if err != nil { log.Println("rpc server: accept error: " , err) return } go server.ServeConn(conn) } }func (server *Server) ServeConn(conn io.ReadWriteCloser) { defer func () { _ = conn.Close() }() var opt Option if err := json.NewDecoder(conn).Decode(&opt); err != nil { log.Println("rpc server: options error: " , err) return } if opt.MagicNumber != MagicNumber { log.Printf("rpc server: invalid magic number %x" , opt.MagicNumber) return } f := codec.NewCodecFuncMap[opt.CodecType] if f == nil { log.Printf("rpc server: invalid codec type %s" , opt.CodecType) return } server.serveCodec(f(conn)) }func (server *Server) serveCodec(cc codec.Codec) { sending := new (sync.Mutex) wg := new (sync.WaitGroup) for { req, err := server.readRequest(cc) if err != nil { break } wg.Add(1 ) go server.handleRequest(cc, req, sending, wg) } wg.Wait() _ = cc.Close() }type request struct { h *codec.Header svc *service mtype *methodType argv, replyv reflect.Value }func (server *Server) readRequest(cc codec.Codec) (*request, error ) { h, err := server.readRequestHeader(cc) if err != nil { return nil , err } req := &request{h: h} req.svc, req.mtype, err = server.findService(h.ServiceMethod) if err != nil { return req, err } req.argv = req.mtype.newArgv() req.replyv = req.mtype.newReplyv() argvi := req.argv.Interface() if req.argv.Type().Kind() != reflect.Ptr { argvi = req.argv.Addr().Interface() } if err = cc.ReadBody(argvi); err != nil { log.Println("rpc server: read body err:" , err) return req, err } return req, nil }func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error ) { var h codec.Header if err := cc.ReadHeader(&h); err != nil { if err != io.EOF && !errors.Is(err, io.ErrUnexpectedEOF) { log.Println("rpc server: read header error: " , err) } return nil , err } return &h, nil }var invalidRequest = struct {}{}func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) { defer wg.Done() err := req.svc.call(req.mtype, req.argv, req.replyv) if err != nil { req.h.Error = err.Error() server.sendResponse(cc, req.h, invalidRequest, sending) return } server.sendResponse(cc, req.h, req.replyv.Interface(), sending) }func (server *Server) sendResponse(cc codec.Codec, header *codec.Header, body interface {}, sending *sync.Mutex) { sending.Lock() defer sending.Unlock() if err := cc.Write(header, body); err != nil { log.Println("rpc server: write response error: " , err) } }var DefaultServer = NewServer()func Register (rcvr interface {}) error { return DefaultServer.Register(rcvr) }func (server *Server) Register(rcvr interface {}) error { s := newService(rcvr) if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup { return errors.New("rpc: service already defined: " + s.name) } return nil }func (server *Server) findService(serviceMethod string ) (svc *service, mtype *methodType, err error ) { dot := strings.LastIndex(serviceMethod, "." ) if dot < 0 { err = errors.New("rpc server: service/method request ill-formed: " + serviceMethod) return } serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1 :] svci, ok := server.serviceMap.Load(serviceName) if !ok { err = errors.New("rpc server: can't find service " + serviceName) return } svc = svci.(*service) mtype = svc.method[methodName] if mtype == nil { err = errors.New("rpc server: can't find method " + methodName) } return }var DefaultOption = &Option{ MagicNumber: MagicNumber, CodecType: codec.GobType, }func Accept (lis net.Listener) { DefaultServer.Accept(lis) }
service.go
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 package geerpcimport ( "go/ast" "log" "reflect" "sync/atomic" )type methodType struct { method reflect.Method ArgType reflect.Type ReplyType reflect.Type numCalls uint64 }func (m *methodType) NumCalls() uint64 { return atomic.LoadUint64(&m.numCalls) }func (m *methodType) newArgv() reflect.Value { var argv reflect.Value if m.ArgType.Kind() == reflect.Ptr { argv = reflect.New(m.ArgType.Elem()) } else { argv = reflect.New(m.ArgType).Elem() } return argv }func (m *methodType) newReplyv() reflect.Value { replyv := reflect.New(m.ReplyType.Elem()) switch m.ReplyType.Elem().Kind() { case reflect.Map: replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem())) case reflect.Slice: replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0 , 0 )) } return replyv }type service struct { name string typ reflect.Type rcvr reflect.Value method map [string ]*methodType }func newService (rcvr interface {}) *service { s := new (service) s.rcvr = reflect.ValueOf(rcvr) s.name = reflect.Indirect(s.rcvr).Type().Name() s.typ = reflect.TypeOf(rcvr) if !ast.IsExported(s.name) { log.Fatalf("rpc server: %s is not a valid service name" , s.name) } s.registerMethods() return s }func (s *service) registerMethods() { s.method = make (map [string ]*methodType) for i := 0 ; i < s.typ.NumMethod(); i++ { method := s.typ.Method(i) mType := method.Type if mType.NumIn() != 3 || mType.NumOut() != 1 { continue } if mType.Out(0 ) != reflect.TypeOf((*error )(nil )).Elem() { continue } argType, replyType := mType.In(1 ), mType.In(2 ) if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) { continue } s.method[method.Name] = &methodType{ method: method, ArgType: argType, ReplyType: replyType, } log.Printf("rpc server: register %s.%s\n" , s.name, method.Name) } }func isExportedOrBuiltinType (t reflect.Type) bool { return ast.IsExported(t.Name()) || t.PkgPath() == "" }func (s *service) call(m *methodType, argv, replyv reflect.Value) error { atomic.AddUint64(&m.numCalls, 1 ) f := m.method.Func returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv}) if errInter := returnValues[0 ].Interface(); errInter != nil { return errInter.(error ) } return nil }
service_test.go
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 package geerpcimport ( "fmt" "reflect" "testing" )type Foo int type Args struct { Num1, Num2 int }func (f Foo) Sum(args Args, reply *int ) error { *reply = args.Num1 + args.Num2 return nil }func (f Foo) sum(args Args, reply *int ) error { *reply = args.Num1 + args.Num2 return nil }func _assert (condition bool , msg string , v ...interface {}) { if !condition { panic (fmt.Sprintf("assertion failed: " +msg, v...)) } }func TestNewService (t *testing.T) { var foo Foo s := newService(&foo) _assert(len (s.method) == 1 , "wrong service Method, expect 1, but got %d" , len (s.method)) mType := s.method["Sum" ] _assert(mType != nil , "wrong Method, Sum shouldn't nil" ) }func TestMethodType_Call (t *testing.T) { var foo Foo s := newService(&foo) mType := s.method["Sum" ] argv := mType.newArgv() replyv := mType.newReplyv() argv.Set(reflect.ValueOf(Args{Num1: 1 , Num2: 3 })) err := s.call(mType, argv, replyv) _assert(err == nil && *replyv.Interface().(*int ) == 4 && mType.NumCalls() == 1 , "failed to call Foo.Sum" ) }
4.超时处理 哪些地方需要超时处理 简单来看整个流程:
客户端:
拨号连接 ✓
远程调用(发送数据)✓
等待服务端处理 ✓
接收返回的结果(接收数据)✓
服务端:
端口监听
接收请求信息(接收数据)✓
调用方法处理请求信息(处理数据)✓
返回请求结果(发送数据)✓
以上打 ✓ 的步骤,均可能出现超时,主要为,
客户端建立连接超时 发送数据 处理数据 接收数据客户端等待服务端响应超时 客户端/服务端读取报文超时 基于这个超时的情景,我们可以在以下三个地方设置超时处理机制。
客户端:
建立连接 Client.call()
的整个过程(包含发送数据、等待处理、接收数据)服务端:
Server.handleRequest()
的整个过程(包括处理数据、发送数据,接收数据先不管了)建立连接的超时处理 启动一个 goroutine 来拨号建立连接,使用 select 设置超时器,阻塞等待拨号结果,如果在接收拨号结果之前,先收到了超时信号,则进行超时处理。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 type NewClientFunc func (conn net.Conn, opt *Option) (*Client, error )func Dial (network, address string , opts ...*Option) (client *Client, err error ) { return dialTimeout(NewClient, network, address, opts...) }type dialResult struct { client *Client err error }func dialTimeout (f NewClientFunc, network, address string , opts ...*Option) (client *Client, err error ) { opt, err := parseOptions(opts...) if err != nil { return nil , err } result := make (chan dialResult) defer func () { if err != nil { client = nil } }() go func () { conn, err := net.DialTimeout(network, address, opt.ConnectTimeout) if err != nil { result <- dialResult{client: nil , err: err} return } client, err = f(conn, opt) result <- dialResult{client: client, err: err} }() select { case <-time.After(opt.ConnectTimeout): return nil , fmt.Errorf("rpc client: connect timeout: %v" , opt.ConnectTimeout) case result := <-result: return result.client, result.err } }
超时的时长,设置在 Option struct 内,便于用户自定义。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 type Option struct { MagicNumber int CodecType codec.Type ConnectTimeout time.Duration HandleTimeout time.Duration }var DefaultOption = &Option{ MagicNumber: MagicNumber, CodecType: codec.GobType, ConnectTimeout: time.Second * 10 , }
Call 的超时处理 把选择权交给用户。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 func (client *Client) Call(ctx context.Context, serviceMethod string , args, reply interface {}) error { call := client.Go(serviceMethod, args, reply, make (chan *Call, 1 )) select { case <-ctx.Done(): client.removeCall(call.Seq) return errors.New("rpc client: call failed: " + ctx.Err().Error()) case call := <-call.Done: return call.Error } }
用户如何使用?
1 2 3 ctx, _ := context.WithTimeout(context.Background(), time.Second)var reply int err := client.Call(ctx, "Foo.Sum" , &Args{1 , 2 }, &reply)
handleRequest 超时处理 这里先不管读取数据了,仅对 handleRequest 做超时处理。handleRequest 其中又分为两个步骤,处理数据、发送数据,仅对处理数据做超时处理。
设置两个 channel,一个表示 req.svc.call
完成,一个表示 server.sendResponse
完成。当 req.svc.call
完成后,就不用管超时时间是否到达,继续让 server.sendResponse 发送数据。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) { defer wg.Done() called := make (chan struct {}, 1 ) sent := make (chan struct {}, 1 ) go func () { err := req.svc.call(req.mtype, req.argv, req.replyv) called <- struct {}{} if err != nil { req.h.Error = err.Error() server.sendResponse(cc, req.h, invalidRequest, sending) sent <- struct {}{} return } server.sendResponse(cc, req.h, req.replyv.Interface(), sending) sent <- struct {}{} }() if timeout == 0 { <-called <-sent return } select { case <-time.After(timeout): server.sendResponse(cc, req.h, invalidRequest, sending) case <-called: <-sent } }
测试代码 测试客户端处理连接超时的情况
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 func TestClient_dialTimeout (t *testing.T) { t.Parallel() f := func (conn net.Conn, opt *Option) (client *Client, err error ) { _ = conn.Close() time.Sleep(time.Second * 2 ) return nil , nil } l, _ := net.Listen("tcp" , ":0" ) t.Run("connect timeout" , func (t *testing.T) { _, err := dialTimeout(f, "tcp" , l.Addr().String(), &Option{ConnectTimeout: time.Second}) _assert(err != nil && strings.Contains(err.Error(), "connect timeout" ), "expect a timeout error" ) }) t.Run("0" , func (t *testing.T) { _, err := dialTimeout(f, "tcp" , l.Addr().String(), &Option{ConnectTimeout: 0 }) _assert(err == nil , "0 means no limit" ) }) }
测试客户端处理调用超时和服务端处理超时的情况
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 type ServiceTemp int func (s ServiceTemp) Timeout(args int , reply *int ) error { time.Sleep(time.Second * time.Duration(args)) *reply = 0 return nil }func TestClient_Call (t *testing.T) { t.Parallel() addrCh := make (chan string ) go func (chan string ) { var s ServiceTemp _ = Register(&s) l, _ := net.Listen("tcp" , ":0" ) addrCh <- l.Addr().String() Accept(l) }(addrCh) addr := <-addrCh t.Run("client call timeout" , func (t *testing.T) { client, _ := Dial("tcp" , addr) ctx, _ := context.WithTimeout(context.Background(), time.Second) var reply int err := client.Call(ctx, "ServiceTemp.Timeout" , 20 , &reply) _assert(err != nil && strings.Contains(err.Error(), ctx.Err().Error()), "expect a timeout error" ) }) t.Run("server handle timeout" , func (t *testing.T) { client, _ := Dial("tcp" , addr, &Option{HandleTimeout: time.Second}) var reply int err := client.Call(context.Background(), "ServiceTemp.Timeout" , 20 , &reply) _assert(err != nil && strings.Contains(err.Error(), "handle timeout" ), "expect a timeout error" ) }) }
完整代码 1 2 3 4 5 6 7 8 9 10 11 12 # 代码结构 version_4_timeout ├── client.go ├── codec │ ├── codec.go │ └── gob.go ├── go .mod ├── main │ └── main.go ├── server.go ├── service.go └── service_test.go
server.go
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 package geerpcimport ( "encoding/json" "errors" "fmt" "geerpc/codec" "io" "log" "net" "reflect" "strings" "sync" "time" )const MagicNumber = 0x3bef5c type Option struct { MagicNumber int CodecType codec.Type ConnectTimeout time.Duration HandleTimeout time.Duration }type Server struct { serviceMap sync.Map }func NewServer () *Server { return &Server{} }func (server *Server) Accept(lis net.Listener) { for { conn, err := lis.Accept() if err != nil { log.Println("rpc server: accept error: " , err) return } go server.ServeConn(conn) } }func (server *Server) ServeConn(conn io.ReadWriteCloser) { defer func () { _ = conn.Close() }() var opt Option if err := json.NewDecoder(conn).Decode(&opt); err != nil { log.Println("rpc server: options error: " , err) return } if opt.MagicNumber != MagicNumber { log.Printf("rpc server: invalid magic number %x" , opt.MagicNumber) return } f := codec.NewCodecFuncMap[opt.CodecType] if f == nil { log.Printf("rpc server: invalid codec type %s" , opt.CodecType) return } server.serveCodec(f(conn), &opt) }func (server *Server) serveCodec(cc codec.Codec, opt *Option) { sending := new (sync.Mutex) wg := new (sync.WaitGroup) for { req, err := server.readRequest(cc) if err != nil { break } wg.Add(1 ) go server.handleRequest(cc, req, sending, wg, opt.HandleTimeout) } wg.Wait() _ = cc.Close() }type request struct { h *codec.Header svc *service mtype *methodType argv, replyv reflect.Value }func (server *Server) readRequest(cc codec.Codec) (*request, error ) { h, err := server.readRequestHeader(cc) if err != nil { return nil , err } req := &request{h: h} req.svc, req.mtype, err = server.findService(h.ServiceMethod) if err != nil { return req, err } req.argv = req.mtype.newArgv() req.replyv = req.mtype.newReplyv() argvi := req.argv.Interface() if req.argv.Type().Kind() != reflect.Ptr { argvi = req.argv.Addr().Interface() } if err = cc.ReadBody(argvi); err != nil { log.Println("rpc server: read body err:" , err) return req, err } return req, nil }func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error ) { var h codec.Header if err := cc.ReadHeader(&h); err != nil { if err != io.EOF && !errors.Is(err, io.ErrUnexpectedEOF) { log.Println("rpc server: read header error: " , err) } return nil , err } return &h, nil }var invalidRequest = struct {}{}func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) { defer wg.Done() called := make (chan struct {}, 1 ) sent := make (chan struct {}, 1 ) go func () { err := req.svc.call(req.mtype, req.argv, req.replyv) called <- struct {}{} if err != nil { req.h.Error = err.Error() server.sendResponse(cc, req.h, invalidRequest, sending) sent <- struct {}{} return } server.sendResponse(cc, req.h, req.replyv.Interface(), sending) sent <- struct {}{} }() if timeout == 0 { <-called <-sent return } select { case <-time.After(timeout): req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s" , timeout) server.sendResponse(cc, req.h, invalidRequest, sending) case <-called: <-sent } }func (server *Server) sendResponse(cc codec.Codec, header *codec.Header, body interface {}, sending *sync.Mutex) { sending.Lock() defer sending.Unlock() if err := cc.Write(header, body); err != nil { log.Println("rpc server: write response error: " , err) } }var DefaultServer = NewServer()func Register (rcvr interface {}) error { return DefaultServer.Register(rcvr) }func (server *Server) Register(rcvr interface {}) error { s := newService(rcvr) if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup { return errors.New("rpc: service already defined: " + s.name) } return nil }func (server *Server) findService(serviceMethod string ) (svc *service, mtype *methodType, err error ) { dot := strings.LastIndex(serviceMethod, "." ) if dot < 0 { err = errors.New("rpc server: service/method request ill-formed: " + serviceMethod) return } serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1 :] svci, ok := server.serviceMap.Load(serviceName) if !ok { err = errors.New("rpc server: can't find service " + serviceName) return } svc = svci.(*service) mtype = svc.method[methodName] if mtype == nil { err = errors.New("rpc server: can't find method " + methodName) } return }var DefaultOption = &Option{ MagicNumber: MagicNumber, CodecType: codec.GobType, ConnectTimeout: time.Second * 10 , }func Accept (lis net.Listener) { DefaultServer.Accept(lis) }
client.go
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 package geerpcimport ( "context" "encoding/json" "errors" "fmt" "geerpc/codec" "log" "net" "sync" "time" )type Call struct { Seq uint64 ServiceMethod string Args interface {} Reply interface {} Error error Done chan *Call }func (call *Call) done() { call.Done <- call }type Client struct { cc codec.Codec opt *Option sending sync.Mutex header codec.Header mu sync.Mutex seq uint64 pending map [uint64 ]*Call closing bool shutdown bool }var ErrShutdown = errors.New("client has been shut down" )func (client *Client) Close() error { client.mu.Lock() defer client.mu.Unlock() if client.closing { return ErrShutdown } client.closing = true return client.cc.Close() }func (client *Client) IsAvailable() bool { client.mu.Lock() defer client.mu.Unlock() return !client.shutdown && !client.closing }func (client *Client) registerCall(call *Call) (uint64 , error ) { client.mu.Lock() defer client.mu.Unlock() if client.closing || client.shutdown { return 0 , ErrShutdown } call.Seq = client.seq client.pending[call.Seq] = call client.seq++ return call.Seq, nil }func (client *Client) removeCall(seq uint64 ) *Call { client.mu.Lock() defer client.mu.Unlock() call := client.pending[seq] delete (client.pending, seq) return call }func (client *Client) terminateCalls(err error ) { client.sending.Lock() defer client.sending.Unlock() client.mu.Lock() defer client.mu.Unlock() client.shutdown = true for _, call := range client.pending { call.Error = err call.done() } }func (client *Client) receive() { var err error for err == nil { var h codec.Header if err = client.cc.ReadHeader(&h); err != nil { break } call := client.removeCall(h.Seq) switch { case call == nil : err = client.cc.ReadBody(nil ) case h.Error != "" : call.Error = fmt.Errorf(h.Error) err = client.cc.ReadBody(nil ) call.done() default : err = client.cc.ReadBody(call.Reply) if err != nil { call.Error = errors.New("reading body " + err.Error()) } call.done() } } client.terminateCalls(err) }func NewClient (conn net.Conn, opt *Option) (*Client, error ) { if err := json.NewEncoder(conn).Encode(opt); err != nil { log.Println("rpc client: options error: " , err) _ = conn.Close() return nil , err } newCodecFunc := codec.NewCodecFuncMap[opt.CodecType] if newCodecFunc == nil { err := fmt.Errorf("invalid codec type %s" , opt.CodecType) log.Println("rpc client: codec error:" , err) return nil , err } return newClientCodec(newCodecFunc(conn), opt), nil }func newClientCodec (cc codec.Codec, opt *Option) *Client { client := &Client{ seq: 1 , cc: cc, opt: opt, pending: make (map [uint64 ]*Call), } go client.receive() return client }func parseOptions (opts ...*Option) (*Option, error ) { if len (opts) == 0 || opts[0 ] == nil { return DefaultOption, nil } if len (opts) != 1 { return nil , errors.New("number of options is more than 1" ) } opt := opts[0 ] opt.MagicNumber = DefaultOption.MagicNumber if opt.CodecType == "" { opt.CodecType = DefaultOption.CodecType } return opt, nil }type NewClientFunc func (conn net.Conn, opt *Option) (*Client, error )func Dial (network, address string , opts ...*Option) (client *Client, err error ) { return dialTimeout(NewClient, network, address, opts...) }type dialResult struct { client *Client err error }func dialTimeout (f NewClientFunc, network, address string , opts ...*Option) (client *Client, err error ) { opt, err := parseOptions(opts...) if err != nil { return nil , err } result := make (chan dialResult, 1 ) defer func () { if err != nil { client = nil } }() go func () { conn, err := net.DialTimeout(network, address, opt.ConnectTimeout) if err != nil { result <- dialResult{client: nil , err: err} return } client, err = f(conn, opt) result <- dialResult{client: client, err: err} }() if opt.ConnectTimeout == 0 { result := <-result return result.client, result.err } select { case <-time.After(opt.ConnectTimeout): return nil , fmt.Errorf("rpc client: connect timeout: %v" , opt.ConnectTimeout) case result := <-result: return result.client, result.err } }func (client *Client) send(call *Call) { client.sending.Lock() defer client.sending.Unlock() seq, err := client.registerCall(call) if err != nil { call.Error = err call.done() return } client.header.ServiceMethod = call.ServiceMethod client.header.Seq = seq client.header.Error = "" if err := client.cc.Write(&client.header, call.Args); err != nil { call := client.removeCall(seq) if call != nil { call.Error = err call.done() } } }func (client *Client) Go(serviceMethod string , args, reply interface {}, done chan *Call) *Call { if done == nil { done = make (chan *Call, 10 ) } else if cap (done) == 0 { log.Panic("rpc client: done channel is unbuffered" ) } call := &Call{ ServiceMethod: serviceMethod, Args: args, Reply: reply, Done: done, } client.send(call) return call }func (client *Client) Call(ctx context.Context, serviceMethod string , args, reply interface {}) error { call := client.Go(serviceMethod, args, reply, make (chan *Call, 1 )) select { case <-ctx.Done(): client.removeCall(call.Seq) return errors.New("rpc client: call failed: " + ctx.Err().Error()) case call := <-call.Done: return call.Error } }
server_test.go
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 package geerpcimport ( "context" "fmt" "net" "reflect" "strings" "testing" "time" )type Foo int type Args struct { Num1, Num2 int }func (f Foo) Sum(args Args, reply *int ) error { *reply = args.Num1 + args.Num2 return nil }func (f Foo) sum(args Args, reply *int ) error { *reply = args.Num1 + args.Num2 return nil }func _assert (condition bool , msg string , v ...interface {}) { if !condition { panic (fmt.Sprintf("assertion failed: " +msg, v...)) } }func TestNewService (t *testing.T) { var foo Foo s := newService(&foo) _assert(len (s.method) == 1 , "wrong service Method, expect 1, but got %d" , len (s.method)) mType := s.method["Sum" ] _assert(mType != nil , "wrong Method, Sum shouldn't nil" ) }func TestMethodType_Call (t *testing.T) { var foo Foo s := newService(&foo) mType := s.method["Sum" ] argv := mType.newArgv() replyv := mType.newReplyv() argv.Set(reflect.ValueOf(Args{Num1: 1 , Num2: 3 })) err := s.call(mType, argv, replyv) _assert(err == nil && *replyv.Interface().(*int ) == 4 && mType.NumCalls() == 1 , "failed to call Foo.Sum" ) }type ServiceTemp int func (s ServiceTemp) Timeout(args int , reply *int ) error { time.Sleep(time.Second * time.Duration(args)) *reply = 0 return nil }func TestClient_Call (t *testing.T) { t.Parallel() addrCh := make (chan string ) go func (chan string ) { var s ServiceTemp _ = Register(&s) l, _ := net.Listen("tcp" , ":0" ) addrCh <- l.Addr().String() Accept(l) }(addrCh) addr := <-addrCh t.Run("client call timeout" , func (t *testing.T) { client, _ := Dial("tcp" , addr) ctx, _ := context.WithTimeout(context.Background(), time.Second) var reply int err := client.Call(ctx, "ServiceTemp.Timeout" , 20 , &reply) _assert(err != nil && strings.Contains(err.Error(), ctx.Err().Error()), "expect a timeout error" ) }) t.Run("server handle timeout" , func (t *testing.T) { client, _ := Dial("tcp" , addr, &Option{HandleTimeout: time.Second}) var reply int err := client.Call(context.Background(), "ServiceTemp.Timeout" , 20 , &reply) _assert(err != nil && strings.Contains(err.Error(), "handle timeout" ), "expect a timeout error" ) }) }
client_test.go
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 package geerpcimport ( "net" "strings" "testing" "time" )func TestClient_dialTimeout (t *testing.T) { t.Parallel() f := func (conn net.Conn, opt *Option) (client *Client, err error ) { _ = conn.Close() time.Sleep(time.Second * 2 ) return nil , nil } l, _ := net.Listen("tcp" , ":0" ) t.Run("connect timeout" , func (t *testing.T) { _, err := dialTimeout(f, "tcp" , l.Addr().String(), &Option{ConnectTimeout: time.Second}) _assert(err != nil && strings.Contains(err.Error(), "connect timeout" ), "expect a timeout error" ) }) t.Run("0" , func (t *testing.T) { _, err := dialTimeout(f, "tcp" , l.Addr().String(), &Option{ConnectTimeout: 0 }) _assert(err == nil , "0 means no limit" ) }) }
5.支持 HTTP 协议 为什么需要支持 HTTP 协议 当前我们 RPC 框架是基于 TCP 协议的。
1 2 l, _ := net.Listen("tcp" , ":9999" )
为什么需要支持 HTTP 协议?
兼容性:很多应用是基于 HTTP 协议的,RPC 框架支持 HTTP,便于兼容。 方便调试:通过 HTTP 提供调试界面,可以实时监控RPC服务的状态和调用统计,方便开发和运维。 安全性:HTTP 可以和 TLS/SSL 结合,提供安全的通信方式。 如何支持 HTTP 协议 基于通信流程梳理需要做哪些事情:
客户端向 RPC 服务器,发送 HTTP CONNECT 方法请求 服务端响应 HTTP CONNECT 方法请求,表示建立连接 客户端使用建立的连接,发送 RPC 报文(先发送 option,再发送 RPC 报文) 服务端处理 RPC 请求,并响应。 客户端发送请求 服务端响应请求 在哪里监听?