diff options
| author | Felix Hanley <felix@userspace.com.au> | 2018-02-21 04:20:06 +0000 |
|---|---|---|
| committer | Felix Hanley <felix@userspace.com.au> | 2018-02-21 04:21:39 +0000 |
| commit | e9adf3a2bf8b81615275a6705b7957e43753f0ec (patch) | |
| tree | 1eaeb5081f3914a8ffa936d96ad1f1548c9aeb2f /dht | |
| parent | 020a8f9ec7e541d284ddb65111aafe42547927e5 (diff) | |
| download | dhtsearch-e9adf3a2bf8b81615275a6705b7957e43753f0ec.tar.gz dhtsearch-e9adf3a2bf8b81615275a6705b7957e43753f0ec.tar.bz2 | |
Seperate shared packages
Diffstat (limited to 'dht')
| -rw-r--r-- | dht/infohash.go | 2 | ||||
| -rw-r--r-- | dht/infohash_test.go | 2 | ||||
| -rw-r--r-- | dht/krpc.go | 177 | ||||
| -rw-r--r-- | dht/krpc_test.go | 30 | ||||
| -rw-r--r-- | dht/messages.go | 38 | ||||
| -rw-r--r-- | dht/node.go | 77 | ||||
| -rw-r--r-- | dht/peer.go | 11 | ||||
| -rw-r--r-- | dht/remote_node.go | 7 | ||||
| -rw-r--r-- | dht/routing_table.go | 6 | ||||
| -rw-r--r-- | dht/routing_table_test.go | 2 |
10 files changed, 68 insertions, 284 deletions
diff --git a/dht/infohash.go b/dht/infohash.go index 6d4596d..cb5170e 100644 --- a/dht/infohash.go +++ b/dht/infohash.go @@ -82,7 +82,7 @@ func generateNeighbour(first, second Infohash) Infohash { return Infohash(s) } -func randomInfoHash() (ih Infohash) { +func GenInfohash() (ih Infohash) { random := rand.New(rand.NewSource(time.Now().UnixNano())) hash := sha1.New() io.WriteString(hash, time.Now().String()) diff --git a/dht/infohash_test.go b/dht/infohash_test.go index 1574b19..6d627fc 100644 --- a/dht/infohash_test.go +++ b/dht/infohash_test.go @@ -39,7 +39,7 @@ func TestInfohashImport(t *testing.T) { } func TestInfohashLength(t *testing.T) { - ih := randomInfoHash() + ih := GenInfohash() if len(ih) != 20 { t.Errorf("%s as string should be length 20, got %d", ih, len(ih)) } diff --git a/dht/krpc.go b/dht/krpc.go deleted file mode 100644 index 2a7c103..0000000 --- a/dht/krpc.go +++ /dev/null @@ -1,177 +0,0 @@ -package dht - -import ( - "errors" - "fmt" - "math/rand" - "net" - "strconv" -) - -const transIDBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" - -func newTransactionID() string { - b := make([]byte, 2) - for i := range b { - b[i] = transIDBytes[rand.Int63()%int64(len(transIDBytes))] - } - return string(b) -} - -// makeQuery returns a query-formed data. -func makeQuery(t, q string, a map[string]interface{}) map[string]interface{} { - return map[string]interface{}{ - "t": t, - "y": "q", - "q": q, - "a": a, - } -} - -// makeResponse returns a response-formed data. -func makeResponse(t string, r map[string]interface{}) map[string]interface{} { - return map[string]interface{}{ - "t": t, - "y": "r", - "r": r, - } -} - -func getStringKey(data map[string]interface{}, key string) (string, error) { - val, ok := data[key] - if !ok { - return "", fmt.Errorf("krpc: missing key %s", key) - } - out, ok := val.(string) - if !ok { - return "", fmt.Errorf("krpc: key type mismatch") - } - return out, nil -} - -func getIntKey(data map[string]interface{}, key string) (int, error) { - val, ok := data[key] - if !ok { - return 0, fmt.Errorf("krpc: missing key %s", key) - } - out, ok := val.(int) - if !ok { - return 0, fmt.Errorf("krpc: key type mismatch") - } - return out, nil -} - -func getMapKey(data map[string]interface{}, key string) (map[string]interface{}, error) { - val, ok := data[key] - if !ok { - return nil, fmt.Errorf("krpc: missing key %s", key) - } - out, ok := val.(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("krpc: key type mismatch") - } - return out, nil -} - -func getListKey(data map[string]interface{}, key string) ([]interface{}, error) { - val, ok := data[key] - if !ok { - return nil, fmt.Errorf("krpc: missing key %s", key) - } - out, ok := val.([]interface{}) - if !ok { - return nil, fmt.Errorf("krpc: key type mismatch") - } - return out, nil -} - -// parseKeys parses keys. It just wraps parseKey. -func checkKeys(data map[string]interface{}, pairs [][]string) (err error) { - for _, args := range pairs { - key, t := args[0], args[1] - if err = checkKey(data, key, t); err != nil { - break - } - } - return err -} - -// parseKey parses the key in dict data. `t` is type of the keyed value. -// It's one of "int", "string", "map", "list". -func checkKey(data map[string]interface{}, key string, t string) error { - val, ok := data[key] - if !ok { - return fmt.Errorf("krpc: missing key %s", key) - } - - switch t { - case "string": - _, ok = val.(string) - case "int": - _, ok = val.(int) - case "map": - _, ok = val.(map[string]interface{}) - case "list": - _, ok = val.([]interface{}) - default: - return errors.New("krpc: invalid type") - } - - if !ok { - return errors.New("krpc: key type mismatch") - } - - return nil -} - -// Swiped from nictuku -func decodeCompactNodeAddr(cni string) string { - if len(cni) == 6 { - return fmt.Sprintf("%d.%d.%d.%d:%d", cni[0], cni[1], cni[2], cni[3], (uint16(cni[4])<<8)|uint16(cni[5])) - } else if len(cni) == 18 { - b := []byte(cni[:16]) - return fmt.Sprintf("[%s]:%d", net.IP.String(b), (uint16(cni[16])<<8)|uint16(cni[17])) - } else { - return "" - } -} - -func encodeCompactNodeAddr(addr string) string { - var a []uint8 - host, port, _ := net.SplitHostPort(addr) - ip := net.ParseIP(host) - if ip == nil { - return "" - } - aa, _ := strconv.ParseUint(port, 10, 16) - c := uint16(aa) - if ip2 := net.IP.To4(ip); ip2 != nil { - a = make([]byte, net.IPv4len+2, net.IPv4len+2) - copy(a, ip2[0:net.IPv4len]) // ignore bytes IPv6 bytes if it's IPv4. - a[4] = byte(c >> 8) - a[5] = byte(c) - } else { - a = make([]byte, net.IPv6len+2, net.IPv6len+2) - copy(a, ip) - a[16] = byte(c >> 8) - a[17] = byte(c) - } - return string(a) -} - -func int2bytes(val int64) []byte { - data, j := make([]byte, 8), -1 - for i := 0; i < 8; i++ { - shift := uint64((7 - i) * 8) - data[i] = byte((val & (0xff << shift)) >> shift) - - if j == -1 && data[i] != 0 { - j = i - } - } - - if j != -1 { - return data[j:] - } - return data[:1] -} diff --git a/dht/krpc_test.go b/dht/krpc_test.go deleted file mode 100644 index 5bc8373..0000000 --- a/dht/krpc_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package dht - -import ( - "encoding/hex" - "testing" -) - -func TestCompactNodeAddr(t *testing.T) { - - tests := []struct { - in string - out string - }{ - {in: "192.168.1.1:6881", out: "c0a801011ae1"}, - {in: "[2001:9372:434a:800::2]:6881", out: "20019372434a080000000000000000021ae1"}, - } - - for _, tt := range tests { - r := encodeCompactNodeAddr(tt.in) - out, _ := hex.DecodeString(tt.out) - if r != string(out) { - t.Errorf("encodeCompactNodeAddr(%s) => %x, expected %s", tt.in, r, tt.out) - } - - s := decodeCompactNodeAddr(r) - if s != tt.in { - t.Errorf("decodeCompactNodeAddr(%x) => %s, expected %s", r, s, tt.in) - } - } -} diff --git a/dht/messages.go b/dht/messages.go index 023e27d..9305b6f 100644 --- a/dht/messages.go +++ b/dht/messages.go @@ -3,28 +3,29 @@ package dht import ( "fmt" "net" - //"strings" + + "github.com/felix/dhtsearch/krpc" ) func (n *Node) onPingQuery(rn remoteNode, msg map[string]interface{}) error { - t, err := getStringKey(msg, "t") + t, err := krpc.GetString(msg, "t") if err != nil { return err } - n.queueMsg(rn, makeResponse(t, map[string]interface{}{ + n.queueMsg(rn, krpc.MakeResponse(t, map[string]interface{}{ "id": string(n.id), })) return nil } func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) error { - a, err := getMapKey(msg, "a") + a, err := krpc.GetMap(msg, "a") if err != nil { return err } // This is the ih of the torrent - torrent, err := getStringKey(a, "info_hash") + torrent, err := krpc.GetString(a, "info_hash") if err != nil { return err } @@ -40,7 +41,7 @@ func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) error nodes := n.rTable.get(8) compactNS := []string{} for _, rn := range nodes { - ns := encodeCompactNodeAddr(rn.address.String()) + ns := encodeCompactNodeAddr(rn.addr.String()) if ns == "" { n.log.Warn("failed to compact node", "address", rn.address.String()) continue @@ -50,7 +51,7 @@ func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) error */ t := msg["t"].(string) - n.queueMsg(rn, makeResponse(t, map[string]interface{}{ + n.queueMsg(rn, krpc.MakeResponse(t, map[string]interface{}{ "id": string(neighbour), "token": token, "nodes": "", @@ -60,7 +61,7 @@ func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) error //nodes := n.rTable.get(50) /* fmt.Printf("sending get_peers for %s to %d nodes\n", *th, len(nodes)) - q := makeQuery(newTransactionID(), "get_peers", map[string]interface{}{ + q := krpc.MakeQuery(newTransactionID(), "get_peers", map[string]interface{}{ "id": string(id), "info_hash": string(*th), }) @@ -72,22 +73,17 @@ func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) error } func (n *Node) onAnnouncePeerQuery(rn remoteNode, msg map[string]interface{}) error { - a, err := getMapKey(msg, "a") + a, err := krpc.GetMap(msg, "a") if err != nil { return err } - err = checkKeys(a, [][]string{ - {"info_hash", "string"}, - {"port", "int"}, - {"token", "string"}, - }) n.log.Debug("announce_peer", "source", rn) - if impliedPort, err := getIntKey(a, "implied_port"); err == nil { + if impliedPort, err := krpc.GetInt(a, "implied_port"); err == nil { if impliedPort != 0 { // Use the port in the message - host, _, err := net.SplitHostPort(rn.address.String()) + host, _, err := net.SplitHostPort(rn.addr.String()) if err != nil { return err } @@ -96,13 +92,13 @@ func (n *Node) onAnnouncePeerQuery(rn remoteNode, msg map[string]interface{}) er return fmt.Errorf("ignoring port 0") } addr, err := net.ResolveUDPAddr(n.family, fmt.Sprintf("%s:%d", host, newPort)) - rn = remoteNode{address: addr, id: rn.id} + rn = remoteNode{addr: addr, id: rn.id} } } // TODO do we reply? - ihStr, err := getStringKey(a, "info_hash") + ihStr, err := krpc.GetString(a, "info_hash") if err != nil { return err } @@ -111,8 +107,7 @@ func (n *Node) onAnnouncePeerQuery(rn remoteNode, msg map[string]interface{}) er n.log.Warn("invalid torrent", "infohash", ihStr) } - p := Peer{Node: rn, Infohash: *ih} - n.log.Info("anounce_peer", p) + p := Peer{Addr: rn.addr, ID: rn.id, Infohash: *ih} if n.OnAnnouncePeer != nil { go n.OnAnnouncePeer(p) } @@ -121,9 +116,6 @@ func (n *Node) onAnnouncePeerQuery(rn remoteNode, msg map[string]interface{}) er func (n *Node) onFindNodeResponse(rn remoteNode, msg map[string]interface{}) { r := msg["r"].(map[string]interface{}) - if err := checkKey(r, "id", "string"); err != nil { - return - } nodes := r["nodes"].(string) n.processFindNodeResults(rn, nodes) } diff --git a/dht/node.go b/dht/node.go index db61eed..f8e6113 100644 --- a/dht/node.go +++ b/dht/node.go @@ -7,6 +7,7 @@ import ( "time" "github.com/felix/dhtsearch/bencode" + "github.com/felix/dhtsearch/krpc" "github.com/felix/logger" "golang.org/x/time/rate" ) @@ -42,7 +43,7 @@ type Node struct { // NewNode creates a new DHT node func NewNode(opts ...Option) (n *Node, err error) { - id := randomInfoHash() + id := GenInfohash() k, err := newRoutingTable(id, 2000) if err != nil { @@ -159,7 +160,7 @@ func (n *Node) bootstrap() { n.log.Error("failed to parse bootstrap address", "error", err) continue } - rn := &remoteNode{address: addr} + rn := &remoteNode{addr: addr} n.findNode(rn, n.id) } } @@ -186,7 +187,7 @@ func (n *Node) packetWriter() { } func (n *Node) findNode(rn *remoteNode, id Infohash) { - target := randomInfoHash() + target := GenInfohash() n.sendQuery(rn, "find_node", map[string]interface{}{ "id": string(id), "target": string(target), @@ -207,10 +208,9 @@ func (n *Node) sendQuery(rn *remoteNode, qType string, a map[string]interface{}) return nil } - t := newTransactionID() - //n.log.Debug("sending message", "type", qType, "remote", rn) + t := krpc.NewTransactionID() - data := makeQuery(t, qType, a) + data := krpc.MakeQuery(t, qType, a) b, err := bencode.Encode(data) if err != nil { return err @@ -218,43 +218,38 @@ func (n *Node) sendQuery(rn *remoteNode, qType string, a map[string]interface{}) //fmt.Printf("sending %s to %s\n", qType, rn.String()) n.packetsOut <- packet{ data: b, - raddr: rn.address, + raddr: rn.addr, } return nil } // Parse a KRPC packet into a message -func (n *Node) processPacket(p packet) { - data, err := bencode.Decode(p.data) +func (n *Node) processPacket(p packet) error { + response, _, err := bencode.DecodeDict(p.data, 0) if err != nil { - return - } - - response, ok := data.(map[string]interface{}) - if !ok { - n.log.Debug("failed to parse packet", "error", "response is not dict") - return + return err } - if err := checkKeys(response, [][]string{{"t", "string"}, {"y", "string"}}); err != nil { - n.log.Debug("failed to parse packet", "error", err) - return + y, err := krpc.GetString(response, "y") + if err != nil { + return err } - switch response["y"].(string) { + switch y { case "q": err = n.handleRequest(p.raddr, response) case "r": err = n.handleResponse(p.raddr, response) case "e": - n.handleError(p.raddr, response) + err = n.handleError(p.raddr, response) default: n.log.Warn("missing request type") - return + return nil } if err != nil { n.log.Warn("failed to process packet", "error", err) } + return err } // bencode data and send @@ -265,24 +260,24 @@ func (n *Node) queueMsg(rn remoteNode, data map[string]interface{}) error { } n.packetsOut <- packet{ data: b, - raddr: rn.address, + raddr: rn.addr, } return nil } // handleRequest handles the requests received from udp. func (n *Node) handleRequest(addr net.Addr, m map[string]interface{}) error { - q, err := getStringKey(m, "q") + q, err := krpc.GetString(m, "q") if err != nil { return err } - a, err := getMapKey(m, "a") + a, err := krpc.GetMap(m, "a") if err != nil { return err } - id, err := getStringKey(a, "id") + id, err := krpc.GetString(a, "id") if err != nil { return err } @@ -296,7 +291,7 @@ func (n *Node) handleRequest(addr net.Addr, m map[string]interface{}) error { return nil } - rn := &remoteNode{address: addr, id: *ih} + rn := &remoteNode{addr: addr, id: *ih} switch q { case "ping": @@ -318,11 +313,11 @@ func (n *Node) handleRequest(addr net.Addr, m map[string]interface{}) error { // handleResponse handles responses received from udp. func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error { - r, err := getMapKey(m, "r") + r, err := krpc.GetMap(m, "r") if err != nil { return err } - id, err := getStringKey(r, "id") + id, err := krpc.GetString(r, "id") if err != nil { return err } @@ -331,9 +326,9 @@ func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error { return err } - rn := &remoteNode{address: addr, id: *ih} + rn := &remoteNode{addr: addr, id: *ih} - nodes, err := getStringKey(r, "nodes") + nodes, err := krpc.GetString(r, "nodes") // find_nodes/get_peers response with nodes if err == nil { n.onFindNodeResponse(*rn, m) @@ -342,12 +337,12 @@ func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error { return nil } - values, err := getListKey(r, "values") + values, err := krpc.GetList(r, "values") // get_peers response if err == nil { n.log.Debug("get_peers response", "source", rn) for _, v := range values { - addr := decodeCompactNodeAddr(v.(string)) + addr := krpc.DecodeCompactNodeAddr(v.(string)) n.log.Debug("unhandled get_peer request", "addres", addr) // TODO new peer needs to be matched to previous get_peers request @@ -359,20 +354,20 @@ func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error { } // handleError handles errors received from udp. -func (n *Node) handleError(addr net.Addr, m map[string]interface{}) bool { - if err := checkKey(m, "e", "list"); err != nil { - return false +func (n *Node) handleError(addr net.Addr, m map[string]interface{}) error { + e, err := krpc.GetList(m, "e") + if err != nil { + return err } - e := m["e"].([]interface{}) if len(e) != 2 { - return false + return fmt.Errorf("error packet wrong length %d", len(e)) } code := e[0].(int64) msg := e[1].(string) n.log.Debug("error packet", "address", addr.String(), "code", code, "error", msg) - return true + return nil } // Process another node's response to a find_node query. @@ -392,7 +387,7 @@ func (n *Node) processFindNodeResults(rn remoteNode, nodeList string) { // We got a byte array in groups of 26 or 38 for i := 0; i < len(nodeList); i += nodeLength { id := nodeList[i : i+ihLength] - addrStr := decodeCompactNodeAddr(nodeList[i+ihLength : i+nodeLength]) + addrStr := krpc.DecodeCompactNodeAddr(nodeList[i+ihLength : i+nodeLength]) ih, err := InfohashFromString(id) if err != nil { @@ -406,7 +401,7 @@ func (n *Node) processFindNodeResults(rn remoteNode, nodeList string) { continue } - rn := &remoteNode{address: addr, id: *ih} + rn := &remoteNode{addr: addr, id: *ih} n.rTable.add(rn) } } diff --git a/dht/peer.go b/dht/peer.go index f9669ba..42e8438 100644 --- a/dht/peer.go +++ b/dht/peer.go @@ -1,13 +1,18 @@ package dht -import "fmt" +import ( + "fmt" + "net" +) // Peer on DHT network type Peer struct { - Node remoteNode + Addr net.Addr + ID Infohash Infohash Infohash } +// String implements fmt.Stringer func (p Peer) String() string { - return fmt.Sprintf("%s (%s)", p.Infohash, p.Node) + return fmt.Sprintf("%s (%s)", p.Infohash, p.Addr.String()) } diff --git a/dht/remote_node.go b/dht/remote_node.go index 5bb2585..4bb9319 100644 --- a/dht/remote_node.go +++ b/dht/remote_node.go @@ -6,12 +6,11 @@ import ( ) type remoteNode struct { - address net.Addr - id Infohash - //lastSeen time.Time + addr net.Addr + id Infohash } // String implements fmt.Stringer func (r remoteNode) String() string { - return fmt.Sprintf("%s (%s)", r.id.String(), r.address.String()) + return fmt.Sprintf("%s (%s)", r.id.String(), r.addr.String()) } diff --git a/dht/routing_table.go b/dht/routing_table.go index 3c1f2d2..b10574c 100644 --- a/dht/routing_table.go +++ b/dht/routing_table.go @@ -73,10 +73,10 @@ func (k *routingTable) add(rn *remoteNode) { k.Lock() defer k.Unlock() - if _, ok := k.addresses[rn.address.String()]; ok { + if _, ok := k.addresses[rn.addr.String()]; ok { return } - k.addresses[rn.address.String()] = rn + k.addresses[rn.addr.String()] = rn item := &rItem{ value: rn, @@ -88,7 +88,7 @@ func (k *routingTable) add(rn *remoteNode) { if len(k.items) > k.max { for i := k.max - 1; i < len(k.items); i++ { old := k.items[i] - delete(k.addresses, old.value.address.String()) + delete(k.addresses, old.value.addr.String()) heap.Remove(&k.items, i) } } diff --git a/dht/routing_table_test.go b/dht/routing_table_test.go index 77c0a17..1eeeca3 100644 --- a/dht/routing_table_test.go +++ b/dht/routing_table_test.go @@ -32,7 +32,7 @@ func TestPriorityQueue(t *testing.T) { t.Errorf("failed to create infohash: %s\n", err) } addr, _ := net.ResolveUDPAddr("udp", fmt.Sprintf("0.0.0.0:%d", i)) - pq.add(&remoteNode{id: *iht, address: addr}) + pq.add(&remoteNode{id: *iht, addr: addr}) } if len(pq.items) != len(pq.addresses) { |
