AI + Milvus:将时尚应用搭建进行到底

本文涉及的产品
阿里云百炼推荐规格 ADB PostgreSQL,4核16GB 100GB 1个月
简介: 如何利用人工智能技术(例如开源 AI 向量数据库 Milvus 和 Hugging Face 模型)寻找与自己穿搭风格相似的明星。

Fashion AI,主要利用微调模型对服装图片进行分割(segmentation),然后裁剪出图像中标注(label)的时尚单品,并将所有图片调整为相同的大小,最后将这些图像转化为 embedding 向量存储在开源向量数据库 Milvus 中。通过这个项目可以在 Milvus 数据库中查询并获得 3 个最相似的向量结果。随后,就可以通过上传一张自己穿着打扮的照片,最终确定与我们时尚风格最为相似的明星。

接下来,我将和大家分享这个项目具体的实现路径。

在正式开始前,可以通过这个链接 https://drive.google.com/file/d/1pBO02iLgToBSCOyMJ58zWHQf4ZRkP5AY/view 获取项目使用到的图片。此外,想要搭建本项目,还需要升级 Python 版本,通过指令 pip install milvus pymilvus torch torchvision matplotli 安装所需软件工具等。本项目使用了 Hugging Face 上由 Mateusz Dziemian 提供的 clothing segmenter 模型 https://huggingface.co/mattmdjaga/segformer_b2_clothes 以及 PyTorch 上由 Nvidia 提供的 ResNet50 模型 https://pytorch.org/hub/nvidia_deeplearningexamples_resnet50/对图像进行分割,将图像转化为 embedding 向量。

01. 图像分割

为了完成图像分割任务,我在 Hugging Face 上找到了以下 3 个模型:

  • Mateusz Dziemian 提供的 segformer_b2_clothes 模型
  • Valentina Feruere 提供的 YOLOS-Fashionpedia 模型
  • Patrick John Chia 提供的 Fashion-CLIP 模型

最终,我选择了 segformer 模型,因为它可以对不同的服装图片进行准确分割,并识别出 18 种“对象”类型。也就是说,这个模型可以检测到图片中的“上衣”、“连衣裙”、“左脚鞋子”、“右脚鞋子”等诸多服装类型。此外,这个模型还可以检测图片中的”脸部”、“头发”、“右腿”、“左腿”等。浏览该链接 https://huggingface.co/mattmdjaga/segformer_b2_clothes/blob/main/config.json#L30了解模型可以识别的全部 18 种对象(object)类型。

开始前,我们首先需要导入本项目中图像处理时所需的工具包,包括:

  • torch用于提取图像特征
  • 来自 transformerssegformer
  • 来自 torchvisionResizemasks_to_boxescrop
import torch
from torch import nn, tensor
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
from torchvision.transforms import Resize
import torchvision.transforms as T
from torchvision.ops import masks_to_boxes
from torchvision.transforms.functional import crop
  • 使用 Hugging Face 生成图像分割掩膜

图像分割方法有很多种,采用哪种方法主要取决于你使用的模型及其检测到的内容。在本项目中,我们使用的模型会返回一个 18 层的图像,每层包含一种检测对象类型,其中包含图像背景。

现在,我们先编写一个函数来生成这个 18 层图像。

get_segmentation 函数需要三个参数:特征提取器(feature extractor)、模型(model)和图像(image)。首先,这个函数会使用图像和提取器生成输入特征(input feature), 然后将模型输出转换为 logits。之后,该函数通过 PyTorch 双线性插值(Bilinear Interpolation)上采样(upsample) logits。最后,该函数仅采取每个像素中的最大预测值,以创建分割掩膜(mask)。

def get_segmentation(extractor, model, image):
    inputs = extractor(images=image, return_tensors="pt")

    outputs = model(**inputs)
    logits = outputs.logits.cpu()

    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )

    pred_seg = upsampled_logits.argmax(dim=1)[0]
    return pred_seg

upsampled_logits 中的图像如下所示:

1.png

pred_seg 图像如下所示。上面两张都是 Andre 3000 的照片,但其实是不同的图像:

1.png

