shikeying
2024-01-11 3b67e947e36133e2a40eb2737b15ea375e157ea0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
package com.walker.tcp.websocket;
 
import com.walker.infrastructure.ApplicationRuntimeException;
import com.walker.infrastructure.utils.ClassUtils;
import com.walker.infrastructure.utils.StringUtils;
import com.walker.tcp.Request;
import com.walker.tcp.data.AbstractStringRequest;
import com.walker.tcp.handler.LongHandler;
import com.walker.tcp.protocol.StringProtocolResolver;
 
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
 
public class WebsocketHandler extends LongHandler {
 
//    private final ProtocolResolver resolver = new WebsocketProtocolResolver();
    private final StringProtocolResolver resolver = new WebsocketProtocolResolver();
 
    @Override
    protected Request<?> createRequest(String content) throws Exception {
        if(this.getMapper() == null){
            throw new IllegalArgumentException(MSG_REQUEST_ERROR);
        }
 
//        ProtocolResolver resolver = ConvertorUtils.getProtocolResolver(msg, getProtocolResolverList());
//        if(resolver == null){
//            throw new IllegalArgumentException("protocolResolver not found, msg : " + msg);
//        }
 
        // 去掉报文中的分隔符,因为netty打开了显示分隔符
//        String content = msg.substring(0, msg.length()-resolver.getDelimiter().length());
 
        String protocol = resolver.getProtocolNum(content, content.length());
        String clazz = this.getMapper().get(protocol);
        if(StringUtils.isEmpty(clazz)){
            throw new IllegalArgumentException("请求协议对应的request类不存在。protocol = " + protocol + ", msg = " + content);
        }
 
        Class<?> clazzRequest = this.acquireRequestClazz(clazz);
        AbstractStringRequest request = (AbstractStringRequest)clazzRequest.newInstance();
        request.fromSource(content);
        request.setProtocolResolverId(resolver.getOrder());
        return request;
 
//        try {
//            Class<?> clazzRequest = ClassUtils.forName(clazz, this.getClass().getClassLoader());
//            AbstractStringRequest request = (AbstractStringRequest)clazzRequest.newInstance();
//            request.fromSource(content);
//            request.setProtocolResolverId(resolver.getOrder());
//            return request;
//
//        } catch (ClassNotFoundException e) {
//            logger.error("根据映射创建request对象错误:" + e.getMessage(), e);
//            throw new Exception(e);
//        }
    }
 
    private Class<?> acquireRequestClazz(String clazz){
        if(StringUtils.isEmpty(clazz)){
            throw new IllegalArgumentException("必须提供请求对象class名称!");
        }
        Class<?> requestClazz = this.requestClazzCache.get(clazz);
        if(requestClazz == null){
            try {
                requestClazz = ClassUtils.forName(clazz, this.getClass().getClassLoader());
                this.requestClazzCache.put(clazz, requestClazz);
            } catch (ClassNotFoundException e) {
                throw new ApplicationRuntimeException("tcp request class not found: " + clazz, e);
            }
        }
        return requestClazz;
    }
 
    private Map<String, Class<?>> requestClazzCache = new ConcurrentHashMap<>();
}