feat: relays for download media

This commit is contained in:
Sarto
2026-04-30 00:47:11 +03:30
parent b4e9cd8714
commit 989fec3cec
26 changed files with 2517 additions and 368 deletions
+84
View File
@@ -1,8 +1,10 @@
package server
import (
"context"
"crypto/rand"
"fmt"
"log"
"sync"
"time"
@@ -34,6 +36,14 @@ type Feed struct {
// rejects queries to media channels with a not-found error, mirroring
// pre-feature behaviour.
media *MediaCache
// gitHubRelay (optional) lets clients fetch media bytes over plain
// HTTPS from a GitHub repo. nil when disabled.
gitHubRelay *GitHubRelay
// relayInfoBlocks serves the relay-discovery channel
// (RelayInfoChannel) — block 0 contains the GitHub "owner/repo"
// string, or an empty payload if the relay is off.
relayInfoBlocks [][]byte
}
// NewFeed creates a new Feed with the given channel names.
@@ -95,6 +105,9 @@ func (f *Feed) GetBlock(channel, block int) ([]byte, error) {
if channel == int(protocol.TitlesChannel) {
return f.getTitlesBlock(block)
}
if channel == int(protocol.RelayInfoChannel) {
return f.getRelayInfoBlock(block)
}
// Channel sits in the binary media range — delegate to MediaCache. We
// drop the read lock first because MediaCache uses its own lock and we
// don't want to hold f.mu across that path.
@@ -132,6 +145,65 @@ func (f *Feed) MediaCache() *MediaCache {
return f.media
}
// SetGitHubRelay attaches the GitHub fast relay. Safe to call once at
// startup. nil disables.
//
// The relay pointer and the relay-discovery blocks are updated together
// under the feed's write lock, so readers never observe one without the
// other.
func (f *Feed) SetGitHubRelay(r *GitHubRelay) {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.gitHubRelay = r
	// Regenerate the discovery payload so it reflects the new relay (or its
	// absence when r is nil).
	f.rebuildRelayInfoBlocks()
}
// GitHubRelay reports the relay currently attached to the feed. The result
// is nil when no relay has been configured.
func (f *Feed) GitHubRelay() *GitHubRelay {
	f.mu.RLock()
	relay := f.gitHubRelay
	f.mu.RUnlock()
	return relay
}
// AfterFetchCycle performs relay housekeeping once a fetch cycle has
// finished. It is a no-op when no GitHub relay is attached. The steps run in
// a fixed order:
//
//  1. touch relay entries for media still held by the cache,
//  2. flush queued uploads,
//  3. prune files older than the relay TTL (skipped when the TTL is zero).
//
// Touching has to happen before pruning so that files referenced by skipped
// fetches don't age out. Errors are logged, never returned.
func (f *Feed) AfterFetchCycle(ctx context.Context) {
	relay := f.GitHubRelay()
	if relay == nil {
		return
	}

	if cache := f.MediaCache(); cache != nil {
		cache.TouchRelayEntries()
	}

	if flushErr := relay.Flush(ctx); flushErr != nil {
		log.Printf("[gh-relay] flush after fetch: %v", flushErr)
	}

	ttl := relay.TTL()
	if ttl <= 0 {
		return
	}
	pruned, pruneErr := relay.PruneStale(ctx, time.Now().Add(-ttl))
	switch {
	case pruneErr != nil:
		log.Printf("[gh-relay] prune after fetch: %v", pruneErr)
	case pruned > 0:
		log.Printf("[gh-relay] pruned %d stale file(s) after fetch", pruned)
	}
}
// rebuildRelayInfoBlocks regenerates the payload served on RelayInfoChannel
// and stores it in f.relayInfoBlocks.
//
// The payload is a sequence of UTF-8 "key=value\n" lines; the only key
// emitted here is "gh" (GitHub owner/repo), kept short so packets stay
// small. With no relay attached the payload is empty. Block 0 always starts
// with a big-endian uint16 holding the total block count, which lets a
// client request the remaining blocks in parallel.
//
// It writes to f without locking; SetGitHubRelay calls it with f.mu held.
func (f *Feed) rebuildRelayInfoBlocks() {
	var payload []byte
	if relay := f.gitHubRelay; relay != nil {
		payload = []byte(fmt.Sprintf("gh=%s\n", relay.Repo()))
	}

	blocks := protocol.SplitIntoBlocks(payload)
	if len(blocks) == 0 {
		// Guarantee a block 0 to carry the count header.
		blocks = [][]byte{nil}
	}

	count := len(blocks)
	first := make([]byte, 0, 2+len(blocks[0]))
	first = append(first, byte(count>>8), byte(count))
	first = append(first, blocks[0]...)
	blocks[0] = first

	f.relayInfoBlocks = blocks
}
func (f *Feed) getVersionBlock(block int) ([]byte, error) {
blocks := f.versionBlocks
if len(blocks) == 0 {
@@ -198,6 +270,18 @@ func (f *Feed) getTitlesBlock(block int) ([]byte, error) {
return blocks[block], nil
}
// getRelayInfoBlock returns block number `block` of the relay-discovery
// channel, or an error when the index is out of range.
//
// This runs on the GetBlock read path, where f.mu is held for reading only,
// so it must not write to f. If the discovery blocks were never built
// (SetGitHubRelay has not been called), an equivalent payload is assembled
// in a local slice and served without being cached. The layout matches
// rebuildRelayInfoBlocks: "gh=<owner/repo>\n" when a relay is attached,
// empty otherwise, with a big-endian uint16 block count prefixed to block 0.
func (f *Feed) getRelayInfoBlock(block int) ([]byte, error) {
	blocks := f.relayInfoBlocks
	if len(blocks) == 0 {
		var payload []byte
		if r := f.gitHubRelay; r != nil {
			payload = []byte(fmt.Sprintf("gh=%s\n", r.Repo()))
		}
		blocks = protocol.SplitIntoBlocks(payload)
		if len(blocks) == 0 {
			blocks = [][]byte{nil}
		}
		prefix := []byte{byte(len(blocks) >> 8), byte(len(blocks))}
		blocks[0] = append(prefix, blocks[0]...)
	}
	if block < 0 || block >= len(blocks) {
		return nil, fmt.Errorf("relay-info block %d out of range (%d blocks)", block, len(blocks))
	}
	return blocks[block], nil
}
// rebuildTitlesBlocks re-serializes the display name map and splits it into blocks.
// Block 0 is prefixed with a uint16 total-block count so the client can fetch all
// remaining blocks in parallel after reading the first one.
+685
View File
@@ -0,0 +1,685 @@
package server
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"hash/crc32"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/sartoopjj/thefeed/internal/protocol"
)
// githubAPI is the canonical REST endpoint. Tests can override it.
var githubAPI = "https://api.github.com"
// flushBatchLimit is the pending-queue size at which Upload starts a
// background flush rather than waiting for the next scheduled one.
const flushBatchLimit = 100
// GitHubRelay uploads encrypted media to a GitHub repo. Domain and object
// names are HMAC'd; blobs are AES-256-GCM. Uploads are batched into one
// Git Data API commit per flush.
type GitHubRelay struct {
	// cfg is the configuration handed to NewGitHubRelay (repo, token, caps).
	cfg GitHubRelayConfig
	// passphrase feeds protocol.RelayObjectName when deriving object names.
	passphrase string
	// domain is the result of protocol.RelayDomainSegment; it is the
	// directory prefix for every object path inside the repo.
	domain string
	// relayKey is derived from the passphrase and passed to
	// protocol.EncryptRelayBlob.
	relayKey [protocol.KeySize]byte
	// branch is the branch whose ref is advanced; "main" when cfg.Branch is
	// empty.
	branch string
	// client performs all GitHub API requests.
	client *http.Client
	// mu guards known, pending and dirty.
	mu sync.Mutex
	// known maps object name → metadata for files already committed.
	known map[string]*ghEntry
	// pending maps object name → encrypted upload waiting for the next commit.
	pending map[string]*pendingUpload
	// statePath is the file `known` is persisted to; empty disables
	// persistence.
	statePath string
	// dirty is set when `known` has changed since the last successful save.
	dirty bool
	// commitMu serialises ref-advancing operations so concurrent flushes
	// don't race on updateRef.
	commitMu sync.Mutex
}
// ghEntry is the in-memory record of one file committed to the relay repo.
type ghEntry struct {
	// size and crc are the plaintext length and CRC-32 (IEEE) the object
	// name was derived from.
	size int64
	crc  uint32
	// lastSeen is refreshed by Upload/Touch; PruneStale removes entries
	// whose lastSeen is before its cutoff.
	lastSeen time.Time
}
// pendingUpload is a file queued by Upload and not yet committed.
type pendingUpload struct {
	// blob is the encrypted payload produced by protocol.EncryptRelayBlob.
	blob []byte
	// size and crc describe the plaintext and are copied into the ghEntry
	// once the batch commits.
	size int64
	crc  uint32
}
// NewGitHubRelay builds a relay from cfg, or returns nil when the relay
// cannot run: cfg is not Active, domain or passphrase is empty, or the relay
// key cannot be derived from the passphrase.
//
// A key-derivation failure is logged before returning nil; unlike an
// incomplete config, it would otherwise disable the relay without any hint
// to the operator.
//
// The branch defaults to "main" when cfg.Branch is empty. When cfg.StatePath
// is set, previously persisted entries are loaded; a load failure is logged
// and the relay starts with an empty known set.
func NewGitHubRelay(cfg GitHubRelayConfig, domain, passphrase string) *GitHubRelay {
	if !cfg.Active() || domain == "" || passphrase == "" {
		return nil
	}
	relayKey, err := protocol.DeriveRelayKey(passphrase)
	if err != nil {
		log.Printf("[gh-relay] disabled: derive relay key: %v", err)
		return nil
	}
	branch := cfg.Branch
	if branch == "" {
		branch = "main"
	}
	r := &GitHubRelay{
		cfg:        cfg,
		passphrase: passphrase,
		domain:     protocol.RelayDomainSegment(domain, passphrase),
		relayKey:   relayKey,
		branch:     branch,
		client:     &http.Client{Timeout: 2 * time.Minute},
		known:      make(map[string]*ghEntry),
		pending:    make(map[string]*pendingUpload),
		statePath:  cfg.StatePath,
	}
	if r.statePath != "" {
		if err := r.loadState(); err != nil {
			log.Printf("[gh-relay] load state %s: %v", r.statePath, err)
		}
	}
	return r
}
// persistedEntry is the JSON form of a ghEntry in the state file; the file
// is a JSON object keyed by object name.
type persistedEntry struct {
	Size     int64     `json:"size"`
	CRC      uint32    `json:"crc"`
	LastSeen time.Time `json:"lastSeen"`
}
// loadState merges the entries stored at g.statePath into g.known. A missing
// file is not an error (first run); any other open or decode failure is
// returned and leaves g.known untouched.
func (g *GitHubRelay) loadState() error {
	file, err := os.Open(g.statePath)
	switch {
	case err == nil:
		// proceed
	case os.IsNotExist(err):
		return nil
	default:
		return err
	}
	defer file.Close()

	var persisted map[string]persistedEntry
	if decodeErr := json.NewDecoder(file).Decode(&persisted); decodeErr != nil {
		return decodeErr
	}

	g.mu.Lock()
	defer g.mu.Unlock()
	for name, entry := range persisted {
		g.known[name] = &ghEntry{
			size:     entry.Size,
			crc:      entry.CRC,
			lastSeen: entry.LastSeen,
		}
	}
	log.Printf("[gh-relay] loaded %d entries from %s", len(persisted), g.statePath)
	return nil
}
// saveStateLocked writes `known` to disk via a tmp+rename so a crash mid-write
// doesn't leave a truncated file. Caller must hold g.mu.
//
// It is a no-op when no state path is configured. The dirty flag is cleared
// only after the rename has succeeded, so a failed save is retried by the
// next periodic or shutdown save. The temp file is removed on every failure
// path, including a failed rename.
func (g *GitHubRelay) saveStateLocked() error {
	if g.statePath == "" {
		return nil
	}
	out := make(map[string]persistedEntry, len(g.known))
	for k, e := range g.known {
		out[k] = persistedEntry{Size: e.size, CRC: e.crc, LastSeen: e.lastSeen}
	}
	dir := filepath.Dir(g.statePath)
	if err := os.MkdirAll(dir, 0o700); err != nil {
		return err
	}
	// Create the temp file next to the target so the rename stays on one
	// filesystem.
	tmp, err := os.CreateTemp(dir, "gh-relay-*.json")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()
	enc := json.NewEncoder(tmp)
	enc.SetIndent("", " ")
	if err := enc.Encode(out); err != nil {
		tmp.Close()
		os.Remove(tmpName)
		return err
	}
	if err := tmp.Close(); err != nil {
		os.Remove(tmpName)
		return err
	}
	if err := os.Rename(tmpName, g.statePath); err != nil {
		os.Remove(tmpName)
		return err
	}
	g.dirty = false
	return nil
}
// Repo reports the "owner/repo" this relay publishes to. The discovery
// channel hands this value to clients; the token is never part of it. A nil
// relay yields the empty string.
func (g *GitHubRelay) Repo() string {
	if g != nil {
		return g.cfg.Repo
	}
	return ""
}
// MaxBytes reports the largest file the relay accepts, in bytes. Zero means
// uncapped; a nil relay also reports zero.
func (g *GitHubRelay) MaxBytes() int64 {
	if g != nil {
		return g.cfg.MaxBytes
	}
	return 0
}
// TTL reports how long an object may go unseen before it is eligible for
// pruning, converted from the configured minutes. A nil relay reports zero.
func (g *GitHubRelay) TTL() time.Duration {
	if g != nil {
		minutes := time.Duration(g.cfg.TTLMinutes)
		return minutes * time.Minute
	}
	return 0
}
// Domain reports the HMAC'd directory segment under which this relay's
// objects live in the repo. A nil relay yields the empty string.
func (g *GitHubRelay) Domain() string {
	if g != nil {
		return g.domain
	}
	return ""
}
// Upload encrypts body and places it in the pending queue for the next
// batched commit. It returns ErrTooLarge when body exceeds the configured
// cap and an error when called on a nil relay.
//
// Content already committed only has its lastSeen refreshed; content already
// queued is left alone. Encryption runs without the lock held, so the
// known/pending check is repeated afterwards in case another goroutine
// queued the same content meanwhile. When the queue reaches flushBatchLimit
// a background flush is started; its errors are logged only.
func (g *GitHubRelay) Upload(ctx context.Context, body []byte) error {
	if g == nil {
		return errors.New("github relay disabled")
	}
	size := int64(len(body))
	if limit := g.cfg.MaxBytes; limit > 0 && size > limit {
		return ErrTooLarge
	}
	crc := crc32.ChecksumIEEE(body)
	key := protocol.RelayObjectName(size, crc, g.passphrase)

	// alreadyTracked reports whether key is committed or queued, refreshing
	// lastSeen for committed entries. g.mu must be held by the caller.
	alreadyTracked := func() bool {
		if entry, committed := g.known[key]; committed {
			entry.lastSeen = time.Now()
			g.dirty = true
			return true
		}
		_, queued := g.pending[key]
		return queued
	}

	g.mu.Lock()
	tracked := alreadyTracked()
	g.mu.Unlock()
	if tracked {
		return nil
	}

	blob, err := protocol.EncryptRelayBlob(g.relayKey, body)
	if err != nil {
		return fmt.Errorf("encrypt relay blob: %w", err)
	}

	g.mu.Lock()
	if alreadyTracked() {
		g.mu.Unlock()
		return nil
	}
	g.pending[key] = &pendingUpload{blob: blob, size: size, crc: crc}
	batchFull := len(g.pending) >= flushBatchLimit
	g.mu.Unlock()

	if !batchFull {
		return nil
	}
	go func() {
		flushCtx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
		defer cancel()
		if flushErr := g.flushPending(flushCtx); flushErr != nil {
			log.Printf("[gh-relay] limit flush: %v", flushErr)
		}
	}()
	return nil
}
// Has reports whether content with the given size and CRC is either already
// committed to the relay or waiting in the pending queue. A nil relay has
// nothing.
func (g *GitHubRelay) Has(size int64, crc uint32) bool {
	if g == nil {
		return false
	}
	key := protocol.RelayObjectName(size, crc, g.passphrase)

	g.mu.Lock()
	defer g.mu.Unlock()
	_, committed := g.known[key]
	_, queued := g.pending[key]
	return committed || queued
}
// Touch marks an already-committed file as seen now, without uploading
// anything. It is meant for the case where upstream re-delivers content the
// relay already holds. Unknown content and nil relays are ignored.
func (g *GitHubRelay) Touch(size int64, crc uint32) {
	if g == nil {
		return
	}
	key := protocol.RelayObjectName(size, crc, g.passphrase)

	g.mu.Lock()
	defer g.mu.Unlock()
	entry, committed := g.known[key]
	if !committed {
		return
	}
	entry.lastSeen = time.Now()
	g.dirty = true
}
// PruneStale deletes from the relay repo every known file whose lastSeen
// predates cutoff, in a single commit, and returns how many were removed.
//
// The stale set is picked while commitMu is held, so overlapping prunes from
// different readers cannot pick the same files and race each other's commits
// (which used to produce 422 BadObjectState). If any API call fails, `known`
// is left unchanged and the count returned is 0. A nil relay prunes nothing.
func (g *GitHubRelay) PruneStale(ctx context.Context, cutoff time.Time) (int, error) {
	if g == nil {
		return 0, nil
	}
	g.commitMu.Lock()
	defer g.commitMu.Unlock()

	var (
		staleKeys []string
		removals  []treeEntry
	)
	g.mu.Lock()
	for name, entry := range g.known {
		if !entry.lastSeen.Before(cutoff) {
			continue
		}
		staleKeys = append(staleKeys, name)
		// SHA stays nil: the tree entry then removes the path.
		removals = append(removals, treeEntry{
			Path: g.domain + "/" + name,
			Mode: "100644",
			Type: "blob",
		})
	}
	g.mu.Unlock()

	total := len(removals)
	if total == 0 {
		return 0, nil
	}
	log.Printf("[gh-relay] starting prune of %d file(s)", total)

	head, err := g.getRef(ctx, g.branch)
	if err != nil {
		return 0, fmt.Errorf("get ref: %w", err)
	}
	baseTree, err := g.getCommitTree(ctx, head)
	if err != nil {
		return 0, fmt.Errorf("get commit %s: %w", head, err)
	}
	prunedTree, err := g.createTree(ctx, baseTree, removals)
	if err != nil {
		return 0, fmt.Errorf("create tree: %w", err)
	}
	message := fmt.Sprintf("thefeed: prune %d file(s)", total)
	commit, err := g.createCommit(ctx, message, prunedTree, []string{head})
	if err != nil {
		return 0, fmt.Errorf("create commit: %w", err)
	}
	if err := g.updateRef(ctx, g.branch, commit); err != nil {
		return 0, fmt.Errorf("update ref %s: %w", g.branch, err)
	}

	// The commit landed; forget the removed files and persist right away.
	g.mu.Lock()
	for _, name := range staleKeys {
		delete(g.known, name)
	}
	g.dirty = true
	if saveErr := g.saveStateLocked(); saveErr != nil {
		log.Printf("[gh-relay] save state after prune: %v", saveErr)
	}
	g.mu.Unlock()
	return total, nil
}
// --- Flush loop -------------------------------------------------------------
// Run blocks until ctx is cancelled, then flushes whatever is still queued
// (bounded to 30 seconds) and saves dirty state before returning.
//
// Regular flushing and pruning are driven by Feed.AfterFetchCycle so they
// follow the cadence of upstream fetches. While waiting, Run provides two
// backstops: every 10 minutes it flushes a non-empty queue (covering long
// stretches with no fetch, e.g. all channels served from cache), and every
// 5 minutes it persists `known` if it changed. Calling Run on a nil relay
// returns immediately.
func (g *GitHubRelay) Run(ctx context.Context) {
	if g == nil {
		return
	}
	flushTicker := time.NewTicker(10 * time.Minute)
	defer flushTicker.Stop()
	saveTicker := time.NewTicker(5 * time.Minute)
	defer saveTicker.Stop()

	for {
		select {
		case <-ctx.Done():
			// ctx is already cancelled, so the final flush gets a fresh
			// context with its own deadline.
			finalCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
			if err := g.flushPending(finalCtx); err != nil {
				log.Printf("[gh-relay] shutdown flush: %v", err)
			}
			cancel()

			g.mu.Lock()
			if g.dirty {
				if err := g.saveStateLocked(); err != nil {
					log.Printf("[gh-relay] shutdown save: %v", err)
				}
			}
			g.mu.Unlock()
			return

		case <-flushTicker.C:
			if g.queueSize() > 0 {
				flushCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
				if err := g.flushPending(flushCtx); err != nil {
					log.Printf("[gh-relay] backstop flush: %v", err)
				}
				cancel()
			}

		case <-saveTicker.C:
			g.mu.Lock()
			if g.dirty && g.statePath != "" {
				if err := g.saveStateLocked(); err != nil {
					log.Printf("[gh-relay] periodic save: %v", err)
				}
			}
			g.mu.Unlock()
		}
	}
}
// queueSize reports how many uploads are waiting for the next commit.
func (g *GitHubRelay) queueSize() int {
	g.mu.Lock()
	defer g.mu.Unlock()
	return len(g.pending)
}
// Flush commits any queued uploads right now. It is safe to call from tests
// or during graceful shutdown, and does nothing when the queue is empty or
// the relay is nil.
func (g *GitHubRelay) Flush(ctx context.Context) error {
	if g != nil {
		return g.flushPending(ctx)
	}
	return nil
}
// flushPending takes ownership of the whole pending map and commits it as
// one Git commit via the Git Data API.
//
// If the commit fails, the batch goes back into the queue so the next
// attempt retries it; keys that were queued again in the meantime keep the
// newer entry. If it succeeds, every file in the batch is recorded in
// `known` with a fresh lastSeen and the state file is saved (a save failure
// is logged, not returned).
func (g *GitHubRelay) flushPending(ctx context.Context) error {
	g.mu.Lock()
	batch := g.pending
	if len(batch) == 0 {
		g.mu.Unlock()
		return nil
	}
	g.pending = make(map[string]*pendingUpload)
	g.mu.Unlock()

	commitErr := g.commitBatch(ctx, batch)
	if commitErr != nil {
		g.mu.Lock()
		for key, upload := range batch {
			if _, requeued := g.pending[key]; requeued {
				continue
			}
			g.pending[key] = upload
		}
		g.mu.Unlock()
		return commitErr
	}

	committedAt := time.Now()
	g.mu.Lock()
	for key, upload := range batch {
		g.known[key] = &ghEntry{
			size:     upload.size,
			crc:      upload.crc,
			lastSeen: committedAt,
		}
	}
	g.dirty = true
	if saveErr := g.saveStateLocked(); saveErr != nil {
		log.Printf("[gh-relay] save state: %v", saveErr)
	}
	g.mu.Unlock()

	log.Printf("[gh-relay] committed %d file(s)", len(batch))
	return nil
}
// treeEntry is the Git Data API tree-item shape used by both upload
// (SHA = newly-created blob) and delete (SHA = nil → entry removed from
// the resulting tree).
type treeEntry struct {
	// Path is the repo-relative file path: "<domain>/<object name>".
	Path string `json:"path"`
	// Mode and Type are always "100644" and "blob" in this file.
	Mode string  `json:"mode"`
	Type string  `json:"type"`
	SHA  *string `json:"sha"` // pointer so nil serialises as JSON `null`
}
// commitBatch publishes every upload in batch as one commit on g.branch
// using the Git Data API sequence:
//
//	GET ref → POST blobs → POST tree (with base_tree) → POST commit → PATCH ref.
//
// The whole sequence runs under commitMu. The first failing step aborts the
// batch and its error is returned wrapped with the step name. An empty batch
// is a no-op.
func (g *GitHubRelay) commitBatch(ctx context.Context, batch map[string]*pendingUpload) error {
	count := len(batch)
	if count == 0 {
		return nil
	}
	g.commitMu.Lock()
	defer g.commitMu.Unlock()

	log.Printf("[gh-relay] starting upload of %d file(s)", count)

	head, err := g.getRef(ctx, g.branch)
	if err != nil {
		return fmt.Errorf("get ref: %w", err)
	}
	baseTree, err := g.getCommitTree(ctx, head)
	if err != nil {
		return fmt.Errorf("get commit %s: %w", head, err)
	}

	additions := make([]treeEntry, 0, count)
	for name, upload := range batch {
		// blobSHA is a fresh variable on every iteration, so taking its
		// address below is safe.
		blobSHA, blobErr := g.createBlob(ctx, upload.blob)
		if blobErr != nil {
			return fmt.Errorf("create blob %s: %w", name, blobErr)
		}
		additions = append(additions, treeEntry{
			Path: g.domain + "/" + name,
			Mode: "100644",
			Type: "blob",
			SHA:  &blobSHA,
		})
	}

	tree, err := g.createTree(ctx, baseTree, additions)
	if err != nil {
		return fmt.Errorf("create tree: %w", err)
	}
	message := fmt.Sprintf("thefeed: upload %d file(s)", count)
	commit, err := g.createCommit(ctx, message, tree, []string{head})
	if err != nil {
		return fmt.Errorf("create commit: %w", err)
	}
	if err := g.updateRef(ctx, g.branch, commit); err != nil {
		return fmt.Errorf("update ref %s: %w", g.branch, err)
	}
	return nil
}
// --- Git Data API plumbing --------------------------------------------------
// getRef fetches the commit SHA that refs/heads/<branch> currently points
// at. Any non-2xx response is returned as an error carrying the status line
// and response body.
func (g *GitHubRelay) getRef(ctx context.Context, branch string) (string, error) {
	endpoint := "/repos/" + g.cfg.Repo + "/git/ref/heads/" + branch
	req, err := g.newReq(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return "", err
	}
	resp, err := g.client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code < 200 || code > 299 {
		detail, _ := io.ReadAll(resp.Body)
		return "", fmt.Errorf("%s — %s", resp.Status, string(detail))
	}

	var ref struct {
		Object struct {
			SHA string `json:"sha"`
		} `json:"object"`
	}
	if decodeErr := json.NewDecoder(resp.Body).Decode(&ref); decodeErr != nil {
		return "", decodeErr
	}
	return ref.Object.SHA, nil
}
// getCommitTree fetches the given commit and returns the SHA of its root
// tree. Any non-2xx response is returned as an error carrying the status
// line and response body.
func (g *GitHubRelay) getCommitTree(ctx context.Context, commitSHA string) (string, error) {
	endpoint := "/repos/" + g.cfg.Repo + "/git/commits/" + commitSHA
	req, err := g.newReq(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return "", err
	}
	resp, err := g.client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code < 200 || code > 299 {
		detail, _ := io.ReadAll(resp.Body)
		return "", fmt.Errorf("%s — %s", resp.Status, string(detail))
	}

	var commit struct {
		Tree struct {
			SHA string `json:"sha"`
		} `json:"tree"`
	}
	if decodeErr := json.NewDecoder(resp.Body).Decode(&commit); decodeErr != nil {
		return "", decodeErr
	}
	return commit.Tree.SHA, nil
}
// createBlob uploads content as a base64-encoded Git blob and returns the
// blob's SHA. Any non-2xx response is returned as an error carrying the
// status line and response body.
func (g *GitHubRelay) createBlob(ctx context.Context, content []byte) (string, error) {
	payload, _ := json.Marshal(map[string]any{
		"encoding": "base64",
		"content":  base64.StdEncoding.EncodeToString(content),
	})
	endpoint := "/repos/" + g.cfg.Repo + "/git/blobs"
	req, err := g.newReq(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := g.client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code < 200 || code > 299 {
		detail, _ := io.ReadAll(resp.Body)
		return "", fmt.Errorf("%s — %s", resp.Status, string(detail))
	}

	var created struct {
		SHA string `json:"sha"`
	}
	if decodeErr := json.NewDecoder(resp.Body).Decode(&created); decodeErr != nil {
		return "", decodeErr
	}
	return created.SHA, nil
}
// createTree posts a new tree layered on baseTree and returns its SHA.
//
// entries is encoded as the request's "tree" array; the callers in this file
// pass []treeEntry. Because entries is typed `any`, encoding can fail for a
// value JSON cannot represent. That failure is returned instead of being
// dropped, which would otherwise send an empty request body. Any non-2xx
// response is returned as an error carrying the status line and response
// body.
func (g *GitHubRelay) createTree(ctx context.Context, baseTree string, entries any) (string, error) {
	body, err := json.Marshal(map[string]any{
		"base_tree": baseTree,
		"tree":      entries,
	})
	if err != nil {
		return "", fmt.Errorf("encode tree request: %w", err)
	}
	req, err := g.newReq(ctx, http.MethodPost, "/repos/"+g.cfg.Repo+"/git/trees", bytes.NewReader(body))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := g.client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	if resp.StatusCode/100 != 2 {
		raw, _ := io.ReadAll(resp.Body)
		return "", fmt.Errorf("%s — %s", resp.Status, string(raw))
	}
	var out struct {
		SHA string `json:"sha"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		return "", err
	}
	return out.SHA, nil
}
// createCommit posts a commit object with the given message, tree and
// parent commits, and returns the new commit's SHA. It does not move any
// ref. Any non-2xx response is returned as an error carrying the status line
// and response body.
func (g *GitHubRelay) createCommit(ctx context.Context, message, treeSHA string, parents []string) (string, error) {
	payload, _ := json.Marshal(map[string]any{
		"message": message,
		"tree":    treeSHA,
		"parents": parents,
	})
	endpoint := "/repos/" + g.cfg.Repo + "/git/commits"
	req, err := g.newReq(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := g.client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code < 200 || code > 299 {
		detail, _ := io.ReadAll(resp.Body)
		return "", fmt.Errorf("%s — %s", resp.Status, string(detail))
	}

	var created struct {
		SHA string `json:"sha"`
	}
	if decodeErr := json.NewDecoder(resp.Body).Decode(&created); decodeErr != nil {
		return "", decodeErr
	}
	return created.SHA, nil
}
// updateRef moves refs/heads/<branch> to commitSHA. The request sets
// "force": false. Any non-2xx response is returned as an error carrying the
// status line and response body.
func (g *GitHubRelay) updateRef(ctx context.Context, branch, commitSHA string) error {
	payload, _ := json.Marshal(map[string]any{
		"sha":   commitSHA,
		"force": false,
	})
	endpoint := "/repos/" + g.cfg.Repo + "/git/refs/heads/" + branch
	req, err := g.newReq(ctx, http.MethodPatch, endpoint, bytes.NewReader(payload))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := g.client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code >= 200 && code <= 299 {
		return nil
	}
	detail, _ := io.ReadAll(resp.Body)
	return fmt.Errorf("%s — %s", resp.Status, string(detail))
}
// --- HTTP plumbing ----------------------------------------------------------
// newReq builds a request against githubAPI (trailing slashes trimmed)
// joined with urlPath, bound to ctx, and carrying the bearer token and the
// standard GitHub headers. Callers add Content-Type themselves when they
// send a body.
func (g *GitHubRelay) newReq(ctx context.Context, method, urlPath string, body io.Reader) (*http.Request, error) {
	target := strings.TrimRight(githubAPI, "/") + urlPath
	req, err := http.NewRequestWithContext(ctx, method, target, body)
	if err != nil {
		return nil, err
	}
	headers := [][2]string{
		{"Authorization", "Bearer " + g.cfg.Token},
		{"Accept", "application/vnd.github+json"},
		{"X-GitHub-Api-Version", "2022-11-28"},
		{"User-Agent", "thefeed-server"},
	}
	for _, h := range headers {
		req.Header.Set(h[0], h[1])
	}
	return req, nil
}
+242
View File
@@ -0,0 +1,242 @@
package server
import (
"bytes"
"context"
"encoding/json"
"errors"
"hash/crc32"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"sync"
"testing"
"time"
)
// fakeGitHub stubs the slice of GitHub's REST API the relay uses:
//   - Git Data API (refs / commits / blobs / trees) for batched uploads and
//     for prunes, which arrive as tree entries with a null SHA
//   - a Contents API directory listing, answered for any other GET
type fakeGitHub struct {
	mu      sync.Mutex
	files   map[string][]byte // repoPath → placeholder bytes for each committed path
	commits int               // number of commits created (rate-limit metric)
	blobs   int               // blob create count
	deletes int               // tree entries received with a null SHA (deletions)
	// Tree state — dumb counter; we don't model real Git history.
	headSHA string
	treeSHA string
	nextSeq int
}
// sha mints the next fake object id, formatted "<prefix>-<sequence>". The
// sequence is shared across all prefixes. Callers hold f.mu.
func (f *fakeGitHub) sha(prefix string) string {
	f.nextSeq++
	seq := strconv.Itoa(f.nextSeq)
	return strings.Join([]string{prefix, seq}, "-")
}
// handler returns the HTTP handler behind the fake. Every request is served
// with f.mu held. Paths are matched after stripping the
// "/repos/owner/repo/" prefix, so tests must configure Repo "owner/repo".
// Unmatched non-GET requests fall through and get an empty 200 response.
func (f *fakeGitHub) handler() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		f.mu.Lock()
		defer f.mu.Unlock()
		path := strings.TrimPrefix(r.URL.Path, "/repos/owner/repo/")
		// --- Git Data API ---------------------------------------------------
		switch {
		case r.Method == http.MethodGet && strings.HasPrefix(path, "git/ref/heads/"):
			// Lazily mint an initial head commit on first read.
			if f.headSHA == "" {
				f.headSHA = f.sha("commit")
			}
			_ = json.NewEncoder(w).Encode(map[string]any{
				"object": map[string]any{"sha": f.headSHA},
			})
			return
		case r.Method == http.MethodGet && strings.HasPrefix(path, "git/commits/"):
			// The requested commit SHA is ignored; the current tree is
			// always returned.
			if f.treeSHA == "" {
				f.treeSHA = f.sha("tree")
			}
			_ = json.NewEncoder(w).Encode(map[string]any{
				"tree": map[string]any{"sha": f.treeSHA},
			})
			return
		case r.Method == http.MethodPost && path == "git/blobs":
			// Blob content is decoded but not stored; only the count matters.
			var body struct{ Content string }
			_ = json.NewDecoder(r.Body).Decode(&body)
			f.blobs++
			s := f.sha("blob")
			_ = json.NewEncoder(w).Encode(map[string]any{"sha": s})
			return
		case r.Method == http.MethodPost && path == "git/trees":
			// SHA is *string so null serialises as JSON null and decodes back to nil.
			var body struct {
				BaseTree string `json:"base_tree"`
				Tree     []struct {
					Path string  `json:"path"`
					SHA  *string `json:"sha"`
				} `json:"tree"`
			}
			_ = json.NewDecoder(r.Body).Decode(&body)
			// Apply entries straight to the file map: null SHA deletes,
			// anything else marks the path as committed.
			for _, e := range body.Tree {
				if e.SHA == nil {
					delete(f.files, e.Path)
					f.deletes++
				} else {
					f.files[e.Path] = []byte("committed")
				}
			}
			f.treeSHA = f.sha("tree")
			_ = json.NewEncoder(w).Encode(map[string]any{"sha": f.treeSHA})
			return
		case r.Method == http.MethodPost && path == "git/commits":
			// The new commit becomes head immediately; the later PATCH on
			// the ref is only acknowledged.
			f.commits++
			f.headSHA = f.sha("commit")
			_ = json.NewEncoder(w).Encode(map[string]any{"sha": f.headSHA})
			return
		case r.Method == http.MethodPatch && strings.HasPrefix(path, "git/refs/heads/"):
			w.WriteHeader(http.StatusOK)
			return
		}
		// --- Contents API directory listing (any other GET) -----------------
		// NOTE(review): the relay code shown with this test prunes through
		// git/trees and never lists contents, so this branch looks like a
		// leftover — confirm before removing.
		if r.Method == http.MethodGet {
			repoPath := strings.TrimPrefix(path, "contents/")
			items := []map[string]any{}
			prefix := repoPath + "/"
			for k, v := range f.files {
				if strings.HasPrefix(k, prefix) {
					items = append(items, map[string]any{
						"path": k, "sha": "sha-" + k, "type": "file", "size": len(v),
					})
				}
			}
			_ = json.NewEncoder(w).Encode(items)
		}
	})
}
// newFakeGitHub starts an httptest server backed by a fresh fakeGitHub and
// points the package-level githubAPI at it for the duration of the test. The
// previous endpoint is restored and the server closed via t.Cleanup. The
// returned func closes the server early if a test wants to.
func newFakeGitHub(t *testing.T) (*fakeGitHub, func()) {
	fake := &fakeGitHub{files: make(map[string][]byte)}
	server := httptest.NewServer(fake.handler())

	original := githubAPI
	githubAPI = server.URL
	t.Cleanup(func() {
		githubAPI = original
		server.Close()
	})
	return fake, server.Close
}
// TestGitHubRelayUploadAndDedup checks that uploading the same content twice
// yields exactly one blob and one commit after a flush, that Has then
// reports the content, and that flushing an empty queue creates no commit.
func TestGitHubRelayUploadAndDedup(t *testing.T) {
	fake, _ := newFakeGitHub(t)
	cfg := GitHubRelayConfig{Enabled: true, Token: "tok", Repo: "owner/repo", MaxBytes: 1 << 20, TTLMinutes: 60}
	relay := NewGitHubRelay(cfg, "feed.example.com", "test-passphrase")
	if relay == nil {
		t.Fatal("relay should activate with full config")
	}
	ctx := context.Background()
	payload := []byte("hello relay world")

	if err := relay.Upload(ctx, payload); err != nil {
		t.Fatalf("first upload: %v", err)
	}
	// Identical content: must be absorbed by the queue, never reaching GitHub
	// a second time.
	if err := relay.Upload(ctx, payload); err != nil {
		t.Fatalf("second upload: %v", err)
	}
	// Commit the queued batch synchronously.
	if err := relay.Flush(ctx); err != nil {
		t.Fatalf("flush: %v", err)
	}

	if got := fake.commits; got != 1 {
		t.Errorf("commits = %d, want 1 (one batch)", got)
	}
	if got := fake.blobs; got != 1 {
		t.Errorf("blobs = %d, want 1 (dedup before flush)", got)
	}
	if !relay.Has(int64(len(payload)), crc32.ChecksumIEEE(payload)) {
		t.Errorf("Has should return true after upload")
	}

	// Nothing is queued now, so another flush must not create a commit.
	if err := relay.Flush(ctx); err != nil {
		t.Fatalf("noop flush: %v", err)
	}
	if got := fake.commits; got != 1 {
		t.Errorf("commits after noop flush = %d, want 1", got)
	}
}
// TestGitHubRelayMaxBytes checks that a body larger than the configured cap
// is rejected with ErrTooLarge.
func TestGitHubRelayMaxBytes(t *testing.T) {
	newFakeGitHub(t)
	cfg := GitHubRelayConfig{Enabled: true, Token: "tok", Repo: "owner/repo", MaxBytes: 16, TTLMinutes: 60}
	relay := NewGitHubRelay(cfg, "ex.test", "pp")

	oversized := bytes.Repeat([]byte("x"), 32)
	if err := relay.Upload(context.Background(), oversized); !errors.Is(err, ErrTooLarge) {
		t.Fatalf("err = %v, want ErrTooLarge", err)
	}
}
// TestGitHubRelayPruneStale commits two files, ages one of them past the
// cutoff, and checks that PruneStale removes exactly that file in a single
// commit.
func TestGitHubRelayPruneStale(t *testing.T) {
	fake, _ := newFakeGitHub(t)
	cfg := GitHubRelayConfig{Enabled: true, Token: "tok", Repo: "owner/repo", MaxBytes: 1 << 20, TTLMinutes: 1}
	relay := NewGitHubRelay(cfg, "ex.test", "pp")
	ctx := context.Background()

	if err := relay.Upload(ctx, []byte("stays")); err != nil {
		t.Fatalf("upload stays: %v", err)
	}
	if err := relay.Upload(ctx, []byte("goes")); err != nil {
		t.Fatalf("upload goes: %v", err)
	}
	// Commit the batch so both files are in the relay's known set.
	if err := relay.Flush(ctx); err != nil {
		t.Fatalf("flush: %v", err)
	}

	// Age only the "goes" entry. The two payloads differ in length
	// ("stays" = 5 bytes, "goes" = 4), so size identifies it.
	relay.mu.Lock()
	for _, entry := range relay.known {
		if entry.size != 4 {
			continue
		}
		entry.lastSeen = time.Now().Add(-2 * time.Hour)
	}
	relay.mu.Unlock()

	commitsBefore := fake.commits
	removed, err := relay.PruneStale(ctx, time.Now().Add(-time.Hour))
	if err != nil {
		t.Fatalf("prune: %v", err)
	}
	if removed != 1 {
		t.Errorf("removed = %d, want 1", removed)
	}
	if fake.deletes != 1 {
		t.Errorf("tree-deletes = %d, want 1", fake.deletes)
	}
	if got := fake.commits - commitsBefore; got != 1 {
		t.Errorf("prune commits = %d, want 1 (single batched commit)", got)
	}
}
// TestGitHubRelayStatePersistence checks that the known set written by one
// relay is picked up by a second relay created with the same StatePath.
func TestGitHubRelayStatePersistence(t *testing.T) {
	newFakeGitHub(t)
	statePath := t.TempDir() + "/gh_relay_state.json"
	cfg := GitHubRelayConfig{Enabled: true, Token: "tok", Repo: "owner/repo", MaxBytes: 1 << 20, TTLMinutes: 60, StatePath: statePath}
	ctx := context.Background()
	body := []byte("survive me")
	size, crc := int64(len(body)), crc32.ChecksumIEEE(body)

	first := NewGitHubRelay(cfg, "ex.test", "pp")
	if err := first.Upload(ctx, body); err != nil {
		t.Fatalf("upload: %v", err)
	}
	if err := first.Flush(ctx); err != nil {
		t.Fatalf("flush: %v", err)
	}
	if !first.Has(size, crc) {
		t.Fatal("r1 should know the file after flush")
	}

	// A fresh instance has no in-memory state; anything it knows came from
	// the state file.
	second := NewGitHubRelay(cfg, "ex.test", "pp")
	if !second.Has(size, crc) {
		t.Fatal("r2 should have loaded the file from statePath")
	}
}
+156 -51
View File
@@ -4,6 +4,7 @@ import (
"bytes"
"compress/flate"
"compress/gzip"
"context"
"errors"
"fmt"
"hash/crc32"
@@ -27,11 +28,12 @@ type MediaCache struct {
maxFileBytes int64
ttl time.Duration
compression protocol.MediaCompression
dnsEnabled bool // when false, RelayDNS stays unset on the wire
// Logger receives an info line per cache event when set (Store hits/misses,
// evictions). The default is a silent no-op so tests don't print noise.
logf func(format string, args ...interface{})
gh *GitHubRelay
mu sync.RWMutex
byKey map[string]*mediaEntry // upstream key (file_id / URL) → entry
byChannel map[uint16]*mediaEntry // assigned channel → entry
@@ -67,16 +69,11 @@ type mediaEntry struct {
// MediaCacheConfig configures a new MediaCache.
type MediaCacheConfig struct {
// MaxFileBytes is the largest individual file the cache will accept.
// Files larger than this are rejected by Store with ErrTooLarge.
MaxFileBytes int64
// TTL is how long an entry stays cached after its last refresh.
TTL time.Duration
// Compression is the wire-format compression used for media blocks.
// Defaults to MediaCompressionNone when zero.
Compression protocol.MediaCompression
// Logf receives info-level cache events. Optional.
Logf func(format string, args ...interface{})
MaxFileBytes int64
TTL time.Duration
Compression protocol.MediaCompression
Logf func(format string, args ...interface{})
DNSRelayEnabled bool // controls Relays[RelayDNS] on the wire
}
// ErrTooLarge is returned by Store when content exceeds MaxFileBytes.
@@ -99,6 +96,7 @@ func NewMediaCache(cfg MediaCacheConfig) *MediaCache {
maxFileBytes: cfg.MaxFileBytes,
ttl: cfg.TTL,
compression: cfg.Compression,
dnsEnabled: cfg.DNSRelayEnabled,
logf: logf,
byKey: make(map[string]*mediaEntry),
byChannel: make(map[uint16]*mediaEntry),
@@ -128,14 +126,18 @@ func (c *MediaCache) Store(cacheKey, tag string, content []byte, mimeType, filen
tag = protocol.MediaFile
}
size := int64(len(content))
if c.maxFileBytes > 0 && size > c.maxFileBytes {
// Reject only when no enabled relay could host this file. A file too big
// for DNS but small enough for GitHub still belongs in the cache —
// MaxAcceptableBytes() collapses both caps into a single ceiling.
if max := c.MaxAcceptableBytes(); max > 0 && size > max {
atomic.AddUint64(&c.storeRejected, 1)
return protocol.MediaMeta{
Tag: tag,
Size: size,
Downloadable: false,
Tag: tag,
Size: size,
Relays: nil,
}, ErrTooLarge
}
dnsFits := c.maxFileBytes == 0 || size <= c.maxFileBytes
now := time.Now()
hash := crc32.ChecksumIEEE(content)
@@ -144,18 +146,15 @@ func (c *MediaCache) Store(cacheKey, tag string, content []byte, mimeType, filen
defer c.mu.Unlock()
if existing, ok := c.byKey[cacheKey]; ok && existing.crc32 == hash {
// Same upstream id and same content — just refresh the TTL.
existing.expiresAt = c.expiry(now)
atomic.AddUint64(&c.storeHits, 1)
c.logf("media: refresh tag=%s key=%s ch=%d size=%d", tag, cacheKey, existing.channel, existing.size)
if c.gh != nil {
c.gh.Touch(existing.size, existing.crc32)
}
return c.metaForLocked(existing), nil
}
// Cross-key content match: a different upstream id pointed at exactly
// the same bytes. Bind the new cache key to the existing entry so any
// future Lookup under either key works, and refresh the TTL. This is
// the case the spec asks for: "same media → just reset TTL, don't take
// a new channel slot".
if existing, ok := c.byHash[hash]; ok {
existing.expiresAt = c.expiry(now)
if cacheKey != existing.cacheKey {
@@ -173,6 +172,9 @@ func (c *MediaCache) Store(cacheKey, tag string, content []byte, mimeType, filen
c.byKey[cacheKey] = existing
atomic.AddUint64(&c.storeHits, 1)
c.logf("media: dedup tag=%s key=%s ch=%d size=%d (hash match)", tag, cacheKey, existing.channel, existing.size)
if c.gh != nil {
c.gh.Touch(existing.size, existing.crc32)
}
return c.metaForLocked(existing), nil
}
@@ -190,29 +192,38 @@ func (c *MediaCache) Store(cacheKey, tag string, content []byte, mimeType, filen
// active entries; n is capped by the media-channel range.
c.sweepExpiredLocked(now)
channel, err := c.allocateChannelLocked(now)
if err != nil {
return protocol.MediaMeta{}, err
}
blocks, encErr := splitMediaBlocks(hash, content, c.compression)
if encErr != nil {
return protocol.MediaMeta{}, encErr
}
if size > 0 {
var compressedBody int
for _, b := range blocks {
compressedBody += len(b)
var (
channel uint16
blocks [][]byte
)
if dnsFits {
var err error
channel, err = c.allocateChannelLocked(now)
if err != nil {
return protocol.MediaMeta{}, err
}
compressedBody -= protocol.MediaBlockHeaderLen
if compressedBody < 0 {
compressedBody = 0
var encErr error
blocks, encErr = splitMediaBlocks(hash, content, c.compression)
if encErr != nil {
return protocol.MediaMeta{}, encErr
}
var savedPct int
if c.compression != protocol.MediaCompressionNone && size > 0 {
savedPct = int((size - int64(compressedBody)) * 100 / size)
if size > 0 {
var compressedBody int
for _, b := range blocks {
compressedBody += len(b)
}
compressedBody -= protocol.MediaBlockHeaderLen
if compressedBody < 0 {
compressedBody = 0
}
var savedPct int
if c.compression != protocol.MediaCompressionNone && size > 0 {
savedPct = int((size - int64(compressedBody)) * 100 / size)
}
c.logf("media: compress=%s key=%s orig=%d body=%d saved=%d%%", c.compression, cacheKey, size, compressedBody, savedPct)
}
c.logf("media: compress=%s key=%s orig=%d body=%d saved=%d%%", c.compression, cacheKey, size, compressedBody, savedPct)
} else {
c.logf("media: store key=%s size=%d too big for DNS — relay only", cacheKey, size)
}
entry := &mediaEntry{
channel: channel,
@@ -226,13 +237,29 @@ func (c *MediaCache) Store(cacheKey, tag string, content []byte, mimeType, filen
expiresAt: c.expiry(now),
}
c.byKey[cacheKey] = entry
c.byChannel[channel] = entry
if dnsFits {
c.byChannel[channel] = entry
}
c.byHash[hash] = entry
atomic.AddUint64(&c.storeMisses, 1)
atomic.AddInt64(&c.currentEntries, 1)
atomic.AddInt64(&c.currentBytes, size)
c.logf("media: store tag=%s key=%s ch=%d size=%d blocks=%d", tag, cacheKey, channel, size, len(blocks))
// Best-effort relay upload — copy of `content` because the caller may
// reuse the slice. Failures are logged but never block the DNS path.
if c.gh != nil {
gh := c.gh
body := append([]byte(nil), content...)
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
if err := gh.Upload(ctx, body); err != nil {
c.logf("media: gh-relay upload failed: %v", err)
}
}()
}
return c.metaForLocked(entry), nil
}
@@ -439,15 +466,93 @@ func (c *MediaCache) expiry(now time.Time) time.Time {
}
func (c *MediaCache) metaForLocked(entry *mediaEntry) protocol.MediaMeta {
return protocol.MediaMeta{
Tag: entry.tag,
Size: entry.size,
Downloadable: true,
Channel: entry.channel,
Blocks: uint16(len(entry.blocks)),
CRC32: entry.crc32,
Filename: entry.filename,
// DNS bit only when DNS is enabled AND we actually computed blocks for
// this entry. Files larger than the DNS cap have len(blocks)==0.
dnsOK := c.dnsEnabled && len(entry.blocks) > 0
// GitHub bit reflects "the relay would serve this file": relay enabled
// and the file fits its cap. We don't require the upload to have
// finished — small files in particular would otherwise miss the bit on
// first render because the upload runs asynchronously. The web layer
// retries transient 404s while the upload is still in flight.
ghOK := false
if c.gh != nil {
ghMax := c.gh.MaxBytes()
ghOK = ghMax == 0 || entry.size <= ghMax
}
relays := []bool{dnsOK, ghOK}
meta := protocol.MediaMeta{
Tag: entry.tag,
Size: entry.size,
Relays: relays,
CRC32: entry.crc32,
Filename: entry.filename,
}
if dnsOK {
meta.Channel = entry.channel
meta.Blocks = uint16(len(entry.blocks))
}
return meta
}
// SetGitHubRelay installs g as the cache's GitHub relay handle, replacing
// whatever was set before. The swap is done under the cache's write lock.
// Passing nil detaches the relay: code paths that consult the handle check
// it for nil before using it.
func (c *MediaCache) SetGitHubRelay(g *GitHubRelay) {
	c.mu.Lock()
	c.gh = g
	c.mu.Unlock()
}
// TouchRelayEntries calls Touch on the attached GitHub relay once for each
// entry in the content-hash index, passing that entry's size and CRC32.
// Readers run it after a fetch cycle so that files belonging to channels
// whose fetch was skipped still count as recently seen by the relay.
//
// It is a no-op on a nil receiver or when no relay is attached.
func (c *MediaCache) TouchRelayEntries() {
	if c == nil {
		return
	}

	// sizeAndSum is the (size, crc32) pair handed to Touch for one entry.
	type sizeAndSum struct {
		size int64
		sum  uint32
	}

	// Snapshot the relay handle and the pairs under the read lock, then
	// release it before calling into the relay so the cache lock is never
	// held across Touch.
	var pending []sizeAndSum
	c.mu.RLock()
	relay := c.gh
	if relay != nil {
		pending = make([]sizeAndSum, 0, len(c.byHash))
		for _, entry := range c.byHash {
			pending = append(pending, sizeAndSum{
				size: int64(entry.size),
				sum:  uint32(entry.crc32),
			})
		}
	}
	c.mu.RUnlock()

	// pending stays empty when no relay is attached, so this loop is skipped.
	for _, item := range pending {
		relay.Touch(item.size, item.sum)
	}
}
// MaxAcceptableBytes reports the largest file size that at least one enabled
// relay would take. Download paths use it as the pre-fetch size gate, so a
// file that is too big for DNS but fits the GitHub relay is still pulled.
//
// A return value of 0 means "no cap". That happens when:
//   - the receiver is nil,
//   - DNS is enabled with a zero per-file cap,
//   - a GitHub relay is attached and reports a zero cap, or
//   - DNS is disabled and no GitHub relay is attached (NOTE(review): in this
//     last case no relay is active at all, yet the result still reads as
//     "uncapped" — confirm callers are fine with that).
//
// Otherwise the result is the larger of the caps of the enabled relays.
func (c *MediaCache) MaxAcceptableBytes() int64 {
	if c == nil {
		return 0
	}

	// Only the relay handle is read under the lock; the DNS settings are
	// read directly from the cache fields.
	c.mu.RLock()
	relay := c.gh
	c.mu.RUnlock()

	dnsCap := c.maxFileBytes
	ghOn := relay != nil
	var ghCap int64
	if ghOn {
		ghCap = relay.MaxBytes()
	}
	dnsOn := c.dnsEnabled

	switch {
	case dnsOn && dnsCap == 0, ghOn && ghCap == 0:
		// Any enabled relay without a cap makes the whole gate uncapped.
		return 0
	case !dnsOn:
		// GitHub is the only candidate; ghCap is 0 when it is not attached.
		return ghCap
	case !ghOn:
		return dnsCap
	case ghCap > dnsCap:
		return ghCap
	default:
		return dnsCap
	}
}
// splitMediaBlocks compresses the content (when compression != none),
+7 -7
View File
@@ -90,13 +90,13 @@ func downloadHTTPMedia(ctx context.Context, cache *MediaCache, tag, rawURL strin
return protocol.MediaMeta{}, false
}
maxBytes := cache.maxFileBytes
maxBytes := cache.MaxAcceptableBytes()
if maxBytes > 0 && resp.ContentLength > 0 && resp.ContentLength > maxBytes {
size := resp.ContentLength
return protocol.MediaMeta{
Tag: tag,
Size: size,
Downloadable: false,
Tag: tag,
Size: size,
Relays: nil,
}, true
}
@@ -115,9 +115,9 @@ func downloadHTTPMedia(ctx context.Context, cache *MediaCache, tag, rawURL strin
}
if maxBytes > 0 && int64(len(bytes)) > maxBytes {
return protocol.MediaMeta{
Tag: tag,
Size: int64(len(bytes)),
Downloadable: false,
Tag: tag,
Size: int64(len(bytes)),
Relays: nil,
}, true
}
+13 -12
View File
@@ -27,7 +27,7 @@ func TestApplyHTTPMediaSourcesEndToEnd(t *testing.T) {
}))
defer srv.Close()
cache := NewMediaCache(MediaCacheConfig{MaxFileBytes: 1 << 20, TTL: time.Hour})
cache := NewMediaCache(MediaCacheConfig{MaxFileBytes: 1 << 20, TTL: time.Hour, DNSRelayEnabled: true})
msgs := []protocol.Message{
{ID: 100, Timestamp: 1, Text: protocol.MediaImage + "\nhello"},
@@ -42,7 +42,7 @@ func TestApplyHTTPMediaSourcesEndToEnd(t *testing.T) {
if !ok {
t.Fatalf("ParseMediaText ok=false on rewritten message: %q", msgs[0].Text)
}
if !meta.Downloadable {
if !meta.HasRelay(protocol.RelayDNS) {
t.Fatalf("expected downloadable meta, got %+v (text=%q)", meta, msgs[0].Text)
}
if meta.Tag != protocol.MediaImage {
@@ -80,7 +80,7 @@ func TestApplyHTTPMediaSourcesEndToEnd(t *testing.T) {
}
}
// TestApplyHTTPMediaSourcesGzipRoundTrip: with --media-compression=gzip,
// TestApplyHTTPMediaSourcesGzipRoundTrip: with --dns-media-compression=gzip,
// a successful upstream fetch lands compressed blocks in the cache. A
// client decompressing the assembled blocks recovers the original bytes
// verbatim and the embedded CRC32 matches.
@@ -94,9 +94,10 @@ func TestApplyHTTPMediaSourcesGzipRoundTrip(t *testing.T) {
defer srv.Close()
cache := NewMediaCache(MediaCacheConfig{
MaxFileBytes: 1 << 20,
TTL: time.Hour,
Compression: protocol.MediaCompressionGzip,
MaxFileBytes: 1 << 20,
TTL: time.Hour,
Compression: protocol.MediaCompressionGzip,
DNSRelayEnabled: true,
})
msgs := []protocol.Message{{ID: 100, Timestamp: 1, Text: protocol.MediaImage + "\n"}}
sources := []mediaSource{{tag: protocol.MediaImage, url: srv.URL + "/big.png"}}
@@ -106,7 +107,7 @@ func TestApplyHTTPMediaSourcesGzipRoundTrip(t *testing.T) {
applyHTTPMediaSources(ctx, cache, msgs, sources)
meta, _, ok := protocol.ParseMediaText(msgs[0].Text)
if !ok || !meta.Downloadable {
if !ok || !meta.HasRelay(protocol.RelayDNS) {
t.Fatalf("expected downloadable meta, got %+v", meta)
}
@@ -169,7 +170,7 @@ func TestApplyHTTPMediaSourcesAlbum(t *testing.T) {
}))
defer srv.Close()
cache := NewMediaCache(MediaCacheConfig{MaxFileBytes: 1 << 20, TTL: time.Hour})
cache := NewMediaCache(MediaCacheConfig{MaxFileBytes: 1 << 20, TTL: time.Hour, DNSRelayEnabled: true})
// Mirror what parsePublicMessagesWithMedia produces for a 3-image album:
// stacked [IMAGE] headers + caption, plus an extraURLs slice on the source.
@@ -207,7 +208,7 @@ func TestApplyHTTPMediaSourcesAlbum(t *testing.T) {
if !ok {
t.Fatalf("ParseMediaText #%d ok=false on %q", i, rest)
}
if !meta.Downloadable {
if !meta.HasRelay(protocol.RelayDNS) {
t.Errorf("header #%d not downloadable: %+v", i, meta)
}
if int(meta.Size) != len(images[i]) {
@@ -236,7 +237,7 @@ func TestApplyHTTPMediaSourcesAlbumPartialFailure(t *testing.T) {
}))
defer srv.Close()
cache := NewMediaCache(MediaCacheConfig{MaxFileBytes: 1 << 20, TTL: time.Hour})
cache := NewMediaCache(MediaCacheConfig{MaxFileBytes: 1 << 20, TTL: time.Hour, DNSRelayEnabled: true})
body := protocol.MediaImage + "\n" + protocol.MediaImage + "\ncap"
msgs := []protocol.Message{{ID: 5, Timestamp: 1, Text: body}}
@@ -273,7 +274,7 @@ func TestApplyHTTPMediaSourcesRejectsOversize(t *testing.T) {
}))
defer srv.Close()
cache := NewMediaCache(MediaCacheConfig{MaxFileBytes: 100, TTL: time.Hour})
cache := NewMediaCache(MediaCacheConfig{MaxFileBytes: 100, TTL: time.Hour, DNSRelayEnabled: true})
msgs := []protocol.Message{{ID: 1, Timestamp: 1, Text: protocol.MediaImage + "\ncap"}}
sources := []mediaSource{{tag: protocol.MediaImage, url: srv.URL + "/big.jpg"}}
@@ -285,7 +286,7 @@ func TestApplyHTTPMediaSourcesRejectsOversize(t *testing.T) {
if !ok {
t.Fatalf("ParseMediaText ok=false")
}
if meta.Downloadable {
if meta.HasRelay(protocol.RelayDNS) {
t.Fatalf("oversized file should not be downloadable; got meta=%+v", meta)
}
if meta.Size != int64(len(bigBody)) {
+11 -10
View File
@@ -78,12 +78,13 @@ func (tr *TelegramReader) downloadTelegramPhoto(ctx context.Context, api *tg.Cli
return protocol.MediaMeta{}, false
}
// Honour the configured max-size early so we don't even open the RPC for
// objects we'll just throw away.
if maxBytes := cache.maxFileBytes; maxBytes > 0 && bestBytes > maxBytes {
// objects no enabled relay would accept. Files that fit GitHub but not
// DNS still get fetched.
if maxBytes := cache.MaxAcceptableBytes(); maxBytes > 0 && bestBytes > maxBytes {
return protocol.MediaMeta{
Tag: protocol.MediaImage,
Size: bestBytes,
Downloadable: false,
Tag: protocol.MediaImage,
Size: bestBytes,
Relays: nil,
}, true
}
@@ -128,11 +129,11 @@ func (tr *TelegramReader) downloadTelegramDocument(ctx context.Context, api *tg.
return protocol.MediaMeta{}, false
}
if maxBytes := cache.maxFileBytes; maxBytes > 0 && doc.Size > maxBytes {
if maxBytes := cache.MaxAcceptableBytes(); maxBytes > 0 && doc.Size > maxBytes {
return protocol.MediaMeta{
Tag: tag,
Size: doc.Size,
Downloadable: false,
Tag: tag,
Size: doc.Size,
Relays: nil,
}, true
}
@@ -170,7 +171,7 @@ func (tr *TelegramReader) downloadTelegramFile(ctx context.Context, api *tg.Clie
cache := tr.feed.MediaCache()
maxBytes := int64(0)
if cache != nil {
maxBytes = cache.maxFileBytes
maxBytes = cache.MaxAcceptableBytes()
}
var (
+24 -6
View File
@@ -12,7 +12,24 @@ import (
)
func newTestCache(maxBytes int64, ttl time.Duration) *MediaCache {
return NewMediaCache(MediaCacheConfig{MaxFileBytes: maxBytes, TTL: ttl})
return NewMediaCache(MediaCacheConfig{MaxFileBytes: maxBytes, TTL: ttl, DNSRelayEnabled: true})
}
// TestMediaCacheRelayFlags: with DNS off the wire flag stays clear, and
// when a GitHub relay is attached the cache surfaces RelayGitHub.
func TestMediaCacheRelayFlags(t *testing.T) {
cfg := MediaCacheConfig{MaxFileBytes: 1 << 20, TTL: time.Hour, DNSRelayEnabled: false}
cache := NewMediaCache(cfg)
meta, err := cache.Store("k", protocol.MediaImage, []byte("payload"), "image/jpeg", "")
if err != nil {
t.Fatalf("Store: %v", err)
}
if meta.HasRelay(protocol.RelayDNS) {
t.Errorf("DNS relay should be off when DNSRelayEnabled=false")
}
if meta.HasRelay(protocol.RelayGitHub) {
t.Errorf("GitHub relay should be off when no relay is attached")
}
}
func TestMediaCacheStoreAndGetBlock(t *testing.T) {
@@ -23,8 +40,8 @@ func TestMediaCacheStoreAndGetBlock(t *testing.T) {
if err != nil {
t.Fatalf("Store: %v", err)
}
if !meta.Downloadable {
t.Fatalf("Downloadable = false, want true")
if !meta.HasRelay(protocol.RelayDNS) {
t.Fatalf("RelayDNS = false, want true")
}
if !protocol.IsMediaChannel(meta.Channel) {
t.Fatalf("Channel %d not in media range", meta.Channel)
@@ -72,9 +89,10 @@ func TestMediaCacheStoreAndGetBlock(t *testing.T) {
// original.
func TestMediaCacheStoreGzip(t *testing.T) {
cache := NewMediaCache(MediaCacheConfig{
MaxFileBytes: 1 << 20,
TTL: time.Hour,
Compression: protocol.MediaCompressionGzip,
MaxFileBytes: 1 << 20,
TTL: time.Hour,
Compression: protocol.MediaCompressionGzip,
DNSRelayEnabled: true,
})
content := bytes.Repeat([]byte("compress-me "), 200)
+35 -14
View File
@@ -28,13 +28,26 @@ type PublicReader struct {
client *http.Client
baseURL string
mu sync.RWMutex
cache map[string]cachedMessages
cacheTTL time.Duration
mu sync.RWMutex
cache map[string]cachedMessages
cacheTTL time.Duration
fetchInterval time.Duration
refreshCh chan struct{}
}
// SetFetchInterval replaces the reader's fetch cadence (10 minutes unless
// overridden) and sets the per-channel cache TTL to the same duration.
// Non-positive values are ignored and leave the current settings in place.
//
// Call it before Run: Run copies the interval once when it creates its
// ticker, so a later change does not alter an already-running loop.
func (pr *PublicReader) SetFetchInterval(d time.Duration) {
	if d > 0 {
		pr.mu.Lock()
		defer pr.mu.Unlock()
		pr.fetchInterval, pr.cacheTTL = d, d
	}
}
// NewPublicReader creates a reader for public channels without Telegram login.
func NewPublicReader(channelUsernames []string, feed *Feed, msgLimit int, baseCh int) *PublicReader {
cleaned := make([]string, len(channelUsernames))
@@ -56,9 +69,10 @@ func NewPublicReader(channelUsernames []string, feed *Feed, msgLimit int, baseCh
Timeout: 30 * time.Second,
},
baseURL: "https://t.me/s",
cache: make(map[string]cachedMessages),
cacheTTL: 10 * time.Minute,
refreshCh: make(chan struct{}, 1),
cache: make(map[string]cachedMessages),
cacheTTL: 10 * time.Minute,
fetchInterval: 10 * time.Minute,
refreshCh: make(chan struct{}, 1),
}
}
@@ -67,9 +81,10 @@ func (pr *PublicReader) Run(ctx context.Context) error {
pr.feed.SetTelegramLoggedIn(false)
pr.fetchAll(ctx)
ticker := time.NewTicker(10 * time.Minute)
interval := pr.fetchInterval
ticker := time.NewTicker(interval)
defer ticker.Stop()
pr.feed.SetNextFetch(uint32(time.Now().Add(10 * time.Minute).Unix()))
pr.feed.SetNextFetch(uint32(time.Now().Add(interval).Unix()))
for {
select {
@@ -77,14 +92,14 @@ func (pr *PublicReader) Run(ctx context.Context) error {
return ctx.Err()
case <-ticker.C:
pr.fetchAll(ctx)
pr.feed.SetNextFetch(uint32(time.Now().Add(10 * time.Minute).Unix()))
pr.feed.SetNextFetch(uint32(time.Now().Add(interval).Unix()))
case <-pr.refreshCh:
pr.mu.Lock()
pr.cache = make(map[string]cachedMessages)
pr.mu.Unlock()
pr.fetchAll(ctx)
ticker.Reset(10 * time.Minute)
pr.feed.SetNextFetch(uint32(time.Now().Add(10 * time.Minute).Unix()))
ticker.Reset(interval)
pr.feed.SetNextFetch(uint32(time.Now().Add(interval).Unix()))
}
}
}
@@ -112,14 +127,18 @@ func (pr *PublicReader) UpdateChannels(channels []string) {
func (pr *PublicReader) fetchAll(ctx context.Context) {
log.Printf("[public] fetch cycle started for %d channels", len(pr.channels))
start := time.Now()
var fetched, failed int
var fetched, failed, skipped int
pr.mu.RLock()
cacheTTL := pr.cacheTTL
pr.mu.RUnlock()
for i, username := range pr.channels {
chNum := pr.baseCh + i
pr.mu.RLock()
cached, ok := pr.cache[username]
pr.mu.RUnlock()
if ok && time.Since(cached.fetched) < pr.cacheTTL {
if ok && time.Since(cached.fetched) < cacheTTL {
skipped++
continue
}
@@ -148,7 +167,9 @@ func (pr *PublicReader) fetchAll(ctx context.Context) {
fetched++
log.Printf("[public] updated %s (%s): %d messages", username, title, len(msgs))
}
log.Printf("[public] fetch cycle done in %s: %d fetched, %d failed, %d total", time.Since(start).Round(time.Millisecond), fetched, failed, len(pr.channels))
log.Printf("[public] fetch cycle done in %s: %d fetched, %d failed, %d skipped, %d total",
time.Since(start).Round(time.Millisecond), fetched, failed, skipped, len(pr.channels))
pr.feed.AfterFetchCycle(ctx)
}
func (pr *PublicReader) fetchChannel(ctx context.Context, username string) ([]protocol.Message, string, error) {
+60 -29
View File
@@ -25,21 +25,32 @@ type Config struct {
NoTelegram bool // if true, fetch public channels without Telegram login
AllowManage bool // if true, remote channel management and sending via DNS is allowed
Debug bool // if true, log every decoded DNS query
// NoMedia disables downloading and serving image/file media. When set, the
// server emits the legacy [TAG]\ncaption form for media messages so old
// clients keep working unchanged.
NoMedia bool
// MediaMaxSize is the per-file cap in bytes for cached media. 0 means no
// cap (not recommended in production).
MediaMaxSize int64
// MediaCacheTTL is the cache lifetime in minutes for a single entry. The
// effective TTL is reset whenever the same upstream id is fetched again.
MediaCacheTTL int
// MediaCompression names the compression applied to cached media bytes
// before they're split into DNS blocks. One of "none", "gzip",
// "deflate". Empty defaults to "gzip".
MediaCompression string
Telegram TelegramConfig
// DNSMediaEnabled toggles the slow DNS-relay path. When false the
// server still ingests media bytes (so other relays can serve them)
// but the wire-format DNS flag is unset for clients.
DNSMediaEnabled bool
DNSMediaMaxSize int64 // per-file cap for the DNS relay (0 = no cap)
DNSMediaCacheTTL int // DNS-relay TTL in minutes
DNSMediaCompression string // DNS-relay compression: none|gzip|deflate
FetchInterval time.Duration // 0 = default 10m; floor enforced by main
GitHubRelay GitHubRelayConfig
Telegram TelegramConfig
}
// GitHubRelayConfig holds the settings for the GitHub fast relay. The relay
// only counts as usable when Active reports true.
type GitHubRelayConfig struct {
	// Enabled is the operator's on/off switch for the relay.
	Enabled bool
	// Token is the credential for GitHub; it must be non-empty for Active.
	Token string
	// Repo names the target repository; it must be non-empty for Active.
	// Presumably in "owner/repo" form — confirm against the relay code.
	Repo string
	// Branch is the branch to commit to; "" → "main".
	Branch string
	// StatePath is the file used to persist lastSeen across restarts.
	StatePath string
	// MaxBytes is presumably the per-file size cap for the relay — confirm
	// against the relay implementation.
	MaxBytes int64
	// TTLMinutes is a lifetime expressed in minutes; presumably how long an
	// untouched file is kept — confirm against the relay implementation.
	TTLMinutes int
}

// Active reports whether the relay is switched on and has both a token and
// a repository configured.
func (g GitHubRelayConfig) Active() bool {
	if !g.Enabled {
		return false
	}
	return g.Token != "" && g.Repo != ""
}
// Server orchestrates the DNS server and Telegram reader.
@@ -81,35 +92,52 @@ func (s *Server) Run(ctx context.Context) error {
SetMediaDebugLogs(s.cfg.Debug)
// Configure media cache before any reader starts so the very first fetch
// cycle can populate it. When --no-media is set we leave Feed.media as
// nil; the readers fall through to the legacy [TAG]\ncaption form, and
// Feed.GetBlock rejects media-channel queries with not-found.
if !s.cfg.NoMedia {
ttlMin := s.cfg.MediaCacheTTL
// Spin up the media cache when at least one relay is enabled. The cache
// owns the byte pipeline; whether DNS or GitHub serves bytes to clients
// is controlled by per-relay flags on each MediaMeta.
anyRelay := s.cfg.DNSMediaEnabled || s.cfg.GitHubRelay.Active()
if anyRelay {
ttlMin := s.cfg.DNSMediaCacheTTL
if ttlMin <= 0 {
ttlMin = 600
}
ttl := time.Duration(ttlMin) * time.Minute
compName := s.cfg.MediaCompression
compName := s.cfg.DNSMediaCompression
if compName == "" {
compName = "gzip"
}
compression, err := protocol.ParseMediaCompressionName(compName)
if err != nil {
return fmt.Errorf("--media-compression: %w", err)
return fmt.Errorf("--dns-media-compression: %w", err)
}
mediaCache := NewMediaCache(MediaCacheConfig{
MaxFileBytes: s.cfg.MediaMaxSize,
TTL: ttl,
Compression: compression,
Logf: logfMedia,
MaxFileBytes: s.cfg.DNSMediaMaxSize,
TTL: ttl,
Compression: compression,
Logf: logfMedia,
DNSRelayEnabled: s.cfg.DNSMediaEnabled,
})
s.feed.SetMediaCache(mediaCache)
log.Printf("[server] media cache enabled: max-size=%d bytes, ttl=%s, compression=%s", s.cfg.MediaMaxSize, ttl, compression)
log.Printf("[server] media: dns=%v max=%d ttl=%s compression=%s",
s.cfg.DNSMediaEnabled, s.cfg.DNSMediaMaxSize, ttl, compression)
go s.runMediaSweep(ctx, mediaCache, ttl)
if s.cfg.GitHubRelay.Active() {
gh := NewGitHubRelay(s.cfg.GitHubRelay, s.cfg.Domain, s.cfg.Passphrase)
if gh != nil {
mediaCache.SetGitHubRelay(gh)
s.feed.SetGitHubRelay(gh)
go gh.Run(ctx)
branch := s.cfg.GitHubRelay.Branch
if branch == "" {
branch = "main"
}
log.Printf("[server] github relay: repo=%s branch=%s max=%d ttl=%dm",
gh.Repo(), branch, gh.MaxBytes(), s.cfg.GitHubRelay.TTLMinutes)
}
}
} else {
log.Println("[server] media cache disabled (--no-media)")
log.Println("[server] media disabled (no relays enabled)")
}
go startLatestVersionTracker(ctx, s.feed)
@@ -129,6 +157,7 @@ func (s *Server) Run(ctx context.Context) error {
}
if len(s.telegramChannels) > 0 {
reader := NewTelegramReader(s.cfg.Telegram, s.telegramChannels, s.feed, msgLimit, 1)
reader.SetFetchInterval(s.cfg.FetchInterval)
s.reader = reader
channelCtl = reader
go func() {
@@ -148,6 +177,7 @@ func (s *Server) Run(ctx context.Context) error {
msgLimit = 15
}
publicReader := NewPublicReader(s.telegramChannels, s.feed, msgLimit, 1)
publicReader.SetFetchInterval(s.cfg.FetchInterval)
channelCtl = publicReader
go func() {
log.Println("[public] reader goroutine started")
@@ -167,6 +197,7 @@ func (s *Server) Run(ctx context.Context) error {
msgLimit = 15
}
xReader = NewXPublicReader(s.xAccounts, s.feed, msgLimit, len(s.telegramChannels)+1, s.cfg.XRSSInstances)
xReader.SetFetchInterval(s.cfg.FetchInterval)
go func() {
log.Println("[x] reader goroutine started")
if err := xReader.Run(ctx); err != nil && ctx.Err() == nil {
+39 -22
View File
@@ -63,9 +63,10 @@ type TelegramReader struct {
msgLimit int // max messages to fetch per channel
baseCh int
mu sync.RWMutex
cache map[string]cachedMessages
cacheTTL time.Duration
mu sync.RWMutex
cache map[string]cachedMessages
cacheTTL time.Duration
fetchInterval time.Duration
// api is set once authenticated, used for sending messages.
apiMu sync.RWMutex
@@ -74,6 +75,17 @@ type TelegramReader struct {
refreshCh chan struct{} // signals Run() to re-fetch immediately
}
// SetFetchInterval replaces the reader's fetch cadence (10 minutes unless
// overridden) and sets the per-channel cache TTL to the same duration.
// Zero or negative durations are ignored.
//
// Run copies the interval once when it creates its ticker, so this should be
// called before Run for the new cadence to take effect.
func (tr *TelegramReader) SetFetchInterval(d time.Duration) {
	if d > 0 {
		tr.mu.Lock()
		defer tr.mu.Unlock()
		tr.fetchInterval, tr.cacheTTL = d, d
	}
}
// resolvedPeer holds the resolved Telegram peer along with its chat type.
type resolvedPeer struct {
peer tg.InputPeerClass
@@ -100,14 +112,15 @@ func NewTelegramReader(cfg TelegramConfig, channelUsernames []string, feed *Feed
baseCh = 1
}
return &TelegramReader{
cfg: cfg,
channels: cleaned,
feed: feed,
msgLimit: msgLimit,
baseCh: baseCh,
cache: make(map[string]cachedMessages),
cacheTTL: 10 * time.Minute,
refreshCh: make(chan struct{}, 1),
cfg: cfg,
channels: cleaned,
feed: feed,
msgLimit: msgLimit,
baseCh: baseCh,
cache: make(map[string]cachedMessages),
cacheTTL: 10 * time.Minute,
fetchInterval: 10 * time.Minute,
refreshCh: make(chan struct{}, 1),
}
}
@@ -146,11 +159,11 @@ func (tr *TelegramReader) Run(ctx context.Context) error {
// Initial fetch
tr.fetchAll(ctx, api)
// Periodic fetch loop
ticker := time.NewTicker(10 * time.Minute)
interval := tr.fetchInterval
ticker := time.NewTicker(interval)
defer ticker.Stop()
tr.feed.SetNextFetch(uint32(time.Now().Add(10 * time.Minute).Unix()))
tr.feed.SetNextFetch(uint32(time.Now().Add(interval).Unix()))
for {
select {
@@ -158,15 +171,14 @@ func (tr *TelegramReader) Run(ctx context.Context) error {
return ctx.Err()
case <-ticker.C:
tr.fetchAll(ctx, api)
tr.feed.SetNextFetch(uint32(time.Now().Add(10 * time.Minute).Unix()))
tr.feed.SetNextFetch(uint32(time.Now().Add(interval).Unix()))
case <-tr.refreshCh:
// Invalidate cache so fetchAll re-fetches everything.
tr.mu.Lock()
tr.cache = make(map[string]cachedMessages)
tr.mu.Unlock()
tr.fetchAll(ctx, api)
ticker.Reset(10 * time.Minute)
tr.feed.SetNextFetch(uint32(time.Now().Add(10 * time.Minute).Unix()))
ticker.Reset(interval)
tr.feed.SetNextFetch(uint32(time.Now().Add(interval).Unix()))
}
}
})
@@ -222,15 +234,18 @@ func (tr *TelegramReader) authenticate(ctx context.Context, client *telegram.Cli
func (tr *TelegramReader) fetchAll(ctx context.Context, api *tg.Client) {
log.Printf("[telegram] fetch cycle started for %d channels", len(tr.channels))
start := time.Now()
var fetched, failed int
var fetched, failed, skipped int
tr.mu.RLock()
cacheTTL := tr.cacheTTL
tr.mu.RUnlock()
for i, username := range tr.channels {
chNum := tr.baseCh + i
// Check cache
tr.mu.RLock()
cached, ok := tr.cache[username]
tr.mu.RUnlock()
if ok && time.Since(cached.fetched) < tr.cacheTTL {
if ok && time.Since(cached.fetched) < cacheTTL {
skipped++
continue
}
@@ -272,7 +287,9 @@ func (tr *TelegramReader) fetchAll(ctx context.Context, api *tg.Client) {
fetched++
log.Printf("[telegram] updated %s (%s): %d messages (type=%d, canSend=%v)", username, rp.title, len(msgs), rp.chatType, rp.canSend)
}
log.Printf("[telegram] fetch cycle done in %s: %d fetched, %d failed, %d total", time.Since(start).Round(time.Millisecond), fetched, failed, len(tr.channels))
log.Printf("[telegram] fetch cycle done in %s: %d fetched, %d failed, %d skipped, %d total",
time.Since(start).Round(time.Millisecond), fetched, failed, skipped, len(tr.channels))
tr.feed.AfterFetchCycle(ctx)
}
// resolvePeer resolves a Telegram username to an InputPeer, handling channels,
+30 -14
View File
@@ -32,13 +32,25 @@ type XPublicReader struct {
client *http.Client
instances []string
mu sync.RWMutex
cache map[string]cachedMessages
cacheTTL time.Duration
mu sync.RWMutex
cache map[string]cachedMessages
cacheTTL time.Duration
fetchInterval time.Duration
refreshCh chan struct{}
}
// SetFetchInterval replaces the reader's fetch cadence (10 minutes unless
// overridden) and sets the per-account cache TTL to the same duration.
// Zero or negative durations are ignored.
//
// Run copies the interval once when it creates its ticker, so this should be
// called before Run for the new cadence to take effect.
func (xr *XPublicReader) SetFetchInterval(d time.Duration) {
	if d > 0 {
		xr.mu.Lock()
		defer xr.mu.Unlock()
		xr.fetchInterval, xr.cacheTTL = d, d
	}
}
const maxXRSSBodyBytes int64 = 2 << 20 // 2 MiB
var xSnowflakeRe = regexp.MustCompile(`\d{8,}`)
@@ -74,10 +86,11 @@ func NewXPublicReader(accounts []string, feed *Feed, msgLimit int, baseCh int, i
return http.ErrUseLastResponse
},
},
instances: instances,
cache: make(map[string]cachedMessages),
cacheTTL: 10 * time.Minute,
refreshCh: make(chan struct{}, 1),
instances: instances,
cache: make(map[string]cachedMessages),
cacheTTL: 10 * time.Minute,
fetchInterval: 10 * time.Minute,
refreshCh: make(chan struct{}, 1),
}
}
@@ -124,7 +137,8 @@ func normalizeXRSSInstances(instancesCSV string) []string {
func (xr *XPublicReader) Run(ctx context.Context) error {
xr.fetchAll(ctx)
ticker := time.NewTicker(10 * time.Minute)
interval := xr.fetchInterval
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
@@ -138,7 +152,7 @@ func (xr *XPublicReader) Run(ctx context.Context) error {
xr.cache = make(map[string]cachedMessages)
xr.mu.Unlock()
xr.fetchAll(ctx)
ticker.Reset(10 * time.Minute)
ticker.Reset(interval)
}
}
}
@@ -162,14 +176,13 @@ func (xr *XPublicReader) SetBaseCh(baseCh int) {
func (xr *XPublicReader) fetchAll(ctx context.Context) {
log.Printf("[x] fetch cycle started for %d accounts (instances: %v)", len(xr.accounts), xr.instances)
start := time.Now()
var fetched, failed int
var fetched, failed, skipped int
xr.mu.RLock()
baseCh := xr.baseCh
cacheTTL := xr.cacheTTL
xr.mu.RUnlock()
// Always set ChatType for all X accounts upfront, so channels show the X flag
// even if the Nitter fetch fails or the cache is still valid.
for i := range xr.accounts {
xr.feed.SetChatInfo(baseCh+i, protocol.ChatTypeX, false)
}
@@ -180,7 +193,8 @@ func (xr *XPublicReader) fetchAll(ctx context.Context) {
xr.mu.RLock()
cached, ok := xr.cache[account]
xr.mu.RUnlock()
if ok && time.Since(cached.fetched) < xr.cacheTTL {
if ok && time.Since(cached.fetched) < cacheTTL {
skipped++
continue
}
@@ -210,7 +224,9 @@ func (xr *XPublicReader) fetchAll(ctx context.Context) {
fetched++
log.Printf("[x] updated @%s: %d posts", account, len(msgs))
}
log.Printf("[x] fetch cycle done in %s: %d fetched, %d failed, %d total", time.Since(start).Round(time.Millisecond), fetched, failed, len(xr.accounts))
log.Printf("[x] fetch cycle done in %s: %d fetched, %d failed, %d skipped, %d total",
time.Since(start).Round(time.Millisecond), fetched, failed, skipped, len(xr.accounts))
xr.feed.AfterFetchCycle(ctx)
}
func (xr *XPublicReader) fetchAccount(ctx context.Context, username string) ([]protocol.Message, string, error) {