feat: media download with DNS query

This commit is contained in:
Sarto
2026-04-29 01:45:27 +03:30
parent 11946c0147
commit b4e9cd8714
34 changed files with 6303 additions and 137 deletions
+78 -10
View File
@@ -238,27 +238,95 @@ func (c *Cache) Cleanup() error {
return nil
}
// detectGaps finds places in a sorted message list where consecutive IDs differ
// by more than 1. Gaps larger than 500 are ignored (natural Telegram numbering).
// Returns nil when there are fewer than 10 messages (not enough history to judge).
// detectGaps finds runs of missing IDs between consecutive messages. Album-
// merged canonicals cover a contiguous span of sibling IDs (counted via
// albumSpan), so absorbed siblings don't show up as fake gaps. Diffs > 500
// are ignored (natural Telegram numbering jumps); under 10 messages we don't
// have enough history to judge.
func detectGaps(msgs []protocol.Message) []Gap {
if len(msgs) < 10 {
return nil
}
var gaps []Gap
for i := 1; i < len(msgs); i++ {
prev, cur := msgs[i-1].ID, msgs[i].ID
if diff := cur - prev; diff > 1 && diff <= 500 {
gaps = append(gaps, Gap{
AfterID: prev,
BeforeID: cur,
Count: int(diff - 1),
})
prev, cur := msgs[i-1], msgs[i]
span := uint32(albumSpan(prev.Text))
if span == 0 {
span = 1
}
expectedNext := prev.ID + span
if cur.ID <= expectedNext {
continue
}
diff := cur.ID - expectedNext
if diff > 500 {
continue
}
gaps = append(gaps, Gap{
AfterID: expectedNext - 1,
BeforeID: cur.ID,
Count: int(diff),
})
}
return gaps
}
// mediaHeaderTags are the leading [TAG] markers extractMessages may stack
// at the start of a canonical message body — one per absorbed album item.
// isMediaHeaderLine checks a line against each entry in this list, so a tag
// missing here is never counted by albumSpan.
var mediaHeaderTags = []string{
protocol.MediaImage,
protocol.MediaVideo,
protocol.MediaFile,
protocol.MediaAudio,
protocol.MediaSticker,
protocol.MediaGIF,
protocol.MediaLocation,
protocol.MediaContact,
}
// albumSpan reports how many media-header lines open a canonical message
// body: 0 for plain text, 1 for a single media item, N for an N-item album.
// When the body starts with the [REPLY] marker, that first line is dropped
// before counting; a reply marker with nothing after it yields 0.
func albumSpan(text string) int {
	body := text
	if strings.HasPrefix(body, protocol.MediaReply) {
		end := strings.IndexByte(body, '\n')
		if end < 0 {
			return 0
		}
		body = body[end+1:]
	}

	count := 0
	for {
		// Peel off one line at a time; stop at the first non-header line.
		line := body
		nl := strings.IndexByte(body, '\n')
		if nl >= 0 {
			line = body[:nl]
		}
		if !isMediaHeaderLine(line) {
			return count
		}
		count++
		if nl < 0 {
			return count
		}
		body = body[nl+1:]
	}
}
// isMediaHeaderLine reports whether line is a media header: either a bare
// [TAG] from mediaHeaderTags, or a tag immediately followed by an ASCII digit
// (the downloadable "[TAG]<digit>..." form). A caption such as
// "[IMAGE]nice photo" is rejected because the byte after the tag is not a
// digit; a caption whose first character after the tag is a digit would
// still match.
func isMediaHeaderLine(line string) bool {
	for _, tag := range mediaHeaderTags {
		if !strings.HasPrefix(line, tag) {
			continue
		}
		suffix := line[len(tag):]
		if suffix == "" {
			return true
		}
		if first := suffix[0]; first >= '0' && first <= '9' {
			return true
		}
	}
	return false
}
// channelPath returns the file path for a channel's cache, keyed by sanitised name.
// Only letters, digits, hyphens, and underscores are kept; everything else becomes _.
func (c *Cache) channelPath(channelName string) string {
+105
View File
@@ -221,6 +221,111 @@ func TestCacheGapDetection_NoGapWhenFewMessages(t *testing.T) {
}
}
func TestCacheGapDetection_AlbumNoFalsePositive(t *testing.T) {
cache, _ := NewCache(t.TempDir())
// 10 sequential messages where ID=5 is a 2-image album: it absorbs ID 6
// (a real Telegram behaviour). The next message is ID 7. Without the
// album-aware fix, the gap detector would flag a missing ID 6.
msgs := []protocol.Message{
{ID: 1, Timestamp: 1700000000, Text: "a"},
{ID: 2, Timestamp: 1700000001, Text: "b"},
{ID: 3, Timestamp: 1700000002, Text: "c"},
{ID: 4, Timestamp: 1700000003, Text: "d"},
{ID: 5, Timestamp: 1700000004, Text: "[IMAGE]100:0:0:0:abcd1234:img1.jpg\n[IMAGE]200:0:0:0:abcd5678:img2.jpg\nalbum caption"},
// ID 6 is absorbed into the album above; the feed jumps to 7.
{ID: 7, Timestamp: 1700000005, Text: "e"},
{ID: 8, Timestamp: 1700000006, Text: "f"},
{ID: 9, Timestamp: 1700000007, Text: "g"},
{ID: 10, Timestamp: 1700000008, Text: "h"},
{ID: 11, Timestamp: 1700000009, Text: "i"},
}
result, _ := cache.MergeAndPut("albumchan", msgs)
if len(result.Gaps) != 0 {
t.Errorf("album-absorbed sibling should not be flagged as a gap, got %+v", result.Gaps)
}
}
func TestCacheGapDetection_AlbumWithRealGap(t *testing.T) {
cache, _ := NewCache(t.TempDir())
// 3-image album at ID=5 absorbs IDs 6,7. A real gap of IDs 8,9 follows
// before ID=10. The detector should report a single 2-message gap.
msgs := []protocol.Message{
{ID: 1, Timestamp: 1700000000, Text: "a"},
{ID: 2, Timestamp: 1700000001, Text: "b"},
{ID: 3, Timestamp: 1700000002, Text: "c"},
{ID: 4, Timestamp: 1700000003, Text: "d"},
{ID: 5, Timestamp: 1700000004, Text: "[IMAGE]100:0:0:0:aaaaaaaa:1.jpg\n[IMAGE]200:0:0:0:bbbbbbbb:2.jpg\n[IMAGE]300:0:0:0:cccccccc:3.jpg\ncap"},
// IDs 6,7 absorbed; IDs 8,9 truly missing; resume at 10.
{ID: 10, Timestamp: 1700000010, Text: "e"},
{ID: 11, Timestamp: 1700000011, Text: "f"},
{ID: 12, Timestamp: 1700000012, Text: "g"},
{ID: 13, Timestamp: 1700000013, Text: "h"},
{ID: 14, Timestamp: 1700000014, Text: "i"},
{ID: 15, Timestamp: 1700000015, Text: "j"},
}
result, _ := cache.MergeAndPut("albumgap", msgs)
if len(result.Gaps) != 1 {
t.Fatalf("expected exactly one gap, got %+v", result.Gaps)
}
g := result.Gaps[0]
if g.AfterID != 7 || g.BeforeID != 10 || g.Count != 2 {
t.Errorf("gap = %+v, want AfterID=7 BeforeID=10 Count=2", g)
}
}
func TestCacheGapDetection_AlbumWithReplyPrefix(t *testing.T) {
cache, _ := NewCache(t.TempDir())
// [REPLY]:42 prefix before the media headers should still let albumSpan
// count the headers correctly.
msgs := []protocol.Message{
{ID: 1, Timestamp: 1700000000, Text: "a"},
{ID: 2, Timestamp: 1700000001, Text: "b"},
{ID: 3, Timestamp: 1700000002, Text: "c"},
{ID: 4, Timestamp: 1700000003, Text: "d"},
{ID: 5, Timestamp: 1700000004, Text: "[REPLY]:42\n[IMAGE]100:0:0:0:aaaaaaaa:1.jpg\n[IMAGE]200:0:0:0:bbbbbbbb:2.jpg\nreplied caption"},
// ID 6 absorbed.
{ID: 7, Timestamp: 1700000005, Text: "e"},
{ID: 8, Timestamp: 1700000006, Text: "f"},
{ID: 9, Timestamp: 1700000007, Text: "g"},
{ID: 10, Timestamp: 1700000008, Text: "h"},
{ID: 11, Timestamp: 1700000009, Text: "i"},
}
result, _ := cache.MergeAndPut("replychan", msgs)
if len(result.Gaps) != 0 {
t.Errorf("album with reply prefix should not produce false gaps, got %+v", result.Gaps)
}
}
// TestAlbumSpan is a table test for albumSpan covering plain text, legacy and
// downloadable headers, mixed albums, reply prefixes, and a caption that only
// mentions a tag mid-line.
func TestAlbumSpan(t *testing.T) {
	type spanCase struct {
		name string
		text string
		want int
	}
	table := []spanCase{
		{"plain text", "hello world", 0},
		{"single image legacy", "[IMAGE]\ncaption", 1},
		{"single image downloadable", "[IMAGE]100:0:0:0:abcd1234:f.jpg\ncap", 1},
		{"two images", "[IMAGE]100:0:0:0:aa:1.jpg\n[IMAGE]200:0:0:0:bb:2.jpg\ncap", 2},
		{"three mixed", "[IMAGE]1:0:0:0:aa:a.jpg\n[VIDEO]2:0:0:0:bb:b.mp4\n[FILE]3:0:0:0:cc:c.pdf\nx", 3},
		{"with reply prefix", "[REPLY]:99\n[IMAGE]100:0:0:0:aa:1.jpg\n[IMAGE]200:0:0:0:bb:2.jpg\ncap", 2},
		{"reply only no media", "[REPLY]:99\nhello", 0},
		{"caption that mentions a tag", "look at this [IMAGE] thing", 0},
	}
	for i := range table {
		tc := table[i]
		t.Run(tc.name, func(t *testing.T) {
			got := albumSpan(tc.text)
			if got == tc.want {
				return
			}
			t.Errorf("albumSpan(%q) = %d, want %d", tc.text, got, tc.want)
		})
	}
}
func TestCacheGapDetection_LargeGapIgnored(t *testing.T) {
cache, _ := NewCache(t.TempDir())
+239
View File
@@ -0,0 +1,239 @@
package client
import (
"bytes"
"compress/flate"
"compress/gzip"
"context"
"fmt"
"hash/crc32"
"io"
"sync"
"github.com/sartoopjj/thefeed/internal/protocol"
)
// DecompressMediaReader wraps r in a reader that undoes the given
// compression. For MediaCompressionNone the reader is returned behind a
// no-op Closer. An unsupported compression value yields an error.
//
// On error the returned ReadCloser is a true nil interface. The gzip case
// is unpacked explicitly because returning gzip.NewReader's results
// directly would wrap a nil *gzip.Reader in a non-nil io.ReadCloser.
func DecompressMediaReader(r io.Reader, compression protocol.MediaCompression) (io.ReadCloser, error) {
	switch compression {
	case protocol.MediaCompressionNone:
		return io.NopCloser(r), nil
	case protocol.MediaCompressionGzip:
		zr, err := gzip.NewReader(r)
		if err != nil {
			return nil, err
		}
		return zr, nil
	case protocol.MediaCompressionDeflate:
		return flate.NewReader(r), nil
	default:
		return nil, fmt.Errorf("unsupported media compression: %d", compression)
	}
}
// decompressMediaBytes decompresses body according to compression and
// returns the full result in memory. The output is read with io.ReadAll,
// so no size limit is applied here.
func decompressMediaBytes(body []byte, compression protocol.MediaCompression) ([]byte, error) {
	src := bytes.NewReader(body)
	reader, err := DecompressMediaReader(src, compression)
	if err != nil {
		return nil, err
	}
	defer reader.Close()

	out, readErr := io.ReadAll(reader)
	return out, readErr
}
// MediaProgress reports per-block progress (completed of total). May be
// invoked from a background goroutine. completed counts blocks fetched so
// far, which is not necessarily the number of blocks already written in
// order.
type MediaProgress func(completed, total int)
// MediaBlockHeaderLen re-exports the protocol header length so callers in
// the web layer can use it without importing the protocol package
// themselves.
const MediaBlockHeaderLen = protocol.MediaBlockHeaderLen
// ErrMediaHashMismatch indicates the assembled bytes don't match the
// expected CRC32. The caller must discard the returned bytes.
var ErrMediaHashMismatch = fmt.Errorf("media content hash mismatch")
// mediaBlockOuterRetries is the per-block retry budget the media path adds
// on top of FetchBlock's own internal retries. A ~200-block file can lose
// individual blocks repeatedly; without this, one persistent bad block
// kills the whole download even though FetchBlock would succeed on a
// later attempt.
const mediaBlockOuterRetries = 5
// fetchMediaBlock fetches one block through FetchBlock, trying up to
// mediaBlockOuterRetries times. It returns the context's error as soon as
// ctx is done (checked before and after each attempt); otherwise, once the
// budget is spent, it returns the error of the last attempt.
func (f *Fetcher) fetchMediaBlock(ctx context.Context, channel, block uint16) ([]byte, error) {
	var failure error
	for tries := mediaBlockOuterRetries; tries > 0; tries-- {
		if cerr := ctx.Err(); cerr != nil {
			return nil, cerr
		}
		payload, err := f.FetchBlock(ctx, channel, block)
		if err == nil {
			return payload, nil
		}
		// A failure caused by cancellation is reported as the context error.
		if cerr := ctx.Err(); cerr != nil {
			return nil, cerr
		}
		failure = err
	}
	return nil, failure
}
// FetchMedia returns the assembled, decompressed bytes of a media blob
// served on a media channel.
//
// Blocks 0..blockCount-1 are fetched with at most five fetches in flight.
// Block 0 must start with the protocol media header, which selects the
// decompression method. When expectedCRC32 is non-zero it is checked against
// both the header's CRC32 and the CRC32 of the decompressed output;
// ErrMediaHashMismatch is returned on either mismatch. When expectedCRC32 is
// zero no hash check is performed.
//
// progress, if non-nil, is called from the calling goroutine after each
// block arrives. A blockCount of zero returns (nil, nil). The first block
// error aborts the download; the remaining fetches are cancelled.
func (f *Fetcher) FetchMedia(ctx context.Context, channel uint16, blockCount uint16, expectedCRC32 uint32, progress MediaProgress) ([]byte, error) {
	if !protocol.IsMediaChannel(channel) {
		return nil, fmt.Errorf("channel %d is outside media range", channel)
	}
	if blockCount == 0 {
		return nil, nil
	}

	// fetchCtx is cancelled on every return path. Without it, an early
	// return on the first failed block would leave the other workers
	// issuing DNS queries (including retries) whose results nobody reads.
	fetchCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	type blockResult struct {
		idx  int
		data []byte
		err  error
	}
	// Buffered to blockCount so workers never block on send, even after the
	// consumer below has returned.
	results := make(chan blockResult, blockCount)
	sem := make(chan struct{}, 5)
	var wg sync.WaitGroup
	for i := 0; i < int(blockCount); i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			select {
			case sem <- struct{}{}:
			case <-fetchCtx.Done():
				results <- blockResult{idx: idx, err: fetchCtx.Err()}
				return
			}
			defer func() { <-sem }()
			data, err := f.fetchMediaBlock(fetchCtx, channel, uint16(idx))
			results <- blockResult{idx: idx, data: data, err: err}
		}(i)
	}
	go func() {
		wg.Wait()
		close(results)
	}()

	ordered := make([][]byte, blockCount)
	completed := 0
	for r := range results {
		if r.err != nil {
			// Cancellation of the caller's context is passed through
			// unwrapped; anything else is annotated with the block.
			if r.err == ctx.Err() {
				return nil, r.err
			}
			return nil, fmt.Errorf("media channel %d block %d: %w", channel, r.idx, r.err)
		}
		ordered[r.idx] = r.data
		completed++
		if progress != nil {
			progress(completed, int(blockCount))
		}
	}

	if len(ordered[0]) < protocol.MediaBlockHeaderLen {
		return nil, fmt.Errorf("media channel %d: malformed block 0", channel)
	}
	header, err := protocol.DecodeMediaBlockHeader(ordered[0][:protocol.MediaBlockHeaderLen])
	if err != nil {
		return nil, fmt.Errorf("media channel %d: %w", channel, err)
	}
	// Cheap early rejection before spending time on decompression.
	if expectedCRC32 != 0 && header.CRC32 != expectedCRC32 {
		return nil, ErrMediaHashMismatch
	}

	// Concatenate all block bytes after the header.
	total := len(ordered[0]) - protocol.MediaBlockHeaderLen
	for i := 1; i < len(ordered); i++ {
		total += len(ordered[i])
	}
	body := make([]byte, 0, total)
	body = append(body, ordered[0][protocol.MediaBlockHeaderLen:]...)
	for i := 1; i < len(ordered); i++ {
		body = append(body, ordered[i]...)
	}

	// Decompress per the header.
	out, err := decompressMediaBytes(body, header.Compression)
	if err != nil {
		return nil, fmt.Errorf("decompress media channel %d: %w", channel, err)
	}
	if expectedCRC32 != 0 {
		if got := crc32.ChecksumIEEE(out); got != expectedCRC32 {
			return nil, ErrMediaHashMismatch
		}
	}
	return out, nil
}
// FetchMediaBlocksStream fetches blocks [startBlock, startBlock+count) and
// writes each block's raw bytes to w in order as soon as they become
// contiguous. No header parsing; callers slice off the protocol header
// themselves and decompress as appropriate.
//
// At most five fetches are in flight at once. If w has a Flush() method it
// is called after every block written. progress, if non-nil, is called from
// the calling goroutine after each block arrives (arrival order, not write
// order).
//
// The function returns an error when the channel is outside the media
// range, when the requested range does not fit the 16-bit block index, on
// the first block fetch failure, or on the first write error. On any return
// the remaining fetches are cancelled. Cancelling ctx aborts both in-flight
// DNS queries and pending writes.
func (f *Fetcher) FetchMediaBlocksStream(ctx context.Context, channel, startBlock, count uint16, w io.Writer, progress MediaProgress) error {
	if !protocol.IsMediaChannel(channel) {
		return fmt.Errorf("channel %d is outside media range", channel)
	}
	if count == 0 {
		return nil
	}
	// Block indices are uint16; a range running past 65535 would silently
	// wrap around and fetch the wrong blocks.
	if end := int(startBlock) + int(count); end > 1<<16 {
		return fmt.Errorf("media channel %d: block range [%d, %d) exceeds the 16-bit block index", channel, startBlock, end)
	}

	// fetchCtx is cancelled on every return path so that workers stop
	// issuing DNS queries once a fetch or write error ends the stream.
	fetchCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	type blockResult struct {
		idx  int
		data []byte
		err  error
	}
	// Buffered to count so workers never block on send after we return.
	results := make(chan blockResult, count)
	sem := make(chan struct{}, 5)
	var wg sync.WaitGroup
	for i := 0; i < int(count); i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			select {
			case sem <- struct{}{}:
			case <-fetchCtx.Done():
				results <- blockResult{idx: idx, err: fetchCtx.Err()}
				return
			}
			defer func() { <-sem }()
			data, err := f.fetchMediaBlock(fetchCtx, channel, uint16(int(startBlock)+idx))
			results <- blockResult{idx: idx, data: data, err: err}
		}(i)
	}
	go func() { wg.Wait(); close(results) }()

	// Resolve the optional Flush capability once instead of per block.
	flusher, canFlush := w.(interface{ Flush() })

	pending := make(map[int][]byte) // blocks that arrived ahead of their turn
	next := 0                       // index of the next block to write
	completed := 0
	for r := range results {
		if r.err != nil {
			if r.err == ctx.Err() {
				return r.err
			}
			return fmt.Errorf("media channel %d block %d: %w", channel, int(startBlock)+r.idx, r.err)
		}
		pending[r.idx] = r.data
		// Drain every block that is now contiguous with what was written.
		for {
			payload, ok := pending[next]
			if !ok {
				break
			}
			if _, werr := w.Write(payload); werr != nil {
				return werr
			}
			if canFlush {
				flusher.Flush()
			}
			next++
		}
		completed++
		if progress != nil {
			progress(completed, int(count))
		}
	}
	if next != int(count) {
		return fmt.Errorf("media channel %d: incomplete (%d / %d)", channel, next, count)
	}
	return nil
}
+207
View File
@@ -0,0 +1,207 @@
package client
import (
"bytes"
"compress/flate"
"compress/gzip"
"context"
"crypto/rand"
"hash/crc32"
"testing"
"time"
"github.com/miekg/dns"
"github.com/sartoopjj/thefeed/internal/protocol"
)
// withMediaHeader returns the protocol media block header followed by body.
// crc is stored in the header as given; the tests pass the CRC32 of the
// uncompressed content. body is the payload as the server would serve it,
// i.e. already compressed with the stated method (for compression=none,
// the content itself).
func withMediaHeader(crc uint32, body []byte, compression protocol.MediaCompression) []byte {
	header := protocol.EncodeMediaBlockHeader(protocol.MediaBlockHeader{
		CRC32:       crc,
		Version:     protocol.MediaHeaderVersion,
		Compression: compression,
	})
	framed := make([]byte, len(header)+len(body))
	n := copy(framed, header)
	copy(framed[n:], body)
	return framed
}
// gzipBytes returns b compressed as a gzip stream at gzip.BestCompression,
// failing the test if writing or closing the stream reports an error.
func gzipBytes(t *testing.T, b []byte) []byte {
	t.Helper()
	compressed := new(bytes.Buffer)
	// The constructor error is ignored here, as in the original helper.
	w, _ := gzip.NewWriterLevel(compressed, gzip.BestCompression)
	_, err := w.Write(b)
	if err != nil {
		t.Fatalf("gzip: %v", err)
	}
	if err = w.Close(); err != nil {
		t.Fatalf("gzip close: %v", err)
	}
	return compressed.Bytes()
}
// blockMockExchange builds an exchange function for a test Fetcher. The
// returned function decodes the queried (channel, block) pair and answers
// with blocks[block] encoded as a single TXT record. Queries for a channel
// other than want, or for a block index beyond blocks, fail with
// errFakeNotFound. A cancelled context is reported before anything else.
func blockMockExchange(f *Fetcher, want uint16, blocks [][]byte) func(context.Context, *dns.Msg, string) (*dns.Msg, time.Duration, error) {
	exchange := func(ctx context.Context, query *dns.Msg, _ string) (*dns.Msg, time.Duration, error) {
		if cerr := ctx.Err(); cerr != nil {
			return nil, 0, cerr
		}
		qname := query.Question[0].Name
		ch, blk, err := protocol.DecodeQuery(f.queryKey, qname, f.domain)
		if err != nil {
			return nil, 0, err
		}
		if ch != want || int(blk) >= len(blocks) {
			return nil, 0, errFakeNotFound{}
		}
		txt, err := protocol.EncodeResponse(f.responseKey, blocks[int(blk)], 0)
		if err != nil {
			return nil, 0, err
		}

		record := &dns.TXT{
			Hdr: dns.RR_Header{Name: qname, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0},
			Txt: []string{txt},
		}
		reply := new(dns.Msg)
		reply.SetReply(query)
		reply.Rcode = dns.RcodeSuccess
		reply.Answer = []dns.RR{record}
		return reply, time.Millisecond, nil
	}
	return exchange
}
// errFakeNotFound is the error the mock exchange returns for a channel or
// block it does not serve.
type errFakeNotFound struct{}

