diff options
Diffstat (limited to 'bt/worker.go')
| -rw-r--r-- | bt/worker.go | 398 |
1 files changed, 398 insertions, 0 deletions
diff --git a/bt/worker.go b/bt/worker.go new file mode 100644 index 0000000..6d247f4 --- /dev/null +++ b/bt/worker.go @@ -0,0 +1,398 @@ +package bt + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "time" + + "github.com/felix/dhtsearch/bencode" + "github.com/felix/dhtsearch/krpc" + "github.com/felix/dhtsearch/models" + "github.com/felix/logger" +) + +const ( + TCPTimeout = 5 + UDPTimeout = 5 +) + +const ( + // MsgRequest marks a request message type + MsgRequest = iota + // MsgData marks a data message type + MsgData + // MsgReject marks a reject message type + MsgReject + // MsgExtended marks it as an extended message + MsgExtended = 20 +) + +const ( + // BlockSize is 2 ^ 14 + BlockSize = 16384 + // MaxMetadataSize represents the max medata it can accept + MaxMetadataSize = BlockSize * 1000 + // HandshakeBit represents handshake bit + HandshakeBit = 0 +) + +var handshakePrefix = []byte{ + 19, 66, 105, 116, 84, 111, 114, 114, 101, 110, 116, 32, 112, 114, + 111, 116, 111, 99, 111, 108, 0, 0, 0, 0, 0, 16, 0, 1, +} + +type Worker struct { + pool chan chan models.Peer + port int + family string + tcpTimeout int + OnNewTorrent func(t models.Torrent) + log logger.Logger +} + +func NewWorker(pool chan chan models.Peer, opts ...Option) (*Worker, error) { + var err error + w := &Worker{ + pool: pool, + } + + // Set variadic options passed + for _, option := range opts { + err = option(w) + if err != nil { + return nil, err + } + } + return w, nil +} + +func (bt *Worker) Run() error { + peerCh := make(chan models.Peer) + + for { + // Signal we are ready for work + bt.pool <- peerCh + + select { + case p := <-peerCh: + // Got work + bt.log.Debug("worker got work", "peer", p) + md, err := bt.fetchMetadata(p) + if err != nil { + bt.log.Debug("failed to fetch metadata", "error", err) + continue + } + t, err := models.TorrentFromMetadata(p.Infohash, md) + if err != nil { + bt.log.Debug("failed to load torrent", "error", err) + continue + } + if bt.OnNewTorrent != nil { + go bt.OnNewTorrent(*t) + } + } + } +} + +// fetchMetadata fetchs medata info accroding to infohash from dht. +func (bt *Worker) fetchMetadata(p models.Peer) (out []byte, err error) { + var ( + length int + msgType byte + totalPieces int + pieces [][]byte + utMetadata int + metadataSize int + ) + + defer func() { + pieces = nil + recover() + }() + + ll := bt.log.WithFields("address", p.Addr.String()) + + ll.Debug("connecting") + dial, err := net.DialTimeout("tcp", p.Addr.String(), time.Second*15) + if err != nil { + return out, err + } + // Cast + conn := dial.(*net.TCPConn) + conn.SetLinger(0) + defer conn.Close() + ll.Debug("dialed") + + data := bytes.NewBuffer(nil) + data.Grow(BlockSize) + + ih := models.GenInfohash() + + // TCP handshake + ll.Debug("sending handshake") + _, err = sendHandshake(conn, p.Infohash, ih) + if err != nil { + return nil, err + } + + // Handle the handshake response + ll.Debug("handling handshake response") + err = read(conn, 68, data) + if err != nil { + return nil, err + } + next := data.Next(68) + ll.Debug("got next data") + if !(bytes.Equal(handshakePrefix[:20], next[:20]) && next[25]&0x10 != 0) { + ll.Debug("next data does not match", "next", next) + return nil, errors.New("invalid handshake response") + } + + ll.Debug("sending ext handshake") + _, err = sendExtHandshake(conn) + if err != nil { + return nil, err + } + + for { + length, err = readMessage(conn, data) + if err != nil { + return out, err + } + + if length == 0 { + continue + } + + msgType, err = data.ReadByte() + if err != nil { + return out, err + } + + switch msgType { + case MsgExtended: + extendedID, err := data.ReadByte() + if err != nil { + return out, err + } + + payload, err := ioutil.ReadAll(data) + if err != nil { + return out, err + } + + if extendedID == 0 { + if pieces != nil { + return out, errors.New("invalid extended ID") + } + + utMetadata, metadataSize, err = getUTMetaSize(payload) + if err != nil { + return out, err + } + + totalPieces = metadataSize / BlockSize + if metadataSize%BlockSize != 0 { + totalPieces++ + } + + pieces = make([][]byte, totalPieces) + go bt.requestPieces(conn, utMetadata, metadataSize, totalPieces) + + continue + } + + if pieces == nil { + return out, errors.New("no pieces found") + } + + dict, index, err := bencode.DecodeDict(payload, 0) + if err != nil { + return out, err + } + + mt, err := krpc.GetInt(dict, "msg_type") + if err != nil { + return out, err + } + + if mt != MsgData { + continue + } + + piece, err := krpc.GetInt(dict, "piece") + if err != nil { + return out, err + } + + pieceLen := length - 2 - index + + // Not last piece? should be full block + if totalPieces > 1 && piece != totalPieces-1 && pieceLen != BlockSize { + return out, fmt.Errorf("incomplete piece %d", piece) + } + // Last piece needs to equal remainder + if piece == totalPieces-1 && pieceLen != metadataSize%BlockSize { + return out, fmt.Errorf("incorrect final piece %d", piece) + } + + pieces[piece] = payload[index:] + + if bt.isDone(pieces) { + return bytes.Join(pieces, nil), nil + } + default: + data.Reset() + } + } +} + +// isDone checks if all pieces are complete +func (bt *Worker) isDone(pieces [][]byte) bool { + for _, piece := range pieces { + if len(piece) == 0 { + return false + } + } + return true +} + +// read reads size-length bytes from conn to data. +func read(conn net.Conn, size int, data io.Writer) error { + conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout))) + n, err := io.CopyN(data, conn, int64(size)) + if err != nil || n != int64(size) { + return errors.New("read error") + } + return nil +} + +// readMessage gets a message from the tcp connection. +func readMessage(conn net.Conn, data *bytes.Buffer) (length int, err error) { + if err = read(conn, 4, data); err != nil { + return length, err + } + + length, err = bytes2int(data.Next(4)) + if err != nil { + return length, err + } + + if length == 0 { + return length, nil + } + + err = read(conn, length, data) + return length, err +} + +// sendMessage sends data to the connection. +func sendMessage(conn net.Conn, data []byte) (int, error) { + length := int32(len(data)) + + buffer := bytes.NewBuffer(nil) + binary.Write(buffer, binary.BigEndian, length) + + conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout))) + return conn.Write(append(buffer.Bytes(), data...)) +} + +// sendHandshake sends handshake message to conn. +func sendHandshake(conn net.Conn, ih, id models.Infohash) (int, error) { + data := make([]byte, 68) + copy(data[:28], handshakePrefix) + copy(data[28:48], []byte(ih)) + copy(data[48:], []byte(id)) + + conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout))) + return conn.Write(data) +} + +// onHandshake handles the handshake response. +func onHandshake(data []byte) (err error) { + if !(bytes.Equal(handshakePrefix[:20], data[:20]) && data[25]&0x10 != 0) { + err = errors.New("invalid handshake response") + } + return err +} + +// sendExtHandshake requests for the ut_metadata and metadata_size. +func sendExtHandshake(conn net.Conn) (int, error) { + m, err := bencode.EncodeDict(map[string]interface{}{ + "m": map[string]interface{}{"ut_metadata": 1}, + }) + if err != nil { + return 0, err + } + data := append([]byte{MsgExtended, HandshakeBit}, m...) + + return sendMessage(conn, data) +} + +// getUTMetaSize returns the ut_metadata and metadata_size. +func getUTMetaSize(data []byte) (utMetadata int, metadataSize int, err error) { + dict, _, err := bencode.DecodeDict(data, 0) + if err != nil { + return utMetadata, metadataSize, err + } + + m, err := krpc.GetMap(dict, "m") + if err != nil { + return utMetadata, metadataSize, err + } + + utMetadata, err = krpc.GetInt(m, "ut_metadata") + if err != nil { + return utMetadata, metadataSize, err + } + + metadataSize, err = krpc.GetInt(dict, "metadata_size") + if err != nil { + return utMetadata, metadataSize, err + } + + if metadataSize > MaxMetadataSize { + err = errors.New("metadata_size too long") + } + return utMetadata, metadataSize, err +} + +// Request more pieces +func (bt *Worker) requestPieces(conn net.Conn, utMetadata int, metadataSize int, totalPieces int) { + buffer := make([]byte, 1024) + for i := 0; i < totalPieces; i++ { + buffer[0] = MsgExtended + buffer[1] = byte(utMetadata) + + msg, _ := bencode.EncodeDict(map[string]interface{}{ + "msg_type": MsgRequest, + "piece": i, + }) + + length := len(msg) + 2 + copy(buffer[2:length], msg) + + sendMessage(conn, buffer[:length]) + } + buffer = nil +} + +// bytes2int returns the int value it represents. +func bytes2int(data []byte) (int, error) { + n := len(data) + if n > 8 { + return 0, errors.New("data too long") + } + + val := uint64(0) + + for i, b := range data { + val += uint64(b) << uint64((n-i-1)*8) + } + return int(val), nil +} |
