|
|
@@ -21,6 +21,7 @@ import lombok.Cleanup;
|
|
|
import lombok.SneakyThrows;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import org.apache.commons.lang3.StringUtils;
|
|
|
+import org.apache.kafka.common.protocol.types.Field;
|
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
|
import org.springframework.beans.factory.annotation.Value;
|
|
|
import org.springframework.data.mongodb.core.query.Criteria;
|
|
|
@@ -118,7 +119,7 @@ public class UserFaceServiceImpl implements UserFaceService {
|
|
|
|
|
|
|
|
|
System.out.println(body);
|
|
|
- List<Float> embedding = resultsModel.getResults().get(0).getEmbedding();
|
|
|
+ List<Double> embedding = resultsModel.getResults().get(0).getEmbedding();
|
|
|
|
|
|
|
|
|
|
|
|
@@ -133,7 +134,7 @@ public class UserFaceServiceImpl implements UserFaceService {
|
|
|
|
|
|
|
|
|
// 将 List<Float> 转换为 float[]
|
|
|
- float[] vector = listToFloatArray(embedding);
|
|
|
+ double[] vector = listToFloatArray(embedding);
|
|
|
// 归一化向量
|
|
|
double norm = calculateNorm(vector);
|
|
|
for (int i = 0; i < vector.length; i++) {
|
|
|
@@ -141,9 +142,9 @@ public class UserFaceServiceImpl implements UserFaceService {
|
|
|
}
|
|
|
|
|
|
// 转换为字节数组 (FLOAT32,每个值占 4 字节)
|
|
|
- ByteBuffer buffer = ByteBuffer.allocate(vector.length * 4);
|
|
|
- for (float value : vector) {
|
|
|
- buffer.putFloat(value);
|
|
|
+ ByteBuffer buffer = ByteBuffer.allocate(vector.length * Double.BYTES);
|
|
|
+ for (double value : vector) {
|
|
|
+ buffer.putDouble(value);
|
|
|
}
|
|
|
|
|
|
String key = prefix + user.getId() + System.currentTimeMillis();
|
|
|
@@ -181,7 +182,7 @@ public class UserFaceServiceImpl implements UserFaceService {
|
|
|
"local dim = ARGV[1] " +
|
|
|
"redis.call('FT.CREATE', indexName, 'ON', 'HASH', 'PREFIX', '1', prefix, " +
|
|
|
"'SCHEMA', fieldName, 'VECTOR', 'FLAT', '6', " +
|
|
|
- "'TYPE', 'FLOAT32', 'DIM', dim, 'DISTANCE_METRIC', 'COSINE')";
|
|
|
+ "'TYPE', 'FLOAT64', 'DIM', dim, 'DISTANCE_METRIC', 'COSINE')";
|
|
|
|
|
|
|
|
|
String a = """
|
|
|
@@ -197,17 +198,17 @@ public class UserFaceServiceImpl implements UserFaceService {
|
|
|
}
|
|
|
|
|
|
|
|
|
- private static float calculateNorm(float[] vector) {
|
|
|
- float sum = 0;
|
|
|
- for (float v : vector) {
|
|
|
+ private static double calculateNorm(double[] vector) {
|
|
|
+ double sum = 0;
|
|
|
+ for (double v : vector) {
|
|
|
sum += v * v; // 对每个分量平方求和
|
|
|
}
|
|
|
- return (float) Math.sqrt(sum); // 求平方和的平方根
|
|
|
+ return (double) Math.sqrt(sum); // 求平方和的平方根
|
|
|
}
|
|
|
|
|
|
- public static float[] listToFloatArray(List<Float> list) {
|
|
|
+ public static double[] listToFloatArray(List<Double> list) {
|
|
|
// 创建一个与 List 大小相同的 float 数组
|
|
|
- float[] array = new float[list.size()];
|
|
|
+ double[] array = new double[list.size()];
|
|
|
|
|
|
// 使用普通的 for 循环将 List 中的每个元素放入数组中
|
|
|
for (int i = 0; i < list.size(); i++) {
|
|
|
@@ -305,18 +306,18 @@ public class UserFaceServiceImpl implements UserFaceService {
|
|
|
|
|
|
|
|
|
System.out.println(body);
|
|
|
- List<Float> embedding = resultsModel.getResults().get(0).getEmbedding();
|
|
|
- float[] vector = listToFloatArray(embedding);
|
|
|
+ List<Double> embedding = resultsModel.getResults().get(0).getEmbedding();
|
|
|
+ double[] vector = listToFloatArray(embedding);
|
|
|
|
|
|
- float norm = calculateNorm(vector);
|
|
|
+ double norm = calculateNorm(vector);
|
|
|
for (int i = 0; i < vector.length; i++) {
|
|
|
vector[i] /= norm;
|
|
|
}
|
|
|
|
|
|
// 转换为字节数组 (FLOAT32,每个值占 4 字节)
|
|
|
- ByteBuffer buffer = ByteBuffer.allocate(vector.length * 4);
|
|
|
- for (float value : vector) {
|
|
|
- buffer.putFloat(value);
|
|
|
+ ByteBuffer buffer = ByteBuffer.allocate(vector.length * Double.BYTES);
|
|
|
+ for (double value : vector) {
|
|
|
+ buffer.putDouble(value);
|
|
|
}
|
|
|
// 打印字节数组的内容
|
|
|
byte[] queryVector = buffer.array();
|