package com.walker.support.milvus.engine; import com.walker.support.milvus.DataSet; import com.walker.support.milvus.FieldType; import com.walker.support.milvus.OperateService; import com.walker.support.milvus.OutData; import com.walker.support.milvus.Query; import com.walker.support.milvus.Table; import com.walker.support.milvus.util.FieldTypeUtils; import io.milvus.client.MilvusServiceClient; import io.milvus.grpc.MutationResult; import io.milvus.grpc.SearchResults; import io.milvus.param.ConnectParam; import io.milvus.param.IndexType; import io.milvus.param.MetricType; import io.milvus.param.R; import io.milvus.param.RpcStatus; import io.milvus.param.collection.CreateCollectionParam; import io.milvus.param.collection.DropCollectionParam; import io.milvus.param.collection.LoadCollectionParam; import io.milvus.param.collection.ReleaseCollectionParam; import io.milvus.param.dml.InsertParam; import io.milvus.param.dml.SearchParam; import io.milvus.param.index.CreateIndexParam; import io.milvus.param.index.DropIndexParam; import io.milvus.response.SearchResultsWrapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.List; import java.util.Map; public class DefaultOperateService implements OperateService { protected final transient Logger logger = LoggerFactory.getLogger(this.getClass()); private MilvusServiceClient client = null; @Override public boolean connect(String ip, int port) { if(client != null){ this.client.close(); logger.warn("MilvusServiceClient在运行,正在停止,准备创建新对象: " + ip + ", " + port); } try { this.client = new MilvusServiceClient(ConnectParam.newBuilder() .withHost(ip) .withPort(port) .build()); return true; } catch (Exception ex){ logger.error("创建 MilvusServiceClient 错误:" + ip, ex); return false; } } @Override public void close() { if(this.client != null){ this.client.close(); } } @Override public boolean createTable(Table table) { this.checkConnection(); if(table == null){ logger.error("table 必须提供"); return false; } List fieldList = table.getFieldTypes(); if(fieldList == null || fieldList.size() == 0){ logger.error("未找到任何字段信息"); return false; } String tableName = table.getCollectionName(); if(tableName == null || tableName.equals("")){ logger.error("表名必须提供:tableName"); return false; } try{ List milvusFieldList = new ArrayList<>(); for(FieldType ft : fieldList){ milvusFieldList.add(FieldTypeUtils.toMilvusFieldType(ft, table.getDimension())); } CreateCollectionParam.Builder builder = CreateCollectionParam.newBuilder(); CreateCollectionParam param = builder.withCollectionName(tableName) .withDescription(table.getDescription()) .withShardsNum(table.getShardsNum()) .withFieldTypes(milvusFieldList) .build(); R statusR = this.client.createCollection(param); return this.checkStatusR(statusR); } catch (Exception ex){ logger.error("创建向量表失败:" + tableName, ex); return false; } } @Override public void dropTable(String tableName){ this.checkConnection(); if(tableName == null || tableName.equals("")){ logger.error("表名必须提供:tableName"); return; } this.client.dropCollection(DropCollectionParam.newBuilder() .withCollectionName(tableName) .build()); } @Override public boolean insertDataSet(DataSet dataSet){ this.checkConnection(); if(dataSet == null){ return false; } String tableName = dataSet.getTableName(); if(tableName == null || tableName.equals("")){ logger.error("表名必须提供:tableName"); return false; } Map> fields = dataSet.getFields(); if(fields == null || fields.size() == 0){ logger.error("数据集合必须提供:fields"); return false; } List fieldList = new ArrayList<>(); for(Map.Entry> entry : fields.entrySet()){ fieldList.add(new InsertParam.Field(entry.getKey(), entry.getValue())); } InsertParam.Builder builder = InsertParam.newBuilder(); builder.withCollectionName(dataSet.getTableName()); if(dataSet.getPartitionName() != null && !dataSet.getPartitionName().equals("")){ builder.withPartitionName(dataSet.getPartitionName()); } builder.withFields(fieldList); InsertParam insertParam = builder.build(); R statusR = this.client.insert(insertParam); if(statusR == null){ return false; } if(statusR.getStatus().intValue() == R.Status.Success.getCode()){ logger.error("insert 返回值:" + statusR.getStatus().intValue()); return true; } return false; } @Override public boolean createIndex(String tableName, String fieldName, String indexType, String indexParam , com.walker.support.milvus.MetricType myMetricType){ this.checkConnection(); IndexType INDEX_TYPE = null; if(indexType.equals("IVF_FLAT")){ INDEX_TYPE = IndexType.IVF_FLAT; } else if(indexType.equals("IVF_SQ8")){ INDEX_TYPE = IndexType.IVF_SQ8; } else if(indexType.equals("IVF_PQ")){ INDEX_TYPE = IndexType.IVF_PQ; } else if(indexType.equals("HNSW")){ INDEX_TYPE = IndexType.HNSW; } else if(indexType.equals("ANNOY")){ // INDEX_TYPE = IndexType.ANNOY; throw new UnsupportedOperationException("新版本已不支持:ANNOY"); } else if(indexType.equals("FLAT")){ INDEX_TYPE = IndexType.FLAT; } else if(indexType.equals("GPU_IVF_FLAT")){ INDEX_TYPE = IndexType.GPU_IVF_FLAT; } else if(indexType.equals("GPU_IVF_PQ")){ INDEX_TYPE = IndexType.GPU_IVF_PQ; } else if(indexType.equals("SCANN")){ INDEX_TYPE = IndexType.SCANN; } else { throw new IllegalArgumentException("暂不支持其他索引类型:" + indexType); } /** * ​**欧氏距离 (L2)**​: 主要运用于计算机视觉领域。 * ​**内积 (IP)**​: 主要运用于自然语言处理(NLP)领域。 * @date 2024-03-26 */ CreateIndexParam.Builder builder = CreateIndexParam.newBuilder(); builder.withCollectionName(tableName) .withFieldName(fieldName) .withIndexName(fieldName + "_index") .withIndexType(INDEX_TYPE) // .withMetricType(MetricType.L2) .withExtraParam(indexParam) .withSyncMode(false); if(myMetricType == com.walker.support.milvus.MetricType.NLP){ builder.withMetricType(MetricType.IP); } else if(myMetricType == com.walker.support.milvus.MetricType.IMAGE){ builder.withMetricType(MetricType.L2); } else { throw new UnsupportedOperationException("暂时不支持距离类型:" + myMetricType); } R statusR = this.client.createIndex(builder.build()); return checkStatusR(statusR); } @Override public boolean dropIndex(String tableName, String fieldName){ this.checkConnection(); R statusR = this.client.dropIndex(DropIndexParam.newBuilder() .withCollectionName(tableName) .withIndexName(fieldName + "_index") .build()); return checkStatusR(statusR); } @Override public boolean prepareSearch(String tableName){ this.checkConnection(); R statusR = this.client.loadCollection(LoadCollectionParam.newBuilder().withCollectionName(tableName).build()); // if(statusR == null){ // return false; // } // if(statusR.getStatus().intValue() == R.Status.Success.getCode()){ // return true; // } // return false; return checkStatusR(statusR); } @Override public OutData searchVector(Query query){ this.checkConnection(); List> search_vectors = query.getSearchVectors(); if(search_vectors == null){ logger.error("未设置搜索向量条件:search_vectors"); return null; } String vectorField = query.getVectorName(); if(vectorField == null || vectorField.equals("")){ logger.error("未设置搜索字段名称:vectorField"); return null; } List outputFieldList = query.getOutFieldList(); if(outputFieldList == null || outputFieldList.size() == 0){ logger.error("未设置输出字段名称:OutFieldList"); return null; } MetricType metricType = null; if(query.getMetricType() == null || query.getMetricType().equals("")){ metricType = MetricType.L2; } SearchParam searchParam = SearchParam.newBuilder() .withCollectionName(query.getTableName()) .withMetricType(metricType) .withOutFields(outputFieldList) .withTopK(query.getTopK()) .withVectors(query.getSearchVectors()) .withVectorFieldName(query.getVectorName()) .withParams(query.getSearchParam()) .build(); R respSearch = this.client.search(searchParam); if(respSearch == null){ logger.warn("未搜索到相似结果对象:" + query); return null; } SearchResultsWrapper wrapperSearch = new SearchResultsWrapper(respSearch.getData().getResults()); System.out.println(wrapperSearch.getIDScore(0)); // 设置一个分值,评分过低的结果直接过滤。2022-08-26 OutData outData = new OutData(); List scoreList = wrapperSearch.getIDScore(0); if(scoreList != null && scoreList.size() > 0){ for(SearchResultsWrapper.IDScore idScore : scoreList){ outData.addIdScore(idScore.getLongID(), idScore.getScore()); } } for(String outField : outputFieldList){ if(outField.equals("id")){ outData.setKeyList((List)wrapperSearch.getFieldData("id", 0)); } else { outData.setBusinessIdList((List)wrapperSearch.getFieldData(outField, 0)); } } // System.out.println(wrapperSearch.getFieldData("book_id", 0)); // return wrapperSearch.getFieldData(query.getFieldPrimaryKey(), 0); return outData; } @Override public void releaseSearch(String tableName){ this.checkConnection(); this.client.releaseCollection(ReleaseCollectionParam.newBuilder() .withCollectionName(tableName) .build()); } private void checkConnection(){ if(this.client == null){ throw new RuntimeException("服务未连接,请先连接 milvus 服务"); } } private boolean checkStatusR(R statusR){ if(statusR == null){ return false; } if(statusR.getStatus().intValue() == R.Status.Success.getCode()){ return true; } return false; } }