package com.walker.embedding.util; import ai.djl.util.Utils; import com.walker.infrastructure.utils.StringUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.nio.file.Path; import java.util.HashMap; import java.util.List; import java.util.Map; public final class WordEncoder { // private static final String EMPTY_VALUE = ""; protected final transient Logger logger = LoggerFactory.getLogger(this.getClass()); // List words = null; private Map wordsCache = new HashMap<>(); private float[][] embeddings = null; public WordEncoder(Path vocab, Path embedding) throws IOException { try (InputStream is = new FileInputStream(new File(vocab.toString()))) { List words = Utils.readLines(is, false); int i = 0; for(String w : words){ this.wordsCache.put(w, i); i++; } } File file = new File(embedding.toString()); INDArray array = Nd4j.readNpy(file); embeddings = array.toFloatMatrix(); } public float[] search(String word) { // for (int i = 0; i < words.size(); i++) { // if (words.get(i).equals(word)) { // return embeddings[i]; // } // } if(StringUtils.isEmpty(word)){ logger.warn("请提供word单词"); return null; } Integer index = this.wordsCache.get(word); if(index == null){ logger.debug("未查询到词向量索引:" + word); return null; } return embeddings[index]; } public float cosineSim(float[] feature1, float[] feature2) { float ret = 0.0f; float mod1 = 0.0f; float mod2 = 0.0f; int length = feature1.length; for (int i = 0; i < length; ++i) { ret += feature1[i] * feature2[i]; mod1 += feature1[i] * feature1[i]; mod2 += feature2[i] * feature2[i]; } return (float) ((ret / Math.sqrt(mod1) / Math.sqrt(mod2) + 1) / 2.0f); } }