至此,获取分割 mask 的操作就十分简单了。我们获取分割结果中所有的唯一值。根据本项目采用的模型,最多可以获取 18 个值。第一个结果代表的是图像背景,所以可以舍弃这个结果。为了生成 mask,我们提取分割像素中与对象 ID 一致的像素。

以下函数会返回 mask 和 ID,以便可以同时查看二者:

# 返回 2 个 lists masks (tensor) 和obj_ids(int)
# 来自 hugging face 的 "mattmdjaga/segformer_b2_clothes" 模型
def get_masks(segmentation):
    obj_ids = torch.unique(segmentation)
    obj_ids = obj_ids[1:]
    masks = segmentation == obj_ids[:, None, None]
    return masks, obj_ids

函数生成的图像 mask 如下所示。左图为头发 mask,右图为上衣 mask:

1.png

  • 使用 Pytorch 裁剪和调整图像大小

接下来使用 get_masks 函数为图像中每个监测到的对象以及原图生成新图像。随后用 masks_to_boxes 函数将 mask 转化为边界框(bounding box)。此前,我们已经通过 torchvision.ops 导入了这个函数。

接着,创建一系列边界框并将边界框坐标系转为 crop 坐标系。边界框的形式为 (x1, x2, y1, y2)crop 函数期望输入形式为 (top, left, height, width)

在正式裁剪图像前,我们还定义了一个图像预处理函数。将每个图像调整为 256x256 的大小,并转化为 PyTorch tensor (目前是 PIL 图像)。裁剪时,循环遍历裁剪框,并调用 crop 函数。随后我们将预处理完成的图片加入到 dictionary 中,以对应分割 ID 的主键值。函数最后会返回 dictionary。

def crop_images(masks, obj_ids, img):
    boxes = masks_to_boxes(masks)
    crop_boxes = []
    for box in boxes:
        crop_box = tensor([box[0], box[1], box[2]-box[0], box[3]-box[1]])
        crop_boxes.append(crop_box)

    preprocess = T.Compose([
        T.Resize(size=(256, 256)),
        T.ToTensor()
    ])

    cropped_images = {}
    for i in range(len(crop_boxes)):
        crop_box = crop_boxes[i]
        cropped = crop(img, crop_box[1].item(), crop_box[0].item(), crop_box[3].item(), crop_box[2].item())
        cropped_images[obj_ids[i].item()] = preprocess(cropped)
    return cropped_images

下面的示例图中 Drake 穿着鲜橙色的衣服。我们使用裁剪框框处图像中的对象(时尚单品)并为他们各自生成单独的图像:

1.png

02. 将图像数据添加至向量数据库中

图像分割裁剪完成后,我们就可以将其添加至 Milvus 向量数据库中了。为了方便上手,本项目中使用了 Milvus Lite 版本,可以在 notebook 中运行 Milvus 实例。接下来,使用 PyMilvus 连接至 Milvus Lite 提供的默认服务器。

这一步骤中,还需要设置一些常量。定义向量维度、数据量、集合名称、返回的结果个数。随后,运行 ssl 函数来创建上下文,从 PyTorch 获取模型。

from milvus import default_server
from pymilvus import utility, connections
default_server.start()
connections.connect(host="127.0.0.1", port=default_server.listen_port)

DIMENSION = 2048
BATCH_SIZE = 128
COLLECTION_NAME = "fashion"
TOP_K = 3

# 如果遇到 SSL 证书 URL 错误,请在导入 resnet50 模型前运行此内容
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
  • 在向量数据库中定制 Schema 并存储元数据

先定制 Schema。Schema 用于组织向量数据库中存储的数据。id 字段就和 SQL 或者 NoSQL 数据库中的 key ID 一样。Milvus Schema 中的其他字段可以设置 int64、varchar、float 等数据类型。

在本项目中,我们是保存文件路径、明星名字、分割 ID,并将其作为元数据,后续还会考虑添加更多字段,例如边界框、mask 位置等。定义好 FieldSchema、CollectionSchema 后,就可以创建 1 个 Miluvs Collection。

