diff options
| author | Felix Hanley <felix@userspace.com.au> | 2018-02-15 11:42:34 +0000 |
|---|---|---|
| committer | Felix Hanley <felix@userspace.com.au> | 2018-02-15 11:42:40 +0000 |
| commit | 32a655f042a3752d93c4507b4c128b21bf6aa602 (patch) | |
| tree | 224c0d7e51efccac3b32dc5d0662baa2ab7304a5 | |
| parent | 2ded0704c8f675c3d92cf2b4874a32c65faf2553 (diff) | |
| download | dhtsearch-32a655f042a3752d93c4507b4c128b21bf6aa602.tar.gz dhtsearch-32a655f042a3752d93c4507b4c128b21bf6aa602.tar.bz2 | |
Refactor DHT code into separate package
| -rw-r--r-- | crawler/crawler.go | 129 | ||||
| -rw-r--r-- | crawler/dht.go | 171 | ||||
| -rw-r--r-- | crawler/dht_worker.go | 245 | ||||
| -rw-r--r-- | crawler/krpc.go | 37 | ||||
| -rw-r--r-- | crawler/packet.go | 9 | ||||
| -rw-r--r-- | crawler/peer.go | 9 | ||||
| -rw-r--r-- | crawler/remote_node.go | 24 | ||||
| -rw-r--r-- | crawler/routing_table.go | 58 | ||||
| -rw-r--r-- | dht/infohash.go | 59 | ||||
| -rw-r--r-- | dht/infohash_test.go | 64 | ||||
| -rw-r--r-- | dht/krpc.go | 72 | ||||
| -rw-r--r-- | dht/krpc_test.go | 25 | ||||
| -rw-r--r-- | dht/messages.go | 105 | ||||
| -rw-r--r-- | dht/node.go | 401 | ||||
| -rw-r--r-- | dht/options.go | 17 | ||||
| -rw-r--r-- | dht/peer.go | 10 | ||||
| -rw-r--r-- | dht/remote_node.go | 7 | ||||
| -rw-r--r-- | dht/routing_table.go | 116 | ||||
| -rw-r--r-- | dht/routing_table_test.go | 46 | ||||
| -rw-r--r-- | dht/slab.go | 4 | ||||
| -rw-r--r-- | dht/worker.go | 274 | ||||
| -rw-r--r-- | infohash.go | 43 | ||||
| -rw-r--r-- | slab.go | 25 | ||||
| -rw-r--r-- | stats.go | 32 |
24 files changed, 733 insertions, 1249 deletions
diff --git a/crawler/crawler.go b/crawler/crawler.go deleted file mode 100644 index bfd9785..0000000 --- a/crawler/crawler.go +++ /dev/null @@ -1,129 +0,0 @@ -package crawler - -import ( - "regexp" - - "github.com/felix/logger" -) - -const ( - TCPTimeout = 5 - UDPTimeout = 5 -) - -type Crawler struct { - port int - nodes int - httpAddress string - tagREs map[string]*regexp.Regexp - log logger.Logger -} - -// Option are options for the server -type Option func(*Crawler) error - -// NewCrawler creates a set of DHT nodes to crawl the network -func NewCrawer(opts ...Option) (*Crawler, error) { - s := &Crawler{ - port: 6881, - nodes: 1, - httpAddress: "localhost:6880", - tagREs: make(map[string]*regexp.Regexp), - } - - // Default logger - logOpts := &logger.Options{ - Name: "crawler", - Level: logger.Info, - } - s.log = logger.New(logOpts) - - err := mergeTagRegexps(s.tagREs, tags) - if err != nil { - s.log.Error("failed to compile tags", "error", err) - return nil, err - } - err = mergeCharacterTagREs(s.tagREs) - if err != nil { - s.log.Error("failed to compile character class tags", "error", err) - return nil, err - } - - // Set variadic options passed - for _, option := range opts { - err = option(s) - if err != nil { - return nil, err - } - } - - s.log.Debug("debugging output enabled") - - peers := make(chan peer) - - for i := 0; i < s.nodes; i++ { - // Consecutive port numbers - port := s.port + i - node := &dhtNode{ - id: genInfoHash(), - address: "", - port: port, - workers: 2, - log: s.log.Named("dht"), - peersOut: peers, - } - go node.run() - } - - return s, nil -} - -// SetLogger sets the server -func SetLogger(l logger.Logger) Option { - return func(s *Crawler) error { - s.log = l - return nil - } -} - -// SetPort sets the base port -func SetPort(p int) Option { - return func(s *Crawler) error { - s.port = p - return nil - } -} - -// SetNodes determines the number of nodes to start -func SetNodes(n int) Option { - return func(s *Crawler) error { - s.nodes = n - return nil - } -} - -// SetHTTPAddress determines the listening address for HTTP -func SetHTTPAddress(a string) Option { - return func(s *Crawler) error { - s.httpAddress = a - return nil - } -} - -// SetTags determines the listening address for HTTP -func SetTags(tags map[string]string) Option { - return func(s *Crawler) error { - // Merge user tags - err := mergeTagRegexps(s.tagREs, tags) - if err != nil { - s.log.Error("failed to compile tags", "error", err) - } - return err - } -} - -func (s *Crawler) Stats() Stats { - s.statlock.RLock() - defer s.statlock.RUnlock() - return s.stats -} diff --git a/crawler/dht.go b/crawler/dht.go deleted file mode 100644 index 51d478e..0000000 --- a/crawler/dht.go +++ /dev/null @@ -1,171 +0,0 @@ -package crawler - -import ( - "math" - "net" - "strconv" - "sync/atomic" - "time" - - "github.com/felix/logger" -) - -var ( - routers = []string{ - "router.bittorrent.com:6881", - "dht.transmissionbt.com:6881", - "router.utorrent.com:6881", - } -) - -type dhtNode struct { - id string - address string - port int - conn *net.UDPConn - pool chan chan packet - workers int - tid uint32 - packetsOut chan packet - peersOut chan<- peer - log logger.Logger - //table routingTable -} - -func (d *dhtNode) run() { - listener, err := net.ListenPacket("udp4", d.address+":"+strconv.Itoa(d.port)) - if err != nil { - d.log.Error("failed to listen", "error", err) - return - } - d.conn = listener.(*net.UDPConn) - d.port = d.conn.LocalAddr().(*net.UDPAddr).Port - - d.log.Info("listening", "address", d.address, "port", d.port) - - d.pool = make(chan chan packet) - - // Packets onto the network - d.packetsOut = make(chan packet, 512) - - // Create a slab for allocation - byteSlab := newSlab(8192, 10) - - rTable := newRoutingTable(d.id) - - // Start our workers - for i := 0; i < d.workers; i++ { - w := &dhtWorker{ - pool: d.pool, - packetsOut: d.packetsOut, - peersOut: d.peersOut, - rTable: rTable, - } - } - - // Start writing packets from channel to DHT - go func() { - var p packet - for { - select { - case p = <-d.packetsOut: - d.conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(UDPTimeout))) - b, err := d.conn.WriteToUDP(p.b, &p.raddr) - if err != nil { - // TODO remove from routing or add to blacklist? - d.log.Error("failed to write packet", "error", err) - } - } - } - }() - - // TODO configurable - ticker := time.Tick(5 * time.Second) - - // Send packets from conn to workers - for { - b := byteSlab.Alloc() - c, addr, err := d.conn.ReadFromUDP(b) - if err != nil { - d.log.Warn("read error", "error", err) - continue - } - - select { - case pCh := <-d.pool: - // Chop and send - pCh <- packet{b[0:c], *addr} - byteSlab.Free(b) - - case <-ticker: - go func() { - d.log.Debug("making neighbours") - if rTable.isEmpty() { - d.bootstrap() - } else { - for _, rn := range rTable.getNodes() { - d.findNode(rn, rn.id) - } - rTable.refresh() - } - }() - } - } - return -} - -func (d *dhtNode) bootstrap() { - d.log.Debug("bootstrapping") - for _, s := range routers { - addr, err := net.ResolveUDPAddr("udp4", s) - if err != nil { - d.log.Error("failed to parse bootstrap address", "error", err) - return - } - rn := newRemoteNode(*addr, "") - d.findNode(rn, "") - } -} - -func (d dhtNode) findNode(rn *remoteNode, target string) { - var id string - if target == "" { - id = d.id - } else { - id = genNeighbour(d.id, target) - } - d.sendQuery(rn, "find_node", map[string]interface{}{ - "id": id, - "target": genInfoHash(), - }) -} - -// ping sends ping query to the chan. -func (d *dhtNode) ping(rn *remoteNode) { - d.sendQuery(rn, "ping", map[string]interface{}{ - "id": genNeighbour(d.id, rn.id), - }) -} - -func (d dhtNode) sendQuery(rn *remoteNode, qType string, a map[string]interface{}) { - - // Stop if sending to self - if rn.id == d.id { - return - } - - t := d.newTransactionId() - - d.sendMsg(rn.address, makeQuery(t, qType, a)) -} - -// bencode data and send -func (d *dhtNode) sendMsg(raddr net.UDPAddr, data map[string]interface{}) { - d.packetsOut <- packet{[]byte(Encode(data)), raddr} -} - -func (d *dhtNode) newTransactionId() string { - t := atomic.AddUint32(&d.tid, 1) - t = t % math.MaxUint16 - return strconv.Itoa(int(t)) -} diff --git a/crawler/dht_worker.go b/crawler/dht_worker.go deleted file mode 100644 index 29f3bc5..0000000 --- a/crawler/dht_worker.go +++ /dev/null @@ -1,245 +0,0 @@ -package crawler - -import ( - "net" - - "github.com/felix/logger" -) - -type dhtWorker struct { - pool chan chan packet - packetsOut chan<- packet - peersOut chan<- peer - log logger.Logger - rTable *routingTable -} - -func (dw *dhtWorker) run(po chan<- packet) error { - packetsIn := make(chan packet) - dw.packetsOut = po - - for { - dw.pool <- packetsIn - - select { - // Wait for work - case p := <-packetsIn: - dw.process(p) - } - } -} - -// Parse a KRPC packet into a message -func (dw *dhtWorker) process(p packet) { - data, err := Decode(p.b) - if err != nil { - return - } - - response, err := parseMessage(data) - if err != nil { - dw.log.Warn("failed to parse packet", "error", err) - return - } - - switch response["y"].(string) { - case "q": - dw.handleRequest(&p.raddr, response) - case "r": - dw.handleResponse(&p.raddr, response) - case "e": - dw.handleError(&p.raddr, response) - default: - dw.log.Warn("missing request type") - return - } -} - -// bencode data and send -func (dw *dhtWorker) sendMsg(raddr net.UDPAddr, data map[string]interface{}) { - dw.packetsOut <- packet{[]byte(Encode(data)), raddr} -} - -// handleRequest handles the requests received from udp. -func (dw *dhtWorker) handleRequest(addr *net.UDPAddr, m map[string]interface{}) (success bool) { - - t := m["t"].(string) - - if err := parseKeys(m, [][]string{{"q", "string"}, {"a", "map"}}); err != nil { - - //d.sendMsg(addr, makeError(t, protocolError, err.Error())) - return - } - - q := m["q"].(string) - a := m["a"].(map[string]interface{}) - - if err := parseKey(a, "id", "string"); err != nil { - //d.sendMsg(addr, makeError(t, protocolError, err.Error())) - return - } - - id := a["id"].(string) - - if dw.rTable.id == id { - return - } - - if len(id) != 20 { - //dw.sendMsg(addr, makeError(t, protocolError, "invalid id")) - return - } - - var rn *remoteNode - switch q { - case "ping": - rn = newRemoteNode(*addr, id) - dw.sendMsg(*addr, makeResponse(t, map[string]interface{}{ - "id": dw.rTable.id, - })) - - case "get_peers": - if err := parseKey(a, "info_hash", "string"); err != nil { - //dw.sendMsg(addr, makeError(t, protocolError, err.Error())) - return - } - rn = newRemoteNode(*addr, id) - ih := a["info_hash"].(string) - dw.log.Debug("get_peers", "source", rn.String(), "infohash", ih) - - if len(ih) != ihLength { - //send(dht, addr, makeError(t, protocolError, "invalid info_hash")) - return - } - - // Crawling, we have no nodes - dw.sendMsg(*addr, makeResponse(t, map[string]interface{}{ - "id": genNeighbour(dw.rTable.id, ih), - "token": ih[:2], - "nodes": "", - })) - - case "announce_peer": - if err := parseKeys(a, [][]string{ - {"info_hash", "string"}, - {"port", "int"}, - {"token", "string"}}); err != nil { - - //dw.sendMsg(addr, makeError(t, protocolError, err.Error())) - return - } - - ih := a["info_hash"].(string) - rn = newRemoteNode(*addr, ih) - dw.log.Debug("announce_peer", "source", rn.String(), "infohash", ih) - - // TODO - if impliedPort, ok := a["implied_port"]; ok && - impliedPort.(int) != 0 { - //port = addr.Port - } - // TODO do we reply? - dw.peersOut <- peer{*addr, ih} - - default: - //dw.sendMsg(addr, makeError(t, protocolError, "invalid q")) - return - } - dw.rTable.add(rn) - return true -} - -// handleResponse handles responses received from udp. -func (dw *dhtWorker) handleResponse(addr *net.UDPAddr, m map[string]interface{}) (success bool) { - - //t := m["t"].(string) - - // inform transManager to delete the transaction. - if err := parseKey(m, "r", "map"); err != nil { - return - } - - r := m["r"].(map[string]interface{}) - if err := parseKey(r, "id", "string"); err != nil { - return - } - - ih := r["id"].(string) - rn := newRemoteNode(*addr, ih) - - // find_nodes response - if err := parseKey(r, "nodes", "string"); err == nil { - nodes := r["nodes"].(string) - dw.processFindNodeResults(rn, nodes) - return - } - - // get_peers response - if err := parseKey(r, "values", "list"); err == nil { - values := r["values"].([]interface{}) - for _, v := range values { - addr := compactNodeInfoToString(v.(string)) - dw.log.Debug("unhandled get_peer request", "addres", addr) - // TODO new peer - // dw.peersManager.Insert(ih, p) - } - } - dw.rTable.add(rn) - return true -} - -// handleError handles errors received from udp. -func (dw *dhtWorker) handleError(addr *net.UDPAddr, m map[string]interface{}) (success bool) { - if err := parseKey(m, "e", "list"); err != nil { - return - } - - if e := m["e"].([]interface{}); len(e) != 2 { - return - } - dw.log.Debug("error packet", "ip", addr.IP.String(), "port", addr.Port) - - return true -} - -// Process another node's response to a find_node query. -func (dw *dhtWorker) processFindNodeResults(rn *remoteNode, nodeList string) { - nodeLength := 26 - /* - if d.config.proto == "udp6" { - nodeList = m.R.Nodes6 - nodeLength = 38 - } else { - nodeList = m.R.Nodes - } - - // Not much to do - if nodeList == "" { - return - } - */ - - if len(nodeList)%nodeLength != 0 { - dw.log.Error("node list is wrong length", "length", len(nodeList)) - return - } - - // We got a byte array in groups of 26 or 38 - for i := 0; i < len(nodeList); i += nodeLength { - id := nodeList[i : i+ihLength] - addr := compactNodeInfoToString(nodeList[i+ihLength : i+nodeLength]) - - if dw.rTable.id == id { - dw.log.Debug("find_nodes ignoring self") - continue - } - - address, err := net.ResolveUDPAddr("udp4", addr) - if err != nil { - dw.log.Error("failed to resolve", "error", err) - continue - } - rn := newRemoteNode(*address, id) - dw.rTable.add(rn) - } -} diff --git a/crawler/krpc.go b/crawler/krpc.go deleted file mode 100644 index 67150c0..0000000 --- a/crawler/krpc.go +++ /dev/null @@ -1,37 +0,0 @@ -package crawler - -import ( - "errors" - "fmt" - "net" -) - -// makeQuery returns a query-formed data. -func makeQuery(t, q string, a map[string]interface{}) map[string]interface{} { - return map[string]interface{}{ - "t": t, - "y": "q", - "q": q, - "a": a, - } -} - -// makeResponse returns a response-formed data. -func makeResponse(t string, r map[string]interface{}) map[string]interface{} { - return map[string]interface{}{ - "t": t, - "y": "r", - "r": r, - } -} - -// parseKeys parses keys. It just wraps parseKey. -func parseKeys(data map[string]interface{}, pairs [][]string) error { - for _, args := range pairs { - key, t := args[0], args[1] - if err := parseKey(data, key, t); err != nil { - return err - } - } - return nil -} diff --git a/crawler/packet.go b/crawler/packet.go deleted file mode 100644 index 1fa3318..0000000 --- a/crawler/packet.go +++ /dev/null @@ -1,9 +0,0 @@ -package crawler - -import "net" - -// Unprocessed packet from socket -type packet struct { - b []byte - raddr net.UDPAddr -} diff --git a/crawler/peer.go b/crawler/peer.go deleted file mode 100644 index 06be506..0000000 --- a/crawler/peer.go +++ /dev/null @@ -1,9 +0,0 @@ -package crawler - -import "net" - -// Peer on DHT network -type Peer struct { - Address net.UDPAddr - ID string -} diff --git a/crawler/remote_node.go b/crawler/remote_node.go deleted file mode 100644 index bfbc5ac..0000000 --- a/crawler/remote_node.go +++ /dev/null @@ -1,24 +0,0 @@ -package crawler - -import ( - "fmt" - "net" - //"time" -) - -type remoteNode struct { - address net.UDPAddr - id string - //lastSeen time.Time -} - -func newRemoteNode(addr net.UDPAddr, id string) *remoteNode { - return &remoteNode{ - address: addr, - id: id, - } -} - -func (r *remoteNode) String() string { - return fmt.Sprintf("%s:%d", r.address.IP.String(), r.address.Port) -} diff --git a/crawler/routing_table.go b/crawler/routing_table.go deleted file mode 100644 index 8bb0d3c..0000000 --- a/crawler/routing_table.go +++ /dev/null @@ -1,58 +0,0 @@ -package crawler - -import ( - "net" - "sync" -) - -// Keep it simple for now -type routingTable struct { - id string - address net.UDPAddr - nodes []*remoteNode - max int - sync.Mutex -} - -func newRoutingTable(id string) *routingTable { - k := &routingTable{id: id, max: 4000} - k.refresh() - return k -} - -func (k *routingTable) add(rn *remoteNode) { - k.Lock() - defer k.Unlock() - - // Check IP and ports are valid and not self - if (rn.address.String() == k.address.String() && rn.address.Port == k.address.Port) || - rn.id == k.id || rn.id == "" { - return - } - k.nodes = append(k.nodes, rn) -} - -func (k *routingTable) getNodes() []*remoteNode { - k.Lock() - defer k.Unlock() - return k.nodes -} - -func (k *routingTable) isEmpty() bool { - k.Lock() - defer k.Unlock() - return len(k.nodes) == 0 -} - -func (k *routingTable) isFull() bool { - k.Lock() - defer k.Unlock() - return len(k.nodes) >= k.max -} - -// For now -func (k *routingTable) refresh() { - k.Lock() - defer k.Unlock() - k.nodes = make([]*remoteNode, 0) -} diff --git a/dht/infohash.go b/dht/infohash.go index 9c3ea3d..cd12446 100644 --- a/dht/infohash.go +++ b/dht/infohash.go @@ -3,6 +3,7 @@ package dht import ( "crypto/sha1" "encoding/hex" + "fmt" "io" "math/rand" "time" @@ -10,9 +11,29 @@ import ( const ihLength = 20 -// Infohash - +// Infohash is a 160 bit (20 byte) value type Infohash []byte +// InfohashFromString converts a 40 digit hexadecimal string to an Infohash +func InfohashFromString(s string) (*Infohash, error) { + switch len(s) { + case 20: + // Binary string + ih := Infohash([]byte(s)) + return &ih, nil + case 40: + // Hex string + b, err := hex.DecodeString(s) + if err != nil { + return nil, err + } + ih := Infohash(b) + return &ih, nil + default: + return nil, fmt.Errorf("invalid length %d", len(s)) + } +} + func (ih Infohash) String() string { return hex.EncodeToString(ih) } @@ -33,25 +54,31 @@ func (ih Infohash) Equal(other Infohash) bool { return true } -// FromString - -func (ih *Infohash) FromString(s string) error { - switch len(s) { - case 20: - // Byte string - *ih = Infohash([]byte(s)) - return nil - case 40: - b, err := hex.DecodeString(s) - if err != nil { - return err +// Distance determines the distance to another infohash as an integer +func (ih Infohash) Distance(other Infohash) int { + i := 0 + for ; i < 20; i++ { + if ih[i] != other[i] { + break } - *ih = Infohash(b) } - return nil + + if i == 20 { + return 160 + } + + xor := ih[i] ^ other[i] + + j := 0 + for (xor & 0x80) == 0 { + xor <<= 1 + j++ + } + return 8*i + j } -func (ih Infohash) GenNeighbour(other Infohash) Infohash { - s := append(ih[:10], other[10:]...) +func generateNeighbour(first, second Infohash) Infohash { + s := append(first[:10], second[10:]...) return Infohash(s) } diff --git a/dht/infohash_test.go b/dht/infohash_test.go index b3223f4..1574b19 100644 --- a/dht/infohash_test.go +++ b/dht/infohash_test.go @@ -6,19 +6,35 @@ import ( ) func TestInfohashImport(t *testing.T) { - var ih Infohash - idHex := "5a3ce1c14e7a08645677bbd1cfe7d8f956d53256" - err := ih.FromString(idHex) - if err != nil { - t.Errorf("FromString failed with %s", err) + tests := []struct { + str string + ok bool + }{ + {str: "5a3ce1c14e7a08645677bbd1cfe7d8f956d53256", ok: true}, + {str: "5a3ce1c14e7a08645677bbd1cfe7d8f956d53256000", ok: false}, } - idBytes, err := hex.DecodeString(idHex) + for _, tt := range tests { + ih, err := InfohashFromString(tt.str) + if tt.ok { + if err != nil { + t.Errorf("FromString failed with %s", err) + } - ih2 := Infohash(idBytes) - if !ih.Equal(ih2) { - t.Errorf("expected %s to equal %s", ih, ih2) + idBytes, err := hex.DecodeString(tt.str) + if err != nil { + t.Errorf("failed to decode %s to hex", tt.str) + } + ih2 := Infohash(idBytes) + if !ih.Equal(ih2) { + t.Errorf("expected %s to equal %s", ih, ih2) + } + } else { + if err == nil { + t.Errorf("FromString should have failed for %s", tt.str) + } + } } } @@ -28,3 +44,33 @@ func TestInfohashLength(t *testing.T) { t.Errorf("%s as string should be length 20, got %d", ih, len(ih)) } } + +func TestInfohashDistance(t *testing.T) { + id := "d1c5676ae7ac98e8b19f63565905105e3c4c37a2" + + var tests = []struct { + ih string + other string + distance int + }{ + {id, id, 160}, + {id, "d1c5676ae7ac98e8b19f63565905105e3c4c37a3", 159}, + } + + ih, err := InfohashFromString(id) + if err != nil { + t.Errorf("Failed to create Infohash: %s", err) + } + + for _, tt := range tests { + other, err := InfohashFromString(tt.other) + if err != nil { + t.Errorf("Failed to create Infohash: %s", err) + } + + dist := ih.Distance(*other) + if dist != tt.distance { + t.Errorf("Distance() => %d, expected %d", dist, tt.distance) + } + } +} diff --git a/dht/krpc.go b/dht/krpc.go index a926e81..de508d9 100644 --- a/dht/krpc.go +++ b/dht/krpc.go @@ -5,6 +5,7 @@ import ( "fmt" "math/rand" "net" + "strconv" ) const transIDBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -36,19 +37,40 @@ func makeResponse(t string, r map[string]interface{}) map[string]interface{} { } } -// parseMessage parses the basic data received from udp. -// It returns a map value. -func parseMessage(data interface{}) (map[string]interface{}, error) { - response, ok := data.(map[string]interface{}) +func getStringKey(data map[string]interface{}, key string) (string, error) { + val, ok := data[key] + if !ok { + return "", fmt.Errorf("krpc: missing key %s", key) + } + out, ok := val.(string) if !ok { - return nil, errors.New("response is not dict") + return "", fmt.Errorf("krpc: key type mismatch") } + return out, nil +} - if err := checkKeys(response, [][]string{{"t", "string"}, {"y", "string"}}); err != nil { - return nil, err +func getMapKey(data map[string]interface{}, key string) (map[string]interface{}, error) { + val, ok := data[key] + if !ok { + return nil, fmt.Errorf("krpc: missing key %s", key) } + out, ok := val.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("krpc: key type mismatch") + } + return out, nil +} - return response, nil +func getListKey(data map[string]interface{}, key string) ([]interface{}, error) { + val, ok := data[key] + if !ok { + return nil, fmt.Errorf("krpc: missing key %s", key) + } + out, ok := val.([]interface{}) + if !ok { + return nil, fmt.Errorf("krpc: key type mismatch") + } + return out, nil } // parseKeys parses keys. It just wraps parseKey. @@ -101,3 +123,37 @@ func compactNodeInfoToString(cni string) string { return "" } } + +func stringToCompactNodeInfo(addr string) ([]byte, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return []byte{}, err + } + pInt, err := strconv.ParseInt(port, 10, 64) + if err != nil { + return []byte{}, err + } + p := int2bytes(pInt) + if len(p) < 2 { + p = append(p, p[0]) + p[0] = 0 + } + return append([]byte(host), p...), nil +} + +func int2bytes(val int64) []byte { + data, j := make([]byte, 8), -1 + for i := 0; i < 8; i++ { + shift := uint64((7 - i) * 8) + data[i] = byte((val & (0xff << shift)) >> shift) + + if j == -1 && data[i] != 0 { + j = i + } + } + + if j != -1 { + return data[j:] + } + return data[:1] +} diff --git a/dht/krpc_test.go b/dht/krpc_test.go new file mode 100644 index 0000000..d710678 --- /dev/null +++ b/dht/krpc_test.go @@ -0,0 +1,25 @@ +package dht + +import ( + "testing" +) + +func TestStringToCompactNodeInfo(t *testing.T) { + + tests := []struct { + in string + out []byte + }{ + {in: "192.168.1.1:6881", out: []byte("asdfasdf")}, + } + + for _, tt := range tests { + r, err := stringToCompactNodeInfo(tt.in) + if err != nil { + t.Errorf("stringToCompactNodeInfo failed with %s", err) + } + if r != tt.out { + t.Errorf("stringToCompactNodeInfo(%s) => %s, expected %s", tt.in, r, tt.out) + } + } +} diff --git a/dht/messages.go b/dht/messages.go new file mode 100644 index 0000000..94b15ee --- /dev/null +++ b/dht/messages.go @@ -0,0 +1,105 @@ +package dht + +import ( + "fmt" + "net" +) + +func (n *Node) onPingQuery(rn remoteNode, msg map[string]interface{}) { + t := msg["t"].(string) + //n.log.Debug("ping", "source", rn) + n.queueMsg(rn, makeResponse(t, map[string]interface{}{ + "id": string(n.id), + })) +} + +func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) { + a := msg["a"].(map[string]interface{}) + if err := checkKey(a, "info_hash", "string"); err != nil { + //n.queueMsg(addr, makeError(t, protocolError, err.Error())) + return + } + + // This is the ih of the torrent + th, err := InfohashFromString(a["info_hash"].(string)) + if err != nil { + n.log.Warn("invalid torrent", "infohash", a["info_hash"]) + } + n.log.Debug("get_peers query", "source", rn, "torrent", th) + + token := []byte(*th)[:2] + + id := generateNeighbour(n.id, *th) + t := msg["t"].(string) + n.queueMsg(rn, makeResponse(t, map[string]interface{}{ + "id": string(id), + "token": token, + "nodes": "", + })) + + nodes := n.rTable.get(50) + fmt.Printf("sending get_peers for %s to %d nodes\n", *th, len(nodes)) + q := makeQuery(newTransactionID(), "get_peers", map[string]interface{}{ + "id": string(id), + "info_hash": string(*th), + }) + for _, o := range nodes { + n.queueMsg(*o, q) + } +} + +func (n *Node) onAnnouncePeerQuery(rn remoteNode, msg map[string]interface{}) { + a := msg["a"].(map[string]interface{}) + err := checkKeys(a, [][]string{ + {"info_hash", "string"}, + {"port", "int"}, + {"token", "string"}, + }) + if err != nil { + //n.queueMsg(addr, makeError(t, protocolError, err.Error())) + return + } + + n.log.Debug("announce_peer", "source", rn) + + // TODO + if impliedPort, ok := a["implied_port"]; ok && impliedPort.(int) != 0 { + // Use the port from the network + } else { + // Use the port in the message + host, _, err := net.SplitHostPort(rn.address.String()) + if err != nil { + n.log.Warn("failed to split host/port", "error", err) + return + } + newPort := a["port"] + if newPort == 0 { + n.log.Warn("sent port 0", "source", rn) + return + } + addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", host, newPort)) + rn = remoteNode{address: addr, id: rn.id} + } + + // TODO do we reply? + + ih, err := InfohashFromString(a["info_hash"].(string)) + if err != nil { + n.log.Warn("invalid torrent", "infohash", a["info_hash"]) + } + + p := Peer{Node: rn, Infohash: *ih} + n.log.Info("anounce_peer", p) + if n.OnAnnouncePeer != nil { + go n.OnAnnouncePeer(p) + } +} + +func (n *Node) onFindNodeResponse(rn remoteNode, msg map[string]interface{}) { + r := msg["r"].(map[string]interface{}) + if err := checkKey(r, "id", "string"); err != nil { + return + } + nodes := r["nodes"].(string) + n.processFindNodeResults(rn, nodes) +} diff --git a/dht/node.go b/dht/node.go index 3f8f349..169f219 100644 --- a/dht/node.go +++ b/dht/node.go @@ -1,12 +1,15 @@ package dht import ( + //"fmt" + "context" "net" "strconv" "time" "github.com/felix/dhtsearch/bencode" "github.com/felix/logger" + "golang.org/x/time/rate" ) var ( @@ -23,35 +26,35 @@ type Node struct { id Infohash address string port int - conn *net.UDPConn + conn net.PacketConn pool chan chan packet rTable *routingTable - workers []*dhtWorker udpTimeout int packetsOut chan packet - peersOut chan Peer - closing chan chan error log logger.Logger //table routingTable // OnAnnoucePeer is called for each peer that announces itself - OnAnnoucePeer func(p *Peer) + OnAnnouncePeer func(p Peer) } // NewNode creates a new DHT node func NewNode(opts ...Option) (n *Node, err error) { - id := randomInfoHash() + k, err := newRoutingTable(id, 2000) + if err != nil { + n.log.Error("failed to create routing table", "error", err) + return nil, err + } + n = &Node{ - id: id, - address: "0.0.0.0", - port: 6881, - rTable: newRoutingTable(id), - workers: make([]*dhtWorker, 1), - closing: make(chan chan error), - log: logger.New(&logger.Options{Name: "dht"}), - peersOut: make(chan Peer), + id: id, + address: "0.0.0.0", + port: 6881, + udpTimeout: 10, + rTable: k, + log: logger.New(&logger.Options{Name: "dht"}), } // Set variadic options passed @@ -62,141 +65,111 @@ func NewNode(opts ...Option) (n *Node, err error) { } } + n.conn, err = net.ListenPacket("udp", n.address+":"+strconv.Itoa(n.port)) + if err != nil { + n.log.Error("failed to listen", "error", err) + return nil, err + } + n.log.Info("listening", "id", n.id, "network", n.conn.LocalAddr().Network(), "address", n.conn.LocalAddr().String()) + return n, nil } // Close stuff func (n *Node) Close() error { n.log.Warn("node closing") - errCh := make(chan error) - n.closing <- errCh - // Signal workers - for _, w := range n.workers { - w.stop() - } - return <-errCh + return nil } // Run starts the node on the DHT -func (n *Node) Run() chan Peer { - listener, err := net.ListenPacket("udp4", n.address+":"+strconv.Itoa(n.port)) - if err != nil { - n.log.Error("failed to listen", "error", err) - return nil - } - n.conn = listener.(*net.UDPConn) - n.port = n.conn.LocalAddr().(*net.UDPAddr).Port - n.log.Info("listening", "id", n.id, "address", n.address, "port", n.port) - - // Worker pool - n.pool = make(chan chan packet) +func (n *Node) Run() { // Packets onto the network - n.packetsOut = make(chan packet, 512) + n.packetsOut = make(chan packet, 1024) // Create a slab for allocation byteSlab := newSlab(8192, 10) - // Start our workers - n.log.Debug("starting workers", "count", len(n.workers)) - for i := 0; i < len(n.workers); i++ { - w := &dhtWorker{ - pool: n.pool, - packetsOut: n.packetsOut, - peersOut: n.peersOut, - rTable: n.rTable, - quit: make(chan struct{}), - log: n.log.Named("worker"), + n.log.Debug("starting packet writer") + go n.packetWriter() + + // Find neighbours + go n.makeNeighbours() + + n.log.Debug("starting packet reader") + for { + b := byteSlab.alloc() + c, addr, err := n.conn.ReadFrom(b) + if err != nil { + n.log.Warn("UDP read error", "error", err) + return } - go w.run() - n.workers[i] = w + + // Chop and process + n.processPacket(packet{ + data: b[0:c], + raddr: addr, + }) + byteSlab.free(b) } +} - n.log.Debug("starting packet writer") - // Start writing packets from channel to DHT - go func() { - var p packet - for { - select { - case p = <-n.packetsOut: - //n.conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(n.udpTimeout))) - _, err := n.conn.WriteToUDP(p.data, &p.raddr) - if err != nil { - // TODO remove from routing or add to blacklist? - n.log.Warn("failed to write packet", "error", err) - } - } - } - }() +func (n *Node) makeNeighbours() { + // TODO configurable + ticker := time.Tick(5 * time.Second) - n.log.Debug("starting packet reader") - // Start reading packets - go func() { - n.bootstrap() - - // TODO configurable - ticker := time.Tick(10 * time.Second) - - // Send packets from conn to workers - for { - select { - case errCh := <-n.closing: - // TODO - errCh <- nil - case pCh := <-n.pool: - go func() { - b := byteSlab.Alloc() - c, addr, err := n.conn.ReadFromUDP(b) - if err != nil { - n.log.Warn("UDP read error", "error", err) - return - } - - // Chop and send - pCh <- packet{ - data: b[0:c], - raddr: *addr, - } - byteSlab.Free(b) - }() - - case <-ticker: - go func() { - if n.rTable.isEmpty() { - n.bootstrap() - } else { - n.makeNeighbours() - } - }() + n.bootstrap() + + for { + select { + case <-ticker: + if n.rTable.isEmpty() { + n.bootstrap() + } else { + // Send to all nodes + nodes := n.rTable.get(0) + for _, rn := range nodes { + n.findNode(rn, generateNeighbour(n.id, rn.id)) + } + n.rTable.flush() } } - }() - return n.peersOut + } } -func (n *Node) bootstrap() { +func (n Node) bootstrap() { n.log.Debug("bootstrapping") for _, s := range routers { - addr, err := net.ResolveUDPAddr("udp4", s) + addr, err := net.ResolveUDPAddr(n.conn.LocalAddr().Network(), s) if err != nil { n.log.Error("failed to parse bootstrap address", "error", err) return } - rn := &remoteNode{address: *addr} + rn := &remoteNode{address: addr} n.findNode(rn, n.id) } } -func (n *Node) makeNeighbours() { - n.log.Debug("making neighbours") - for _, rn := range n.rTable.getNodes() { - n.findNode(rn, n.id) +func (n *Node) packetWriter() { + l := rate.NewLimiter(rate.Limit(500), 100) + + for p := range n.packetsOut { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := l.Wait(ctx); err != nil { + n.log.Warn("rate limited", "error", err) + continue + } + _, err := n.conn.WriteTo(p.data, p.raddr) + if err != nil { + // TODO remove from routing or add to blacklist? + n.log.Warn("failed to write packet", "error", err) + } } - n.rTable.refresh() } func (n Node) findNode(rn *remoteNode, id Infohash) { target := randomInfoHash() - n.sendMsg(rn, "find_node", map[string]interface{}{ + n.sendQuery(rn, "find_node", map[string]interface{}{ "id": string(id), "target": string(target), }) @@ -204,13 +177,13 @@ func (n Node) findNode(rn *remoteNode, id Infohash) { // ping sends ping query to the chan. func (n *Node) ping(rn *remoteNode) { - id := n.id.GenNeighbour(rn.id) - n.sendMsg(rn, "ping", map[string]interface{}{ + id := generateNeighbour(n.id, rn.id) + n.sendQuery(rn, "ping", map[string]interface{}{ "id": string(id), }) } -func (n Node) sendMsg(rn *remoteNode, qType string, a map[string]interface{}) error { +func (n Node) sendQuery(rn *remoteNode, qType string, a map[string]interface{}) error { // Stop if sending to self if rn.id.Equal(n.id) { return nil @@ -224,9 +197,207 @@ func (n Node) sendMsg(rn *remoteNode, qType string, a map[string]interface{}) er if err != nil { return err } + //fmt.Printf("sending %s to %s\n", qType, rn.String()) + n.packetsOut <- packet{ + data: b, + raddr: rn.address, + } + return nil +} + +// Parse a KRPC packet into a message +func (n *Node) processPacket(p packet) { + data, err := bencode.Decode(p.data) + if err != nil { + return + } + + response, ok := data.(map[string]interface{}) + if !ok { + n.log.Debug("failed to parse packet", "error", "response is not dict") + return + } + + if err := checkKeys(response, [][]string{{"t", "string"}, {"y", "string"}}); err != nil { + n.log.Debug("failed to parse packet", "error", err) + return + } + + switch response["y"].(string) { + case "q": + n.handleRequest(p.raddr, response) + case "r": + err = n.handleResponse(p.raddr, response) + case "e": + n.handleError(p.raddr, response) + default: + n.log.Warn("missing request type") + return + } + if err != nil { + n.log.Warn("failed to process packet", "error", err) + } +} + +// bencode data and send +func (n *Node) queueMsg(rn remoteNode, data map[string]interface{}) error { + b, err := bencode.Encode(data) + if err != nil { + return err + } n.packetsOut <- packet{ data: b, raddr: rn.address, } return nil } + +// handleRequest handles the requests received from udp. +func (n *Node) handleRequest(addr net.Addr, m map[string]interface{}) (success bool) { + if err := checkKeys(m, [][]string{{"q", "string"}, {"a", "map"}}); err != nil { + + //d.queueMsg(addr, makeError(t, protocolError, err.Error())) + return + } + + a := m["a"].(map[string]interface{}) + + if err := checkKey(a, "id", "string"); err != nil { + //d.queueMsg(addr, makeError(t, protocolError, err.Error())) + return + } + + ih, err := InfohashFromString(a["id"].(string)) + if err != nil { + n.log.Warn("invalid request", "infohash", a["id"].(string)) + } + + if n.id.Equal(*ih) { + return + } + + rn := &remoteNode{address: addr, id: *ih} + q := m["q"].(string) + + switch q { + case "ping": + n.onPingQuery(*rn, m) + + case "get_peers": + n.onGetPeersQuery(*rn, m) + + case "announce_peer": + n.onAnnouncePeerQuery(*rn, m) + + default: + //n.queueMsg(addr, makeError(t, protocolError, "invalid q")) + return + } + n.rTable.add(rn) + return true +} + +// handleResponse handles responses received from udp. +func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error { + r, err := getMapKey(m, "r") + if err != nil { + return err + } + id, err := getStringKey(r, "id") + if err != nil { + return err + } + ih, err := InfohashFromString(id) + if err != nil { + return err + } + + rn := &remoteNode{address: addr, id: *ih} + + nodes, err := getStringKey(r, "nodes") + // find_nodes/get_peers response with nodes + if err == nil { + n.onFindNodeResponse(*rn, m) + n.processFindNodeResults(*rn, nodes) + n.rTable.add(rn) + return nil + } + + values, err := getListKey(r, "values") + // get_peers response + if err == nil { + n.log.Debug("get_peers response", "source", rn) + for _, v := range values { + addr := compactNodeInfoToString(v.(string)) + n.log.Debug("unhandled get_peer request", "addres", addr) + + // TODO new peer needs to be matched to previous get_peers request + // n.peersManager.Insert(ih, p) + } + n.rTable.add(rn) + } + return nil +} + +// handleError handles errors received from udp. +func (n *Node) handleError(addr net.Addr, m map[string]interface{}) bool { + if err := checkKey(m, "e", "list"); err != nil { + return false + } + + e := m["e"].([]interface{}) + if len(e) != 2 { + return false + } + code := e[0].(int64) + msg := e[1].(string) + n.log.Debug("error packet", "address", addr.String(), "code", code, "error", msg) + + return true +} + +// Process another node's response to a find_node query. +func (n *Node) processFindNodeResults(rn remoteNode, nodeList string) { + nodeLength := 26 + /* + if d.config.proto == "udp6" { + nodeList = m.R.Nodes6 + nodeLength = 38 + } else { + nodeList = m.R.Nodes + } + + // Not much to do + if nodeList == "" { + return + } + */ + + if len(nodeList)%nodeLength != 0 { + n.log.Error("node list is wrong length", "length", len(nodeList)) + return + } + + //fmt.Printf("%s sent %d nodes\n", rn.address.String(), len(nodeList)/nodeLength) + + // We got a byte array in groups of 26 or 38 + for i := 0; i < len(nodeList); i += nodeLength { + id := nodeList[i : i+ihLength] + addrStr := compactNodeInfoToString(nodeList[i+ihLength : i+nodeLength]) + + ih, err := InfohashFromString(id) + if err != nil { + n.log.Warn("invalid infohash in node list") + continue + } + + addr, err := net.ResolveUDPAddr("udp", addrStr) + if err != nil || addr.Port == 0 { + n.log.Warn("unable to resolve", "address", addrStr, "error", err) + continue + } + + rn := &remoteNode{address: addr, id: *ih} + n.rTable.add(rn) + } +} diff --git a/dht/options.go b/dht/options.go index 03a85bf..f870a79 100644 --- a/dht/options.go +++ b/dht/options.go @@ -6,6 +6,13 @@ import ( type Option func(*Node) error +func SetOnAnnouncePeer(f func(Peer)) Option { + return func(n *Node) error { + n.OnAnnouncePeer = f + return nil + } +} + // SetAddress sets the IP address to listen on func SetAddress(ip string) Option { return func(n *Node) error { @@ -22,14 +29,6 @@ func SetPort(p int) Option { } } -// SetWorkers sets the number of workers -func SetWorkers(c int) Option { - return func(n *Node) error { - n.workers = make([]*dhtWorker, c) - return nil - } -} - // SetUDPTimeout sets the number of seconds to wait for UDP connections func SetUDPTimeout(s int) Option { return func(n *Node) error { @@ -38,7 +37,7 @@ func SetUDPTimeout(s int) Option { } } -// SetLogger sets the number of workers +// SetLogger sets the logger func SetLogger(l logger.Logger) Option { return func(n *Node) error { n.log = l diff --git a/dht/peer.go b/dht/peer.go index a801331..f9669ba 100644 --- a/dht/peer.go +++ b/dht/peer.go @@ -1,9 +1,13 @@ package dht -import "net" +import "fmt" // Peer on DHT network type Peer struct { - Address net.UDPAddr - ID Infohash + Node remoteNode + Infohash Infohash +} + +func (p Peer) String() string { + return fmt.Sprintf("%s (%s)", p.Infohash, p.Node) } diff --git a/dht/remote_node.go b/dht/remote_node.go index 8a8d4a2..5bb2585 100644 --- a/dht/remote_node.go +++ b/dht/remote_node.go @@ -3,16 +3,15 @@ package dht import ( "fmt" "net" - //"time" ) type remoteNode struct { - address net.UDPAddr + address net.Addr id Infohash //lastSeen time.Time } // String implements fmt.Stringer -func (r *remoteNode) String() string { - return fmt.Sprintf("%s:%d", r.address.IP.String(), r.address.Port) +func (r remoteNode) String() string { + return fmt.Sprintf("%s (%s)", r.id.String(), r.address.String()) } diff --git a/dht/routing_table.go b/dht/routing_table.go index 0252519..3c1f2d2 100644 --- a/dht/routing_table.go +++ b/dht/routing_table.go @@ -1,57 +1,119 @@ package dht import ( - "net" + "container/heap" "sync" ) -// Keep it simple for now +type rItem struct { + value *remoteNode + distance int + index int // Index in heap +} + +type priorityQueue []*rItem + type routingTable struct { - id Infohash - address net.UDPAddr - nodes []*remoteNode - max int + id Infohash + max int + items priorityQueue + addresses map[string]*remoteNode sync.Mutex } -func newRoutingTable(id Infohash) *routingTable { - k := &routingTable{id: id, max: 4000} - k.refresh() - return k +func newRoutingTable(id Infohash, max int) (*routingTable, error) { + k := &routingTable{ + id: id, + max: max, + } + k.flush() + heap.Init(&k.items) + return k, nil } -func (k *routingTable) add(rn *remoteNode) { - k.Lock() - defer k.Unlock() +// Len implements sort.Interface +func (pq priorityQueue) Len() int { return len(pq) } + +// Less implements sort.Interface +func (pq priorityQueue) Less(i, j int) bool { + return pq[i].distance > pq[j].distance +} + +// Swap implements sort.Interface +func (pq priorityQueue) Swap(i, j int) { + pq[i], pq[j] = pq[j], pq[i] + pq[i].index = i + pq[j].index = j +} + +// Push implements heap.Interface +func (pq *priorityQueue) Push(x interface{}) { + n := len(*pq) + item := x.(*rItem) + item.index = n + *pq = append(*pq, item) +} + +// Pop implements heap.Interface +func (pq *priorityQueue) Pop() interface{} { + old := *pq + n := len(old) + item := old[n-1] + item.index = -1 // for safety + *pq = old[0 : n-1] + return item +} +func (k *routingTable) add(rn *remoteNode) { // Check IP and ports are valid and not self - if (rn.address.String() == k.address.String() && rn.address.Port == k.address.Port) || !rn.id.Valid() || rn.id.Equal(k.id) { + if !rn.id.Valid() || rn.id.Equal(k.id) { return } - k.nodes = append(k.nodes, rn) -} -func (k *routingTable) getNodes() []*remoteNode { k.Lock() defer k.Unlock() - return k.nodes + + if _, ok := k.addresses[rn.address.String()]; ok { + return + } + k.addresses[rn.address.String()] = rn + + item := &rItem{ + value: rn, + distance: k.id.Distance(rn.id), + } + + heap.Push(&k.items, item) + + if len(k.items) > k.max { + for i := k.max - 1; i < len(k.items); i++ { + old := k.items[i] + delete(k.addresses, old.value.address.String()) + heap.Remove(&k.items, i) + } + } } -func (k *routingTable) isEmpty() bool { - k.Lock() - defer k.Unlock() - return len(k.nodes) == 0 +func (k *routingTable) get(n int) (out []*remoteNode) { + if n == 0 { + n = len(k.items) + } + for i := 0; i < n && i < len(k.items); i++ { + out = append(out, k.items[i].value) + } + return out } -func (k *routingTable) isFull() bool { +func (k *routingTable) flush() { k.Lock() defer k.Unlock() - return len(k.nodes) >= k.max + + k.items = make(priorityQueue, 0) + k.addresses = make(map[string]*remoteNode, k.max) } -// For now -func (k *routingTable) refresh() { +func (k *routingTable) isEmpty() bool { k.Lock() defer k.Unlock() - k.nodes = make([]*remoteNode, 0) + return len(k.items) == 0 } diff --git a/dht/routing_table_test.go b/dht/routing_table_test.go new file mode 100644 index 0000000..77c0a17 --- /dev/null +++ b/dht/routing_table_test.go @@ -0,0 +1,46 @@ +package dht + +import ( + "fmt" + "net" + "testing" +) + +func TestPriorityQueue(t *testing.T) { + id := "d1c5676ae7ac98e8b19f63565905105e3c4c37a2" + + tests := []string{ + "d1c5676ae7ac98e8b19f63565905105e3c4c37b9", + "d1c5676ae7ac98e8b19f63565905105e3c4c37a9", + "d1c5676ae7ac98e8b19f63565905105e3c4c37a4", + "d1c5676ae7ac98e8b19f63565905105e3c4c37a3", // distance of 159 + } + + ih, err := InfohashFromString(id) + if err != nil { + t.Errorf("failed to create infohash: %s\n", err) + } + + pq, err := newRoutingTable(*ih, 3) + if err != nil { + t.Errorf("failed to create kTable: %s\n", err) + } + + for i, idt := range tests { + iht, err := InfohashFromString(idt) + if err != nil { + t.Errorf("failed to create infohash: %s\n", err) + } + addr, _ := net.ResolveUDPAddr("udp", fmt.Sprintf("0.0.0.0:%d", i)) + pq.add(&remoteNode{id: *iht, address: addr}) + } + + if len(pq.items) != len(pq.addresses) { + t.Errorf("items and addresses out of sync") + } + + first := pq.items[0].value.id + if first.String() != "d1c5676ae7ac98e8b19f63565905105e3c4c37a3" { + t.Errorf("first is %s with distance %d\n", first, ih.Distance(first)) + } +} diff --git a/dht/slab.go b/dht/slab.go index 737b0b6..a8b4018 100644 --- a/dht/slab.go +++ b/dht/slab.go @@ -14,11 +14,11 @@ func newSlab(blockSize int, numBlocks int) slab { return s } -func (s slab) Alloc() (x []byte) { +func (s slab) alloc() (x []byte) { return <-s } -func (s slab) Free(x []byte) { +func (s slab) free(x []byte) { // Check we are using the right dimensions x = x[:cap(x)] s <- x diff --git a/dht/worker.go b/dht/worker.go deleted file mode 100644 index 36b0519..0000000 --- a/dht/worker.go +++ /dev/null @@ -1,274 +0,0 @@ -package dht - -import ( - "net" - - "github.com/felix/dhtsearch/bencode" - "github.com/felix/logger" -) - -type dhtWorker struct { - pool chan chan packet - packetsOut chan<- packet - peersOut chan<- Peer - log logger.Logger - rTable *routingTable - quit chan struct{} -} - -func (dw *dhtWorker) run() error { - packetsIn := make(chan packet) - - for { - dw.pool <- packetsIn - - // Wait for work or shutdown - select { - case p := <-packetsIn: - dw.process(p) - case <-dw.quit: - dw.log.Warn("worker closing") - break - } - } -} - -func (dw dhtWorker) stop() { - go func() { - dw.quit <- struct{}{} - }() -} - -// Parse a KRPC packet into a message -func (dw *dhtWorker) process(p packet) { - data, err := bencode.Decode(p.data) - if err != nil { - return - } - - response, err := parseMessage(data) - if err != nil { - dw.log.Debug("failed to parse packet", "error", err) - return - } - - switch response["y"].(string) { - case "q": - dw.handleRequest(&p.raddr, response) - case "r": - dw.handleResponse(&p.raddr, response) - case "e": - dw.handleError(&p.raddr, response) - default: - dw.log.Warn("missing request type") - return - } -} - -// bencode data and send -func (dw *dhtWorker) queueMsg(raddr net.UDPAddr, data map[string]interface{}) error { - b, err := bencode.Encode(data) - if err != nil { - return err - } - dw.packetsOut <- packet{ - data: b, - raddr: raddr, - } - return nil -} - -// handleRequest handles the requests received from udp. -func (dw *dhtWorker) handleRequest(addr *net.UDPAddr, m map[string]interface{}) (success bool) { - - t := m["t"].(string) - - if err := checkKeys(m, [][]string{{"q", "string"}, {"a", "map"}}); err != nil { - - //d.queueMsg(addr, makeError(t, protocolError, err.Error())) - return - } - - q := m["q"].(string) - a := m["a"].(map[string]interface{}) - - if err := checkKey(a, "id", "string"); err != nil { - //d.queueMsg(addr, makeError(t, protocolError, err.Error())) - return - } - - var ih Infohash - err := ih.FromString(a["id"].(string)) - if err != nil { - dw.log.Warn("invalid packet", "infohash", a["id"]) - } - - if dw.rTable.id.Equal(ih) { - return - } - - var rn *remoteNode - switch q { - case "ping": - rn = &remoteNode{address: *addr, id: ih} - dw.log.Debug("ping", "source", rn, "infohash", ih) - dw.queueMsg(*addr, makeResponse(t, map[string]interface{}{ - "id": string(dw.rTable.id), - })) - - case "get_peers": - if err := checkKey(a, "info_hash", "string"); err != nil { - //dw.queueMsg(addr, makeError(t, protocolError, err.Error())) - return - } - rn = &remoteNode{address: *addr, id: ih} - err = ih.FromString(a["info_hash"].(string)) - if err != nil { - dw.log.Warn("invalid packet", "infohash", a["id"]) - } - dw.log.Debug("get_peers", "source", rn, "infohash", ih) - - // Crawling, we have no nodes - id := dw.rTable.id.GenNeighbour(ih) - dw.queueMsg(*addr, makeResponse(t, map[string]interface{}{ - "id": string(id), - "token": ih[:2], - "nodes": "", - })) - - case "announce_peer": - if err := checkKeys(a, [][]string{ - {"info_hash", "string"}, - {"port", "int"}, - {"token", "string"}}); err != nil { - - //dw.queueMsg(addr, makeError(t, protocolError, err.Error())) - return - } - - rn = &remoteNode{address: *addr, id: ih} - dw.log.Debug("announce_peer", "source", rn, "infohash", ih) - - // TODO - if impliedPort, ok := a["implied_port"]; ok && - impliedPort.(int) != 0 { - //port = addr.Port - } - // TODO do we reply? - dw.peersOut <- Peer{*addr, ih} - - default: - //dw.queueMsg(addr, makeError(t, protocolError, "invalid q")) - return - } - dw.rTable.add(rn) - return true -} - -// handleResponse handles responses received from udp. -func (dw *dhtWorker) handleResponse(addr *net.UDPAddr, m map[string]interface{}) (success bool) { - - //t := m["t"].(string) - - if err := checkKey(m, "r", "map"); err != nil { - return - } - - r := m["r"].(map[string]interface{}) - if err := checkKey(r, "id", "string"); err != nil { - return - } - - var ih Infohash - ih.FromString(r["id"].(string)) - rn := &remoteNode{address: *addr, id: ih} - - // find_nodes response - if err := checkKey(r, "nodes", "string"); err == nil { - nodes := r["nodes"].(string) - dw.processFindNodeResults(rn, nodes) - return - } - - // get_peers response - if err := checkKey(r, "values", "list"); err == nil { - values := r["values"].([]interface{}) - for _, v := range values { - addr := compactNodeInfoToString(v.(string)) - dw.log.Debug("unhandled get_peer request", "addres", addr) - // TODO new peer - // dw.peersManager.Insert(ih, p) - } - } - dw.rTable.add(rn) - return true -} - -// handleError handles errors received from udp. -func (dw *dhtWorker) handleError(addr *net.UDPAddr, m map[string]interface{}) bool { - if err := checkKey(m, "e", "list"); err != nil { - return false - } - - e := m["e"].([]interface{}) - if len(e) != 2 { - return false - } - code := e[0].(int64) - msg := e[1].(string) - dw.log.Debug("error packet", "ip", addr.IP.String(), "port", addr.Port, "code", code, "error", msg) - - return true -} - -// Process another node's response to a find_node query. -func (dw *dhtWorker) processFindNodeResults(rn *remoteNode, nodeList string) { - nodeLength := 26 - /* - if d.config.proto == "udp6" { - nodeList = m.R.Nodes6 - nodeLength = 38 - } else { - nodeList = m.R.Nodes - } - - // Not much to do - if nodeList == "" { - return - } - */ - - if len(nodeList)%nodeLength != 0 { - dw.log.Error("node list is wrong length", "length", len(nodeList)) - return - } - - var ih Infohash - var err error - - //dw.log.Debug("got node list", "length", len(nodeList)) - - // We got a byte array in groups of 26 or 38 - for i := 0; i < len(nodeList); i += nodeLength { - id := nodeList[i : i+ihLength] - addr := compactNodeInfoToString(nodeList[i+ihLength : i+nodeLength]) - - err = ih.FromString(id) - if err != nil { - dw.log.Warn("invalid node list") - continue - } - - if dw.rTable.id.Equal(ih) { - continue - } - - address, err := net.ResolveUDPAddr("udp4", addr) - if err != nil { - dw.log.Error("failed to resolve", "error", err) - continue - } - rn := &remoteNode{address: *address, id: ih} - dw.rTable.add(rn) - } -} diff --git a/infohash.go b/infohash.go deleted file mode 100644 index 0327771..0000000 --- a/infohash.go +++ /dev/null @@ -1,43 +0,0 @@ -package dhtsearch - -import ( - "crypto/sha1" - "encoding/hex" - "errors" - "io" - "math/rand" - "time" -) - -const ihLength = 20 - -func genInfoHash() string { - random := rand.New(rand.NewSource(time.Now().UnixNano())) - hash := sha1.New() - io.WriteString(hash, time.Now().String()) - io.WriteString(hash, string(random.Int())) - ih := hash.Sum(nil) - return string(ih) -} - -func genNeighbour(first, second string) string { - s := second[:10] + first[10:] - return s -} - -func decodeInfoHash(in string) (b string, err error) { - var h []byte - h, err = hex.DecodeString(in) - if len(h) != ihLength { - return "", errors.New("invalid length") - } - return string(h), err -} - -func isValidInfoHash(id string) bool { - ih, err := hex.DecodeString(id) - if err != nil { - return false - } - return len(ih) == ihLength -} diff --git a/slab.go b/slab.go deleted file mode 100644 index 2210c9f..0000000 --- a/slab.go +++ /dev/null @@ -1,25 +0,0 @@ -package dhtsearch - -// Slab memory allocation - -// Initialise the slab as a channel of blocks, allocating them as required and -// pushing them back on the slab. This reduces garbage collection. -type slab chan []byte - -func newSlab(blockSize int, numBlocks int) slab { - s := make(slab, numBlocks) - for i := 0; i < numBlocks; i++ { - s <- make([]byte, blockSize) - } - return s -} - -func (s slab) Alloc() (x []byte) { - return <-s -} - -func (s slab) Free(x []byte) { - // Check we are using the right dimensions - x = x[:cap(x)] - s <- x -} diff --git a/stats.go b/stats.go deleted file mode 100644 index 0ff5e6f..0000000 --- a/stats.go +++ /dev/null @@ -1,32 +0,0 @@ -package dhtsearch - -type Stats struct { - DHTPacketsIn int `json:"dht_packets_in"` - DHTPacketsOut int `json:"dht_packets_out"` - DHTPacketsDropped int `json:"dht_packets_dropped"` - DHTErrors int `json:"dht_errors"` - DHTCachedPeers int `json:"dht_cached_peers"` - DHTBytesIn int `json:"dht_bytes_in"` - DHTBytesOut int `json:"dht_bytes_out"` - DHTWorkers int `json:"dht_workers"` - BTBytesInt int `json:"bt_bytes_int"` - BTBytesOut int `json:"bt_bytes_out"` - BTWorkers int `json:"bt_workers"` - PeersAnnounced int `json:"peers_announced"` - PeersSkipped int `json:"peers_skipped"` - TorrentsSkipped int `json:"torrents_skipped"` - TorrentsSaved int `json:"torrents_saved"` - TorrentsTotal int `json:"torrents_total"` -} - -func (s *Stats) Sub(other *Stats) Stats { - if other == nil { - return *s - } - var diff Stats - diff.MessagesIn = s.MessagesIn - other.MessagesIn - diff.BytesIn = s.BytesIn - other.BytesIn - diff.MessagesOut = s.MessagesOut - other.MessagesOut - diff.BytesOut = s.BytesOut - other.BytesOut - return diff -} |
