mirror of
https://github.com/sartoopjj/thefeed.git
synced 2026-05-19 06:54:34 +03:00
feat: ✨ media download with DNS query
This commit is contained in:
@@ -0,0 +1,33 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormaliseAutoUpdateList(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in []string
|
||||
want []string
|
||||
}{
|
||||
{"nil", nil, []string{}},
|
||||
{"empty", []string{}, []string{}},
|
||||
{"strip @", []string{"@one", "two"}, []string{"one", "two"}},
|
||||
{"trim whitespace", []string{" one ", "\ttwo\n"}, []string{"one", "two"}},
|
||||
{"drop empties", []string{"one", "", " ", "@", "two"}, []string{"one", "two"}},
|
||||
{"dedupe preserves order", []string{"a", "b", "@a", "c", "b"}, []string{"a", "b", "c"}},
|
||||
{"dedupe across @ form", []string{"@chan", "chan"}, []string{"chan"}},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
got := normaliseAutoUpdateList(c.in)
|
||||
if got == nil {
|
||||
got = []string{}
|
||||
}
|
||||
if !reflect.DeepEqual(got, c.want) {
|
||||
t.Errorf("normaliseAutoUpdateList(%v) = %v, want %v", c.in, got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,356 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"unicode"
|
||||
|
||||
"github.com/sartoopjj/thefeed/internal/client"
|
||||
"github.com/sartoopjj/thefeed/internal/protocol"
|
||||
)
|
||||
|
||||
// mediaDLProgress tracks how many blocks of a single in-flight download have
|
||||
// been fetched. The frontend polls /api/media/progress to drive a smooth
|
||||
// per-block counter while the xhr is still reading the response body.
|
||||
type mediaDLProgress struct {
|
||||
completed int32
|
||||
total int32
|
||||
}
|
||||
|
||||
// mediaProgressKey is the join of (channel, blockCount, crc) the frontend
|
||||
// uses to look up its own download. It matches the params on the GET URL so
|
||||
// no extra bookkeeping leaks into the JSON response.
|
||||
func mediaProgressKey(channel uint16, blocks uint16, crc uint32) string {
|
||||
return fmt.Sprintf("%d:%d:%08x", channel, blocks, crc)
|
||||
}
|
||||
|
||||
func (s *Server) handleMediaProgress(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
ch64, _ := strconv.ParseUint(q.Get("ch"), 10, 16)
|
||||
blk64, _ := strconv.ParseUint(q.Get("blk"), 10, 16)
|
||||
crc64, _ := strconv.ParseUint(strings.TrimSpace(q.Get("crc")), 16, 32)
|
||||
key := mediaProgressKey(uint16(ch64), uint16(blk64), uint32(crc64))
|
||||
|
||||
s.dlMu.Lock()
|
||||
prog := s.dlProgress[key]
|
||||
s.dlMu.Unlock()
|
||||
if prog == nil {
|
||||
writeJSON(w, map[string]any{"active": false, "completed": 0, "total": int(blk64)})
|
||||
return
|
||||
}
|
||||
writeJSON(w, map[string]any{
|
||||
"active": true,
|
||||
"completed": int(atomic.LoadInt32(&prog.completed)),
|
||||
"total": int(atomic.LoadInt32(&prog.total)),
|
||||
})
|
||||
}
|
||||
|
||||
// handleMediaGet streams a media blob assembled from the
|
||||
// (channel, blocks, crc) tuple embedded in a message's text.
|
||||
//
|
||||
// Query string:
|
||||
//
|
||||
// ch=<uint16> media channel number (10000..60000)
|
||||
// blk=<uint16> total block count
|
||||
// size=<bytes> expected file size (Content-Length)
|
||||
// crc=<hex8> expected CRC32 of full body
|
||||
// name=<filename> optional filename for Content-Disposition
|
||||
// type=<mime> optional mime type override; sanitized
|
||||
func (s *Server) handleMediaGet(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
q := r.URL.Query()
|
||||
ch64, err := strconv.ParseUint(q.Get("ch"), 10, 16)
|
||||
if err != nil {
|
||||
http.Error(w, "bad ch", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
channel := uint16(ch64)
|
||||
if !protocol.IsMediaChannel(channel) {
|
||||
http.Error(w, "ch out of media range", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
blk64, err := strconv.ParseUint(q.Get("blk"), 10, 16)
|
||||
if err != nil || blk64 == 0 {
|
||||
http.Error(w, "bad blk", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
blockCount := uint16(blk64)
|
||||
|
||||
const maxClaimedSize = 100 * 1024 * 1024
|
||||
expectedSize, _ := strconv.ParseInt(q.Get("size"), 10, 64)
|
||||
if expectedSize < 0 || expectedSize > maxClaimedSize {
|
||||
http.Error(w, "bad size", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
expectedCRC := uint32(0)
|
||||
if v := strings.TrimSpace(q.Get("crc")); v != "" {
|
||||
c, err := strconv.ParseUint(v, 16, 32)
|
||||
if err != nil {
|
||||
http.Error(w, "bad crc", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
expectedCRC = uint32(c)
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
fetcher := s.fetcher
|
||||
s.mu.RUnlock()
|
||||
if fetcher == nil {
|
||||
http.Error(w, "fetcher not configured", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
// Disk-cache hit: serve directly without ever talking to DNS.
|
||||
if s.mediaCache != nil && expectedCRC != 0 && expectedSize > 0 {
|
||||
if body, mime, ok := s.mediaCache.Get(expectedSize, expectedCRC); ok {
|
||||
servedMime := sanitizeMime(q.Get("type"))
|
||||
if servedMime == "application/octet-stream" {
|
||||
if mime != "" {
|
||||
servedMime = sanitizeMime(mime)
|
||||
} else if sniffed := http.DetectContentType(body); sniffed != "" {
|
||||
servedMime = sanitizeMime(sniffed)
|
||||
}
|
||||
}
|
||||
w.Header().Set("Content-Type", servedMime)
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(body)))
|
||||
w.Header().Set("Cache-Control", "private, max-age=86400")
|
||||
w.Header().Set("X-Total-Blocks", strconv.Itoa(int(blockCount)))
|
||||
w.Header().Set("X-Cache", "HIT")
|
||||
if filename := sanitizeFilename(q.Get("name")); filename != "" {
|
||||
w.Header().Set("Content-Disposition", "inline; filename=\""+filename+"\"")
|
||||
}
|
||||
if _, err := w.Write(body); err != nil {
|
||||
s.addLog(fmt.Sprintf("media disk-cache write failed: %v", err))
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch block 0 synchronously: it carries the protocol header (CRC32,
|
||||
// version, compression). We need that before we can decompress and
|
||||
// before we can sniff Content-Type from the decompressed body.
|
||||
firstBlock, err := fetcher.FetchBlock(ctx, channel, 0)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
http.Error(w, "fetch cancelled", 499)
|
||||
return
|
||||
}
|
||||
http.Error(w, fmt.Sprintf("fetch media: %v", err), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
if len(firstBlock) < protocol.MediaBlockHeaderLen {
|
||||
http.Error(w, "malformed block 0", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
header, err := protocol.DecodeMediaBlockHeader(firstBlock[:protocol.MediaBlockHeaderLen])
|
||||
if err != nil {
|
||||
http.Error(w, "malformed block 0", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
if expectedCRC != 0 && header.CRC32 != expectedCRC {
|
||||
http.Error(w, "content hash mismatch", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
firstCompressed := firstBlock[protocol.MediaBlockHeaderLen:]
|
||||
|
||||
// Register this download so /api/media/progress can report block
|
||||
// progress as the client polls. Block 0 is already fetched.
|
||||
progKey := mediaProgressKey(channel, blockCount, expectedCRC)
|
||||
prog := &mediaDLProgress{total: int32(blockCount), completed: 1}
|
||||
s.dlMu.Lock()
|
||||
s.dlProgress[progKey] = prog
|
||||
s.dlMu.Unlock()
|
||||
defer func() {
|
||||
s.dlMu.Lock()
|
||||
delete(s.dlProgress, progKey)
|
||||
s.dlMu.Unlock()
|
||||
}()
|
||||
|
||||
// Pipe compressed bytes (block-0 payload + later blocks) into a
|
||||
// decompressor reader. Fed by a goroutine; consumed below for sniffing
|
||||
// and for streaming to the HTTP response.
|
||||
pipeR, pipeW := io.Pipe()
|
||||
go func() {
|
||||
var pipeErr error
|
||||
defer func() { pipeW.CloseWithError(pipeErr) }()
|
||||
if _, err := pipeW.Write(firstCompressed); err != nil {
|
||||
pipeErr = err
|
||||
return
|
||||
}
|
||||
if blockCount > 1 {
|
||||
progressCB := func(done, _ int) {
|
||||
// done counts blocks 1..N-1; add 1 for block 0 already fetched.
|
||||
atomic.StoreInt32(&prog.completed, int32(done+1))
|
||||
}
|
||||
pipeErr = fetcher.FetchMediaBlocksStream(ctx, channel, 1, blockCount-1, pipeW, progressCB)
|
||||
}
|
||||
}()
|
||||
|
||||
body, err := client.DecompressMediaReader(pipeR, header.Compression)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("decompress: %v", err), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
|
||||
// Tee decompressed bytes into a buffer so we can persist them to the
|
||||
// disk cache after a successful response.
|
||||
var teeBuf *bytes.Buffer
|
||||
if s.mediaCache != nil && expectedCRC != 0 && expectedSize > 0 && expectedSize <= mediaCacheMaxFileExt {
|
||||
teeBuf = bytes.NewBuffer(make([]byte, 0, expectedSize))
|
||||
}
|
||||
|
||||
// Sniff Content-Type from the first decompressed bytes before flushing
|
||||
// headers — once Content-Type goes out we can't change it.
|
||||
const sniffSize = 512
|
||||
sniff := make([]byte, sniffSize)
|
||||
n, err := io.ReadFull(body, sniff)
|
||||
if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
|
||||
http.Error(w, fmt.Sprintf("read media: %v", err), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
sniff = sniff[:n]
|
||||
|
||||
mime := sanitizeMime(q.Get("type"))
|
||||
if mime == "application/octet-stream" {
|
||||
if got := http.DetectContentType(sniff); got != "" {
|
||||
mime = sanitizeMime(got)
|
||||
}
|
||||
}
|
||||
filename := sanitizeFilename(q.Get("name"))
|
||||
|
||||
w.Header().Set("Content-Type", mime)
|
||||
if expectedSize > 0 {
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(expectedSize, 10))
|
||||
}
|
||||
w.Header().Set("Cache-Control", "private, max-age=86400")
|
||||
w.Header().Set("X-Total-Blocks", strconv.Itoa(int(blockCount)))
|
||||
w.Header().Set("X-Cache", "MISS")
|
||||
w.Header().Set("X-Media-Compression", header.Compression.String())
|
||||
if filename != "" {
|
||||
w.Header().Set("Content-Disposition", "inline; filename=\""+filename+"\"")
|
||||
}
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
if teeBuf != nil {
|
||||
teeBuf.Write(sniff)
|
||||
}
|
||||
if _, err := w.Write(sniff); err != nil {
|
||||
s.addLog(fmt.Sprintf("media write head failed: %v", err))
|
||||
return
|
||||
}
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
dst := io.Writer(&flushAfterEachWriter{w: w, flusher: flusher})
|
||||
if teeBuf != nil {
|
||||
dst = io.MultiWriter(dst, teeBuf)
|
||||
}
|
||||
// Small buffer so the browser sees many small chunks instead of one big
|
||||
// one — the xhr onprogress event fires per chunk, which is what drives
|
||||
// the smooth K/N block counter on the client.
|
||||
buf := make([]byte, 2048)
|
||||
if _, err := io.CopyBuffer(dst, body, buf); err != nil {
|
||||
s.addLog(fmt.Sprintf("media stream failed: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if teeBuf == nil {
|
||||
s.addLog(fmt.Sprintf("media disk-cache skipped: size=%d crc=%x mediaCache=%v", expectedSize, expectedCRC, s.mediaCache != nil))
|
||||
} else if expectedSize > 0 && int64(teeBuf.Len()) != expectedSize {
|
||||
s.addLog(fmt.Sprintf("media disk-cache skipped: tee=%d expected=%d (truncated stream)", teeBuf.Len(), expectedSize))
|
||||
} else {
|
||||
if err := s.mediaCache.Put(int64(teeBuf.Len()), expectedCRC, teeBuf.Bytes(), mime); err != nil {
|
||||
s.addLog(fmt.Sprintf("media disk-cache put failed: %v", err))
|
||||
} else {
|
||||
s.addLog(fmt.Sprintf("media cached: %d bytes, crc=%08x, mime=%s", teeBuf.Len(), expectedCRC, mime))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type flushAfterEachWriter struct {
|
||||
w http.ResponseWriter
|
||||
flusher http.Flusher
|
||||
}
|
||||
|
||||
func (fw *flushAfterEachWriter) Write(p []byte) (int, error) {
|
||||
n, err := fw.w.Write(p)
|
||||
if err == nil && fw.flusher != nil {
|
||||
fw.flusher.Flush()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (fw *flushAfterEachWriter) Flush() {
|
||||
if fw.flusher != nil {
|
||||
fw.flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// sanitizeMime returns a "type/subtype" MIME string built from safe
|
||||
// characters. HTML/SVG variants are rejected.
|
||||
func sanitizeMime(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return "application/octet-stream"
|
||||
}
|
||||
if i := strings.IndexByte(s, ';'); i >= 0 {
|
||||
s = strings.TrimSpace(s[:i])
|
||||
}
|
||||
slash := strings.IndexByte(s, '/')
|
||||
if slash <= 0 || slash == len(s)-1 {
|
||||
return "application/octet-stream"
|
||||
}
|
||||
for _, r := range s {
|
||||
if r == '/' || r == '-' || r == '+' || r == '.' {
|
||||
continue
|
||||
}
|
||||
if !unicode.IsLetter(r) && !unicode.IsDigit(r) {
|
||||
return "application/octet-stream"
|
||||
}
|
||||
}
|
||||
switch strings.ToLower(s) {
|
||||
case "text/html", "application/xhtml+xml", "image/svg+xml":
|
||||
return "application/octet-stream"
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// sanitizeFilename strips path components and control characters.
|
||||
func sanitizeFilename(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
if i := strings.LastIndexAny(s, `/\`); i >= 0 {
|
||||
s = s[i+1:]
|
||||
}
|
||||
if s == "" || s == ".." {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for _, r := range s {
|
||||
if r < 0x20 || r == 0x7F || r == '"' || r == '\\' {
|
||||
continue
|
||||
}
|
||||
b.WriteRune(r)
|
||||
}
|
||||
out := b.String()
|
||||
if len(out) > 200 {
|
||||
out = out[:200]
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
mediaCacheFileExt = ".cache"
|
||||
mediaCacheMaxMime = 200
|
||||
mediaCacheMaxFileExt = 1 << 26 // 64 MiB hard cap per cached file
|
||||
)
|
||||
|
||||
// mediaDiskCache stores downloaded media blobs on disk so multiple devices
|
||||
// connected to the same client/server share the cost of one DNS-tunnelled
|
||||
// fetch. Entries are content-addressed by (size, crc32) and reaped after
|
||||
// ttl based on file mtime.
|
||||
//
|
||||
// File format: each entry is a single file
|
||||
//
|
||||
// <size>_<crc8hex>.cache
|
||||
//
|
||||
// containing:
|
||||
//
|
||||
// 2 bytes BE — mime length
|
||||
// N bytes — mime utf8
|
||||
// rest — raw file bytes
|
||||
type mediaDiskCache struct {
|
||||
dir string
|
||||
ttl time.Duration
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newMediaDiskCache(dir string, ttl time.Duration) (*mediaDiskCache, error) {
|
||||
if dir == "" {
|
||||
return nil, errors.New("media cache dir is empty")
|
||||
}
|
||||
if err := os.MkdirAll(dir, 0o700); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &mediaDiskCache{dir: dir, ttl: ttl}, nil
|
||||
}
|
||||
|
||||
func (c *mediaDiskCache) keyFile(size int64, crc uint32) string {
|
||||
return filepath.Join(c.dir, fmt.Sprintf("%d_%08x%s", size, crc, mediaCacheFileExt))
|
||||
}
|
||||
|
||||
// Get returns the cached body and mime type if present and not expired.
|
||||
// Touching mtime on hit so the entry stays alive while it's in use.
|
||||
func (c *mediaDiskCache) Get(size int64, crc uint32) (body []byte, mime string, ok bool) {
|
||||
if size <= 0 || crc == 0 {
|
||||
return nil, "", false
|
||||
}
|
||||
path := c.keyFile(size, crc)
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, "", false
|
||||
}
|
||||
if c.ttl > 0 && time.Since(info.ModTime()) > c.ttl {
|
||||
_ = os.Remove(path)
|
||||
return nil, "", false
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil || len(data) < 2 {
|
||||
return nil, "", false
|
||||
}
|
||||
mimeLen := int(binary.BigEndian.Uint16(data[:2]))
|
||||
if mimeLen > mediaCacheMaxMime || 2+mimeLen > len(data) {
|
||||
return nil, "", false
|
||||
}
|
||||
mime = string(data[2 : 2+mimeLen])
|
||||
body = data[2+mimeLen:]
|
||||
if int64(len(body)) != size {
|
||||
// Corrupt or partial write — treat as miss.
|
||||
return nil, "", false
|
||||
}
|
||||
_ = os.Chtimes(path, time.Now(), time.Now())
|
||||
return body, mime, true
|
||||
}
|
||||
|
||||
// Put writes the body+mime atomically to the cache.
|
||||
func (c *mediaDiskCache) Put(size int64, crc uint32, body []byte, mime string) error {
|
||||
if size <= 0 || crc == 0 || int64(len(body)) != size {
|
||||
return errors.New("media cache: invalid put")
|
||||
}
|
||||
if len(body) > mediaCacheMaxFileExt {
|
||||
return errors.New("media cache: body too large")
|
||||
}
|
||||
if len(mime) > mediaCacheMaxMime {
|
||||
mime = mime[:mediaCacheMaxMime]
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
path := c.keyFile(size, crc)
|
||||
tmp := path + ".tmp"
|
||||
f, err := os.Create(tmp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(header, uint16(len(mime)))
|
||||
if _, err := f.Write(header); err != nil {
|
||||
f.Close()
|
||||
os.Remove(tmp)
|
||||
return err
|
||||
}
|
||||
if _, err := f.Write([]byte(mime)); err != nil {
|
||||
f.Close()
|
||||
os.Remove(tmp)
|
||||
return err
|
||||
}
|
||||
if _, err := f.Write(body); err != nil {
|
||||
f.Close()
|
||||
os.Remove(tmp)
|
||||
return err
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
os.Remove(tmp)
|
||||
return err
|
||||
}
|
||||
return os.Rename(tmp, path)
|
||||
}
|
||||
|
||||
// Cleanup removes entries older than ttl. Returns the count removed.
|
||||
func (c *mediaDiskCache) Cleanup() int {
|
||||
if c.ttl <= 0 {
|
||||
return 0
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
entries, err := os.ReadDir(c.dir)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
now := time.Now()
|
||||
removed := 0
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), mediaCacheFileExt) {
|
||||
continue
|
||||
}
|
||||
info, err := e.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if now.Sub(info.ModTime()) > c.ttl {
|
||||
if os.Remove(filepath.Join(c.dir, e.Name())) == nil {
|
||||
removed++
|
||||
}
|
||||
}
|
||||
}
|
||||
return removed
|
||||
}
|
||||
|
||||
// Clear deletes every cached entry. Returns the count removed.
|
||||
func (c *mediaDiskCache) Clear() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
entries, err := os.ReadDir(c.dir)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
removed := 0
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), mediaCacheFileExt) {
|
||||
continue
|
||||
}
|
||||
if os.Remove(filepath.Join(c.dir, e.Name())) == nil {
|
||||
removed++
|
||||
}
|
||||
}
|
||||
return removed
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package web
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizeMime(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"": "application/octet-stream",
|
||||
"image/jpeg": "image/jpeg",
|
||||
"image/png; charset=utf-8": "image/png",
|
||||
"text/html": "application/octet-stream", // blocked
|
||||
"application/xhtml+xml": "application/octet-stream", // blocked
|
||||
"image/svg+xml": "application/octet-stream", // blocked (XSS via SVG)
|
||||
"image/jpeg<script>": "application/octet-stream", // bad chars
|
||||
"weird": "application/octet-stream", // no slash
|
||||
"/leading": "application/octet-stream",
|
||||
"trailing/": "application/octet-stream",
|
||||
"application/vnd.api+json": "application/vnd.api+json",
|
||||
"image/webp": "image/webp",
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := sanitizeMime(in); got != want {
|
||||
t.Errorf("sanitizeMime(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeFilename(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"": "",
|
||||
"foo.png": "foo.png",
|
||||
"../../etc/passwd": "passwd",
|
||||
"foo/bar/baz.txt": "baz.txt",
|
||||
"weird\nname.txt": "weirdname.txt",
|
||||
`bad"quote"name`: "badquotename",
|
||||
"..": "",
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := sanitizeFilename(in); got != want {
|
||||
t.Errorf("sanitizeFilename(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
+1479
-35
File diff suppressed because it is too large
Load Diff
+372
-14
@@ -49,12 +49,31 @@ type Config struct {
|
||||
}
|
||||
|
||||
// Profile wraps a Config with a user-chosen nickname and a unique ID.
|
||||
// AutoUpdate is the per-profile list of channel usernames the auto-update
|
||||
// goroutine should refresh; AutoUpdateInterval (seconds, 0 → default) sets
|
||||
// the cadence.
|
||||
type Profile struct {
|
||||
ID string `json:"id"`
|
||||
Nickname string `json:"nickname"`
|
||||
Config Config `json:"config"`
|
||||
ID string `json:"id"`
|
||||
Nickname string `json:"nickname"`
|
||||
Config Config `json:"config"`
|
||||
AutoUpdate []string `json:"autoUpdate,omitempty"`
|
||||
AutoUpdateInterval int `json:"autoUpdateInterval,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
// minAutoUpdateInterval is the floor — never tick faster than once per
|
||||
// minute, even if the user sets something silly. The DNS path is
|
||||
// expensive and the server's own fetch cycle is much longer.
|
||||
minAutoUpdateInterval = 60 * time.Second
|
||||
// serverFetchSettleDelay is how long after nextFetch we wait before
|
||||
// asking the server for fresh data — gives it time to process the
|
||||
// upstream Telegram fetch and have a coherent metadata snapshot.
|
||||
serverFetchSettleDelay = 30 * time.Second
|
||||
// autoUpdateStartupDelay defers the first tick so the initial metadata
|
||||
// + resolver checks have a chance to land before we start polling.
|
||||
autoUpdateStartupDelay = 30 * time.Second
|
||||
)
|
||||
|
||||
// SavedResolverScore stores persistent resolver performance data.
|
||||
type SavedResolverScore struct {
|
||||
Success int64 `json:"success"`
|
||||
@@ -137,6 +156,17 @@ type Server struct {
|
||||
titlesMu sync.Mutex
|
||||
titlesLoading bool
|
||||
titlesBackoffUntil time.Time
|
||||
|
||||
// dlMu guards dlProgress. Active media downloads register their block
|
||||
// counter here so the frontend can poll /api/media/progress and show
|
||||
// per-block updates instead of waiting for byte chunks.
|
||||
dlMu sync.Mutex
|
||||
dlProgress map[string]*mediaDLProgress
|
||||
|
||||
// mediaCache is a disk-backed store for downloaded media bytes so that
|
||||
// multiple devices on the same network share a single DNS-tunnelled
|
||||
// fetch. Entries expire after 7 days.
|
||||
mediaCache *mediaDiskCache
|
||||
}
|
||||
|
||||
// New creates a new web server.
|
||||
@@ -154,6 +184,11 @@ func New(dataDir string, port int, host string, password string) (*Server, error
|
||||
|
||||
scanner := client.NewResolverScanner()
|
||||
|
||||
mediaCache, mcErr := newMediaDiskCache(filepath.Join(dataDir, "media-cache"), 7*24*time.Hour)
|
||||
if mcErr != nil {
|
||||
log.Printf("Warning: media disk cache disabled: %v", mcErr)
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
dataDir: dataDir,
|
||||
port: port,
|
||||
@@ -165,6 +200,13 @@ func New(dataDir string, port int, host string, password string) (*Server, error
|
||||
lastMsgIDs: make(map[int]uint32),
|
||||
lastHashes: make(map[int]uint32),
|
||||
scanner: scanner,
|
||||
mediaCache: mediaCache,
|
||||
dlProgress: make(map[string]*mediaDLProgress),
|
||||
}
|
||||
|
||||
if mediaCache != nil {
|
||||
go mediaCache.Cleanup()
|
||||
go s.runMediaCacheSweep()
|
||||
}
|
||||
|
||||
// Migrate per-profile resolvers into the shared bank on first run.
|
||||
@@ -213,6 +255,8 @@ func (s *Server) Run() error {
|
||||
mux.HandleFunc("/api/events", s.handleSSE)
|
||||
mux.HandleFunc("/api/profiles", s.handleProfiles)
|
||||
mux.HandleFunc("/api/profiles/switch", s.handleProfileSwitch)
|
||||
mux.HandleFunc("/api/auto-update", s.handleAutoUpdate)
|
||||
mux.HandleFunc("/api/auto-update/toggle", s.handleAutoUpdateToggle)
|
||||
mux.HandleFunc("/api/settings", s.handleSettings)
|
||||
mux.HandleFunc("/api/version-check", s.handleVersionCheck)
|
||||
mux.HandleFunc("/api/cache/clear", s.handleClearCache)
|
||||
@@ -230,6 +274,11 @@ func (s *Server) Run() error {
|
||||
mux.HandleFunc("/api/scanner/progress", s.handleScannerProgress)
|
||||
mux.HandleFunc("/api/scanner/apply", s.handleScannerApply)
|
||||
mux.HandleFunc("/api/scanner/presets", s.handleScannerPresets)
|
||||
// Media (image/file) downloader: assembles a binary blob from a media
|
||||
// channel and streams it back. See internal/web/media.go for the param
|
||||
// contract.
|
||||
mux.HandleFunc("/api/media/get", s.handleMediaGet)
|
||||
mux.HandleFunc("/api/media/progress", s.handleMediaProgress)
|
||||
mux.HandleFunc("/", s.handleIndex)
|
||||
|
||||
// Listen on the specified host (default 127.0.0.1)
|
||||
@@ -782,9 +831,140 @@ func (s *Server) initFetcher() error {
|
||||
s.fetcher = fetcher
|
||||
s.cache = cache
|
||||
go cache.Cleanup() // remove channel files not updated in 7 days
|
||||
|
||||
// Goroutine dies with fetcherCtx, so a profile switch / config change
|
||||
// stops it cleanly.
|
||||
go s.runAutoUpdateLoop(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
// runAutoUpdateLoop refreshes the active profile's AutoUpdate channels on a
|
||||
// schedule that follows the server's own fetch cycle — there's no point
|
||||
// polling more often than the server actually pulls fresh data from
|
||||
// Telegram. User-set Profile.AutoUpdateInterval is honoured if it's >= the
|
||||
// 60s floor; otherwise we align with nextFetch + settle delay.
|
||||
func (s *Server) runAutoUpdateLoop(ctx context.Context) {
|
||||
select {
|
||||
case <-time.After(autoUpdateStartupDelay):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
var lastTick time.Time
|
||||
for {
|
||||
wait := s.nextAutoUpdateWait(lastTick)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(wait):
|
||||
}
|
||||
if !s.canAutoUpdate() {
|
||||
continue
|
||||
}
|
||||
s.tickAutoUpdate()
|
||||
lastTick = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// nextAutoUpdateWait returns how long to sleep before the next tick. Honours
|
||||
// user override when set sensibly; otherwise sleeps until just after the
|
||||
// server's next Telegram fetch so we always pull just-refreshed data.
|
||||
func (s *Server) nextAutoUpdateWait(lastTick time.Time) time.Duration {
|
||||
pl, _ := s.loadProfiles()
|
||||
if pl != nil && pl.Active != "" {
|
||||
for _, p := range pl.Profiles {
|
||||
if p.ID != pl.Active {
|
||||
continue
|
||||
}
|
||||
if p.AutoUpdateInterval > 0 {
|
||||
user := time.Duration(p.AutoUpdateInterval) * time.Second
|
||||
if user < minAutoUpdateInterval {
|
||||
user = minAutoUpdateInterval
|
||||
}
|
||||
return user
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
nf := s.nextFetch
|
||||
s.mu.RUnlock()
|
||||
if nf == 0 {
|
||||
return minAutoUpdateInterval
|
||||
}
|
||||
target := time.Unix(int64(nf), 0).Add(serverFetchSettleDelay)
|
||||
delay := time.Until(target)
|
||||
if delay < minAutoUpdateInterval {
|
||||
delay = minAutoUpdateInterval
|
||||
}
|
||||
if !lastTick.IsZero() {
|
||||
if since := time.Since(lastTick); since < minAutoUpdateInterval {
|
||||
if rem := minAutoUpdateInterval - since; rem > delay {
|
||||
delay = rem
|
||||
}
|
||||
}
|
||||
}
|
||||
return delay
|
||||
}
|
||||
|
||||
// canAutoUpdate returns false when we should skip a tick: server hasn't
|
||||
// produced metadata yet (channel list empty), or the resolver scanner is
|
||||
// busy (it'd race with our DNS fetches), or there's no fetcher.
|
||||
func (s *Server) canAutoUpdate() bool {
|
||||
s.mu.RLock()
|
||||
channels := s.channels
|
||||
fetcher := s.fetcher
|
||||
scanner := s.scanner
|
||||
s.mu.RUnlock()
|
||||
if fetcher == nil || len(channels) == 0 {
|
||||
return false
|
||||
}
|
||||
if scanner != nil {
|
||||
switch scanner.State() {
|
||||
case client.ScannerRunning, client.ScannerPaused:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) tickAutoUpdate() {
|
||||
pl, err := s.loadProfiles()
|
||||
if err != nil || pl == nil || pl.Active == "" {
|
||||
return
|
||||
}
|
||||
var watch []string
|
||||
for _, p := range pl.Profiles {
|
||||
if p.ID == pl.Active {
|
||||
watch = p.AutoUpdate
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(watch) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
channels := s.channels
|
||||
s.mu.RUnlock()
|
||||
if len(channels) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
wantSet := make(map[string]bool, len(watch))
|
||||
for _, name := range watch {
|
||||
wantSet[strings.TrimPrefix(strings.TrimSpace(name), "@")] = true
|
||||
}
|
||||
|
||||
for i, ch := range channels {
|
||||
if !wantSet[ch.Name] {
|
||||
continue
|
||||
}
|
||||
go s.refreshChannel(i + 1) // 1-indexed
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) checkLatestVersion(ctx context.Context) (string, error) {
|
||||
s.mu.RLock()
|
||||
cfg := s.config
|
||||
@@ -1961,6 +2141,10 @@ func (s *Server) handleProfiles(w http.ResponseWriter, r *http.Request) {
|
||||
addToBank(pl, req.Profile.Config.Resolvers)
|
||||
req.Profile.Config.Resolvers = nil
|
||||
}
|
||||
// Carry over fields the edit-profile UI doesn't manage so
|
||||
// they don't get wiped on save (auto-update list etc.).
|
||||
req.Profile.AutoUpdate = p.AutoUpdate
|
||||
req.Profile.AutoUpdateInterval = p.AutoUpdateInterval
|
||||
pl.Profiles[i] = req.Profile
|
||||
if p.ID == pl.Active {
|
||||
needsReinit = true
|
||||
@@ -2097,6 +2281,164 @@ func (s *Server) handleProfileSwitch(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, map[string]any{"ok": true})
|
||||
}
|
||||
|
||||
// handleAutoUpdate exposes the active profile's auto-update channel list.
|
||||
// GET → {channels, intervalSeconds, defaultIntervalSeconds}.
|
||||
// POST {channels, intervalSeconds?} replaces both. Names are stripped and
|
||||
// dedup'd before saving.
|
||||
func (s *Server) handleAutoUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
pl, _ := s.loadProfiles()
|
||||
channels := []string{}
|
||||
interval := 0
|
||||
if pl != nil && pl.Active != "" {
|
||||
for _, p := range pl.Profiles {
|
||||
if p.ID == pl.Active {
|
||||
if p.AutoUpdate != nil {
|
||||
channels = p.AutoUpdate
|
||||
}
|
||||
interval = p.AutoUpdateInterval
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
writeJSON(w, map[string]any{
|
||||
"channels": channels,
|
||||
"intervalSeconds": interval,
|
||||
"defaultIntervalSeconds": int(minAutoUpdateInterval / time.Second),
|
||||
})
|
||||
|
||||
case http.MethodPost:
|
||||
var req struct {
|
||||
Channels []string `json:"channels"`
|
||||
IntervalSeconds *int `json:"intervalSeconds,omitempty"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "invalid JSON", 400)
|
||||
return
|
||||
}
|
||||
pl, err := s.loadProfiles()
|
||||
if err != nil || pl == nil || pl.Active == "" {
|
||||
http.Error(w, "no active profile", 400)
|
||||
return
|
||||
}
|
||||
idx := -1
|
||||
for i, p := range pl.Profiles {
|
||||
if p.ID == pl.Active {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx < 0 {
|
||||
http.Error(w, "active profile not found", 400)
|
||||
return
|
||||
}
|
||||
pl.Profiles[idx].AutoUpdate = normaliseAutoUpdateList(req.Channels)
|
||||
if req.IntervalSeconds != nil {
|
||||
v := *req.IntervalSeconds
|
||||
if v < 0 {
|
||||
v = 0
|
||||
}
|
||||
minSec := int(minAutoUpdateInterval / time.Second)
|
||||
if v > 0 && v < minSec {
|
||||
v = minSec // floor: never poll faster than the server fetches
|
||||
}
|
||||
pl.Profiles[idx].AutoUpdateInterval = v
|
||||
}
|
||||
if err := s.saveProfiles(pl); err != nil {
|
||||
http.Error(w, fmt.Sprintf("save: %v", err), 500)
|
||||
return
|
||||
}
|
||||
writeJSON(w, map[string]any{
|
||||
"ok": true,
|
||||
"channels": pl.Profiles[idx].AutoUpdate,
|
||||
"intervalSeconds": pl.Profiles[idx].AutoUpdateInterval,
|
||||
})
|
||||
|
||||
default:
|
||||
http.Error(w, "method not allowed", 405)
|
||||
}
|
||||
}
|
||||
|
||||
// handleAutoUpdateToggle flips one channel's membership. Body {channel}.
|
||||
// Returns {enabled, channels}.
|
||||
func (s *Server) handleAutoUpdateToggle(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", 405)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
Channel string `json:"channel"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "invalid JSON", 400)
|
||||
return
|
||||
}
|
||||
name := strings.TrimPrefix(strings.TrimSpace(req.Channel), "@")
|
||||
if name == "" {
|
||||
http.Error(w, "channel required", 400)
|
||||
return
|
||||
}
|
||||
pl, err := s.loadProfiles()
|
||||
if err != nil || pl == nil || pl.Active == "" {
|
||||
http.Error(w, "no active profile", 400)
|
||||
return
|
||||
}
|
||||
idx := -1
|
||||
for i, p := range pl.Profiles {
|
||||
if p.ID == pl.Active {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx < 0 {
|
||||
http.Error(w, "active profile not found", 400)
|
||||
return
|
||||
}
|
||||
current := pl.Profiles[idx].AutoUpdate
|
||||
on := false
|
||||
hit := -1
|
||||
for i, n := range current {
|
||||
if strings.TrimPrefix(strings.TrimSpace(n), "@") == name {
|
||||
hit = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if hit >= 0 {
|
||||
current = append(current[:hit], current[hit+1:]...)
|
||||
} else {
|
||||
current = append(current, name)
|
||||
on = true
|
||||
}
|
||||
pl.Profiles[idx].AutoUpdate = normaliseAutoUpdateList(current)
|
||||
if err := s.saveProfiles(pl); err != nil {
|
||||
http.Error(w, fmt.Sprintf("save: %v", err), 500)
|
||||
return
|
||||
}
|
||||
writeJSON(w, map[string]any{
|
||||
"ok": true,
|
||||
"channel": name,
|
||||
"enabled": on,
|
||||
"channels": pl.Profiles[idx].AutoUpdate,
|
||||
})
|
||||
}
|
||||
|
||||
// normaliseAutoUpdateList strips @ + whitespace, drops empties, dedupes
|
||||
// while preserving order.
|
||||
func normaliseAutoUpdateList(in []string) []string {
|
||||
seen := make(map[string]bool, len(in))
|
||||
out := make([]string, 0, len(in))
|
||||
for _, raw := range in {
|
||||
name := strings.TrimPrefix(strings.TrimSpace(raw), "@")
|
||||
if name == "" || seen[name] {
|
||||
continue
|
||||
}
|
||||
seen[name] = true
|
||||
out = append(out, name)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// handleSettings manages user preferences (font size etc.).
|
||||
func (s *Server) handleSettings(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
@@ -2223,26 +2565,42 @@ func (s *Server) handleVersionCheck(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, map[string]any{"ok": true, "latestVersion": v})
|
||||
}
|
||||
|
||||
// handleClearCache deletes all files in the cache directory.
|
||||
// runMediaCacheSweep evicts expired media-cache entries every hour for the
|
||||
// lifetime of the process.
|
||||
func (s *Server) runMediaCacheSweep() {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
if s.mediaCache == nil {
|
||||
return
|
||||
}
|
||||
s.mediaCache.Cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
// handleClearCache wipes both the per-channel message cache and the
|
||||
// downloaded-media disk cache.
|
||||
func (s *Server) handleClearCache(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", 405)
|
||||
return
|
||||
}
|
||||
cacheDir := filepath.Join(s.dataDir, "cache")
|
||||
entries, err := os.ReadDir(cacheDir)
|
||||
if err != nil {
|
||||
writeJSON(w, map[string]any{"ok": true, "deleted": 0})
|
||||
return
|
||||
}
|
||||
deleted := 0
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() {
|
||||
cacheDir := filepath.Join(s.dataDir, "cache")
|
||||
if entries, err := os.ReadDir(cacheDir); err == nil {
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
if os.Remove(filepath.Join(cacheDir, e.Name())) == nil {
|
||||
deleted++
|
||||
}
|
||||
}
|
||||
}
|
||||
s.addLog(fmt.Sprintf("Cache cleared: %d files deleted", deleted))
|
||||
writeJSON(w, map[string]any{"ok": true, "deleted": deleted})
|
||||
mediaDeleted := 0
|
||||
if s.mediaCache != nil {
|
||||
mediaDeleted = s.mediaCache.Clear()
|
||||
}
|
||||
s.addLog(fmt.Sprintf("Cache cleared: %d message files, %d media files", deleted, mediaDeleted))
|
||||
writeJSON(w, map[string]any{"ok": true, "deleted": deleted, "mediaDeleted": mediaDeleted})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user