[Initial Image Segmentation Generator]论文实现:Efficient Graph-Based Image Segmentation

简介: [Initial Image Segmentation Generator]论文实现:Efficient Graph-Based Image Segmentation

一、完整代码

作者在文章开头地址中使用C++实现了这一过程,为了便于python学者理解,这里我们使用python代码进行实现

1.1 Python 完整程序

import cv2
import matplotlib.pyplot as plt
import numpy as np
class UnionFindSet:
    """
    创建一个交并集
    """
    def __init__(self, datalist):
        self.father = {}
        self.size = {}
        for data in datalist:
            self.father[data] = data
            self.size[data] = 1
    def find(self, node):
        father = self.father[node]
        if father != node:
            if father != self.father[father]:
                self.size[father] -= 1
            father = self.find(father)
        self.father[node] = father
        return father
    def union(self, node_a, node_b):
        if node_a is None or node_b is None:
            return
        a_head = self.find(node_a)
        b_head = self.find(node_b)
        if a_head != b_head:
            a_size = self.size[a_head]
            b_size = self.size[b_head]
            if a_size >= b_size:
                self.father[b_head] = a_head
                self.size[a_head] += b_size
            else:
                self.father[a_head] = b_head
                self.size[b_head] += a_size
    def the_same(self, node_a, node_b):
        return self.find(node_a) == self.find(node_b)
class Edge:
    def __init__(self, vi, vj, weight):
        self.vi = vi
        self.vj = vj
        self.weight = weight
def gaussian_blur(img, kernel_size, sigma):
    """
    img: the pic
    kernel_size: size
    sigma: Gaussian sigma
    output: img
    """
    img = cv2.GaussianBlur(np.array(img),(kernel_size,kernel_size),sigma,sigma)
    return img
def calculate_weight(img, x1, y1, x2, y2):
    r = img[x1][y1][0] - img[x2][y2][0]
    g = img[x1][y1][1] - img[x2][y2][1]
    b = img[x1][y1][2] - img[x2][y2][2]
    return np.sqrt(np.sum(np.square(np.array([r, g, b]))))
def calculate_weights(img, h, w):
    lst = []
    for i in range(h):
        for j in range(w):
            # 右上角
            if i != 0 and j < w - 1:
                weight = calculate_weight(img, i, j, i - 1, j + 1)
                edge = Edge(i * w + j, (i - 1) * w + j + 1, weight)
                lst.append(edge)
            # 右
            if j < w - 1:
                weight = calculate_weight(img, i, j, i, j + 1)
                edge = Edge(i * w + j, i * w + j + 1, weight)
                lst.append(edge)
            # 右下角
            if i != h - 1 and j < w - 1:
                weight = calculate_weight(img, i, j, i + 1, j + 1)
                edge = Edge(i * w + j, (i + 1) * w + j + 1, weight)
                lst.append(edge)
            # 下
            if i != h - 1:
                weight = calculate_weight(img, i, j, i + 1, j)
                edge = Edge(i * w + j, (i + 1) * w + j, weight)
                lst.append(edge)
    lst = sorted(lst, key=lambda x: x.weight)
    return lst
def get_size(S, node):
    return S.size[S.find(node)]
def main(img, k, min_size):
    h, w, chanles = img.shape
    lst = calculate_weights(img.astype(int), h, w)
    S = UnionFindSet(np.arange(h * w))
    Int = np.zeros(shape=h * w)
    # 因为从小到大排列 diff 就是weight
    for edge in lst:
        vi = edge.vi
        vj = edge.vj
        weight = edge.weight
        if not S.the_same(vi, vj) and weight <= min(Int[vi] + k / get_size(S, vi), Int[vj] + k / get_size(S, vj)):
            S.union(vi, vj)
            Int[vi] = Int[vj] = max(Int[vi], Int[vj], weight)
    for edge in lst:
        vi = edge.vi
        vj = edge.vj
        if not S.the_same(vi, vj) and (get_size(S, vi) <= min_size or get_size(S, vj) <= min_size):
            S.union(vi, vj)
            Int[vi] = Int[vj] = max(Int[vi], Int[vj], weight)
    trade = dict(zip(set(S.father.values()), range(len(set(S.father.values())))))
    img_ = np.array([trade[item] for item in list(S.father.values())]).reshape(h, w)
    return img_
