简单Websocket通信实现

JavaWeb,技术向 2019-03-11

服务端推送技术

众所周知, HTTP协议是客户端向服务端单向请求的文本协议. 一般来说, 服务端返回请求结果以后, TCP连接就会关闭.
在这样的协议模式下, 是无法支持全双工通信的, 也就是服务端无法主动向客户端推送信息, 必须等待客户端请求才能返回信息.
因此, 服务端推送技术一般有如下几种解决方案:

  • 定时轮询: 这种方式是传统WEB解决服务端推送的方案

    • 问题主要是延迟大, 而且对服务端的压力非常高, 当然可以通过在服务端之间加一层缓存之类的方法加以缓解, 但依然效率太低, 而且浪费带宽
  • 长轮询: 是定时轮询的改进版, 即不再每次请求时服务端都立刻返回, 而是直到数据发生变化再返回

    • 能够有效节省带宽, 但问题主要是, 需要同时维持大量HTTP连接, 对服务端的压力也不小
    • 需要在服务端做额外处理
  • HTTP长连接: 在长轮询的基础上, 服务端在返回数据后不再断开TCP连接

    • 相比长轮询能够更加有效节省带宽, 而且通信的实时性更高, 依然是需要维持大量HTTP连接
    • 需要在客户端和服务端都做额外处理, 兼容性较差
  • HTML5 SSE: 直接运行于HTTP服务之上的服务器推送方案, 优点时方便简单, 缺点和上面一样, 需要维持大量HTTP连接
  • HTML5 WebSocket: 相当于在应用层重新实现了一个全双工的通信协议, 能够从HTTP协议升级(101)而来

    • 如果客户端支持的话, 这是实现HTTP全双工通信的最优方案, WebSocket协议与HTTP协议独立, 那么就可以将两个服务器分别实现, 从而避免给HTTP服务器带来太大连接压力

Java WebSocket

Java在2013年就已经通过JSR356支持了WebSocket技术, 使用起来非常简单, 使用注解就可以.

从实际业务出发, 我实现了一个AbstractEndPoint的抽象接口类, 封装了如下功能:

  • 维护客户端在线列表
  • 基于AES128生成AccessKey并验证
  • 封装了简单的文本信息发送

客户端在线列表需要考虑, 一个账号可能存在多个同时在线的客户端, 那么如果需要向一个账号发送消息就需要向所有在线的客户端发送消息.
因为操作的都是static的集合类, 所以需要考虑线程安全问题, 这里应该是读多写少所以采用ReadWriteLock.

public abstract class AbstractEndPoint {
    // Encrypt Parameters
    private final static String ivParameters = "1234567890123456"; // CBC需要16位长的IV, 随意设置
    private final static long expiredTime = 1000 * 60; // 一分钟过期
    private final static String magicStr = "QAQ"; // 用于替换accessKey中的/以避免无法捕捉到连接, 随意设置注意不要冲突
    // Static Map and rwLock - WARNING: Thread Safe Problem
    protected static Map<String, List<AbstractEndPoint>> clientMap = new HashMap<String, List<AbstractEndPoint>>(); // 在线列表
    protected static Map<String, SecretKey> keyMap = new HashMap<String, SecretKey>(); // 密钥列表
    protected static ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock(); // 读写锁
    // object Parameter
    protected String accountId;
    protected Session session;

    /**
     * 向指定账户发送文本信息
     * - 会向所有在线的连接发送消息
     * - 可以在上层套一层Object2Json比如GSON从而实现Object发送
     * - 或者使用JSR356的Encoder&Decoder
     *
     * @return 状态码 -2: 发送失败, -1:此用户不在线, 0:发送成功
     */
    public static int sendMessage(String accountId, String content){
        int statusCode = -2;
        try{
            if(rwLock.readLock().tryLock() || rwLock.readLock().tryLock(100, TimeUnit.MILLISECONDS)){
                try{
                    List<CustomEndPoint> endPoints = clientMap.get(accountId);
                    if(endPoints != null && endPoints.size() > 0){
                        for(CustomEndPoint endPoint : endPoints){
                            if(endPoint.session.isOpen()){
                                endPoint.session.getBasicRemote().sendText(content);
                            }
                        }
                        statusCode = 0;
                    }else{
                        statusCode = -1;
                    }
                }finally{
                    rwLock.readLock().unlock();
                }
            }
        }catch(InterruptedException | IOException e){
            e.printStackTrace();
        }
        return statusCode;
    }

