/*
 * NOTE(review): this copy of the file appears damaged.
 *  - Line breaks have been collapsed.
 *  - Text between a '<' and the next '>' has been stripped. That removed generic
 *    arguments (e.g. "List> data") and, in two places, whole stretches of code:
 *    after "for(int i=0; i" in acquireEmbedding, and again in transfer2FloatList.
 *  - Because the line breaks are gone, each "//" comment below hides the rest of
 *    its physical line, so this text does not compile as-is.
 * The comments added here describe the code as it reads in SOURCE. Restore the
 * file from version control before relying on them.
 */
package com.iplatform.api;

import com.iplatform.base.SystemController;
import com.iplatform.milvus.EventVo;
import com.iplatform.milvus.MilvusEngine;
import com.iplatform.milvus.ParamList;
import com.iplatform.milvus.service.EventServiceImpl;
import com.walker.infrastructure.utils.StringUtils;
import com.walker.web.ResponseValue;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.client.RestTemplate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.concurrent.TimeUnit;

/**
 * Test-only REST endpoints, mapped under {@code /test/milvus}. They post text
 * content to an HTTP embedding service and write the resulting records to
 * Milvus through {@link MilvusEngine}.
 *
 * <p>NOTE(review): {@code logger} is not declared in this class; presumably it is
 * inherited from {@link SystemController} - confirm.
 */
@RestController
@RequestMapping("/test/milvus")
public class MilvusChatApi extends SystemController {

    /** Source of the EventVo records loaded by acquireEventVoList. */
    private EventServiceImpl eventService;

    /** HTTP client used to POST to {@link #URL_EMBEDDING}. */
    private RestTemplate restTemplate;

    /** Milvus client; created in the constructor. */
    private MilvusEngine milvusEngine;

    /**
     * Resume point for WriteTask. Records whose id is below this value are
     * skipped, treated as already written by an earlier run.
     * NOTE(review): hard-coded value.
     */
    private Long existId = 4048L;

    /**
     * Set to true when a WriteTask batch throws. It is checked on every loop
     * iteration of WriteTask.run and reset at the start of each run.
     * NOTE(review): read and written from background threads but not volatile,
     * so cross-thread visibility is not guaranteed.
     */
    private boolean isBreak = false;

    /** Embedding service endpoint. NOTE(review): hard-coded host; move to configuration. */
    private static final String URL_EMBEDDING = "http://120.26.128.84:7003/ai/text/embedding";

    /** Milvus host; port 19530 is passed in the constructor. NOTE(review): hard-coded host. */
    private static final String URL_MILVUS = "120.26.128.84";

    /**
     * Constructor injection; also builds the Milvus client.
     *
     * @param eventService service used to load EventVo records
     * @param restTemplate HTTP client for the embedding service
     */
    @Autowired
    public MilvusChatApi(EventServiceImpl eventService, RestTemplate restTemplate){
        this.eventService = eventService;
        this.restTemplate = restTemplate;
        // NOTE(review): milvusEngine has no field initializer, so this check is
        // always true inside the constructor. Whether MilvusEngine connects
        // eagerly is not visible from this file.
        if(this.milvusEngine == null){
            MilvusEngine engine = new MilvusEngine(URL_MILVUS, 19530);
            this.milvusEngine = engine;
            logger.info("milvus engine ok!");
        }
    }

    /**
     * Smoke test for the embedding service.
     *
     * <p>Requests embeddings for two fixed sentences ("first sentence" and
     * "second sentence"), then calls milvusEngine.insertTestData().
     *
     * @return a success response whose message is "the result is:" followed by
     *         the boolean returned by acquireEmbedding
     */
    @RequestMapping("/embedding")
    public ResponseValue testHttpEmbedding(){
        List data = new ArrayList(8);
        EventVo eventVo = new EventVo();
        eventVo.setContent("第一句");
        data.add(eventVo);
        eventVo = new EventVo();
        eventVo.setContent("第二句");
        data.add(eventVo);
        boolean success = this.acquireEmbedding(data);
        // NOTE(review): insertTestData() runs even when success is false.
        this.milvusEngine.insertTestData();
        return ResponseValue.success("结果是:" + success);
    }

