基于Permutohedral Lattice 的Bilateral filter 源码及部分注释【C++】

简介: 基于Permutohedral Lattice 的Bilateral filter 源码及部分注释【C++】

基于Permutohedral Lattice 的Bilateral filter 源码及部分注释【来自于网络】



本实现基于论文《Fast High-Dimensional Filtering Using the Permutohedral Lattice》。

延伸阅读:saliency filters 精读之 permutohedral lattice


1.bilateralPermutohedral 方法:

static Mat bilateralPermutohedral(Mat img, Mat edge, float sigma_s, float sigma_r)  // img and edge must both be CV_32F
{
    // Bilateral-filters img, using edge as the guide image, through
    // PermutohedralLattice::filter.
    //   img     : image supplying the values to be filtered (read as float).
    //   edge    : guide image whose channels become the range part of each
    //             position vector (read as float; the original note says it
    //             has 1 or 3 channels).
    //   sigma_s : spatial standard deviation; pixel coordinates are divided by it.
    //   sigma_r : range standard deviation; edge values are divided by it.
    // Returns a new Mat with the size and type of img.
    const int rows = img.rows;
    const int cols = img.cols;
    const int edgeChannels = edge.channels();
    const int imgChannels = img.channels();

    // Dividing by the standard deviations up front lets the lattice work in
    // normalised coordinates.
    const float spatialScale = 1.0f / sigma_s;
    const float rangeScale = 1.0f / sigma_r;

    // Single-frame containers: one position vector (x, y, edge channels...)
    // and one value vector (image channels) per pixel.
    Image positions(1, cols, rows, 2 + edgeChannels);
    Image input(1, cols, rows, imgChannels);

    // Convert the Mat data into the Image containers.
    for (int row = 0; row < rows; ++row)
    {
        float *imgPixel = img.ptr<float>(row);
        float *edgePixel = edge.ptr<float>(row);
        for (int col = 0; col < cols; ++col, imgPixel += imgChannels, edgePixel += edgeChannels)
        {
            // Position vector (section 3.1 of the paper): scaled pixel
            // coordinates followed by the scaled edge channels.
            positions(col, row)[0] = spatialScale * col;
            positions(col, row)[1] = spatialScale * row;
            for (int ch = 0; ch < edgeChannels; ++ch)
                positions(col, row)[2 + ch] = rangeScale * edgePixel[ch];

            // Value vector: the raw image channels.
            for (int ch = 0; ch < imgChannels; ++ch)
                input(col, row)[ch] = imgPixel[ch];
        }
    }

    // Filter the values with respect to the position vectors.
    Image filtered = PermutohedralLattice::filter(input, positions);

    // Copy the filtered values back into a Mat shaped like the input.
    Mat imgOut(img.size(), img.type());
    for (int row = 0; row < rows; ++row)
    {
        float *dstPixel = imgOut.ptr<float>(row);
        for (int col = 0; col < cols; ++col, dstPixel += imgChannels)
        {
            for (int ch = 0; ch < imgChannels; ++ch)
                dstPixel[ch] = filtered(col, row)[ch];
        }
    }
    return imgOut;
}

2. PermutohedralLattice 类:

/***************************************************************/
/* The algorithm class that performs the filter
 *
 * PermutohedralLattice::filter(...) does all the work.
 *
 */
