package com.walker.support.milvus;
|
|
import org.slf4j.Logger;
|
import org.slf4j.LoggerFactory;
|
|
import java.io.Serializable;
|
import java.util.ArrayList;
|
import java.util.HashMap;
|
import java.util.List;
|
import java.util.Map;
|
|
/**
|
* 检索输出结果定义。<br/>
|
* 由于输出列信息需要分开调用,因此使用该对象汇总处理。
|
* @author 时克英
|
* @date 2022-08-26
|
*/
|
public class OutData {
|
|
private final transient Logger logger = LoggerFactory.getLogger(this.getClass());
|
|
private Map<Long, Float> idScoreMap = new HashMap<>();
|
|
private List<Long> keyList = new ArrayList<>();
|
|
private List<String> businessIdList = new ArrayList<>();
|
|
/**
|
* 根据milvus数据库主键返回记录对应的预测分值。
|
* @param key 数据库主键(key)
|
* @return
|
*/
|
public float getScoreByKey(long key){
|
Float score = this.idScoreMap.get(key);
|
if(score == null){
|
System.out.println("key对应分值不存在,key = " + key);
|
return 0;
|
}
|
return score.floatValue();
|
}
|
|
/**
|
* 返回查询结果对象,里面包含了Data是记录相关内容,包含:分值、主键、业务ID。
|
* @return
|
*/
|
public List<Data> getResultList(){
|
if(keyList == null || keyList.size() == 0){
|
logger.error("keyList 为空");
|
return null;
|
}
|
if(businessIdList == null || businessIdList.size() == 0){
|
System.out.println("businessIdList 为空");
|
return null;
|
}
|
if(keyList.size() != businessIdList.size()){
|
logger.error("businessIdList 与 keyList 大小不一致!");
|
return null;
|
}
|
List<Data> resultList = new ArrayList<>();
|
Data d = null;
|
float oneScore = 0;
|
for(int i=0; i<keyList.size(); i++){
|
oneScore = this.getScoreByKey(keyList.get(i));
|
d = new Data(keyList.get(i), businessIdList.get(i), oneScore);
|
resultList.add(d);
|
logger.debug(d.toString());
|
}
|
return resultList;
|
}
|
|
/**
|
* 返回milvus主键对应的分值。
|
* @return
|
*/
|
public Map<Long, Float> getIdScoreMap() {
|
return idScoreMap;
|
}
|
|
public void setIdScoreMap(Map<Long, Float> idScoreMap) {
|
this.idScoreMap = idScoreMap;
|
}
|
|
/**
|
* 返回milvus数据库存储的主键
|
* @return
|
*/
|
public List<Long> getKeyList() {
|
return keyList;
|
}
|
|
public void setKeyList(List<Long> keyList) {
|
this.keyList = keyList;
|
}
|
|
/**
|
* 返回业务编号集合。
|
* @return
|
*/
|
public List<String> getBusinessIdList() {
|
return businessIdList;
|
}
|
|
public void setBusinessIdList(List<String> businessIdList) {
|
this.businessIdList = businessIdList;
|
}
|
|
public void addIdScore(long id, float score){
|
idScoreMap.put(id, score);
|
}
|
|
public static class Data implements Serializable {
|
|
private long key;
|
private String businessId;
|
|
private float score = 0;
|
|
public long getKey() {
|
return key;
|
}
|
|
public String getBusinessId() {
|
return businessId;
|
}
|
|
public float getScore() {
|
return score;
|
}
|
|
public Data(long key, String businessId, float score){
|
this.key = key;
|
this.businessId = businessId;
|
this.score = score;
|
}
|
|
@Override
|
public String toString(){
|
return new StringBuilder("[key=").append(this.key)
|
.append(", businessId=").append(this.businessId)
|
.append(", score=").append(this.score)
|
.append("]").toString();
|
}
|
}
|
}
|