【CCCC】L3-023 计算图 (30分),dfs搜索+偏导数计算

简介: 【CCCC】L3-023 计算图 (30分),dfs搜索+偏导数计算

problem

L3-023 计算图 (30分)
“计算图”(computational graph)是现代深度学习系统的基础执行引擎,提供了一种表示任意数学表达式的方法,例如用有向无环图表示的神经网络。 图中的节点表示基本操作或输入变量,边表示节点之间的中间值的依赖性。 例如,下图就是一个函数 f(x
​1
​​ ,x
​2
​​ )=lnx
​1
​​ +x
​1
​​ x
​2
​​ −sinx
​2
​​ 的计算图。

figure.png

现在给定一个计算图,请你根据所有输入变量计算函数值及其偏导数(即梯度)。 例如,给定输入x
​1
​​ =2,x
​2
​​ =5,上述计算图获得函数值 f(2,5)=ln(2)+2×5−sin(5)=11.652;并且根据微分链式法则,上图得到的梯度 ∇f=[∂f/∂x
​1
​​ ,∂f/∂x
​2
​​ ]=[1/x
​1
​​ +x
​2
​​ ,x
​1
​​ −cosx
​2
​​ ]=[5.500,1.716]。

知道你已经把微积分忘了,所以这里只要求你处理几个简单的算子:加法、减法、乘法、指数(e
​x
​​ ,即编程语言中的 exp(x) 函数)、对数(lnx,即编程语言中的 log(x) 函数)和正弦函数(sinx,即编程语言中的 sin(x) 函数)。

友情提醒:

常数的导数是 0;x 的导数是 1;e
​x
​​ 的导数还是 e
​x
​​ ;lnx 的导数是 1/x;sinx 的导数是 cosx。
回顾一下什么是偏导数:在数学中,一个多变量的函数的偏导数,就是它关于其中一个变量的导数而保持其他变量恒定。在上面的例子中,当我们对 x
​1
​​ 求偏导数 ∂f/∂x
​1
​​ 时,就将 x
​2
​​ 当成常数,所以得到 lnx
​1
​​ 的导数是 1/x
​1
​​ ,x
​1
​​ x
​2
​​ 的导数是 x
​2
​​ ,sinx
​2
​​ 的导数是 0。
回顾一下链式法则:复合函数的导数是构成复合这有限个函数在相应点的导数的乘积,即若有 u=f(y),y=g(x),则 du/dx=du/dy⋅dy/dx。例如对 sin(lnx) 求导,就得到 cos(lnx)⋅(1/x)。
如果你注意观察,可以发现在计算图中,计算函数值是一个从左向右进行的计算,而计算偏导数则正好相反。

输入格式:
输入在第一行给出正整数 N(≤5×10
​4
​​ ),为计算图中的顶点数。

以下 N 行,第 i 行给出第 i 个顶点的信息,其中 i=0,1,⋯,N−1。第一个值是顶点的类型编号,分别为:

0 代表输入变量
1 代表加法,对应 x
​1
​​ +x
​2
​​
2 代表减法,对应 x
​1
​​ −x
​2
​​
3 代表乘法,对应 x
​1
​​ ×x
​2
​​
4 代表指数,对应 e
​x
​​
5 代表对数,对应 lnx
6 代表正弦函数,对应 sinx
对于输入变量,后面会跟它的双精度浮点数值;对于单目算子,后面会跟它对应的单个变量的顶点编号(编号从 0 开始);对于双目算子,后面会跟它对应两个变量的顶点编号。

题目保证只有一个输出顶点(即没有出边的顶点,例如上图最右边的 -),且计算过程不会超过双精度浮点数的计算精度范围。

输出格式:
首先在第一行输出给定计算图的函数值。在第二行顺序输出函数对于每个变量的偏导数的值,其间以一个空格分隔,行首尾不得有多余空格。偏导数的输出顺序与输入变量的出现顺序相同。输出小数点后 3 位。

输入样例:
7
0 2.0
0 5.0
5 0
3 0 1
6 1
1 2 3
2 5 4
输出样例:
11.652
5.500 1.716

  • 给出一张图(AOV网?)
  • 计算汇点的值

solution

  • 考虑直接搜索+记忆化
  • 结果节点编号不一定为n-1,否则会WA3,4,5
  • 微积分没学,计算公式不太会,参考了一下别人的
#include<bits/stdc++.h>
using namespace std;
const int maxn = 5e4+10;

struct node{
    int op, left, right; //运算符和数值
    double val; //当前节点的值
    int post;  //后继节点的
}a[maxn];
map<int,map<int,map<int,double>>>f;//记忆化数组

