mirror of
https://github.com/sartoopjj/thefeed.git
synced 2026-05-19 06:54:34 +03:00
feat: ✨ media download with DNS query
This commit is contained in:
+78
-10
@@ -238,27 +238,95 @@ func (c *Cache) Cleanup() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// detectGaps finds places in a sorted message list where consecutive IDs differ
|
||||
// by more than 1. Gaps larger than 500 are ignored (natural Telegram numbering).
|
||||
// Returns nil when there are fewer than 10 messages (not enough history to judge).
|
||||
// detectGaps finds runs of missing IDs between consecutive messages. Album-
|
||||
// merged canonicals cover a contiguous span of sibling IDs (counted via
|
||||
// albumSpan), so absorbed siblings don't show up as fake gaps. Diffs > 500
|
||||
// are ignored (natural Telegram numbering jumps); under 10 messages we don't
|
||||
// have enough history to judge.
|
||||
func detectGaps(msgs []protocol.Message) []Gap {
|
||||
if len(msgs) < 10 {
|
||||
return nil
|
||||
}
|
||||
var gaps []Gap
|
||||
for i := 1; i < len(msgs); i++ {
|
||||
prev, cur := msgs[i-1].ID, msgs[i].ID
|
||||
if diff := cur - prev; diff > 1 && diff <= 500 {
|
||||
gaps = append(gaps, Gap{
|
||||
AfterID: prev,
|
||||
BeforeID: cur,
|
||||
Count: int(diff - 1),
|
||||
})
|
||||
prev, cur := msgs[i-1], msgs[i]
|
||||
span := uint32(albumSpan(prev.Text))
|
||||
if span == 0 {
|
||||
span = 1
|
||||
}
|
||||
expectedNext := prev.ID + span
|
||||
if cur.ID <= expectedNext {
|
||||
continue
|
||||
}
|
||||
diff := cur.ID - expectedNext
|
||||
if diff > 500 {
|
||||
continue
|
||||
}
|
||||
gaps = append(gaps, Gap{
|
||||
AfterID: expectedNext - 1,
|
||||
BeforeID: cur.ID,
|
||||
Count: int(diff),
|
||||
})
|
||||
}
|
||||
return gaps
|
||||
}
|
||||
|
||||
// mediaHeaderTags are the leading [TAG] markers extractMessages may stack
|
||||
// at the start of a canonical message body — one per absorbed album item.
|
||||
var mediaHeaderTags = []string{
|
||||
protocol.MediaImage,
|
||||
protocol.MediaVideo,
|
||||
protocol.MediaFile,
|
||||
protocol.MediaAudio,
|
||||
protocol.MediaSticker,
|
||||
protocol.MediaGIF,
|
||||
protocol.MediaLocation,
|
||||
protocol.MediaContact,
|
||||
}
|
||||
|
||||
// albumSpan counts the leading media-header lines in a canonical body — 0
|
||||
// for plain text, 1 for a single media item, N for an N-item album. A
|
||||
// leading [REPLY]... line is skipped first.
|
||||
func albumSpan(text string) int {
|
||||
if strings.HasPrefix(text, protocol.MediaReply) {
|
||||
nl := strings.IndexByte(text, '\n')
|
||||
if nl < 0 {
|
||||
return 0
|
||||
}
|
||||
text = text[nl+1:]
|
||||
}
|
||||
n := 0
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
if !isMediaHeaderLine(line) {
|
||||
break
|
||||
}
|
||||
n++
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// isMediaHeaderLine matches both the bare [TAG] form and the downloadable
|
||||
// "[TAG]<digit>..." form. Caption text that happens to start with "[IMAGE]"
|
||||
// is rejected because rest[0] won't be a digit.
|
||||
func isMediaHeaderLine(line string) bool {
|
||||
for _, tag := range mediaHeaderTags {
|
||||
if line == tag {
|
||||
return true
|
||||
}
|
||||
if !strings.HasPrefix(line, tag) {
|
||||
continue
|
||||
}
|
||||
rest := line[len(tag):]
|
||||
if rest == "" {
|
||||
return true
|
||||
}
|
||||
if rest[0] >= '0' && rest[0] <= '9' {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// channelPath returns the file path for a channel's cache, keyed by sanitised name.
|
||||
// Only letters, digits, hyphens, and underscores are kept; everything else becomes _.
|
||||
func (c *Cache) channelPath(channelName string) string {
|
||||
|
||||
@@ -221,6 +221,111 @@ func TestCacheGapDetection_NoGapWhenFewMessages(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheGapDetection_AlbumNoFalsePositive(t *testing.T) {
|
||||
cache, _ := NewCache(t.TempDir())
|
||||
|
||||
// 10 sequential messages where ID=5 is a 2-image album: it absorbs ID 6
|
||||
// (a real Telegram behaviour). The next message is ID 7. Without the
|
||||
// album-aware fix, the gap detector would flag a missing ID 6.
|
||||
msgs := []protocol.Message{
|
||||
{ID: 1, Timestamp: 1700000000, Text: "a"},
|
||||
{ID: 2, Timestamp: 1700000001, Text: "b"},
|
||||
{ID: 3, Timestamp: 1700000002, Text: "c"},
|
||||
{ID: 4, Timestamp: 1700000003, Text: "d"},
|
||||
{ID: 5, Timestamp: 1700000004, Text: "[IMAGE]100:0:0:0:abcd1234:img1.jpg\n[IMAGE]200:0:0:0:abcd5678:img2.jpg\nalbum caption"},
|
||||
// ID 6 is absorbed into the album above; the feed jumps to 7.
|
||||
{ID: 7, Timestamp: 1700000005, Text: "e"},
|
||||
{ID: 8, Timestamp: 1700000006, Text: "f"},
|
||||
{ID: 9, Timestamp: 1700000007, Text: "g"},
|
||||
{ID: 10, Timestamp: 1700000008, Text: "h"},
|
||||
{ID: 11, Timestamp: 1700000009, Text: "i"},
|
||||
}
|
||||
result, _ := cache.MergeAndPut("albumchan", msgs)
|
||||
|
||||
if len(result.Gaps) != 0 {
|
||||
t.Errorf("album-absorbed sibling should not be flagged as a gap, got %+v", result.Gaps)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheGapDetection_AlbumWithRealGap(t *testing.T) {
|
||||
cache, _ := NewCache(t.TempDir())
|
||||
|
||||
// 3-image album at ID=5 absorbs IDs 6,7. A real gap of IDs 8,9 follows
|
||||
// before ID=10. The detector should report a single 2-message gap.
|
||||
msgs := []protocol.Message{
|
||||
{ID: 1, Timestamp: 1700000000, Text: "a"},
|
||||
{ID: 2, Timestamp: 1700000001, Text: "b"},
|
||||
{ID: 3, Timestamp: 1700000002, Text: "c"},
|
||||
{ID: 4, Timestamp: 1700000003, Text: "d"},
|
||||
{ID: 5, Timestamp: 1700000004, Text: "[IMAGE]100:0:0:0:aaaaaaaa:1.jpg\n[IMAGE]200:0:0:0:bbbbbbbb:2.jpg\n[IMAGE]300:0:0:0:cccccccc:3.jpg\ncap"},
|
||||
// IDs 6,7 absorbed; IDs 8,9 truly missing; resume at 10.
|
||||
{ID: 10, Timestamp: 1700000010, Text: "e"},
|
||||
{ID: 11, Timestamp: 1700000011, Text: "f"},
|
||||
{ID: 12, Timestamp: 1700000012, Text: "g"},
|
||||
{ID: 13, Timestamp: 1700000013, Text: "h"},
|
||||
{ID: 14, Timestamp: 1700000014, Text: "i"},
|
||||
{ID: 15, Timestamp: 1700000015, Text: "j"},
|
||||
}
|
||||
result, _ := cache.MergeAndPut("albumgap", msgs)
|
||||
|
||||
if len(result.Gaps) != 1 {
|
||||
t.Fatalf("expected exactly one gap, got %+v", result.Gaps)
|
||||
}
|
||||
g := result.Gaps[0]
|
||||
if g.AfterID != 7 || g.BeforeID != 10 || g.Count != 2 {
|
||||
t.Errorf("gap = %+v, want AfterID=7 BeforeID=10 Count=2", g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheGapDetection_AlbumWithReplyPrefix(t *testing.T) {
|
||||
cache, _ := NewCache(t.TempDir())
|
||||
|
||||
// [REPLY]:42 prefix before the media headers should still let albumSpan
|
||||
// count the headers correctly.
|
||||
msgs := []protocol.Message{
|
||||
{ID: 1, Timestamp: 1700000000, Text: "a"},
|
||||
{ID: 2, Timestamp: 1700000001, Text: "b"},
|
||||
{ID: 3, Timestamp: 1700000002, Text: "c"},
|
||||
{ID: 4, Timestamp: 1700000003, Text: "d"},
|
||||
{ID: 5, Timestamp: 1700000004, Text: "[REPLY]:42\n[IMAGE]100:0:0:0:aaaaaaaa:1.jpg\n[IMAGE]200:0:0:0:bbbbbbbb:2.jpg\nreplied caption"},
|
||||
// ID 6 absorbed.
|
||||
{ID: 7, Timestamp: 1700000005, Text: "e"},
|
||||
{ID: 8, Timestamp: 1700000006, Text: "f"},
|
||||
{ID: 9, Timestamp: 1700000007, Text: "g"},
|
||||
{ID: 10, Timestamp: 1700000008, Text: "h"},
|
||||
{ID: 11, Timestamp: 1700000009, Text: "i"},
|
||||
}
|
||||
result, _ := cache.MergeAndPut("replychan", msgs)
|
||||
|
||||
if len(result.Gaps) != 0 {
|
||||
t.Errorf("album with reply prefix should not produce false gaps, got %+v", result.Gaps)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAlbumSpan(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
text string
|
||||
want int
|
||||
}{
|
||||
{"plain text", "hello world", 0},
|
||||
{"single image legacy", "[IMAGE]\ncaption", 1},
|
||||
{"single image downloadable", "[IMAGE]100:0:0:0:abcd1234:f.jpg\ncap", 1},
|
||||
{"two images", "[IMAGE]100:0:0:0:aa:1.jpg\n[IMAGE]200:0:0:0:bb:2.jpg\ncap", 2},
|
||||
{"three mixed", "[IMAGE]1:0:0:0:aa:a.jpg\n[VIDEO]2:0:0:0:bb:b.mp4\n[FILE]3:0:0:0:cc:c.pdf\nx", 3},
|
||||
{"with reply prefix", "[REPLY]:99\n[IMAGE]100:0:0:0:aa:1.jpg\n[IMAGE]200:0:0:0:bb:2.jpg\ncap", 2},
|
||||
{"reply only no media", "[REPLY]:99\nhello", 0},
|
||||
{"caption that mentions a tag", "look at this [IMAGE] thing", 0},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
if got := albumSpan(c.text); got != c.want {
|
||||
t.Errorf("albumSpan(%q) = %d, want %d", c.text, got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheGapDetection_LargeGapIgnored(t *testing.T) {
|
||||
cache, _ := NewCache(t.TempDir())
|
||||
|
||||
|
||||
@@ -0,0 +1,239 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/sartoopjj/thefeed/internal/protocol"
|
||||
)
|
||||
|
||||
// DecompressMediaReader wraps r per the given compression.
|
||||
func DecompressMediaReader(r io.Reader, compression protocol.MediaCompression) (io.ReadCloser, error) {
|
||||
switch compression {
|
||||
case protocol.MediaCompressionNone:
|
||||
return io.NopCloser(r), nil
|
||||
case protocol.MediaCompressionGzip:
|
||||
return gzip.NewReader(r)
|
||||
case protocol.MediaCompressionDeflate:
|
||||
return flate.NewReader(r), nil
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported media compression: %d", compression)
|
||||
}
|
||||
|
||||
func decompressMediaBytes(body []byte, compression protocol.MediaCompression) ([]byte, error) {
|
||||
rc, err := DecompressMediaReader(bytes.NewReader(body), compression)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rc.Close()
|
||||
return io.ReadAll(rc)
|
||||
}
|
||||
|
||||
// MediaProgress reports per-block progress (completed of total). May be
|
||||
// invoked from a background goroutine.
|
||||
type MediaProgress func(completed, total int)
|
||||
|
||||
// MediaBlockHeaderLen re-exports the protocol header length so callers in
|
||||
// the web layer don't have to import the protocol package twice.
|
||||
const MediaBlockHeaderLen = protocol.MediaBlockHeaderLen
|
||||
|
||||
// ErrMediaHashMismatch indicates the assembled bytes don't match the
|
||||
// expected CRC32. The caller must discard the returned bytes.
|
||||
var ErrMediaHashMismatch = fmt.Errorf("media content hash mismatch")
|
||||
|
||||
// mediaBlockOuterRetries is the per-block retry budget the media path adds
|
||||
// on top of FetchBlock's own internal retries. A ~200-block file can lose
|
||||
// individual blocks repeatedly; without this, one persistent bad block
|
||||
// kills the whole download even though FetchBlock would succeed on a
|
||||
// later attempt.
|
||||
const mediaBlockOuterRetries = 5
|
||||
|
||||
func (f *Fetcher) fetchMediaBlock(ctx context.Context, channel, block uint16) ([]byte, error) {
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < mediaBlockOuterRetries; attempt++ {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
data, err := f.FetchBlock(ctx, channel, block)
|
||||
if err == nil {
|
||||
return data, nil
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// FetchMedia returns the assembled bytes of a media blob served on a media
|
||||
// channel, optionally verifying expectedCRC32.
|
||||
func (f *Fetcher) FetchMedia(ctx context.Context, channel uint16, blockCount uint16, expectedCRC32 uint32, progress MediaProgress) ([]byte, error) {
|
||||
if !protocol.IsMediaChannel(channel) {
|
||||
return nil, fmt.Errorf("channel %d is outside media range", channel)
|
||||
}
|
||||
if blockCount == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type blockResult struct {
|
||||
idx int
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
|
||||
results := make(chan blockResult, blockCount)
|
||||
sem := make(chan struct{}, 5)
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < int(blockCount); i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case sem <- struct{}{}:
|
||||
case <-ctx.Done():
|
||||
results <- blockResult{idx: idx, err: ctx.Err()}
|
||||
return
|
||||
}
|
||||
defer func() { <-sem }()
|
||||
data, err := f.fetchMediaBlock(ctx, channel, uint16(idx))
|
||||
results <- blockResult{idx: idx, data: data, err: err}
|
||||
}(i)
|
||||
}
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(results)
|
||||
}()
|
||||
|
||||
ordered := make([][]byte, blockCount)
|
||||
completed := 0
|
||||
var progMu sync.Mutex
|
||||
for r := range results {
|
||||
if r.err != nil {
|
||||
if r.err == ctx.Err() {
|
||||
return nil, r.err
|
||||
}
|
||||
return nil, fmt.Errorf("media channel %d block %d: %w", channel, r.idx, r.err)
|
||||
}
|
||||
ordered[r.idx] = r.data
|
||||
completed++
|
||||
if progress != nil {
|
||||
progMu.Lock()
|
||||
progress(completed, int(blockCount))
|
||||
progMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
if len(ordered) == 0 || len(ordered[0]) < protocol.MediaBlockHeaderLen {
|
||||
return nil, fmt.Errorf("media channel %d: malformed block 0", channel)
|
||||
}
|
||||
header, err := protocol.DecodeMediaBlockHeader(ordered[0][:protocol.MediaBlockHeaderLen])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("media channel %d: %w", channel, err)
|
||||
}
|
||||
if expectedCRC32 != 0 && header.CRC32 != expectedCRC32 {
|
||||
return nil, ErrMediaHashMismatch
|
||||
}
|
||||
|
||||
// Concatenate all block bytes after the header.
|
||||
total := len(ordered[0]) - protocol.MediaBlockHeaderLen
|
||||
for i := 1; i < len(ordered); i++ {
|
||||
total += len(ordered[i])
|
||||
}
|
||||
body := make([]byte, 0, total)
|
||||
body = append(body, ordered[0][protocol.MediaBlockHeaderLen:]...)
|
||||
for i := 1; i < len(ordered); i++ {
|
||||
body = append(body, ordered[i]...)
|
||||
}
|
||||
|
||||
// Decompress per the header.
|
||||
out, err := decompressMediaBytes(body, header.Compression)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decompress media channel %d: %w", channel, err)
|
||||
}
|
||||
if expectedCRC32 != 0 {
|
||||
if got := crc32.ChecksumIEEE(out); got != expectedCRC32 {
|
||||
return nil, ErrMediaHashMismatch
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// FetchMediaBlocksStream fetches blocks [startBlock, startBlock+count) and
|
||||
// writes each block's raw bytes to w in order as soon as they become
|
||||
// contiguous. No header parsing; callers slice off the protocol header
|
||||
// themselves and decompress as appropriate. Cancelling ctx aborts both
|
||||
// in-flight DNS queries and pending writes.
|
||||
func (f *Fetcher) FetchMediaBlocksStream(ctx context.Context, channel, startBlock, count uint16, w io.Writer, progress MediaProgress) error {
|
||||
if !protocol.IsMediaChannel(channel) {
|
||||
return fmt.Errorf("channel %d is outside media range", channel)
|
||||
}
|
||||
if count == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
type blockResult struct {
|
||||
idx int
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
results := make(chan blockResult, count)
|
||||
sem := make(chan struct{}, 5)
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < int(count); i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case sem <- struct{}{}:
|
||||
case <-ctx.Done():
|
||||
results <- blockResult{idx: idx, err: ctx.Err()}
|
||||
return
|
||||
}
|
||||
defer func() { <-sem }()
|
||||
data, err := f.fetchMediaBlock(ctx, channel, uint16(int(startBlock)+idx))
|
||||
results <- blockResult{idx: idx, data: data, err: err}
|
||||
}(i)
|
||||
}
|
||||
go func() { wg.Wait(); close(results) }()
|
||||
|
||||
pending := make(map[int][]byte)
|
||||
next := 0
|
||||
completed := 0
|
||||
for r := range results {
|
||||
if r.err != nil {
|
||||
if r.err == ctx.Err() {
|
||||
return r.err
|
||||
}
|
||||
return fmt.Errorf("media channel %d block %d: %w", channel, int(startBlock)+r.idx, r.err)
|
||||
}
|
||||
pending[r.idx] = r.data
|
||||
for {
|
||||
payload, ok := pending[next]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
if _, werr := w.Write(payload); werr != nil {
|
||||
return werr
|
||||
}
|
||||
if flusher, ok := w.(interface{ Flush() }); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
next++
|
||||
}
|
||||
completed++
|
||||
if progress != nil {
|
||||
progress(completed, int(count))
|
||||
}
|
||||
}
|
||||
if next != int(count) {
|
||||
return fmt.Errorf("media channel %d: incomplete (%d / %d)", channel, next, count)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,207 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"hash/crc32"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/sartoopjj/thefeed/internal/protocol"
|
||||
)
|
||||
|
||||
// withMediaHeader prepends the protocol media block header to body. The
|
||||
// CRC32 is computed over the DECOMPRESSED bytes the caller passes in, but
|
||||
// `body` itself is what the server would have produced after compressing —
|
||||
// which for compression=none is just the bytes themselves.
|
||||
func withMediaHeader(crc uint32, body []byte, compression protocol.MediaCompression) []byte {
|
||||
hdr := protocol.EncodeMediaBlockHeader(protocol.MediaBlockHeader{
|
||||
CRC32: crc,
|
||||
Version: protocol.MediaHeaderVersion,
|
||||
Compression: compression,
|
||||
})
|
||||
out := make([]byte, 0, len(hdr)+len(body))
|
||||
out = append(out, hdr...)
|
||||
out = append(out, body...)
|
||||
return out
|
||||
}
|
||||
|
||||
func gzipBytes(t *testing.T, b []byte) []byte {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
zw, _ := gzip.NewWriterLevel(&buf, gzip.BestCompression)
|
||||
if _, err := zw.Write(b); err != nil {
|
||||
t.Fatalf("gzip: %v", err)
|
||||
}
|
||||
if err := zw.Close(); err != nil {
|
||||
t.Fatalf("gzip close: %v", err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// blockMockExchange wires the fetcher's exchangeFn so each (channel, block)
|
||||
// pair returns the matching slice from blocks.
|
||||
func blockMockExchange(f *Fetcher, want uint16, blocks [][]byte) func(context.Context, *dns.Msg, string) (*dns.Msg, time.Duration, error) {
|
||||
return func(ctx context.Context, m *dns.Msg, _ string) (*dns.Msg, time.Duration, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
ch, blk, err := protocol.DecodeQuery(f.queryKey, m.Question[0].Name, f.domain)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if ch != want {
|
||||
return nil, 0, errFakeNotFound{}
|
||||
}
|
||||
if int(blk) >= len(blocks) {
|
||||
return nil, 0, errFakeNotFound{}
|
||||
}
|
||||
encoded, encErr := protocol.EncodeResponse(f.responseKey, blocks[int(blk)], 0)
|
||||
if encErr != nil {
|
||||
return nil, 0, encErr
|
||||
}
|
||||
resp := new(dns.Msg)
|
||||
resp.SetReply(m)
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
resp.Answer = []dns.RR{&dns.TXT{
|
||||
Hdr: dns.RR_Header{Name: m.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0},
|
||||
Txt: []string{encoded},
|
||||
}}
|
||||
return resp, time.Millisecond, nil
|
||||
}
|
||||
}
|
||||
|
||||
type errFakeNotFound struct{}
|
||||
|
||||
func (errFakeNotFound) Error() string { return "fake nxdomain" }
|
||||
|
||||
func TestFetchMediaUncompressed(t *testing.T) {
|
||||
f := newTestFetcher(t, []string{"1.1.1.1:53"})
|
||||
original := make([]byte, 1500)
|
||||
if _, err := rand.Read(original); err != nil {
|
||||
t.Fatalf("rand: %v", err)
|
||||
}
|
||||
crc := crc32.ChecksumIEEE(original)
|
||||
blocks := protocol.SplitIntoBlocks(withMediaHeader(crc, original, protocol.MediaCompressionNone))
|
||||
|
||||
channel := protocol.MediaChannelStart + 7
|
||||
f.exchangeFn = blockMockExchange(f, channel, blocks)
|
||||
|
||||
out, err := f.FetchMedia(context.Background(), channel, uint16(len(blocks)), crc, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("FetchMedia: %v", err)
|
||||
}
|
||||
if !bytes.Equal(out, original) {
|
||||
t.Fatalf("decompressed output differs from original")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchMediaDeflate(t *testing.T) {
|
||||
f := newTestFetcher(t, []string{"1.1.1.1:53"})
|
||||
original := bytes.Repeat([]byte("xy "), 250)
|
||||
crc := crc32.ChecksumIEEE(original)
|
||||
var buf bytes.Buffer
|
||||
zw, _ := flate.NewWriter(&buf, flate.BestCompression)
|
||||
zw.Write(original)
|
||||
zw.Close()
|
||||
blocks := protocol.SplitIntoBlocks(withMediaHeader(crc, buf.Bytes(), protocol.MediaCompressionDeflate))
|
||||
|
||||
channel := protocol.MediaChannelStart + 9
|
||||
f.exchangeFn = blockMockExchange(f, channel, blocks)
|
||||
out, err := f.FetchMedia(context.Background(), channel, uint16(len(blocks)), crc, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("FetchMedia: %v", err)
|
||||
}
|
||||
if !bytes.Equal(out, original) {
|
||||
t.Fatalf("decompressed differs from original")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchMediaGzip(t *testing.T) {
|
||||
f := newTestFetcher(t, []string{"1.1.1.1:53"})
|
||||
original := bytes.Repeat([]byte("abc123 "), 200) // compressible
|
||||
crc := crc32.ChecksumIEEE(original)
|
||||
body := gzipBytes(t, original)
|
||||
blocks := protocol.SplitIntoBlocks(withMediaHeader(crc, body, protocol.MediaCompressionGzip))
|
||||
|
||||
channel := protocol.MediaChannelStart + 8
|
||||
f.exchangeFn = blockMockExchange(f, channel, blocks)
|
||||
|
||||
out, err := f.FetchMedia(context.Background(), channel, uint16(len(blocks)), crc, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("FetchMedia: %v", err)
|
||||
}
|
||||
if !bytes.Equal(out, original) {
|
||||
t.Fatalf("decompressed output differs from original")
|
||||
}
|
||||
if len(body) >= len(original) {
|
||||
t.Fatalf("compressed body should be smaller than original (got %d vs %d)", len(body), len(original))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchMediaRejectsNonMediaChannel(t *testing.T) {
|
||||
f := newTestFetcher(t, []string{"1.1.1.1:53"})
|
||||
_, err := f.FetchMedia(context.Background(), 1, 1, 0, nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for non-media channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchMediaRejectsBadHash(t *testing.T) {
|
||||
f := newTestFetcher(t, []string{"1.1.1.1:53"})
|
||||
original := []byte("hello hash mismatch")
|
||||
crc := crc32.ChecksumIEEE(original)
|
||||
blocks := [][]byte{withMediaHeader(crc, original, protocol.MediaCompressionNone)}
|
||||
channel := protocol.MediaChannelStart + 1
|
||||
f.exchangeFn = blockMockExchange(f, channel, blocks)
|
||||
|
||||
_, err := f.FetchMedia(context.Background(), channel, 1, 0xDEADBEEF, nil)
|
||||
if err != ErrMediaHashMismatch {
|
||||
t.Fatalf("err = %v, want ErrMediaHashMismatch", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchMediaBlocksStreamWritesInOrder(t *testing.T) {
|
||||
f := newTestFetcher(t, []string{"1.1.1.1:53"})
|
||||
blocks := [][]byte{
|
||||
[]byte("alpha"),
|
||||
[]byte("beta"),
|
||||
[]byte("gamma"),
|
||||
}
|
||||
channel := protocol.MediaChannelStart + 12
|
||||
f.exchangeFn = blockMockExchange(f, channel, blocks)
|
||||
|
||||
var got bytes.Buffer
|
||||
if err := f.FetchMediaBlocksStream(context.Background(), channel, 0, 3, &got, nil); err != nil {
|
||||
t.Fatalf("FetchMediaBlocksStream: %v", err)
|
||||
}
|
||||
want := append(append(append([]byte{}, blocks[0]...), blocks[1]...), blocks[2]...)
|
||||
if !bytes.Equal(got.Bytes(), want) {
|
||||
t.Fatalf("got %q, want %q", got.Bytes(), want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchMediaBlocksStreamPartialRange(t *testing.T) {
|
||||
f := newTestFetcher(t, []string{"1.1.1.1:53"})
|
||||
blocks := [][]byte{
|
||||
[]byte("first-block"),
|
||||
[]byte("second-block"),
|
||||
[]byte("third-block"),
|
||||
}
|
||||
channel := protocol.MediaChannelStart + 13
|
||||
f.exchangeFn = blockMockExchange(f, channel, blocks)
|
||||
|
||||
var got bytes.Buffer
|
||||
if err := f.FetchMediaBlocksStream(context.Background(), channel, 1, 2, &got, nil); err != nil {
|
||||
t.Fatalf("FetchMediaBlocksStream: %v", err)
|
||||
}
|
||||
want := append(append([]byte{}, blocks[1]...), blocks[2]...)
|
||||
if !bytes.Equal(got.Bytes(), want) {
|
||||
t.Fatalf("got %q, want %q", got.Bytes(), want)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user