aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bencode/decode.go84
-rw-r--r--bencode/decode_test.go163
2 files changed, 112 insertions, 135 deletions
diff --git a/bencode/decode.go b/bencode/decode.go
index 2e356c2..29b60f7 100644
--- a/bencode/decode.go
+++ b/bencode/decode.go
@@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"strconv"
- "unicode/utf8"
)
// Decode decodes a bencoded string to string, int, list or map.
@@ -14,53 +13,52 @@ func Decode(data []byte) (r interface{}, err error) {
}
// DecodeString decodes a string from a given index
-// It returns the string and the position of the last character
-// or the offset at the point of error
-func DecodeString(data []byte, start int) (r string, end int, err error) {
+// It returns the string, the number of bytes successfully read
+func DecodeString(data []byte, start int) (r string, n int, err error) {
if start >= len(data) || data[start] < '0' || data[start] > '9' {
err = errors.New("bencode: invalid string length")
- return r, 0, err
+ return r, 1, err
}
prefix, i, err := readUntil(data, start, ':')
+ end := start + i
if err != nil {
- return r, i, err
+ return r, end - start, err
}
length, err := strconv.ParseInt(string(prefix), 10, 0)
if err != nil {
- return r, 0, err
+ return r, end - start, err
}
-
- end = i + int(length)
+ end = end + int(length)
if end > len(data) || end < i {
err = errors.New("bencode: string length out of range")
- return r, end, err
+ return r, end - start, err
}
- return string(data[i:end]), end - 1, nil
+ return string(data[start+i : end]), end - start, nil
}
// DecodeInt decodes an integer value
-// It returns the integer and the position of the last character
+// It returns the integer and the number of bytes successfully read
func DecodeInt(data []byte, start int) (r int64, end int, err error) {
if start >= len(data) || data[start] != 'i' {
err = errors.New("bencode: invalid integer")
return r, end, err
}
- prefix, end, err := readUntil(data, start, 'e')
+ prefix, n, err := readUntil(data, start, 'e')
if err != nil {
- return r, end, err
+ return r, n, err
}
r, err = strconv.ParseInt(string(prefix[1:]), 10, 64)
- return r, end - 1, err
+ return r, n, err
}
// DecodeList decodes a list value
-// It returns the array and the position of the last character
+// It returns the array and the number of bytes successfully read
func DecodeList(data []byte, start int) (r []interface{}, end int, err error) {
if start >= len(data) {
return r, end, errors.New("bencode: list range error")
@@ -73,22 +71,22 @@ func DecodeList(data []byte, start int) (r []interface{}, end int, err error) {
// Empty list
if data[end] == 'e' {
- return r, end, nil
+ return r, 2, nil
}
var item interface{}
+ var n int
for end < len(data) {
- item, end, err = decodeItem(data, end)
+ item, n, err = decodeItem(data, end)
+ end = end + n
if err != nil {
- return r, end, err
+ return r, end - start, err
}
r = append(r, item)
- end++
- char, _ := utf8.DecodeRune(data[end:])
- if char == 'e' {
- return r, end, nil
+ if data[end] == 'e' {
+ return r, end - start + 1, nil
}
}
@@ -97,28 +95,27 @@ func DecodeList(data []byte, start int) (r []interface{}, end int, err error) {
// DecodeDict decodes a dict as a map
// It returns the map and the position of the last character
-func DecodeDict(data []byte, start int) (r map[string]interface{}, end int, err error) {
+func DecodeDict(data []byte, start int) (map[string]interface{}, int, error) {
+ r := make(map[string]interface{})
+
if start >= len(data) {
- return r, end, errors.New("bencode: dict range error")
+ return r, 0, errors.New("bencode: dict range error")
}
if data[start] != 'd' {
- return r, end, errors.New("bencode: invalid dict")
+ return r, 1, errors.New("bencode: invalid dict")
}
- end = start + 1
+ end := start + 1
// Empty dict
if data[end] == 'e' {
- return r, end, nil
+ return r, 2, nil
}
- var item interface{}
- var key string
- r = make(map[string]interface{})
-
for end < len(data) {
- key, end, err = DecodeString(data, end)
+ key, n, err := DecodeString(data, end)
+ end = end + n
if err != nil {
return r, end, errors.New("bencode: invalid dict key")
}
@@ -127,26 +124,24 @@ func DecodeDict(data []byte, start int) (r map[string]interface{}, end int, err
return r, end, errors.New("bencode: dict range error")
}
- end++
- item, end, err = decodeItem(data, end)
+ item, n, err := decodeItem(data, end)
+ end = end + n
+
if err != nil {
return r, end, err
}
r[key] = item
- end++
- char, _ := utf8.DecodeRune(data[end:])
- if char == 'e' {
- return r, end, nil
+ if data[end] == 'e' {
+ return r, end - start + 1, nil
}
}
return r, end, errors.New("bencode: invalid dict termination")
}
// decodeItem decodes the next type at the given index
-// It returns the index of the last character decoded
-func decodeItem(data []byte, start int) (r interface{}, end int, err error) {
+func decodeItem(data []byte, start int) (r interface{}, n int, err error) {
switch data[start] {
case 'l':
return DecodeList(data, start)
@@ -160,14 +155,13 @@ func decodeItem(data []byte, start int) (r interface{}, end int, err error) {
}
// Read until the given character
-// Returns the slice before the character and the index of the next character
-// or the offset at the point of error
+// Returns the slice before the character and the number of bytes successfully read
func readUntil(data []byte, start int, c byte) ([]byte, int, error) {
i := start
for ; i < len(data); i++ {
if data[i] == c {
- return data[start:i], i + 1, nil
+ return data[start:i], i - start + 1, nil
}
}
- return data, i, fmt.Errorf("bencode: '%b' not found", c)
+ return data, i - start, fmt.Errorf("bencode: '%b' not found", c)
}
diff --git a/bencode/decode_test.go b/bencode/decode_test.go
index e83f6a8..a39dc9f 100644
--- a/bencode/decode_test.go
+++ b/bencode/decode_test.go
@@ -6,19 +6,21 @@ import (
func TestDecodeString(t *testing.T) {
tests := []struct {
- in string
- out string
- end int
+ in string
+ start int
+ out string
+ n int
}{
- {in: "0:", out: "", end: 1},
- {in: "5:hello", out: "hello", end: 6},
- {in: "7:goodbye", out: "goodbye", end: 8},
- {in: "11:hello world", out: "hello world", end: 13},
- {in: "20:1-5%3~]+=\\| []>.,`??", out: "1-5%3~]+=\\| []>.,`??", end: 22},
+ {in: "0:", start: 0, out: "", n: 2},
+ {in: "5:hello", start: 0, out: "hello", n: 7},
+ {in: "7:goodbye", start: 0, out: "goodbye", n: 9},
+ {in: "11:hello world", start: 0, out: "hello world", n: 14},
+ {in: "20:1-5%3~]+=\\| []>.,`??", start: 0, out: "1-5%3~]+=\\| []>.,`??", n: 23},
+ {in: "123412347:goodbye", start: 8, out: "goodbye", n: 9},
}
for _, tt := range tests {
- r1, end, err := DecodeString([]byte(tt.in), 0)
+ r1, n, err := DecodeString([]byte(tt.in), tt.start)
if err != nil {
t.Errorf("DecodeString(%q) failed with error %q", tt.in, err)
}
@@ -27,17 +29,19 @@ func TestDecodeString(t *testing.T) {
t.Errorf("DecodeString(%q) => %q, expected %q", tt.in, r1, tt.out)
}
- if end != tt.end {
- t.Errorf("DecodeString(%q) ended at %d, expected %d", tt.in, end, tt.end)
+ if n != tt.n {
+ t.Errorf("DecodeString(%q) read %d, expected %d", tt.in, n, tt.n)
}
- r2, err := Decode([]byte(tt.in))
- if err != nil {
- t.Errorf("DecodeString(%q) failed with error %q", tt.in, err)
- }
+ if tt.start == 0 {
+ r2, err := Decode([]byte(tt.in))
+ if err != nil {
+ t.Errorf("Decode(%q) failed with error %q", tt.in, err)
+ }
- if r2 != tt.out {
- t.Errorf("DecodeString(%q) => %q, expected %q", tt.in, r2, tt.out)
+ if r2 != tt.out {
+ t.Errorf("Decode(%q) => %q, expected %q", tt.in, r2, tt.out)
+ }
}
}
@@ -45,19 +49,21 @@ func TestDecodeString(t *testing.T) {
func TestDecodeInt(t *testing.T) {
tests := []struct {
- in string
- out int64
- end int
+ in string
+ start int
+ out int64
+ n int
}{
- {in: "i0e", out: int64(0), end: 2},
- {in: "i5e", out: int64(5), end: 2},
- {in: "i-5e", out: int64(-5), end: 3},
- {in: "i1234567890e", out: int64(1234567890), end: 11},
- {in: "i-1234567890e", out: int64(-1234567890), end: 12},
+ {in: "i0e", start: 0, out: int64(0), n: 3},
+ {in: "i5e", start: 0, out: int64(5), n: 3},
+ {in: "i-5e", start: 0, out: int64(-5), n: 4},
+ {in: "i1234567890e", start: 0, out: int64(1234567890), n: 12},
+ {in: "i-1234567890e", start: 0, out: int64(-1234567890), n: 13},
+ {in: "asdfasdfi-5e", start: 8, out: int64(-5), n: 4},
}
for _, tt := range tests {
- r1, end, err := DecodeInt([]byte(tt.in), 0)
+ r1, n, err := DecodeInt([]byte(tt.in), tt.start)
if err != nil {
t.Errorf("DecodeInt(%q) failed with error %q", tt.in, err)
}
@@ -66,17 +72,19 @@ func TestDecodeInt(t *testing.T) {
t.Errorf("DecodeInt(%q) => %d, expected %d", tt.in, r1, tt.out)
}
- if end != tt.end {
- t.Errorf("DecodeInt(%q) ended at %d, expected %d", tt.in, end, tt.end)
+ if n != tt.n {
+ t.Errorf("DecodeInt(%q) read %d, expected %d", tt.in, n, tt.n)
}
- r2, err := Decode([]byte(tt.in))
- if err != nil {
- t.Errorf("DecodeInt(%q) failed with error %q", tt.in, err)
- }
+ if tt.start == 0 {
+ r2, err := Decode([]byte(tt.in))
+ if err != nil {
+ t.Errorf("Decode(%q) failed with error %q", tt.in, err)
+ }
- if r2 != tt.out {
- t.Errorf("DecodeInt(%q) => %d, expected %d", tt.in, r2, tt.out)
+ if r2 != tt.out {
+ t.Errorf("Decode(%q) => %d, expected %d", tt.in, r2, tt.out)
+ }
}
}
@@ -86,16 +94,16 @@ func TestDecodeList(t *testing.T) {
tests := []struct {
in string
out []interface{}
- end int
+ n int
}{
- {in: "l4:spam4:eggse", out: []interface{}{"spam", "eggs"}, end: 13},
- {in: "le", out: []interface{}{}, end: 1},
- {in: "li-1ei0ee", out: []interface{}{int64(-1), int64(0)}, end: 8},
- {in: "l4:testi-1ei0ee", out: []interface{}{"test", int64(-1), int64(0)}, end: 14},
+ {in: "l4:spam4:eggse", out: []interface{}{"spam", "eggs"}, n: 14},
+ {in: "le", out: []interface{}{}, n: 2},
+ {in: "li-1ei0ee", out: []interface{}{int64(-1), int64(0)}, n: 9},
+ {in: "l4:testi-1ei0ee", out: []interface{}{"test", int64(-1), int64(0)}, n: 15},
}
for _, tt := range tests {
- r1, end, err := DecodeList([]byte(tt.in), 0)
+ r1, n, err := DecodeList([]byte(tt.in), 0)
if err != nil {
t.Errorf("DecodeList(%q) failed with error %q", tt.in, err)
}
@@ -110,8 +118,8 @@ func TestDecodeList(t *testing.T) {
}
}
- if end != tt.end {
- t.Errorf("DecodeList(%q) ended at %d, expected %d", tt.in, end, tt.end)
+ if n != tt.n {
+ t.Errorf("DecodeList(%q) read %d, expected %d", tt.in, n, tt.n)
}
r2, err := Decode([]byte(tt.in))
@@ -134,18 +142,20 @@ func TestDecodeList(t *testing.T) {
func TestDecodeDict(t *testing.T) {
tests := []struct {
- in string
- out map[string]interface{}
- end int
+ in string
+ start int
+ out map[string]interface{}
+ n int
}{
- {in: "d4:spam4:eggse", out: map[string]interface{}{"spam": "eggs"}, end: 13},
- {in: "de", out: map[string]interface{}{}, end: 1},
- {in: "d4:testi-1e3:twoi0ee", out: map[string]interface{}{"test": int64(-1), "two": int64(0)}, end: 14},
- {in: "d4:testi0ee", out: map[string]interface{}{"test": int64(0)}, end: 10},
+ {in: "d4:spam4:eggse", start: 0, out: map[string]interface{}{"spam": "eggs"}, n: 14},
+ {in: "de", start: 0, out: map[string]interface{}{}, n: 2},
+ {in: "d4:testi-1e3:twoi0ee", start: 0, out: map[string]interface{}{"test": int64(-1), "two": int64(0)}, n: 20},
+ {in: "d4:testi0ee", start: 0, out: map[string]interface{}{"test": int64(0)}, n: 11},
+ {in: "012345d4:spam4:eggse", start: 6, out: map[string]interface{}{"spam": "eggs"}, n: 14},
}
for _, tt := range tests {
- r1, _, err := DecodeDict([]byte(tt.in), 0)
+ r1, _, err := DecodeDict([]byte(tt.in), tt.start)
if err != nil {
t.Errorf("DecodeDict(%q) failed with error %q", tt.in, err)
}
@@ -160,49 +170,22 @@ func TestDecodeDict(t *testing.T) {
}
}
- r2, err := Decode([]byte(tt.in))
- if err != nil {
- t.Errorf("Decode(%q) failed with error %q", tt.in, err)
- }
-
- r3, ok := r2.(map[string]interface{})
- if !ok {
- t.Errorf("Decode(%q) did not return a map", tt.in)
- }
-
- for k := range r3 {
- if r3[k] != tt.out[k] {
- t.Errorf("Decode(%q) => %v, expected %v", tt.in, r3, tt.out)
+ if tt.start == 0 {
+ r2, err := Decode([]byte(tt.in))
+ if err != nil {
+ t.Errorf("Decode(%q) failed with error %q", tt.in, err)
}
- }
- }
-}
-
-func TestReadUntil(t *testing.T) {
- tests := []struct {
- in string
- start int
- out string
- end int
- }{
- {in: "0:", start: 0, out: "0", end: 2},
- {in: "5:hello", start: 0, out: "5", end: 2},
- {in: "1234567:goodbye", start: 0, out: "1234567", end: 8},
- {in: "asdfasdfsa5:hello", start: 10, out: "5", end: 12},
- }
-
- for _, tt := range tests {
- r, i, err := readUntil([]byte(tt.in), tt.start, ':')
- if err != nil {
- t.Errorf("readUntil(%q) failed with error %q", tt.in, err)
- }
- if string(r) != tt.out {
- t.Errorf("readUntil(%q) => %q, expected %q", tt.in, r, tt.out)
- }
+ r3, ok := r2.(map[string]interface{})
+ if !ok {
+ t.Errorf("Decode(%q) did not return a map", tt.in)
+ }
- if i != tt.end {
- t.Errorf("readUntil(%q) ended at %d, expected %d", tt.in, i, tt.end)
+ for k := range r3 {
+ if r3[k] != tt.out[k] {
+ t.Errorf("Decode(%q) => %v, expected %v", tt.in, r3, tt.out)
+ }
+ }
}
}
}