package com.walker.support.milvus.engine;
|
|
import com.walker.support.milvus.DataSet;
|
import com.walker.support.milvus.FieldType;
|
import com.walker.support.milvus.OperateService;
|
import com.walker.support.milvus.OutData;
|
import com.walker.support.milvus.Query;
|
import com.walker.support.milvus.Table;
|
import com.walker.support.milvus.util.FieldTypeUtils;
|
import io.milvus.client.MilvusServiceClient;
|
import io.milvus.grpc.MutationResult;
|
import io.milvus.grpc.SearchResults;
|
import io.milvus.param.ConnectParam;
|
import io.milvus.param.IndexType;
|
import io.milvus.param.MetricType;
|
import io.milvus.param.R;
|
import io.milvus.param.RpcStatus;
|
import io.milvus.param.collection.CreateCollectionParam;
|
import io.milvus.param.collection.DropCollectionParam;
|
import io.milvus.param.collection.LoadCollectionParam;
|
import io.milvus.param.collection.ReleaseCollectionParam;
|
import io.milvus.param.dml.InsertParam;
|
import io.milvus.param.dml.SearchParam;
|
import io.milvus.param.index.CreateIndexParam;
|
import io.milvus.param.index.DropIndexParam;
|
import io.milvus.response.SearchResultsWrapper;
|
import org.slf4j.Logger;
|
import org.slf4j.LoggerFactory;
|
|
import java.util.ArrayList;
|
import java.util.List;
|
import java.util.Map;
|
|
public class DefaultOperateService implements OperateService {
|
|
protected final transient Logger logger = LoggerFactory.getLogger(this.getClass());
|
|
private MilvusServiceClient client = null;
|
|
@Override
|
public boolean connect(String ip, int port) {
|
if(client != null){
|
this.client.close();
|
logger.warn("MilvusServiceClient在运行,正在停止,准备创建新对象: " + ip + ", " + port);
|
}
|
try {
|
this.client = new MilvusServiceClient(ConnectParam.newBuilder()
|
.withHost(ip)
|
.withPort(port)
|
.build());
|
return true;
|
} catch (Exception ex){
|
logger.error("创建 MilvusServiceClient 错误:" + ip, ex);
|
return false;
|
}
|
}
|
|
@Override
|
public void close() {
|
if(this.client != null){
|
this.client.close();
|
}
|
}
|
|
@Override
|
public boolean createTable(Table table) {
|
this.checkConnection();
|
if(table == null){
|
logger.error("table 必须提供");
|
return false;
|
}
|
List<FieldType> fieldList = table.getFieldTypes();
|
if(fieldList == null || fieldList.size() == 0){
|
logger.error("未找到任何字段信息");
|
return false;
|
}
|
String tableName = table.getCollectionName();
|
if(tableName == null || tableName.equals("")){
|
logger.error("表名必须提供:tableName");
|
return false;
|
}
|
|
try{
|
List<io.milvus.param.collection.FieldType> milvusFieldList = new ArrayList<>();
|
for(FieldType ft : fieldList){
|
milvusFieldList.add(FieldTypeUtils.toMilvusFieldType(ft, table.getDimension()));
|
}
|
|
CreateCollectionParam.Builder builder = CreateCollectionParam.newBuilder();
|
CreateCollectionParam param = builder.withCollectionName(tableName)
|
.withDescription(table.getDescription())
|
.withShardsNum(table.getShardsNum())
|
.withFieldTypes(milvusFieldList)
|
.build();
|
R<RpcStatus> statusR = this.client.createCollection(param);
|
return this.checkStatusR(statusR);
|
} catch (Exception ex){
|
logger.error("创建向量表失败:" + tableName, ex);
|
return false;
|
}
|
}
|
|
@Override
|
public void dropTable(String tableName){
|
this.checkConnection();
|
if(tableName == null || tableName.equals("")){
|
logger.error("表名必须提供:tableName");
|
return;
|
}
|
this.client.dropCollection(DropCollectionParam.newBuilder()
|
.withCollectionName(tableName)
|
.build());
|
}
|
|
@Override
|
public boolean insertDataSet(DataSet dataSet){
|
this.checkConnection();
|
if(dataSet == null){
|
return false;
|
}
|
String tableName = dataSet.getTableName();
|
if(tableName == null || tableName.equals("")){
|
logger.error("表名必须提供:tableName");
|
return false;
|
}
|
Map<String, List<?>> fields = dataSet.getFields();
|
if(fields == null || fields.size() == 0){
|
logger.error("数据集合必须提供:fields");
|
return false;
|
}
|
|
List<InsertParam.Field> fieldList = new ArrayList<>();
|
for(Map.Entry<String, List<?>> entry : fields.entrySet()){
|
fieldList.add(new InsertParam.Field(entry.getKey(), entry.getValue()));
|
}
|
|
InsertParam.Builder builder = InsertParam.newBuilder();
|
builder.withCollectionName(dataSet.getTableName());
|
if(dataSet.getPartitionName() != null && !dataSet.getPartitionName().equals("")){
|
builder.withPartitionName(dataSet.getPartitionName());
|
}
|
builder.withFields(fieldList);
|
|
InsertParam insertParam = builder.build();
|
R<MutationResult> statusR = this.client.insert(insertParam);
|
if(statusR == null){
|
return false;
|
}
|
if(statusR.getStatus().intValue() == R.Status.Success.getCode()){
|
logger.debug("insert 返回值:" + statusR.getStatus().intValue());
|
return true;
|
}
|
return false;
|
}
|
|
@Override
|
public boolean createIndex(String tableName, String fieldName, String indexType, String indexParam
|
, com.walker.support.milvus.MetricType myMetricType){
|
this.checkConnection();
|
IndexType INDEX_TYPE = null;
|
if(indexType.equals("IVF_FLAT")){
|
INDEX_TYPE = IndexType.IVF_FLAT;
|
} else if(indexType.equals("IVF_SQ8")){
|
INDEX_TYPE = IndexType.IVF_SQ8;
|
} else if(indexType.equals("IVF_PQ")){
|
INDEX_TYPE = IndexType.IVF_PQ;
|
} else if(indexType.equals("HNSW")){
|
INDEX_TYPE = IndexType.HNSW;
|
}
|
else if(indexType.equals("ANNOY")){
|
// INDEX_TYPE = IndexType.ANNOY;
|
throw new UnsupportedOperationException("新版本已不支持:ANNOY");
|
}
|
else if(indexType.equals("FLAT")){
|
INDEX_TYPE = IndexType.FLAT;
|
} else if(indexType.equals("GPU_IVF_FLAT")){
|
INDEX_TYPE = IndexType.GPU_IVF_FLAT;
|
} else if(indexType.equals("GPU_IVF_PQ")){
|
INDEX_TYPE = IndexType.GPU_IVF_PQ;
|
} else if(indexType.equals("SCANN")){
|
INDEX_TYPE = IndexType.SCANN;
|
} else {
|
throw new IllegalArgumentException("暂不支持其他索引类型:" + indexType);
|
}
|
|
/**
|
* **欧氏距离 (L2)**: 主要运用于计算机视觉领域。
|
* **内积 (IP)**: 主要运用于自然语言处理(NLP)领域。
|
* @date 2024-03-26
|
*/
|
CreateIndexParam.Builder builder = CreateIndexParam.newBuilder();
|
builder.withCollectionName(tableName)
|
.withFieldName(fieldName)
|
.withIndexName(fieldName + "_index")
|
.withIndexType(INDEX_TYPE)
|
// .withMetricType(MetricType.L2)
|
.withExtraParam(indexParam)
|
.withSyncMode(false);
|
if(myMetricType == com.walker.support.milvus.MetricType.NLP){
|
builder.withMetricType(MetricType.IP);
|
} else if(myMetricType == com.walker.support.milvus.MetricType.IMAGE){
|
builder.withMetricType(MetricType.L2);
|
} else {
|
throw new UnsupportedOperationException("暂时不支持距离类型:" + myMetricType);
|
}
|
|
R<RpcStatus> statusR = this.client.createIndex(builder.build());
|
return checkStatusR(statusR);
|
}
|
|
@Override
|
public boolean dropIndex(String tableName, String fieldName){
|
this.checkConnection();
|
R<RpcStatus> statusR = this.client.dropIndex(DropIndexParam.newBuilder()
|
.withCollectionName(tableName)
|
.withIndexName(fieldName + "_index")
|
.build());
|
return checkStatusR(statusR);
|
}
|
|
@Override
|
public boolean prepareSearch(String tableName){
|
this.checkConnection();
|
R<RpcStatus> statusR = this.client.loadCollection(LoadCollectionParam.newBuilder().withCollectionName(tableName).build());
|
// if(statusR == null){
|
// return false;
|
// }
|
// if(statusR.getStatus().intValue() == R.Status.Success.getCode()){
|
// return true;
|
// }
|
// return false;
|
return checkStatusR(statusR);
|
}
|
|
@Override
|
public OutData searchVector(Query query){
|
this.checkConnection();
|
List<List<Float>> search_vectors = query.getSearchVectors();
|
if(search_vectors == null){
|
logger.error("未设置搜索向量条件:search_vectors");
|
return null;
|
}
|
String vectorField = query.getVectorName();
|
if(vectorField == null || vectorField.equals("")){
|
logger.error("未设置搜索字段名称:vectorField");
|
return null;
|
}
|
|
List<String> outputFieldList = query.getOutFieldList();
|
if(outputFieldList == null || outputFieldList.size() == 0){
|
logger.error("未设置输出字段名称:OutFieldList");
|
return null;
|
}
|
MetricType metricType = null;
|
if(query.getMetricType() == null || query.getMetricType().equals("")){
|
metricType = MetricType.L2;
|
} else if(query.getMetricType().equals(com.walker.support.milvus.MetricType.INDEX_IMAGE)){
|
metricType = MetricType.L2;
|
} else if(query.getMetricType().equals(com.walker.support.milvus.MetricType.INDEX_NLP)){
|
metricType = MetricType.IP;
|
} else {
|
throw new UnsupportedOperationException("暂未支持的距离类型:" + query.getMetricType());
|
}
|
|
SearchParam searchParam = SearchParam.newBuilder()
|
.withCollectionName(query.getTableName())
|
.withMetricType(metricType)
|
.withOutFields(outputFieldList)
|
.withTopK(query.getTopK())
|
.withVectors(query.getSearchVectors())
|
.withVectorFieldName(query.getVectorName())
|
.withParams(query.getSearchParam())
|
.build();
|
|
R<SearchResults> respSearch = this.client.search(searchParam);
|
if(respSearch == null){
|
logger.warn("未搜索到相似结果对象:" + query);
|
return null;
|
}
|
SearchResultsWrapper wrapperSearch = new SearchResultsWrapper(respSearch.getData().getResults());
|
System.out.println(wrapperSearch.getIDScore(0));
|
|
// 设置一个分值,评分过低的结果直接过滤。2022-08-26
|
OutData outData = new OutData();
|
|
List<SearchResultsWrapper.IDScore> scoreList = wrapperSearch.getIDScore(0);
|
if(scoreList != null && scoreList.size() > 0){
|
for(SearchResultsWrapper.IDScore idScore : scoreList){
|
outData.addIdScore(idScore.getLongID(), idScore.getScore());
|
}
|
}
|
for(String outField : outputFieldList){
|
if(outField.equals("id")){
|
outData.setKeyList((List<Long>)wrapperSearch.getFieldData("id", 0));
|
} else {
|
outData.setBusinessIdList((List<String>)wrapperSearch.getFieldData(outField, 0));
|
}
|
}
|
// System.out.println(wrapperSearch.getFieldData("book_id", 0));
|
// return wrapperSearch.getFieldData(query.getFieldPrimaryKey(), 0);
|
return outData;
|
}
|
|
@Override
|
public void releaseSearch(String tableName){
|
this.checkConnection();
|
this.client.releaseCollection(ReleaseCollectionParam.newBuilder()
|
.withCollectionName(tableName)
|
.build());
|
}
|
|
private void checkConnection(){
|
if(this.client == null){
|
throw new RuntimeException("服务未连接,请先连接 milvus 服务");
|
}
|
}
|
|
private boolean checkStatusR(R<RpcStatus> statusR){
|
if(statusR == null){
|
return false;
|
}
|
if(statusR.getStatus().intValue() == R.Status.Success.getCode()){
|
return true;
|
}
|
return false;
|
}
|
}
|