package com.walker.embedding;
|
|
import ai.djl.ModelException;
|
import ai.djl.translate.TranslateException;
|
import com.walker.embedding.util.FeatureComparison;
|
import com.walker.embedding.util.WordEncoder;
|
import org.slf4j.Logger;
|
import org.slf4j.LoggerFactory;
|
|
import java.io.IOException;
|
import java.nio.file.Path;
|
import java.nio.file.Paths;
|
import java.util.Arrays;
|
|
public final class WordEncoderExample7 {
|
|
private static final Logger logger = LoggerFactory.getLogger(WordEncoderExample7.class);
|
|
private WordEncoderExample7() {}
|
|
public static void main(String[] args) throws IOException, ModelException, TranslateException {
|
Path vocabPath = Paths.get("d:/dev_tools/ai/w2v_sogou_dim300_vocab.txt");
|
Path embeddingPath = Paths.get("d:/dev_tools/ai/w2v_sogou_dim300.npy");
|
// Path vocabPath = Paths.get("src/test/resources/w2v_sogou_dim300_vocab.txt");
|
// Path embeddingPath = Paths.get("src/test/resources/w2v_sogou_dim300.npy");
|
|
WordEncoder encoder = new WordEncoder(vocabPath, embeddingPath);
|
|
// 获取单词的特征值embedding
|
// float[] embedding1 = encoder.search("搜索");
|
float[] embedding1 = encoder.search("搜索");
|
logger.info("搜索-特征值: " + Arrays.toString(embedding1));
|
// System.out.println("搜索-特征值: " + Arrays.toString(embedding1));
|
float[] embedding2 = encoder.search("检索");
|
logger.info("检索-特征值: " + Arrays.toString(embedding2));
|
// System.out.println("检索-特征值: " + Arrays.toString(embedding2));
|
|
// 计算两个词向量的余弦相似度
|
float cosineSim = FeatureComparison.cosineSim(embedding1, embedding2);
|
logger.info("余弦相似度: "+ Float.toString(cosineSim));
|
System.out.println("余弦相似度: "+ Float.toString(cosineSim));
|
|
// 计算两个词向量的内积
|
float dot = FeatureComparison.dot(embedding1, embedding2);
|
logger.info("内积: "+ Float.toString(dot));
|
System.out.println("内积: "+ Float.toString(dot));
|
}
|
}
|