vlambda博客
学习文章列表

【50/50】WebSocket协议详解与实现

Photo by Kristine Weilert on Unsplash


这是2019年50篇文章的最后一篇


今天,我们来了解一下WebSocket协议,然后用Python基于TCP实现一个简单的WebSocket协议解析。

首先,我们来了解一下什么是WebSocket协议。

WebSocket协议是为了解决HTTP协议只能由客户端主动发起请求,服务器不能主动向客户端发送数据的问题而诞生的。WebSocket也是基于TCP协议的,握手过程使用HTTP,传输的内容可以是文本,也可以是二进制内容。使用WebSocket可以实现客户端和服务器的双向实时通信。

然后,我们来看一下WebSocket协议的握手过程。

WebSocket的握手连接过程是基于HTTP的,客户端首先要发送一个HTTP的请求,并且带上一些特定的请求头,然后服务器按照协议规定返回101状态码和相应的响应头完成连接的建立,之后就可以通过这个连接进行双向通信了。

那握手过程的请求和响应有什么规定呢?

首先,客户端向服务器发送一个GET请求,和HTTP协议不同的是WebSocket协议需要添加4个特殊的请求头,例如:

GET / HTTP/1.1
Host: example.com:8000
Connection: Upgrade
Upgrade: websocket
Sec-WebSocket-Version: 13
Sec-WebSocket-Key: YtqzKW5j8rYIYauXEwcJFw==
  • Connection: Upgrade:表示需要升级协议。

  • Upgrade: websocket:表示需要升级为WebSocket协议。

  • Sec-WebSocket-Version: 13:表示WebSocket协议的版本。

  • Sec-WebSocket-Key: YtqzKW5j8rYIYauXEwcJFw==:是客户端随机生成的。


然后服务器会返回101 Switching Protocols,并且会增加3个特殊的响应头,例如:

HTTP/1.1 101 Switching Protocols
Connection: Upgrade
Sec-WebSocket-Accept: 4q50AMbiRegDNPtQYmvSw+HGHv8=
Upgrade: WebSocket

响应头中的Connection和Upgrade就不说了,我们来说说这个Sec-WebSocket-Accept响应头。应该不难猜到,这个响应头应该和请求头中的Sec-WebSocket-Key有关,但是具体是如何生成的呢?

其实也很简单,就是服务器获取到Sec-WebSocket-Key请求头之后,将Sec-WebSocket-Key请求头的值拼接上一个特殊的字符串258EAFA5-E914-47DA-95CA-C5AB0DC85B11,然后使用SHA-1算法计算出摘要并使用base64编码就得到了响应头Sec-WebSocket-Accept的值。由于WebSocket的握手过程是基于HTTP的,所以这两个头是为了减少恶意连接、意外连接而设置的。

我们来用Python来实现这个key生成的过程:

import hashlibimport base64

def gen_websocket_key(key): sha1 = hashlib.sha1() magic_value = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" sha1.update((key + magic_value).encode()) return base64.b64encode(sha1.digest())

if __name__ == "__main__": print(gen_websocket_key("YtqzKW5j8rYIYauXEwcJFw=="))

运行输出:4q50AMbiRegDNPtQYmvSw+HGHv8=

在WebSocket握手完成后,就可以发送WebSocket协议的数据了。我们来继续看下WebSocket协议的格式是什么样的:

 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1+-+-+-+-+-------+-+-------------+-------------------------------+|F|R|R|R| opcode|M| Payload len | Extended payload length ||I|S|S|S| (4) |A| (7) | (16/64) ||N|V|V|V| |S| | (if payload len==126/127) || |1|2|3| |K| | |+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +| Extended payload length continued, if payload len == 127 |+ - - - - - - - - - - - - - - - +-------------------------------+| |Masking-key, if MASK set to 1 |+-------------------------------+-------------------------------+| Masking-key (continued) | Payload Data |+-------------------------------- - - - - - - - - - - - - - - - +: Payload Data continued ... :+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +| Payload Data continued ... |+---------------------------------------------------------------+

