AIHIA梦工厂

标题: Insightface记录——验证集制作 [打印本页]

作者: 骑着拖拉机去旅    时间: 2023-9-24 05:44 PM
标题: Insightface记录——验证集制作
一、pairs 文件制作
名单包含两种内容:
标签为1, 是同一个人的两两组合图相对,
不同人的两两组合图像对,标签是0,代表是不同人




看网上很多人写的在生成不同人的图像对时忽略了,同一个人员ID目录下包含多张图的情况,我这里进行了修改,代码如下:

  1. import argparse
  2. import itertools
  3. import os
  4. import random
  5. import time


  6. if __name__ == "__main__":
  7.     parser = argparse.ArgumentParser(description='Generate image pairs.')
  8.     parser.add_argument('--dataset', default='G:/myFaceData/myvaldata/imgs/',
  9.                         type=str, help='Dataset path.')
  10.     parser.add_argument('--output', default='G:/myFaceData/myvaldata/pairs1.txt', type=str, help='Output path.')
  11.     parser.add_argument('--same', default=5000, type=int,
  12.                         help="Number of same pairs for each person.")
  13.     args = parser.parse_args()

  14.     assert args.dataset and args.output, \
  15.         'Dataset and output should be defined.'
  16.     file = open(args.output, 'w+')
  17.     #file.writelines('\n')
  18.     allList = []
  19.     # ---- Generate same pairs for each person
  20.     same = 0
  21.     personList = os.listdir(args.dataset)
  22.     samePairs = []
  23.     for person in personList:
  24.         if not os.path.isdir(os.path.join(args.dataset, person)):
  25.             continue
  26.         IDPath = os.path.join(args.dataset, person)
  27.         subList = []
  28.         images = os.listdir(os.path.join(args.dataset, person))
  29.         for image in images:
  30.             if not image.lower().endswith(('jpg', 'jpeg', 'png')):
  31.                 continue
  32.             imagePath = os.path.join(args.dataset, person, image)
  33.             
  34.             allList.append(imagePath)#所有的路径都存下来
  35.             subList.append(imagePath)
  36.             


  37.         for pair in itertools.combinations(subList, 2):
  38.             samePairs.append(pair)

  39.     choices = random.sample(samePairs, args.same)
  40.     for pair in choices:
  41.         file.writelines(f'{pair[0]} {pair[1]} {1}\n')
  42.         same += 1
  43.         # ---- Generate different pairs
  44.     print(f"Will generate {same} different pairs.")
  45.     count = 0
  46.     sets = set()
  47.     while True:
  48.         if count >= same:
  49.             break
  50.         random.seed(time.time() * 100000 % 10000)
  51.         random.shuffle(allList)
  52.         indexa=random.sample(range(0, len(allList)),1)[0]
  53.         indexb=random.sample(range(0, len(allList)),1)[0]
  54.         IDPath_1 = os.path.split(allList[indexa])
  55.         IDPath_2 = os.path.split(allList[indexb])
  56.         if allList[indexa] != allList[indexb] and IDPath_1[0]!= IDPath_2[0] and not (allList[indexa], allList[indexb]) in sets:
  57.             sets.add((allList[indexa], allList[indexb]))
  58.             sets.add((allList[indexb], allList[indexa]))
  59.             file.writelines(f'{allList[indexa]} {allList[indexb]} {0}\n')
  60.             count += 1
复制代码

二、bin 文件制作

同样网上找到跑不通,自己修改了一些内容。


  1. #coding:utf-8

  2. import mxnet as mx
  3. from mxnet import ndarray as nd
  4. import argparse
  5. import pickle
  6. import sys
  7. import os
  8. import numpy as np
  9. import pdb
  10. import matplotlib.pyplot as plt

  11. def read_pairs(pairs_filename):
  12.     pairs = []
  13.     with open(pairs_filename, 'r') as f:
  14.         for line in f.readlines()[0:]:
  15.             pair = line.strip().split(',')
  16.             pairs.append(pair)
  17.     return np.array(pairs)


  18. def get_paths(pairs, same_pairs):
  19.     nrof_skipped_pairs = 0
  20.     path_list = []
  21.     issame_list = []
  22.     cnt = 1
  23.     for pair in pairs:
  24.         path0 = pair[0].split(' ')[0]
  25.         path1 = pair[0].split(' ')[1]

  26.         if cnt < same_pairs:
  27.             issame = True
  28.         else:
  29.             issame = False
  30.         if os.path.exists(path0) and os.path.exists(path1):  # Only add the pair if both paths exist
  31.             path_list += (path0, path1)
  32.             issame_list.append(issame)
  33.         else:
  34.             print('not exists', path0, path1)
  35.             nrof_skipped_pairs += 1
  36.         cnt += 1
  37.     if nrof_skipped_pairs > 0:
  38.         print('Skipped %d image pairs' % nrof_skipped_pairs)

  39.     return path_list, issame_list


  40. if __name__ == '__main__':
  41.     parser = argparse.ArgumentParser(description='Package  images')
  42.     # general
  43.     parser.add_argument('--data-dir', default='G:/myFaceData/myvaldata/imgs', help='')
  44.     parser.add_argument('--image-size', type=str, default='112,112', help='')
  45.     parser.add_argument('--output', default='G:/myFaceData/myvaldata/myval.bin', help='path to save.')
  46.     parser.add_argument('--txtfile', default='G:/myFaceData/myvaldata/pairs.txt', help='txtfile path.')
  47.     args = parser.parse_args()
  48.     image_size = [int(x) for x in args.image_size.split(',')]
  49.     img_pairs = read_pairs(args.txtfile)
  50.     img_paths, issame_list = get_paths(img_pairs, 5000)   # 这里的15925是相同图像对的个数,需要按照实际产生的相同图像对数量替换
  51.     img_bins = []
  52.     i = 0
  53.     for path in img_paths:
  54.         with open(path, 'rb') as fin:
  55.             _bin = fin.read()
  56.             img_bins.append(_bin)
  57.             i += 1
  58.     with open(args.output, 'wb') as f:
  59.          pickle.dump((img_bins, issame_list), f, protocol=pickle.HIGHEST_PROTOCOL)
复制代码

















欢迎光临 AIHIA梦工厂 (https://aihiamgc.com/) Powered by Discuz! X3.5