Collection 创建完成后,构建索引。索引参数十分简单。选择 IVF Flat 的索引类型和 L2 相似度类型。这个索引是针对于 Collection 中的 embedding 向量字段。索引构建完成后,将 Collection 加载到内存中,以便后续操作。

from pymilvus import FieldSchema, CollectionSchema, Collection, DataType

fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),
    FieldSchema(name="name", dtype=DataType.VARCHAR, max_length=200),
    FieldSchema(name="seg_id", dtype=DataType.INT64),
    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
]

schema = CollectionSchema(fields=fields)
collection = Collection(name=COLLECTION_NAME, schema=schema)

index_params = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}
collection.create_index(field_name="embedding", index_params=index_params)
collection.load()
  • 从 Nvidia ResNet50 模型获取 embedding 向量

我们需要先从 PyTorch 中加载 Nvidia ResNet50 模型,然后删除最后一层输出层,因为embedding 向量是模型的倒数第二层输出。

# 加载 embedding 模型并删除最后一层输出
embeddings_model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resnet50', pretrained=True)
embeddings_model = torch.nn.Sequential(*(list(embeddings_model.children())[:-1]))
embeddings_model.eval()

以下函数负责接收向量并将数据插入 Milvus。主要有三个参数:数据、集合对象和模型(也就是本项目中使用的 embedding 模型)。为了解插入到数据库中的数据,以下代码中添加了几条打印语句。

除了打印调试数据外,我们还将 data[0] 中的所有值堆叠到一个 tensor 中,然后使用 squeeze 函数从输出中删除维度是 1 的值。随后,插入新的数据列表,其中包括原数据中的最后三条以及由 tensor 输出转化而来的数据列表,这些数据对应文件路径、名称、分割 ID、2048 维向量。

def embed_insert(data, collection, model):
    with torch.no_grad():
        print(len(data[0]))
        print(data[0][0].size())
        output = model(torch.stack(data[0])).squeeze()
        print(type(output))
        print(len(output))
        print(len(output[0]))
        print(output[0])

    collection.insert([data[1], data[2], data[3], output.tolist()])

打印的数据如下图所示:

1.png

每个数据批次的大小为 128,每条数据的大小为 3x256x256。输出是 PyTorch tensor,长度为 128,输出中的每条数据长度为 2048。打印的 tensor 是数据批次中的第一条数据。

  • 将图像数据存储到向量数据库中

还记得前文提到的特征提取器和分割模型吗?接下来轮到它们出场了。我们需要用到 segformer 预训练模型,在循环遍历所有文件路径之后,将所有文件路径放入一个列表中。

extractor = AutoFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
import os
image_paths = []
for celeb in os.listdir("./photos"):
    for image in os.listdir(f"./photos/{celeb}/"):
        # print(image)
        image_paths.append(f"./photos/{celeb}/{image}")

Milvus 期望输入格式为列表。在本项目中,我们使用了 4 个列表,分别对应图像、文件路径、名称和分割 ID。在 embed_insert 函数中,将图像转换为 embedding 向量。然后,循环遍历每个图像文件的文件路径,收集它们的分割 mask 并对其进行裁剪。最后,将图像及元数据添加到数据批处理中。

每 128 张图像作为一批数据,我们将其转化为向量并插入到 Milvus 中,然后清空这批数据。在循环结束时,会 flush 数据完成索引构建。注意,在配备 M1 2021 Mac 和 16GB RAM 的计算机上,运行此过程需要约 8 分钟。

from PIL import Image
data_batch = [[], [], [], []]

for path in image_paths:
    image = Image.open(path)
    path_split = path.split("/")
    name = " ".join(path_split[2].split("_"))
    segmentation = get_segmentation(extractor, model, image)
    masks, ids = get_masks(segmentation)
    cropped_images = crop_images(masks, ids, image)

    for key, image in cropped_images.items():
        data_batch[0].append(image)
        data_batch[1].append(path)
        data_batch[2].append(name)
        data_batch[3].append(key)

    if len(data_batch[0]) % BATCH_SIZE == 0:
        embed_insert(data_batch, collection, embeddings_model)
        data_batch = [[], [], [], []]