/***************************************************************/
class PermutohedralLattice
{
public:
    /* Filters a given image against a reference image.
     *   im  : image to be bilateral-filtered (supplies the value vectors).
     *   ref : reference image whose edges are to be respected (supplies the
     *         position vectors). It is walked in step with im, advancing by
     *         ref.channels floats per pixel.
     * Returns a new image with the frames/width/height/channels of im.
     */
    static Image filter(Image im, Image ref)
    {
        // Create the lattice.
        //   d  = ref.channels     (e.g. 5)
        //   vd = im.channels + 1  (e.g. 3 + 1)
        PermutohedralLattice lattice(ref.channels, im.channels + 1, im.width * im.height * im.frames);

        // Scratch value vector: the pixel channels plus a homogeneous weight.
        float *col = new float[im.channels + 1];
        col[im.channels] = 1; // homogeneous coordinate

        // Splat every pixel into the lattice.
        float *imPtr = im(0, 0, 0);
        float *refPtr = ref(0, 0, 0); // position vector of the current pixel
        for (int t = 0; t < im.frames; t++)
        {
            for (int y = 0; y < im.height; y++)
            {
                for (int x = 0; x < im.width; x++)
                {
                    for (int c = 0; c < im.channels; c++)
                    {
                        col[c] = *imPtr++;
                    }
                    lattice.splat(refPtr, col);
                    refPtr += ref.channels;
                }
            }
        }

        // Blur the lattice.
        lattice.blur();

        // Slice from the lattice, in the same pixel order used for splatting.
        Image out(im.frames, im.width, im.height, im.channels);
        lattice.beginSlice();
        float *outPtr = out(0, 0, 0);
        for (int t = 0; t < im.frames; t++)
        {
            for (int y = 0; y < im.height; y++)
            {
                for (int x = 0; x < im.width; x++)
                {
                    lattice.slice(col);
                    // Normalise by the filtered homogeneous weight.
                    float scale = 1.0f / col[im.channels];
                    for (int c = 0; c < im.channels; c++)
                    {
                        *outPtr++ = col[c] * scale;
                    }
                }
            }
        }

        delete[] col; // previously leaked on every call
        return out;
    }

    /* Constructor
     *     d_ : dimensionality of key vectors (ref.channels)
     *    vd_ : dimensionality of value vectors (im.channels + 1)
     *    nData_ : number of points in the input (im.size * im.frames);
     *             splat() must not be called more than nData_ times, because
     *             the replay buffer holds exactly nData_ * (d_ + 1) entries.
     */
    PermutohedralLattice(int d_, int vd_, int nData_) :
        d(d_), vd(vd_), nData(nData_), hashTable(d_, vd_)
    {
        // Allocate storage for various arrays
        elevated = new float[d + 1];
        scaleFactor = new float[d];
        greedy = new short[d + 1];
        rank = new char[d + 1];
        barycentric = new float[d + 2];
        // Compute the element count in size_t so that large inputs do not
        // overflow int arithmetic.
        replay = new ReplayEntry[static_cast<size_t>(nData) * (d + 1)];
        nReplay = 0;
        nReplaySub = 0; // not used by this class; initialised for determinism
        canonical = new short[(d + 1) * (d + 1)];
        key = new short[d + 1];

        // Compute the coordinates of the canonical simplex, in which
        // the difference between a contained point and the zero
        // remainder vertex is always in ascending order. (See pg.4 of paper;
        // this is the d = 4 matrix example there, stored one vertex per row
        // of d + 1 entries.)
        for (int i = 0; i <= d; i++)
        {
            for (int j = 0; j <= d - i; j++)
                canonical[i * (d + 1) + j] = i;
            for (int j = d - i + 1; j <= d; j++)
                canonical[i * (d + 1) + j] = i - (d + 1);
        }

        // Compute parts of the rotation matrix E. (See pg.4-5 of paper.)
        for (int i = 0; i < d; i++)
        {
            // the diagonal entries for normalization
            scaleFactor[i] = 1.0f / (sqrtf( (float)(i + 1) * (i + 2) ));
            /* We presume that the user would like to do a Gaussian blur of standard deviation
             * 1 in each dimension (or a total variance of d, summed over dimensions.)
             * Because the total variance of the blur performed by this algorithm is not d,
             * we must scale the space to offset this.
             *
             * The total variance of the algorithm is (See pg.6 and 10 of paper):
             *  [variance of splatting] + [variance of blurring] + [variance of splatting]
             *   = d(d+1)(d+1)/12 + d(d+1)(d+1)/2 + d(d+1)(d+1)/12
             *   = 2d(d+1)(d+1)/3.
             *
             * So we need to scale the space by (d+1)sqrt(2/3).
             */
            // Scaling of the position vector (pg.4 of paper).
            scaleFactor[i] *= (d + 1) * sqrtf(2.0 / 3);
        }
    }

    /* Releases every array allocated by the constructor. */
    ~PermutohedralLattice()
    {
        delete[] elevated;
        delete[] scaleFactor;
        delete[] greedy;
        delete[] rank;
        delete[] barycentric;
        delete[] replay;
        delete[] canonical;
        delete[] key;
    }

