diff options
Diffstat (limited to 'dht')
| -rw-r--r-- | dht/infohash.go | 64 | ||||
| -rw-r--r-- | dht/infohash_test.go | 30 | ||||
| -rw-r--r-- | dht/krpc.go | 103 | ||||
| -rw-r--r-- | dht/node.go | 232 | ||||
| -rw-r--r-- | dht/options.go | 47 | ||||
| -rw-r--r-- | dht/peer.go | 9 | ||||
| -rw-r--r-- | dht/remote_node.go | 18 | ||||
| -rw-r--r-- | dht/routing_table.go | 57 | ||||
| -rw-r--r-- | dht/slab.go | 25 | ||||
| -rw-r--r-- | dht/worker.go | 274 |
10 files changed, 859 insertions, 0 deletions
diff --git a/dht/infohash.go b/dht/infohash.go new file mode 100644 index 0000000..9c3ea3d --- /dev/null +++ b/dht/infohash.go @@ -0,0 +1,64 @@ +package dht + +import ( + "crypto/sha1" + "encoding/hex" + "io" + "math/rand" + "time" +) + +const ihLength = 20 + +// Infohash - +type Infohash []byte + +func (ih Infohash) String() string { + return hex.EncodeToString(ih) +} +func (ih Infohash) Valid() bool { + // TODO + return len(ih) == 20 +} + +func (ih Infohash) Equal(other Infohash) bool { + if len(ih) != len(other) { + return false + } + for i := 0; i < len(ih); i++ { + if ih[i] != other[i] { + return false + } + } + return true +} + +// FromString - +func (ih *Infohash) FromString(s string) error { + switch len(s) { + case 20: + // Byte string + *ih = Infohash([]byte(s)) + return nil + case 40: + b, err := hex.DecodeString(s) + if err != nil { + return err + } + *ih = Infohash(b) + } + return nil +} + +func (ih Infohash) GenNeighbour(other Infohash) Infohash { + s := append(ih[:10], other[10:]...) 
+ return Infohash(s) +} + +func randomInfoHash() (ih Infohash) { + random := rand.New(rand.NewSource(time.Now().UnixNano())) + hash := sha1.New() + io.WriteString(hash, time.Now().String()) + io.WriteString(hash, string(random.Int())) + return Infohash(hash.Sum(nil)) +} diff --git a/dht/infohash_test.go b/dht/infohash_test.go new file mode 100644 index 0000000..b3223f4 --- /dev/null +++ b/dht/infohash_test.go @@ -0,0 +1,30 @@ +package dht + +import ( + "encoding/hex" + "testing" +) + +func TestInfohashImport(t *testing.T) { + var ih Infohash + + idHex := "5a3ce1c14e7a08645677bbd1cfe7d8f956d53256" + err := ih.FromString(idHex) + if err != nil { + t.Errorf("FromString failed with %s", err) + } + + idBytes, err := hex.DecodeString(idHex) + + ih2 := Infohash(idBytes) + if !ih.Equal(ih2) { + t.Errorf("expected %s to equal %s", ih, ih2) + } +} + +func TestInfohashLength(t *testing.T) { + ih := randomInfoHash() + if len(ih) != 20 { + t.Errorf("%s as string should be length 20, got %d", ih, len(ih)) + } +} diff --git a/dht/krpc.go b/dht/krpc.go new file mode 100644 index 0000000..a926e81 --- /dev/null +++ b/dht/krpc.go @@ -0,0 +1,103 @@ +package dht + +import ( + "errors" + "fmt" + "math/rand" + "net" +) + +const transIDBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +func newTransactionID() string { + b := make([]byte, 2) + for i := range b { + b[i] = transIDBytes[rand.Int63()%int64(len(transIDBytes))] + } + return string(b) +} + +// makeQuery returns a query-formed data. +func makeQuery(t, q string, a map[string]interface{}) map[string]interface{} { + return map[string]interface{}{ + "t": t, + "y": "q", + "q": q, + "a": a, + } +} + +// makeResponse returns a response-formed data. +func makeResponse(t string, r map[string]interface{}) map[string]interface{} { + return map[string]interface{}{ + "t": t, + "y": "r", + "r": r, + } +} + +// parseMessage parses the basic data received from udp. +// It returns a map value. 
+func parseMessage(data interface{}) (map[string]interface{}, error) { + response, ok := data.(map[string]interface{}) + if !ok { + return nil, errors.New("response is not dict") + } + + if err := checkKeys(response, [][]string{{"t", "string"}, {"y", "string"}}); err != nil { + return nil, err + } + + return response, nil +} + +// checkKeys checks keys. It just wraps checkKey. +func checkKeys(data map[string]interface{}, pairs [][]string) (err error) { + for _, args := range pairs { + key, t := args[0], args[1] + if err = checkKey(data, key, t); err != nil { + break + } + } + return err +} + +// checkKey checks the key in dict data. `t` is type of the keyed value. +// It's one of "int", "string", "map", "list". +func checkKey(data map[string]interface{}, key string, t string) error { + val, ok := data[key] + if !ok { + return fmt.Errorf("krpc: missing key %s", key) + } + + switch t { + case "string": + _, ok = val.(string) + case "int": + _, ok = val.(int) + case "map": + _, ok = val.(map[string]interface{}) + case "list": + _, ok = val.([]interface{}) + default: + return errors.New("krpc: invalid type") + } + + if !ok { + return errors.New("krpc: key type mismatch") + } + + return nil +} + +// Swiped from nictuku +func compactNodeInfoToString(cni string) string { + if len(cni) == 6 { + return fmt.Sprintf("%d.%d.%d.%d:%d", cni[0], cni[1], cni[2], cni[3], (uint16(cni[4])<<8)|uint16(cni[5])) + } else if len(cni) == 18 { + b := []byte(cni[:16]) + return fmt.Sprintf("[%s]:%d", net.IP.String(b), (uint16(cni[16])<<8)|uint16(cni[17])) + } else { + return "" + } +} diff --git a/dht/node.go b/dht/node.go new file mode 100644 index 0000000..3f8f349 --- /dev/null +++ b/dht/node.go @@ -0,0 +1,232 @@ +package dht + +import ( + "net" + "strconv" + "time" + + "github.com/felix/dhtsearch/bencode" + "github.com/felix/logger" +) + +var ( + routers = []string{ + "router.bittorrent.com:6881", + "dht.transmissionbt.com:6881", + "router.utorrent.com:6881", + "dht.aelitis.com:6881", + } 
+) + +// Node joins the DHT network +type Node struct { + id Infohash + address string + port int + conn *net.UDPConn + pool chan chan packet + rTable *routingTable + workers []*dhtWorker + udpTimeout int + packetsOut chan packet + peersOut chan Peer + closing chan chan error + log logger.Logger + //table routingTable + + // OnAnnoucePeer is called for each peer that announces itself + OnAnnoucePeer func(p *Peer) +} + +// NewNode creates a new DHT node +func NewNode(opts ...Option) (n *Node, err error) { + + id := randomInfoHash() + + n = &Node{ + id: id, + address: "0.0.0.0", + port: 6881, + rTable: newRoutingTable(id), + workers: make([]*dhtWorker, 1), + closing: make(chan chan error), + log: logger.New(&logger.Options{Name: "dht"}), + peersOut: make(chan Peer), + } + + // Set variadic options passed + for _, option := range opts { + err = option(n) + if err != nil { + return nil, err + } + } + + return n, nil +} + +// Close stuff +func (n *Node) Close() error { + n.log.Warn("node closing") + errCh := make(chan error) + n.closing <- errCh + // Signal workers + for _, w := range n.workers { + w.stop() + } + return <-errCh +} + +// Run starts the node on the DHT +func (n *Node) Run() chan Peer { + listener, err := net.ListenPacket("udp4", n.address+":"+strconv.Itoa(n.port)) + if err != nil { + n.log.Error("failed to listen", "error", err) + return nil + } + n.conn = listener.(*net.UDPConn) + n.port = n.conn.LocalAddr().(*net.UDPAddr).Port + n.log.Info("listening", "id", n.id, "address", n.address, "port", n.port) + + // Worker pool + n.pool = make(chan chan packet) + // Packets onto the network + n.packetsOut = make(chan packet, 512) + + // Create a slab for allocation + byteSlab := newSlab(8192, 10) + + // Start our workers + n.log.Debug("starting workers", "count", len(n.workers)) + for i := 0; i < len(n.workers); i++ { + w := &dhtWorker{ + pool: n.pool, + packetsOut: n.packetsOut, + peersOut: n.peersOut, + rTable: n.rTable, + quit: make(chan struct{}), + log: 
n.log.Named("worker"), + } + go w.run() + n.workers[i] = w + } + + n.log.Debug("starting packet writer") + // Start writing packets from channel to DHT + go func() { + var p packet + for { + select { + case p = <-n.packetsOut: + //n.conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(n.udpTimeout))) + _, err := n.conn.WriteToUDP(p.data, &p.raddr) + if err != nil { + // TODO remove from routing or add to blacklist? + n.log.Warn("failed to write packet", "error", err) + } + } + } + }() + + n.log.Debug("starting packet reader") + // Start reading packets + go func() { + n.bootstrap() + + // TODO configurable + ticker := time.Tick(10 * time.Second) + + // Send packets from conn to workers + for { + select { + case errCh := <-n.closing: + // TODO + errCh <- nil + case pCh := <-n.pool: + go func() { + b := byteSlab.Alloc() + c, addr, err := n.conn.ReadFromUDP(b) + if err != nil { + n.log.Warn("UDP read error", "error", err) + return + } + + // Chop and send + pCh <- packet{ + data: b[0:c], + raddr: *addr, + } + byteSlab.Free(b) + }() + + case <-ticker: + go func() { + if n.rTable.isEmpty() { + n.bootstrap() + } else { + n.makeNeighbours() + } + }() + } + } + }() + return n.peersOut +} + +func (n *Node) bootstrap() { + n.log.Debug("bootstrapping") + for _, s := range routers { + addr, err := net.ResolveUDPAddr("udp4", s) + if err != nil { + n.log.Error("failed to parse bootstrap address", "error", err) + return + } + rn := &remoteNode{address: *addr} + n.findNode(rn, n.id) + } +} + +func (n *Node) makeNeighbours() { + n.log.Debug("making neighbours") + for _, rn := range n.rTable.getNodes() { + n.findNode(rn, n.id) + } + n.rTable.refresh() +} + +func (n Node) findNode(rn *remoteNode, id Infohash) { + target := randomInfoHash() + n.sendMsg(rn, "find_node", map[string]interface{}{ + "id": string(id), + "target": string(target), + }) +} + +// ping sends ping query to the chan. 
+func (n *Node) ping(rn *remoteNode) { + id := n.id.GenNeighbour(rn.id) + n.sendMsg(rn, "ping", map[string]interface{}{ + "id": string(id), + }) +} + +func (n Node) sendMsg(rn *remoteNode, qType string, a map[string]interface{}) error { + // Stop if sending to self + if rn.id.Equal(n.id) { + return nil + } + + t := newTransactionID() + //n.log.Debug("sending message", "type", qType, "remote", rn) + + data := makeQuery(t, qType, a) + b, err := bencode.Encode(data) + if err != nil { + return err + } + n.packetsOut <- packet{ + data: b, + raddr: rn.address, + } + return nil +} diff --git a/dht/options.go b/dht/options.go new file mode 100644 index 0000000..03a85bf --- /dev/null +++ b/dht/options.go @@ -0,0 +1,47 @@ +package dht + +import ( + "github.com/felix/logger" +) + +type Option func(*Node) error + +// SetAddress sets the IP address to listen on +func SetAddress(ip string) Option { + return func(n *Node) error { + n.address = ip + return nil + } +} + +// SetPort sets the port to listen on +func SetPort(p int) Option { + return func(n *Node) error { + n.port = p + return nil + } +} + +// SetWorkers sets the number of workers +func SetWorkers(c int) Option { + return func(n *Node) error { + n.workers = make([]*dhtWorker, c) + return nil + } +} + +// SetUDPTimeout sets the number of seconds to wait for UDP connections +func SetUDPTimeout(s int) Option { + return func(n *Node) error { + n.udpTimeout = s + return nil + } +} + +// SetLogger sets the logger +func SetLogger(l logger.Logger) Option { + return func(n *Node) error { + n.log = l + return nil + } +} diff --git a/dht/peer.go b/dht/peer.go new file mode 100644 index 0000000..a801331 --- /dev/null +++ b/dht/peer.go @@ -0,0 +1,9 @@ +package dht + +import "net" + +// Peer on DHT network +type Peer struct { + Address net.UDPAddr + ID Infohash +} diff --git a/dht/remote_node.go b/dht/remote_node.go new file mode 100644 index 0000000..8a8d4a2 --- /dev/null +++ b/dht/remote_node.go @@ -0,0 +1,18 @@ 
+package dht + +import ( + "fmt" + "net" + //"time" +) + +type remoteNode struct { + address net.UDPAddr + id Infohash + //lastSeen time.Time +} + +// String implements fmt.Stringer +func (r *remoteNode) String() string { + return fmt.Sprintf("%s:%d", r.address.IP.String(), r.address.Port) +} diff --git a/dht/routing_table.go b/dht/routing_table.go new file mode 100644 index 0000000..0252519 --- /dev/null +++ b/dht/routing_table.go @@ -0,0 +1,57 @@ +package dht + +import ( + "net" + "sync" +) + +// Keep it simple for now +type routingTable struct { + id Infohash + address net.UDPAddr + nodes []*remoteNode + max int + sync.Mutex +} + +func newRoutingTable(id Infohash) *routingTable { + k := &routingTable{id: id, max: 4000} + k.refresh() + return k +} + +func (k *routingTable) add(rn *remoteNode) { + k.Lock() + defer k.Unlock() + + // Check IP and ports are valid and not self + if (rn.address.String() == k.address.String() && rn.address.Port == k.address.Port) || !rn.id.Valid() || rn.id.Equal(k.id) { + return + } + k.nodes = append(k.nodes, rn) +} + +func (k *routingTable) getNodes() []*remoteNode { + k.Lock() + defer k.Unlock() + return k.nodes +} + +func (k *routingTable) isEmpty() bool { + k.Lock() + defer k.Unlock() + return len(k.nodes) == 0 +} + +func (k *routingTable) isFull() bool { + k.Lock() + defer k.Unlock() + return len(k.nodes) >= k.max +} + +// For now +func (k *routingTable) refresh() { + k.Lock() + defer k.Unlock() + k.nodes = make([]*remoteNode, 0) +} diff --git a/dht/slab.go b/dht/slab.go new file mode 100644 index 0000000..737b0b6 --- /dev/null +++ b/dht/slab.go @@ -0,0 +1,25 @@ +package dht + +// Slab memory allocation + +// Initialise the slab as a channel of blocks, allocating them as required and +// pushing them back on the slab. This reduces garbage collection. 
+type slab chan []byte + +func newSlab(blockSize int, numBlocks int) slab { + s := make(slab, numBlocks) + for i := 0; i < numBlocks; i++ { + s <- make([]byte, blockSize) + } + return s +} + +func (s slab) Alloc() (x []byte) { + return <-s +} + +func (s slab) Free(x []byte) { + // Check we are using the right dimensions + x = x[:cap(x)] + s <- x +} diff --git a/dht/worker.go b/dht/worker.go new file mode 100644 index 0000000..36b0519 --- /dev/null +++ b/dht/worker.go @@ -0,0 +1,274 @@ +package dht + +import ( + "net" + + "github.com/felix/dhtsearch/bencode" + "github.com/felix/logger" +) + +type dhtWorker struct { + pool chan chan packet + packetsOut chan<- packet + peersOut chan<- Peer + log logger.Logger + rTable *routingTable + quit chan struct{} +} + +func (dw *dhtWorker) run() error { + packetsIn := make(chan packet) + + for { + dw.pool <- packetsIn + + // Wait for work or shutdown + select { + case p := <-packetsIn: + dw.process(p) + case <-dw.quit: + dw.log.Warn("worker closing") + break + } + } +} + +func (dw dhtWorker) stop() { + go func() { + dw.quit <- struct{}{} + }() +} + +// Parse a KRPC packet into a message +func (dw *dhtWorker) process(p packet) { + data, err := bencode.Decode(p.data) + if err != nil { + return + } + + response, err := parseMessage(data) + if err != nil { + dw.log.Debug("failed to parse packet", "error", err) + return + } + + switch response["y"].(string) { + case "q": + dw.handleRequest(&p.raddr, response) + case "r": + dw.handleResponse(&p.raddr, response) + case "e": + dw.handleError(&p.raddr, response) + default: + dw.log.Warn("missing request type") + return + } +} + +// bencode data and send +func (dw *dhtWorker) queueMsg(raddr net.UDPAddr, data map[string]interface{}) error { + b, err := bencode.Encode(data) + if err != nil { + return err + } + dw.packetsOut <- packet{ + data: b, + raddr: raddr, + } + return nil +} + +// handleRequest handles the requests received from udp. 
+func (dw *dhtWorker) handleRequest(addr *net.UDPAddr, m map[string]interface{}) (success bool) { + + t := m["t"].(string) + + if err := checkKeys(m, [][]string{{"q", "string"}, {"a", "map"}}); err != nil { + + //d.queueMsg(addr, makeError(t, protocolError, err.Error())) + return + } + + q := m["q"].(string) + a := m["a"].(map[string]interface{}) + + if err := checkKey(a, "id", "string"); err != nil { + //d.queueMsg(addr, makeError(t, protocolError, err.Error())) + return + } + + var ih Infohash + err := ih.FromString(a["id"].(string)) + if err != nil { + dw.log.Warn("invalid packet", "infohash", a["id"]) + } + + if dw.rTable.id.Equal(ih) { + return + } + + var rn *remoteNode + switch q { + case "ping": + rn = &remoteNode{address: *addr, id: ih} + dw.log.Debug("ping", "source", rn, "infohash", ih) + dw.queueMsg(*addr, makeResponse(t, map[string]interface{}{ + "id": string(dw.rTable.id), + })) + + case "get_peers": + if err := checkKey(a, "info_hash", "string"); err != nil { + //dw.queueMsg(addr, makeError(t, protocolError, err.Error())) + return + } + rn = &remoteNode{address: *addr, id: ih} + err = ih.FromString(a["info_hash"].(string)) + if err != nil { + dw.log.Warn("invalid packet", "infohash", a["id"]) + } + dw.log.Debug("get_peers", "source", rn, "infohash", ih) + + // Crawling, we have no nodes + id := dw.rTable.id.GenNeighbour(ih) + dw.queueMsg(*addr, makeResponse(t, map[string]interface{}{ + "id": string(id), + "token": ih[:2], + "nodes": "", + })) + + case "announce_peer": + if err := checkKeys(a, [][]string{ + {"info_hash", "string"}, + {"port", "int"}, + {"token", "string"}}); err != nil { + + //dw.queueMsg(addr, makeError(t, protocolError, err.Error())) + return + } + + rn = &remoteNode{address: *addr, id: ih} + dw.log.Debug("announce_peer", "source", rn, "infohash", ih) + + // TODO + if impliedPort, ok := a["implied_port"]; ok && + impliedPort.(int) != 0 { + //port = addr.Port + } + // TODO do we reply? 
+ dw.peersOut <- Peer{*addr, ih} + + default: + //dw.queueMsg(addr, makeError(t, protocolError, "invalid q")) + return + } + dw.rTable.add(rn) + return true +} + +// handleResponse handles responses received from udp. +func (dw *dhtWorker) handleResponse(addr *net.UDPAddr, m map[string]interface{}) (success bool) { + + //t := m["t"].(string) + + if err := checkKey(m, "r", "map"); err != nil { + return + } + + r := m["r"].(map[string]interface{}) + if err := checkKey(r, "id", "string"); err != nil { + return + } + + var ih Infohash + ih.FromString(r["id"].(string)) + rn := &remoteNode{address: *addr, id: ih} + + // find_nodes response + if err := checkKey(r, "nodes", "string"); err == nil { + nodes := r["nodes"].(string) + dw.processFindNodeResults(rn, nodes) + return + } + + // get_peers response + if err := checkKey(r, "values", "list"); err == nil { + values := r["values"].([]interface{}) + for _, v := range values { + addr := compactNodeInfoToString(v.(string)) + dw.log.Debug("unhandled get_peer request", "addres", addr) + // TODO new peer + // dw.peersManager.Insert(ih, p) + } + } + dw.rTable.add(rn) + return true +} + +// handleError handles errors received from udp. +func (dw *dhtWorker) handleError(addr *net.UDPAddr, m map[string]interface{}) bool { + if err := checkKey(m, "e", "list"); err != nil { + return false + } + + e := m["e"].([]interface{}) + if len(e) != 2 { + return false + } + code := e[0].(int64) + msg := e[1].(string) + dw.log.Debug("error packet", "ip", addr.IP.String(), "port", addr.Port, "code", code, "error", msg) + + return true +} + +// Process another node's response to a find_node query. 
+func (dw *dhtWorker) processFindNodeResults(rn *remoteNode, nodeList string) { + nodeLength := 26 + /* + if d.config.proto == "udp6" { + nodeList = m.R.Nodes6 + nodeLength = 38 + } else { + nodeList = m.R.Nodes + } + + // Not much to do + if nodeList == "" { + return + } + */ + + if len(nodeList)%nodeLength != 0 { + dw.log.Error("node list is wrong length", "length", len(nodeList)) + return + } + + var ih Infohash + var err error + + //dw.log.Debug("got node list", "length", len(nodeList)) + + // We got a byte array in groups of 26 or 38 + for i := 0; i < len(nodeList); i += nodeLength { + id := nodeList[i : i+ihLength] + addr := compactNodeInfoToString(nodeList[i+ihLength : i+nodeLength]) + + err = ih.FromString(id) + if err != nil { + dw.log.Warn("invalid node list") + continue + } + + if dw.rTable.id.Equal(ih) { + continue + } + + address, err := net.ResolveUDPAddr("udp4", addr) + if err != nil { + dw.log.Error("failed to resolve", "error", err) + continue + } + rn := &remoteNode{address: *address, id: ih} + dw.rTable.add(rn) + } +} |