if len(data_batch[0]) != 0:
    embed_insert(data_batch, collection, embeddings_model)

collection.flush()

03. 寻找与你时尚风格最相似的明星

上述步骤都完成后,就可以开始玩转这个系统了,它可以根据你上传的图片返回前 3 个与你穿搭风格最相似的明星。

  • 将上传图像转化为向量

首先需要处理上传的图像。以下函数需要两个参数:数据和 (embedding)模型。我们使用模型将图像转化为向量、处理图像,图像转化为列表并返回图片列表。

def embed_search_images(data, model):
    with torch.no_grad():
        print(len(data[0]))
        print(data[0][0].size())
        output = model(torch.stack(data))
        print(type(output))
        print(len(output))
        print(len(output[0]))
        print(output[0])
        if len(output) > 1:
            return output.squeeze().tolist()
        Else:
     return torch.flatten(output, start_dim=1).tolist()

如下图所示,传入本函数的 data 实际上是 data[0] 对象。

1.png

在查询时,我们只需要向量数据,但还是可以保留其他数据字段,就像把数据插入到 Milvus 中一样。

# data_batch[0] is a list of tensors
# data_batch[1] is a list of filepaths to the images (string)
# data_batch[2] is a list of the names of the people in the images (string)
# data_batch[3] is a list of segmentation keys (int)
data_batch = [[], [], [], []]


search_paths = ["./photos/Taylor_Swift/Taylor_Swift_3.jpg", "./photos/Taylor_Swift/Taylor_Swift_8.jpg"]


for path in search_paths:
    image = Image.open(path)
    path_split = path.split("/")
    name = " ".join(path_split[2].split("_"))
    segmentation = get_segmentation(extractor, model, image)
    masks, ids = get_masks(segmentation)
    cropped_images = crop_images(masks, ids, image)
    for key, image in cropped_images.items():
        data_batch[0].append(image)
        data_batch[1].append(path)
        data_batch[2].append(name)
        data_batch[3].append(key)


embeds = embed_search_images(data_batch[0], embeddings_model)
  • 查询向量数据库

将上传图片转化为向量后,便可以开始在向量数据库中查询相似数据了。为了测试,我们添加了 time 模块记录每次查询所需的时间。本项目中测量了查询 23 个 2048 维向量数据所需的时间,如果没有这个需求,可以直接使用 search 函数。

import time
start = time.time()
res = collection.search(embeds,
    anns_field='embedding',
    param={"metric_type": "L2",
    "params": {"nprobe": 10}},
    limit=TOP_K,
    output_fields=['filepath'])
finish = time.time()
print(finish - start)

在循环后,可以看到以下生成的响应。

for index, result in enumerate(res):
    print(index)
    print(result)

1.png

欢迎大家上手操作,期待你们的结果分享!

如果在使用 Milvus 或 Zilliz 产品有任何问题,可添加小助手微信 “zilliz-tech” 加入交流群。

