AFTFull
// AFTFull: block diagram of the "full" Attention Free Transformer layer.
// The input is projected to Q, K and V; K and V are combined with the
// learned pairwise position bias W inside atten_op; the result is gated by
// Sigmoid(Q) and projected back to HidSize by LinearO.
digraph AFTFull {
    // Draw bottom-to-top: the input node sits at the bottom of the picture.
    rankdir=BT;

    // Defaults shared by every node. SimHei is a CJK-capable font, required
    // because the input/output labels are written in Chinese.
    node [
        style=filled,
        color=Black,
        fontcolor=White,
        fillcolor="#30638e",
        fontname="SimHei",
        fontsize=32,
        width=5,
        height=2
    ];

    // --- Nodes -----------------------------------------------------------
    // Rounded "Mrecord" nodes are tensors/parameters, "box" nodes are ops.

    // "输入" = input.
    inp [label="输入\n[BatchSize,\n SeqLen,\n HidSize]", shape="Mrecord"];

    // Linear projections HidSize -> ProjSize.
    llq [label="LinearQ\n[HidSize, ProjSize]", shape="box"];
    llk [label="LinearK\n[HidSize, ProjSize]", shape="box"];
    llv [label="LinearV\n[HidSize, ProjSize]", shape="box"];

    // Learned position-bias parameter, one entry per (target, source) pair.
    w [label="W:Param\n[SeqLen, SeqLen]", shape="Mrecord"];

    q [label="Q\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"];
    k [label="K\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"];
    v [label="V\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"];

    σ [label="Sigmoid", shape="box", width=3];

    // Weighted average of V over source positions. Numerator and denominator
    // are both matrix products with exp(W) (contraction over the source
    // position axis), so the denominator uses "@" and is parenthesised.
    atten_op [label="exp(W) @ (exp(K) * V)\n/ (exp(W) @ exp(K))", shape="box"];
    atten [label="[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"];

    // Element-wise gate: Sigmoid(Q) * atten.
    mul [label="*", shape="box", width=3];

    // Output projection ProjSize -> HidSize.
    llo [label="LinearO\n[ProjSize, HidSize]", shape="box"];

    // "输出" = output.
    oup [label="输出\n[BatchSize,\n SeqLen,\n HidSize]", shape="Mrecord"];

    // --- Edges (declaration order kept stable to keep the layout stable) --
    inp -> llq;
    inp -> llk;
    inp -> llv;
    llq -> q;
    llk -> k;
    llv -> v;
    q -> σ;
    w -> atten_op;
    k -> atten_op;
    v -> atten_op;
    atten_op -> atten;
    σ -> mul;
    atten -> mul;
    mul -> llo;
    llo -> oup;
}
AFTSimple
/*
 * AFTSimple: block diagram of the "simple" Attention Free Transformer layer.
 * No position-bias parameter appears in this graph: K is soft-maxed along
 * dimension 1, multiplied with V and summed along dimension 1, and the result
 * is gated by Sigmoid(Q) before the output projection.
 */
digraph AFTSimple {
    /* Bottom-to-top layout: input at the bottom, output at the top. */
    graph [rankdir=BT];

    /* Node defaults; SimHei is a CJK-capable font for the Chinese labels. */
    node [style=filled, color=Black, fontcolor=White, fillcolor="#30638e",
          fontname=SimHei, fontsize=32, width=5, height=2];

    /* Input tensor ("输入" = input). */
    inp [shape=Mrecord, label="输入\n[BatchSize,\n SeqLen,\n HidSize]"];

    /* Q / K / V projection layers. */
    llq [shape=box, label="LinearQ\n[HidSize, ProjSize]"];
    llk [shape=box, label="LinearK\n[HidSize, ProjSize]"];
    llv [shape=box, label="LinearV\n[HidSize, ProjSize]"];

    /* Projected tensors. */
    q [shape=Mrecord, label="Q\n[BatchSize,\n SeqLen,\n ProjSize]"];
    k [shape=Mrecord, label="K\n[BatchSize,\n SeqLen,\n ProjSize]"];
    v [shape=Mrecord, label="V\n[BatchSize,\n SeqLen,\n ProjSize]"];

    /* Gate on Q. */
    σ [shape=box, width=3, label="Sigmoid"];

    /* Context computation and its result tensor. */
    atten_op [shape=box, label="sum(softmax(K, 1) * V, 1)"];
    atten    [shape=Mrecord, label="[BatchSize, 1, ProjSize]"];

    /* Gate application, output projection, output ("输出" = output). */
    mul [shape=box, width=3, label="*"];
    llo [shape=box, label="LinearO\n[ProjSize, HidSize]"];
    oup [shape=Mrecord, label="输出\n[BatchSize,\n SeqLen,\n HidSize]"];

    /* Fan-out from the input into the three projections. */
    inp -> llq;  inp -> llk;  inp -> llv;

    /* Each projection yields its tensor. */
    llq -> q;  llk -> k;  llv -> v;

    /* Q goes through the sigmoid gate. */
    q -> σ;

    /* K and V feed the context op; chains keep the original edge order. */
    k -> atten_op;
    v -> atten_op -> atten;

    /* Gate the context, then project to the output. */
    σ -> mul;
    atten -> mul -> llo -> oup;
}
AFTLocal
// AFTLocal: block diagram of the "local" Attention Free Transformer layer.
// Same node set as the full variant plus a `mask` op: the W parameter is routed
// through `mask` (label: abs(i - j) < S ? 1 : 0) before reaching atten_op.
// K and V also feed atten_op; its result is multiplied with Sigmoid(Q) and
// projected by LinearO. Layout is bottom-to-top (rankdir=BT).
// Labels "输入" / "输出" mean "input" / "output".
// NOTE(review): no closing `}` for this digraph is visible in the reviewed
// span — confirm it exists further down, otherwise the graph will not parse.
// NOTE(review): the atten_op label writes the denominator as
// "/ exp(W) * exp(K)"; presumably "/ (exp(W) @ exp(K))" is intended, mirroring
// the numerator's matrix product — confirm against the AFT formulation.
digraph AFTLocal { rankdir=BT node [ style=filled, color=Black fontcolor=White, fillcolor="#30638e", fontname="SimHei", fontsize=32, width=5, height=2, ] inp [label="输入\n[BatchSize,\n SeqLen,\n HidSize]", shape="Mrecord"] llq [label="LinearQ\n[HidSize, ProjSize]", shape="box"] llk [label="LinearK\n[HidSize, ProjSize]", shape="box"] llv [label="LinearV\n[HidSize, ProjSize]", shape="box"] w [label="W:Param\n[SeqLen, SeqLen]", shape="Mrecord"] mask [label="mask\n[SeqLen, SeqLen]\nabs(i - j) < S? 1: 0", shape="box"] q [label="Q\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"] k [label="K\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"] v [label="V\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"] σ [label="Sigmoid", shape="box", width=3] atten_op [label="exp(W) @ (exp(K) * V)\n/ exp(W) * exp(K)", shape="box"] atten [label="[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"] mul [label="*", shape="box", width=3] llo [label="LinearO\n[ProjSize, HidSize]", shape="box"] oup [label="输出\n[BatchSize,\n SeqLen,\n HidSize]", shape="Mrecord"] inp -> llq inp -> llk inp -> llv llq -> q llk -> k llv -> v q -> σ w -> mask mask -> atten_op k -> atten_op v -> atten_op atten_op -> atten σ -> mul atten -> mul mul -> llo llo -> oup