def felzenszwalb(img, k, min_size):
    img = gaussian_blur(img, 5,5)
    img_ = main(img, k, min_size)
    return img_
def get_color(origin_img, img):
    """
    orgin_img: origin img
    img: after process img
    """
    new_img = np.zeros_like(origin_img)
    for i in range(np.max(img)):
        new_img[:,:,0][img == i] = np.random.randint(255)
        new_img[:,:,1][img == i] = np.random.randint(255)
        new_img[:,:,2][img == i] = np.random.randint(255)
    return new_img
if __name__ == '__main__':
    img = plt.imread('../data/cat-dog.jpg')
    img_ = felzenszwalb(img, 5000, 500)
    new_img = get_color(img, img_)
    plt.imshow(new_img)
    plt.show()

1.2 skimage 导包

import numpy as np
from skimage import segmentation
import matplotlib.pyplot as plt
def get_color(origin_img, img):
    """
    orgin_img: origin img
    img: after process img
    """
    new_img = np.zeros_like(origin_img)
    for i in range(np.max(img)):
        new_img[:,:,0][img == i] = np.random.randint(255)
        new_img[:,:,1][img == i] = np.random.randint(255)
        new_img[:,:,2][img == i] = np.random.randint(255)
    return new_img
if __name__ == '__main__':
    img = plt.imread('../data/cat-dog.jpg')
    img_ = segmentation.felzenszwalb(
        img,
        scale=500,
        sigma=0.95
    )
    new_img = get_color(img, img_)
    plt.imshow(new_img)
    plt.show()

二、论文解读

2.1 目的及其意义

目的:将图像根据颜色RGB特征按区域进行分割

意义:将分割后得到的区域进行有效筛选后,可以进行例如区域推荐,目标检测等等高级操作,当然其本质还是图像分割

特点:该方法的一个重要特点是,它能够在低变异性图像区域保留细节,而忽略了高变异性图像区域的细节。也就是说在变化不大的区域可以对细节进行保留,在变化很大的区域对细节进行剔除,这个效果根据后面合并小区域得到的。

2.2 框架以及指标

正如论文题目所说,这个方法是基于Graph的图像分割技术,是用图 G = ( V , E ) G=(V,E) G=(V,E) 来表达问题并对问题进行分析的,在此问题中, G G G是无向图,其 V V V中每一个节点对应于图中每一个像素点的位置, E E E中的每一条边表示相邻的两个节点的信息以及其差异值即权重。


权重定义如下:

Screenshot_20240510_101950.jpg

其中 r i , g i , b i r_i, g_i,b_i ri,gi,bi是图像在三个通道上的值的大小;


现在问题已经表示完毕,接下来我们需要对图像进行分割,其本质就是把 V V V中的节点分成不同的集合 C 1 , C 2 , … , C m . C_1,C_2, \dots,C_m. C1,C2,,Cm. 为了表示集合中节点的整体关系,我们定义一个指标:

 

Screenshot_20240510_102131.jpg

Int(C)=eMST(C,E)maxw(vi,vj) 这个指标可以表示集合内部的最大差异,其中MST是最小生成树,论文的算法主要是围绕MST的Kruskal算法展开,在这里我们可以通过对 E E E排序的方式进行构造,并不需要对MST的生成有要求,感兴趣可以自行了解;


上面对集合内部节点整体关系定义了一个指标,我们还需要对集合之间的关系定义一个指标,指标如下:


Screenshot_20240510_102233.jpg

该指标可以表示集合之间的最小差异;


我们如何对集合之间进行合并呢?定义以下指标: Screenshot_20240510_101514.jpg

