From 40732d76b81aa16ed38d855e93ce82387d29643f Mon Sep 17 00:00:00 2001 From: Felix Hanley Date: Mon, 26 Feb 2018 22:25:32 +1100 Subject: Move bt client --- bittorrent/client.go | 416 --------------------------------------------------- bt/options.go | 50 +++++++ bt/worker.go | 398 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 448 insertions(+), 416 deletions(-) delete mode 100644 bittorrent/client.go create mode 100644 bt/options.go create mode 100644 bt/worker.go diff --git a/bittorrent/client.go b/bittorrent/client.go deleted file mode 100644 index f4ac28e..0000000 --- a/bittorrent/client.go +++ /dev/null @@ -1,416 +0,0 @@ -package dhtsearch - -// Lifted and adapted from github.com/shiyanhui/dht - -import ( - "bytes" - "crypto/sha1" - "encoding/binary" - "encoding/hex" - "errors" - "io" - "io/ioutil" - "net" - "strings" - "time" - - "github.com/felix/dhtsearch/bencode" - "github.com/felix/logger" -) - -const ( - TCPTimeout = 5 - UDPTimeout = 5 -) - -const ( - // MsgRequest represents request message type - MsgRequest = iota - // MsgData represents data message type - MsgData - // MsgReject represents reject message type - MsgReject - // MsgExtended represents it is a extended message - MsgExtended = 20 -) - -const ( - // BlockSize is 2 ^ 14 - BlockSize = 16384 - // MaxMetadataSize represents the max medata it can accept - MaxMetadataSize = BlockSize * 1000 - // HandshakeBit represents handshake bit - HandshakeBit = 0 -) - -var handshakePrefix = []byte{ - 19, 66, 105, 116, 84, 111, 114, 114, 101, 110, 116, 32, 112, 114, - 111, 116, 111, 99, 111, 108, 0, 0, 0, 0, 0, 16, 0, 1, -} - -type btClient struct { - pool chan chan peer - log logger.Logger -} - -func (bt *btClient) run(torrentCh chan<- *Torrent) error { - peerCh := make(chan peer) - - if bt.log == nil { - bt.log = logger.New(&logger.Options{ - Name: "bt", - Level: logger.Info, - }) - } - - go func() { - for { - // Signal we are ready for work - bt.pool <- peerCh - - select { - case p := <-peerCh: - // Got work - if len(p.id) != 20 { - return - } - bt.log.Debug("fetching metadata", "peer", p.id) - md, err := bt.fetchMetadata(p) - if err != nil { - bt.log.Error("failed to fetch metadata", "error", err) - } - - t, err := decodeMetadata(p, md) - if err != nil { - bt.log.Error("failed to decode metadata", "error", err) - } - torrentCh <- t - } - } - }() - return nil -} - -// fetchMetadata fetchs medata info accroding to infohash from dht. -func (bt *btClient) fetchMetadata(p peer) (out []byte, err error) { - var ( - length int - msgType byte - piecesNum int - pieces [][]byte - utMetadata int - metadataSize int - ) - - defer func() { - pieces = nil - recover() - }() - - infoHash := p.id - address := p.address.String() - - dial, err := net.DialTimeout("tcp", address, time.Second*15) - if err != nil { - return out, err - } - conn := dial.(*net.TCPConn) - conn.SetLinger(0) - defer conn.Close() - - data := bytes.NewBuffer(nil) - data.Grow(BlockSize) - - // TCP handshake - if sendHandshake(conn, []byte(infoHash), []byte(genInfoHash())) != nil || - read(conn, 68, data) != nil || - onHandshake(data.Next(68)) != nil || - sendExtHandshake(conn) != nil { - return - } - - for { - length, err = readMessage(conn, data) - if err != nil { - return out, err - } - - if length == 0 { - continue - } - - msgType, err = data.ReadByte() - if err != nil { - return out, err - } - - switch msgType { - case MsgExtended: - extendedID, err := data.ReadByte() - if err != nil { - return out, err - } - - payload, err := ioutil.ReadAll(data) - if err != nil { - return out, err - } - - if extendedID == 0 { - if pieces != nil { - return out, errors.New("invalid extended ID") - } - - utMetadata, metadataSize, err = getUTMetaSize(payload) - if err != nil { - return out, err - } - - piecesNum = metadataSize / BlockSize - if metadataSize%BlockSize != 0 { - piecesNum++ - } - - pieces = make([][]byte, piecesNum) - go bt.requestPieces(conn, utMetadata, metadataSize, piecesNum) - - continue - } - - if pieces == nil { - return out, errors.New("no pieces found") - } - - d, index, err := bencode.DecodeDict(payload, 0) - if err != nil { - return out, err - } - dict := d.(map[string]interface{}) - - err = parseKeys(dict, [][]string{{"msg_type", "int"}, {"piece", "int"}}) - if err != nil { - return out, err - } - - if dict["msg_type"].(int) != MsgData { - continue - } - - piece := dict["piece"].(int) - pieceLen := length - 2 - index - - if (piece != piecesNum-1 && pieceLen != BlockSize) || - (piece == piecesNum-1 && pieceLen != metadataSize%BlockSize) { - return out, errors.New("invalid piece count") - } - - pieces[piece] = payload[index:] - - if bt.isDone(pieces) { - metadataInfo := bytes.Join(pieces, nil) - - // Check the metadata - info := sha1.Sum(metadataInfo) - if !bytes.Equal([]byte(infoHash), info[:]) { - return out, errors.New("metadata does not match infohash") - } - return metadataInfo, nil - } - default: - data.Reset() - } - } -} - -func decodeMetadata(p peer, md []byte) (*Torrent, error) { - metadata, err := bencode.Decode(md) - if err != nil { - return nil, err - } - info := metadata.(map[string]interface{}) - - if _, ok := info["name"]; !ok { - return nil, errors.New("Metadata missing name") - } - - bt := Torrent{ - InfoHash: hex.EncodeToString([]byte(p.id)), - Name: info["name"].(string), - } - - if v, ok := info["files"]; ok { - files := v.([]interface{}) - bt.Files = make([]File, len(files)) - - for i, item := range files { - f := item.(map[string]interface{}) - paths := f["path"].([]interface{}) - path := make([]string, len(paths)) - for j, p := range paths { - path[j] = p.(string) - } - fSize := f["length"].(int) - bt.Files[i] = File{ - // Assume Unix path sep - Path: strings.Join(path[:], "/"), - Size: fSize, - } - // Ensure the torrent size totals all files' - bt.Size = bt.Size + fSize - } - } else if _, ok := info["length"]; ok { - bt.Size = info["length"].(int) - } - return &bt, nil -} - -// isDone checks if all pieces are complete -func (bt *btClient) isDone(pieces [][]byte) bool { - for _, piece := range pieces { - if len(piece) == 0 { - return false - } - } - return true -} - -// read reads size-length bytes from conn to data. -func read(conn *net.TCPConn, size int, data *bytes.Buffer) error { - conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout))) - - n, err := io.CopyN(data, conn, int64(size)) - if err != nil || n != int64(size) { - return errors.New("read error") - } - return nil -} - -// readMessage gets a message from the tcp connection. -func readMessage(conn *net.TCPConn, data *bytes.Buffer) (length int, err error) { - if err = read(conn, 4, data); err != nil { - return length, err - } - - length, err = bytes2int(data.Next(4)) - if err != nil { - return length, err - } - - if length == 0 { - return length, nil - } - - err = read(conn, length, data) - return length, err -} - -// sendMessage sends data to the connection. -func sendMessage(conn *net.TCPConn, data []byte) error { - length := int32(len(data)) - - buffer := bytes.NewBuffer(nil) - binary.Write(buffer, binary.BigEndian, length) - - conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout))) - b, err := conn.Write(append(buffer.Bytes(), data...)) - return err -} - -// sendHandshake sends handshake message to conn. -func sendHandshake(conn *net.TCPConn, infoHash, peerID []byte) error { - data := make([]byte, 68) - copy(data[:28], handshakePrefix) - copy(data[28:48], infoHash) - copy(data[48:], peerID) - - conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout))) - b, err := conn.Write(data) - return err -} - -// onHandshake handles the handshake response. -func onHandshake(data []byte) (err error) { - if !(bytes.Equal(handshakePrefix[:20], data[:20]) && data[25]&0x10 != 0) { - err = errors.New("invalid handshake response") - } - return err -} - -// sendExtHandshake requests for the ut_metadata and metadata_size. -func sendExtHandshake(conn *net.TCPConn) error { - data := append( - []byte{MsgExtended, HandshakeBit}, - bencode.Encode(map[string]interface{}{ - "m": map[string]interface{}{"ut_metadata": 1}, - })..., - ) - - return sendMessage(conn, data) -} - -// getUTMetaSize returns the ut_metadata and metadata_size. -func getUTMetaSize(data []byte) (utMetadata int, metadataSize int, err error) { - v, err := bencode.Decode(data) - if err != nil { - return utMetadata, metadataSize, err - } - - dict, ok := v.(map[string]interface{}) - if !ok { - return utMetadata, metadataSize, errors.New("invalid dict") - } - - err = parseKeys(dict, [][]string{{"metadata_size", "int"}, {"m", "map"}}) - if err != nil { - return utMetadata, metadataSize, err - } - - m := dict["m"].(map[string]interface{}) - err = parseKey(m, "ut_metadata", "int") - if err != nil { - return utMetadata, metadataSize, err - } - - utMetadata = m["ut_metadata"].(int) - metadataSize = dict["metadata_size"].(int) - - if metadataSize > MaxMetadataSize { - err = errors.New("metadata_size too long") - } - return utMetadata, metadataSize, err -} - -// Request more pieces -func (bt *btClient) requestPieces(conn *net.TCPConn, utMetadata int, metadataSize int, piecesNum int) { - buffer := make([]byte, 1024) - for i := 0; i < piecesNum; i++ { - buffer[0] = MsgExtended - buffer[1] = byte(utMetadata) - - msg := bencode.Encode(map[string]interface{}{ - "msg_type": MsgRequest, - "piece": i, - }) - - length := len(msg) + 2 - copy(buffer[2:length], msg) - - sendMessage(conn, buffer[:length]) - } - buffer = nil -} - -// bytes2int returns the int value it represents. -func bytes2int(data []byte) (int, error) { - n := len(data) - if n > 8 { - return 0, errors.New("data too long") - } - - val := uint64(0) - - for i, b := range data { - val += uint64(b) << uint64((n-i-1)*8) - } - return int(val), nil -} diff --git a/bt/options.go b/bt/options.go new file mode 100644 index 0000000..25ebc21 --- /dev/null +++ b/bt/options.go @@ -0,0 +1,50 @@ +package bt + +import ( + "github.com/felix/dhtsearch/models" + "github.com/felix/logger" +) + +type Option func(*Worker) error + +// SetNewTorrent sets the callback +func SetOnNewTorrent(f func(models.Torrent)) Option { + return func(w *Worker) error { + w.OnNewTorrent = f + return nil + } +} + +// SetPort sets the port to listen on +func SetPort(p int) Option { + return func(w *Worker) error { + w.port = p + return nil + } +} + +// SetIPv6 enables IPv6 +func SetIPv6(b bool) Option { + return func(w *Worker) error { + if b { + w.family = "tcp6" + } + return nil + } +} + +// SetUDPTimeout sets the number of seconds to wait for UDP connections +func SetTCPTimeout(s int) Option { + return func(w *Worker) error { + w.tcpTimeout = s + return nil + } +} + +// SetLogger sets the logger +func SetLogger(l logger.Logger) Option { + return func(w *Worker) error { + w.log = l + return nil + } +} diff --git a/bt/worker.go b/bt/worker.go new file mode 100644 index 0000000..6d247f4 --- /dev/null +++ b/bt/worker.go @@ -0,0 +1,398 @@ +package bt + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "time" + + "github.com/felix/dhtsearch/bencode" + "github.com/felix/dhtsearch/krpc" + "github.com/felix/dhtsearch/models" + "github.com/felix/logger" +) + +const ( + TCPTimeout = 5 + UDPTimeout = 5 +) + +const ( + // MsgRequest marks a request message type + MsgRequest = iota + // MsgData marks a data message type + MsgData + // MsgReject marks a reject message type + MsgReject + // MsgExtended marks it as an extended message + MsgExtended = 20 +) + +const ( + // BlockSize is 2 ^ 14 + BlockSize = 16384 + // MaxMetadataSize represents the max medata it can accept + MaxMetadataSize = BlockSize * 1000 + // HandshakeBit represents handshake bit + HandshakeBit = 0 +) + +var handshakePrefix = []byte{ + 19, 66, 105, 116, 84, 111, 114, 114, 101, 110, 116, 32, 112, 114, + 111, 116, 111, 99, 111, 108, 0, 0, 0, 0, 0, 16, 0, 1, +} + +type Worker struct { + pool chan chan models.Peer + port int + family string + tcpTimeout int + OnNewTorrent func(t models.Torrent) + log logger.Logger +} + +func NewWorker(pool chan chan models.Peer, opts ...Option) (*Worker, error) { + var err error + w := &Worker{ + pool: pool, + } + + // Set variadic options passed + for _, option := range opts { + err = option(w) + if err != nil { + return nil, err + } + } + return w, nil +} + +func (bt *Worker) Run() error { + peerCh := make(chan models.Peer) + + for { + // Signal we are ready for work + bt.pool <- peerCh + + select { + case p := <-peerCh: + // Got work + bt.log.Debug("worker got work", "peer", p) + md, err := bt.fetchMetadata(p) + if err != nil { + bt.log.Debug("failed to fetch metadata", "error", err) + continue + } + t, err := models.TorrentFromMetadata(p.Infohash, md) + if err != nil { + bt.log.Debug("failed to load torrent", "error", err) + continue + } + if bt.OnNewTorrent != nil { + go bt.OnNewTorrent(*t) + } + } + } +} + +// fetchMetadata fetchs medata info accroding to infohash from dht. +func (bt *Worker) fetchMetadata(p models.Peer) (out []byte, err error) { + var ( + length int + msgType byte + totalPieces int + pieces [][]byte + utMetadata int + metadataSize int + ) + + defer func() { + pieces = nil + recover() + }() + + ll := bt.log.WithFields("address", p.Addr.String()) + + ll.Debug("connecting") + dial, err := net.DialTimeout("tcp", p.Addr.String(), time.Second*15) + if err != nil { + return out, err + } + // Cast + conn := dial.(*net.TCPConn) + conn.SetLinger(0) + defer conn.Close() + ll.Debug("dialed") + + data := bytes.NewBuffer(nil) + data.Grow(BlockSize) + + ih := models.GenInfohash() + + // TCP handshake + ll.Debug("sending handshake") + _, err = sendHandshake(conn, p.Infohash, ih) + if err != nil { + return nil, err + } + + // Handle the handshake response + ll.Debug("handling handshake response") + err = read(conn, 68, data) + if err != nil { + return nil, err + } + next := data.Next(68) + ll.Debug("got next data") + if !(bytes.Equal(handshakePrefix[:20], next[:20]) && next[25]&0x10 != 0) { + ll.Debug("next data does not match", "next", next) + return nil, errors.New("invalid handshake response") + } + + ll.Debug("sending ext handshake") + _, err = sendExtHandshake(conn) + if err != nil { + return nil, err + } + + for { + length, err = readMessage(conn, data) + if err != nil { + return out, err + } + + if length == 0 { + continue + } + + msgType, err = data.ReadByte() + if err != nil { + return out, err + } + + switch msgType { + case MsgExtended: + extendedID, err := data.ReadByte() + if err != nil { + return out, err + } + + payload, err := ioutil.ReadAll(data) + if err != nil { + return out, err + } + + if extendedID == 0 { + if pieces != nil { + return out, errors.New("invalid extended ID") + } + + utMetadata, metadataSize, err = getUTMetaSize(payload) + if err != nil { + return out, err + } + + totalPieces = metadataSize / BlockSize + if metadataSize%BlockSize != 0 { + totalPieces++ + } + + pieces = make([][]byte, totalPieces) + go bt.requestPieces(conn, utMetadata, metadataSize, totalPieces) + + continue + } + + if pieces == nil { + return out, errors.New("no pieces found") + } + + dict, index, err := bencode.DecodeDict(payload, 0) + if err != nil { + return out, err + } + + mt, err := krpc.GetInt(dict, "msg_type") + if err != nil { + return out, err + } + + if mt != MsgData { + continue + } + + piece, err := krpc.GetInt(dict, "piece") + if err != nil { + return out, err + } + + pieceLen := length - 2 - index + + // Not last piece? should be full block + if totalPieces > 1 && piece != totalPieces-1 && pieceLen != BlockSize { + return out, fmt.Errorf("incomplete piece %d", piece) + } + // Last piece needs to equal remainder + if piece == totalPieces-1 && pieceLen != metadataSize%BlockSize { + return out, fmt.Errorf("incorrect final piece %d", piece) + } + + pieces[piece] = payload[index:] + + if bt.isDone(pieces) { + return bytes.Join(pieces, nil), nil + } + default: + data.Reset() + } + } +} + +// isDone checks if all pieces are complete +func (bt *Worker) isDone(pieces [][]byte) bool { + for _, piece := range pieces { + if len(piece) == 0 { + return false + } + } + return true +} + +// read reads size-length bytes from conn to data. +func read(conn net.Conn, size int, data io.Writer) error { + conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout))) + n, err := io.CopyN(data, conn, int64(size)) + if err != nil || n != int64(size) { + return errors.New("read error") + } + return nil +} + +// readMessage gets a message from the tcp connection. +func readMessage(conn net.Conn, data *bytes.Buffer) (length int, err error) { + if err = read(conn, 4, data); err != nil { + return length, err + } + + length, err = bytes2int(data.Next(4)) + if err != nil { + return length, err + } + + if length == 0 { + return length, nil + } + + err = read(conn, length, data) + return length, err +} + +// sendMessage sends data to the connection. +func sendMessage(conn net.Conn, data []byte) (int, error) { + length := int32(len(data)) + + buffer := bytes.NewBuffer(nil) + binary.Write(buffer, binary.BigEndian, length) + + conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout))) + return conn.Write(append(buffer.Bytes(), data...)) +} + +// sendHandshake sends handshake message to conn. +func sendHandshake(conn net.Conn, ih, id models.Infohash) (int, error) { + data := make([]byte, 68) + copy(data[:28], handshakePrefix) + copy(data[28:48], []byte(ih)) + copy(data[48:], []byte(id)) + + conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout))) + return conn.Write(data) +} + +// onHandshake handles the handshake response. +func onHandshake(data []byte) (err error) { + if !(bytes.Equal(handshakePrefix[:20], data[:20]) && data[25]&0x10 != 0) { + err = errors.New("invalid handshake response") + } + return err +} + +// sendExtHandshake requests for the ut_metadata and metadata_size. +func sendExtHandshake(conn net.Conn) (int, error) { + m, err := bencode.EncodeDict(map[string]interface{}{ + "m": map[string]interface{}{"ut_metadata": 1}, + }) + if err != nil { + return 0, err + } + data := append([]byte{MsgExtended, HandshakeBit}, m...) + + return sendMessage(conn, data) +} + +// getUTMetaSize returns the ut_metadata and metadata_size. +func getUTMetaSize(data []byte) (utMetadata int, metadataSize int, err error) { + dict, _, err := bencode.DecodeDict(data, 0) + if err != nil { + return utMetadata, metadataSize, err + } + + m, err := krpc.GetMap(dict, "m") + if err != nil { + return utMetadata, metadataSize, err + } + + utMetadata, err = krpc.GetInt(m, "ut_metadata") + if err != nil { + return utMetadata, metadataSize, err + } + + metadataSize, err = krpc.GetInt(dict, "metadata_size") + if err != nil { + return utMetadata, metadataSize, err + } + + if metadataSize > MaxMetadataSize { + err = errors.New("metadata_size too long") + } + return utMetadata, metadataSize, err +} + +// Request more pieces +func (bt *Worker) requestPieces(conn net.Conn, utMetadata int, metadataSize int, totalPieces int) { + buffer := make([]byte, 1024) + for i := 0; i < totalPieces; i++ { + buffer[0] = MsgExtended + buffer[1] = byte(utMetadata) + + msg, _ := bencode.EncodeDict(map[string]interface{}{ + "msg_type": MsgRequest, + "piece": i, + }) + + length := len(msg) + 2 + copy(buffer[2:length], msg) + + sendMessage(conn, buffer[:length]) + } + buffer = nil +} + +// bytes2int returns the int value it represents. +func bytes2int(data []byte) (int, error) { + n := len(data) + if n > 8 { + return 0, errors.New("data too long") + } + + val := uint64(0) + + for i, b := range data { + val += uint64(b) << uint64((n-i-1)*8) + } + return int(val), nil +} -- cgit v1.2.3