diff options
| author | Felix Hanley <felix@userspace.com.au> | 2018-02-26 11:25:32 +0000 |
|---|---|---|
| committer | Felix Hanley <felix@userspace.com.au> | 2018-02-26 11:25:32 +0000 |
| commit | 40732d76b81aa16ed38d855e93ce82387d29643f (patch) | |
| tree | 9ce244a2635fdf549291d059a9bcc542283606cb | |
| parent | 5921a5c3a7829359703fa16c29ae6520408ef47c (diff) | |
| download | dhtsearch-40732d76b81aa16ed38d855e93ce82387d29643f.tar.gz dhtsearch-40732d76b81aa16ed38d855e93ce82387d29643f.tar.bz2 | |
Move bt client
| -rw-r--r-- | bt/options.go | 50 | ||||
| -rw-r--r-- | bt/worker.go (renamed from bittorrent/client.go) | 276 |
2 files changed, 179 insertions, 147 deletions
diff --git a/bt/options.go b/bt/options.go new file mode 100644 index 0000000..25ebc21 --- /dev/null +++ b/bt/options.go @@ -0,0 +1,50 @@ +package bt + +import ( + "github.com/felix/dhtsearch/models" + "github.com/felix/logger" +) + +type Option func(*Worker) error + +// SetNewTorrent sets the callback +func SetOnNewTorrent(f func(models.Torrent)) Option { + return func(w *Worker) error { + w.OnNewTorrent = f + return nil + } +} + +// SetPort sets the port to listen on +func SetPort(p int) Option { + return func(w *Worker) error { + w.port = p + return nil + } +} + +// SetIPv6 enables IPv6 +func SetIPv6(b bool) Option { + return func(w *Worker) error { + if b { + w.family = "tcp6" + } + return nil + } +} + +// SetUDPTimeout sets the number of seconds to wait for UDP connections +func SetTCPTimeout(s int) Option { + return func(w *Worker) error { + w.tcpTimeout = s + return nil + } +} + +// SetLogger sets the logger +func SetLogger(l logger.Logger) Option { + return func(w *Worker) error { + w.log = l + return nil + } +} diff --git a/bittorrent/client.go b/bt/worker.go index f4ac28e..6d247f4 100644 --- a/bittorrent/client.go +++ b/bt/worker.go @@ -1,20 +1,18 @@ -package dhtsearch - -// Lifted and adapted from github.com/shiyanhui/dht +package bt import ( "bytes" - "crypto/sha1" "encoding/binary" - "encoding/hex" "errors" + "fmt" "io" "io/ioutil" "net" - "strings" "time" "github.com/felix/dhtsearch/bencode" + "github.com/felix/dhtsearch/krpc" + "github.com/felix/dhtsearch/models" "github.com/felix/logger" ) @@ -24,13 +22,13 @@ const ( ) const ( - // MsgRequest represents request message type + // MsgRequest marks a request message type MsgRequest = iota - // MsgData represents data message type + // MsgData marks a data message type MsgData - // MsgReject represents reject message type + // MsgReject marks a reject message type MsgReject - // MsgExtended represents it is a extended message + // MsgExtended marks it as an extended message MsgExtended = 20 ) @@ -48,55 +46,65 @@ var handshakePrefix = []byte{ 111, 116, 111, 99, 111, 108, 0, 0, 0, 0, 0, 16, 0, 1, } -type btClient struct { - pool chan chan peer - log logger.Logger +type Worker struct { + pool chan chan models.Peer + port int + family string + tcpTimeout int + OnNewTorrent func(t models.Torrent) + log logger.Logger } -func (bt *btClient) run(torrentCh chan<- *Torrent) error { - peerCh := make(chan peer) +func NewWorker(pool chan chan models.Peer, opts ...Option) (*Worker, error) { + var err error + w := &Worker{ + pool: pool, + } - if bt.log == nil { - bt.log = logger.New(&logger.Options{ - Name: "bt", - Level: logger.Info, - }) + // Set variadic options passed + for _, option := range opts { + err = option(w) + if err != nil { + return nil, err + } } + return w, nil +} - go func() { - for { - // Signal we are ready for work - bt.pool <- peerCh +func (bt *Worker) Run() error { + peerCh := make(chan models.Peer) - select { - case p := <-peerCh: - // Got work - if len(p.id) != 20 { - return - } - bt.log.Debug("fetching metadata", "peer", p.id) - md, err := bt.fetchMetadata(p) - if err != nil { - bt.log.Error("failed to fetch metadata", "error", err) - } - - t, err := decodeMetadata(p, md) - if err != nil { - bt.log.Error("failed to decode metadata", "error", err) - } - torrentCh <- t + for { + // Signal we are ready for work + bt.pool <- peerCh + + select { + case p := <-peerCh: + // Got work + bt.log.Debug("worker got work", "peer", p) + md, err := bt.fetchMetadata(p) + if err != nil { + bt.log.Debug("failed to fetch metadata", "error", err) + continue + } + t, err := models.TorrentFromMetadata(p.Infohash, md) + if err != nil { + bt.log.Debug("failed to load torrent", "error", err) + continue + } + if bt.OnNewTorrent != nil { + go bt.OnNewTorrent(*t) } } - }() - return nil + } } // fetchMetadata fetchs medata info accroding to infohash from dht. -func (bt *btClient) fetchMetadata(p peer) (out []byte, err error) { +func (bt *Worker) fetchMetadata(p models.Peer) (out []byte, err error) { var ( length int msgType byte - piecesNum int + totalPieces int pieces [][]byte utMetadata int metadataSize int @@ -107,26 +115,48 @@ func (bt *btClient) fetchMetadata(p peer) (out []byte, err error) { recover() }() - infoHash := p.id - address := p.address.String() + ll := bt.log.WithFields("address", p.Addr.String()) - dial, err := net.DialTimeout("tcp", address, time.Second*15) + ll.Debug("connecting") + dial, err := net.DialTimeout("tcp", p.Addr.String(), time.Second*15) if err != nil { return out, err } + // Cast conn := dial.(*net.TCPConn) conn.SetLinger(0) defer conn.Close() + ll.Debug("dialed") data := bytes.NewBuffer(nil) data.Grow(BlockSize) + ih := models.GenInfohash() + // TCP handshake - if sendHandshake(conn, []byte(infoHash), []byte(genInfoHash())) != nil || - read(conn, 68, data) != nil || - onHandshake(data.Next(68)) != nil || - sendExtHandshake(conn) != nil { - return + ll.Debug("sending handshake") + _, err = sendHandshake(conn, p.Infohash, ih) + if err != nil { + return nil, err + } + + // Handle the handshake response + ll.Debug("handling handshake response") + err = read(conn, 68, data) + if err != nil { + return nil, err + } + next := data.Next(68) + ll.Debug("got next data") + if !(bytes.Equal(handshakePrefix[:20], next[:20]) && next[25]&0x10 != 0) { + ll.Debug("next data does not match", "next", next) + return nil, errors.New("invalid handshake response") + } + + ll.Debug("sending ext handshake") + _, err = sendExtHandshake(conn) + if err != nil { + return nil, err } for { @@ -166,13 +196,13 @@ func (bt *btClient) fetchMetadata(p peer) (out []byte, err error) { return out, err } - piecesNum = metadataSize / BlockSize + totalPieces = metadataSize / BlockSize if metadataSize%BlockSize != 0 { - piecesNum++ + totalPieces++ } - pieces = make([][]byte, piecesNum) - go bt.requestPieces(conn, utMetadata, metadataSize, piecesNum) + pieces = make([][]byte, totalPieces) + go bt.requestPieces(conn, utMetadata, metadataSize, totalPieces) continue } @@ -181,40 +211,40 @@ func (bt *btClient) fetchMetadata(p peer) (out []byte, err error) { return out, errors.New("no pieces found") } - d, index, err := bencode.DecodeDict(payload, 0) + dict, index, err := bencode.DecodeDict(payload, 0) if err != nil { return out, err } - dict := d.(map[string]interface{}) - err = parseKeys(dict, [][]string{{"msg_type", "int"}, {"piece", "int"}}) + mt, err := krpc.GetInt(dict, "msg_type") if err != nil { return out, err } - if dict["msg_type"].(int) != MsgData { + if mt != MsgData { continue } - piece := dict["piece"].(int) + piece, err := krpc.GetInt(dict, "piece") + if err != nil { + return out, err + } + pieceLen := length - 2 - index - if (piece != piecesNum-1 && pieceLen != BlockSize) || - (piece == piecesNum-1 && pieceLen != metadataSize%BlockSize) { - return out, errors.New("invalid piece count") + // Not last piece? should be full block + if totalPieces > 1 && piece != totalPieces-1 && pieceLen != BlockSize { + return out, fmt.Errorf("incomplete piece %d", piece) + } + // Last piece needs to equal remainder + if piece == totalPieces-1 && pieceLen != metadataSize%BlockSize { + return out, fmt.Errorf("incorrect final piece %d", piece) } pieces[piece] = payload[index:] if bt.isDone(pieces) { - metadataInfo := bytes.Join(pieces, nil) - - // Check the metadata - info := sha1.Sum(metadataInfo) - if !bytes.Equal([]byte(infoHash), info[:]) { - return out, errors.New("metadata does not match infohash") - } - return metadataInfo, nil + return bytes.Join(pieces, nil), nil } default: data.Reset() @@ -222,50 +252,8 @@ func (bt *btClient) fetchMetadata(p peer) (out []byte, err error) { } } -func decodeMetadata(p peer, md []byte) (*Torrent, error) { - metadata, err := bencode.Decode(md) - if err != nil { - return nil, err - } - info := metadata.(map[string]interface{}) - - if _, ok := info["name"]; !ok { - return nil, errors.New("Metadata missing name") - } - - bt := Torrent{ - InfoHash: hex.EncodeToString([]byte(p.id)), - Name: info["name"].(string), - } - - if v, ok := info["files"]; ok { - files := v.([]interface{}) - bt.Files = make([]File, len(files)) - - for i, item := range files { - f := item.(map[string]interface{}) - paths := f["path"].([]interface{}) - path := make([]string, len(paths)) - for j, p := range paths { - path[j] = p.(string) - } - fSize := f["length"].(int) - bt.Files[i] = File{ - // Assume Unix path sep - Path: strings.Join(path[:], "/"), - Size: fSize, - } - // Ensure the torrent size totals all files' - bt.Size = bt.Size + fSize - } - } else if _, ok := info["length"]; ok { - bt.Size = info["length"].(int) - } - return &bt, nil -} - // isDone checks if all pieces are complete -func (bt *btClient) isDone(pieces [][]byte) bool { +func (bt *Worker) isDone(pieces [][]byte) bool { for _, piece := range pieces { if len(piece) == 0 { return false @@ -275,9 +263,8 @@ func (bt *btClient) isDone(pieces [][]byte) bool { } // read reads size-length bytes from conn to data. -func read(conn *net.TCPConn, size int, data *bytes.Buffer) error { +func read(conn net.Conn, size int, data io.Writer) error { conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout))) - n, err := io.CopyN(data, conn, int64(size)) if err != nil || n != int64(size) { return errors.New("read error") @@ -286,7 +273,7 @@ func read(conn *net.TCPConn, size int, data *bytes.Buffer) error { } // readMessage gets a message from the tcp connection. -func readMessage(conn *net.TCPConn, data *bytes.Buffer) (length int, err error) { +func readMessage(conn net.Conn, data *bytes.Buffer) (length int, err error) { if err = read(conn, 4, data); err != nil { return length, err } @@ -305,27 +292,25 @@ func readMessage(conn *net.TCPConn, data *bytes.Buffer) (length int, err error) } // sendMessage sends data to the connection. -func sendMessage(conn *net.TCPConn, data []byte) error { +func sendMessage(conn net.Conn, data []byte) (int, error) { length := int32(len(data)) buffer := bytes.NewBuffer(nil) binary.Write(buffer, binary.BigEndian, length) conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout))) - b, err := conn.Write(append(buffer.Bytes(), data...)) - return err + return conn.Write(append(buffer.Bytes(), data...)) } // sendHandshake sends handshake message to conn. -func sendHandshake(conn *net.TCPConn, infoHash, peerID []byte) error { +func sendHandshake(conn net.Conn, ih, id models.Infohash) (int, error) { data := make([]byte, 68) copy(data[:28], handshakePrefix) - copy(data[28:48], infoHash) - copy(data[48:], peerID) + copy(data[28:48], []byte(ih)) + copy(data[48:], []byte(id)) conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout))) - b, err := conn.Write(data) - return err + return conn.Write(data) } // onHandshake handles the handshake response. @@ -337,43 +322,40 @@ func onHandshake(data []byte) (err error) { } // sendExtHandshake requests for the ut_metadata and metadata_size. -func sendExtHandshake(conn *net.TCPConn) error { - data := append( - []byte{MsgExtended, HandshakeBit}, - bencode.Encode(map[string]interface{}{ - "m": map[string]interface{}{"ut_metadata": 1}, - })..., - ) +func sendExtHandshake(conn net.Conn) (int, error) { + m, err := bencode.EncodeDict(map[string]interface{}{ + "m": map[string]interface{}{"ut_metadata": 1}, + }) + if err != nil { + return 0, err + } + data := append([]byte{MsgExtended, HandshakeBit}, m...) return sendMessage(conn, data) } // getUTMetaSize returns the ut_metadata and metadata_size. func getUTMetaSize(data []byte) (utMetadata int, metadataSize int, err error) { - v, err := bencode.Decode(data) + dict, _, err := bencode.DecodeDict(data, 0) if err != nil { return utMetadata, metadataSize, err } - dict, ok := v.(map[string]interface{}) - if !ok { - return utMetadata, metadataSize, errors.New("invalid dict") + m, err := krpc.GetMap(dict, "m") + if err != nil { + return utMetadata, metadataSize, err } - err = parseKeys(dict, [][]string{{"metadata_size", "int"}, {"m", "map"}}) + utMetadata, err = krpc.GetInt(m, "ut_metadata") if err != nil { return utMetadata, metadataSize, err } - m := dict["m"].(map[string]interface{}) - err = parseKey(m, "ut_metadata", "int") + metadataSize, err = krpc.GetInt(dict, "metadata_size") if err != nil { return utMetadata, metadataSize, err } - utMetadata = m["ut_metadata"].(int) - metadataSize = dict["metadata_size"].(int) - if metadataSize > MaxMetadataSize { err = errors.New("metadata_size too long") } @@ -381,13 +363,13 @@ func getUTMetaSize(data []byte) (utMetadata int, metadataSize int, err error) { } // Request more pieces -func (bt *btClient) requestPieces(conn *net.TCPConn, utMetadata int, metadataSize int, piecesNum int) { +func (bt *Worker) requestPieces(conn net.Conn, utMetadata int, metadataSize int, totalPieces int) { buffer := make([]byte, 1024) - for i := 0; i < piecesNum; i++ { + for i := 0; i < totalPieces; i++ { buffer[0] = MsgExtended buffer[1] = byte(utMetadata) - msg := bencode.Encode(map[string]interface{}{ + msg, _ := bencode.EncodeDict(map[string]interface{}{ "msg_type": MsgRequest, "piece": i, }) |
