package com.walker.embedding;
|
|
import com.walker.embedding.util.WordEncoder;
|
import com.walker.infrastructure.ApplicationRuntimeException;
|
import com.walker.infrastructure.utils.StringUtils;
|
import org.slf4j.Logger;
|
import org.slf4j.LoggerFactory;
|
|
import java.io.IOException;
|
import java.nio.file.Path;
|
import java.nio.file.Paths;
|
|
public class DefaultVectorGenerator implements VectorGenerator{
|
|
protected final transient Logger logger = LoggerFactory.getLogger(this.getClass());
|
|
private WordEncoder wordEncoder = null;
|
|
@Override
|
public void initLoadDict(String wordFile, String embeddingFile) {
|
if(StringUtils.isEmpty(wordFile) || StringUtils.isEmpty(embeddingFile)){
|
logger.warn("VectorGenerator初始化失败:请提供词库与对应的向量维度模型文件!");
|
return;
|
}
|
Path vocabPath = Paths.get(wordFile);
|
Path embeddingPath = Paths.get(embeddingFile);
|
try {
|
this.wordEncoder = new WordEncoder(vocabPath, embeddingPath);
|
} catch (IOException e) {
|
throw new ApplicationRuntimeException("VectorGenerator启动异常:" + e.getMessage(), e);
|
}
|
}
|
|
@Override
|
public float[] getWordEmbedding(String word) {
|
return this.wordEncoder.search(word);
|
}
|
|
@Override
|
public float cosineSim(float[] feature1, float[] feature2) {
|
return this.wordEncoder.cosineSim(feature1, feature2);
|
}
|
}
|