package com.iplatform.api;
|
|
import com.iplatform.base.SystemController;
|
import com.iplatform.milvus.EventVo;
|
import com.iplatform.milvus.MilvusEngine;
|
import com.iplatform.milvus.ParamList;
|
import com.iplatform.milvus.service.EventServiceImpl;
|
import com.walker.infrastructure.utils.StringUtils;
|
import com.walker.web.ResponseValue;
|
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.http.HttpStatus;
|
import org.springframework.http.ResponseEntity;
|
import org.springframework.web.bind.annotation.RequestMapping;
|
import org.springframework.web.bind.annotation.RestController;
|
import org.springframework.web.client.RestTemplate;
|
|
import java.util.ArrayList;
|
import java.util.Collection;
|
import java.util.LinkedHashMap;
|
import java.util.List;
|
import java.util.concurrent.TimeUnit;
|
|
@RestController
|
@RequestMapping("/test/milvus")
|
public class MilvusChatApi extends SystemController {
|
|
private EventServiceImpl eventService;
|
private RestTemplate restTemplate;
|
private MilvusEngine milvusEngine;
|
|
private Long existId = 4048L;
|
private boolean isBreak = false;
|
|
private static final String URL_EMBEDDING = "http://120.26.128.84:7003/ai/text/embedding";
|
private static final String URL_MILVUS = "120.26.128.84";
|
|
@Autowired
|
public MilvusChatApi(EventServiceImpl eventService, RestTemplate restTemplate){
|
this.eventService = eventService;
|
this.restTemplate = restTemplate;
|
if(this.milvusEngine == null){
|
MilvusEngine engine = new MilvusEngine(URL_MILVUS, 19530);
|
this.milvusEngine = engine;
|
logger.info("milvus engine ok!");
|
}
|
}
|
|
@RequestMapping("/embedding")
|
public ResponseValue testHttpEmbedding(){
|
List<EventVo> data = new ArrayList(8);
|
EventVo eventVo = new EventVo();
|
eventVo.setContent("第一句");
|
data.add(eventVo);
|
|
eventVo = new EventVo();
|
eventVo.setContent("第二句");
|
data.add(eventVo);
|
|
boolean success = this.acquireEmbedding(data);
|
this.milvusEngine.insertTestData();
|
|
return ResponseValue.success("结果是:" + success);
|
}
|
|
@RequestMapping("/write")
|
public ResponseValue testWriteMilvus(){
|
Collection<EventVo> eventVoList = this.acquireEventVoList();
|
if(eventVoList == null){
|
return ResponseValue.error("没有加载到数据");
|
}
|
logger.info("加载了 event vo: {}个", eventVoList.size());
|
|
new Thread(new WriteTask(eventVoList)).start();
|
return ResponseValue.success();
|
}
|
|
@RequestMapping("/query")
|
public ResponseValue testQueryMilvus(String text){
|
|
return ResponseValue.success();
|
}
|
|
private boolean acquireEmbedding(List<EventVo> batchData){
|
ParamList paramList = new ParamList();
|
for(EventVo eventVo : batchData){
|
paramList.add(eventVo.getContent());
|
}
|
try{
|
ResponseEntity<ResponseValue> responseEntity = this.restTemplate.postForEntity(URL_EMBEDDING, paramList, ResponseValue.class);
|
if(responseEntity.getStatusCode() == HttpStatus.OK){
|
List<List<Float>> data = (List<List<Float>>)responseEntity.getBody().getData();
|
// double[] one = null;
|
List<Float> vector = null;
|
for(int i=0; i<data.size(); i++){
|
vector = data.get(i);
|
// one = new double[vector.size()];
|
// for(int j=0; j< vector.size(); j++){
|
// one[j] = vector.get(j);
|
// }
|
|
// logger.debug("data = {}", vector);
|
// logger.debug("class type = {}", vector.getClass().getName());
|
batchData.get(i).setEmbedding(this.transfer2FloatList(vector));
|
}
|
return true;
|
} else {
|
logger.error("http 返回错误:{}", responseEntity.getBody());
|
}
|
return false;
|
} catch (Exception cause){
|
logger.error("获取向量出现错误:{}" + cause.getMessage(), cause);
|
return false;
|
}
|
}
|
|
private List<Float> transfer2FloatList(List<?> list){
|
List<Float> vector = new ArrayList<>(768);
|
for(int i=0; i<list.size(); i++){
|
vector.add(Float.parseFloat(list.get(i).toString()));
|
}
|
return vector;
|
}
|
|
private class WriteTask implements Runnable{
|
|
private Collection<EventVo> eventVoList;
|
|
public WriteTask(Collection<EventVo> eventVoList){
|
this.eventVoList = eventVoList;
|
}
|
|
@Override
|
public void run() {
|
logger.info(".......... start task ...");
|
|
// 复位状态
|
isBreak = false;
|
|
int count = 0;
|
List<EventVo> batchData = new ArrayList<>();
|
for(EventVo eventVo : this.eventVoList){
|
// 已存在上次加载记录,已写入的不在重新处理
|
if(existId != null && eventVo.getId() < existId.longValue()){
|
continue;
|
}
|
|
if(isBreak){
|
break;
|
}
|
|
if(count == 0){
|
logger.info("1) 开始(或继续)采集第 {} 记录", eventVo.getId());
|
}
|
|
if(count >= 8){
|
// 触发一次批量写入
|
logger.info("2) 触发一次调用:{}", batchData.get(7).getId());
|
try {
|
boolean successEmbedding = acquireEmbedding(batchData);
|
if(!successEmbedding){
|
logger.error("获取向量失败,任务结束");
|
break;
|
}
|
|
// 获取向量,写入数据库
|
milvusEngine.insertEventVoList(batchData);
|
|
} catch (Exception ex){
|
logger.error("error = " + ex.getMessage(), ex);
|
// existId = eventVo.getId();
|
isBreak = true;
|
logger.error("3) 采集任务异常, 当前 id = {}", existId == null ? "" : existId);
|
break;
|
|
} finally {
|
try {
|
TimeUnit.SECONDS.sleep(1);
|
} catch (InterruptedException e) {
|
e.printStackTrace();
|
}
|
}
|
|
// 清理本次数据
|
batchData.clear();
|
count = 0;
|
}
|
batchData.add(eventVo);
|
count ++;
|
}
|
logger.info(".......... end task ...");
|
}
|
}
|
|
private Collection<EventVo> acquireEventVoList(){
|
List<EventVo> data = null;
|
// if(this.isBreak){
|
// data = this.eventService.queryEventAll(existId);
|
// } else {
|
// data = this.eventService.queryEventAll(null);
|
// }
|
data = this.eventService.queryEventAll(null);
|
if(StringUtils.isEmptyList(data)){
|
return null;
|
}
|
|
// 过滤掉 content 重复的数据记录
|
LinkedHashMap<String, EventVo> cache = new LinkedHashMap();
|
EventVo temp = null;
|
for(EventVo e : data){
|
temp = cache.get(e.getContent());
|
if(temp == null){
|
cache.put(e.getContent(), e);
|
}
|
}
|
return cache.values();
|
}
|
}
|