diff --git a/inet_diag.go b/inet_diag.go index 72c1fcb..bee391a 100644 --- a/inet_diag.go +++ b/inet_diag.go @@ -27,4 +27,5 @@ const ( type InetDiagTCPInfoResp struct { InetDiagMsg *Socket TCPInfo *TCPInfo + TCPBBRInfo *TCPBBRInfo } diff --git a/socket_linux.go b/socket_linux.go index e4e7f7a..9b0f4a0 100644 --- a/socket_linux.go +++ b/socket_linux.go @@ -184,7 +184,7 @@ func SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) { req.AddData(&socketRequest{ Family: family, Protocol: unix.IPPROTO_TCP, - Ext: INET_DIAG_INFO, + Ext: (1 << (INET_DIAG_VEGASINFO - 1)) | (1 << (INET_DIAG_INFO - 1)), States: uint32(0xfff), // All TCP states }) s.Send(req) @@ -220,19 +220,42 @@ loop: if err != nil { return nil, err } - var tcpInfo *TCPInfo - for _, a := range attrs { - if a.Attr.Type == INET_DIAG_INFO { - tcpInfo = &TCPInfo{} - if err := tcpInfo.deserialize(a.Value); err != nil { - return nil, err - } - break - } + + res, err := attrsToInetDiagTCPInfoResp(attrs, sockInfo) + if err != nil { + return nil, err } - r := &InetDiagTCPInfoResp{InetDiagMsg: sockInfo, TCPInfo: tcpInfo} - result = append(result, r) + + result = append(result, res) } } return result, nil } + +func attrsToInetDiagTCPInfoResp(attrs []syscall.NetlinkRouteAttr, sockInfo *Socket) (*InetDiagTCPInfoResp, error) { + var tcpInfo *TCPInfo + var tcpBBRInfo *TCPBBRInfo + for _, a := range attrs { + if a.Attr.Type == INET_DIAG_INFO { + tcpInfo = &TCPInfo{} + if err := tcpInfo.deserialize(a.Value); err != nil { + return nil, err + } + continue + } + + if a.Attr.Type == INET_DIAG_BBRINFO { + tcpBBRInfo = &TCPBBRInfo{} + if err := tcpBBRInfo.deserialize(a.Value); err != nil { + return nil, err + } + continue + } + } + + return &InetDiagTCPInfoResp{ + InetDiagMsg: sockInfo, + TCPInfo: tcpInfo, + TCPBBRInfo: tcpBBRInfo, + }, nil +} diff --git a/socket_linux_test.go b/socket_linux_test.go new file mode 100644 index 0000000..d275b73 --- /dev/null +++ b/socket_linux_test.go @@ -0,0 +1,145 @@ +package netlink + +import ( + "reflect" + "syscall" + "testing" +) + +func TestAttrsToInetDiagTCPInfoResp(t *testing.T) { + tests := []struct { + name string + attrs []syscall.NetlinkRouteAttr + expected *InetDiagTCPInfoResp + wantFail bool + }{ + { + name: "Empty", + attrs: []syscall.NetlinkRouteAttr{}, + expected: &InetDiagTCPInfoResp{}, + }, + { + name: "BBRInfo Only", + attrs: []syscall.NetlinkRouteAttr{ + { + Attr: syscall.RtAttr{ + Len: 20, + Type: INET_DIAG_BBRINFO, + }, + Value: []byte{ + 100, 0, 0, 0, 0, 0, 0, 0, + 111, 0, 0, 0, + 222, 0, 0, 0, + 123, 0, 0, 0, + }, + }, + }, + expected: &InetDiagTCPInfoResp{ + TCPBBRInfo: &TCPBBRInfo{ + BBRBW: 100, + BBRMinRTT: 111, + BBRPacingGain: 222, + BBRCwndGain: 123, + }, + }, + }, + { + name: "TCPInfo Only", + attrs: []syscall.NetlinkRouteAttr{ + { + Attr: syscall.RtAttr{ + Len: 232, + Type: INET_DIAG_INFO, + }, + Value: tcpInfoData, + }, + }, + expected: &InetDiagTCPInfoResp{ + TCPInfo: tcpInfo, + }, + }, + { + name: "TCPInfo + TCPBBR", + attrs: []syscall.NetlinkRouteAttr{ + { + Attr: syscall.RtAttr{ + Len: 232, + Type: INET_DIAG_INFO, + }, + Value: tcpInfoData, + }, + { + Attr: syscall.RtAttr{ + Len: 20, + Type: INET_DIAG_BBRINFO, + }, + Value: []byte{ + 100, 0, 0, 0, 0, 0, 0, 0, + 111, 0, 0, 0, + 222, 0, 0, 0, + 123, 0, 0, 0, + }, + }, + }, + expected: &InetDiagTCPInfoResp{ + TCPInfo: tcpInfo, + TCPBBRInfo: &TCPBBRInfo{ + BBRBW: 100, + BBRMinRTT: 111, + BBRPacingGain: 222, + BBRCwndGain: 123, + }, + }, + }, + { + name: "TCPBBR + TCPInfo (reverse)", + attrs: []syscall.NetlinkRouteAttr{ + { + Attr: syscall.RtAttr{ + Len: 20, + Type: INET_DIAG_BBRINFO, + }, + Value: []byte{ + 100, 0, 0, 0, 0, 0, 0, 0, + 111, 0, 0, 0, + 222, 0, 0, 0, + 123, 0, 0, 0, + }, + }, + { + Attr: syscall.RtAttr{ + Len: 232, + Type: INET_DIAG_INFO, + }, + Value: tcpInfoData, + }, + }, + expected: &InetDiagTCPInfoResp{ + TCPInfo: tcpInfo, + TCPBBRInfo: &TCPBBRInfo{ + BBRBW: 100, + BBRMinRTT: 111, + BBRPacingGain: 222, + BBRCwndGain: 123, + }, + }, + }, + } + + for _, test := range tests { + res, err := attrsToInetDiagTCPInfoResp(test.attrs, nil) + if err != nil && !test.wantFail { + t.Errorf("Unexpected failure for test %q", test.name) + continue + } + + if err == nil && test.wantFail { + t.Errorf("Unexpected success for test %q", test.name) + continue + } + + if !reflect.DeepEqual(test.expected, res) { + t.Errorf("Unexpected failure for test %q", test.name) + } + } +} diff --git a/tcp_linux.go b/tcp_linux.go index 741ea16..d3b46a7 100644 --- a/tcp_linux.go +++ b/tcp_linux.go @@ -2,9 +2,14 @@ package netlink import ( "bytes" + "errors" "io" ) +const ( + tcpBBRInfoLen = 20 +) + type TCPInfo struct { State uint8 Ca_state uint8 @@ -391,3 +396,24 @@ func (t *TCPInfo) deserialize(b []byte) error { t.Snd_wnd = native.Uint32(next) return nil } + +type TCPBBRInfo struct { + BBRBW uint64 + BBRMinRTT uint32 + BBRPacingGain uint32 + BBRCwndGain uint32 +} + +func (t *TCPBBRInfo) deserialize(b []byte) error { + if len(b) != tcpBBRInfoLen { + return errors.New("Invalid length") + } + + rb := bytes.NewBuffer(b) + t.BBRBW = native.Uint64(rb.Next(8)) + t.BBRMinRTT = native.Uint32(rb.Next(4)) + t.BBRPacingGain = native.Uint32(rb.Next(4)) + t.BBRCwndGain = native.Uint32(rb.Next(4)) + + return nil +} diff --git a/tcp_linux_test.go b/tcp_linux_test.go new file mode 100644 index 0000000..1950cf2 --- /dev/null +++ b/tcp_linux_test.go @@ -0,0 +1,153 @@ +package netlink + +import ( + "reflect" + "testing" +) + +var ( + tcpInfoData []byte + tcpInfo *TCPInfo +) + +func init() { + tcpInfoData = []byte{ + 1, 0, 0, 0, 0, 7, 120, 1, 96, 216, 3, 0, 64, + 156, 0, 0, 120, 5, 0, 0, 64, 3, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 236, 216, 0, 0, 0, 0, 0, 0, 56, 216, + 0, 0, 144, 39, 0, 0, 220, 5, 0, 0, 88, 250, + 0, 0, 79, 190, 0, 0, 7, 5, 0, 0, 255, 255, + 255, 127, 10, 0, 0, 0, 168, 5, 0, 0, 3, 0, 0, + 0, 0, 0, 0, 0, 144, 56, 0, 0, 0, 0, 0, 0, 1, 197, + 8, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, + 255, 255, 157, 42, 0, 0, 0, 0, 0, 0, 148, 26, 0, + 0, 0, 0, 0, 0, 181, 0, 0, 0, 95, 0, 0, 0, 0, 0, 0, + 0, 93, 180, 0, 0, 61, 0, 0, 0, 89, 0, 0, 0, 47, 216, + 1, 0, 0, 0, 0, 0, 32, 65, 23, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 90, 0, 0, + 0, 0, 0, 0, 0, 156, 42, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 195, 1, 0, + } + tcpInfo = &TCPInfo{ + State: 1, + Options: 7, + Snd_wscale: 7, + Rcv_wscale: 8, + Rto: 252000, + Ato: 40000, + Snd_mss: 1400, + Rcv_mss: 832, + Last_data_sent: 55532, + Last_data_recv: 55352, + Last_ack_recv: 10128, + Pmtu: 1500, + Rcv_ssthresh: 64088, + Rtt: 48719, + Rttvar: 1287, + Snd_ssthresh: 2147483647, + Snd_cwnd: 10, + Advmss: 1448, + Reordering: 3, + Rcv_space: 14480, + Pacing_rate: 574721, + Max_pacing_rate: 18446744073709551615, + Bytes_acked: 10909, + Bytes_received: 6804, + Segs_out: 181, + Segs_in: 95, + Min_rtt: 46173, + Data_segs_in: 61, + Data_segs_out: 89, + Delivery_rate: 120879, + Busy_time: 1524000, + Delivered: 90, + Bytes_sent: 10908, + Snd_wnd: 115456, + } +} + +func TestTCPInfoDeserialize(t *testing.T) { + tests := []struct { + name string + input []byte + expected *TCPInfo + wantFail bool + }{ + { + name: "Valid data", + input: tcpInfoData, + expected: tcpInfo, + }, + } + + for _, test := range tests { + tcpbbr := &TCPInfo{} + err := tcpbbr.deserialize(test.input) + if err != nil && !test.wantFail { + t.Errorf("Unexpected failure for test %q", test.name) + continue + } + + if err != nil && test.wantFail { + continue + } + + if !reflect.DeepEqual(test.expected, tcpbbr) { + t.Errorf("Unexpected failure for test %q", test.name) + } + } +} + +func TestTCPBBRInfoDeserialize(t *testing.T) { + tests := []struct { + name string + input []byte + expected *TCPBBRInfo + wantFail bool + }{ + { + name: "Valid data", + input: []byte{ + 100, 0, 0, 0, 0, 0, 0, 0, + 111, 0, 0, 0, + 222, 0, 0, 0, + 123, 0, 0, 0, + }, + expected: &TCPBBRInfo{ + BBRBW: 100, + BBRMinRTT: 111, + BBRPacingGain: 222, + BBRCwndGain: 123, + }, + }, + { + name: "Invalid length", + input: []byte{ + 100, 0, 0, 0, 0, 0, 0, 0, + 111, 0, 0, 0, + 222, 0, 0, 0, + 123, 0, 0, + }, + wantFail: true, + }, + } + + for _, test := range tests { + tcpbbr := &TCPBBRInfo{} + err := tcpbbr.deserialize(test.input) + if err != nil && !test.wantFail { + t.Errorf("Unexpected failure for test %q", test.name) + continue + } + + if err != nil && test.wantFail { + continue + } + + if !reflect.DeepEqual(test.expected, tcpbbr) { + t.Errorf("Unexpected failure for test %q", test.name) + } + } +}