【50/50】WebSocket协议详解与实现
Photo by Kristine Weilert on Unsplash
这是2019年50篇文章的最后一篇
今天,我们来了解一下WebSocket协议,然后用Python基于TCP实现一个简单的WebSocket协议解析。
首先,我们来了解一下什么是WebSocket协议。
WebSocket协议是为了解决HTTP协议只能由客户端主动发起请求,服务器不能主动向客户端发送数据的问题而诞生的。WebSocket也是基于TCP协议的,握手过程使用HTTP,传输的内容可以是文本,也可以是二进制内容。使用WebSocket可以实现客户端和服务器的双向实时通信。
然后,我们来看一下WebSocket协议的握手过程。
WebSocket的握手连接过程是基于HTTP的,客户端首先要发送一个HTTP的请求,并且带上一些特定的请求头,然后服务器按照协议规定返回101状态码和相应的响应头完成连接的建立,之后就可以通过这个连接进行双向通信了。
那握手过程的请求和响应有什么规定呢?
首先,客户端向服务器发送一个GET请求,和HTTP协议不同的是WebSocket协议需要添加4个特殊的请求头,例如:
GET / HTTP/1.1
Host: example.com:8000
Connection: Upgrade
Upgrade: websocket
Sec-WebSocket-Version: 13
Sec-WebSocket-Key: YtqzKW5j8rYIYauXEwcJFw==
Connection: Upgrade:表示需要升级协议。Upgrade: websocket:表示需要升级为WebSocket协议。Sec-WebSocket-Version: 13:表示WebSocket协议的版本。Sec-WebSocket-Key: YtqzKW5j8rYIYauXEwcJFw==:是客户端随机生成的。
然后服务器会返回101 Switching Protocols,并且会增加3个特殊的响应头,例如:
HTTP/1.1 101 Switching Protocols
Connection: Upgrade
Sec-WebSocket-Accept: 4q50AMbiRegDNPtQYmvSw+HGHv8=
Upgrade: WebSocket
响应头中的Connection和Upgrade就不说了,我们来说说这个Sec-WebSocket-Accept响应头。应该不难猜到,这个响应头应该和请求头中的Sec-WebSocket-Key有关,但是具体是如何生成的呢?
其实也很简单,就是服务器获取到Sec-WebSocket-Key请求头之后,将Sec-WebSocket-Key请求头的值拼接上一个特殊的字符串258EAFA5-E914-47DA-95CA-C5AB0DC85B11,然后使用SHA-1算法计算出摘要并使用base64编码就得到了响应头Sec-WebSocket-Accept的值。由于WebSocket的握手过程是基于HTTP的,所以这两个头是为了减少恶意连接、意外连接而设置的。
我们来用Python来实现这个key生成的过程:
import hashlibimport base64def gen_websocket_key(key):sha1 = hashlib.sha1()magic_value = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sha1.update((key + magic_value).encode())return base64.b64encode(sha1.digest())if __name__ == "__main__":print(gen_websocket_key("YtqzKW5j8rYIYauXEwcJFw=="))
运行输出:4q50AMbiRegDNPtQYmvSw+HGHv8=
在WebSocket握手完成后,就可以发送WebSocket协议的数据了。我们来继续看下WebSocket协议的格式是什么样的:
0 1 2 30 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1+-+-+-+-+-------+-+-------------+-------------------------------+|F|R|R|R| opcode|M| Payload len | Extended payload length ||I|S|S|S| (4) |A| (7) | (16/64) ||N|V|V|V| |S| | (if payload len==126/127) || |1|2|3| |K| | |+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +| Extended payload length continued, if payload len == 127 |+ - - - - - - - - - - - - - - - +-------------------------------+| |Masking-key, if MASK set to 1 |+-------------------------------+-------------------------------+| Masking-key (continued) | Payload Data |+-------------------------------- - - - - - - - - - - - - - - - +: Payload Data continued ... :+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +| Payload Data continued ... |+---------------------------------------------------------------+
我们来看下这些字段都有什么用,格式是什么样的:
FIN:如果是1,表示这是消息的最后一个分片,如果是0,表示不是是消息的最后一个分片。
RSV1,RSV2,RSV3:用于WebSocket的扩展,一般情况下全为0。
opcode:数据的类型。
MASK:是否对数据进行掩码操作。客户端发送消息必须要进行掩码。
Payload len:数据长度,单位是字节。
Masking-key:掩码。如果MASK为1,则有4字节掩码,否则没有掩码。
Payload Data:数据。
opcode可选项:
0x0:表示一个延续帧。当opcode为0时,表示本次数据传输采用了数据分片,当前收到的数据帧为其中一个数据分片。
0x1:表示这是一个文本帧。
0x2:表示这是一个二进制帧。
0x3-0x7:保留的操作代码,用于后续定义的非控制帧。
0x8:表示连接断开。
0x9:表示这是一个ping操作。
0xA:表示这是一个pong操作。
0xB-0xF:保留的操作代码,用于后续定义的控制帧。
数据长度:
如果Payload len的值小于125,数据的长度就是Payload len的无符号整型的值。
如果Payload len的值是126,数据的长度是Payload len后边的16位的无符号整型的值。
如果Payload len的值是127,数据的长度是Payload len后边的64位的无符号整型的值(最高位为0)。
掩码算法:如果MASK设置为1,则数据是经过Masking-key掩码运算后的,掩码操作是对每一个字节做异或操作,需要对收到的数据再次进行异或操作才能获取到原始数据。Masking-key有4个字节,需要对Payload len和Masking-key进行循环异或操作进行掩码。使用Python代码来描述就是:
for i in range(len(payload))
payload[i] = payload[i] ^ masking_key[i%4]
一个数两次异或同一个数可以得到它本身
我们来使用Python的twisted这个网络库来基于TCP实现一个WebSocket的服务器:
import hashlibimport base64import refrom collections import defaultdictfrom twisted.internet.protocol import Factory, connectionDonefrom twisted.internet import reactor, protocolopcode_map = {"text": 1,"binary": 2,"close": 8,"ping": 9,"pong": 10,}# 计算Sec-WebSocket-Acceptdef gen_sec_websocket_accept(key):sha1 = hashlib.sha1()magic_value = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sha1.update((key + magic_value).encode())return base64.b64encode(sha1.digest())class Chat(protocol.Protocol):def __init__(self, rooms):self.rooms = rooms # 所有聊天室self.connection_name = ""self.room = "default"self.nick = "Anonymous"self.upgraded = False # 是否已升级# 处理方法和opcode映射self.handle_map = {1: self.handle_text,2: self.handle_binary,8: self.handle_close,9: self.hanlde_ping}def connectionMade(self):self.connection_name = "%s:%s" % self.transport.clientdef connectionLost(self, reason=connectionDone):if self.connection_name in self.rooms:print(f"connectionLost: {self.connection_name}")self.rooms[self.connection_name].remove(self.transport.socket)# 协议升级def upgrade(self, data):data = data.decode("utf-8")lines = data.split("\r\n")request_line = lines[0]try:method, path, version = request_line.split()except ValueError:raise Exception(f"error request line: {request_line}")if method != "GET" or version != "HTTP/1.1":raise Exception(f"error request line: {request_line}")# 匹配聊天室名字和用户nickr = re.compile(r"/(?P<room>\w+)(\?nick=(?P<nick>\w+))?")if r.match(path):res = r.search(path)self.room = res.group("room")nick = res.group("nick")self.nick = nick if nick else self.nickself.rooms[self.room].add(self)req_headers = dict()for line in lines:if line and len(line.split(":")) == 2:key, value = line.split(":")req_headers[key.strip().lower()] = value.strip()if "upgrade" not in req_headers.get("connection", "").lower():raise Exception("connection error")if req_headers.get("upgrade", "") != "websocket":raise Exception("upgrade error")if req_headers.get("sec-websocket-version", "") != "13":raise Exception("websocket version error")if not req_headers.get("sec-websocket-key"):raise Exception("no Sec-WebSocket-Key")resp = ["HTTP/1.1 101 Switching Protocols"]resp_headers = {"Connection": "Upgrade", "Upgrade": "WebSocket","Sec-WebSocket-Accept": gen_sec_websocket_accept(req_headers["sec-websocket-key"]).decode()}for key, value in resp_headers.items():resp.append(f"{key}: {value}")resp.append("\r\n")self.transport.write("\r\n".join(resp).encode())print("switching protocols success")self.upgraded = True# 处理WebSocket数据def parse(self, data):if len(data) < 2:raise Exception("data error")# TODO:处理FIN=0if data[1] & 0x80 != 0x80:raise Exception("MASK error")length = data[1] & 0x7fmasking_key = data[2:6]payload_data = data[6:6+length]if length == 126:if len(data) < 4:raise Exception("data error")length = int.from_bytes(data[2:4], byteorder='big')masking_key = data[4:8]payload_data = data[8:8+length]if length == 127:if len(data) < 8:raise Exception("data error")length = int.from_bytes(data[4:8], byteorder='big')masking_key = data[8:12]payload_data = data[12:12+length]message = bytearray()for i in range(length):t = payload_data[i] ^ masking_key[i % 4]message.append(t)opcode = data[0] & 0xfreturn opcode, bytes(message)def handle(self, data):opcode, message = self.parse(data)handle = self.handle_map.get(opcode)if not handle:raise Exception("not support")handle(message)def handle_binary(self, message):raise NotImplementedError# 处理文本消息def handle_text(self, message):message = self.nick.encode() + b": " + messagedata = self.make_data(opcode_map["text"], message)for c in self.rooms[self.room]:c.transport.write(data)# 处理ping消息def hanlde_ping(self, message):self.transport.write(self.make_data(opcode_map["pong"]))# 处理close消息def handle_close(self, message):self.transport.write(self.make_data(opcode_map["close"]))# 生成响应消息def make_data(self, opcode, message=b""):data = bytearray()length = len(message)data.append(0x80 | opcode)if length < 126:data.append(length)elif 126 <= length <= 65536:data.append(126)data.extend(length.to_bytes(2, 'big'))elif 65536 < length < 2**31:data.append(127)data.extend(length.to_bytes(4, 'big'))else:raise Exception("data too long")data.extend(message)return bytes(data)# 接受数据回调def dataReceived(self, data):print("dataReceived:", data)if not self.upgraded:try:self.upgrade(data)except Exception as err:print("upgrade error: ", err)else:try:self.handle(data)except Exception as err:print("handle error: ", err)class ChatFactory(Factory):def __init__(self):# 聊天室,key为聊天室名字,value为聊天室下所有的Chat对象集合self.rooms = defaultdict(set)def buildProtocol(self, addr):return Chat(self.rooms)def main():reactor.listenTCP(8000, ChatFactory(), interface="0.0.0.0")reactor.run()if __name__ == "__main__":main()
我们来使用http://coolaf.com/tool/chattest这个网站测试一下:
这个程序仅作为理解WebSocket协议的一个示例,并没有完全实现WebSocket协议,并且有些方法的设计和错误处理并不是太好,大家可以自己阅读文档自己来实现一个更好的。
twisted这个库实现了很多协议,但就是没有实现WebSocket协议,有一个第三方库autobahn在twisted上实现了的WebSocket协议,如果想要深入了解WebSocket协议,可以去看下autobahn的源码。
参考资料:
https://www.cnblogs.com/chyingp/p/websocket-deep-in.html
https://developer.mozilla.org/zh-CN/docs/Web/API/WebSockets_API/Writing_WebSocket_servers
作者水平有限,如有错误之处,还望指正。
