wujiefeng il y a 1 an
Parent
commit
2a3b6f13b6

+ 1 - 1
centers/AuthCenter/AuthServer/src/main/java/com/github/microservice/auth/server/core/domain/UserFace.java

@@ -24,7 +24,7 @@ public class UserFace extends SuperEntity {
     private String faceFSId;
 
     //人脸向量
-    private List<Float> vector;
+    private List<Double> vector;
 
     private String faceDataKey;
 

+ 1 - 1
centers/AuthCenter/AuthServer/src/main/java/com/github/microservice/auth/server/core/model/ResultsModel.java

@@ -16,7 +16,7 @@ public class ResultsModel {
     @AllArgsConstructor
     @NoArgsConstructor
     public static class Result {
-        private List<Float> embedding;
+        private List<Double> embedding;
         private double faceConfidence;
         private FacialArea facialArea;
 

+ 19 - 18
centers/AuthCenter/AuthServer/src/main/java/com/github/microservice/auth/server/core/service/local/UserFaceServiceImpl.java

@@ -21,6 +21,7 @@ import lombok.Cleanup;
 import lombok.SneakyThrows;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.lang3.StringUtils;
+import org.apache.kafka.common.protocol.types.Field;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Value;
 import org.springframework.data.mongodb.core.query.Criteria;
@@ -118,7 +119,7 @@ public class UserFaceServiceImpl implements UserFaceService {
 
 
         System.out.println(body);
-        List<Float> embedding = resultsModel.getResults().get(0).getEmbedding();
+        List<Double> embedding = resultsModel.getResults().get(0).getEmbedding();
 
 
 
@@ -133,7 +134,7 @@ public class UserFaceServiceImpl implements UserFaceService {
 
 
         // 将 List<Float> 转换为 float[]
-        float[] vector = listToFloatArray(embedding);
+        double[] vector = listToFloatArray(embedding);
         // 归一化向量
         double norm = calculateNorm(vector);
         for (int i = 0; i < vector.length; i++) {
@@ -141,9 +142,9 @@ public class UserFaceServiceImpl implements UserFaceService {
         }
 
         // 转换为字节数组 (FLOAT32,每个值占 4 字节)
-        ByteBuffer buffer = ByteBuffer.allocate(vector.length * 4);
-        for (float value : vector) {
-            buffer.putFloat(value);
+        ByteBuffer buffer = ByteBuffer.allocate(vector.length * Double.BYTES);
+        for (double value : vector) {
+            buffer.putDouble(value);
         }
 
         String key = prefix + user.getId() + System.currentTimeMillis();
@@ -181,7 +182,7 @@ public class UserFaceServiceImpl implements UserFaceService {
                     "local dim = ARGV[1] " +
                     "redis.call('FT.CREATE', indexName, 'ON', 'HASH', 'PREFIX', '1', prefix, " +
                     "'SCHEMA', fieldName, 'VECTOR', 'FLAT', '6', " +
-                    "'TYPE', 'FLOAT32', 'DIM', dim, 'DISTANCE_METRIC', 'COSINE')";
+                    "'TYPE', 'FLOAT64', 'DIM', dim, 'DISTANCE_METRIC', 'COSINE')";
 
 
             String a = """
@@ -197,17 +198,17 @@ public class UserFaceServiceImpl implements UserFaceService {
     }
 
 
-    private static float calculateNorm(float[] vector) {
-        float sum = 0;
-        for (float v : vector) {
+    private static double calculateNorm(double[] vector) {
+        double sum = 0;
+        for (double v : vector) {
             sum += v * v;  // 对每个分量平方求和
         }
-        return (float) Math.sqrt(sum);  // 求平方和的平方根
+        return (double) Math.sqrt(sum);  // 求平方和的平方根
     }
 
-    public static float[] listToFloatArray(List<Float> list) {
+    public static double[] listToFloatArray(List<Double> list) {
         // 创建一个与 List 大小相同的 float 数组
-        float[] array = new float[list.size()];
+        double[] array = new double[list.size()];
 
         // 使用普通的 for 循环将 List 中的每个元素放入数组中
         for (int i = 0; i < list.size(); i++) {
@@ -305,18 +306,18 @@ public class UserFaceServiceImpl implements UserFaceService {
 
 
         System.out.println(body);
-        List<Float> embedding = resultsModel.getResults().get(0).getEmbedding();
-        float[] vector = listToFloatArray(embedding);
+        List<Double> embedding = resultsModel.getResults().get(0).getEmbedding();
+        double[] vector = listToFloatArray(embedding);
 
-        float norm = calculateNorm(vector);
+        double norm = calculateNorm(vector);
         for (int i = 0; i < vector.length; i++) {
             vector[i] /= norm;
         }
 
         // 转换为字节数组 (FLOAT32,每个值占 4 字节)
-        ByteBuffer buffer = ByteBuffer.allocate(vector.length * 4);
-        for (float value : vector) {
-            buffer.putFloat(value);
+        ByteBuffer buffer = ByteBuffer.allocate(vector.length * Double.BYTES);
+        for (double value : vector) {
+            buffer.putDouble(value);
         }
         // 打印字节数组的内容
         byte[] queryVector = buffer.array();