diff options
| -rw-r--r-- | bencode/decode.go | 84 | ||||
| -rw-r--r-- | bencode/decode_test.go | 163 |
2 files changed, 112 insertions, 135 deletions
diff --git a/bencode/decode.go b/bencode/decode.go index 2e356c2..29b60f7 100644 --- a/bencode/decode.go +++ b/bencode/decode.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "strconv" - "unicode/utf8" ) // Decode decodes a bencoded string to string, int, list or map. @@ -14,53 +13,52 @@ func Decode(data []byte) (r interface{}, err error) { } // DecodeString decodes a string from a given index -// It returns the string and the position of the last character -// or the offset at the point of error -func DecodeString(data []byte, start int) (r string, end int, err error) { +// It returns the string, the number of bytes successfully read +func DecodeString(data []byte, start int) (r string, n int, err error) { if start >= len(data) || data[start] < '0' || data[start] > '9' { err = errors.New("bencode: invalid string length") - return r, 0, err + return r, 1, err } prefix, i, err := readUntil(data, start, ':') + end := start + i if err != nil { - return r, i, err + return r, end - start, err } length, err := strconv.ParseInt(string(prefix), 10, 0) if err != nil { - return r, 0, err + return r, end - start, err } - - end = i + int(length) + end = end + int(length) if end > len(data) || end < i { err = errors.New("bencode: string length out of range") - return r, end, err + return r, end - start, err } - return string(data[i:end]), end - 1, nil + return string(data[start+i : end]), end - start, nil } // DecodeInt decodes an integer value -// It returns the integer and the position of the last character +// It returns the integer and the number of bytes successfully read func DecodeInt(data []byte, start int) (r int64, end int, err error) { if start >= len(data) || data[start] != 'i' { err = errors.New("bencode: invalid integer") return r, end, err } - prefix, end, err := readUntil(data, start, 'e') + prefix, n, err := readUntil(data, start, 'e') if err != nil { - return r, end, err + return r, n, err } r, err = strconv.ParseInt(string(prefix[1:]), 10, 64) - return r, end - 1, err + return r, n, err } // DecodeList decodes a list value -// It returns the array and the position of the last character +// It returns the array and the number of bytes successfully read func DecodeList(data []byte, start int) (r []interface{}, end int, err error) { if start >= len(data) { return r, end, errors.New("bencode: list range error") @@ -73,22 +71,22 @@ func DecodeList(data []byte, start int) (r []interface{}, end int, err error) { // Empty list if data[end] == 'e' { - return r, end, nil + return r, 2, nil } var item interface{} + var n int for end < len(data) { - item, end, err = decodeItem(data, end) + item, n, err = decodeItem(data, end) + end = end + n if err != nil { - return r, end, err + return r, end - start, err } r = append(r, item) - end++ - char, _ := utf8.DecodeRune(data[end:]) - if char == 'e' { - return r, end, nil + if data[end] == 'e' { + return r, end - start + 1, nil } } @@ -97,28 +95,27 @@ func DecodeList(data []byte, start int) (r []interface{}, end int, err error) { // DecodeDict decodes a dict as a map // It returns the map and the position of the last character -func DecodeDict(data []byte, start int) (r map[string]interface{}, end int, err error) { +func DecodeDict(data []byte, start int) (map[string]interface{}, int, error) { + r := make(map[string]interface{}) + if start >= len(data) { - return r, end, errors.New("bencode: dict range error") + return r, 0, errors.New("bencode: dict range error") } if data[start] != 'd' { - return r, end, errors.New("bencode: invalid dict") + return r, 1, errors.New("bencode: invalid dict") } - end = start + 1 + end := start + 1 // Empty dict if data[end] == 'e' { - return r, end, nil + return r, 2, nil } - var item interface{} - var key string - r = make(map[string]interface{}) - for end < len(data) { - key, end, err = DecodeString(data, end) + key, n, err := DecodeString(data, end) + end = end + n if err != nil { return r, end, errors.New("bencode: invalid dict key") } @@ -127,26 +124,24 @@ func DecodeDict(data []byte, start int) (r map[string]interface{}, end int, err return r, end, errors.New("bencode: dict range error") } - end++ - item, end, err = decodeItem(data, end) + item, n, err := decodeItem(data, end) + end = end + n + if err != nil { return r, end, err } r[key] = item - end++ - char, _ := utf8.DecodeRune(data[end:]) - if char == 'e' { - return r, end, nil + if data[end] == 'e' { + return r, end - start + 1, nil } } return r, end, errors.New("bencode: invalid dict termination") } // decodeItem decodes the next type at the given index -// It returns the index of the last character decoded -func decodeItem(data []byte, start int) (r interface{}, end int, err error) { +func decodeItem(data []byte, start int) (r interface{}, n int, err error) { switch data[start] { case 'l': return DecodeList(data, start) @@ -160,14 +155,13 @@ func decodeItem(data []byte, start int) (r interface{}, end int, err error) { } // Read until the given character -// Returns the slice before the character and the index of the next character -// or the offset at the point of error +// Returns the slice before the character and the number of bytes successfully read func readUntil(data []byte, start int, c byte) ([]byte, int, error) { i := start for ; i < len(data); i++ { if data[i] == c { - return data[start:i], i + 1, nil + return data[start:i], i - start + 1, nil } } - return data, i, fmt.Errorf("bencode: '%b' not found", c) + return data, i - start, fmt.Errorf("bencode: '%b' not found", c) } diff --git a/bencode/decode_test.go b/bencode/decode_test.go index e83f6a8..a39dc9f 100644 --- a/bencode/decode_test.go +++ b/bencode/decode_test.go @@ -6,19 +6,21 @@ import ( func TestDecodeString(t *testing.T) { tests := []struct { - in string - out string - end int + in string + start int + out string + n int }{ - {in: "0:", out: "", end: 1}, - {in: "5:hello", out: "hello", end: 6}, - {in: "7:goodbye", out: "goodbye", end: 8}, - {in: "11:hello world", out: "hello world", end: 13}, - {in: "20:1-5%3~]+=\\| []>.,`??", out: "1-5%3~]+=\\| []>.,`??", end: 22}, + {in: "0:", start: 0, out: "", n: 2}, + {in: "5:hello", start: 0, out: "hello", n: 7}, + {in: "7:goodbye", start: 0, out: "goodbye", n: 9}, + {in: "11:hello world", start: 0, out: "hello world", n: 14}, + {in: "20:1-5%3~]+=\\| []>.,`??", start: 0, out: "1-5%3~]+=\\| []>.,`??", n: 23}, + {in: "123412347:goodbye", start: 8, out: "goodbye", n: 9}, } for _, tt := range tests { - r1, end, err := DecodeString([]byte(tt.in), 0) + r1, n, err := DecodeString([]byte(tt.in), tt.start) if err != nil { t.Errorf("DecodeString(%q) failed with error %q", tt.in, err) } @@ -27,17 +29,19 @@ func TestDecodeString(t *testing.T) { t.Errorf("DecodeString(%q) => %q, expected %q", tt.in, r1, tt.out) } - if end != tt.end { - t.Errorf("DecodeString(%q) ended at %d, expected %d", tt.in, end, tt.end) + if n != tt.n { + t.Errorf("DecodeString(%q) read %d, expected %d", tt.in, n, tt.n) } - r2, err := Decode([]byte(tt.in)) - if err != nil { - t.Errorf("DecodeString(%q) failed with error %q", tt.in, err) - } + if tt.start == 0 { + r2, err := Decode([]byte(tt.in)) + if err != nil { + t.Errorf("Decode(%q) failed with error %q", tt.in, err) + } - if r2 != tt.out { - t.Errorf("DecodeString(%q) => %q, expected %q", tt.in, r2, tt.out) + if r2 != tt.out { + t.Errorf("Decode(%q) => %q, expected %q", tt.in, r2, tt.out) + } } } @@ -45,19 +49,21 @@ func TestDecodeString(t *testing.T) { func TestDecodeInt(t *testing.T) { tests := []struct { - in string - out int64 - end int + in string + start int + out int64 + n int }{ - {in: "i0e", out: int64(0), end: 2}, - {in: "i5e", out: int64(5), end: 2}, - {in: "i-5e", out: int64(-5), end: 3}, - {in: "i1234567890e", out: int64(1234567890), end: 11}, - {in: "i-1234567890e", out: int64(-1234567890), end: 12}, + {in: "i0e", start: 0, out: int64(0), n: 3}, + {in: "i5e", start: 0, out: int64(5), n: 3}, + {in: "i-5e", start: 0, out: int64(-5), n: 4}, + {in: "i1234567890e", start: 0, out: int64(1234567890), n: 12}, + {in: "i-1234567890e", start: 0, out: int64(-1234567890), n: 13}, + {in: "asdfasdfi-5e", start: 8, out: int64(-5), n: 4}, } for _, tt := range tests { - r1, end, err := DecodeInt([]byte(tt.in), 0) + r1, n, err := DecodeInt([]byte(tt.in), tt.start) if err != nil { t.Errorf("DecodeInt(%q) failed with error %q", tt.in, err) } @@ -66,17 +72,19 @@ func TestDecodeInt(t *testing.T) { t.Errorf("DecodeInt(%q) => %d, expected %d", tt.in, r1, tt.out) } - if end != tt.end { - t.Errorf("DecodeInt(%q) ended at %d, expected %d", tt.in, end, tt.end) + if n != tt.n { + t.Errorf("DecodeInt(%q) read %d, expected %d", tt.in, n, tt.n) } - r2, err := Decode([]byte(tt.in)) - if err != nil { - t.Errorf("DecodeInt(%q) failed with error %q", tt.in, err) - } + if tt.start == 0 { + r2, err := Decode([]byte(tt.in)) + if err != nil { + t.Errorf("Decode(%q) failed with error %q", tt.in, err) + } - if r2 != tt.out { - t.Errorf("DecodeInt(%q) => %d, expected %d", tt.in, r2, tt.out) + if r2 != tt.out { + t.Errorf("Decode(%q) => %d, expected %d", tt.in, r2, tt.out) + } } } @@ -86,16 +94,16 @@ func TestDecodeList(t *testing.T) { tests := []struct { in string out []interface{} - end int + n int }{ - {in: "l4:spam4:eggse", out: []interface{}{"spam", "eggs"}, end: 13}, - {in: "le", out: []interface{}{}, end: 1}, - {in: "li-1ei0ee", out: []interface{}{int64(-1), int64(0)}, end: 8}, - {in: "l4:testi-1ei0ee", out: []interface{}{"test", int64(-1), int64(0)}, end: 14}, + {in: "l4:spam4:eggse", out: []interface{}{"spam", "eggs"}, n: 14}, + {in: "le", out: []interface{}{}, n: 2}, + {in: "li-1ei0ee", out: []interface{}{int64(-1), int64(0)}, n: 9}, + {in: "l4:testi-1ei0ee", out: []interface{}{"test", int64(-1), int64(0)}, n: 15}, } for _, tt := range tests { - r1, end, err := DecodeList([]byte(tt.in), 0) + r1, n, err := DecodeList([]byte(tt.in), 0) if err != nil { t.Errorf("DecodeList(%q) failed with error %q", tt.in, err) } @@ -110,8 +118,8 @@ func TestDecodeList(t *testing.T) { } } - if end != tt.end { - t.Errorf("DecodeList(%q) ended at %d, expected %d", tt.in, end, tt.end) + if n != tt.n { + t.Errorf("DecodeList(%q) read %d, expected %d", tt.in, n, tt.n) } r2, err := Decode([]byte(tt.in)) @@ -134,18 +142,20 @@ func TestDecodeList(t *testing.T) { func TestDecodeDict(t *testing.T) { tests := []struct { - in string - out map[string]interface{} - end int + in string + start int + out map[string]interface{} + n int }{ - {in: "d4:spam4:eggse", out: map[string]interface{}{"spam": "eggs"}, end: 13}, - {in: "de", out: map[string]interface{}{}, end: 1}, - {in: "d4:testi-1e3:twoi0ee", out: map[string]interface{}{"test": int64(-1), "two": int64(0)}, end: 14}, - {in: "d4:testi0ee", out: map[string]interface{}{"test": int64(0)}, end: 10}, + {in: "d4:spam4:eggse", start: 0, out: map[string]interface{}{"spam": "eggs"}, n: 14}, + {in: "de", start: 0, out: map[string]interface{}{}, n: 2}, + {in: "d4:testi-1e3:twoi0ee", start: 0, out: map[string]interface{}{"test": int64(-1), "two": int64(0)}, n: 20}, + {in: "d4:testi0ee", start: 0, out: map[string]interface{}{"test": int64(0)}, n: 11}, + {in: "012345d4:spam4:eggse", start: 6, out: map[string]interface{}{"spam": "eggs"}, n: 14}, } for _, tt := range tests { - r1, _, err := DecodeDict([]byte(tt.in), 0) + r1, _, err := DecodeDict([]byte(tt.in), tt.start) if err != nil { t.Errorf("DecodeDict(%q) failed with error %q", tt.in, err) } @@ -160,49 +170,22 @@ func TestDecodeDict(t *testing.T) { } } - r2, err := Decode([]byte(tt.in)) - if err != nil { - t.Errorf("Decode(%q) failed with error %q", tt.in, err) - } - - r3, ok := r2.(map[string]interface{}) - if !ok { - t.Errorf("Decode(%q) did not return a map", tt.in) - } - - for k := range r3 { - if r3[k] != tt.out[k] { - t.Errorf("Decode(%q) => %v, expected %v", tt.in, r3, tt.out) + if tt.start == 0 { + r2, err := Decode([]byte(tt.in)) + if err != nil { + t.Errorf("Decode(%q) failed with error %q", tt.in, err) } - } - } -} - -func TestReadUntil(t *testing.T) { - tests := []struct { - in string - start int - out string - end int - }{ - {in: "0:", start: 0, out: "0", end: 2}, - {in: "5:hello", start: 0, out: "5", end: 2}, - {in: "1234567:goodbye", start: 0, out: "1234567", end: 8}, - {in: "asdfasdfsa5:hello", start: 10, out: "5", end: 12}, - } - - for _, tt := range tests { - r, i, err := readUntil([]byte(tt.in), tt.start, ':') - if err != nil { - t.Errorf("readUntil(%q) failed with error %q", tt.in, err) - } - if string(r) != tt.out { - t.Errorf("readUntil(%q) => %q, expected %q", tt.in, r, tt.out) - } + r3, ok := r2.(map[string]interface{}) + if !ok { + t.Errorf("Decode(%q) did not return a map", tt.in) + } - if i != tt.end { - t.Errorf("readUntil(%q) ended at %d, expected %d", tt.in, i, tt.end) + for k := range r3 { + if r3[k] != tt.out[k] { + t.Errorf("Decode(%q) => %v, expected %v", tt.in, r3, tt.out) + } + } } } } |
