aboutsummaryrefslogtreecommitdiff
path: root/dht/node.go
diff options
context:
space:
mode:
Diffstat (limited to 'dht/node.go')
-rw-r--r--dht/node.go77
1 files changed, 36 insertions, 41 deletions
diff --git a/dht/node.go b/dht/node.go
index db61eed..f8e6113 100644
--- a/dht/node.go
+++ b/dht/node.go
@@ -7,6 +7,7 @@ import (
"time"
"github.com/felix/dhtsearch/bencode"
+ "github.com/felix/dhtsearch/krpc"
"github.com/felix/logger"
"golang.org/x/time/rate"
)
@@ -42,7 +43,7 @@ type Node struct {
// NewNode creates a new DHT node
func NewNode(opts ...Option) (n *Node, err error) {
- id := randomInfoHash()
+ id := GenInfohash()
k, err := newRoutingTable(id, 2000)
if err != nil {
@@ -159,7 +160,7 @@ func (n *Node) bootstrap() {
n.log.Error("failed to parse bootstrap address", "error", err)
continue
}
- rn := &remoteNode{address: addr}
+ rn := &remoteNode{addr: addr}
n.findNode(rn, n.id)
}
}
@@ -186,7 +187,7 @@ func (n *Node) packetWriter() {
}
func (n *Node) findNode(rn *remoteNode, id Infohash) {
- target := randomInfoHash()
+ target := GenInfohash()
n.sendQuery(rn, "find_node", map[string]interface{}{
"id": string(id),
"target": string(target),
@@ -207,10 +208,9 @@ func (n *Node) sendQuery(rn *remoteNode, qType string, a map[string]interface{})
return nil
}
- t := newTransactionID()
- //n.log.Debug("sending message", "type", qType, "remote", rn)
+ t := krpc.NewTransactionID()
- data := makeQuery(t, qType, a)
+ data := krpc.MakeQuery(t, qType, a)
b, err := bencode.Encode(data)
if err != nil {
return err
@@ -218,43 +218,38 @@ func (n *Node) sendQuery(rn *remoteNode, qType string, a map[string]interface{})
//fmt.Printf("sending %s to %s\n", qType, rn.String())
n.packetsOut <- packet{
data: b,
- raddr: rn.address,
+ raddr: rn.addr,
}
return nil
}
// Parse a KRPC packet into a message
-func (n *Node) processPacket(p packet) {
- data, err := bencode.Decode(p.data)
+func (n *Node) processPacket(p packet) error {
+ response, _, err := bencode.DecodeDict(p.data, 0)
if err != nil {
- return
- }
-
- response, ok := data.(map[string]interface{})
- if !ok {
- n.log.Debug("failed to parse packet", "error", "response is not dict")
- return
+ return err
}
- if err := checkKeys(response, [][]string{{"t", "string"}, {"y", "string"}}); err != nil {
- n.log.Debug("failed to parse packet", "error", err)
- return
+ y, err := krpc.GetString(response, "y")
+ if err != nil {
+ return err
}
- switch response["y"].(string) {
+ switch y {
case "q":
err = n.handleRequest(p.raddr, response)
case "r":
err = n.handleResponse(p.raddr, response)
case "e":
- n.handleError(p.raddr, response)
+ err = n.handleError(p.raddr, response)
default:
n.log.Warn("missing request type")
- return
+ return nil
}
if err != nil {
n.log.Warn("failed to process packet", "error", err)
}
+ return err
}
// bencode data and send
@@ -265,24 +260,24 @@ func (n *Node) queueMsg(rn remoteNode, data map[string]interface{}) error {
}
n.packetsOut <- packet{
data: b,
- raddr: rn.address,
+ raddr: rn.addr,
}
return nil
}
// handleRequest handles the requests received from udp.
func (n *Node) handleRequest(addr net.Addr, m map[string]interface{}) error {
- q, err := getStringKey(m, "q")
+ q, err := krpc.GetString(m, "q")
if err != nil {
return err
}
- a, err := getMapKey(m, "a")
+ a, err := krpc.GetMap(m, "a")
if err != nil {
return err
}
- id, err := getStringKey(a, "id")
+ id, err := krpc.GetString(a, "id")
if err != nil {
return err
}
@@ -296,7 +291,7 @@ func (n *Node) handleRequest(addr net.Addr, m map[string]interface{}) error {
return nil
}
- rn := &remoteNode{address: addr, id: *ih}
+ rn := &remoteNode{addr: addr, id: *ih}
switch q {
case "ping":
@@ -318,11 +313,11 @@ func (n *Node) handleRequest(addr net.Addr, m map[string]interface{}) error {
// handleResponse handles responses received from udp.
func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error {
- r, err := getMapKey(m, "r")
+ r, err := krpc.GetMap(m, "r")
if err != nil {
return err
}
- id, err := getStringKey(r, "id")
+ id, err := krpc.GetString(r, "id")
if err != nil {
return err
}
@@ -331,9 +326,9 @@ func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error {
return err
}
- rn := &remoteNode{address: addr, id: *ih}
+ rn := &remoteNode{addr: addr, id: *ih}
- nodes, err := getStringKey(r, "nodes")
+ nodes, err := krpc.GetString(r, "nodes")
// find_nodes/get_peers response with nodes
if err == nil {
n.onFindNodeResponse(*rn, m)
@@ -342,12 +337,12 @@ func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error {
return nil
}
- values, err := getListKey(r, "values")
+ values, err := krpc.GetList(r, "values")
// get_peers response
if err == nil {
n.log.Debug("get_peers response", "source", rn)
for _, v := range values {
- addr := decodeCompactNodeAddr(v.(string))
+ addr := krpc.DecodeCompactNodeAddr(v.(string))
n.log.Debug("unhandled get_peer request", "addres", addr)
// TODO new peer needs to be matched to previous get_peers request
@@ -359,20 +354,20 @@ func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error {
}
// handleError handles errors received from udp.
-func (n *Node) handleError(addr net.Addr, m map[string]interface{}) bool {
- if err := checkKey(m, "e", "list"); err != nil {
- return false
+func (n *Node) handleError(addr net.Addr, m map[string]interface{}) error {
+ e, err := krpc.GetList(m, "e")
+ if err != nil {
+ return err
}
- e := m["e"].([]interface{})
if len(e) != 2 {
- return false
+ return fmt.Errorf("error packet wrong length %d", len(e))
}
code := e[0].(int64)
msg := e[1].(string)
n.log.Debug("error packet", "address", addr.String(), "code", code, "error", msg)
- return true
+ return nil
}
// Process another node's response to a find_node query.
@@ -392,7 +387,7 @@ func (n *Node) processFindNodeResults(rn remoteNode, nodeList string) {
// We got a byte array in groups of 26 or 38
for i := 0; i < len(nodeList); i += nodeLength {
id := nodeList[i : i+ihLength]
- addrStr := decodeCompactNodeAddr(nodeList[i+ihLength : i+nodeLength])
+ addrStr := krpc.DecodeCompactNodeAddr(nodeList[i+ihLength : i+nodeLength])
ih, err := InfohashFromString(id)
if err != nil {
@@ -406,7 +401,7 @@ func (n *Node) processFindNodeResults(rn remoteNode, nodeList string) {
continue
}
- rn := &remoteNode{address: addr, id: *ih}
+ rn := &remoteNode{addr: addr, id: *ih}
n.rTable.add(rn)
}
}