    /**
     * Loads all EventVo records (deduplicated by content) and starts a background
     * WriteTask that embeds them and writes them to Milvus in batches.
     *
     * <p>Returns immediately; progress is only reported through the log.
     *
     * <p>NOTE(review): every call starts a new, unmanaged Thread, and all tasks
     * share isBreak, so concurrent calls would interfere with each other.
     *
     * @return an error response ("no data loaded") when acquireEventVoList
     *         returns null, otherwise a success response
     */
    @RequestMapping("/write")
    public ResponseValue testWriteMilvus(){
        Collection eventVoList = this.acquireEventVoList();
        if(eventVoList == null){
            return ResponseValue.error("没有加载到数据");
        }
        // Log message reads: "loaded {} event vo".
        logger.info("加载了 event vo: {}个", eventVoList.size());
        new Thread(new WriteTask(eventVoList)).start();
        return ResponseValue.success();
    }

    /**
     * Placeholder for a query endpoint: ignores {@code text} and always returns
     * success.
     *
     * @param text query text (currently unused)
     * @return a success response
     */
    @RequestMapping("/query")
    public ResponseValue testQueryMilvus(String text){
        return ResponseValue.success();
    }

    /**
     * Sends the content of every record in {@code batchData} to the embedding
     * service in a single POST and, on HTTP 200, reads the returned data.
     *
     * <p>Its callers treat a false result as "failed to get vectors".
     *
     * <p>NOTE(review): only the beginning of this method survives in this copy;
     * the loop that consumes {@code data}, the return statements and the catch
     * block are missing. getBody() is dereferenced without a null check.
     *
     * @param batchData records to embed (presumably a List of EventVo; the generic
     *                  argument was lost)
     * @return whether embedding succeeded (exact criteria not visible here)
     */
    private boolean acquireEmbedding(List batchData){
        ParamList paramList = new ParamList();
        for(EventVo eventVo : batchData){
            paramList.add(eventVo.getContent());
        }
        try{
            ResponseEntity responseEntity = this.restTemplate.postForEntity(URL_EMBEDDING, paramList, ResponseValue.class);
            if(responseEntity.getStatusCode() == HttpStatus.OK){
                List> data = (List>)responseEntity.getBody().getData();
                /*
                 * NOTE(review): in this copy, everything from the "//" on the next
                 * line to the end of that line is ONE line comment. It holds the rest
                 * of acquireEmbedding, transfer2FloatList and the inner WriteTask
                 * class, partly lost.
                 *
                 * transfer2FloatList: only "new ArrayList<>(768)" and the loop header
                 * survive. It presumably converts one returned vector to a list of
                 * floats; 768 looks like the expected embedding size - confirm.
                 *
                 * WriteTask holds the collection given by testWriteMilvus. It reads
                 * the outer isBreak, existId, milvusEngine and acquireEmbedding.
                 * run(), as it reads in SOURCE:
                 *  - Resets isBreak (comment: "reset state").
                 *  - Skips records whose id is below existId (comment: "records from
                 *    an earlier load exist; ones already written are not reprocessed").
                 *    NOTE(review): throws if getId() returns a null Long - confirm
                 *    getId's type.
                 *  - Stops as soon as isBreak is set.
                 *  - Buffers records. When 8 are buffered and another record arrives
                 *    (comment: "trigger one batch write"), it calls acquireEmbedding.
                 *  - Stops if acquireEmbedding returns false; otherwise calls
                 *    milvusEngine.insertEventVoList (comment: "get vectors, write to
                 *    the database").
                 *  - On any exception it logs, sets isBreak and stops.
                 *  - A finally block (continued on the following line) sleeps one
                 *    second after every batch attempt.
                 * Log messages, in order:
                 *  - "start (or resume) collecting record {}"
                 *  - "trigger one call: {}"
                 *  - "failed to get vectors, task ends"
                 *  - "collection task failed, current id = {}"
                 */
                // double[] one = null; List vector = null; for(int i=0; i transfer2FloatList(List list){ List vector = new ArrayList<>(768); for(int i=0; i eventVoList; public WriteTask(Collection eventVoList){ this.eventVoList = eventVoList; } @Override public void run() { logger.info(".......... start task ..."); // 复位状态 isBreak = false; int count = 0; List batchData = new ArrayList<>(); for(EventVo eventVo : this.eventVoList){ // 已存在上次加载记录,已写入的不在重新处理 if(existId != null && eventVo.getId() < existId.longValue()){ continue; } if(isBreak){ break; } if(count == 0){ logger.info("1) 开始(或继续)采集第 {} 记录", eventVo.getId()); } if(count >= 8){ // 触发一次批量写入 logger.info("2) 触发一次调用:{}", batchData.get(7).getId()); try { boolean successEmbedding = acquireEmbedding(batchData); if(!successEmbedding){ logger.error("获取向量失败,任务结束"); break; } // 获取向量,写入数据库 milvusEngine.insertEventVoList(batchData); } catch (Exception ex){ logger.error("error = " + ex.getMessage(), ex); // existId = eventVo.getId(); isBreak = true; logger.error("3) 采集任务异常, 当前 id = {}", existId == null ? 
/*
 * NOTE(review): the next line continues the "3) ..." error log and the finally
 * block of WriteTask.run. The InterruptedException is only printed with
 * printStackTrace. The thread's interrupt status should be restored with
 * Thread.currentThread().interrupt(), and the message should go to logger.
 */
"" : existId); break; } finally { try { TimeUnit.SECONDS.sleep(1); } catch (InterruptedException e) { e.printStackTrace(); } }
/*
 * NOTE(review): in this copy the "//" line below is likewise ONE line comment.
 * It holds the end of WriteTask.run and the whole of acquireEventVoList.
 *
 * End of WriteTask.run:
 *  - After a batch attempt the buffer is cleared (comment: "clear this batch's
 *    data") and count is reset.
 *  - The current record is then buffered.
 *  - BUG(review): records still buffered when the loop ends are never embedded
 *    or written. This is the final batch, 1 to 8 records, because a batch is
 *    flushed only when the next record arrives. With 8 or fewer eligible records
 *    nothing is written at all. A flush of any non-empty batchData after the
 *    loop is needed.
 *
 * acquireEventVoList:
 *  - Loads eventService.queryEventAll(null).
 *  - Returns null when StringUtils.isEmptyList reports the result empty
 *    (presumably null or empty).
 *  - Otherwise keeps the first record for each distinct content, in insertion
 *    order (comment: "filter out records with duplicate content"), and returns
 *    the LinkedHashMap's values() view.
 *  - The commented-out branch would have resumed from existId after a break.
 *  - NOTE(review): returning null rather than an empty collection forces the
 *    null check in testWriteMilvus. cache.get + put could be putIfAbsent.
 */
// 清理本次数据 batchData.clear(); count = 0; } batchData.add(eventVo); count ++; } logger.info(".......... end task ..."); } } private Collection acquireEventVoList(){ List data = null; // if(this.isBreak){ // data = this.eventService.queryEventAll(existId); // } else { // data = this.eventService.queryEventAll(null); // } data = this.eventService.queryEventAll(null); if(StringUtils.isEmptyList(data)){ return null; } // 过滤掉 content 重复的数据记录 LinkedHashMap cache = new LinkedHashMap(); EventVo temp = null; for(EventVo e : data){ temp = cache.get(e.getContent()); if(temp == null){ cache.put(e.getContent(), e); } } return cache.values(); } }