package com.iplatform.api; import com.iplatform.base.SystemController; import com.iplatform.milvus.EventVo; import com.iplatform.milvus.MilvusEngine; import com.iplatform.milvus.ParamList; import com.iplatform.milvus.ScoreText; import com.iplatform.milvus.SearchResult; import com.iplatform.milvus.service.EventServiceImpl; import com.walker.infrastructure.utils.FileUtils; import com.walker.infrastructure.utils.JsonUtils; import com.walker.infrastructure.utils.StringUtils; import com.walker.support.milvus.OutData; import com.walker.web.ResponseValue; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.client.RestTemplate; import java.util.ArrayList; import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @RestController @RequestMapping("/test/milvus") public class MilvusChatApi extends SystemController { private EventServiceImpl eventService; private RestTemplate restTemplate; private MilvusEngine milvusEngine; private Long existId = 4048L; private boolean isBreak = false; private static final String URL_EMBEDDING = "http://120.26.128.84:7003/ai/text/embedding"; private static final String URL_SEARCH_SIMILAR = "http://120.26.128.84:7003/ai/text/search_similar"; private static final String URL_MILVUS = "120.26.128.84"; private static final double BEST_MATCH_SCORE = 0.75; @Autowired public MilvusChatApi(EventServiceImpl eventService, RestTemplate restTemplate){ this.eventService = eventService; this.restTemplate = restTemplate; if(this.milvusEngine == null){ MilvusEngine engine = new MilvusEngine(URL_MILVUS, 19530); this.milvusEngine = engine; this.milvusEngine.loadChatSimilar4Search(); logger.info("milvus engine ok!"); } } @RequestMapping("/embedding") public ResponseValue testHttpEmbedding(){ List data = new ArrayList(8); EventVo eventVo = new EventVo(); eventVo.setContent("第一句"); data.add(eventVo); eventVo = new EventVo(); eventVo.setContent("第二句"); data.add(eventVo); boolean success = this.acquireEmbedding(data); this.milvusEngine.insertTestData(); return ResponseValue.success("结果是:" + success); } @RequestMapping("/write") public ResponseValue testWriteMilvus(){ Collection eventVoList = this.acquireEventVoList(); if(eventVoList == null){ return ResponseValue.error("没有加载到数据"); } logger.info("加载了 event vo: {}个", eventVoList.size()); new Thread(new WriteTask(eventVoList)).start(); return ResponseValue.success(); } @RequestMapping("/query") public ResponseValue testQueryMilvus(String text){ if(StringUtils.isEmpty(text)){ return ResponseValue.error("text is required!"); } // 用分号分隔多个语句 String[] array = StringUtils.toArray(text); ParamList paramList = new ParamList(); for(String value : array){ paramList.add(value); } SearchResult searchResult = this.acquireSearchResult(paramList); if(searchResult == null){ logger.warn("远程调用最匹配语句向量失败,无法继续搜索相似度"); return ResponseValue.error("远程调用最匹配语句向量失败,无法继续搜索相似度"); } List> requestVectors = new ArrayList(4); if(!StringUtils.isEmptyList(searchResult.getBest_list()) && searchResult.getBest_embedding() != null){ ScoreText bestScoreText = searchResult.getBest_list().get(0); logger.debug("在对话集合中,存在最匹配的句子: {}", bestScoreText); if(bestScoreText.getScore() >= BEST_MATCH_SCORE){ logger.debug("最匹配的分值: {} 大于设置阈值,可以直接作为查询基准", bestScoreText.getScore()); requestVectors.add(searchResult.getBest_embedding()); } else { requestVectors.add(searchResult.getBest_embedding()); requestVectors.add(searchResult.getAll_embedding()); } } else { logger.debug("只使用全量查询向量:{}", text); requestVectors.add(searchResult.getAll_embedding()); } OutData outData = this.milvusEngine.searchChatSimilar(requestVectors); if(outData == null){ return ResponseValue.error("未检索到任何匹配相似提问"); } List dataList = outData.getResultList(); for(OutData.Data d : dataList){ logger.info("data = {}", d); } return ResponseValue.success(dataList); } private SearchResult acquireSearchResult(ParamList paramList){ try { ResponseEntity responseEntity = this.restTemplate.postForEntity(URL_SEARCH_SIMILAR, paramList, ResponseValue.class); // logger.debug(responseEntity.toString()); if(responseEntity.getStatusCode() == HttpStatus.OK){ ResponseValue> responseValue = responseEntity.getBody(); if(responseValue.getCode() != 1){ logger.error("调用返回acquireSearchResult返回状态错误:{}", responseValue.getMsg()); return null; } Map map = responseValue.getData(); logger.debug("map = {}", map); // String bestListJson = JsonUtils.objectToJsonString(map.get("best_list")); // List bestList = JsonUtils.jsonStringToList(bestListJson, ScoreText.class); // logger.debug("bestList = {}", bestList); // SearchResult searchResult = new SearchResult(); // searchResult.setBest_embedding(this.transfer2FloatList((List)map.get("best_embedding"))); // searchResult.setAll_embedding(this.transfer2FloatList((List)map.get("all_embedding"))); String json = JsonUtils.objectToJsonString(map); SearchResult searchResult = JsonUtils.jsonStringToObject(json, SearchResult.class); searchResult.setBest_embedding(this.transfer2FloatList(searchResult.getBest_embedding())); searchResult.setAll_embedding(this.transfer2FloatList(searchResult.getAll_embedding())); return searchResult; } else { logger.error("调用 {} 结果返回失败:{}", URL_SEARCH_SIMILAR, responseEntity.getStatusCodeValue()); return null; } } catch (Exception ex){ logger.error("获取搜索语句向量结果错误:{}" + ex.getMessage(), ex); return null; } } private boolean acquireEmbedding(List batchData){ ParamList paramList = new ParamList(); for(EventVo eventVo : batchData){ paramList.add(eventVo.getContent()); } try{ ResponseEntity responseEntity = this.restTemplate.postForEntity(URL_EMBEDDING, paramList, ResponseValue.class); if(responseEntity.getStatusCode() == HttpStatus.OK){ List> data = (List>)responseEntity.getBody().getData(); // double[] one = null; List vector = null; for(int i=0; i transfer2FloatList(List list){ List vector = new ArrayList<>(768); for(int i=0; i eventVoList; public WriteTask(Collection eventVoList){ this.eventVoList = eventVoList; } @Override public void run() { logger.info(".......... start task ..."); // 复位状态 isBreak = false; int count = 0; List batchData = new ArrayList<>(); for(EventVo eventVo : this.eventVoList){ // 已存在上次加载记录,已写入的不在重新处理 if(existId != null && eventVo.getId() < existId.longValue()){ continue; } if(isBreak){ break; } if(count == 0){ logger.info("1) 开始(或继续)采集第 {} 记录", eventVo.getId()); } if(count >= 8){ // 触发一次批量写入 logger.info("2) 触发一次调用:{}", batchData.get(7).getId()); try { boolean successEmbedding = acquireEmbedding(batchData); if(!successEmbedding){ logger.error("获取向量失败,任务结束"); break; } // 获取向量,写入数据库 milvusEngine.insertEventVoList(batchData); } catch (Exception ex){ logger.error("error = " + ex.getMessage(), ex); // existId = eventVo.getId(); isBreak = true; logger.error("3) 采集任务异常, 当前 id = {}", existId == null ? "" : existId); break; } finally { try { TimeUnit.SECONDS.sleep(1); } catch (InterruptedException e) { e.printStackTrace(); } } // 清理本次数据 batchData.clear(); count = 0; } batchData.add(eventVo); count ++; } logger.info(".......... end task ..."); } } private Collection acquireEventVoList(){ List data = null; // if(this.isBreak){ // data = this.eventService.queryEventAll(existId); // } else { // data = this.eventService.queryEventAll(null); // } data = this.eventService.queryEventAll(null); if(StringUtils.isEmptyList(data)){ return null; } // 过滤掉 content 重复的数据记录 LinkedHashMap cache = new LinkedHashMap(); EventVo temp = null; for(EventVo e : data){ temp = cache.get(e.getContent()); if(temp == null){ cache.put(e.getContent(), e); } } return cache.values(); } }