From 32a655f042a3752d93c4507b4c128b21bf6aa602 Mon Sep 17 00:00:00 2001
From: Felix Hanley
Date: Thu, 15 Feb 2018 22:42:34 +1100
Subject: Refactor DHT code into separate package

---
 dht/node.go | 401 +++++++++++++++++++++++++++++++++++++++++++-----------------
 1 file changed, 286 insertions(+), 115 deletions(-)

diff --git a/dht/node.go b/dht/node.go
index 3f8f349..169f219 100644
--- a/dht/node.go
+++ b/dht/node.go
@@ -1,12 +1,15 @@
 package dht
 
 import (
+	//"fmt"
+	"context"
 	"net"
 	"strconv"
 	"time"
 
 	"github.com/felix/dhtsearch/bencode"
 	"github.com/felix/logger"
+	"golang.org/x/time/rate"
 )
 
 var (
@@ -23,35 +26,35 @@ type Node struct {
 	id         Infohash
 	address    string
 	port       int
-	conn       *net.UDPConn
+	conn       net.PacketConn
 	pool       chan chan packet
 	rTable     *routingTable
-	workers    []*dhtWorker
 	udpTimeout int
 	packetsOut chan packet
-	peersOut   chan Peer
-	closing    chan chan error
 	log        logger.Logger
 	//table routingTable
 
-	// OnAnnoucePeer is called for each peer that announces itself
-	OnAnnoucePeer func(p *Peer)
+	// OnAnnouncePeer is called for each peer that announces itself
+	OnAnnouncePeer func(p Peer)
 }
 
 // NewNode creates a new DHT node
 func NewNode(opts ...Option) (n *Node, err error) {
 	id := randomInfoHash()
+	k, err := newRoutingTable(id, 2000)
+	if err != nil {
+		// n is not built yet, so there is no logger to report through
+		return nil, err
+	}
+
 	n = &Node{
-		id:       id,
-		address:  "0.0.0.0",
-		port:     6881,
-		rTable:   newRoutingTable(id),
-		workers:  make([]*dhtWorker, 1),
-		closing:  make(chan chan error),
-		log:      logger.New(&logger.Options{Name: "dht"}),
-		peersOut: make(chan Peer),
+		id:         id,
+		address:    "0.0.0.0",
+		port:       6881,
+		udpTimeout: 10,
+		rTable:     k,
+		log:        logger.New(&logger.Options{Name: "dht"}),
 	}
 
 	// Set variadic options passed
@@ -62,141 +65,111 @@ func NewNode(opts ...Option) (n *Node, err error) {
 		}
 	}
 
+	n.conn, err = net.ListenPacket("udp", n.address+":"+strconv.Itoa(n.port))
+	if err != nil {
+		n.log.Error("failed to listen", "error", err)
+		return nil, err
+	}
+	n.log.Info("listening", "id", n.id, "network", n.conn.LocalAddr().Network(), "address", n.conn.LocalAddr().String())
+
 	return n, nil
 }
 
 // Close shuts down the node
 func (n *Node) Close() error {
 	n.log.Warn("node closing")
-	errCh := make(chan error)
-	n.closing <- errCh
-	// Signal workers
-	for _, w := range n.workers {
-		w.stop()
-	}
-	return <-errCh
+	return nil
 }
 
 // Run starts the node on the DHT
-func (n *Node) Run() chan Peer {
-	listener, err := net.ListenPacket("udp4", n.address+":"+strconv.Itoa(n.port))
-	if err != nil {
-		n.log.Error("failed to listen", "error", err)
-		return nil
-	}
-	n.conn = listener.(*net.UDPConn)
-	n.port = n.conn.LocalAddr().(*net.UDPAddr).Port
-	n.log.Info("listening", "id", n.id, "address", n.address, "port", n.port)
-
-	// Worker pool
-	n.pool = make(chan chan packet)
+func (n *Node) Run() {
 	// Packets onto the network
-	n.packetsOut = make(chan packet, 512)
+	n.packetsOut = make(chan packet, 1024)
 
 	// Create a slab for allocation
 	byteSlab := newSlab(8192, 10)
 
-	// Start our workers
-	n.log.Debug("starting workers", "count", len(n.workers))
-	for i := 0; i < len(n.workers); i++ {
-		w := &dhtWorker{
-			pool:       n.pool,
-			packetsOut: n.packetsOut,
-			peersOut:   n.peersOut,
-			rTable:     n.rTable,
-			quit:       make(chan struct{}),
-			log:        n.log.Named("worker"),
+	n.log.Debug("starting packet writer")
+	go n.packetWriter()
+
+	// Find neighbours
+	go n.makeNeighbours()
+
+	n.log.Debug("starting packet reader")
+	for {
+		b := byteSlab.alloc()
+		c, addr, err := n.conn.ReadFrom(b)
+		if err != nil {
+			n.log.Warn("UDP read error", "error", err)
+			return
 		}
-		go w.run()
-		n.workers[i] = w
+
+		// Chop and process
+		n.processPacket(packet{
+			data:  b[0:c],
+			raddr: addr,
+		})
+		byteSlab.free(b)
 	}
+}
 
-	n.log.Debug("starting packet writer")
-	// Start writing packets from channel to DHT
-	go func() {
-		var p packet
-		for {
-			select {
-			case p = <-n.packetsOut:
-				//n.conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(n.udpTimeout)))
-				_, err := n.conn.WriteToUDP(p.data, &p.raddr)
-				if err != nil {
-					// TODO remove from routing or add to blacklist?
-					n.log.Warn("failed to write packet", "error", err)
-				}
-			}
-		}
-	}()
+func (n *Node) makeNeighbours() {
+	// TODO configurable
+	ticker := time.Tick(5 * time.Second)
 
-	n.log.Debug("starting packet reader")
-	// Start reading packets
-	go func() {
-		n.bootstrap()
-
-		// TODO configurable
-		ticker := time.Tick(10 * time.Second)
-
-		// Send packets from conn to workers
-		for {
-			select {
-			case errCh := <-n.closing:
-				// TODO
-				errCh <- nil
-			case pCh := <-n.pool:
-				go func() {
-					b := byteSlab.Alloc()
-					c, addr, err := n.conn.ReadFromUDP(b)
-					if err != nil {
-						n.log.Warn("UDP read error", "error", err)
-						return
-					}
-
-					// Chop and send
-					pCh <- packet{
-						data:  b[0:c],
-						raddr: *addr,
-					}
-					byteSlab.Free(b)
-				}()
-
-			case <-ticker:
-				go func() {
-					if n.rTable.isEmpty() {
-						n.bootstrap()
-					} else {
-						n.makeNeighbours()
-					}
-				}()
+	n.bootstrap()
+
+	for {
+		select {
+		case <-ticker:
+			if n.rTable.isEmpty() {
+				n.bootstrap()
+			} else {
+				// Send to all nodes
+				nodes := n.rTable.get(0)
+				for _, rn := range nodes {
+					n.findNode(rn, generateNeighbour(n.id, rn.id))
+				}
+				n.rTable.flush()
 			}
 		}
-	}()
-	return n.peersOut
+	}
 }
 
-func (n *Node) bootstrap() {
+func (n Node) bootstrap() {
 	n.log.Debug("bootstrapping")
 	for _, s := range routers {
-		addr, err := net.ResolveUDPAddr("udp4", s)
+		addr, err := net.ResolveUDPAddr(n.conn.LocalAddr().Network(), s)
 		if err != nil {
 			n.log.Error("failed to parse bootstrap address", "error", err)
 			return
 		}
-		rn := &remoteNode{address: *addr}
+		rn := &remoteNode{address: addr}
 		n.findNode(rn, n.id)
 	}
 }
 
-func (n *Node) makeNeighbours() {
-	n.log.Debug("making neighbours")
-	for _, rn := range n.rTable.getNodes() {
-		n.findNode(rn, n.id)
+func (n *Node) packetWriter() {
+	l := rate.NewLimiter(rate.Limit(500), 100)
+
+	for p := range n.packetsOut {
+		// Block until the limiter allows another packet; with the
+		// background context Wait only fails on a burst overflow
+		if err := l.Wait(context.Background()); err != nil {
+			n.log.Warn("rate limited", "error", err)
+			continue
+		}
+		_, err := n.conn.WriteTo(p.data, p.raddr)
+		if err != nil {
+			// TODO remove from routing or add to blacklist?
+			n.log.Warn("failed to write packet", "error", err)
+		}
 	}
-	n.rTable.refresh()
 }
 
 func (n Node) findNode(rn *remoteNode, id Infohash) {
 	target := randomInfoHash()
-	n.sendMsg(rn, "find_node", map[string]interface{}{
+	n.sendQuery(rn, "find_node", map[string]interface{}{
 		"id":     string(id),
 		"target": string(target),
 	})
@@ -204,13 +177,13 @@ func (n Node) findNode(rn *remoteNode, id Infohash) {
 
 // ping sends a ping query to the remote node.
 func (n *Node) ping(rn *remoteNode) {
-	id := n.id.GenNeighbour(rn.id)
-	n.sendMsg(rn, "ping", map[string]interface{}{
+	id := generateNeighbour(n.id, rn.id)
+	n.sendQuery(rn, "ping", map[string]interface{}{
 		"id": string(id),
 	})
 }
 
-func (n Node) sendMsg(rn *remoteNode, qType string, a map[string]interface{}) error {
+func (n Node) sendQuery(rn *remoteNode, qType string, a map[string]interface{}) error {
 	// Stop if sending to self
 	if rn.id.Equal(n.id) {
 		return nil
@@ -220,6 +193,54 @@ func (n Node) sendQuery(rn *remoteNode, qType string, a map[string]interface{}) er
 	//n.log.Debug("sending message", "type", qType, "remote", rn)
 	data := makeQuery(t, qType, a)
 
+	b, err := bencode.Encode(data)
+	if err != nil {
+		return err
+	}
+	//fmt.Printf("sending %s to %s\n", qType, rn.String())
+	n.packetsOut <- packet{
+		data:  b,
+		raddr: rn.address,
+	}
+	return nil
+}
+
+// processPacket parses a KRPC packet and dispatches it by message type
+func (n *Node) processPacket(p packet) {
+	data, err := bencode.Decode(p.data)
+	if err != nil {
+		return
+	}
+
+	response, ok := data.(map[string]interface{})
+	if !ok {
+		n.log.Debug("failed to parse packet", "error", "response is not dict")
+		return
+	}
+
+	if err := checkKeys(response, [][]string{{"t", "string"}, {"y", "string"}}); err != nil {
+		n.log.Debug("failed to parse packet", "error", err)
+		return
+	}
+
+	switch response["y"].(string) {
+	case "q":
+		n.handleRequest(p.raddr, response)
+	case "r":
+		err = n.handleResponse(p.raddr, response)
+	case "e":
+		n.handleError(p.raddr, response)
+	default:
+		n.log.Warn("missing request type")
+		return
+	}
+	if err != nil {
+		n.log.Warn("failed to process packet", "error", err)
+	}
+}
+
+// queueMsg bencodes the data and queues the packet for sending
+func (n *Node) queueMsg(rn remoteNode, data map[string]interface{}) error {
 	b, err := bencode.Encode(data)
 	if err != nil {
 		return err
@@ -230,3 +251,153 @@ func (n Node) sendMsg(rn *remoteNode, qType string, a map[string]interface{}) er
 	}
 	return nil
 }
+
+// handleRequest handles the requests received from udp.
+func (n *Node) handleRequest(addr net.Addr, m map[string]interface{}) (success bool) {
+	if err := checkKeys(m, [][]string{{"q", "string"}, {"a", "map"}}); err != nil {
+		//d.queueMsg(addr, makeError(t, protocolError, err.Error()))
+		return
+	}
+
+	a := m["a"].(map[string]interface{})
+
+	if err := checkKey(a, "id", "string"); err != nil {
+		//d.queueMsg(addr, makeError(t, protocolError, err.Error()))
+		return
+	}
+
+	ih, err := InfohashFromString(a["id"].(string))
+	if err != nil {
+		n.log.Warn("invalid request", "infohash", a["id"].(string))
+		return
+	}
+
+	if n.id.Equal(*ih) {
+		return
+	}
+
+	rn := &remoteNode{address: addr, id: *ih}
+	q := m["q"].(string)
+
+	switch q {
+	case "ping":
+		n.onPingQuery(*rn, m)
+
+	case "get_peers":
+		n.onGetPeersQuery(*rn, m)
+
+	case "announce_peer":
+		n.onAnnouncePeerQuery(*rn, m)
+
+	default:
+		//n.queueMsg(addr, makeError(t, protocolError, "invalid q"))
+		return
+	}
+	n.rTable.add(rn)
+	return true
+}
+
+// handleResponse handles responses received from udp.
+func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error {
+	r, err := getMapKey(m, "r")
+	if err != nil {
+		return err
+	}
+	id, err := getStringKey(r, "id")
+	if err != nil {
+		return err
+	}
+	ih, err := InfohashFromString(id)
+	if err != nil {
+		return err
+	}
+
+	rn := &remoteNode{address: addr, id: *ih}
+
+	nodes, err := getStringKey(r, "nodes")
+	// find_node/get_peers response with nodes
+	if err == nil {
+		n.onFindNodeResponse(*rn, m)
+		n.processFindNodeResults(*rn, nodes)
+		n.rTable.add(rn)
+		return nil
+	}
+
+	values, err := getListKey(r, "values")
+	// get_peers response
+	if err == nil {
+		n.log.Debug("get_peers response", "source", rn)
+		for _, v := range values {
+			addr := compactNodeInfoToString(v.(string))
+			n.log.Debug("unhandled get_peers response", "address", addr)
+
+			// TODO new peer needs to be matched to previous get_peers request
+			// n.peersManager.Insert(ih, p)
+		}
+		n.rTable.add(rn)
+	}
+	return nil
+}
+
+// handleError handles errors received from udp.
+func (n *Node) handleError(addr net.Addr, m map[string]interface{}) bool {
+	if err := checkKey(m, "e", "list"); err != nil {
+		return false
+	}
+
+	e := m["e"].([]interface{})
+	if len(e) != 2 {
+		return false
+	}
+	code := e[0].(int64)
+	msg := e[1].(string)
+	n.log.Debug("error packet", "address", addr.String(), "code", code, "error", msg)
+
+	return true
+}
+
+// Process another node's response to a find_node query.
+func (n *Node) processFindNodeResults(rn remoteNode, nodeList string) {
+	nodeLength := 26
+	/*
+		if d.config.proto == "udp6" {
+			nodeList = m.R.Nodes6
+			nodeLength = 38
+		} else {
+			nodeList = m.R.Nodes
+		}
+
+		// Not much to do
+		if nodeList == "" {
+			return
+		}
+	*/
+
+	if len(nodeList)%nodeLength != 0 {
+		n.log.Error("node list is wrong length", "length", len(nodeList))
+		return
+	}
+
+	//fmt.Printf("%s sent %d nodes\n", rn.address.String(), len(nodeList)/nodeLength)
+
+	// We got a byte array in groups of 26 or 38
+	for i := 0; i < len(nodeList); i += nodeLength {
+		id := nodeList[i : i+ihLength]
+		addrStr := compactNodeInfoToString(nodeList[i+ihLength : i+nodeLength])
+
+		ih, err := InfohashFromString(id)
+		if err != nil {
+			n.log.Warn("invalid infohash in node list")
+			continue
+		}
+
+		addr, err := net.ResolveUDPAddr("udp", addrStr)
+		if err != nil || addr.Port == 0 {
+			n.log.Warn("unable to resolve", "address", addrStr, "error", err)
+			continue
+		}
+
+		rn := &remoteNode{address: addr, id: *ih}
+		n.rTable.add(rn)
+	}
+}
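
For context, a minimal sketch of how the refactored node might be driven by a caller. The import path github.com/felix/dhtsearch/dht is an assumption inferred from the bencode import and the dht/ directory, and the fields of dht.Peer are not shown in this patch, so the handler below only logs the value:

package main

import (
	"log"
	"time"

	"github.com/felix/dhtsearch/dht" // assumed import path, not confirmed by the patch
)

func main() {
	// Defaults from NewNode: listen on 0.0.0.0:6881 with a random node ID.
	n, err := dht.NewNode()
	if err != nil {
		log.Fatal(err)
	}
	defer n.Close()

	// Invoked for each peer that announces itself on the DHT.
	n.OnAnnouncePeer = func(p dht.Peer) {
		log.Printf("announce: %v", p)
	}

	// Run blocks in its packet read loop, so start it in a goroutine.
	go n.Run()

	time.Sleep(time.Minute)
}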
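
The new packetWriter paces outgoing packets with golang.org/x/time/rate. A self-contained sketch of what those limiter parameters mean, a steady 500 tokens per second with bursts of up to 100; the iteration count and timing figures are illustrative, not taken from the patch:

package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

func main() {
	// Same parameters as packetWriter: refill at 500 tokens/s, burst of 100.
	l := rate.NewLimiter(rate.Limit(500), 100)

	start := time.Now()
	for i := 0; i < 1000; i++ {
		// Wait blocks until a token is available or the context is done.
		if err := l.Wait(context.Background()); err != nil {
			fmt.Println("wait:", err)
			return
		}
	}
	// The bucket starts full, so roughly 100 iterations pass immediately;
	// the remaining ~900 are paced at 500/s, about 1.8s in total.
	fmt.Printf("1000 waits took %v\n", time.Since(start))
}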