// Error implements the error interface with a fixed message.
func (e errFakeNotFound) Error() string {
	const msg = "fake nxdomain"
	return msg
}
// TestFetchMediaUncompressed round-trips 1500 random bytes through FetchMedia
// with compression=none and a matching CRC32.
func TestFetchMediaUncompressed(t *testing.T) {
	fetcher := newTestFetcher(t, []string{"1.1.1.1:53"})

	payload := make([]byte, 1500)
	if _, err := rand.Read(payload); err != nil {
		t.Fatalf("rand: %v", err)
	}
	sum := crc32.ChecksumIEEE(payload)
	framed := withMediaHeader(sum, payload, protocol.MediaCompressionNone)
	blocks := protocol.SplitIntoBlocks(framed)

	channel := protocol.MediaChannelStart + 7
	fetcher.exchangeFn = blockMockExchange(fetcher, channel, blocks)

	got, err := fetcher.FetchMedia(context.Background(), channel, uint16(len(blocks)), sum, nil)
	if err != nil {
		t.Fatalf("FetchMedia: %v", err)
	}
	if !bytes.Equal(got, payload) {
		t.Fatalf("decompressed output differs from original")
	}
}
// TestFetchMediaDeflate round-trips a compressible payload through FetchMedia
// with compression=deflate and a matching CRC32. Errors while building the
// deflate fixture fail the test immediately instead of being dropped, in
// line with the gzipBytes helper.
func TestFetchMediaDeflate(t *testing.T) {
	f := newTestFetcher(t, []string{"1.1.1.1:53"})
	original := bytes.Repeat([]byte("xy "), 250)
	crc := crc32.ChecksumIEEE(original)

	var buf bytes.Buffer
	zw, err := flate.NewWriter(&buf, flate.BestCompression)
	if err != nil {
		t.Fatalf("flate: %v", err)
	}
	if _, err := zw.Write(original); err != nil {
		t.Fatalf("flate write: %v", err)
	}
	if err := zw.Close(); err != nil {
		t.Fatalf("flate close: %v", err)
	}

	blocks := protocol.SplitIntoBlocks(withMediaHeader(crc, buf.Bytes(), protocol.MediaCompressionDeflate))
	channel := protocol.MediaChannelStart + 9
	f.exchangeFn = blockMockExchange(f, channel, blocks)

	out, err := f.FetchMedia(context.Background(), channel, uint16(len(blocks)), crc, nil)
	if err != nil {
		t.Fatalf("FetchMedia: %v", err)
	}
	if !bytes.Equal(out, original) {
		t.Fatalf("decompressed differs from original")
	}
}
// TestFetchMediaGzip round-trips a compressible payload through FetchMedia
// with compression=gzip, and also confirms the fixture really was smaller
// after compression.
func TestFetchMediaGzip(t *testing.T) {
	fetcher := newTestFetcher(t, []string{"1.1.1.1:53"})

	payload := bytes.Repeat([]byte("abc123 "), 200) // compressible
	sum := crc32.ChecksumIEEE(payload)
	packed := gzipBytes(t, payload)
	blocks := protocol.SplitIntoBlocks(withMediaHeader(sum, packed, protocol.MediaCompressionGzip))

	channel := protocol.MediaChannelStart + 8
	fetcher.exchangeFn = blockMockExchange(fetcher, channel, blocks)

	got, err := fetcher.FetchMedia(context.Background(), channel, uint16(len(blocks)), sum, nil)
	if err != nil {
		t.Fatalf("FetchMedia: %v", err)
	}
	if !bytes.Equal(got, payload) {
		t.Fatalf("decompressed output differs from original")
	}
	if packedLen, rawLen := len(packed), len(payload); packedLen >= rawLen {
		t.Fatalf("compressed body should be smaller than original (got %d vs %d)", packedLen, rawLen)
	}
}
// TestFetchMediaRejectsNonMediaChannel checks that FetchMedia refuses
// channel 1, which the test expects to be outside the media range.
func TestFetchMediaRejectsNonMediaChannel(t *testing.T) {
	fetcher := newTestFetcher(t, []string{"1.1.1.1:53"})
	if _, err := fetcher.FetchMedia(context.Background(), 1, 1, 0, nil); err == nil {
		t.Fatalf("expected error for non-media channel")
	}
}
// TestFetchMediaRejectsBadHash checks that FetchMedia returns
// ErrMediaHashMismatch when the caller's expected CRC32 differs from the one
// recorded for the content.
func TestFetchMediaRejectsBadHash(t *testing.T) {
	fetcher := newTestFetcher(t, []string{"1.1.1.1:53"})

	payload := []byte("hello hash mismatch")
	sum := crc32.ChecksumIEEE(payload)
	blocks := [][]byte{withMediaHeader(sum, payload, protocol.MediaCompressionNone)}

	channel := protocol.MediaChannelStart + 1
	fetcher.exchangeFn = blockMockExchange(fetcher, channel, blocks)

	if _, err := fetcher.FetchMedia(context.Background(), channel, 1, 0xDEADBEEF, nil); err != ErrMediaHashMismatch {
		t.Fatalf("err = %v, want ErrMediaHashMismatch", err)
	}
}
// TestFetchMediaBlocksStreamWritesInOrder checks that streaming all three
// blocks writes them to the sink in block order.
func TestFetchMediaBlocksStreamWritesInOrder(t *testing.T) {
	fetcher := newTestFetcher(t, []string{"1.1.1.1:53"})
	blocks := [][]byte{
		[]byte("alpha"),
		[]byte("beta"),
		[]byte("gamma"),
	}
	channel := protocol.MediaChannelStart + 12
	fetcher.exchangeFn = blockMockExchange(fetcher, channel, blocks)

	var sink bytes.Buffer
	err := fetcher.FetchMediaBlocksStream(context.Background(), channel, 0, 3, &sink, nil)
	if err != nil {
		t.Fatalf("FetchMediaBlocksStream: %v", err)
	}
	want := bytes.Join(blocks, nil)
	if got := sink.Bytes(); !bytes.Equal(got, want) {
		t.Fatalf("got %q, want %q", got, want)
	}
}
// TestFetchMediaBlocksStreamPartialRange checks that a stream starting at
// block 1 with count 2 yields exactly blocks 1 and 2, in order.
func TestFetchMediaBlocksStreamPartialRange(t *testing.T) {
	fetcher := newTestFetcher(t, []string{"1.1.1.1:53"})
	blocks := [][]byte{
		[]byte("first-block"),
		[]byte("second-block"),
		[]byte("third-block"),
	}
	channel := protocol.MediaChannelStart + 13
	fetcher.exchangeFn = blockMockExchange(fetcher, channel, blocks)

	var sink bytes.Buffer
	err := fetcher.FetchMediaBlocksStream(context.Background(), channel, 1, 2, &sink, nil)
	if err != nil {
		t.Fatalf("FetchMediaBlocksStream: %v", err)
	}
	want := bytes.Join(blocks[1:], nil)
	if got := sink.Bytes(); !bytes.Equal(got, want) {
		t.Fatalf("got %q, want %q", got, want)
	}
}
+229
View File
@@ -0,0 +1,229 @@
package protocol
import (
"encoding/hex"
"fmt"
"hash/fnv"
"strconv"
"strings"
)
// MediaMeta describes a downloadable media blob attached to a feed message.
//
// Wire format embedded in a message's text body (immediately after the media
// tag, before any caption):
//
// [IMAGE]<size>:<dl>:<ch>:<blk>:<crc32hex>[:<filename>]
// caption goes here on the next line(s)
//
// <size>, <ch> and <blk> are decimal, <dl> is 0 or 1, and <crc32hex> is
// hexadecimal (String writes it as eight lowercase hex digits).
//
// The filename field is optional; when present it carries an OS-friendly
// suggested filename (server-sanitised: no newlines, no path separators, no
// control characters, length-capped). Old clients that split on ':' and
// only read parts[0..4] keep working — they just ignore the trailing field.
type MediaMeta struct {
Tag string // e.g. MediaImage, MediaVideo, MediaFile
Size int64 // <size> field; presumably the file size in bytes — TODO confirm
Downloadable bool // <dl> field: true when encoded as 1
Channel uint16 // <ch> field: channel the blob is served on
Blocks uint16 // <blk> field: number of blocks
CRC32 uint32 // <crc32hex> field
Filename string // optional suggested filename; sanitised on encode and on parse
}
// String renders the metadata row in wire format: the tag, the five
// colon-separated fields, an optional sanitised filename, and the trailing
// newline that separates the row from any caption. The filename field is
// omitted entirely when sanitising leaves nothing.
func (m MediaMeta) String() string {
	var flag int
	if m.Downloadable {
		flag = 1
	}
	name := SanitiseMediaFilename(m.Filename)
	if name == "" {
		return fmt.Sprintf("%s%d:%d:%d:%d:%08x\n",
			m.Tag, m.Size, flag, m.Channel, m.Blocks, m.CRC32)
	}
	return fmt.Sprintf("%s%d:%d:%d:%d:%08x:%s\n",
		m.Tag, m.Size, flag, m.Channel, m.Blocks, m.CRC32, name)
}
// SanitiseMediaFilename returns a filename that is safe to embed in the wire
// metadata line, or "" when nothing usable remains.
//
// Steps: trim surrounding whitespace, keep only the part after the last '/'
// or '\', then drop every character outside [A-Za-z0-9._-]. A result of "",
// "." or ".." is rejected. The extension (text after the last dot) is cut to
// 8 bytes. A base longer than 24 bytes is replaced by "media-" plus 8 hex
// digits of an FNV-1a hash of the cleaned name, so the extension survives
// and other OSes still recognise the file type. An empty or "." base
// becomes "media".
func SanitiseMediaFilename(s string) string {
	const (
		maxBase = 24
		maxExt  = 8
	)

	name := strings.TrimSpace(s)
	if name == "" {
		return ""
	}
	if sep := strings.LastIndexAny(name, `/\`); sep >= 0 {
		name = name[sep+1:]
	}

	safe := filterFilenameRunes(name)
	switch safe {
	case "", ".", "..":
		return ""
	}

	stem, suffix := splitFilenameExt(safe)
	if len(suffix) > maxExt {
		suffix = suffix[:maxExt]
	}
	switch {
	case len(stem) > maxBase:
		digest := fnv.New64a()
		_, _ = digest.Write([]byte(safe))
		stem = "media-" + hex.EncodeToString(digest.Sum(nil))[:8]
	case stem == "", stem == ".":
		stem = "media"
	}

	if suffix == "" {
		return stem
	}
	return stem + "." + suffix
}
// filterFilenameRunes returns s with every rune removed that is not an ASCII
// letter, an ASCII digit, '.', '_' or '-'.
func filterFilenameRunes(s string) string {
	keep := func(r rune) rune {
		switch {
		case '0' <= r && r <= '9':
			return r
		case 'A' <= r && r <= 'Z':
			return r
		case 'a' <= r && r <= 'z':
			return r
		case r == '.' || r == '_' || r == '-':
			return r
		}
		return -1 // a negative result tells strings.Map to drop the rune
	}
	return strings.Map(keep, s)
}
// splitFilenameExt splits s at its last '.' into base and extension, without
// the dot. When s has no dot, or the last dot is the final character, the
// whole string is returned as base and ext is empty.
func splitFilenameExt(s string) (base, ext string) {
	dot := strings.LastIndexByte(s, '.')
	if dot < 0 || dot == len(s)-1 {
		return s, ""
	}
	return s[:dot], s[dot+1:]
}
// EncodeMediaText builds a message body from the metadata line followed by
// the caption. With an empty caption the metadata line's trailing newline is
// removed, so caption-less media does not end in a blank line.
func EncodeMediaText(meta MediaMeta, caption string) string {
	line := meta.String()
	if caption != "" {
		return line + caption
	}
	return strings.TrimSuffix(line, "\n")
}
// ParseMediaText parses a message body that begins with a known media tag.
//
// Results, by case:
//   - body does not start with a known tag: zero MediaMeta, body unchanged
//     as caption, ok=false.
//   - legacy "[TAG]\ncaption" form (only whitespace between the tag and the
//     first newline): MediaMeta{Tag: tag}, the text after that newline as
//     caption, ok=true. The caller can treat it as a non-downloadable
//     placeholder.
//   - metadata that does not parse (fewer than five ':'-separated fields, or
//     a field that is not a valid number): MediaMeta{Tag: tag} and everything
//     after the tag as caption, still with ok=true, so the message renders
//     as text. Note this case does NOT return ok=false.
//   - well-formed metadata: the populated MediaMeta and the text after the
//     metadata line as caption (which may be empty), ok=true. Downloadable
//     is forced to false when the channel is outside the media range or the
//     block count is zero.
func ParseMediaText(body string) (meta MediaMeta, caption string, ok bool) {
tag, rest, found := splitKnownMediaTag(body)
if !found {
return MediaMeta{}, body, false
}
meta.Tag = tag
// The bit between the tag and the first newline is the metadata payload.
nl := strings.IndexByte(rest, '\n')
var metaLine string
if nl < 0 {
metaLine = rest
caption = ""
} else {
metaLine = rest[:nl]
caption = rest[nl+1:]
}
metaLine = strings.TrimSpace(metaLine)
if metaLine == "" {
// Legacy [TAG]\ncaption — no per-file metadata. Treat as not-downloadable.
return MediaMeta{Tag: tag}, caption, true
}
parts := strings.Split(metaLine, ":")
if len(parts) < 5 {
// Looks like a caption line that happens to start with this tag (e.g.
// "[IMAGE]nice photo"). Don't claim a structured parse — return the
// whole `rest` as caption so the message still renders.
return MediaMeta{Tag: tag}, rest, true
}
size, err := strconv.ParseInt(parts[0], 10, 64)
if err != nil || size < 0 {
return MediaMeta{Tag: tag}, rest, true
}
dl, err := strconv.Atoi(parts[1])
if err != nil || (dl != 0 && dl != 1) {
return MediaMeta{Tag: tag}, rest, true
}
ch, err := strconv.ParseUint(parts[2], 10, 16)
if err != nil {
return MediaMeta{Tag: tag}, rest, true
}
blk, err := strconv.ParseUint(parts[3], 10, 16)
if err != nil {
return MediaMeta{Tag: tag}, rest, true
}
crc, err := strconv.ParseUint(parts[4], 16, 32)
if err != nil {
return MediaMeta{Tag: tag}, rest, true
}
// Reject any channel claimed inside a parseable metadata line that falls
// outside the reserved media range — that can only be a malformed message
// or a tampering attempt; refuse to surface it as downloadable.
channel := uint16(ch)
downloadable := dl == 1
if downloadable && (!IsMediaChannel(channel) || blk == 0) {
downloadable = false
}
meta.Size = size
meta.Downloadable = downloadable
meta.Channel = channel
meta.Blocks = uint16(blk)
meta.CRC32 = uint32(crc)
if len(parts) >= 6 {
// SanitiseMediaFilename strips the field separator, so we can't
// reach this point with a colon inside the filename. Take parts[5]
// directly and re-sanitise defensively.
meta.Filename = SanitiseMediaFilename(parts[5])
}
return meta, caption, true
}
// knownMediaTags are the message text prefixes that splitKnownMediaTag (and
// therefore ParseMediaText) accepts as the start of a media message. The
// first entry that prefixes the body wins; the list is kept in alphabetical
// order for clarity.
var knownMediaTags = []string{
MediaAudio,
MediaFile,
MediaGIF,
MediaImage,
MediaSticker,
MediaVideo,
}
// splitKnownMediaTag checks whether body starts with one of knownMediaTags.
// On a match it returns that tag, the text following it, and ok=true;
// otherwise it returns an empty tag, body unchanged, and ok=false.
func splitKnownMediaTag(body string) (tag, rest string, ok bool) {
	for i := range knownMediaTags {
		candidate := knownMediaTags[i]
		if !strings.HasPrefix(body, candidate) {
			continue
		}
		return candidate, strings.TrimPrefix(body, candidate), true
	}
	return "", body, false
}
+102
View File
@@ -0,0 +1,102 @@
package protocol
import (
	"encoding/binary"
	"fmt"
	"strings"
)
// MediaCompression names a compression method applied to a cached media
// file's bytes before they're split into DNS blocks. Its numeric value is
// what is stored in byte 5 of the media block header.
type MediaCompression byte
const (
MediaCompressionNone MediaCompression = 0 // content stored as-is
MediaCompressionGzip MediaCompression = 1 // gzip stream
MediaCompressionDeflate MediaCompression = 2 // raw DEFLATE stream
)
// MediaHeaderVersion is the current header version. Bumped when the layout
// changes incompatibly; until then, the reserved bytes carry future fields.
// DecodeMediaBlockHeader rejects any other version.
const MediaHeaderVersion uint8 = 1
// MediaBlockHeaderLen is the fixed length of the metadata prefix that the
// server prepends to a cached media file's bytes before splitting into
// blocks. Block 0 of every media channel begins with these bytes.
//
// Layout (big-endian where multi-byte):
// [0:4] CRC32(IEEE) of the DECOMPRESSED file content
// [4] header version (currently 1)
// [5] compression byte (MediaCompression*)
// [6:16] reserved (zero) — room for future protocol fields without
// bumping the version byte
const MediaBlockHeaderLen = 16
// MediaBlockHeader is the parsed form of a media-channel block-0 header.
// The reserved header bytes have no field here.
type MediaBlockHeader struct {
CRC32 uint32 // header bytes [0:4], big-endian
Version uint8 // header byte 4
Compression MediaCompression // header byte 5
}
// EncodeMediaBlockHeader serialises h into a new slice of exactly
// MediaBlockHeaderLen bytes. A zero Version is written as
// MediaHeaderVersion. The reserved bytes are left at zero.
func EncodeMediaBlockHeader(h MediaBlockHeader) []byte {
	version := h.Version
	if version == 0 {
		version = MediaHeaderVersion
	}
	out := make([]byte, MediaBlockHeaderLen)
	binary.BigEndian.PutUint32(out[:4], h.CRC32)
	out[4], out[5] = version, byte(h.Compression)
	return out
}
// DecodeMediaBlockHeader parses the first MediaBlockHeaderLen bytes of a
// media block. It returns an error, with a zero header, when b is shorter
// than the header, when the version is not MediaHeaderVersion, or when the
// compression byte is not one of the known methods. Reserved bytes are not
// inspected.
func DecodeMediaBlockHeader(b []byte) (MediaBlockHeader, error) {
	var none MediaBlockHeader
	if have := len(b); have < MediaBlockHeaderLen {
		return none, fmt.Errorf("media block header truncated: have %d bytes, need %d", have, MediaBlockHeaderLen)
	}

	version, method := b[4], MediaCompression(b[5])
	if version != MediaHeaderVersion {
		return none, fmt.Errorf("media block header version %d not supported (want %d)", version, MediaHeaderVersion)
	}
	known := method == MediaCompressionNone ||
		method == MediaCompressionGzip ||
		method == MediaCompressionDeflate
	if !known {
		return none, fmt.Errorf("media block header: unknown compression %d", method)
	}

	return MediaBlockHeader{
		CRC32:       binary.BigEndian.Uint32(b[:4]),
		Version:     version,
		Compression: method,
	}, nil
}
// ParseMediaCompressionName returns the MediaCompression matching one of
// "none", "gzip", "deflate" (case-insensitive). The empty string is treated
// as "none". Any other input yields an error that quotes the input as given.
// Used by the CLI flag to translate user input.
func ParseMediaCompressionName(s string) (MediaCompression, error) {
	// Fold case so that e.g. "GZIP" and "Gzip" are accepted, as documented.
	switch strings.ToLower(s) {
	case "", "none":
		return MediaCompressionNone, nil
	case "gzip":
		return MediaCompressionGzip, nil
	case "deflate":
		return MediaCompressionDeflate, nil
	}
	return 0, fmt.Errorf("unknown media compression %q", s)
}
// String returns the canonical lowercase name of the compression value
// ("none", "gzip" or "deflate"), or "unknown(N)" with the numeric value for
// anything else.
func (c MediaCompression) String() string {
	if c == MediaCompressionNone {
		return "none"
	}
	if c == MediaCompressionGzip {
		return "gzip"
	}
	if c == MediaCompressionDeflate {
		return "deflate"
	}
	return fmt.Sprintf("unknown(%d)", byte(c))
}
+78
View File
@@ -0,0 +1,78 @@
package protocol
import (
"bytes"
"testing"
)
// TestEncodeDecodeMediaBlockHeader round-trips headers for each compression
// method and checks the encoded length and the zeroed reserved bytes.
func TestEncodeDecodeMediaBlockHeader(t *testing.T) {
	headers := []MediaBlockHeader{
		{CRC32: 0x01020304, Version: MediaHeaderVersion, Compression: MediaCompressionNone},
		{CRC32: 0xdeadbeef, Version: MediaHeaderVersion, Compression: MediaCompressionGzip},
		{CRC32: 0, Version: MediaHeaderVersion, Compression: MediaCompressionDeflate},
	}
	zeroTail := make([]byte, MediaBlockHeaderLen-6)

	for i := range headers {
		want := headers[i]
		wire := EncodeMediaBlockHeader(want)
		if n := len(wire); n != MediaBlockHeaderLen {
			t.Fatalf("encoded length = %d, want %d", n, MediaBlockHeaderLen)
		}
		// Reserved bytes must be zero for forward compatibility.
		if reserved := wire[6:]; !bytes.Equal(reserved, zeroTail) {
			t.Fatalf("reserved bytes not zero: %x", reserved)
		}
		decoded, err := DecodeMediaBlockHeader(wire)
		if err != nil {
			t.Fatalf("Decode: %v", err)
		}
		if decoded != want {
			t.Fatalf("round-trip: got %+v, want %+v", decoded, want)
		}
	}
}
// TestDecodeMediaBlockHeaderRejectsBadVersion checks that a header whose
// version byte is not the supported one fails to decode.
func TestDecodeMediaBlockHeaderRejectsBadVersion(t *testing.T) {
	wire := EncodeMediaBlockHeader(MediaBlockHeader{CRC32: 1, Version: MediaHeaderVersion, Compression: MediaCompressionNone})
	wire[4] = 9 // bogus version
	if _, err := DecodeMediaBlockHeader(wire); err == nil {
		t.Fatal("expected error for unknown version")
	}
}
// TestDecodeMediaBlockHeaderRejectsBadCompression checks that a header with
// an unrecognised compression byte fails to decode.
func TestDecodeMediaBlockHeaderRejectsBadCompression(t *testing.T) {
	wire := EncodeMediaBlockHeader(MediaBlockHeader{Version: MediaHeaderVersion})
	wire[5] = 99
	if _, err := DecodeMediaBlockHeader(wire); err == nil {
		t.Fatal("expected error for unknown compression")
	}
}
func TestDecodeMediaBlockHeaderRejectsTruncated(t *testing.T) {
_, err := DecodeMediaBlockHeader(make([]byte, MediaBlockHeaderLen-1))
if err == nil {
t.Fatal("expected error for truncated header")
}
}
// TestParseMediaCompressionName checks every accepted name (including the
// empty string) and that an unknown name is rejected.
func TestParseMediaCompressionName(t *testing.T) {
	accepted := []struct {
		in   string
		want MediaCompression
	}{
		{"", MediaCompressionNone},
		{"none", MediaCompressionNone},
		{"gzip", MediaCompressionGzip},
		{"deflate", MediaCompressionDeflate},
	}
	for _, tc := range accepted {
		got, err := ParseMediaCompressionName(tc.in)
		switch {
		case err != nil:
			t.Errorf("ParseMediaCompressionName(%q): %v", tc.in, err)
		case got != tc.want:
			t.Errorf("ParseMediaCompressionName(%q) = %v, want %v", tc.in, got, tc.want)
		}
	}

	if _, err := ParseMediaCompressionName("brotli"); err == nil {
		t.Fatal("expected error for unknown compression name")
	}
}
+232
View File
@@ -0,0 +1,232 @@
package protocol
import (
"strings"
"testing"
)
// TestEncodeMediaTextRoundTrip encodes each metadata/caption pair with
// EncodeMediaText, parses the result back with ParseMediaText, and checks
// that every field survives. The filename is compared against its sanitised
// form.
func TestEncodeMediaTextRoundTrip(t *testing.T) {
	type roundTrip struct {
		name    string
		meta    MediaMeta
		caption string
	}

	withCaption := roundTrip{
		name: "image with caption",
		meta: MediaMeta{
			Tag:          MediaImage,
			Size:         123456,
			Downloadable: true,
			Channel:      12345,
			Blocks:       42,
			CRC32:        0xabcdef01,
		},
		caption: "hello world\nmulti-line",
	}
	withFilename := roundTrip{
		name: "file with filename",
		meta: MediaMeta{
			Tag:          MediaFile,
			Size:         800,
			Downloadable: true,
			Channel:      MediaChannelStart,
			Blocks:       2,
			CRC32:        0,
			Filename:     "report.zip",
		},
		caption: "",
	}
	traversal := roundTrip{
		name: "filename strips path traversal",
		meta: MediaMeta{
			Tag:          MediaFile,
			Size:         100,
			Downloadable: true,
			Channel:      MediaChannelStart + 1,
			Blocks:       1,
			CRC32:        0xdeadbeef,
			// Server-side sanitisation strips dirs, control chars, and ":"
			// before the metadata reaches the wire, so a parsed filename
			// never contains any of those.
			Filename: "/tmp/../etc/passwd:bad\nname",
		},
		caption: "",
	}
	notDownloadable := roundTrip{
		name: "non-downloadable image",
		meta: MediaMeta{
			Tag:          MediaImage,
			Size:         50_000_000,
			Downloadable: false,
			Channel:      0,
			Blocks:       0,
			CRC32:        0xdeadbeef,
		},
		caption: "too big",
	}

	for _, c := range []roundTrip{withCaption, withFilename, traversal, notDownloadable} {
		t.Run(c.name, func(t *testing.T) {
			wire := EncodeMediaText(c.meta, c.caption)
			got, gotCaption, ok := ParseMediaText(wire)
			if !ok {
				t.Fatalf("ParseMediaText returned ok=false for body %q", wire)
			}
			if gotCaption != c.caption {
				t.Fatalf("caption = %q, want %q", gotCaption, c.caption)
			}
			if got.Tag != c.meta.Tag {
				t.Fatalf("Tag = %q, want %q", got.Tag, c.meta.Tag)
			}
			if got.Size != c.meta.Size {
				t.Fatalf("Size = %d, want %d", got.Size, c.meta.Size)
			}
			if got.Downloadable != c.meta.Downloadable {
				t.Fatalf("Downloadable = %v, want %v", got.Downloadable, c.meta.Downloadable)
			}
			if got.Channel != c.meta.Channel {
				t.Fatalf("Channel = %d, want %d", got.Channel, c.meta.Channel)
			}
			if got.Blocks != c.meta.Blocks {
				t.Fatalf("Blocks = %d, want %d", got.Blocks, c.meta.Blocks)
			}
			if got.CRC32 != c.meta.CRC32 {
				t.Fatalf("CRC32 = %x, want %x", got.CRC32, c.meta.CRC32)
			}
			// The wire only ever carries the sanitised name.
			if expect := SanitiseMediaFilename(c.meta.Filename); got.Filename != expect {
				t.Fatalf("Filename = %q, want %q", got.Filename, expect)
			}
		})
	}
}
// TestSanitiseMediaFilename pins the sanitiser's output for a set of
// representative inputs: empty, plain, path-prefixed, dot-dot, names with
// forbidden characters, invalid UTF-8 and non-ASCII names.
func TestSanitiseMediaFilename(t *testing.T) {
	table := []struct{ in, want string }{
		{"", ""},
		{"report.zip", "report.zip"},
		{"path/to/report.zip", "report.zip"},
		{"..", ""},
		{"a:b\nc.txt", "abc.txt"},
		{"hello", "hello"},
		{"WeIrD-Name_v2.tar.gz", "WeIrD-Name_v2.tar.gz"},
		{"\xff\xfe.txt", "media.txt"},
		{"\u062d\u0645\u0644\u0647.zip", "media.zip"},
	}
	for _, row := range table {
		got := SanitiseMediaFilename(row.in)
		if got == row.want {
			continue
		}
		t.Errorf("SanitiseMediaFilename(%q) = %q, want %q", row.in, got, row.want)
	}
}
// TestSanitiseMediaFilenameLongName checks that an over-long name is replaced
// by a bounded "media-<hash>.<ext>" form, and that the replacement is stable
// across calls.
func TestSanitiseMediaFilenameLongName(t *testing.T) {
	// "media-" + 8 hash chars + "." + 3-char extension.
	const maxLen = 6 + 8 + 1 + 3

	input := strings.Repeat("abc", 50) + ".zip"
	first := SanitiseMediaFilename(input)

	shaped := strings.HasPrefix(first, "media-") && strings.HasSuffix(first, ".zip")
	if !shaped {
		t.Fatalf("long filename = %q, want media-<hash>.zip", first)
	}
	if len(first) > maxLen {
		t.Fatalf("long filename too long: %q", first)
	}
	second := SanitiseMediaFilename(input)
	if second != first {
		t.Fatalf("non-deterministic: %q vs %q", first, second)
	}
}
// TestParseMediaTextLegacy covers backward compatibility: a bare
// "[IMAGE]\ncaption" body with no metadata line must still parse, keep its
// caption, and report Downloadable=false.
func TestParseMediaTextLegacy(t *testing.T) {
	const wantCaption = "look at this"

	got, gotCaption, ok := ParseMediaText("[IMAGE]\n" + wantCaption)
	switch {
	case !ok:
		t.Fatalf("ParseMediaText ok=false on legacy body")
	case got.Tag != MediaImage:
		t.Fatalf("Tag = %q, want %q", got.Tag, MediaImage)
	case got.Downloadable:
		t.Fatalf("Downloadable should be false on legacy body")
	case gotCaption != wantCaption:
		t.Fatalf("caption = %q, want %q", gotCaption, "look at this")
	}
}
// TestParseMediaTextLegacyNoCaption covers backward compatibility for a
// legacy [IMAGE] body without a caption, with and without a trailing newline.
func TestParseMediaTextLegacyNoCaption(t *testing.T) {
	bodies := [...]string{"[IMAGE]", "[IMAGE]\n"}
	for i := range bodies {
		in := bodies[i]
		got, gotCaption, ok := ParseMediaText(in)
		switch {
		case !ok:
			t.Fatalf("ok=false on %q", in)
		case got.Tag != MediaImage:
			t.Fatalf("Tag = %q, want [IMAGE]", got.Tag)
		case got.Downloadable:
			t.Fatalf("legacy body should not be downloadable")
		case gotCaption != "":
			t.Fatalf("caption = %q, want empty", gotCaption)
		}
	}
}
// TestParseMediaTextHumanCaption ensures that ordinary text directly after a
// media tag is treated as caption, not misparsed as downloadable metadata.
func TestParseMediaTextHumanCaption(t *testing.T) {
	const expected = "nice picture\nrest of post"

	got, gotCaption, ok := ParseMediaText("[IMAGE]" + expected)
	if !ok {
		t.Fatalf("ok=false on caption-leading body")
	}
	if got.Downloadable {
		t.Fatalf("downloadable should be false for a human caption")
	}
	if ch := got.Channel; ch != 0 {
		t.Fatalf("channel should be 0 for non-metadata body, got %d", ch)
	}
	if gotCaption != expected {
		t.Fatalf("caption = %q, want %q", gotCaption, expected)
	}
}
// TestParseMediaTextUnknownTag checks that a body with no leading media tag
// yields ok=false.
func TestParseMediaTextUnknownTag(t *testing.T) {
	if _, _, recognised := ParseMediaText("not a tag"); recognised {
		t.Fatalf("ok=true for non-tag body")
	}
}
// TestParseMediaTextRejectsOutOfRangeChannel checks that a well-formed
// metadata line whose channel lies outside the media range still parses,
// but is NOT surfaced as downloadable.
func TestParseMediaTextRejectsOutOfRangeChannel(t *testing.T) {
	const wire = "[IMAGE]100:1:5:200:00000000\ncaption"

	got, _, ok := ParseMediaText(wire)
	switch {
	case !ok:
		t.Fatalf("ok=false on otherwise-valid metadata")
	case got.Downloadable:
		t.Fatalf("Downloadable should be false for channel %d outside media range", got.Channel)
	}
}
// TestIsMediaChannel probes both edges of the media channel range, a value
// inside it, and the extremes of the uint16 space.
func TestIsMediaChannel(t *testing.T) {
	table := []struct {
		ch   uint16
		want bool
	}{
		{0, false},
		{1, false},
		{MediaChannelStart - 1, false},
		{MediaChannelStart, true},
		{MediaChannelStart + 100, true},
		{MediaChannelEnd, true},
		{MediaChannelEnd + 1, false},
		{65535, false},
	}
	for _, row := range table {
		got := IsMediaChannel(row.ch)
		if got == row.want {
			continue
		}
		t.Errorf("IsMediaChannel(%d) = %v, want %v", row.ch, got, row.want)
	}
}
+22
View File
@@ -20,6 +20,12 @@ const (
// DefaultBlockPayload is kept for compatibility; equals MaxBlockPayload.
DefaultBlockPayload = MaxBlockPayload
// MediaBlockPayload is the fixed payload size used for media (image/file)
// blocks. Media blocks are raw binary, and using a fixed size simplifies
// both server-side block boundaries and client-side range/resume math.
// Tuned for safe DNS UDP response after AES-GCM + base64 + padding.
MediaBlockPayload = MaxBlockPayload
// DefaultMaxPadding is the default random padding added to responses to vary DNS response size.
DefaultMaxPadding = 32
@@ -29,6 +35,14 @@ const (
// MetadataChannel is the special channel number for server metadata.
MetadataChannel = 0
// MediaChannelStart and MediaChannelEnd bound the channel-number range
// reserved for cached binary media (images, files, ...). Each cached file
// occupies one channel; bytes are split into raw blocks served via the
// usual DNS TXT path. The range is well above typical feed channel counts
// and well below the special control channels at the top of uint16 space.
MediaChannelStart uint16 = 10000
MediaChannelEnd uint16 = 60000 // inclusive
// MarkerSize is the random marker in metadata to verify data freshness.
MarkerSize = 3
@@ -46,6 +60,14 @@ const (
MsgContentHashSize = 4
)
// IsMediaChannel reports whether ch lies in the inclusive range
// [MediaChannelStart, MediaChannelEnd] reserved for media blobs. Media
// channels are not enumerated in Metadata; the client learns each
// (channel, blocks, hash) tuple from the corresponding feed message text via
// [TAG]<size>:<dl>:<ch>:<blk>:<crc32hex>.
func IsMediaChannel(ch uint16) bool {
	if ch < MediaChannelStart {
		return false
	}
	return ch <= MediaChannelEnd
}
// Media placeholder strings for non-text content.
const (
MediaImage = "[IMAGE]"
+12
View File
@@ -58,6 +58,7 @@ type hourlyFetchReport struct {
totalQueries int64
metadataQueries int64
versionQueries int64
mediaQueries int64 // queries that landed in the media-blob channel range
perChannel map[uint16]*channelFetchStats
perResolver map[string]int64
}
@@ -696,6 +697,13 @@ func recordReportQuery(rep *hourlyFetchReport, event reportEvent) {
rep.versionQueries++
return
}
if protocol.IsMediaChannel(channel) {
// We don't fan out per-media-channel stats — the channel-id is just
// a transient slot, and 50K possible ids would explode the report.
// Total media-query volume is enough for the operator's purposes.
rep.mediaQueries++
return
}
stats := rep.perChannel[channel]
if stats == nil {
@@ -769,10 +777,14 @@ func (s *DNSServer) emitHourlyReport(rep *hourlyFetchReport, final bool) {
"totalDnsQueries": rep.totalQueries,
"totalMetadataQueries": rep.metadataQueries,
"totalVersionQueries": rep.versionQueries,
"totalMediaQueries": rep.mediaQueries,
"channels": entries,
"topResolvers": resolvers,
"finalFlush": final,
}
if mediaCache := s.feed.MediaCache(); mediaCache != nil {
payload["mediaCache"] = mediaCache.Stats()
}
b, err := json.Marshal(payload)
if err != nil {
log.Printf("[dns_hourly] marshal error: %v", err)
+33
View File
@@ -27,6 +27,13 @@ type Feed struct {
telegramLoggedIn bool
nextFetch uint32
latestVersion string
// media holds binary blobs (images, files, ...) on a separate set of
// channel numbers in the [MediaChannelStart, MediaChannelEnd] range. It
// may be nil when media downloads are disabled — Feed.GetBlock then
// rejects queries to media channels with a not-found error, mirroring
// pre-feature behaviour.
media *MediaCache
}
// NewFeed creates a new Feed with the given channel names.
@@ -88,6 +95,16 @@ func (f *Feed) GetBlock(channel, block int) ([]byte, error) {
if channel == int(protocol.TitlesChannel) {
return f.getTitlesBlock(block)
}
// Channel sits in the binary media range — delegate to MediaCache. We
// drop the read lock first because MediaCache uses its own lock and we
// don't want to hold f.mu across that path.
if channel >= 0 && channel <= 0xFFFF && protocol.IsMediaChannel(uint16(channel)) {
media := f.media
if media == nil {
return nil, fmt.Errorf("media channel %d not configured", channel)
}
return media.GetBlock(uint16(channel), uint16(block))
}
ch, ok := f.blocks[channel]
if !ok {
@@ -99,6 +116,22 @@ func (f *Feed) GetBlock(channel, block int) ([]byte, error) {
return ch[block], nil
}
// SetMediaCache installs c as the Feed's media cache under the write lock.
// Passing nil disables media serving (the default, for backward
// compatibility). Intended to be called once at startup, before any DNS
// query is served.
func (f *Feed) SetMediaCache(c *MediaCache) {
	f.mu.Lock()
	f.media = c
	f.mu.Unlock()
}
// MediaCache returns the cache installed by SetMediaCache, or nil when none
// has been set. The field is read under the Feed's read lock.
func (f *Feed) MediaCache() *MediaCache {
	f.mu.RLock()
	current := f.media
	f.mu.RUnlock()
	return current
}
func (f *Feed) getVersionBlock(block int) ([]byte, error) {
blocks := f.versionBlocks
if len(blocks) == 0 {
+524
View File
@@ -0,0 +1,524 @@
package server
import (
"bytes"
"compress/flate"
"compress/gzip"
"errors"
"fmt"
"hash/crc32"
"io"
"sync"
"sync/atomic"
"time"
"github.com/sartoopjj/thefeed/internal/protocol"
)
// MediaCache stores binary media blobs (images, files, ...) keyed by an
// upstream-stable identifier (Telegram file_id, image URL, ...). Each entry
// occupies one channel number drawn from the [MediaChannelStart, MediaChannelEnd]
// range, plus a precomputed list of blocks served via the regular DNS TXT
// path.
//
// The cache is safe for concurrent use. Lookups by key, channel or content
// hash are single map accesses. A Store that has to allocate a new slot also
// sweeps expired entries, which is linear in the number of cached entries.
type MediaCache struct {
	// Counters surfaced via Stats(); written with atomics so reads from the
	// hourly reporter don't have to acquire mu.
	//
	// They are deliberately the FIRST fields of the struct: sync/atomic
	// requires 64-bit alignment for 64-bit operations on 32-bit platforms,
	// and only the first word of an allocated struct is guaranteed to be
	// aligned. Keep any new non-atomic field below this group.
	storeHits      uint64
	storeMisses    uint64
	storeRejected  uint64 // file too large
	queryCount     uint64 // total media block queries served
	evictionCount  uint64
	currentEntries int64 // live entry count
	currentBytes   int64 // sum of file sizes currently cached

	// Immutable configuration, set once by NewMediaCache.
	maxFileBytes int64
	ttl          time.Duration
	compression  protocol.MediaCompression

	// logf receives an info line per cache event (Store hits/misses,
	// sweeps). NewMediaCache installs a silent no-op when none is configured
	// so tests don't print noise.
	logf func(format string, args ...interface{})

	// mu guards the three index maps, nextChannel, and the mutable fields of
	// every mediaEntry reachable from them.
	mu          sync.RWMutex
	byKey       map[string]*mediaEntry // upstream key (file_id / URL) → entry
	byChannel   map[uint16]*mediaEntry // assigned channel → entry
	byHash      map[uint32]*mediaEntry // CRC32(content) → entry, for cross-key dedup
	nextChannel uint16                 // round-robin allocation hint
}
// mediaEntry is one cached blob together with the bookkeeping needed to
// serve and expire it. The same *mediaEntry is shared by MediaCache.byKey
// (under its primary key and every alias), byChannel and byHash.
type mediaEntry struct {
	channel  uint16   // media channel number this entry is served on
	cacheKey string   // primary upstream id this entry was first stored under
	aliases  []string // additional keys (different upstream ids, same content)
	mimeType string   // as passed to Store; returned by LookupByChannel
	filename string   // sanitised via protocol.SanitiseMediaFilename in Store
	tag      string   // protocol media tag (MediaImage, MediaFile, ...)
	size     int64    // length of the original content in bytes
	crc32    uint32   // CRC32 (IEEE) of the original content
	blocks   [][]byte // encoded blocks as produced by splitMediaBlocks
	// expiresAt is refreshed by Store hits, Lookup and GetBlock.
	expiresAt time.Time
	// inflight prevents the eviction sweep from reaping an entry that is
	// currently being downloaded by a goroutine that hasn't installed it yet.
	// NOTE(review): nothing in the visible code sets this flag; confirm it
	// has a writer elsewhere.
	inflight bool
}
// MediaCacheConfig configures a new MediaCache. The zero value is usable: no
// size cap, entries never expire, and logging is discarded.
type MediaCacheConfig struct {
	// MaxFileBytes is the largest individual file the cache will accept.
	// Files larger than this are rejected by Store with ErrTooLarge.
	// Zero (or negative) disables the cap.
	MaxFileBytes int64
	// TTL is how long an entry stays cached after its last refresh.
	// Zero (or negative) means entries never expire.
	TTL time.Duration
	// Compression is the wire-format compression used for media blocks.
	// Defaults to MediaCompressionNone when zero.
	Compression protocol.MediaCompression
	// Logf receives info-level cache events. Optional; nil discards them.
	Logf func(format string, args ...interface{})
}
// ErrTooLarge is returned by Store when content exceeds MaxFileBytes. Store
// still returns a MediaMeta carrying the tag and size (Downloadable=false)
// alongside this error.
var ErrTooLarge = errors.New("media file exceeds configured max-size")
// ErrCacheFull is returned by Store when no media channel slot is available.
// The allocator evicts the entry closest to expiry before giving up, so this
// is only reached when every slot in the media range is occupied by an entry
// that is exempt from eviction (marked inflight).
var ErrCacheFull = errors.New("no free media channel slot")
// NewMediaCache builds an empty cache from cfg. A zero MaxFileBytes disables
// the size cap; a zero TTL means entries never expire (not recommended in
// production). When cfg.Logf is nil a no-op logger is installed so the rest
// of the cache can call logf unconditionally.
func NewMediaCache(cfg MediaCacheConfig) *MediaCache {
	cache := &MediaCache{
		maxFileBytes: cfg.MaxFileBytes,
		ttl:          cfg.TTL,
		compression:  cfg.Compression,
		logf:         cfg.Logf,
		byKey:        make(map[string]*mediaEntry),
		byChannel:    make(map[uint16]*mediaEntry),
		byHash:       make(map[uint32]*mediaEntry),
		nextChannel:  protocol.MediaChannelStart,
	}
	if cache.logf == nil {
		cache.logf = func(string, ...interface{}) {}
	}
	return cache
}
// Store inserts (or refreshes) a media blob into the cache and returns
// metadata that the caller can embed in a feed message.
//
// cacheKey is an upstream-stable identifier (e.g. Telegram file_id, image
// URL). When the same key is stored again with the same content, the
// existing entry's TTL is refreshed and the same channel/blocks are returned
// without re-encoding. Callers rely on this for the periodic re-fetch case.
//
// tag is the protocol media tag (MediaImage, MediaFile, ...) and defaults to
// MediaFile when empty; mimeType and filename are optional and stored for
// the HTTP layer to surface to the client. content is the raw file bytes;
// the caller may keep using the slice after the call, because the encoded
// blocks are built from a fresh buffer.
//
// Errors:
//   - an empty cacheKey is rejected;
//   - ErrTooLarge when content exceeds MaxFileBytes, or when the encoded
//     blob needs more blocks than the uint16 block counter can address. In
//     both cases the returned MediaMeta carries Tag and Size with
//     Downloadable=false;
//   - ErrCacheFull when no channel slot can be freed;
//   - any error from encoding/compressing the content.
//
// Encoding happens before any existing entry is dropped or evicted, so a
// failed Store leaves the cache unchanged.
func (c *MediaCache) Store(cacheKey, tag string, content []byte, mimeType, filename string) (protocol.MediaMeta, error) {
	// maxBlocks is the largest block count MediaMeta.Blocks (a uint16) can
	// carry. GetBlock also addresses blocks with a uint16 index.
	const maxBlocks = 1<<16 - 1

	if cacheKey == "" {
		return protocol.MediaMeta{}, errors.New("media: empty cache key")
	}
	if tag == "" {
		tag = protocol.MediaFile
	}
	size := int64(len(content))
	rejected := protocol.MediaMeta{
		Tag:          tag,
		Size:         size,
		Downloadable: false,
	}
	if c.maxFileBytes > 0 && size > c.maxFileBytes {
		atomic.AddUint64(&c.storeRejected, 1)
		return rejected, ErrTooLarge
	}

	now := time.Now()
	hash := crc32.ChecksumIEEE(content)

	c.mu.Lock()
	defer c.mu.Unlock()

	// Same upstream id and same content: just refresh the TTL. The size is
	// compared as well because CRC32 alone is only 32 bits wide; a collision
	// must not alias two different files.
	if existing, ok := c.byKey[cacheKey]; ok && existing.crc32 == hash && existing.size == size {
		existing.expiresAt = c.expiry(now)
		atomic.AddUint64(&c.storeHits, 1)
		c.logf("media: refresh tag=%s key=%s ch=%d size=%d", tag, cacheKey, existing.channel, existing.size)
		return c.metaForLocked(existing), nil
	}

	// Cross-key content match: a different upstream id pointed at exactly
	// the same bytes. Bind the new cache key to the existing entry so any
	// future Lookup under either key works, and refresh the TTL instead of
	// taking a new channel slot.
	if existing, ok := c.byHash[hash]; ok && existing.size == size {
		existing.expiresAt = c.expiry(now)
		if cacheKey != existing.cacheKey {
			alreadyAliased := false
			for _, a := range existing.aliases {
				if a == cacheKey {
					alreadyAliased = true
					break
				}
			}
			if !alreadyAliased {
				existing.aliases = append(existing.aliases, cacheKey)
			}
		}
		c.byKey[cacheKey] = existing
		atomic.AddUint64(&c.storeHits, 1)
		c.logf("media: dedup tag=%s key=%s ch=%d size=%d (hash match)", tag, cacheKey, existing.channel, existing.size)
		return c.metaForLocked(existing), nil
	}

	// Either a new key, or the same key carries different bytes (a Telegram
	// edit, a re-upload). Encode first: if this fails, or the result cannot
	// be addressed, nothing has been evicted yet.
	blocks, err := splitMediaBlocks(hash, content, c.compression)
	if err != nil {
		return protocol.MediaMeta{}, err
	}
	if len(blocks) > maxBlocks {
		// The block count would be truncated in the metadata and the tail
		// of the file would be unreachable, so treat it as too large.
		atomic.AddUint64(&c.storeRejected, 1)
		return rejected, ErrTooLarge
	}
	if size > 0 {
		compressedBody := -protocol.MediaBlockHeaderLen
		for _, b := range blocks {
			compressedBody += len(b)
		}
		if compressedBody < 0 {
			compressedBody = 0
		}
		var savedPct int
		if c.compression != protocol.MediaCompressionNone {
			savedPct = int((size - int64(compressedBody)) * 100 / size)
		}
		c.logf("media: compress=%s key=%s orig=%d body=%d saved=%d%%", c.compression, cacheKey, size, compressedBody, savedPct)
	}

	// Replace whatever this key pointed at before.
	if existing, ok := c.byKey[cacheKey]; ok {
		c.dropEntryLocked(existing)
	}

	// Opportunistic sweep before we allocate. Without this, expired entries
	// that don't sit on the allocator's linear-scan path (i.e. ones below
	// nextChannel) accumulate until the periodic sweep runs. The cost is
	// O(n) over active entries; n is capped by the media-channel range.
	c.sweepExpiredLocked(now)

	channel, err := c.allocateChannelLocked(now)
	if err != nil {
		return protocol.MediaMeta{}, err
	}

	entry := &mediaEntry{
		channel:   channel,
		cacheKey:  cacheKey,
		mimeType:  mimeType,
		filename:  protocol.SanitiseMediaFilename(filename),
		tag:       tag,
		size:      size,
		crc32:     hash,
		blocks:    blocks,
		expiresAt: c.expiry(now),
	}
	c.byKey[cacheKey] = entry
	c.byChannel[channel] = entry
	c.byHash[hash] = entry
	atomic.AddUint64(&c.storeMisses, 1)
	atomic.AddInt64(&c.currentEntries, 1)
	atomic.AddInt64(&c.currentBytes, size)
	c.logf("media: store tag=%s key=%s ch=%d size=%d blocks=%d", tag, cacheKey, channel, size, len(blocks))
	return c.metaForLocked(entry), nil
}
// LookupByChannel reports the MIME type and filename recorded for the entry
// served on channel, or ok=false when no entry is mapped there. It does not
// refresh the entry's TTL. Used by the HTTP layer to pick a sensible
// Content-Type/Content-Disposition for clients that didn't provide one in
// the query string.
func (c *MediaCache) LookupByChannel(channel uint16) (mime, filename string, ok bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()

	if hit, present := c.byChannel[channel]; present {
		return hit.mimeType, hit.filename, true
	}
	return "", "", false
}
// Lookup returns the metadata of the entry bound to cacheKey (primary key or
// alias) and refreshes its TTL. It reports ok=false when the key is unknown.
func (c *MediaCache) Lookup(cacheKey string) (protocol.MediaMeta, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if hit, found := c.byKey[cacheKey]; found {
		hit.expiresAt = c.expiry(time.Now())
		return c.metaForLocked(hit), true
	}
	return protocol.MediaMeta{}, false
}
// GetBlock returns one block of cached media for serving over DNS. It
// returns an error if channel is outside the media range, no entry is mapped
// to it, or the block index is out of range. Every call for a channel inside
// the media range increments the served-query counter, including calls that
// end in an error.
//
// Expiry is not checked here: an entry that has passed its deadline but has
// not been swept yet is still served, and serving it pushes its deadline out
// again.
func (c *MediaCache) GetBlock(channel, block uint16) ([]byte, error) {
	if !protocol.IsMediaChannel(channel) {
		return nil, fmt.Errorf("channel %d is outside media range", channel)
	}
	atomic.AddUint64(&c.queryCount, 1)

	c.mu.RLock()
	hit, found := c.byChannel[channel]
	c.mu.RUnlock()

	switch {
	case !found:
		return nil, fmt.Errorf("media channel %d not found", channel)
	case int(block) >= len(hit.blocks):
		return nil, fmt.Errorf("media block %d out of range (%d blocks)", block, len(hit.blocks))
	}

	// Reading a block extends the entry lifetime: clients in the middle of
	// downloading shouldn't have the cache rug pulled mid-transfer.
	c.mu.Lock()
	hit.expiresAt = c.expiry(time.Now())
	c.mu.Unlock()

	return hit.blocks[block], nil
}
// Sweep drops every entry whose TTL has elapsed and returns how many were
// evicted. It is a no-op when the cache has no TTL. Safe to call from a
// periodic goroutine.
func (c *MediaCache) Sweep() int {
	if c.ttl <= 0 {
		return 0
	}
	cutoff := time.Now()

	c.mu.Lock()
	defer c.mu.Unlock()

	evicted := c.sweepExpiredLocked(cutoff)
	if evicted == 0 {
		return 0
	}
	c.logf("media: sweep evicted=%d remaining=%d", evicted, len(c.byChannel))
	return evicted
}
// sweepExpiredLocked drops every non-inflight entry whose deadline lies
// before now and returns how many were dropped. It backs both the periodic
// Sweep and the opportunistic sweep in Store. Caller must hold c.mu for
// writing.
func (c *MediaCache) sweepExpiredLocked(now time.Time) int {
	if c.ttl <= 0 {
		return 0
	}
	dropped := 0
	// Deleting the current key while ranging over a map is well-defined in
	// Go; dropEntryLocked only removes this entry's own channel from
	// byChannel.
	for _, candidate := range c.byChannel {
		if candidate.inflight || !now.After(candidate.expiresAt) {
			continue
		}
		c.dropEntryLocked(candidate)
		dropped++
	}
	return dropped
}
// MediaCacheStats is a snapshot of cache counters, as produced by
// MediaCache.Stats. It is JSON-serialisable using the field tags below.
type MediaCacheStats struct {
	Entries       int64  `json:"entries"`       // live entry count
	Bytes         int64  `json:"bytes"`         // sum of original file sizes currently cached
	Queries       uint64 `json:"queries"`       // block queries counted by GetBlock
	StoreHits     uint64 `json:"storeHits"`     // Store calls answered by an existing entry
	StoreMisses   uint64 `json:"storeMisses"`   // Store calls that created a new entry
	StoreRejected uint64 `json:"storeRejected"` // Store calls rejected as too large
	Evictions     uint64 `json:"evictions"`     // entries dropped (expiry, replacement, pressure)
	MaxFileBytes  int64  `json:"maxFileBytes"`  // configured per-file size cap
	TTLSeconds    int64  `json:"ttlSeconds"`    // configured TTL, in whole seconds
}
// Stats returns a snapshot of the cache counters without taking c.mu: each
// counter is loaded atomically, and the two configuration values are
// immutable after construction. The counters are read one by one, so the
// snapshot is not a single consistent point in time.
func (c *MediaCache) Stats() MediaCacheStats {
	var snap MediaCacheStats
	snap.Entries = atomic.LoadInt64(&c.currentEntries)
	snap.Bytes = atomic.LoadInt64(&c.currentBytes)
	snap.Queries = atomic.LoadUint64(&c.queryCount)
	snap.StoreHits = atomic.LoadUint64(&c.storeHits)
	snap.StoreMisses = atomic.LoadUint64(&c.storeMisses)
	snap.StoreRejected = atomic.LoadUint64(&c.storeRejected)
	snap.Evictions = atomic.LoadUint64(&c.evictionCount)
	snap.MaxFileBytes = c.maxFileBytes
	snap.TTLSeconds = int64(c.ttl / time.Second)
	return snap
}
// allocateChannelLocked picks a channel number for a new entry. Starting at
// the round-robin hint, it probes every slot in the media range once and
// takes the first that is free or holds an expired, non-inflight entry
// (which is dropped). If the whole range is occupied by live entries, it
// evicts the non-inflight entry with the earliest deadline. ErrCacheFull is
// returned only when there is no such entry. Caller must hold c.mu for
// writing.
func (c *MediaCache) allocateChannelLocked(now time.Time) (uint16, error) {
	first, last := protocol.MediaChannelStart, protocol.MediaChannelEnd
	slots := int(last) - int(first) + 1

	candidate := c.nextChannel
	if candidate < first || candidate > last {
		candidate = first
	}

	for probed := 0; probed < slots; probed++ {
		occupant, inUse := c.byChannel[candidate]
		switch {
		case !inUse:
			c.advanceNextLocked(candidate)
			return candidate, nil
		case !occupant.inflight && c.ttl > 0 && now.After(occupant.expiresAt):
			c.dropEntryLocked(occupant)
			c.advanceNextLocked(candidate)
			return candidate, nil
		}
		// Move on to the next slot, wrapping at the end of the range.
		if candidate == last {
			candidate = first
		} else {
			candidate++
		}
	}

	// Range fully occupied with non-expired entries: evict the one closest
	// to expiry as a last resort, so the cache never hard-fails under
	// steady-state pressure with reasonable configs.
	var victim *mediaEntry
	for _, e := range c.byChannel {
		if e.inflight {
			continue
		}
		if victim == nil || e.expiresAt.Before(victim.expiresAt) {
			victim = e
		}
	}
	if victim == nil {
		return 0, ErrCacheFull
	}
	slot := victim.channel
	c.dropEntryLocked(victim)
	c.advanceNextLocked(slot)
	return slot, nil
}
// advanceNextLocked moves the round-robin allocation hint to the slot after
// used, wrapping back to MediaChannelStart past the end of the range. Caller
// must hold c.mu for writing.
func (c *MediaCache) advanceNextLocked(used uint16) {
	following := used + 1
	if used == protocol.MediaChannelEnd {
		following = protocol.MediaChannelStart
	}
	c.nextChannel = following
}
// dropEntryLocked removes entry from every index and updates the live
// entry/byte counters and the eviction counter. Caller must hold c.mu for
// writing.
//
// Key and hash mappings are only removed while they still resolve to this
// entry. Store's dedup path can rebind a key (primary or alias) to a
// different entry without touching the old one, so deleting unconditionally
// would unmap the live entry.
func (c *MediaCache) dropEntryLocked(entry *mediaEntry) {
	delete(c.byChannel, entry.channel)
	if c.byKey[entry.cacheKey] == entry {
		delete(c.byKey, entry.cacheKey)
	}
	for _, alias := range entry.aliases {
		if c.byKey[alias] == entry {
			delete(c.byKey, alias)
		}
	}
	if c.byHash[entry.crc32] == entry {
		delete(c.byHash, entry.crc32)
	}
	atomic.AddInt64(&c.currentEntries, -1)
	atomic.AddInt64(&c.currentBytes, -entry.size)
	atomic.AddUint64(&c.evictionCount, 1)
}
// expiry returns the deadline for an entry refreshed at now: now+TTL when a
// TTL is configured, otherwise a fixed far-future instant so that "never
// expires" still behaves sensibly in Before/After comparisons.
func (c *MediaCache) expiry(now time.Time) time.Time {
	if c.ttl > 0 {
		return now.Add(c.ttl)
	}
	return time.Unix(1<<62, 0)
}
// metaForLocked builds the wire metadata for a cached entry. Entries handed
// to it are always reported as downloadable. Caller must hold c.mu.
func (c *MediaCache) metaForLocked(entry *mediaEntry) protocol.MediaMeta {
	var meta protocol.MediaMeta
	meta.Tag = entry.tag
	meta.Size = entry.size
	meta.Downloadable = true
	meta.Channel = entry.channel
	meta.Blocks = uint16(len(entry.blocks))
	meta.CRC32 = entry.crc32
	meta.Filename = entry.filename
	return meta
}
// splitMediaBlocks turns raw content into the blocks served over DNS: the
// content is compressed (when compression != none), the protocol media
// header is prepended, and the result is handed to protocol.SplitIntoBlocks.
// The CRC32 carried in the header is the one supplied by the caller, which
// is computed over the DECOMPRESSED bytes so the client can verify integrity
// after decompression. Uniform sizing is avoided to match the anti-DPI
// strategy used for feed-message blocks.
func splitMediaBlocks(crc32Hash uint32, content []byte, compression protocol.MediaCompression) ([][]byte, error) {
	payload, err := compressMediaBytes(content, compression)
	if err != nil {
		return nil, err
	}

	hdr := protocol.EncodeMediaBlockHeader(protocol.MediaBlockHeader{
		CRC32:       crc32Hash,
		Version:     protocol.MediaHeaderVersion,
		Compression: compression,
	})

	// Assemble header+payload in a fresh buffer; the caller's slice is never
	// aliased by the result.
	framed := make([]byte, len(hdr)+len(payload))
	n := copy(framed, hdr)
	copy(framed[n:], payload)

	return protocol.SplitIntoBlocks(framed), nil
}
// compressMediaBytes encodes content with the requested compression at the
// codec's best-compression level. MediaCompressionNone returns content
// itself (no copy). An unrecognised compression value is an error.
func compressMediaBytes(content []byte, compression protocol.MediaCompression) ([]byte, error) {
	if compression == protocol.MediaCompressionNone {
		return content, nil
	}

	var (
		out bytes.Buffer
		enc io.WriteCloser
		err error
	)
	switch compression {
	case protocol.MediaCompressionGzip:
		enc, err = gzip.NewWriterLevel(&out, gzip.BestCompression)
	case protocol.MediaCompressionDeflate:
		enc, err = flate.NewWriter(&out, flate.BestCompression)
	default:
		return nil, fmt.Errorf("unsupported media compression: %d", compression)
	}
	// Check err before touching enc: on failure enc wraps a nil pointer.
	if err != nil {
		return nil, err
	}

	if _, err := enc.Write(content); err != nil {
		enc.Close()
		return nil, err
	}
	// Close flushes the trailer; its error means the output is incomplete.
	if err := enc.Close(); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// DecompressMediaBytes is the inverse of compressMediaBytes: it wraps r in a
// reader that yields the decompressed bytes. It is exposed for the HTTP
// layer (which receives a stream of compressed bytes after the header is
// stripped) and for tests.
//
// For MediaCompressionNone, r is returned behind a no-op Closer. For gzip,
// the stream header is read and validated immediately, so an error can be
// returned here rather than on first Read. An unrecognised compression value
// is an error. On any error the returned ReadCloser is a true nil interface,
// so callers may safely test it against nil.
func DecompressMediaBytes(r io.Reader, compression protocol.MediaCompression) (io.ReadCloser, error) {
	switch compression {
	case protocol.MediaCompressionNone:
		return io.NopCloser(r), nil
	case protocol.MediaCompressionGzip:
		// Don't `return gzip.NewReader(r)` directly: on failure that would
		// wrap a nil *gzip.Reader in a non-nil io.ReadCloser.
		zr, err := gzip.NewReader(r)
		if err != nil {
			return nil, err
		}
		return zr, nil
	case protocol.MediaCompressionDeflate:
		return flate.NewReader(r), nil
	}
	return nil, fmt.Errorf("unsupported media compression: %d", compression)
}
+148
View File
@@ -0,0 +1,148 @@
package server
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"path"
"strings"
"time"
"github.com/sartoopjj/thefeed/internal/protocol"
)
// httpMediaClient is the shared HTTP client used to fetch media URLs that
// the public-Telegram and X readers extract. The whole request is capped at
// 60 seconds so a slow media download cannot stall the rest of a fetch
// cycle indefinitely.
//
// Redirects are followed (Telegram CDN sometimes bounces through 301/302 to
// a regional host), but a chain is abandoned once more than five requests
// have already been made, and any hop to a non-http(s) scheme is refused.
var httpMediaClient = &http.Client{
	Timeout: 60 * time.Second,
	CheckRedirect: func(req *http.Request, via []*http.Request) error {
		switch scheme := req.URL.Scheme; {
		case len(via) > 5:
			return errors.New("too many redirects")
		case scheme != "http" && scheme != "https":
			return fmt.Errorf("disallowed redirect scheme %q", scheme)
		}
		return nil
	},
}
// allowedMediaSchemes lists the URL schemes downloadHTTPMedia is willing to
// fetch. Any scheme absent from the map is rejected.
var allowedMediaSchemes = map[string]bool{"http": true, "https": true}
// downloadHTTPMedia fetches the bytes at rawURL and stores them in cache
// under a key derived from the tag and the URL, so re-fetching the same URL
// just refreshes the existing entry's TTL via cache.Lookup.
//
// The boolean result reports whether usable metadata is returned. It is
// false when cache is nil, the URL is empty/unparsable/has a disallowed
// scheme, the request fails, the status is not 200, the response is HTML, or
// the body cannot be read or stored. An oversized file is NOT a failure: it
// yields ok=true with Downloadable=false and the known size.
//
// The configured max size is enforced both up-front (Content-Length) and on
// the wire (LimitReader), so a server lying about size can't blow past the
// limit. Private-network targets are not blocked here because callers
// (PublicReader, XPublicReader) only pass URLs scraped from Telegram/Nitter
// responses.
func downloadHTTPMedia(ctx context.Context, cache *MediaCache, tag, rawURL string) (protocol.MediaMeta, bool) {
	none := protocol.MediaMeta{}
	if cache == nil || rawURL == "" {
		return none, false
	}
	target, parseErr := url.Parse(rawURL)
	if parseErr != nil || !allowedMediaSchemes[target.Scheme] {
		return none, false
	}
	canonical := target.String()

	// Cache key is the canonical URL: image-link rotation on the upstream
	// side will create a fresh entry, but identical URLs across fetches will
	// just refresh TTL.
	cacheKey := tag + ":url:" + canonical
	if cached, hit := cache.Lookup(cacheKey); hit {
		return cached, true
	}

	req, reqErr := http.NewRequestWithContext(ctx, http.MethodGet, canonical, nil)
	if reqErr != nil {
		return none, false
	}
	req.Header.Set("User-Agent", "Mozilla/5.0 (compatible; thefeed/1.0)")
	req.Header.Set("Accept", "image/*, application/octet-stream;q=0.9, */*;q=0.5")

	resp, doErr := httpMediaClient.Do(req)
	if doErr != nil {
		logfMedia("[media-http] %s: request failed: %v", canonical, doErr)
		return none, false
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return none, false
	}

	// Defense in depth: reject HTML/XHTML responses outright. Telegram's
	// public web view sometimes redirects "file" links to the channel page
	// itself; without this check we'd happily cache the channel's HTML as
	// the user's downloadable file.
	contentType := resp.Header.Get("Content-Type")
	mediaType := contentType
	if semi := strings.IndexByte(mediaType, ';'); semi >= 0 {
		mediaType = mediaType[:semi] // drop parameters such as charset
	}
	mediaType = strings.ToLower(strings.TrimSpace(mediaType))
	switch mediaType {
	case "text/html", "application/xhtml+xml":
		logfMedia("[media-http] %s: refusing HTML response (got %s)", canonical, mediaType)
		return none, false
	}

	// oversize builds the "known but not downloadable" result.
	oversize := func(n int64) (protocol.MediaMeta, bool) {
		return protocol.MediaMeta{
			Tag:          tag,
			Size:         n,
			Downloadable: false,
		}, true
	}

	sizeCap := cache.maxFileBytes
	capped := sizeCap > 0
	if capped && resp.ContentLength > sizeCap {
		return oversize(resp.ContentLength)
	}

	var src io.Reader = resp.Body
	if capped {
		// Read one byte past the cap to tell "exactly at the cap" from
		// "over the cap". The > 0 guard skips the limiter if +1 overflowed.
		if readCap := sizeCap + 1; readCap > 0 {
			src = io.LimitReader(resp.Body, readCap)
		}
	}
	payload, readErr := io.ReadAll(src)
	if readErr != nil {
		logfMedia("[media-http] %s: read failed: %v", canonical, readErr)
		return none, false
	}
	if got := int64(len(payload)); capped && got > sizeCap {
		return oversize(got)
	}

	stored, storeErr := cache.Store(cacheKey, tag, payload, contentType, urlBaseName(target))
	switch {
	case storeErr == nil:
		return stored, true
	case errors.Is(storeErr, ErrTooLarge):
		// Store returns non-downloadable metadata alongside ErrTooLarge.
		return stored, true
	default:
		return none, false
	}
}
// urlBaseName derives a best-effort filename from u: the last segment of the
// URL path, cut at the first '?' should one appear in the decoded path. It
// returns "" for a nil URL or when the path has no usable last segment.
func urlBaseName(u *url.URL) string {
	if u == nil {
		return ""
	}
	name := path.Base(u.Path)
	switch name {
	case "", "/", ".":
		return ""
	}
	if cut := strings.IndexByte(name, '?'); cut >= 0 {
		return name[:cut]
	}
	return name
}
+298
View File
@@ -0,0 +1,298 @@
package server
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/sartoopjj/thefeed/internal/protocol"
)
// TestApplyHTTPMediaSourcesEndToEnd stands up a throwaway upstream image
// server, pushes a single [IMAGE] message through applyHTTPMediaSources, and
// checks two things: the rewritten body parses back (ParseMediaText) into
// downloadable metadata with the expected tag, size and caption, and the
// blocks held by the MediaCache reassemble into header + original bytes.
func TestApplyHTTPMediaSourcesEndToEnd(t *testing.T) {
	payload := []byte("fake-image-bytes-payload-1234567890")
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "image/png")
		w.WriteHeader(http.StatusOK)
		w.Write(payload)
	}))
	defer upstream.Close()

	cache := NewMediaCache(MediaCacheConfig{MaxFileBytes: 1 << 20, TTL: time.Hour})
	msgs := []protocol.Message{
		{ID: 100, Timestamp: 1, Text: protocol.MediaImage + "\nhello"},
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	applyHTTPMediaSources(ctx, cache, msgs, []mediaSource{
		{tag: protocol.MediaImage, url: upstream.URL + "/photo.png"},
	})

	// Each failing case aborts the test, so the switch checks them in the
	// same order a chain of ifs would.
	meta, caption, ok := protocol.ParseMediaText(msgs[0].Text)
	switch {
	case !ok:
		t.Fatalf("ParseMediaText ok=false on rewritten message: %q", msgs[0].Text)
	case !meta.Downloadable:
		t.Fatalf("expected downloadable meta, got %+v (text=%q)", meta, msgs[0].Text)
	case meta.Tag != protocol.MediaImage:
		t.Fatalf("Tag = %q, want %q", meta.Tag, protocol.MediaImage)
	case meta.Size != int64(len(payload)):
		t.Fatalf("Size = %d, want %d", meta.Size, len(payload))
	case caption != "hello":
		t.Fatalf("caption = %q, want %q", caption, "hello")
	}

	// Concatenate every block; the stream starts with the protocol media
	// block header and the stored content follows it.
	var assembled []byte
	for idx := uint16(0); idx < meta.Blocks; idx++ {
		chunk, err := cache.GetBlock(meta.Channel, idx)
		if err != nil {
			t.Fatalf("GetBlock(%d, %d): %v", meta.Channel, idx, err)
		}
		assembled = append(assembled, chunk...)
	}
	if len(assembled) < protocol.MediaBlockHeaderLen {
		t.Fatalf("block 0 too short: %d", len(assembled))
	}

	rawHeader, content := assembled[:protocol.MediaBlockHeaderLen], assembled[protocol.MediaBlockHeaderLen:]
	hdr, err := protocol.DecodeMediaBlockHeader(rawHeader)
	if err != nil {
		t.Fatalf("DecodeMediaBlockHeader: %v", err)
	}
	if hdr.CRC32 != meta.CRC32 {
		t.Fatalf("header CRC = %x, want %x", hdr.CRC32, meta.CRC32)
	}
	if !bytes.Equal(content, payload) {
		t.Fatalf("reassembled bytes differ:\n got: %q\n want: %q", content, payload)
	}
}
// TestApplyHTTPMediaSourcesGzipRoundTrip: with --media-compression=gzip,
// a successful upstream fetch lands compressed blocks in the cache. A
// client decompressing the assembled blocks recovers the original bytes
// verbatim and the embedded CRC32 matches.
func TestApplyHTTPMediaSourcesGzipRoundTrip(t *testing.T) {
	imageBytes := bytes.Repeat([]byte("compressible-stripe "), 300)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "image/png")
		w.WriteHeader(http.StatusOK)
		w.Write(imageBytes)
	}))
	defer srv.Close()

	cache := NewMediaCache(MediaCacheConfig{
		MaxFileBytes: 1 << 20,
		TTL:          time.Hour,
		Compression:  protocol.MediaCompressionGzip,
	})
	msgs := []protocol.Message{{ID: 100, Timestamp: 1, Text: protocol.MediaImage + "\n"}}
	sources := []mediaSource{{tag: protocol.MediaImage, url: srv.URL + "/big.png"}}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	applyHTTPMediaSources(ctx, cache, msgs, sources)

	meta, _, ok := protocol.ParseMediaText(msgs[0].Text)
	if !ok || !meta.Downloadable {
		t.Fatalf("expected downloadable meta, got %+v", meta)
	}

	// Reassemble the wire stream: header first, compressed payload after.
	var got []byte
	for blk := uint16(0); blk < meta.Blocks; blk++ {
		b, err := cache.GetBlock(meta.Channel, blk)
		if err != nil {
			t.Fatalf("GetBlock: %v", err)
		}
		got = append(got, b...)
	}
	// Fail cleanly instead of panicking on the header slice below when the
	// cache hands back fewer bytes than a header (e.g. meta.Blocks == 0).
	if len(got) < protocol.MediaBlockHeaderLen {
		t.Fatalf("assembled blocks too short for header: %d bytes", len(got))
	}

	hdr, err := protocol.DecodeMediaBlockHeader(got[:protocol.MediaBlockHeaderLen])
	if err != nil {
		t.Fatalf("DecodeMediaBlockHeader: %v", err)
	}
	if hdr.Compression != protocol.MediaCompressionGzip {
		t.Fatalf("compression = %v, want gzip", hdr.Compression)
	}
	if hdr.CRC32 != meta.CRC32 {
		t.Fatalf("header CRC = %x, want %x", hdr.CRC32, meta.CRC32)
	}

	rc, err := DecompressMediaBytes(bytes.NewReader(got[protocol.MediaBlockHeaderLen:]), hdr.Compression)
	if err != nil {
		t.Fatalf("decompress: %v", err)
	}
	defer rc.Close()
	out, err := io.ReadAll(rc)
	if err != nil {
		t.Fatalf("read all: %v", err)
	}
	if !bytes.Equal(out, imageBytes) {
		t.Fatalf("decompressed differs from upstream")
	}
}
// TestApplyHTTPMediaSourcesAlbum covers the public-mode album case: a source
// with extraURLs set. Every URL must be fetched, and the canonical body must
// come back as one downloadable header per URL followed by the untouched
// caption.
func TestApplyHTTPMediaSourcesAlbum(t *testing.T) {
	images := [][]byte{
		[]byte("first-image-bytes-XXXXXX"),
		[]byte("second-image-bytes-YYYYY"),
		[]byte("third-image-bytes-ZZZZZZ"),
	}
	// Request path (/imgN.jpg) -> payload served for it.
	byPath := map[string][]byte{
		"/img1.jpg": images[0],
		"/img2.jpg": images[1],
		"/img3.jpg": images[2],
	}
	served := 0
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "image/jpeg")
		w.WriteHeader(http.StatusOK)
		if payload, known := byPath[r.URL.Path]; known {
			w.Write(payload)
		}
		served++
	}))
	defer srv.Close()

	cache := NewMediaCache(MediaCacheConfig{MaxFileBytes: 1 << 20, TTL: time.Hour})

	// Input shape for a 3-image album: three stacked [IMAGE] headers, then
	// the caption, with the sibling URLs carried in extraURLs.
	msgs := []protocol.Message{{
		ID:        5,
		Timestamp: 1,
		Text:      strings.Repeat(protocol.MediaImage+"\n", 3) + "album caption",
	}}
	sources := []mediaSource{{
		tag:       protocol.MediaImage,
		url:       srv.URL + "/img1.jpg",
		extraURLs: []string{srv.URL + "/img2.jpg", srv.URL + "/img3.jpg"},
	}}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	applyHTTPMediaSources(ctx, cache, msgs, sources)

	if served != 3 {
		t.Errorf("served = %d, want 3 upstream fetches", served)
	}

	// The rewritten body needs exactly three image headers and must still
	// end with the caption line.
	rewritten := msgs[0].Text
	if n := strings.Count(rewritten, protocol.MediaImage); n != 3 {
		t.Fatalf("header count = %d, want 3 (text=%q)", n, rewritten)
	}
	if !strings.HasSuffix(rewritten, "\nalbum caption") {
		t.Errorf("caption not preserved: %q", rewritten)
	}

	// Peel the headers one at a time; each must be downloadable and carry
	// the size of the image served for that slot.
	remaining := rewritten
	for i := range images {
		meta, tail, ok := protocol.ParseMediaText(remaining)
		if !ok {
			t.Fatalf("ParseMediaText #%d ok=false on %q", i, remaining)
		}
		if !meta.Downloadable {
			t.Errorf("header #%d not downloadable: %+v", i, meta)
		}
		if int(meta.Size) != len(images[i]) {
			t.Errorf("header #%d size = %d, want %d", i, meta.Size, len(images[i]))
		}
		remaining = tail
	}
	if remaining != "album caption" {
		t.Errorf("trailing caption = %q, want %q", remaining, "album caption")
	}
}
// TestApplyHTTPMediaSourcesAlbumPartialFailure checks the mixed-outcome
// album: one upstream URL answers 500 while the other succeeds. The failed
// slot must still appear as a bare [TAG] placeholder so the number of leading
// headers (the album's ID span) is unchanged, and the caption must survive.
func TestApplyHTTPMediaSourcesAlbumPartialFailure(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/broken.jpg":
			w.WriteHeader(http.StatusInternalServerError)
		default:
			w.Header().Set("Content-Type", "image/jpeg")
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("ok-image"))
		}
	}))
	defer srv.Close()

	cache := NewMediaCache(MediaCacheConfig{MaxFileBytes: 1 << 20, TTL: time.Hour})
	msgs := []protocol.Message{{
		ID:        5,
		Timestamp: 1,
		Text:      protocol.MediaImage + "\n" + protocol.MediaImage + "\ncap",
	}}
	sources := []mediaSource{{
		tag:       protocol.MediaImage,
		url:       srv.URL + "/ok.jpg",
		extraURLs: []string{srv.URL + "/broken.jpg"},
	}}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	applyHTTPMediaSources(ctx, cache, msgs, sources)

	rewritten := msgs[0].Text
	if n := strings.Count(rewritten, protocol.MediaImage); n != 2 {
		t.Errorf("header count = %d, want 2 (text=%q)", n, rewritten)
	}
	// The broken item is second, so the text must end with its bare tag
	// followed by the caption.
	wantTail := "\n" + protocol.MediaImage + "\ncap"
	if !strings.HasSuffix(rewritten, wantTail) {
		t.Errorf("expected placeholder + caption tail, got %q", rewritten)
	}
}
// TestApplyHTTPMediaSourcesRejectsOversize: when the upstream file is larger
// than the cache's MaxFileBytes, the message is still rewritten with a
// "metadata only" header (downloadable=false, real size) so the UI can show
// the size without offering a download, and nothing is stored in the cache.
func TestApplyHTTPMediaSourcesRejectsOversize(t *testing.T) {
	const upstreamSize = 1024
	oversized := strings.Repeat("X", upstreamSize)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Length", "1024")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(oversized))
	}))
	defer srv.Close()

	// 100-byte cap versus a 1024-byte upstream body.
	cache := NewMediaCache(MediaCacheConfig{MaxFileBytes: 100, TTL: time.Hour})
	msgs := []protocol.Message{{ID: 1, Timestamp: 1, Text: protocol.MediaImage + "\ncap"}}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	applyHTTPMediaSources(ctx, cache, msgs, []mediaSource{
		{tag: protocol.MediaImage, url: srv.URL + "/big.jpg"},
	})

	meta, _, ok := protocol.ParseMediaText(msgs[0].Text)
	switch {
	case !ok:
		t.Fatalf("ParseMediaText ok=false")
	case meta.Downloadable:
		t.Fatalf("oversized file should not be downloadable; got meta=%+v", meta)
	case meta.Size != int64(len(oversized)):
		t.Fatalf("Size = %d, want %d (server should still surface the size)", meta.Size, len(oversized))
	}

	if entries := cache.Stats().Entries; entries != 0 {
		t.Fatalf("oversized file should not occupy a cache slot, got entries=%d", entries)
	}
}
+27
View File
@@ -0,0 +1,27 @@
package server
import (
"log"
"sync/atomic"
)
// mediaDebugLogs is the on/off switch for verbose media-cache logging. It is
// an atomic.Bool so goroutines can consult it while logging without taking a
// lock. Server.Run is expected to set it from the --debug flag at startup.
var mediaDebugLogs atomic.Bool

// SetMediaDebugLogs turns the media debug log channel on or off.
func SetMediaDebugLogs(enabled bool) {
	mediaDebugLogs.Store(enabled)
}

// logfMedia writes a Printf-style line, prefixed with "[media-debug] ", to
// the standard logger, but only while media debug logging is enabled; when
// it is disabled the call is a no-op. Use it for chatty per-store and
// per-cache-hit output; errors that operators must always see should go
// straight to log.Printf.
func logfMedia(format string, args ...any) {
	if mediaDebugLogs.Load() {
		log.Printf("[media-debug] "+format, args...)
	}
}
+321
View File
@@ -0,0 +1,321 @@
package server
import (
"context"
"errors"
"fmt"
"strconv"
"github.com/gotd/td/tg"
"github.com/sartoopjj/thefeed/internal/protocol"
)
// telegramMediaDownloadChunk is the per-RPC chunk size used by UploadGetFile.
// MTProto requires Limit to be a multiple of 4KB and ≤ 1MB; 256KB is a good
// trade-off between API call overhead and memory pressure for tiny files.
//
// downloadTelegramFile sends this value as the request Limit and also treats
// any response shorter than it as end-of-file, so the two uses must stay in
// sync if the value is ever changed.
const telegramMediaDownloadChunk = 256 * 1024
// telegramMediaPhotoSizeOrder lists Telegram photo size codes from smallest
// to largest. The downloader picks the smallest *usable* size — for a
// DNS-tunnelled feed, bandwidth is precious and a thumbnail is usually
// enough for the user to decide whether to look at the original. The
// "stripped" placeholder type is filtered out separately because it is not
// a real renderable image.
//
// pickSmallestPhotoSize consults this order only as a tie-break: it compares
// declared byte sizes first and falls back to the position in this slice when
// sizes are equal or one of them is unknown (0).
//
// a / b — tiny (≤ 100px)
// c — small chat preview
// m — medium
// s — small (legacy)
// x — high-quality
// y / w — original / largest
//
// NOTE(review): the descriptions above, and the placement of "m" before "s",
// look inconsistent with Telegram's documented thumbnail types (where "s" is
// the smallest box and "a"/"b"/"c" are progressively larger crops) — confirm
// against the Telegram API docs before relying on this order.
var telegramMediaPhotoSizeOrder = []string{"a", "b", "c", "m", "s", "x", "y", "w"}
// downloadTelegramMedia is the entry point for caching the media attached to
// a Telegram message. Photos go to downloadTelegramPhoto and documents to
// downloadTelegramDocument; every other media kind is declined.
//
// It returns the metadata to embed in the message body together with
// ok=true, or a zero MediaMeta with ok=false when the caller should keep the
// legacy [TAG] rendering. ok=false is returned when api, msg or msg.Media is
// nil, when the feed has no media cache, or when the media payload is not a
// concrete *tg.Photo / *tg.Document. Substituting the encoded metadata into
// the message text is left to the caller.
func (tr *TelegramReader) downloadTelegramMedia(ctx context.Context, api *tg.Client, msg *tg.Message) (protocol.MediaMeta, bool) {
	var none protocol.MediaMeta
	if api == nil || msg == nil || msg.Media == nil {
		return none, false
	}
	cache := tr.feed.MediaCache()
	if cache == nil {
		return none, false
	}

	if wrapper, isPhoto := msg.Media.(*tg.MessageMediaPhoto); isPhoto {
		if photo, ok := wrapper.Photo.(*tg.Photo); ok {
			return tr.downloadTelegramPhoto(ctx, api, cache, photo)
		}
		return none, false
	}
	if wrapper, isDoc := msg.Media.(*tg.MessageMediaDocument); isDoc {
		if doc, ok := wrapper.Document.(*tg.Document); ok {
			return tr.downloadTelegramDocument(ctx, api, cache, doc)
		}
		return none, false
	}
	return none, false
}
// downloadTelegramPhoto caches the smallest usable rendition of photo under
// the key "tg-photo:<id>" and returns its metadata.
//
// Outcomes:
//   - cache hit: the stored metadata, ok=true, no network I/O;
//   - no usable size: zero meta, ok=false;
//   - chosen size above the cache's max file size: non-downloadable meta
//     carrying the declared size, ok=true, nothing fetched;
//   - download error: logged, zero meta, ok=false (treated as transient, so
//     the message keeps its legacy body rather than being marked
//     non-downloadable);
//   - store error: ErrTooLarge returns whatever meta Store produced with
//     ok=true; any other error is logged and yields ok=false.
func (tr *TelegramReader) downloadTelegramPhoto(ctx context.Context, api *tg.Client, cache *MediaCache, photo *tg.Photo) (protocol.MediaMeta, bool) {
	key := "tg-photo:" + strconv.FormatInt(photo.ID, 10)

	// Exact dedup: an existing entry means no bytes need to move at all.
	if cached, hit := cache.Lookup(key); hit {
		return cached, true
	}

	sizeType, sizeBytes := pickSmallestPhotoSize(photo.Sizes)
	if sizeType == "" {
		return protocol.MediaMeta{}, false
	}

	// Apply the size cap before opening the RPC so we never fetch something
	// that would only be discarded.
	if limit := cache.maxFileBytes; limit > 0 && sizeBytes > limit {
		tooBig := protocol.MediaMeta{
			Tag:          protocol.MediaImage,
			Size:         sizeBytes,
			Downloadable: false,
		}
		return tooBig, true
	}

	data, err := tr.downloadTelegramFile(ctx, api, &tg.InputPhotoFileLocation{
		ID:            photo.ID,
		AccessHash:    photo.AccessHash,
		FileReference: photo.FileReference,
		ThumbSize:     sizeType,
	}, sizeBytes)
	if err != nil {
		// A fetch failure (network, FILE_REFERENCE_EXPIRED, ...) is not the
		// same as "file exists but is too big to cache", so it is not
		// reported as non-downloadable; the legacy body is kept.
		tr.logMediaError("photo", photo.ID, err)
		return protocol.MediaMeta{}, false
	}

	stored, err := cache.Store(key, protocol.MediaImage, data, "image/jpeg", "")
	switch {
	case err == nil:
		return stored, true
	case errors.Is(err, ErrTooLarge):
		// Oversize is surfaced to the client via the meta Store returned.
		return stored, true
	default:
		tr.logMediaError("photo", photo.ID, err)
		return protocol.MediaMeta{}, false
	}
}
// downloadTelegramDocument caches a Telegram document under the key
// "tg-doc:<id>" and returns its metadata.
//
// Outcomes:
//   - cache hit: the stored metadata, ok=true;
//   - sticker documents: declined, ok=false;
//   - declared size above the cache's max file size: non-downloadable meta
//     with the document's tag and size, ok=true, nothing fetched;
//   - download error: logged, ok=false (treated as transient, legacy body is
//     kept);
//   - store error: ErrTooLarge returns whatever meta Store produced with
//     ok=true; any other error is logged and yields ok=false.
func (tr *TelegramReader) downloadTelegramDocument(ctx context.Context, api *tg.Client, cache *MediaCache, doc *tg.Document) (protocol.MediaMeta, bool) {
	key := "tg-doc:" + strconv.FormatInt(doc.ID, 10)
	if cached, hit := cache.Lookup(key); hit {
		return cached, true
	}

	kind, name := classifyDocumentTagAndName(doc)
	if kind == protocol.MediaSticker {
		return protocol.MediaMeta{}, false
	}

	if limit := cache.maxFileBytes; limit > 0 && doc.Size > limit {
		tooBig := protocol.MediaMeta{
			Tag:          kind,
			Size:         doc.Size,
			Downloadable: false,
		}
		return tooBig, true
	}

	data, err := tr.downloadTelegramFile(ctx, api, &tg.InputDocumentFileLocation{
		ID:            doc.ID,
		AccessHash:    doc.AccessHash,
		FileReference: doc.FileReference,
		ThumbSize:     "",
	}, doc.Size)
	if err != nil {
		// Same policy as photos: a fetch failure is not reported as
		// "non-downloadable"; the message keeps its legacy [TAG]\ncaption
		// body.
		tr.logMediaError("doc", doc.ID, err)
		return protocol.MediaMeta{}, false
	}

	stored, err := cache.Store(key, kind, data, doc.MimeType, name)
	switch {
	case err == nil:
		return stored, true
	case errors.Is(err, ErrTooLarge):
		return stored, true
	default:
		tr.logMediaError("doc", doc.ID, err)
		return protocol.MediaMeta{}, false
	}
}
// downloadTelegramFile pulls a file from loc in telegramMediaDownloadChunk
// sized UploadGetFile requests and returns the concatenated bytes.
//
// The loop ends when the server returns an empty chunk, a chunk shorter than
// the requested limit, or — if expectedSize > 0 — once at least expectedSize
// bytes have been collected. With expectedSize <= 0 only the first two
// conditions apply.
//
// Errors: ctx cancellation (checked before every request), an RPC failure, a
// response that is not *tg.UploadFile, or the accumulated size exceeding the
// media cache's max file size (a defensive cap for files whose declared size
// was wrong). No partial data is returned on error.
func (tr *TelegramReader) downloadTelegramFile(ctx context.Context, api *tg.Client, loc tg.InputFileLocationClass, expectedSize int64) ([]byte, error) {
	var limit int64 // 0 means "no cap"
	if mc := tr.feed.MediaCache(); mc != nil {
		limit = mc.maxFileBytes
	}

	var (
		data []byte
		pos  int64
	)
	for finished := false; !finished; {
		if err := ctx.Err(); err != nil {
			return nil, err
		}

		res, err := api.UploadGetFile(ctx, &tg.UploadGetFileRequest{
			Location: loc,
			Offset:   pos,
			Limit:    telegramMediaDownloadChunk,
		})
		if err != nil {
			return nil, fmt.Errorf("upload.getFile offset=%d: %w", pos, err)
		}
		part, ok := res.(*tg.UploadFile)
		if !ok {
			return nil, fmt.Errorf("unexpected upload response type %T", res)
		}

		received := len(part.Bytes)
		if received == 0 {
			break
		}
		data = append(data, part.Bytes...)
		pos += int64(received)

		// Hard stop for runaway transfers.
		if limit > 0 && int64(len(data)) > limit {
			return nil, fmt.Errorf("download exceeded configured max-size (%d > %d)", len(data), limit)
		}

		// A short chunk is the canonical EOF signal; reaching the expected
		// size is the other way a transfer completes.
		shortChunk := received < telegramMediaDownloadChunk
		gotExpected := expectedSize > 0 && int64(len(data)) >= expectedSize
		finished = shortChunk || gotExpected
	}
	return data, nil
}
// pickSmallestPhotoSize returns the smallest usable size in a Telegram
// Photo as (type-code, byte-size). DNS-tunnelled bandwidth is precious, so
// we prefer a small thumbnail over the full-resolution original whenever
// Telegram offers both. It returns an empty type when no usable size is
// available (e.g. only stripped placeholder thumbs, or an empty list).
//
// Candidates come from PhotoSize (declared Size), PhotoCachedSize (length of
// the inline bytes) and PhotoSizeProgressive (first entry of Sizes, or 0 when
// Sizes is empty). Entries with an empty type code, stripped sizes and any
// other size class are ignored.
//
// Selection: when both candidates have a known (> 0) size the strictly
// smaller one wins, with equal sizes broken by position in
// telegramMediaPhotoSizeOrder. When either size is unknown the type-code
// position alone decides. Type codes that are not listed in
// telegramMediaPhotoSizeOrder rank after every listed code, so an
// unrecognised code can never beat a known one on a tie-break.
//
// NOTE(review): for progressive photos the returned byte-size is the first
// progressive prefix, which may be smaller than what a full download of that
// type yields — confirm callers are fine with that.
func pickSmallestPhotoSize(sizes []tg.PhotoSizeClass) (string, int64) {
	// rankOf maps a type code to its preference position; unknown codes get
	// the worst possible rank instead of silently tying with the first entry.
	rankOf := func(typ string) int {
		for i, known := range telegramMediaPhotoSizeOrder {
			if known == typ {
				return i
			}
		}
		return len(telegramMediaPhotoSizeOrder)
	}

	var (
		bestType string
		bestSize int64
		bestRank int
	)
	// consider folds one candidate into the running best, in input order.
	consider := func(typ string, size int64) {
		if typ == "" {
			return
		}
		rank := rankOf(typ)
		switch {
		case bestType == "":
			// First usable candidate.
		case size > 0 && bestSize > 0:
			// Both sizes known: smaller wins, rank breaks exact ties.
			if size > bestSize || (size == bestSize && rank >= bestRank) {
				return
			}
		default:
			// At least one size unknown: type-code rank decides.
			if rank >= bestRank {
				return
			}
		}
		bestType, bestSize, bestRank = typ, size, rank
	}

	for _, s := range sizes {
		switch v := s.(type) {
		case *tg.PhotoSize:
			consider(v.Type, int64(v.Size))
		case *tg.PhotoCachedSize:
			consider(v.Type, int64(len(v.Bytes)))
		case *tg.PhotoSizeProgressive:
			// Use the first progressive prefix size when one is listed.
			var prefix int64
			if len(v.Sizes) > 0 {
				prefix = int64(v.Sizes[0])
			}
			consider(v.Type, prefix)
		case *tg.PhotoStrippedSize:
			// Stripped sizes are tiny placeholder thumbs — skip.
		}
	}
	return bestType, bestSize
}
// classifyDocumentTagAndName walks a Telegram Document's attributes and
// reports (media tag, filename). The tag starts as MediaFile and is replaced
// by each video / audio / sticker / animated attribute encountered, so when
// several are present the last one in the attribute list wins. The filename
// is taken from the filename attribute and is "" when there is none; it lets
// the HTTP layer offer a reasonable Content-Disposition.
func classifyDocumentTagAndName(doc *tg.Document) (string, string) {
	var name string
	kind := protocol.MediaFile
	for i := range doc.Attributes {
		switch attr := doc.Attributes[i].(type) {
		case *tg.DocumentAttributeFilename:
			name = attr.FileName
		case *tg.DocumentAttributeAnimated:
			kind = protocol.MediaGIF
		case *tg.DocumentAttributeSticker:
			kind = protocol.MediaSticker
		case *tg.DocumentAttributeAudio:
			kind = protocol.MediaAudio
		case *tg.DocumentAttributeVideo:
			kind = protocol.MediaVideo
		}
	}
	return kind, name
}
// logMediaError reports a failed Telegram media download through logfMedia,
// so the line is only emitted while media debug logging is enabled. kind is a
// short label for the media type (callers in this file pass "photo" or
// "doc"), id is the Telegram object ID, and err is the failure being
// reported. The receiver is not used by the current implementation.
func (tr *TelegramReader) logMediaError(kind string, id int64, err error) {
// Best-effort log; the receiver's package log is fine for now.
logfMedia("[telegram] media %s id=%d download failed: %v", kind, id, err)
}
+307
View File
@@ -0,0 +1,307 @@
package server
import (
"bytes"
"errors"
"hash/crc32"
"strings"
"testing"
"time"
"github.com/sartoopjj/thefeed/internal/protocol"
)
// newTestCache is a test helper that builds a MediaCache with the given
// per-file size cap and TTL, leaving every other MediaCacheConfig field at
// its zero value. Tests in this file pass maxBytes=0 when they do not want a
// size cap to interfere.
func newTestCache(maxBytes int64, ttl time.Duration) *MediaCache {
return NewMediaCache(MediaCacheConfig{MaxFileBytes: maxBytes, TTL: ttl})
}
// TestMediaCacheStoreAndGetBlock stores a 2000-byte payload and verifies the
// returned metadata (downloadable, media-range channel, size, CRC32, block
// count), then reads every block back and checks that the stream is a
// protocol header (matching CRC, no compression) followed by the original
// bytes.
func TestMediaCacheStoreAndGetBlock(t *testing.T) {
	payload := bytes.Repeat([]byte("ab"), 1000) // 2000 bytes
	cache := newTestCache(1<<20, time.Hour)

	meta, err := cache.Store("key1", protocol.MediaImage, payload, "image/jpeg", "")
	if err != nil {
		t.Fatalf("Store: %v", err)
	}
	// Each failing case aborts the test, so the order matches a chain of ifs.
	switch {
	case !meta.Downloadable:
		t.Fatalf("Downloadable = false, want true")
	case !protocol.IsMediaChannel(meta.Channel):
		t.Fatalf("Channel %d not in media range", meta.Channel)
	case meta.Size != int64(len(payload)):
		t.Fatalf("Size = %d, want %d", meta.Size, len(payload))
	case meta.CRC32 != crc32.ChecksumIEEE(payload):
		t.Fatalf("CRC32 mismatch")
	case meta.Blocks == 0:
		t.Fatalf("Blocks should be > 0")
	}

	// Stitch the blocks together: header first, then (with the default
	// "none" compression) the untouched payload.
	var stream bytes.Buffer
	for idx := uint16(0); idx < meta.Blocks; idx++ {
		chunk, err := cache.GetBlock(meta.Channel, idx)
		if err != nil {
			t.Fatalf("GetBlock(%d, %d): %v", meta.Channel, idx, err)
		}
		stream.Write(chunk)
	}
	assembled := stream.Bytes()
	if len(assembled) < protocol.MediaBlockHeaderLen {
		t.Fatalf("assembled bytes too short: %d", len(assembled))
	}

	hdr, err := protocol.DecodeMediaBlockHeader(assembled[:protocol.MediaBlockHeaderLen])
	if err != nil {
		t.Fatalf("DecodeMediaBlockHeader: %v", err)
	}
	if hdr.CRC32 != meta.CRC32 {
		t.Fatalf("header CRC = %x, want %x", hdr.CRC32, meta.CRC32)
	}
	if hdr.Compression != protocol.MediaCompressionNone {
		t.Fatalf("header compression = %v, want none", hdr.Compression)
	}
	if content := assembled[protocol.MediaBlockHeaderLen:]; !bytes.Equal(content, payload) {
		t.Fatalf("reassembled bytes differ: got %d, want %d", len(assembled)-protocol.MediaBlockHeaderLen, len(payload))
	}
}
// TestMediaCacheStoreGzip exercises the compressed wire path: bytes after
// the header are gzip-compressed and DecompressMediaBytes reproduces the
// original. The header's CRC32 must match the decompressed content.
func TestMediaCacheStoreGzip(t *testing.T) {
	cache := NewMediaCache(MediaCacheConfig{
		MaxFileBytes: 1 << 20,
		TTL:          time.Hour,
		Compression:  protocol.MediaCompressionGzip,
	})
	content := bytes.Repeat([]byte("compress-me "), 200)
	meta, err := cache.Store("gz", protocol.MediaFile, content, "text/plain", "")
	if err != nil {
		t.Fatalf("Store: %v", err)
	}

	var got []byte
	for blk := uint16(0); blk < meta.Blocks; blk++ {
		b, err := cache.GetBlock(meta.Channel, blk)
		if err != nil {
			t.Fatalf("GetBlock(%d, %d): %v", meta.Channel, blk, err)
		}
		got = append(got, b...)
	}
	// Fail cleanly instead of panicking on the header slice below when the
	// cache hands back fewer bytes than a header (e.g. meta.Blocks == 0).
	if len(got) < protocol.MediaBlockHeaderLen {
		t.Fatalf("assembled bytes too short for header: %d", len(got))
	}

	hdr, err := protocol.DecodeMediaBlockHeader(got[:protocol.MediaBlockHeaderLen])
	if err != nil {
		t.Fatalf("DecodeMediaBlockHeader: %v", err)
	}
	if hdr.Compression != protocol.MediaCompressionGzip {
		t.Fatalf("compression = %v, want gzip", hdr.Compression)
	}

	body, err := DecompressMediaBytes(bytes.NewReader(got[protocol.MediaBlockHeaderLen:]), hdr.Compression)
	if err != nil {
		t.Fatalf("decompress: %v", err)
	}
	defer body.Close()
	decompressed := new(bytes.Buffer)
	if _, err := decompressed.ReadFrom(body); err != nil {
		t.Fatalf("read decompressed: %v", err)
	}
	if !bytes.Equal(decompressed.Bytes(), content) {
		t.Fatalf("decompressed differs from original")
	}
	if crc32.ChecksumIEEE(decompressed.Bytes()) != hdr.CRC32 {
		t.Fatalf("header CRC %x doesn't match decompressed CRC %x", hdr.CRC32, crc32.ChecksumIEEE(decompressed.Bytes()))
	}
}
// TestMediaCacheDedup: a second Store of the same key with identical bytes
// must land on the same channel, count as a store hit and not as a miss.
// This is the periodic-refresh deduplication path.
func TestMediaCacheDedup(t *testing.T) {
	cache := newTestCache(0, time.Hour)
	payload := []byte("hello")

	first, err := cache.Store("dup", protocol.MediaImage, payload, "", "")
	if err != nil {
		t.Fatalf("first Store: %v", err)
	}
	before := cache.Stats()

	second, err := cache.Store("dup", protocol.MediaImage, payload, "", "")
	if err != nil {
		t.Fatalf("second Store: %v", err)
	}
	if first.Channel != second.Channel {
		t.Fatalf("dedup: channel changed (%d → %d)", first.Channel, second.Channel)
	}

	after := cache.Stats()
	if wantHits := before.StoreHits + 1; after.StoreHits != wantHits {
		t.Fatalf("StoreHits did not increment: %d → %d", before.StoreHits, after.StoreHits)
	}
	if after.StoreMisses != before.StoreMisses {
		t.Fatalf("StoreMisses changed unexpectedly")
	}
}
// TestMediaCacheCrossKeyDedup: the same bytes stored under a second,
// different key must reuse the existing slot — same channel, no new entry,
// one more store hit — and both keys must resolve to that entry via Lookup.
func TestMediaCacheCrossKeyDedup(t *testing.T) {
	cache := newTestCache(0, time.Hour)
	payload := []byte("the same bytes under different keys")

	original, err := cache.Store("key-A", protocol.MediaImage, payload, "", "")
	if err != nil {
		t.Fatalf("first Store: %v", err)
	}
	before := cache.Stats()

	alias, err := cache.Store("key-B-different", protocol.MediaImage, payload, "", "")
	if err != nil {
		t.Fatalf("second Store: %v", err)
	}
	if original.Channel != alias.Channel {
		t.Fatalf("cross-key dedup: channel changed (%d -> %d)", original.Channel, alias.Channel)
	}

	after := cache.Stats()
	if after.Entries != before.Entries {
		t.Fatalf("cross-key dedup: entries grew %d -> %d (should reuse slot)", before.Entries, after.Entries)
	}
	if after.StoreHits != before.StoreHits+1 {
		t.Fatalf("StoreHits should have incremented")
	}

	// Either key must now resolve to the shared entry.
	fromA, okA := cache.Lookup("key-A")
	if !okA || fromA.Channel != original.Channel {
		t.Fatalf("Lookup(key-A) failed: ok=%v meta=%+v", okA, fromA)
	}
	fromB, okB := cache.Lookup("key-B-different")
	if !okB || fromB.Channel != original.Channel {
		t.Fatalf("Lookup(key-B-different) failed: ok=%v meta=%+v", okB, fromB)
	}
}
// TestMediaCacheKeyReplaceOnContentChange: re-storing a key with different
// bytes (e.g. after a Telegram edit) must yield a different CRC32, and the
// channel handed out for the first content must never serve those old bytes
// again.
func TestMediaCacheKeyReplaceOnContentChange(t *testing.T) {
	cache := newTestCache(0, time.Hour)
	oldBytes := []byte("first content")
	newBytes := []byte("second content (different)")

	before, err := cache.Store("k", protocol.MediaImage, oldBytes, "", "")
	if err != nil {
		t.Fatalf("first Store: %v", err)
	}
	after, err := cache.Store("k", protocol.MediaImage, newBytes, "", "")
	if err != nil {
		t.Fatalf("second Store: %v", err)
	}
	if before.CRC32 == after.CRC32 {
		t.Fatalf("CRC32 should differ for different content")
	}

	// Reading the first channel may fail outright or return the new content
	// (channel reuse); both are fine. The CRC in block 0's header tells us
	// which content the slot is serving, and it must not be the old one.
	blk, err := cache.GetBlock(before.Channel, 0)
	if err != nil || len(blk) < protocol.MediaBlockHeaderLen {
		return
	}
	hdr, err := protocol.DecodeMediaBlockHeader(blk[:protocol.MediaBlockHeaderLen])
	if err == nil && hdr.CRC32 == before.CRC32 {
		t.Fatalf("GetBlock returned stale (first) bytes after content change")
	}
}
// TestMediaCacheRejectsOversizeFile: storing 200 bytes into a cache capped at
// 100 bytes per file must fail with ErrTooLarge, bump StoreRejected once and
// leave the cache empty.
func TestMediaCacheRejectsOversizeFile(t *testing.T) {
	cache := newTestCache(100, time.Hour)
	tooBig := bytes.Repeat([]byte("x"), 200)

	if _, err := cache.Store("big", protocol.MediaFile, tooBig, "", ""); !errors.Is(err, ErrTooLarge) {
		t.Fatalf("err = %v, want ErrTooLarge", err)
	}

	snapshot := cache.Stats()
	if snapshot.StoreRejected != 1 {
		t.Fatalf("StoreRejected = %d, want 1", snapshot.StoreRejected)
	}
	if snapshot.Entries != 0 {
		t.Fatalf("Entries = %d, want 0", snapshot.Entries)
	}
}
// TestMediaCacheGetBlockOutOfRange: GetBlock must fail for a media-range
// channel that holds nothing, and must fail with an "outside media range"
// error for channel 0.
func TestMediaCacheGetBlockOutOfRange(t *testing.T) {
	cache := newTestCache(0, time.Hour)

	if _, err := cache.GetBlock(protocol.MediaChannelStart, 0); err == nil {
		t.Fatalf("expected error for unknown channel")
	}

	if _, err := cache.GetBlock(0, 0); err == nil || !strings.Contains(err.Error(), "outside media range") {
		t.Fatalf("expected media-range error, got %v", err)
	}
}
// TestMediaCacheSweepEvictsExpired: with a 10ms TTL, an entry that has been
// left for 20ms must be removed by Sweep, which reports one eviction.
func TestMediaCacheSweepEvictsExpired(t *testing.T) {
	cache := newTestCache(0, 10*time.Millisecond)

	if _, err := cache.Store("k", protocol.MediaFile, []byte("data"), "", ""); err != nil {
		t.Fatalf("Store: %v", err)
	}
	if cache.Stats().Entries != 1 {
		t.Fatalf("Entries = %d, want 1", cache.Stats().Entries)
	}

	// Let the entry outlive its TTL before sweeping.
	time.Sleep(20 * time.Millisecond)

	if evicted := cache.Sweep(); evicted != 1 {
		t.Fatalf("Sweep evicted %d, want 1", evicted)
	}
	if cache.Stats().Entries != 0 {
		t.Fatalf("Entries after sweep = %d, want 0", cache.Stats().Entries)
	}
}
// TestMediaCacheReclaimsExpiredSlot: once an entry has outlived its TTL, a
// later Store must not leave it counted alongside the new entry — the cache
// ends up with exactly one entry. Whether the new entry reuses the expired
// entry's channel is not asserted, only noted in the test log.
func TestMediaCacheReclaimsExpiredSlot(t *testing.T) {
	cache := newTestCache(0, 10*time.Millisecond)

	expired, err := cache.Store("a", protocol.MediaFile, []byte("aaa"), "", "")
	if err != nil {
		t.Fatalf("Store a: %v", err)
	}
	time.Sleep(20 * time.Millisecond)

	// Driving the allocator all the way around to the expired channel is
	// impractical here; storing a fresh key with the expired entry still in
	// place is enough to show that it gets cleared out.
	fresh, err := cache.Store("b", protocol.MediaFile, []byte("bbb"), "", "")
	if err != nil {
		t.Fatalf("Store b: %v", err)
	}
	if fresh.Channel == expired.Channel {
		t.Logf("note: reused expired slot at ch %d (expected when nextChannel wraps)", fresh.Channel)
	}

	if entries := cache.Stats().Entries; entries != 1 {
		t.Fatalf("Entries = %d, want 1 (the old expired entry should be gone)", entries)
	}
}
// TestMediaCacheMetadataRoundTrip: metadata returned by Store, once embedded
// in a message body with EncodeMediaText, must parse back through
// ParseMediaText with the same channel and CRC32 and the caption intact.
func TestMediaCacheMetadataRoundTrip(t *testing.T) {
	const wantCaption = "look at this"
	cache := newTestCache(0, time.Hour)

	stored, err := cache.Store("rt", protocol.MediaImage, []byte("round trip content"), "image/png", "pic.png")
	if err != nil {
		t.Fatalf("Store: %v", err)
	}

	parsed, caption, ok := protocol.ParseMediaText(protocol.EncodeMediaText(stored, wantCaption))
	switch {
	case !ok:
		t.Fatalf("ParseMediaText ok=false")
	case parsed.Channel != stored.Channel:
		t.Fatalf("Channel: parsed %d, stored %d", parsed.Channel, stored.Channel)
	case parsed.CRC32 != stored.CRC32:
		t.Fatalf("CRC32 mismatch")
	case caption != wantCaption:
		t.Fatalf("caption = %q", caption)
	}
}
+176 -13
View File
@@ -172,13 +172,79 @@ func (pr *PublicReader) fetchChannel(ctx context.Context, username string) ([]pr
if err != nil {
return nil, "", err
}
msgs, err := parsePublicMessages(body)
msgs, sources, err := parsePublicMessagesWithMedia(body)
if err != nil {
return nil, "", err
}
// If the server has a configured media cache, fetch each scraped image
// URL and rewrite the corresponding message text to embed downloadable
// metadata. Failures here are best-effort: messages keep their legacy
// "[IMAGE]\ncaption" body when downloads don't succeed.
if cache := pr.feed.MediaCache(); cache != nil {
applyHTTPMediaSources(ctx, cache, msgs, sources)
}
return msgs, extractChannelTitle(body), nil
}
// applyHTTPMediaSources pairs msgs[i] with sources[i] and, for every source
// that has both a url and a tag, downloads the primary URL followed by each
// of extraURLs (album siblings), in that order. The message text is then
// replaced by one header line per URL — the encoded metadata on success, the
// bare tag on failure so the number of leading headers is kept — followed by
// the original caption when it is non-empty.
//
// Messages are modified in place. A message is left untouched when its
// source is incomplete, when it has no matching source, or when none of its
// downloads succeeded.
func applyHTTPMediaSources(ctx context.Context, cache *MediaCache, msgs []protocol.Message, sources []mediaSource) {
	paired := len(msgs)
	if len(sources) < paired {
		paired = len(sources)
	}

	for idx := 0; idx < paired; idx++ {
		src := sources[idx]
		if src.url == "" || src.tag == "" {
			continue
		}

		// ParseMediaText removes a single leading header per call, so keep
		// peeling until only the caption remains.
		caption := msgs[idx].Text
		for {
			_, rest, ok := protocol.ParseMediaText(caption)
			if !ok {
				break
			}
			caption = rest
		}

		links := make([]string, 0, 1+len(src.extraURLs))
		links = append(links, src.url)
		links = append(links, src.extraURLs...)

		lines := make([]string, 0, len(links)+1)
		succeeded := 0
		for _, link := range links {
			meta, ok := downloadHTTPMedia(ctx, cache, src.tag, link)
			if ok {
				succeeded++
				lines = append(lines, strings.TrimSuffix(meta.String(), "\n"))
			} else {
				// Placeholder keeps this slot's header in the stack.
				lines = append(lines, src.tag)
			}
		}
		if succeeded == 0 {
			continue
		}

		if caption != "" {
			lines = append(lines, caption)
		}
		msgs[idx].Text = strings.Join(lines, "\n")
	}
}
// mediaSource is the per-message media descriptor returned by the public
// scraper. extraURLs holds additional album siblings; url is the first one.
// The zero value means "no downloadable media for this message".
type mediaSource struct {
// tag is the media header tag (e.g. protocol.MediaImage) emitted for each
// item; applyHTTPMediaSources skips the message when it is empty.
tag string
// url is the first (or only) media URL; applyHTTPMediaSources skips the
// message when it is empty.
url string
// extraURLs are the remaining album item URLs, downloaded after url.
extraURLs []string
}
// extractChannelTitle parses the channel display name from the Telegram public page.
func extractChannelTitle(body []byte) string {
doc, err := html.Parse(strings.NewReader(string(body)))
@@ -219,12 +285,27 @@ func mergeMessages(old, new []protocol.Message) []protocol.Message {
}
// parsePublicMessages parses a public channel page and returns only the
// messages, discarding the per-message media descriptors that
// parsePublicMessagesWithMedia also produces.
func parsePublicMessages(body []byte) ([]protocol.Message, error) {
	parsed, _, parseErr := parsePublicMessagesWithMedia(body)
	return parsed, parseErr
}
// parsePublicMessagesWithMedia is identical to parsePublicMessages but also
// returns a per-message media descriptor — same length and ordering as the
// returned messages — that callers can use to fetch the underlying photo or
// document over HTTP and rewrite the message body. The legacy behaviour
// (returning just messages) is preserved by parsePublicMessages above for
// existing tests and pre-feature callers.
func parsePublicMessagesWithMedia(body []byte) ([]protocol.Message, []mediaSource, error) {
doc, err := html.Parse(strings.NewReader(string(body)))
if err != nil {
return nil, fmt.Errorf("parse html: %w", err)
return nil, nil, fmt.Errorf("parse html: %w", err)
}
var collected []publicMessage
type collectedMsg struct {
msg publicMessage
src mediaSource
}
var collected []collectedMsg
visitNodes(doc, func(n *html.Node) {
post := attrValue(n, "data-post")
if post == "" {
@@ -235,17 +316,42 @@ func parsePublicMessages(body []byte) ([]protocol.Message, error) {
return
}
text := strings.TrimSpace(extractMessageText(findMessageBodyNode(n)))
var src mediaSource
mediaPrefix := ""
switch {
case findFirstByClass(n, "tgme_widget_message_photo_wrap") != nil:
mediaPrefix = protocol.MediaImage
// Albums share one data-post block with N nested photo wraps.
// Stack N [IMAGE] headers so the client-side gap detector
// (albumSpan) doesn't flag the absorbed sibling IDs as missing.
photoWraps := findAllByClass(n, "tgme_widget_message_photo_wrap")
if len(photoWraps) > 1 {
headers := make([]string, len(photoWraps))
for i := range headers {
headers[i] = protocol.MediaImage
}
mediaPrefix = strings.Join(headers, "\n")
} else {
mediaPrefix = protocol.MediaImage
}
src = mediaSource{tag: protocol.MediaImage, url: extractBackgroundImageURL(photoWraps[0])}
for i := 1; i < len(photoWraps); i++ {
if u := extractBackgroundImageURL(photoWraps[i]); u != "" {
src.extraURLs = append(src.extraURLs, u)
}
}
case findFirstByClass(n, "tgme_widget_message_video_player") != nil ||
findFirstByClass(n, "tgme_widget_message_roundvideo_player") != nil:
mediaPrefix = protocol.MediaVideo
// t.me/s/ does not serve real video bytes — the player anchor links
// to the channel page itself. Don't try to download.
case findFirstByClass(n, "tgme_widget_message_sticker_wrap") != nil:
mediaPrefix = protocol.MediaSticker
// Stickers are emitted as the legacy tag only; we don't cache or
// serve their bytes (animated/.tgs variants don't render inline
// in the browser anyway).
case findFirstByClass(n, "tgme_widget_message_voice") != nil:
mediaPrefix = protocol.MediaAudio
// Public web view doesn't expose voice file bytes either.
case findFirstByClass(n, "tgme_widget_message_poll") != nil:
mediaPrefix = protocol.MediaPoll
pollBody := extractPollData(n)
@@ -264,6 +370,10 @@ func parsePublicMessages(body []byte) ([]protocol.Message, error) {
mediaPrefix = protocol.MediaContact
case findFirstByClass(n, "tgme_widget_message_document_wrap") != nil:
mediaPrefix = protocol.MediaFile
// In t.me/s/ the document link is a "view in Telegram" page link,
// not the file CDN — fetching it would download the channel HTML.
// Skip; documents are downloadable only when the server runs with
// a Telegram login (gotd UploadGetFile path).
case findFirstByClass(n, "message_media_not_supported") != nil:
// Telegram shows "Please open Telegram to view this post" for
// content the public web view can't render: polls/quizzes, but
@@ -294,26 +404,31 @@ func parsePublicMessages(body []byte) ([]protocol.Message, error) {
text = protocol.MediaReply + "\n" + text
}
}
collected = append(collected, publicMessage{
id: id,
timestamp: extractMessageTimestamp(n),
text: text,
collected = append(collected, collectedMsg{
msg: publicMessage{
id: id,
timestamp: extractMessageTimestamp(n),
text: text,
},
src: src,
})
})
if len(collected) == 0 {
return nil, fmt.Errorf("no public messages found")
return nil, nil, fmt.Errorf("no public messages found")
}
sort.Slice(collected, func(i, j int) bool {
return collected[i].id > collected[j].id
return collected[i].msg.id > collected[j].msg.id
})
msgs := make([]protocol.Message, 0, len(collected))
for _, msg := range collected {
msgs = append(msgs, protocol.Message{ID: msg.id, Timestamp: msg.timestamp, Text: msg.text})
sources := make([]mediaSource, 0, len(collected))
for _, c := range collected {
msgs = append(msgs, protocol.Message{ID: c.msg.id, Timestamp: c.msg.timestamp, Text: c.msg.text})
sources = append(sources, c.src)
}
return msgs, nil
return msgs, sources, nil
}
func visitNodes(n *html.Node, fn func(*html.Node)) {
@@ -339,6 +454,17 @@ func findFirstByClass(n *html.Node, class string) *html.Node {
return found
}
// findAllByClass collects, in visitNodes traversal order, every node reached
// from n for which hasClass reports the given class. The result is nil when
// nothing matches.
func findAllByClass(n *html.Node, class string) []*html.Node {
	var matches []*html.Node
	collect := func(node *html.Node) {
		if !hasClass(node, class) {
			return
		}
		matches = append(matches, node)
	}
	visitNodes(n, collect)
	return matches
}
func hasClass(n *html.Node, class string) bool {
if n == nil || n.Type != html.ElementNode {
return false
@@ -578,3 +704,40 @@ func extractReplyID(replyNode *html.Node) uint32 {
}
return id
}
// extractBackgroundImageURL returns the URL found inside the first url(...)
// that follows a "background-image" property in n's inline style attribute,
// e.g. style="background-image:url('https://host/img.jpg')". Surrounding
// whitespace and at most one leading/trailing single quote and double quote
// are removed. It returns "" when n is nil, when the style attribute is
// missing, or when the property / url( / closing parenthesis is not found.
func extractBackgroundImageURL(n *html.Node) string {
	if n == nil {
		return ""
	}
	const (
		property = "background-image"
		opener   = "url("
	)

	style := attrValue(n, "style")
	propAt := strings.Index(style, property)
	if propAt < 0 {
		// Also covers an empty or absent style attribute.
		return ""
	}

	// Only look for url( after the property name itself.
	afterProp := style[propAt:]
	openAt := strings.Index(afterProp, opener)
	if openAt < 0 {
		return ""
	}
	inner := afterProp[openAt+len(opener):]
	closeAt := strings.IndexByte(inner, ')')
	if closeAt < 0 {
		return ""
	}

	value := strings.TrimSpace(inner[:closeAt])
	// Strip one layer of quoting: single quotes first, then double quotes.
	for _, quote := range []string{"'", "\""} {
		value = strings.TrimPrefix(value, quote)
		value = strings.TrimSuffix(value, quote)
	}
	return value
}
+70
View File
@@ -104,6 +104,76 @@ func TestMergeMessages(t *testing.T) {
}
}
// TestParsePublicMessagesAlbumStacksHeaders checks that one data-post block
// containing three photo wraps yields a single message whose text carries
// three stacked [IMAGE] headers followed by the caption, and a single media
// source whose url is the first photo and whose extraURLs are the other two.
func TestParsePublicMessagesAlbumStacksHeaders(t *testing.T) {
	const page = `
<html><body>
<div class="tgme_widget_message" data-post="testchan/210">
<a class="tgme_widget_message_date"><time datetime="2026-04-10T12:00:00+00:00"></time></a>
<a class="tgme_widget_message_photo_wrap" style="background-image:url('https://cdn.telegram.org/img1.jpg')"></a>
<a class="tgme_widget_message_photo_wrap" style="background-image:url('https://cdn.telegram.org/img2.jpg')"></a>
<a class="tgme_widget_message_photo_wrap" style="background-image:url('https://cdn.telegram.org/img3.jpg')"></a>
<div class="tgme_widget_message_text">album caption</div>
</div>
</body></html>
`
	msgs, sources, err := parsePublicMessagesWithMedia([]byte(page))
	if err != nil {
		t.Fatalf("parsePublicMessagesWithMedia: %v", err)
	}
	if len(msgs) != 1 {
		t.Fatalf("len(msgs) = %d, want 1", len(msgs))
	}

	// One header per album item, then the caption.
	wantText := "[IMAGE]\n[IMAGE]\n[IMAGE]\nalbum caption"
	if got := msgs[0].Text; got != wantText {
		t.Fatalf("msgs[0].Text = %q, want %q", got, wantText)
	}
	if len(sources) != 1 {
		t.Fatalf("len(sources) = %d, want 1", len(sources))
	}

	src := sources[0]
	if src.tag != protocol.MediaImage {
		t.Errorf("src.tag = %q, want %q", src.tag, protocol.MediaImage)
	}
	if src.url != "https://cdn.telegram.org/img1.jpg" {
		t.Errorf("src.url = %q, want first photo URL", src.url)
	}

	wantExtra := []string{
		"https://cdn.telegram.org/img2.jpg",
		"https://cdn.telegram.org/img3.jpg",
	}
	extraOK := len(src.extraURLs) == len(wantExtra)
	for i := 0; extraOK && i < len(wantExtra); i++ {
		extraOK = src.extraURLs[i] == wantExtra[i]
	}
	if !extraOK {
		t.Errorf("src.extraURLs = %v, want [img2, img3]", src.extraURLs)
	}
}
// TestParsePublicMessagesSinglePhotoUnchanged checks that a post with exactly
// one photo wrap produces a single [IMAGE] header followed by the caption,
// and a media source carrying that photo's URL with no extraURLs.
func TestParsePublicMessagesSinglePhotoUnchanged(t *testing.T) {
	// Single photo: one [IMAGE] header, no extraURLs.
	body := []byte(`
<html><body>
<div class="tgme_widget_message" data-post="testchan/220">
<a class="tgme_widget_message_photo_wrap" style="background-image:url('https://cdn.telegram.org/single.jpg')"></a>
<div class="tgme_widget_message_text">just one</div>
</div>
</body></html>
`)
	msgs, sources, err := parsePublicMessagesWithMedia(body)
	if err != nil {
		t.Fatalf("parsePublicMessagesWithMedia: %v", err)
	}
	if len(msgs) != 1 {
		t.Fatalf("len(msgs) = %d, want 1", len(msgs))
	}
	wantText := "[IMAGE]\njust one"
	if msgs[0].Text != wantText {
		t.Fatalf("msgs[0].Text = %q, want %q", msgs[0].Text, wantText)
	}
	// Guard the index below: a length regression should fail the test with a
	// message rather than panic with an index-out-of-range.
	if len(sources) != 1 {
		t.Fatalf("len(sources) = %d, want 1", len(sources))
	}
	if sources[0].url != "https://cdn.telegram.org/single.jpg" || len(sources[0].extraURLs) != 0 {
		t.Errorf("source = %+v, want url=single, extraURLs empty", sources[0])
	}
}
func TestParsePublicMessagesReplyPreviewUsesMainBody(t *testing.T) {
body := []byte(`
<html><body>
+75 -1
View File
@@ -7,6 +7,7 @@ import (
"log"
"os"
"strings"
"time"
"github.com/sartoopjj/thefeed/internal/protocol"
)
@@ -24,7 +25,21 @@ type Config struct {
NoTelegram bool // if true, fetch public channels without Telegram login
AllowManage bool // if true, remote channel management and sending via DNS is allowed
Debug bool // if true, log every decoded DNS query
Telegram TelegramConfig
// NoMedia disables downloading and serving image/file media. When set, the
// server emits the legacy [TAG]\ncaption form for media messages so old
// clients keep working unchanged.
NoMedia bool
// MediaMaxSize is the per-file cap in bytes for cached media. 0 means no
// cap (not recommended in production).
MediaMaxSize int64
// MediaCacheTTL is the cache lifetime in minutes for a single entry. The
// effective TTL is reset whenever the same upstream id is fetched again.
MediaCacheTTL int
// MediaCompression names the compression applied to cached media bytes
// before they're split into DNS blocks. One of "none", "gzip",
// "deflate". Empty defaults to "gzip".
MediaCompression string
Telegram TelegramConfig
}
// Server orchestrates the DNS server and Telegram reader.
@@ -64,6 +79,39 @@ func (s *Server) Run(ctx context.Context) error {
return fmt.Errorf("derive keys: %w", err)
}
SetMediaDebugLogs(s.cfg.Debug)
// Configure media cache before any reader starts so the very first fetch
// cycle can populate it. When --no-media is set we leave Feed.media as
// nil; the readers fall through to the legacy [TAG]\ncaption form, and
// Feed.GetBlock rejects media-channel queries with not-found.
if !s.cfg.NoMedia {
ttlMin := s.cfg.MediaCacheTTL
if ttlMin <= 0 {
ttlMin = 600
}
ttl := time.Duration(ttlMin) * time.Minute
compName := s.cfg.MediaCompression
if compName == "" {
compName = "gzip"
}
compression, err := protocol.ParseMediaCompressionName(compName)
if err != nil {
return fmt.Errorf("--media-compression: %w", err)
}
mediaCache := NewMediaCache(MediaCacheConfig{
MaxFileBytes: s.cfg.MediaMaxSize,
TTL: ttl,
Compression: compression,
Logf: logfMedia,
})
s.feed.SetMediaCache(mediaCache)
log.Printf("[server] media cache enabled: max-size=%d bytes, ttl=%s, compression=%s", s.cfg.MediaMaxSize, ttl, compression)
go s.runMediaSweep(ctx, mediaCache, ttl)
} else {
log.Println("[server] media cache disabled (--no-media)")
}
go startLatestVersionTracker(ctx, s.feed)
var channelCtl channelRefresher
@@ -180,3 +228,29 @@ func prefixXAccounts(accounts []string) []string {
}
return out
}
// runMediaSweep calls cache.Sweep on a fixed interval until ctx is done. The
// interval is ttl/4 clamped to the range [30s, 5min]; a non-positive ttl/4
// also selects 5min. It returns immediately when cache is nil and is meant
// to be run in its own goroutine.
func (s *Server) runMediaSweep(ctx context.Context, cache *MediaCache, ttl time.Duration) {
	if cache == nil {
		return
	}

	const (
		minInterval = 30 * time.Second
		maxInterval = 5 * time.Minute
	)
	every := ttl / 4
	switch {
	case every <= 0, every > maxInterval:
		every = maxInterval
	case every < minInterval:
		every = minInterval
	}

	tick := time.NewTicker(every)
	defer tick.Stop()

	done := ctx.Done()
	for {
		select {
		case <-done:
			return
		case <-tick.C:
			cache.Sweep()
		}
	}
}
+91 -11
View File
@@ -253,7 +253,7 @@ func (tr *TelegramReader) fetchAll(ctx context.Context, api *tg.Client) {
}
userNames := buildUserMap(hist)
msgs, err := tr.extractMessages(hist, rp.chatType, userNames)
msgs, err := tr.extractMessages(ctx, api, hist, rp.chatType, userNames)
if err != nil {
log.Printf("[telegram] fetch %s: extract messages failed: %v", username, err)
failed++
@@ -334,7 +334,7 @@ func (tr *TelegramReader) fetchChannel(ctx context.Context, api *tg.Client, user
}
userNames := buildUserMap(hist)
return tr.extractMessages(hist, protocol.ChatTypeChannel, userNames)
return tr.extractMessages(ctx, api, hist, protocol.ChatTypeChannel, userNames)
}
// buildUserMap extracts a user ID → display name map from a history response.
@@ -366,7 +366,7 @@ func buildUserMap(hist tg.MessagesMessagesClass) map[int64]string {
return m
}
func (tr *TelegramReader) extractMessages(hist tg.MessagesMessagesClass, chatType protocol.ChatType, userNames map[int64]string) ([]protocol.Message, error) {
func (tr *TelegramReader) extractMessages(ctx context.Context, api *tg.Client, hist tg.MessagesMessagesClass, chatType protocol.ChatType, userNames map[int64]string) ([]protocol.Message, error) {
var tgMsgs []tg.MessageClass
switch h := hist.(type) {
@@ -380,21 +380,67 @@ func (tr *TelegramReader) extractMessages(hist tg.MessagesMessagesClass, chatTyp
return nil, fmt.Errorf("unexpected messages type: %T", hist)
}
var msgs []protocol.Message
// Album-aware grouping: Telegram delivers an album as N separate
// messages sharing the same GroupedID. We merge them into one feed
// message that carries every album item's media header and the album's
// single caption, with the lowest message ID as the canonical post id.
type album struct {
canonical *tg.Message
headers []string
caption string
}
groups := map[int64]*album{}
var order []int64
var nextSingleID int64 = -1 // sentinel keys for non-grouped messages
for _, raw := range tgMsgs {
msg, ok := raw.(*tg.Message)
if !ok {
continue
}
header, caption := tr.extractMediaHeaderAndCaption(ctx, api, msg)
if header == "" && caption == "" {
continue
}
gid := msg.GroupedID
if gid == 0 {
gid = nextSingleID
nextSingleID--
}
g, exists := groups[gid]
if !exists {
g = &album{canonical: msg}
groups[gid] = g
order = append(order, gid)
}
if header != "" {
g.headers = append(g.headers, header)
}
if caption != "" && g.caption == "" {
g.caption = caption
}
// Keep the canonical pointer at the lowest-id message so reply,
// timestamp, and ordering stay stable across album items.
if msg.ID < g.canonical.ID {
g.canonical = msg
}
}
text := tr.extractText(msg)
msgs := make([]protocol.Message, 0, len(order))
for _, gid := range order {
g := groups[gid]
text := strings.Join(g.headers, "\n")
if text != "" && g.caption != "" {
text += "\n" + g.caption
} else if text == "" {
text = g.caption
}
if text == "" {
continue
}
// For private chats, prefix with the sender's name.
if chatType == protocol.ChatTypePrivate {
if fromID, ok := msg.GetFromID(); ok {
if fromID, ok := g.canonical.GetFromID(); ok {
if pu, ok := fromID.(*tg.PeerUser); ok {
if name, ok := userNames[pu.UserID]; ok {
text = name + ": " + text
@@ -403,8 +449,7 @@ func (tr *TelegramReader) extractMessages(hist tg.MessagesMessagesClass, chatTyp
}
}
// Mark messages that are replies (include reply-to message ID).
if replyTo, hasReply := msg.GetReplyTo(); hasReply {
if replyTo, hasReply := g.canonical.GetReplyTo(); hasReply {
if rh, ok := replyTo.(*tg.MessageReplyHeader); ok {
if rid, hasID := rh.GetReplyToMsgID(); hasID {
text = fmt.Sprintf("%s:%d\n%s", protocol.MediaReply, rid, text)
@@ -417,8 +462,8 @@ func (tr *TelegramReader) extractMessages(hist tg.MessagesMessagesClass, chatTyp
}
msgs = append(msgs, protocol.Message{
ID: uint32(msg.ID),
Timestamp: uint32(msg.Date),
ID: uint32(g.canonical.ID),
Timestamp: uint32(g.canonical.Date),
Text: text,
})
}
@@ -426,6 +471,41 @@ func (tr *TelegramReader) extractMessages(hist tg.MessagesMessagesClass, chatTyp
return msgs, nil
}
// extractMediaHeaderAndCaption splits one Telegram message into a media
// header line and a caption.
//
//   - No media: header is "" and caption is the message text with URL
//     entities applied.
//   - Photo or document: header is the downloaded metadata line (trailing
//     newline removed) when downloadTelegramMedia succeeds; otherwise the
//     bare protocol.MediaImage tag for photos or the classifyDocument tag
//     for documents.
//   - Geo, live geo, venue: header is protocol.MediaLocation.
//   - Contact: header is protocol.MediaContact.
//   - Poll: header is "" and the whole extractText payload is returned as
//     the caption.
//   - Any other media type: header is "" with the plain caption.
func (tr *TelegramReader) extractMediaHeaderAndCaption(ctx context.Context, api *tg.Client, msg *tg.Message) (header, caption string) {
	caption = applyTextURLEntities(msg.Message, msg.Entities)
	if msg.Media == nil {
		return "", caption
	}

	switch media := msg.Media.(type) {
	case *tg.MessageMediaPhoto, *tg.MessageMediaDocument:
		if meta, ok := tr.downloadTelegramMedia(ctx, api, msg); ok {
			return strings.TrimSuffix(meta.String(), "\n"), caption
		}
		// Download unavailable: emit the legacy tag-only header.
		switch item := media.(type) {
		case *tg.MessageMediaPhoto:
			return protocol.MediaImage, caption
		case *tg.MessageMediaDocument:
			return tr.classifyDocument(item), caption
		}
	case *tg.MessageMediaGeo, *tg.MessageMediaGeoLive, *tg.MessageMediaVenue:
		return protocol.MediaLocation, caption
	case *tg.MessageMediaContact:
		return protocol.MediaContact, caption
	case *tg.MessageMediaPoll:
		// The poll body is synthesised by extractText rather than taken from
		// the caption, so it is passed through whole with no header.
		return "", tr.extractText(msg)
	}
	return "", caption
}
func (tr *TelegramReader) extractText(msg *tg.Message) string {
text := applyTextURLEntities(msg.Message, msg.Entities)
+81 -6
View File
@@ -247,7 +247,7 @@ func (xr *XPublicReader) fetchAccount(ctx context.Context, username string) ([]p
continue
}
msgs, title, err := parseXRSSMessages(body, username)
msgs, sources, title, err := parseXRSSMessagesWithMedia(body, username)
if err != nil {
log.Printf("[x] @%s: instance %s: parse error: %v", username, instance, err)
lastErr = fmt.Errorf("%s: %w", instance, err)
@@ -255,9 +255,15 @@ func (xr *XPublicReader) fetchAccount(ctx context.Context, username string) ([]p
}
// Filter out garbled messages (invalid UTF-8 or mostly non-printable).
cleaned := msgs[:0]
for _, m := range msgs {
cleanedSources := sources[:0]
for i, m := range msgs {
if isReadableText(m.Text) {
cleaned = append(cleaned, m)
if i < len(sources) {
cleanedSources = append(cleanedSources, sources[i])
} else {
cleanedSources = append(cleanedSources, mediaSource{})
}
} else {
log.Printf("[x] @%s: skipping garbled message ID=%d (len=%d)", username, m.ID, len(m.Text))
}
@@ -266,6 +272,14 @@ func (xr *XPublicReader) fetchAccount(ctx context.Context, username string) ([]p
lastErr = fmt.Errorf("%s: all %d messages were garbled", instance, len(msgs))
continue
}
// Run image downloads when a media cache is attached. Each Nitter
// item carries an image URL we extracted from the description; for
// non-image media types we have no public URL to fetch on X, so the
// downstream rendering simply falls back to the legacy [TAG]\ncaption
// form for those.
if cache := xr.feed.MediaCache(); cache != nil && len(cleanedSources) > 0 {
applyHTTPMediaSources(ctx, cache, cleaned, cleanedSources)
}
return cleaned, title, nil
}
if lastErr == nil {
@@ -291,19 +305,32 @@ type xRSSItem struct {
}
// parseXRSSMessages parses an RSS feed body and returns the messages and the
// feed title, discarding the per-message media descriptors that
// parseXRSSMessagesWithMedia also produces.
func parseXRSSMessages(body []byte, feedUser string) ([]protocol.Message, string, error) {
	parsed, _, feedTitle, parseErr := parseXRSSMessagesWithMedia(body, feedUser)
	return parsed, feedTitle, parseErr
}
// parseXRSSMessagesWithMedia parses a Nitter RSS feed and additionally
// returns one mediaSource per parsed message — same length and order — so
// the caller can run HTTP downloads against the extracted image URLs and
// rewrite messages to use the [IMAGE]<size>:<dl>:<ch>:<blk>:<crc32> form.
// X posts on Nitter can contain multiple images per status; we only surface
// the *first* one for now, which keeps the download pipeline simple and
// matches what the legacy text rendering shows.
func parseXRSSMessagesWithMedia(body []byte, feedUser string) ([]protocol.Message, []mediaSource, string, error) {
body = sanitizeUTF8(body)
var feed xRSS
if err := xml.Unmarshal(body, &feed); err != nil {
return nil, "", fmt.Errorf("parse rss: %w", err)
return nil, nil, "", fmt.Errorf("parse rss: %w", err)
}
if len(feed.Channel.Items) == 0 {
return nil, "", fmt.Errorf("empty rss feed")
return nil, nil, "", fmt.Errorf("empty rss feed")
}
title := strings.TrimSpace(feed.Channel.Title)
feedUserLower := strings.ToLower(strings.TrimPrefix(feedUser, "@"))
msgs := make([]protocol.Message, 0, len(feed.Channel.Items))
sources := make([]mediaSource, 0, len(feed.Channel.Items))
for _, item := range feed.Channel.Items {
id, err := extractXStatusID(item.GUID, item.Link)
if err != nil {
@@ -329,12 +356,60 @@ func parseXRSSMessages(body []byte, feedUser string) ([]protocol.Message, string
}
}
// Best-effort image extraction from the description / encoded HTML.
src := mediaSource{}
if u := extractFirstImgSrc(item.Description); u != "" {
src = mediaSource{tag: protocol.MediaImage, url: u}
} else if u := extractFirstImgSrc(item.Encoded); u != "" {
src = mediaSource{tag: protocol.MediaImage, url: u}
}
msgs = append(msgs, protocol.Message{ID: id, Timestamp: ts, Text: text})
sources = append(sources, src)
}
if len(msgs) == 0 {
return nil, "", fmt.Errorf("no parseable posts")
return nil, nil, "", fmt.Errorf("no parseable posts")
}
return msgs, title, nil
return msgs, sources, title, nil
}
// extractFirstImgSrc scans an HTML fragment for the first "<img " tag and
// returns the value of the first "src=" attribute found after it. Tag and
// attribute names are matched ASCII case-insensitively. The value may be
// double-quoted, single-quoted, or bare (terminated by a space or '>').
// It returns "" when there is no img tag, no src attribute, or the value is
// unterminated.
//
// This is a deliberately small string scanner, not an HTML parser: it only
// needs to cope with the simple markup feed generators emit.
func extractFirstImgSrc(htmlFrag string) string {
	if htmlFrag == "" {
		return ""
	}
	// Search the original string directly. Lower-casing a copy and reusing
	// its offsets is unsafe because strings.ToLower can change the byte
	// length of non-ASCII text, which misaligns (or overruns) the slice.
	idx := indexASCIIFold(htmlFrag, "<img ")
	if idx < 0 {
		return ""
	}
	tail := htmlFrag[idx:]
	srcIdx := indexASCIIFold(tail, "src=")
	if srcIdx < 0 {
		return ""
	}
	tail = tail[srcIdx+len("src="):]
	if len(tail) == 0 {
		return ""
	}
	quote := tail[0]
	if quote != '"' && quote != '\'' {
		// Bare attribute value — read until whitespace or '>'.
		end := strings.IndexAny(tail, " >")
		if end < 0 {
			return ""
		}
		return strings.TrimSpace(tail[:end])
	}
	tail = tail[1:]
	end := strings.IndexByte(tail, quote)
	if end < 0 {
		return ""
	}
	return tail[:end]
}

// indexASCIIFold returns the byte index of the first occurrence of lowerSub
// in s, treating ASCII letters in s case-insensitively, or -1 if absent.
// lowerSub must already be lower-case ASCII. Offsets refer to s itself.
func indexASCIIFold(s, lowerSub string) int {
	n := len(lowerSub)
	for i := 0; i+n <= len(s); i++ {
		j := 0
		for ; j < n; j++ {
			c := s[i+j]
			if 'A' <= c && c <= 'Z' {
				c += 'a' - 'A'
			}
			if c != lowerSub[j] {
				break
			}
		}
		if j == n {
			return i
		}
	}
	return -1
}
// extractLinkUsername extracts the username from a Nitter/X status URL.
+33
View File
@@ -0,0 +1,33 @@
package web
import (
"reflect"
"testing"
)
// TestNormaliseAutoUpdateList is a table-driven check of
// normaliseAutoUpdateList: leading "@" removal, whitespace trimming, dropping
// of empty entries, and order-preserving de-duplication. A nil result is
// treated as equivalent to an empty slice.
func TestNormaliseAutoUpdateList(t *testing.T) {
	type testCase struct {
		name string
		in   []string
		want []string
	}
	table := []testCase{
		{name: "nil", in: nil, want: []string{}},
		{name: "empty", in: []string{}, want: []string{}},
		{name: "strip @", in: []string{"@one", "two"}, want: []string{"one", "two"}},
		{name: "trim whitespace", in: []string{" one ", "\ttwo\n"}, want: []string{"one", "two"}},
		{name: "drop empties", in: []string{"one", "", " ", "@", "two"}, want: []string{"one", "two"}},
		{name: "dedupe preserves order", in: []string{"a", "b", "@a", "c", "b"}, want: []string{"a", "b", "c"}},
		{name: "dedupe across @ form", in: []string{"@chan", "chan"}, want: []string{"chan"}},
	}
	for i := range table {
		tc := table[i]
		t.Run(tc.name, func(t *testing.T) {
			result := normaliseAutoUpdateList(tc.in)
			if result == nil {
				// Normalise nil to empty so DeepEqual compares contents only.
				result = []string{}
			}
			if reflect.DeepEqual(result, tc.want) {
				return
			}
			t.Errorf("normaliseAutoUpdateList(%v) = %v, want %v", tc.in, result, tc.want)
		})
	}
}
+356
View File
@@ -0,0 +1,356 @@
package web
import (
"bytes"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync/atomic"
"unicode"
"github.com/sartoopjj/thefeed/internal/client"
"github.com/sartoopjj/thefeed/internal/protocol"
)
// mediaDLProgress tracks how many blocks of a single in-flight download have
// been fetched. The frontend polls /api/media/progress to drive a smooth
// per-block counter while the xhr is still reading the response body.
// Both fields are accessed with sync/atomic (StoreInt32 / LoadInt32) because
// the download goroutine and the progress handler touch them concurrently.
type mediaDLProgress struct {
// completed is the number of blocks fetched so far, including block 0.
completed int32
// total is the block count of the download as requested by the client.
total int32
}
// mediaProgressKey builds the lookup key for an in-flight download from the
// same (channel, block count, crc) values that appear on the media GET URL.
// The key has the form "<channel>:<blocks>:<crc>", with channel and blocks in
// decimal and crc as eight lower-case hex digits.
func mediaProgressKey(channel uint16, blocks uint16, crc uint32) string {
	const layout = "%d:%d:%08x"
	return fmt.Sprintf(layout, channel, blocks, crc)
}
// handleMediaProgress reports block progress for the download identified by
// the ch, blk and crc query parameters (crc is hex). It writes a JSON object
// with "active", "completed" and "total". When no matching download is
// registered, active is false, completed is 0 and total echoes blk.
// Parameter parse errors are ignored; whatever value ParseUint returned is
// used to build the lookup key.
func (s *Server) handleMediaProgress(w http.ResponseWriter, r *http.Request) {
	params := r.URL.Query()
	channel, _ := strconv.ParseUint(params.Get("ch"), 10, 16)
	blocks, _ := strconv.ParseUint(params.Get("blk"), 10, 16)
	crc, _ := strconv.ParseUint(strings.TrimSpace(params.Get("crc")), 16, 32)
	key := mediaProgressKey(uint16(channel), uint16(blocks), uint32(crc))

	// Hold the lock only for the map lookup; counters are read atomically.
	s.dlMu.Lock()
	entry := s.dlProgress[key]
	s.dlMu.Unlock()

	resp := map[string]any{"active": false, "completed": 0, "total": int(blocks)}
	if entry != nil {
		resp = map[string]any{
			"active":    true,
			"completed": int(atomic.LoadInt32(&entry.completed)),
			"total":     int(atomic.LoadInt32(&entry.total)),
		}
	}
	writeJSON(w, resp)
}
// handleMediaGet streams a media blob assembled from the
// (channel, blocks, crc) tuple embedded in a message's text.
//
// Query string:
//
//	ch=<uint16>      media channel number (must satisfy protocol.IsMediaChannel)
//	blk=<uint16>     total block count (non-zero)
//	size=<bytes>     expected file size (Content-Length); 0..100 MiB
//	crc=<hex8>       expected CRC32 of full body
//	name=<filename>  optional filename for Content-Disposition
//	type=<mime>      optional mime type override; sanitized
//
// Flow: validate parameters, try the disk cache, otherwise fetch block 0 to
// read the media header, then stream the remaining blocks through a pipe into
// the decompressor and on to the response. When the cache is configured and
// the full expected size was received, the decompressed body is stored in the
// disk cache afterwards.
//
// Responses: 405 for non-GET, 400 for bad parameters, 503 when no fetcher is
// configured, 499 when the request was cancelled during the first fetch, and
// 502 for upstream / format / decompression failures that occur before the
// response headers are written. Failures after that point are only logged.
func (s *Server) handleMediaGet(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	q := r.URL.Query()
	ch64, err := strconv.ParseUint(q.Get("ch"), 10, 16)
	if err != nil {
		http.Error(w, "bad ch", http.StatusBadRequest)
		return
	}
	channel := uint16(ch64)
	if !protocol.IsMediaChannel(channel) {
		http.Error(w, "ch out of media range", http.StatusBadRequest)
		return
	}
	blk64, err := strconv.ParseUint(q.Get("blk"), 10, 16)
	if err != nil || blk64 == 0 {
		http.Error(w, "bad blk", http.StatusBadRequest)
		return
	}
	blockCount := uint16(blk64)
	const maxClaimedSize = 100 * 1024 * 1024
	// A missing or unparsable size is treated as 0 ("unknown").
	expectedSize, _ := strconv.ParseInt(q.Get("size"), 10, 64)
	if expectedSize < 0 || expectedSize > maxClaimedSize {
		http.Error(w, "bad size", http.StatusBadRequest)
		return
	}
	expectedCRC := uint32(0)
	if v := strings.TrimSpace(q.Get("crc")); v != "" {
		c, err := strconv.ParseUint(v, 16, 32)
		if err != nil {
			http.Error(w, "bad crc", http.StatusBadRequest)
			return
		}
		expectedCRC = uint32(c)
	}

	s.mu.RLock()
	fetcher := s.fetcher
	s.mu.RUnlock()
	if fetcher == nil {
		http.Error(w, "fetcher not configured", http.StatusServiceUnavailable)
		return
	}

	ctx := r.Context()

	// Disk-cache hit: serve directly without ever talking to DNS.
	if s.mediaCache != nil && expectedCRC != 0 && expectedSize > 0 {
		if body, mime, ok := s.mediaCache.Get(expectedSize, expectedCRC); ok {
			servedMime := sanitizeMime(q.Get("type"))
			if servedMime == "application/octet-stream" {
				if mime != "" {
					servedMime = sanitizeMime(mime)
				} else if sniffed := http.DetectContentType(body); sniffed != "" {
					servedMime = sanitizeMime(sniffed)
				}
			}
			w.Header().Set("Content-Type", servedMime)
			w.Header().Set("Content-Length", strconv.Itoa(len(body)))
			w.Header().Set("Cache-Control", "private, max-age=86400")
			w.Header().Set("X-Total-Blocks", strconv.Itoa(int(blockCount)))
			w.Header().Set("X-Cache", "HIT")
			if filename := sanitizeFilename(q.Get("name")); filename != "" {
				w.Header().Set("Content-Disposition", "inline; filename=\""+filename+"\"")
			}
			if _, err := w.Write(body); err != nil {
				s.addLog(fmt.Sprintf("media disk-cache write failed: %v", err))
			}
			return
		}
	}

	// Fetch block 0 synchronously: it carries the protocol header (CRC32,
	// version, compression). We need that before we can decompress and
	// before we can sniff Content-Type from the decompressed body.
	firstBlock, err := fetcher.FetchBlock(ctx, channel, 0)
	if err != nil {
		if ctx.Err() != nil {
			http.Error(w, "fetch cancelled", 499)
			return
		}
		http.Error(w, fmt.Sprintf("fetch media: %v", err), http.StatusBadGateway)
		return
	}
	if len(firstBlock) < protocol.MediaBlockHeaderLen {
		http.Error(w, "malformed block 0", http.StatusBadGateway)
		return
	}
	header, err := protocol.DecodeMediaBlockHeader(firstBlock[:protocol.MediaBlockHeaderLen])
	if err != nil {
		http.Error(w, "malformed block 0", http.StatusBadGateway)
		return
	}
	if expectedCRC != 0 && header.CRC32 != expectedCRC {
		http.Error(w, "content hash mismatch", http.StatusBadGateway)
		return
	}
	firstCompressed := firstBlock[protocol.MediaBlockHeaderLen:]

	// Register this download so /api/media/progress can report block
	// progress as the client polls. Block 0 is already fetched.
	progKey := mediaProgressKey(channel, blockCount, expectedCRC)
	prog := &mediaDLProgress{total: int32(blockCount), completed: 1}
	s.dlMu.Lock()
	s.dlProgress[progKey] = prog
	s.dlMu.Unlock()
	defer func() {
		s.dlMu.Lock()
		delete(s.dlProgress, progKey)
		s.dlMu.Unlock()
	}()

	// Pipe compressed bytes (block-0 payload + later blocks) into a
	// decompressor reader. Fed by a goroutine; consumed below for sniffing
	// and for streaming to the HTTP response.
	pipeR, pipeW := io.Pipe()
	// Always close the read side on return. A write to an io.Pipe blocks
	// until it is consumed or the reader is closed, so without this the
	// producer goroutine would be stuck forever whenever the handler bails
	// out early (decompress error, sniff error, client write error).
	// Closing an already-closed PipeReader is harmless.
	defer pipeR.Close()
	go func() {
		var pipeErr error
		defer func() { pipeW.CloseWithError(pipeErr) }()
		if _, err := pipeW.Write(firstCompressed); err != nil {
			pipeErr = err
			return
		}
		if blockCount > 1 {
			progressCB := func(done, _ int) {
				// done counts blocks 1..N-1; add 1 for block 0 already fetched.
				atomic.StoreInt32(&prog.completed, int32(done+1))
			}
			pipeErr = fetcher.FetchMediaBlocksStream(ctx, channel, 1, blockCount-1, pipeW, progressCB)
		}
	}()

	body, err := client.DecompressMediaReader(pipeR, header.Compression)
	if err != nil {
		http.Error(w, fmt.Sprintf("decompress: %v", err), http.StatusBadGateway)
		return
	}
	defer body.Close()

	// Tee decompressed bytes into a buffer so we can persist them to the
	// disk cache after a successful response.
	var teeBuf *bytes.Buffer
	if s.mediaCache != nil && expectedCRC != 0 && expectedSize > 0 && expectedSize <= mediaCacheMaxFileExt {
		teeBuf = bytes.NewBuffer(make([]byte, 0, expectedSize))
	}

	// Sniff Content-Type from the first decompressed bytes before flushing
	// headers — once Content-Type goes out we can't change it.
	const sniffSize = 512
	sniff := make([]byte, sniffSize)
	n, err := io.ReadFull(body, sniff)
	if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
		http.Error(w, fmt.Sprintf("read media: %v", err), http.StatusBadGateway)
		return
	}
	sniff = sniff[:n]
	mime := sanitizeMime(q.Get("type"))
	if mime == "application/octet-stream" {
		if got := http.DetectContentType(sniff); got != "" {
			mime = sanitizeMime(got)
		}
	}
	filename := sanitizeFilename(q.Get("name"))

	w.Header().Set("Content-Type", mime)
	if expectedSize > 0 {
		w.Header().Set("Content-Length", strconv.FormatInt(expectedSize, 10))
	}
	w.Header().Set("Cache-Control", "private, max-age=86400")
	w.Header().Set("X-Total-Blocks", strconv.Itoa(int(blockCount)))
	w.Header().Set("X-Cache", "MISS")
	w.Header().Set("X-Media-Compression", header.Compression.String())
	if filename != "" {
		w.Header().Set("Content-Disposition", "inline; filename=\""+filename+"\"")
	}

	flusher, _ := w.(http.Flusher)

	if teeBuf != nil {
		teeBuf.Write(sniff)
	}
	if _, err := w.Write(sniff); err != nil {
		s.addLog(fmt.Sprintf("media write head failed: %v", err))
		return
	}
	if flusher != nil {
		flusher.Flush()
	}

	dst := io.Writer(&flushAfterEachWriter{w: w, flusher: flusher})
	if teeBuf != nil {
		dst = io.MultiWriter(dst, teeBuf)
	}
	// Small buffer so the browser sees many small chunks instead of one big
	// one — the xhr onprogress event fires per chunk, which is what drives
	// the smooth K/N block counter on the client.
	buf := make([]byte, 2048)
	if _, err := io.CopyBuffer(dst, body, buf); err != nil {
		s.addLog(fmt.Sprintf("media stream failed: %v", err))
		return
	}

	if teeBuf == nil {
		s.addLog(fmt.Sprintf("media disk-cache skipped: size=%d crc=%x mediaCache=%v", expectedSize, expectedCRC, s.mediaCache != nil))
	} else if expectedSize > 0 && int64(teeBuf.Len()) != expectedSize {
		s.addLog(fmt.Sprintf("media disk-cache skipped: tee=%d expected=%d (truncated stream)", teeBuf.Len(), expectedSize))
	} else {
		if err := s.mediaCache.Put(int64(teeBuf.Len()), expectedCRC, teeBuf.Bytes(), mime); err != nil {
			s.addLog(fmt.Sprintf("media disk-cache put failed: %v", err))
		} else {
			s.addLog(fmt.Sprintf("media cached: %d bytes, crc=%08x, mime=%s", teeBuf.Len(), expectedCRC, mime))
		}
	}
}
// flushAfterEachWriter wraps an http.ResponseWriter and pushes every
// successful write out to the client immediately instead of letting the
// response buffer fill up. flusher may be nil (the wrapped writer does not
// support flushing), in which case writes simply pass through.
type flushAfterEachWriter struct {
	w       http.ResponseWriter
	flusher http.Flusher
}

// Write forwards p to the wrapped ResponseWriter and flushes when the write
// succeeded. A failed write is reported unchanged and is not followed by a
// flush.
func (fw *flushAfterEachWriter) Write(p []byte) (int, error) {
	written, err := fw.w.Write(p)
	if err != nil {
		return written, err
	}
	fw.Flush()
	return written, nil
}

// Flush flushes the underlying writer when a flusher is available; without
// one it is a no-op.
func (fw *flushAfterEachWriter) Flush() {
	if fw.flusher == nil {
		return
	}
	fw.flusher.Flush()
}
// sanitizeMime reduces an untrusted Content-Type value to a bare
// "type/subtype" string that is safe to echo back in a response header.
//
// Parameters (";charset=..." etc.) are dropped. The result must contain
// exactly one '/', with a non-empty part on each side, and may only use
// ASCII letters, ASCII digits and the punctuation '-', '+', '.'. Anything
// else collapses to "application/octet-stream".
//
// Types a browser may execute script from are also collapsed to
// "application/octet-stream": text/html, text/xml, application/xml and every
// "+xml" structured type (XHTML, SVG, Atom, ...). The comparison is
// case-insensitive; an accepted value is returned with its original casing.
func sanitizeMime(s string) string {
	const fallback = "application/octet-stream"

	s = strings.TrimSpace(s)
	if i := strings.IndexByte(s, ';'); i >= 0 {
		s = strings.TrimSpace(s[:i])
	}

	// Exactly one slash, with something on both sides. An empty string fails
	// here as well (slash == -1).
	slash := strings.IndexByte(s, '/')
	if slash <= 0 || slash == len(s)-1 || strings.IndexByte(s[slash+1:], '/') >= 0 {
		return fallback
	}

	for _, r := range s {
		switch {
		case r == '/', r == '-', r == '+', r == '.':
			// allowed punctuation
		case r > unicode.MaxASCII:
			// unicode.IsLetter/IsDigit would accept non-ASCII letters and
			// digits, which are not valid in a media type.
			return fallback
		case !unicode.IsLetter(r) && !unicode.IsDigit(r):
			return fallback
		}
	}

	lower := strings.ToLower(s)
	switch lower {
	case "text/html", "text/xml", "application/xml":
		return fallback
	}
	// Covers application/xhtml+xml, image/svg+xml and every other XML
	// flavour: browsers render these as documents and run embedded script.
	if strings.HasSuffix(lower, "+xml") {
		return fallback
	}
	return s
}
// sanitizeFilename turns an untrusted file name into something safe to place
// inside a quoted Content-Disposition filename parameter.
//
// It keeps only the final path component (both '/' and '\' count as
// separators), drops ASCII control characters, double quotes and
// backslashes, and limits the result to 200 bytes. The cut is made on a rune
// boundary so a multi-byte character is never split into invalid UTF-8.
// An empty name or ".." (before or after stripping) yields "".
func sanitizeFilename(s string) string {
	const maxLen = 200

	s = strings.TrimSpace(s)
	if i := strings.LastIndexAny(s, `/\`); i >= 0 {
		s = s[i+1:]
	}
	if s == "" || s == ".." {
		return ""
	}

	var b strings.Builder
	for _, r := range s {
		if r < 0x20 || r == 0x7F || r == '"' || r == '\\' {
			continue
		}
		b.WriteRune(r)
	}
	out := b.String()

	if len(out) > maxLen {
		// out is valid UTF-8 (WriteRune re-encodes every rune), so stepping
		// back over continuation bytes (10xxxxxx) lands on a rune start.
		cut := maxLen
		for cut > 0 && out[cut]&0xC0 == 0x80 {
			cut--
		}
		out = out[:cut]
	}

	// Removing control characters can turn e.g. ".\x01." into "..".
	if out == ".." {
		return ""
	}
	return out
}
+179
View File
@@ -0,0 +1,179 @@
package web
import (
"encoding/binary"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
// Naming and size limits for on-disk media cache entries.
const (
	// mediaCacheFileExt is the file name suffix carried by every cache entry.
	mediaCacheFileExt = ".cache"
	// mediaCacheMaxMime is the longest MIME string, in bytes, stored in an
	// entry header.
	mediaCacheMaxMime = 200
	// mediaCacheMaxFileExt is the hard cap on a cached body: 64 MiB.
	mediaCacheMaxFileExt = 64 << 20
)
// mediaDiskCache stores downloaded media blobs on disk so multiple devices
// connected to the same client/server share the cost of one DNS-tunnelled
// fetch. Entries are content-addressed by (size, crc32) and expire once
// their file mtime is older than ttl; Get refreshes the mtime on a hit.
//
// File format: each entry is a single file named
//
// <size>_<crc as 8 hex digits>.cache
//
// containing:
//
// 2 bytes, big-endian — length of the MIME string
// N bytes — MIME string (UTF-8)
// rest — raw file bytes
type mediaDiskCache struct {
// dir is the directory that holds every cache entry.
dir string
// ttl is the maximum entry age; a value <= 0 disables expiry in Get and
// turns Cleanup into a no-op.
ttl time.Duration
// mu serialises Put, Cleanup and Clear. Get does not take it.
mu sync.Mutex
}
// newMediaDiskCache returns a cache rooted at dir whose entries live for ttl.
// The directory (including missing parents) is created with mode 0700. An
// empty dir or a failure to create the directory is returned as an error
// together with a nil cache.
func newMediaDiskCache(dir string, ttl time.Duration) (*mediaDiskCache, error) {
	if len(dir) == 0 {
		return nil, errors.New("media cache dir is empty")
	}
	mkErr := os.MkdirAll(dir, 0o700)
	if mkErr != nil {
		return nil, mkErr
	}
	cache := &mediaDiskCache{ttl: ttl, dir: dir}
	return cache, nil
}
// keyFile returns the path of the entry for (size, crc): the decimal size,
// an underscore, the crc as eight zero-padded lowercase hex digits and the
// cache extension, inside c.dir.
func (c *mediaDiskCache) keyFile(size int64, crc uint32) string {
	name := fmt.Sprintf("%d_%08x", size, crc) + mediaCacheFileExt
	return filepath.Join(c.dir, name)
}
// Get looks up the entry for (size, crc) and returns its body and MIME type.
// ok is false for an invalid key (size <= 0 or crc == 0), a missing or
// unreadable file, an expired entry (which is also deleted), or a file whose
// header or body length does not match. A hit bumps the file's mtime so that
// entries in active use do not expire.
func (c *mediaDiskCache) Get(size int64, crc uint32) (body []byte, mime string, ok bool) {
	if size <= 0 || crc == 0 {
		return nil, "", false
	}

	path := c.keyFile(size, crc)
	st, statErr := os.Stat(path)
	if statErr != nil {
		return nil, "", false
	}
	if expired := c.ttl > 0 && time.Since(st.ModTime()) > c.ttl; expired {
		_ = os.Remove(path)
		return nil, "", false
	}

	raw, readErr := os.ReadFile(path)
	if readErr != nil || len(raw) < 2 {
		return nil, "", false
	}

	// Header: two big-endian bytes giving the MIME length, then the MIME.
	mimeLen := int(binary.BigEndian.Uint16(raw[:2]))
	payloadStart := 2 + mimeLen
	if mimeLen > mediaCacheMaxMime || payloadStart > len(raw) {
		return nil, "", false
	}

	payload := raw[payloadStart:]
	if int64(len(payload)) != size {
		// Corrupt or partial write — treat as miss.
		return nil, "", false
	}

	now := time.Now()
	_ = os.Chtimes(path, now, now)
	return payload, string(raw[2:payloadStart]), true
}
// Put stores body and its MIME type under the key (size, crc).
//
// The entry is written to "<entry>.tmp" and then renamed over the final
// path, so a reader never sees a half-written file. The temporary file is
// removed on every failure path, including a failed rename, so aborted
// writes do not pile up next to the real entries.
//
// It returns an error when the key is invalid (size <= 0, crc == 0, or size
// differing from len(body)), when body exceeds mediaCacheMaxFileExt, or when
// any file operation fails. A MIME string longer than mediaCacheMaxMime
// bytes is truncated rather than rejected.
func (c *mediaDiskCache) Put(size int64, crc uint32, body []byte, mime string) error {
	if size <= 0 || crc == 0 || int64(len(body)) != size {
		return errors.New("media cache: invalid put")
	}
	if len(body) > mediaCacheMaxFileExt {
		return errors.New("media cache: body too large")
	}
	if len(mime) > mediaCacheMaxMime {
		mime = mime[:mediaCacheMaxMime]
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	path := c.keyFile(size, crc)
	tmp := path + ".tmp"
	if err := writeMediaCacheEntry(tmp, mime, body); err != nil {
		_ = os.Remove(tmp) // best effort: nothing useful to do if this fails too
		return err
	}
	if err := os.Rename(tmp, path); err != nil {
		_ = os.Remove(tmp) // don't leave the orphan behind; sweeps only match ".cache"
		return err
	}
	return nil
}

// writeMediaCacheEntry creates (or truncates) path and writes one cache
// entry to it: a 2-byte big-endian MIME length, the MIME string, then body.
// The file is closed before returning; the caller removes it on error.
func writeMediaCacheEntry(path, mime string, body []byte) error {
	f, err := os.Create(path)
	if err != nil {
		return err
	}
	var prefix [2]byte
	binary.BigEndian.PutUint16(prefix[:], uint16(len(mime)))
	for _, chunk := range [][]byte{prefix[:], []byte(mime), body} {
		if _, err := f.Write(chunk); err != nil {
			f.Close()
			return err
		}
	}
	return f.Close()
}
// Cleanup deletes every ".cache" file in the cache directory whose mtime is
// more than ttl in the past and reports how many files were removed.
// Directories and files with other suffixes are ignored. With ttl <= 0, or
// when the directory cannot be listed, nothing is removed and 0 is returned.
func (c *mediaDiskCache) Cleanup() int {
	if c.ttl <= 0 {
		return 0
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	listing, err := os.ReadDir(c.dir)
	if err != nil {
		return 0
	}

	count := 0
	now := time.Now()
	for _, item := range listing {
		if item.IsDir() {
			continue
		}
		name := item.Name()
		if !strings.HasSuffix(name, mediaCacheFileExt) {
			continue
		}
		info, infoErr := item.Info()
		if infoErr != nil || now.Sub(info.ModTime()) <= c.ttl {
			continue
		}
		if rmErr := os.Remove(filepath.Join(c.dir, name)); rmErr == nil {
			count++
		}
	}
	return count
}
// Clear removes every ".cache" file in the cache directory regardless of
// age and reports how many were deleted. Directories and files with other
// suffixes are left alone; if the directory cannot be listed the result is 0.
func (c *mediaDiskCache) Clear() int {
	c.mu.Lock()
	defer c.mu.Unlock()

	listing, err := os.ReadDir(c.dir)
	if err != nil {
		return 0
	}

	count := 0
	for _, item := range listing {
		if item.IsDir() {
			continue
		}
		name := item.Name()
		if !strings.HasSuffix(name, mediaCacheFileExt) {
			continue
		}
		if rmErr := os.Remove(filepath.Join(c.dir, name)); rmErr == nil {
			count++
		}
	}
	return count
}
+42
View File
@@ -0,0 +1,42 @@
package web
import "testing"
// TestSanitizeMime checks that sanitizeMime passes well-formed media types
// through (dropping parameters) and collapses empty, malformed and
// script-capable types to application/octet-stream.
func TestSanitizeMime(t *testing.T) {
	const fallback = "application/octet-stream"
	tests := []struct{ in, want string }{
		{"", fallback},
		{"image/jpeg", "image/jpeg"},
		{"image/png; charset=utf-8", "image/png"},
		{"text/html", fallback},             // blocked
		{"application/xhtml+xml", fallback}, // blocked
		{"image/svg+xml", fallback},         // blocked (XSS via SVG)
		{"image/jpeg<script>", fallback},    // bad chars
		{"weird", fallback},                 // no slash
		{"/leading", fallback},
		{"trailing/", fallback},
		{"application/vnd.api+json", "application/vnd.api+json"},
		{"image/webp", "image/webp"},
	}
	for _, tc := range tests {
		got := sanitizeMime(tc.in)
		if got == tc.want {
			continue
		}
		t.Errorf("sanitizeMime(%q) = %q, want %q", tc.in, got, tc.want)
	}
}
// TestSanitizeFilename checks that sanitizeFilename keeps only the last path
// component, strips control characters and quotes, and rejects "..".
func TestSanitizeFilename(t *testing.T) {
	tests := []struct{ in, want string }{
		{"", ""},
		{"foo.png", "foo.png"},
		{"../../etc/passwd", "passwd"},
		{"foo/bar/baz.txt", "baz.txt"},
		{"weird\nname.txt", "weirdname.txt"},
		{`bad"quote"name`, "badquotename"},
		{"..", ""},
	}
	for _, tc := range tests {
		got := sanitizeFilename(tc.in)
		if got == tc.want {
			continue
		}
		t.Errorf("sanitizeFilename(%q) = %q, want %q", tc.in, got, tc.want)
	}
}
File diff suppressed because it is too large Load Diff
+372 -14
View File
@@ -49,12 +49,31 @@ type Config struct {
}
// Profile wraps a Config with a user-chosen nickname and a unique ID.
// AutoUpdate is the per-profile list of channel usernames the auto-update
// goroutine should refresh; AutoUpdateInterval (seconds, 0 → default) sets
// the cadence.
type Profile struct {
ID string `json:"id"`
Nickname string `json:"nickname"`
Config Config `json:"config"`
ID string `json:"id"`
Nickname string `json:"nickname"`
Config Config `json:"config"`
AutoUpdate []string `json:"autoUpdate,omitempty"`
AutoUpdateInterval int `json:"autoUpdateInterval,omitempty"`
}
// Timing knobs for the auto-update loop.
const (
	// minAutoUpdateInterval is the shortest permitted gap between ticks.
	// User-supplied intervals below it are raised to it: the DNS path is
	// expensive and the server's own fetch cycle is much longer anyway.
	minAutoUpdateInterval = time.Minute

	// serverFetchSettleDelay is added to the server's announced nextFetch
	// time before we ask for fresh data, so the upstream Telegram fetch has
	// finished and the metadata snapshot is coherent.
	serverFetchSettleDelay = 30 * time.Second

	// autoUpdateStartupDelay postpones the very first tick so the initial
	// metadata and resolver checks can land before polling begins.
	autoUpdateStartupDelay = 30 * time.Second
)
// SavedResolverScore stores persistent resolver performance data.
type SavedResolverScore struct {
Success int64 `json:"success"`
@@ -137,6 +156,17 @@ type Server struct {
titlesMu sync.Mutex
titlesLoading bool
titlesBackoffUntil time.Time
// dlMu guards dlProgress. Active media downloads register their block
// counter here so the frontend can poll /api/media/progress and show
// per-block updates instead of waiting for byte chunks.
dlMu sync.Mutex
dlProgress map[string]*mediaDLProgress
// mediaCache is a disk-backed store for downloaded media bytes so that
// multiple devices on the same network share a single DNS-tunnelled
// fetch. Entries expire after 7 days.
mediaCache *mediaDiskCache
}
// New creates a new web server.
@@ -154,6 +184,11 @@ func New(dataDir string, port int, host string, password string) (*Server, error
scanner := client.NewResolverScanner()
mediaCache, mcErr := newMediaDiskCache(filepath.Join(dataDir, "media-cache"), 7*24*time.Hour)
if mcErr != nil {
log.Printf("Warning: media disk cache disabled: %v", mcErr)
}
s := &Server{
dataDir: dataDir,
port: port,
@@ -165,6 +200,13 @@ func New(dataDir string, port int, host string, password string) (*Server, error
lastMsgIDs: make(map[int]uint32),
lastHashes: make(map[int]uint32),
scanner: scanner,
mediaCache: mediaCache,
dlProgress: make(map[string]*mediaDLProgress),
}
if mediaCache != nil {
go mediaCache.Cleanup()
go s.runMediaCacheSweep()
}
// Migrate per-profile resolvers into the shared bank on first run.
@@ -213,6 +255,8 @@ func (s *Server) Run() error {
mux.HandleFunc("/api/events", s.handleSSE)
mux.HandleFunc("/api/profiles", s.handleProfiles)
mux.HandleFunc("/api/profiles/switch", s.handleProfileSwitch)
mux.HandleFunc("/api/auto-update", s.handleAutoUpdate)
mux.HandleFunc("/api/auto-update/toggle", s.handleAutoUpdateToggle)
mux.HandleFunc("/api/settings", s.handleSettings)
mux.HandleFunc("/api/version-check", s.handleVersionCheck)
mux.HandleFunc("/api/cache/clear", s.handleClearCache)
@@ -230,6 +274,11 @@ func (s *Server) Run() error {
mux.HandleFunc("/api/scanner/progress", s.handleScannerProgress)
mux.HandleFunc("/api/scanner/apply", s.handleScannerApply)
mux.HandleFunc("/api/scanner/presets", s.handleScannerPresets)
// Media (image/file) downloader: assembles a binary blob from a media
// channel and streams it back. See internal/web/media.go for the param
// contract.
mux.HandleFunc("/api/media/get", s.handleMediaGet)
mux.HandleFunc("/api/media/progress", s.handleMediaProgress)
mux.HandleFunc("/", s.handleIndex)
// Listen on the specified host (default 127.0.0.1)
@@ -782,9 +831,140 @@ func (s *Server) initFetcher() error {
s.fetcher = fetcher
s.cache = cache
go cache.Cleanup() // remove channel files not updated in 7 days
// Goroutine dies with fetcherCtx, so a profile switch / config change
// stops it cleanly.
go s.runAutoUpdateLoop(ctx)
return nil
}
// runAutoUpdateLoop periodically refreshes the active profile's AutoUpdate
// channels until ctx is cancelled. It sleeps autoUpdateStartupDelay first,
// then repeatedly waits for the duration chosen by nextAutoUpdateWait and
// runs tickAutoUpdate. Iterations for which canAutoUpdate reports false are
// skipped without updating the last-tick time.
func (s *Server) runAutoUpdateLoop(ctx context.Context) {
	startup := time.After(autoUpdateStartupDelay)
	select {
	case <-ctx.Done():
		return
	case <-startup:
	}

	lastTick := time.Time{}
	for {
		pause := time.After(s.nextAutoUpdateWait(lastTick))
		select {
		case <-ctx.Done():
			return
		case <-pause:
		}

		if s.canAutoUpdate() {
			s.tickAutoUpdate()
			lastTick = time.Now()
		}
	}
}
// nextAutoUpdateWait returns how long the auto-update loop should sleep
// before its next tick. The result is never below minAutoUpdateInterval.
//
//   - If the active profile sets AutoUpdateInterval > 0, that many seconds
//     are used (raised to the minimum if smaller).
//   - Otherwise the wait runs until serverFetchSettleDelay after the server's
//     announced nextFetch, so we pull just-refreshed data. If nextFetch is
//     unknown (0) or that moment is already near or past, the minimum
//     interval is used.
//
// lastTick is accepted for interface stability but no longer consulted. The
// previous "at least the minimum interval since lastTick" clamp could never
// fire, because the delay is already floored at minAutoUpdateInterval and
// the remainder it compared against cannot exceed that floor.
func (s *Server) nextAutoUpdateWait(lastTick time.Time) time.Duration {
	// A load failure simply means "no override"; fall through to the
	// server-aligned schedule.
	if pl, _ := s.loadProfiles(); pl != nil && pl.Active != "" {
		for _, p := range pl.Profiles {
			if p.ID != pl.Active {
				continue
			}
			if p.AutoUpdateInterval > 0 {
				override := time.Duration(p.AutoUpdateInterval) * time.Second
				if override < minAutoUpdateInterval {
					override = minAutoUpdateInterval
				}
				return override
			}
			break
		}
	}

	s.mu.RLock()
	nf := s.nextFetch
	s.mu.RUnlock()
	if nf == 0 {
		return minAutoUpdateInterval
	}

	target := time.Unix(int64(nf), 0).Add(serverFetchSettleDelay)
	if delay := time.Until(target); delay > minAutoUpdateInterval {
		return delay
	}
	return minAutoUpdateInterval
}
// canAutoUpdate reports whether an auto-update tick may run right now. It is
// false when there is no fetcher, when the channel list is still empty, or
// when the resolver scanner is running or paused (its DNS traffic would
// compete with ours).
func (s *Server) canAutoUpdate() bool {
	s.mu.RLock()
	haveFetcher := s.fetcher != nil
	channelCount := len(s.channels)
	scanner := s.scanner
	s.mu.RUnlock()

	if !haveFetcher || channelCount == 0 {
		return false
	}
	if scanner == nil {
		return true
	}
	state := scanner.State()
	return state != client.ScannerRunning && state != client.ScannerPaused
}
// tickAutoUpdate starts a refresh for every known channel whose name appears
// in the active profile's AutoUpdate list. List entries are compared after
// trimming whitespace and one leading "@". Each matching channel is
// refreshed in its own goroutine using its 1-based position in the channel
// list. Nothing happens when profiles cannot be loaded, no profile is
// active, the watch list is empty, or no channels are known yet.
func (s *Server) tickAutoUpdate() {
	pl, err := s.loadProfiles()
	if err != nil || pl == nil || pl.Active == "" {
		return
	}

	var watch []string
	for i := range pl.Profiles {
		if pl.Profiles[i].ID == pl.Active {
			watch = pl.Profiles[i].AutoUpdate
			break
		}
	}
	if len(watch) == 0 {
		return
	}

	s.mu.RLock()
	channels := s.channels
	s.mu.RUnlock()
	if len(channels) == 0 {
		return
	}

	wanted := make(map[string]bool, len(watch))
	for _, raw := range watch {
		key := strings.TrimPrefix(strings.TrimSpace(raw), "@")
		wanted[key] = true
	}

	for idx := range channels {
		if wanted[channels[idx].Name] {
			go s.refreshChannel(idx + 1) // 1-indexed
		}
	}
}
func (s *Server) checkLatestVersion(ctx context.Context) (string, error) {
s.mu.RLock()
cfg := s.config
@@ -1961,6 +2141,10 @@ func (s *Server) handleProfiles(w http.ResponseWriter, r *http.Request) {
addToBank(pl, req.Profile.Config.Resolvers)
req.Profile.Config.Resolvers = nil
}
// Carry over fields the edit-profile UI doesn't manage so
// they don't get wiped on save (auto-update list etc.).
req.Profile.AutoUpdate = p.AutoUpdate
req.Profile.AutoUpdateInterval = p.AutoUpdateInterval
pl.Profiles[i] = req.Profile
if p.ID == pl.Active {
needsReinit = true
@@ -2097,6 +2281,164 @@ func (s *Server) handleProfileSwitch(w http.ResponseWriter, r *http.Request) {
writeJSON(w, map[string]any{"ok": true})
}
// handleAutoUpdate serves the active profile's auto-update settings.
//
// GET responds with {channels, intervalSeconds, defaultIntervalSeconds}.
// When no profile is active (or profiles cannot be loaded) it reports an
// empty list and a zero interval.
//
// POST takes {channels, intervalSeconds?}. The channel list is always
// replaced by its normalised form; the interval is only touched when the
// field is present. A negative interval is stored as 0 and a positive one
// below the minimum is raised to the minimum. The response echoes the saved
// values.
//
// Any other method is answered with 405.
func (s *Server) handleAutoUpdate(w http.ResponseWriter, r *http.Request) {
	floorSeconds := int(minAutoUpdateInterval / time.Second)

	switch r.Method {
	case http.MethodGet:
		list := []string{}
		seconds := 0
		if pl, _ := s.loadProfiles(); pl != nil && pl.Active != "" {
			for i := range pl.Profiles {
				if pl.Profiles[i].ID != pl.Active {
					continue
				}
				if saved := pl.Profiles[i].AutoUpdate; saved != nil {
					list = saved
				}
				seconds = pl.Profiles[i].AutoUpdateInterval
				break
			}
		}
		writeJSON(w, map[string]any{
			"channels":               list,
			"intervalSeconds":        seconds,
			"defaultIntervalSeconds": floorSeconds,
		})

	case http.MethodPost:
		var payload struct {
			Channels        []string `json:"channels"`
			IntervalSeconds *int     `json:"intervalSeconds,omitempty"`
		}
		if decodeErr := json.NewDecoder(r.Body).Decode(&payload); decodeErr != nil {
			http.Error(w, "invalid JSON", 400)
			return
		}

		pl, err := s.loadProfiles()
		if err != nil || pl == nil || pl.Active == "" {
			http.Error(w, "no active profile", 400)
			return
		}
		active := -1
		for i := range pl.Profiles {
			if pl.Profiles[i].ID == pl.Active {
				active = i
				break
			}
		}
		if active == -1 {
			http.Error(w, "active profile not found", 400)
			return
		}

		pl.Profiles[active].AutoUpdate = normaliseAutoUpdateList(payload.Channels)
		if payload.IntervalSeconds != nil {
			seconds := *payload.IntervalSeconds
			switch {
			case seconds <= 0:
				seconds = 0
			case seconds < floorSeconds:
				seconds = floorSeconds // floor: never poll faster than the server fetches
			}
			pl.Profiles[active].AutoUpdateInterval = seconds
		}

		if saveErr := s.saveProfiles(pl); saveErr != nil {
			http.Error(w, fmt.Sprintf("save: %v", saveErr), 500)
			return
		}
		writeJSON(w, map[string]any{
			"ok":              true,
			"channels":        pl.Profiles[active].AutoUpdate,
			"intervalSeconds": pl.Profiles[active].AutoUpdateInterval,
		})

	default:
		http.Error(w, "method not allowed", 405)
	}
}
// handleAutoUpdateToggle flips one channel's membership in the active
// profile's auto-update list. It accepts POST with body {channel}; the name
// is trimmed and one leading "@" is removed before comparison.
//
// The stored list is normalised before the toggle is applied. After
// normalisation each name occurs at most once, so switching a channel off
// removes it completely even if the saved list held the same channel in
// several spellings (e.g. "@foo" and "foo"). Previously only the first
// match was removed and the channel could stay enabled while the response
// claimed otherwise.
//
// Responds with {ok, channel, enabled, channels}. Errors: 405 for a
// non-POST method, 400 for bad JSON, an empty channel or a missing active
// profile, 500 when saving fails.
func (s *Server) handleAutoUpdateToggle(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", 405)
		return
	}
	var req struct {
		Channel string `json:"channel"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "invalid JSON", 400)
		return
	}
	name := strings.TrimPrefix(strings.TrimSpace(req.Channel), "@")
	if name == "" {
		http.Error(w, "channel required", 400)
		return
	}

	pl, err := s.loadProfiles()
	if err != nil || pl == nil || pl.Active == "" {
		http.Error(w, "no active profile", 400)
		return
	}
	idx := -1
	for i, p := range pl.Profiles {
		if p.ID == pl.Active {
			idx = i
			break
		}
	}
	if idx < 0 {
		http.Error(w, "active profile not found", 400)
		return
	}

	// normaliseAutoUpdateList returns a fresh slice, so the in-place removal
	// below cannot disturb the slice still referenced by the profile.
	current := normaliseAutoUpdateList(pl.Profiles[idx].AutoUpdate)
	on := true
	for i, n := range current {
		if n == name {
			current = append(current[:i], current[i+1:]...)
			on = false
			break
		}
	}
	if on {
		current = append(current, name)
	}
	pl.Profiles[idx].AutoUpdate = current

	if err := s.saveProfiles(pl); err != nil {
		http.Error(w, fmt.Sprintf("save: %v", err), 500)
		return
	}
	writeJSON(w, map[string]any{
		"ok":       true,
		"channel":  name,
		"enabled":  on,
		"channels": pl.Profiles[idx].AutoUpdate,
	})
}
// normaliseAutoUpdateList cleans a list of channel names: surrounding
// whitespace and then one leading "@" are removed from each entry, entries
// that end up empty are dropped, and repeats are removed keeping the first
// occurrence. The result is always a non-nil slice.
func normaliseAutoUpdateList(in []string) []string {
	out := make([]string, 0, len(in))
	seen := make(map[string]bool, len(in))
	for i := range in {
		name := strings.TrimSpace(in[i])
		name = strings.TrimPrefix(name, "@")
		if name != "" && !seen[name] {
			seen[name] = true
			out = append(out, name)
		}
	}
	return out
}
// handleSettings manages user preferences (font size etc.).
func (s *Server) handleSettings(w http.ResponseWriter, r *http.Request) {
switch r.Method {
@@ -2223,26 +2565,42 @@ func (s *Server) handleVersionCheck(w http.ResponseWriter, r *http.Request) {
writeJSON(w, map[string]any{"ok": true, "latestVersion": v})
}
// handleClearCache deletes all files in the cache directory.
// runMediaCacheSweep runs the media cache's Cleanup once an hour. It never
// returns while a media cache is configured; if s.mediaCache is found to be
// nil when a tick arrives, the loop ends.
func (s *Server) runMediaCacheSweep() {
	hourly := time.NewTicker(time.Hour)
	defer hourly.Stop()

	for {
		<-hourly.C
		if s.mediaCache == nil {
			return
		}
		s.mediaCache.Cleanup()
	}
}
// handleClearCache wipes both the per-channel message cache and the
// downloaded-media disk cache.
func (s *Server) handleClearCache(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", 405)
return
}
cacheDir := filepath.Join(s.dataDir, "cache")
entries, err := os.ReadDir(cacheDir)
if err != nil {
writeJSON(w, map[string]any{"ok": true, "deleted": 0})
return
}
deleted := 0
for _, e := range entries {
if !e.IsDir() {
cacheDir := filepath.Join(s.dataDir, "cache")
if entries, err := os.ReadDir(cacheDir); err == nil {
for _, e := range entries {
if e.IsDir() {
continue
}
if os.Remove(filepath.Join(cacheDir, e.Name())) == nil {
deleted++
}
}
}
s.addLog(fmt.Sprintf("Cache cleared: %d files deleted", deleted))
writeJSON(w, map[string]any{"ok": true, "deleted": deleted})
mediaDeleted := 0
if s.mediaCache != nil {
mediaDeleted = s.mediaCache.Clear()
}
s.addLog(fmt.Sprintf("Cache cleared: %d message files, %d media files", deleted, mediaDeleted))
writeJSON(w, map[string]any{"ok": true, "deleted": deleted, "mediaDeleted": mediaDeleted})
}