NSFW 图片分类

简介: NSFW指的是 **不适宜工作场所**("Not Safe (or Suitable) For Work;")。在本文中,将介绍如何创建一个检测NSFW图像的图像分类模型。

数据集

由于数据集的性质,我们无法从一些数据集的网站(如Kaggle等)获得所有图像。

但是我们找到了一个专门抓取这种类型图片的github库,所以我们可以直接使用。clone项目后可以运行下面的代码来创建文件夹,并将每个图像下载到其特定的文件夹中。

 folders = ['drawings','hentai','neutral','porn','sexy']
 urls = ['urls_drawings.txt','urls_hentai.txt','urls_neutral.txt','urls_porn.txt','urls_sexy.txt']
 names = ['d','h','n','p','s']

 for i,j,k in zip(folders,urls,names):
     try:
         #Specify the path of the  folder that has to be made
         folder_path = os.path.join('your directory',i)
         os.mkdir(folder_path)
     except:
         pass
     #setup the path of url text file
     url_path = os.path.join('Datasets_Urls',j)
     my_file = open(url_path, "r")
     data = my_file.read()
     #create a list with all urls
     data_into_list = data.split("\n")
     my_file.close()
     icount = 0
     for ii in data_into_list:
         try:
             #create a unique image names for each images
             image_name = 'image'+str(icount)+str(k)+'.png'
             image_path = os.path.join(folder_path,image_name)
             #download it using the library
             urllib.request.urlretrieve(ii, image_path)
             icount+=1
         except Exception as e:
             pass
         #this below code is done to make the count of the image same for all the data 
         #you can use a big number if you are building a more complex model or if you have a good system
         if icount == 2000:
             break

这里的folder变量表示类的名称,urls变量用于获取URL文本文件(可以根据文本文件名更改它),name变量用于为每个图像创建唯一的名称。

上面代码将为每个类下载2000张图像,可以编辑最后一个“if”条件来更改下载图像的个数。

数据准备

我们下载的文件夹可能包含其他类型的文件,所以首先必须删除不需要的类型的文件。

 image_exts = ['jpeg','.jpg','bmp','png']
 path_list = ['drawings','hentai','neutral','porn','sexy']
 cwd = os.getcwd()
 def remove_other_images(path_list):
     for ii in path_list:
         data_dir = os.path.join(cwd,'DataSet',ii)
         for image in os.listdir(os.path.join(data_dir)):
             image_path = os.path.join(data_dir,image_class,image)
             try:
                 img = cv2.imread(image_path)
                 tip = imghdr.what(image_path)
                 if tip not in image_exts:
                     print('Image not in ext list {}'.format(image_path))
                     os.remove(image_path)
             except Exception as e:
                 print("Issue with image {}".format(image_path))
 remove_other_images(path_list)

上面的代码删除了扩展名不是指定格式的图像。

另外图像可能包含许多重复的图像,所以我们必须从每个文件夹中删除重复的图像。

 cwd = os.getcwd()
 path_list = ['drawings','hentai','neutral','porn','sexy']
 def remove_dup_images(path_list):
     for ii in path_list:
         os.chdir(os.path.join(cwd,'DataSet',ii))
         filelist = os.listdir()
         duplicates = []
         hash_keys = dict()
         for index, filename in enumerate(filelist):
             if os.path.isfile(filename):
                 with open(filename,'rb') as f:
                     filehash = hashlib.md5(f.read()).hexdigest()
                 if filehash not in hash_keys:
                     hash_keys[filehash] = index
                 else:
                     duplicates.append((index,hash_keys[filehash]))

         for index in duplicates:
             os.remove(filelist[index[0]])
             print('{} duplicates removed from {}'.format(len(duplicates),ii))
 remove_dup_images(path_list)

这里我们使用hashlib.md5编码来查找每个类中的重复图像。

Md5为每个图像创建一个唯一的哈希值,如果哈希值重复(重复图像),那么我们将重复图片添加到一个列表中,稍后进行删除。

因为使用TensorFlow框架所以需要判断是否被TensorFlow支持,所以我们这里加一个判断:

 import tensorflow as tf

 os.chdir('{data-set} directory')
 cwd = os.getcwd()

 for ii in path_list:
     os.chdir(os.path.join(cwd,ii))
     filelist = os.listdir()
     for image_file in filelist:
         with open(image_file, 'rb') as f:
             image_data = f.read()

         # Check the file format
         _, ext = os.path.splitext(image_file)
         if ext.lower() not in ['.jpg', '.jpeg', '.png', '.gif', '.bmp']:
             print('Unsupported image format:', ext)
             os.remove(os.path.join(cwd,ii,image_file))            
         else:
             # Decode the image
             try:
                 image = tf.image.decode_image(image_data)
             except:
                 print(image_file)
                 print("unspported")
                 os.remove(os.path.join(cwd,ii,image_file))

