package com.walker.embedding.util;
|
|
import ai.djl.util.Utils;
|
import com.walker.infrastructure.utils.StringUtils;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.factory.Nd4j;
|
import org.slf4j.Logger;
|
import org.slf4j.LoggerFactory;
|
|
import java.io.File;
|
import java.io.FileInputStream;
|
import java.io.IOException;
|
import java.io.InputStream;
|
import java.nio.file.Path;
|
import java.util.HashMap;
|
import java.util.List;
|
import java.util.Map;
|
|
public final class WordEncoder {
|
|
// private static final String EMPTY_VALUE = "";
|
|
protected final transient Logger logger = LoggerFactory.getLogger(this.getClass());
|
|
// List<String> words = null;
|
private Map<String, Integer> wordsCache = new HashMap<>();
|
private float[][] embeddings = null;
|
|
public WordEncoder(Path vocab, Path embedding) throws IOException {
|
try (InputStream is = new FileInputStream(new File(vocab.toString()))) {
|
List<String> words = Utils.readLines(is, false);
|
int i = 0;
|
for(String w : words){
|
this.wordsCache.put(w, i);
|
i++;
|
}
|
}
|
|
File file = new File(embedding.toString());
|
INDArray array = Nd4j.readNpy(file);
|
embeddings = array.toFloatMatrix();
|
}
|
|
public float[] search(String word) {
|
// for (int i = 0; i < words.size(); i++) {
|
// if (words.get(i).equals(word)) {
|
// return embeddings[i];
|
// }
|
// }
|
if(StringUtils.isEmpty(word)){
|
logger.warn("请提供word单词");
|
return null;
|
}
|
Integer index = this.wordsCache.get(word);
|
if(index == null){
|
logger.debug("未查询到词向量索引:" + word);
|
return null;
|
}
|
return embeddings[index];
|
}
|
|
public float cosineSim(float[] feature1, float[] feature2) {
|
float ret = 0.0f;
|
float mod1 = 0.0f;
|
float mod2 = 0.0f;
|
int length = feature1.length;
|
for (int i = 0; i < length; ++i) {
|
ret += feature1[i] * feature2[i];
|
mod1 += feature1[i] * feature1[i];
|
mod2 += feature2[i] * feature2[i];
|
}
|
return (float) ((ret / Math.sqrt(mod1) / Math.sqrt(mod2) + 1) / 2.0f);
|
}
|
}
|