{TrueDiff(C1,C2)>MInt(C1,C2)Falseotherwise

\begin{cases} True & \quad Diff(C_1, C_2) > MInt(C_1, C_2)\\ False & \quad otherwise \end{cases}

其中True表示合并,False表示不合并,其中


Screenshot_20240510_101810.jpg


这里 k k k是一个常数, ∣ C ∣ |C| C表示集合 C C C的大小;

指标定义到此完毕!

2.3 论文算法流程

算法原文如下图所示:

  1. 初始化,设  V中有n个节点,  E中有m条边, image.png 即每一个节点单独成一个集合
  2. 对 E E E根据权重进行由小到大的排序
  3. 遍历 E E E中每个元素,得到每个元素中包含的两个节点 image.png 找到 image.png 属于的集合 image.png 如果 C i = C j ,则跳过,否则计算 D ( C i , C j )判断是否合并,得到 image.png
  4. image.png

三、过程实现

为了适应实际情况,实现分为三步走:高斯模糊 -> 主要算法 -> 随机上色

3.1 导包

import cv2
import numpy as np
import matplotlib.pyplot as plt

3.2 高斯模糊

由于是对图像进行分割,所以不需要图片过于细节,把图像进行高斯平滑处理,去掉细节

def gaussian_blur(img, kernel_size, sigma):
    """
    img: the pic
    kernel_size: size
    sigma: Gaussian sigma
    output: img
    """
    img = cv2.GaussianBlur(np.array(img),(kernel_size,kernel_size),sigma,sigma)
    return img

在kernel_size = 5, sigma = 5 得到结果如下:

左边是原图, 右边是平滑过后的图像;

3.3 主要算法

首先创建一个并查集类

class UnionFindSet:
    """
    创建一个并查集
    """
    def __init__(self, datalist):
        self.father = {}
        self.size = {}
        for data in datalist:
            self.father[data] = data
            self.size[data] = 1
    
    def find(self, node):
        father = self.father[node]
        if father != node:
            if father != self.father[father]:
                self.size[father] -= 1
            father = self.find(father)
        self.father[node] = father
        return father
    def union(self, node_a, node_b):
        if node_a is None or node_b is None:
            return 
        
        a_head = self.find(node_a)
        b_head = self.find(node_b)
        if a_head != b_head:
            a_size = self.size[a_head]
            b_size = self.size[b_head]
            if a_size >= b_size:
                self.father[b_head] = a_head
                self.size[a_head] += b_size
            else:
                self.father[a_head] = b_head
                self.size[b_head] += a_size
    
    def the_same(self, node_a, node_b):
        return self.find(node_a) == self.find(node_b)

再进行主要算法计算

class Edge:
    def __init__(self, vi, vj, weight):
        self.vi = vi
        self.vj = vj
        self.weight = weight
def calculate_weight(img, x1, y1, x2, y2):
    r = img[x1][y1][0] - img[x2][y2][0]
    g = img[x1][y1][1] - img[x2][y2][1]
    b = img[x1][y1][2] - img[x2][y2][2]
    return np.sqrt(np.sum(np.square(np.array([r,g,b]))))
def calculate_weights(img, h, w):
    lst = []
    for i in range(h):
        for j in range(w):
            # 右上角
            if i != 0 and j < w-1:
                weight = calculate_weight(img, i, j, i-1,j+1)
                edge = Edge(i*w+j, (i-1)*w+j+1, weight)
                lst.append(edge)
            # 右
            if j < w-1:
                weight = calculate_weight(img, i, j, i, j+1)
                edge = Edge(i*w+j, i*w+j+1, weight)
                lst.append(edge)
            # 右下角
            if i != h-1 and j < w-1:
                weight = calculate_weight(img, i, j, i+1, j+1)
                edge = Edge(i*w+j, (i+1)*w+j+1, weight)
                lst.append(edge)
            # 下
            if i != h-1:
                weight = calculate_weight(img, i, j, i+1, j)
                edge = Edge(i*w+j, (i+1)*w+j, weight)
                lst.append(edge)
    lst = sorted(lst, key=lambda x:x.weight)
    return lst
def get_size(S, node):
    return S.size[S.find(node)]
def main(k, min_size):
    h, w, chanles = img.shape
    lst = calculate_weights(img.astype(int), h, w)
    S = UnionFindSet(np.arange(h*w))
    Int = np.zeros(shape=h*w)
    
    # 因为从小到大排列 diff 就是weight
    for edge in lst:
        vi = edge.vi
        vj = edge.vj
        weight = edge.weight
        if not S.the_same(vi, vj) and weight <= min(Int[vi] + k/get_size(S, vi), Int[vj] + k/get_size(S, vj)):
            S.union(vi, vj)
            Int[vi] = Int[vj] = max(Int[vi], Int[vj], weight)
            
    # 合并小集合
    for edge in lst:
        vi = edge.vi
        vj = edge.vj
        if not S.the_same(vi, vj) and (get_size(S, vi) <= min_size or get_size(S, vj) <= min_size):
            S.union(vi, vj)
            Int[vi] = Int[vj] = max(Int[vi], Int[vj], weight)
    trade = dict(zip(set(S.father.values()),range(len(set(S.father.values())))))
    img_ = np.array([trade[item] for item in list(S.father.values())]).reshape(h,w)
    return img_
main(5000,500)

3.4 随机上色

相比于之前的步骤,这一步很简单

def get_color(origin_img, img):
    """
    orgin_img: the original img
    img: the img after process
    """
    new_img = np.zeros_like(origin_img)
    for i in range(np.max(img)):
        new_img[:,:,0][img == i] = np.random.randint(255)
        new_img[:,:,1][img == i] = np.random.randint(255)
        new_img[:,:,2][img == i] = np.random.randint(255)
    return new_img

得到结果如下:

四、整体总结

没什么好总结的,干就完了!


目录
相关文章
RPN(Region Proposal Networks)候选区域网络算法解析(附PyTorch代码)
RPN(Region Proposal Networks)候选区域网络算法解析(附PyTorch代码)
2310 1
论文解读:LaMa:Resolution-robust Large Mask Inpainting with Fourier Convolutions
论文解读:LaMa:Resolution-robust Large Mask Inpainting with Fourier Convolutions
1160 0
React 图片轮播 Carousel:从入门到进阶
本文介绍了在 React 中实现图片轮播(Carousel)的方法,从基础安装和配置 `react-slick` 开始,逐步讲解了常见问题如图片路径、性能优化、自定义样式和交互处理,以及高级话题如动态数据加载和响应式设计。通过详细示例,帮助开发者避免易错点,提升轮播图的用户体验。
184 3
利用PyTorch的三元组损失Hard Triplet Loss进行嵌入模型微调
本文介绍了如何使用 PyTorch 和三元组边缘损失(Triplet Margin Loss)微调嵌入模型,详细讲解了实现细节和代码示例。
244 4
大规模数据集管理:DataLoader在分布式环境中的应用
【8月更文第29天】随着大数据时代的到来,如何高效地处理和利用大规模数据集成为了许多领域面临的关键挑战之一。本文将探讨如何在分布式环境中使用`DataLoader`来优化大规模数据集的管理与加载过程,并通过具体的代码示例展示其实现方法。
633 1
数据库技术的前沿探索:创新、挑战与未来机遇
一、引言 数据库技术作为信息化社会的基础设施,一直在不断演进以适应日益复杂的数据处理需求
1128 0
Python用Markowitz马克维兹有效边界构建最优投资组合可视化分析四只股票
Python用Markowitz马克维兹有效边界构建最优投资组合可视化分析四只股票
MySQL的MyISAM引擎:技术特点与应用场景
【4月更文挑战第20天】MySQL的MyISAM引擎特点是表级锁定,适合读多写少的场景,不支持事务但提供全文索引,适用于只读应用、全文搜索和简单备份恢复。在选择存储引擎时,应根据具体需求权衡。
1018 11
Flink报错问题之SQL作业中调用UDTF报错如何解决
Apache Flink是由Apache软件基金会开发的开源流处理框架,其核心是用Java和Scala编写的分布式流数据流引擎。本合集提供有关Apache Flink相关技术、使用技巧和最佳实践的资源。
计算机视觉五大核心研究任务全解:分类识别、检测分割、人体分析、三维视觉、视频分析
计算机视觉五大核心研究任务全解:分类识别、检测分割、人体分析、三维视觉、视频分析
595 0
AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等

登录插画

登录以查看您的控制台资源

管理云资源
状态一览
快捷访问