相关实践学习
阿里云百炼xAnalyticDB PostgreSQL构建AIGC应用
通过该实验体验在阿里云百炼中构建企业专属知识库构建及应用全流程。同时体验使用ADB-PG向量检索引擎提供专属安全存储,保障企业数据隐私安全。
AnalyticDB PostgreSQL 企业智能数据中台:一站式管理数据服务资产
企业在数据仓库之上可构建丰富的数据服务用以支持数据应用及业务场景;ADB PG推出全新企业智能数据平台,用以帮助用户一站式的管理企业数据服务资产,包括创建, 管理,探索, 监控等; 助力企业在现有平台之上快速构建起数据服务资产体系
目录
相关文章
|
2天前
|
机器学习/深度学习 人工智能 自然语言处理
当前AI大模型在软件开发中的创新应用与挑战
2024年,AI大模型在软件开发领域的应用正重塑传统流程,从自动化编码、智能协作到代码审查和测试,显著提升了开发效率和代码质量。然而,技术挑战、伦理安全及模型可解释性等问题仍需解决。未来,AI将继续推动软件开发向更高效、智能化方向发展。
|
6天前
|
机器学习/深度学习 人工智能 自然语言处理
AI在医疗领域的应用及其挑战
【10月更文挑战第34天】本文将探讨人工智能(AI)在医疗领域的应用及其面临的挑战。我们将从AI技术的基本概念入手,然后详细介绍其在医疗领域的各种应用,如疾病诊断、药物研发、患者护理等。最后,我们将讨论AI在医疗领域面临的主要挑战,包括数据隐私、算法偏见、法规合规等问题。
19 1
|
8天前
|
存储 XML 人工智能
深度解读AI在数字档案馆中的创新应用:高效识别与智能档案管理
基于OCR技术的纸质档案电子化方案,通过先进的AI能力平台,实现手写、打印、复古文档等多格式高效识别与智能归档。该方案大幅提升了档案管理效率,确保数据安全与隐私,为档案馆提供全面、智能化的电子化管理解决方案。
87 48
|
4天前
|
机器学习/深度学习 人工智能 算法
AI在医疗领域的应用与挑战
本文探讨了人工智能(AI)在医疗领域的应用,包括其在疾病诊断、治疗方案制定、患者管理等方面的优势和潜力。同时,也分析了AI在医疗领域面临的挑战,如数据隐私、伦理问题以及技术局限性等。通过对这些内容的深入分析,旨在为读者提供一个全面了解AI在医疗领域现状和未来发展的视角。
24 10
|
4天前
|
机器学习/深度学习 人工智能 监控
探索AI在医疗领域的应用与挑战
本文深入探讨了人工智能(AI)在医疗领域中的应用现状和面临的挑战。通过分析AI技术如何助力疾病诊断、治疗方案优化、患者管理等方面的创新实践,揭示了AI技术为医疗行业带来的变革潜力。同时,文章也指出了数据隐私、算法透明度、跨学科合作等关键问题,并对未来的发展趋势进行了展望。
|
8天前
|
机器学习/深度学习 人工智能 自然语言处理
当前AI大模型在软件开发中的创新应用与挑战
【10月更文挑战第31天】2024年,AI大模型在软件开发领域的应用取得了显著进展,从自动化代码生成、智能代码审查到智能化测试,极大地提升了开发效率和代码质量。然而,技术挑战、伦理与安全问题以及模型可解释性仍是亟待解决的关键问题。开发者需不断学习和适应,以充分利用AI的优势。
|
8天前
|
人工智能 安全 测试技术
探索AI在软件开发中的应用:提升开发效率与质量
【10月更文挑战第31天】在快速发展的科技时代,人工智能(AI)已成为软件开发领域的重要组成部分。本文探讨了AI在代码生成、缺陷预测、自动化测试、性能优化和CI/CD中的应用,以及这些应用如何提升开发效率和产品质量。同时,文章也讨论了数据隐私、模型可解释性和技术更新等挑战。
|
3天前
|
存储 人工智能 固态存储
如何应对生成式AI和大模型应用带来的存储挑战
如何应对生成式AI和大模型应用带来的存储挑战
|
5天前
|
传感器 人工智能 算法
AI在农业中的应用:精准农业的发展
随着科技的发展,人工智能(AI)在农业领域的应用日益广泛,尤其在精准农业方面取得了显著成效。精准农业通过GPS、GIS、遥感技术和自动化技术,实现对农业生产过程的精确监测和控制,提高产量和品质,降低成本和环境影响。AI在作物生长监测、气候预测、智能农机、农产品品质检测和智能灌溉等方面发挥重要作用,推动农业向智能化、高效化和可持续化方向发展。尽管面临技术集成、数据共享等挑战,但未来前景广阔。
|
11天前
|
机器学习/深度学习 人工智能 自然语言处理
思通数科AI平台在尽职调查中的技术解析与应用
思通数科AI多模态能力平台结合OCR、NLP和深度学习技术,为IPO尽职调查、融资等重要交易环节提供智能化解决方案。平台自动识别、提取并分类海量文档,实现高效数据核验与合规性检查,显著提升审查速度和精准度,同时保障敏感信息管理和数据安全。
56 11

热门文章

最新文章