以上就是数据准备的所有工作,在清理完数据后,我们可以拆分数据。比如分割创建一个训练、验证和测试文件夹,并手动添加文件夹中的图像,我们将80%用于训练,10%用于验证,10%用于测试。

模型

首先导入tensorflow

 import tensorflow as tf
 import os
 import numpy as np
 import matplotlib.pyplot as plt
 from sklearn.utils import shuffle
 import hashlib
 from imageio import imread
 import numpy as np
 from tensorflow.keras.preprocessing.image import ImageDataGenerator
 from tensorflow.keras.applications.vgg16 import VGG16
 from tensorflow.keras.applications.vgg16 import preprocess_input
 from tensorflow.keras.layers import Flatten,Dense,Input
 from tensorflow.keras.models import Model,Sequential
 from keras import optimizers

对于图像,默认大小设置为224,224。

 IMAGE_SIZE = [224,224]

可以使用ImageDataGenerator库,进行数据增强。数据增强也叫数据扩充,是为了增加数据集的大小。ImageDataGenerator根据给定的参数创建新图像,并将其用于训练(注意:当使用ImageDataGenerator时,原始数据将不用于训练)。

 train_datagen = ImageDataGenerator(
         rescale=1./255,
         preprocessing_function=preprocess_input,
         rotation_range=40,
         width_shift_range=0.2,
         height_shift_range=0.2,
         shear_range=0.2,
         zoom_range=0.2,
         horizontal_flip=True,
         fill_mode='nearest')

对于测试集也是这样:

 test_datagen = ImageDataGenerator(rescale=1./255)

为了演示,我们直接使用VGG模型

vgg = VGG16(input_shape=IMAGE_SIZE+[3],weights='imagenet',include_top=False

然后冻结前面的层:

for layer in vgg.layers:
    layer.trainable = False

最后我们加入自己的分类头:

x = Flatten()(vgg.output)
prediction = Dense(5,activation='softmax')(x)
model = Model(inputs=vgg.input, outputs=prediction)
model.summary()

模型是这样的:

训练

看看我们训练集:

train_set = train_datagen.flow_from_directory('DataSet/train',
                                              target_size=(224,224),
                                              batch_size=32,
                                              class_mode='sparse')

验证集

val_set = train_datagen.flow_from_directory('DataSet/validation',
                                              target_size=(224,224),
                                              batch_size=32,
                                              class_mode='sparse')

使用' sparse_categorical_crossentropy '损失,这样可以将标签编码为整数而不是独热编码。

from tensorflow.keras.metrics import MeanSquaredError
from tensorflow.keras.metrics import CategoricalAccuracy
adam = optimizers.Adam()
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=adam,
              metrics=['accuracy',MeanSquaredError(name='val_loss'),CategoricalAccuracy(name='val_accuracy')])

然后就可以训练了:

from datetime import datetime
from keras.callbacks import ModelCheckpoint

log_dir = 'vg_log'

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir = log_dir)

start = datetime.now()

history = model.fit_generator(train_set,
                              validation_data=val_set,
                              epochs=100,
                              steps_per_epoch=len(train_set)// batch_size,
                              validation_steps=len(val_set)//batch_size,
                              callbacks=[tensorboard_callback],
                             verbose=1)

duration = datetime.now() - start
print("Time taken for training is ",duration)

模型训练了100次。得到了80%的验证准确率。f1得分为93%

预测

下面的函数将获取一个图像列表并根据该列表进行预测。

import numpy as np
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
def print_classes(images,model):
    classes = ['Drawing','Hentai','Neutral','Porn','Sexual']
    fig, ax = plt.subplots(ncols=len(images), figsize=(20,20))
    for idx,img in enumerate(images):
        img = mpimg.imread(img)
        resize = tf.image.resize(img,(224,224))
        result = model.predict(np.expand_dims(resize/255,0))
        result = np.argmax(result)
        if classes[result] == 'Porn':
            img = gaussian_filter(img, sigma=6)
        elif classes[result] == 'Sexual':
            img = gaussian_filter(img, sigma=6)
        elif classes[result] == 'Hentai':
            img = gaussian_filter(img, sigma=6)
        ax[idx].imshow(img)
        ax[idx].title.set_text(classes[result])

li = ['test1.jpeg','test2.jpeg','test3.jpeg','test4.jpeg','test5.jpeg']
print_classes(li,model)

看结果还是可以的。

最后,本文的源代码:

https://avoid.overfit.cn/post/8f681841d02e4a8db7bcf77926e123f1

