feat: media download with DNS query

This commit is contained in:
Sarto
2026-04-29 01:45:27 +03:30
parent 11946c0147
commit b4e9cd8714
34 changed files with 6303 additions and 137 deletions
+33
View File
@@ -0,0 +1,33 @@
package web
import (
"reflect"
"testing"
)
func TestNormaliseAutoUpdateList(t *testing.T) {
cases := []struct {
name string
in []string
want []string
}{
{"nil", nil, []string{}},
{"empty", []string{}, []string{}},
{"strip @", []string{"@one", "two"}, []string{"one", "two"}},
{"trim whitespace", []string{" one ", "\ttwo\n"}, []string{"one", "two"}},
{"drop empties", []string{"one", "", " ", "@", "two"}, []string{"one", "two"}},
{"dedupe preserves order", []string{"a", "b", "@a", "c", "b"}, []string{"a", "b", "c"}},
{"dedupe across @ form", []string{"@chan", "chan"}, []string{"chan"}},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got := normaliseAutoUpdateList(c.in)
if got == nil {
got = []string{}
}
if !reflect.DeepEqual(got, c.want) {
t.Errorf("normaliseAutoUpdateList(%v) = %v, want %v", c.in, got, c.want)
}
})
}
}
+356
View File
@@ -0,0 +1,356 @@
package web
import (
"bytes"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync/atomic"
"unicode"
"github.com/sartoopjj/thefeed/internal/client"
"github.com/sartoopjj/thefeed/internal/protocol"
)
// mediaDLProgress tracks how many blocks of a single in-flight download have
// been fetched. The frontend polls /api/media/progress to drive a smooth
// per-block counter while the xhr is still reading the response body.
type mediaDLProgress struct {
completed int32
total int32
}
// mediaProgressKey is the join of (channel, blockCount, crc) the frontend
// uses to look up its own download. It matches the params on the GET URL so
// no extra bookkeeping leaks into the JSON response.
func mediaProgressKey(channel uint16, blocks uint16, crc uint32) string {
return fmt.Sprintf("%d:%d:%08x", channel, blocks, crc)
}
func (s *Server) handleMediaProgress(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
ch64, _ := strconv.ParseUint(q.Get("ch"), 10, 16)
blk64, _ := strconv.ParseUint(q.Get("blk"), 10, 16)
crc64, _ := strconv.ParseUint(strings.TrimSpace(q.Get("crc")), 16, 32)
key := mediaProgressKey(uint16(ch64), uint16(blk64), uint32(crc64))
s.dlMu.Lock()
prog := s.dlProgress[key]
s.dlMu.Unlock()
if prog == nil {
writeJSON(w, map[string]any{"active": false, "completed": 0, "total": int(blk64)})
return
}
writeJSON(w, map[string]any{
"active": true,
"completed": int(atomic.LoadInt32(&prog.completed)),
"total": int(atomic.LoadInt32(&prog.total)),
})
}
// handleMediaGet streams a media blob assembled from the
// (channel, blocks, crc) tuple embedded in a message's text.
//
// Query string:
//
// ch=<uint16> media channel number (10000..60000)
// blk=<uint16> total block count
// size=<bytes> expected file size (Content-Length)
// crc=<hex8> expected CRC32 of full body
// name=<filename> optional filename for Content-Disposition
// type=<mime> optional mime type override; sanitized
func (s *Server) handleMediaGet(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
q := r.URL.Query()
ch64, err := strconv.ParseUint(q.Get("ch"), 10, 16)
if err != nil {
http.Error(w, "bad ch", http.StatusBadRequest)
return
}
channel := uint16(ch64)
if !protocol.IsMediaChannel(channel) {
http.Error(w, "ch out of media range", http.StatusBadRequest)
return
}
blk64, err := strconv.ParseUint(q.Get("blk"), 10, 16)
if err != nil || blk64 == 0 {
http.Error(w, "bad blk", http.StatusBadRequest)
return
}
blockCount := uint16(blk64)
const maxClaimedSize = 100 * 1024 * 1024
expectedSize, _ := strconv.ParseInt(q.Get("size"), 10, 64)
if expectedSize < 0 || expectedSize > maxClaimedSize {
http.Error(w, "bad size", http.StatusBadRequest)
return
}
expectedCRC := uint32(0)
if v := strings.TrimSpace(q.Get("crc")); v != "" {
c, err := strconv.ParseUint(v, 16, 32)
if err != nil {
http.Error(w, "bad crc", http.StatusBadRequest)
return
}
expectedCRC = uint32(c)
}
s.mu.RLock()
fetcher := s.fetcher
s.mu.RUnlock()
if fetcher == nil {
http.Error(w, "fetcher not configured", http.StatusServiceUnavailable)
return
}
ctx := r.Context()
// Disk-cache hit: serve directly without ever talking to DNS.
if s.mediaCache != nil && expectedCRC != 0 && expectedSize > 0 {
if body, mime, ok := s.mediaCache.Get(expectedSize, expectedCRC); ok {
servedMime := sanitizeMime(q.Get("type"))
if servedMime == "application/octet-stream" {
if mime != "" {
servedMime = sanitizeMime(mime)
} else if sniffed := http.DetectContentType(body); sniffed != "" {
servedMime = sanitizeMime(sniffed)
}
}
w.Header().Set("Content-Type", servedMime)
w.Header().Set("Content-Length", strconv.Itoa(len(body)))
w.Header().Set("Cache-Control", "private, max-age=86400")
w.Header().Set("X-Total-Blocks", strconv.Itoa(int(blockCount)))
w.Header().Set("X-Cache", "HIT")
if filename := sanitizeFilename(q.Get("name")); filename != "" {
w.Header().Set("Content-Disposition", "inline; filename=\""+filename+"\"")
}
if _, err := w.Write(body); err != nil {
s.addLog(fmt.Sprintf("media disk-cache write failed: %v", err))
}
return
}
}
// Fetch block 0 synchronously: it carries the protocol header (CRC32,
// version, compression). We need that before we can decompress and
// before we can sniff Content-Type from the decompressed body.
firstBlock, err := fetcher.FetchBlock(ctx, channel, 0)
if err != nil {
if ctx.Err() != nil {
http.Error(w, "fetch cancelled", 499)
return
}
http.Error(w, fmt.Sprintf("fetch media: %v", err), http.StatusBadGateway)
return
}
if len(firstBlock) < protocol.MediaBlockHeaderLen {
http.Error(w, "malformed block 0", http.StatusBadGateway)
return
}
header, err := protocol.DecodeMediaBlockHeader(firstBlock[:protocol.MediaBlockHeaderLen])
if err != nil {
http.Error(w, "malformed block 0", http.StatusBadGateway)
return
}
if expectedCRC != 0 && header.CRC32 != expectedCRC {
http.Error(w, "content hash mismatch", http.StatusBadGateway)
return
}
firstCompressed := firstBlock[protocol.MediaBlockHeaderLen:]
// Register this download so /api/media/progress can report block
// progress as the client polls. Block 0 is already fetched.
progKey := mediaProgressKey(channel, blockCount, expectedCRC)
prog := &mediaDLProgress{total: int32(blockCount), completed: 1}
s.dlMu.Lock()
s.dlProgress[progKey] = prog
s.dlMu.Unlock()
defer func() {
s.dlMu.Lock()
delete(s.dlProgress, progKey)
s.dlMu.Unlock()
}()
// Pipe compressed bytes (block-0 payload + later blocks) into a
// decompressor reader. Fed by a goroutine; consumed below for sniffing
// and for streaming to the HTTP response.
pipeR, pipeW := io.Pipe()
go func() {
var pipeErr error
defer func() { pipeW.CloseWithError(pipeErr) }()
if _, err := pipeW.Write(firstCompressed); err != nil {
pipeErr = err
return
}
if blockCount > 1 {
progressCB := func(done, _ int) {
// done counts blocks 1..N-1; add 1 for block 0 already fetched.
atomic.StoreInt32(&prog.completed, int32(done+1))
}
pipeErr = fetcher.FetchMediaBlocksStream(ctx, channel, 1, blockCount-1, pipeW, progressCB)
}
}()
body, err := client.DecompressMediaReader(pipeR, header.Compression)
if err != nil {
http.Error(w, fmt.Sprintf("decompress: %v", err), http.StatusBadGateway)
return
}
defer body.Close()
// Tee decompressed bytes into a buffer so we can persist them to the
// disk cache after a successful response.
var teeBuf *bytes.Buffer
if s.mediaCache != nil && expectedCRC != 0 && expectedSize > 0 && expectedSize <= mediaCacheMaxFileExt {
teeBuf = bytes.NewBuffer(make([]byte, 0, expectedSize))
}
// Sniff Content-Type from the first decompressed bytes before flushing
// headers — once Content-Type goes out we can't change it.
const sniffSize = 512
sniff := make([]byte, sniffSize)
n, err := io.ReadFull(body, sniff)
if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
http.Error(w, fmt.Sprintf("read media: %v", err), http.StatusBadGateway)
return
}
sniff = sniff[:n]
mime := sanitizeMime(q.Get("type"))
if mime == "application/octet-stream" {
if got := http.DetectContentType(sniff); got != "" {
mime = sanitizeMime(got)
}
}
filename := sanitizeFilename(q.Get("name"))
w.Header().Set("Content-Type", mime)
if expectedSize > 0 {
w.Header().Set("Content-Length", strconv.FormatInt(expectedSize, 10))
}
w.Header().Set("Cache-Control", "private, max-age=86400")
w.Header().Set("X-Total-Blocks", strconv.Itoa(int(blockCount)))
w.Header().Set("X-Cache", "MISS")
w.Header().Set("X-Media-Compression", header.Compression.String())
if filename != "" {
w.Header().Set("Content-Disposition", "inline; filename=\""+filename+"\"")
}
flusher, _ := w.(http.Flusher)
if teeBuf != nil {
teeBuf.Write(sniff)
}
if _, err := w.Write(sniff); err != nil {
s.addLog(fmt.Sprintf("media write head failed: %v", err))
return
}
if flusher != nil {
flusher.Flush()
}
dst := io.Writer(&flushAfterEachWriter{w: w, flusher: flusher})
if teeBuf != nil {
dst = io.MultiWriter(dst, teeBuf)
}
// Small buffer so the browser sees many small chunks instead of one big
// one — the xhr onprogress event fires per chunk, which is what drives
// the smooth K/N block counter on the client.
buf := make([]byte, 2048)
if _, err := io.CopyBuffer(dst, body, buf); err != nil {
s.addLog(fmt.Sprintf("media stream failed: %v", err))
return
}
if teeBuf == nil {
s.addLog(fmt.Sprintf("media disk-cache skipped: size=%d crc=%x mediaCache=%v", expectedSize, expectedCRC, s.mediaCache != nil))
} else if expectedSize > 0 && int64(teeBuf.Len()) != expectedSize {
s.addLog(fmt.Sprintf("media disk-cache skipped: tee=%d expected=%d (truncated stream)", teeBuf.Len(), expectedSize))
} else {
if err := s.mediaCache.Put(int64(teeBuf.Len()), expectedCRC, teeBuf.Bytes(), mime); err != nil {
s.addLog(fmt.Sprintf("media disk-cache put failed: %v", err))
} else {
s.addLog(fmt.Sprintf("media cached: %d bytes, crc=%08x, mime=%s", teeBuf.Len(), expectedCRC, mime))
}
}
}
type flushAfterEachWriter struct {
w http.ResponseWriter
flusher http.Flusher
}
func (fw *flushAfterEachWriter) Write(p []byte) (int, error) {
n, err := fw.w.Write(p)
if err == nil && fw.flusher != nil {
fw.flusher.Flush()
}
return n, err
}
func (fw *flushAfterEachWriter) Flush() {
if fw.flusher != nil {
fw.flusher.Flush()
}
}
// sanitizeMime returns a "type/subtype" MIME string built from safe
// characters. HTML/SVG variants are rejected.
func sanitizeMime(s string) string {
s = strings.TrimSpace(s)
if s == "" {
return "application/octet-stream"
}
if i := strings.IndexByte(s, ';'); i >= 0 {
s = strings.TrimSpace(s[:i])
}
slash := strings.IndexByte(s, '/')
if slash <= 0 || slash == len(s)-1 {
return "application/octet-stream"
}
for _, r := range s {
if r == '/' || r == '-' || r == '+' || r == '.' {
continue
}
if !unicode.IsLetter(r) && !unicode.IsDigit(r) {
return "application/octet-stream"
}
}
switch strings.ToLower(s) {
case "text/html", "application/xhtml+xml", "image/svg+xml":
return "application/octet-stream"
}
return s
}
// sanitizeFilename strips path components and control characters.
func sanitizeFilename(s string) string {
s = strings.TrimSpace(s)
if s == "" {
return ""
}
if i := strings.LastIndexAny(s, `/\`); i >= 0 {
s = s[i+1:]
}
if s == "" || s == ".." {
return ""
}
var b strings.Builder
for _, r := range s {
if r < 0x20 || r == 0x7F || r == '"' || r == '\\' {
continue
}
b.WriteRune(r)
}
out := b.String()
if len(out) > 200 {
out = out[:200]
}
return out
}
+179
View File
@@ -0,0 +1,179 @@
package web
import (
"encoding/binary"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
const (
mediaCacheFileExt = ".cache"
mediaCacheMaxMime = 200
mediaCacheMaxFileExt = 1 << 26 // 64 MiB hard cap per cached file
)
// mediaDiskCache stores downloaded media blobs on disk so multiple devices
// connected to the same client/server share the cost of one DNS-tunnelled
// fetch. Entries are content-addressed by (size, crc32) and reaped after
// ttl based on file mtime.
//
// File format: each entry is a single file
//
// <size>_<crc8hex>.cache
//
// containing:
//
// 2 bytes BE — mime length
// N bytes — mime utf8
// rest — raw file bytes
type mediaDiskCache struct {
dir string
ttl time.Duration
mu sync.Mutex
}
func newMediaDiskCache(dir string, ttl time.Duration) (*mediaDiskCache, error) {
if dir == "" {
return nil, errors.New("media cache dir is empty")
}
if err := os.MkdirAll(dir, 0o700); err != nil {
return nil, err
}
return &mediaDiskCache{dir: dir, ttl: ttl}, nil
}
func (c *mediaDiskCache) keyFile(size int64, crc uint32) string {
return filepath.Join(c.dir, fmt.Sprintf("%d_%08x%s", size, crc, mediaCacheFileExt))
}
// Get returns the cached body and mime type if present and not expired.
// Touching mtime on hit so the entry stays alive while it's in use.
func (c *mediaDiskCache) Get(size int64, crc uint32) (body []byte, mime string, ok bool) {
if size <= 0 || crc == 0 {
return nil, "", false
}
path := c.keyFile(size, crc)
info, err := os.Stat(path)
if err != nil {
return nil, "", false
}
if c.ttl > 0 && time.Since(info.ModTime()) > c.ttl {
_ = os.Remove(path)
return nil, "", false
}
data, err := os.ReadFile(path)
if err != nil || len(data) < 2 {
return nil, "", false
}
mimeLen := int(binary.BigEndian.Uint16(data[:2]))
if mimeLen > mediaCacheMaxMime || 2+mimeLen > len(data) {
return nil, "", false
}
mime = string(data[2 : 2+mimeLen])
body = data[2+mimeLen:]
if int64(len(body)) != size {
// Corrupt or partial write — treat as miss.
return nil, "", false
}
_ = os.Chtimes(path, time.Now(), time.Now())
return body, mime, true
}
// Put writes the body+mime atomically to the cache.
func (c *mediaDiskCache) Put(size int64, crc uint32, body []byte, mime string) error {
if size <= 0 || crc == 0 || int64(len(body)) != size {
return errors.New("media cache: invalid put")
}
if len(body) > mediaCacheMaxFileExt {
return errors.New("media cache: body too large")
}
if len(mime) > mediaCacheMaxMime {
mime = mime[:mediaCacheMaxMime]
}
c.mu.Lock()
defer c.mu.Unlock()
path := c.keyFile(size, crc)
tmp := path + ".tmp"
f, err := os.Create(tmp)
if err != nil {
return err
}
header := make([]byte, 2)
binary.BigEndian.PutUint16(header, uint16(len(mime)))
if _, err := f.Write(header); err != nil {
f.Close()
os.Remove(tmp)
return err
}
if _, err := f.Write([]byte(mime)); err != nil {
f.Close()
os.Remove(tmp)
return err
}
if _, err := f.Write(body); err != nil {
f.Close()
os.Remove(tmp)
return err
}
if err := f.Close(); err != nil {
os.Remove(tmp)
return err
}
return os.Rename(tmp, path)
}
// Cleanup removes entries older than ttl. Returns the count removed.
func (c *mediaDiskCache) Cleanup() int {
if c.ttl <= 0 {
return 0
}
c.mu.Lock()
defer c.mu.Unlock()
entries, err := os.ReadDir(c.dir)
if err != nil {
return 0
}
now := time.Now()
removed := 0
for _, e := range entries {
if e.IsDir() || !strings.HasSuffix(e.Name(), mediaCacheFileExt) {
continue
}
info, err := e.Info()
if err != nil {
continue
}
if now.Sub(info.ModTime()) > c.ttl {
if os.Remove(filepath.Join(c.dir, e.Name())) == nil {
removed++
}
}
}
return removed
}
// Clear deletes every cached entry. Returns the count removed.
func (c *mediaDiskCache) Clear() int {
c.mu.Lock()
defer c.mu.Unlock()
entries, err := os.ReadDir(c.dir)
if err != nil {
return 0
}
removed := 0
for _, e := range entries {
if e.IsDir() || !strings.HasSuffix(e.Name(), mediaCacheFileExt) {
continue
}
if os.Remove(filepath.Join(c.dir, e.Name())) == nil {
removed++
}
}
return removed
}
+42
View File
@@ -0,0 +1,42 @@
package web
import "testing"
func TestSanitizeMime(t *testing.T) {
cases := map[string]string{
"": "application/octet-stream",
"image/jpeg": "image/jpeg",
"image/png; charset=utf-8": "image/png",
"text/html": "application/octet-stream", // blocked
"application/xhtml+xml": "application/octet-stream", // blocked
"image/svg+xml": "application/octet-stream", // blocked (XSS via SVG)
"image/jpeg<script>": "application/octet-stream", // bad chars
"weird": "application/octet-stream", // no slash
"/leading": "application/octet-stream",
"trailing/": "application/octet-stream",
"application/vnd.api+json": "application/vnd.api+json",
"image/webp": "image/webp",
}
for in, want := range cases {
if got := sanitizeMime(in); got != want {
t.Errorf("sanitizeMime(%q) = %q, want %q", in, got, want)
}
}
}
func TestSanitizeFilename(t *testing.T) {
cases := map[string]string{
"": "",
"foo.png": "foo.png",
"../../etc/passwd": "passwd",
"foo/bar/baz.txt": "baz.txt",
"weird\nname.txt": "weirdname.txt",
`bad"quote"name`: "badquotename",
"..": "",
}
for in, want := range cases {
if got := sanitizeFilename(in); got != want {
t.Errorf("sanitizeFilename(%q) = %q, want %q", in, got, want)
}
}
}
File diff suppressed because it is too large Load Diff
+372 -14
View File
@@ -49,12 +49,31 @@ type Config struct {
}
// Profile wraps a Config with a user-chosen nickname and a unique ID.
// AutoUpdate is the per-profile list of channel usernames the auto-update
// goroutine should refresh; AutoUpdateInterval (seconds, 0 → default) sets
// the cadence.
type Profile struct {
ID string `json:"id"`
Nickname string `json:"nickname"`
Config Config `json:"config"`
ID string `json:"id"`
Nickname string `json:"nickname"`
Config Config `json:"config"`
AutoUpdate []string `json:"autoUpdate,omitempty"`
AutoUpdateInterval int `json:"autoUpdateInterval,omitempty"`
}
const (
// minAutoUpdateInterval is the floor — never tick faster than once per
// minute, even if the user sets something silly. The DNS path is
// expensive and the server's own fetch cycle is much longer.
minAutoUpdateInterval = 60 * time.Second
// serverFetchSettleDelay is how long after nextFetch we wait before
// asking the server for fresh data — gives it time to process the
// upstream Telegram fetch and have a coherent metadata snapshot.
serverFetchSettleDelay = 30 * time.Second
// autoUpdateStartupDelay defers the first tick so the initial metadata
// + resolver checks have a chance to land before we start polling.
autoUpdateStartupDelay = 30 * time.Second
)
// SavedResolverScore stores persistent resolver performance data.
type SavedResolverScore struct {
Success int64 `json:"success"`
@@ -137,6 +156,17 @@ type Server struct {
titlesMu sync.Mutex
titlesLoading bool
titlesBackoffUntil time.Time
// dlMu guards dlProgress. Active media downloads register their block
// counter here so the frontend can poll /api/media/progress and show
// per-block updates instead of waiting for byte chunks.
dlMu sync.Mutex
dlProgress map[string]*mediaDLProgress
// mediaCache is a disk-backed store for downloaded media bytes so that
// multiple devices on the same network share a single DNS-tunnelled
// fetch. Entries expire after 7 days.
mediaCache *mediaDiskCache
}
// New creates a new web server.
@@ -154,6 +184,11 @@ func New(dataDir string, port int, host string, password string) (*Server, error
scanner := client.NewResolverScanner()
mediaCache, mcErr := newMediaDiskCache(filepath.Join(dataDir, "media-cache"), 7*24*time.Hour)
if mcErr != nil {
log.Printf("Warning: media disk cache disabled: %v", mcErr)
}
s := &Server{
dataDir: dataDir,
port: port,
@@ -165,6 +200,13 @@ func New(dataDir string, port int, host string, password string) (*Server, error
lastMsgIDs: make(map[int]uint32),
lastHashes: make(map[int]uint32),
scanner: scanner,
mediaCache: mediaCache,
dlProgress: make(map[string]*mediaDLProgress),
}
if mediaCache != nil {
go mediaCache.Cleanup()
go s.runMediaCacheSweep()
}
// Migrate per-profile resolvers into the shared bank on first run.
@@ -213,6 +255,8 @@ func (s *Server) Run() error {
mux.HandleFunc("/api/events", s.handleSSE)
mux.HandleFunc("/api/profiles", s.handleProfiles)
mux.HandleFunc("/api/profiles/switch", s.handleProfileSwitch)
mux.HandleFunc("/api/auto-update", s.handleAutoUpdate)
mux.HandleFunc("/api/auto-update/toggle", s.handleAutoUpdateToggle)
mux.HandleFunc("/api/settings", s.handleSettings)
mux.HandleFunc("/api/version-check", s.handleVersionCheck)
mux.HandleFunc("/api/cache/clear", s.handleClearCache)
@@ -230,6 +274,11 @@ func (s *Server) Run() error {
mux.HandleFunc("/api/scanner/progress", s.handleScannerProgress)
mux.HandleFunc("/api/scanner/apply", s.handleScannerApply)
mux.HandleFunc("/api/scanner/presets", s.handleScannerPresets)
// Media (image/file) downloader: assembles a binary blob from a media
// channel and streams it back. See internal/web/media.go for the param
// contract.
mux.HandleFunc("/api/media/get", s.handleMediaGet)
mux.HandleFunc("/api/media/progress", s.handleMediaProgress)
mux.HandleFunc("/", s.handleIndex)
// Listen on the specified host (default 127.0.0.1)
@@ -782,9 +831,140 @@ func (s *Server) initFetcher() error {
s.fetcher = fetcher
s.cache = cache
go cache.Cleanup() // remove channel files not updated in 7 days
// Goroutine dies with fetcherCtx, so a profile switch / config change
// stops it cleanly.
go s.runAutoUpdateLoop(ctx)
return nil
}
// runAutoUpdateLoop refreshes the active profile's AutoUpdate channels on a
// schedule that follows the server's own fetch cycle — there's no point
// polling more often than the server actually pulls fresh data from
// Telegram. User-set Profile.AutoUpdateInterval is honoured if it's >= the
// 60s floor; otherwise we align with nextFetch + settle delay.
func (s *Server) runAutoUpdateLoop(ctx context.Context) {
select {
case <-time.After(autoUpdateStartupDelay):
case <-ctx.Done():
return
}
var lastTick time.Time
for {
wait := s.nextAutoUpdateWait(lastTick)
select {
case <-ctx.Done():
return
case <-time.After(wait):
}
if !s.canAutoUpdate() {
continue
}
s.tickAutoUpdate()
lastTick = time.Now()
}
}
// nextAutoUpdateWait returns how long to sleep before the next tick. Honours
// user override when set sensibly; otherwise sleeps until just after the
// server's next Telegram fetch so we always pull just-refreshed data.
func (s *Server) nextAutoUpdateWait(lastTick time.Time) time.Duration {
pl, _ := s.loadProfiles()
if pl != nil && pl.Active != "" {
for _, p := range pl.Profiles {
if p.ID != pl.Active {
continue
}
if p.AutoUpdateInterval > 0 {
user := time.Duration(p.AutoUpdateInterval) * time.Second
if user < minAutoUpdateInterval {
user = minAutoUpdateInterval
}
return user
}
break
}
}
s.mu.RLock()
nf := s.nextFetch
s.mu.RUnlock()
if nf == 0 {
return minAutoUpdateInterval
}
target := time.Unix(int64(nf), 0).Add(serverFetchSettleDelay)
delay := time.Until(target)
if delay < minAutoUpdateInterval {
delay = minAutoUpdateInterval
}
if !lastTick.IsZero() {
if since := time.Since(lastTick); since < minAutoUpdateInterval {
if rem := minAutoUpdateInterval - since; rem > delay {
delay = rem
}
}
}
return delay
}
// canAutoUpdate returns false when we should skip a tick: server hasn't
// produced metadata yet (channel list empty), or the resolver scanner is
// busy (it'd race with our DNS fetches), or there's no fetcher.
func (s *Server) canAutoUpdate() bool {
s.mu.RLock()
channels := s.channels
fetcher := s.fetcher
scanner := s.scanner
s.mu.RUnlock()
if fetcher == nil || len(channels) == 0 {
return false
}
if scanner != nil {
switch scanner.State() {
case client.ScannerRunning, client.ScannerPaused:
return false
}
}
return true
}
func (s *Server) tickAutoUpdate() {
pl, err := s.loadProfiles()
if err != nil || pl == nil || pl.Active == "" {
return
}
var watch []string
for _, p := range pl.Profiles {
if p.ID == pl.Active {
watch = p.AutoUpdate
break
}
}
if len(watch) == 0 {
return
}
s.mu.RLock()
channels := s.channels
s.mu.RUnlock()
if len(channels) == 0 {
return
}
wantSet := make(map[string]bool, len(watch))
for _, name := range watch {
wantSet[strings.TrimPrefix(strings.TrimSpace(name), "@")] = true
}
for i, ch := range channels {
if !wantSet[ch.Name] {
continue
}
go s.refreshChannel(i + 1) // 1-indexed
}
}
func (s *Server) checkLatestVersion(ctx context.Context) (string, error) {
s.mu.RLock()
cfg := s.config
@@ -1961,6 +2141,10 @@ func (s *Server) handleProfiles(w http.ResponseWriter, r *http.Request) {
addToBank(pl, req.Profile.Config.Resolvers)
req.Profile.Config.Resolvers = nil
}
// Carry over fields the edit-profile UI doesn't manage so
// they don't get wiped on save (auto-update list etc.).
req.Profile.AutoUpdate = p.AutoUpdate
req.Profile.AutoUpdateInterval = p.AutoUpdateInterval
pl.Profiles[i] = req.Profile
if p.ID == pl.Active {
needsReinit = true
@@ -2097,6 +2281,164 @@ func (s *Server) handleProfileSwitch(w http.ResponseWriter, r *http.Request) {
writeJSON(w, map[string]any{"ok": true})
}
// handleAutoUpdate exposes the active profile's auto-update channel list.
// GET → {channels, intervalSeconds, defaultIntervalSeconds}.
// POST {channels, intervalSeconds?} replaces both. Names are stripped and
// dedup'd before saving.
func (s *Server) handleAutoUpdate(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
pl, _ := s.loadProfiles()
channels := []string{}
interval := 0
if pl != nil && pl.Active != "" {
for _, p := range pl.Profiles {
if p.ID == pl.Active {
if p.AutoUpdate != nil {
channels = p.AutoUpdate
}
interval = p.AutoUpdateInterval
break
}
}
}
writeJSON(w, map[string]any{
"channels": channels,
"intervalSeconds": interval,
"defaultIntervalSeconds": int(minAutoUpdateInterval / time.Second),
})
case http.MethodPost:
var req struct {
Channels []string `json:"channels"`
IntervalSeconds *int `json:"intervalSeconds,omitempty"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "invalid JSON", 400)
return
}
pl, err := s.loadProfiles()
if err != nil || pl == nil || pl.Active == "" {
http.Error(w, "no active profile", 400)
return
}
idx := -1
for i, p := range pl.Profiles {
if p.ID == pl.Active {
idx = i
break
}
}
if idx < 0 {
http.Error(w, "active profile not found", 400)
return
}
pl.Profiles[idx].AutoUpdate = normaliseAutoUpdateList(req.Channels)
if req.IntervalSeconds != nil {
v := *req.IntervalSeconds
if v < 0 {
v = 0
}
minSec := int(minAutoUpdateInterval / time.Second)
if v > 0 && v < minSec {
v = minSec // floor: never poll faster than the server fetches
}
pl.Profiles[idx].AutoUpdateInterval = v
}
if err := s.saveProfiles(pl); err != nil {
http.Error(w, fmt.Sprintf("save: %v", err), 500)
return
}
writeJSON(w, map[string]any{
"ok": true,
"channels": pl.Profiles[idx].AutoUpdate,
"intervalSeconds": pl.Profiles[idx].AutoUpdateInterval,
})
default:
http.Error(w, "method not allowed", 405)
}
}
// handleAutoUpdateToggle flips one channel's membership. Body {channel}.
// Returns {enabled, channels}.
func (s *Server) handleAutoUpdateToggle(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", 405)
return
}
var req struct {
Channel string `json:"channel"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "invalid JSON", 400)
return
}
name := strings.TrimPrefix(strings.TrimSpace(req.Channel), "@")
if name == "" {
http.Error(w, "channel required", 400)
return
}
pl, err := s.loadProfiles()
if err != nil || pl == nil || pl.Active == "" {
http.Error(w, "no active profile", 400)
return
}
idx := -1
for i, p := range pl.Profiles {
if p.ID == pl.Active {
idx = i
break
}
}
if idx < 0 {
http.Error(w, "active profile not found", 400)
return
}
current := pl.Profiles[idx].AutoUpdate
on := false
hit := -1
for i, n := range current {
if strings.TrimPrefix(strings.TrimSpace(n), "@") == name {
hit = i
break
}
}
if hit >= 0 {
current = append(current[:hit], current[hit+1:]...)
} else {
current = append(current, name)
on = true
}
pl.Profiles[idx].AutoUpdate = normaliseAutoUpdateList(current)
if err := s.saveProfiles(pl); err != nil {
http.Error(w, fmt.Sprintf("save: %v", err), 500)
return
}
writeJSON(w, map[string]any{
"ok": true,
"channel": name,
"enabled": on,
"channels": pl.Profiles[idx].AutoUpdate,
})
}
// normaliseAutoUpdateList strips @ + whitespace, drops empties, dedupes
// while preserving order.
func normaliseAutoUpdateList(in []string) []string {
seen := make(map[string]bool, len(in))
out := make([]string, 0, len(in))
for _, raw := range in {
name := strings.TrimPrefix(strings.TrimSpace(raw), "@")
if name == "" || seen[name] {
continue
}
seen[name] = true
out = append(out, name)
}
return out
}
// handleSettings manages user preferences (font size etc.).
func (s *Server) handleSettings(w http.ResponseWriter, r *http.Request) {
switch r.Method {
@@ -2223,26 +2565,42 @@ func (s *Server) handleVersionCheck(w http.ResponseWriter, r *http.Request) {
writeJSON(w, map[string]any{"ok": true, "latestVersion": v})
}
// handleClearCache deletes all files in the cache directory.
// runMediaCacheSweep evicts expired media-cache entries every hour for the
// lifetime of the process.
func (s *Server) runMediaCacheSweep() {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for range ticker.C {
if s.mediaCache == nil {
return
}
s.mediaCache.Cleanup()
}
}
// handleClearCache wipes both the per-channel message cache and the
// downloaded-media disk cache.
func (s *Server) handleClearCache(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", 405)
return
}
cacheDir := filepath.Join(s.dataDir, "cache")
entries, err := os.ReadDir(cacheDir)
if err != nil {
writeJSON(w, map[string]any{"ok": true, "deleted": 0})
return
}
deleted := 0
for _, e := range entries {
if !e.IsDir() {
cacheDir := filepath.Join(s.dataDir, "cache")
if entries, err := os.ReadDir(cacheDir); err == nil {
for _, e := range entries {
if e.IsDir() {
continue
}
if os.Remove(filepath.Join(cacheDir, e.Name())) == nil {
deleted++
}
}
}
s.addLog(fmt.Sprintf("Cache cleared: %d files deleted", deleted))
writeJSON(w, map[string]any{"ok": true, "deleted": deleted})
mediaDeleted := 0
if s.mediaCache != nil {
mediaDeleted = s.mediaCache.Clear()
}
s.addLog(fmt.Sprintf("Cache cleared: %d message files, %d media files", deleted, mediaDeleted))
writeJSON(w, map[string]any{"ok": true, "deleted": deleted, "mediaDeleted": mediaDeleted})
}