链路预测算法MATLAB实现

简介: 链路预测是复杂网络分析中的重要任务,旨在预测网络中尚未连接的两个节点之间未来产生连接的可能性。

链路预测是复杂网络分析中的重要任务,旨在预测网络中尚未连接的两个节点之间未来产生连接的可能性。

程序概述

MATLAB程序实现了以下链路预测算法:

  1. 基于局部信息的相似性指标(Common Neighbors, Jaccard, Adamic-Adar等)
  2. 基于路径的相似性指标(Katz指数)
  3. 基于随机游走的相似性指标(Rooted PageRank, SimRank)
  4. 矩阵分解方法

代码

classdef LinkPrediction
    %LINKPREDICTION 链路预测算法实现
    %   包含多种链路预测算法

    properties
        A;          % 邻接矩阵
        train_mask; % 训练掩码矩阵
        test_mask;  % 测试掩码矩阵
        methods;    % 可用的预测方法
    end

    methods
        function obj = LinkPrediction(adj_matrix, train_ratio)
            %LINKPREDICTION 构造函数
            %   adj_matrix - 网络邻接矩阵
            %   train_ratio - 训练集比例(0-1)

            obj.A = adj_matrix;
            obj.methods = {
   'CommonNeighbors', 'Jaccard', 'AdamicAdar', ...
                          'PreferentialAttachment', 'Katz', 'RootedPageRank', ...
                          'SimRank', 'MatrixFactorization'};

            % 划分训练集和测试集
            if nargin > 1
                obj = obj.split_dataset(train_ratio);
            else
                obj.train_mask = ones(size(obj.A));
                obj.test_mask = zeros(size(obj.A));
            end
        end

        function obj = split_dataset(obj, train_ratio)
            %SPLIT_DATASET 划分训练集和测试集
            %   随机隐藏一部分边作为测试集

            [n, ~] = size(obj.A);
            obj.train_mask = ones(n);
            obj.test_mask = zeros(n);

            % 获取所有边的索引
            [rows, cols] = find(triu(obj.A, 1)); % 只取上三角避免重复
            num_edges = length(rows);
            num_train = round(num_edges * train_ratio);

            % 随机选择训练边
            idx = randperm(num_edges);
            train_idx = idx(1:num_train);
            test_idx = idx(num_train+1:end);

            % 创建掩码矩阵
            for i = 1:length(test_idx)
                r = rows(test_idx(i));
                c = cols(test_idx(i));
                obj.train_mask(r, c) = 0;
                obj.train_mask(c, r) = 0;
                obj.test_mask(r, c) = 1;
                obj.test_mask(c, r) = 1;
            end

            % 确保对角线为0
            obj.train_mask(1:n+1:end) = 0;
            obj.test_mask(1:n+1:end) = 0;
        end

        function scores = common_neighbors(obj)
            %COMMON_NEIGHBORS 共同邻居算法
            scores = (obj.A * obj.A) .* obj.train_mask;
        end

        function scores = jaccard(obj)
            %JACCARD Jaccard相似系数
            [n, ~] = size(obj.A);
            scores = zeros(n);

            for i = 1:n
                for j = i+1:n
                    if obj.train_mask(i, j) == 0
                        continue;
                    end

                    neighbors_i = find(obj.A(i, :));
                    neighbors_j = find(obj.A(j, :));

                    intersection = length(intersect(neighbors_i, neighbors_j));
                    union = length(union(neighbors_i, neighbors_j));

                    if union > 0
                        scores(i, j) = intersection / union;
                    else
                        scores(i, j) = 0;
                    end

                    scores(j, i) = scores(i, j);
                end
            end
        end

        function scores = adamic_adar(obj)
            %ADAMIC_ADAR Adamic-Adar指数
            [n, ~] = size(obj.A);
            scores = zeros(n);

            % 计算每个节点的度
            degrees = sum(obj.A, 2);

            for i = 1:n
                for j = i+1:n
                    if obj.train_mask(i, j) == 0
                        continue;
                    end

                    common_neighbors = find(obj.A(i, :) & obj.A(j, :));
                    score = 0;

                    for k = common_neighbors
                        if degrees(k) > 1 % 避免除以0
                            score = score + 1 / log(degrees(k));
                        end
                    end

                    scores(i, j) = score;
                    scores(j, i) = score;
                end
            end
        end

        function scores = preferential_attachment(obj)
            %PREFERENTIAL_ATTACHMENT 优先连接算法
            degrees = sum(obj.A, 2);
            scores = (degrees * degrees') .* obj.train_mask;
        end

        function scores = katz(obj, beta)
            %KATZ Katz指数
            %   beta - 衰减因子,默认0.01

            if nargin < 2
                beta = 0.01;
            end

            [n, ~] = size(obj.A);
            I = eye(n);
            scores = beta * obj.A; % 长度为1的路径

            % 计算Katz指数:S = βA + β²A² + β³A³ + ...
            % 使用矩阵求逆近似:S = (I - βA)^-1 - I
            scores = inv(I - beta * obj.A) - I;
            scores = scores .* obj.train_mask;
        end

        function scores = rooted_pagerank(obj, alpha, max_iter, tol)
            %ROOTED_PAGERANK Rooted PageRank算法
            %   alpha - 随机游走概率,默认0.85
            %   max_iter - 最大迭代次数,默认100
            %   tol - 收敛容差,默认1e-6

            if nargin < 2
                alpha = 0.85;
            end
            if nargin < 3
                max_iter = 100;
            end
            if nargin < 4
                tol = 1e-6;
            end

            [n, ~] = size(obj.A);
            scores = zeros(n);

            % 创建列随机矩阵(转移概率矩阵)
            P = obj.A ./ sum(obj.A, 1);
            P(isnan(P)) = 0; % 处理度为0的节点

            % 对每个节点计算Rooted PageRank
            for i = 1:n
                r = zeros(n, 1);
                r(i) = 1;

                for iter = 1:max_iter
                    r_new = alpha * P * r + (1 - alpha) * r;

                    if norm(r_new - r, 1) < tol
                        break;
                    end

                    r = r_new;
                end

                scores(:, i) = r;
            end

            scores = scores .* obj.train_mask;
        end

        function scores = simrank(obj, C, max_iter, tol)
            %SIMRANK SimRank算法
            %   C - 衰减因子,默认0.8
            %   max_iter - 最大迭代次数,默认10
            %   tol - 收敛容差,默认1e-4

            if nargin < 2
                C = 0.8;
            end
            if nargin < 3
                max_iter = 10;
            end
            if nargin < 4
                tol = 1e-4;
            end

            [n, ~] = size(obj.A);
            S = eye(n); % 初始化SimRank矩阵

            % 计算入邻居
            in_neighbors = cell(n, 1);
            for i = 1:n
                in_neighbors{
   i} = find(obj.A(:, i));
            end

            % 迭代计算SimRank
            for iter = 1:max_iter
                S_old = S;

                for i = 1:n
                    for j = 1:n
                        if i == j
                            S(i, j) = 1;
                            continue;
                        end

                        in_i = in_neighbors{
   i};
                        in_j = in_neighbors{
   j};

                        if isempty(in_i) || isempty(in_j)
                            S(i, j) = 0;
                            continue;
                        end

                        sum_sim = 0;
                        for a = 1:length(in_i)
                            for b = 1:length(in_j)
                                sum_sim = sum_sim + S_old(in_i(a), in_j(b));
                            end
                        end

                        S(i, j) = C * sum_sim / (length(in_i) * length(in_j));
                    end
                end

                if norm(S - S_old, 'fro') < tol
                    break;
                end
            end

            scores = S .* obj.train_mask;
        end

        function scores = matrix_factorization(obj, dim, lambda, max_iter, learning_rate)
            %MATRIX_FACTORIZATION 矩阵分解方法
            %   dim - 潜在特征维度,默认10
            %   lambda - 正则化参数,默认0.01
            %   max_iter - 最大迭代次数,默认100
            %   learning_rate - 学习率,默认0.01

            if nargin < 2
                dim = 10;
            end
            if nargin < 3
                lambda = 0.01;
            end
            if nargin < 4
                max_iter = 100;
            end
            if nargin < 5
                learning_rate = 0.01;
            end

            [n, ~] = size(obj.A);

            % 初始化用户和物品特征矩阵
            U = randn(n, dim) * 0.01;
            V = randn(n, dim) * 0.01;

            % 获取训练集中的非零元素(即存在的边)
            [rows, cols] = find(obj.train_mask);
            values = ones(length(rows), 1);

            % 随机梯度下降
            for iter = 1:max_iter
                total_error = 0;

                for idx = 1:length(rows)
                    i = rows(idx);
                    j = cols(idx);

                    % 计算预测值和误差
                    prediction = U(i, :) * V(j, :)';
                    error = values(idx) - prediction;
                    total_error = total_error + error^2;

                    % 更新特征向量
                    U_i_old = U(i, :);
                    U(i, :) = U(i, :) + learning_rate * (error * V(j, :) - lambda * U(i, :));
                    V(j, :) = V(j, :) + learning_rate * (error * U_i_old - lambda * V(j, :));
                end

                % 添加正则化项
                total_error = total_error + lambda * (norm(U, 'fro')^2 + norm(V, 'fro')^2);

                if mod(iter, 10) == 0
                    fprintf('迭代 %d, 误差: %.4f\n', iter, total_error);
                end
            end

            % 计算得分矩阵
            scores = U * V';
            scores = scores .* obj.train_mask;
        end

        function [precision, recall, auc] = evaluate(obj, scores, top_k)
            %EVALUATE 评估预测结果
            %   scores - 预测得分矩阵
            %   top_k - 计算precision@k和recall@k的k值

            if nargin < 3
                top_k = 100;
            end

            % 获取测试集中的正样本
            [test_rows, test_cols] = find(obj.test_mask);
            positive_pairs = [test_rows, test_cols];
            num_positives = size(positive_pairs, 1);

            % 获取所有未连接的节点对(负样本+测试正样本)
            negative_mask = (obj.train_mask == 0) & (obj.A == 0) & (eye(size(obj.A)) == 0);
            [negative_rows, negative_cols] = find(negative_mask);
            negative_pairs = [negative_rows, negative_cols];

            % 随机选择与正样本数量相同的负样本
            idx = randperm(size(negative_pairs, 1), num_positives);
            negative_pairs = negative_pairs(idx, :);

            % 合并正负样本
            all_pairs = [positive_pairs; negative_pairs];
            labels = [ones(num_positives, 1); zeros(num_positives, 1)];

            % 获取预测得分
            pred_scores = zeros(size(all_pairs, 1), 1);
            for i = 1:size(all_pairs, 1)
                pred_scores(i) = scores(all_pairs(i, 1), all_pairs(i, 2));
            end

            % 计算AUC
            [~, ~, ~, auc] = perfcurve(labels, pred_scores, 1);

            % 计算Precision@K和Recall@K
            % 获取得分最高的top_k个节点对
            [~, sorted_idx] = sort(pred_scores(1:num_positives), 'descend');
            top_predictions = sorted_idx(1:min(top_k, length(sorted_idx)));

            true_positives = sum(ismember(top_predictions, 1:num_positives));
            precision = true_positives / top_k;
            recall = true_positives / num_positives;
        end

        function results = compare_methods(obj, methods, top_k)
            %COMPARE_METHODS 比较不同算法的性能
            %   methods - 要比较的方法列表
            %   top_k - 计算precision@k和recall@k的k值

            if nargin < 2
                methods = obj.methods;
            end
            if nargin < 3
                top_k = 100;
            end

            results = struct();

            for i = 1:length(methods)
                method = methods{
   i};
                fprintf('正在计算 %s...\n', method);

                try
                    % 调用相应的方法
                    tic;
                    scores = obj.(lower(method))();
                    time = toc;

                    % 评估性能
                    [precision, recall, auc] = obj.evaluate(scores, top_k);

                    % 保存结果
                    results.(method).scores = scores;
                    results.(method).precision = precision;
                    results.(method).recall = recall;
                    results.(method).auc = auc;
                    results.(method).time = time;

                    fprintf('%s: Precision@%d=%.4f, Recall@%d=%.4f, AUC=%.4f, 时间=%.2fs\n', ...
                            method, top_k, precision, top_k, recall, auc, time);

                catch ME
                    fprintf('计算 %s 时出错: %s\n', method, ME.message);
                    results.(method).error = ME.message;
                end
            end
        end

        function plot_results(obj, results)
            %PLOT_RESULTS 可视化比较结果
            methods = fieldnames(results);
            num_methods = length(methods);

            precisions = zeros(num_methods, 1);
            recalls = zeros(num_methods, 1);
            aucs = zeros(num_methods, 1);
            times = zeros(num_methods, 1);

            for i = 1:num_methods
                if isfield(results.(methods{
   i}), 'error')
                    continue;
                end
                precisions(i) = results.(methods{
   i}).precision;
                recalls(i) = results.(methods{
   i}).recall;
                aucs(i) = results.(methods{
   i}).auc;
                times(i) = results.(methods{
   i}).time;
            end

            % 创建图形
            figure('Position', [100, 100, 1200, 800]);

            % 绘制精确度
            subplot(2, 2, 1);
            bar(precisions);
            set(gca, 'XTickLabel', methods, 'XTickLabelRotation', 45);
            title('Precision@K');
            ylabel('Precision');
            grid on;

            % 绘制召回率
            subplot(2, 2, 2);
            bar(recalls);
            set(gca, 'XTickLabel', methods, 'XTickLabelRotation', 45);
            title('Recall@K');
            ylabel('Recall');
            grid on;

            % 绘制AUC
            subplot(2, 2, 3);
            bar(aucs);
            set(gca, 'XTickLabel', methods, 'XTickLabelRotation', 45);
            title('AUC');
            ylabel('AUC');
            grid on;

            % 绘制运行时间
            subplot(2, 2, 4);
            bar(times);
            set(gca, 'XTickLabel', methods, 'XTickLabelRotation', 45);
            title('运行时间');
            ylabel('时间 (秒)');
            grid on;

            % 调整布局
            set(gcf, 'Color', 'w');
        end
    end
end

% 示例使用代码
function example_usage()
    % 生成示例网络(无标度网络)
    n = 100; % 节点数量
    A = create_scale_free_network(n);

    % 创建链路预测对象,使用80%的边作为训练集
    lp = LinkPrediction(A, 0.8);

    % 比较所有方法的性能
    results = lp.compare_methods();

    % 可视化结果
    lp.plot_results(results);

    % 单独使用某个方法
    scores = lp.common_neighbors();
    [precision, recall, auc] = lp.evaluate(scores);
    fprintf('\nCommon Neighbors: Precision=%.4f, Recall=%.4f, AUC=%.4f\n', precision, recall, auc);
end

function A = create_scale_free_network(n)
    %CREATE_SCALE_FREE_NETWORK 生成无标度网络(Barabási-Albert模型)
    %   n - 网络节点数

    % 初始完全图
    m0 = 5; % 初始节点数
    A = zeros(n);
    A(1:m0, 1:m0) = ones(m0) - eye(m0);

    % 添加新节点
    for i = m0+1:n
        % 计算现有节点的度
        degrees = sum(A(1:i-1, 1:i-1), 2);
        total_degree = sum(degrees);

        % 根据度分布选择连接节点
        if total_degree > 0
            prob = degrees / total_degree;
            targets = randsample(1:i-1, m0, true, prob);
        else
            targets = randperm(i-1, min(m0, i-1));
        end

        % 添加连接
        for j = targets
            A(i, j) = 1;
            A(j, i) = 1;
        end
    end
end

% 运行示例
example_usage();

说明

这个MATLAB链路预测程序提供了以下功能:

1. 核心类 LinkPrediction

包含多种链路预测算法的实现,以及评估和比较功能。

2. 实现的算法

  1. Common Neighbors (共同邻居):基于两个节点共同邻居的数量
  2. Jaccard Coefficient:共同邻居数除以总邻居数
  3. Adamic-Adar:考虑共同邻居的度,度越小权重越大
  4. Preferential Attachment:基于两个节点的度乘积
  5. Katz Index:考虑所有路径,路径越短权重越大
  6. Rooted PageRank:基于随机游走的相似性度量
  7. SimRank:基于结构上下文的相似性度量
  8. Matrix Factorization:基于矩阵分解的潜在特征方法

3. 评估指标

  • Precision@K:前K个预测中正确预测的比例
  • Recall@K:正确预测的正样本占所有正样本的比例
  • AUC:ROC曲线下面积,衡量分类器整体性能

4. 可视化功能

提供四种评估指标的可视化比较,便于分析不同算法的性能。

推荐代码 链路预测程序,主程序,包含31个链路预测的函数代码 www.youwenfan.com/contentalj/52463.html

使用

程序最后提供了一个示例使用代码:

  1. 生成一个无标度网络(Barabási-Albert模型)
  2. 创建链路预测对象,划分训练集和测试集
  3. 比较所有算法的性能
  4. 可视化比较结果
  5. 单独使用Common Neighbors算法并进行评估
相关文章
|
10天前
|
缓存 测试技术 API
Qwen 3.7 Plus 与 Max 实测:性价比与多模态能力差异解析(2026)
2026 年 6 月 1 日,阿里悄无声息地发布了 Qwen 3.7 Plus,距 Qwen 3.7 Max 上线刚好 11 天。同样的 1M 上下文,同样的 35 小时自治上限。但价格才是头条:Plus 是 0.40/M输入,Max是 2.50/M——便宜约 6 倍——并且还能看图、看视频。Vision Arena 上 Plus 已经排到 #16。所以这周真正值得讨论的问题不是”要不要为视觉能力买单”,而是”Max 凭什么用 6 倍价格换来 2 个百分点的 benchmark 领先”。
|
11天前
|
JavaScript 定位技术 API
CodeGraph 爆火:编程 Agent 需要的不是更多上下文,而是一张提前画好的代码地图
CodeGraph 是一款爆火的本地代码智能工具,通过 tree-sitter 解析 AST 构建结构化知识图谱(存于 SQLite),为编程 Agent 提前生成“代码地图”。它显著降低 Agent 在中大型项目中的探索成本——实测工具调用减少71%、Token 降57%、速度提升46%,支持19+语言及主流框架路由识别,完全离线、无需 API Key。
842 11
CodeGraph 爆火:编程 Agent 需要的不是更多上下文,而是一张提前画好的代码地图
|
11天前
|
人工智能 运维 JavaScript
阿里云Qoder CN(原通义灵码)全解析 产品形态、版本划分与技术适配说明
在AI辅助开发与智能办公工具持续普及的当下,阿里云旗下原通义灵码正式更名为Qoder CN,同时延伸出QoderWork CN、Qoder CN CLI、Qoder CN Mobile等多款配套产品,形成覆盖代码开发、日常办公、终端交互、移动端使用的完整工具矩阵。Qoder CN核心定位为AI智能编码助手,深度适配主流代码编辑器、集成开发环境以及终端场景;QoderWork CN则偏向桌面端综合办公辅助,二者面向不同使用场景,划分了多个版本档位,搭配差异化资源配额、功能权限与计费规则,同时兼容多款主流大模型。
848 7
|
11天前
|
存储 安全 Java
AgentScope Java 2.0:打造分布式、企业级智能体底座
AgentScope 2.0 面向分布式部署、稳定运行、权限安全等企业级需求全面升级,打造支持多租户隔离与长期稳定运行的企业级智能体底座。
|
11天前
|
JSON 缓存 安全
通过 CC Switch 本地路由让 Codex CLI 接入 DeepSeek 等第三方模型
CC Switch 通过本地路由(`127.0.0.1:15721`)实现协议转换:将 Codex 的 Responses API 请求自动映射为 DeepSeek 等厂商的 Chat Completions 接口,兼容流式响应与工具调用,无需修改 Codex 源码,安全隔离 API Key。(239字)
2273 4
通过 CC Switch 本地路由让 Codex CLI 接入 DeepSeek 等第三方模型
|
11天前
|
人工智能 弹性计算 安全
阿里云618活动时间、活动入口、优惠活动详细解读
2026年阿里云618创新加速季已全面开启,作为年度力度最大的云产品促销活动,本次大促覆盖轻量应用服务器、ECS云服务器、GPU云服务器、数据库、AI算力、安全服务、CDN等全品类产品,推出5亿元算力补贴、新用户限时秒杀、普惠满减、企业专享、免费试用、云大使返佣等多重福利,个人开发者、中小企业、AI团队均可享受专属低价。本文将系统梳理2026年阿里云618活动的完整时间节点、官方参与入口、各类优惠细则、使用规则、热门产品推荐及实操代码,帮助用户精准参与、高效省钱,以最低成本完成上云部署。
1872 6
|
11天前
|
数据采集 人工智能 前端开发
让 Coding Agent 从黑盒到透明:阿里云 Agent 观测审计数据采集实践
AI Agent 规模化落地带来执行黑盒、行为难追溯、成本难度量三大难题。阿里云基于 OTel 标准,面向 Coding Agent、个人通用助理和框架型 Agent,推出 LoongSuite Pilot、插件及探针等无侵入采集方案,让 Agent 实现可看见、可分析、可审计、可治理。
784 150
|
11天前
|
人工智能 运维 自然语言处理
阿里云百炼Qwen3.7-Max模型详解:综合能力、核心优势与订阅计划参考指南
2026年,大模型技术持续向通用化、高性能、场景化方向迭代,阿里云百炼作为一站式大模型服务平台,持续推出迭代升级的模型产品,Qwen3.7-Max便是当前主力旗舰级大模型之一。该模型依托深度优化的底层架构与大规模训练数据,在文本理解、逻辑推理、多模态交互、代码生成、长文本处理等多个维度实现能力升级,同时搭配灵活的订阅计划体系,能够适配个人开发者、中小企业、大型企业、政企机构等不同类型用户的使用需求。
632 2