author     Felix Hanley <felix@userspace.com.au>  2018-02-15 11:42:34 +0000
committer  Felix Hanley <felix@userspace.com.au>  2018-02-15 11:42:40 +0000
commit     32a655f042a3752d93c4507b4c128b21bf6aa602
tree       224c0d7e51efccac3b32dc5d0662baa2ab7304a5
parent     2ded0704c8f675c3d92cf2b4874a32c65faf2553
Refactor DHT code into separate package
-rw-r--r--  crawler/crawler.go         129
-rw-r--r--  crawler/dht.go             171
-rw-r--r--  crawler/dht_worker.go      245
-rw-r--r--  crawler/krpc.go             37
-rw-r--r--  crawler/packet.go            9
-rw-r--r--  crawler/peer.go              9
-rw-r--r--  crawler/remote_node.go      24
-rw-r--r--  crawler/routing_table.go    58
-rw-r--r--  dht/infohash.go             59
-rw-r--r--  dht/infohash_test.go        64
-rw-r--r--  dht/krpc.go                 72
-rw-r--r--  dht/krpc_test.go            25
-rw-r--r--  dht/messages.go            105
-rw-r--r--  dht/node.go                401
-rw-r--r--  dht/options.go              17
-rw-r--r--  dht/peer.go                 10
-rw-r--r--  dht/remote_node.go           7
-rw-r--r--  dht/routing_table.go       116
-rw-r--r--  dht/routing_table_test.go   46
-rw-r--r--  dht/slab.go                  4
-rw-r--r--  dht/worker.go              274
-rw-r--r--  infohash.go                 43
-rw-r--r--  slab.go                     25
-rw-r--r--  stats.go                    32
24 files changed, 733 insertions, 1249 deletions
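
For orientation before the full diff: the crawler package below is removed and its DHT plumbing moves into a standalone dht package with a small options-based API. A minimal sketch of how a caller might drive the refactored package, inferred from the new code in dht/node.go, dht/options.go and dht/peer.go; the import path is an assumption based on the bencode import in dht/node.go and is not part of this commit:

    package main

    import (
        "log"

        "github.com/felix/dhtsearch/dht" // assumed import path
    )

    func main() {
        node, err := dht.NewNode(
            dht.SetPort(6881),
            dht.SetOnAnnouncePeer(func(p dht.Peer) {
                // Invoked for each announce_peer query the node accepts.
                log.Printf("announced: %s", p)
            }),
        )
        if err != nil {
            log.Fatal(err)
        }
        defer node.Close()

        // Run blocks, reading UDP packets and maintaining the routing table.
        node.Run()
    }
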
diff --git a/crawler/crawler.go b/crawler/crawler.go
deleted file mode 100644
index bfd9785..0000000
--- a/crawler/crawler.go
+++ /dev/null
@@ -1,129 +0,0 @@
-package crawler
-
-import (
- "regexp"
-
- "github.com/felix/logger"
-)
-
-const (
- TCPTimeout = 5
- UDPTimeout = 5
-)
-
-type Crawler struct {
- port int
- nodes int
- httpAddress string
- tagREs map[string]*regexp.Regexp
- log logger.Logger
-}
-
-// Option are options for the server
-type Option func(*Crawler) error
-
-// NewCrawler creates a set of DHT nodes to crawl the network
-func NewCrawer(opts ...Option) (*Crawler, error) {
- s := &Crawler{
- port: 6881,
- nodes: 1,
- httpAddress: "localhost:6880",
- tagREs: make(map[string]*regexp.Regexp),
- }
-
- // Default logger
- logOpts := &logger.Options{
- Name: "crawler",
- Level: logger.Info,
- }
- s.log = logger.New(logOpts)
-
- err := mergeTagRegexps(s.tagREs, tags)
- if err != nil {
- s.log.Error("failed to compile tags", "error", err)
- return nil, err
- }
- err = mergeCharacterTagREs(s.tagREs)
- if err != nil {
- s.log.Error("failed to compile character class tags", "error", err)
- return nil, err
- }
-
- // Set variadic options passed
- for _, option := range opts {
- err = option(s)
- if err != nil {
- return nil, err
- }
- }
-
- s.log.Debug("debugging output enabled")
-
- peers := make(chan peer)
-
- for i := 0; i < s.nodes; i++ {
- // Consecutive port numbers
- port := s.port + i
- node := &dhtNode{
- id: genInfoHash(),
- address: "",
- port: port,
- workers: 2,
- log: s.log.Named("dht"),
- peersOut: peers,
- }
- go node.run()
- }
-
- return s, nil
-}
-
-// SetLogger sets the server
-func SetLogger(l logger.Logger) Option {
- return func(s *Crawler) error {
- s.log = l
- return nil
- }
-}
-
-// SetPort sets the base port
-func SetPort(p int) Option {
- return func(s *Crawler) error {
- s.port = p
- return nil
- }
-}
-
-// SetNodes determines the number of nodes to start
-func SetNodes(n int) Option {
- return func(s *Crawler) error {
- s.nodes = n
- return nil
- }
-}
-
-// SetHTTPAddress determines the listening address for HTTP
-func SetHTTPAddress(a string) Option {
- return func(s *Crawler) error {
- s.httpAddress = a
- return nil
- }
-}
-
-// SetTags determines the listening address for HTTP
-func SetTags(tags map[string]string) Option {
- return func(s *Crawler) error {
- // Merge user tags
- err := mergeTagRegexps(s.tagREs, tags)
- if err != nil {
- s.log.Error("failed to compile tags", "error", err)
- }
- return err
- }
-}
-
-func (s *Crawler) Stats() Stats {
- s.statlock.RLock()
- defer s.statlock.RUnlock()
- return s.stats
-}
diff --git a/crawler/dht.go b/crawler/dht.go
deleted file mode 100644
index 51d478e..0000000
--- a/crawler/dht.go
+++ /dev/null
@@ -1,171 +0,0 @@
-package crawler
-
-import (
- "math"
- "net"
- "strconv"
- "sync/atomic"
- "time"
-
- "github.com/felix/logger"
-)
-
-var (
- routers = []string{
- "router.bittorrent.com:6881",
- "dht.transmissionbt.com:6881",
- "router.utorrent.com:6881",
- }
-)
-
-type dhtNode struct {
- id string
- address string
- port int
- conn *net.UDPConn
- pool chan chan packet
- workers int
- tid uint32
- packetsOut chan packet
- peersOut chan<- peer
- log logger.Logger
- //table routingTable
-}
-
-func (d *dhtNode) run() {
- listener, err := net.ListenPacket("udp4", d.address+":"+strconv.Itoa(d.port))
- if err != nil {
- d.log.Error("failed to listen", "error", err)
- return
- }
- d.conn = listener.(*net.UDPConn)
- d.port = d.conn.LocalAddr().(*net.UDPAddr).Port
-
- d.log.Info("listening", "address", d.address, "port", d.port)
-
- d.pool = make(chan chan packet)
-
- // Packets onto the network
- d.packetsOut = make(chan packet, 512)
-
- // Create a slab for allocation
- byteSlab := newSlab(8192, 10)
-
- rTable := newRoutingTable(d.id)
-
- // Start our workers
- for i := 0; i < d.workers; i++ {
- w := &dhtWorker{
- pool: d.pool,
- packetsOut: d.packetsOut,
- peersOut: d.peersOut,
- rTable: rTable,
- }
- }
-
- // Start writing packets from channel to DHT
- go func() {
- var p packet
- for {
- select {
- case p = <-d.packetsOut:
- d.conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(UDPTimeout)))
- b, err := d.conn.WriteToUDP(p.b, &p.raddr)
- if err != nil {
- // TODO remove from routing or add to blacklist?
- d.log.Error("failed to write packet", "error", err)
- }
- }
- }
- }()
-
- // TODO configurable
- ticker := time.Tick(5 * time.Second)
-
- // Send packets from conn to workers
- for {
- b := byteSlab.Alloc()
- c, addr, err := d.conn.ReadFromUDP(b)
- if err != nil {
- d.log.Warn("read error", "error", err)
- continue
- }
-
- select {
- case pCh := <-d.pool:
- // Chop and send
- pCh <- packet{b[0:c], *addr}
- byteSlab.Free(b)
-
- case <-ticker:
- go func() {
- d.log.Debug("making neighbours")
- if rTable.isEmpty() {
- d.bootstrap()
- } else {
- for _, rn := range rTable.getNodes() {
- d.findNode(rn, rn.id)
- }
- rTable.refresh()
- }
- }()
- }
- }
- return
-}
-
-func (d *dhtNode) bootstrap() {
- d.log.Debug("bootstrapping")
- for _, s := range routers {
- addr, err := net.ResolveUDPAddr("udp4", s)
- if err != nil {
- d.log.Error("failed to parse bootstrap address", "error", err)
- return
- }
- rn := newRemoteNode(*addr, "")
- d.findNode(rn, "")
- }
-}
-
-func (d dhtNode) findNode(rn *remoteNode, target string) {
- var id string
- if target == "" {
- id = d.id
- } else {
- id = genNeighbour(d.id, target)
- }
- d.sendQuery(rn, "find_node", map[string]interface{}{
- "id": id,
- "target": genInfoHash(),
- })
-}
-
-// ping sends ping query to the chan.
-func (d *dhtNode) ping(rn *remoteNode) {
- d.sendQuery(rn, "ping", map[string]interface{}{
- "id": genNeighbour(d.id, rn.id),
- })
-}
-
-func (d dhtNode) sendQuery(rn *remoteNode, qType string, a map[string]interface{}) {
-
- // Stop if sending to self
- if rn.id == d.id {
- return
- }
-
- t := d.newTransactionId()
-
- d.sendMsg(rn.address, makeQuery(t, qType, a))
-}
-
-// bencode data and send
-func (d *dhtNode) sendMsg(raddr net.UDPAddr, data map[string]interface{}) {
- d.packetsOut <- packet{[]byte(Encode(data)), raddr}
-}
-
-func (d *dhtNode) newTransactionId() string {
- t := atomic.AddUint32(&d.tid, 1)
- t = t % math.MaxUint16
- return strconv.Itoa(int(t))
-}
diff --git a/crawler/dht_worker.go b/crawler/dht_worker.go
deleted file mode 100644
index 29f3bc5..0000000
--- a/crawler/dht_worker.go
+++ /dev/null
@@ -1,245 +0,0 @@
-package crawler
-
-import (
- "net"
-
- "github.com/felix/logger"
-)
-
-type dhtWorker struct {
- pool chan chan packet
- packetsOut chan<- packet
- peersOut chan<- peer
- log logger.Logger
- rTable *routingTable
-}
-
-func (dw *dhtWorker) run(po chan<- packet) error {
- packetsIn := make(chan packet)
- dw.packetsOut = po
-
- for {
- dw.pool <- packetsIn
-
- select {
- // Wait for work
- case p := <-packetsIn:
- dw.process(p)
- }
- }
-}
-
-// Parse a KRPC packet into a message
-func (dw *dhtWorker) process(p packet) {
- data, err := Decode(p.b)
- if err != nil {
- return
- }
-
- response, err := parseMessage(data)
- if err != nil {
- dw.log.Warn("failed to parse packet", "error", err)
- return
- }
-
- switch response["y"].(string) {
- case "q":
- dw.handleRequest(&p.raddr, response)
- case "r":
- dw.handleResponse(&p.raddr, response)
- case "e":
- dw.handleError(&p.raddr, response)
- default:
- dw.log.Warn("missing request type")
- return
- }
-}
-
-// bencode data and send
-func (dw *dhtWorker) sendMsg(raddr net.UDPAddr, data map[string]interface{}) {
- dw.packetsOut <- packet{[]byte(Encode(data)), raddr}
-}
-
-// handleRequest handles the requests received from udp.
-func (dw *dhtWorker) handleRequest(addr *net.UDPAddr, m map[string]interface{}) (success bool) {
-
- t := m["t"].(string)
-
- if err := parseKeys(m, [][]string{{"q", "string"}, {"a", "map"}}); err != nil {
-
- //d.sendMsg(addr, makeError(t, protocolError, err.Error()))
- return
- }
-
- q := m["q"].(string)
- a := m["a"].(map[string]interface{})
-
- if err := parseKey(a, "id", "string"); err != nil {
- //d.sendMsg(addr, makeError(t, protocolError, err.Error()))
- return
- }
-
- id := a["id"].(string)
-
- if dw.rTable.id == id {
- return
- }
-
- if len(id) != 20 {
- //dw.sendMsg(addr, makeError(t, protocolError, "invalid id"))
- return
- }
-
- var rn *remoteNode
- switch q {
- case "ping":
- rn = newRemoteNode(*addr, id)
- dw.sendMsg(*addr, makeResponse(t, map[string]interface{}{
- "id": dw.rTable.id,
- }))
-
- case "get_peers":
- if err := parseKey(a, "info_hash", "string"); err != nil {
- //dw.sendMsg(addr, makeError(t, protocolError, err.Error()))
- return
- }
- rn = newRemoteNode(*addr, id)
- ih := a["info_hash"].(string)
- dw.log.Debug("get_peers", "source", rn.String(), "infohash", ih)
-
- if len(ih) != ihLength {
- //send(dht, addr, makeError(t, protocolError, "invalid info_hash"))
- return
- }
-
- // Crawling, we have no nodes
- dw.sendMsg(*addr, makeResponse(t, map[string]interface{}{
- "id": genNeighbour(dw.rTable.id, ih),
- "token": ih[:2],
- "nodes": "",
- }))
-
- case "announce_peer":
- if err := parseKeys(a, [][]string{
- {"info_hash", "string"},
- {"port", "int"},
- {"token", "string"}}); err != nil {
-
- //dw.sendMsg(addr, makeError(t, protocolError, err.Error()))
- return
- }
-
- ih := a["info_hash"].(string)
- rn = newRemoteNode(*addr, ih)
- dw.log.Debug("announce_peer", "source", rn.String(), "infohash", ih)
-
- // TODO
- if impliedPort, ok := a["implied_port"]; ok &&
- impliedPort.(int) != 0 {
- //port = addr.Port
- }
- // TODO do we reply?
- dw.peersOut <- peer{*addr, ih}
-
- default:
- //dw.sendMsg(addr, makeError(t, protocolError, "invalid q"))
- return
- }
- dw.rTable.add(rn)
- return true
-}
-
-// handleResponse handles responses received from udp.
-func (dw *dhtWorker) handleResponse(addr *net.UDPAddr, m map[string]interface{}) (success bool) {
-
- //t := m["t"].(string)
-
- // inform transManager to delete the transaction.
- if err := parseKey(m, "r", "map"); err != nil {
- return
- }
-
- r := m["r"].(map[string]interface{})
- if err := parseKey(r, "id", "string"); err != nil {
- return
- }
-
- ih := r["id"].(string)
- rn := newRemoteNode(*addr, ih)
-
- // find_nodes response
- if err := parseKey(r, "nodes", "string"); err == nil {
- nodes := r["nodes"].(string)
- dw.processFindNodeResults(rn, nodes)
- return
- }
-
- // get_peers response
- if err := parseKey(r, "values", "list"); err == nil {
- values := r["values"].([]interface{})
- for _, v := range values {
- addr := compactNodeInfoToString(v.(string))
- dw.log.Debug("unhandled get_peer request", "addres", addr)
- // TODO new peer
- // dw.peersManager.Insert(ih, p)
- }
- }
- dw.rTable.add(rn)
- return true
-}
-
-// handleError handles errors received from udp.
-func (dw *dhtWorker) handleError(addr *net.UDPAddr, m map[string]interface{}) (success bool) {
- if err := parseKey(m, "e", "list"); err != nil {
- return
- }
-
- if e := m["e"].([]interface{}); len(e) != 2 {
- return
- }
- dw.log.Debug("error packet", "ip", addr.IP.String(), "port", addr.Port)
-
- return true
-}
-
-// Process another node's response to a find_node query.
-func (dw *dhtWorker) processFindNodeResults(rn *remoteNode, nodeList string) {
- nodeLength := 26
- /*
- if d.config.proto == "udp6" {
- nodeList = m.R.Nodes6
- nodeLength = 38
- } else {
- nodeList = m.R.Nodes
- }
-
- // Not much to do
- if nodeList == "" {
- return
- }
- */
-
- if len(nodeList)%nodeLength != 0 {
- dw.log.Error("node list is wrong length", "length", len(nodeList))
- return
- }
-
- // We got a byte array in groups of 26 or 38
- for i := 0; i < len(nodeList); i += nodeLength {
- id := nodeList[i : i+ihLength]
- addr := compactNodeInfoToString(nodeList[i+ihLength : i+nodeLength])
-
- if dw.rTable.id == id {
- dw.log.Debug("find_nodes ignoring self")
- continue
- }
-
- address, err := net.ResolveUDPAddr("udp4", addr)
- if err != nil {
- dw.log.Error("failed to resolve", "error", err)
- continue
- }
- rn := newRemoteNode(*address, id)
- dw.rTable.add(rn)
- }
-}
diff --git a/crawler/krpc.go b/crawler/krpc.go
deleted file mode 100644
index 67150c0..0000000
--- a/crawler/krpc.go
+++ /dev/null
@@ -1,37 +0,0 @@
-package crawler
-
-import (
- "errors"
- "fmt"
- "net"
-)
-
-// makeQuery returns a query-formed data.
-func makeQuery(t, q string, a map[string]interface{}) map[string]interface{} {
- return map[string]interface{}{
- "t": t,
- "y": "q",
- "q": q,
- "a": a,
- }
-}
-
-// makeResponse returns a response-formed data.
-func makeResponse(t string, r map[string]interface{}) map[string]interface{} {
- return map[string]interface{}{
- "t": t,
- "y": "r",
- "r": r,
- }
-}
-
-// parseKeys parses keys. It just wraps parseKey.
-func parseKeys(data map[string]interface{}, pairs [][]string) error {
- for _, args := range pairs {
- key, t := args[0], args[1]
- if err := parseKey(data, key, t); err != nil {
- return err
- }
- }
- return nil
-}
diff --git a/crawler/packet.go b/crawler/packet.go
deleted file mode 100644
index 1fa3318..0000000
--- a/crawler/packet.go
+++ /dev/null
@@ -1,9 +0,0 @@
-package crawler
-
-import "net"
-
-// Unprocessed packet from socket
-type packet struct {
- b []byte
- raddr net.UDPAddr
-}
diff --git a/crawler/peer.go b/crawler/peer.go
deleted file mode 100644
index 06be506..0000000
--- a/crawler/peer.go
+++ /dev/null
@@ -1,9 +0,0 @@
-package crawler
-
-import "net"
-
-// Peer on DHT network
-type Peer struct {
- Address net.UDPAddr
- ID string
-}
diff --git a/crawler/remote_node.go b/crawler/remote_node.go
deleted file mode 100644
index bfbc5ac..0000000
--- a/crawler/remote_node.go
+++ /dev/null
@@ -1,24 +0,0 @@
-package crawler
-
-import (
- "fmt"
- "net"
- //"time"
-)
-
-type remoteNode struct {
- address net.UDPAddr
- id string
- //lastSeen time.Time
-}
-
-func newRemoteNode(addr net.UDPAddr, id string) *remoteNode {
- return &remoteNode{
- address: addr,
- id: id,
- }
-}
-
-func (r *remoteNode) String() string {
- return fmt.Sprintf("%s:%d", r.address.IP.String(), r.address.Port)
-}
diff --git a/crawler/routing_table.go b/crawler/routing_table.go
deleted file mode 100644
index 8bb0d3c..0000000
--- a/crawler/routing_table.go
+++ /dev/null
@@ -1,58 +0,0 @@
-package crawler
-
-import (
- "net"
- "sync"
-)
-
-// Keep it simple for now
-type routingTable struct {
- id string
- address net.UDPAddr
- nodes []*remoteNode
- max int
- sync.Mutex
-}
-
-func newRoutingTable(id string) *routingTable {
- k := &routingTable{id: id, max: 4000}
- k.refresh()
- return k
-}
-
-func (k *routingTable) add(rn *remoteNode) {
- k.Lock()
- defer k.Unlock()
-
- // Check IP and ports are valid and not self
- if (rn.address.String() == k.address.String() && rn.address.Port == k.address.Port) ||
- rn.id == k.id || rn.id == "" {
- return
- }
- k.nodes = append(k.nodes, rn)
-}
-
-func (k *routingTable) getNodes() []*remoteNode {
- k.Lock()
- defer k.Unlock()
- return k.nodes
-}
-
-func (k *routingTable) isEmpty() bool {
- k.Lock()
- defer k.Unlock()
- return len(k.nodes) == 0
-}
-
-func (k *routingTable) isFull() bool {
- k.Lock()
- defer k.Unlock()
- return len(k.nodes) >= k.max
-}
-
-// For now
-func (k *routingTable) refresh() {
- k.Lock()
- defer k.Unlock()
- k.nodes = make([]*remoteNode, 0)
-}
diff --git a/dht/infohash.go b/dht/infohash.go
index 9c3ea3d..cd12446 100644
--- a/dht/infohash.go
+++ b/dht/infohash.go
@@ -3,6 +3,7 @@ package dht
import (
"crypto/sha1"
"encoding/hex"
+ "fmt"
"io"
"math/rand"
"time"
@@ -10,9 +11,29 @@ import (
const ihLength = 20
-// Infohash -
+// Infohash is a 160 bit (20 byte) value
type Infohash []byte
+// InfohashFromString converts a 40 digit hexadecimal string to an Infohash
+func InfohashFromString(s string) (*Infohash, error) {
+ switch len(s) {
+ case 20:
+ // Binary string
+ ih := Infohash([]byte(s))
+ return &ih, nil
+ case 40:
+ // Hex string
+ b, err := hex.DecodeString(s)
+ if err != nil {
+ return nil, err
+ }
+ ih := Infohash(b)
+ return &ih, nil
+ default:
+ return nil, fmt.Errorf("invalid length %d", len(s))
+ }
+}
+
func (ih Infohash) String() string {
return hex.EncodeToString(ih)
}
@@ -33,25 +54,31 @@ func (ih Infohash) Equal(other Infohash) bool {
return true
}
-// FromString -
-func (ih *Infohash) FromString(s string) error {
- switch len(s) {
- case 20:
- // Byte string
- *ih = Infohash([]byte(s))
- return nil
- case 40:
- b, err := hex.DecodeString(s)
- if err != nil {
- return err
+// Distance determines the distance to another infohash as an integer
+func (ih Infohash) Distance(other Infohash) int {
+ i := 0
+ for ; i < 20; i++ {
+ if ih[i] != other[i] {
+ break
}
- *ih = Infohash(b)
}
- return nil
+
+ if i == 20 {
+ return 160
+ }
+
+ xor := ih[i] ^ other[i]
+
+ j := 0
+ for (xor & 0x80) == 0 {
+ xor <<= 1
+ j++
+ }
+ return 8*i + j
}
-func (ih Infohash) GenNeighbour(other Infohash) Infohash {
- s := append(ih[:10], other[10:]...)
+func generateNeighbour(first, second Infohash) Infohash {
+ s := append(first[:10], second[10:]...)
return Infohash(s)
}
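
A side note on the Distance method added above: despite its name, the returned value is the number of leading bits the two infohashes have in common (160 for identical hashes), so a larger result means a closer node, and the new routing table's heap orders on exactly this value. A small sketch, again assuming the import path:

    package main

    import (
        "fmt"

        "github.com/felix/dhtsearch/dht" // assumed import path
    )

    func main() {
        a, _ := dht.InfohashFromString("d1c5676ae7ac98e8b19f63565905105e3c4c37a2")
        b, _ := dht.InfohashFromString("d1c5676ae7ac98e8b19f63565905105e3c4c37a3")
        fmt.Println(a.Distance(*a)) // 160: identical
        fmt.Println(a.Distance(*b)) // 159: only the final bit differs
    }
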
diff --git a/dht/infohash_test.go b/dht/infohash_test.go
index b3223f4..1574b19 100644
--- a/dht/infohash_test.go
+++ b/dht/infohash_test.go
@@ -6,19 +6,35 @@ import (
)
func TestInfohashImport(t *testing.T) {
- var ih Infohash
- idHex := "5a3ce1c14e7a08645677bbd1cfe7d8f956d53256"
- err := ih.FromString(idHex)
- if err != nil {
- t.Errorf("FromString failed with %s", err)
+ tests := []struct {
+ str string
+ ok bool
+ }{
+ {str: "5a3ce1c14e7a08645677bbd1cfe7d8f956d53256", ok: true},
+ {str: "5a3ce1c14e7a08645677bbd1cfe7d8f956d53256000", ok: false},
}
- idBytes, err := hex.DecodeString(idHex)
+ for _, tt := range tests {
+ ih, err := InfohashFromString(tt.str)
+ if tt.ok {
+ if err != nil {
+ t.Errorf("FromString failed with %s", err)
+ }
- ih2 := Infohash(idBytes)
- if !ih.Equal(ih2) {
- t.Errorf("expected %s to equal %s", ih, ih2)
+ idBytes, err := hex.DecodeString(tt.str)
+ if err != nil {
+ t.Errorf("failed to decode %s to hex", tt.str)
+ }
+ ih2 := Infohash(idBytes)
+ if !ih.Equal(ih2) {
+ t.Errorf("expected %s to equal %s", ih, ih2)
+ }
+ } else {
+ if err == nil {
+ t.Errorf("FromString should have failed for %s", tt.str)
+ }
+ }
}
}
@@ -28,3 +44,33 @@ func TestInfohashLength(t *testing.T) {
t.Errorf("%s as string should be length 20, got %d", ih, len(ih))
}
}
+
+func TestInfohashDistance(t *testing.T) {
+ id := "d1c5676ae7ac98e8b19f63565905105e3c4c37a2"
+
+ var tests = []struct {
+ ih string
+ other string
+ distance int
+ }{
+ {id, id, 160},
+ {id, "d1c5676ae7ac98e8b19f63565905105e3c4c37a3", 159},
+ }
+
+ ih, err := InfohashFromString(id)
+ if err != nil {
+ t.Errorf("Failed to create Infohash: %s", err)
+ }
+
+ for _, tt := range tests {
+ other, err := InfohashFromString(tt.other)
+ if err != nil {
+ t.Errorf("Failed to create Infohash: %s", err)
+ }
+
+ dist := ih.Distance(*other)
+ if dist != tt.distance {
+ t.Errorf("Distance() => %d, expected %d", dist, tt.distance)
+ }
+ }
+}
diff --git a/dht/krpc.go b/dht/krpc.go
index a926e81..de508d9 100644
--- a/dht/krpc.go
+++ b/dht/krpc.go
@@ -5,6 +5,7 @@ import (
"fmt"
"math/rand"
"net"
+ "strconv"
)
const transIDBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
@@ -36,19 +37,40 @@ func makeResponse(t string, r map[string]interface{}) map[string]interface{} {
}
}
-// parseMessage parses the basic data received from udp.
-// It returns a map value.
-func parseMessage(data interface{}) (map[string]interface{}, error) {
- response, ok := data.(map[string]interface{})
+func getStringKey(data map[string]interface{}, key string) (string, error) {
+ val, ok := data[key]
+ if !ok {
+ return "", fmt.Errorf("krpc: missing key %s", key)
+ }
+ out, ok := val.(string)
if !ok {
- return nil, errors.New("response is not dict")
+ return "", fmt.Errorf("krpc: key type mismatch")
}
+ return out, nil
+}
- if err := checkKeys(response, [][]string{{"t", "string"}, {"y", "string"}}); err != nil {
- return nil, err
+func getMapKey(data map[string]interface{}, key string) (map[string]interface{}, error) {
+ val, ok := data[key]
+ if !ok {
+ return nil, fmt.Errorf("krpc: missing key %s", key)
}
+ out, ok := val.(map[string]interface{})
+ if !ok {
+ return nil, fmt.Errorf("krpc: key type mismatch")
+ }
+ return out, nil
+}
- return response, nil
+func getListKey(data map[string]interface{}, key string) ([]interface{}, error) {
+ val, ok := data[key]
+ if !ok {
+ return nil, fmt.Errorf("krpc: missing key %s", key)
+ }
+ out, ok := val.([]interface{})
+ if !ok {
+ return nil, fmt.Errorf("krpc: key type mismatch")
+ }
+ return out, nil
}
// parseKeys parses keys. It just wraps parseKey.
@@ -101,3 +123,37 @@ func compactNodeInfoToString(cni string) string {
return ""
}
}
+
+func stringToCompactNodeInfo(addr string) ([]byte, error) {
+ host, port, err := net.SplitHostPort(addr)
+ if err != nil {
+ return []byte{}, err
+ }
+ pInt, err := strconv.ParseInt(port, 10, 64)
+ if err != nil {
+ return []byte{}, err
+ }
+ p := int2bytes(pInt)
+ if len(p) < 2 {
+ p = append(p, p[0])
+ p[0] = 0
+ }
+ return append([]byte(host), p...), nil
+}
+
+func int2bytes(val int64) []byte {
+ data, j := make([]byte, 8), -1
+ for i := 0; i < 8; i++ {
+ shift := uint64((7 - i) * 8)
+ data[i] = byte((val & (0xff << shift)) >> shift)
+
+ if j == -1 && data[i] != 0 {
+ j = i
+ }
+ }
+
+ if j != -1 {
+ return data[j:]
+ }
+ return data[:1]
+}
diff --git a/dht/krpc_test.go b/dht/krpc_test.go
new file mode 100644
index 0000000..d710678
--- /dev/null
+++ b/dht/krpc_test.go
@@ -0,0 +1,25 @@
+package dht
+
+import (
+ "testing"
+)
+
+func TestStringToCompactNodeInfo(t *testing.T) {
+
+ tests := []struct {
+ in string
+ out []byte
+ }{
+ {in: "192.168.1.1:6881", out: []byte("asdfasdf")},
+ }
+
+ for _, tt := range tests {
+ r, err := stringToCompactNodeInfo(tt.in)
+ if err != nil {
+ t.Errorf("stringToCompactNodeInfo failed with %s", err)
+ }
+ if r != tt.out {
+ t.Errorf("stringToCompactNodeInfo(%s) => %s, expected %s", tt.in, r, tt.out)
+ }
+ }
+}
diff --git a/dht/messages.go b/dht/messages.go
new file mode 100644
index 0000000..94b15ee
--- /dev/null
+++ b/dht/messages.go
@@ -0,0 +1,105 @@
+package dht
+
+import (
+ "fmt"
+ "net"
+)
+
+func (n *Node) onPingQuery(rn remoteNode, msg map[string]interface{}) {
+ t := msg["t"].(string)
+ //n.log.Debug("ping", "source", rn)
+ n.queueMsg(rn, makeResponse(t, map[string]interface{}{
+ "id": string(n.id),
+ }))
+}
+
+func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) {
+ a := msg["a"].(map[string]interface{})
+ if err := checkKey(a, "info_hash", "string"); err != nil {
+ //n.queueMsg(addr, makeError(t, protocolError, err.Error()))
+ return
+ }
+
+ // This is the ih of the torrent
+ th, err := InfohashFromString(a["info_hash"].(string))
+ if err != nil {
+ n.log.Warn("invalid torrent", "infohash", a["info_hash"])
+ }
+ n.log.Debug("get_peers query", "source", rn, "torrent", th)
+
+ token := []byte(*th)[:2]
+
+ id := generateNeighbour(n.id, *th)
+ t := msg["t"].(string)
+ n.queueMsg(rn, makeResponse(t, map[string]interface{}{
+ "id": string(id),
+ "token": token,
+ "nodes": "",
+ }))
+
+ nodes := n.rTable.get(50)
+ fmt.Printf("sending get_peers for %s to %d nodes\n", *th, len(nodes))
+ q := makeQuery(newTransactionID(), "get_peers", map[string]interface{}{
+ "id": string(id),
+ "info_hash": string(*th),
+ })
+ for _, o := range nodes {
+ n.queueMsg(*o, q)
+ }
+}
+
+func (n *Node) onAnnouncePeerQuery(rn remoteNode, msg map[string]interface{}) {
+ a := msg["a"].(map[string]interface{})
+ err := checkKeys(a, [][]string{
+ {"info_hash", "string"},
+ {"port", "int"},
+ {"token", "string"},
+ })
+ if err != nil {
+ //n.queueMsg(addr, makeError(t, protocolError, err.Error()))
+ return
+ }
+
+ n.log.Debug("announce_peer", "source", rn)
+
+ // TODO
+ if impliedPort, ok := a["implied_port"]; ok && impliedPort.(int) != 0 {
+ // Use the port from the network
+ } else {
+ // Use the port in the message
+ host, _, err := net.SplitHostPort(rn.address.String())
+ if err != nil {
+ n.log.Warn("failed to split host/port", "error", err)
+ return
+ }
+ newPort := a["port"]
+ if newPort == 0 {
+ n.log.Warn("sent port 0", "source", rn)
+ return
+ }
+ addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", host, newPort))
+ rn = remoteNode{address: addr, id: rn.id}
+ }
+
+ // TODO do we reply?
+
+ ih, err := InfohashFromString(a["info_hash"].(string))
+ if err != nil {
+ n.log.Warn("invalid torrent", "infohash", a["info_hash"])
+ }
+
+ p := Peer{Node: rn, Infohash: *ih}
+ n.log.Info("anounce_peer", p)
+ if n.OnAnnouncePeer != nil {
+ go n.OnAnnouncePeer(p)
+ }
+}
+
+func (n *Node) onFindNodeResponse(rn remoteNode, msg map[string]interface{}) {
+ r := msg["r"].(map[string]interface{})
+ if err := checkKey(r, "id", "string"); err != nil {
+ return
+ }
+ nodes := r["nodes"].(string)
+ n.processFindNodeResults(rn, nodes)
+}
diff --git a/dht/node.go b/dht/node.go
index 3f8f349..169f219 100644
--- a/dht/node.go
+++ b/dht/node.go
@@ -1,12 +1,15 @@
package dht
import (
+ //"fmt"
+ "context"
"net"
"strconv"
"time"
"github.com/felix/dhtsearch/bencode"
"github.com/felix/logger"
+ "golang.org/x/time/rate"
)
var (
@@ -23,35 +26,35 @@ type Node struct {
id Infohash
address string
port int
- conn *net.UDPConn
+ conn net.PacketConn
pool chan chan packet
rTable *routingTable
- workers []*dhtWorker
udpTimeout int
packetsOut chan packet
- peersOut chan Peer
- closing chan chan error
log logger.Logger
//table routingTable
// OnAnnoucePeer is called for each peer that announces itself
- OnAnnoucePeer func(p *Peer)
+ OnAnnouncePeer func(p Peer)
}
// NewNode creates a new DHT node
func NewNode(opts ...Option) (n *Node, err error) {
-
id := randomInfoHash()
+ k, err := newRoutingTable(id, 2000)
+ if err != nil {
+ n.log.Error("failed to create routing table", "error", err)
+ return nil, err
+ }
+
n = &Node{
- id: id,
- address: "0.0.0.0",
- port: 6881,
- rTable: newRoutingTable(id),
- workers: make([]*dhtWorker, 1),
- closing: make(chan chan error),
- log: logger.New(&logger.Options{Name: "dht"}),
- peersOut: make(chan Peer),
+ id: id,
+ address: "0.0.0.0",
+ port: 6881,
+ udpTimeout: 10,
+ rTable: k,
+ log: logger.New(&logger.Options{Name: "dht"}),
}
// Set variadic options passed
@@ -62,141 +65,111 @@ func NewNode(opts ...Option) (n *Node, err error) {
}
}
+ n.conn, err = net.ListenPacket("udp", n.address+":"+strconv.Itoa(n.port))
+ if err != nil {
+ n.log.Error("failed to listen", "error", err)
+ return nil, err
+ }
+ n.log.Info("listening", "id", n.id, "network", n.conn.LocalAddr().Network(), "address", n.conn.LocalAddr().String())
+
return n, nil
}
// Close stuff
func (n *Node) Close() error {
n.log.Warn("node closing")
- errCh := make(chan error)
- n.closing <- errCh
- // Signal workers
- for _, w := range n.workers {
- w.stop()
- }
- return <-errCh
+ return nil
}
// Run starts the node on the DHT
-func (n *Node) Run() chan Peer {
- listener, err := net.ListenPacket("udp4", n.address+":"+strconv.Itoa(n.port))
- if err != nil {
- n.log.Error("failed to listen", "error", err)
- return nil
- }
- n.conn = listener.(*net.UDPConn)
- n.port = n.conn.LocalAddr().(*net.UDPAddr).Port
- n.log.Info("listening", "id", n.id, "address", n.address, "port", n.port)
-
- // Worker pool
- n.pool = make(chan chan packet)
+func (n *Node) Run() {
// Packets onto the network
- n.packetsOut = make(chan packet, 512)
+ n.packetsOut = make(chan packet, 1024)
// Create a slab for allocation
byteSlab := newSlab(8192, 10)
- // Start our workers
- n.log.Debug("starting workers", "count", len(n.workers))
- for i := 0; i < len(n.workers); i++ {
- w := &dhtWorker{
- pool: n.pool,
- packetsOut: n.packetsOut,
- peersOut: n.peersOut,
- rTable: n.rTable,
- quit: make(chan struct{}),
- log: n.log.Named("worker"),
+ n.log.Debug("starting packet writer")
+ go n.packetWriter()
+
+ // Find neighbours
+ go n.makeNeighbours()
+
+ n.log.Debug("starting packet reader")
+ for {
+ b := byteSlab.alloc()
+ c, addr, err := n.conn.ReadFrom(b)
+ if err != nil {
+ n.log.Warn("UDP read error", "error", err)
+ return
}
- go w.run()
- n.workers[i] = w
+
+ // Chop and process
+ n.processPacket(packet{
+ data: b[0:c],
+ raddr: addr,
+ })
+ byteSlab.free(b)
}
+}
- n.log.Debug("starting packet writer")
- // Start writing packets from channel to DHT
- go func() {
- var p packet
- for {
- select {
- case p = <-n.packetsOut:
- //n.conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(n.udpTimeout)))
- _, err := n.conn.WriteToUDP(p.data, &p.raddr)
- if err != nil {
- // TODO remove from routing or add to blacklist?
- n.log.Warn("failed to write packet", "error", err)
- }
- }
- }
- }()
+func (n *Node) makeNeighbours() {
+ // TODO configurable
+ ticker := time.Tick(5 * time.Second)
- n.log.Debug("starting packet reader")
- // Start reading packets
- go func() {
- n.bootstrap()
-
- // TODO configurable
- ticker := time.Tick(10 * time.Second)
-
- // Send packets from conn to workers
- for {
- select {
- case errCh := <-n.closing:
- // TODO
- errCh <- nil
- case pCh := <-n.pool:
- go func() {
- b := byteSlab.Alloc()
- c, addr, err := n.conn.ReadFromUDP(b)
- if err != nil {
- n.log.Warn("UDP read error", "error", err)
- return
- }
-
- // Chop and send
- pCh <- packet{
- data: b[0:c],
- raddr: *addr,
- }
- byteSlab.Free(b)
- }()
-
- case <-ticker:
- go func() {
- if n.rTable.isEmpty() {
- n.bootstrap()
- } else {
- n.makeNeighbours()
- }
- }()
+ n.bootstrap()
+
+ for {
+ select {
+ case <-ticker:
+ if n.rTable.isEmpty() {
+ n.bootstrap()
+ } else {
+ // Send to all nodes
+ nodes := n.rTable.get(0)
+ for _, rn := range nodes {
+ n.findNode(rn, generateNeighbour(n.id, rn.id))
+ }
+ n.rTable.flush()
}
}
- }()
- return n.peersOut
+ }
}
-func (n *Node) bootstrap() {
+func (n Node) bootstrap() {
n.log.Debug("bootstrapping")
for _, s := range routers {
- addr, err := net.ResolveUDPAddr("udp4", s)
+ addr, err := net.ResolveUDPAddr(n.conn.LocalAddr().Network(), s)
if err != nil {
n.log.Error("failed to parse bootstrap address", "error", err)
return
}
- rn := &remoteNode{address: *addr}
+ rn := &remoteNode{address: addr}
n.findNode(rn, n.id)
}
}
-func (n *Node) makeNeighbours() {
- n.log.Debug("making neighbours")
- for _, rn := range n.rTable.getNodes() {
- n.findNode(rn, n.id)
+func (n *Node) packetWriter() {
+ l := rate.NewLimiter(rate.Limit(500), 100)
+
+ for p := range n.packetsOut {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ if err := l.Wait(ctx); err != nil {
+ n.log.Warn("rate limited", "error", err)
+ continue
+ }
+ _, err := n.conn.WriteTo(p.data, p.raddr)
+ if err != nil {
+ // TODO remove from routing or add to blacklist?
+ n.log.Warn("failed to write packet", "error", err)
+ }
}
- n.rTable.refresh()
}
func (n Node) findNode(rn *remoteNode, id Infohash) {
target := randomInfoHash()
- n.sendMsg(rn, "find_node", map[string]interface{}{
+ n.sendQuery(rn, "find_node", map[string]interface{}{
"id": string(id),
"target": string(target),
})
@@ -204,13 +177,13 @@ func (n Node) findNode(rn *remoteNode, id Infohash) {
// ping sends ping query to the chan.
func (n *Node) ping(rn *remoteNode) {
- id := n.id.GenNeighbour(rn.id)
- n.sendMsg(rn, "ping", map[string]interface{}{
+ id := generateNeighbour(n.id, rn.id)
+ n.sendQuery(rn, "ping", map[string]interface{}{
"id": string(id),
})
}
-func (n Node) sendMsg(rn *remoteNode, qType string, a map[string]interface{}) error {
+func (n Node) sendQuery(rn *remoteNode, qType string, a map[string]interface{}) error {
// Stop if sending to self
if rn.id.Equal(n.id) {
return nil
@@ -224,9 +197,207 @@ func (n Node) sendMsg(rn *remoteNode, qType string, a map[string]interface{}) er
if err != nil {
return err
}
+ //fmt.Printf("sending %s to %s\n", qType, rn.String())
+ n.packetsOut <- packet{
+ data: b,
+ raddr: rn.address,
+ }
+ return nil
+}
+
+// Parse a KRPC packet into a message
+func (n *Node) processPacket(p packet) {
+ data, err := bencode.Decode(p.data)
+ if err != nil {
+ return
+ }
+
+ response, ok := data.(map[string]interface{})
+ if !ok {
+ n.log.Debug("failed to parse packet", "error", "response is not dict")
+ return
+ }
+
+ if err := checkKeys(response, [][]string{{"t", "string"}, {"y", "string"}}); err != nil {
+ n.log.Debug("failed to parse packet", "error", err)
+ return
+ }
+
+ switch response["y"].(string) {
+ case "q":
+ n.handleRequest(p.raddr, response)
+ case "r":
+ err = n.handleResponse(p.raddr, response)
+ case "e":
+ n.handleError(p.raddr, response)
+ default:
+ n.log.Warn("missing request type")
+ return
+ }
+ if err != nil {
+ n.log.Warn("failed to process packet", "error", err)
+ }
+}
+
+// bencode data and send
+func (n *Node) queueMsg(rn remoteNode, data map[string]interface{}) error {
+ b, err := bencode.Encode(data)
+ if err != nil {
+ return err
+ }
n.packetsOut <- packet{
data: b,
raddr: rn.address,
}
return nil
}
+
+// handleRequest handles the requests received from udp.
+func (n *Node) handleRequest(addr net.Addr, m map[string]interface{}) (success bool) {
+ if err := checkKeys(m, [][]string{{"q", "string"}, {"a", "map"}}); err != nil {
+
+ //d.queueMsg(addr, makeError(t, protocolError, err.Error()))
+ return
+ }
+
+ a := m["a"].(map[string]interface{})
+
+ if err := checkKey(a, "id", "string"); err != nil {
+ //d.queueMsg(addr, makeError(t, protocolError, err.Error()))
+ return
+ }
+
+ ih, err := InfohashFromString(a["id"].(string))
+ if err != nil {
+ n.log.Warn("invalid request", "infohash", a["id"].(string))
+ }
+
+ if n.id.Equal(*ih) {
+ return
+ }
+
+ rn := &remoteNode{address: addr, id: *ih}
+ q := m["q"].(string)
+
+ switch q {
+ case "ping":
+ n.onPingQuery(*rn, m)
+
+ case "get_peers":
+ n.onGetPeersQuery(*rn, m)
+
+ case "announce_peer":
+ n.onAnnouncePeerQuery(*rn, m)
+
+ default:
+ //n.queueMsg(addr, makeError(t, protocolError, "invalid q"))
+ return
+ }
+ n.rTable.add(rn)
+ return true
+}
+
+// handleResponse handles responses received from udp.
+func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error {
+ r, err := getMapKey(m, "r")
+ if err != nil {
+ return err
+ }
+ id, err := getStringKey(r, "id")
+ if err != nil {
+ return err
+ }
+ ih, err := InfohashFromString(id)
+ if err != nil {
+ return err
+ }
+
+ rn := &remoteNode{address: addr, id: *ih}
+
+ nodes, err := getStringKey(r, "nodes")
+ // find_nodes/get_peers response with nodes
+ if err == nil {
+ n.onFindNodeResponse(*rn, m)
+ n.processFindNodeResults(*rn, nodes)
+ n.rTable.add(rn)
+ return nil
+ }
+
+ values, err := getListKey(r, "values")
+ // get_peers response
+ if err == nil {
+ n.log.Debug("get_peers response", "source", rn)
+ for _, v := range values {
+ addr := compactNodeInfoToString(v.(string))
+ n.log.Debug("unhandled get_peer request", "addres", addr)
+
+ // TODO new peer needs to be matched to previous get_peers request
+ // n.peersManager.Insert(ih, p)
+ }
+ n.rTable.add(rn)
+ }
+ return nil
+}
+
+// handleError handles errors received from udp.
+func (n *Node) handleError(addr net.Addr, m map[string]interface{}) bool {
+ if err := checkKey(m, "e", "list"); err != nil {
+ return false
+ }
+
+ e := m["e"].([]interface{})
+ if len(e) != 2 {
+ return false
+ }
+ code := e[0].(int64)
+ msg := e[1].(string)
+ n.log.Debug("error packet", "address", addr.String(), "code", code, "error", msg)
+
+ return true
+}
+
+// Process another node's response to a find_node query.
+func (n *Node) processFindNodeResults(rn remoteNode, nodeList string) {
+ nodeLength := 26
+ /*
+ if d.config.proto == "udp6" {
+ nodeList = m.R.Nodes6
+ nodeLength = 38
+ } else {
+ nodeList = m.R.Nodes
+ }
+
+ // Not much to do
+ if nodeList == "" {
+ return
+ }
+ */
+
+ if len(nodeList)%nodeLength != 0 {
+ n.log.Error("node list is wrong length", "length", len(nodeList))
+ return
+ }
+
+ //fmt.Printf("%s sent %d nodes\n", rn.address.String(), len(nodeList)/nodeLength)
+
+ // We got a byte array in groups of 26 or 38
+ for i := 0; i < len(nodeList); i += nodeLength {
+ id := nodeList[i : i+ihLength]
+ addrStr := compactNodeInfoToString(nodeList[i+ihLength : i+nodeLength])
+
+ ih, err := InfohashFromString(id)
+ if err != nil {
+ n.log.Warn("invalid infohash in node list")
+ continue
+ }
+
+ addr, err := net.ResolveUDPAddr("udp", addrStr)
+ if err != nil || addr.Port == 0 {
+ n.log.Warn("unable to resolve", "address", addrStr, "error", err)
+ continue
+ }
+
+ rn := &remoteNode{address: addr, id: *ih}
+ n.rTable.add(rn)
+ }
+}
diff --git a/dht/options.go b/dht/options.go
index 03a85bf..f870a79 100644
--- a/dht/options.go
+++ b/dht/options.go
@@ -6,6 +6,13 @@ import (
type Option func(*Node) error
+func SetOnAnnouncePeer(f func(Peer)) Option {
+ return func(n *Node) error {
+ n.OnAnnouncePeer = f
+ return nil
+ }
+}
+
// SetAddress sets the IP address to listen on
func SetAddress(ip string) Option {
return func(n *Node) error {
@@ -22,14 +29,6 @@ func SetPort(p int) Option {
}
}
-// SetWorkers sets the number of workers
-func SetWorkers(c int) Option {
- return func(n *Node) error {
- n.workers = make([]*dhtWorker, c)
- return nil
- }
-}
-
// SetUDPTimeout sets the number of seconds to wait for UDP connections
func SetUDPTimeout(s int) Option {
return func(n *Node) error {
@@ -38,7 +37,7 @@ func SetUDPTimeout(s int) Option {
}
}
-// SetLogger sets the number of workers
+// SetLogger sets the logger
func SetLogger(l logger.Logger) Option {
return func(n *Node) error {
n.log = l
diff --git a/dht/peer.go b/dht/peer.go
index a801331..f9669ba 100644
--- a/dht/peer.go
+++ b/dht/peer.go
@@ -1,9 +1,13 @@
package dht
-import "net"
+import "fmt"
// Peer on DHT network
type Peer struct {
- Address net.UDPAddr
- ID Infohash
+ Node remoteNode
+ Infohash Infohash
+}
+
+func (p Peer) String() string {
+ return fmt.Sprintf("%s (%s)", p.Infohash, p.Node)
}
diff --git a/dht/remote_node.go b/dht/remote_node.go
index 8a8d4a2..5bb2585 100644
--- a/dht/remote_node.go
+++ b/dht/remote_node.go
@@ -3,16 +3,15 @@ package dht
import (
"fmt"
"net"
- //"time"
)
type remoteNode struct {
- address net.UDPAddr
+ address net.Addr
id Infohash
//lastSeen time.Time
}
// String implements fmt.Stringer
-func (r *remoteNode) String() string {
- return fmt.Sprintf("%s:%d", r.address.IP.String(), r.address.Port)
+func (r remoteNode) String() string {
+ return fmt.Sprintf("%s (%s)", r.id.String(), r.address.String())
}
diff --git a/dht/routing_table.go b/dht/routing_table.go
index 0252519..3c1f2d2 100644
--- a/dht/routing_table.go
+++ b/dht/routing_table.go
@@ -1,57 +1,119 @@
package dht
import (
- "net"
+ "container/heap"
"sync"
)
-// Keep it simple for now
+type rItem struct {
+ value *remoteNode
+ distance int
+ index int // Index in heap
+}
+
+type priorityQueue []*rItem
+
type routingTable struct {
- id Infohash
- address net.UDPAddr
- nodes []*remoteNode
- max int
+ id Infohash
+ max int
+ items priorityQueue
+ addresses map[string]*remoteNode
sync.Mutex
}
-func newRoutingTable(id Infohash) *routingTable {
- k := &routingTable{id: id, max: 4000}
- k.refresh()
- return k
+func newRoutingTable(id Infohash, max int) (*routingTable, error) {
+ k := &routingTable{
+ id: id,
+ max: max,
+ }
+ k.flush()
+ heap.Init(&k.items)
+ return k, nil
}
-func (k *routingTable) add(rn *remoteNode) {
- k.Lock()
- defer k.Unlock()
+// Len implements sort.Interface
+func (pq priorityQueue) Len() int { return len(pq) }
+
+// Less implements sort.Interface
+func (pq priorityQueue) Less(i, j int) bool {
+ return pq[i].distance > pq[j].distance
+}
+
+// Swap implements sort.Interface
+func (pq priorityQueue) Swap(i, j int) {
+ pq[i], pq[j] = pq[j], pq[i]
+ pq[i].index = i
+ pq[j].index = j
+}
+
+// Push implements heap.Interface
+func (pq *priorityQueue) Push(x interface{}) {
+ n := len(*pq)
+ item := x.(*rItem)
+ item.index = n
+ *pq = append(*pq, item)
+}
+
+// Pop implements heap.Interface
+func (pq *priorityQueue) Pop() interface{} {
+ old := *pq
+ n := len(old)
+ item := old[n-1]
+ item.index = -1 // for safety
+ *pq = old[0 : n-1]
+ return item
+}
+func (k *routingTable) add(rn *remoteNode) {
// Check IP and ports are valid and not self
- if (rn.address.String() == k.address.String() && rn.address.Port == k.address.Port) || !rn.id.Valid() || rn.id.Equal(k.id) {
+ if !rn.id.Valid() || rn.id.Equal(k.id) {
return
}
- k.nodes = append(k.nodes, rn)
-}
-func (k *routingTable) getNodes() []*remoteNode {
k.Lock()
defer k.Unlock()
- return k.nodes
+
+ if _, ok := k.addresses[rn.address.String()]; ok {
+ return
+ }
+ k.addresses[rn.address.String()] = rn
+
+ item := &rItem{
+ value: rn,
+ distance: k.id.Distance(rn.id),
+ }
+
+ heap.Push(&k.items, item)
+
+ if len(k.items) > k.max {
+ for i := k.max - 1; i < len(k.items); i++ {
+ old := k.items[i]
+ delete(k.addresses, old.value.address.String())
+ heap.Remove(&k.items, i)
+ }
+ }
}
-func (k *routingTable) isEmpty() bool {
- k.Lock()
- defer k.Unlock()
- return len(k.nodes) == 0
+func (k *routingTable) get(n int) (out []*remoteNode) {
+ if n == 0 {
+ n = len(k.items)
+ }
+ for i := 0; i < n && i < len(k.items); i++ {
+ out = append(out, k.items[i].value)
+ }
+ return out
}
-func (k *routingTable) isFull() bool {
+func (k *routingTable) flush() {
k.Lock()
defer k.Unlock()
- return len(k.nodes) >= k.max
+
+ k.items = make(priorityQueue, 0)
+ k.addresses = make(map[string]*remoteNode, k.max)
}
-// For now
-func (k *routingTable) refresh() {
+func (k *routingTable) isEmpty() bool {
k.Lock()
defer k.Unlock()
- k.nodes = make([]*remoteNode, 0)
+ return len(k.items) == 0
}
diff --git a/dht/routing_table_test.go b/dht/routing_table_test.go
new file mode 100644
index 0000000..77c0a17
--- /dev/null
+++ b/dht/routing_table_test.go
@@ -0,0 +1,46 @@
+package dht
+
+import (
+ "fmt"
+ "net"
+ "testing"
+)
+
+func TestPriorityQueue(t *testing.T) {
+ id := "d1c5676ae7ac98e8b19f63565905105e3c4c37a2"
+
+ tests := []string{
+ "d1c5676ae7ac98e8b19f63565905105e3c4c37b9",
+ "d1c5676ae7ac98e8b19f63565905105e3c4c37a9",
+ "d1c5676ae7ac98e8b19f63565905105e3c4c37a4",
+ "d1c5676ae7ac98e8b19f63565905105e3c4c37a3", // distance of 159
+ }
+
+ ih, err := InfohashFromString(id)
+ if err != nil {
+ t.Errorf("failed to create infohash: %s\n", err)
+ }
+
+ pq, err := newRoutingTable(*ih, 3)
+ if err != nil {
+ t.Errorf("failed to create kTable: %s\n", err)
+ }
+
+ for i, idt := range tests {
+ iht, err := InfohashFromString(idt)
+ if err != nil {
+ t.Errorf("failed to create infohash: %s\n", err)
+ }
+ addr, _ := net.ResolveUDPAddr("udp", fmt.Sprintf("0.0.0.0:%d", i))
+ pq.add(&remoteNode{id: *iht, address: addr})
+ }
+
+ if len(pq.items) != len(pq.addresses) {
+ t.Errorf("items and addresses out of sync")
+ }
+
+ first := pq.items[0].value.id
+ if first.String() != "d1c5676ae7ac98e8b19f63565905105e3c4c37a3" {
+ t.Errorf("first is %s with distance %d\n", first, ih.Distance(first))
+ }
+}
diff --git a/dht/slab.go b/dht/slab.go
index 737b0b6..a8b4018 100644
--- a/dht/slab.go
+++ b/dht/slab.go
@@ -14,11 +14,11 @@ func newSlab(blockSize int, numBlocks int) slab {
return s
}
-func (s slab) Alloc() (x []byte) {
+func (s slab) alloc() (x []byte) {
return <-s
}
-func (s slab) Free(x []byte) {
+func (s slab) free(x []byte) {
// Check we are using the right dimensions
x = x[:cap(x)]
s <- x
diff --git a/dht/worker.go b/dht/worker.go
deleted file mode 100644
index 36b0519..0000000
--- a/dht/worker.go
+++ /dev/null
@@ -1,274 +0,0 @@
-package dht
-
-import (
- "net"
-
- "github.com/felix/dhtsearch/bencode"
- "github.com/felix/logger"
-)
-
-type dhtWorker struct {
- pool chan chan packet
- packetsOut chan<- packet
- peersOut chan<- Peer
- log logger.Logger
- rTable *routingTable
- quit chan struct{}
-}
-
-func (dw *dhtWorker) run() error {
- packetsIn := make(chan packet)
-
- for {
- dw.pool <- packetsIn
-
- // Wait for work or shutdown
- select {
- case p := <-packetsIn:
- dw.process(p)
- case <-dw.quit:
- dw.log.Warn("worker closing")
- break
- }
- }
-}
-
-func (dw dhtWorker) stop() {
- go func() {
- dw.quit <- struct{}{}
- }()
-}
-
-// Parse a KRPC packet into a message
-func (dw *dhtWorker) process(p packet) {
- data, err := bencode.Decode(p.data)
- if err != nil {
- return
- }
-
- response, err := parseMessage(data)
- if err != nil {
- dw.log.Debug("failed to parse packet", "error", err)
- return
- }
-
- switch response["y"].(string) {
- case "q":
- dw.handleRequest(&p.raddr, response)
- case "r":
- dw.handleResponse(&p.raddr, response)
- case "e":
- dw.handleError(&p.raddr, response)
- default:
- dw.log.Warn("missing request type")
- return
- }
-}
-
-// bencode data and send
-func (dw *dhtWorker) queueMsg(raddr net.UDPAddr, data map[string]interface{}) error {
- b, err := bencode.Encode(data)
- if err != nil {
- return err
- }
- dw.packetsOut <- packet{
- data: b,
- raddr: raddr,
- }
- return nil
-}
-
-// handleRequest handles the requests received from udp.
-func (dw *dhtWorker) handleRequest(addr *net.UDPAddr, m map[string]interface{}) (success bool) {
-
- t := m["t"].(string)
-
- if err := checkKeys(m, [][]string{{"q", "string"}, {"a", "map"}}); err != nil {
-
- //d.queueMsg(addr, makeError(t, protocolError, err.Error()))
- return
- }
-
- q := m["q"].(string)
- a := m["a"].(map[string]interface{})
-
- if err := checkKey(a, "id", "string"); err != nil {
- //d.queueMsg(addr, makeError(t, protocolError, err.Error()))
- return
- }
-
- var ih Infohash
- err := ih.FromString(a["id"].(string))
- if err != nil {
- dw.log.Warn("invalid packet", "infohash", a["id"])
- }
-
- if dw.rTable.id.Equal(ih) {
- return
- }
-
- var rn *remoteNode
- switch q {
- case "ping":
- rn = &remoteNode{address: *addr, id: ih}
- dw.log.Debug("ping", "source", rn, "infohash", ih)
- dw.queueMsg(*addr, makeResponse(t, map[string]interface{}{
- "id": string(dw.rTable.id),
- }))
-
- case "get_peers":
- if err := checkKey(a, "info_hash", "string"); err != nil {
- //dw.queueMsg(addr, makeError(t, protocolError, err.Error()))
- return
- }
- rn = &remoteNode{address: *addr, id: ih}
- err = ih.FromString(a["info_hash"].(string))
- if err != nil {
- dw.log.Warn("invalid packet", "infohash", a["id"])
- }
- dw.log.Debug("get_peers", "source", rn, "infohash", ih)
-
- // Crawling, we have no nodes
- id := dw.rTable.id.GenNeighbour(ih)
- dw.queueMsg(*addr, makeResponse(t, map[string]interface{}{
- "id": string(id),
- "token": ih[:2],
- "nodes": "",
- }))
-
- case "announce_peer":
- if err := checkKeys(a, [][]string{
- {"info_hash", "string"},
- {"port", "int"},
- {"token", "string"}}); err != nil {
-
- //dw.queueMsg(addr, makeError(t, protocolError, err.Error()))
- return
- }
-
- rn = &remoteNode{address: *addr, id: ih}
- dw.log.Debug("announce_peer", "source", rn, "infohash", ih)
-
- // TODO
- if impliedPort, ok := a["implied_port"]; ok &&
- impliedPort.(int) != 0 {
- //port = addr.Port
- }
- // TODO do we reply?
- dw.peersOut <- Peer{*addr, ih}
-
- default:
- //dw.queueMsg(addr, makeError(t, protocolError, "invalid q"))
- return
- }
- dw.rTable.add(rn)
- return true
-}
-
-// handleResponse handles responses received from udp.
-func (dw *dhtWorker) handleResponse(addr *net.UDPAddr, m map[string]interface{}) (success bool) {
-
- //t := m["t"].(string)
-
- if err := checkKey(m, "r", "map"); err != nil {
- return
- }
-
- r := m["r"].(map[string]interface{})
- if err := checkKey(r, "id", "string"); err != nil {
- return
- }
-
- var ih Infohash
- ih.FromString(r["id"].(string))
- rn := &remoteNode{address: *addr, id: ih}
-
- // find_nodes response
- if err := checkKey(r, "nodes", "string"); err == nil {
- nodes := r["nodes"].(string)
- dw.processFindNodeResults(rn, nodes)
- return
- }
-
- // get_peers response
- if err := checkKey(r, "values", "list"); err == nil {
- values := r["values"].([]interface{})
- for _, v := range values {
- addr := compactNodeInfoToString(v.(string))
- dw.log.Debug("unhandled get_peer request", "addres", addr)
- // TODO new peer
- // dw.peersManager.Insert(ih, p)
- }
- }
- dw.rTable.add(rn)
- return true
-}
-
-// handleError handles errors received from udp.
-func (dw *dhtWorker) handleError(addr *net.UDPAddr, m map[string]interface{}) bool {
- if err := checkKey(m, "e", "list"); err != nil {
- return false
- }
-
- e := m["e"].([]interface{})
- if len(e) != 2 {
- return false
- }
- code := e[0].(int64)
- msg := e[1].(string)
- dw.log.Debug("error packet", "ip", addr.IP.String(), "port", addr.Port, "code", code, "error", msg)
-
- return true
-}
-
-// Process another node's response to a find_node query.
-func (dw *dhtWorker) processFindNodeResults(rn *remoteNode, nodeList string) {
- nodeLength := 26
- /*
- if d.config.proto == "udp6" {
- nodeList = m.R.Nodes6
- nodeLength = 38
- } else {
- nodeList = m.R.Nodes
- }
-
- // Not much to do
- if nodeList == "" {
- return
- }
- */
-
- if len(nodeList)%nodeLength != 0 {
- dw.log.Error("node list is wrong length", "length", len(nodeList))
- return
- }
-
- var ih Infohash
- var err error
-
- //dw.log.Debug("got node list", "length", len(nodeList))
-
- // We got a byte array in groups of 26 or 38
- for i := 0; i < len(nodeList); i += nodeLength {
- id := nodeList[i : i+ihLength]
- addr := compactNodeInfoToString(nodeList[i+ihLength : i+nodeLength])
-
- err = ih.FromString(id)
- if err != nil {
- dw.log.Warn("invalid node list")
- continue
- }
-
- if dw.rTable.id.Equal(ih) {
- continue
- }
-
- address, err := net.ResolveUDPAddr("udp4", addr)
- if err != nil {
- dw.log.Error("failed to resolve", "error", err)
- continue
- }
- rn := &remoteNode{address: *address, id: ih}
- dw.rTable.add(rn)
- }
-}
diff --git a/infohash.go b/infohash.go
deleted file mode 100644
index 0327771..0000000
--- a/infohash.go
+++ /dev/null
@@ -1,43 +0,0 @@
-package dhtsearch
-
-import (
- "crypto/sha1"
- "encoding/hex"
- "errors"
- "io"
- "math/rand"
- "time"
-)
-
-const ihLength = 20
-
-func genInfoHash() string {
- random := rand.New(rand.NewSource(time.Now().UnixNano()))
- hash := sha1.New()
- io.WriteString(hash, time.Now().String())
- io.WriteString(hash, string(random.Int()))
- ih := hash.Sum(nil)
- return string(ih)
-}
-
-func genNeighbour(first, second string) string {
- s := second[:10] + first[10:]
- return s
-}
-
-func decodeInfoHash(in string) (b string, err error) {
- var h []byte
- h, err = hex.DecodeString(in)
- if len(h) != ihLength {
- return "", errors.New("invalid length")
- }
- return string(h), err
-}
-
-func isValidInfoHash(id string) bool {
- ih, err := hex.DecodeString(id)
- if err != nil {
- return false
- }
- return len(ih) == ihLength
-}
diff --git a/slab.go b/slab.go
deleted file mode 100644
index 2210c9f..0000000
--- a/slab.go
+++ /dev/null
@@ -1,25 +0,0 @@
-package dhtsearch
-
-// Slab memory allocation
-
-// Initialise the slab as a channel of blocks, allocating them as required and
-// pushing them back on the slab. This reduces garbage collection.
-type slab chan []byte
-
-func newSlab(blockSize int, numBlocks int) slab {
- s := make(slab, numBlocks)
- for i := 0; i < numBlocks; i++ {
- s <- make([]byte, blockSize)
- }
- return s
-}
-
-func (s slab) Alloc() (x []byte) {
- return <-s
-}
-
-func (s slab) Free(x []byte) {
- // Check we are using the right dimensions
- x = x[:cap(x)]
- s <- x
-}
diff --git a/stats.go b/stats.go
deleted file mode 100644
index 0ff5e6f..0000000
--- a/stats.go
+++ /dev/null
@@ -1,32 +0,0 @@
-package dhtsearch
-
-type Stats struct {
- DHTPacketsIn int `json:"dht_packets_in"`
- DHTPacketsOut int `json:"dht_packets_out"`
- DHTPacketsDropped int `json:"dht_packets_dropped"`
- DHTErrors int `json:"dht_errors"`
- DHTCachedPeers int `json:"dht_cached_peers"`
- DHTBytesIn int `json:"dht_bytes_in"`
- DHTBytesOut int `json:"dht_bytes_out"`
- DHTWorkers int `json:"dht_workers"`
- BTBytesInt int `json:"bt_bytes_int"`
- BTBytesOut int `json:"bt_bytes_out"`
- BTWorkers int `json:"bt_workers"`
- PeersAnnounced int `json:"peers_announced"`
- PeersSkipped int `json:"peers_skipped"`
- TorrentsSkipped int `json:"torrents_skipped"`
- TorrentsSaved int `json:"torrents_saved"`
- TorrentsTotal int `json:"torrents_total"`
-}
-
-func (s *Stats) Sub(other *Stats) Stats {
- if other == nil {
- return *s
- }
- var diff Stats
- diff.MessagesIn = s.MessagesIn - other.MessagesIn
- diff.BytesIn = s.BytesIn - other.BytesIn
- diff.MessagesOut = s.MessagesOut - other.MessagesOut
- diff.BytesOut = s.BytesOut - other.BytesOut
- return diff
-}