aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFelix Hanley <felix@userspace.com.au>2018-02-26 11:25:32 +0000
committerFelix Hanley <felix@userspace.com.au>2018-02-26 11:25:32 +0000
commit40732d76b81aa16ed38d855e93ce82387d29643f (patch)
tree9ce244a2635fdf549291d059a9bcc542283606cb
parent5921a5c3a7829359703fa16c29ae6520408ef47c (diff)
downloaddhtsearch-40732d76b81aa16ed38d855e93ce82387d29643f.tar.gz
dhtsearch-40732d76b81aa16ed38d855e93ce82387d29643f.tar.bz2
Move bt client
-rw-r--r--bt/options.go50
-rw-r--r--bt/worker.go (renamed from bittorrent/client.go)276
2 files changed, 179 insertions, 147 deletions
diff --git a/bt/options.go b/bt/options.go
new file mode 100644
index 0000000..25ebc21
--- /dev/null
+++ b/bt/options.go
@@ -0,0 +1,50 @@
+package bt
+
+import (
+ "github.com/felix/dhtsearch/models"
+ "github.com/felix/logger"
+)
+
+type Option func(*Worker) error
+
+// SetNewTorrent sets the callback
+func SetOnNewTorrent(f func(models.Torrent)) Option {
+ return func(w *Worker) error {
+ w.OnNewTorrent = f
+ return nil
+ }
+}
+
+// SetPort sets the port to listen on
+func SetPort(p int) Option {
+ return func(w *Worker) error {
+ w.port = p
+ return nil
+ }
+}
+
+// SetIPv6 enables IPv6
+func SetIPv6(b bool) Option {
+ return func(w *Worker) error {
+ if b {
+ w.family = "tcp6"
+ }
+ return nil
+ }
+}
+
+// SetUDPTimeout sets the number of seconds to wait for UDP connections
+func SetTCPTimeout(s int) Option {
+ return func(w *Worker) error {
+ w.tcpTimeout = s
+ return nil
+ }
+}
+
+// SetLogger sets the logger
+func SetLogger(l logger.Logger) Option {
+ return func(w *Worker) error {
+ w.log = l
+ return nil
+ }
+}
diff --git a/bittorrent/client.go b/bt/worker.go
index f4ac28e..6d247f4 100644
--- a/bittorrent/client.go
+++ b/bt/worker.go
@@ -1,20 +1,18 @@
-package dhtsearch
-
-// Lifted and adapted from github.com/shiyanhui/dht
+package bt
import (
"bytes"
- "crypto/sha1"
"encoding/binary"
- "encoding/hex"
"errors"
+ "fmt"
"io"
"io/ioutil"
"net"
- "strings"
"time"
"github.com/felix/dhtsearch/bencode"
+ "github.com/felix/dhtsearch/krpc"
+ "github.com/felix/dhtsearch/models"
"github.com/felix/logger"
)
@@ -24,13 +22,13 @@ const (
)
const (
- // MsgRequest represents request message type
+ // MsgRequest marks a request message type
MsgRequest = iota
- // MsgData represents data message type
+ // MsgData marks a data message type
MsgData
- // MsgReject represents reject message type
+ // MsgReject marks a reject message type
MsgReject
- // MsgExtended represents it is a extended message
+ // MsgExtended marks it as an extended message
MsgExtended = 20
)
@@ -48,55 +46,65 @@ var handshakePrefix = []byte{
111, 116, 111, 99, 111, 108, 0, 0, 0, 0, 0, 16, 0, 1,
}
-type btClient struct {
- pool chan chan peer
- log logger.Logger
+type Worker struct {
+ pool chan chan models.Peer
+ port int
+ family string
+ tcpTimeout int
+ OnNewTorrent func(t models.Torrent)
+ log logger.Logger
}
-func (bt *btClient) run(torrentCh chan<- *Torrent) error {
- peerCh := make(chan peer)
+func NewWorker(pool chan chan models.Peer, opts ...Option) (*Worker, error) {
+ var err error
+ w := &Worker{
+ pool: pool,
+ }
- if bt.log == nil {
- bt.log = logger.New(&logger.Options{
- Name: "bt",
- Level: logger.Info,
- })
+ // Set variadic options passed
+ for _, option := range opts {
+ err = option(w)
+ if err != nil {
+ return nil, err
+ }
}
+ return w, nil
+}
- go func() {
- for {
- // Signal we are ready for work
- bt.pool <- peerCh
+func (bt *Worker) Run() error {
+ peerCh := make(chan models.Peer)
- select {
- case p := <-peerCh:
- // Got work
- if len(p.id) != 20 {
- return
- }
- bt.log.Debug("fetching metadata", "peer", p.id)
- md, err := bt.fetchMetadata(p)
- if err != nil {
- bt.log.Error("failed to fetch metadata", "error", err)
- }
-
- t, err := decodeMetadata(p, md)
- if err != nil {
- bt.log.Error("failed to decode metadata", "error", err)
- }
- torrentCh <- t
+ for {
+ // Signal we are ready for work
+ bt.pool <- peerCh
+
+ select {
+ case p := <-peerCh:
+ // Got work
+ bt.log.Debug("worker got work", "peer", p)
+ md, err := bt.fetchMetadata(p)
+ if err != nil {
+ bt.log.Debug("failed to fetch metadata", "error", err)
+ continue
+ }
+ t, err := models.TorrentFromMetadata(p.Infohash, md)
+ if err != nil {
+ bt.log.Debug("failed to load torrent", "error", err)
+ continue
+ }
+ if bt.OnNewTorrent != nil {
+ go bt.OnNewTorrent(*t)
}
}
- }()
- return nil
+ }
}
// fetchMetadata fetchs medata info accroding to infohash from dht.
-func (bt *btClient) fetchMetadata(p peer) (out []byte, err error) {
+func (bt *Worker) fetchMetadata(p models.Peer) (out []byte, err error) {
var (
length int
msgType byte
- piecesNum int
+ totalPieces int
pieces [][]byte
utMetadata int
metadataSize int
@@ -107,26 +115,48 @@ func (bt *btClient) fetchMetadata(p peer) (out []byte, err error) {
recover()
}()
- infoHash := p.id
- address := p.address.String()
+ ll := bt.log.WithFields("address", p.Addr.String())
- dial, err := net.DialTimeout("tcp", address, time.Second*15)
+ ll.Debug("connecting")
+ dial, err := net.DialTimeout("tcp", p.Addr.String(), time.Second*15)
if err != nil {
return out, err
}
+ // Cast
conn := dial.(*net.TCPConn)
conn.SetLinger(0)
defer conn.Close()
+ ll.Debug("dialed")
data := bytes.NewBuffer(nil)
data.Grow(BlockSize)
+ ih := models.GenInfohash()
+
// TCP handshake
- if sendHandshake(conn, []byte(infoHash), []byte(genInfoHash())) != nil ||
- read(conn, 68, data) != nil ||
- onHandshake(data.Next(68)) != nil ||
- sendExtHandshake(conn) != nil {
- return
+ ll.Debug("sending handshake")
+ _, err = sendHandshake(conn, p.Infohash, ih)
+ if err != nil {
+ return nil, err
+ }
+
+ // Handle the handshake response
+ ll.Debug("handling handshake response")
+ err = read(conn, 68, data)
+ if err != nil {
+ return nil, err
+ }
+ next := data.Next(68)
+ ll.Debug("got next data")
+ if !(bytes.Equal(handshakePrefix[:20], next[:20]) && next[25]&0x10 != 0) {
+ ll.Debug("next data does not match", "next", next)
+ return nil, errors.New("invalid handshake response")
+ }
+
+ ll.Debug("sending ext handshake")
+ _, err = sendExtHandshake(conn)
+ if err != nil {
+ return nil, err
}
for {
@@ -166,13 +196,13 @@ func (bt *btClient) fetchMetadata(p peer) (out []byte, err error) {
return out, err
}
- piecesNum = metadataSize / BlockSize
+ totalPieces = metadataSize / BlockSize
if metadataSize%BlockSize != 0 {
- piecesNum++
+ totalPieces++
}
- pieces = make([][]byte, piecesNum)
- go bt.requestPieces(conn, utMetadata, metadataSize, piecesNum)
+ pieces = make([][]byte, totalPieces)
+ go bt.requestPieces(conn, utMetadata, metadataSize, totalPieces)
continue
}
@@ -181,40 +211,40 @@ func (bt *btClient) fetchMetadata(p peer) (out []byte, err error) {
return out, errors.New("no pieces found")
}
- d, index, err := bencode.DecodeDict(payload, 0)
+ dict, index, err := bencode.DecodeDict(payload, 0)
if err != nil {
return out, err
}
- dict := d.(map[string]interface{})
- err = parseKeys(dict, [][]string{{"msg_type", "int"}, {"piece", "int"}})
+ mt, err := krpc.GetInt(dict, "msg_type")
if err != nil {
return out, err
}
- if dict["msg_type"].(int) != MsgData {
+ if mt != MsgData {
continue
}
- piece := dict["piece"].(int)
+ piece, err := krpc.GetInt(dict, "piece")
+ if err != nil {
+ return out, err
+ }
+
pieceLen := length - 2 - index
- if (piece != piecesNum-1 && pieceLen != BlockSize) ||
- (piece == piecesNum-1 && pieceLen != metadataSize%BlockSize) {
- return out, errors.New("invalid piece count")
+ // Not last piece? should be full block
+ if totalPieces > 1 && piece != totalPieces-1 && pieceLen != BlockSize {
+ return out, fmt.Errorf("incomplete piece %d", piece)
+ }
+ // Last piece needs to equal remainder
+ if piece == totalPieces-1 && pieceLen != metadataSize%BlockSize {
+ return out, fmt.Errorf("incorrect final piece %d", piece)
}
pieces[piece] = payload[index:]
if bt.isDone(pieces) {
- metadataInfo := bytes.Join(pieces, nil)
-
- // Check the metadata
- info := sha1.Sum(metadataInfo)
- if !bytes.Equal([]byte(infoHash), info[:]) {
- return out, errors.New("metadata does not match infohash")
- }
- return metadataInfo, nil
+ return bytes.Join(pieces, nil), nil
}
default:
data.Reset()
@@ -222,50 +252,8 @@ func (bt *btClient) fetchMetadata(p peer) (out []byte, err error) {
}
}
-func decodeMetadata(p peer, md []byte) (*Torrent, error) {
- metadata, err := bencode.Decode(md)
- if err != nil {
- return nil, err
- }
- info := metadata.(map[string]interface{})
-
- if _, ok := info["name"]; !ok {
- return nil, errors.New("Metadata missing name")
- }
-
- bt := Torrent{
- InfoHash: hex.EncodeToString([]byte(p.id)),
- Name: info["name"].(string),
- }
-
- if v, ok := info["files"]; ok {
- files := v.([]interface{})
- bt.Files = make([]File, len(files))
-
- for i, item := range files {
- f := item.(map[string]interface{})
- paths := f["path"].([]interface{})
- path := make([]string, len(paths))
- for j, p := range paths {
- path[j] = p.(string)
- }
- fSize := f["length"].(int)
- bt.Files[i] = File{
- // Assume Unix path sep
- Path: strings.Join(path[:], "/"),
- Size: fSize,
- }
- // Ensure the torrent size totals all files'
- bt.Size = bt.Size + fSize
- }
- } else if _, ok := info["length"]; ok {
- bt.Size = info["length"].(int)
- }
- return &bt, nil
-}
-
// isDone checks if all pieces are complete
-func (bt *btClient) isDone(pieces [][]byte) bool {
+func (bt *Worker) isDone(pieces [][]byte) bool {
for _, piece := range pieces {
if len(piece) == 0 {
return false
@@ -275,9 +263,8 @@ func (bt *btClient) isDone(pieces [][]byte) bool {
}
// read reads size-length bytes from conn to data.
-func read(conn *net.TCPConn, size int, data *bytes.Buffer) error {
+func read(conn net.Conn, size int, data io.Writer) error {
conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout)))
-
n, err := io.CopyN(data, conn, int64(size))
if err != nil || n != int64(size) {
return errors.New("read error")
@@ -286,7 +273,7 @@ func read(conn *net.TCPConn, size int, data *bytes.Buffer) error {
}
// readMessage gets a message from the tcp connection.
-func readMessage(conn *net.TCPConn, data *bytes.Buffer) (length int, err error) {
+func readMessage(conn net.Conn, data *bytes.Buffer) (length int, err error) {
if err = read(conn, 4, data); err != nil {
return length, err
}
@@ -305,27 +292,25 @@ func readMessage(conn *net.TCPConn, data *bytes.Buffer) (length int, err error)
}
// sendMessage sends data to the connection.
-func sendMessage(conn *net.TCPConn, data []byte) error {
+func sendMessage(conn net.Conn, data []byte) (int, error) {
length := int32(len(data))
buffer := bytes.NewBuffer(nil)
binary.Write(buffer, binary.BigEndian, length)
conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout)))
- b, err := conn.Write(append(buffer.Bytes(), data...))
- return err
+ return conn.Write(append(buffer.Bytes(), data...))
}
// sendHandshake sends handshake message to conn.
-func sendHandshake(conn *net.TCPConn, infoHash, peerID []byte) error {
+func sendHandshake(conn net.Conn, ih, id models.Infohash) (int, error) {
data := make([]byte, 68)
copy(data[:28], handshakePrefix)
- copy(data[28:48], infoHash)
- copy(data[48:], peerID)
+ copy(data[28:48], []byte(ih))
+ copy(data[48:], []byte(id))
conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout)))
- b, err := conn.Write(data)
- return err
+ return conn.Write(data)
}
// onHandshake handles the handshake response.
@@ -337,43 +322,40 @@ func onHandshake(data []byte) (err error) {
}
// sendExtHandshake requests for the ut_metadata and metadata_size.
-func sendExtHandshake(conn *net.TCPConn) error {
- data := append(
- []byte{MsgExtended, HandshakeBit},
- bencode.Encode(map[string]interface{}{
- "m": map[string]interface{}{"ut_metadata": 1},
- })...,
- )
+func sendExtHandshake(conn net.Conn) (int, error) {
+ m, err := bencode.EncodeDict(map[string]interface{}{
+ "m": map[string]interface{}{"ut_metadata": 1},
+ })
+ if err != nil {
+ return 0, err
+ }
+ data := append([]byte{MsgExtended, HandshakeBit}, m...)
return sendMessage(conn, data)
}
// getUTMetaSize returns the ut_metadata and metadata_size.
func getUTMetaSize(data []byte) (utMetadata int, metadataSize int, err error) {
- v, err := bencode.Decode(data)
+ dict, _, err := bencode.DecodeDict(data, 0)
if err != nil {
return utMetadata, metadataSize, err
}
- dict, ok := v.(map[string]interface{})
- if !ok {
- return utMetadata, metadataSize, errors.New("invalid dict")
+ m, err := krpc.GetMap(dict, "m")
+ if err != nil {
+ return utMetadata, metadataSize, err
}
- err = parseKeys(dict, [][]string{{"metadata_size", "int"}, {"m", "map"}})
+ utMetadata, err = krpc.GetInt(m, "ut_metadata")
if err != nil {
return utMetadata, metadataSize, err
}
- m := dict["m"].(map[string]interface{})
- err = parseKey(m, "ut_metadata", "int")
+ metadataSize, err = krpc.GetInt(dict, "metadata_size")
if err != nil {
return utMetadata, metadataSize, err
}
- utMetadata = m["ut_metadata"].(int)
- metadataSize = dict["metadata_size"].(int)
-
if metadataSize > MaxMetadataSize {
err = errors.New("metadata_size too long")
}
@@ -381,13 +363,13 @@ func getUTMetaSize(data []byte) (utMetadata int, metadataSize int, err error) {
}
// Request more pieces
-func (bt *btClient) requestPieces(conn *net.TCPConn, utMetadata int, metadataSize int, piecesNum int) {
+func (bt *Worker) requestPieces(conn net.Conn, utMetadata int, metadataSize int, totalPieces int) {
buffer := make([]byte, 1024)
- for i := 0; i < piecesNum; i++ {
+ for i := 0; i < totalPieces; i++ {
buffer[0] = MsgExtended
buffer[1] = byte(utMetadata)
- msg := bencode.Encode(map[string]interface{}{
+ msg, _ := bencode.EncodeDict(map[string]interface{}{
"msg_type": MsgRequest,
"piece": i,
})