gopacket reassembly source code analysis
Usage
See example/reassemblydump for a reference implementation.
- Define a custom factory implementing the New method of reassembly.StreamFactory:
type tcpStreamFactory struct {
wg sync.WaitGroup
doHTTP bool
}
func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream {
// ... construct and return a reassembly.Stream (elided here; see the sketch after this list)
}
- Define a custom stream implementing the reassembly.Stream interface (a minimal sketch follows this list):
type Stream interface {
// Accept decides whether to take this segment into the stream
Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir TCPFlowDirection, nextSeq Sequence, start *bool, ac AssemblerContext) bool
// ReassembledSG delivers reassembled payload to be read
ReassembledSG(sg ScatterGather, ac AssemblerContext)
// ReassemblyComplete handles the end of the stream
ReassemblyComplete(ac AssemblerContext) bool
}
- Feed the TCP packets into the assembler:
func main() {
defer util.Run()()
// 1. Open the capture device
var handle *pcap.Handle
var err error
handle, err = pcap.OpenLive(*iface, int32(*snaplen), true, pcap.BlockForever)
if err != nil {
log.Fatal(err)
}
// Set the BPF filter
if err := handle.SetBPFFilter(*filter); err != nil {
log.Fatal(err)
}
// 2. Initialize the assembler
streamFactory := &tcpStreamFactory{doHTTP: !*nohttp}
streamPool := reassembly.NewStreamPool(streamFactory)
assembler := reassembly.NewAssembler(streamPool)
log.Println("reading in packets")
// 3. Initialize the packetSource
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
packets := packetSource.Packets()
ticker := time.Tick(time.Second)
for {
select {
// 4. Read packets
case packet := <-packets:
// A nil packet indicates the end of a pcap file.
if packet == nil {
return
}
if *logAllPackets {
log.Println(packet)
}
if packet.NetworkLayer() == nil || packet.TransportLayer() == nil || packet.TransportLayer().LayerType() != layers.LayerTypeTCP {
log.Println("Unusable packet")
continue
}
tcp := packet.TransportLayer().(*layers.TCP)
// 5. Hand the TCP segment straight to the assembler, together with an
// AssemblerContext carrying the packet's capture info
c := Context{CaptureInfo: packet.Metadata().CaptureInfo}
assembler.AssembleWithContext(packet.NetworkLayer().NetworkFlow(), tcp, &c)
case <-ticker:
// 6. Periodically flush stale connections
// (ref, timeout and closeTimeout are defined elsewhere in the example)
flushed, closed := assembler.FlushWithOptions(reassembly.FlushOptions{T: ref.Add(-timeout), TC: ref.Add(-closeTimeout)})
log.Printf("forced flush: %d flushed, %d closed", flushed, closed)
}
}
}
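To fill in what the excerpt above elides, here is a minimal sketch: the Context type passed as the AssemblerContext (&c) in the loop, a bare-bones Stream that accepts everything and just counts bytes, and a matching factory New body. Only the gopacket/reassembly identifiers are real; the rest is illustrative (the real example's stream additionally parses HTTP).
// Context carries a packet's capture info into the assembler; it satisfies
// reassembly.AssemblerContext.
type Context struct {
	CaptureInfo gopacket.CaptureInfo
}

func (c *Context) GetCaptureInfo() gopacket.CaptureInfo {
	return c.CaptureInfo
}

// tcpStream is a minimal reassembly.Stream: accept every segment, count the bytes.
type tcpStream struct {
	net, transport gopacket.Flow
	bytes          int
}

func (t *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool {
	return true // a real implementation would check checksums, FSM state, etc.
}

func (t *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) {
	length, _ := sg.Lengths()
	t.bytes += len(sg.Fetch(length))
}

func (t *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool {
	log.Printf("%v:%v closed after %d bytes", t.net, t.transport, t.bytes)
	return false // false: keep the connection in the pool (e.g. to see the last ACK)
}

func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream {
	return &tcpStream{net: net, transport: transport}
}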
Assembler
The Assembler handles reordering and reassembly of TCP streams. It is not safe for concurrent use: after passing a packet in via Assemble you must wait for the call to return before calling Assemble again. Concurrency is achieved by creating multiple Assemblers that share a StreamPool.
Assembler provides (hopefully) fast TCP stream reassembly for sniffing applications written in Go. It uses the following methods to process packets as quickly as possible:
- Avoid locking:
Assemblers lock connections, but each connection has its own lock, and two Assemblers will rarely look at the same connection. Assemblers lock the StreamPool when looking up connections, but they take a reader lock first and only force a write lock when a new connection has to be created or an existing one closed. Both happen far less often than per-packet handling.
Each Assembler runs in its own goroutine, and the only state shared between goroutines is the StreamPool, so all internal Assembler state can be handled without any locking.
NOTE: if you can guarantee that the packets sent to a set of Assemblers carry information about different connections per Assembler (for example, they have already been hashed by PF_RING or some other hashing mechanism), we recommend a separate StreamPool per Assembler, avoiding all lock contention. A StreamPool should only be shared between Assemblers when different Assemblers can receive packets for the same Stream. (A sketch of this model follows the list.)
- Avoid memory copying:
In the common case, handling a single TCP packet causes zero memory allocations. The Assembler looks up the connection, determines that the packet arrived in order, and immediately passes it to the appropriate connection-handling code. Only when a packet arrives out of order is its content copied and stored for later.
- Avoid memory allocation:
The Assembler tries hard not to allocate unless absolutely necessary. Data from in-order packets is passed directly to the Stream with no copying or allocation. Data from out-of-order packets is copied into reusable pages, and new pages are only rarely allocated when the page cache runs out. The pageCache is specific to one Assembler, so it is never used concurrently and needs no locking.
The internal representation of conn objects is also reused over time. As a result, the most common allocation the Assembler triggers is usually the one done by the caller in StreamFactory.New. If nothing is allocated there, very little allocation happens at all, mostly to absorb large increases in bandwidth or connection count.
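The concurrency model this implies, as a minimal sketch (numWorkers, workerChans and the nil-layer checks are assumed or omitted): one Assembler per goroutine, all sharing a single StreamPool, so the only cross-goroutine state sits behind the pool's locks.
streamPool := reassembly.NewStreamPool(&tcpStreamFactory{})
var wg sync.WaitGroup
for i := 0; i < numWorkers; i++ {
	wg.Add(1)
	go func(in <-chan gopacket.Packet) {
		defer wg.Done()
		// One Assembler per goroutine: its internal state needs no locking.
		assembler := reassembly.NewAssembler(streamPool)
		for packet := range in {
			tcp := packet.TransportLayer().(*layers.TCP)
			c := Context{CaptureInfo: packet.Metadata().CaptureInfo}
			assembler.AssembleWithContext(packet.NetworkLayer().NetworkFlow(), tcp, &c)
		}
		assembler.FlushAll()
	}(workerChans[i])
}
wg.Wait()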
type AssemblerOptions struct {
// Maximum number of pages to buffer in total while waiting for out-of-order packets.
// Once this limit is reached, the Assembler degrades to flushing every connection it gets a packet for. Ignored if <= 0.
MaxBufferedPagesTotal int
// Maximum number of pages to buffer for a single connection.
// If this limit is reached, the lowest sequence number (plus any contiguous data) is flushed. Ignored if <= 0.
MaxBufferedPagesPerConnection int
}
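Because AssemblerOptions is embedded in Assembler, the limits can be set directly on a freshly created assembler; the numbers below are arbitrary examples, not recommendations.
assembler := reassembly.NewAssembler(streamPool)
// Bound the memory spent on out-of-order data: at most 4096 pages overall,
// and at most 256 pages buffered for any single connection.
assembler.MaxBufferedPagesTotal = 4096
assembler.MaxBufferedPagesPerConnection = 256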
type Assembler struct {
AssemblerOptions // options
ret []byteContainer // data queued for delivery to the Stream
pc *pageCache // cache of pages for out-of-order data
connPool *StreamPool // pool of connections, possibly shared with other Assemblers
cacheLP livePacket
cacheSG reassemblyObject
start bool
}
func NewAssembler(pool *StreamPool) *Assembler {
pool.mu.Lock()
pool.users++
pool.mu.Unlock()
return &Assembler{
ret: make([]byteContainer, 0, assemblerReturnValueInitialSize),
pc: newPageCache(),
connPool: pool,
AssemblerOptions: DefaultAssemblerOptions,
}
}
AssembleWithContext
AssembleWithContext reassembles the given TCP packet into its appropriate Stream.
The timestamp passed in must be the time the packet was seen. For packets read off the wire, time.Now() is fine. For packets read from a PCAP file, CaptureInfo.Timestamp should be passed in. This timestamp affects which streams are flushed by a call to FlushCloseOlderThan.
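For a PCAP replay, the capture timestamps rather than the wall clock should therefore drive flushing. A minimal sketch, assuming lastTimestamp holds the CaptureInfo.Timestamp of the most recently processed packet and using an arbitrary two-minute cutoff:
flushed, closed := assembler.FlushCloseOlderThan(lastTimestamp.Add(-2 * time.Minute))
log.Printf("flush: %d flushed, %d closed", flushed, closed)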
func (a *Assembler) AssembleWithContext(netFlow gopacket.Flow, t *layers.TCP, ac AssemblerContext) {
var conn *connection
var half *halfconnection
var rev *halfconnection
a.ret = a.ret[:0]
key := key{netFlow, t.TransportFlow()}
ci := ac.GetCaptureInfo()
timestamp := ci.Timestamp
// Look up or create the connection
conn, half, rev = a.connPool.getConnection(key, false, timestamp, t, ac)
if conn == nil {
if *debugLog {
log.Printf("%v got empty packet on otherwise empty connection", key)
}
return
}
// Does the lock really need to cover this whole scope?
conn.mu.Lock()
defer conn.mu.Unlock()
if half.lastSeen.Before(timestamp) {
half.lastSeen = timestamp
}
a.start = half.nextSeq == invalidSequence && t.SYN
if *debugLog {
if half.nextSeq < rev.ackSeq {
log.Printf("Delay detected on %v, data is acked but not assembled yet (acked %v, nextSeq %v)", key, rev.ackSeq, half.nextSeq)
}
}
// Ask the Stream whether this segment should be dropped outright
if !half.stream.Accept(t, ci, half.dir, half.nextSeq, &a.start, ac) {
if *debugLog {
log.Printf("Ignoring packet")
}
return
}
// This half of the connection is closed; do nothing
if half.closed {
// this way is closed
if *debugLog {
log.Printf("%v got packet on closed half", key)
}
return
}
seq, ack, bytes := Sequence(t.Seq), Sequence(t.Ack), t.Payload
if t.ACK {
half.ackSeq = ack
}
// TODO: push when Ack is seen ??
action := assemblerAction{
nextSeq: Sequence(invalidSequence),
queue: true,
}
a.dump("AssembleWithContext()", half)
if half.nextSeq == invalidSequence {
// Normally only the first packet of a flow arrives with nextSeq == invalidSequence.
// Only a SYN (or a forced start) is handled immediately here; everything else is queued on the connection until the stream has a start.
if t.SYN {
if *debugLog {
log.Printf("%v saw first SYN packet, returning immediately, seq=%v", key, seq)
}
seq = seq.Add(1)
half.nextSeq = seq
action.queue = false
} else if a.start {
if *debugLog {
log.Printf("%v start forced", key)
}
half.nextSeq = seq
action.queue = false
} else {
if *debugLog {
log.Printf("%v waiting for start, storing into connection", key)
}
}
} else {
diff := half.nextSeq.Difference(seq)
if diff > 0 {
if *debugLog {
log.Printf("%v gap in sequence numbers (%v, %v) diff %v, storing into connection", key, half.nextSeq, seq, diff)
}
} else {
if *debugLog {
log.Printf("%v found contiguous data (%v, %v), returning immediately: len:%d", key, seq, half.nextSeq, len(bytes))
}
action.queue = false
}
}
action = a.handleBytes(bytes, seq, half, t.SYN, t.RST || t.FIN, action, ac)
if len(a.ret) > 0 {
action.nextSeq = a.sendToConnection(conn, half, ac)
}
if action.nextSeq != invalidSequence {
half.nextSeq = action.nextSeq
if t.FIN {
half.nextSeq = half.nextSeq.Add(1)
}
}
if *debugLog {
log.Printf("%v nextSeq:%d", key, half.nextSeq)
}
}
handleBytes
handleBytes appends a segment that is ready for delivery to ret, or queues it on the connection if it has to wait for earlier data.
func (a *Assembler) handleBytes(bytes []byte, seq Sequence, half *halfconnection, start bool, end bool, action assemblerAction, ac AssemblerContext) assemblerAction {
a.cacheLP.bytes = bytes
a.cacheLP.start = start
a.cacheLP.end = end
a.cacheLP.seq = seq
a.cacheLP.ac = ac
if action.queue {
a.checkOverlap(half, true, ac)
if (a.MaxBufferedPagesPerConnection > 0 && half.pages >= a.MaxBufferedPagesPerConnection) ||
(a.MaxBufferedPagesTotal > 0 && a.pc.used >= a.MaxBufferedPagesTotal) {
if *debugLog {
log.Printf("hit max buffer size: %+v, %v, %v", a.AssemblerOptions, half.pages, a.pc.used)
}
action.queue = false
a.addNextFromConn(half)
}
a.dump("handleBytes after queue", half)
} else {
a.cacheLP.bytes, a.cacheLP.seq = a.overlapExisting(half, seq, seq.Add(len(bytes)), a.cacheLP.bytes)
a.checkOverlap(half, false, ac)
if len(a.cacheLP.bytes) != 0 || end || start {
a.ret = append(a.ret, &a.cacheLP)
}
a.dump("handleBytes after no queue", half)
}
return action
}
checkOverlap
checkOverlap walks the queued out-of-order pages and resolves any overlap between the incoming segment and data already queued (the six cases sketched in the comment below), then splits what remains of the segment into pages and links them into the queue.
func (a *Assembler) checkOverlap(half *halfconnection, queue bool, ac AssemblerContext) {
var next *page
cur := half.last
bytes := a.cacheLP.bytes
start := a.cacheLP.seq
end := start.Add(len(bytes))
a.dump("before checkOverlap", half)
// [s6 : e6]
// [s1:e1][s2:e2] -- [s3:e3] -- [s4:e4][s5:e5]
// [s <--ds-- : --de--> e]
for cur != nil {
if *debugLog {
log.Printf("cur = %p (%s)\n", cur, cur)
}
// end < cur.start: continue (5)
if end.Difference(cur.seq) > 0 {
if *debugLog {
log.Printf("case 5\n")
}
next = cur
cur = cur.prev
continue
}
curEnd := cur.seq.Add(len(cur.bytes))
// start > cur.end: stop (1)
if start.Difference(curEnd) <= 0 {
if *debugLog {
log.Printf("case 1\n")
}
break
}
diffStart := start.Difference(cur.seq)
diffEnd := end.Difference(curEnd)
// end > cur.end && start < cur.start: drop (3)
if diffEnd <= 0 && diffStart >= 0 {
if *debugLog {
log.Printf("case 3\n")
}
if cur.isPacket() {
half.overlapPackets++
}
half.overlapBytes += len(cur.bytes)
// update links
if cur.prev != nil {
cur.prev.next = cur.next
} else {
half.first = cur.next
}
if cur.next != nil {
cur.next.prev = cur.prev
} else {
half.last = cur.prev
}
tmp := cur.prev
half.pages -= cur.release(a.pc)
cur = tmp
continue
}
// end > cur.end && start < cur.end: drop cur's end (2)
if diffEnd < 0 && start.Difference(curEnd) > 0 {
if *debugLog {
log.Printf("case 2\n")
}
cur.bytes = cur.bytes[:-start.Difference(cur.seq)]
break
} else
// start < cur.start && end > cur.start: drop cur's start (4)
if diffStart > 0 && end.Difference(cur.seq) < 0 {
if *debugLog {
log.Printf("case 4\n")
}
cur.bytes = cur.bytes[-end.Difference(cur.seq):]
cur.seq = cur.seq.Add(-end.Difference(cur.seq))
next = cur
} else
// end < cur.end && start > cur.start: replace bytes inside cur (6)
if diffEnd >= 0 && diffStart <= 0 {
if *debugLog {
log.Printf("case 6\n")
}
copy(cur.bytes[-diffStart:-diffStart+len(bytes)], bytes)
bytes = bytes[:0]
} else {
if *debugLog {
log.Printf("no overlap\n")
}
next = cur
}
cur = cur.prev
}
// Split bytes into pages, and insert in queue
a.cacheLP.bytes = bytes
a.cacheLP.seq = start
if len(bytes) > 0 && queue {
p, p2, numPages := a.cacheLP.convertToPages(a.pc, 0, ac)
half.queuedPackets++
half.queuedBytes += len(bytes)
half.pages += numPages
if cur != nil {
if *debugLog {
log.Printf("adding %s after %s", p, cur)
}
cur.next = p
p.prev = cur
} else {
if *debugLog {
log.Printf("adding %s as first", p)
}
half.first = p
}
if next != nil {
if *debugLog {
log.Printf("setting %s as next of new %s", next, p2)
}
p2.next = next
next.prev = p2
} else {
if *debugLog {
log.Printf("setting %s as last", p2)
}
half.last = p2
}
}
a.dump("After checkOverlap", half)
}
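All of the case analysis above is done with Sequence arithmetic, which orders sequence numbers safely across the 32-bit roll-over. A small illustration of the semantics the code relies on (expected results in the comments):
s := reassembly.Sequence(1000)
fmt.Println(s.Difference(s.Add(4))) // 4: s.Add(4) comes after s
fmt.Println(s.Add(4).Difference(s)) // -4: s comes before s.Add(4)
// Difference stays correct across the 2^32 sequence-number wrap, which is why
// the overlap cases compare with it rather than with plain < and >.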
overlapExisting
overlapExisting trims the front of an in-order segment that overlaps bytes already delivered to the Stream (i.e. bytes below half.nextSeq).
func (a *Assembler) overlapExisting(half *halfconnection, start, end Sequence, bytes []byte) ([]byte, Sequence) {
if half.nextSeq == invalidSequence {
// no start yet
return bytes, start
}
diff := start.Difference(half.nextSeq)
if diff == 0 {
return bytes, start
}
s := 0
e := len(bytes)
// TODO: depending on strategy, we might want to shrink half.saved if possible
if e != 0 {
if *debugLog {
log.Printf("Overlap detected: ignoring current packet's first %d bytes", diff)
}
half.overlapPackets++
half.overlapBytes += diff
}
s += diff
if s >= e {
// Completely included in sent
s = e
}
bytes = bytes[s:]
return bytes, half.nextSeq
}
dump
func (a *Assembler) dump(text string, half *halfconnection) {
if !*debugLog {
return
}
log.Printf("%s: dump\n", text)
if half != nil {
p := half.first
if p == nil {
log.Printf(" * half.first = %p, no chunks queued\n", p)
} else {
s := 0
nb := 0
log.Printf(" * half.first = %p, queued chunks:", p)
for p != nil {
log.Printf("\t%s bytes:%s\n", p, hex.EncodeToString(p.bytes))
s += len(p.bytes)
nb++
p = p.next
}
log.Printf("\t%d chunks for %d bytes", nb, s)
}
log.Printf(" * half.last = %p\n", half.last)
log.Printf(" * half.saved = %p\n", half.saved)
p = half.saved
for p != nil {
log.Printf("\tseq:%d %s bytes:%s\n", p.getSeq(), p, hex.EncodeToString(p.bytes))
p = p.next
}
}
log.Printf(" * a.ret\n")
for i, r := range a.ret {
log.Printf("\t%d: %v b:%s\n", i, r.captureInfo(), hex.EncodeToString(r.getBytes()))
}
log.Printf(" * a.cacheSG.all\n")
for i, r := range a.cacheSG.all {
log.Printf("\t%d: %v b:%s\n", i, r.captureInfo(), hex.EncodeToString(r.getBytes()))
}
}
key
The key of a conn: the first element is the network flow, the second the transport flow.
type key [2]gopacket.Flow
Reverse: swaps src and dst
func (k *key) Reverse() key {
return key{
k[0].Reverse(),
k[1].Reverse(),
}
}
String: string form for printing
func (k *key) String() string {
return fmt.Sprintf("%s:%s", k[0], k[1])
}
StreamPool
StreamPool stores all Streams created by Assemblers, allowing multiple Assemblers to work together on stream processing while enforcing that a single Stream receives its data serially. It is safe for concurrent use by multiple Assemblers.
StreamPool handles the creation and storage of the Stream objects used by one or more Assembler objects. When an Assembler finds a new TCP stream, it creates an associated Stream by calling the StreamFactory's New method. From then on (until the stream is closed), that Stream object receives assembled TCP data through the Assembler's calls to the stream's ReassembledSG function.
Like the Assembler, StreamPool tries to minimize allocations. Unlike the Assembler, however, it does have to do some locking to make sure the connection objects it stores can be accessed by multiple Assemblers.
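A shutdown sketch using the assembler and streamPool from the usage example above: flush whatever is still buffered, then dump the connections the pool still tracks.
closed := assembler.FlushAll() // flush all remaining data and close every connection
log.Printf("FlushAll closed %d connections", closed)
streamPool.Dump() // log connections still held by the pool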
type StreamPool struct {
conns map[key]*connection
users int
mu sync.RWMutex
factory StreamFactory
free []*connection
all [][]connection
nextAlloc int
newConnectionCount int64
}
const initialAllocSize = 1024
func NewStreamPool(factory StreamFactory) *StreamPool {
return &StreamPool{
conns: make(map[key]*connection, initialAllocSize),
free: make([]*connection, 0, initialAllocSize),
factory: factory,
nextAlloc: initialAllocSize,
}
}
connections - returns all current connections
func (p *StreamPool) connections() []*connection {
p.mu.RLock()
conns := make([]*connection, 0, len(p.conns))
for _, conn := range p.conns {
conns = append(conns, conn)
}
p.mu.RUnlock()
return conns
}
remove - removes a connection and returns it to the free list
func (p *StreamPool) remove(conn *connection) {
p.mu.Lock()
if _, ok := p.conns[conn.key]; ok {
delete(p.conns, conn.key)
p.free = append(p.free, conn)
}
p.mu.Unlock()
}
grow - allocates a new batch of connections
The first batch is 1024 connections; each later call doubles the batch size.
func (p *StreamPool) grow() {
conns := make([]connection, p.nextAlloc)
p.all = append(p.all, conns)
for i := range conns {
p.free = append(p.free, &conns[i])
}
if *memLog {
log.Println("StreamPool: created", p.nextAlloc, "new connections")
}
p.nextAlloc *= 2
}
Dump - prints the number of remaining connections and each connection's details
func (p *StreamPool) Dump() {
p.mu.Lock()
defer p.mu.Unlock()
log.Printf("Remaining %d connections: ", len(p.conns))
for _, conn := range p.conns {
log.Printf("%v %s", conn.key, conn)
}
}
newConnection - creates a new connection (reusing one from the free list)
func (p *StreamPool) newConnection(k key, s Stream, ts time.Time) (c *connection, h *halfconnection, r *halfconnection) {
if *memLog {
p.newConnectionCount++
if p.newConnectionCount&0x7FFF == 0 {
log.Println("StreamPool:", p.newConnectionCount, "requests,", len(p.conns), "used,", len(p.free), "free")
}
}
if len(p.free) == 0 {
p.grow()
}
index := len(p.free) - 1
c, p.free = p.free[index], p.free[:index]
c.reset(k, s, ts)
return c, &c.c2s, &c.s2c
}
getHalf - looks up a conn by key, trying both directions
func (p *StreamPool) getHalf(k key) (*connection, *halfconnection, *halfconnection) {
conn := p.conns[k]
if conn != nil {
return conn, &conn.c2s, &conn.s2c
}
rk := k.Reverse()
conn = p.conns[rk]
if conn != nil {
return conn, &conn.s2c, &conn.c2s
}
return nil, nil, nil
}
getConnection - looks up a conn, creating it if it does not already exist
func (p *StreamPool) getConnection(k key, end bool, ts time.Time, tcp *layers.TCP, ac AssemblerContext) (*connection, *halfconnection, *halfconnection) {
p.mu.RLock()
conn, half, rev := p.getHalf(k)
p.mu.RUnlock()
if end || conn != nil {
return conn, half, rev
}
s := p.factory.New(k[0], k[1], tcp, ac)
p.mu.Lock()
defer p.mu.Unlock()
conn, half, rev = p.newConnection(k, s, ts)
conn2, half2, rev2 := p.getHalf(k)
if conn2 != nil {
if conn2.key != k {
panic("FIXME: other dir added in the meantime...")
}
// FIXME: delete s ?
return conn2, half2, rev2
}
p.conns[k] = conn
return conn, half, rev
}
connection
type connection struct {
key key // client->server
c2s, s2c halfconnection
mu sync.Mutex
}
reset - reinitializes a (reused) connection
func (c *connection) reset(k key, s Stream, ts time.Time) {
c.key = k
base := halfconnection{
nextSeq: invalidSequence,
ackSeq: invalidSequence,
created: ts,
lastSeen: ts,
stream: s,
}
c.c2s, c.s2c = base, base
c.c2s.dir, c.s2c.dir = TCPDirClientToServer, TCPDirServerToClient
}
lastSeen
func (c *connection) lastSeen() time.Time {
if c.c2s.lastSeen.Before(c.s2c.lastSeen) {
return c.s2c.lastSeen
}
return c.c2s.lastSeen
}
String
func (c *connection) String() string {
return fmt.Sprintf("c2s: %s, s2c: %s", &c.c2s, &c.s2c)
}
halfconnection - one direction of a connection
type halfconnection struct {
dir TCPFlowDirection
pages int // Number of pages used (both in first/last and saved)
saved *page // Doubly-linked list of in-order pages (seq < nextSeq) already given to Stream who told us to keep
first, last *page // Doubly-linked list of out-of-order pages (seq > nextSeq)
nextSeq Sequence // sequence number of in-order received bytes
ackSeq Sequence
created, lastSeen time.Time
stream Stream
closed bool
// for stats
queuedBytes int
queuedPackets int
overlapBytes int
overlapPackets int
}
Dump
func (half *halfconnection) Dump() string {
s := fmt.Sprintf("pages: %d\n"+
"nextSeq: %d\n"+
"ackSeq: %d\n"+
"Seen : %s\n"+
"dir: %s\n", half.pages, half.nextSeq, half.ackSeq, half.lastSeen, half.dir)
nb := 0
for p := half.first; p != nil; p = p.next {
s += fmt.Sprintf(" Page[%d] %s len: %d\n", nb, p, len(p.bytes))
nb++
}
return s
}
String
func (half *halfconnection) String() string {
closed := ""
if half.closed {
closed = "closed "
}
return fmt.Sprintf("%screated:%v, last:%v", closed, half.created, half.lastSeen)
}
pageCache
type pageCache struct {
pagePool *sync.Pool
used int
pageRequests int64
}
func newPageCache() *pageCache {
pc := &pageCache{
pagePool: &sync.Pool{
New: func() interface{} { return new(page) },
}}
return pc
}
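pageCache is a thin layer over sync.Pool plus usage counters. A generic sketch of the underlying pattern it relies on (the buffer type is illustrative, not part of the package):
var bufPool = sync.Pool{
	// New is only called when the pool has nothing left to hand out.
	New: func() interface{} { return new(bytes.Buffer) },
}

func handle(data []byte) {
	buf := bufPool.Get().(*bytes.Buffer) // reuse a buffer if one is available
	buf.Reset()
	defer bufPool.Put(buf) // hand it back for the next caller
	buf.Write(data)
	// ... use buf ...
}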
next
func (c *pageCache) next(ts time.Time) (p *page) {
if *memLog {
c.pageRequests++
if c.pageRequests&0xFFFF == 0 {
log.Println("PageCache:", c.pageRequests, "requested,", c.used, "used,")
}
}
p = c.pagePool.Get().(*page)
p.seen = ts
p.bytes = p.buf[:0]
c.used++
if *memLog {
log.Printf("allocator returns %s\n", p)
}
return p
}
replace
func (c *pageCache) replace(p *page) {
c.used--
if *memLog {
log.Printf("replacing %s\n", p)
}
p.prev = nil
p.next = nil
c.pagePool.Put(p)
}