| | |
| | | import com.iplatform.milvus.EventVo; |
| | | import com.iplatform.milvus.MilvusEngine; |
| | | import com.iplatform.milvus.ParamList; |
| | | import com.iplatform.milvus.ScoreText; |
| | | import com.iplatform.milvus.SearchResult; |
| | | import com.iplatform.milvus.service.EventServiceImpl; |
| | | import com.walker.infrastructure.utils.FileUtils; |
| | | import com.walker.infrastructure.utils.JsonUtils; |
| | | import com.walker.infrastructure.utils.StringUtils; |
| | | import com.walker.support.milvus.OutData; |
| | | import com.walker.web.ResponseValue; |
| | | import org.springframework.beans.factory.annotation.Autowired; |
| | | import org.springframework.http.HttpStatus; |
| | |
| | | import java.util.Collection; |
| | | import java.util.LinkedHashMap; |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | import java.util.concurrent.TimeUnit; |
| | | |
| | | @RestController |
| | |
| | | private boolean isBreak = false; |
| | | |
| | | private static final String URL_EMBEDDING = "http://120.26.128.84:7003/ai/text/embedding"; |
| | | private static final String URL_SEARCH_SIMILAR = "http://120.26.128.84:7003/ai/text/search_similar"; |
| | | private static final String URL_MILVUS = "120.26.128.84"; |
| | | private static final double BEST_MATCH_SCORE = 0.75; |
| | | |
| | | @Autowired |
| | | public MilvusChatApi(EventServiceImpl eventService, RestTemplate restTemplate){ |
| | |
| | | if(this.milvusEngine == null){ |
| | | MilvusEngine engine = new MilvusEngine(URL_MILVUS, 19530); |
| | | this.milvusEngine = engine; |
| | | this.milvusEngine.loadChatSimilar4Search(); |
| | | logger.info("milvus engine ok!"); |
| | | } |
| | | } |
| | |
| | | |
| | | @RequestMapping("/query") |
| | | public ResponseValue testQueryMilvus(String text){ |
| | | if(StringUtils.isEmpty(text)){ |
| | | return ResponseValue.error("text is required!"); |
| | | } |
| | | // 用分号分隔多个语句 |
| | | String[] array = StringUtils.toArray(text); |
| | | ParamList paramList = new ParamList(); |
| | | for(String value : array){ |
| | | paramList.add(value); |
| | | } |
| | | |
| | | return ResponseValue.success(); |
| | | SearchResult searchResult = this.acquireSearchResult(paramList); |
| | | if(searchResult == null){ |
| | | logger.warn("远程调用最匹配语句向量失败,无法继续搜索相似度"); |
| | | return ResponseValue.error("远程调用最匹配语句向量失败,无法继续搜索相似度"); |
| | | } |
| | | |
| | | List<List<Float>> requestVectors = new ArrayList(4); |
| | | |
| | | if(!StringUtils.isEmptyList(searchResult.getBest_list()) && searchResult.getBest_embedding() != null){ |
| | | ScoreText bestScoreText = searchResult.getBest_list().get(0); |
| | | logger.debug("在对话集合中,存在最匹配的句子: {}", bestScoreText); |
| | | if(bestScoreText.getScore() >= BEST_MATCH_SCORE){ |
| | | logger.debug("最匹配的分值: {} 大于设置阈值,可以直接作为查询基准", bestScoreText.getScore()); |
| | | requestVectors.add(searchResult.getBest_embedding()); |
| | | } else { |
| | | requestVectors.add(searchResult.getBest_embedding()); |
| | | requestVectors.add(searchResult.getAll_embedding()); |
| | | } |
| | | } else { |
| | | logger.debug("只使用全量查询向量:{}", text); |
| | | requestVectors.add(searchResult.getAll_embedding()); |
| | | } |
| | | |
| | | OutData outData = this.milvusEngine.searchChatSimilar(requestVectors); |
| | | if(outData == null){ |
| | | return ResponseValue.error("未检索到任何匹配相似提问"); |
| | | } |
| | | |
| | | List<OutData.Data> dataList = outData.getResultList(); |
| | | for(OutData.Data d : dataList){ |
| | | logger.info("data = {}", d); |
| | | } |
| | | |
| | | |
| | | return ResponseValue.success(dataList); |
| | | } |
| | | |
| | | private SearchResult acquireSearchResult(ParamList paramList){ |
| | | try { |
| | | ResponseEntity<ResponseValue> responseEntity = this.restTemplate.postForEntity(URL_SEARCH_SIMILAR, paramList, ResponseValue.class); |
| | | // logger.debug(responseEntity.toString()); |
| | | if(responseEntity.getStatusCode() == HttpStatus.OK){ |
| | | ResponseValue<Map<String, Object>> responseValue = responseEntity.getBody(); |
| | | if(responseValue.getCode() != 1){ |
| | | logger.error("调用返回acquireSearchResult返回状态错误:{}", responseValue.getMsg()); |
| | | return null; |
| | | } |
| | | Map<String, Object> map = responseValue.getData(); |
| | | logger.debug("map = {}", map); |
| | | // String bestListJson = JsonUtils.objectToJsonString(map.get("best_list")); |
| | | // List<ScoreText> bestList = JsonUtils.jsonStringToList(bestListJson, ScoreText.class); |
| | | // logger.debug("bestList = {}", bestList); |
| | | // SearchResult searchResult = new SearchResult(); |
| | | // searchResult.setBest_embedding(this.transfer2FloatList((List<Float>)map.get("best_embedding"))); |
| | | // searchResult.setAll_embedding(this.transfer2FloatList((List<Float>)map.get("all_embedding"))); |
| | | String json = JsonUtils.objectToJsonString(map); |
| | | SearchResult searchResult = JsonUtils.jsonStringToObject(json, SearchResult.class); |
| | | searchResult.setBest_embedding(this.transfer2FloatList(searchResult.getBest_embedding())); |
| | | searchResult.setAll_embedding(this.transfer2FloatList(searchResult.getAll_embedding())); |
| | | return searchResult; |
| | | } else { |
| | | logger.error("调用 {} 结果返回失败:{}", URL_SEARCH_SIMILAR, responseEntity.getStatusCodeValue()); |
| | | return null; |
| | | } |
| | | } catch (Exception ex){ |
| | | logger.error("获取搜索语句向量结果错误:{}" + ex.getMessage(), ex); |
| | | return null; |
| | | } |
| | | } |
| | | |
| | | private boolean acquireEmbedding(List<EventVo> batchData){ |