diff options
| -rw-r--r-- | btclient.go | 393 | ||||
| -rw-r--r-- | tag.go | 2 |
2 files changed, 190 insertions, 205 deletions
diff --git a/btclient.go b/btclient.go index 8a87c72..3a33ffa 100644 --- a/btclient.go +++ b/btclient.go @@ -42,200 +42,52 @@ var handshakePrefix = []byte{ 111, 116, 111, 99, 111, 108, 0, 0, 0, 0, 0, 16, 0, 1, } -// Annouced peer -type peer struct { - address net.UDPAddr - id string -} - type btClient struct { - peersIn <-chan peer - torrentsOut chan<- Torrent - workerTokens chan struct{} + pool chan chan peer + log logger.Logger } -func newBTClient(r <-chan peer, t chan<- Torrent) *btClient { - return &btClient{ - peersIn: r, - torrentsOut: t, - workerTokens: make(chan struct{}, Config.Advanced.MaxBtWorkers), +func (bt *btClient) run(torrentCh chan<- *Torrent) error { + peerCh := make(chan peer) + + if bt.log == nil { + bt.log = logger.New(&logger.Options{ + Name: "bt", + Level: logger.Info, + }) } -} -func (bt *btClient) run(done <-chan struct{}) error { - var p peer go func() { for { + // Signal we are ready for work + bt.pool <- peerCh + select { - case <-done: - return - case p = <-bt.peersIn: - bt.workerTokens <- struct{}{} + case p := <-peerCh: + // Got work btWorkers.Add(1) + if len(p.id) != 20 { + return + } + bt.log.Debug("fetching metadata", "peer", p.id) + md, err := bt.fetchMetadata(p) + if err != nil { + bt.log.Error("failed to fetch metadata", "error", err) + } - go func(p peer) { - defer func() { - btWorkers.Add(-1) - <-bt.workerTokens - }() - - if len(p.id) != 20 { - return - } - - if Config.Debug { - fmt.Printf("Fetching metadata for %x\n", p.id) - } - bt.fetchMetadata(p) - }(p) + t, err := decodeMetadata(p, md) + if err != nil { + bt.log.Error("failed to decode metadata", "error", err) + } + torrentCh <- t } } }() return nil } -// isDone returns whether the wire get all pieces of the metadata info. -func (bt *btClient) isDone(pieces [][]byte) bool { - for _, piece := range pieces { - if len(piece) == 0 { - return false - } - } - return true -} - -// read reads size-length bytes from conn to data. -func read(conn *net.TCPConn, size int, data *bytes.Buffer) error { - conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(Config.Advanced.TcpTimeout))) - - n, err := io.CopyN(data, conn, int64(size)) - if err != nil || n != int64(size) { - return errors.New("read error") - } - btBytesIn.Add(n) - return nil -} - -// readMessage gets a message from the tcp connection. -func readMessage(conn *net.TCPConn, data *bytes.Buffer) (length int, err error) { - if err = read(conn, 4, data); err != nil { - return - } - - length = int(bytes2int(data.Next(4))) - if length == 0 { - return - } - - if err = read(conn, length, data); err != nil { - return - } - return -} - -// sendMessage sends data to the connection. -func sendMessage(conn *net.TCPConn, data []byte) error { - length := int32(len(data)) - - buffer := bytes.NewBuffer(nil) - binary.Write(buffer, binary.BigEndian, length) - - conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(Config.Advanced.TcpTimeout))) - b, err := conn.Write(append(buffer.Bytes(), data...)) - btBytesOut.Add(int64(b)) - return err -} - -// sendHandshake sends handshake message to conn. -func sendHandshake(conn *net.TCPConn, infoHash, peerID []byte) error { - data := make([]byte, 68) - copy(data[:28], handshakePrefix) - copy(data[28:48], infoHash) - copy(data[48:], peerID) - - conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(Config.Advanced.TcpTimeout))) - b, err := conn.Write(data) - btBytesOut.Add(int64(b)) - return err -} - -// onHandshake handles the handshake response. -func onHandshake(data []byte) (err error) { - if !(bytes.Equal(handshakePrefix[:20], data[:20]) && data[25]&0x10 != 0) { - err = errors.New("invalid handshake response") - } - return -} - -// sendExtHandshake requests for the ut_metadata and metadata_size. -func sendExtHandshake(conn *net.TCPConn) error { - data := append( - []byte{EXTENDED, HANDSHAKE}, - Encode(map[string]interface{}{ - "m": map[string]interface{}{"ut_metadata": 1}, - })..., - ) - - return sendMessage(conn, data) -} - -// getUTMetaSize returns the ut_metadata and metadata_size. -func getUTMetaSize(data []byte) ( - utMetadata int, metadataSize int, err error) { - - v, err := Decode(data) - if err != nil { - return - } - - dict, ok := v.(map[string]interface{}) - if !ok { - err = errors.New("invalid dict") - return - } - - if err = parseKeys( - dict, [][]string{{"metadata_size", "int"}, {"m", "map"}}); err != nil { - return - } - - m := dict["m"].(map[string]interface{}) - if err = parseKey(m, "ut_metadata", "int"); err != nil { - return - } - - utMetadata = m["ut_metadata"].(int) - metadataSize = dict["metadata_size"].(int) - - if metadataSize > MaxMetadataSize { - err = errors.New("metadata_size too long") - } - return -} - -func (bt *btClient) requestPieces( - conn *net.TCPConn, utMetadata int, metadataSize int, piecesNum int) { - - buffer := make([]byte, 1024) - for i := 0; i < piecesNum; i++ { - buffer[0] = EXTENDED - buffer[1] = byte(utMetadata) - - msg := Encode(map[string]interface{}{ - "msg_type": REQUEST, - "piece": i, - }) - - length := len(msg) + 2 - copy(buffer[2:length], msg) - - sendMessage(conn, buffer[:length]) - } - buffer = nil -} - // fetchMetadata fetchs medata info accroding to infohash from dht. -func (bt *btClient) fetchMetadata(p peer) { +func (bt *btClient) fetchMetadata(p peer) (out []byte, err error) { var ( length int msgType byte @@ -255,7 +107,7 @@ func (bt *btClient) fetchMetadata(p peer) { dial, err := net.DialTimeout("tcp", address, time.Second*15) if err != nil { - return + return out, err } conn := dial.(*net.TCPConn) conn.SetLinger(0) @@ -275,7 +127,7 @@ func (bt *btClient) fetchMetadata(p peer) { for { length, err = readMessage(conn, data) if err != nil { - return + return out, err } if length == 0 { @@ -284,29 +136,29 @@ func (bt *btClient) fetchMetadata(p peer) { msgType, err = data.ReadByte() if err != nil { - return + return out, err } switch msgType { case EXTENDED: extendedID, err := data.ReadByte() if err != nil { - return + return out, err } payload, err := ioutil.ReadAll(data) if err != nil { - return + return out, err } if extendedID == 0 { if pieces != nil { - return + return out, errors.New("invalid extended ID") } utMetadata, metadataSize, err = getUTMetaSize(payload) if err != nil { - return + return out, err } piecesNum = metadataSize / BLOCK @@ -321,19 +173,18 @@ func (bt *btClient) fetchMetadata(p peer) { } if pieces == nil { - return + return out, errors.New("no pieces found") } d, index, err := DecodeDict(payload, 0) if err != nil { - return + return out, err } dict := d.(map[string]interface{}) - if err = parseKeys(dict, [][]string{ - {"msg_type", "int"}, - {"piece", "int"}}); err != nil { - return + err = parseKeys(dict, [][]string{{"msg_type", "int"}, {"piece", "int"}}) + if err != nil { + return out, err } if dict["msg_type"].(int) != DATA { @@ -345,7 +196,7 @@ func (bt *btClient) fetchMetadata(p peer) { if (piece != piecesNum-1 && pieceLen != BLOCK) || (piece == piecesNum-1 && pieceLen != metadataSize%BLOCK) { - return + return out, errors.New("invalid piece count") } pieces[piece] = payload[index:] @@ -356,16 +207,9 @@ func (bt *btClient) fetchMetadata(p peer) { // Check the metadata info := sha1.Sum(metadataInfo) if !bytes.Equal([]byte(infoHash), info[:]) { - fmt.Println("Metadata does not match infohash") - return + return out, errors.New("metadata does not match infohash") } - - torrent, err := decodeMetadata(p, metadataInfo) - if err != nil { - return - } - bt.torrentsOut <- *torrent - return + return metadataInfo, nil } default: data.Reset() @@ -415,15 +259,156 @@ func decodeMetadata(p peer, md []byte) (*Torrent, error) { return &bt, nil } +// isDone checks if all pieces are complete +func (bt *btClient) isDone(pieces [][]byte) bool { + for _, piece := range pieces { + if len(piece) == 0 { + return false + } + } + return true +} + +// read reads size-length bytes from conn to data. +func read(conn *net.TCPConn, size int, data *bytes.Buffer) error { + conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout))) + + n, err := io.CopyN(data, conn, int64(size)) + if err != nil || n != int64(size) { + return errors.New("read error") + } + btBytesIn.Add(n) + return nil +} + +// readMessage gets a message from the tcp connection. +func readMessage(conn *net.TCPConn, data *bytes.Buffer) (length int, err error) { + if err = read(conn, 4, data); err != nil { + return length, err + } + + length, err = bytes2int(data.Next(4)) + if err != nil { + return length, err + } + + if length == 0 { + return length, nil + } + + err = read(conn, length, data) + return length, err +} + +// sendMessage sends data to the connection. +func sendMessage(conn *net.TCPConn, data []byte) error { + length := int32(len(data)) + + buffer := bytes.NewBuffer(nil) + binary.Write(buffer, binary.BigEndian, length) + + conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout))) + b, err := conn.Write(append(buffer.Bytes(), data...)) + btBytesOut.Add(int64(b)) + return err +} + +// sendHandshake sends handshake message to conn. +func sendHandshake(conn *net.TCPConn, infoHash, peerID []byte) error { + data := make([]byte, 68) + copy(data[:28], handshakePrefix) + copy(data[28:48], infoHash) + copy(data[48:], peerID) + + conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout))) + b, err := conn.Write(data) + btBytesOut.Add(int64(b)) + return err +} + +// onHandshake handles the handshake response. +func onHandshake(data []byte) (err error) { + if !(bytes.Equal(handshakePrefix[:20], data[:20]) && data[25]&0x10 != 0) { + err = errors.New("invalid handshake response") + } + return err +} + +// sendExtHandshake requests for the ut_metadata and metadata_size. +func sendExtHandshake(conn *net.TCPConn) error { + data := append( + []byte{EXTENDED, HANDSHAKE}, + Encode(map[string]interface{}{ + "m": map[string]interface{}{"ut_metadata": 1}, + })..., + ) + + return sendMessage(conn, data) +} + +// getUTMetaSize returns the ut_metadata and metadata_size. +func getUTMetaSize(data []byte) (utMetadata int, metadataSize int, err error) { + v, err := Decode(data) + if err != nil { + return utMetadata, metadataSize, err + } + + dict, ok := v.(map[string]interface{}) + if !ok { + return utMetadata, metadataSize, errors.New("invalid dict") + } + + err = parseKeys(dict, [][]string{{"metadata_size", "int"}, {"m", "map"}}) + if err != nil { + return utMetadata, metadataSize, err + } + + m := dict["m"].(map[string]interface{}) + err = parseKey(m, "ut_metadata", "int") + if err != nil { + return utMetadata, metadataSize, err + } + + utMetadata = m["ut_metadata"].(int) + metadataSize = dict["metadata_size"].(int) + + if metadataSize > MaxMetadataSize { + err = errors.New("metadata_size too long") + } + return utMetadata, metadataSize, err +} + +// Request more pieces +func (bt *btClient) requestPieces(conn *net.TCPConn, utMetadata int, metadataSize int, piecesNum int) { + buffer := make([]byte, 1024) + for i := 0; i < piecesNum; i++ { + buffer[0] = EXTENDED + buffer[1] = byte(utMetadata) + + msg := Encode(map[string]interface{}{ + "msg_type": REQUEST, + "piece": i, + }) + + length := len(msg) + 2 + copy(buffer[2:length], msg) + + sendMessage(conn, buffer[:length]) + } + buffer = nil +} + // bytes2int returns the int value it represents. -func bytes2int(data []byte) uint64 { - n, val := len(data), uint64(0) +func bytes2int(data []byte) (int, error) { + n := len(data) if n > 8 { - panic("data too long") + return 0, errors.New("data too long") } + val := uint64(0) + for i, b := range data { val += uint64(b) << uint64((n-i-1)*8) } - return val + return int(val), nil } @@ -1,4 +1,4 @@ -package dhtsearch +package main import ( "fmt" |
