deploy-jar-template/src/main/java/com/iplatform/api/MilvusChatApi.java | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
deploy-jar-template/src/main/java/com/iplatform/api/TestWebController.java | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
deploy-jar-template/src/main/java/com/iplatform/milvus/MilvusEngine.java | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
deploy-jar-template/src/main/java/com/iplatform/milvus/ScoreText.java | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
deploy-jar-template/src/main/java/com/iplatform/milvus/SearchResult.java | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
deploy-jar-template/src/main/java/com/iplatform/milvus/service/EventServiceImpl.java | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
deploy-jar-template/src/main/resources/application-dev.yml | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
iplatform-base/src/main/java/com/iplatform/base/support/LogAspect.java | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
deploy-jar-template/src/main/java/com/iplatform/api/MilvusChatApi.java
@@ -4,8 +4,13 @@ import com.iplatform.milvus.EventVo; import com.iplatform.milvus.MilvusEngine; import com.iplatform.milvus.ParamList; import com.iplatform.milvus.ScoreText; import com.iplatform.milvus.SearchResult; import com.iplatform.milvus.service.EventServiceImpl; import com.walker.infrastructure.utils.FileUtils; import com.walker.infrastructure.utils.JsonUtils; import com.walker.infrastructure.utils.StringUtils; import com.walker.support.milvus.OutData; import com.walker.web.ResponseValue; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.HttpStatus; @@ -18,6 +23,7 @@ import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @RestController @@ -32,7 +38,9 @@ private boolean isBreak = false; private static final String URL_EMBEDDING = "http://120.26.128.84:7003/ai/text/embedding"; private static final String URL_SEARCH_SIMILAR = "http://120.26.128.84:7003/ai/text/search_similar"; private static final String URL_MILVUS = "120.26.128.84"; private static final double BEST_MATCH_SCORE = 0.75; @Autowired public MilvusChatApi(EventServiceImpl eventService, RestTemplate restTemplate){ @@ -41,6 +49,7 @@ if(this.milvusEngine == null){ MilvusEngine engine = new MilvusEngine(URL_MILVUS, 19530); this.milvusEngine = engine; this.milvusEngine.loadChatSimilar4Search(); logger.info("milvus engine ok!"); } } @@ -76,8 +85,84 @@ @RequestMapping("/query") public ResponseValue testQueryMilvus(String text){ if(StringUtils.isEmpty(text)){ return ResponseValue.error("text is required!"); } // 用分号分隔多个语句 String[] array = StringUtils.toArray(text); ParamList paramList = new ParamList(); for(String value : array){ paramList.add(value); } return ResponseValue.success(); SearchResult searchResult = this.acquireSearchResult(paramList); if(searchResult == null){ logger.warn("远程调用最匹配语句向量失败,无法继续搜索相似度"); return ResponseValue.error("远程调用最匹配语句向量失败,无法继续搜索相似度"); } List<List<Float>> requestVectors = new ArrayList(4); if(!StringUtils.isEmptyList(searchResult.getBest_list()) && searchResult.getBest_embedding() != null){ ScoreText bestScoreText = searchResult.getBest_list().get(0); logger.debug("在对话集合中,存在最匹配的句子: {}", bestScoreText); if(bestScoreText.getScore() >= BEST_MATCH_SCORE){ logger.debug("最匹配的分值: {} 大于设置阈值,可以直接作为查询基准", bestScoreText.getScore()); requestVectors.add(searchResult.getBest_embedding()); } else { requestVectors.add(searchResult.getBest_embedding()); requestVectors.add(searchResult.getAll_embedding()); } } else { logger.debug("只使用全量查询向量:{}", text); requestVectors.add(searchResult.getAll_embedding()); } OutData outData = this.milvusEngine.searchChatSimilar(requestVectors); if(outData == null){ return ResponseValue.error("未检索到任何匹配相似提问"); } List<OutData.Data> dataList = outData.getResultList(); for(OutData.Data d : dataList){ logger.info("data = {}", d); } return ResponseValue.success(dataList); } private SearchResult acquireSearchResult(ParamList paramList){ try { ResponseEntity<ResponseValue> responseEntity = this.restTemplate.postForEntity(URL_SEARCH_SIMILAR, paramList, ResponseValue.class); // logger.debug(responseEntity.toString()); if(responseEntity.getStatusCode() == HttpStatus.OK){ ResponseValue<Map<String, Object>> responseValue = responseEntity.getBody(); if(responseValue.getCode() != 1){ logger.error("调用返回acquireSearchResult返回状态错误:{}", responseValue.getMsg()); return null; } Map<String, Object> map = responseValue.getData(); logger.debug("map = {}", map); // String bestListJson = JsonUtils.objectToJsonString(map.get("best_list")); // List<ScoreText> bestList = JsonUtils.jsonStringToList(bestListJson, ScoreText.class); // logger.debug("bestList = {}", bestList); // SearchResult searchResult = new SearchResult(); // searchResult.setBest_embedding(this.transfer2FloatList((List<Float>)map.get("best_embedding"))); // searchResult.setAll_embedding(this.transfer2FloatList((List<Float>)map.get("all_embedding"))); String json = JsonUtils.objectToJsonString(map); SearchResult searchResult = JsonUtils.jsonStringToObject(json, SearchResult.class); searchResult.setBest_embedding(this.transfer2FloatList(searchResult.getBest_embedding())); searchResult.setAll_embedding(this.transfer2FloatList(searchResult.getAll_embedding())); return searchResult; } else { logger.error("调用 {} 结果返回失败:{}", URL_SEARCH_SIMILAR, responseEntity.getStatusCodeValue()); return null; } } catch (Exception ex){ logger.error("获取搜索语句向量结果错误:{}" + ex.getMessage(), ex); return null; } } private boolean acquireEmbedding(List<EventVo> batchData){ deploy-jar-template/src/main/java/com/iplatform/api/TestWebController.java
@@ -2,6 +2,7 @@ import com.iplatform.base.PushController; import com.iplatform.base.PushData; import com.iplatform.base.pojo.UserParam; import com.iplatform.base.service.UserServiceImpl; import com.iplatform.model.po.S_user_core; import com.walker.db.page.GenericPager; @@ -9,10 +10,13 @@ import com.walker.infrastructure.utils.JsonUtils; import com.walker.web.ResponseValue; import com.walker.web.WebRuntimeException; import com.walker.web.log.BusinessType; import com.walker.web.log.Log; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.security.core.parameters.P; import org.springframework.transaction.annotation.Transactional; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; @@ -35,6 +39,19 @@ this.userService = userService; } @Log(title = "测试分页列表", businessType = BusinessType.Delete, isSaveRequestData = true, isSaveResponseData = true) @GetMapping("/list") public ResponseValue list(UserParam userParam){ try { logger.debug("userParam = {}", JsonUtils.objectToJsonString(userParam)); } catch (Exception e) { throw new RuntimeException(e); } GenericPager<S_user_core> pager = this.userService.queryPageUserList(0 , 0, userParam.getUserName(), userParam.getPhonenumber(), userParam.getStatus()); return ResponseValue.success(pager); } @RequestMapping("/push_msg") public ResponseValue testPushMessage() throws Exception{ deploy-jar-template/src/main/java/com/iplatform/milvus/MilvusEngine.java
@@ -6,6 +6,8 @@ import com.walker.support.milvus.FieldType; import com.walker.support.milvus.MetricType; import com.walker.support.milvus.OperateService; import com.walker.support.milvus.OutData; import com.walker.support.milvus.Query; import com.walker.support.milvus.Table; import com.walker.support.milvus.engine.DefaultOperateService; import org.slf4j.Logger; @@ -20,6 +22,8 @@ public class MilvusEngine { protected final transient Logger logger = LoggerFactory.getLogger(this.getClass()); public static final String TABLE_CHAT_SIMILAR = "chat_similar"; public MilvusEngine(String ip, int port){ DefaultOperateService service = new DefaultOperateService(); @@ -44,7 +48,7 @@ */ public void createChatSimilarTable(){ Table chatSimilarTable = new Table(); chatSimilarTable.setCollectionName("chat_similar"); chatSimilarTable.setCollectionName(TABLE_CHAT_SIMILAR); chatSimilarTable.setDescription("聊天提取工单摘要历史数据"); chatSimilarTable.setShardsNum(1); chatSimilarTable.setDimension(768); // 这个是根据使用向量模型维度定的 @@ -84,7 +88,7 @@ public void insertTestData(){ DataSet dataSet = new DataSet(); dataSet.setTableName("chat_similar"); dataSet.setTableName(TABLE_CHAT_SIMILAR); List<List<Float>> vectorList = new ArrayList<>(); vectorList.add(Arrays.asList(mockVector)); @@ -132,7 +136,7 @@ } DataSet dataSet = new DataSet(); dataSet.setTableName("chat_similar"); dataSet.setTableName(TABLE_CHAT_SIMILAR); Map<String, List<?>> fieldMap = new HashMap(); fieldMap.put("id", ids); @@ -145,6 +149,26 @@ logger.info("写入了: {}", ids); } public OutData searchChatSimilar(List<List<Float>> vectors){ Query query = new Query(); query.setMetricType(MetricType.NLP.getIndex()); query.setTableName(TABLE_CHAT_SIMILAR); query.setTopK(4); query.setVectorName("embedding"); query.setOutFieldList(Arrays.asList(new String[]{"id","title","content"})); query.setFieldPrimaryKey("id"); query.setSearchVectors(vectors); return this.operateService.searchVector(query); } /** * 必须在查询之前,加载数据到内存中。 * @date 2024-03-31 */ public void loadChatSimilar4Search(){ this.operateService.prepareSearch(TABLE_CHAT_SIMILAR); } private OperateService operateService; // private Double[] mockVector = new Double[]{-0.051114246249198914, 0.889954432}; deploy-jar-template/src/main/java/com/iplatform/milvus/ScoreText.java
New file @@ -0,0 +1,31 @@ package com.iplatform.milvus; public class ScoreText { public double getScore() { return score; } public void setScore(double score) { this.score = score; } public String getText() { return text; } public void setText(String text) { this.text = text; } private double score; private String text; @Override public String toString() { return "ScoreText{" + "score=" + score + ", text='" + text + '\'' + '}'; } } deploy-jar-template/src/main/java/com/iplatform/milvus/SearchResult.java
New file @@ -0,0 +1,38 @@ package com.iplatform.milvus; import java.util.List; /** * 查询给定语句集合,返回的向量信息。 * @date 2024-03-31 */ public class SearchResult { public List<ScoreText> getBest_list() { return best_list; } public void setBest_list(List<ScoreText> best_list) { this.best_list = best_list; } public List<Float> getBest_embedding() { return best_embedding; } public void setBest_embedding(List<Float> best_embedding) { this.best_embedding = best_embedding; } public List<Float> getAll_embedding() { return all_embedding; } public void setAll_embedding(List<Float> all_embedding) { this.all_embedding = all_embedding; } private List<ScoreText> best_list; private List<Float> best_embedding; private List<Float> all_embedding; } deploy-jar-template/src/main/java/com/iplatform/milvus/service/EventServiceImpl.java
@@ -3,17 +3,32 @@ import com.iplatform.milvus.EventVo; import com.walker.jdbc.service.BaseServiceImpl; import org.springframework.jdbc.core.RowMapper; import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; import org.springframework.stereotype.Service; import java.sql.ResultSet; import java.sql.SQLException; import java.util.HashMap; import java.util.List; import java.util.Map; @Service public class EventServiceImpl extends BaseServiceImpl { private final EventVoMapper eventVoMapper = new EventVoMapper(); /** * 根据工单id集合,返回这些工单对象(历史工单基本信息) * @param ids * @return * @date 2024-03-31 */ public List<EventVo> queryEventWhereIn(List<Long> ids){ MapSqlParameterSource parameterSource = new MapSqlParameterSource(); parameterSource.addValue("ids", ids); return this.sqlListObjectWhereIn("select * from event_history where id in (:ids)", eventVoMapper, parameterSource); } public List<EventVo> queryEventAll(Long existId){ if(existId == null){ return this.select("select * from event_history order by id asc", new Object[]{}, this.eventVoMapper); deploy-jar-template/src/main/resources/application-dev.yml
@@ -288,7 +288,7 @@ # 如果不打开,则设备登录的uuid更新操作也无法获得,用于记录每个登录用户的uuid(用户登录角色更新),2023-03-23 login-enabled: true # 是否打开操作日志,2023-01-05 operate-enabled: false operate-enabled: true # 验证码相关配置,2023-01-27 captcha: iplatform-base/src/main/java/com/iplatform/base/support/LogAspect.java
@@ -9,6 +9,7 @@ import com.walker.infrastructure.utils.StringUtils; import com.walker.web.Constants; import com.walker.web.ResponseCode; import com.walker.web.WebRuntimeException; import com.walker.web.log.BusinessType; import com.walker.web.log.Log; import com.walker.web.util.IpUtils; @@ -92,7 +93,13 @@ private void handleLog(final JoinPoint joinPoint, Log logAnnotation, final Exception e, Object jsonResult){ try{ S_oper_log s_oper_log = new S_oper_log(); S_user_core user_core = this.securitySpi.getCurrentUser(); // S_user_core user_core = this.securitySpi.getCurrentUser(); S_user_core user_core = null; try { user_core = this.securitySpi.getCurrentUser(); } catch (WebRuntimeException ex) { logger.debug("该接口无需认证,无法找到当前人信息"); } if(user_core != null){ s_oper_log.setOper_name(user_core.getUser_name()); } @@ -128,9 +135,24 @@ s_oper_log.setOper_param(params); } } else { String queryString = ServletUtils.getRequest().getQueryString(); // Object param2 = ServletUtils.getRequest().getAttribute(HandlerMapping.MATRIX_VARIABLES_ATTRIBUTE); // Object param3 = ServletUtils.getRequest().getAttribute(HandlerMapping.PRODUCIBLE_MEDIA_TYPES_ATTRIBUTE); // logger.debug(queryString); // logger.debug("param2 = {}", param2); // logger.debug("param3 = {}", param3); if(StringUtils.isEmpty(queryString)){ Map<?, ?> paramsMap = (Map<?, ?>) ServletUtils.getRequest().getAttribute(HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE); if(paramsMap != null){ s_oper_log.setOper_param(StringUtils.substring(paramsMap.toString(), 0, MAX_DATA_SIZE)); queryString = paramsMap.toString(); } } if(StringUtils.isNotEmpty(queryString)){ if(queryString.length() > MAX_DATA_SIZE){ s_oper_log.setOper_param(StringUtils.substring(queryString, 0, MAX_DATA_SIZE)); } else { s_oper_log.setOper_param(queryString); } } } }