aboutsummaryrefslogtreecommitdiff
path: root/dht
diff options
context:
space:
mode:
authorFelix Hanley <felix@userspace.com.au>2018-02-21 04:20:06 +0000
committerFelix Hanley <felix@userspace.com.au>2018-02-21 04:21:39 +0000
commite9adf3a2bf8b81615275a6705b7957e43753f0ec (patch)
tree1eaeb5081f3914a8ffa936d96ad1f1548c9aeb2f /dht
parent020a8f9ec7e541d284ddb65111aafe42547927e5 (diff)
downloaddhtsearch-e9adf3a2bf8b81615275a6705b7957e43753f0ec.tar.gz
dhtsearch-e9adf3a2bf8b81615275a6705b7957e43753f0ec.tar.bz2
Seperate shared packages
Diffstat (limited to 'dht')
-rw-r--r--dht/infohash.go2
-rw-r--r--dht/infohash_test.go2
-rw-r--r--dht/krpc.go177
-rw-r--r--dht/krpc_test.go30
-rw-r--r--dht/messages.go38
-rw-r--r--dht/node.go77
-rw-r--r--dht/peer.go11
-rw-r--r--dht/remote_node.go7
-rw-r--r--dht/routing_table.go6
-rw-r--r--dht/routing_table_test.go2
10 files changed, 68 insertions, 284 deletions
diff --git a/dht/infohash.go b/dht/infohash.go
index 6d4596d..cb5170e 100644
--- a/dht/infohash.go
+++ b/dht/infohash.go
@@ -82,7 +82,7 @@ func generateNeighbour(first, second Infohash) Infohash {
return Infohash(s)
}
-func randomInfoHash() (ih Infohash) {
+func GenInfohash() (ih Infohash) {
random := rand.New(rand.NewSource(time.Now().UnixNano()))
hash := sha1.New()
io.WriteString(hash, time.Now().String())
diff --git a/dht/infohash_test.go b/dht/infohash_test.go
index 1574b19..6d627fc 100644
--- a/dht/infohash_test.go
+++ b/dht/infohash_test.go
@@ -39,7 +39,7 @@ func TestInfohashImport(t *testing.T) {
}
func TestInfohashLength(t *testing.T) {
- ih := randomInfoHash()
+ ih := GenInfohash()
if len(ih) != 20 {
t.Errorf("%s as string should be length 20, got %d", ih, len(ih))
}
diff --git a/dht/krpc.go b/dht/krpc.go
deleted file mode 100644
index 2a7c103..0000000
--- a/dht/krpc.go
+++ /dev/null
@@ -1,177 +0,0 @@
-package dht
-
-import (
- "errors"
- "fmt"
- "math/rand"
- "net"
- "strconv"
-)
-
-const transIDBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
-
-func newTransactionID() string {
- b := make([]byte, 2)
- for i := range b {
- b[i] = transIDBytes[rand.Int63()%int64(len(transIDBytes))]
- }
- return string(b)
-}
-
-// makeQuery returns a query-formed data.
-func makeQuery(t, q string, a map[string]interface{}) map[string]interface{} {
- return map[string]interface{}{
- "t": t,
- "y": "q",
- "q": q,
- "a": a,
- }
-}
-
-// makeResponse returns a response-formed data.
-func makeResponse(t string, r map[string]interface{}) map[string]interface{} {
- return map[string]interface{}{
- "t": t,
- "y": "r",
- "r": r,
- }
-}
-
-func getStringKey(data map[string]interface{}, key string) (string, error) {
- val, ok := data[key]
- if !ok {
- return "", fmt.Errorf("krpc: missing key %s", key)
- }
- out, ok := val.(string)
- if !ok {
- return "", fmt.Errorf("krpc: key type mismatch")
- }
- return out, nil
-}
-
-func getIntKey(data map[string]interface{}, key string) (int, error) {
- val, ok := data[key]
- if !ok {
- return 0, fmt.Errorf("krpc: missing key %s", key)
- }
- out, ok := val.(int)
- if !ok {
- return 0, fmt.Errorf("krpc: key type mismatch")
- }
- return out, nil
-}
-
-func getMapKey(data map[string]interface{}, key string) (map[string]interface{}, error) {
- val, ok := data[key]
- if !ok {
- return nil, fmt.Errorf("krpc: missing key %s", key)
- }
- out, ok := val.(map[string]interface{})
- if !ok {
- return nil, fmt.Errorf("krpc: key type mismatch")
- }
- return out, nil
-}
-
-func getListKey(data map[string]interface{}, key string) ([]interface{}, error) {
- val, ok := data[key]
- if !ok {
- return nil, fmt.Errorf("krpc: missing key %s", key)
- }
- out, ok := val.([]interface{})
- if !ok {
- return nil, fmt.Errorf("krpc: key type mismatch")
- }
- return out, nil
-}
-
-// parseKeys parses keys. It just wraps parseKey.
-func checkKeys(data map[string]interface{}, pairs [][]string) (err error) {
- for _, args := range pairs {
- key, t := args[0], args[1]
- if err = checkKey(data, key, t); err != nil {
- break
- }
- }
- return err
-}
-
-// parseKey parses the key in dict data. `t` is type of the keyed value.
-// It's one of "int", "string", "map", "list".
-func checkKey(data map[string]interface{}, key string, t string) error {
- val, ok := data[key]
- if !ok {
- return fmt.Errorf("krpc: missing key %s", key)
- }
-
- switch t {
- case "string":
- _, ok = val.(string)
- case "int":
- _, ok = val.(int)
- case "map":
- _, ok = val.(map[string]interface{})
- case "list":
- _, ok = val.([]interface{})
- default:
- return errors.New("krpc: invalid type")
- }
-
- if !ok {
- return errors.New("krpc: key type mismatch")
- }
-
- return nil
-}
-
-// Swiped from nictuku
-func decodeCompactNodeAddr(cni string) string {
- if len(cni) == 6 {
- return fmt.Sprintf("%d.%d.%d.%d:%d", cni[0], cni[1], cni[2], cni[3], (uint16(cni[4])<<8)|uint16(cni[5]))
- } else if len(cni) == 18 {
- b := []byte(cni[:16])
- return fmt.Sprintf("[%s]:%d", net.IP.String(b), (uint16(cni[16])<<8)|uint16(cni[17]))
- } else {
- return ""
- }
-}
-
-func encodeCompactNodeAddr(addr string) string {
- var a []uint8
- host, port, _ := net.SplitHostPort(addr)
- ip := net.ParseIP(host)
- if ip == nil {
- return ""
- }
- aa, _ := strconv.ParseUint(port, 10, 16)
- c := uint16(aa)
- if ip2 := net.IP.To4(ip); ip2 != nil {
- a = make([]byte, net.IPv4len+2, net.IPv4len+2)
- copy(a, ip2[0:net.IPv4len]) // ignore bytes IPv6 bytes if it's IPv4.
- a[4] = byte(c >> 8)
- a[5] = byte(c)
- } else {
- a = make([]byte, net.IPv6len+2, net.IPv6len+2)
- copy(a, ip)
- a[16] = byte(c >> 8)
- a[17] = byte(c)
- }
- return string(a)
-}
-
-func int2bytes(val int64) []byte {
- data, j := make([]byte, 8), -1
- for i := 0; i < 8; i++ {
- shift := uint64((7 - i) * 8)
- data[i] = byte((val & (0xff << shift)) >> shift)
-
- if j == -1 && data[i] != 0 {
- j = i
- }
- }
-
- if j != -1 {
- return data[j:]
- }
- return data[:1]
-}
diff --git a/dht/krpc_test.go b/dht/krpc_test.go
deleted file mode 100644
index 5bc8373..0000000
--- a/dht/krpc_test.go
+++ /dev/null
@@ -1,30 +0,0 @@
-package dht
-
-import (
- "encoding/hex"
- "testing"
-)
-
-func TestCompactNodeAddr(t *testing.T) {
-
- tests := []struct {
- in string
- out string
- }{
- {in: "192.168.1.1:6881", out: "c0a801011ae1"},
- {in: "[2001:9372:434a:800::2]:6881", out: "20019372434a080000000000000000021ae1"},
- }
-
- for _, tt := range tests {
- r := encodeCompactNodeAddr(tt.in)
- out, _ := hex.DecodeString(tt.out)
- if r != string(out) {
- t.Errorf("encodeCompactNodeAddr(%s) => %x, expected %s", tt.in, r, tt.out)
- }
-
- s := decodeCompactNodeAddr(r)
- if s != tt.in {
- t.Errorf("decodeCompactNodeAddr(%x) => %s, expected %s", r, s, tt.in)
- }
- }
-}
diff --git a/dht/messages.go b/dht/messages.go
index 023e27d..9305b6f 100644
--- a/dht/messages.go
+++ b/dht/messages.go
@@ -3,28 +3,29 @@ package dht
import (
"fmt"
"net"
- //"strings"
+
+ "github.com/felix/dhtsearch/krpc"
)
func (n *Node) onPingQuery(rn remoteNode, msg map[string]interface{}) error {
- t, err := getStringKey(msg, "t")
+ t, err := krpc.GetString(msg, "t")
if err != nil {
return err
}
- n.queueMsg(rn, makeResponse(t, map[string]interface{}{
+ n.queueMsg(rn, krpc.MakeResponse(t, map[string]interface{}{
"id": string(n.id),
}))
return nil
}
func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) error {
- a, err := getMapKey(msg, "a")
+ a, err := krpc.GetMap(msg, "a")
if err != nil {
return err
}
// This is the ih of the torrent
- torrent, err := getStringKey(a, "info_hash")
+ torrent, err := krpc.GetString(a, "info_hash")
if err != nil {
return err
}
@@ -40,7 +41,7 @@ func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) error
nodes := n.rTable.get(8)
compactNS := []string{}
for _, rn := range nodes {
- ns := encodeCompactNodeAddr(rn.address.String())
+ ns := encodeCompactNodeAddr(rn.addr.String())
if ns == "" {
n.log.Warn("failed to compact node", "address", rn.address.String())
continue
@@ -50,7 +51,7 @@ func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) error
*/
t := msg["t"].(string)
- n.queueMsg(rn, makeResponse(t, map[string]interface{}{
+ n.queueMsg(rn, krpc.MakeResponse(t, map[string]interface{}{
"id": string(neighbour),
"token": token,
"nodes": "",
@@ -60,7 +61,7 @@ func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) error
//nodes := n.rTable.get(50)
/*
fmt.Printf("sending get_peers for %s to %d nodes\n", *th, len(nodes))
- q := makeQuery(newTransactionID(), "get_peers", map[string]interface{}{
+ q := krpc.MakeQuery(newTransactionID(), "get_peers", map[string]interface{}{
"id": string(id),
"info_hash": string(*th),
})
@@ -72,22 +73,17 @@ func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) error
}
func (n *Node) onAnnouncePeerQuery(rn remoteNode, msg map[string]interface{}) error {
- a, err := getMapKey(msg, "a")
+ a, err := krpc.GetMap(msg, "a")
if err != nil {
return err
}
- err = checkKeys(a, [][]string{
- {"info_hash", "string"},
- {"port", "int"},
- {"token", "string"},
- })
n.log.Debug("announce_peer", "source", rn)
- if impliedPort, err := getIntKey(a, "implied_port"); err == nil {
+ if impliedPort, err := krpc.GetInt(a, "implied_port"); err == nil {
if impliedPort != 0 {
// Use the port in the message
- host, _, err := net.SplitHostPort(rn.address.String())
+ host, _, err := net.SplitHostPort(rn.addr.String())
if err != nil {
return err
}
@@ -96,13 +92,13 @@ func (n *Node) onAnnouncePeerQuery(rn remoteNode, msg map[string]interface{}) er
return fmt.Errorf("ignoring port 0")
}
addr, err := net.ResolveUDPAddr(n.family, fmt.Sprintf("%s:%d", host, newPort))
- rn = remoteNode{address: addr, id: rn.id}
+ rn = remoteNode{addr: addr, id: rn.id}
}
}
// TODO do we reply?
- ihStr, err := getStringKey(a, "info_hash")
+ ihStr, err := krpc.GetString(a, "info_hash")
if err != nil {
return err
}
@@ -111,8 +107,7 @@ func (n *Node) onAnnouncePeerQuery(rn remoteNode, msg map[string]interface{}) er
n.log.Warn("invalid torrent", "infohash", ihStr)
}
- p := Peer{Node: rn, Infohash: *ih}
- n.log.Info("anounce_peer", p)
+ p := Peer{Addr: rn.addr, ID: rn.id, Infohash: *ih}
if n.OnAnnouncePeer != nil {
go n.OnAnnouncePeer(p)
}
@@ -121,9 +116,6 @@ func (n *Node) onAnnouncePeerQuery(rn remoteNode, msg map[string]interface{}) er
func (n *Node) onFindNodeResponse(rn remoteNode, msg map[string]interface{}) {
r := msg["r"].(map[string]interface{})
- if err := checkKey(r, "id", "string"); err != nil {
- return
- }
nodes := r["nodes"].(string)
n.processFindNodeResults(rn, nodes)
}
diff --git a/dht/node.go b/dht/node.go
index db61eed..f8e6113 100644
--- a/dht/node.go
+++ b/dht/node.go
@@ -7,6 +7,7 @@ import (
"time"
"github.com/felix/dhtsearch/bencode"
+ "github.com/felix/dhtsearch/krpc"
"github.com/felix/logger"
"golang.org/x/time/rate"
)
@@ -42,7 +43,7 @@ type Node struct {
// NewNode creates a new DHT node
func NewNode(opts ...Option) (n *Node, err error) {
- id := randomInfoHash()
+ id := GenInfohash()
k, err := newRoutingTable(id, 2000)
if err != nil {
@@ -159,7 +160,7 @@ func (n *Node) bootstrap() {
n.log.Error("failed to parse bootstrap address", "error", err)
continue
}
- rn := &remoteNode{address: addr}
+ rn := &remoteNode{addr: addr}
n.findNode(rn, n.id)
}
}
@@ -186,7 +187,7 @@ func (n *Node) packetWriter() {
}
func (n *Node) findNode(rn *remoteNode, id Infohash) {
- target := randomInfoHash()
+ target := GenInfohash()
n.sendQuery(rn, "find_node", map[string]interface{}{
"id": string(id),
"target": string(target),
@@ -207,10 +208,9 @@ func (n *Node) sendQuery(rn *remoteNode, qType string, a map[string]interface{})
return nil
}
- t := newTransactionID()
- //n.log.Debug("sending message", "type", qType, "remote", rn)
+ t := krpc.NewTransactionID()
- data := makeQuery(t, qType, a)
+ data := krpc.MakeQuery(t, qType, a)
b, err := bencode.Encode(data)
if err != nil {
return err
@@ -218,43 +218,38 @@ func (n *Node) sendQuery(rn *remoteNode, qType string, a map[string]interface{})
//fmt.Printf("sending %s to %s\n", qType, rn.String())
n.packetsOut <- packet{
data: b,
- raddr: rn.address,
+ raddr: rn.addr,
}
return nil
}
// Parse a KRPC packet into a message
-func (n *Node) processPacket(p packet) {
- data, err := bencode.Decode(p.data)
+func (n *Node) processPacket(p packet) error {
+ response, _, err := bencode.DecodeDict(p.data, 0)
if err != nil {
- return
- }
-
- response, ok := data.(map[string]interface{})
- if !ok {
- n.log.Debug("failed to parse packet", "error", "response is not dict")
- return
+ return err
}
- if err := checkKeys(response, [][]string{{"t", "string"}, {"y", "string"}}); err != nil {
- n.log.Debug("failed to parse packet", "error", err)
- return
+ y, err := krpc.GetString(response, "y")
+ if err != nil {
+ return err
}
- switch response["y"].(string) {
+ switch y {
case "q":
err = n.handleRequest(p.raddr, response)
case "r":
err = n.handleResponse(p.raddr, response)
case "e":
- n.handleError(p.raddr, response)
+ err = n.handleError(p.raddr, response)
default:
n.log.Warn("missing request type")
- return
+ return nil
}
if err != nil {
n.log.Warn("failed to process packet", "error", err)
}
+ return err
}
// bencode data and send
@@ -265,24 +260,24 @@ func (n *Node) queueMsg(rn remoteNode, data map[string]interface{}) error {
}
n.packetsOut <- packet{
data: b,
- raddr: rn.address,
+ raddr: rn.addr,
}
return nil
}
// handleRequest handles the requests received from udp.
func (n *Node) handleRequest(addr net.Addr, m map[string]interface{}) error {
- q, err := getStringKey(m, "q")
+ q, err := krpc.GetString(m, "q")
if err != nil {
return err
}
- a, err := getMapKey(m, "a")
+ a, err := krpc.GetMap(m, "a")
if err != nil {
return err
}
- id, err := getStringKey(a, "id")
+ id, err := krpc.GetString(a, "id")
if err != nil {
return err
}
@@ -296,7 +291,7 @@ func (n *Node) handleRequest(addr net.Addr, m map[string]interface{}) error {
return nil
}
- rn := &remoteNode{address: addr, id: *ih}
+ rn := &remoteNode{addr: addr, id: *ih}
switch q {
case "ping":
@@ -318,11 +313,11 @@ func (n *Node) handleRequest(addr net.Addr, m map[string]interface{}) error {
// handleResponse handles responses received from udp.
func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error {
- r, err := getMapKey(m, "r")
+ r, err := krpc.GetMap(m, "r")
if err != nil {
return err
}
- id, err := getStringKey(r, "id")
+ id, err := krpc.GetString(r, "id")
if err != nil {
return err
}
@@ -331,9 +326,9 @@ func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error {
return err
}
- rn := &remoteNode{address: addr, id: *ih}
+ rn := &remoteNode{addr: addr, id: *ih}
- nodes, err := getStringKey(r, "nodes")
+ nodes, err := krpc.GetString(r, "nodes")
// find_nodes/get_peers response with nodes
if err == nil {
n.onFindNodeResponse(*rn, m)
@@ -342,12 +337,12 @@ func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error {
return nil
}
- values, err := getListKey(r, "values")
+ values, err := krpc.GetList(r, "values")
// get_peers response
if err == nil {
n.log.Debug("get_peers response", "source", rn)
for _, v := range values {
- addr := decodeCompactNodeAddr(v.(string))
+ addr := krpc.DecodeCompactNodeAddr(v.(string))
n.log.Debug("unhandled get_peer request", "addres", addr)
// TODO new peer needs to be matched to previous get_peers request
@@ -359,20 +354,20 @@ func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error {
}
// handleError handles errors received from udp.
-func (n *Node) handleError(addr net.Addr, m map[string]interface{}) bool {
- if err := checkKey(m, "e", "list"); err != nil {
- return false
+func (n *Node) handleError(addr net.Addr, m map[string]interface{}) error {
+ e, err := krpc.GetList(m, "e")
+ if err != nil {
+ return err
}
- e := m["e"].([]interface{})
if len(e) != 2 {
- return false
+ return fmt.Errorf("error packet wrong length %d", len(e))
}
code := e[0].(int64)
msg := e[1].(string)
n.log.Debug("error packet", "address", addr.String(), "code", code, "error", msg)
- return true
+ return nil
}
// Process another node's response to a find_node query.
@@ -392,7 +387,7 @@ func (n *Node) processFindNodeResults(rn remoteNode, nodeList string) {
// We got a byte array in groups of 26 or 38
for i := 0; i < len(nodeList); i += nodeLength {
id := nodeList[i : i+ihLength]
- addrStr := decodeCompactNodeAddr(nodeList[i+ihLength : i+nodeLength])
+ addrStr := krpc.DecodeCompactNodeAddr(nodeList[i+ihLength : i+nodeLength])
ih, err := InfohashFromString(id)
if err != nil {
@@ -406,7 +401,7 @@ func (n *Node) processFindNodeResults(rn remoteNode, nodeList string) {
continue
}
- rn := &remoteNode{address: addr, id: *ih}
+ rn := &remoteNode{addr: addr, id: *ih}
n.rTable.add(rn)
}
}
diff --git a/dht/peer.go b/dht/peer.go
index f9669ba..42e8438 100644
--- a/dht/peer.go
+++ b/dht/peer.go
@@ -1,13 +1,18 @@
package dht
-import "fmt"
+import (
+ "fmt"
+ "net"
+)
// Peer on DHT network
type Peer struct {
- Node remoteNode
+ Addr net.Addr
+ ID Infohash
Infohash Infohash
}
+// String implements fmt.Stringer
func (p Peer) String() string {
- return fmt.Sprintf("%s (%s)", p.Infohash, p.Node)
+ return fmt.Sprintf("%s (%s)", p.Infohash, p.Addr.String())
}
diff --git a/dht/remote_node.go b/dht/remote_node.go
index 5bb2585..4bb9319 100644
--- a/dht/remote_node.go
+++ b/dht/remote_node.go
@@ -6,12 +6,11 @@ import (
)
type remoteNode struct {
- address net.Addr
- id Infohash
- //lastSeen time.Time
+ addr net.Addr
+ id Infohash
}
// String implements fmt.Stringer
func (r remoteNode) String() string {
- return fmt.Sprintf("%s (%s)", r.id.String(), r.address.String())
+ return fmt.Sprintf("%s (%s)", r.id.String(), r.addr.String())
}
diff --git a/dht/routing_table.go b/dht/routing_table.go
index 3c1f2d2..b10574c 100644
--- a/dht/routing_table.go
+++ b/dht/routing_table.go
@@ -73,10 +73,10 @@ func (k *routingTable) add(rn *remoteNode) {
k.Lock()
defer k.Unlock()
- if _, ok := k.addresses[rn.address.String()]; ok {
+ if _, ok := k.addresses[rn.addr.String()]; ok {
return
}
- k.addresses[rn.address.String()] = rn
+ k.addresses[rn.addr.String()] = rn
item := &rItem{
value: rn,
@@ -88,7 +88,7 @@ func (k *routingTable) add(rn *remoteNode) {
if len(k.items) > k.max {
for i := k.max - 1; i < len(k.items); i++ {
old := k.items[i]
- delete(k.addresses, old.value.address.String())
+ delete(k.addresses, old.value.addr.String())
heap.Remove(&k.items, i)
}
}
diff --git a/dht/routing_table_test.go b/dht/routing_table_test.go
index 77c0a17..1eeeca3 100644
--- a/dht/routing_table_test.go
+++ b/dht/routing_table_test.go
@@ -32,7 +32,7 @@ func TestPriorityQueue(t *testing.T) {
t.Errorf("failed to create infohash: %s\n", err)
}
addr, _ := net.ResolveUDPAddr("udp", fmt.Sprintf("0.0.0.0:%d", i))
- pq.add(&remoteNode{id: *iht, address: addr})
+ pq.add(&remoteNode{id: *iht, addr: addr})
}
if len(pq.items) != len(pq.addresses) {