aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Makefile15
-rw-r--r--bt/messages.go99
-rw-r--r--bt/options.go50
-rw-r--r--bt/torrent.go289
-rw-r--r--bt/worker.go398
-rw-r--r--cmd/indexer/README.md10
-rw-r--r--cmd/indexer/main.go15
-rw-r--r--cmd/indexer/otel.go45
-rw-r--r--cmd/indexer/run.go245
-rw-r--r--dht/client.go767
-rw-r--r--dht/compact_node.go133
-rw-r--r--dht/compact_node_test.go63
-rw-r--r--dht/krpc.go259
-rw-r--r--dht/krpc_test.go106
-rw-r--r--dht/ktable.go411
-rw-r--r--dht/ktable_test.go106
-rw-r--r--dht/messages.go127
-rw-r--r--dht/metrics.go76
-rw-r--r--dht/node.go443
-rw-r--r--dht/node_test.go42
-rw-r--r--dht/options.go73
-rw-r--r--dht/packet.go33
-rw-r--r--dht/remote_node.go18
-rw-r--r--dht/routing_table.go121
-rw-r--r--dht/routing_table_test.go48
-rw-r--r--dht/slab.go25
-rw-r--r--dht/transactions.go93
-rw-r--r--krpc/krpc.go181
-rw-r--r--krpc/krpc_test.go30
29 files changed, 2810 insertions, 1511 deletions
diff --git a/Makefile b/Makefile
index eb5ee50..1649254 100644
--- a/Makefile
+++ b/Makefile
@@ -1,13 +1,12 @@
-VERSION ?= $(shell git describe --tags --always)
-SRC := $(shell find . -type f -name '*.go')
-FLAGS := --tags fts5
+VERSION != git describe --tags --always
+SRC != find . -type f -name '*.go'
PLAT := windows darwin linux freebsd openbsd
BINARY := $(patsubst %,dist/%,$(shell find cmd/* -maxdepth 0 -type d -exec basename {} \;))
RELEASE := $(foreach os, $(PLAT), $(patsubst %,%-$(os), $(BINARY)))
.PHONY: build
-build: sqlite $(BINARY)
+build: $(BINARY)
.PHONY: release
release: $(RELEASE)
@@ -15,20 +14,16 @@ release: $(RELEASE)
dist/%: export GOOS=$(word 2,$(subst -, ,$*))
dist/%: bin=$(word 1,$(subst -, ,$*))
dist/%: $(SRC) $(shell find cmd/$(bin) -type f -name '*.go')
- go build -ldflags "-X main.version=$(VERSION)" $(FLAGS) \
+ go build -ldflags "-X main.version=$(VERSION)" \
-o $@ ./cmd/$(bin)
-sqlite:
- CGO_ENABLED=1 go get -u $(FLAGS) github.com/mattn/go-sqlite3 \
- && go install $(FLAGS) github.com/mattn/go-sqlite3
-
.PHONY: test
test:
go test -short -coverprofile=coverage.out ./... \
&& go tool cover -func=coverage.out
.PHONY: lint
-lint: ; go vet ./...
+lint: ; golangci-lint run
.PHONY: clean
clean:
diff --git a/bt/messages.go b/bt/messages.go
new file mode 100644
index 0000000..bcf7da6
--- /dev/null
+++ b/bt/messages.go
@@ -0,0 +1,99 @@
+package bt
+
+// BEP4 Core protocol Message IDs
+
+const (
+ BTMsgChoke uint8 = iota
+ BTMsgUnchoke
+ BTMsgInterested
+ BTMsgNotInterested
+ BTMsgHave
+ BTMsgBitfield
+ BTMsgRequest
+ BTMsgPiece
+ BTMsgCancel
+ BTMsgPort
+ _
+ _
+ _
+ BTMsgSuggest // 13
+ BTMsgHaveAll
+ BTMsgHaveNone
+ BTMsgRejectRequest
+ BTMsgAllowedFast
+ _
+ _
+ BTMsgExtended // 20 Extended messages BEP10
+)
+const (
+ // BEP10 send message types
+ ExtMsgTypeHandshake = uint8(0)
+ // >0 is extension dependent
+)
+
+const (
+ // BEP9 extension message types
+
+ // ExtMsgTypeRequest marks a request message type
+ ExtMsgTypeRequest = 0
+ // ExtMsgTypeData marks a data message type
+ ExtMsgTypeData = 1
+ // ExtMsgTypeReject marks a reject message type
+ ExtMsgTypeReject = 2
+)
+
+type ExtMsg struct {
+ Type int `bencode:"msg_type"`
+ Piece int `bencode:"piece"`
+ TotalSize int `bencode:"total_size,omitzero"`
+}
+
+type ExtMsgHandshake struct {
+ // Dictionary of supported extension messages which maps names of
+ // extensions to an extended message ID for each extension message.
+ Messages map[string]int `bencode:"m"`
+ // Local TCP listen port. Allows each side to learn about the TCP port
+ // number of the other side. Note that there is no need for the receiving
+ // side of the connection to send this extension message, since its port
+ // number is already known.
+ Port int `bencode:"p,omitzero"`
+ // Client name and version (as a utf-8 string).
+ Version string `bencode:"v,omitzero"`
+ // A string containing the compact representation of the ip address this
+ // peer sees you as. i.e. this is the receiver's external ip address
+ // (no port is included). This may be either an IPv4 (4 bytes) or an
+ // IPv6 (16 bytes) address.
+ IP string `bencode:"yourip,omitzero"`
+ // If this peer has an IPv6 interface, this is the compact representation
+ // of that address (16 bytes).
+ IPv6 string `bencode:"ipv6,omitzero"`
+ // If this peer has an IPv4 interface, this is the compact representation
+ // of that address (4 bytes).
+ IPv4 string `bencode:"ipv4,omitzero"`
+ // An integer, the number of outstanding request messages this client
+ // supports without dropping any. The default in libtorrent is 250.
+ Qsize int `bencode:"reqq,omitzero"`
+
+ // BEP9
+ // Specifies an integer value of the number of bytes of the metadata.
+ MetadataSize int `bencode:"metadata_size,omitzero"`
+}
+
+type MetaInfo struct {
+ PieceLength int `bencode:"piece_length"`
+ Pieces string `bencode:"pieces"`
+ Private bool `bencode:"private"`
+
+ Name string `bencode:"name"`
+
+ // Single file
+ Length int `bencode:"length,omitzero"`
+ MD5Sum string `bencode:"md5sum,omitzero"`
+
+ // Multiple files
+ Files []struct {
+ Length int `bencode:"length"`
+ MD5Sum string `bencode:"md5sum"`
+ Path string `bencode:"path"`
+ } `bencode:"files,omitzero"`
+}
diff --git a/bt/options.go b/bt/options.go
deleted file mode 100644
index 391ad0f..0000000
--- a/bt/options.go
+++ /dev/null
@@ -1,50 +0,0 @@
-package bt
-
-import (
- "src.userspace.com.au/dhtsearch/models"
- "src.userspace.com.au/logger"
-)
-
-type Option func(*Worker) error
-
-// SetOnNewTorrent sets the callback
-func SetOnNewTorrent(f func(models.Torrent)) Option {
- return func(w *Worker) error {
- w.OnNewTorrent = f
- return nil
- }
-}
-
-// SetOnBadPeer sets the callback
-func SetOnBadPeer(f func(models.Peer)) Option {
- return func(w *Worker) error {
- w.OnBadPeer = f
- return nil
- }
-}
-
-// SetPort sets the port to listen on
-func SetPort(p int) Option {
- return func(w *Worker) error {
- w.port = p
- return nil
- }
-}
-
-// SetIPv6 enables IPv6
-func SetIPv6(b bool) Option {
- return func(w *Worker) error {
- if b {
- w.family = "tcp6"
- }
- return nil
- }
-}
-
-// SetLogger sets the logger
-func SetLogger(l logger.Logger) Option {
- return func(w *Worker) error {
- w.log = l
- return nil
- }
-}
diff --git a/bt/torrent.go b/bt/torrent.go
new file mode 100644
index 0000000..785fd47
--- /dev/null
+++ b/bt/torrent.go
@@ -0,0 +1,289 @@
+package bt
+
+import (
+ "bytes"
+ "context"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/netip"
+ "time"
+
+ "userspace.com.au/dhtsearch/bencode"
+ "userspace.com.au/dhtsearch/infohash"
+)
+
+type Torrent struct {
+ ih infohash.ID
+}
+
+func NewTorrent(ih infohash.ID) *Torrent {
+ return &Torrent{ih: ih}
+}
+
+func (t *Torrent) FetchMetadata(ctx context.Context, ap netip.AddrPort) ([]byte, error) {
+ tcpTimeout := 5 * time.Second
+ tcpAddr := net.TCPAddrFromAddrPort(ap)
+ conn, err := net.DialTCP("tcp", nil, tcpAddr)
+ if err != nil {
+ return nil, err
+ }
+ defer conn.Close()
+ // conn.SetLinger(0)
+ // conn.SetNoDelay(true)
+
+ local := conn.LocalAddr().(*net.TCPAddr).AddrPort()
+
+ // Collect reads here
+ buf := new(bytes.Buffer)
+ //buf.Grow(blockSize)
+
+ // Handshake
+ conn.SetWriteDeadline(time.Now().Add(tcpTimeout))
+ hsLength, err := sendHandshake(conn, t.ih)
+ if err != nil {
+ return nil, err
+ }
+ conn.SetReadDeadline(time.Now().Add(tcpTimeout))
+ if err := read(conn, hsLength, buf); err != nil {
+ return nil, fmt.Errorf("handshake failed: %w", err)
+ }
+ resp := buf.Next(hsLength)
+ // BEP10 Check for extension protocol
+ // The bit selected for the extension protocol is bit 20 from the right
+ // (counting starts at 0). So (reserved_byte[5] & 0x10) is the
+ // expression to use for checking if the client supports extended messaging.
+ // See https://www.libtorrent.org/extension_protocol.html
+ extBit := 1 + len(protocolString) + 6
+ if resp[extBit]&0x10 != 0 {
+ return nil, errors.New("extension protocol not supported")
+ }
+ if err = sendExtHandshake(conn, local, ap); err != nil {
+ return nil, err
+ }
+
+ var (
+ // extended message type, set in first message
+ extMsgType uint8
+ // Expected size of all metadata, set in first message
+ // Used to determine length of last piece
+ metadataSize int
+ // Collection of piece data, the output
+ pieces [][]byte
+ // Expected number of pieces, derived from metadatasize
+ totalPieces int
+ // Current count of pieces
+ gotPieces int
+ )
+
+ // Loop over incoming messages
+ for {
+ buf.Reset()
+ if err := readMessage(conn, buf); err != nil {
+ return nil, err
+ }
+ length := buf.Len()
+ if length == 0 {
+ continue
+ }
+ // We are only interested in bt extension protocol (20)
+ msgProtocol, err := buf.ReadByte()
+ if err != nil {
+ return nil, err
+ }
+ if msgProtocol != BTMsgExtended {
+ continue
+ }
+
+ // Read the extension type ie. 0 == handshake
+ msgExtType, err := buf.ReadByte()
+ if err != nil {
+ return nil, err
+ }
+ // // This was set in the handshake and should match
+ // if msgExtType != ExtMsgTypeData {
+ // continue
+ // }
+ payload, err := io.ReadAll(buf)
+ if err != nil {
+ return nil, err
+ }
+ // We are past the protocol and extension type bits
+ br := bencode.NewReaderFromBytes(payload)
+
+ if msgExtType == ExtMsgTypeHandshake {
+ if pieces != nil {
+ // We have already done the handshake!
+ return nil, errors.New("duplicate handshake")
+ }
+ msg := ExtMsgHandshake{}
+ if !br.ReadStruct(&msg) {
+ return nil, br.Err()
+ }
+ totalPieces = msg.MetadataSize / blockSize
+ if msg.MetadataSize%blockSize != 0 {
+ totalPieces += 1
+ }
+ if totalPieces == 0 {
+ return nil, errors.New("no pieces to fetch")
+ }
+ pieces = make([][]byte, totalPieces)
+
+ // The extension type we can handle
+ if id, ok := msg.Messages["ut_metadata"]; !ok {
+ return nil, errors.New("missing ut_metadata extension")
+ } else {
+ extMsgType = uint8(id)
+ }
+
+ metadataSize = msg.MetadataSize
+ // Request all metadata pieces
+ if err := requestMetadata(conn, extMsgType, totalPieces); err != nil {
+ return nil, err
+ }
+ continue
+ }
+ if pieces == nil {
+ return nil, errors.New("pieces not created")
+ }
+
+ var msg ExtMsg
+ if !br.ReadStruct(&msg) {
+ return nil, br.Err()
+ }
+ if msg.Type != ExtMsgTypeData {
+ continue
+ }
+ pieceLen := len(payload) - int(br.Count())
+ if msg.Piece == totalPieces-1 {
+ // Last piece needs to equal remainder
+ if pieceLen != metadataSize%blockSize {
+ return nil, fmt.Errorf("incorrect final piece %d", msg.Piece)
+ }
+ } else {
+ // Should be full block
+ if pieceLen != blockSize {
+ return nil, fmt.Errorf("incomplete piece %d", msg.Piece)
+ }
+ }
+
+ pieces[msg.Piece] = payload[br.Count():]
+ gotPieces += 1
+ if gotPieces == totalPieces {
+ break
+ }
+ }
+ return bytes.Join(pieces, nil), nil
+}
+
+// Our peer name and version TODO
+const peerID = "ML-010"
+
+// BTv1 identifier
+const protocolString = "BitTorrent protocol"
+
+// The metadata is handled in blocks of 16KiB (16384 Bytes), see BEP9
+const blockSize = 16384
+
+const (
+ extMsgTypeHandshake = uint8(0)
+)
+
+func sendHandshake(w io.Writer, id infohash.ID) (int, error) {
+ // 49+len(pstr) bytes long
+ // <pstrlen><pstr><reserved><info_hash><peer_id>
+ out := make([]byte, 49+len(protocolString))
+ out[0] = byte(len(protocolString))
+ n := 1
+ // pstr
+ copy(out[n:n+len(protocolString)], []byte(protocolString))
+ n += len([]byte(protocolString))
+ // reserved, setting extended protocol bit
+ reserved := bytes.Repeat([]byte{0}, 8)
+ reserved[5] |= 0x10 // Extension protocol
+ reserved[7] |= 0x01 // Distributed Hash Table
+ copy(out[n:n+8], reserved)
+ n += 8
+ // target infohash
+ copy(out[n:n+20], id[:])
+ n += 20
+ // our peer ID
+ copy(out[n:], peerID[:])
+ return w.Write(out)
+}
+
+func sendExtHandshake(rw io.ReadWriter, local, remote netip.AddrPort) error {
+ payload := ExtMsgHandshake{
+ Messages: map[string]int{
+ "ut_metadata": 1,
+ },
+ IP: string(remote.Addr().AsSlice()),
+ Port: int(local.Port()),
+ }
+ if local.Addr().Is6() {
+ payload.IPv6 = string(local.Addr().AsSlice())
+ } else {
+ payload.IPv4 = string(local.Addr().AsSlice())
+ }
+ b, err := bencode.Marshal(payload)
+ if err != nil {
+ return err
+ }
+ _, err = sendMessage(rw, ExtMsgTypeHandshake, b)
+ return err
+}
+
+// Send multiple requests for all the pieces
+func requestMetadata(w io.Writer, extMsgType uint8, totalPieces int) error {
+ for i := range totalPieces {
+ payload := ExtMsg{
+ Type: ExtMsgTypeRequest,
+ Piece: i,
+ }
+ b, err := bencode.Marshal(payload)
+ if err != nil {
+ return err
+ }
+ sendMessage(w, extMsgType, b)
+ }
+ return nil
+}
+
+// read reads size-length bytes
+func read(r io.Reader, size int, w io.Writer) error {
+ n, err := io.CopyN(w, r, int64(size))
+ if err != nil {
+ return err
+ }
+ if n != int64(size) {
+ return fmt.Errorf("short read, got %d, want %d", n, size)
+ }
+ return nil
+}
+
+// Sends a length prefix, protocol byte, type byte, then data, see BEP10
+func sendMessage(w io.Writer, mType uint8, data []byte) (n int, err error) {
+ // 4 bytes length prefix. Size of the entire message. (Big endian)
+ // 1 byte bt extended message id (20)
+ // 1 byte bt message type id (0: handshake, >0: according to handshake)
+ out := append([]byte{BTMsgExtended}, mType)
+ out = append(out, data...)
+ length := int32(len(out))
+ if err = binary.Write(w, binary.BigEndian, length); err != nil {
+ return n, err
+ }
+ return w.Write(out)
+}
+
+func readMessage(r io.Reader, buf *bytes.Buffer) error {
+ var length int32
+ if err := binary.Read(r, binary.BigEndian, &length); err != nil {
+ return err
+ }
+ if length == 0 {
+ return nil
+ }
+ return read(r, int(length), buf)
+}
diff --git a/bt/worker.go b/bt/worker.go
deleted file mode 100644
index f3c45d0..0000000
--- a/bt/worker.go
+++ /dev/null
@@ -1,398 +0,0 @@
-package bt
-
-import (
- "bytes"
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "io/ioutil"
- "net"
- "time"
-
- "src.userspace.com.au/dhtsearch/krpc"
- "src.userspace.com.au/dhtsearch/models"
- "src.userspace.com.au/go-bencode"
- "src.userspace.com.au/logger"
-)
-
-const (
- // MsgRequest marks a request message type
- MsgRequest = iota
- // MsgData marks a data message type
- MsgData
- // MsgReject marks a reject message type
- MsgReject
- // MsgExtended marks it as an extended message
- MsgExtended = 20
-)
-
-const (
- // BlockSize is 2 ^ 14
- BlockSize = 16384
- // MaxMetadataSize represents the max medata it can accept
- MaxMetadataSize = BlockSize * 1000
- // HandshakeBit represents handshake bit
- HandshakeBit = 0
- // TCPTimeout for BT connections
- TCPTimeout = 5
-)
-
-var handshakePrefix = []byte{
- 19, 66, 105, 116, 84, 111, 114, 114, 101, 110, 116, 32, 112, 114,
- 111, 116, 111, 99, 111, 108, 0, 0, 0, 0, 0, 16, 0, 1,
-}
-
-type Worker struct {
- pool chan chan models.Peer
- port int
- family string
- OnNewTorrent func(t models.Torrent)
- OnBadPeer func(p models.Peer)
- log logger.Logger
-}
-
-func NewWorker(pool chan chan models.Peer, opts ...Option) (*Worker, error) {
- var err error
- w := &Worker{
- pool: pool,
- }
-
- // Set variadic options passed
- for _, option := range opts {
- err = option(w)
- if err != nil {
- return nil, err
- }
- }
- return w, nil
-}
-
-func (bt *Worker) Run() error {
- peerCh := make(chan models.Peer)
-
- for {
- // Signal we are ready for work
- bt.pool <- peerCh
-
- select {
- case p := <-peerCh:
- // Got work
- bt.log.Debug("worker got work", "peer", p)
- md, err := bt.fetchMetadata(p)
- if err != nil {
- bt.log.Debug("failed to fetch metadata", "error", err)
- if bt.OnBadPeer != nil {
- bt.OnBadPeer(p)
- }
- continue
- }
- t, err := models.TorrentFromMetadata(p.Infohash, md)
- if err != nil {
- bt.log.Warn("failed to load torrent", "error", err)
- continue
- }
- if bt.OnNewTorrent != nil {
- bt.OnNewTorrent(*t)
- }
- }
- }
-}
-
-// fetchMetadata fetchs medata info accroding to infohash from dht.
-func (bt *Worker) fetchMetadata(p models.Peer) (out []byte, err error) {
- var (
- length int
- msgType byte
- totalPieces int
- pieces [][]byte
- utMetadata int
- metadataSize int
- )
-
- defer func() {
- pieces = nil
- recover()
- }()
-
- //ll := bt.log.WithFields("address", p.Addr.String())
-
- //ll.Debug("connecting")
- dial, err := net.DialTimeout("tcp", p.Addr.String(), time.Second*15)
- if err != nil {
- return out, err
- }
- // Cast
- conn := dial.(*net.TCPConn)
- conn.SetLinger(0)
- defer conn.Close()
- //ll.Debug("dialed")
-
- data := bytes.NewBuffer(nil)
- data.Grow(BlockSize)
-
- ih := models.GenInfohash()
-
- // TCP handshake
- //ll.Debug("sending handshake")
- _, err = sendHandshake(conn, p.Infohash, ih)
- if err != nil {
- return nil, err
- }
-
- // Handle the handshake response
- //ll.Debug("handling handshake response")
- err = read(conn, 68, data)
- if err != nil {
- return nil, err
- }
- next := data.Next(68)
- //ll.Debug("got next data")
- if !(bytes.Equal(handshakePrefix[:20], next[:20]) && next[25]&0x10 != 0) {
- //ll.Debug("next data does not match", "next", next)
- return nil, errors.New("invalid handshake response")
- }
-
- //ll.Debug("sending ext handshake")
- _, err = sendExtHandshake(conn)
- if err != nil {
- return nil, err
- }
-
- for {
- length, err = readMessage(conn, data)
- if err != nil {
- return out, err
- }
-
- if length == 0 {
- continue
- }
-
- msgType, err = data.ReadByte()
- if err != nil {
- return out, err
- }
-
- switch msgType {
- case MsgExtended:
- extendedID, err := data.ReadByte()
- if err != nil {
- return out, err
- }
-
- payload, err := ioutil.ReadAll(data)
- if err != nil {
- return out, err
- }
-
- if extendedID == 0 {
- if pieces != nil {
- return out, errors.New("invalid extended ID")
- }
-
- utMetadata, metadataSize, err = getUTMetaSize(payload)
- if err != nil {
- return out, err
- }
-
- totalPieces = metadataSize / BlockSize
- if metadataSize%BlockSize != 0 {
- totalPieces++
- }
-
- pieces = make([][]byte, totalPieces)
- go bt.requestPieces(conn, utMetadata, metadataSize, totalPieces)
-
- continue
- }
-
- if pieces == nil {
- return out, errors.New("no pieces found")
- }
-
- dict, index, err := bencode.DecodeDict(payload, 0)
- if err != nil {
- return out, err
- }
-
- mt, err := krpc.GetInt(dict, "msg_type")
- if err != nil {
- return out, err
- }
-
- if mt != MsgData {
- continue
- }
-
- piece, err := krpc.GetInt(dict, "piece")
- if err != nil {
- return out, err
- }
-
- pieceLen := length - 2 - index
-
- // Not last piece? should be full block
- if totalPieces > 1 && piece != totalPieces-1 && pieceLen != BlockSize {
- return out, fmt.Errorf("incomplete piece %d", piece)
- }
- // Last piece needs to equal remainder
- if piece == totalPieces-1 && pieceLen != metadataSize%BlockSize {
- return out, fmt.Errorf("incorrect final piece %d", piece)
- }
-
- pieces[piece] = payload[index:]
-
- if bt.isDone(pieces) {
- return bytes.Join(pieces, nil), nil
- }
- default:
- data.Reset()
- }
- }
-}
-
-// isDone checks if all pieces are complete
-func (bt *Worker) isDone(pieces [][]byte) bool {
- for _, piece := range pieces {
- if len(piece) == 0 {
- return false
- }
- }
- return true
-}
-
-// read reads size-length bytes from conn to data.
-func read(conn net.Conn, size int, data io.Writer) error {
- conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout)))
- n, err := io.CopyN(data, conn, int64(size))
- if err != nil || n != int64(size) {
- return errors.New("read error")
- }
- return nil
-}
-
-// readMessage gets a message from the tcp connection.
-func readMessage(conn net.Conn, data *bytes.Buffer) (length int, err error) {
- if err = read(conn, 4, data); err != nil {
- return length, err
- }
-
- length, err = bytes2int(data.Next(4))
- if err != nil {
- return length, err
- }
-
- if length == 0 {
- return length, nil
- }
-
- err = read(conn, length, data)
- return length, err
-}
-
-// sendMessage sends data to the connection.
-func sendMessage(conn net.Conn, data []byte) (int, error) {
- length := int32(len(data))
-
- buffer := bytes.NewBuffer(nil)
- binary.Write(buffer, binary.BigEndian, length)
-
- conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout)))
- return conn.Write(append(buffer.Bytes(), data...))
-}
-
-// sendHandshake sends handshake message to conn.
-func sendHandshake(conn net.Conn, ih, id models.Infohash) (int, error) {
- data := make([]byte, 68)
- copy(data[:28], handshakePrefix)
- copy(data[28:48], []byte(ih))
- copy(data[48:], []byte(id))
-
- conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(TCPTimeout)))
- return conn.Write(data)
-}
-
-// onHandshake handles the handshake response.
-func onHandshake(data []byte) (err error) {
- if !(bytes.Equal(handshakePrefix[:20], data[:20]) && data[25]&0x10 != 0) {
- err = errors.New("invalid handshake response")
- }
- return err
-}
-
-// sendExtHandshake requests for the ut_metadata and metadata_size.
-func sendExtHandshake(conn net.Conn) (int, error) {
- m, err := bencode.EncodeDict(map[string]interface{}{
- "m": map[string]interface{}{"ut_metadata": 1},
- })
- if err != nil {
- return 0, err
- }
- data := append([]byte{MsgExtended, HandshakeBit}, m...)
-
- return sendMessage(conn, data)
-}
-
-// getUTMetaSize returns the ut_metadata and metadata_size.
-func getUTMetaSize(data []byte) (utMetadata int, metadataSize int, err error) {
- dict, _, err := bencode.DecodeDict(data, 0)
- if err != nil {
- return utMetadata, metadataSize, err
- }
-
- m, err := krpc.GetMap(dict, "m")
- if err != nil {
- return utMetadata, metadataSize, err
- }
-
- utMetadata, err = krpc.GetInt(m, "ut_metadata")
- if err != nil {
- return utMetadata, metadataSize, err
- }
-
- metadataSize, err = krpc.GetInt(dict, "metadata_size")
- if err != nil {
- return utMetadata, metadataSize, err
- }
-
- if metadataSize > MaxMetadataSize {
- err = errors.New("metadata_size too long")
- }
- return utMetadata, metadataSize, err
-}
-
-// Request more pieces
-func (bt *Worker) requestPieces(conn net.Conn, utMetadata int, metadataSize int, totalPieces int) {
- buffer := make([]byte, 1024)
- for i := 0; i < totalPieces; i++ {
- buffer[0] = MsgExtended
- buffer[1] = byte(utMetadata)
-
- msg, _ := bencode.EncodeDict(map[string]interface{}{
- "msg_type": MsgRequest,
- "piece": i,
- })
-
- length := len(msg) + 2
- copy(buffer[2:length], msg)
-
- sendMessage(conn, buffer[:length])
- }
- buffer = nil
-}
-
-// bytes2int returns the int value it represents.
-func bytes2int(data []byte) (int, error) {
- n := len(data)
- if n > 8 {
- return 0, errors.New("data too long")
- }
-
- val := uint64(0)
-
- for i, b := range data {
- val += uint64(b) << uint64((n-i-1)*8)
- }
- return int(val), nil
-}
diff --git a/cmd/indexer/README.md b/cmd/indexer/README.md
new file mode 100644
index 0000000..345301a
--- /dev/null
+++ b/cmd/indexer/README.md
@@ -0,0 +1,10 @@
+# DHT indexer
+
+The general process:
+
+- generate random infohash
+- on tick: send sample request to close nodes
+- on find_nodes response: send sample request to found nodes
+- on samples response: send get_peers request to sender
+- on get_peers response: download metadata
+
diff --git a/cmd/indexer/main.go b/cmd/indexer/main.go
new file mode 100644
index 0000000..758c08e
--- /dev/null
+++ b/cmd/indexer/main.go
@@ -0,0 +1,15 @@
+package main
+
+import (
+ "context"
+ "fmt"
+ "os"
+)
+
+func main() {
+ ctx := context.Background()
+ if err := run(ctx, os.Args[1:], os.Getenv, os.Stdout, os.Stderr); err != nil {
+ fmt.Fprintf(os.Stderr, "%s\n", err)
+ os.Exit(1)
+ }
+}
diff --git a/cmd/indexer/otel.go b/cmd/indexer/otel.go
new file mode 100644
index 0000000..ae86a3f
--- /dev/null
+++ b/cmd/indexer/otel.go
@@ -0,0 +1,45 @@
+package main
+
+import (
+ "context"
+ "time"
+
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp"
+ "go.opentelemetry.io/otel/sdk/metric"
+ "go.opentelemetry.io/otel/sdk/resource"
+ semconv "go.opentelemetry.io/otel/semconv/v1.37.0"
+)
+
+func newMeterProvider(ctx context.Context) (func(context.Context) error, error) {
+ res, err := resource.Merge(
+ resource.Default(),
+ resource.NewWithAttributes(
+ semconv.SchemaURL,
+ attribute.String("service.name", "indexer"),
+ //semconv.ServiceVersion("0.1.0"),
+ ),
+ )
+ if err != nil {
+ return nil, err
+ }
+ exporter, err := otlpmetrichttp.New(ctx)
+ if err != nil {
+ return nil, err
+ }
+ // metricExporter, err := stdoutmetric.New()
+ // if err != nil {
+ // panic(err)
+ // }
+
+ mp := metric.NewMeterProvider(
+ metric.WithResource(res),
+ metric.WithReader(metric.NewPeriodicReader(
+ exporter,
+ metric.WithInterval(30*time.Second),
+ )),
+ )
+ otel.SetMeterProvider(mp)
+ return mp.Shutdown, nil
+}
diff --git a/cmd/indexer/run.go b/cmd/indexer/run.go
new file mode 100644
index 0000000..85b3198
--- /dev/null
+++ b/cmd/indexer/run.go
@@ -0,0 +1,245 @@
+package main
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "flag"
+ "io"
+ "log/slog"
+ "net"
+ "net/netip"
+ "os"
+ "os/signal"
+ "sync"
+ "syscall"
+ "time"
+
+ "userspace.com.au/dhtsearch/bencode"
+ "userspace.com.au/dhtsearch/bt"
+ "userspace.com.au/dhtsearch/dht"
+ "userspace.com.au/dhtsearch/infohash"
+)
+
+func run(
+ ctx context.Context,
+ _ []string,
+ _ func(string) string,
+ stdout io.Writer,
+ _ io.Writer,
+) error {
+ ctx, cancel := signal.NotifyContext(
+ ctx,
+ syscall.SIGINT, syscall.SIGTERM,
+ )
+ defer cancel()
+
+ var (
+ addrFlag string
+ publicFlag string
+ stateFlag string
+ secureFlag bool
+ metricsFlag bool
+ verboseFlag bool
+ )
+
+ flag.StringVar(&addrFlag, "addr", "", "listen address:port")
+ flag.StringVar(&publicFlag, "public", "", "public address:port")
+ flag.StringVar(&stateFlag, "state", "", "state file")
+ flag.BoolVar(&secureFlag, "secure", false, "Generate secure infohash")
+ flag.BoolVar(&metricsFlag, "metrics", false, "Generate metrics")
+ flag.BoolVar(&verboseFlag, "verbose", false, "verbose")
+ flag.Parse()
+
+ var lvl = new(slog.LevelVar)
+ lvl.Set(slog.LevelInfo)
+ if verboseFlag {
+ lvl.Set(slog.LevelDebug)
+ }
+ logger := slog.New(slog.NewTextHandler(
+ stdout,
+ &slog.HandlerOptions{Level: lvl},
+ ))
+
+ opts := []dht.Option{
+ dht.WithLogger(logger),
+ }
+ if addrFlag != "" {
+ opts = append(opts, dht.WithListenAddress(addrFlag))
+ } else {
+ ip := getFirstPublicIP()
+ if ip == nil {
+ logger.Error("no IP")
+ return errors.New("no IP")
+ }
+ addr := net.JoinHostPort(ip.String(), "6881")
+ l, _ := net.ListenPacket("udp", addr)
+ opts = append(opts, dht.WithListener(l))
+
+ }
+ if publicFlag != "" {
+ opts = append(opts, dht.WithPublicAddr(publicFlag, secureFlag))
+ }
+
+ if stateFlag != "" {
+ b, err := os.ReadFile(stateFlag)
+ if err != nil {
+ if !os.IsNotExist(err) {
+ return err
+ }
+ } else {
+ opts = append(opts, dht.WithState(bytes.NewReader(b)))
+ }
+ }
+
+ if metricsFlag {
+ closer, err := newMeterProvider(ctx)
+ if err != nil {
+ return err
+ }
+ defer func() {
+ if err := closer(context.Background()); err != nil {
+ logger.Error(err.Error())
+ }
+ }()
+ }
+
+ //fetchMetadata := fetchMetadata(logger)
+
+ svr, err := dht.NewClient(ctx, opts...)
+ if err != nil {
+ return err
+ }
+ gotInfohash := make(map[infohash.ID]bool)
+ type toGet struct {
+ ih infohash.ID
+ ap netip.AddrPort
+ }
+ hashes := make(chan toGet)
+
+ go func() {
+ for tg := range hashes {
+ if gotInfohash[tg.ih] {
+ continue
+ }
+ gotInfohash[tg.ih] = true
+ t := bt.NewTorrent(tg.ih)
+ logger.Info("fetching metadata", "ih", tg.ih, "address", tg.ap)
+ b, err := t.FetchMetadata(ctx, tg.ap)
+ if err != nil {
+ gotInfohash[tg.ih] = false
+ logger.Error("failed to fetch metadata", "error", err)
+ return
+ }
+ var info bt.MetaInfo
+ err = bencode.Unmarshal(b, &info)
+ logger.Info("fetched metadata", "name", info.Name)
+
+ }
+ }()
+ // port := uint16(m.Args.Port)
+ // if m.Args.ImpliedPort {
+ // port = rn.AddrPort.Port()
+ // }
+ // ap := netip.AddrPortFrom(rn.AddrPort.Addr(), port)
+
+ getPeers := func(ctx context.Context, rn *dht.Node, m dht.Msg) {
+ if m.Response == nil {
+ return
+ }
+ samples := infohash.SplitCompactInfohashes(m.Response.Samples)
+ logger.Info("got samples", "count", len(samples))
+ for _, ih := range samples {
+ svr.GetPeers(rn, ih, func(rn *dht.Node) {
+ logger.Info("get_peers callback")
+ hashes <- toGet{
+ ih: ih,
+ ap: rn.AddrPort,
+ }
+ })
+ }
+ }
+ getSamples := func(_ context.Context, _ *dht.Node, _ dht.Msg) {
+ ih := infohash.NewRandomID()
+ svr.GetSamples(ih)
+ }
+ //svr.OnPeersResponse = fetchMetadata
+ svr.OnNodesResponse = getSamples
+ svr.OnSamplesResponse = getPeers
+
+ var wg sync.WaitGroup
+ wg.Go(func() {
+ if err := svr.Run(ctx); err != nil {
+ logger.Error("dht client failed", "error", err)
+ }
+ })
+
+ wg.Go(func() {
+ tick := time.NewTicker(10 * time.Second)
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-tick.C:
+ ih := infohash.NewRandomID()
+ svr.GetSamples(ih)
+ }
+ }
+ })
+ wg.Wait()
+
+ if stateFlag != "" {
+ f, err := os.Create(stateFlag)
+ if err != nil {
+ return err
+ }
+ if err := svr.SaveState(f); err != nil {
+ return err
+ }
+ err = f.Close()
+ }
+ return err
+}
+
+func getFirstPublicIP() net.IP {
+ addrs, _ := net.InterfaceAddrs()
+ for _, a := range addrs {
+ ip, _, err := net.ParseCIDR(a.String())
+ if err == nil && !ip.IsLoopback() && !ip.IsPrivate() {
+ return ip
+ }
+ }
+ return nil
+}
+
+func fetchMetadata(log *slog.Logger) func(ctx context.Context, rn *dht.Node, m dht.Msg) {
+ return func(ctx context.Context, rn *dht.Node, m dht.Msg) {
+ // port := uint16(m.Args.Port)
+ // // If it is present and non-zero, the port argument should be
+ // // ignored and the source port of the UDP packet should be used
+ // // as the peer's port instead.
+ // if m.Args.ImpliedPort {
+ // port = rn.AddrPort.Port()
+ // }
+ if m.Response == nil || m.Response.ID == nil {
+ log.Warn("cannot fetch metadata, missing details")
+ return
+ }
+ // ap := netip.AddrPortFrom(rn.AddrPort.Addr(), port)
+ ap := rn.AddrPort
+
+ samples := infohash.SplitCompactInfohashes(m.Response.Samples)
+ for _, id := range samples {
+ t := bt.NewTorrent(id)
+ log.Info("fetching metadata", "ih", id, "address", ap)
+ b, err := t.FetchMetadata(ctx, ap)
+ if err != nil {
+ log.Error("failed to fetch metadata", "error", err)
+ }
+ var info bt.MetaInfo
+ err = bencode.Unmarshal(b, &info)
+ log.Info("fetched metadata", "name", info.Name)
+ }
+
+ }
+}
diff --git a/dht/client.go b/dht/client.go
new file mode 100644
index 0000000..ed0d0ce
--- /dev/null
+++ b/dht/client.go
@@ -0,0 +1,767 @@
+package dht
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "log/slog"
+ "net"
+ "net/netip"
+ "os"
+ "strconv"
+ "sync"
+ "time"
+
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/metric"
+
+ "userspace.com.au/dhtsearch/bencode"
+ "userspace.com.au/dhtsearch/infohash"
+)
+
+var (
+	// DefaultBootstraps are well-known DHT routers used to seed an empty
+	// routing table when no custom bootstrap addresses are configured
+	// (see WithBootstraps). Entries are "host:port".
+	DefaultBootstraps = []string{
+		//"malkmus.yelnah.org:51000",
+		"torrents:51000",
+		"dht.aelitis.com:6881",
+		"dht.libtorrent.org:25401",
+		"dht.transmissionbt.com:6881",
+		"router.bittorrent.cloud:42069",
+		"router.bittorrent.com:6881",
+		"router.silotis.us:6881",
+		"router.utorrent.com:6881",
+	}
+)
+
+// defaultPacketSize is the read-buffer size per UDP datagram (typical MTU).
+const defaultPacketSize = 1500
+
+const (
+	// defaultTickTableMaintenance is how often neighbours are refreshed
+	// and stale nodes re-contacted.
+	defaultTickTableMaintenance = 31 * time.Second
+	// defaultStaleNodeGrace is how long a node may go unseen before it is
+	// considered stale.
+	defaultStaleNodeGrace = 15 * time.Minute
+	// defaultTickAuditTransactions is how often the transaction registry
+	// is audited (see auditTransactions).
+	defaultTickAuditTransactions = 23 * time.Second
+)
+
+// Client joins the DHT network
+type Client struct {
+	// conn is the UDP socket shared by the reader and writer loops.
+	conn          net.PacketConn
+	maxPacketSize int
+
+	// node is our own identity on the network.
+	node       Node
+	bootstraps []string
+	udpTimeout int
+	// packetsOut queues encoded messages for the writer loop.
+	packetsOut chan packet
+	// chk tracks outstanding query transactions.
+	chk *transactions
+	log *slog.Logger
+
+	table RoutingTable
+	// tableState, when set, is imported into the table at startup.
+	tableState io.Reader
+
+	// Ticker and TTL durations
+	tickTableMaintenance  time.Duration
+	tickAuditTransactions time.Duration
+	staleNodeGrace        time.Duration
+
+	//limiter *rate.Limiter
+	////blacklist *lru.ARCCache
+
+	// Hooks. Each is invoked in its own goroutine by the handlers.
+
+	// onPingQuery is called when a ping query is received
+	onPingQuery func(*Node)
+	// onFindNodeQuery is called when a find_node query is received
+	onFindNodeQuery     func(context.Context, *Node, Msg)
+	onGetPeersQuery     func(context.Context, *Node, Msg)
+	onAnnouncePeerQuery func(context.Context, *Node, Msg)
+	// Exported response hooks may also be assigned directly by callers.
+	OnNodesResponse   func(context.Context, *Node, Msg)
+	OnPeersResponse   func(context.Context, *Node, Msg)
+	OnSamplesResponse func(context.Context, *Node, Msg)
+}
+
+// NewClient creates a new DHT client. Options are applied in order and the
+// first option error aborts construction. ctx is currently unused.
+func NewClient(ctx context.Context, opts ...Option) (*Client, error) {
+	var err error
+
+	c := &Client{
+		maxPacketSize: defaultPacketSize,
+		chk:           newTransRegistry(),
+		node: Node{
+			ID: infohash.NewRandomID(),
+		},
+		log:        slog.New(slog.NewTextHandler(os.Stdout, nil)),
+		udpTimeout: 1000,
+
+		tickTableMaintenance:  defaultTickTableMaintenance,
+		tickAuditTransactions: defaultTickAuditTransactions,
+		staleNodeGrace:        defaultStaleNodeGrace,
+	}
+
+	// Set variadic options passed
+	for _, option := range opts {
+		if err = option(c); err != nil {
+			return nil, err
+		}
+	}
+
+	if c.table == nil {
+		// Default ktable implementation
+		table := NewTable(c.node.ID)
+		table.log = c.log
+		c.table = table
+	}
+
+	// After table exists, before import.
+	// Return nil on failure so callers cannot use a half-built client
+	// (the original returned c together with the error).
+	if err := configureMetrics(c); err != nil {
+		return nil, err
+	}
+
+	if c.tableState != nil {
+		// A failed import is not fatal; we just start with an empty table.
+		if _, err := c.table.ReadFrom(c.tableState); err != nil {
+			c.log.Warn("failed to read table", "error", err)
+		}
+		// Adopt the identity stored with the imported table.
+		c.node.ID = c.table.ID()
+	}
+
+	if c.conn == nil {
+		if c.conn, err = net.ListenPacket("udp", "0.0.0.0:6881"); err != nil {
+			c.log.Error("failed to listen", "error", err)
+			return nil, err
+		}
+	}
+
+	if len(c.bootstraps) == 0 {
+		c.bootstraps = DefaultBootstraps
+	}
+
+	return c, nil
+}
+
+// RoutingTable is the node-storage strategy used by the Client. The
+// default implementation is the k-bucket Table in this package.
+type RoutingTable interface {
+	// ID returns the local node ID the table is centred on.
+	ID() infohash.ID
+	// Add stores a node, reporting whether it was newly added.
+	Add(*Node) bool
+	Has(infohash.ID) bool
+	// SetSeen records that the node was just heard from.
+	SetSeen(infohash.ID)
+	Count() int
+	// GetClosest returns up to n nodes closest to the given ID.
+	GetClosest(infohash.ID, int) []*Node
+	// GetStale returns up to n nodes not seen within the grace duration.
+	GetStale(int, time.Duration) []*Node
+	Remove(infohash.ID) bool
+
+	// Used for import/export
+	io.WriterTo
+	io.ReaderFrom
+}
+
+type Option func(*Client) error
+
+// WithPublicAddr sets the public IP:port if different to the listen
+// IP:port. When secure is true the node ID is derived from the IP
+// (BEP42-style); otherwise a random ID is used.
+func WithPublicAddr(s string, secure bool) Option {
+	return func(c *Client) error {
+		var err error
+		c.node.AddrPort, err = netip.ParseAddrPort(s)
+		if err != nil {
+			return err
+		}
+		// (debug fmt.Println removed)
+		ip := net.UDPAddrFromAddrPort(c.node.AddrPort)
+		c.node.ID = infohash.NewRandomID()
+		if secure {
+			c.node.ID = infohash.NewSecureID(ip.IP)
+		}
+		c.node.Secure = infohash.IsSecure(c.node.ID, ip.IP)
+		return nil
+	}
+}
+
+// WithListener supplies an existing packet connection to use instead of
+// opening a new UDP socket.
+func WithListener(pc net.PacketConn) Option {
+	return func(c *Client) error {
+		c.conn = pc
+		return nil
+	}
+}
+
+// WithListenAddress sets the IP:port address to listen on. The socket is
+// opened immediately, so an unusable address fails NewClient.
+func WithListenAddress(s string) Option {
+	return func(c *Client) error {
+		var err error
+		c.conn, err = net.ListenPacket("udp", s)
+		return err
+	}
+}
+
+// WithLogger sets the structured logger used by the client.
+// (The previous comment was a copy-paste of WithListenAddress.)
+func WithLogger(l *slog.Logger) Option {
+	return func(c *Client) error {
+		c.log = l
+		return nil
+	}
+}
+
+// WithBootstraps enables custom bootstrap addresses; each entry is a
+// "host:port" string (see bootstrap).
+func WithBootstraps(addrs ...string) Option {
+	return func(c *Client) error {
+		c.bootstraps = addrs
+		return nil
+	}
+}
+
+// WithTable enables using a custom node table implementation.
+// Without it, NewClient builds the default k-bucket Table.
+func WithTable(rt RoutingTable) Option {
+	return func(c *Client) error {
+		c.table = rt
+		return nil
+	}
+}
+
+// WithState provides previously saved routing-table state (see SaveState)
+// to import at startup. (The previous comment was copy-pasted from
+// WithTable.)
+func WithState(r io.Reader) Option {
+	return func(c *Client) error {
+		c.tableState = r
+		return nil
+	}
+}
+
+// OnAnnouncePeerQuery is called when an announce_peer query is received.
+// The hook runs in its own goroutine.
+func OnAnnouncePeerQuery(f func(context.Context, *Node, Msg)) Option {
+	return func(c *Client) error {
+		c.onAnnouncePeerQuery = f
+		return nil
+	}
+}
+
+// OnGetPeersQuery is called when a get_peers query is received.
+// The hook runs in its own goroutine.
+func OnGetPeersQuery(f func(context.Context, *Node, Msg)) Option {
+	return func(c *Client) error {
+		c.onGetPeersQuery = f
+		return nil
+	}
+}
+
+// OnPeersResponse is called when a response has peers ("values").
+// The hook runs in its own goroutine.
+func OnPeersResponse(f func(context.Context, *Node, Msg)) Option {
+	return func(c *Client) error {
+		c.OnPeersResponse = f
+		return nil
+	}
+}
+
+// OnSamplesResponse is called when a response has samples (BEP51).
+// The hook runs in its own goroutine.
+func OnSamplesResponse(f func(context.Context, *Node, Msg)) Option {
+	return func(c *Client) error {
+		c.OnSamplesResponse = f
+		return nil
+	}
+}
+
+// Close logs that the node is shutting down. It does not close the UDP
+// socket itself; packetReader closes it when the context is cancelled.
+// NOTE(review): confirm Close being a near no-op is intentional.
+func (c *Client) Close() error {
+	c.log.Warn("node closing")
+	return nil
+}
+
+// SaveState writes the routing table to w for a later WithState import
+// (the node ID round-trips with it; see NewClient).
+func (c *Client) SaveState(w io.Writer) error {
+	_, err := c.table.WriteTo(w)
+	return err
+}
+
+// Run starts the node on the DHT: the packet writer/reader loops, table
+// maintenance and transaction auditing. It blocks until ctx is cancelled
+// and every worker has exited.
+// NOTE(review): packetsOut is created inside packetWriter's goroutine
+// while tableMaintenance may send on it immediately — confirm a sender
+// cannot run first (a send on the nil channel blocks forever).
+func (c *Client) Run(ctx context.Context) error {
+	var wg sync.WaitGroup
+	wg.Go(func() { c.packetWriter(ctx) })
+	wg.Go(func() { c.tableMaintenance(ctx) })
+	wg.Go(func() { c.auditTransactions(ctx) })
+	wg.Go(func() { c.packetReader(ctx) })
+	c.log.Info("listening", "id", c.node.ID, "address", c.node.AddrPort, "listen", c.conn.LocalAddr().String())
+	wg.Wait()
+	// Close this after packetWriter has finished
+	close(c.packetsOut)
+	c.log.Debug("client stopped")
+	return nil
+}
+
+// wants reports which address family (BEP32 "want" values) to request
+// from peers, based on the family of our public address.
+func (c *Client) wants() []string {
+	if c.node.AddrPort.Addr().Is4() {
+		return []string{Want4}
+	}
+	// No else after a terminating return.
+	return []string{Want6}
+}
+
+//type PeerStore interface {
+// Add(Node) (bool, error)
+// Get(int) ([]Node, error)
+// //Delete(Node) error
+// Reset() error
+
+// // Used for import/export
+// io.WriterTo
+// io.ReaderFrom
+//}
+
+// func addrPort2Addr(in netip.AddrPort) net.Addr {
+// return net.UDPAddrFromAddrPort(in)
+// }
+// func addr2AddrPort(in net.Addr) (netip.AddrPort, error) {
+// return netip.ParseAddrPort(in.String())
+// }
+
+// Unprocessed packet from socket, also used for packets queued to be sent.
+type packet struct {
+	// raddr is the remote address the packet came from or goes to.
+	raddr net.Addr
+	data  []byte
+}
+
+// tableMaintenance periodically grows the routing table (find_node on
+// the closest neighbours) and re-contacts stale entries. It returns when
+// ctx is cancelled.
+func (c *Client) tableMaintenance(ctx context.Context) {
+	c.log.Info("starting table maintenance", "interval", c.tickTableMaintenance, "grace", c.staleNodeGrace)
+
+	// pingStale re-contacts nodes not seen within the grace period. A
+	// find_node query doubles as a ping and also yields new nodes.
+	pingStale := func() {
+		stale := c.table.GetStale(8, c.staleNodeGrace)
+		for _, n := range stale {
+			n.PingAttempts += 1
+			// Don't just ping, get more nodes
+			rID := infohash.NewCloseID(c.node.ID)
+			c.sendMsg(n, NewFindNodeQuery(c.node.ID, rID, c.wants()))
+		}
+		c.log.Debug("contacted stale nodes", "sent", len(stale))
+	}
+
+	// makeNeighbours asks our closest nodes for nodes near a random ID,
+	// bootstrapping first when the table is empty.
+	makeNeighbours := func() {
+		if c.table.Count() == 0 {
+			c.bootstrap()
+			return
+		}
+		rID := infohash.NewRandomID()
+		nodes := c.table.GetClosest(c.node.ID, 8)
+		for _, rn := range nodes {
+			c.sendMsg(rn, NewFindNodeQuery(c.node.ID, rID, c.wants()))
+		}
+		c.log.Debug("making neighbours", "sent", len(nodes))
+	}
+
+	// Once at startup
+	makeNeighbours()
+
+	// Use an explicit ticker (not time.Tick) so it can be stopped and
+	// does not outlive this goroutine.
+	ticker := time.NewTicker(c.tickTableMaintenance)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-ctx.Done():
+			c.log.Info("stopping table maintenance")
+			return
+		case <-ticker.C:
+			makeNeighbours()
+			pingStale()
+		}
+	}
+}
+
+// bootstrap contacts the configured bootstrap hosts with find_node
+// queries to seed an empty routing table. Unresolvable or malformed
+// entries are logged and skipped.
+func (c *Client) bootstrap() {
+	for _, s := range c.bootstraps {
+		host, portStr, err := net.SplitHostPort(s)
+		if err != nil {
+			c.log.Warn("failed to parse bootstrap entry", "address", s, "error", err)
+			continue
+		}
+		log := c.log.With(slog.String("host", host))
+		port, err := strconv.ParseUint(portStr, 10, 16)
+		if err != nil {
+			// Use the host-scoped logger for consistency with the
+			// other warnings below.
+			log.Warn("failed to parse bootstrap port", "error", err)
+			continue
+		}
+		ips, err := net.LookupIP(host)
+		if err != nil {
+			log.Warn("failed to resolve", "error", err)
+			continue
+		}
+		for _, b := range ips {
+			ip, ok := netip.AddrFromSlice(b)
+			if !ok {
+				log.Warn("failed to convert address")
+				continue
+			}
+			// Only contact nodes in our own address family.
+			if c.node.AddrPort.Addr().Is4() != ip.Is4() {
+				log.Debug("different network family")
+				continue
+			}
+			ap := netip.AddrPortFrom(ip, uint16(port))
+			rn := &Node{
+				// Placeholder ID until the node replies with its real one.
+				ID:       infohash.NewRandomID(),
+				AddrPort: ap,
+			}
+			log.Info("bootstrapping", "address", ap)
+			c.sendMsg(rn, NewFindNodeQuery(c.node.ID, rn.ID, c.wants()))
+		}
+	}
+}
+
+// GetPeers sends a get_peers message for ih to rn; cb is attached to the
+// transaction and invoked with the responding node.
+func (c *Client) GetPeers(rn *Node, ih infohash.ID, cb TransactionCallback) {
+	msg := &Msg{
+		Query: "get_peers",
+		Type:  "q",
+		Args: &MsgArgs{
+			ID:       &c.node.ID,
+			InfoHash: &ih,
+		},
+	}
+	c.log.Debug("sending get_peers", "ih", ih, "raddr", rn.AddrPort, "id", rn.ID)
+	c.sendMsgWithCallback(rn, msg, cb)
+}
+
+// GetSamples sends a sample_infohashes message (BEP51) for ih to the
+// eight closest nodes in the routing table.
+func (c *Client) GetSamples(ih infohash.ID) {
+	nodes := c.table.GetClosest(ih, 8)
+	msg := &Msg{
+		Query: "sample_infohashes",
+		Type:  "q",
+		Args: &MsgArgs{
+			// The querying node
+			ID: &c.node.ID,
+			// The ID sought
+			Target: &ih,
+		},
+	}
+	for _, rn := range nodes {
+		c.sendMsg(rn, msg)
+	}
+}
+
+// sendMsg sends a KRPC message without a transaction callback.
+func (c *Client) sendMsg(rn *Node, m *Msg) {
+	c.sendMsgWithCallback(rn, m, nil)
+}
+
+// sendMsgWithCallback bencodes m and queues it for delivery to rn,
+// registering a transaction (with optional callback) for queries.
+// Messages to ourselves, rate-limited queries and unmarshalable messages
+// are dropped.
+func (c *Client) sendMsgWithCallback(rn *Node, m *Msg, cb TransactionCallback) {
+	log := c.log.With(
+		"type", m.Query,
+		"addr", rn.AddrPort,
+	)
+	// Don't send to self
+	if rn.ID.Equal(c.node.ID) {
+		log.Warn("not sending to self")
+		return
+	}
+	if m.Type == "q" {
+		log = c.log.With(
+			"qType", m.Type,
+			"query", m.Query,
+		)
+		// Per-node, per-query-type rate limit.
+		if !rn.canSend(m.Query) {
+			log.Warn("ratelimited")
+			return
+		}
+	}
+	addr := net.UDPAddrFromAddrPort(rn.AddrPort)
+	if m.Type == "q" {
+		// Registration assigns the transaction ID logged below.
+		c.registerTransaction(rn, m, cb)
+	}
+	log = log.With("tid", m.TID)
+	b, err := bencode.Marshal(m)
+	if err != nil {
+		// Nothing sensible to send — the original queued a nil payload.
+		log.Warn("failed to marshal", "error", err)
+		return
+	}
+	log.Debug("sending message")
+	c.packetsOut <- packet{
+		data:  b,
+		raddr: addr,
+	}
+}
+
+// packetWriter drains packetsOut and writes each packet to the UDP
+// socket until ctx is cancelled, recording transmit metrics.
+func (c *Client) packetWriter(ctx context.Context) {
+	c.log.Debug("starting packet writer")
+	// Packets onto the network
+	// NOTE(review): the channel is created here, after sibling goroutines
+	// that send on it have started — confirm a sender cannot observe the
+	// nil channel (a send on nil blocks forever).
+	c.packetsOut = make(chan packet, 2024)
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case p := <-c.packetsOut:
+			// Never send to our own listen address.
+			if p.raddr.String() == c.conn.LocalAddr().String() {
+				continue
+			}
+			_ = c.conn.SetWriteDeadline(time.Now().Add(200 * time.Millisecond))
+			n, err := c.conn.WriteTo(p.data, p.raddr)
+			if err != nil {
+				//n.blacklist.Add(p.raddr.String(), true)
+				// TODO reduce limit
+				c.log.Warn("failed to write packet", "error", err)
+			}
+			netPackets.Add(
+				ctx, 1,
+				metric.WithAttributes(
+					attribute.String("network.io.direction", "transmit")))
+			netOctets.Add(
+				ctx, int64(n),
+				metric.WithAttributes(
+					attribute.String("network.io.direction", "transmit")))
+		}
+	}
+}
+
+// packetReader reads datagrams from the UDP socket until ctx is
+// cancelled, decodes each into a Msg and dispatches it in a goroutine.
+// It owns the socket and closes it on shutdown.
+func (c *Client) packetReader(ctx context.Context) {
+	c.log.Info("starting packet reader")
+	// Reuse read buffers to avoid an allocation per datagram.
+	pool := sync.Pool{
+		New: func() any {
+			out := make([]byte, c.maxPacketSize)
+			return &out
+		},
+	}
+
+	for {
+		select {
+		case <-ctx.Done():
+			c.log.Info("stopping packet reader")
+			_ = c.conn.Close()
+			return
+		default:
+			b := pool.Get().(*[]byte)
+			// Short deadline so ctx cancellation is noticed promptly.
+			_ = c.conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
+			n, addr, err := c.conn.ReadFrom(*b)
+			netOctets.Add(
+				ctx, int64(n),
+				metric.WithAttributes(
+					attribute.String("network.io.direction", "receive")))
+			if err != nil {
+				// Return the buffer on every exit from this iteration.
+				pool.Put(b)
+				if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
+					continue
+				}
+				c.log.Warn("UDP read error", "error", err)
+				return
+			}
+			netPackets.Add(
+				ctx, 1,
+				metric.WithAttributes(
+					attribute.String("network.io.direction", "receive")))
+
+			var m Msg
+			err = bencode.Unmarshal((*b)[:n], &m)
+			pool.Put(b)
+			if err != nil {
+				// A single malformed packet must not stop the reader;
+				// the original returned here and killed the loop.
+				c.log.Warn("krpc unmarshal msg failed", "error", err)
+				continue
+			}
+			rn := &Node{
+				LastContact: time.Now(),
+				AddrPort:    addr.(*net.UDPAddr).AddrPort(),
+			}
+			go c.processMessage(ctx, rn, m)
+		}
+	}
+}
+
+// processMessage dispatches a decoded KRPC message to the query or
+// response handler. It runs in its own goroutine, one per packet.
+func (c *Client) processMessage(ctx context.Context, rn *Node, m Msg) {
+	var err error
+	switch m.Type {
+	case "q":
+		err = c.handleQuery(ctx, rn, m)
+	default:
+		// Responses and error replies share a handler.
+		err = c.handleResponse(ctx, rn, m)
+	}
+	if err != nil {
+		log := c.log.With("type", m.Type, "raddr", rn.AddrPort)
+		log.Warn("failed to process message", "error", err)
+	}
+}
+
+// handleResponse handles responses (and error replies) received from
+// udp: it validates the transaction, refreshes the routing table and
+// fans out to the On*Response hooks (each in its own goroutine).
+func (c *Client) handleResponse(ctx context.Context, rn *Node, m Msg) error {
+	log := c.log.With(
+		"raddr", rn.AddrPort,
+		"id", rn.ID,
+	)
+
+	trans, err := c.checkTransaction(m)
+	// NOTE(review): trans is used below even when checkTransaction
+	// errors — confirm it returns a usable value alongside an error.
+	if m.Type == "e" {
+		// KRPC error reply: validate before dereferencing — a message
+		// with y=e but no "e" payload used to panic here.
+		if m.Error == nil {
+			return errors.New("error reply missing payload")
+		}
+		log = log.With("code", m.Error.Code, "error", m.Error.Msg, "tid", m.TID, "qType", trans.msg)
+		if m.Error.Code == 204 {
+			// Don't query this node like this again
+			rn.rateQuery(trans.msg, -1)
+			log.Info("rate limiting node")
+		}
+		return nil
+	}
+	if err != nil {
+		return err
+	}
+
+	if m.Response == nil {
+		return errors.New("missing response")
+	}
+	if m.Response.ID == nil {
+		// Guard the dereference below.
+		return errors.New("response missing id")
+	}
+	rn.ID = *m.Response.ID
+	if !c.table.Has(rn.ID) {
+		_ = c.table.Add(rn)
+	}
+	c.table.SetSeen(rn.ID)
+	log.Debug("received response")
+	r := *m.Response
+
+	// Add new nodes to our routing table
+	var newNodes []*Node
+	if r.NodesList != nil {
+		newNodes = append(newNodes, r.NodesList.Nodes...)
+	}
+	if r.Nodes6List != nil {
+		newNodes = append(newNodes, r.Nodes6List.Nodes...)
+	}
+	if len(newNodes) > 0 {
+		for _, cn := range newNodes {
+			_ = c.table.Add(cn)
+		}
+		if c.OnNodesResponse != nil {
+			log.Debug("on nodes")
+			// TODO set deadline
+			go c.OnNodesResponse(ctx, rn, m)
+		}
+	}
+
+	if len(r.Values) > 0 {
+		// Honour the node's requested re-query interval.
+		rn.rateQuery(trans.msg, r.Interval)
+		if c.OnPeersResponse != nil {
+			log.Info("on peers")
+			// TODO set deadline
+			go c.OnPeersResponse(ctx, rn, m)
+		}
+	}
+
+	if len(r.Samples) > 0 {
+		log.Info("on samples")
+		if r.Interval > 0 {
+			rn.NextQuery["sample_infohashes"] = time.Now().Add(time.Duration(r.Interval) * time.Second)
+		}
+		if c.OnSamplesResponse != nil {
+			// TODO set deadline
+			go c.OnSamplesResponse(ctx, rn, m)
+		}
+	}
+	if trans.cb != nil {
+		go trans.cb(rn)
+	}
+
+	return nil
+}
+
+// handleQuery handles incoming queries: it replies per BEP5 and invokes
+// any registered hook for the query type in its own goroutine.
+func (c *Client) handleQuery(ctx context.Context, rn *Node, m Msg) error {
+	// Guard both pointers — a query without an "a" dict used to panic.
+	if m.Args == nil || m.Args.ID == nil {
+		return errors.New("query missing ID")
+	}
+	rn.ID = *m.Args.ID
+	log := c.log.With(
+		"query", m.Query,
+		"raddr", rn.AddrPort,
+		"id", rn.ID,
+	)
+	c.table.Add(rn)
+	log.Debug("received query")
+
+	switch m.Query {
+	case "ping":
+		c.sendMsg(rn, &Msg{
+			Type: "r",
+			TID:  m.TID,
+			Response: &Response{
+				ID: &c.node.ID,
+			},
+		})
+		if c.onPingQuery != nil {
+			log.Info("on ping")
+			// TODO set deadline
+			go c.onPingQuery(rn)
+		}
+	case "find_node":
+		if m.Args.Target == nil {
+			return errors.New("query missing target ID")
+		}
+		c.sendMsg(rn, c.closestNodesReply(*m.Args.Target, m))
+		if c.onFindNodeQuery != nil {
+			log.Info("on find_node")
+			// TODO set deadline
+			go c.onFindNodeQuery(ctx, rn, m)
+		}
+	case "get_peers":
+		if m.Args.InfoHash == nil {
+			return errors.New("query missing infohash ID")
+		}
+		c.sendMsg(rn, c.closestNodesReply(*m.Args.InfoHash, m))
+		if c.onGetPeersQuery != nil {
+			log.Info("on get_peers")
+			// TODO set deadline
+			go c.onGetPeersQuery(ctx, rn, m)
+		}
+	case "announce_peer":
+		if c.onAnnouncePeerQuery != nil {
+			log.Info("on announce_peer")
+			// TODO set deadline
+			go c.onAnnouncePeerQuery(ctx, rn, m)
+		}
+	default:
+		log.Warn("unknown type")
+	}
+	return nil
+}
+
+// closestNodesReply builds the find_node/get_peers response carrying the
+// K closest nodes to target in the family-appropriate compact list.
+func (c *Client) closestNodesReply(target infohash.ID, m Msg) *Msg {
+	out := &Msg{
+		Type: "r",
+		TID:  m.TID,
+		Response: &Response{
+			ID:    &c.node.ID,
+			Token: m.Args.Token,
+		},
+	}
+	nodes := c.table.GetClosest(target, 8)
+	// Allocate the list: the original assigned through the nil
+	// NodesList/Nodes6List pointers, which panics.
+	switch c.node.Family {
+	case Want4:
+		out.Response.NodesList = &CompactNodeList{Nodes: nodes}
+	case Want6:
+		out.Response.Nodes6List = &CompactNode6List{Nodes: nodes}
+	}
+	return out
+}
diff --git a/dht/compact_node.go b/dht/compact_node.go
new file mode 100644
index 0000000..3d9a405
--- /dev/null
+++ b/dht/compact_node.go
@@ -0,0 +1,133 @@
+package dht
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "net/netip"
+ "slices"
+
+ "userspace.com.au/dhtsearch/bencode"
+ "userspace.com.au/dhtsearch/infohash"
+)
+
+// Wire sizes for compact node encodings (BEP5 "nodes" / BEP32 "nodes6").
+const (
+	ip4AddrLength    = 4
+	ip6AddrLength    = 16
+	portLength       = 2
+	compactIP4Length = ip4AddrLength + portLength
+	compactIP6Length = ip6AddrLength + portLength
+	CompactNodeInfoLength  = infohash.Length + compactIP4Length // 26 bytes
+	CompactNode6InfoLength = infohash.Length + compactIP6Length // 38 bytes
+)
+
+// NodeAddr is a netip.AddrPort with the compact wire encoding used by
+// the DHT: raw address bytes followed by a big-endian 2-byte port.
+type NodeAddr netip.AddrPort
+
+// UnmarshalBinary parses a compact address. Like AddrPort UnmarshalBinary
+// except the port is big-endian. Input must be exactly 6 (IPv4) or 18
+// (IPv6) bytes; anything else is rejected — the original sliced blindly
+// and panicked on short input.
+func (n *NodeAddr) UnmarshalBinary(b []byte) error {
+	var addr netip.Addr
+	var offset int
+	var ok bool
+	switch len(b) {
+	case compactIP4Length:
+		addr, ok = netip.AddrFromSlice(b[:ip4AddrLength])
+		offset = ip4AddrLength
+	case compactIP6Length:
+		addr, ok = netip.AddrFromSlice(b[:ip6AddrLength])
+		offset = ip6AddrLength
+	default:
+		return fmt.Errorf("invalid compact address length %d", len(b))
+	}
+	if !ok {
+		return errors.New("failed to parse node address")
+	}
+	ap := netip.AddrPortFrom(addr, binary.BigEndian.Uint16(b[offset:]))
+	*n = NodeAddr(ap)
+	return nil
+}
+
+// MarshalBinary encodes the address as raw address bytes followed by a
+// big-endian 2-byte port (unlike AddrPort MarshalBinary).
+func (n NodeAddr) MarshalBinary() ([]byte, error) {
+	ap := netip.AddrPort(n)
+	addrBytes, err := ap.Addr().MarshalBinary()
+	if err != nil {
+		return nil, err
+	}
+	// Append the port in network byte order.
+	return binary.Append(addrBytes, binary.BigEndian, ap.Port())
+}
+
+// CompactNodeList is the BEP5 "nodes" value: nodes encoded as a single
+// bencoded string of concatenated 26-byte compact entries.
+type CompactNodeList struct {
+	Nodes []*Node
+}
+
+// UnmarshalBencode decodes a bencoded string of concatenated 26-byte
+// compact node entries into c.Nodes. Entries that fail to parse are
+// skipped; a length not divisible by the entry size is an error.
+func (c *CompactNodeList) UnmarshalBencode(b []byte) error {
+	var raw []byte
+	if err := bencode.Unmarshal(b, &raw); err != nil {
+		return err
+	}
+	if mod := len(raw) % CompactNodeInfoLength; mod != 0 {
+		return fmt.Errorf("CompactNodeList trailing %d bytes", mod)
+	}
+	// The divisibility check above guarantees every chunk is full-size.
+	for chunk := range slices.Chunk(raw, CompactNodeInfoLength) {
+		node := new(Node)
+		if err := node.UnmarshalBinary(chunk); err != nil {
+			continue
+		}
+		c.Nodes = append(c.Nodes, node)
+	}
+	return nil
+}
+
+// MarshalBencode encodes c.Nodes as one bencoded string of concatenated
+// compact node entries ("<len>:<payload>").
+func (c *CompactNodeList) MarshalBencode() ([]byte, error) {
+	var payload []byte
+	for _, node := range c.Nodes {
+		enc, err := node.MarshalBinary()
+		if err != nil {
+			return nil, err
+		}
+		payload = append(payload, enc...)
+	}
+	return append(fmt.Appendf(nil, "%d:", len(payload)), payload...), nil
+}
+
+// CompactNode6List is the BEP32 "nodes6" value: nodes encoded as a single
+// bencoded string of concatenated 38-byte compact IPv6 entries.
+type CompactNode6List struct {
+	Nodes []*Node
+}
+
+// UnmarshalBencode decodes a bencoded string of concatenated 38-byte
+// compact IPv6 node entries into c.Nodes. Entries that fail to parse are
+// skipped — matching CompactNodeList, where the original IPv6 variant
+// aborted the whole list on one bad entry.
+func (c *CompactNode6List) UnmarshalBencode(b []byte) error {
+	var in []byte
+	if err := bencode.Unmarshal(b, &in); err != nil {
+		return err
+	}
+	if mod := len(in) % CompactNode6InfoLength; mod != 0 {
+		return fmt.Errorf("CompactNode6List trailing %d bytes", mod)
+	}
+	// The divisibility check guarantees every chunk is full-size.
+	for chunk := range slices.Chunk(in, CompactNode6InfoLength) {
+		n := new(Node)
+		if err := n.UnmarshalBinary(chunk); err != nil {
+			continue
+		}
+		c.Nodes = append(c.Nodes, n)
+	}
+	return nil
+}
+
+// MarshalBencode encodes c.Nodes as one bencoded string of concatenated
+// compact IPv6 node entries ("<len>:<payload>").
+func (c *CompactNode6List) MarshalBencode() ([]byte, error) {
+	var payload []byte
+	for _, node := range c.Nodes {
+		enc, err := node.MarshalBinary()
+		if err != nil {
+			return nil, err
+		}
+		payload = append(payload, enc...)
+	}
+	return append(fmt.Appendf(nil, "%d:", len(payload)), payload...), nil
+}
diff --git a/dht/compact_node_test.go b/dht/compact_node_test.go
new file mode 100644
index 0000000..3429317
--- /dev/null
+++ b/dht/compact_node_test.go
@@ -0,0 +1,63 @@
+package dht
+
+import "testing"
+
+func TestCompactNodeList(t *testing.T) {
+ type nodeInfo struct {
+ id string
+ addr string
+ }
+ tests := map[string]struct {
+ in string
+ n int
+ items []nodeInfo
+ }{
+ "contrived": {
+ in: "52:\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x03\x04\x05\x06\x07",
+ n: 2,
+ items: []nodeInfo{{
+ id: "0000000000000000000000000000000000000000",
+ addr: "1.2.3.4:1286",
+ },{
+ id: "0000000000000000000000000000000000000000",
+ addr: "2.3.4.5:1543",
+ }},
+ },
+ "captured": {
+ in: "208:\xb4\x99d\xcf)\xaa\x11\x14\x8d\r\x8b\x8dL̑b\xf6\x14/2\xaf\xb1\xcc+P\xb6\xb4\x99d\xcf)\xaa\x11\x14\x8d\r\x8b\x8dL̑b\xf6\x14/2\xaf\xb1\xcc+P\xb6\xb4\x99d\xcf)\xaa\x11\x14\x8d\r\x8b\x8dL̑b\xf6\x14/2\xaf\xb1\xcc+P\xb6\xb4\x99d\xcf)\xaa\x11\x14\x8d\r\x8b\x8dL̑b\xf6\x14/2\xaf\xb1\xcc+P\xb6\xb4\x99d\xcf)\xaa\x11\x14\x8d\r\x8b\x8dL̑b\xf6\x14/2\xaf\xb1\xcc+P\xb6\xb4\x99d\xcf)\xaa\x11\x14\x8d\r\x8b\x8dL̑b\xf6\x14/2\xaf\xb1\xcc+P\xb6\xb4\x99d\xcf)\xaa\x11\x14\x8d\r\x8b\x8dL̑b\xf6\x14/2\xaf\xb1\xcc+P\xb6\xb4\x99d\xcf)\xaa\x11\x14\x8d\r\x8b\x8dL̑b\xf6\x14/2\xaf\xb1\xcc+P\xb6",
+ n: 8,
+ items: []nodeInfo{{
+ id: "b49964cf29aa11148d0d8b8d4ccc9162f6142f32",
+ addr: "175.177.204.43:20662",
+ },{
+ id: "b49964cf29aa11148d0d8b8d4ccc9162f6142f32",
+ addr: "175.177.204.43:20662",
+ }},
+ },
+ }
+
+ for name, tt := range tests {
+ t.Run(name, func(t *testing.T) {
+ t.Logf("input length=%d", len([]byte(tt.in)))
+ var cnl CompactNodeList
+ if err := cnl.UnmarshalBencode([]byte(tt.in)); err != nil {
+ t.Fatal(err)
+ }
+ if len(cnl.Nodes) != tt.n {
+ t.Fatalf("got %d, want %d", len(cnl.Nodes), tt.n)
+ }
+ for i, ni := range tt.items {
+ id := cnl.Nodes[i].ID
+ if id.String() != ni.id {
+ t.Fatalf("got %q, want %q", id.String(), ni.id)
+ }
+ addr := cnl.Nodes[i]
+ if addr.AddrPort.String() != ni.addr {
+ t.Fatalf("got %q, want %q", addr.AddrPort.String(), ni.addr)
+ }
+ }
+ })
+ }
+
+}
diff --git a/dht/krpc.go b/dht/krpc.go
new file mode 100644
index 0000000..1c9ee4a
--- /dev/null
+++ b/dht/krpc.go
@@ -0,0 +1,259 @@
+package dht
+
+import (
+ "fmt"
+
+ "userspace.com.au/dhtsearch/bencode"
+ "userspace.com.au/dhtsearch/infohash"
+)
+
+// Msg is a KRPC message envelope (BEP5): a query, a response or an error,
+// distinguished by Type.
+type Msg struct {
+	// Query method
+	// one of: "ping", "find_node", "get_peers", "announce_peer"
+	Query string `bencode:"q,omitzero"`
+
+	// Arguments sent with a query
+	Args *MsgArgs `bencode:"a,omitzero"`
+
+	// Transaction ID, required
+	TID string `bencode:"t"`
+
+	// Type of message, required
+	// one of: q for QUERY, r for RESPONSE, e for ERROR
+	Type string `bencode:"y"`
+
+	// Response payload for type 'r'
+	Response *Response `bencode:"r,omitzero"`
+
+	// Error payload for type 'e'
+	Error *Error `bencode:"e,omitzero"`
+
+	//IP NodeAddr `bencode:"ip,omitzero"`
+	// Node IP address, required for BEP42
+	IP string `bencode:"ip,omitzero"`
+
+	// Sender does not respond to queries, BEP43
+	ReadOnly bool `bencode:"ro,omitzero"`
+
+	// https://www.libtorrent.org/dht_extensions.html
+	ClientId string `bencode:"v,omitzero"`
+}
+
+// Response is the payload of a KRPC "r" message.
+// NOTE(review): Token, Values, Interval and Num carry no bencode tags —
+// confirm the bencode package lowercases field names by default,
+// otherwise they will not match the wire keys
+// "token"/"values"/"interval"/"num".
+type Response struct {
+	ID *infohash.ID `bencode:"id"`
+
+	// K closest nodes
+	// from get_peers, find_nodes, get & sample_infohashes
+	NodesList  *CompactNodeList  `bencode:"nodes,omitzero"`
+	Nodes6List *CompactNode6List `bencode:"nodes6,omitzero"`
+
+	// Token for future announce_peer or put, BEP44
+	Token string
+
+	// Torrent peers
+	Values []NodeAddr
+
+	// BEP33 (scrapes)
+	// BFsd *ScrapeBloomFilter `bencode:"BFsd,omitzero"`
+	// BFpe *ScrapeBloomFilter `bencode:"BFpe,omitzero"`
+
+	// BEP51
+	Interval int64
+	Num      int64
+	// Nodes supporting this extension should always include the samples field in the response, even
+	// when it is zero-length. This lets indexing nodes to distinguish nodes supporting this
+	// extension from those that respond to unknown query types which contain a target field.
+	Samples []byte `bencode:"samples,omitzero"`
+}
+
+// Address-family selectors for the BEP32 "want" argument.
+const (
+	Want6 = "n6"
+	Want4 = "n4"
+)
+
+// MsgArgs is the "a" dictionary carried by a KRPC query.
+type MsgArgs struct {
+	// ID of the querying Node
+	ID *infohash.ID `bencode:"id,omitzero"`
+
+	// InfoHash of the torrent
+	InfoHash *infohash.ID `bencode:"info_hash,omitzero"`
+
+	// ID of the node sought
+	Target *infohash.ID `bencode:"target,omitzero"`
+
+	// Token received from an earlier get_peers query
+	// Also used in a BEP44 put
+	Token string `bencode:"token,omitzero"`
+
+	// Sender's torrent port
+	Port int `bencode:"port,omitzero"`
+
+	// Use senders apparent DHT port
+	ImpliedPort bool `bencode:"implied_port,omitzero"`
+
+	// Network family wanted, any of "n4" and "n6", BEP32
+	Want []string `bencode:"want,omitzero"`
+
+	// BEP33
+	// NoSeed int `bencode:"noseed,omitzero"`
+	// Scrape int `bencode:"scrape,omitzero"`
+
+	// BEP44
+	//V any `bencode:"v,omitzero"`
+	// Seq *int64 `bencode:"seq,omitzero"`
+	// Cas int64 `bencode:"cas,omitzero"`
+	// K [32]byte `bencode:"k,omitzero"`
+	// Salt []byte `bencode:"salt,omitzero"`
+	// Sig [64]byte `bencode:"sig,omitzero"`
+}
+
+// Error is the payload of a KRPC "e" message: a numeric code followed by
+// a human-readable message, encoded on the wire as a two-element list.
+type Error struct {
+	Code int64
+	Msg  string
+}
+
+// UnmarshalBencode decodes the KRPC error list [code, message] into e.
+func (e *Error) UnmarshalBencode(b []byte) error {
+	r := bencode.NewReaderFromBytes(b)
+	ok := r.ReadList(func(r *bencode.Reader) bool {
+		// Code first, then the message; short-circuit on failure.
+		return r.ReadInt(&e.Code) && r.ReadString(&e.Msg)
+	})
+	if !ok {
+		return fmt.Errorf("error unmarshal failed: %w", r.Err())
+	}
+	return nil
+}
+
+// NewFindNodeQuery builds a find_node query from id looking for target,
+// requesting the given address families (BEP32 "want").
+func NewFindNodeQuery(id, target infohash.ID, wants []string) *Msg {
+	args := &MsgArgs{
+		// The querying node
+		ID: &id,
+		// The ID sought
+		Target: &target,
+		Want:   wants,
+	}
+	return &Msg{
+		Query: "find_node",
+		Type:  "q",
+		Args:  args,
+	}
+}
+
+// NewPingQuery builds a ping query from the given node ID. Type must be
+// "q" (as in NewFindNodeQuery) or the client never registers a
+// transaction when sending it.
+func NewPingQuery(id infohash.ID) *Msg {
+	return &Msg{
+		Query: "ping",
+		Type:  "q",
+		Args: &MsgArgs{
+			ID: &id,
+		},
+	}
+}
+
+// func NewGetPeersQuery(id infohash.ID) *Msg {
+// msg := &Msg{
+// Query: "get_peers",
+// Type: "q",
+// Args: &MsgArgs{
+// ID: &c.node.ID,
+// InfoHash: &ih,
+// },
+// }
+// return &Msg{
+// Query: "ping",
+// TID: newTransactionID(),
+// Args: &MsgArgs{
+// ID: &id,
+// },
+// }
+// }
+
+// func NewSampleInfohashesQuery(id, target infohash.ID) *Msg {
+// return &Msg{
+// Query: "sample_infohashes",
+// TID: newTransactionID(),
+// Type: "q",
+// Args: &MsgArgs{
+// // The querying node
+// ID: &id,
+// // The ID sought
+// Target: &target,
+// },
+// }
+// }
+
+// const (
+// transIDBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+// )
+
+// makeQuery returns a query-formed data.
+// func MakeQuery(tid, query string, data map[string]any) map[string]any {
+// return map[string]any{
+// "t": tid,
+// "y": "q",
+// "q": query,
+// "a": data,
+// }
+// }
+
+// makeResponse returns a response-formed data.
+// func MakeResponse(tid string, data map[string]any) map[string]any {
+// return map[string]any{
+// "t": tid,
+// "y": "r",
+// "r": data,
+// }
+// }
+
+// func DecodeCompactNodeAddr(cni string) string {
+// if len(cni) == 6 {
+// return fmt.Sprintf("%d.%d.%d.%d:%d", cni[0], cni[1], cni[2], cni[3], (uint16(cni[4])<<8)|uint16(cni[5]))
+// } else if len(cni) == 18 {
+// b := []byte(cni[:16])
+// return fmt.Sprintf("[%s]:%d", net.IP.String(b), (uint16(cni[16])<<8)|uint16(cni[17]))
+// } else {
+// return ""
+// }
+// }
+
+// func EncodeCompactNodeAddr(addr string) string {
+// var a []uint8
+// host, port, _ := net.SplitHostPort(addr)
+// ip := net.ParseIP(host)
+// if ip == nil {
+// return ""
+// }
+// aa, _ := strconv.ParseUint(port, 10, 16)
+// c := uint16(aa)
+// if ip2 := net.IP.To4(ip); ip2 != nil {
+// a = make([]byte, net.IPv4len+2, net.IPv4len+2)
+// copy(a, ip2[0:net.IPv4len]) // ignore bytes IPv6 bytes if it's IPv4.
+// a[4] = byte(c >> 8)
+// a[5] = byte(c)
+// } else {
+// a = make([]byte, net.IPv6len+2, net.IPv6len+2)
+// copy(a, ip)
+// a[16] = byte(c >> 8)
+// a[17] = byte(c)
+// }
+// return string(a)
+// }
+
+// func int2bytes(val int64) []byte {
+// data, j := make([]byte, 8), -1
+// for i := 0; i < 8; i++ {
+// shift := uint64((7 - i) * 8)
+// data[i] = byte((val & (0xff << shift)) >> shift)
+
+// if j == -1 && data[i] != 0 {
+// j = i
+// }
+// }
+
+// if j != -1 {
+// return data[j:]
+// }
+// return data[:1]
+// }
diff --git a/dht/krpc_test.go b/dht/krpc_test.go
new file mode 100644
index 0000000..99f5fdf
--- /dev/null
+++ b/dht/krpc_test.go
@@ -0,0 +1,106 @@
+package dht
+
+import (
+ "testing"
+
+ "userspace.com.au/dhtsearch/bencode"
+ "userspace.com.au/dhtsearch/infohash"
+)
+
+func TestKRPCMsg(t *testing.T) {
+ id := infohash.ID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20})
+ tests := []struct {
+ in Msg
+ out string
+ }{{
+ in: Msg{Args: &MsgArgs{ID: &id}},
+ out: "d1:ad2:id20:\x01\x02\x03\x04\x05\x06\a\b\t\n\v\f\r\x0e\x0f\x10\x11\x12\x13\x14e1:t0:1:y0:e",
+ }, {
+ in: Msg{Args: &MsgArgs{Want: []string{"n4", "n6"}}},
+ out: "d1:ad4:wantl2:n42:n6ee1:t0:1:y0:e",
+ }}
+
+ for _, tt := range tests {
+ t.Run("marshal", func(t *testing.T) {
+ t.Logf("pre msg: %v", tt.in)
+ b, err := bencode.Marshal(tt.in)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(b) != tt.out {
+ t.Fatalf("got %q, want %q", string(b), tt.out)
+ }
+ })
+ // t.Run("unmarshal", func(t *testing.T) {
+ // var got Msg
+ // err := bencode.Unmarshal([]byte(tt.out), &got)
+ // if err != nil {
+ // t.Fatal(err)
+ // }
+ // if got != tt.in {
+ // t.Fatalf("got %v, want %v", got, tt.in)
+ // }
+ // })
+ }
+}
+
+// TestError checks decoding of a captured KRPC error reply.
+func TestError(t *testing.T) {
+	in := "d1:eli201e17:too many requestse1:t2:rz1:y1:re"
+	var m Msg
+	if err := bencode.Unmarshal([]byte(in), &m); err != nil {
+		t.Fatal(err)
+	}
+	if m.Error.Code != 201 {
+		t.Errorf("got %v, want %d", m.Error.Code, 201)
+	}
+	if m.Error.Msg != "too many requests" {
+		// Report the message, not the code (fixes a copy-paste slip).
+		t.Errorf("got %q, want %q", m.Error.Msg, "too many requests")
+	}
+}
+
+// func TestCompactNode(t *testing.T) {
+// ih := "infohashinfohash1234"
+// idIn, _ := infohash.FromString(ih)
+// tests := []struct {
+// ip string
+// port uint16
+// }{
+// {ip: "127.0.0.1", port: 6881},
+// {ip: "[2404:6800:4006:814::200e]", port: 6881},
+// }
+//
+// type compactNode interface {
+// ID() (*infohash.ID, error)
+// AddrPort() (netip.AddrPort, error)
+// }
+//
+// for _, tt := range tests {
+// apIn, err := netip.ParseAddrPort(fmt.Sprintf("%s:%d", tt.ip, tt.port))
+// if err != nil {
+// panic(err)
+// }
+// s := NewCompactNodeInfo(*idIn, apIn.Addr(), tt.port)
+// t.Logf("CompactNodeInfo: %d %v", len(s), []byte(s))
+//
+// var cni compactNode
+// if apIn.Addr().Is6() {
+// cni = CompactNode6Info(s)
+// } else {
+// cni = CompactNodeInfo(s)
+// }
+// ap, err := cni.AddrPort()
+// if err != nil {
+// t.Error(err)
+// }
+// if ap != apIn {
+// t.Errorf("got %v, want %v", ap, apIn)
+// }
+// id, err := cni.ID()
+// if err != nil {
+// t.Error(err)
+// }
+// if !id.Equal(*idIn) {
+// t.Errorf("got %v, want %v", id, idIn)
+// }
+// }
+// }
diff --git a/dht/ktable.go b/dht/ktable.go
new file mode 100644
index 0000000..76f9fdd
--- /dev/null
+++ b/dht/ktable.go
@@ -0,0 +1,411 @@
+package dht
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "io"
+ "log/slog"
+ "slices"
+ "sync"
+ "time"
+
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/metric"
+
+ "userspace.com.au/dhtsearch/bencode"
+ "userspace.com.au/dhtsearch/infohash"
+)
+
+type Table struct {
+ sync.RWMutex
+
+ // Local node infohash
+ id infohash.ID
+
+ // Root bucket for the tree
+ root *bucket
+
+ log *slog.Logger
+
+ // BEP5: Each bucket can only hold K nodes, currently eight, before
+ // becoming "full."
+ k int
+}
+
+type bucket struct {
+ nodes []*Node
+ left *bucket
+ right *bucket
+ dontSplit bool
+
+ // BEP5: Each bucket should maintain a "last changed" property to
+ // indicate how "fresh" the contents are. When a node in a bucket is
+ // pinged and it responds, or a node is added to a bucket, or a node in
+ // a bucket is replaced with another node, the bucket's last changed
+ // property should be updated. Buckets that have not been changed in
+ // 15 minutes should be "refreshed." This is done by picking a random
+ // ID in the range of the bucket and performing a find_nodes search on it.
+ lastChanged time.Time
+}
+
+var _ RoutingTable = (*Table)(nil)
+
+func NewTable(id infohash.ID) *Table {
+	return &Table{
+		id:   id,
+		root: newBucket(),
+		log:  slog.Default(),
+		k:    8,
+	}
+}
+
+func (t *Table) ID() infohash.ID {
+ return t.id
+}
+
+func (t *Table) Add(n *Node) bool {
+ if t.id.Equal(n.ID) {
+ t.log.Debug("not adding self")
+ return false
+ }
+ t.Lock()
+ defer t.Unlock()
+ return t.addNoLock(n)
+}
+
+// Potentially recursive
+func (t *Table) addNoLock(n *Node) bool {
+ log := t.log.With("id", n.ID, "addr", n.AddrPort)
+ b, bitIndex := t.locateBucket(n.ID)
+ if b.has(n.ID) {
+ //log.Debug("node exists")
+ return false
+ }
+ if len(b.nodes) < t.k {
+ b.add(n)
+ ktableCalls.Add(
+ context.TODO(), 1,
+ metric.WithAttributes(attribute.String("method", "add")))
+ log.Debug("added node")
+ return true
+ }
+ if b.dontSplit {
+ toReplace := b.mostQuestionable(2)
+ if toReplace != nil {
+ log.Debug(
+ "replacing node",
+ "old", toReplace.AddrPort,
+ "attempts", toReplace.PingAttempts,
+ "contact", toReplace.LastContact)
+ b.remove(toReplace.ID)
+ b.add(n)
+ ktableCalls.Add(
+ context.TODO(), 1,
+ metric.WithAttributes(attribute.String("method", "replace")))
+ return true
+ }
+ //log.Debug("bucket is full")
+ return false
+ }
+
+ // BEP5: When a bucket is full of known good nodes, no more nodes may be
+ // added unless our own node ID falls within the range of the bucket. In
+ // that case, the bucket is replaced by two new buckets each with half the
+ // range of the old bucket and the nodes from the old bucket are
+ // distributed among the two new ones. For a new table with only one
+ // bucket, the full bucket is always split into two new buckets covering
+ // the ranges 0..2^159 and 2^159..2^160.
+ b.split(bitIndex)
+ b.farChild(t.id, bitIndex).dontSplit = true
+ return t.addNoLock(n)
+}
+
+func (t *Table) Has(id infohash.ID) bool {
+ t.RLock()
+ defer t.RUnlock()
+ return t.hasNoLock(id)
+}
+
+func (t *Table) hasNoLock(id infohash.ID) bool {
+ b, _ := t.locateBucket(id)
+ return b.has(id)
+}
+
+func (t *Table) Remove(id infohash.ID) bool {
+ t.Lock()
+ defer t.Unlock()
+ return t.removeNoLock(id)
+}
+
+func (t *Table) removeNoLock(id infohash.ID) bool {
+ b, _ := t.locateBucket(id)
+ ktableCalls.Add(
+ context.TODO(), 1,
+ metric.WithAttributes(attribute.String("method", "remove")))
+ return b.remove(id)
+}
+
+func (t *Table) Count() int {
+ t.RLock()
+ defer t.RUnlock()
+ n := 0
+ for _, bucket := range nonEmptyBuckets(t.root) {
+ n += len(bucket.nodes)
+ }
+ return n
+}
+
+// SetSeen clears a node's questionable flag and updates contact time.
+func (t *Table) SetSeen(id infohash.ID) {
+ t.Lock()
+ defer t.Unlock()
+ t.seenNoLock(id)
+}
+
+func (t *Table) seenNoLock(id infohash.ID) {
+ bucket, _ := t.locateBucket(id)
+ bucket.update()
+ if n := bucket.find(id); n != nil {
+ if n.PingAttempts > 0 {
+			t.log.Debug("resetting questionable status", "address", n.AddrPort)
+ }
+ n.PingAttempts = 0
+ n.LastContact = time.Now()
+ }
+}
+
+func (t *Table) GetClosest(target infohash.ID, limit int) []*Node {
+ t.RLock()
+ defer t.RUnlock()
+ ktableCalls.Add(
+ context.TODO(), 1,
+ metric.WithAttributes(attribute.String("method", "closest")))
+ bitIndex := 0
+ buckets := []*bucket{t.root}
+ nodes := make([]*Node, 0, limit)
+ var bucket *bucket
+ for len(buckets) > 0 && len(nodes) < limit {
+ bucket, buckets = buckets[len(buckets)-1], buckets[:len(buckets)-1]
+ if bucket.nodes == nil {
+ near := bucket.nearChild(target, bitIndex)
+ far := bucket.farChild(target, bitIndex)
+ buckets = append(buckets, far, near)
+ bitIndex++
+ } else {
+ nodes = append(nodes, bucket.nodes...)
+ }
+ }
+ // We have less than requested
+ if length := len(nodes); limit > length {
+ limit = length
+ }
+ slices.SortFunc(nodes, func(a, b *Node) int {
+ aDist := target.Xor(a.ID)
+ bDist := target.Xor(b.ID)
+ return bytes.Compare(aDist, bDist)
+ })
+ return nodes[:limit]
+}
+
+func (t *Table) GetStale(limit int, d time.Duration) []*Node {
+ t.RLock()
+ defer t.RUnlock()
+ nodes := make([]*Node, 0)
+ ktableCalls.Add(
+ context.TODO(), 1,
+ metric.WithAttributes(attribute.String("method", "stale")))
+ buckets := nonEmptyBuckets(t.root)
+ slices.SortFunc(buckets, func(a, b *bucket) int {
+ // Oldest first
+		return a.lastChanged.Compare(b.lastChanged)
+ })
+ for _, b := range buckets {
+ for _, n := range b.nodes {
+ if len(nodes) < limit && time.Since(n.LastContact) > d {
+ nodes = append(nodes, n)
+ }
+ }
+ }
+ return nodes
+}
+
+// recursive
+func nonEmptyBuckets(b *bucket) []*bucket {
+ if b == nil {
+ return nil
+ }
+ var out []*bucket
+ if len(b.nodes) > 0 {
+ out = append(out, b)
+ }
+ out = append(out, nonEmptyBuckets(b.left)...)
+ out = append(out, nonEmptyBuckets(b.right)...)
+ return out
+}
+
+func (t *Table) locateBucket(id infohash.ID) (bucket *bucket, bitIndex int) {
+ bucket = t.root
+ for bucket.nodes == nil {
+ bucket = bucket.nearChild(id, bitIndex)
+ bitIndex++
+ }
+ return
+}
+
+func newBucket() *bucket {
+ return &bucket{
+ nodes: make([]*Node, 0),
+ //lastChanged: time.Now(),
+ }
+}
+
+func (b *bucket) update() {
+ b.lastChanged = time.Now()
+}
+
+func (b *bucket) add(n *Node) {
+ b.nodes = append(b.nodes, n)
+ b.update()
+}
+
+func (b *bucket) split(bitIndex int) {
+ b.left = newBucket()
+ b.right = newBucket()
+ for _, n := range b.nodes {
+ b.nearChild(n.ID, bitIndex).add(n)
+ }
+ b.nodes = nil
+}
+
+func (b *bucket) has(id infohash.ID) bool {
+ return b.indexOf(id) >= 0
+}
+
+func (b *bucket) nearChild(id infohash.ID, bitIndex int) *bucket {
+ bitIndexWithinByte := bitIndex % 8
+ desiredByte := id[bitIndex/8]
+	if desiredByte&(1<<(uint(7-bitIndexWithinByte))) != 0 {
+ return b.right
+ }
+ return b.left
+}
+
+func (b *bucket) farChild(id infohash.ID, bitIndex int) *bucket {
+ if c := b.nearChild(id, bitIndex); c == b.right {
+ return b.left
+ }
+ return b.right
+}
+
+func (b *bucket) remove(id infohash.ID) bool {
+ var out bool
+ b.nodes = slices.DeleteFunc(b.nodes, func(n *Node) bool {
+ out = n.ID.Equal(id)
+ return out
+ })
+ return out
+}
+
+func (b *bucket) indexOf(id infohash.ID) int {
+ for i, c := range b.nodes {
+ if id.Equal(c.ID) {
+ return i
+ }
+ }
+ return -1
+}
+
+func (b *bucket) find(id infohash.ID) *Node {
+ if index := b.indexOf(id); index > -1 {
+ return b.nodes[index]
+ }
+ return nil
+}
+
+// mostQuestionable finds the node with the most ping attempts
+func (b *bucket) mostQuestionable(attempts int) *Node {
+ var out *Node
+ for _, n := range b.nodes {
+ if n.PingAttempts > attempts {
+ out = n
+ attempts = n.PingAttempts
+ }
+ }
+ return out
+}
+
+func (t *Table) WriteTo(w io.Writer) (int64, error) {
+ t.RLock()
+ defer t.RUnlock()
+
+ var c int64
+ var n int
+ var err error
+
+ // Infohash first
+ b, err := bencode.EncodeString(string(t.id[:]))
+ if err != nil {
+ return c, err
+ }
+ n, err = w.Write(b)
+ if err != nil {
+ return c, err
+ }
+ c += int64(n)
+ t.log.Debug("exporting table", "id", t.id)
+
+ // This should write n4 nodes correctly too
+ var l int
+ for _, bu := range nonEmptyBuckets(t.root) {
+ for _, node := range bu.nodes {
+ cni, err := node.MarshalBinary()
+ if err != nil {
+ return c, err
+ }
+ b, err := bencode.EncodeString(string(cni))
+ if err != nil {
+ return c, err
+ }
+ n, err = w.Write(b)
+ if err != nil {
+ return c, err
+ }
+ l += 1
+ c += int64(n)
+ t.log.Debug("exported node", "id", node.ID, "address", node.AddrPort)
+ }
+ }
+ t.log.Info("exported table", "id", t.id, "count", l)
+ return c, err
+}
+
+func (t *Table) ReadFrom(r io.Reader) (int64, error) {
+ t.Lock()
+ defer t.Unlock()
+
+ br := bencode.NewReader(bufio.NewReader(r))
+ var ih string
+ if !br.ReadString(&ih) {
+ return br.Count(), br.Err()
+ }
+ t.id = infohash.ID([]byte(ih))
+ t.log.Debug("importing table", "id", t.id)
+ var nodes []Node
+ for br.Err() == nil {
+ var cni string
+ if br.ReadString(&cni) {
+ var node Node
+ if err := node.UnmarshalBinary([]byte(cni)); err != nil {
+ return br.Count(), err
+ }
+ nodes = append(nodes, node)
+ t.log.Debug("imported node", "id", node.ID, "address", node.AddrPort)
+ }
+ }
+ for _, node := range nodes {
+ t.addNoLock(&node)
+ }
+ t.log.Info("imported table", "id", t.id, "count", len(nodes))
+ return br.Count(), nil
+}
diff --git a/dht/ktable_test.go b/dht/ktable_test.go
new file mode 100644
index 0000000..266c1af
--- /dev/null
+++ b/dht/ktable_test.go
@@ -0,0 +1,106 @@
+package dht
+
+import (
+ "bytes"
+ "net/netip"
+ "testing"
+
+ "github.com/neilotoole/slogt"
+
+ "userspace.com.au/dhtsearch/infohash"
+)
+
+func TestTable(t *testing.T) {
+ id := bytes.Repeat([]byte{0}, 20)
+ tbl := NewTable(infohash.ID(id))
+ tbl.log = slogt.New(t)
+
+ var nodes []*Node
+
+ for range 100 {
+ n := &Node{
+ ID: infohash.NewRandomID(),
+ }
+ if tbl.Add(n) {
+ nodes = append(nodes, n)
+ }
+ }
+
+ t.Run("updating", func(t *testing.T) {
+ t.Parallel()
+ for _, n := range nodes {
+ tbl.SetSeen(n.ID)
+ }
+ })
+ t.Run("lookup", func(t *testing.T) {
+ t.Parallel()
+ for _, n := range nodes {
+ tbl.Has(n.ID)
+ }
+ })
+ t.Run("remove add", func(t *testing.T) {
+ t.Parallel()
+ for _, n := range nodes {
+ tbl.Remove(n.ID)
+ tbl.Add(n)
+ }
+ })
+}
+
+func TestTableExport(t *testing.T) {
+ nodes := []Node{
+ {
+ ID: infohash.MustParseString("0000000000000000000000000000000000000000"),
+ AddrPort: netip.MustParseAddrPort("[2001:19f0:5:6d01:5400:2ff:feec:644a]:6881")},
+ {
+ ID: infohash.MustParseString("c7fb3eab2d33331610075ed0322f6cebc5351aed"),
+ AddrPort: netip.MustParseAddrPort("[2607:9000:3000:33:d5fe:c61c:6adc:6b79]:17980")},
+ {
+ ID: infohash.MustParseString("c7fba7fc5ec5dfd49cff98b31b68e37b0a08fd84"),
+ AddrPort: netip.MustParseAddrPort("[240e:3b3:9613:52f0:215:5dff:fe0a:8971]:61128")},
+ {
+ ID: infohash.MustParseString("c7fc64a9ebb14481981529ba200a40d00643eb2a"),
+ AddrPort: netip.MustParseAddrPort("[2001:da8:e000:3002:20c:29ff:fea3:7226]:10301")},
+ {
+ ID: infohash.MustParseString("c5198df3d46a32799ac34436a4737d7b3bb2ff6e"),
+ AddrPort: netip.MustParseAddrPort("[240e:b8f:966e:9e00::bf9]:63219")},
+ {
+ ID: infohash.MustParseString("c1ea1df5ce162836c2080bdc94a127a90cfad24a"),
+ AddrPort: netip.MustParseAddrPort("[2a01:e0a:d5b:c690:211:32ff:fe26:7445]:51413")},
+ {
+ ID: infohash.MustParseString("c17f24479d8c5ad3ad7a731a600ff323ecb26ebd"),
+ AddrPort: netip.MustParseAddrPort("[2001:e68:542f:5959:11b7:2243:527d:1c86]:19264")},
+ }
+ id := infohash.MustParseString("c4bee4bf16dd88527f63cab902b1111111111111")
+ tbl := NewTable(id)
+ tbl.log = slogt.New(t)
+ for _, n := range nodes {
+ tbl.Add(&n)
+ }
+ if tbl.Count() != len(nodes) {
+ t.Fatalf("got %d, want %d", tbl.Count(), len(nodes))
+ }
+
+ buf := new(bytes.Buffer)
+ n1, err := tbl.WriteTo(buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ tbl2 := &Table{
+ root: newBucket(),
+ k: 8,
+ }
+ tbl2.log = slogt.New(t)
+ n2, err := tbl2.ReadFrom(buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n1 != n2 {
+ t.Errorf("got %d, want %d", n2, n1)
+ }
+ for _, n := range nodes {
+ if !tbl.Has(n.ID) {
+ t.Errorf("missing id %s", n.ID)
+ }
+ }
+}
diff --git a/dht/messages.go b/dht/messages.go
deleted file mode 100644
index 188ed0f..0000000
--- a/dht/messages.go
+++ /dev/null
@@ -1,127 +0,0 @@
-package dht
-
-import (
- "fmt"
- "net"
-
- "src.userspace.com.au/dhtsearch/krpc"
- "src.userspace.com.au/dhtsearch/models"
-)
-
-func (n *Node) onPingQuery(rn remoteNode, msg map[string]interface{}) error {
- t, err := krpc.GetString(msg, "t")
- if err != nil {
- return err
- }
- n.queueMsg(rn, krpc.MakeResponse(t, map[string]interface{}{
- "id": string(n.id),
- }))
- return nil
-}
-
-func (n *Node) onGetPeersQuery(rn remoteNode, msg map[string]interface{}) error {
- a, err := krpc.GetMap(msg, "a")
- if err != nil {
- return err
- }
-
- // This is the ih of the torrent
- torrent, err := krpc.GetString(a, "info_hash")
- if err != nil {
- return err
- }
- th, err := models.InfohashFromString(torrent)
- if err != nil {
- return err
- }
- //n.log.Debug("get_peers query", "source", rn, "torrent", th)
-
- token := torrent[:2]
- neighbour := models.GenerateNeighbour(n.id, *th)
- /*
- nodes := n.rTable.get(8)
- compactNS := []string{}
- for _, rn := range nodes {
- ns := encodeCompactNodeAddr(rn.addr.String())
- if ns == "" {
- n.log.Warn("failed to compact node", "address", rn.address.String())
- continue
- }
- compactNS = append(compactNS, ns)
- }
- */
-
- t := msg["t"].(string)
- n.queueMsg(rn, krpc.MakeResponse(t, map[string]interface{}{
- "id": string(neighbour),
- "token": token,
- "nodes": "",
- //"nodes": strings.Join(compactNS, ""),
- }))
-
- //nodes := n.rTable.get(50)
- /*
- fmt.Printf("sending get_peers for %s to %d nodes\n", *th, len(nodes))
- q := krpc.MakeQuery(newTransactionID(), "get_peers", map[string]interface{}{
- "id": string(id),
- "info_hash": string(*th),
- })
- for _, o := range nodes {
- n.queueMsg(*o, q)
- }
- */
- return nil
-}
-
-func (n *Node) onAnnouncePeerQuery(rn remoteNode, msg map[string]interface{}) error {
- a, err := krpc.GetMap(msg, "a")
- if err != nil {
- return err
- }
-
- n.log.Debug("announce_peer", "source", rn)
-
- host, port, err := net.SplitHostPort(rn.addr.String())
- if err != nil {
- return err
- }
- if port == "0" {
- return fmt.Errorf("ignoring port 0")
- }
-
- ihStr, err := krpc.GetString(a, "info_hash")
- if err != nil {
- return err
- }
- ih, err := models.InfohashFromString(ihStr)
- if err != nil {
- return fmt.Errorf("invalid torrent: %s", err)
- }
-
- newPort, err := krpc.GetInt(a, "port")
- if err == nil {
- if iPort, err := krpc.GetInt(a, "implied_port"); err == nil && iPort == 0 {
- // Use the port in the message
- addr, err := net.ResolveUDPAddr(n.family, fmt.Sprintf("%s:%d", host, newPort))
- if err != nil {
- return err
- }
- n.log.Debug("implied port", "infohash", ih, "original", rn.addr.String(), "new", addr.String())
- rn = remoteNode{addr: addr, id: rn.id}
- }
- }
-
- // TODO do we reply?
-
- p := models.Peer{Addr: rn.addr, Infohash: *ih}
- if n.OnAnnouncePeer != nil {
- go n.OnAnnouncePeer(p)
- }
- return nil
-}
-
-func (n *Node) onFindNodeResponse(rn remoteNode, msg map[string]interface{}) {
- r := msg["r"].(map[string]interface{})
- nodes := r["nodes"].(string)
- n.processFindNodeResults(rn, nodes)
-}
diff --git a/dht/metrics.go b/dht/metrics.go
new file mode 100644
index 0000000..da708ba
--- /dev/null
+++ b/dht/metrics.go
@@ -0,0 +1,76 @@
+package dht
+
+import (
+ "context"
+ "time"
+
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/metric"
+)
+
+var meter = otel.Meter("userspace.com.au/dhtsearch/dht")
+
+var (
+ netOctets metric.Int64Counter
+ netPackets metric.Int64Counter
+)
+
+var ktableCalls metric.Int64Counter
+
+func init() {
+ var err error
+ netOctets, err = meter.Int64Counter(
+ "network.octets",
+ metric.WithDescription("The number of bytes."),
+ metric.WithUnit("{byte}"),
+ )
+ if err != nil {
+ panic(err)
+ }
+	netPackets, err = meter.Int64Counter(
+		"network.packets",
+		metric.WithDescription("The number of packets."),
+		metric.WithUnit("{packet}"),
+	)
+ if err != nil {
+ panic(err)
+ }
+ ktableCalls, err = meter.Int64Counter(
+ "ktable.calls",
+ metric.WithDescription("Number of calls to ktable."),
+ metric.WithUnit("{call}"),
+ )
+ if err != nil {
+ panic(err)
+ }
+}
+
+func configureMetrics(c *Client) error {
+ var err error
+ start := time.Now()
+ if _, err = meter.Float64ObservableCounter(
+ "uptime",
+ metric.WithDescription("The running duration."),
+ metric.WithUnit("s"),
+ metric.WithFloat64Callback(func(_ context.Context, o metric.Float64Observer) error {
+			o.Observe(time.Since(start).Seconds())
+ return nil
+ }),
+ ); err != nil {
+		return err
+ }
+
+ _, err = meter.Int64ObservableUpDownCounter(
+ "ktable.size",
+ metric.WithDescription("The number of nodes in the ktable."),
+ metric.WithUnit("{nodes}"),
+ metric.WithInt64Callback(func(_ context.Context, o metric.Int64Observer) error {
+ o.Observe(int64(c.table.Count()))
+ return nil
+ }),
+ )
+ if err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/dht/node.go b/dht/node.go
index 2ea5f2f..437f13d 100644
--- a/dht/node.go
+++ b/dht/node.go
@@ -1,427 +1,76 @@
package dht
import (
- "context"
- "fmt"
- "net"
+ "errors"
+ "net/netip"
+ "sync"
"time"
- "github.com/hashicorp/golang-lru"
- "golang.org/x/time/rate"
- "src.userspace.com.au/dhtsearch/krpc"
- "src.userspace.com.au/dhtsearch/models"
- "src.userspace.com.au/go-bencode"
- "src.userspace.com.au/logger"
+ "userspace.com.au/dhtsearch/infohash"
)
-var (
- routers = []string{
- "dht.libtorrent.org:25401",
- "router.bittorrent.com:6881",
- "dht.transmissionbt.com:6881",
- "router.utorrent.com:6881",
- "dht.aelitis.com:6881",
- }
-)
-
-// Node joins the DHT network
type Node struct {
- id models.Infohash
- family string
- address string
- port int
- conn net.PacketConn
- pool chan chan packet
- rTable *routingTable
- udpTimeout int
- packetsOut chan packet
- log logger.Logger
- limiter *rate.Limiter
- blacklist *lru.ARCCache
-
- // OnAnnoucePeer is called for each peer that announces itself
- OnAnnouncePeer func(models.Peer)
- // OnBadPeer is called for each bad peer
- OnBadPeer func(models.Peer)
-}
-
-// NewNode creates a new DHT node
-func NewNode(opts ...Option) (*Node, error) {
- var err error
- id := models.GenInfohash()
-
- n := &Node{
- id: id,
- family: "udp4",
- port: 6881,
- udpTimeout: 10,
- limiter: rate.NewLimiter(rate.Limit(100000), 2000000),
- log: logger.New(&logger.Options{Name: "dht"}),
- }
-
- n.rTable, err = newRoutingTable(id, 2000)
- if err != nil {
- n.log.Error("failed to create routing table", "error", err)
- return nil, err
- }
-
- // Set variadic options passed
- for _, option := range opts {
- err = option(n)
- if err != nil {
- return nil, err
- }
- }
-
- if n.blacklist == nil {
- n.blacklist, err = lru.NewARC(1000)
- if err != nil {
- return nil, err
- }
- }
-
- if n.family != "udp4" {
- n.log.Debug("trying udp6 server")
- n.conn, err = net.ListenPacket("udp6", fmt.Sprintf("[%s]:%d", net.IPv6zero.String(), n.port))
- if err == nil {
- n.family = "udp6"
- }
- }
- if n.conn == nil {
- n.conn, err = net.ListenPacket("udp4", fmt.Sprintf("%s:%d", net.IPv4zero.String(), n.port))
- if err == nil {
- n.family = "udp4"
- }
- }
- if err != nil {
- n.log.Error("failed to listen", "error", err)
- return nil, err
- }
- n.log.Info("listening", "id", n.id, "network", n.family, "address", n.conn.LocalAddr().String())
-
- return n, nil
-}
-
-// Close stuff
-func (n *Node) Close() error {
- n.log.Warn("node closing")
- return nil
-}
-
-// Run starts the node on the DHT
-func (n *Node) Run() {
- // Packets onto the network
- n.packetsOut = make(chan packet, 1024)
-
- // Create a slab for allocation
- byteSlab := newSlab(8192, 10)
+ sync.Mutex
+ ID infohash.ID
+ AddrPort netip.AddrPort
+ Secure bool
+ Family string
- n.log.Debug("starting packet writer")
- go n.packetWriter()
+ // To assist with rate limiting
+ NextQuery map[string]time.Time
- // Find neighbours
- go n.makeNeighbours()
+ // Incremented when selected for ping, cleared on pong
+ PingAttempts int
- n.log.Debug("starting packet reader")
- for {
- b := byteSlab.alloc()
- c, addr, err := n.conn.ReadFrom(b)
- if err != nil {
- n.log.Warn("UDP read error", "error", err)
- return
- }
-
- // Chop and process
- n.processPacket(packet{
- data: b[0:c],
- raddr: addr,
- })
- byteSlab.free(b)
- }
+ // Used by table
+ LastContact time.Time
}
-func (n *Node) makeNeighbours() {
- // TODO configurable
- ticker := time.Tick(5 * time.Second)
-
- n.bootstrap()
-
- for {
- select {
- case <-ticker:
- if n.rTable.isEmpty() {
- n.bootstrap()
- } else {
- // Send to all nodes
- nodes := n.rTable.get(0)
- for _, rn := range nodes {
- n.findNode(rn, models.GenerateNeighbour(n.id, rn.id))
- }
- n.rTable.flush()
- }
- }
- }
+func (n *Node) String() string {
+ return n.AddrPort.String()
}
-func (n *Node) bootstrap() {
- n.log.Debug("bootstrapping")
- for _, s := range routers {
- addr, err := net.ResolveUDPAddr(n.family, s)
- if err != nil {
- n.log.Error("failed to parse bootstrap address", "error", err)
- continue
- }
- rn := &remoteNode{addr: addr}
- n.findNode(rn, n.id)
+func (n *Node) canSend(query string) bool {
+	n.Lock()
+	defer n.Unlock()
+	if len(n.NextQuery) == 0 {
+		return true
+	}
+	t, ok := n.NextQuery[query]
+	return !ok || time.Now().After(t)
}
-func (n *Node) packetWriter() {
- for p := range n.packetsOut {
- if p.raddr.String() == n.conn.LocalAddr().String() {
- continue
- }
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
- if err := n.limiter.WaitN(ctx, len(p.data)); err != nil {
- n.log.Warn("rate limited", "error", err)
- continue
- }
- //n.log.Debug("writing packet", "dest", p.raddr.String())
- _, err := n.conn.WriteTo(p.data, p.raddr)
- if err != nil {
- n.blacklist.Add(p.raddr.String(), true)
- // TODO reduce limit
- n.log.Warn("failed to write packet", "error", err)
- if n.OnBadPeer != nil {
- peer := models.Peer{Addr: p.raddr}
- go n.OnBadPeer(peer)
- }
- }
- }
-}
-
-func (n *Node) findNode(rn *remoteNode, id models.Infohash) {
- target := models.GenInfohash()
- n.sendQuery(rn, "find_node", map[string]interface{}{
- "id": string(id),
- "target": string(target),
- })
-}
-
-// ping sends ping query to the chan.
-func (n *Node) ping(rn *remoteNode) {
- id := models.GenerateNeighbour(n.id, rn.id)
- n.sendQuery(rn, "ping", map[string]interface{}{
- "id": string(id),
- })
-}
-
-func (n *Node) sendQuery(rn *remoteNode, qType string, a map[string]interface{}) error {
- // Stop if sending to self
- if rn.id.Equal(n.id) {
- return nil
- }
-
- t := krpc.NewTransactionID()
-
- data := krpc.MakeQuery(t, qType, a)
- b, err := bencode.Encode(data)
- if err != nil {
- return err
- }
- //fmt.Printf("sending %s to %s\n", qType, rn.String())
- n.packetsOut <- packet{
- data: b,
- raddr: rn.addr,
- }
- return nil
-}
-
-// Parse a KRPC packet into a message
-func (n *Node) processPacket(p packet) error {
- response, _, err := bencode.DecodeDict(p.data, 0)
- if err != nil {
- return err
- }
-
- y, err := krpc.GetString(response, "y")
- if err != nil {
- return err
- }
-
- if _, black := n.blacklist.Get(p.raddr.String()); black {
- return fmt.Errorf("blacklisted: %s", p.raddr.String())
+func (n *Node) rateQuery(q string, s int64) {
+ n.Lock()
+ defer n.Unlock()
+ if len(n.NextQuery) == 0 {
+ n.NextQuery = make(map[string]time.Time)
}
-
- switch y {
- case "q":
- err = n.handleRequest(p.raddr, response)
- case "r":
- err = n.handleResponse(p.raddr, response)
- case "e":
- err = n.handleError(p.raddr, response)
- default:
- err = fmt.Errorf("missing request type")
+ if s == 0 {
+ n.NextQuery[q] = time.Now()
+ return
}
- if err != nil {
- n.log.Warn("failed to process packet", "error", err)
- n.blacklist.Add(p.raddr.String(), true)
+ if s == -1 {
+ // TODO delay one day
+ s = 24 * 60 * 60
}
- return err
+ n.NextQuery[q] = time.Now().Add(time.Duration(s) * time.Second)
}
-// bencode data and send
-func (n *Node) queueMsg(rn remoteNode, data map[string]interface{}) error {
- b, err := bencode.Encode(data)
- if err != nil {
- return err
- }
- n.packetsOut <- packet{
- data: b,
- raddr: rn.addr,
- }
- return nil
+func (n *Node) MarshalBinary() ([]byte, error) {
+	b, err := NodeAddr(n.AddrPort).MarshalBinary()
+	if err != nil {
+		return nil, err
+	}
+	return append(n.ID.Bytes(), b...), nil
}
-// handleRequest handles the requests received from udp.
-func (n *Node) handleRequest(addr net.Addr, m map[string]interface{}) error {
- q, err := krpc.GetString(m, "q")
- if err != nil {
- return err
- }
-
- a, err := krpc.GetMap(m, "a")
- if err != nil {
- return err
- }
-
- id, err := krpc.GetString(a, "id")
- if err != nil {
- return err
- }
-
- ih, err := models.InfohashFromString(id)
- if err != nil {
- return err
- }
-
- if n.id.Equal(*ih) {
- return nil
- }
-
- rn := &remoteNode{addr: addr, id: *ih}
-
- switch q {
- case "ping":
- err = n.onPingQuery(*rn, m)
-
- case "get_peers":
- err = n.onGetPeersQuery(*rn, m)
-
- case "announce_peer":
- n.onAnnouncePeerQuery(*rn, m)
-
- default:
- //n.queueMsg(addr, makeError(t, protocolError, "invalid q"))
- return nil
- }
- n.rTable.add(rn)
- return err
-}
-
-// handleResponse handles responses received from udp.
-func (n *Node) handleResponse(addr net.Addr, m map[string]interface{}) error {
- r, err := krpc.GetMap(m, "r")
- if err != nil {
- return err
- }
- id, err := krpc.GetString(r, "id")
- if err != nil {
- return err
- }
- ih, err := models.InfohashFromString(id)
- if err != nil {
- return err
- }
-
- rn := &remoteNode{addr: addr, id: *ih}
-
- nodes, err := krpc.GetString(r, "nodes")
- // find_nodes/get_peers response with nodes
- if err == nil {
- n.onFindNodeResponse(*rn, m)
- n.processFindNodeResults(*rn, nodes)
- n.rTable.add(rn)
+func (n *Node) UnmarshalBinary(b []byte) error {
+ n.ID = infohash.ID(b[:infohash.Length])
+ var na NodeAddr
+ if err := na.UnmarshalBinary(b[infohash.Length:]); err != nil {
 	return err
}
-
- values, err := krpc.GetList(r, "values")
- // get_peers response
- if err == nil {
- n.log.Debug("get_peers response", "source", rn)
- for _, v := range values {
- addr := krpc.DecodeCompactNodeAddr(v.(string))
- n.log.Debug("unhandled get_peer request", "addres", addr)
-
- // TODO new peer needs to be matched to previous get_peers request
- // n.peersManager.Insert(ih, p)
- }
- n.rTable.add(rn)
+ n.AddrPort = netip.AddrPort(na)
+ if !n.AddrPort.IsValid() || n.AddrPort.Port() == 0 {
+ return errors.New("invalid AddrPort")
}
return nil
}
-
-// handleError handles errors received from udp.
-func (n *Node) handleError(addr net.Addr, m map[string]interface{}) error {
- e, err := krpc.GetList(m, "e")
- if err != nil {
- return err
- }
-
- if len(e) != 2 {
- return fmt.Errorf("error packet wrong length %d", len(e))
- }
- code := e[0].(int64)
- msg := e[1].(string)
- n.log.Debug("error packet", "address", addr.String(), "code", code, "error", msg)
-
- return nil
-}
-
-// Process another node's response to a find_node query.
-func (n *Node) processFindNodeResults(rn remoteNode, nodeList string) {
- nodeLength := krpc.IPv4NodeAddrLen
- if n.family == "udp6" {
- nodeLength = krpc.IPv6NodeAddrLen
- }
-
- if len(nodeList)%nodeLength != 0 {
- n.log.Error("node list is wrong length", "length", len(nodeList))
- n.blacklist.Add(rn.addr.String(), true)
- return
- }
-
- //fmt.Printf("%s sent %d nodes\n", rn.address.String(), len(nodeList)/nodeLength)
-
- // We got a byte array in groups of 26 or 38
- for i := 0; i < len(nodeList); i += nodeLength {
- id := nodeList[i : i+models.InfohashLength]
- addrStr := krpc.DecodeCompactNodeAddr(nodeList[i+models.InfohashLength : i+nodeLength])
-
- ih, err := models.InfohashFromString(id)
- if err != nil {
- n.log.Warn("invalid infohash in node list")
- continue
- }
-
- addr, err := net.ResolveUDPAddr(n.family, addrStr)
- if err != nil || addr.Port == 0 {
- //n.log.Warn("unable to resolve", "address", addrStr, "error", err)
- continue
- }
-
- rn := &remoteNode{addr: addr, id: *ih}
- n.rTable.add(rn)
- }
-}
diff --git a/dht/node_test.go b/dht/node_test.go
new file mode 100644
index 0000000..6926c1e
--- /dev/null
+++ b/dht/node_test.go
@@ -0,0 +1,42 @@
+package dht
+
+import (
+ "bytes"
+ "net/netip"
+ "testing"
+
+ "userspace.com.au/dhtsearch/infohash"
+)
+
+func TestNodeMarshaling(t *testing.T) {
+ tests := []Node{
+ {
+ ID: infohash.MustParseString("0000000000000000000000000000000000000000"),
+ AddrPort: netip.MustParseAddrPort("[2001:19f0:5:6d01:5400:2ff:feec:644a]:6881")},
+ {
+ ID: infohash.MustParseString("c4301bf4b0a4f83eb7afe1eeee8434fb1893cc15"),
+ AddrPort: netip.MustParseAddrPort("[2001:f90:4090:1230:2697:edff:fe27:d081]:50000")},
+ {
+ ID: infohash.MustParseString("c17f24479d8c5ad3ad7a731a600ff323ecb26ebd"),
+ AddrPort: netip.MustParseAddrPort("[2001:e68:542f:5959:11b7:2243:527d:1c86]:19264")},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.ID.String(), func(t *testing.T) {
+ b, err := tt.MarshalBinary()
+ if err != nil {
+ t.Fatal(err)
+ }
+ n2 := new(Node)
+ if err := n2.UnmarshalBinary(b); err != nil {
+ t.Fatal(err)
+ }
+ if tt.AddrPort.Compare(n2.AddrPort) != 0 {
+ t.Fatalf("got %s, want %s", n2.AddrPort, tt.AddrPort)
+ }
+ if !bytes.Equal(n2.ID[:], tt.ID[:]) {
+ t.Fatalf("got %s, want %s", n2.ID, tt.ID)
+ }
+ })
+ }
+}
diff --git a/dht/options.go b/dht/options.go
deleted file mode 100644
index 19a5508..0000000
--- a/dht/options.go
+++ /dev/null
@@ -1,73 +0,0 @@
-package dht
-
-import (
- "github.com/hashicorp/golang-lru"
- "src.userspace.com.au/dhtsearch/models"
- "src.userspace.com.au/logger"
-)
-
-type Option func(*Node) error
-
-func SetOnAnnouncePeer(f func(models.Peer)) Option {
- return func(n *Node) error {
- n.OnAnnouncePeer = f
- return nil
- }
-}
-
-func SetOnBadPeer(f func(models.Peer)) Option {
- return func(n *Node) error {
- n.OnBadPeer = f
- return nil
- }
-}
-
-// SetAddress sets the IP address to listen on
-func SetAddress(ip string) Option {
- return func(n *Node) error {
- n.address = ip
- return nil
- }
-}
-
-// SetPort sets the port to listen on
-func SetPort(p int) Option {
- return func(n *Node) error {
- n.port = p
- return nil
- }
-}
-
-// SetIPv6 enables IPv6
-func SetIPv6(b bool) Option {
- return func(n *Node) error {
- if b {
- n.family = "udp6"
- }
- return nil
- }
-}
-
-// SetUDPTimeout sets the number of seconds to wait for UDP connections
-func SetUDPTimeout(s int) Option {
- return func(n *Node) error {
- n.udpTimeout = s
- return nil
- }
-}
-
-// SetLogger sets the logger
-func SetLogger(l logger.Logger) Option {
- return func(n *Node) error {
- n.log = l
- return nil
- }
-}
-
-// SetBlacklist sets the size of the node blacklist
-func SetBlacklist(bl *lru.ARCCache) Option {
- return func(n *Node) (err error) {
- n.blacklist = bl
- return err
- }
-}
diff --git a/dht/packet.go b/dht/packet.go
deleted file mode 100644
index 62c08fa..0000000
--- a/dht/packet.go
+++ /dev/null
@@ -1,33 +0,0 @@
-package dht
-
-import "net"
-
-// Arbitrary packet types
-// Order these lowest to highest priority for use in
-// priority queue heap
-const (
- _ int = iota
- pktQPing
- pktRPing
- pktQFindNode
- pktRAnnouncePeer
- pktRGetPeers
-)
-
-var pktName = map[int]string{
- pktQFindNode: "find_node",
- pktQPing: "ping",
- pktRPing: "ping",
- pktRAnnouncePeer: "annouce_peer",
- pktRGetPeers: "get_peers",
-}
-
-// Unprocessed packet from socket
-type packet struct {
- // The packet type
- //priority int
- // Required by heap interface
- //index int
- data []byte
- raddr net.Addr
-}
diff --git a/dht/remote_node.go b/dht/remote_node.go
deleted file mode 100644
index 0a5cafa..0000000
--- a/dht/remote_node.go
+++ /dev/null
@@ -1,18 +0,0 @@
-package dht
-
-import (
- "fmt"
- "net"
-
- "src.userspace.com.au/dhtsearch/models"
-)
-
-type remoteNode struct {
- addr net.Addr
- id models.Infohash
-}
-
-// String implements fmt.Stringer
-func (r remoteNode) String() string {
- return fmt.Sprintf("%s (%s)", r.id.String(), r.addr.String())
-}
diff --git a/dht/routing_table.go b/dht/routing_table.go
deleted file mode 100644
index 37af542..0000000
--- a/dht/routing_table.go
+++ /dev/null
@@ -1,121 +0,0 @@
-package dht
-
-import (
- "container/heap"
- "sync"
-
- "src.userspace.com.au/dhtsearch/models"
-)
-
-type rItem struct {
- value *remoteNode
- distance int
- index int // Index in heap
-}
-
-type priorityQueue []*rItem
-
-type routingTable struct {
- id models.Infohash
- max int
- items priorityQueue
- addresses map[string]*remoteNode
- sync.Mutex
-}
-
-func newRoutingTable(id models.Infohash, max int) (*routingTable, error) {
- k := &routingTable{
- id: id,
- max: max,
- }
- k.flush()
- heap.Init(&k.items)
- return k, nil
-}
-
-// Len implements sort.Interface
-func (pq priorityQueue) Len() int { return len(pq) }
-
-// Less implements sort.Interface
-func (pq priorityQueue) Less(i, j int) bool {
- return pq[i].distance > pq[j].distance
-}
-
-// Swap implements sort.Interface
-func (pq priorityQueue) Swap(i, j int) {
- pq[i], pq[j] = pq[j], pq[i]
- pq[i].index = i
- pq[j].index = j
-}
-
-// Push implements heap.Interface
-func (pq *priorityQueue) Push(x interface{}) {
- n := len(*pq)
- item := x.(*rItem)
- item.index = n
- *pq = append(*pq, item)
-}
-
-// Pop implements heap.Interface
-func (pq *priorityQueue) Pop() interface{} {
- old := *pq
- n := len(old)
- item := old[n-1]
- item.index = -1 // for safety
- *pq = old[0 : n-1]
- return item
-}
-
-func (k *routingTable) add(rn *remoteNode) {
- // Check IP and ports are valid and not self
- if !rn.id.Valid() || rn.id.Equal(k.id) {
- return
- }
-
- k.Lock()
- defer k.Unlock()
-
- if _, ok := k.addresses[rn.addr.String()]; ok {
- return
- }
- k.addresses[rn.addr.String()] = rn
-
- item := &rItem{
- value: rn,
- distance: k.id.Distance(rn.id),
- }
-
- heap.Push(&k.items, item)
-
- if len(k.items) > k.max {
- for i := k.max - 1; i < len(k.items); i++ {
- old := k.items[i]
- delete(k.addresses, old.value.addr.String())
- heap.Remove(&k.items, i)
- }
- }
-}
-
-func (k *routingTable) get(n int) (out []*remoteNode) {
- if n == 0 {
- n = len(k.items)
- }
- for i := 0; i < n && i < len(k.items); i++ {
- out = append(out, k.items[i].value)
- }
- return out
-}
-
-func (k *routingTable) flush() {
- k.Lock()
- defer k.Unlock()
-
- k.items = make(priorityQueue, 0)
- k.addresses = make(map[string]*remoteNode, k.max)
-}
-
-func (k *routingTable) isEmpty() bool {
- k.Lock()
- defer k.Unlock()
- return len(k.items) == 0
-}
diff --git a/dht/routing_table_test.go b/dht/routing_table_test.go
deleted file mode 100644
index 763df80..0000000
--- a/dht/routing_table_test.go
+++ /dev/null
@@ -1,48 +0,0 @@
-package dht
-
-import (
- "fmt"
- "net"
- "testing"
-
- "src.userspace.com.au/dhtsearch/models"
-)
-
-func TestPriorityQueue(t *testing.T) {
- id := "d1c5676ae7ac98e8b19f63565905105e3c4c37a2"
-
- tests := []string{
- "d1c5676ae7ac98e8b19f63565905105e3c4c37b9",
- "d1c5676ae7ac98e8b19f63565905105e3c4c37a9",
- "d1c5676ae7ac98e8b19f63565905105e3c4c37a4",
- "d1c5676ae7ac98e8b19f63565905105e3c4c37a3", // distance of 159
- }
-
- ih, err := models.InfohashFromString(id)
- if err != nil {
- t.Errorf("failed to create infohash: %s\n", err)
- }
-
- pq, err := newRoutingTable(*ih, 3)
- if err != nil {
- t.Errorf("failed to create kTable: %s\n", err)
- }
-
- for i, idt := range tests {
- iht, err := models.InfohashFromString(idt)
- if err != nil {
- t.Errorf("failed to create infohash: %s\n", err)
- }
- addr, _ := net.ResolveUDPAddr("udp", fmt.Sprintf("0.0.0.0:%d", i))
- pq.add(&remoteNode{id: *iht, addr: addr})
- }
-
- if len(pq.items) != len(pq.addresses) {
- t.Errorf("items and addresses out of sync")
- }
-
- first := pq.items[0].value.id
- if first.String() != "d1c5676ae7ac98e8b19f63565905105e3c4c37a3" {
- t.Errorf("first is %s with distance %d\n", first, ih.Distance(first))
- }
-}
diff --git a/dht/slab.go b/dht/slab.go
deleted file mode 100644
index a8b4018..0000000
--- a/dht/slab.go
+++ /dev/null
@@ -1,25 +0,0 @@
-package dht
-
-// Slab memory allocation
-
-// Initialise the slab as a channel of blocks, allocating them as required and
-// pushing them back on the slab. This reduces garbage collection.
-type slab chan []byte
-
-func newSlab(blockSize int, numBlocks int) slab {
- s := make(slab, numBlocks)
- for i := 0; i < numBlocks; i++ {
- s <- make([]byte, blockSize)
- }
- return s
-}
-
-func (s slab) alloc() (x []byte) {
- return <-s
-}
-
-func (s slab) free(x []byte) {
- // Check we are using the right dimensions
- x = x[:cap(x)]
- s <- x
-}
diff --git a/dht/transactions.go b/dht/transactions.go
new file mode 100644
index 0000000..209d7e3
--- /dev/null
+++ b/dht/transactions.go
@@ -0,0 +1,93 @@
+package dht
+
+import (
+ "context"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "net/netip"
+ "sync"
+ "time"
+)
+
+type transactions struct {
+ sync.Mutex
+ ttl time.Duration
+ requests map[string]trans
+ tid uint64
+ buf [binary.MaxVarintLen64]byte
+}
+
+type trans struct {
+ ap netip.AddrPort
+ msg string
+ ts time.Time
+ cb func(*Node)
+}
+
+func newTransRegistry() *transactions {
+ return &transactions{
+ ttl: 30 * time.Second,
+ requests: make(map[string]trans),
+ }
+}
+
+type TransactionCallback func(*Node)
+
+func (c *Client) registerTransaction(rn *Node, m *Msg, cb TransactionCallback) {
+ c.chk.Lock()
+ defer c.chk.Unlock()
+ c.chk.tid++
+
+ n := binary.PutUvarint(c.chk.buf[:], c.chk.tid)
+ m.TID = string(c.chk.buf[:n])
+
+ if _, ok := c.chk.requests[m.TID]; ok {
+ panic(fmt.Sprintf("duplicate transaction ID %q", m.TID))
+ }
+ c.chk.requests[m.TID] = trans{
+ ap: rn.AddrPort,
+ ts: time.Now(),
+ msg: m.Query,
+ cb: cb,
+ }
+}
+
+func (c *Client) checkTransaction(m Msg) (*trans, error) {
+ c.chk.Lock()
+ defer c.chk.Unlock()
+ if t, ok := c.chk.requests[m.TID]; ok {
+ delete(c.chk.requests, m.TID)
+ return &t, nil
+ }
+ return nil, errors.New("unsolicited transaction")
+}
+
+// func (c *Client) inflightQueries() int {
+// c.chk.Lock()
+// defer c.chk.Unlock()
+// return len(c.chk.requests)
+// }
+
+func (c *Client) auditTransactions(ctx context.Context) {
+ c.log.Info("starting transaction auditing", "interval", c.tickAuditTransactions)
+ ticker := time.Tick(c.tickAuditTransactions)
+ for {
+ select {
+ case <-ctx.Done():
+ c.log.Info("stopping transaction auditing")
+ return
+ case <-ticker:
+ c.chk.Lock()
+ before := len(c.chk.requests)
+ for tid, t := range c.chk.requests {
+ if time.Since(t.ts) > c.chk.ttl {
+ delete(c.chk.requests, tid)
+ }
+ }
+ after := len(c.chk.requests)
+ c.chk.Unlock()
+ c.log.Debug("cleared transactions", "size", after, "removed", before-after)
+ }
+ }
+}
diff --git a/krpc/krpc.go b/krpc/krpc.go
deleted file mode 100644
index d5d1480..0000000
--- a/krpc/krpc.go
+++ /dev/null
@@ -1,181 +0,0 @@
-package krpc
-
-import (
- "errors"
- "fmt"
- "math/rand"
- "net"
- "strconv"
-)
-
-const (
- transIDBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
- IPv4NodeAddrLen = 26
- IPv6NodeAddrLen = 38
-)
-
-func NewTransactionID() string {
- b := make([]byte, 2)
- for i := range b {
- b[i] = transIDBytes[rand.Int63()%int64(len(transIDBytes))]
- }
- return string(b)
-}
-
-// makeQuery returns a query-formed data.
-func MakeQuery(transaction, query string, data map[string]interface{}) map[string]interface{} {
- return map[string]interface{}{
- "t": transaction,
- "y": "q",
- "q": query,
- "a": data,
- }
-}
-
-// makeResponse returns a response-formed data.
-func MakeResponse(transaction string, data map[string]interface{}) map[string]interface{} {
- return map[string]interface{}{
- "t": transaction,
- "y": "r",
- "r": data,
- }
-}
-
-func GetString(data map[string]interface{}, key string) (string, error) {
- val, ok := data[key]
- if !ok {
- return "", fmt.Errorf("krpc: missing key %s", key)
- }
- out, ok := val.(string)
- if !ok {
- return "", fmt.Errorf("krpc: key type mismatch")
- }
- return out, nil
-}
-
-func GetInt(data map[string]interface{}, key string) (int, error) {
- val, ok := data[key]
- if !ok {
- return 0, fmt.Errorf("krpc: missing key %s", key)
- }
- out, ok := val.(int64)
- if !ok {
- return 0, fmt.Errorf("krpc: key type mismatch")
- }
- return int(out), nil
-}
-
-func GetMap(data map[string]interface{}, key string) (map[string]interface{}, error) {
- val, ok := data[key]
- if !ok {
- return nil, fmt.Errorf("krpc: missing key %s", key)
- }
- out, ok := val.(map[string]interface{})
- if !ok {
- return nil, fmt.Errorf("krpc: key type mismatch")
- }
- return out, nil
-}
-
-func GetList(data map[string]interface{}, key string) ([]interface{}, error) {
- val, ok := data[key]
- if !ok {
- return nil, fmt.Errorf("krpc: missing key %s", key)
- }
- out, ok := val.([]interface{})
- if !ok {
- return nil, fmt.Errorf("krpc: key type mismatch")
- }
- return out, nil
-}
-
-// parseKeys parses keys. It just wraps parseKey.
-func checkKeys(data map[string]interface{}, pairs [][]string) (err error) {
- for _, args := range pairs {
- key, t := args[0], args[1]
- if err = checkKey(data, key, t); err != nil {
- break
- }
- }
- return err
-}
-
-// parseKey parses the key in dict data. `t` is type of the keyed value.
-// It's one of "int", "string", "map", "list".
-func checkKey(data map[string]interface{}, key string, t string) error {
- val, ok := data[key]
- if !ok {
- return fmt.Errorf("krpc: missing key %s", key)
- }
-
- switch t {
- case "string":
- _, ok = val.(string)
- case "int":
- _, ok = val.(int)
- case "map":
- _, ok = val.(map[string]interface{})
- case "list":
- _, ok = val.([]interface{})
- default:
- return errors.New("krpc: invalid type")
- }
-
- if !ok {
- return errors.New("krpc: key type mismatch")
- }
-
- return nil
-}
-
-// Swiped from nictuku
-func DecodeCompactNodeAddr(cni string) string {
- if len(cni) == 6 {
- return fmt.Sprintf("%d.%d.%d.%d:%d", cni[0], cni[1], cni[2], cni[3], (uint16(cni[4])<<8)|uint16(cni[5]))
- } else if len(cni) == 18 {
- b := []byte(cni[:16])
- return fmt.Sprintf("[%s]:%d", net.IP.String(b), (uint16(cni[16])<<8)|uint16(cni[17]))
- } else {
- return ""
- }
-}
-
-func EncodeCompactNodeAddr(addr string) string {
- var a []uint8
- host, port, _ := net.SplitHostPort(addr)
- ip := net.ParseIP(host)
- if ip == nil {
- return ""
- }
- aa, _ := strconv.ParseUint(port, 10, 16)
- c := uint16(aa)
- if ip2 := net.IP.To4(ip); ip2 != nil {
- a = make([]byte, net.IPv4len+2, net.IPv4len+2)
- copy(a, ip2[0:net.IPv4len]) // ignore bytes IPv6 bytes if it's IPv4.
- a[4] = byte(c >> 8)
- a[5] = byte(c)
- } else {
- a = make([]byte, net.IPv6len+2, net.IPv6len+2)
- copy(a, ip)
- a[16] = byte(c >> 8)
- a[17] = byte(c)
- }
- return string(a)
-}
-
-func int2bytes(val int64) []byte {
- data, j := make([]byte, 8), -1
- for i := 0; i < 8; i++ {
- shift := uint64((7 - i) * 8)
- data[i] = byte((val & (0xff << shift)) >> shift)
-
- if j == -1 && data[i] != 0 {
- j = i
- }
- }
-
- if j != -1 {
- return data[j:]
- }
- return data[:1]
-}
diff --git a/krpc/krpc_test.go b/krpc/krpc_test.go
deleted file mode 100644
index 59480e1..0000000
--- a/krpc/krpc_test.go
+++ /dev/null
@@ -1,30 +0,0 @@
-package krpc
-
-import (
- "encoding/hex"
- "testing"
-)
-
-func TestCompactNodeAddr(t *testing.T) {
-
- tests := []struct {
- in string
- out string
- }{
- {in: "192.168.1.1:6881", out: "c0a801011ae1"},
- {in: "[2001:9372:434a:800::2]:6881", out: "20019372434a080000000000000000021ae1"},
- }
-
- for _, tt := range tests {
- r := EncodeCompactNodeAddr(tt.in)
- out, _ := hex.DecodeString(tt.out)
- if r != string(out) {
- t.Errorf("encodeCompactNodeAddr(%s) => %x, expected %s", tt.in, r, tt.out)
- }
-
- s := DecodeCompactNodeAddr(r)
- if s != tt.in {
- t.Errorf("decodeCompactNodeAddr(%x) => %s, expected %s", r, s, tt.in)
- }
- }
-}