作者:Nikhil Thalappalli

目录
相关文章
|
并行计算 开发工具 C++
无所不谈,百无禁忌,Win11本地部署无内容审查中文大语言模型CausalLM-14B
目前流行的开源大语言模型大抵都会有内容审查机制,这并非是新鲜事,因为之前chat-gpt就曾经被“玩”坏过,如果没有内容审查,恶意用户可能通过精心设计的输入(prompt)来操纵LLM执行不当行为。内容审查可以帮助识别和过滤这些潜在的攻击,确保LLM按照既定的安全策略和道德标准运行。 但我们今天讨论的是无内容审查机制的大模型,在中文领域公开的模型中,能力相对比较强的有阿里的 Qwen-14B 和清华的 ChatGLM3-6B。 而今天的主角,CausalLM-14B则是在Qwen-14B基础上使用了 Qwen-14B 的部分权重,并且加入一些其他的中文数据集,最终炼制了一个无内容审核的
无所不谈,百无禁忌,Win11本地部署无内容审查中文大语言模型CausalLM-14B
|
Java Spring
仿写@DS 多数据源动态切换
最近公司在做项目,用到了多数据源,我在网上找了好多的开源项目。
仿写@DS 多数据源动态切换
|
关系型数据库 数据库 PostgreSQL
PostgreSQL批量删除数据
当需要对一些不需要的历史数据进行大批量删除时, 在使用delete语句时,会发现在删除一些数据时会非常慢 比如 DELETE FROM test where id < 10000; 删除缓慢的原因主要在于外键约束,当数据库在有约束的情况下,无论进行删除或者更新操作, 都会对相关表进行一个校验,判断相关表的相关记录是否被删除或者更新。 这个检查的过程会非常慢, 尤其在外建表又关联着外建表的这种层层嵌套的情况下。
2460 0
|
10月前
|
机器学习/深度学习 存储 数据中心
《深度揭秘:TPU张量计算架构如何重塑深度学习运算》
TPU(张量处理单元)是谷歌为应对深度学习模型计算需求而设计的专用硬件。其核心矩阵乘法单元(MXU)采用脉动阵列架构,显著提升矩阵运算效率;内存管理单元优化数据流通,减少瓶颈;控制单元协调系统运作,确保高效稳定。TPU在训练和推理速度、能耗方面表现出色,大幅缩短BERT等模型的训练时间,降低数据中心成本。尽管通用性和易用性仍有挑战,但TPU已为深度学习带来革命性变化,未来有望进一步优化。
617 19
|
人工智能 数据安全/隐私保护 计算机视觉
GitHub爆款神器 | IOPaint:21.7k star 开源AI图像修复项目,竟能秒删水印、拓展画幅!
IOPaint 是一款由 Sanster 团队开发的开源图像处理工具,集成多种 SOTA AI 模型,支持图像擦除、对象替换、文本绘制和图像外扩等功能。它操作简便,一键安装,适用于 Windows、macOS、Linux 和 Apple Silicon 系统,适合摄影爱好者、电商从业者及内容创作者使用,大幅提升图像处理效率。
584 0
|
Unix 编译器 iOS开发
苹果AppleMacOs系统Sonoma本地部署无内容审查(NSFW)大语言量化模型Causallm
最近Mac系统在运行大语言模型(LLMs)方面的性能已经得到了显著提升,尤其是随着苹果M系列芯片的不断迭代,本次我们在最新的MacOs系统Sonoma中本地部署无内容审查大语言量化模型Causallm。 这里推荐使用koboldcpp项目,它是由c++编写的kobold项目,而MacOS又是典型的Unix操作系统,自带clang编译器,也就是说MacOS操作系统是可以直接编译C语言的。
苹果AppleMacOs系统Sonoma本地部署无内容审查(NSFW)大语言量化模型Causallm
|
JavaScript Java Spring
springboot+vue 实现校园二手商城(毕业设计一)
这篇文章介绍了一个使用Spring Boot和Vue实现的校园二手商城系统的毕业设计,包括用户和商家的功能需求,如登录注册、订单管理、商品评价、联系客服等,以及项目依赖项的安装过程。
springboot+vue 实现校园二手商城(毕业设计一)
|
安全 Java Shell
【内网—内网转发】——http协议代理转发_reGeorg代理转发
【内网—内网转发】——http协议代理转发_reGeorg代理转发
587 3
|
前端开发
CSS中class的样式赋值方法详解
CSS中class的样式赋值方法详解
323 0
|
JavaScript 前端开发 Java
博客管理系统|基于SpringBoot+Vue+ElementUI个人博客系统的设计与实现
博客管理系统|基于SpringBoot+Vue+ElementUI个人博客系统的设计与实现
665 0