package com.iplatform.milvus;
import com.walker.infrastructure.utils.StringUtils;
import com.walker.support.milvus.DataSet;
import com.walker.support.milvus.DataType;
import com.walker.support.milvus.FieldType;
import com.walker.support.milvus.MetricType;
import com.walker.support.milvus.OperateService;
import com.walker.support.milvus.OutData;
import com.walker.support.milvus.Query;
import com.walker.support.milvus.Table;
import com.walker.support.milvus.engine.DefaultOperateService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class MilvusEngine {
protected final transient Logger logger = LoggerFactory.getLogger(this.getClass());
public static final String TABLE_CHAT_SIMILAR = "chat_similar";
// private static final int VECTOR_DIMENSION = 768;
private static final int VECTOR_DIMENSION = 512;
public MilvusEngine(String ip, int port){
DefaultOperateService service = new DefaultOperateService();
service.connect(ip, port);
this.operateService = service;
logger.info("connect milvus: {}:{}", ip, port);
}
public void close(){
if(this.operateService != null){
this.operateService.close();
}
}
/**
* 创建表:测试从聊天一键提取工单内容使用。
*
* 1) 从历史工单数据中,收集用户提问内容,整理到表中
* 2) 把这些数据通过向量转化,写入milvus数据库。
*
* @date 2024-03-28
*/
public void createChatSimilarTable(){
Table chatSimilarTable = new Table();
chatSimilarTable.setCollectionName(TABLE_CHAT_SIMILAR);
chatSimilarTable.setDescription("聊天提取工单摘要历史数据");
chatSimilarTable.setShardsNum(1);
// chatSimilarTable.setDimension(768); // 这个是根据使用向量模型维度定的
chatSimilarTable.setDimension(VECTOR_DIMENSION); // 这个是根据使用向量模型维度定的
// 设置字段
FieldType id = FieldType.newBuilder()
.withName("id").withPrimaryKey(true).withMaxLength(18).withDataType(DataType.Long).build();
FieldType title = FieldType.newBuilder()
.withName("title").withPrimaryKey(false).withMaxLength(180).withDataType(DataType.VarChar).build();
FieldType content = FieldType.newBuilder()
.withName("content").withPrimaryKey(false).withMaxLength(255).withDataType(DataType.VarChar).build();
FieldType answer = FieldType.newBuilder()
.withName("answer").withPrimaryKey(false).withMaxLength(255).withDataType(DataType.VarChar).build();
FieldType embedding = FieldType.newBuilder()
.withName("embedding").withPrimaryKey(false).withDataType(DataType.FloatVector).withDimension(VECTOR_DIMENSION).build();
List fieldTypeList = new ArrayList<>(8);
fieldTypeList.add(id);
fieldTypeList.add(title);
fieldTypeList.add(content);
fieldTypeList.add(answer);
fieldTypeList.add(embedding);
chatSimilarTable.setFieldTypes(fieldTypeList);
this.operateService.createTable(chatSimilarTable);
logger.info("创建了 table = {}", chatSimilarTable.getCollectionName());
// 创建索引
this.operateService.createIndex(chatSimilarTable.getCollectionName()
, "embedding", "HNSW", "{\"nlist\":16384, \"efConstruction\":128, \"M\":8}", MetricType.NLP);
logger.info("创建了 index = {}", chatSimilarTable.getCollectionName() + "_index");
}
public void dropChatSimilarTable(){
this.operateService.dropTable("chat_similar");
this.operateService.dropIndex("chat_similar", "chat_similar_index");
}
public void insertTestData(){
DataSet dataSet = new DataSet();
dataSet.setTableName(TABLE_CHAT_SIMILAR);
List> vectorList = new ArrayList<>();
vectorList.add(Arrays.asList(mockVector));
vectorList.add(Arrays.asList(mockVector));
Map> fieldMap = new HashMap();
fieldMap.put("id", Arrays.asList(new Long[]{1L, 2L}));
fieldMap.put("title", Arrays.asList(new String[]{"第一个标题", "第二个标题"}));
fieldMap.put("content", Arrays.asList(new String[]{"第一个内容", "2222"}));
fieldMap.put("answer", Arrays.asList(new String[]{"第一个答案", "22222222"}));
fieldMap.put("embedding", vectorList);
dataSet.setFields(fieldMap);
this.operateService.insertDataSet(dataSet);
logger.info("写入了测试数据: {}", dataSet);
}
public void insertEventVoList(List batchData){
if(StringUtils.isEmptyList(batchData)){
return;
}
List ids = new ArrayList<>(8);
List titles = new ArrayList<>(8);
List contents = new ArrayList<>(8);
List answers = new ArrayList<>(8);
List> vectorSet = new ArrayList<>(8);
for(EventVo vo : batchData){
ids.add(vo.getId());
vectorSet.add(vo.getEmbedding());
if(StringUtils.isNotEmpty(vo.getTitle())){
titles.add(vo.getTitle());
} else {
titles.add("none");
}
if(StringUtils.isNotEmpty(vo.getAnswer())){
answers.add(vo.getAnswer());
} else {
answers.add("none");
}
if(vo.getContent().length() > 200){
contents.add(vo.getContent().substring(0, 200));
} else {
contents.add(vo.getContent());
}
}
DataSet dataSet = new DataSet();
dataSet.setTableName(TABLE_CHAT_SIMILAR);
Map> fieldMap = new HashMap();
fieldMap.put("id", ids);
fieldMap.put("title", titles);
fieldMap.put("content", contents);
fieldMap.put("answer", answers);
fieldMap.put("embedding", vectorSet);
dataSet.setFields(fieldMap);
this.operateService.insertDataSet(dataSet);
logger.info("写入了: {}", ids);
}
public OutData searchChatSimilar(List> vectors){
Query query = new Query();
query.setMetricType(MetricType.NLP.getIndex());
query.setTableName(TABLE_CHAT_SIMILAR);
query.setTopK(4);
query.setVectorName("embedding");
query.setOutFieldList(Arrays.asList(new String[]{"id","title","content"}));
query.setFieldPrimaryKey("id");
query.setSearchVectors(vectors);
return this.operateService.searchVector(query);
}
/**
* 必须在查询之前,加载数据到内存中。
* @date 2024-03-31
*/
public void loadChatSimilar4Search(){
this.operateService.prepareSearch(TABLE_CHAT_SIMILAR);
}
private OperateService operateService;
// private Double[] mockVector = new Double[]{-0.051114246249198914, 0.889954432};
private Float[] mockVector = new Float[]{-0.051114246249198914f, 0.889954432f};
}