diff options
Diffstat (limited to 'dht')
| -rw-r--r-- | dht/krpc.go | 34 | ||||
| -rw-r--r-- | dht/krpc_test.go | 21 | ||||
| -rw-r--r-- | dht/messages.go | 36 | ||||
| -rw-r--r-- | dht/node.go | 67 | ||||
| -rw-r--r-- | dht/options.go | 10 |
5 files changed, 105 insertions, 63 deletions
diff --git a/dht/krpc.go b/dht/krpc.go index de508d9..bf66e20 100644 --- a/dht/krpc.go +++ b/dht/krpc.go @@ -113,7 +113,7 @@ func checkKey(data map[string]interface{}, key string, t string) error { } // Swiped from nictuku -func compactNodeInfoToString(cni string) string { +func decodeCompactNodeAddr(cni string) string { if len(cni) == 6 { return fmt.Sprintf("%d.%d.%d.%d:%d", cni[0], cni[1], cni[2], cni[3], (uint16(cni[4])<<8)|uint16(cni[5])) } else if len(cni) == 18 { @@ -124,21 +124,27 @@ func compactNodeInfoToString(cni string) string { } } -func stringToCompactNodeInfo(addr string) ([]byte, error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return []byte{}, err - } - pInt, err := strconv.ParseInt(port, 10, 64) - if err != nil { - return []byte{}, err +func encodeCompactNodeAddr(addr string) string { + var a []uint8 + host, port, _ := net.SplitHostPort(addr) + ip := net.ParseIP(host) + if ip == nil { + return "" } - p := int2bytes(pInt) - if len(p) < 2 { - p = append(p, p[0]) - p[0] = 0 + aa, _ := strconv.ParseUint(port, 10, 16) + c := uint16(aa) + if ip2 := net.IP.To4(ip); ip2 != nil { + a = make([]byte, net.IPv4len+2, net.IPv4len+2) + copy(a, ip2[0:net.IPv4len]) // ignore bytes IPv6 bytes if it's IPv4. + a[4] = byte(c >> 8) + a[5] = byte(c) + } else { + a = make([]byte, net.IPv6len+2, net.IPv6len+2) + copy(a, ip) + a[16] = byte(c >> 8) + a[17] = byte(c) } - return append([]byte(host), p...), nil + return string(a) } func int2bytes(val int64) []byte { diff --git a/dht/krpc_test.go b/dht/krpc_test.go index d710678..5bc8373 100644 --- a/dht/krpc_test.go +++ b/dht/krpc_test.go @@ -1,25 +1,30 @@ package dht import ( + "encoding/hex" "testing" ) -func TestStringToCompactNodeInfo(t *testing.T) { +func TestCompactNodeAddr(t *testing.T) { tests := []struct { in string - out []byte + out string }{ - {in: "192.168.1.1:6881", out: []byte("asdfasdf")}, + {in: "192.168.1.1:6881", out: "c0a801011ae1"}, + {in: "[2001:9372:434a:800::2]:6881", out: "20019372434a080000000000000000021ae1"}, } for _, tt := range tests { - r, err := stringToCompactNodeInfo(tt.in) - if err != nil { - t.Errorf("stringToCompactNodeInfo failed with %s", err) + r := encodeCompactNodeAddr(tt.in) + out, _ := hex.DecodeString(tt.out) + if r != string(out) { + t.Errorf("encodeCompactNodeAddr(%s) => %x, expected %s", tt.in, r, tt.out) } - if r != tt.out { - t.Errorf("stringToCompactNodeInfo(%s) => %s, expected %s", tt.in, r, tt.out) + + s := decodeCompactNodeAddr(r) + if s != tt.in { + t.Errorf("decodeCompactNodeAddr(%x) => %s, expected %s", r, s, tt.in) } } } diff --git a/dht/messages.go b/dht/messages.go index 94b15ee..be7d6b6 100644 --- a/dht/messages.go +++ b/dht/messages.go @@ -3,6 +3,7 @@ package dht import ( "fmt" "net" + "strings" ) func (n *Node) onPingQuery(rn remoteNode, msg map[string]interface{}) { @@ -30,22 +31,35 @@ func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) { token := []byte(*th)[:2] id := generateNeighbour(n.id, *th) + nodes := n.rTable.get(8) + compactNS := []string{} + for _, rn := range nodes { + ns := encodeCompactNodeAddr(rn.address.String()) + if ns == "" { + n.log.Warn("failed to compact node", "address", rn.address.String()) + continue + } + compactNS = append(compactNS, ns) + } + t := msg["t"].(string) n.queueMsg(rn, makeResponse(t, map[string]interface{}{ "id": string(id), "token": token, - "nodes": "", + "nodes": strings.Join(compactNS, ""), })) - nodes := n.rTable.get(50) - fmt.Printf("sending get_peers for %s to %d nodes\n", *th, len(nodes)) - q := makeQuery(newTransactionID(), "get_peers", map[string]interface{}{ - "id": string(id), - "info_hash": string(*th), - }) - for _, o := range nodes { - n.queueMsg(*o, q) - } + //nodes := n.rTable.get(50) + /* + fmt.Printf("sending get_peers for %s to %d nodes\n", *th, len(nodes)) + q := makeQuery(newTransactionID(), "get_peers", map[string]interface{}{ + "id": string(id), + "info_hash": string(*th), + }) + for _, o := range nodes { + n.queueMsg(*o, q) + } + */ } func (n *Node) onAnnouncePeerQuery(rn remoteNode, msg map[string]interface{}) { @@ -77,7 +91,7 @@ func (n *Node) onAnnouncePeerQuery(rn remoteNode, msg map[string]interface{}) { n.log.Warn("sent port 0", "source", rn) return } - addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", host, newPort)) + addr, err := net.ResolveUDPAddr(n.family, fmt.Sprintf("%s:%d", host, newPort)) rn = remoteNode{address: addr, id: rn.id} } diff --git a/dht/node.go b/dht/node.go index 169f219..7905964 100644 --- a/dht/node.go +++ b/dht/node.go @@ -1,10 +1,9 @@ package dht import ( - //"fmt" "context" + "fmt" "net" - "strconv" "time" "github.com/felix/dhtsearch/bencode" @@ -14,6 +13,7 @@ import ( var ( routers = []string{ + "dht.libtorrent.org:25401", "router.bittorrent.com:6881", "dht.transmissionbt.com:6881", "router.utorrent.com:6881", @@ -24,6 +24,7 @@ var ( // Node joins the DHT network type Node struct { id Infohash + family string address string port int conn net.PacketConn @@ -32,6 +33,7 @@ type Node struct { udpTimeout int packetsOut chan packet log logger.Logger + limiter *rate.Limiter //table routingTable // OnAnnoucePeer is called for each peer that announces itself @@ -50,10 +52,11 @@ func NewNode(opts ...Option) (n *Node, err error) { n = &Node{ id: id, - address: "0.0.0.0", + family: "udp4", port: 6881, udpTimeout: 10, rTable: k, + limiter: rate.NewLimiter(rate.Limit(100000), 2000000), log: logger.New(&logger.Options{Name: "dht"}), } @@ -65,12 +68,24 @@ func NewNode(opts ...Option) (n *Node, err error) { } } - n.conn, err = net.ListenPacket("udp", n.address+":"+strconv.Itoa(n.port)) + if n.family != "udp4" { + n.log.Debug("trying udp6 server") + n.conn, err = net.ListenPacket("udp6", fmt.Sprintf("[%s]:%d", net.IPv6zero.String(), n.port)) + if err == nil { + n.family = "udp6" + } + } + if n.conn == nil { + n.conn, err = net.ListenPacket("udp4", fmt.Sprintf("%s:%d", net.IPv4zero.String(), n.port)) + if err == nil { + n.family = "udp4" + } + } if err != nil { n.log.Error("failed to listen", "error", err) return nil, err } - n.log.Info("listening", "id", n.id, "network", n.conn.LocalAddr().Network(), "address", n.conn.LocalAddr().String()) + n.log.Info("listening", "id", n.id, "network", n.family, "address", n.conn.LocalAddr().String()) return n, nil } @@ -136,13 +151,13 @@ func (n *Node) makeNeighbours() { } } -func (n Node) bootstrap() { +func (n *Node) bootstrap() { n.log.Debug("bootstrapping") for _, s := range routers { - addr, err := net.ResolveUDPAddr(n.conn.LocalAddr().Network(), s) + addr, err := net.ResolveUDPAddr(n.family, s) if err != nil { n.log.Error("failed to parse bootstrap address", "error", err) - return + continue } rn := &remoteNode{address: addr} n.findNode(rn, n.id) @@ -150,24 +165,26 @@ func (n Node) bootstrap() { } func (n *Node) packetWriter() { - l := rate.NewLimiter(rate.Limit(500), 100) - for p := range n.packetsOut { + if p.raddr.String() == n.conn.LocalAddr().String() { + continue + } ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := l.Wait(ctx); err != nil { + if err := n.limiter.WaitN(ctx, len(p.data)); err != nil { n.log.Warn("rate limited", "error", err) continue } _, err := n.conn.WriteTo(p.data, p.raddr) if err != nil { // TODO remove from routing or add to blacklist? + // TODO reduce limit n.log.Warn("failed to write packet", "error", err) } } } -func (n Node) findNode(rn *remoteNode, id Infohash) { +func (n *Node) findNode(rn *remoteNode, id Infohash) { target := randomInfoHash() n.sendQuery(rn, "find_node", map[string]interface{}{ "id": string(id), @@ -183,7 +200,7 @@ func (n *Node) ping(rn *remoteNode) { }) } -func (n Node) sendQuery(rn *remoteNode, qType string, a map[string]interface{}) error { +func (n *Node) sendQuery(rn *remoteNode, qType string, a map[string]interface{}) error { // Stop if sending to self if rn.id.Equal(n.id) { return nil @@ -328,7 +345,7 @@ func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error { if err == nil { n.log.Debug("get_peers response", "source", rn) for _, v := range values { - addr := compactNodeInfoToString(v.(string)) + addr := decodeCompactNodeAddr(v.(string)) n.log.Debug("unhandled get_peer request", "addres", addr) // TODO new peer needs to be matched to previous get_peers request @@ -359,19 +376,9 @@ func (n *Node) handleError(addr net.Addr, m map[string]interface{}) bool { // Process another node's response to a find_node query. func (n *Node) processFindNodeResults(rn remoteNode, nodeList string) { nodeLength := 26 - /* - if d.config.proto == "udp6" { - nodeList = m.R.Nodes6 - nodeLength = 38 - } else { - nodeList = m.R.Nodes - } - - // Not much to do - if nodeList == "" { - return - } - */ + if n.family == "udp6" { + nodeLength = 38 + } if len(nodeList)%nodeLength != 0 { n.log.Error("node list is wrong length", "length", len(nodeList)) @@ -383,7 +390,7 @@ func (n *Node) processFindNodeResults(rn remoteNode, nodeList string) { // We got a byte array in groups of 26 or 38 for i := 0; i < len(nodeList); i += nodeLength { id := nodeList[i : i+ihLength] - addrStr := compactNodeInfoToString(nodeList[i+ihLength : i+nodeLength]) + addrStr := decodeCompactNodeAddr(nodeList[i+ihLength : i+nodeLength]) ih, err := InfohashFromString(id) if err != nil { @@ -391,9 +398,9 @@ func (n *Node) processFindNodeResults(rn remoteNode, nodeList string) { continue } - addr, err := net.ResolveUDPAddr("udp", addrStr) + addr, err := net.ResolveUDPAddr(n.family, addrStr) if err != nil || addr.Port == 0 { - n.log.Warn("unable to resolve", "address", addrStr, "error", err) + //n.log.Warn("unable to resolve", "address", addrStr, "error", err) continue } diff --git a/dht/options.go b/dht/options.go index f870a79..b7ded8a 100644 --- a/dht/options.go +++ b/dht/options.go @@ -29,6 +29,16 @@ func SetPort(p int) Option { } } +// SetIPv6 enables IPv6 +func SetIPv6(b bool) Option { + return func(n *Node) error { + if b { + n.family = "udp6" + } + return nil + } +} + // SetUDPTimeout sets the number of seconds to wait for UDP connections func SetUDPTimeout(s int) Option { return func(n *Node) error { |