    /* Performs splatting with given position and value vectors.
     *   position : d-dimensional position vector
     *   value    : vd-dimensional value vector, e.g. [r, g, b, 1]
     */
    void splat(float *position, float *value)
    {
        // First rotate position into the (d+1)-dimensional hyperplane
        // (computation of Ex, pg.5 of paper).
        elevated[d] = -d * position[d - 1] * scaleFactor[d - 1];
        for (int i = d - 1; i > 0; i--)
            elevated[i] = (elevated[i + 1] -
                           i * position[i - 1] * scaleFactor[i - 1] +
                           (i + 2) * position[i] * scaleFactor[i]);
        elevated[0] = elevated[1] + 2 * position[0] * scaleFactor[0];

        // Prepare to find the closest lattice points.
        float scale = 1.0f / (d + 1);

        // Greedily search for the closest zero-colored lattice point
        // (pg.3 of paper): round each coordinate to the nearest multiple
        // of d + 1.
        int sum = 0;
        for (int i = 0; i <= d; i++)
        {
            float v = elevated[i] * scale;
            float up = ceilf(v) * (d + 1);
            float down = floorf(v) * (d + 1);
            if (up - elevated[i] < elevated[i] - down)
                greedy[i] = (short)up;
            else
                greedy[i] = (short)down;
            sum += greedy[i];
        }
        sum /= d + 1; // consistent remainder (d+1)

        // Rank differential to find the permutation between this simplex and
        // the canonical one. (See pg. 3-4 in paper.) For every pair, the
        // coordinate with the smaller differential gets its rank incremented.
        memset(rank, 0, sizeof(char) * (d + 1));
        for (int i = 0; i < d; i++)
            for (int j = i + 1; j <= d; j++)
                if (elevated[i] - greedy[i] < elevated[j] - greedy[j])
                    rank[i]++;
                else
                    rank[j]++;

        if (sum > 0)
        {
            // sum too large - the point is off the hyperplane.
            // need to bring down the ones with the smallest differential
            for (int i = 0; i <= d; i++)
            {
                if (rank[i] >= d + 1 - sum)
                {
                    greedy[i] -= d + 1;
                    rank[i] += sum - (d + 1);
                }
                else
                    rank[i] += sum;
            }
        }
        else if (sum < 0)
        {
            // sum too small - the point is off the hyperplane
            // need to bring up the ones with largest differential
            for (int i = 0; i <= d; i++)
            {
                if (rank[i] < -sum)
                {
                    greedy[i] += d + 1;
                    rank[i] += (d + 1) + sum;
                }
                else
                    rank[i] += sum;
            }
        }

        // Compute barycentric coordinates (See pg.10 of paper.)
        memset(barycentric, 0, sizeof(float) * (d + 2));
        for (int i = 0; i <= d; i++)
        {
            barycentric[d - rank[i]] += (elevated[i] - greedy[i]) * scale;
            barycentric[d + 1 - rank[i]] -= (elevated[i] - greedy[i]) * scale;
        }
        barycentric[0] += 1.0f + barycentric[d + 1];

        // Splat the value into each vertex of the simplex, with barycentric weights.
        for (int remainder = 0; remainder <= d; remainder++)
        {
            // Compute the location of the lattice point explicitly (all but
            // the last coordinate - it's redundant because they sum to zero)
            for (int i = 0; i < d; i++)
                key[i] = greedy[i] + canonical[remainder * (d + 1) + rank[i]];

            // Retrieve pointer to the value at this vertex.
            float *val = hashTable.lookup(key, true);

            // Accumulate values with barycentric weight.
            for (int i = 0; i < vd; i++)
                val[i] += barycentric[remainder] * value[i];

            // Record this interaction to use later when slicing. The offset
            // is relative to the values array, so it stays meaningful if the
            // hash table reallocates that array.
            replay[nReplay].offset = val - hashTable.getValues();
            replay[nReplay].weight = barycentric[remainder];
            nReplay++;
        }
    }

    // Prepare for slicing: rewind the replay cursor to the first recorded entry.
    void beginSlice()
    {
        nReplay = 0;
    }