    /**
     * 检查连接的有效性
     * 有效性包括:
     * 1. 账户含有对应的密钥
     * 2. 能够使用密钥解密加密内容
     * 3. 加密内容包含帐户名和密钥创建时间
     * 4. 密钥未过期
     *
     * @param accountId 帐号id
     * @param accessKey 连接密钥
     */
    protected boolean checkAccessible(String accountId, String accessKey){
        // 尝试验证accessKey的有效性
        boolean allowAccess = false;
        try{
            if(rwLock.readLock().tryLock() || rwLock.readLock().tryLock(100, TimeUnit.MILLISECONDS)){
                try{
                    if(keyMap.get(accountId) != null){ // 密钥存在
                        // 尝试解密AccessKey
                        accessKey = accessKey.replaceAll(magicStr, "/");
                        String information = decrypt(accessKey, keyMap.get(accountId));

                        // 验证AccessKey内信息
                        int position = information.indexOf(accountId + "_");
                        if(position != -1){
                            String createdTime = information.substring(accountId.length() + 1);
                            if(System.currentTimeMillis() - Long.parseLong(createdTime) < expiredTime){
                                allowAccess = true;
                            }
                        }
                    }
                }finally{
                    rwLock.readLock().unlock();
                }
            }
        }catch(Exception e){
            e.printStackTrace();
            log.error("Authority Exception: " + e.getMessage());
        }
        return allowAccess;
    }

    /**
     * 更新在线列表, 清除使用过的密钥
     */
    protected void updateSession(String accountId){
        try{
            if(rwLock.writeLock().tryLock() || rwLock.writeLock().tryLock(100, TimeUnit.MILLISECONDS)){
                try{
                    List<CustomEndPoint> endpoints = clientMap.get(accountId) == null ? new ArrayList<CustomEndPoint>() : clientMap.get(accountId);
                    endpoints.add(this);
                    clientMap.put(accountId, endpoints); // 添加到在线列表
                    clearClosedSession(accountId); // 清除当前账户已关闭的连接
                    keyMap.remove(accountId); // 删除AccessKey
                }finally{
                    rwLock.writeLock().unlock();
                }
            }
        }catch(InterruptedException e){
            e.printStackTrace();
        }
    }

    /**
     * 清除已经关闭的连接
     *
     * @param accountId 指定时只清除指定帐号的连接, 否则清除所有帐号的连接(注意可能的性能问题, 当在线用户很多时)
     */
    protected void clearClosedSession(String accountId){
        try{
            if(rwLock.writeLock().tryLock() || rwLock.writeLock().tryLock(100, TimeUnit.MILLISECONDS)){
                try{
                    // Init accountIds
                    List<String> accountIds = new ArrayList<>();
                    if(accountId != null){
                        accountIds.add(accountId);
                    }else{
                        accountIds.addAll(clientMap.keySet());
                    }

                    // Delete closed session
                    for(String id : accountIds){
                        if(clientMap.containsKey(id)){
                            Iterator<CustomEndPoint> iterator = clientMap.get(id).iterator();
                            while(iterator.hasNext()){
                                CustomEndPoint endpoint = iterator.next();
                                if(!endpoint.session.isOpen()){
                                    iterator.remove();
                                }
                            }
                            if(clientMap.get(id).size() == 0){
                                clientMap.remove(id);
                            }
                        }
                    }
                }finally{
                    rwLock.writeLock().unlock();
                }
            }
        }catch(InterruptedException e){
            e.printStackTrace();
        }
    }