我们来看下这些字段都有什么用,格式是什么样的:

  • FIN:如果是1,表示这是消息的最后一个分片,如果是0,表示不是是消息的最后一个分片。

  • RSV1,RSV2,RSV3:用于WebSocket的扩展,一般情况下全为0。

  • opcode:数据的类型。

  • MASK:是否对数据进行掩码操作。客户端发送消息必须要进行掩码。

  • Payload len:数据长度,单位是字节。

  • Masking-key:掩码。如果MASK为1,则有4字节掩码,否则没有掩码。

  • Payload Data:数据。


opcode可选项:

0x0:表示一个延续帧。当opcode为0时,表示本次数据传输采用了数据分片,当前收到的数据帧为其中一个数据分片。

0x1:表示这是一个文本帧。

0x2:表示这是一个二进制帧。

0x3-0x7:保留的操作代码,用于后续定义的非控制帧。

0x8:表示连接断开。

0x9:表示这是一个ping操作。

0xA:表示这是一个pong操作。

0xB-0xF:保留的操作代码,用于后续定义的控制帧。


数据长度:

如果Payload len的值小于125,数据的长度就是Payload len的无符号整型的值。

如果Payload len的值是126,数据的长度是Payload len后边的16位的无符号整型的值。

如果Payload len的值是127,数据的长度是Payload len后边的64位的无符号整型的值(最高位为0)。


掩码算法:如果MASK设置为1,则数据是经过Masking-key掩码运算后的,掩码操作是对每一个字节做异或操作,需要对收到的数据再次进行异或操作才能获取到原始数据。Masking-key有4个字节,需要对Payload len和Masking-key进行循环异或操作进行掩码。使用Python代码来描述就是:

for i in range(len(payload))
payload[i] = payload[i] ^ masking_key[i%4]

一个数两次异或同一个数可以得到它本身

我们来使用Python的twisted这个网络库来基于TCP实现一个WebSocket的服务器:

import hashlibimport base64import refrom collections import defaultdict
from twisted.internet.protocol import Factory, connectionDonefrom twisted.internet import reactor, protocol

opcode_map = { "text": 1, "binary": 2, "close": 8, "ping": 9, "pong": 10,}

# 计算Sec-WebSocket-Acceptdef gen_sec_websocket_accept(key): sha1 = hashlib.sha1() magic_value = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" sha1.update((key + magic_value).encode()) return base64.b64encode(sha1.digest())