//第一个参数为结点,第二个参数决定是否求导,第三个参数是对谁求导
double calc(int nd, int key, int p){
    if(f[nd][key][p])return f[nd][key][p];
    int id = a[nd].op;
    if(id==0)return f[nd][key][p] = (key == 0 ? a[nd].val : (nd == p ? 1 : 0)); 
    if(id==1)return f[nd][key][p] = calc(a[nd].left, key, p) + calc(a[nd].right, key, p); 
    if(id==2)return f[nd][key][p]= calc(a[nd].left, key, p) - calc(a[nd].right, key, p); 
    if(id==3)return f[nd][key][p] = (key ? calc(a[nd].left, key, p) * calc(a[nd].right, 0, p) + calc(a[nd].left, 0, p) * calc(a[nd].right, key, p) : calc(a[nd].left, key, p) * calc(a[nd].right, key, p)); 
    if(id==4)return f[nd][key][p]=(key ? exp(calc(a[nd].left, 0, p)) * calc(a[nd].left, key, p) : exp(calc(a[nd].left, key, p)));
    if(id==5)return f[nd][key][p] = (key ? 1 / (calc(a[nd].left, 0, p)) * (calc(a[nd].left, key, p)) : log(calc(a[nd].left, key, p)));
    if(id==6)return f[nd][key][p] = (key ? cos(calc(a[nd].left, 0, p)) * calc(a[nd].left, key, p) : sin(calc(a[nd].left, key, p)));
}

int main(){
    int n;  cin>>n;
    for(int i = 0; i < n; i++){
        cin>>a[i].op;
        if(a[i].op==0){
            cin>>a[i].val;
        }else if(a[i].op<=3){
            cin>>a[i].left>>a[i].right;
            a[a[i].left].post = 1;
            a[a[i].right].post = 1;
        }else{
            cin>>a[i].left;
            a[a[i].left].post = 1;
        }
    }
    int ed = 0, ok=0;
    while(a[ed].post)ed++;
    printf("%0.3lf\n",calc(ed,0,-1));
    for(int i = 0; i < n; i++){
        if(a[i].op==0){
            if(ok)cout<<" ";
            printf("%0.3lf",calc(ed,1,i));
            ok = 1;
        }
    }
    return 0;
}
目录
相关文章
|
3月前
|
机器学习/深度学习 人工智能 算法
【代数学作业1完整版-python实现GNFS一般数域筛】构造特定的整系数不可约多项式:涉及素数、模运算和优化问题
【代数学作业1完整版-python实现GNFS一般数域筛】构造特定的整系数不可约多项式:涉及素数、模运算和优化问题
60 0
|
3月前
|
机器学习/深度学习 人工智能 算法
【代数学作业1-python实现GNFS一般数域筛】构造特定的整系数不可约多项式:涉及素数、模运算和优化问题
【代数学作业1-python实现GNFS一般数域筛】构造特定的整系数不可约多项式:涉及素数、模运算和优化问题
54 0
|
5天前
【视频】什么是非线性模型与R语言多项式回归、局部平滑样条、 广义相加GAM分析工资数据|数据分享(上)
【视频】什么是非线性模型与R语言多项式回归、局部平滑样条、 广义相加GAM分析工资数据|数据分享
14 0
|
5天前
【视频】什么是非线性模型与R语言多项式回归、局部平滑样条、 广义相加GAM分析工资数据|数据分享(下)
【视频】什么是非线性模型与R语言多项式回归、局部平滑样条、 广义相加GAM分析工资数据|数据分享
|
4月前
油管公式(全)
油管公式(全)
56 0
|
5月前
|
算法 定位技术
插值、平稳假设、本征假设、变异函数、基台、块金、克里格、线性无偏最优…地学计算概念及公式推导
插值、平稳假设、本征假设、变异函数、基台、块金、克里格、线性无偏最优…地学计算概念及公式推导
|
7月前
|
算法
华为机试HJ70:矩阵乘法计算量估算
华为机试HJ70:矩阵乘法计算量估算
|
9月前
wustojc4010按公式计算y和z的值
wustojc4010按公式计算y和z的值
54 0
|
10月前
|
机器学习/深度学习 移动开发
线性代数高级--二次型--特征值与特征向量--特征值分解--多元函数的泰勒展开
线性代数高级--二次型--特征值与特征向量--特征值分解--多元函数的泰勒展开
|
数据可视化 JavaScript 前端开发
【数学篇】07 # 如何用向量和参数方程描述曲线?
【数学篇】07 # 如何用向量和参数方程描述曲线?
81 0
【数学篇】07 # 如何用向量和参数方程描述曲线?