diff options
| -rw-r--r-- | Makefile | 15 | ||||
| -rw-r--r-- | bt/messages.go | 99 | ||||
| -rw-r--r-- | bt/options.go | 50 | ||||
| -rw-r--r-- | bt/torrent.go | 289 | ||||
| -rw-r--r-- | bt/worker.go | 398 | ||||
| -rw-r--r-- | cmd/indexer/README.md | 10 | ||||
| -rw-r--r-- | cmd/indexer/main.go | 15 | ||||
| -rw-r--r-- | cmd/indexer/otel.go | 45 | ||||
| -rw-r--r-- | cmd/indexer/run.go | 245 | ||||
| -rw-r--r-- | dht/client.go | 767 | ||||
| -rw-r--r-- | dht/compact_node.go | 133 | ||||
| -rw-r--r-- | dht/compact_node_test.go | 63 | ||||
| -rw-r--r-- | dht/krpc.go | 259 | ||||
| -rw-r--r-- | dht/krpc_test.go | 106 | ||||
| -rw-r--r-- | dht/ktable.go | 411 | ||||
| -rw-r--r-- | dht/ktable_test.go | 106 | ||||
| -rw-r--r-- | dht/messages.go | 127 | ||||
| -rw-r--r-- | dht/metrics.go | 76 | ||||
| -rw-r--r-- | dht/node.go | 443 | ||||
| -rw-r--r-- | dht/node_test.go | 42 | ||||
| -rw-r--r-- | dht/options.go | 73 | ||||
| -rw-r--r-- | dht/packet.go | 33 | ||||
| -rw-r--r-- | dht/remote_node.go | 18 | ||||
| -rw-r--r-- | dht/routing_table.go | 121 | ||||
| -rw-r--r-- | dht/routing_table_test.go | 48 | ||||
| -rw-r--r-- | dht/slab.go | 25 | ||||
| -rw-r--r-- | dht/transactions.go | 93 | ||||
| -rw-r--r-- | krpc/krpc.go | 181 | ||||
| -rw-r--r-- | krpc/krpc_test.go | 30 |
29 files changed, 2810 insertions, 1511 deletions
@@ -1,13 +1,12 @@ -VERSION ?= $(shell git describe --tags --always) -SRC := $(shell find . -type f -name '*.go') -FLAGS := --tags fts5 +VERSION != git describe --tags --always +SRC != find . -type f -name '*.go' PLAT := windows darwin linux freebsd openbsd BINARY := $(patsubst %,dist/%,$(shell find cmd/* -maxdepth 0 -type d -exec basename {} \;)) RELEASE := $(foreach os, $(PLAT), $(patsubst %,%-$(os), $(BINARY))) .PHONY: build -build: sqlite $(BINARY) +build: $(BINARY) .PHONY: release release: $(RELEASE) @@ -15,20 +14,16 @@ release: $(RELEASE) dist/%: export GOOS=$(word 2,$(subst -, ,$*)) dist/%: bin=$(word 1,$(subst -, ,$*)) dist/%: $(SRC) $(shell find cmd/$(bin) -type f -name '*.go') - go build -ldflags "-X main.version=$(VERSION)" $(FLAGS) \ + go build -ldflags "-X main.version=$(VERSION)" \ -o $@ ./cmd/$(bin) -sqlite: - CGO_ENABLED=1 go get -u $(FLAGS) github.com/mattn/go-sqlite3 \ - && go install $(FLAGS) github.com/mattn/go-sqlite3 - .PHONY: test test: go test -short -coverprofile=coverage.out ./... \ && go tool cover -func=coverage.out .PHONY: lint -lint: ; go vet ./... 
+lint: ; golangci-lint run .PHONY: clean clean: diff --git a/bt/messages.go b/bt/messages.go new file mode 100644 index 0000000..bcf7da6 --- /dev/null +++ b/bt/messages.go @@ -0,0 +1,99 @@ +package bt + +// BEP4 Core protocol Message IDs + +const ( + BTMsgChoke uint8 = iota + BTMsgUnchoke + BTMsgInterested + BTMsgNotInterested + BTMsgHave + BTMsgBitfield + BTMsgRequest + BTMsgPiece + BTMsgCancel + BTMsgPort + _ + _ + _ + BTMsgSuggest // 13 + BTMsgHaveAll + BTMsgHaveNone + BTMsgRejectRequest + BTMsgAllowedFast + _ + _ + BTMsgExtended // 20 Extended messages BEP10 +) +const ( + // BEP10 send message types + ExtMsgTypeHandshake = uint8(0) + // >0 is extension dependant +) + +const ( + // BEP9 extension message types + + // ExtMsgTypeRequest marks a request message type + ExtMsgTypeRequest = 0 + // extMsgTypeData marks a data message type + ExtMsgTypeData = 1 + // extMsgTypeReject marks a reject message type + ExtMsgTypeReject = 2 +) + +type ExtMsg struct { + Type int `bencode:"msg_type"` + Piece int `bencode:"piece"` + TotalSize int `bencode:"total_size,omitzero"` +} + +type ExtMsgHandshake struct { + // Dictionary of supported extension messages which maps names of + // extensions to an extended message ID for each extension message. + Messages map[string]int `bencode:"m"` + // Local TCP listen port. Allows each side to learn about the TCP port + // number of the other side. Note that there is no need for the receiving + // side of the connection to send this extension message, since its port + // number is already known. + Port int `bencode:"p,omitzero"` + // Client name and version (as a utf-8 string). + Version string `bencode:"v,omitzero"` + // A string containing the compact representation of the ip address this + // peer sees you as. i.e. this is the receiver's external ip address + // (no port is included). This may be either an IPv4 (4 bytes) or an + // IPv6 (16 bytes) address. 
+ IP string `bencode:"yourip,omitzero"` + // If this peer has an IPv6 interface, this is the compact representation + // of that address (16 bytes). + IPv6 string `bencode:"ipv6,omitzero"` + // If this peer has an IPv4 interface, this is the compact representation + // of that address (4 bytes). + IPv4 string `bencode:"ipv4,omitzero"` + // An integer, the number of outstanding request messages this client + // supports without dropping any. The default in libtorrent is 250. + Qsize int `bencode:"reqq,omitzero"` + + // BEP9 + // Specifies an integer value of the number of bytes of the metadata. + MetadataSize int `bencode:"metadata_size,omitzero"` +} + +// MetaInfo is the decoded "info" dictionary of a torrent, see BEP3. +type MetaInfo struct { + // The BEP3 key is "piece length", with a space + PieceLength int `bencode:"piece length"` + Pieces string `bencode:"pieces"` + Private bool `bencode:"private"` + + Name string `bencode:"name"` + + // Single file + Length int `bencode:"length,omitzero"` + MD5Sum string `bencode:"md5sum,omitzero"` + + // Multiple files + Files []struct { + Length int `bencode:"length"` + MD5Sum string `bencode:"md5sum"` + // NOTE(review): BEP3 defines path as a list of strings, confirm + // the decoder can fill a plain string here. + Path string `bencode:"path"` + } `bencode:"files,omitzero"` +} diff --git a/bt/options.go b/bt/options.go deleted file mode 100644 index 391ad0f..0000000 --- a/bt/options.go +++ /dev/null @@ -1,50 +0,0 @@ -package bt - -import ( - "src.userspace.com.au/dhtsearch/models" - "src.userspace.com.au/logger" -) - -type Option func(*Worker) error - -// SetOnNewTorrent sets the callback -func SetOnNewTorrent(f func(models.Torrent)) Option { - return func(w *Worker) error { - w.OnNewTorrent = f - return nil - } -} - -// SetOnBadPeer sets the callback -func SetOnBadPeer(f func(models.Peer)) Option { - return func(w *Worker) error { - w.OnBadPeer = f - return nil - } -} - -// SetPort sets the port to listen on -func SetPort(p int) Option { - return func(w *Worker) error { - w.port = p - return nil - } -} - -// SetIPv6 enables IPv6 -func SetIPv6(b bool) Option { - return func(w *Worker) error { - if b { - w.family = "tcp6" - } - return nil - } -} -
-// SetLogger sets the logger -func SetLogger(l logger.Logger) Option { - return func(w *Worker) error { - w.log = l - return nil - } -} diff --git a/bt/torrent.go b/bt/torrent.go new file mode 100644 index 0000000..785fd47 --- /dev/null +++ b/bt/torrent.go @@ -0,0 +1,289 @@ +package bt + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/netip" + "time" + + "userspace.com.au/dhtsearch/bencode" + "userspace.com.au/dhtsearch/infohash" +) + +type Torrent struct { + ih infohash.ID +} + +func NewTorrent(ih infohash.ID) *Torrent { + return &Torrent{ih: ih} +} + +func (t *Torrent) FetchMetadata(ctx context.Context, ap netip.AddrPort) ([]byte, error) { + tcpTimeout := 5 * time.Second + tcpAddr := net.TCPAddrFromAddrPort(ap) + conn, err := net.DialTCP("tcp", nil, tcpAddr) + if err != nil { + return nil, err + } + defer conn.Close() + // conn.SetLinger(0) + // conn.SetNoDelay(true) + + local := conn.LocalAddr().(*net.TCPAddr).AddrPort() + + // Collect reads here + buf := new(bytes.Buffer) + //buf.Grow(blockSize) + + // Handshake + conn.SetWriteDeadline(time.Now().Add(tcpTimeout)) + hsLength, err := sendHandshake(conn, t.ih) + if err != nil { + return nil, err + } + conn.SetReadDeadline(time.Now().Add(tcpTimeout)) + if err := read(conn, hsLength, buf); err != nil { + return nil, fmt.Errorf("handshake failed: %w", err) + } + resp := buf.Next(hsLength) + // BEP10 Check for extension protocol + // The bit selected for the extension protocol is bit 20 from the right + // (counting starts at 0). So (reserved_byte[5] & 0x10) is the + // expression to use for checking if the client supports extended messaging. 
+ // See https://www.libtorrent.org/extension_protocol.html + // reserved[5] sits after the 1-byte length prefix and the protocol string, + // the same offset sendHandshake sets the bit at. + extBit := 1 + len(protocolString) + 5 + if resp[extBit]&0x10 == 0 { + return nil, errors.New("extension protocol not supported") + } + if err = sendExtHandshake(conn, local, ap); err != nil { + return nil, err + } + + var ( + // extended message type, set in first message + extMsgType uint8 + // Expected size of all metadata, set in first message + // Used to determine length of last piece + metadataSize int + // Collection of piece data, the output + pieces [][]byte + // Expected number of pieces, derived from metadatasize + totalPieces int + // Current count of pieces + gotPieces int + ) + + // Loop over incoming messages + for { + buf.Reset() + if err := readMessage(conn, buf); err != nil { + return nil, err + } + length := buf.Len() + if length == 0 { + continue + } + // We are only interested in bt extension protocol (20) + msgProtocol, err := buf.ReadByte() + if err != nil { + return nil, err + } + if msgProtocol != BTMsgExtended { + continue + } + + // Read the extension type ie. 0 == handshake + msgExtType, err := buf.ReadByte() + if err != nil { + return nil, err + } + // // This was set in the handshake and should match + // if msgExtType != ExtMsgTypeData { + // continue + // } + payload, err := io.ReadAll(buf) + if err != nil { + return nil, err + } + // We are past the protocol and extension type bits + br := bencode.NewReaderFromBytes(payload) + + if msgExtType == ExtMsgTypeHandshake { + if pieces != nil { + // We have already done the handshake!
+ return nil, errors.New("duplicate handshake") + } + msg := ExtMsgHandshake{} + if !br.ReadStruct(&msg) { + return nil, br.Err() + } + // metadata_size comes from the remote peer, bound it before it is + // used to size an allocation (at most 1000 blocks) + if msg.MetadataSize < 0 || msg.MetadataSize > 1000*blockSize { + return nil, fmt.Errorf("invalid metadata size %d", msg.MetadataSize) + } + totalPieces = msg.MetadataSize / blockSize + if msg.MetadataSize%blockSize != 0 { + totalPieces += 1 + } + if totalPieces == 0 { + return nil, errors.New("no pieces to fetch") + } + pieces = make([][]byte, totalPieces) + + // The extension type we can handle + if id, ok := msg.Messages["ut_metadata"]; !ok { + return nil, errors.New("missing ut_metadata extension") + } else { + extMsgType = uint8(id) + } + + metadataSize = msg.MetadataSize + // Request all metadata pieces + if err := requestMetadata(conn, extMsgType, totalPieces); err != nil { + return nil, err + } + continue + } + if pieces == nil { + return nil, errors.New("pieces not created") + } + + var msg ExtMsg + if !br.ReadStruct(&msg) { + return nil, br.Err() + } + if msg.Type != ExtMsgTypeData { + continue + } + // The piece index is supplied by the peer, never index with it unchecked + if msg.Piece < 0 || msg.Piece >= totalPieces { + return nil, fmt.Errorf("piece %d out of range", msg.Piece) + } + if pieces[msg.Piece] != nil { + // Already have this piece, do not count it twice + continue + } + pieceLen := len(payload) - int(br.Count()) + if rem := metadataSize % blockSize; msg.Piece == totalPieces-1 && rem != 0 { + // Last piece needs to equal remainder, unless the metadata is an + // exact multiple of the block size + if pieceLen != rem { + return nil, fmt.Errorf("incorrect final piece %d", msg.Piece) + } + } else { + // Should be full block + if pieceLen != blockSize { + return nil, fmt.Errorf("incomplete piece %d", msg.Piece) + } + } + + pieces[msg.Piece] = payload[br.Count():] + gotPieces += 1 + if gotPieces == totalPieces { + break + } + } + return bytes.Join(pieces, nil), nil +} + +// Our peer name and version TODO +const peerID = "ML-010" + +// BTv1 identifier +const protocolString = "BitTorrent protocol" + +// The metadata is handled in blocks of 16KiB (16384 Bytes), see BEP9 +const blockSize = 16384 + +const ( + extMsgTypeHandshake = uint8(0) +) + +func sendHandshake(w io.Writer, id infohash.ID) (int, error) { + // 49+len(pstr) bytes long + // <pstrlen><pstr><reserved><info_hash><peer_id> + out := make([]byte, 49+len(protocolString)) + out[0] = byte(len(protocolString)) + n := 1 + // pstr + copy(out[n:n+len(protocolString)],
[]byte(protocolString)) + n += len([]byte(protocolString)) + // reserved, setting extended protocol bit + reserved := bytes.Repeat([]byte{0}, 8) + reserved[5] |= 0x10 // Extension protocol + reserved[7] |= 0x01 // Distributed Hash Table + copy(out[n:n+8], reserved) + n += 8 + // target infohash + copy(out[n:n+20], id[:]) + n += 20 + // our peer ID + copy(out[n:], peerID[:]) + return w.Write(out) +} + +func sendExtHandshake(rw io.ReadWriter, local, remote netip.AddrPort) error { + payload := ExtMsgHandshake{ + Messages: map[string]int{ + "ut_metadata": 1, + }, + IP: string(remote.Addr().AsSlice()), + Port: int(local.Port()), + } + if local.Addr().Is6() { + payload.IPv6 = string(local.Addr().AsSlice()) + } else { + payload.IPv4 = string(local.Addr().AsSlice()) + } + b, err := bencode.Marshal(payload) + if err != nil { + return err + } + _, err = sendMessage(rw, ExtMsgTypeHandshake, b) + return err +} + +// requestMetadata sends one request message per metadata piece and stops +// at the first marshal or write failure. +func requestMetadata(w io.Writer, extMsgType uint8, totalPieces int) error { + for i := range totalPieces { + payload := ExtMsg{ + Type: ExtMsgTypeRequest, + Piece: i, + } + b, err := bencode.Marshal(payload) + if err != nil { + return err + } + if _, err := sendMessage(w, extMsgType, b); err != nil { + return err + } + } + return nil +} + +// read reads size-length bytes +func read(r io.Reader, size int, w io.Writer) error { + n, err := io.CopyN(w, r, int64(size)) + if err != nil { + return err + } + if n != int64(size) { + return fmt.Errorf("short read, got %d, want %d", n, size) + } + return nil +} + +// Sends a length prefix, protocol byte, type byte, then data, see BEP10 +func sendMessage(w io.Writer, mType uint8, data []byte) (n int, err error) { + // 4 bytes length prefix. Size of the entire message. (Big endian) + // 1 byte bt extended message id (20) + // 1 byte bt message type id (0: handshake, >0: according to handshake) + out := append([]byte{BTMsgExtended}, mType) + out = append(out, data...)
+ length := int32(len(out)) + if err = binary.Write(w, binary.BigEndian, length); err != nil { + return n, err + } + return w.Write(out) +} + +func readMessage(r io.Reader, buf *bytes.Buffer) error { + var length int32 + if err := binary.Read(r, binary.BigEndian, &length); err != nil { + return err + } + if length == 0 { + return nil + } + return read(r, int(length), buf) +} diff --git a/bt/worker.go b/bt/worker.go deleted file mode 100644 index f3c45d0..0000000 --- a/bt/worker.go +++ /dev/null @@ -1,398 +0,0 @@ -package bt - -import ( - "bytes" - "encoding/binary" - "errors" - "fmt" - "io" - "io/ioutil" - "net" - "time" - - "src.userspace.com.au/dhtsearch/krpc" - "src.userspace.com.au/dhtsearch/models" - "src.userspace.com.au/go-bencode" - "src.userspace.com.au/logger" -) - -const ( - // MsgRequest marks a request message type - MsgRequest = iota - // MsgData marks a data message type - MsgData - // MsgReject marks a reject message type - MsgReject - // MsgExtended marks it as an extended message - MsgExtended = 20 -) - -const ( - // BlockSize is 2 ^ 14 - BlockSize = 16384 - // MaxMetadataSize represents the max medata it can accept - MaxMetadataSize = BlockSize * 1000 - // HandshakeBit represents handshake bit - HandshakeBit = 0 - // TCPTimeout for BT connections - TCPTimeout = 5 -) - -var handshakePrefix = []byte{ - 19, 66, 105, 116, 84, 111, 114, 114, 101, 110, 116, 32, 112, 114, - 111, 116, 111, 99, 111, 108, 0, 0, 0, 0, 0, 16, 0, 1, -} - -type Worker struct { - pool chan chan models.Peer - port int - family string - OnNewTorrent func(t models.Torrent) - OnBadPeer func(p models.Peer) - log logger.Logger -} - -func NewWorker(pool chan chan models.Peer, opts ...Option) (*Worker, error) { - var err error - w := &Worker{ - pool: pool, - } - - // Set variadic options passed - for _, option := range opts { - err = option(w) - if err != nil { - return nil, err - } - } - return w, nil -} - -func (bt *Worker) Run() error { - peerCh := make(chan models.Peer) - - for { - 
// Signal we are ready for work - bt.pool <- peerCh - - select { - case p := <-peerCh: - // Got work - bt.log.Debug("worker got work", "peer", p) - md, err := bt.fetchMetadata(p) - if err != nil { - bt.log.Debug("failed to fetch metadata", "error", err) - if bt.OnBadPeer != nil { - bt.OnBadPeer(p) - } - continue - } - t, err := models.TorrentFromMetadata(p.Infohash, md) - if err != nil { - bt.log.Warn("failed to load torrent", "error", err) - continue - } - if bt.OnNewTorrent != nil { - bt.OnNewTorrent(*t) - } - } - } -} - -// fetchMetadata fetchs medata info accroding to infohash from dht. -func (bt *Worker) fetchMetadata(p models.Peer) (out []byte, err error) { - var ( - length int - msgType byte - totalPieces int - pieces [][]byte - utMetadata int - metadataSize int - ) - - defer func() { - pieces = nil - recover() - }() - - //ll := bt.log.WithFields("address", p.Addr.String()) - - //ll.Debug("connecting") - dial, err := net.DialTimeout("tcp", p.Addr.String(), time.Second*15) - if err != nil { - return out, err - } - // Cast - conn := dial.(*net.TCPConn) - conn.SetLinger(0) - defer conn.Close() - //ll.Debug("dialed") - - data := bytes.NewBuffer(nil) - data.Grow(BlockSize) - - ih := models.GenInfohash() - - // TCP handshake - //ll.Debug("sending handshake") - _, err = sendHandshake(conn, p.Infohash, ih) - if err != nil { - return nil, err - } - - // Handle the handshake response - //ll.Debug("handling handshake response") - err = read(conn, 68, data) - if err != nil { - return nil, err - } - next := data.Next(68) - //ll.Debug("got next data") - if !(bytes.Equal(handshakePrefix[:20], next[:20]) && next[25]&0x10 != 0) { - //ll.Debug("next data does not match", "next", next) - return nil, errors.New("invalid handshake response") - } - - //ll.Debug("sending ext handshake") - _, err = sendExtHandshake(conn) - if err != nil { - return nil, err - } - - for { - length, err = readMessage(conn, data) - if err != nil { - return out, err - } - - if length == 0 { - continue - 
} - - msgType, err = data.ReadByte() - if err != nil { - return out, err - } - - switch msgType { - case MsgExtended: - extendedID, err := data.ReadByte() - if err != nil { - return out, err - } - - payload, err := ioutil.ReadAll(data) - if err != nil { - return out, err - } - - if extendedID == 0 { - if pieces != nil { - return out, errors.New("invalid extended ID") - } - - utMetadata, metadataSize, err = getUTMetaSize(payload) - if err != nil { - return out, err - } - - totalPieces = metadataSize / BlockSize - if metadataSize%BlockSize != 0 { - totalPieces++ - } - - pieces = make([][]byte, totalPieces) - go bt.requestPieces(conn, utMetadata, metadataSize, totalPieces) - - continue - } - - if pieces == nil { - return out, errors.New("no pieces found") - } - - dict, index, err := bencode.DecodeDict(payload, 0) - if err != nil { - return out, err - } - - mt, err := krpc.GetInt(dict, "msg_type") - if err != nil { - return out, err - } - - if mt != MsgData { - continue - } - - piece, err := krpc.GetInt(dict, "piece") - if err != nil { - return out, err - } - - pieceLen := length - 2 - index - - // Not last piece? should be full block - if totalPieces > 1 && piece != totalPieces-1 && pieceLen != BlockSize { - return out, fmt.Errorf("incomplete piece %d", piece) - } - // Last piece needs to equal remainder - if piece == totalPieces-1 && pieceLen != metadataSize%BlockSize { - return out, fmt.Errorf("incorrect final piece %d", piece) - } - - pieces[piece] = payload[index:] - - if bt.isDone(pieces) { - return bytes.Join(pieces, nil), nil - } - default: - data.Reset() - } - } -} - -// isDone checks if all pieces are complete -func (bt *Worker) isDone(pieces [][]byte) bool { - for _, piece := range pieces { - if len(piece) == 0 { - return false - } - } - return true -} - -// read reads size-length bytes from conn to data. 
-func read(conn net.Conn, size int, data io.Writer) error { - conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout))) - n, err := io.CopyN(data, conn, int64(size)) - if err != nil || n != int64(size) { - return errors.New("read error") - } - return nil -} - -// readMessage gets a message from the tcp connection. -func readMessage(conn net.Conn, data *bytes.Buffer) (length int, err error) { - if err = read(conn, 4, data); err != nil { - return length, err - } - - length, err = bytes2int(data.Next(4)) - if err != nil { - return length, err - } - - if length == 0 { - return length, nil - } - - err = read(conn, length, data) - return length, err -} - -// sendMessage sends data to the connection. -func sendMessage(conn net.Conn, data []byte) (int, error) { - length := int32(len(data)) - - buffer := bytes.NewBuffer(nil) - binary.Write(buffer, binary.BigEndian, length) - - conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout))) - return conn.Write(append(buffer.Bytes(), data...)) -} - -// sendHandshake sends handshake message to conn. -func sendHandshake(conn net.Conn, ih, id models.Infohash) (int, error) { - data := make([]byte, 68) - copy(data[:28], handshakePrefix) - copy(data[28:48], []byte(ih)) - copy(data[48:], []byte(id)) - - conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout))) - return conn.Write(data) -} - -// onHandshake handles the handshake response. -func onHandshake(data []byte) (err error) { - if !(bytes.Equal(handshakePrefix[:20], data[:20]) && data[25]&0x10 != 0) { - err = errors.New("invalid handshake response") - } - return err -} - -// sendExtHandshake requests for the ut_metadata and metadata_size. -func sendExtHandshake(conn net.Conn) (int, error) { - m, err := bencode.EncodeDict(map[string]interface{}{ - "m": map[string]interface{}{"ut_metadata": 1}, - }) - if err != nil { - return 0, err - } - data := append([]byte{MsgExtended, HandshakeBit}, m...) 
- - return sendMessage(conn, data) -} - -// getUTMetaSize returns the ut_metadata and metadata_size. -func getUTMetaSize(data []byte) (utMetadata int, metadataSize int, err error) { - dict, _, err := bencode.DecodeDict(data, 0) - if err != nil { - return utMetadata, metadataSize, err - } - - m, err := krpc.GetMap(dict, "m") - if err != nil { - return utMetadata, metadataSize, err - } - - utMetadata, err = krpc.GetInt(m, "ut_metadata") - if err != nil { - return utMetadata, metadataSize, err - } - - metadataSize, err = krpc.GetInt(dict, "metadata_size") - if err != nil { - return utMetadata, metadataSize, err - } - - if metadataSize > MaxMetadataSize { - err = errors.New("metadata_size too long") - } - return utMetadata, metadataSize, err -} - -// Request more pieces -func (bt *Worker) requestPieces(conn net.Conn, utMetadata int, metadataSize int, totalPieces int) { - buffer := make([]byte, 1024) - for i := 0; i < totalPieces; i++ { - buffer[0] = MsgExtended - buffer[1] = byte(utMetadata) - - msg, _ := bencode.EncodeDict(map[string]interface{}{ - "msg_type": MsgRequest, - "piece": i, - }) - - length := len(msg) + 2 - copy(buffer[2:length], msg) - - sendMessage(conn, buffer[:length]) - } - buffer = nil -} - -// bytes2int returns the int value it represents. 
-func bytes2int(data []byte) (int, error) { - n := len(data) - if n > 8 { - return 0, errors.New("data too long") - } - - val := uint64(0) - - for i, b := range data { - val += uint64(b) << uint64((n-i-1)*8) - } - return int(val), nil -} diff --git a/cmd/indexer/README.md b/cmd/indexer/README.md new file mode 100644 index 0000000..345301a --- /dev/null +++ b/cmd/indexer/README.md @@ -0,0 +1,10 @@ +# DHT indexer + +The general process: + +- generate random infohash +- on tick: send sample request to close nodes +- on find_nodes response: send sample request to found +- on samples response: send get_peers request to sender +- on get_peers response: download metadata + diff --git a/cmd/indexer/main.go b/cmd/indexer/main.go new file mode 100644 index 0000000..758c08e --- /dev/null +++ b/cmd/indexer/main.go @@ -0,0 +1,15 @@ +package main + +import ( + "context" + "fmt" + "os" +) + +func main() { + ctx := context.Background() + if err := run(ctx, os.Args[1:], os.Getenv, os.Stdout, os.Stderr); err != nil { + fmt.Fprintf(os.Stderr, "%s\n", err) + os.Exit(1) + } +} diff --git a/cmd/indexer/otel.go b/cmd/indexer/otel.go new file mode 100644 index 0000000..ae86a3f --- /dev/null +++ b/cmd/indexer/otel.go @@ -0,0 +1,45 @@ +package main + +import ( + "context" + "time" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp" + "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/resource" + semconv "go.opentelemetry.io/otel/semconv/v1.37.0" +) + +func newMeterProvider(ctx context.Context) (func(context.Context) error, error) { + res, err := resource.Merge( + resource.Default(), + resource.NewWithAttributes( + semconv.SchemaURL, + attribute.String("service.name", "indexer"), + //semconv.ServiceVersion("0.1.0"), + ), + ) + if err != nil { + return nil, err + } + exporter, err := otlpmetrichttp.New(ctx) + if err != nil { + return nil, err + } + // metricExporter, err := 
stdoutmetric.New() + // if err != nil { + // panic(err) + // } + + mp := metric.NewMeterProvider( + metric.WithResource(res), + metric.WithReader(metric.NewPeriodicReader( + exporter, + metric.WithInterval(30*time.Second), + )), + ) + otel.SetMeterProvider(mp) + return mp.Shutdown, nil +} diff --git a/cmd/indexer/run.go b/cmd/indexer/run.go new file mode 100644 index 0000000..85b3198 --- /dev/null +++ b/cmd/indexer/run.go @@ -0,0 +1,245 @@ +package main + +import ( + "bytes" + "context" + "errors" + "flag" + "io" + "log/slog" + "net" + "net/netip" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "userspace.com.au/dhtsearch/bencode" + "userspace.com.au/dhtsearch/bt" + "userspace.com.au/dhtsearch/dht" + "userspace.com.au/dhtsearch/infohash" +) + +func run( + ctx context.Context, + _ []string, + _ func(string) string, + stdout io.Writer, + _ io.Writer, +) error { + ctx, cancel := signal.NotifyContext( + ctx, + syscall.SIGINT, syscall.SIGTERM, + ) + defer cancel() + + var ( + addrFlag string + publicFlag string + stateFlag string + secureFlag bool + metricsFlag bool + verboseFlag bool + ) + + flag.StringVar(&addrFlag, "addr", "", "listen address:port") + flag.StringVar(&publicFlag, "public", "", "public address:port") + flag.StringVar(&stateFlag, "state", "", "state file") + flag.BoolVar(&secureFlag, "secure", false, "Generate secure infohash") + flag.BoolVar(&metricsFlag, "metrics", false, "Generate metrics") + flag.BoolVar(&verboseFlag, "verbose", false, "verbose") + flag.Parse() + + var lvl = new(slog.LevelVar) + lvl.Set(slog.LevelInfo) + if verboseFlag { + lvl.Set(slog.LevelDebug) + } + logger := slog.New(slog.NewTextHandler( + stdout, + &slog.HandlerOptions{Level: lvl}, + )) + + opts := []dht.Option{ + dht.WithLogger(logger), + } + if addrFlag != "" { + opts = append(opts, dht.WithListenAddress(addrFlag)) + } else { + ip := getFirstPublicIP() + if ip == nil { + logger.Error("no IP") + return errors.New("no IP") + } + addr := net.JoinHostPort(ip.String(), 
"6881") + l, err := net.ListenPacket("udp", addr) + if err != nil { + // Without a listener option the client opens its default socket + logger.Warn("failed to listen", "address", addr, "error", err) + } else { + opts = append(opts, dht.WithListener(l)) + } + } + if publicFlag != "" { + opts = append(opts, dht.WithPublicAddr(publicFlag, secureFlag)) + } + + if stateFlag != "" { + b, err := os.ReadFile(stateFlag) + if err != nil { + if !os.IsNotExist(err) { + return err + } + } else { + opts = append(opts, dht.WithState(bytes.NewReader(b))) + } + } + + if metricsFlag { + closer, err := newMeterProvider(ctx) + if err != nil { + return err + } + defer func() { + if err := closer(context.Background()); err != nil { + logger.Error(err.Error()) + } + }() + } + + //fetchMetadata := fetchMetadata(logger) + + svr, err := dht.NewClient(ctx, opts...) + if err != nil { + return err + } + gotInfohash := make(map[infohash.ID]bool) + type toGet struct { + ih infohash.ID + ap netip.AddrPort + } + hashes := make(chan toGet) + + go func() { + for tg := range hashes { + if gotInfohash[tg.ih] { + continue + } + gotInfohash[tg.ih] = true + t := bt.NewTorrent(tg.ih) + logger.Info("fetching metadata", "ih", tg.ih, "address", tg.ap) + b, err := t.FetchMetadata(ctx, tg.ap) + if err != nil { + gotInfohash[tg.ih] = false + logger.Error("failed to fetch metadata", "error", err) + // Keep draining the channel, returning here would block + // every later send on hashes. + continue + } + var info bt.MetaInfo + if err := bencode.Unmarshal(b, &info); err != nil { + logger.Error("failed to decode metadata", "error", err) + continue + } + logger.Info("fetched metadata", "name", info.Name) + + } + }() + // port := uint16(m.Args.Port) + // if m.Args.ImpliedPort { + // port = rn.AddrPort.Port() + // } + // ap := netip.AddrPortFrom(rn.AddrPort.Addr(), port) + + getPeers := func(ctx context.Context, rn *dht.Node, m dht.Msg) { + if m.Response == nil { + return + } + samples := infohash.SplitCompactInfohashes(m.Response.Samples) + logger.Info("got samples", "count", len(samples)) + for _, ih := range samples { + svr.GetPeers(rn, ih, func(rn *dht.Node) { + logger.Info("get_peers callback") + hashes <- toGet{ + ih: ih, + ap: rn.AddrPort, + } + }) + } + } + getSamples := func(_ context.Context, _ *dht.Node, _ dht.Msg) { + ih := infohash.NewRandomID()
svr.GetSamples(ih) + } + //svr.OnPeersResponse = fetchMetadata + svr.OnNodesResponse = getSamples + svr.OnSamplesResponse = getPeers + + var wg sync.WaitGroup + wg.Go(func() { + if err := svr.Run(ctx); err != nil { + logger.Error("dht client failed", "error", err) + } + }) + + wg.Go(func() { + tick := time.NewTicker(10 * time.Second) + for { + select { + case <-ctx.Done(): + return + case <-tick.C: + ih := infohash.NewRandomID() + svr.GetSamples(ih) + } + } + }) + wg.Wait() + + if stateFlag != "" { + f, err := os.Create(stateFlag) + if err != nil { + return err + } + if err := svr.SaveState(f); err != nil { + return err + } + err = f.Close() + } + return err +} + +func getFirstPublicIP() net.IP { + addrs, _ := net.InterfaceAddrs() + for _, a := range addrs { + ip, _, err := net.ParseCIDR(a.String()) + if err == nil && !ip.IsLoopback() && !ip.IsPrivate() { + return ip + } + } + return nil +} + +func fetchMetadata(log *slog.Logger) func(ctx context.Context, rn *dht.Node, m dht.Msg) { + return func(ctx context.Context, rn *dht.Node, m dht.Msg) { + // port := uint16(m.Args.Port) + // // If it is present and non-zero, the port argument should be + // // ignored and the source port of the UDP packet should be used + // // as the peer's port instead. 
+ // if m.Args.ImpliedPort { + // port = rn.AddrPort.Port() + // } + if m.Response == nil || m.Response.ID == nil { + log.Warn("cannot fetch metadata, missing details") + return + } + // ap := netip.AddrPortFrom(rn.AddrPort.Addr(), port) + ap := rn.AddrPort + + samples := infohash.SplitCompactInfohashes(m.Response.Samples) + for _, id := range samples { + t := bt.NewTorrent(id) + log.Info("fetching metadata", "ih", id, "address", ap) + b, err := t.FetchMetadata(ctx, ap) + if err != nil { + log.Error("failed to fetch metadata", "error", err) + // Nothing to decode, move on to the next sample + continue + } + var info bt.MetaInfo + if err := bencode.Unmarshal(b, &info); err != nil { + log.Error("failed to decode metadata", "error", err) + continue + } + log.Info("fetched metadata", "name", info.Name) + } + + } +} diff --git a/dht/client.go b/dht/client.go new file mode 100644 index 0000000..ed0d0ce --- /dev/null +++ b/dht/client.go @@ -0,0 +1,767 @@ +package dht + +import ( + "context" + "errors" + "fmt" + "io" + "log/slog" + "net" + "net/netip" + "os" + "strconv" + "sync" + "time" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + + "userspace.com.au/dhtsearch/bencode" + "userspace.com.au/dhtsearch/infohash" +) + +var ( + DefaultBootstraps = []string{ + //"malkmus.yelnah.org:51000", + "torrents:51000", + "dht.aelitis.com:6881", + "dht.libtorrent.org:25401", + "dht.transmissionbt.com:6881", + "router.bittorrent.cloud:42069", + "router.bittorrent.com:6881", + "router.silotis.us:6881", + "router.utorrent.com:6881", + } +) + +const defaultPacketSize = 1500 + +const ( + defaultTickTableMaintenance = 31 * time.Second + defaultStaleNodeGrace = 15 * time.Minute + defaultTickAuditTransactions = 23 * time.Second +) + +// Client joins the DHT network +type Client struct { + conn net.PacketConn + maxPacketSize int + + node Node + bootstraps []string + udpTimeout int + packetsOut chan packet + chk *transactions + log *slog.Logger + + table RoutingTable + tableState io.Reader + + // Ticker and TTL durations + tickTableMaintenance time.Duration + tickAuditTransactions time.Duration + staleNodeGrace
time.Duration + + //limiter *rate.Limiter + ////blacklist *lru.ARCCache + + // Hooks + + // onPingQuery is called when a ping query is received + onPingQuery func(*Node) + // onFindNodeQuery is called when a find_node query is received + onFindNodeQuery func(context.Context, *Node, Msg) + onGetPeersQuery func(context.Context, *Node, Msg) + onAnnouncePeerQuery func(context.Context, *Node, Msg) + OnNodesResponse func(context.Context, *Node, Msg) + OnPeersResponse func(context.Context, *Node, Msg) + OnSamplesResponse func(context.Context, *Node, Msg) +} + +// NewClient creates a new DHT client +func NewClient(ctx context.Context, opts ...Option) (*Client, error) { + var err error + + c := &Client{ + maxPacketSize: defaultPacketSize, + chk: newTransRegistry(), + node: Node{ + ID: infohash.NewRandomID(), + }, + log: slog.New(slog.NewTextHandler(os.Stdout, nil)), + udpTimeout: 1000, + //limiter: rate.NewLimiter(rate.Limit(50), 70), + + tickTableMaintenance: defaultTickTableMaintenance, + tickAuditTransactions: defaultTickAuditTransactions, + staleNodeGrace: defaultStaleNodeGrace, + } + + // Set variadic options passed + for _, option := range opts { + err = option(c) + if err != nil { + return nil, err + } + } + + if c.table == nil { + // Default ktable implementation + table := NewTable(c.node.ID) + table.log = c.log + c.table = table + } + + // After table exists, before import + if err := configureMetrics(c); err != nil { + return c, err + } + + if c.tableState != nil { + if _, err := c.table.ReadFrom(c.tableState); err != nil { + c.log.Warn("failed to read table", "error", err) + } + c.node.ID = c.table.ID() + } + + // if n.blacklist == nil { + // n.blacklist, err = lru.NewARC(1000) + // if err != nil { + // return nil, err + // } + // } + + if c.conn == nil { + if c.conn, err = net.ListenPacket("udp", "0.0.0.0:6881"); err != nil { + c.log.Error("failed to listen", "error", err) + return nil, err + } + } + + if len(c.bootstraps) == 0 { + c.bootstraps = 
DefaultBootstraps + } + + return c, nil +} + +type RoutingTable interface { + ID() infohash.ID + Add(*Node) bool + Has(infohash.ID) bool + SetSeen(infohash.ID) + Count() int + GetClosest(infohash.ID, int) []*Node + GetStale(int, time.Duration) []*Node + Remove(infohash.ID) bool + + // Used for import/export + io.WriterTo + io.ReaderFrom +} + +type Option func(*Client) error + +// WithPublicAddr sets the IP:port if different to the listen IP:port +func WithPublicAddr(s string, secure bool) Option { + return func(c *Client) error { + var err error + c.node.AddrPort, err = netip.ParseAddrPort(s) + if err != nil { + return err + } + fmt.Println(c.node.AddrPort) + ip := net.UDPAddrFromAddrPort(c.node.AddrPort) + c.node.ID = infohash.NewRandomID() + if secure { + c.node.ID = infohash.NewSecureID(ip.IP) + } + c.node.Secure = infohash.IsSecure(c.node.ID, ip.IP) + return nil + } +} + +func WithListener(p net.PacketConn) Option { + return func(c *Client) error { + c.conn = p + return nil + } +} + +// WithListenAddress sets the IP:port address to listen on +func WithListenAddress(s string) Option { + return func(c *Client) error { + var err error + c.conn, err = net.ListenPacket("udp", s) + return err + } +} + +// WithLogger sets the IP:port address to listen on +func WithLogger(l *slog.Logger) Option { + return func(c *Client) error { + c.log = l + return nil + } +} + +// WithBootstraps enables custom bootstrap addresses +func WithBootstraps(addrs ...string) Option { + return func(c *Client) error { + c.bootstraps = addrs + return nil + } +} + +// WithTable enables using a custom node table implementation. +func WithTable(rt RoutingTable) Option { + return func(c *Client) error { + c.table = rt + return nil + } +} + +// WithState enables using a custom node table implementation. 
+func WithState(r io.Reader) Option { + return func(c *Client) error { + c.tableState = r + return nil + } +} + +// OnAnnoucePeer is called when an announce_peer query is received +func OnAnnouncePeerQuery(f func(context.Context, *Node, Msg)) Option { + return func(c *Client) error { + c.onAnnouncePeerQuery = f + return nil + } +} + +// OnGetPeersQuery is called when a get_peers query is received +func OnGetPeersQuery(f func(context.Context, *Node, Msg)) Option { + return func(c *Client) error { + c.onGetPeersQuery = f + return nil + } +} + +// OnPeersResponse is called when a response has peers +func OnPeersResponse(f func(context.Context, *Node, Msg)) Option { + return func(c *Client) error { + c.OnPeersResponse = f + return nil + } +} + +// OnSamplesResponse is called when a response has samples +func OnSamplesResponse(f func(context.Context, *Node, Msg)) Option { + return func(c *Client) error { + c.OnSamplesResponse = f + return nil + } +} + +// Close stuff +func (c *Client) Close() error { + c.log.Warn("node closing") + return nil +} + +func (c *Client) SaveState(w io.Writer) error { + _, err := c.table.WriteTo(w) + return err +} + +// Run starts the node on the DHT +func (c *Client) Run(ctx context.Context) error { + var wg sync.WaitGroup + wg.Go(func() { c.packetWriter(ctx) }) + wg.Go(func() { c.tableMaintenance(ctx) }) + wg.Go(func() { c.auditTransactions(ctx) }) + wg.Go(func() { c.packetReader(ctx) }) + c.log.Info("listening", "id", c.node.ID, "address", c.node.AddrPort, "listen", c.conn.LocalAddr().String()) + wg.Wait() + // Close this after packetWriter has finished + close(c.packetsOut) + c.log.Debug("client stopped") + return nil +} + +func (c *Client) wants() []string { + if c.node.AddrPort.Addr().Is4() { + return []string{Want4} + } else { + return []string{Want6} + } +} + +//type PeerStore interface { +// Add(Node) (bool, error) +// Get(int) ([]Node, error) +// //Delete(Node) error +// Reset() error + +// // Used for import/export +// io.WriterTo 
+// io.ReaderFrom +//} + +// func addrPort2Addr(in netip.AddrPort) net.Addr { +// return net.UDPAddrFromAddrPort(in) +// } +// func addr2AddrPort(in net.Addr) (netip.AddrPort, error) { +// return netip.ParseAddrPort(in.String()) +// } + +// Unprocessed packet from socket +type packet struct { + raddr net.Addr + data []byte +} + +func (c *Client) tableMaintenance(ctx context.Context) { + c.log.Info("starting table maintenance", "interval", c.tickTableMaintenance, "grace", c.staleNodeGrace) + + pingStale := func() { + stale := c.table.GetStale(8, c.staleNodeGrace) + for _, n := range stale { + // if n.FailedResponses > 2 { + // c.table.Remove(n.ID) + // removed++ + // continue + // } + n.PingAttempts += 1 + // Don't just ping, get more nodes + //c.SendMsg(*n, NewPingQuery(c.node.ID)) + rID := infohash.NewCloseID(c.node.ID) + c.sendMsg(n, NewFindNodeQuery(c.node.ID, rID, c.wants())) + } + c.log.Debug("contacted stale nodes", "sent", len(stale)) + } + + makeNeighbours := func() { + n := c.table.Count() + if n == 0 { + c.bootstrap() + return + } + rID := infohash.NewRandomID() + nodes := c.table.GetClosest(c.node.ID, 8) + for _, rn := range nodes { + c.sendMsg(rn, NewFindNodeQuery(c.node.ID, rID, c.wants())) + } + c.log.Debug("making neighbours", "sent", len(nodes)) + } + + // Once at startup + makeNeighbours() + + ticker := time.Tick(c.tickTableMaintenance) + for { + select { + case <-ctx.Done(): + c.log.Info("stopping table maintenance") + return + case <-ticker: + makeNeighbours() + pingStale() + } + } +} + +func (c *Client) bootstrap() { + for _, s := range c.bootstraps { + host, portStr, err := net.SplitHostPort(s) + if err != nil { + c.log.Warn("failed to parse bootstrap entry", "address", s, "error", err) + continue + } + log := c.log.With(slog.String("host", host)) + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + c.log.Warn("failed to parse bootstrap port", "error", err) + continue + } + ips, err := net.LookupIP(host) + if err != nil { + 
log.Warn("failed to resolve", "error", err) + continue + } + for _, b := range ips { + ip, ok := netip.AddrFromSlice(b) + if !ok { + log.Warn("failed to convert address") + continue + } + if c.node.AddrPort.Addr().Is4() != ip.Is4() { + log.Debug("different network family") + continue + } + ap := netip.AddrPortFrom(ip, uint16(port)) + rn := &Node{ + ID: infohash.NewRandomID(), + AddrPort: ap, + } + log.Info("bootstrapping", "address", ap) + c.sendMsg(rn, NewFindNodeQuery(c.node.ID, rn.ID, c.wants())) + } + } +} + +// GetPeers sends a get_peers message +func (c *Client) GetPeers(rn *Node, ih infohash.ID, cb TransactionCallback) { + msg := &Msg{ + Query: "get_peers", + Type: "q", + Args: &MsgArgs{ + ID: &c.node.ID, + InfoHash: &ih, + }, + } + c.log.Debug("sending get_peers", "ih", ih, "raddr", rn.AddrPort, "id", rn.ID) + c.sendMsgWithCallback(rn, msg, cb) +} + +// GetSamples sends a sample_infohashes message +func (c *Client) GetSamples(ih infohash.ID) { + nodes := c.table.GetClosest(ih, 8) + msg := &Msg{ + Query: "sample_infohashes", + Type: "q", + Args: &MsgArgs{ + // The querying node + ID: &c.node.ID, + // The ID sought + Target: &ih, + }, + } + for _, rn := range nodes { + c.sendMsg(rn, msg) + } +} + +func (c *Client) sendMsg(rn *Node, m *Msg) { + c.sendMsgWithCallback(rn, m, nil) +} + +// sendMsg sends a KRPC message to the network +func (c *Client) sendMsgWithCallback(rn *Node, m *Msg, cb TransactionCallback) { + log := c.log.With( + "type", m.Query, + "addr", rn.AddrPort, + ) + // Don't send to self + if rn.ID.Equal(c.node.ID) { + log.Warn("not sending to self") + return + } + if m.Type == "q" { + log = c.log.With( + "qType", m.Type, + "query", m.Query, + ) + if !rn.canSend(m.Query) { + log.Warn("ratelimited") + return + } + } + addr := net.UDPAddrFromAddrPort(rn.AddrPort) + if m.Type == "q" { + c.registerTransaction(rn, m, cb) + } + log = log.With("tid", m.TID) + b, err := bencode.Marshal(m) + if err != nil { + log.Warn("failed to marshal", "error", err) + } 
+ log.Debug("sending message") + c.packetsOut <- packet{ + data: b, + raddr: addr, + } +} + +func (c *Client) packetWriter(ctx context.Context) { + c.log.Debug("starting packet writer") + // Packets onto the network + c.packetsOut = make(chan packet, 2024) + for { + select { + case <-ctx.Done(): + return + case p := <-c.packetsOut: + if p.raddr.String() == c.conn.LocalAddr().String() { + continue + } + _ = c.conn.SetWriteDeadline(time.Now().Add(200 * time.Millisecond)) + n, err := c.conn.WriteTo(p.data, p.raddr) + if err != nil { + //n.blacklist.Add(p.raddr.String(), true) + // TODO reduce limit + c.log.Warn("failed to write packet", "error", err) + } + netPackets.Add( + ctx, 1, + metric.WithAttributes( + attribute.String("network.io.direction", "transmit"))) + netOctets.Add( + ctx, int64(n), + metric.WithAttributes( + attribute.String("network.io.direction", "transmit"))) + } + } +} + +func (c *Client) packetReader(ctx context.Context) { + c.log.Info("starting packet reader") + pool := sync.Pool{ + New: func() any { + out := make([]byte, c.maxPacketSize) + return &out + }, + } + + for { + select { + case <-ctx.Done(): + c.log.Info("stopping packet reader") + _ = c.conn.Close() + return + default: + b := pool.Get().(*[]byte) + _ = c.conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) + //c.log.Debug("waiting UDP read") + n, addr, err := c.conn.ReadFrom(*b) + netOctets.Add( + ctx, int64(n), + metric.WithAttributes( + attribute.String("network.io.direction", "receive"))) + if err != nil { + if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() { + continue + } + c.log.Warn("UDP read error", "error", err) + return + } + netPackets.Add( + ctx, 1, + metric.WithAttributes( + attribute.String("network.io.direction", "receive"))) + + var m Msg + if err := bencode.Unmarshal((*b)[:n], &m); err != nil { + c.log.Warn("krpc unmarshal msg failed", "error", err) + return + } + pool.Put(b) + rn := &Node{ + LastContact: time.Now(), + AddrPort: 
addr.(*net.UDPAddr).AddrPort(), + } + go c.processMessage(ctx, rn, m) + } + } +} + +// Parse a KRPC packet into a message +// Called in goroutine +func (c *Client) processMessage(ctx context.Context, rn *Node, m Msg) { + log := c.log.With("type", m.Type, "raddr", rn.AddrPort) + // if _, black := n.blacklist.Get(p.raddr.String()); black { + // return fmt.Errorf("blacklisted: %s", p.raddr.String()) + // } + var err error + if m.Type == "q" { + err = c.handleQuery(ctx, rn, m) + } else { + err = c.handleResponse(ctx, rn, m) + } + if err != nil { + log.Warn("failed to process message", "error", err) + //n.blacklist.Add(p.raddr.String(), true) + } +} + +// handleResponse handles responses received from udp. +func (c *Client) handleResponse(ctx context.Context, rn *Node, m Msg) error { + log := c.log.With( + "raddr", rn.AddrPort, + "id", rn.ID, + ) + + trans, err := c.checkTransaction(m) + if m.Type == "e" { + log = log.With("code", m.Error.Code, "error", m.Error.Msg, "tid", m.TID, "qType", trans.msg) + if m.Error.Code == 204 { + // Don't query this node like this again + rn.rateQuery(trans.msg, -1) + log.Info("rate limiting node") + } + return nil + } + if err != nil { + return err + } + + if m.Response == nil { + return errors.New("missing reponse") + } + rn.ID = *m.Response.ID + if !c.table.Has(rn.ID) { + _ = c.table.Add(rn) + } + c.table.SetSeen(rn.ID) + log.Debug("received response") + r := *m.Response + + // Add new nodes to our routing table + var newNodes []*Node + if r.NodesList != nil { + newNodes = append(newNodes, r.NodesList.Nodes...) + } + if r.Nodes6List != nil { + newNodes = append(newNodes, r.Nodes6List.Nodes...) 
+ } + if len(newNodes) > 0 { + for _, cn := range newNodes { + _ = c.table.Add(cn) + } + if c.OnNodesResponse != nil { + log.Debug("on nodes") + // TODO set deadline + go c.OnNodesResponse(ctx, rn, m) + } + } + + if len(r.Values) > 0 { + rn.rateQuery(trans.msg, r.Interval) + if c.OnPeersResponse != nil { + log.Info("on peers") + // TODO set deadline + go c.OnPeersResponse(ctx, rn, m) + } + } + + if len(r.Samples) > 0 { + log.Info("on samples") + if r.Interval > 0 { + rn.NextQuery["sample_infohashes"] = time.Now().Add(time.Duration(r.Interval) * time.Second) + } + // if m.Response != nil && m.Response.ID != nil { + // samples := infohash.SplitCompactInfohashes( + // m.Response.Samples, + // ) + // for _, ih := range samples { + // c.GetPeers(ih) + // } + // } + if c.OnSamplesResponse != nil { + // TODO set deadline + go c.OnSamplesResponse(ctx, rn, m) + } + } + if trans.cb != nil { + go trans.cb(rn) + } + + return nil +} + +// handleQuery handles incoming queries +func (c *Client) handleQuery(ctx context.Context, rn *Node, m Msg) error { + if m.Args.ID == nil { + return errors.New("query missing ID") + } + rn.ID = *m.Args.ID + log := c.log.With( + "query", m.Query, + "raddr", rn.AddrPort, + "id", rn.ID, + ) + c.table.Add(rn) + log.Debug("received query") + + switch m.Query { + case "ping": + c.sendMsg(rn, &Msg{ + Type: "r", + TID: m.TID, + Response: &Response{ + ID: &c.node.ID, + }, + }) + if c.onPingQuery != nil { + log.Info("on ping") + // TODO set deadline + go c.onPingQuery(rn) + } + case "find_node": + if m.Args.Target == nil { + return errors.New("query missing target ID") + } + target := *m.Args.Target + out := &Msg{ + Type: "r", + TID: m.TID, + Response: &Response{ + ID: &c.node.ID, + Token: m.Args.Token, + }, + } + nodes := c.table.GetClosest(target, 8) + switch c.node.Family { + case Want4: + out.Response.NodesList.Nodes = nodes + case Want6: + out.Response.Nodes6List.Nodes = nodes + } + c.sendMsg(rn, out) + if c.onFindNodeQuery != nil { + log.Info("on 
find_node") + // TODO set deadline + go c.onFindNodeQuery(ctx, rn, m) + } + case "get_peers": + if m.Args.InfoHash == nil { + return errors.New("query missing infohash ID") + } + target := *m.Args.InfoHash + out := &Msg{ + Type: "r", + TID: m.TID, + Response: &Response{ + ID: &c.node.ID, + Token: m.Args.Token, + }, + } + nodes := c.table.GetClosest(target, 8) + switch c.node.Family { + case Want4: + out.Response.NodesList.Nodes = nodes + case Want6: + out.Response.Nodes6List.Nodes = nodes + } + c.sendMsg(rn, out) + if c.onGetPeersQuery != nil { + log.Info("on get_peers") + // TODO set deadline + go c.onGetPeersQuery(ctx, rn, m) + } + case "announce_peer": + // port := uint16(m.Args.Port) + // // If it is present and non-zero, the port argument should be + // // ignored and the source port of the UDP packet should be used + // // as the peer's port instead. + // if m.Args.ImpliedPort { + // port = rn.AddrPort.Port() + // } + if c.onAnnouncePeerQuery != nil { + log.Info("on announce_peer") + // TODO set deadline + go c.onAnnouncePeerQuery(ctx, rn, m) + } + default: + log.Warn("unknown type") + } + return nil +} diff --git a/dht/compact_node.go b/dht/compact_node.go new file mode 100644 index 0000000..3d9a405 --- /dev/null +++ b/dht/compact_node.go @@ -0,0 +1,133 @@ +package dht + +import ( + "encoding/binary" + "errors" + "fmt" + "net/netip" + "slices" + + "userspace.com.au/dhtsearch/bencode" + "userspace.com.au/dhtsearch/infohash" +) + +const ( + ip4AddrLength = 4 + ip6AddrLength = 16 + portLength = 2 + compactIP4Length = ip4AddrLength + portLength + compactIP6Length = ip6AddrLength + portLength + CompactNodeInfoLength = infohash.Length + compactIP4Length // 26 bytes + CompactNode6InfoLength = infohash.Length + compactIP6Length // 38 bytes +) + +type NodeAddr netip.AddrPort + +// Like AddrPort UnmarshalBinary except port is big-endian +func (n *NodeAddr) UnmarshalBinary(b []byte) error { + var addr netip.Addr + var offset int + var ok bool + if len(b) > 
compactIP4Length { + addr, ok = netip.AddrFromSlice(b[:ip6AddrLength]) + offset = ip6AddrLength + } else { + addr, ok = netip.AddrFromSlice(b[:ip4AddrLength]) + offset = ip4AddrLength + } + if !ok { + return errors.New("failed to parse node address") + } + ap := netip.AddrPortFrom(addr, binary.BigEndian.Uint16(b[offset:])) + *n = NodeAddr(ap) + return nil +} + +// Like AddrPort MarshalBinary except port is big-endian +func (n NodeAddr) MarshalBinary() ([]byte, error) { + ap := netip.AddrPort(n) + b, err := ap.Addr().MarshalBinary() + if err != nil { + return nil, err + } + return binary.Append(b, binary.BigEndian, ap.Port()) +} + +type CompactNodeList struct { + Nodes []*Node +} + +func (c *CompactNodeList) UnmarshalBencode(b []byte) error { + var in []byte + if err := bencode.Unmarshal(b, &in); err != nil { + return err + } + if mod := len(in) % CompactNodeInfoLength; mod != 0 { + return fmt.Errorf("CompactNodeList trailing %d bytes", mod) + } + //fmt.Println("chunking", len(in), "into chunks of", CompactNodeInfoLength) + for chunk := range slices.Chunk(in, CompactNodeInfoLength) { + if len(chunk) == CompactNodeInfoLength { + n := new(Node) + if err := n.UnmarshalBinary(chunk); err != nil { + continue + } + c.Nodes = append(c.Nodes, n) + } + } + return nil +} + +func (c *CompactNodeList) MarshalBencode() ([]byte, error) { + var bb []byte + for _, n := range c.Nodes { + b, err := n.MarshalBinary() + if err != nil { + return nil, err + } + bb = append(bb, b...) 
+ } + out := fmt.Appendf([]byte{}, "%d:", len(bb)) + return append(out, bb...), nil +} + +type CompactNode6List struct { + Nodes []*Node +} + +func (c *CompactNode6List) UnmarshalBencode(b []byte) error { + var in []byte + if err := bencode.Unmarshal(b, &in); err != nil { + return err + } + //fmt.Println("nodes6", len(in), len(b), "in", hex.EncodeToString(b)) + // if in == "" { + // return nil + // } + if mod := len(in) % CompactNode6InfoLength; mod != 0 { + return fmt.Errorf("CompactNode6List trailing %d bytes", mod) + } + for chunk := range slices.Chunk([]byte(in), CompactNode6InfoLength) { + if len(chunk) == CompactNode6InfoLength { + n := new(Node) + if err := n.UnmarshalBinary(chunk); err != nil { + return err + } + c.Nodes = append(c.Nodes, n) + } + } + return nil +} + +func (c *CompactNode6List) MarshalBencode() ([]byte, error) { + var bb []byte + for _, n := range c.Nodes { + b, err := n.MarshalBinary() + if err != nil { + return nil, err + } + bb = append(bb, b...) + } + out := fmt.Appendf([]byte{}, "%d:", len(bb)) + return append(out, bb...), nil +} diff --git a/dht/compact_node_test.go b/dht/compact_node_test.go new file mode 100644 index 0000000..3429317 --- /dev/null +++ b/dht/compact_node_test.go @@ -0,0 +1,63 @@ +package dht + +import "testing" + +func TestCompactNodeList(t *testing.T) { + type nodeInfo struct { + id string + addr string + } + tests := map[string]struct { + in string + n int + items []nodeInfo + }{ + "contrived": { + in: "52:\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x03\x04\x05\x06\x07", + n: 2, + items: []nodeInfo{{ + id: "0000000000000000000000000000000000000000", + addr: "1.2.3.4:1286", + },{ + id: "0000000000000000000000000000000000000000", + addr: "2.3.4.5:1543", + }}, + }, + "captured": { + in: 
"208:\xb4\x99d\xcf)\xaa\x11\x14\x8d\r\x8b\x8dL̑b\xf6\x14/2\xaf\xb1\xcc+P\xb6\xb4\x99d\xcf)\xaa\x11\x14\x8d\r\x8b\x8dL̑b\xf6\x14/2\xaf\xb1\xcc+P\xb6\xb4\x99d\xcf)\xaa\x11\x14\x8d\r\x8b\x8dL̑b\xf6\x14/2\xaf\xb1\xcc+P\xb6\xb4\x99d\xcf)\xaa\x11\x14\x8d\r\x8b\x8dL̑b\xf6\x14/2\xaf\xb1\xcc+P\xb6\xb4\x99d\xcf)\xaa\x11\x14\x8d\r\x8b\x8dL̑b\xf6\x14/2\xaf\xb1\xcc+P\xb6\xb4\x99d\xcf)\xaa\x11\x14\x8d\r\x8b\x8dL̑b\xf6\x14/2\xaf\xb1\xcc+P\xb6\xb4\x99d\xcf)\xaa\x11\x14\x8d\r\x8b\x8dL̑b\xf6\x14/2\xaf\xb1\xcc+P\xb6\xb4\x99d\xcf)\xaa\x11\x14\x8d\r\x8b\x8dL̑b\xf6\x14/2\xaf\xb1\xcc+P\xb6", + n: 8, + items: []nodeInfo{{ + id: "b49964cf29aa11148d0d8b8d4ccc9162f6142f32", + addr: "175.177.204.43:20662", + },{ + id: "b49964cf29aa11148d0d8b8d4ccc9162f6142f32", + addr: "175.177.204.43:20662", + }}, + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + t.Logf("input length=%d", len([]byte(tt.in))) + var cnl CompactNodeList + if err := cnl.UnmarshalBencode([]byte(tt.in)); err != nil { + t.Fatal(err) + } + if len(cnl.Nodes) != tt.n { + t.Fatalf("got %d, want %d", len(cnl.Nodes), tt.n) + } + for i, ni := range tt.items { + id := cnl.Nodes[i].ID + if id.String() != ni.id { + t.Fatalf("got %q, want %q", id.String(), ni.id) + } + addr := cnl.Nodes[i] + if addr.AddrPort.String() != ni.addr { + t.Fatalf("got %q, want %q", addr.AddrPort.String(), ni.addr) + } + } + }) + } + +} diff --git a/dht/krpc.go b/dht/krpc.go new file mode 100644 index 0000000..1c9ee4a --- /dev/null +++ b/dht/krpc.go @@ -0,0 +1,259 @@ +package dht + +import ( + "fmt" + + "userspace.com.au/dhtsearch/bencode" + "userspace.com.au/dhtsearch/infohash" +) + +type Msg struct { + // Query method + // one of: "ping", "find_node", "get_peers", "announce_peer" + Query string `bencode:"q,omitzero"` + + // Arguments sent with a query + Args *MsgArgs `bencode:"a,omitzero"` + + // Transaction ID, required + TID string `bencode:"t"` + + // Type of message, required + // one of: q for QUERY, r for RESPONSE, e for ERROR + 
Type string `bencode:"y"` + + // Response payload for type 'r' + Response *Response `bencode:"r,omitzero"` + + // Error payload for type 'e' + Error *Error `bencode:"e,omitzero"` + + //IP NodeAddr `bencode:"ip,omitzero"` + // Node IP address, required for BEP42 + IP string `bencode:"ip,omitzero"` + + // Sender does not respond to queries, BEP43 + ReadOnly bool `bencode:"ro,omitzero"` + + // https://www.libtorrent.org/dht_extensions.html + ClientId string `bencode:"v,omitzero"` +} + +type Response struct { + ID *infohash.ID `bencode:"id"` + + // K closest nodes + // from get_peers, find_nodes, get & sample_infohashes + NodesList *CompactNodeList `bencode:"nodes,omitzero"` + Nodes6List *CompactNode6List `bencode:"nodes6,omitzero"` + + // Token for future announce_peer or put, BEP44 + Token string + + // Torrent peers + Values []NodeAddr + + // BEP33 (scrapes) + // BFsd *ScrapeBloomFilter `bencode:"BFsd,omitzero"` + // BFpe *ScrapeBloomFilter `bencode:"BFpe,omitzero"` + + // BEP51 + Interval int64 + Num int64 + // Nodes supporting this extension should always include the samples field in the response, even + // when it is zero-length. This lets indexing nodes to distinguish nodes supporting this + // extension from those that respond to unknown query types which contain a target field. 
+ Samples []byte `bencode:"samples,omitzero"` +} + +const ( + Want6 = "n6" + Want4 = "n4" +) + +type MsgArgs struct { + // ID of the querying Node + ID *infohash.ID `bencode:"id,omitzero"` + + // InfoHash of the torrent + InfoHash *infohash.ID `bencode:"info_hash,omitzero"` + + // ID of the node sought + Target *infohash.ID `bencode:"target,omitzero"` + + // Token received from an earlier get_peers query + // Also used in a BEP44 put + Token string `bencode:"token,omitzero"` + + // Sender's torrent port + Port int `bencode:"port,omitzero"` + + // Use senders apparent DHT port + ImpliedPort bool `bencode:"implied_port,omitzero"` + + // Network family wanted, any of "n4" and "n6", BEP32 + Want []string `bencode:"want,omitzero"` + + // BEP33 + // NoSeed int `bencode:"noseed,omitzero"` + // Scrape int `bencode:"scrape,omitzero"` + + // BEP44 + //V any `bencode:"v,omitzero"` + // Seq *int64 `bencode:"seq,omitzero"` + // Cas int64 `bencode:"cas,omitzero"` + // K [32]byte `bencode:"k,omitzero"` + // Salt []byte `bencode:"salt,omitzero"` + // Sig [64]byte `bencode:"sig,omitzero"` +} + +type Error struct { + Code int64 + Msg string +} + +func (e *Error) UnmarshalBencode(b []byte) error { + r := bencode.NewReaderFromBytes(b) + ok := r.ReadList(func(r *bencode.Reader) bool { + if !r.ReadInt(&e.Code) { + return false + } + if !r.ReadString(&e.Msg) { + return false + } + return true + }) + if !ok { + return fmt.Errorf("error unmarshal failed: %w", r.Err()) + } + return nil +} + +func NewFindNodeQuery(id, target infohash.ID, wants []string) *Msg { + return &Msg{ + Query: "find_node", + Type: "q", + Args: &MsgArgs{ + // The querying node + ID: &id, + // The ID sought + Target: &target, + Want: wants, + }, + } +} + +func NewPingQuery(id infohash.ID) *Msg { + return &Msg{ + Query: "ping", + Args: &MsgArgs{ + ID: &id, + }, + } +} + +// func NewGetPeersQuery(id infohash.ID) *Msg { +// msg := &Msg{ +// Query: "get_peers", +// Type: "q", +// Args: &MsgArgs{ +// ID: &c.node.ID, +// 
InfoHash: &ih, +// }, +// } +// return &Msg{ +// Query: "ping", +// TID: newTransactionID(), +// Args: &MsgArgs{ +// ID: &id, +// }, +// } +// } + +// func NewSampleInfohashesQuery(id, target infohash.ID) *Msg { +// return &Msg{ +// Query: "sample_infohashes", +// TID: newTransactionID(), +// Type: "q", +// Args: &MsgArgs{ +// // The querying node +// ID: &id, +// // The ID sought +// Target: &target, +// }, +// } +// } + +// const ( +// transIDBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +// ) + +// makeQuery returns a query-formed data. +// func MakeQuery(tid, query string, data map[string]any) map[string]any { +// return map[string]any{ +// "t": tid, +// "y": "q", +// "q": query, +// "a": data, +// } +// } + +// makeResponse returns a response-formed data. +// func MakeResponse(tid string, data map[string]any) map[string]any { +// return map[string]any{ +// "t": tid, +// "y": "r", +// "r": data, +// } +// } + +// func DecodeCompactNodeAddr(cni string) string { +// if len(cni) == 6 { +// return fmt.Sprintf("%d.%d.%d.%d:%d", cni[0], cni[1], cni[2], cni[3], (uint16(cni[4])<<8)|uint16(cni[5])) +// } else if len(cni) == 18 { +// b := []byte(cni[:16]) +// return fmt.Sprintf("[%s]:%d", net.IP.String(b), (uint16(cni[16])<<8)|uint16(cni[17])) +// } else { +// return "" +// } +// } + +// func EncodeCompactNodeAddr(addr string) string { +// var a []uint8 +// host, port, _ := net.SplitHostPort(addr) +// ip := net.ParseIP(host) +// if ip == nil { +// return "" +// } +// aa, _ := strconv.ParseUint(port, 10, 16) +// c := uint16(aa) +// if ip2 := net.IP.To4(ip); ip2 != nil { +// a = make([]byte, net.IPv4len+2, net.IPv4len+2) +// copy(a, ip2[0:net.IPv4len]) // ignore bytes IPv6 bytes if it's IPv4. 
+// a[4] = byte(c >> 8) +// a[5] = byte(c) +// } else { +// a = make([]byte, net.IPv6len+2, net.IPv6len+2) +// copy(a, ip) +// a[16] = byte(c >> 8) +// a[17] = byte(c) +// } +// return string(a) +// } + +// func int2bytes(val int64) []byte { +// data, j := make([]byte, 8), -1 +// for i := 0; i < 8; i++ { +// shift := uint64((7 - i) * 8) +// data[i] = byte((val & (0xff << shift)) >> shift) + +// if j == -1 && data[i] != 0 { +// j = i +// } +// } + +// if j != -1 { +// return data[j:] +// } +// return data[:1] +// } diff --git a/dht/krpc_test.go b/dht/krpc_test.go new file mode 100644 index 0000000..99f5fdf --- /dev/null +++ b/dht/krpc_test.go @@ -0,0 +1,106 @@ +package dht + +import ( + "testing" + + "userspace.com.au/dhtsearch/bencode" + "userspace.com.au/dhtsearch/infohash" +) + +func TestKRPCMsg(t *testing.T) { + id := infohash.ID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}) + tests := []struct { + in Msg + out string + }{{ + in: Msg{Args: &MsgArgs{ID: &id}}, + out: "d1:ad2:id20:\x01\x02\x03\x04\x05\x06\a\b\t\n\v\f\r\x0e\x0f\x10\x11\x12\x13\x14e1:t0:1:y0:e", + }, { + in: Msg{Args: &MsgArgs{Want: []string{"n4", "n6"}}}, + out: "d1:ad4:wantl2:n42:n6ee1:t0:1:y0:e", + }} + + for _, tt := range tests { + t.Run("marshal", func(t *testing.T) { + t.Logf("pre msg: %v", tt.in) + b, err := bencode.Marshal(tt.in) + if err != nil { + t.Fatal(err) + } + if string(b) != tt.out { + t.Fatalf("got %q, want %q", string(b), tt.out) + } + }) + // t.Run("unmarshal", func(t *testing.T) { + // var got Msg + // err := bencode.Unmarshal([]byte(tt.out), &got) + // if err != nil { + // t.Fatal(err) + // } + // if got != tt.in { + // t.Fatalf("got %v, want %v", got, tt.in) + // } + // }) + } +} + +func TestError(t *testing.T) { + in := "d1:eli201e17:too many requestse1:t2:rz1:y1:re" + var m Msg + if err := bencode.Unmarshal([]byte(in), &m); err != nil { + t.Fatal(err) + } + if m.Error.Code != 201 { + t.Errorf("got %v, want %d", m.Error.Code, 201) + } + if 
m.Error.Msg != "too many requests" { + t.Errorf("got %v, want %d", m.Error.Code, 201) + } +} + +// func TestCompactNode(t *testing.T) { +// ih := "infohashinfohash1234" +// idIn, _ := infohash.FromString(ih) +// tests := []struct { +// ip string +// port uint16 +// }{ +// {ip: "127.0.0.1", port: 6881}, +// {ip: "[2404:6800:4006:814::200e]", port: 6881}, +// } +// +// type compactNode interface { +// ID() (*infohash.ID, error) +// AddrPort() (netip.AddrPort, error) +// } +// +// for _, tt := range tests { +// apIn, err := netip.ParseAddrPort(fmt.Sprintf("%s:%d", tt.ip, tt.port)) +// if err != nil { +// panic(err) +// } +// s := NewCompactNodeInfo(*idIn, apIn.Addr(), tt.port) +// t.Logf("CompactNodeInfo: %d %v", len(s), []byte(s)) +// +// var cni compactNode +// if apIn.Addr().Is6() { +// cni = CompactNode6Info(s) +// } else { +// cni = CompactNodeInfo(s) +// } +// ap, err := cni.AddrPort() +// if err != nil { +// t.Error(err) +// } +// if ap != apIn { +// t.Errorf("got %v, want %v", ap, apIn) +// } +// id, err := cni.ID() +// if err != nil { +// t.Error(err) +// } +// if !id.Equal(*idIn) { +// t.Errorf("got %v, want %v", id, idIn) +// } +// } +// } diff --git a/dht/ktable.go b/dht/ktable.go new file mode 100644 index 0000000..76f9fdd --- /dev/null +++ b/dht/ktable.go @@ -0,0 +1,411 @@ +package dht + +import ( + "bufio" + "bytes" + "context" + "io" + "log/slog" + "slices" + "sync" + "time" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + + "userspace.com.au/dhtsearch/bencode" + "userspace.com.au/dhtsearch/infohash" +) + +type Table struct { + sync.RWMutex + + // Local node infohash + id infohash.ID + + // Root bucket for the tree + root *bucket + + log *slog.Logger + + // BEP5: Each bucket can only hold K nodes, currently eight, before + // becoming "full." 
+ k int +} + +type bucket struct { + nodes []*Node + left *bucket + right *bucket + dontSplit bool + + // BEP5: Each bucket should maintain a "last changed" property to + // indicate how "fresh" the contents are. When a node in a bucket is + // pinged and it responds, or a node is added to a bucket, or a node in + // a bucket is replaced with another node, the bucket's last changed + // property should be updated. Buckets that have not been changed in + // 15 minutes should be "refreshed." This is done by picking a random + // ID in the range of the bucket and performing a find_nodes search on it. + lastChanged time.Time +} + +var _ RoutingTable = (*Table)(nil) + +func NewTable(id infohash.ID) *Table { + out := &Table{ + id: id, + root: newBucket(), + k: 8, + } + return out +} + +func (t *Table) ID() infohash.ID { + return t.id +} + +func (t *Table) Add(n *Node) bool { + if t.id.Equal(n.ID) { + t.log.Debug("not adding self") + return false + } + t.Lock() + defer t.Unlock() + return t.addNoLock(n) +} + +// Potentially recursive +func (t *Table) addNoLock(n *Node) bool { + log := t.log.With("id", n.ID, "addr", n.AddrPort) + b, bitIndex := t.locateBucket(n.ID) + if b.has(n.ID) { + //log.Debug("node exists") + return false + } + if len(b.nodes) < t.k { + b.add(n) + ktableCalls.Add( + context.TODO(), 1, + metric.WithAttributes(attribute.String("method", "add"))) + log.Debug("added node") + return true + } + if b.dontSplit { + toReplace := b.mostQuestionable(2) + if toReplace != nil { + log.Debug( + "replacing node", + "old", toReplace.AddrPort, + "attempts", toReplace.PingAttempts, + "contact", toReplace.LastContact) + b.remove(toReplace.ID) + b.add(n) + ktableCalls.Add( + context.TODO(), 1, + metric.WithAttributes(attribute.String("method", "replace"))) + return true + } + //log.Debug("bucket is full") + return false + } + + // BEP5: When a bucket is full of known good nodes, no more nodes may be + // added unless our own node ID falls within the range of the bucket. 
In + // that case, the bucket is replaced by two new buckets each with half the + // range of the old bucket and the nodes from the old bucket are + // distributed among the two new ones. For a new table with only one + // bucket, the full bucket is always split into two new buckets covering + // the ranges 0..2^159 and 2^159..2^160. + b.split(bitIndex) + b.farChild(t.id, bitIndex).dontSplit = true + return t.addNoLock(n) +} + +func (t *Table) Has(id infohash.ID) bool { + t.RLock() + defer t.RUnlock() + return t.hasNoLock(id) +} + +func (t *Table) hasNoLock(id infohash.ID) bool { + b, _ := t.locateBucket(id) + return b.has(id) +} + +func (t *Table) Remove(id infohash.ID) bool { + t.Lock() + defer t.Unlock() + return t.removeNoLock(id) +} + +func (t *Table) removeNoLock(id infohash.ID) bool { + b, _ := t.locateBucket(id) + ktableCalls.Add( + context.TODO(), 1, + metric.WithAttributes(attribute.String("method", "remove"))) + return b.remove(id) +} + +func (t *Table) Count() int { + t.RLock() + defer t.RUnlock() + n := 0 + for _, bucket := range nonEmptyBuckets(t.root) { + n += len(bucket.nodes) + } + return n +} + +// Seen clears a node's questionable flag and updates contact time. 
+func (t *Table) SetSeen(id infohash.ID) { + t.Lock() + defer t.Unlock() + t.seenNoLock(id) +} + +func (t *Table) seenNoLock(id infohash.ID) { + bucket, _ := t.locateBucket(id) + bucket.update() + if n := bucket.find(id); n != nil { + if n.PingAttempts > 0 { + t.log.Debug("reseting questionable status", "address", n.AddrPort) + } + n.PingAttempts = 0 + n.LastContact = time.Now() + } +} + +func (t *Table) GetClosest(target infohash.ID, limit int) []*Node { + t.RLock() + defer t.RUnlock() + ktableCalls.Add( + context.TODO(), 1, + metric.WithAttributes(attribute.String("method", "closest"))) + bitIndex := 0 + buckets := []*bucket{t.root} + nodes := make([]*Node, 0, limit) + var bucket *bucket + for len(buckets) > 0 && len(nodes) < limit { + bucket, buckets = buckets[len(buckets)-1], buckets[:len(buckets)-1] + if bucket.nodes == nil { + near := bucket.nearChild(target, bitIndex) + far := bucket.farChild(target, bitIndex) + buckets = append(buckets, far, near) + bitIndex++ + } else { + nodes = append(nodes, bucket.nodes...) 
+ } + } + // We have less than requested + if length := len(nodes); limit > length { + limit = length + } + slices.SortFunc(nodes, func(a, b *Node) int { + aDist := target.Xor(a.ID) + bDist := target.Xor(b.ID) + return bytes.Compare(aDist, bDist) + }) + return nodes[:limit] +} + +func (t *Table) GetStale(limit int, d time.Duration) []*Node { + t.RLock() + defer t.RUnlock() + nodes := make([]*Node, 0) + ktableCalls.Add( + context.TODO(), 1, + metric.WithAttributes(attribute.String("method", "stale"))) + buckets := nonEmptyBuckets(t.root) + slices.SortFunc(buckets, func(a, b *bucket) int { + // Oldest first + return b.lastChanged.Compare(a.lastChanged) + }) + for _, b := range buckets { + for _, n := range b.nodes { + if len(nodes) < limit && time.Since(n.LastContact) > d { + nodes = append(nodes, n) + } + } + } + return nodes +} + +// recursive +func nonEmptyBuckets(b *bucket) []*bucket { + if b == nil { + return nil + } + var out []*bucket + if len(b.nodes) > 0 { + out = append(out, b) + } + out = append(out, nonEmptyBuckets(b.left)...) + out = append(out, nonEmptyBuckets(b.right)...) 
+ return out +} + +func (t *Table) locateBucket(id infohash.ID) (bucket *bucket, bitIndex int) { + bucket = t.root + for bucket.nodes == nil { + bucket = bucket.nearChild(id, bitIndex) + bitIndex++ + } + return +} + +func newBucket() *bucket { + return &bucket{ + nodes: make([]*Node, 0), + //lastChanged: time.Now(), + } +} + +func (b *bucket) update() { + b.lastChanged = time.Now() +} + +func (b *bucket) add(n *Node) { + b.nodes = append(b.nodes, n) + b.update() +} + +func (b *bucket) split(bitIndex int) { + b.left = newBucket() + b.right = newBucket() + for _, n := range b.nodes { + b.nearChild(n.ID, bitIndex).add(n) + } + b.nodes = nil +} + +func (b *bucket) has(id infohash.ID) bool { + return b.indexOf(id) >= 0 +} + +func (b *bucket) nearChild(id infohash.ID, bitIndex int) *bucket { + bitIndexWithinByte := bitIndex % 8 + desiredByte := id[bitIndex/8] + if desiredByte&(1<<(uint(7-bitIndexWithinByte))) == 1 { + return b.right + } + return b.left +} + +func (b *bucket) farChild(id infohash.ID, bitIndex int) *bucket { + if c := b.nearChild(id, bitIndex); c == b.right { + return b.left + } + return b.right +} + +func (b *bucket) remove(id infohash.ID) bool { + var out bool + b.nodes = slices.DeleteFunc(b.nodes, func(n *Node) bool { + out = n.ID.Equal(id) + return out + }) + return out +} + +func (b *bucket) indexOf(id infohash.ID) int { + for i, c := range b.nodes { + if id.Equal(c.ID) { + return i + } + } + return -1 +} + +func (b *bucket) find(id infohash.ID) *Node { + if index := b.indexOf(id); index > -1 { + return b.nodes[index] + } + return nil +} + +// mostQuestionable finds the node with the most ping attempts +func (b *bucket) mostQuestionable(attempts int) *Node { + var out *Node + for _, n := range b.nodes { + if n.PingAttempts > attempts { + out = n + attempts = n.PingAttempts + } + } + return out +} + +func (t *Table) WriteTo(w io.Writer) (int64, error) { + t.RLock() + defer t.RUnlock() + + var c int64 + var n int + var err error + + // Infohash first + 
b, err := bencode.EncodeString(string(t.id[:])) + if err != nil { + return c, err + } + n, err = w.Write(b) + if err != nil { + return c, err + } + c += int64(n) + t.log.Debug("exporting table", "id", t.id) + + // This should write n4 nodes correctly too + var l int + for _, bu := range nonEmptyBuckets(t.root) { + for _, node := range bu.nodes { + cni, err := node.MarshalBinary() + if err != nil { + return c, err + } + b, err := bencode.EncodeString(string(cni)) + if err != nil { + return c, err + } + n, err = w.Write(b) + if err != nil { + return c, err + } + l += 1 + c += int64(n) + t.log.Debug("exported node", "id", node.ID, "address", node.AddrPort) + } + } + t.log.Info("exported table", "id", t.id, "count", l) + return c, err +} + +func (t *Table) ReadFrom(r io.Reader) (int64, error) { + t.Lock() + defer t.Unlock() + + br := bencode.NewReader(bufio.NewReader(r)) + var ih string + if !br.ReadString(&ih) { + return br.Count(), br.Err() + } + t.id = infohash.ID([]byte(ih)) + t.log.Debug("importing table", "id", t.id) + var nodes []Node + for br.Err() == nil { + var cni string + if br.ReadString(&cni) { + var node Node + if err := node.UnmarshalBinary([]byte(cni)); err != nil { + return br.Count(), err + } + nodes = append(nodes, node) + t.log.Debug("imported node", "id", node.ID, "address", node.AddrPort) + } + } + for _, node := range nodes { + t.addNoLock(&node) + } + t.log.Info("imported table", "id", t.id, "count", len(nodes)) + return br.Count(), nil +} diff --git a/dht/ktable_test.go b/dht/ktable_test.go new file mode 100644 index 0000000..266c1af --- /dev/null +++ b/dht/ktable_test.go @@ -0,0 +1,106 @@ +package dht + +import ( + "bytes" + "net/netip" + "testing" + + "github.com/neilotoole/slogt" + + "userspace.com.au/dhtsearch/infohash" +) + +func TestTable(t *testing.T) { + id := bytes.Repeat([]byte{0}, 20) + tbl := NewTable(infohash.ID(id)) + tbl.log = slogt.New(t) + + var nodes []*Node + + for range 100 { + n := &Node{ + ID: infohash.NewRandomID(), + } 
+ if tbl.Add(n) { + nodes = append(nodes, n) + } + } + + t.Run("updating", func(t *testing.T) { + t.Parallel() + for _, n := range nodes { + tbl.SetSeen(n.ID) + } + }) + t.Run("lookup", func(t *testing.T) { + t.Parallel() + for _, n := range nodes { + tbl.Has(n.ID) + } + }) + t.Run("remove add", func(t *testing.T) { + t.Parallel() + for _, n := range nodes { + tbl.Remove(n.ID) + tbl.Add(n) + } + }) +} + +func TestTableExport(t *testing.T) { + nodes := []Node{ + { + ID: infohash.MustParseString("0000000000000000000000000000000000000000"), + AddrPort: netip.MustParseAddrPort("[2001:19f0:5:6d01:5400:2ff:feec:644a]:6881")}, + { + ID: infohash.MustParseString("c7fb3eab2d33331610075ed0322f6cebc5351aed"), + AddrPort: netip.MustParseAddrPort("[2607:9000:3000:33:d5fe:c61c:6adc:6b79]:17980")}, + { + ID: infohash.MustParseString("c7fba7fc5ec5dfd49cff98b31b68e37b0a08fd84"), + AddrPort: netip.MustParseAddrPort("[240e:3b3:9613:52f0:215:5dff:fe0a:8971]:61128")}, + { + ID: infohash.MustParseString("c7fc64a9ebb14481981529ba200a40d00643eb2a"), + AddrPort: netip.MustParseAddrPort("[2001:da8:e000:3002:20c:29ff:fea3:7226]:10301")}, + { + ID: infohash.MustParseString("c5198df3d46a32799ac34436a4737d7b3bb2ff6e"), + AddrPort: netip.MustParseAddrPort("[240e:b8f:966e:9e00::bf9]:63219")}, + { + ID: infohash.MustParseString("c1ea1df5ce162836c2080bdc94a127a90cfad24a"), + AddrPort: netip.MustParseAddrPort("[2a01:e0a:d5b:c690:211:32ff:fe26:7445]:51413")}, + { + ID: infohash.MustParseString("c17f24479d8c5ad3ad7a731a600ff323ecb26ebd"), + AddrPort: netip.MustParseAddrPort("[2001:e68:542f:5959:11b7:2243:527d:1c86]:19264")}, + } + id := infohash.MustParseString("c4bee4bf16dd88527f63cab902b1111111111111") + tbl := NewTable(id) + tbl.log = slogt.New(t) + for _, n := range nodes { + tbl.Add(&n) + } + if tbl.Count() != len(nodes) { + t.Fatalf("got %d, want %d", tbl.Count(), len(nodes)) + } + + buf := new(bytes.Buffer) + n1, err := tbl.WriteTo(buf) + if err != nil { + t.Fatal(err) + } + tbl2 := &Table{ + 
root: newBucket(), + k: 8, + } + tbl2.log = slogt.New(t) + n2, err := tbl2.ReadFrom(buf) + if err != nil { + t.Fatal(err) + } + if n1 != n2 { + t.Errorf("got %d, want %d", n2, n1) + } + for _, n := range nodes { + if !tbl.Has(n.ID) { + t.Errorf("missing id %s", n.ID) + } + } +} diff --git a/dht/messages.go b/dht/messages.go deleted file mode 100644 index 188ed0f..0000000 --- a/dht/messages.go +++ /dev/null @@ -1,127 +0,0 @@ -package dht - -import ( - "fmt" - "net" - - "src.userspace.com.au/dhtsearch/krpc" - "src.userspace.com.au/dhtsearch/models" -) - -func (n *Node) onPingQuery(rn remoteNode, msg map[string]interface{}) error { - t, err := krpc.GetString(msg, "t") - if err != nil { - return err - } - n.queueMsg(rn, krpc.MakeResponse(t, map[string]interface{}{ - "id": string(n.id), - })) - return nil -} - -func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) error { - a, err := krpc.GetMap(msg, "a") - if err != nil { - return err - } - - // This is the ih of the torrent - torrent, err := krpc.GetString(a, "info_hash") - if err != nil { - return err - } - th, err := models.InfohashFromString(torrent) - if err != nil { - return err - } - //n.log.Debug("get_peers query", "source", rn, "torrent", th) - - token := torrent[:2] - neighbour := models.GenerateNeighbour(n.id, *th) - /* - nodes := n.rTable.get(8) - compactNS := []string{} - for _, rn := range nodes { - ns := encodeCompactNodeAddr(rn.addr.String()) - if ns == "" { - n.log.Warn("failed to compact node", "address", rn.address.String()) - continue - } - compactNS = append(compactNS, ns) - } - */ - - t := msg["t"].(string) - n.queueMsg(rn, krpc.MakeResponse(t, map[string]interface{}{ - "id": string(neighbour), - "token": token, - "nodes": "", - //"nodes": strings.Join(compactNS, ""), - })) - - //nodes := n.rTable.get(50) - /* - fmt.Printf("sending get_peers for %s to %d nodes\n", *th, len(nodes)) - q := krpc.MakeQuery(newTransactionID(), "get_peers", map[string]interface{}{ - "id": string(id), 
- "info_hash": string(*th), - }) - for _, o := range nodes { - n.queueMsg(*o, q) - } - */ - return nil -} - -func (n *Node) onAnnouncePeerQuery(rn remoteNode, msg map[string]interface{}) error { - a, err := krpc.GetMap(msg, "a") - if err != nil { - return err - } - - n.log.Debug("announce_peer", "source", rn) - - host, port, err := net.SplitHostPort(rn.addr.String()) - if err != nil { - return err - } - if port == "0" { - return fmt.Errorf("ignoring port 0") - } - - ihStr, err := krpc.GetString(a, "info_hash") - if err != nil { - return err - } - ih, err := models.InfohashFromString(ihStr) - if err != nil { - return fmt.Errorf("invalid torrent: %s", err) - } - - newPort, err := krpc.GetInt(a, "port") - if err == nil { - if iPort, err := krpc.GetInt(a, "implied_port"); err == nil && iPort == 0 { - // Use the port in the message - addr, err := net.ResolveUDPAddr(n.family, fmt.Sprintf("%s:%d", host, newPort)) - if err != nil { - return err - } - n.log.Debug("implied port", "infohash", ih, "original", rn.addr.String(), "new", addr.String()) - rn = remoteNode{addr: addr, id: rn.id} - } - } - - // TODO do we reply? 
- - p := models.Peer{Addr: rn.addr, Infohash: *ih} - if n.OnAnnouncePeer != nil { - go n.OnAnnouncePeer(p) - } - return nil -} - -func (n *Node) onFindNodeResponse(rn remoteNode, msg map[string]interface{}) { - r := msg["r"].(map[string]interface{}) - nodes := r["nodes"].(string) - n.processFindNodeResults(rn, nodes) -} diff --git a/dht/metrics.go b/dht/metrics.go new file mode 100644 index 0000000..da708ba --- /dev/null +++ b/dht/metrics.go @@ -0,0 +1,76 @@ +package dht + +import ( + "context" + "time" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/metric" +) + +var meter = otel.Meter("userspace.com.au/dhtsearch/dht") + +var ( + netOctets metric.Int64Counter + netPackets metric.Int64Counter +) + +var ktableCalls metric.Int64Counter + +func init() { + var err error + netOctets, err = meter.Int64Counter( + "network.octets", + metric.WithDescription("The number of bytes."), + metric.WithUnit("{byte}"), + ) + if err != nil { + panic(err) + } + netPackets, err = meter.Int64Counter( + "bytes.received", + metric.WithDescription("The number of packets."), + metric.WithUnit("{byte}"), + ) + if err != nil { + panic(err) + } + ktableCalls, err = meter.Int64Counter( + "ktable.calls", + metric.WithDescription("Number of calls to ktable."), + metric.WithUnit("{call}"), + ) + if err != nil { + panic(err) + } +} + +func configureMetrics(c *Client) error { + var err error + start := time.Now() + if _, err = meter.Float64ObservableCounter( + "uptime", + metric.WithDescription("The running duration."), + metric.WithUnit("s"), + metric.WithFloat64Callback(func(_ context.Context, o metric.Float64Observer) error { + o.Observe(float64(time.Since(start).Seconds())) + return nil + }), + ); err != nil { + panic(err) + } + + _, err = meter.Int64ObservableUpDownCounter( + "ktable.size", + metric.WithDescription("The number of nodes in the ktable."), + metric.WithUnit("{nodes}"), + metric.WithInt64Callback(func(_ context.Context, o metric.Int64Observer) error { + 
o.Observe(int64(c.table.Count())) + return nil + }), + ) + if err != nil { + return err + } + return nil +} diff --git a/dht/node.go b/dht/node.go index 2ea5f2f..437f13d 100644 --- a/dht/node.go +++ b/dht/node.go @@ -1,427 +1,76 @@ package dht import ( - "context" - "fmt" - "net" + "errors" + "net/netip" + "sync" "time" - "github.com/hashicorp/golang-lru" - "golang.org/x/time/rate" - "src.userspace.com.au/dhtsearch/krpc" - "src.userspace.com.au/dhtsearch/models" - "src.userspace.com.au/go-bencode" - "src.userspace.com.au/logger" + "userspace.com.au/dhtsearch/infohash" ) -var ( - routers = []string{ - "dht.libtorrent.org:25401", - "router.bittorrent.com:6881", - "dht.transmissionbt.com:6881", - "router.utorrent.com:6881", - "dht.aelitis.com:6881", - } -) - -// Node joins the DHT network type Node struct { - id models.Infohash - family string - address string - port int - conn net.PacketConn - pool chan chan packet - rTable *routingTable - udpTimeout int - packetsOut chan packet - log logger.Logger - limiter *rate.Limiter - blacklist *lru.ARCCache - - // OnAnnoucePeer is called for each peer that announces itself - OnAnnouncePeer func(models.Peer) - // OnBadPeer is called for each bad peer - OnBadPeer func(models.Peer) -} - -// NewNode creates a new DHT node -func NewNode(opts ...Option) (*Node, error) { - var err error - id := models.GenInfohash() - - n := &Node{ - id: id, - family: "udp4", - port: 6881, - udpTimeout: 10, - limiter: rate.NewLimiter(rate.Limit(100000), 2000000), - log: logger.New(&logger.Options{Name: "dht"}), - } - - n.rTable, err = newRoutingTable(id, 2000) - if err != nil { - n.log.Error("failed to create routing table", "error", err) - return nil, err - } - - // Set variadic options passed - for _, option := range opts { - err = option(n) - if err != nil { - return nil, err - } - } - - if n.blacklist == nil { - n.blacklist, err = lru.NewARC(1000) - if err != nil { - return nil, err - } - } - - if n.family != "udp4" { - n.log.Debug("trying udp6 
server") - n.conn, err = net.ListenPacket("udp6", fmt.Sprintf("[%s]:%d", net.IPv6zero.String(), n.port)) - if err == nil { - n.family = "udp6" - } - } - if n.conn == nil { - n.conn, err = net.ListenPacket("udp4", fmt.Sprintf("%s:%d", net.IPv4zero.String(), n.port)) - if err == nil { - n.family = "udp4" - } - } - if err != nil { - n.log.Error("failed to listen", "error", err) - return nil, err - } - n.log.Info("listening", "id", n.id, "network", n.family, "address", n.conn.LocalAddr().String()) - - return n, nil -} - -// Close stuff -func (n *Node) Close() error { - n.log.Warn("node closing") - return nil -} - -// Run starts the node on the DHT -func (n *Node) Run() { - // Packets onto the network - n.packetsOut = make(chan packet, 1024) - - // Create a slab for allocation - byteSlab := newSlab(8192, 10) + sync.Mutex + ID infohash.ID + AddrPort netip.AddrPort + Secure bool + Family string - n.log.Debug("starting packet writer") - go n.packetWriter() + // To assist with rate limiting + NextQuery map[string]time.Time - // Find neighbours - go n.makeNeighbours() + // Incremented when selected for ping, cleared on pong + PingAttempts int - n.log.Debug("starting packet reader") - for { - b := byteSlab.alloc() - c, addr, err := n.conn.ReadFrom(b) - if err != nil { - n.log.Warn("UDP read error", "error", err) - return - } - - // Chop and process - n.processPacket(packet{ - data: b[0:c], - raddr: addr, - }) - byteSlab.free(b) - } + // Used by table + LastContact time.Time } -func (n *Node) makeNeighbours() { - // TODO configurable - ticker := time.Tick(5 * time.Second) - - n.bootstrap() - - for { - select { - case <-ticker: - if n.rTable.isEmpty() { - n.bootstrap() - } else { - // Send to all nodes - nodes := n.rTable.get(0) - for _, rn := range nodes { - n.findNode(rn, models.GenerateNeighbour(n.id, rn.id)) - } - n.rTable.flush() - } - } - } +func (n *Node) String() string { + return n.AddrPort.String() } -func (n *Node) bootstrap() { - n.log.Debug("bootstrapping") - for 
_, s := range routers { - addr, err := net.ResolveUDPAddr(n.family, s) - if err != nil { - n.log.Error("failed to parse bootstrap address", "error", err) - continue - } - rn := &remoteNode{addr: addr} - n.findNode(rn, n.id) +func (n *Node) canSend(query string) bool { + if len(n.NextQuery) == 0 { + return true } + n.Lock() + defer n.Unlock() + t, ok := n.NextQuery[query] + return !ok || time.Now().After(t) } -func (n *Node) packetWriter() { - for p := range n.packetsOut { - if p.raddr.String() == n.conn.LocalAddr().String() { - continue - } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - if err := n.limiter.WaitN(ctx, len(p.data)); err != nil { - n.log.Warn("rate limited", "error", err) - continue - } - //n.log.Debug("writing packet", "dest", p.raddr.String()) - _, err := n.conn.WriteTo(p.data, p.raddr) - if err != nil { - n.blacklist.Add(p.raddr.String(), true) - // TODO reduce limit - n.log.Warn("failed to write packet", "error", err) - if n.OnBadPeer != nil { - peer := models.Peer{Addr: p.raddr} - go n.OnBadPeer(peer) - } - } - } -} - -func (n *Node) findNode(rn *remoteNode, id models.Infohash) { - target := models.GenInfohash() - n.sendQuery(rn, "find_node", map[string]interface{}{ - "id": string(id), - "target": string(target), - }) -} - -// ping sends ping query to the chan. 
-func (n *Node) ping(rn *remoteNode) { - id := models.GenerateNeighbour(n.id, rn.id) - n.sendQuery(rn, "ping", map[string]interface{}{ - "id": string(id), - }) -} - -func (n *Node) sendQuery(rn *remoteNode, qType string, a map[string]interface{}) error { - // Stop if sending to self - if rn.id.Equal(n.id) { - return nil - } - - t := krpc.NewTransactionID() - - data := krpc.MakeQuery(t, qType, a) - b, err := bencode.Encode(data) - if err != nil { - return err - } - //fmt.Printf("sending %s to %s\n", qType, rn.String()) - n.packetsOut <- packet{ - data: b, - raddr: rn.addr, - } - return nil -} - -// Parse a KRPC packet into a message -func (n *Node) processPacket(p packet) error { - response, _, err := bencode.DecodeDict(p.data, 0) - if err != nil { - return err - } - - y, err := krpc.GetString(response, "y") - if err != nil { - return err - } - - if _, black := n.blacklist.Get(p.raddr.String()); black { - return fmt.Errorf("blacklisted: %s", p.raddr.String()) +func (n *Node) rateQuery(q string, s int64) { + n.Lock() + defer n.Unlock() + if len(n.NextQuery) == 0 { + n.NextQuery = make(map[string]time.Time) } - - switch y { - case "q": - err = n.handleRequest(p.raddr, response) - case "r": - err = n.handleResponse(p.raddr, response) - case "e": - err = n.handleError(p.raddr, response) - default: - err = fmt.Errorf("missing request type") + if s == 0 { + n.NextQuery[q] = time.Now() + return } - if err != nil { - n.log.Warn("failed to process packet", "error", err) - n.blacklist.Add(p.raddr.String(), true) + if s == -1 { + // TODO delay one day + s = 24 * 60 * 60 } - return err + n.NextQuery[q] = time.Now().Add(time.Duration(s) * time.Second) } -// bencode data and send -func (n *Node) queueMsg(rn remoteNode, data map[string]interface{}) error { - b, err := bencode.Encode(data) - if err != nil { - return err - } - n.packetsOut <- packet{ - data: b, - raddr: rn.addr, - } - return nil +func (n *Node) MarshalBinary() ([]byte, error) { + b, _ := 
NodeAddr(n.AddrPort).MarshalBinary() + return append(n.ID.Bytes(), b...), nil } -// handleRequest handles the requests received from udp. -func (n *Node) handleRequest(addr net.Addr, m map[string]interface{}) error { - q, err := krpc.GetString(m, "q") - if err != nil { - return err - } - - a, err := krpc.GetMap(m, "a") - if err != nil { - return err - } - - id, err := krpc.GetString(a, "id") - if err != nil { - return err - } - - ih, err := models.InfohashFromString(id) - if err != nil { - return err - } - - if n.id.Equal(*ih) { - return nil - } - - rn := &remoteNode{addr: addr, id: *ih} - - switch q { - case "ping": - err = n.onPingQuery(*rn, m) - - case "get_peers": - err = n.onGetPeersQuery(*rn, m) - - case "announce_peer": - n.onAnnouncePeerQuery(*rn, m) - - default: - //n.queueMsg(addr, makeError(t, protocolError, "invalid q")) - return nil - } - n.rTable.add(rn) - return err -} - -// handleResponse handles responses received from udp. -func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error { - r, err := krpc.GetMap(m, "r") - if err != nil { - return err - } - id, err := krpc.GetString(r, "id") - if err != nil { - return err - } - ih, err := models.InfohashFromString(id) - if err != nil { - return err - } - - rn := &remoteNode{addr: addr, id: *ih} - - nodes, err := krpc.GetString(r, "nodes") - // find_nodes/get_peers response with nodes - if err == nil { - n.onFindNodeResponse(*rn, m) - n.processFindNodeResults(*rn, nodes) - n.rTable.add(rn) +func (n *Node) UnmarshalBinary(b []byte) error { + n.ID = infohash.ID(b[:infohash.Length]) + var na NodeAddr + if err := na.UnmarshalBinary(b[infohash.Length:]); err != nil { return nil } - - values, err := krpc.GetList(r, "values") - // get_peers response - if err == nil { - n.log.Debug("get_peers response", "source", rn) - for _, v := range values { - addr := krpc.DecodeCompactNodeAddr(v.(string)) - n.log.Debug("unhandled get_peer request", "addres", addr) - - // TODO new peer needs to be matched 
to previous get_peers request - // n.peersManager.Insert(ih, p) - } - n.rTable.add(rn) + n.AddrPort = netip.AddrPort(na) + if !n.AddrPort.IsValid() || n.AddrPort.Port() == 0 { + return errors.New("invalid AddrPort") } return nil } - -// handleError handles errors received from udp. -func (n *Node) handleError(addr net.Addr, m map[string]interface{}) error { - e, err := krpc.GetList(m, "e") - if err != nil { - return err - } - - if len(e) != 2 { - return fmt.Errorf("error packet wrong length %d", len(e)) - } - code := e[0].(int64) - msg := e[1].(string) - n.log.Debug("error packet", "address", addr.String(), "code", code, "error", msg) - - return nil -} - -// Process another node's response to a find_node query. -func (n *Node) processFindNodeResults(rn remoteNode, nodeList string) { - nodeLength := krpc.IPv4NodeAddrLen - if n.family == "udp6" { - nodeLength = krpc.IPv6NodeAddrLen - } - - if len(nodeList)%nodeLength != 0 { - n.log.Error("node list is wrong length", "length", len(nodeList)) - n.blacklist.Add(rn.addr.String(), true) - return - } - - //fmt.Printf("%s sent %d nodes\n", rn.address.String(), len(nodeList)/nodeLength) - - // We got a byte array in groups of 26 or 38 - for i := 0; i < len(nodeList); i += nodeLength { - id := nodeList[i : i+models.InfohashLength] - addrStr := krpc.DecodeCompactNodeAddr(nodeList[i+models.InfohashLength : i+nodeLength]) - - ih, err := models.InfohashFromString(id) - if err != nil { - n.log.Warn("invalid infohash in node list") - continue - } - - addr, err := net.ResolveUDPAddr(n.family, addrStr) - if err != nil || addr.Port == 0 { - //n.log.Warn("unable to resolve", "address", addrStr, "error", err) - continue - } - - rn := &remoteNode{addr: addr, id: *ih} - n.rTable.add(rn) - } -} diff --git a/dht/node_test.go b/dht/node_test.go new file mode 100644 index 0000000..6926c1e --- /dev/null +++ b/dht/node_test.go @@ -0,0 +1,42 @@ +package dht + +import ( + "bytes" + "net/netip" + "testing" + + "userspace.com.au/dhtsearch/infohash" 
+) + +func TestNodeMarshaling(t *testing.T) { + tests := []Node{ + { + ID: infohash.MustParseString("0000000000000000000000000000000000000000"), + AddrPort: netip.MustParseAddrPort("[2001:19f0:5:6d01:5400:2ff:feec:644a]:6881")}, + { + ID: infohash.MustParseString("c4301bf4b0a4f83eb7afe1eeee8434fb1893cc15"), + AddrPort: netip.MustParseAddrPort("[2001:f90:4090:1230:2697:edff:fe27:d081]:50000")}, + { + ID: infohash.MustParseString("c17f24479d8c5ad3ad7a731a600ff323ecb26ebd"), + AddrPort: netip.MustParseAddrPort("[2001:e68:542f:5959:11b7:2243:527d:1c86]:19264")}, + } + + for _, tt := range tests { + t.Run(tt.ID.String(), func(t *testing.T) { + b, err := tt.MarshalBinary() + if err != nil { + t.Fatal(err) + } + n2 := new(Node) + if err := n2.UnmarshalBinary(b); err != nil { + t.Fatal(err) + } + if tt.AddrPort.Compare(n2.AddrPort) != 0 { + t.Fatalf("got %s, want %s", n2.AddrPort, tt.AddrPort) + } + if !bytes.Equal(n2.ID[:], tt.ID[:]) { + t.Fatalf("got %s, want %s", n2.ID, tt.ID) + } + }) + } +} diff --git a/dht/options.go b/dht/options.go deleted file mode 100644 index 19a5508..0000000 --- a/dht/options.go +++ /dev/null @@ -1,73 +0,0 @@ -package dht - -import ( - "github.com/hashicorp/golang-lru" - "src.userspace.com.au/dhtsearch/models" - "src.userspace.com.au/logger" -) - -type Option func(*Node) error - -func SetOnAnnouncePeer(f func(models.Peer)) Option { - return func(n *Node) error { - n.OnAnnouncePeer = f - return nil - } -} - -func SetOnBadPeer(f func(models.Peer)) Option { - return func(n *Node) error { - n.OnBadPeer = f - return nil - } -} - -// SetAddress sets the IP address to listen on -func SetAddress(ip string) Option { - return func(n *Node) error { - n.address = ip - return nil - } -} - -// SetPort sets the port to listen on -func SetPort(p int) Option { - return func(n *Node) error { - n.port = p - return nil - } -} - -// SetIPv6 enables IPv6 -func SetIPv6(b bool) Option { - return func(n *Node) error { - if b { - n.family = "udp6" - } - return nil - } 
-} - -// SetUDPTimeout sets the number of seconds to wait for UDP connections -func SetUDPTimeout(s int) Option { - return func(n *Node) error { - n.udpTimeout = s - return nil - } -} - -// SetLogger sets the logger -func SetLogger(l logger.Logger) Option { - return func(n *Node) error { - n.log = l - return nil - } -} - -// SetBlacklist sets the size of the node blacklist -func SetBlacklist(bl *lru.ARCCache) Option { - return func(n *Node) (err error) { - n.blacklist = bl - return err - } -} diff --git a/dht/packet.go b/dht/packet.go deleted file mode 100644 index 62c08fa..0000000 --- a/dht/packet.go +++ /dev/null @@ -1,33 +0,0 @@ -package dht - -import "net" - -// Arbitrary packet types -// Order these lowest to highest priority for use in -// priority queue heap -const ( - _ int = iota - pktQPing - pktRPing - pktQFindNode - pktRAnnouncePeer - pktRGetPeers -) - -var pktName = map[int]string{ - pktQFindNode: "find_node", - pktQPing: "ping", - pktRPing: "ping", - pktRAnnouncePeer: "annouce_peer", - pktRGetPeers: "get_peers", -} - -// Unprocessed packet from socket -type packet struct { - // The packet type - //priority int - // Required by heap interface - //index int - data []byte - raddr net.Addr -} diff --git a/dht/remote_node.go b/dht/remote_node.go deleted file mode 100644 index 0a5cafa..0000000 --- a/dht/remote_node.go +++ /dev/null @@ -1,18 +0,0 @@ -package dht - -import ( - "fmt" - "net" - - "src.userspace.com.au/dhtsearch/models" -) - -type remoteNode struct { - addr net.Addr - id models.Infohash -} - -// String implements fmt.Stringer -func (r remoteNode) String() string { - return fmt.Sprintf("%s (%s)", r.id.String(), r.addr.String()) -} diff --git a/dht/routing_table.go b/dht/routing_table.go deleted file mode 100644 index 37af542..0000000 --- a/dht/routing_table.go +++ /dev/null @@ -1,121 +0,0 @@ -package dht - -import ( - "container/heap" - "sync" - - "src.userspace.com.au/dhtsearch/models" -) - -type rItem struct { - value *remoteNode - distance int 
- index int // Index in heap -} - -type priorityQueue []*rItem - -type routingTable struct { - id models.Infohash - max int - items priorityQueue - addresses map[string]*remoteNode - sync.Mutex -} - -func newRoutingTable(id models.Infohash, max int) (*routingTable, error) { - k := &routingTable{ - id: id, - max: max, - } - k.flush() - heap.Init(&k.items) - return k, nil -} - -// Len implements sort.Interface -func (pq priorityQueue) Len() int { return len(pq) } - -// Less implements sort.Interface -func (pq priorityQueue) Less(i, j int) bool { - return pq[i].distance > pq[j].distance -} - -// Swap implements sort.Interface -func (pq priorityQueue) Swap(i, j int) { - pq[i], pq[j] = pq[j], pq[i] - pq[i].index = i - pq[j].index = j -} - -// Push implements heap.Interface -func (pq *priorityQueue) Push(x interface{}) { - n := len(*pq) - item := x.(*rItem) - item.index = n - *pq = append(*pq, item) -} - -// Pop implements heap.Interface -func (pq *priorityQueue) Pop() interface{} { - old := *pq - n := len(old) - item := old[n-1] - item.index = -1 // for safety - *pq = old[0 : n-1] - return item -} - -func (k *routingTable) add(rn *remoteNode) { - // Check IP and ports are valid and not self - if !rn.id.Valid() || rn.id.Equal(k.id) { - return - } - - k.Lock() - defer k.Unlock() - - if _, ok := k.addresses[rn.addr.String()]; ok { - return - } - k.addresses[rn.addr.String()] = rn - - item := &rItem{ - value: rn, - distance: k.id.Distance(rn.id), - } - - heap.Push(&k.items, item) - - if len(k.items) > k.max { - for i := k.max - 1; i < len(k.items); i++ { - old := k.items[i] - delete(k.addresses, old.value.addr.String()) - heap.Remove(&k.items, i) - } - } -} - -func (k *routingTable) get(n int) (out []*remoteNode) { - if n == 0 { - n = len(k.items) - } - for i := 0; i < n && i < len(k.items); i++ { - out = append(out, k.items[i].value) - } - return out -} - -func (k *routingTable) flush() { - k.Lock() - defer k.Unlock() - - k.items = make(priorityQueue, 0) - k.addresses = 
make(map[string]*remoteNode, k.max) -} - -func (k *routingTable) isEmpty() bool { - k.Lock() - defer k.Unlock() - return len(k.items) == 0 -} diff --git a/dht/routing_table_test.go b/dht/routing_table_test.go deleted file mode 100644 index 763df80..0000000 --- a/dht/routing_table_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package dht - -import ( - "fmt" - "net" - "testing" - - "src.userspace.com.au/dhtsearch/models" -) - -func TestPriorityQueue(t *testing.T) { - id := "d1c5676ae7ac98e8b19f63565905105e3c4c37a2" - - tests := []string{ - "d1c5676ae7ac98e8b19f63565905105e3c4c37b9", - "d1c5676ae7ac98e8b19f63565905105e3c4c37a9", - "d1c5676ae7ac98e8b19f63565905105e3c4c37a4", - "d1c5676ae7ac98e8b19f63565905105e3c4c37a3", // distance of 159 - } - - ih, err := models.InfohashFromString(id) - if err != nil { - t.Errorf("failed to create infohash: %s\n", err) - } - - pq, err := newRoutingTable(*ih, 3) - if err != nil { - t.Errorf("failed to create kTable: %s\n", err) - } - - for i, idt := range tests { - iht, err := models.InfohashFromString(idt) - if err != nil { - t.Errorf("failed to create infohash: %s\n", err) - } - addr, _ := net.ResolveUDPAddr("udp", fmt.Sprintf("0.0.0.0:%d", i)) - pq.add(&remoteNode{id: *iht, addr: addr}) - } - - if len(pq.items) != len(pq.addresses) { - t.Errorf("items and addresses out of sync") - } - - first := pq.items[0].value.id - if first.String() != "d1c5676ae7ac98e8b19f63565905105e3c4c37a3" { - t.Errorf("first is %s with distance %d\n", first, ih.Distance(first)) - } -} diff --git a/dht/slab.go b/dht/slab.go deleted file mode 100644 index a8b4018..0000000 --- a/dht/slab.go +++ /dev/null @@ -1,25 +0,0 @@ -package dht - -// Slab memory allocation - -// Initialise the slab as a channel of blocks, allocating them as required and -// pushing them back on the slab. This reduces garbage collection. 
-type slab chan []byte - -func newSlab(blockSize int, numBlocks int) slab { - s := make(slab, numBlocks) - for i := 0; i < numBlocks; i++ { - s <- make([]byte, blockSize) - } - return s -} - -func (s slab) alloc() (x []byte) { - return <-s -} - -func (s slab) free(x []byte) { - // Check we are using the right dimensions - x = x[:cap(x)] - s <- x -} diff --git a/dht/transactions.go b/dht/transactions.go new file mode 100644 index 0000000..209d7e3 --- /dev/null +++ b/dht/transactions.go @@ -0,0 +1,93 @@ +package dht + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "net/netip" + "sync" + "time" +) + +type transactions struct { + sync.Mutex + ttl time.Duration + requests map[string]trans + tid uint64 + buf [binary.MaxVarintLen64]byte +} + +type trans struct { + ap netip.AddrPort + msg string + ts time.Time + cb func(*Node) +} + +func newTransRegistry() *transactions { + return &transactions{ + ttl: 30 * time.Second, + requests: make(map[string]trans), + } +} + +type TransactionCallback func(*Node) + +func (c *Client) registerTransaction(rn *Node, m *Msg, cb TransactionCallback) { + c.chk.Lock() + defer c.chk.Unlock() + c.chk.tid++ + + n := binary.PutUvarint(c.chk.buf[:], c.chk.tid) + m.TID = string(c.chk.buf[:n]) + + if _, ok := c.chk.requests[m.TID]; ok { + panic(fmt.Sprintf("duplicate transaction ID %q", m.TID)) + } + c.chk.requests[m.TID] = trans{ + ap: rn.AddrPort, + ts: time.Now(), + msg: m.Query, + cb: cb, + } +} + +func (c *Client) checkTransaction(m Msg) (*trans, error) { + c.chk.Lock() + defer c.chk.Unlock() + if t, ok := c.chk.requests[m.TID]; ok { + delete(c.chk.requests, m.TID) + return &t, nil + } + return nil, errors.New("unsolicited transaction") +} + +// func (c *Client) inflightQueries() int { +// c.chk.Lock() +// defer c.chk.Unlock() +// return len(c.chk.requests) +// } + +func (c *Client) auditTransactions(ctx context.Context) { + c.log.Info("starting transaction auditing", "interval", c.tickAuditTransactions) + ticker := 
time.Tick(c.tickAuditTransactions) + for { + select { + case <-ctx.Done(): + c.log.Info("stopping transaction auditing") + return + case <-ticker: + c.chk.Lock() + before := len(c.chk.requests) + for tid, t := range c.chk.requests { + if time.Since(t.ts) > c.chk.ttl { + delete(c.chk.requests, tid) + } + } + after := len(c.chk.requests) + c.chk.Unlock() + c.log.Debug("cleared transactions", "size", after, "removed", before-after) + } + } +} diff --git a/krpc/krpc.go b/krpc/krpc.go deleted file mode 100644 index d5d1480..0000000 --- a/krpc/krpc.go +++ /dev/null @@ -1,181 +0,0 @@ -package krpc - -import ( - "errors" - "fmt" - "math/rand" - "net" - "strconv" -) - -const ( - transIDBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" - IPv4NodeAddrLen = 26 - IPv6NodeAddrLen = 38 -) - -func NewTransactionID() string { - b := make([]byte, 2) - for i := range b { - b[i] = transIDBytes[rand.Int63()%int64(len(transIDBytes))] - } - return string(b) -} - -// makeQuery returns a query-formed data. -func MakeQuery(transaction, query string, data map[string]interface{}) map[string]interface{} { - return map[string]interface{}{ - "t": transaction, - "y": "q", - "q": query, - "a": data, - } -} - -// makeResponse returns a response-formed data. 
-func MakeResponse(transaction string, data map[string]interface{}) map[string]interface{} { - return map[string]interface{}{ - "t": transaction, - "y": "r", - "r": data, - } -} - -func GetString(data map[string]interface{}, key string) (string, error) { - val, ok := data[key] - if !ok { - return "", fmt.Errorf("krpc: missing key %s", key) - } - out, ok := val.(string) - if !ok { - return "", fmt.Errorf("krpc: key type mismatch") - } - return out, nil -} - -func GetInt(data map[string]interface{}, key string) (int, error) { - val, ok := data[key] - if !ok { - return 0, fmt.Errorf("krpc: missing key %s", key) - } - out, ok := val.(int64) - if !ok { - return 0, fmt.Errorf("krpc: key type mismatch") - } - return int(out), nil -} - -func GetMap(data map[string]interface{}, key string) (map[string]interface{}, error) { - val, ok := data[key] - if !ok { - return nil, fmt.Errorf("krpc: missing key %s", key) - } - out, ok := val.(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("krpc: key type mismatch") - } - return out, nil -} - -func GetList(data map[string]interface{}, key string) ([]interface{}, error) { - val, ok := data[key] - if !ok { - return nil, fmt.Errorf("krpc: missing key %s", key) - } - out, ok := val.([]interface{}) - if !ok { - return nil, fmt.Errorf("krpc: key type mismatch") - } - return out, nil -} - -// parseKeys parses keys. It just wraps parseKey. -func checkKeys(data map[string]interface{}, pairs [][]string) (err error) { - for _, args := range pairs { - key, t := args[0], args[1] - if err = checkKey(data, key, t); err != nil { - break - } - } - return err -} - -// parseKey parses the key in dict data. `t` is type of the keyed value. -// It's one of "int", "string", "map", "list". 
-func checkKey(data map[string]interface{}, key string, t string) error { - val, ok := data[key] - if !ok { - return fmt.Errorf("krpc: missing key %s", key) - } - - switch t { - case "string": - _, ok = val.(string) - case "int": - _, ok = val.(int) - case "map": - _, ok = val.(map[string]interface{}) - case "list": - _, ok = val.([]interface{}) - default: - return errors.New("krpc: invalid type") - } - - if !ok { - return errors.New("krpc: key type mismatch") - } - - return nil -} - -// Swiped from nictuku -func DecodeCompactNodeAddr(cni string) string { - if len(cni) == 6 { - return fmt.Sprintf("%d.%d.%d.%d:%d", cni[0], cni[1], cni[2], cni[3], (uint16(cni[4])<<8)|uint16(cni[5])) - } else if len(cni) == 18 { - b := []byte(cni[:16]) - return fmt.Sprintf("[%s]:%d", net.IP.String(b), (uint16(cni[16])<<8)|uint16(cni[17])) - } else { - return "" - } -} - -func EncodeCompactNodeAddr(addr string) string { - var a []uint8 - host, port, _ := net.SplitHostPort(addr) - ip := net.ParseIP(host) - if ip == nil { - return "" - } - aa, _ := strconv.ParseUint(port, 10, 16) - c := uint16(aa) - if ip2 := net.IP.To4(ip); ip2 != nil { - a = make([]byte, net.IPv4len+2, net.IPv4len+2) - copy(a, ip2[0:net.IPv4len]) // ignore bytes IPv6 bytes if it's IPv4. 
- a[4] = byte(c >> 8) - a[5] = byte(c) - } else { - a = make([]byte, net.IPv6len+2, net.IPv6len+2) - copy(a, ip) - a[16] = byte(c >> 8) - a[17] = byte(c) - } - return string(a) -} - -func int2bytes(val int64) []byte { - data, j := make([]byte, 8), -1 - for i := 0; i < 8; i++ { - shift := uint64((7 - i) * 8) - data[i] = byte((val & (0xff << shift)) >> shift) - - if j == -1 && data[i] != 0 { - j = i - } - } - - if j != -1 { - return data[j:] - } - return data[:1] -} diff --git a/krpc/krpc_test.go b/krpc/krpc_test.go deleted file mode 100644 index 59480e1..0000000 --- a/krpc/krpc_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package krpc - -import ( - "encoding/hex" - "testing" -) - -func TestCompactNodeAddr(t *testing.T) { - - tests := []struct { - in string - out string - }{ - {in: "192.168.1.1:6881", out: "c0a801011ae1"}, - {in: "[2001:9372:434a:800::2]:6881", out: "20019372434a080000000000000000021ae1"}, - } - - for _, tt := range tests { - r := EncodeCompactNodeAddr(tt.in) - out, _ := hex.DecodeString(tt.out) - if r != string(out) { - t.Errorf("encodeCompactNodeAddr(%s) => %x, expected %s", tt.in, r, tt.out) - } - - s := DecodeCompactNodeAddr(r) - if s != tt.in { - t.Errorf("decodeCompactNodeAddr(%x) => %s, expected %s", r, s, tt.in) - } - } -} |
