From e9adf3a2bf8b81615275a6705b7957e43753f0ec Mon Sep 17 00:00:00 2001 From: Felix Hanley Date: Wed, 21 Feb 2018 15:20:06 +1100 Subject: Seperate shared packages --- db.go | 28 ----- dht/infohash.go | 2 +- dht/infohash_test.go | 2 +- dht/krpc.go | 177 --------------------------- dht/krpc_test.go | 30 ----- dht/messages.go | 38 +++--- dht/node.go | 77 ++++++------ dht/peer.go | 11 +- dht/remote_node.go | 7 +- dht/routing_table.go | 6 +- dht/routing_table_test.go | 2 +- http.go | 299 ---------------------------------------------- infohash_test.go | 43 ------- krpc/krpc.go | 177 +++++++++++++++++++++++++++ krpc/krpc_test.go | 30 +++++ models/tag.go | 5 + models/torrent.go | 98 +++++++++++++++ torrent.go | 190 ----------------------------- util.go | 92 -------------- 19 files changed, 378 insertions(+), 936 deletions(-) delete mode 100644 db.go delete mode 100644 dht/krpc.go delete mode 100644 dht/krpc_test.go delete mode 100644 http.go delete mode 100644 infohash_test.go create mode 100644 krpc/krpc.go create mode 100644 krpc/krpc_test.go create mode 100644 models/tag.go create mode 100644 models/torrent.go delete mode 100644 torrent.go delete mode 100644 util.go diff --git a/db.go b/db.go deleted file mode 100644 index 1051366..0000000 --- a/db.go +++ /dev/null @@ -1,28 +0,0 @@ -package dhtsearch - -import ( - "fmt" - _ "github.com/jackc/pgx/stdlib" - "github.com/jmoiron/sqlx" -) - -type database struct { - *sqlx.DB -} - -// Global -var DB *database - -func newDB(dsn string) (*database, error) { - d, err := sqlx.Connect("pgx", dsn) - if err != nil { - fmt.Printf("Error creating DB %q\n", err) - return nil, err - } - var count int - err = d.QueryRow("select count(*) from torrents").Scan(&count) - if err != nil { - return nil, err - } - return &database{d}, nil -} diff --git a/dht/infohash.go b/dht/infohash.go index 6d4596d..cb5170e 100644 --- a/dht/infohash.go +++ b/dht/infohash.go @@ -82,7 +82,7 @@ func generateNeighbour(first, second Infohash) Infohash { return Infohash(s) } -func randomInfoHash() (ih Infohash) { +func GenInfohash() (ih Infohash) { random := rand.New(rand.NewSource(time.Now().UnixNano())) hash := sha1.New() io.WriteString(hash, time.Now().String()) diff --git a/dht/infohash_test.go b/dht/infohash_test.go index 1574b19..6d627fc 100644 --- a/dht/infohash_test.go +++ b/dht/infohash_test.go @@ -39,7 +39,7 @@ func TestInfohashImport(t *testing.T) { } func TestInfohashLength(t *testing.T) { - ih := randomInfoHash() + ih := GenInfohash() if len(ih) != 20 { t.Errorf("%s as string should be length 20, got %d", ih, len(ih)) } diff --git a/dht/krpc.go b/dht/krpc.go deleted file mode 100644 index 2a7c103..0000000 --- a/dht/krpc.go +++ /dev/null @@ -1,177 +0,0 @@ -package dht - -import ( - "errors" - "fmt" - "math/rand" - "net" - "strconv" -) - -const transIDBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" - -func newTransactionID() string { - b := make([]byte, 2) - for i := range b { - b[i] = transIDBytes[rand.Int63()%int64(len(transIDBytes))] - } - return string(b) -} - -// makeQuery returns a query-formed data. -func makeQuery(t, q string, a map[string]interface{}) map[string]interface{} { - return map[string]interface{}{ - "t": t, - "y": "q", - "q": q, - "a": a, - } -} - -// makeResponse returns a response-formed data. -func makeResponse(t string, r map[string]interface{}) map[string]interface{} { - return map[string]interface{}{ - "t": t, - "y": "r", - "r": r, - } -} - -func getStringKey(data map[string]interface{}, key string) (string, error) { - val, ok := data[key] - if !ok { - return "", fmt.Errorf("krpc: missing key %s", key) - } - out, ok := val.(string) - if !ok { - return "", fmt.Errorf("krpc: key type mismatch") - } - return out, nil -} - -func getIntKey(data map[string]interface{}, key string) (int, error) { - val, ok := data[key] - if !ok { - return 0, fmt.Errorf("krpc: missing key %s", key) - } - out, ok := val.(int) - if !ok { - return 0, fmt.Errorf("krpc: key type mismatch") - } - return out, nil -} - -func getMapKey(data map[string]interface{}, key string) (map[string]interface{}, error) { - val, ok := data[key] - if !ok { - return nil, fmt.Errorf("krpc: missing key %s", key) - } - out, ok := val.(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("krpc: key type mismatch") - } - return out, nil -} - -func getListKey(data map[string]interface{}, key string) ([]interface{}, error) { - val, ok := data[key] - if !ok { - return nil, fmt.Errorf("krpc: missing key %s", key) - } - out, ok := val.([]interface{}) - if !ok { - return nil, fmt.Errorf("krpc: key type mismatch") - } - return out, nil -} - -// parseKeys parses keys. It just wraps parseKey. -func checkKeys(data map[string]interface{}, pairs [][]string) (err error) { - for _, args := range pairs { - key, t := args[0], args[1] - if err = checkKey(data, key, t); err != nil { - break - } - } - return err -} - -// parseKey parses the key in dict data. `t` is type of the keyed value. -// It's one of "int", "string", "map", "list". -func checkKey(data map[string]interface{}, key string, t string) error { - val, ok := data[key] - if !ok { - return fmt.Errorf("krpc: missing key %s", key) - } - - switch t { - case "string": - _, ok = val.(string) - case "int": - _, ok = val.(int) - case "map": - _, ok = val.(map[string]interface{}) - case "list": - _, ok = val.([]interface{}) - default: - return errors.New("krpc: invalid type") - } - - if !ok { - return errors.New("krpc: key type mismatch") - } - - return nil -} - -// Swiped from nictuku -func decodeCompactNodeAddr(cni string) string { - if len(cni) == 6 { - return fmt.Sprintf("%d.%d.%d.%d:%d", cni[0], cni[1], cni[2], cni[3], (uint16(cni[4])<<8)|uint16(cni[5])) - } else if len(cni) == 18 { - b := []byte(cni[:16]) - return fmt.Sprintf("[%s]:%d", net.IP.String(b), (uint16(cni[16])<<8)|uint16(cni[17])) - } else { - return "" - } -} - -func encodeCompactNodeAddr(addr string) string { - var a []uint8 - host, port, _ := net.SplitHostPort(addr) - ip := net.ParseIP(host) - if ip == nil { - return "" - } - aa, _ := strconv.ParseUint(port, 10, 16) - c := uint16(aa) - if ip2 := net.IP.To4(ip); ip2 != nil { - a = make([]byte, net.IPv4len+2, net.IPv4len+2) - copy(a, ip2[0:net.IPv4len]) // ignore bytes IPv6 bytes if it's IPv4. - a[4] = byte(c >> 8) - a[5] = byte(c) - } else { - a = make([]byte, net.IPv6len+2, net.IPv6len+2) - copy(a, ip) - a[16] = byte(c >> 8) - a[17] = byte(c) - } - return string(a) -} - -func int2bytes(val int64) []byte { - data, j := make([]byte, 8), -1 - for i := 0; i < 8; i++ { - shift := uint64((7 - i) * 8) - data[i] = byte((val & (0xff << shift)) >> shift) - - if j == -1 && data[i] != 0 { - j = i - } - } - - if j != -1 { - return data[j:] - } - return data[:1] -} diff --git a/dht/krpc_test.go b/dht/krpc_test.go deleted file mode 100644 index 5bc8373..0000000 --- a/dht/krpc_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package dht - -import ( - "encoding/hex" - "testing" -) - -func TestCompactNodeAddr(t *testing.T) { - - tests := []struct { - in string - out string - }{ - {in: "192.168.1.1:6881", out: "c0a801011ae1"}, - {in: "[2001:9372:434a:800::2]:6881", out: "20019372434a080000000000000000021ae1"}, - } - - for _, tt := range tests { - r := encodeCompactNodeAddr(tt.in) - out, _ := hex.DecodeString(tt.out) - if r != string(out) { - t.Errorf("encodeCompactNodeAddr(%s) => %x, expected %s", tt.in, r, tt.out) - } - - s := decodeCompactNodeAddr(r) - if s != tt.in { - t.Errorf("decodeCompactNodeAddr(%x) => %s, expected %s", r, s, tt.in) - } - } -} diff --git a/dht/messages.go b/dht/messages.go index 023e27d..9305b6f 100644 --- a/dht/messages.go +++ b/dht/messages.go @@ -3,28 +3,29 @@ package dht import ( "fmt" "net" - //"strings" + + "github.com/felix/dhtsearch/krpc" ) func (n *Node) onPingQuery(rn remoteNode, msg map[string]interface{}) error { - t, err := getStringKey(msg, "t") + t, err := krpc.GetString(msg, "t") if err != nil { return err } - n.queueMsg(rn, makeResponse(t, map[string]interface{}{ + n.queueMsg(rn, krpc.MakeResponse(t, map[string]interface{}{ "id": string(n.id), })) return nil } func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) error { - a, err := getMapKey(msg, "a") + a, err := krpc.GetMap(msg, "a") if err != nil { return err } // This is the ih of the torrent - torrent, err := getStringKey(a, "info_hash") + torrent, err := krpc.GetString(a, "info_hash") if err != nil { return err } @@ -40,7 +41,7 @@ func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) error nodes := n.rTable.get(8) compactNS := []string{} for _, rn := range nodes { - ns := encodeCompactNodeAddr(rn.address.String()) + ns := encodeCompactNodeAddr(rn.addr.String()) if ns == "" { n.log.Warn("failed to compact node", "address", rn.address.String()) continue @@ -50,7 +51,7 @@ func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) error */ t := msg["t"].(string) - n.queueMsg(rn, makeResponse(t, map[string]interface{}{ + n.queueMsg(rn, krpc.MakeResponse(t, map[string]interface{}{ "id": string(neighbour), "token": token, "nodes": "", @@ -60,7 +61,7 @@ func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) error //nodes := n.rTable.get(50) /* fmt.Printf("sending get_peers for %s to %d nodes\n", *th, len(nodes)) - q := makeQuery(newTransactionID(), "get_peers", map[string]interface{}{ + q := krpc.MakeQuery(newTransactionID(), "get_peers", map[string]interface{}{ "id": string(id), "info_hash": string(*th), }) @@ -72,22 +73,17 @@ func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) error } func (n *Node) onAnnouncePeerQuery(rn remoteNode, msg map[string]interface{}) error { - a, err := getMapKey(msg, "a") + a, err := krpc.GetMap(msg, "a") if err != nil { return err } - err = checkKeys(a, [][]string{ - {"info_hash", "string"}, - {"port", "int"}, - {"token", "string"}, - }) n.log.Debug("announce_peer", "source", rn) - if impliedPort, err := getIntKey(a, "implied_port"); err == nil { + if impliedPort, err := krpc.GetInt(a, "implied_port"); err == nil { if impliedPort != 0 { // Use the port in the message - host, _, err := net.SplitHostPort(rn.address.String()) + host, _, err := net.SplitHostPort(rn.addr.String()) if err != nil { return err } @@ -96,13 +92,13 @@ func (n *Node) onAnnouncePeerQuery(rn remoteNode, msg map[string]interface{}) er return fmt.Errorf("ignoring port 0") } addr, err := net.ResolveUDPAddr(n.family, fmt.Sprintf("%s:%d", host, newPort)) - rn = remoteNode{address: addr, id: rn.id} + rn = remoteNode{addr: addr, id: rn.id} } } // TODO do we reply? - ihStr, err := getStringKey(a, "info_hash") + ihStr, err := krpc.GetString(a, "info_hash") if err != nil { return err } @@ -111,8 +107,7 @@ func (n *Node) onAnnouncePeerQuery(rn remoteNode, msg map[string]interface{}) er n.log.Warn("invalid torrent", "infohash", ihStr) } - p := Peer{Node: rn, Infohash: *ih} - n.log.Info("anounce_peer", p) + p := Peer{Addr: rn.addr, ID: rn.id, Infohash: *ih} if n.OnAnnouncePeer != nil { go n.OnAnnouncePeer(p) } @@ -121,9 +116,6 @@ func (n *Node) onAnnouncePeerQuery(rn remoteNode, msg map[string]interface{}) er func (n *Node) onFindNodeResponse(rn remoteNode, msg map[string]interface{}) { r := msg["r"].(map[string]interface{}) - if err := checkKey(r, "id", "string"); err != nil { - return - } nodes := r["nodes"].(string) n.processFindNodeResults(rn, nodes) } diff --git a/dht/node.go b/dht/node.go index db61eed..f8e6113 100644 --- a/dht/node.go +++ b/dht/node.go @@ -7,6 +7,7 @@ import ( "time" "github.com/felix/dhtsearch/bencode" + "github.com/felix/dhtsearch/krpc" "github.com/felix/logger" "golang.org/x/time/rate" ) @@ -42,7 +43,7 @@ type Node struct { // NewNode creates a new DHT node func NewNode(opts ...Option) (n *Node, err error) { - id := randomInfoHash() + id := GenInfohash() k, err := newRoutingTable(id, 2000) if err != nil { @@ -159,7 +160,7 @@ func (n *Node) bootstrap() { n.log.Error("failed to parse bootstrap address", "error", err) continue } - rn := &remoteNode{address: addr} + rn := &remoteNode{addr: addr} n.findNode(rn, n.id) } } @@ -186,7 +187,7 @@ func (n *Node) packetWriter() { } func (n *Node) findNode(rn *remoteNode, id Infohash) { - target := randomInfoHash() + target := GenInfohash() n.sendQuery(rn, "find_node", map[string]interface{}{ "id": string(id), "target": string(target), @@ -207,10 +208,9 @@ func (n *Node) sendQuery(rn *remoteNode, qType string, a map[string]interface{}) return nil } - t := newTransactionID() - //n.log.Debug("sending message", "type", qType, "remote", rn) + t := krpc.NewTransactionID() - data := makeQuery(t, qType, a) + data := krpc.MakeQuery(t, qType, a) b, err := bencode.Encode(data) if err != nil { return err @@ -218,43 +218,38 @@ func (n *Node) sendQuery(rn *remoteNode, qType string, a map[string]interface{}) //fmt.Printf("sending %s to %s\n", qType, rn.String()) n.packetsOut <- packet{ data: b, - raddr: rn.address, + raddr: rn.addr, } return nil } // Parse a KRPC packet into a message -func (n *Node) processPacket(p packet) { - data, err := bencode.Decode(p.data) +func (n *Node) processPacket(p packet) error { + response, _, err := bencode.DecodeDict(p.data, 0) if err != nil { - return - } - - response, ok := data.(map[string]interface{}) - if !ok { - n.log.Debug("failed to parse packet", "error", "response is not dict") - return + return err } - if err := checkKeys(response, [][]string{{"t", "string"}, {"y", "string"}}); err != nil { - n.log.Debug("failed to parse packet", "error", err) - return + y, err := krpc.GetString(response, "y") + if err != nil { + return err } - switch response["y"].(string) { + switch y { case "q": err = n.handleRequest(p.raddr, response) case "r": err = n.handleResponse(p.raddr, response) case "e": - n.handleError(p.raddr, response) + err = n.handleError(p.raddr, response) default: n.log.Warn("missing request type") - return + return nil } if err != nil { n.log.Warn("failed to process packet", "error", err) } + return err } // bencode data and send @@ -265,24 +260,24 @@ func (n *Node) queueMsg(rn remoteNode, data map[string]interface{}) error { } n.packetsOut <- packet{ data: b, - raddr: rn.address, + raddr: rn.addr, } return nil } // handleRequest handles the requests received from udp. func (n *Node) handleRequest(addr net.Addr, m map[string]interface{}) error { - q, err := getStringKey(m, "q") + q, err := krpc.GetString(m, "q") if err != nil { return err } - a, err := getMapKey(m, "a") + a, err := krpc.GetMap(m, "a") if err != nil { return err } - id, err := getStringKey(a, "id") + id, err := krpc.GetString(a, "id") if err != nil { return err } @@ -296,7 +291,7 @@ func (n *Node) handleRequest(addr net.Addr, m map[string]interface{}) error { return nil } - rn := &remoteNode{address: addr, id: *ih} + rn := &remoteNode{addr: addr, id: *ih} switch q { case "ping": @@ -318,11 +313,11 @@ func (n *Node) handleRequest(addr net.Addr, m map[string]interface{}) error { // handleResponse handles responses received from udp. func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error { - r, err := getMapKey(m, "r") + r, err := krpc.GetMap(m, "r") if err != nil { return err } - id, err := getStringKey(r, "id") + id, err := krpc.GetString(r, "id") if err != nil { return err } @@ -331,9 +326,9 @@ func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error { return err } - rn := &remoteNode{address: addr, id: *ih} + rn := &remoteNode{addr: addr, id: *ih} - nodes, err := getStringKey(r, "nodes") + nodes, err := krpc.GetString(r, "nodes") // find_nodes/get_peers response with nodes if err == nil { n.onFindNodeResponse(*rn, m) @@ -342,12 +337,12 @@ func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error { return nil } - values, err := getListKey(r, "values") + values, err := krpc.GetList(r, "values") // get_peers response if err == nil { n.log.Debug("get_peers response", "source", rn) for _, v := range values { - addr := decodeCompactNodeAddr(v.(string)) + addr := krpc.DecodeCompactNodeAddr(v.(string)) n.log.Debug("unhandled get_peer request", "addres", addr) // TODO new peer needs to be matched to previous get_peers request @@ -359,20 +354,20 @@ func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error { } // handleError handles errors received from udp. -func (n *Node) handleError(addr net.Addr, m map[string]interface{}) bool { - if err := checkKey(m, "e", "list"); err != nil { - return false +func (n *Node) handleError(addr net.Addr, m map[string]interface{}) error { + e, err := krpc.GetList(m, "e") + if err != nil { + return err } - e := m["e"].([]interface{}) if len(e) != 2 { - return false + return fmt.Errorf("error packet wrong length %d", len(e)) } code := e[0].(int64) msg := e[1].(string) n.log.Debug("error packet", "address", addr.String(), "code", code, "error", msg) - return true + return nil } // Process another node's response to a find_node query. @@ -392,7 +387,7 @@ func (n *Node) processFindNodeResults(rn remoteNode, nodeList string) { // We got a byte array in groups of 26 or 38 for i := 0; i < len(nodeList); i += nodeLength { id := nodeList[i : i+ihLength] - addrStr := decodeCompactNodeAddr(nodeList[i+ihLength : i+nodeLength]) + addrStr := krpc.DecodeCompactNodeAddr(nodeList[i+ihLength : i+nodeLength]) ih, err := InfohashFromString(id) if err != nil { @@ -406,7 +401,7 @@ func (n *Node) processFindNodeResults(rn remoteNode, nodeList string) { continue } - rn := &remoteNode{address: addr, id: *ih} + rn := &remoteNode{addr: addr, id: *ih} n.rTable.add(rn) } } diff --git a/dht/peer.go b/dht/peer.go index f9669ba..42e8438 100644 --- a/dht/peer.go +++ b/dht/peer.go @@ -1,13 +1,18 @@ package dht -import "fmt" +import ( + "fmt" + "net" +) // Peer on DHT network type Peer struct { - Node remoteNode + Addr net.Addr + ID Infohash Infohash Infohash } +// String implements fmt.Stringer func (p Peer) String() string { - return fmt.Sprintf("%s (%s)", p.Infohash, p.Node) + return fmt.Sprintf("%s (%s)", p.Infohash, p.Addr.String()) } diff --git a/dht/remote_node.go b/dht/remote_node.go index 5bb2585..4bb9319 100644 --- a/dht/remote_node.go +++ b/dht/remote_node.go @@ -6,12 +6,11 @@ import ( ) type remoteNode struct { - address net.Addr - id Infohash - //lastSeen time.Time + addr net.Addr + id Infohash } // String implements fmt.Stringer func (r remoteNode) String() string { - return fmt.Sprintf("%s (%s)", r.id.String(), r.address.String()) + return fmt.Sprintf("%s (%s)", r.id.String(), r.addr.String()) } diff --git a/dht/routing_table.go b/dht/routing_table.go index 3c1f2d2..b10574c 100644 --- a/dht/routing_table.go +++ b/dht/routing_table.go @@ -73,10 +73,10 @@ func (k *routingTable) add(rn *remoteNode) { k.Lock() defer k.Unlock() - if _, ok := k.addresses[rn.address.String()]; ok { + if _, ok := k.addresses[rn.addr.String()]; ok { return } - k.addresses[rn.address.String()] = rn + k.addresses[rn.addr.String()] = rn item := &rItem{ value: rn, @@ -88,7 +88,7 @@ func (k *routingTable) add(rn *remoteNode) { if len(k.items) > k.max { for i := k.max - 1; i < len(k.items); i++ { old := k.items[i] - delete(k.addresses, old.value.address.String()) + delete(k.addresses, old.value.addr.String()) heap.Remove(&k.items, i) } } diff --git a/dht/routing_table_test.go b/dht/routing_table_test.go index 77c0a17..1eeeca3 100644 --- a/dht/routing_table_test.go +++ b/dht/routing_table_test.go @@ -32,7 +32,7 @@ func TestPriorityQueue(t *testing.T) { t.Errorf("failed to create infohash: %s\n", err) } addr, _ := net.ResolveUDPAddr("udp", fmt.Sprintf("0.0.0.0:%d", i)) - pq.add(&remoteNode{id: *iht, address: addr}) + pq.add(&remoteNode{id: *iht, addr: addr}) } if len(pq.items) != len(pq.addresses) { diff --git a/http.go b/http.go deleted file mode 100644 index 32fdae9..0000000 --- a/http.go +++ /dev/null @@ -1,299 +0,0 @@ -package dhtsearch - -import ( - "encoding/json" - "expvar" - "fmt" - "net/http" - "strconv" -) - -type results struct { - Page int `json:"page"` - PageSize int `json:"page_size"` - Torrents []Torrent `json:"torrents"` -} - -func indexHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "public") - if r.URL.Path != "/" { - w.WriteHeader(404) - return - } - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.WriteHeader(200) - w.Write(html) -} - -func statsHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Header().Set("Cache-Control", "no-cache") - w.WriteHeader(200) - fmt.Fprintf(w, "{") - first := true - expvar.Do(func(kv expvar.KeyValue) { - if kv.Key == "cmdline" || kv.Key == "memstats" { - return - } - if !first { - fmt.Fprintf(w, ",") - } - first = false - fmt.Fprintf(w, "%q: %s", kv.Key, kv.Value) - }) - fmt.Fprintf(w, "}") -} - -func searchHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Header().Set("Cache-Control", "no-cache") - - offset := 0 - page := 1 - var err error - pStr := r.URL.Query().Get("page") - if pStr != "" { - page, err = strconv.Atoi(pStr) - if err != nil { - fmt.Printf("Failed to parse page: %q\n", err) - } - offset = (page - 1) * 50 - } - - if q := r.URL.Query().Get("q"); q != "" { - torrents, err := torrentsByName(q, offset) - if err != nil { - w.WriteHeader(500) - fmt.Printf("Error: %q\n", err) - return - } - w.WriteHeader(200) - json.NewEncoder(w).Encode(results{Page: page, PageSize: Config.ResultsPageSize, Torrents: torrents}) - return - } - - if tag := r.URL.Query().Get("tag"); tag != "" { - torrents, err := torrentsByTag(tag, offset) - if err != nil { - w.WriteHeader(500) - fmt.Printf("Error: %q\n", err) - return - } - w.WriteHeader(200) - json.NewEncoder(w).Encode(results{Page: page, PageSize: Config.ResultsPageSize, Torrents: torrents}) - return - } - - w.WriteHeader(406) - json.NewEncoder(w).Encode("Query required") -} - -var html = []byte(` - - - - DHT search - - - - - -
- -
-
-
- - - -`) diff --git a/infohash_test.go b/infohash_test.go deleted file mode 100644 index b62f3e2..0000000 --- a/infohash_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package dhtsearch - -import ( - "testing" -) - -var hashes = []struct { - s string - valid bool -}{ - {"59066769b9ad42da2e508611c33d7c4480b3857b", true}, - {"59066769b9ad42da2e508611c33d7c4480b3857", false}, - {"59066769b9ad42da2e508611c33d7c4480b385", false}, - {"59066769b9ad42da2e508611c33d7c4480b3857k", false}, - {"5906676b99a4d2d2ae506811c33d7c4480b8357b", true}, -} - -func TestGenNeighbour(t *testing.T) { - for _, test := range hashes { - r := genNeighbour(test.s) - if r != test.valid { - t.Errorf("isValidInfoHash(%q) => %v expected %v", test.s, r, test.valid) - } - } -} - -func TestIsValidInfoHash(t *testing.T) { - for _, test := range hashes { - r := isValidInfoHash(test.s) - if r != test.valid { - t.Errorf("isValidInfoHash(%q) => %v, expected %v", test.s, r, test.valid) - } - } -} - -func TestDecodeInfoHash(t *testing.T) { - for _, test := range hashes { - _, err := decodeInfoHash(test.s) - if (err == nil) != test.valid { - t.Errorf("decodeInfoHash(%q) => %v expected %v", test.s, err, test.valid) - } - } -} diff --git a/krpc/krpc.go b/krpc/krpc.go new file mode 100644 index 0000000..a766fcf --- /dev/null +++ b/krpc/krpc.go @@ -0,0 +1,177 @@ +package krpc + +import ( + "errors" + "fmt" + "math/rand" + "net" + "strconv" +) + +const transIDBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +func NewTransactionID() string { + b := make([]byte, 2) + for i := range b { + b[i] = transIDBytes[rand.Int63()%int64(len(transIDBytes))] + } + return string(b) +} + +// makeQuery returns a query-formed data. +func MakeQuery(transaction, query string, data map[string]interface{}) map[string]interface{} { + return map[string]interface{}{ + "t": transaction, + "y": "q", + "q": query, + "a": data, + } +} + +// makeResponse returns a response-formed data. +func MakeResponse(transaction string, data map[string]interface{}) map[string]interface{} { + return map[string]interface{}{ + "t": transaction, + "y": "r", + "r": data, + } +} + +func GetString(data map[string]interface{}, key string) (string, error) { + val, ok := data[key] + if !ok { + return "", fmt.Errorf("krpc: missing key %s", key) + } + out, ok := val.(string) + if !ok { + return "", fmt.Errorf("krpc: key type mismatch") + } + return out, nil +} + +func GetInt(data map[string]interface{}, key string) (int, error) { + val, ok := data[key] + if !ok { + return 0, fmt.Errorf("krpc: missing key %s", key) + } + out, ok := val.(int) + if !ok { + return 0, fmt.Errorf("krpc: key type mismatch") + } + return out, nil +} + +func GetMap(data map[string]interface{}, key string) (map[string]interface{}, error) { + val, ok := data[key] + if !ok { + return nil, fmt.Errorf("krpc: missing key %s", key) + } + out, ok := val.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("krpc: key type mismatch") + } + return out, nil +} + +func GetList(data map[string]interface{}, key string) ([]interface{}, error) { + val, ok := data[key] + if !ok { + return nil, fmt.Errorf("krpc: missing key %s", key) + } + out, ok := val.([]interface{}) + if !ok { + return nil, fmt.Errorf("krpc: key type mismatch") + } + return out, nil +} + +// parseKeys parses keys. It just wraps parseKey. +func checkKeys(data map[string]interface{}, pairs [][]string) (err error) { + for _, args := range pairs { + key, t := args[0], args[1] + if err = checkKey(data, key, t); err != nil { + break + } + } + return err +} + +// parseKey parses the key in dict data. `t` is type of the keyed value. +// It's one of "int", "string", "map", "list". +func checkKey(data map[string]interface{}, key string, t string) error { + val, ok := data[key] + if !ok { + return fmt.Errorf("krpc: missing key %s", key) + } + + switch t { + case "string": + _, ok = val.(string) + case "int": + _, ok = val.(int) + case "map": + _, ok = val.(map[string]interface{}) + case "list": + _, ok = val.([]interface{}) + default: + return errors.New("krpc: invalid type") + } + + if !ok { + return errors.New("krpc: key type mismatch") + } + + return nil +} + +// Swiped from nictuku +func DecodeCompactNodeAddr(cni string) string { + if len(cni) == 6 { + return fmt.Sprintf("%d.%d.%d.%d:%d", cni[0], cni[1], cni[2], cni[3], (uint16(cni[4])<<8)|uint16(cni[5])) + } else if len(cni) == 18 { + b := []byte(cni[:16]) + return fmt.Sprintf("[%s]:%d", net.IP.String(b), (uint16(cni[16])<<8)|uint16(cni[17])) + } else { + return "" + } +} + +func EncodeCompactNodeAddr(addr string) string { + var a []uint8 + host, port, _ := net.SplitHostPort(addr) + ip := net.ParseIP(host) + if ip == nil { + return "" + } + aa, _ := strconv.ParseUint(port, 10, 16) + c := uint16(aa) + if ip2 := net.IP.To4(ip); ip2 != nil { + a = make([]byte, net.IPv4len+2, net.IPv4len+2) + copy(a, ip2[0:net.IPv4len]) // ignore bytes IPv6 bytes if it's IPv4. + a[4] = byte(c >> 8) + a[5] = byte(c) + } else { + a = make([]byte, net.IPv6len+2, net.IPv6len+2) + copy(a, ip) + a[16] = byte(c >> 8) + a[17] = byte(c) + } + return string(a) +} + +func int2bytes(val int64) []byte { + data, j := make([]byte, 8), -1 + for i := 0; i < 8; i++ { + shift := uint64((7 - i) * 8) + data[i] = byte((val & (0xff << shift)) >> shift) + + if j == -1 && data[i] != 0 { + j = i + } + } + + if j != -1 { + return data[j:] + } + return data[:1] +} diff --git a/krpc/krpc_test.go b/krpc/krpc_test.go new file mode 100644 index 0000000..c46d70c --- /dev/null +++ b/krpc/krpc_test.go @@ -0,0 +1,30 @@ +package krpc + +import ( + "encoding/hex" + "testing" +) + +func TestCompactNodeAddr(t *testing.T) { + + tests := []struct { + in string + out string + }{ + {in: "192.168.1.1:6881", out: "c0a801011ae1"}, + {in: "[2001:9372:434a:800::2]:6881", out: "20019372434a080000000000000000021ae1"}, + } + + for _, tt := range tests { + r := encodeCompactNodeAddr(tt.in) + out, _ := hex.DecodeString(tt.out) + if r != string(out) { + t.Errorf("encodeCompactNodeAddr(%s) => %x, expected %s", tt.in, r, tt.out) + } + + s := decodeCompactNodeAddr(r) + if s != tt.in { + t.Errorf("decodeCompactNodeAddr(%x) => %s, expected %s", r, s, tt.in) + } + } +} diff --git a/models/tag.go b/models/tag.go new file mode 100644 index 0000000..3675c75 --- /dev/null +++ b/models/tag.go @@ -0,0 +1,5 @@ +package models + +type tagStore interface { + saveTag(string) (int, error) +} diff --git a/models/torrent.go b/models/torrent.go new file mode 100644 index 0000000..960a6de --- /dev/null +++ b/models/torrent.go @@ -0,0 +1,98 @@ +package models + +import ( + "bytes" + "crypto/sha1" + "encoding/hex" + "fmt" + "os" + "strings" + "time" + + "github.com/felix/dhtsearch/bencode" + "github.com/felix/dhtsearch/dht" + "github.com/felix/dhtsearch/krpc" +) + +// Data for persistent storage +type Torrent struct { + ID int `json:"-"` + InfoHash string `json:"infohash"` + Name string `json:"name"` + Files []File `json:"files" db:"-"` + Size int `json:"size"` + Seen time.Time `json:"seen"` + Tags []string `json:"tags" db:"-"` +} + +type File struct { + ID int `json:"-"` + Path string `json:"path"` + Size int `json:"size"` + TorrentID int `json:"torrent_id" db:"torrent_id"` +} + +type torrentStore interface { + saveTorrent(*Torrent) error + torrentsByHash(hashes dht.Infohash, offset, limit int) (*Torrent, error) + torrentsByName(query string, offset, limit int) ([]*Torrent, error) + torrentsByTags(tags []string, offset, limit int) ([]*Torrent, error) +} + +func validMetadata(ih dht.Infohash, md []byte) bool { + info := sha1.Sum(md) + return bytes.Equal([]byte(ih), info[:]) +} + +func TorrentFromMetadata(ih dht.Infohash, md []byte) (*Torrent, error) { + if !validMetadata(ih, md) { + return nil, fmt.Errorf("infohash does not match metadata") + } + info, _, err := bencode.DecodeDict(md, 0) + if err != nil { + return nil, err + } + + // Get the directory or advisory filename + name, err := krpc.GetString(info, "name") + if err != nil { + return nil, err + } + + bt := Torrent{ + InfoHash: hex.EncodeToString([]byte(ih)), + Name: name, + } + + if files, err := krpc.GetList(info, "files"); err == nil { + // Multiple file mode + bt.Files = make([]File, len(files)) + + // Files is a list of dicts + for i, item := range files { + file := item.(map[string]interface{}) + + // Paths is a list of strings + paths := file["path"].([]interface{}) + path := make([]string, len(paths)) + for j, p := range paths { + path[j] = p.(string) + } + + fSize := file["length"].(int) + bt.Files[i] = File{ + // Assume Unix path sep? + Path: strings.Join(path[:], string(os.PathSeparator)), + Size: fSize, + } + // Ensure the torrent size totals all files' + bt.Size = bt.Size + fSize + } + } else if length, err := krpc.GetInt(info, "length"); err == nil { + // Single file mode + bt.Size = length + } else { + return nil, fmt.Errorf("found neither length or files") + } + return &bt, nil +} diff --git a/torrent.go b/torrent.go deleted file mode 100644 index bb8dcff..0000000 --- a/torrent.go +++ /dev/null @@ -1,190 +0,0 @@ -package dhtsearch - -import ( - "fmt" - "time" -) - -// Data for persistent storage -type Torrent struct { - Id int `json:"-"` - InfoHash string `json:"infohash"` - Name string `json:"name"` - Files []File `json:"files" db:"-"` - Size int `json:"size"` - Seen time.Time `json:"seen"` - Tags []string `json:"tags" db:"-"` -} - -type File struct { - Id int `json:"-"` - Path string `json:"path"` - Size int `json:"size"` - TorrentId int `json:"torrent_id" db:"torrent_id"` -} - -func torrentExists(ih string) bool { - rows, err := DB.Query(sqlGetTorrent, fmt.Sprintf("%s", ih)) - defer rows.Close() - if err != nil { - fmt.Printf("Failed to exec SQL: %q\n", err) - return false - } - return rows.Next() -} - -func (t *Torrent) save() error { - tx, err := DB.Begin() - if err != nil { - fmt.Printf("Transaction err %q\n", err) - } - defer tx.Commit() - - var torrentId int - - // Need to turn infohash into string here - err = tx.QueryRow(sqlInsertTorrent, t.Name, fmt.Sprintf("%s", t.InfoHash), t.Size).Scan(&torrentId) - if err != nil { - tx.Rollback() - return err - } - - // Write tags - for _, tag := range t.Tags { - tagId, err := createTag(tag) - if err != nil { - tx.Rollback() - return err - } - _, err = tx.Exec(sqlInsertTagTorrent, tagId, torrentId) - if err != nil { - tx.Rollback() - return err - } - } - - // Write files - for _, f := range t.Files { - _, err := tx.Exec(sqlInsertFile, torrentId, f.Path, f.Size) - if err != nil { - tx.Rollback() - return err - } - } - - // Should this be outside the transaction? - tx.Exec(sqlUpdateFTSVectors, torrentId) - if err != nil { - tx.Rollback() - return err - } - return nil -} - -// Fill in a torrents dependant data -func (t *Torrent) load() (err error) { - // Files - t.Files = []File{} - err = DB.Select(&t.Files, sqlSelectFiles, t.Id) - if err != nil { - fmt.Printf("Error selecting files %s\n", err) - } - // t.Files = files - - // Tags - t.Tags = []string{} - err = DB.Select(&t.Tags, sqlSelectTags, t.Id) - if err != nil { - fmt.Printf("Error selecting tags %s\n", err) - } - return -} - -func torrentsByName(query string, offset int) ([]Torrent, error) { - torrents := []Torrent{} - err := DB.Select(&torrents, sqlSearchTorrents, fmt.Sprintf("%%%s%%", query), offset) - if err != nil { - return nil, err - } - fmt.Printf("Search for %q returned %d torrents\n", query, len(torrents)) - - for idx, _ := range torrents { - torrents[idx].load() - } - return torrents, nil -} - -func torrentsByTag(tag string, offset int) ([]Torrent, error) { - torrents := []Torrent{} - err := DB.Select(&torrents, sqlTorrentsByTag, tag, offset) - if err != nil { - return nil, err - } - fmt.Printf("Search for tag %q returned %d torrents\n", tag, len(torrents)) - - for idx, _ := range torrents { - torrents[idx].load() - } - return torrents, nil -} - -const ( - sqlGetTorrent = `update torrents - set seen = now() - where infohash = $1 - returning id` - - sqlInsertTorrent = `insert into torrents ( - name, infohash, size, seen - ) values ( - $1, $2, $3, now() - ) on conflict (infohash) do - update set seen = now() - returning id` - - sqlUpdateFTSVectors = `update torrents - set tsv = sub.tsv from ( - select t.id, - setweight(to_tsvector(translate(t.name, '._-', ' ')), 'A') || - setweight(to_tsvector(translate(string_agg(coalesce(f.path, ''), ' '), './_-', ' ')), 'B') as tsv - from torrents t - left join files f on t.id = f.torrent_id - where t.id = $1 - group by t.id - ) as sub - where sub.id = torrents.id` - - sqlSearchTorrents = ` - select t.id, t.infohash, t.name, t.size, t.seen - from torrents t - where t.tsv @@ plainto_tsquery($1) - order by ts_rank(tsv, plainto_tsquery($1)) desc, t.seen desc - limit 50 offset $2` - - sqlTorrentsByTag = ` - select t.id, t.infohash, t.name, t.size, t.seen - from torrents t - inner join tags_torrents tt on t.id = tt.torrent_id - inner join tags ta on tt.tag_id = ta.id - where ta.name = $1 group by t.id - order by seen desc - limit 50 offset $2` - - sqlSelectFiles = `select * from files - where torrent_id = $1 - order by path asc` - - sqlInsertFile = `insert into files ( - torrent_id, path, size - ) values($1, $2, $3)` - - sqlSelectTags = `select name - from tags t - inner join tags_torrents tt on t.id = tt.tag_id - where tt.torrent_id = $1` - - sqlInsertTagTorrent = `insert into tags_torrents ( - tag_id, torrent_id - ) values ($1, $2) - on conflict do nothing` -) diff --git a/util.go b/util.go deleted file mode 100644 index 1548fab..0000000 --- a/util.go +++ /dev/null @@ -1,92 +0,0 @@ -package dhtsearch - -import ( - "errors" - "fmt" - "net" -) - -// makeQuery returns a query-formed data. -func makeQuery(t, q string, a map[string]interface{}) map[string]interface{} { - return map[string]interface{}{ - "t": t, - "y": "q", - "q": q, - "a": a, - } -} - -// makeResponse returns a response-formed data. -func makeResponse(t string, r map[string]interface{}) map[string]interface{} { - return map[string]interface{}{ - "t": t, - "y": "r", - "r": r, - } -} - -// parseKeys parses keys. It just wraps parseKey. -func parseKeys(data map[string]interface{}, pairs [][]string) error { - for _, args := range pairs { - key, t := args[0], args[1] - if err := parseKey(data, key, t); err != nil { - return err - } - } - return nil -} - -// parseKey parses the key in dict data. `t` is type of the keyed value. -// It's one of "int", "string", "map", "list". -func parseKey(data map[string]interface{}, key string, t string) error { - val, ok := data[key] - if !ok { - return errors.New("lack of key") - } - - switch t { - case "string": - _, ok = val.(string) - case "int": - _, ok = val.(int) - case "map": - _, ok = val.(map[string]interface{}) - case "list": - _, ok = val.([]interface{}) - default: - panic("invalid type") - } - - if !ok { - return errors.New("invalid key type") - } - - return nil -} - -// parseMessage parses the basic data received from udp. -// It returns a map value. -func parseMessage(data interface{}) (map[string]interface{}, error) { - response, ok := data.(map[string]interface{}) - if !ok { - return nil, errors.New("response is not dict") - } - - if err := parseKeys(response, [][]string{{"t", "string"}, {"y", "string"}}); err != nil { - return nil, err - } - - return response, nil -} - -// Swiped from nictuku -func compactNodeInfoToString(cni string) string { - if len(cni) == 6 { - return fmt.Sprintf("%d.%d.%d.%d:%d", cni[0], cni[1], cni[2], cni[3], (uint16(cni[4])<<8)|uint16(cni[5])) - } else if len(cni) == 18 { - b := []byte(cni[:16]) - return fmt.Sprintf("[%s]:%d", net.IP.String(b), (uint16(cni[16])<<8)|uint16(cni[17])) - } else { - return "" - } -} -- cgit v1.2.3