    /* Performs slicing out of position vectors. Note that the barycentric weights and the simplex
     * containing each position vector were calculated and stored in the splatting step.
     * We may reuse this to accelerate the algorithm. (See pg. 6 in paper.)
     *   col : output buffer of vd floats; receives the weighted sum of the
     *         d + 1 vertex values recorded for the next splatted point.
     */
    void slice(float *col)
    {
        float *base = hashTable.getValues();
        for (int j = 0; j < vd; j++)
            col[j] = 0;
        for (int i = 0; i <= d; i++)
        {
            ReplayEntry r = replay[nReplay++];
            for (int j = 0; j < vd; j++)
            {
                col[j] += r.weight * base[r.offset + j];
            }
        }
    }

    /* Performs a Gaussian blur along each projected axis in the hyperplane. */
    void blur()
    {
        // Prepare arrays
        short *neighbor1 = new short[d + 1];
        short *neighbor2 = new short[d + 1];
        float *newValue = new float[vd * hashTable.size()];
        float *oldValue = hashTable.getValues();
        float *hashTableBase = oldValue;
        float *zero = new float[vd];
        for (int k = 0; k < vd; k++)
            zero[k] = 0;
        // Only the first d coordinates are stored per key and passed on to
        // the hash table; keep the spare slot defined.
        neighbor1[d] = 0;
        neighbor2[d] = 0;

        // For each of d+1 axes,
        for (int j = 0; j <= d; j++)
        {
            printf("blur %d\t", j);
            fflush(stdout);
            // For each vertex in the lattice,
            for (int i = 0; i < hashTable.size(); i++)   // blur point i in dimension j
            {
                short *vertexKey = hashTable.getKeys() + i * (d); // key of the current vertex
                for (int k = 0; k < d; k++)
                {
                    neighbor1[k] = vertexKey[k] + 1;
                    neighbor2[k] = vertexKey[k] - 1;
                }
                // Keys of the neighbors along the given axis. For j == d the
                // axis coordinate is the implicit (d+1)-th one, which is not
                // stored; the original code read one element past the key
                // here and wrote a value that was never used.
                if (j < d)
                {
                    neighbor1[j] = vertexKey[j] - d;
                    neighbor2[j] = vertexKey[j] + d;
                }

                float *oldVal = oldValue + i * vd;
                float *newVal = newValue + i * vd;
                float *vm1, *vp1;

                // lookup() returns pointers into the hash table's own value
                // array; rebase them onto whichever buffer currently holds
                // the freshest data. Missing neighbors contribute zero.
                vm1 = hashTable.lookup(neighbor1, false); // look up first neighbor
                if (vm1)
                    vm1 = vm1 - hashTableBase + oldValue;
                else
                    vm1 = zero;

                vp1 = hashTable.lookup(neighbor2, false); // look up second neighbor
                if (vp1)
                    vp1 = vp1 - hashTableBase + oldValue;
                else
                    vp1 = zero;

                // Mix values of the three vertices
                for (int k = 0; k < vd; k++)
                    newVal[k] = (0.25f * vm1[k] + 0.5f * oldVal[k] + 0.25f * vp1[k]);
            }
            float *tmp = newValue;
            newValue = oldValue;
            oldValue = tmp;
            // the freshest data is now in oldValue, and newValue is ready to be written over
        }

        // Depending where we ended up, we may have to copy data back into
        // the hash table. Exactly one of the two buffers is the scratch
        // allocation made above; it was created with new[] and must be
        // released with delete[].
        if (oldValue != hashTableBase)
        {
            memcpy(hashTableBase, oldValue, hashTable.size()*vd * sizeof(float));
            delete[] oldValue;
        }
        else
        {
            delete[] newValue;
        }
        printf("\n");
        delete[] zero;
        delete[] neighbor1;
        delete[] neighbor2;
    }

private:
    // The lattice owns raw arrays, so a member-wise copy would lead to a
    // double free. Copying is disabled (declared, intentionally not defined).
    PermutohedralLattice(const PermutohedralLattice &);
    PermutohedralLattice &operator=(const PermutohedralLattice &);