    /**
     * 创建连接用AccessKey - AES加密后的 帐户名_时间戳
     * 因为Encode后的字符串可能含有/导致URL解析错误, 所以用magicStr替换
     *
     * @param accountId 用户名
     * @return 连接用密钥
     */
    public static String createAccessKey(String accountId){
        SecretKey key = generateSecretKey();
        String accessKey = encrypt(accountId + "_" + System.currentTimeMillis(), key);
        try{
            if(rwLock.writeLock().tryLock() || rwLock.writeLock().tryLock(100, TimeUnit.MILLISECONDS)){
                try{
                    keyMap.put(accountId, key);
                    return accessKey.replaceAll("/", magicStr);
                }finally{
                    rwLock.writeLock().unlock();
                }
            }
        }catch(InterruptedException e){
            e.printStackTrace();
        }
        throw new IllegalStateException("Failed to generate access key. ");
    }

    /**
     * 加密文本 - AES-128-CBC
     * 使用AES256需要额外的设置, 避免麻烦采用AES128
     *
     * @param plainText 原文本
     * @param key       加密密钥
     * @return 加密后的文本
     */
    private static String encrypt(String plainText, SecretKey key){
        if(key == null || plainText == null || plainText.isEmpty()){
            throw new IllegalArgumentException();
        }

        try{
            Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
            cipher.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(ivParameters.getBytes()));
            byte[] encryptedTestBytes = cipher.doFinal(plainText.getBytes(StandardCharsets.UTF_8));
            return DatatypeConverter.printBase64Binary(encryptedTestBytes);
        }catch(NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException | InvalidAlgorithmParameterException |
                IllegalBlockSizeException | BadPaddingException e){
            e.printStackTrace();
        }
        throw new IllegalStateException("Cannot encrypt text. ");
    }

    /**
     * 解密文本 - AES-128-CBC
     *
     * @param encryptedText 被加密过的文字
     * @param key           AES密钥
     * @return 解密后的文本
     */
    private static String decrypt(String encryptedText, SecretKey key){
        if(key == null || encryptedText == null || encryptedText.isEmpty()){
            throw new IllegalArgumentException();
        }

        byte[] encryptedBytes = DatatypeConverter.parseBase64Binary(encryptedText);
        try{
            Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
            cipher.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(ivParameters.getBytes()));
            byte[] plainTextBytes = cipher.doFinal(encryptedBytes);
            return new String(plainTextBytes);
        }catch(NoSuchAlgorithmException | NoSuchPaddingException | BadPaddingException | InvalidKeyException |
                IllegalBlockSizeException | InvalidAlgorithmParameterException e){
            e.printStackTrace();
        }
        throw new IllegalStateException("Cannot decrypt text. ");
    }

    /**
     * 随机产生AES加密所需密钥 - 128bit
     */
    private static SecretKey generateSecretKey(){
        try{
            return KeyGenerator.getInstance("AES").generateKey();
        }catch(NoSuchAlgorithmException e){
            e.printStackTrace();
        }
        throw new IllegalStateException("Failed to generate secretKey. ");
    }
}

具体的EndPoint实现放在了子类中, 由于经过Nignx转发以后连接可能会超时, 这里采用了心跳机制保证连接不超时.
另外, 浏览器的页面跳转也可能导致连接断开, 这一部分的连接断开是正常的, 不需要记录warn日志, 浏览器如果主动放弃连接, 也是不需要记录warn日志的.
需要注意的是, 如果验证成功的话需要发送一个连接确认信息告知客户端验证通过, 否则客户端无法分辨逻辑上的连接是否建立成功.

@ServerEndpoint("/ws/{accountId}/{accessKey}")
public class MessageEndpoint extends AbstractEndPoint {
    @OnOpen
    public void onOpen(final Session session, @PathParam("accountId") String accountId, @PathParam("accessKey") String accessKey) throws IOException{
        boolean allowAccess = checkAccessible(accountId, accessKey);
        if(allowAccess){ // 允许连接
            this.accountId = accountId;
            this.session = session;
            updateSession(accountId); // 添加到在线列表
            session.getBasicRemote().sendText("ok"); // 发送连接确认信息
            log.info(String.format("Account %s #%d connected", accountId, clientMap.get(accountId).size()));
        }else{ // 验证失败
            CloseReason reason = new CloseReason(CloseReason.CloseCodes.CANNOT_ACCEPT, "Authority Failed");
            session.close(reason);
        }
    }

