feat: media download with DNS query

This commit is contained in:
Sarto
2026-04-29 01:45:27 +03:30
parent 11946c0147
commit b4e9cd8714
34 changed files with 6303 additions and 137 deletions
+78 -10
View File
@@ -238,27 +238,95 @@ func (c *Cache) Cleanup() error {
return nil
}
// detectGaps finds places in a sorted message list where consecutive IDs differ
// by more than 1. Gaps larger than 500 are ignored (natural Telegram numbering).
// Returns nil when there are fewer than 10 messages (not enough history to judge).
// detectGaps finds runs of missing IDs between consecutive messages. Album-
// merged canonicals cover a contiguous span of sibling IDs (counted via
// albumSpan), so absorbed siblings don't show up as fake gaps. Diffs > 500
// are ignored (natural Telegram numbering jumps); under 10 messages we don't
// have enough history to judge.
func detectGaps(msgs []protocol.Message) []Gap {
if len(msgs) < 10 {
return nil
}
var gaps []Gap
for i := 1; i < len(msgs); i++ {
prev, cur := msgs[i-1].ID, msgs[i].ID
if diff := cur - prev; diff > 1 && diff <= 500 {
gaps = append(gaps, Gap{
AfterID: prev,
BeforeID: cur,
Count: int(diff - 1),
})
prev, cur := msgs[i-1], msgs[i]
span := uint32(albumSpan(prev.Text))
if span == 0 {
span = 1
}
expectedNext := prev.ID + span
if cur.ID <= expectedNext {
continue
}
diff := cur.ID - expectedNext
if diff > 500 {
continue
}
gaps = append(gaps, Gap{
AfterID: expectedNext - 1,
BeforeID: cur.ID,
Count: int(diff),
})
}
return gaps
}
// mediaHeaderTags are the leading [TAG] markers extractMessages may stack
// at the start of a canonical message body — one per absorbed album item.
var mediaHeaderTags = []string{
protocol.MediaImage,
protocol.MediaVideo,
protocol.MediaFile,
protocol.MediaAudio,
protocol.MediaSticker,
protocol.MediaGIF,
protocol.MediaLocation,
protocol.MediaContact,
}
// albumSpan counts the leading media-header lines in a canonical body — 0
// for plain text, 1 for a single media item, N for an N-item album. A
// leading [REPLY]... line is skipped first.
func albumSpan(text string) int {
if strings.HasPrefix(text, protocol.MediaReply) {
nl := strings.IndexByte(text, '\n')
if nl < 0 {
return 0
}
text = text[nl+1:]
}
n := 0
for _, line := range strings.Split(text, "\n") {
if !isMediaHeaderLine(line) {
break
}
n++
}
return n
}
// isMediaHeaderLine matches both the bare [TAG] form and the downloadable
// "[TAG]<digit>..." form. Caption text that happens to start with "[IMAGE]"
// is rejected because rest[0] won't be a digit.
func isMediaHeaderLine(line string) bool {
for _, tag := range mediaHeaderTags {
if line == tag {
return true
}
if !strings.HasPrefix(line, tag) {
continue
}
rest := line[len(tag):]
if rest == "" {
return true
}
if rest[0] >= '0' && rest[0] <= '9' {
return true
}
}
return false
}
// channelPath returns the file path for a channel's cache, keyed by sanitised name.
// Only letters, digits, hyphens, and underscores are kept; everything else becomes _.
func (c *Cache) channelPath(channelName string) string {
+105
View File
@@ -221,6 +221,111 @@ func TestCacheGapDetection_NoGapWhenFewMessages(t *testing.T) {
}
}
func TestCacheGapDetection_AlbumNoFalsePositive(t *testing.T) {
cache, _ := NewCache(t.TempDir())
// 10 sequential messages where ID=5 is a 2-image album: it absorbs ID 6
// (a real Telegram behaviour). The next message is ID 7. Without the
// album-aware fix, the gap detector would flag a missing ID 6.
msgs := []protocol.Message{
{ID: 1, Timestamp: 1700000000, Text: "a"},
{ID: 2, Timestamp: 1700000001, Text: "b"},
{ID: 3, Timestamp: 1700000002, Text: "c"},
{ID: 4, Timestamp: 1700000003, Text: "d"},
{ID: 5, Timestamp: 1700000004, Text: "[IMAGE]100:0:0:0:abcd1234:img1.jpg\n[IMAGE]200:0:0:0:abcd5678:img2.jpg\nalbum caption"},
// ID 6 is absorbed into the album above; the feed jumps to 7.
{ID: 7, Timestamp: 1700000005, Text: "e"},
{ID: 8, Timestamp: 1700000006, Text: "f"},
{ID: 9, Timestamp: 1700000007, Text: "g"},
{ID: 10, Timestamp: 1700000008, Text: "h"},
{ID: 11, Timestamp: 1700000009, Text: "i"},
}
result, _ := cache.MergeAndPut("albumchan", msgs)
if len(result.Gaps) != 0 {
t.Errorf("album-absorbed sibling should not be flagged as a gap, got %+v", result.Gaps)
}
}
func TestCacheGapDetection_AlbumWithRealGap(t *testing.T) {
cache, _ := NewCache(t.TempDir())
// 3-image album at ID=5 absorbs IDs 6,7. A real gap of IDs 8,9 follows
// before ID=10. The detector should report a single 2-message gap.
msgs := []protocol.Message{
{ID: 1, Timestamp: 1700000000, Text: "a"},
{ID: 2, Timestamp: 1700000001, Text: "b"},
{ID: 3, Timestamp: 1700000002, Text: "c"},
{ID: 4, Timestamp: 1700000003, Text: "d"},
{ID: 5, Timestamp: 1700000004, Text: "[IMAGE]100:0:0:0:aaaaaaaa:1.jpg\n[IMAGE]200:0:0:0:bbbbbbbb:2.jpg\n[IMAGE]300:0:0:0:cccccccc:3.jpg\ncap"},
// IDs 6,7 absorbed; IDs 8,9 truly missing; resume at 10.
{ID: 10, Timestamp: 1700000010, Text: "e"},
{ID: 11, Timestamp: 1700000011, Text: "f"},
{ID: 12, Timestamp: 1700000012, Text: "g"},
{ID: 13, Timestamp: 1700000013, Text: "h"},
{ID: 14, Timestamp: 1700000014, Text: "i"},
{ID: 15, Timestamp: 1700000015, Text: "j"},
}
result, _ := cache.MergeAndPut("albumgap", msgs)
if len(result.Gaps) != 1 {
t.Fatalf("expected exactly one gap, got %+v", result.Gaps)
}
g := result.Gaps[0]
if g.AfterID != 7 || g.BeforeID != 10 || g.Count != 2 {
t.Errorf("gap = %+v, want AfterID=7 BeforeID=10 Count=2", g)
}
}
func TestCacheGapDetection_AlbumWithReplyPrefix(t *testing.T) {
cache, _ := NewCache(t.TempDir())
// [REPLY]:42 prefix before the media headers should still let albumSpan
// count the headers correctly.
msgs := []protocol.Message{
{ID: 1, Timestamp: 1700000000, Text: "a"},
{ID: 2, Timestamp: 1700000001, Text: "b"},
{ID: 3, Timestamp: 1700000002, Text: "c"},
{ID: 4, Timestamp: 1700000003, Text: "d"},
{ID: 5, Timestamp: 1700000004, Text: "[REPLY]:42\n[IMAGE]100:0:0:0:aaaaaaaa:1.jpg\n[IMAGE]200:0:0:0:bbbbbbbb:2.jpg\nreplied caption"},
// ID 6 absorbed.
{ID: 7, Timestamp: 1700000005, Text: "e"},
{ID: 8, Timestamp: 1700000006, Text: "f"},
{ID: 9, Timestamp: 1700000007, Text: "g"},
{ID: 10, Timestamp: 1700000008, Text: "h"},
{ID: 11, Timestamp: 1700000009, Text: "i"},
}
result, _ := cache.MergeAndPut("replychan", msgs)
if len(result.Gaps) != 0 {
t.Errorf("album with reply prefix should not produce false gaps, got %+v", result.Gaps)
}
}
func TestAlbumSpan(t *testing.T) {
cases := []struct {
name string
text string
want int
}{
{"plain text", "hello world", 0},
{"single image legacy", "[IMAGE]\ncaption", 1},
{"single image downloadable", "[IMAGE]100:0:0:0:abcd1234:f.jpg\ncap", 1},
{"two images", "[IMAGE]100:0:0:0:aa:1.jpg\n[IMAGE]200:0:0:0:bb:2.jpg\ncap", 2},
{"three mixed", "[IMAGE]1:0:0:0:aa:a.jpg\n[VIDEO]2:0:0:0:bb:b.mp4\n[FILE]3:0:0:0:cc:c.pdf\nx", 3},
{"with reply prefix", "[REPLY]:99\n[IMAGE]100:0:0:0:aa:1.jpg\n[IMAGE]200:0:0:0:bb:2.jpg\ncap", 2},
{"reply only no media", "[REPLY]:99\nhello", 0},
{"caption that mentions a tag", "look at this [IMAGE] thing", 0},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
if got := albumSpan(c.text); got != c.want {
t.Errorf("albumSpan(%q) = %d, want %d", c.text, got, c.want)
}
})
}
}
func TestCacheGapDetection_LargeGapIgnored(t *testing.T) {
cache, _ := NewCache(t.TempDir())
+239
View File
@@ -0,0 +1,239 @@
package client
import (
"bytes"
"compress/flate"
"compress/gzip"
"context"
"fmt"
"hash/crc32"
"io"
"sync"
"github.com/sartoopjj/thefeed/internal/protocol"
)
// DecompressMediaReader wraps r per the given compression.
func DecompressMediaReader(r io.Reader, compression protocol.MediaCompression) (io.ReadCloser, error) {
switch compression {
case protocol.MediaCompressionNone:
return io.NopCloser(r), nil
case protocol.MediaCompressionGzip:
return gzip.NewReader(r)
case protocol.MediaCompressionDeflate:
return flate.NewReader(r), nil
}
return nil, fmt.Errorf("unsupported media compression: %d", compression)
}
func decompressMediaBytes(body []byte, compression protocol.MediaCompression) ([]byte, error) {
rc, err := DecompressMediaReader(bytes.NewReader(body), compression)
if err != nil {
return nil, err
}
defer rc.Close()
return io.ReadAll(rc)
}
// MediaProgress reports per-block progress (completed of total). May be
// invoked from a background goroutine.
type MediaProgress func(completed, total int)
// MediaBlockHeaderLen re-exports the protocol header length so callers in
// the web layer don't have to import the protocol package twice.
const MediaBlockHeaderLen = protocol.MediaBlockHeaderLen
// ErrMediaHashMismatch indicates the assembled bytes don't match the
// expected CRC32. The caller must discard the returned bytes.
var ErrMediaHashMismatch = fmt.Errorf("media content hash mismatch")
// mediaBlockOuterRetries is the per-block retry budget the media path adds
// on top of FetchBlock's own internal retries. A ~200-block file can lose
// individual blocks repeatedly; without this, one persistent bad block
// kills the whole download even though FetchBlock would succeed on a
// later attempt.
const mediaBlockOuterRetries = 5
func (f *Fetcher) fetchMediaBlock(ctx context.Context, channel, block uint16) ([]byte, error) {
var lastErr error
for attempt := 0; attempt < mediaBlockOuterRetries; attempt++ {
if ctx.Err() != nil {
return nil, ctx.Err()
}
data, err := f.FetchBlock(ctx, channel, block)
if err == nil {
return data, nil
}
if ctx.Err() != nil {
return nil, ctx.Err()
}
lastErr = err
}
return nil, lastErr
}
// FetchMedia returns the assembled bytes of a media blob served on a media
// channel, optionally verifying expectedCRC32.
func (f *Fetcher) FetchMedia(ctx context.Context, channel uint16, blockCount uint16, expectedCRC32 uint32, progress MediaProgress) ([]byte, error) {
if !protocol.IsMediaChannel(channel) {
return nil, fmt.Errorf("channel %d is outside media range", channel)
}
if blockCount == 0 {
return nil, nil
}
type blockResult struct {
idx int
data []byte
err error
}
results := make(chan blockResult, blockCount)
sem := make(chan struct{}, 5)
var wg sync.WaitGroup
for i := 0; i < int(blockCount); i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
select {
case sem <- struct{}{}:
case <-ctx.Done():
results <- blockResult{idx: idx, err: ctx.Err()}
return
}
defer func() { <-sem }()
data, err := f.fetchMediaBlock(ctx, channel, uint16(idx))
results <- blockResult{idx: idx, data: data, err: err}
}(i)
}
go func() {
wg.Wait()
close(results)
}()
ordered := make([][]byte, blockCount)
completed := 0
var progMu sync.Mutex
for r := range results {
if r.err != nil {
if r.err == ctx.Err() {
return nil, r.err
}
return nil, fmt.Errorf("media channel %d block %d: %w", channel, r.idx, r.err)
}
ordered[r.idx] = r.data
completed++
if progress != nil {
progMu.Lock()
progress(completed, int(blockCount))
progMu.Unlock()
}
}
if len(ordered) == 0 || len(ordered[0]) < protocol.MediaBlockHeaderLen {
return nil, fmt.Errorf("media channel %d: malformed block 0", channel)
}
header, err := protocol.DecodeMediaBlockHeader(ordered[0][:protocol.MediaBlockHeaderLen])
if err != nil {
return nil, fmt.Errorf("media channel %d: %w", channel, err)
}
if expectedCRC32 != 0 && header.CRC32 != expectedCRC32 {
return nil, ErrMediaHashMismatch
}
// Concatenate all block bytes after the header.
total := len(ordered[0]) - protocol.MediaBlockHeaderLen
for i := 1; i < len(ordered); i++ {
total += len(ordered[i])
}
body := make([]byte, 0, total)
body = append(body, ordered[0][protocol.MediaBlockHeaderLen:]...)
for i := 1; i < len(ordered); i++ {
body = append(body, ordered[i]...)
}
// Decompress per the header.
out, err := decompressMediaBytes(body, header.Compression)
if err != nil {
return nil, fmt.Errorf("decompress media channel %d: %w", channel, err)
}
if expectedCRC32 != 0 {
if got := crc32.ChecksumIEEE(out); got != expectedCRC32 {
return nil, ErrMediaHashMismatch
}
}
return out, nil
}
// FetchMediaBlocksStream fetches blocks [startBlock, startBlock+count) and
// writes each block's raw bytes to w in order as soon as they become
// contiguous. No header parsing; callers slice off the protocol header
// themselves and decompress as appropriate. Cancelling ctx aborts both
// in-flight DNS queries and pending writes.
func (f *Fetcher) FetchMediaBlocksStream(ctx context.Context, channel, startBlock, count uint16, w io.Writer, progress MediaProgress) error {
if !protocol.IsMediaChannel(channel) {
return fmt.Errorf("channel %d is outside media range", channel)
}
if count == 0 {
return nil
}
type blockResult struct {
idx int
data []byte
err error
}
results := make(chan blockResult, count)
sem := make(chan struct{}, 5)
var wg sync.WaitGroup
for i := 0; i < int(count); i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
select {
case sem <- struct{}{}:
case <-ctx.Done():
results <- blockResult{idx: idx, err: ctx.Err()}
return
}
defer func() { <-sem }()
data, err := f.fetchMediaBlock(ctx, channel, uint16(int(startBlock)+idx))
results <- blockResult{idx: idx, data: data, err: err}
}(i)
}
go func() { wg.Wait(); close(results) }()
pending := make(map[int][]byte)
next := 0
completed := 0
for r := range results {
if r.err != nil {
if r.err == ctx.Err() {
return r.err
}
return fmt.Errorf("media channel %d block %d: %w", channel, int(startBlock)+r.idx, r.err)
}
pending[r.idx] = r.data
for {
payload, ok := pending[next]
if !ok {
break
}
if _, werr := w.Write(payload); werr != nil {
return werr
}
if flusher, ok := w.(interface{ Flush() }); ok {
flusher.Flush()
}
next++
}
completed++
if progress != nil {
progress(completed, int(count))
}
}
if next != int(count) {
return fmt.Errorf("media channel %d: incomplete (%d / %d)", channel, next, count)
}
return nil
}
+207
View File
@@ -0,0 +1,207 @@
package client
import (
"bytes"
"compress/flate"
"compress/gzip"
"context"
"crypto/rand"
"hash/crc32"
"testing"
"time"
"github.com/miekg/dns"
"github.com/sartoopjj/thefeed/internal/protocol"
)
// withMediaHeader prepends the protocol media block header to body. The
// CRC32 is computed over the DECOMPRESSED bytes the caller passes in, but
// `body` itself is what the server would have produced after compressing —
// which for compression=none is just the bytes themselves.
func withMediaHeader(crc uint32, body []byte, compression protocol.MediaCompression) []byte {
hdr := protocol.EncodeMediaBlockHeader(protocol.MediaBlockHeader{
CRC32: crc,
Version: protocol.MediaHeaderVersion,
Compression: compression,
})
out := make([]byte, 0, len(hdr)+len(body))
out = append(out, hdr...)
out = append(out, body...)
return out
}
func gzipBytes(t *testing.T, b []byte) []byte {
t.Helper()
var buf bytes.Buffer
zw, _ := gzip.NewWriterLevel(&buf, gzip.BestCompression)
if _, err := zw.Write(b); err != nil {
t.Fatalf("gzip: %v", err)
}
if err := zw.Close(); err != nil {
t.Fatalf("gzip close: %v", err)
}
return buf.Bytes()
}
// blockMockExchange wires the fetcher's exchangeFn so each (channel, block)
// pair returns the matching slice from blocks.
func blockMockExchange(f *Fetcher, want uint16, blocks [][]byte) func(context.Context, *dns.Msg, string) (*dns.Msg, time.Duration, error) {
return func(ctx context.Context, m *dns.Msg, _ string) (*dns.Msg, time.Duration, error) {
if err := ctx.Err(); err != nil {
return nil, 0, err
}
ch, blk, err := protocol.DecodeQuery(f.queryKey, m.Question[0].Name, f.domain)
if err != nil {
return nil, 0, err
}
if ch != want {
return nil, 0, errFakeNotFound{}
}
if int(blk) >= len(blocks) {
return nil, 0, errFakeNotFound{}
}
encoded, encErr := protocol.EncodeResponse(f.responseKey, blocks[int(blk)], 0)
if encErr != nil {
return nil, 0, encErr
}
resp := new(dns.Msg)
resp.SetReply(m)
resp.Rcode = dns.RcodeSuccess
resp.Answer = []dns.RR{&dns.TXT{
Hdr: dns.RR_Header{Name: m.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0},
Txt: []string{encoded},
}}
return resp, time.Millisecond, nil
}
}
type errFakeNotFound struct{}
func (errFakeNotFound) Error() string { return "fake nxdomain" }
func TestFetchMediaUncompressed(t *testing.T) {
f := newTestFetcher(t, []string{"1.1.1.1:53"})
original := make([]byte, 1500)
if _, err := rand.Read(original); err != nil {
t.Fatalf("rand: %v", err)
}
crc := crc32.ChecksumIEEE(original)
blocks := protocol.SplitIntoBlocks(withMediaHeader(crc, original, protocol.MediaCompressionNone))
channel := protocol.MediaChannelStart + 7
f.exchangeFn = blockMockExchange(f, channel, blocks)
out, err := f.FetchMedia(context.Background(), channel, uint16(len(blocks)), crc, nil)
if err != nil {
t.Fatalf("FetchMedia: %v", err)
}
if !bytes.Equal(out, original) {
t.Fatalf("decompressed output differs from original")
}
}
func TestFetchMediaDeflate(t *testing.T) {
f := newTestFetcher(t, []string{"1.1.1.1:53"})
original := bytes.Repeat([]byte("xy "), 250)
crc := crc32.ChecksumIEEE(original)
var buf bytes.Buffer
zw, _ := flate.NewWriter(&buf, flate.BestCompression)
zw.Write(original)
zw.Close()
blocks := protocol.SplitIntoBlocks(withMediaHeader(crc, buf.Bytes(), protocol.MediaCompressionDeflate))
channel := protocol.MediaChannelStart + 9
f.exchangeFn = blockMockExchange(f, channel, blocks)
out, err := f.FetchMedia(context.Background(), channel, uint16(len(blocks)), crc, nil)
if err != nil {
t.Fatalf("FetchMedia: %v", err)
}
if !bytes.Equal(out, original) {
t.Fatalf("decompressed differs from original")
}
}
func TestFetchMediaGzip(t *testing.T) {
f := newTestFetcher(t, []string{"1.1.1.1:53"})
original := bytes.Repeat([]byte("abc123 "), 200) // compressible
crc := crc32.ChecksumIEEE(original)
body := gzipBytes(t, original)
blocks := protocol.SplitIntoBlocks(withMediaHeader(crc, body, protocol.MediaCompressionGzip))
channel := protocol.MediaChannelStart + 8
f.exchangeFn = blockMockExchange(f, channel, blocks)
out, err := f.FetchMedia(context.Background(), channel, uint16(len(blocks)), crc, nil)
if err != nil {
t.Fatalf("FetchMedia: %v", err)
}
if !bytes.Equal(out, original) {
t.Fatalf("decompressed output differs from original")
}
if len(body) >= len(original) {
t.Fatalf("compressed body should be smaller than original (got %d vs %d)", len(body), len(original))
}
}
func TestFetchMediaRejectsNonMediaChannel(t *testing.T) {
f := newTestFetcher(t, []string{"1.1.1.1:53"})
_, err := f.FetchMedia(context.Background(), 1, 1, 0, nil)
if err == nil {
t.Fatalf("expected error for non-media channel")
}
}
func TestFetchMediaRejectsBadHash(t *testing.T) {
f := newTestFetcher(t, []string{"1.1.1.1:53"})
original := []byte("hello hash mismatch")
crc := crc32.ChecksumIEEE(original)
blocks := [][]byte{withMediaHeader(crc, original, protocol.MediaCompressionNone)}
channel := protocol.MediaChannelStart + 1
f.exchangeFn = blockMockExchange(f, channel, blocks)
_, err := f.FetchMedia(context.Background(), channel, 1, 0xDEADBEEF, nil)
if err != ErrMediaHashMismatch {
t.Fatalf("err = %v, want ErrMediaHashMismatch", err)
}
}
func TestFetchMediaBlocksStreamWritesInOrder(t *testing.T) {
f := newTestFetcher(t, []string{"1.1.1.1:53"})
blocks := [][]byte{
[]byte("alpha"),
[]byte("beta"),
[]byte("gamma"),
}
channel := protocol.MediaChannelStart + 12
f.exchangeFn = blockMockExchange(f, channel, blocks)
var got bytes.Buffer
if err := f.FetchMediaBlocksStream(context.Background(), channel, 0, 3, &got, nil); err != nil {
t.Fatalf("FetchMediaBlocksStream: %v", err)
}
want := append(append(append([]byte{}, blocks[0]...), blocks[1]...), blocks[2]...)
if !bytes.Equal(got.Bytes(), want) {
t.Fatalf("got %q, want %q", got.Bytes(), want)
}
}
func TestFetchMediaBlocksStreamPartialRange(t *testing.T) {
f := newTestFetcher(t, []string{"1.1.1.1:53"})
blocks := [][]byte{
[]byte("first-block"),
[]byte("second-block"),
[]byte("third-block"),
}
channel := protocol.MediaChannelStart + 13
f.exchangeFn = blockMockExchange(f, channel, blocks)
var got bytes.Buffer
if err := f.FetchMediaBlocksStream(context.Background(), channel, 1, 2, &got, nil); err != nil {
t.Fatalf("FetchMediaBlocksStream: %v", err)
}
want := append(append([]byte{}, blocks[1]...), blocks[2]...)
if !bytes.Equal(got.Bytes(), want) {
t.Fatalf("got %q, want %q", got.Bytes(), want)
}
}