package com.walker.embedding; import ai.djl.ModelException; import ai.djl.translate.TranslateException; import com.walker.embedding.util.FeatureComparison; import com.walker.embedding.util.WordEncoder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.nio.file.Path; import java.nio.file.Paths; import java.util.Arrays; public final class WordEncoderExample7 { private static final Logger logger = LoggerFactory.getLogger(WordEncoderExample7.class); private WordEncoderExample7() {} public static void main(String[] args) throws IOException, ModelException, TranslateException { Path vocabPath = Paths.get("d:/dev_tools/ai/w2v_sogou_dim300_vocab.txt"); Path embeddingPath = Paths.get("d:/dev_tools/ai/w2v_sogou_dim300.npy"); // Path vocabPath = Paths.get("src/test/resources/w2v_sogou_dim300_vocab.txt"); // Path embeddingPath = Paths.get("src/test/resources/w2v_sogou_dim300.npy"); WordEncoder encoder = new WordEncoder(vocabPath, embeddingPath); // 获取单词的特征值embedding // float[] embedding1 = encoder.search("搜索"); float[] embedding1 = encoder.search("搜索"); logger.info("搜索-特征值: " + Arrays.toString(embedding1)); // System.out.println("搜索-特征值: " + Arrays.toString(embedding1)); float[] embedding2 = encoder.search("检索"); logger.info("检索-特征值: " + Arrays.toString(embedding2)); // System.out.println("检索-特征值: " + Arrays.toString(embedding2)); // 计算两个词向量的余弦相似度 float cosineSim = FeatureComparison.cosineSim(embedding1, embedding2); logger.info("余弦相似度: "+ Float.toString(cosineSim)); System.out.println("余弦相似度: "+ Float.toString(cosineSim)); // 计算两个词向量的内积 float dot = FeatureComparison.dot(embedding1, embedding2); logger.info("内积: "+ Float.toString(dot)); System.out.println("内积: "+ Float.toString(dot)); } }