    @OnMessage
    public void onMessage(String txt) {
        if(!txt.equals("heartBeat")){
            log.debug("Received Msg: " + txt);
        }
    }

    @OnClose
    public void onClose(Session session, CloseReason closeReason) {
        clearClosedSession(accountId);
        if(closeReason.getCloseCode().equals(CloseReason.CloseCodes.GOING_AWAY)){ // 浏览器跳转页面导致的断开
            log.info("Account " + accountId + " disconnected due to SWITCH THE PAGE.");
        }else{
            log.warn("Account " + accountId + " disconnected due to " + closeReason.getReasonPhrase());
        }
    }

    @OnError
    public void onError(Session session, Throwable error) {
        clearClosedSession(accountId);
        if(!error.getMessage().contains("The client aborted the connection.")){ // 浏览器主动放弃连接不记录
            log.warn("Error communicating with Account " + accountId + ". Detail: " + error.getMessage());
        }
    }
}

在浏览器端就比较简单, 这里比较简单的维护了一下无权访问时以及WebSocket出错时的Fallback, 只有当客户端收到ok信息时才能认为连接成功:

function readMessage() {
    var url = "requestAccessKey";
    var wsBaseUrl = $('base').attr('href').replace('http', 'ws');
    var sessionStorage = window.sessionStorage;
    if(sessionStorage.getItem("NoPermission") != null){ // 没有权限的情况下不再请求
        return;
    }
    // 如果不支持Websocket或者连接出错超过10次则使用轮询方式
    if (!'WebSocket' in window || Number(sessionStorage.getItem("wsRetryCnt")) > 10) {
        $.get(url, function (data) {
            console.log(data);
        }).onload = function (evt) { // 无权访问时, 停止操作
            if (this.status === 403) {
                sessionStorage.setItem("NoPermission", "1");
            }
        }
    } else {
        if (null == ws_client) { // connection closed or hasn't established.
            var tmp = $.get(url, function (data) {
                ws_client = new WebSocket(wsBaseUrl + "ws/" + data.accountId + "/" + data.accessKey);
                ws_client.onopen = function (evt) {
                    console.log("Connecting...");
                };
                ws_client.onmessage = function (evt) {
                    if(evt.data === "ok"){
                        sessionStorage.setItem("wsRetryCnt", "0");
                        console.log("Connection Created.");
                    }else{
                        console.log(data);
                    }
                };
                ws_client.onclose = function (evt) {
                    ws_client = null;
                    var retryCnt = sessionStorage.getItem("wsRetryCnt") == null ? 1 : Number(sessionStorage.getItem("wsRetryCnt")) + 1;
                    sessionStorage.setItem("wsRetryCnt", String(retryCnt));
                    console.log("Connection Failed #"+ retryCnt + ".");
                }
            }).onload = function (evt) {  // 无权访问时, 停止操作
                if (this.status === 403) {
                    sessionStorage.setItem("NoPermission", "1");
                }
            }
        } else { // 如果已经有ws对象, 那么在这里维持心跳保证连接不断开
            if (ws_client.readyState === WebSocket.CLOSING || ws_client.readyState === WebSocket.CLOSED) { // 如果连接已断开则重新建立连接
                ws_client = null;
            } else {
                ws_client.send("heartBeat");
            }
        }
    }
}

简单总结

WebSocket作为最先进的服务端推送/客户端服务端双工通信技术的实现, 单纯实现起来其实很简单, JSR356等已经做了大量的底层实现.
但是如果要结合业务并同时考虑到安全性等问题的话, 就需要从即时通讯的角度来充分思考需要在其上需要额外附加的机制.

从本质上来说, WebSocket为HTTP协议提供了这样一种功能, 从一个简单的单向的文本协议升级为一个更为复杂的全双工二进制通信协议, 正如WebSokcet中Socket一词, 在使用上更像是向操作系统升级了一个Socket来进行操作.


本文由 SLKun 创作,采用 知识共享署名 3.0,可自由转载、引用,但需署名作者且注明文章出处。

还不快抢沙发

添加新评论