    int d, vd, nData;
    float *elevated, *scaleFactor, *barycentric;
    short *canonical;
    short *key;
    // slicing is done by replaying splatting (ie storing the sparse matrix)
    struct ReplayEntry
    {
        int offset;   // offset of the vertex value inside the hash table's values array
        float weight; // barycentric weight used when splatting
    } *replay;
    int nReplay, nReplaySub;
public:
    char  *rank;
    short *greedy;
    HashTablePermutohedral hashTable;
};

3. 用于permutohedral lattice的哈希表:

/***************************************************************/
/* Hash table implementation for permutohedral lattice
 *
 * The lattice points are stored sparsely using a hash table.
 * The key for each point is its spatial location in the (d+1)-
 * dimensional space.
 */
/***************************************************************/
class HashTablePermutohedral
{
public:
    /* Constructor
     *  kd_: the dimensionality of the position vectors on the hyperplane.
     *  vd_: the dimensionality of the value vectors
     */
    HashTablePermutohedral(int kd_, int vd_) : kd(kd_), vd(vd_)
    {
        capacity = 1 << 15;
        filled = 0;
        entries = new Entry[capacity];
        // The table is grown before it becomes half full, so the key and
        // value arrays only need room for capacity / 2 vectors.
        keys = new short[kd * capacity / 2];
        values = new float[vd * capacity / 2];
        memset(values, 0, sizeof(float)*vd * capacity / 2);
    }

    /* Releases the entry table and the key/value arrays. */
    ~HashTablePermutohedral()
    {
        delete[] entries;
        delete[] keys;
        delete[] values;
    }

    // Returns the number of vectors stored.
    int size()
    {
        return filled;
    }

    // Returns a pointer to the keys array (kd shorts per stored vector).
    short *getKeys()
    {
        return keys;
    }

    // Returns a pointer to the values array (vd floats per stored vector).
    // The pointer is invalidated whenever the table grows.
    float *getValues()
    {
        return values;
    }

    /* Returns the offset into the values array for a given key, or -1 if the
     * key is absent and create is false.
     *  key: a pointer to the position vector.
     *  h: hash of the position vector, already reduced modulo the capacity.
     *  create: a flag specifying whether an entry should be created,
     *          should an entry with the given key not found.
     */
    int lookupOffset(short *key, size_t h, bool create = true)
    {
        // Double the hash table size if an insertion could push the load to
        // half of the capacity. Pure queries (create == false) never insert,
        // so they must not reallocate the arrays that callers may hold
        // pointers into.
        if (create && filled >= (capacity / 2) - 1)
        {
            grow();
            // h was reduced modulo the old capacity; recompute it so that
            // probing starts at the bucket every later lookup will use.
            h = hash(key) % capacity;
        }

        // Find the entry with the given key by linear probing from bucket h.
        while (1)
        {
            Entry e = entries[h];

            // An empty cell ends the probe sequence: the key is not stored.
            if (e.keyIdx == -1)
            {
                if (!create)
                    return -1; // Return not found.

                // need to create an entry. Store the given key.
                for (int i = 0; i < kd; i++)
                    keys[filled * kd + i] = key[i];
                e.keyIdx = filled * kd;
                e.valueIdx = filled * vd;
                entries[h] = e;
                filled++;
                return e.valueIdx;
            }

            // check if the cell has a matching key
            bool match = true;
            for (int i = 0; i < kd && match; i++)
                match = keys[e.keyIdx + i] == key[i];
            if (match)
                return e.valueIdx;

            // Collision: move on to the next bucket, wrapping around at the
            // end of the table.
            h++;
            if (h == capacity)
                h = 0;
        }
    }

    /* Looks up the value vector associated with a given key vector.
     *  k : pointer to the key vector to be looked up.
     *  create : true if a non-existing key should be created.
     * Returns a pointer to the vd floats of the value vector, or NULL if the
     * key is absent and create is false.
     */
    float *lookup(short *k, bool create = true)
    {
        size_t h = hash(k) % capacity;
        int offset = lookupOffset(k, h, create);
        if (offset < 0)
            return NULL;
        else
            return values + offset;
    }