class Chat(protocol.Protocol):
def __init__(self, rooms): self.rooms = rooms # 所有聊天室 self.connection_name = "" self.room = "default" self.nick = "Anonymous" self.upgraded = False # 是否已升级 # 处理方法和opcode映射 self.handle_map = { 1: self.handle_text, 2: self.handle_binary, 8: self.handle_close, 9: self.hanlde_ping }
def connectionMade(self): self.connection_name = "%s:%s" % self.transport.client
def connectionLost(self, reason=connectionDone): if self.connection_name in self.rooms: print(f"connectionLost: {self.connection_name}") self.rooms[self.connection_name].remove(self.transport.socket)
# 协议升级 def upgrade(self, data): data = data.decode("utf-8") lines = data.split("\r\n") request_line = lines[0] try: method, path, version = request_line.split() except ValueError: raise Exception(f"error request line: {request_line}") if method != "GET" or version != "HTTP/1.1": raise Exception(f"error request line: {request_line}") # 匹配聊天室名字和用户nick r = re.compile(r"/(?P<room>\w+)(\?nick=(?P<nick>\w+))?") if r.match(path): res = r.search(path) self.room = res.group("room") nick = res.group("nick") self.nick = nick if nick else self.nick self.rooms[self.room].add(self) req_headers = dict() for line in lines: if line and len(line.split(":")) == 2: key, value = line.split(":") req_headers[key.strip().lower()] = value.strip() if "upgrade" not in req_headers.get("connection", "").lower(): raise Exception("connection error") if req_headers.get("upgrade", "") != "websocket": raise Exception("upgrade error") if req_headers.get("sec-websocket-version", "") != "13": raise Exception("websocket version error") if not req_headers.get("sec-websocket-key"): raise Exception("no Sec-WebSocket-Key") resp = ["HTTP/1.1 101 Switching Protocols"] resp_headers = {"Connection": "Upgrade", "Upgrade": "WebSocket", "Sec-WebSocket-Accept": gen_sec_websocket_accept(req_headers["sec-websocket-key"]).decode()} for key, value in resp_headers.items(): resp.append(f"{key}: {value}") resp.append("\r\n") self.transport.write("\r\n".join(resp).encode()) print("switching protocols success") self.upgraded = True
# 处理WebSocket数据 def parse(self, data): if len(data) < 2: raise Exception("data error") # TODO:处理FIN=0 if data[1] & 0x80 != 0x80: raise Exception("MASK error") length = data[1] & 0x7f masking_key = data[2:6] payload_data = data[6:6+length] if length == 126: if len(data) < 4: raise Exception("data error") length = int.from_bytes(data[2:4], byteorder='big') masking_key = data[4:8] payload_data = data[8:8+length] if length == 127: if len(data) < 8: raise Exception("data error") length = int.from_bytes(data[4:8], byteorder='big') masking_key = data[8:12] payload_data = data[12:12+length] message = bytearray() for i in range(length): t = payload_data[i] ^ masking_key[i % 4] message.append(t) opcode = data[0] & 0xf return opcode, bytes(message)
def handle(self, data): opcode, message = self.parse(data) handle = self.handle_map.get(opcode) if not handle: raise Exception("not support") handle(message)
def handle_binary(self, message): raise NotImplementedError
# 处理文本消息 def handle_text(self, message): message = self.nick.encode() + b": " + message data = self.make_data(opcode_map["text"], message) for c in self.rooms[self.room]: c.transport.write(data)
# 处理ping消息 def hanlde_ping(self, message): self.transport.write(self.make_data(opcode_map["pong"]))
# 处理close消息 def handle_close(self, message): self.transport.write(self.make_data(opcode_map["close"]))
# 生成响应消息 def make_data(self, opcode, message=b""): data = bytearray() length = len(message) data.append(0x80 | opcode) if length < 126: data.append(length) elif 126 <= length <= 65536: data.append(126) data.extend(length.to_bytes(2, 'big')) elif 65536 < length < 2**31: data.append(127) data.extend(length.to_bytes(4, 'big')) else: raise Exception("data too long") data.extend(message) return bytes(data)
# 接受数据回调 def dataReceived(self, data): print("dataReceived:", data) if not self.upgraded: try: self.upgrade(data) except Exception as err: print("upgrade error: ", err) else: try: self.handle(data) except Exception as err: print("handle error: ", err)

class ChatFactory(Factory):
def __init__(self): # 聊天室,key为聊天室名字,value为聊天室下所有的Chat对象集合 self.rooms = defaultdict(set)
def buildProtocol(self, addr): return Chat(self.rooms)

def main(): reactor.listenTCP(8000, ChatFactory(), interface="0.0.0.0") reactor.run()

if __name__ == "__main__": main()

我们来使用http://coolaf.com/tool/chattest这个网站测试一下:

这个程序仅作为理解WebSocket协议的一个示例,并没有完全实现WebSocket协议,并且有些方法的设计和错误处理并不是太好,大家可以自己阅读文档自己来实现一个更好的。

twisted这个库实现了很多协议,但就是没有实现WebSocket协议,有一个第三方库autobahn在twisted上实现了的WebSocket协议,如果想要深入了解WebSocket协议,可以去看下autobahn的源码。

参考资料:

https://www.cnblogs.com/chyingp/p/websocket-deep-in.html

https://developer.mozilla.org/zh-CN/docs/Web/API/WebSockets_API/Writing_WebSocket_servers


作者水平有限,如有错误之处,还望指正。