1 训练营课程链接
实战训练营的课程:https://mp.weixin.qq.com/s/3WrTMItNAGt8l2kjjf042w。
- 学习目的
基于车辆检测+AI安全+分类模型的模式,将攻击与防御注入到检测任务与分类任务的级联点中,完成AI项目的对抗攻防安全功能。
- 代码实现
整体流程:检测->截取检测目标的小图->送入对抗攻击监测模块->如有问题发送喵提醒
# aidlux相关
from cvs import *
import aidlite_gpu
from utils import detect_postprocess, preprocess_img, draw_detect_res, extract_detect_res
import time
import cv2,os
import numpy as np
import torch.nn as nn
import requests
import torch
from timm.models import create_model
from advertorch.utils import NormalizeByChannelMeanStd
from advertorch_examples.utils import bhwc2bchw
from advertorch_examples.utils import bchw2bhwc
### 对抗攻击监测模型
class Detect_Model(nn.Module):
def __init__(self, num_classes=2):
super(Detect_Model, self).__init__()
self.num_classes = num_classes
#model = create_model('mobilenetv3_large_075', pretrained=False, num_classes=num_classes)
model = create_model('resnet50', pretrained=False, num_classes=num_classes)
# self.multi_PreProcess = multi_PreProcess()
pth_path = os.path.join("/home/Lesson5_code/model", 'track2_resnet50_ANT_best_albation1_64_checkpoint.pth')
#pth_path = os.path.join("/Users/rocky/Desktop/训练营/Lesson5_code/model/", "track2_tf_mobilenetv3_large_075_64_checkpoint.pth")
state_dict = torch.load(pth_path, map_location='cpu')
is_strict = False
if 'model' in state_dict.keys():
model.load_state_dict(state_dict['model'], strict=is_strict)
else:
model.load_state_dict(state_dict, strict=is_strict)
normalize = NormalizeByChannelMeanStd(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# self.model = nn.Sequential(normalize, self.multi_PreProcess, model)
self.model = nn.Sequential(normalize, model)
def load_params(self):
pass
def forward(self, x):
# x = x[:,:,32:193,32:193]
# x = F.interpolate(x, size=(224,224), mode="bilinear", align_corners=True)
# x = self.multi_PreProcess.forward(x)
out = self.model(x)
if self.num_classes == 2:
out = out.softmax(1)
#return out[:,1:]
return out[:,1:]
device = "cuda" if torch.cuda.is_available() else "cpu"
detect_model = Detect_Model().eval().to(device)
# AidLite初始化:调用AidLite进行AI模型的加载与推理,需导入aidlite
aidlite = aidlite_gpu.aidlite()
# Aidlite模型路径
model_path = '/home/Lesson5_code/yolov5_code/models/yolov5_car_best-fp16.tflite'
# 定义输入输出shape
in_shape = [1 * 640 * 640 * 3 * 4]
out_shape = [1 * 25200 * 6 * 4]
# 加载Aidlite检测模型:支持tflite, tnn, mnn, ms, nb格式的模型加载
aidlite.ANNModel(model_path, in_shape, out_shape, 4, 0)
# 读取图片进行推理
# 设置测试集路径
source = "/home/Lesson5_code/yolov5_code/data/images/tests"
images_list = os.listdir(source)
print(images_list)
frame_id = 0
# 读取数据集
for image_name in images_list:
frame_id += 1
print("frame_id:", frame_id)
image_path = os.path.join(source, image_name)
frame = cvs.imread(image_path)
# 预处理
img = preprocess_img(frame, target_shape=(640, 640), div_num=255, means=None, stds=None)
# 数据转换:因为setTensor_Fp32()需要的是float32类型的数据,所以送入的input的数据需为float32,大多数的开发者都会忘记将图像的数据类型转换为float32
aidlite.setInput_Float32(img, 640, 640)
# 模型推理API
aidlite.invoke()
# 读取返回的结果
pred = aidlite.getOutput_Float32(0)
# 数据维度转换
pred = pred.reshape(1, 25200, 6)[0]
# 模型推理后处理
pred = detect_postprocess(pred, frame.shape, [640, 640, 3], conf_thres=0.25, iou_thres=0.45)
# 绘制推理结果
res_img = draw_detect_res(frame, pred)
# cvs.imshow(res_img)
# 测试结果展示停顿
#time.sleep(5)
# 图片裁剪,提取车辆目标区域
# extract_detect_res(frame, pred, image_name)
'''
检测结果提取
'''
img, all_boxes, image_name = frame, pred, image_name
img = img.astype(np.uint8)
color_step = int(255/len(all_boxes))
for bi in range(len(all_boxes)):
if len(all_boxes[bi]) == 0:
continue
count = 0
for box in all_boxes[bi]:
x, y, w, h = [int(t) for t in box[:4]]
#cv2.putText(img, f'{coco_class[bi]}', (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
#cv2.rectangle(img, (x,y), (x+w, y+h),(0, bi*color_step, 255-bi*color_step),thickness = 2)
cut_img = img[y:(y+h), x:(x + w)]
cv2.resize(cut_img,(80,177))
img = torch.tensor(bhwc2bchw(cut_img))[None, :, :, :].float().to(device)
### 对抗攻击监测
detect_pred = detect_model(img)
print(detect_pred)
if detect_pred > 0.5:
id = 'tGinrX9'
# 填写喵提醒中,发送的消息,这里放上前面提到的图片外链
text = "出现对抗攻击风险!!"
ts = str(time.time()) # 时间戳
type = 'json' # 返回内容格式
request_url = "http://miaotixing.com/trigger?"
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.67 Safari/537.36 Edg/87.0.664.47'}
result = requests.post(request_url + "id=" + id + "&text=" + text + "&ts=" + ts + "&type=" + type,
headers=headers)
# cv2.imwrite("/home/Lesson5_code/yolov5_code/aidlux/extract_results/" + image_name + "_" + str(count) + ".jpg",cut_img)
count += 1
实现视频:
https://zhuanlan.zhihu.com/p/589784525
- 总结
加深了对AidLux的认识,同时学习了对抗攻击等知识。