    /* Hash function used in this implementation. A simple base conversion. */
    size_t hash(const short *key)
    {
        size_t k = 0;
        for (int i = 0; i < kd; i++)
        {
            k += key[i];
            k *= 2531011;
        }
        return k;
    }

private:
    // The table owns raw arrays, so a member-wise copy would lead to a
    // double free. Copying is disabled (declared, intentionally not defined).
    HashTablePermutohedral(const HashTablePermutohedral &);
    HashTablePermutohedral &operator=(const HashTablePermutohedral &);

    /* Grows the size of the hash table */
    void grow()
    {
        printf("Resizing hash table\n");
        size_t oldCapacity = capacity;
        capacity *= 2;  // double the capacity

        // Migrate the value vectors.
        float *newValues = new float[vd * capacity / 2];
        memset(newValues, 0, sizeof(float)*vd * capacity / 2);
        memcpy(newValues, values, sizeof(float)*vd * filled);
        delete[] values;
        values = newValues;

        // Migrate the key vectors.
        short *newKeys = new short[kd * capacity / 2];
        memcpy(newKeys, keys, sizeof(short)*kd * filled);
        delete[] keys;
        keys = newKeys;

        Entry *newEntries = new Entry[capacity];

        // Migrate the table of indices.
        for (size_t i = 0; i < oldCapacity; i++)
        {
            if (entries[i].keyIdx == -1)
                continue;

            // Re-hash the stored key against the new capacity.
            size_t h = hash(keys + entries[i].keyIdx) % capacity;

            // If that bucket is taken, probe forward until a free one is found.
            while (newEntries[h].keyIdx != -1)
            {
                h++;
                if (h == capacity)
                    h = 0;
            }
            newEntries[h] = entries[i];
        }
        delete[] entries;
        entries = newEntries;
    }

    // Private struct for the hash table entries.
    struct Entry
    {
        Entry() : keyIdx(-1), valueIdx(-1) {}
        int keyIdx;   // index into keys, or -1 when the bucket is empty
        int valueIdx; // index into values
    };

    short *keys;
    float *values;
    Entry *entries;
    size_t capacity, filled;  // number of buckets, and number of stored vectors
    int kd, vd;  // dimensionality of the key vectors and of the value vectors
};

效果图:

image.png

目录
相关文章
|
5天前
|
存储 机器学习/深度学习 人工智能
c/c++线性表实现附源码(超详解)
c/c++线性表实现附源码(超详解)
14 0
|
1月前
|
存储 人工智能 数据安全/隐私保护
【C++面向对象】C++考试题库管理系统(源码)【独一无二】
【C++面向对象】C++考试题库管理系统(源码)【独一无二】
|
1月前
|
存储 人工智能 搜索推荐
【C语言/C++】电子元器件管理系统(C源码)【独一无二】
【C语言/C++】电子元器件管理系统(C源码)【独一无二】
|
1月前
|
存储 人工智能 机器人
【C++面向对象】C++图书管理系统 (源码)【独一无二】
【C++面向对象】C++图书管理系统 (源码)【独一无二】
|
1月前
|
存储 人工智能 机器人
【C/C++】C语言 学生信息管理系统(源码)【独一无二】
【C/C++】C语言 学生信息管理系统(源码)【独一无二】
|
1月前
|
人工智能 机器人 测试技术
【C/C++】C语言 21点桌牌游戏 (源码) 【独一无二】
【C/C++】C语言 21点桌牌游戏 (源码) 【独一无二】
|
1月前
|
存储 人工智能 机器人
【C/C++】C++学籍信息管理系统(源码+报告)【独一无二】
【C/C++】C++学籍信息管理系统(源码+报告)【独一无二】
|
1月前
|
存储 人工智能 搜索推荐
【C/C++】C/C++招聘信息管理系统(源码)【独一无二】
【C/C++】C/C++招聘信息管理系统(源码)【独一无二】
|
1月前
|
人工智能 机器人 测试技术
【C++面向对象】C++飞机购票订票系统(源码+说明)【独一无二】
【C++面向对象】C++飞机购票订票系统(源码+说明)【独一无二】
|
1月前
|
存储 人工智能 BI
【C++面向对象】C++银行卡管理系统(源码+论文)【独一无二】
【C++面向对象】C++银行卡管理系统(源码+论文)【独一无二】