feat: implement AES-256 block cipher for query encryption and decryption

This commit is contained in:
Sarto
2026-03-25 23:13:33 +03:30
parent 6753c541e7
commit e9226d6543
12 changed files with 1021 additions and 364 deletions
+25 -11
View File
@@ -96,7 +96,7 @@ func TestE2E_FetchMetadataThroughDNS(t *testing.T) {
t.Fatalf("create fetcher: %v", err)
}
meta, err := fetcher.FetchMetadata()
meta, err := fetcher.FetchMetadata(context.Background())
if err != nil {
t.Fatalf("fetch metadata: %v", err)
}
@@ -136,7 +136,7 @@ func TestE2E_FetchChannelMessages(t *testing.T) {
t.Fatalf("create fetcher: %v", err)
}
meta, err := fetcher.FetchMetadata()
meta, err := fetcher.FetchMetadata(context.Background())
if err != nil {
t.Fatalf("fetch metadata: %v", err)
}
@@ -146,7 +146,7 @@ func TestE2E_FetchChannelMessages(t *testing.T) {
t.Fatal("expected blocks > 0")
}
fetchedMsgs, err := fetcher.FetchChannel(1, blockCount)
fetchedMsgs, err := fetcher.FetchChannel(context.Background(), 1, blockCount)
if err != nil {
t.Fatalf("fetch channel: %v", err)
}
@@ -180,14 +180,14 @@ func TestE2E_FetchWithDoubleLabel(t *testing.T) {
if err != nil {
t.Fatalf("create fetcher: %v", err)
}
fetcher.SetQueryMode(protocol.QueryDoubleLabel)
fetcher.SetQueryMode(protocol.QueryMultiLabel)
meta, err := fetcher.FetchMetadata()
meta, err := fetcher.FetchMetadata(context.Background())
if err != nil {
t.Fatalf("fetch metadata: %v", err)
}
fetchedMsgs, err := fetcher.FetchChannel(1, int(meta.Channels[0].Blocks))
fetchedMsgs, err := fetcher.FetchChannel(context.Background(), 1, int(meta.Channels[0].Blocks))
if err != nil {
t.Fatalf("fetch channel: %v", err)
}
@@ -216,7 +216,7 @@ func TestE2E_WrongPassphrase(t *testing.T) {
t.Fatalf("create fetcher: %v", err)
}
_, err = fetcher.FetchMetadata()
_, err = fetcher.FetchMetadata(context.Background())
if err == nil {
t.Fatal("expected error with wrong passphrase, got nil")
}
@@ -243,12 +243,12 @@ func TestE2E_LargeMessages(t *testing.T) {
t.Fatalf("create fetcher: %v", err)
}
meta, err := fetcher.FetchMetadata()
meta, err := fetcher.FetchMetadata(context.Background())
if err != nil {
t.Fatalf("fetch metadata: %v", err)
}
fetchedMsgs, err := fetcher.FetchChannel(1, int(meta.Channels[0].Blocks))
fetchedMsgs, err := fetcher.FetchChannel(context.Background(), 1, int(meta.Channels[0].Blocks))
if err != nil {
t.Fatalf("fetch channel: %v", err)
}
@@ -586,8 +586,15 @@ func TestE2E_FullRoundTrip(t *testing.T) {
t.Fatalf("config POST status=%d", resp.StatusCode)
}
// Wait for auto-refresh to fetch data
time.Sleep(3 * time.Second)
// Refresh channels via selected-channel API semantics.
respRefresh1, err := http.Post(base+"/api/refresh?channel=1", "application/json", nil)
if err != nil {
t.Fatalf("POST /api/refresh?channel=1: %v", err)
}
respRefresh1.Body.Close()
// Give channel 1 refresh goroutine time to complete before refreshing channel 2,
// because starting a new refresh cancels the previous in-flight refresh.
time.Sleep(700 * time.Millisecond)
// Channels should be populated
resp2, err := http.Get(base + "/api/channels")
@@ -621,6 +628,13 @@ func TestE2E_FullRoundTrip(t *testing.T) {
t.Errorf("msg[0].Text = %q, want %q", msgList[0].Text, "General message 1")
}
respRefresh2, err := http.Post(base+"/api/refresh?channel=2", "application/json", nil)
if err != nil {
t.Fatalf("POST /api/refresh?channel=2: %v", err)
}
respRefresh2.Body.Close()
time.Sleep(700 * time.Millisecond)
// Messages for channel 2
resp4, err := http.Get(base + "/api/messages/2")
if err != nil {
+273 -92
View File
@@ -1,6 +1,7 @@
package client
import (
"context"
"fmt"
"math/rand"
"strings"
@@ -12,26 +13,36 @@ import (
"github.com/sartoopjj/thefeed/internal/protocol"
)
// LogFunc is a callback for logging DNS queries (for debug/TUI).
// LogFunc is a callback for log messages.
type LogFunc func(msg string)
// noiseDomains are popular domains used to blend feed queries into normal-looking DNS traffic.
var noiseDomains = []string{
"www.google.com", "www.cloudflare.com", "one.one.one.one",
"www.youtube.com", "www.instagram.com", "www.amazon.com",
"www.microsoft.com", "www.apple.com", "www.github.com",
"www.wikipedia.org", "www.reddit.com", "www.twitter.com",
}
// Fetcher fetches feed blocks over DNS.
type Fetcher struct {
domain string
queryKey [protocol.KeySize]byte
responseKey [protocol.KeySize]byte
queryMode protocol.QueryEncoding
timeout time.Duration
mu sync.RWMutex
resolvers []string
timeout time.Duration
// Resolver pools — allResolvers is what the user configured;
// activeResolvers is kept up-to-date by ResolverChecker (only healthy ones).
mu sync.RWMutex
allResolvers []string
activeResolvers []string
// Rate limiting
rateMu sync.Mutex
queryDelay time.Duration
lastQuery time.Time
// Rate limiting via token bucket; nil means unlimited.
rateQPS float64
rateCh chan struct{}
// Debug logging
debug bool
logFunc LogFunc
}
@@ -42,23 +53,28 @@ func NewFetcher(domain, passphrase string, resolvers []string) (*Fetcher, error)
return nil, fmt.Errorf("derive keys: %w", err)
}
r := make([]string, len(resolvers))
copy(r, resolvers)
return &Fetcher{
domain: strings.TrimSuffix(domain, "."),
queryKey: qk,
responseKey: rk,
queryMode: protocol.QuerySingleLabel,
resolvers: resolvers,
timeout: 5 * time.Second,
domain: strings.TrimSuffix(domain, "."),
queryKey: qk,
responseKey: rk,
queryMode: protocol.QuerySingleLabel,
allResolvers: r,
activeResolvers: r,
timeout: 10 * time.Second,
}, nil
}
// SetRateLimit sets the maximum queries per second (0 = unlimited).
// SetRateLimit sets the maximum queries per second (0 = unlimited). Must be called before Start.
func (f *Fetcher) SetRateLimit(qps float64) {
if qps <= 0 {
f.queryDelay = 0
return
}
f.queryDelay = time.Duration(float64(time.Second) / qps)
f.rateQPS = qps
}
// SetTimeout sets the per-query DNS timeout.
func (f *Fetcher) SetTimeout(d time.Duration) {
f.timeout = d
}
// SetLogFunc sets the debug log callback.
@@ -66,83 +82,225 @@ func (f *Fetcher) SetLogFunc(fn LogFunc) {
f.logFunc = fn
}
// SetDebug enables or disables debug logging of generated query names.
func (f *Fetcher) SetDebug(debug bool) {
f.debug = debug
}
// SetQueryMode sets the DNS query encoding mode.
func (f *Fetcher) SetQueryMode(mode protocol.QueryEncoding) {
f.queryMode = mode
}
// SetActiveResolvers updates the healthy resolver pool. Called by ResolverChecker.
// If the new list is empty, the current pool is unchanged (to avoid blackout during a bad check).
func (f *Fetcher) SetActiveResolvers(resolvers []string) {
f.mu.Lock()
defer f.mu.Unlock()
if len(resolvers) > 0 {
f.activeResolvers = make([]string, len(resolvers))
copy(f.activeResolvers, resolvers)
}
}
// SetResolvers replaces the full resolver list and resets the active pool.
func (f *Fetcher) SetResolvers(resolvers []string) {
f.mu.Lock()
defer f.mu.Unlock()
f.allResolvers = make([]string, len(resolvers))
copy(f.allResolvers, resolvers)
f.activeResolvers = make([]string, len(resolvers))
copy(f.activeResolvers, resolvers)
}
// AllResolvers returns all user-configured resolvers.
func (f *Fetcher) AllResolvers() []string {
f.mu.RLock()
defer f.mu.RUnlock()
result := make([]string, len(f.allResolvers))
copy(result, f.allResolvers)
return result
}
// Resolvers returns the currently active (healthy) resolver list.
func (f *Fetcher) Resolvers() []string {
f.mu.RLock()
defer f.mu.RUnlock()
result := make([]string, len(f.activeResolvers))
copy(result, f.activeResolvers)
return result
}
// Start launches background goroutines (rate limiter and noise generator).
// ctx controls their lifetime — cancel it to cleanly stop them.
// Call once per fetcher configuration; creating a new fetcher replaces the old one.
func (f *Fetcher) Start(ctx context.Context) {
if f.rateQPS > 0 {
f.rateCh = make(chan struct{}, 1)
go f.runRateLimiter(ctx)
go f.runNoise(ctx)
}
}
// runRateLimiter issues one token into rateCh every 1/QPS seconds.
// The channel capacity is 1, so tokens do not accumulate (no burst).
func (f *Fetcher) runRateLimiter(ctx context.Context) {
interval := time.Duration(float64(time.Second) / f.rateQPS)
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
select {
case f.rateCh <- struct{}{}:
default: // bucket full; discard extra token to prevent burst
}
}
}
}
// runNoise sends decoy A-record queries to popular domains at a randomised
// rate matching the configured QPS, to make feed traffic look like normal DNS usage.
func (f *Fetcher) runNoise(ctx context.Context) {
interval := time.Duration(float64(time.Second) / f.rateQPS)
for {
// Random delay: 13× the query interval.
jitter := time.Duration(rand.Int63n(int64(2*interval) + 1))
select {
case <-ctx.Done():
return
case <-time.After(interval + jitter):
}
resolvers := f.Resolvers()
if len(resolvers) == 0 {
continue
}
resolver := resolvers[rand.Intn(len(resolvers))]
if !strings.Contains(resolver, ":") {
resolver += ":53"
}
target := noiseDomains[rand.Intn(len(noiseDomains))]
go func(r, d string) {
c := &dns.Client{Timeout: f.timeout}
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(d), dns.TypeA)
m.RecursionDesired = true
c.Exchange(m, r) //nolint:errcheck — fire-and-forget noise query
}(resolver, target)
}
}
func (f *Fetcher) log(format string, args ...any) {
if f.logFunc != nil {
f.logFunc(fmt.Sprintf(format, args...))
}
}
func (f *Fetcher) rateWait() {
if f.queryDelay <= 0 {
// logProgress logs a progress bar: "prefix [====> ] 45%"
func (f *Fetcher) logProgress(prefix string, current, total float64) {
if f.logFunc == nil || total <= 0 {
return
}
f.rateMu.Lock()
defer f.rateMu.Unlock()
elapsed := time.Since(f.lastQuery)
if elapsed < f.queryDelay {
time.Sleep(f.queryDelay - elapsed)
percent := int((current / total) * 100)
barLen := 20
filled := int((current / total) * float64(barLen))
empty := barLen - filled
bar := "["
for i := 0; i < filled; i++ {
bar += "="
}
f.lastQuery = time.Now()
if filled < barLen {
bar += ">"
}
for i := 0; i < empty-1; i++ {
bar += " "
}
bar += "]"
f.logFunc(fmt.Sprintf("%s %s %d%%", prefix, bar, percent))
}
// SetResolvers replaces the resolver list.
func (f *Fetcher) SetResolvers(resolvers []string) {
f.mu.Lock()
defer f.mu.Unlock()
f.resolvers = resolvers
// rateWait blocks until a rate-limit token is available or ctx is cancelled.
// Returns nil when a token was acquired, ctx.Err() when cancelled.
func (f *Fetcher) rateWait(ctx context.Context) error {
if f.rateCh == nil {
// Unlimited: just propagate any existing cancellation.
return ctx.Err()
}
select {
case <-f.rateCh:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// Resolvers returns the current resolver list.
func (f *Fetcher) Resolvers() []string {
f.mu.RLock()
defer f.mu.RUnlock()
result := make([]string, len(f.resolvers))
copy(result, f.resolvers)
return result
}
// FetchBlock fetches a single block from a channel.
func (f *Fetcher) FetchBlock(channel, block uint16) ([]byte, error) {
f.rateWait()
qname, err := protocol.EncodeQuery(f.queryKey, channel, block, f.domain, f.queryMode)
if err != nil {
return nil, fmt.Errorf("encode query: %w", err)
}
f.log("Q ch=%d blk=%d → %s", channel, block, qname)
resolvers := f.Resolvers()
if len(resolvers) == 0 {
return nil, fmt.Errorf("no resolvers configured")
}
// Shuffle resolvers to distribute load
shuffled := make([]string, len(resolvers))
copy(shuffled, resolvers)
rand.Shuffle(len(shuffled), func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] })
// FetchBlock fetches a single encrypted block from the given channel.
// It enqueues through the rate limiter and respects ctx cancellation.
// On transient failure it retries up to 2 additional times with a short back-off.
func (f *Fetcher) FetchBlock(ctx context.Context, channel, block uint16) ([]byte, error) {
const maxAttempts = 3
var lastErr error
for _, resolver := range shuffled {
data, err := f.queryResolver(resolver, qname)
if err != nil {
lastErr = err
continue
for attempt := 0; attempt < maxAttempts; attempt++ {
if attempt > 0 {
// Brief back-off before retry; bail immediately if ctx is done.
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(time.Duration(attempt) * 500 * time.Millisecond):
}
}
return data, nil
}
return nil, fmt.Errorf("all resolvers failed, last error: %w", lastErr)
if err := f.rateWait(ctx); err != nil {
return nil, err
}
qname, err := protocol.EncodeQuery(f.queryKey, channel, block, f.domain, f.queryMode)
if err != nil {
return nil, fmt.Errorf("encode query: %w", err)
}
if f.debug {
f.log("[debug] query ch=%d blk=%d attempt=%d qname=%s", channel, block, attempt+1, qname)
}
resolvers := f.Resolvers()
if len(resolvers) == 0 {
return nil, fmt.Errorf("no active resolvers")
}
// Shuffle to spread load across resolvers.
shuffled := make([]string, len(resolvers))
copy(shuffled, resolvers)
rand.Shuffle(len(shuffled), func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] })
for _, resolver := range shuffled {
if ctx.Err() != nil {
return nil, ctx.Err()
}
data, err := f.queryResolver(ctx, resolver, qname)
if err != nil {
lastErr = err
continue
}
return data, nil
}
lastErr = fmt.Errorf("all resolvers failed: %w", lastErr)
if attempt+1 < maxAttempts {
f.log("block ch=%d blk=%d attempt %d/%d failed, retrying: %v", channel, block, attempt+1, maxAttempts, lastErr)
}
}
return nil, lastErr
}
// FetchMetadata fetches and parses metadata (channel 0).
func (f *Fetcher) FetchMetadata() (*protocol.Metadata, error) {
data, err := f.FetchBlock(protocol.MetadataChannel, 0)
// FetchMetadata fetches and parses the metadata block (channel 0).
func (f *Fetcher) FetchMetadata(ctx context.Context) (*protocol.Metadata, error) {
data, err := f.FetchBlock(ctx, protocol.MetadataChannel, 0)
if err != nil {
return nil, fmt.Errorf("fetch metadata block 0: %w", err)
}
@@ -152,12 +310,15 @@ func (f *Fetcher) FetchMetadata() (*protocol.Metadata, error) {
return meta, nil
}
// Metadata might span multiple blocks
// Metadata may span multiple blocks.
allData := make([]byte, len(data))
copy(allData, data)
for blk := uint16(1); blk < 10; blk++ {
block, fetchErr := f.FetchBlock(protocol.MetadataChannel, blk)
if ctx.Err() != nil {
return nil, ctx.Err()
}
block, fetchErr := f.FetchBlock(ctx, protocol.MetadataChannel, blk)
if fetchErr != nil {
break
}
@@ -171,31 +332,41 @@ func (f *Fetcher) FetchMetadata() (*protocol.Metadata, error) {
return nil, fmt.Errorf("could not parse metadata: %w", err)
}
// FetchChannel fetches all blocks for a channel and parses messages.
func (f *Fetcher) FetchChannel(channelNum int, blockCount int) ([]protocol.Message, error) {
// FetchChannel fetches all blocks for a channel and returns the parsed messages.
// Cancelling ctx immediately aborts any queued or in-flight block fetches.
// Each block is retried individually via FetchBlock before the channel fetch fails.
func (f *Fetcher) FetchChannel(ctx context.Context, channelNum int, blockCount int) ([]protocol.Message, error) {
if blockCount <= 0 {
return nil, nil
}
type result struct {
type blockResult struct {
idx int
data []byte
err error
}
results := make(chan result, blockCount)
// Limit concurrency to 3 to reduce DNS burst traffic
sem := make(chan struct{}, 3)
results := make(chan blockResult, blockCount)
// Cap concurrent DNS queries at 5; the token-bucket rate limiter provides
// the actual throughput control regardless of this concurrency cap.
sem := make(chan struct{}, 5)
var wg sync.WaitGroup
for i := 0; i < blockCount; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
sem <- struct{}{}
// Acquire semaphore or bail on cancellation.
select {
case sem <- struct{}{}:
case <-ctx.Done():
results <- blockResult{idx: idx, err: ctx.Err()}
return
}
defer func() { <-sem }()
data, err := f.FetchBlock(uint16(channelNum), uint16(idx))
results <- result{idx: idx, data: data, err: err}
data, err := f.FetchBlock(ctx, uint16(channelNum), uint16(idx))
results <- blockResult{idx: idx, data: data, err: err}
}(i)
}
@@ -205,24 +376,33 @@ func (f *Fetcher) FetchChannel(channelNum int, blockCount int) ([]protocol.Messa
}()
ordered := make([][]byte, blockCount)
completed := 0
for r := range results {
if r.err != nil {
return nil, fmt.Errorf("fetch block %d: %w", r.idx, r.err)
if r.err == ctx.Err() {
// Context cancelled — abort immediately.
return nil, r.err
}
// FetchBlock already retried internally; log and treat as fatal for this channel.
f.log("Channel %d block %d permanently failed: %v", channelNum, r.idx, r.err)
return nil, fmt.Errorf("channel %d block %d: %w", channelNum, r.idx, r.err)
}
ordered[r.idx] = r.data
completed++
f.logProgress(fmt.Sprintf("Channel %d", channelNum), float64(completed), float64(blockCount))
}
var allData []byte
for _, block := range ordered {
allData = append(allData, block...)
for _, b := range ordered {
allData = append(allData, b...)
}
return protocol.ParseMessages(allData)
}
func (f *Fetcher) queryResolver(resolver, qname string) ([]byte, error) {
func (f *Fetcher) queryResolver(ctx context.Context, resolver, qname string) ([]byte, error) {
if !strings.Contains(resolver, ":") {
resolver = resolver + ":53"
resolver += ":53"
}
c := new(dns.Client)
@@ -231,8 +411,9 @@ func (f *Fetcher) queryResolver(resolver, qname string) ([]byte, error) {
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(qname), dns.TypeTXT)
m.RecursionDesired = true
m.SetEdns0(4096, false) // advertise 4 KiB UDP buffer to avoid response truncation
resp, _, err := c.Exchange(m, resolver)
resp, _, err := c.ExchangeContext(ctx, m, resolver)
if err != nil {
return nil, fmt.Errorf("dns exchange with %s: %w", resolver, err)
}
+85 -137
View File
@@ -1,181 +1,129 @@
package client
import (
"bufio"
"context"
"fmt"
"log"
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
"github.com/sartoopjj/thefeed/internal/protocol"
)
// ResolverScanner scans CIDR ranges to find working DNS resolvers.
type ResolverScanner struct {
fetcher *Fetcher
concurrency int
timeout time.Duration
// ResolverChecker periodically probes the fetcher's configured resolvers and
// updates the active (healthy) resolver pool. It replaces the old file/CIDR
// scanner — no file I/O; just a plain DNS probe on channel 0.
type ResolverChecker struct {
fetcher *Fetcher
timeout time.Duration
logFunc LogFunc
}
// NewResolverScanner creates a resolver scanner.
func NewResolverScanner(fetcher *Fetcher, concurrency int) *ResolverScanner {
if concurrency <= 0 {
concurrency = 50
// NewResolverChecker creates a health checker for the resolvers in fetcher.
// timeout is the per-probe deadline; 0 uses a 5-second default.
func NewResolverChecker(fetcher *Fetcher, timeout time.Duration) *ResolverChecker {
if timeout <= 0 {
timeout = 5 * time.Second
}
return &ResolverScanner{
fetcher: fetcher,
concurrency: concurrency,
timeout: 3 * time.Second,
return &ResolverChecker{
fetcher: fetcher,
timeout: timeout,
}
}
// ScanCIDR scans a CIDR range for working DNS resolvers.
func (rs *ResolverScanner) ScanCIDR(cidr string, onFound func(ip string)) error {
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
return fmt.Errorf("parse CIDR %q: %w", cidr, err)
}
ips := expandCIDR(ipNet)
return rs.scanIPs(ips, onFound)
// SetLogFunc sets the callback used to emit health-check results to the log panel.
func (rc *ResolverChecker) SetLogFunc(fn LogFunc) {
rc.logFunc = fn
}
// ScanFile scans resolver IPs from a file (one per line, supports CIDR notation).
func (rs *ResolverScanner) ScanFile(path string, onFound func(ip string)) error {
f, err := os.Open(path)
if err != nil {
return err
}
defer f.Close()
var ips []string
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
if strings.Contains(line, "/") {
_, ipNet, err := net.ParseCIDR(line)
if err != nil {
log.Printf("[resolver] skip invalid CIDR: %s", line)
continue
// Start begins the periodic health-check loop in the background.
// An initial check runs immediately; subsequent checks happen every 10 minutes.
// ctx controls the lifetime — cancel it to stop the checker.
func (rc *ResolverChecker) Start(ctx context.Context) {
go func() {
rc.runCheck()
ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
rc.runCheck()
}
ips = append(ips, expandCIDR(ipNet)...)
} else {
ips = append(ips, line)
}
}
if err := scanner.Err(); err != nil {
return err
}
return rs.scanIPs(ips, onFound)
}()
}
// CheckResolver tests if a single resolver works by querying metadata.
func (rs *ResolverScanner) CheckResolver(ip string) bool {
if !strings.Contains(ip, ":") {
ip = ip + ":53"
func (rc *ResolverChecker) runCheck() {
resolvers := rc.fetcher.AllResolvers()
if len(resolvers) == 0 {
return
}
// Create a new fetcher with only this resolver to avoid copying the lock.
tmpFetcher := &Fetcher{
domain: rs.fetcher.domain,
queryKey: rs.fetcher.queryKey,
responseKey: rs.fetcher.responseKey,
resolvers: []string{ip},
timeout: rs.timeout,
}
rc.log("Checking %d resolver(s)...", len(resolvers))
_, err := tmpFetcher.FetchBlock(0, 0)
return err == nil
}
func (rs *ResolverScanner) scanIPs(ips []string, onFound func(ip string)) error {
if len(ips) == 0 {
return fmt.Errorf("no IPs to scan")
}
var found atomic.Int32
sem := make(chan struct{}, rs.concurrency)
var healthy []string
var mu sync.Mutex
var wg sync.WaitGroup
sem := make(chan struct{}, 10) // probe up to 10 resolvers concurrently
for _, ip := range ips {
for _, r := range resolvers {
wg.Add(1)
go func(ip string) {
go func(r string) {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
if rs.CheckResolver(ip) {
found.Add(1)
if onFound != nil {
onFound(ip)
}
if rc.checkOne(r) {
mu.Lock()
healthy = append(healthy, r)
mu.Unlock()
rc.log("Resolver OK: %s", r)
} else {
rc.log("Resolver failed: %s", r)
}
}(ip)
}(r)
}
wg.Wait()
if found.Load() == 0 {
return fmt.Errorf("no working resolvers found among %d IPs", len(ips))
}
return nil
rc.fetcher.SetActiveResolvers(healthy)
rc.log("Resolver check done: %d/%d healthy", len(healthy), len(resolvers))
}
// LoadResolversFile loads resolver IPs from a file (one per line).
func LoadResolversFile(path string) ([]string, error) {
f, err := os.Open(path)
// checkOne probes a single resolver by sending a metadata channel query
// (channel 0, block 0). A successful DNS response (any rcode that isn't a
// network/timeout error) means the resolver is reachable and understands the domain.
func (rc *ResolverChecker) checkOne(resolver string) bool {
if !strings.Contains(resolver, ":") {
resolver += ":53"
}
qname, err := protocol.EncodeQuery(
rc.fetcher.queryKey,
protocol.MetadataChannel, 0,
rc.fetcher.domain,
rc.fetcher.queryMode,
)
if err != nil {
return nil, err
return false
}
defer f.Close()
var resolvers []string
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
resolvers = append(resolvers, line)
}
return resolvers, scanner.Err()
c := &dns.Client{Timeout: rc.timeout}
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(qname), dns.TypeTXT)
m.RecursionDesired = true
resp, _, err := c.Exchange(m, resolver)
// We consider the resolver healthy if we get any DNS response back
// (even NXDOMAIN means the resolver forwarded the query to our server).
return err == nil && resp != nil
}
func expandCIDR(ipNet *net.IPNet) []string {
var ips []string
ip := ipNet.IP.Mask(ipNet.Mask)
for ip := cloneIP(ip); ipNet.Contains(ip); incIP(ip) {
// Skip network and broadcast addresses for /24 and smaller
ones, bits := ipNet.Mask.Size()
if bits-ones <= 8 {
last := ip[len(ip)-1]
if last == 0 || last == 255 {
continue
}
}
ips = append(ips, ip.String())
}
return ips
}
func cloneIP(ip net.IP) net.IP {
dup := make(net.IP, len(ip))
copy(dup, ip)
return dup
}
func incIP(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
if ip[j] > 0 {
break
}
func (rc *ResolverChecker) log(format string, args ...any) {
if rc.logFunc != nil {
rc.logFunc(fmt.Sprintf(format, args...))
}
}
+39
View File
@@ -67,3 +67,42 @@ func Decrypt(key [KeySize]byte, ciphertext []byte) ([]byte, error) {
nonce := ciphertext[:gcm.NonceSize()]
return gcm.Open(nil, nonce, ciphertext[gcm.NonceSize():], nil)
}
// encryptQueryBlock encrypts an 8-byte query payload using a direct AES-256 block cipher.
// The payload is expanded to one AES block (16 bytes) with 8 trailing zero bytes before
// encryption. No nonce or auth tag needed: the 4 random bytes in the payload guarantee
// unique ciphertext per query. Result is always 16 bytes.
func encryptQueryBlock(key [KeySize]byte, payload []byte) ([]byte, error) {
if len(payload) != QueryPayloadSize {
return nil, fmt.Errorf("encryptQueryBlock: payload must be %d bytes, got %d", QueryPayloadSize, len(payload))
}
block, err := aes.NewCipher(key[:])
if err != nil {
return nil, err
}
var buf [aes.BlockSize]byte
copy(buf[:QueryPayloadSize], payload) // bytes 8-15 stay zero
block.Encrypt(buf[:], buf[:])
return buf[:], nil
}
// decryptQueryBlock decrypts a query ciphertext produced by encryptQueryBlock.
// Accepts ciphertext with optional random suffix bytes (≥ BlockSize); only the
// first BlockSize bytes are used. Verifies the last 8 bytes of plaintext are zero.
func decryptQueryBlock(key [KeySize]byte, ciphertext []byte) ([]byte, error) {
if len(ciphertext) < aes.BlockSize {
return nil, fmt.Errorf("decryptQueryBlock: need at least %d bytes, got %d", aes.BlockSize, len(ciphertext))
}
block, err := aes.NewCipher(key[:])
if err != nil {
return nil, err
}
var buf [aes.BlockSize]byte
block.Decrypt(buf[:], ciphertext[:aes.BlockSize]) // ignore suffix
for i := QueryPayloadSize; i < aes.BlockSize; i++ {
if buf[i] != 0 {
return nil, fmt.Errorf("decryptQueryBlock: integrity check failed (wrong key?)")
}
}
return buf[:QueryPayloadSize], nil
}
+137 -29
View File
@@ -7,26 +7,48 @@ import (
"encoding/binary"
"encoding/hex"
"fmt"
"math/big"
"strconv"
"strings"
)
const (
maxDNSLabelLen = 63
maxDNSNameLen = 253 // without trailing dot
)
// QueryEncoding controls how DNS query subdomains are encoded.
type QueryEncoding int
const (
// QuerySingleLabel uses base32 in a single DNS label (default, stealthier).
QuerySingleLabel QueryEncoding = iota
// QueryDoubleLabel uses hex split across two DNS labels.
QueryDoubleLabel
// QueryMultiLabel uses hex split across multiple DNS labels.
QueryMultiLabel
// QueryPlainLabel encodes channel and block as plain decimal text (no query encryption).
// Responses are always encrypted regardless of this setting.
QueryPlainLabel
)
var b32 = base32.StdEncoding.WithPadding(base32.NoPadding)
// EncodeQuery creates an encrypted DNS query subdomain.
// EncodeQuery creates a DNS query subdomain for the given channel and block.
// Single-label (default): [base32_encrypted].domain
// Double-label: [hex_part1].[hex_part2].domain
// Payload: 4 random + 2 channel + 2 block = 8 bytes, encrypted with AES-GCM.
// Multi-label: [hex_part1].[hex_part2].domain
// Plain-label: c<channel>b<block>.domain (no query encryption)
// Responses are always encrypted regardless of mode.
func EncodeQuery(queryKey [KeySize]byte, channel, block uint16, domain string, mode QueryEncoding) (string, error) {
domain = strings.TrimSuffix(domain, ".")
if domain == "" {
return "", fmt.Errorf("empty domain")
}
// Plain text mode: no encryption, just human-readable label.
if mode == QueryPlainLabel {
label := fmt.Sprintf("c%db%d", channel, block)
return joinQName([]string{label}, domain)
}
payload := make([]byte, QueryPayloadSize)
if _, err := rand.Read(payload[:QueryPaddingSize]); err != nil {
@@ -36,24 +58,88 @@ func EncodeQuery(queryKey [KeySize]byte, channel, block uint16, domain string, m
binary.BigEndian.PutUint16(payload[QueryPaddingSize:], channel)
binary.BigEndian.PutUint16(payload[QueryPaddingSize+QueryChannelSize:], block)
encrypted, err := Encrypt(queryKey, payload)
encrypted, err := encryptQueryBlock(queryKey, payload)
if err != nil {
return "", fmt.Errorf("encrypt query: %w", err)
}
// Append 04 random suffix bytes so query length varies per request.
// The decoder strips these by only using the first aes.BlockSize bytes.
suffixLen, _ := rand.Int(rand.Reader, big.NewInt(5)) // [0,4]
suffix := make([]byte, int(suffixLen.Int64()))
rand.Read(suffix) //nolint:errcheck — non-critical randomness
ciphertext := append(encrypted, suffix...)
switch mode {
case QueryDoubleLabel:
h := hex.EncodeToString(encrypted)
mid := len(h) / 2
return fmt.Sprintf("%s.%s.%s", h[:mid], h[mid:], domain), nil
case QueryMultiLabel:
h := hex.EncodeToString(ciphertext)
labels := splitMultiLabel(h)
return joinQName(labels, domain)
default:
encoded := strings.ToLower(b32.EncodeToString(encrypted))
return fmt.Sprintf("%s.%s", encoded, domain), nil
encoded := strings.ToLower(b32.EncodeToString(ciphertext))
return joinQName([]string{encoded}, domain)
}
}
func splitLabel(s string, size int) []string {
if size <= 0 {
size = maxDNSLabelLen
}
if size > maxDNSLabelLen {
size = maxDNSLabelLen
}
parts := make([]string, 0, (len(s)+size-1)/size)
for len(s) > size {
parts = append(parts, s[:size])
s = s[size:]
}
if len(s) > 0 {
parts = append(parts, s)
}
return parts
}
// splitMultiLabel splits a hex string into two labels of randomised, unequal length.
// The first label is between 12 and (len-4) chars so the second is at least 4 chars.
// This makes query labels look less uniform across requests.
func splitMultiLabel(h string) []string {
if len(h) <= 8 {
return []string{h}
}
// first label: random length in [minFirst, len-4]
minFirst := 8
maxFirst := len(h) - 4
if maxFirst <= minFirst {
maxFirst = minFirst + 1
}
// crypto/rand for the split point; fall back to midpoint on error
split := (len(h) + 1) / 2 // default: slightly off-centre
if n, err := rand.Int(rand.Reader, big.NewInt(int64(maxFirst-minFirst+1))); err == nil {
split = minFirst + int(n.Int64())
}
return []string{h[:split], h[split:]}
}
func joinQName(labels []string, domain string) (string, error) {
for _, l := range labels {
if len(l) == 0 {
return "", fmt.Errorf("empty label")
}
if len(l) > maxDNSLabelLen {
return "", fmt.Errorf("label too long: %d", len(l))
}
}
qname := strings.Join(append(labels, domain), ".")
if len(qname) > maxDNSNameLen {
return "", fmt.Errorf("query name too long: %d", len(qname))
}
return qname, nil
}
// DecodeQuery parses and decrypts a DNS query subdomain.
// Auto-detects single-label (base32) or double-label (hex) encoding.
// Auto-detects plain-text (c<N>b<M>), single-label base32, or multi-label hex encoding.
func DecodeQuery(queryKey [KeySize]byte, qname, domain string) (channel, block uint16, err error) {
qname = strings.TrimSuffix(qname, ".")
domain = strings.TrimSuffix(domain, ".")
@@ -65,32 +151,54 @@ func DecodeQuery(queryKey [KeySize]byte, qname, domain string) (channel, block u
encoded := qname[:len(qname)-len(suffix)]
// Try base32 first (single label, no dots, or dots stripped)
b32str := strings.ReplaceAll(encoded, ".", "")
ciphertext, err := b32.DecodeString(strings.ToUpper(b32str))
if err == nil {
return decryptQuery(queryKey, ciphertext)
// Try plain-label first: c<channel>b<block> (short, no dots, all decimal)
if ch, blk, ok := parsePlainLabel(encoded); ok {
return ch, blk, nil
}
// Fall back to hex (double-label)
// Try base32 (single-label: no dots or dots stripped)
b32str := strings.ReplaceAll(encoded, ".", "")
if ct, e := b32.DecodeString(strings.ToUpper(b32str)); e == nil {
return parseQueryCiphertext(queryKey, ct)
}
// Fall back to hex (multi-label: dots stripped)
hexStr := strings.ReplaceAll(encoded, ".", "")
ciphertext, err = hex.DecodeString(hexStr)
if err != nil {
ct, e := hex.DecodeString(hexStr)
if e != nil {
return 0, 0, fmt.Errorf("decode query: invalid encoding")
}
return decryptQuery(queryKey, ciphertext)
return parseQueryCiphertext(queryKey, ct)
}
func decryptQuery(queryKey [KeySize]byte, ciphertext []byte) (channel, block uint16, err error) {
plaintext, err := Decrypt(queryKey, ciphertext)
// parsePlainLabel parses the plain-text query format "c<channel>b<block>".
// Returns ok=false if the string does not match this pattern.
func parsePlainLabel(s string) (channel, block uint16, ok bool) {
if len(s) < 3 || s[0] != 'c' {
return 0, 0, false
}
bi := strings.IndexByte(s[1:], 'b')
if bi < 0 {
return 0, 0, false
}
bi++ // adjust for the slice offset
chStr, bStr := s[1:bi], s[bi+1:]
if len(chStr) == 0 || len(bStr) == 0 {
return 0, 0, false
}
ch, err1 := strconv.ParseUint(chStr, 10, 16)
blk, err2 := strconv.ParseUint(bStr, 10, 16)
if err1 != nil || err2 != nil {
return 0, 0, false
}
return uint16(ch), uint16(blk), true
}
func parseQueryCiphertext(queryKey [KeySize]byte, ciphertext []byte) (channel, block uint16, err error) {
plaintext, err := decryptQueryBlock(queryKey, ciphertext)
if err != nil {
return 0, 0, fmt.Errorf("decrypt: %w", err)
}
if len(plaintext) != QueryPayloadSize {
return 0, 0, fmt.Errorf("invalid payload size: %d", len(plaintext))
}
channel = binary.BigEndian.Uint16(plaintext[QueryPaddingSize:])
block = binary.BigEndian.Uint16(plaintext[QueryPaddingSize+QueryChannelSize:])
return channel, block, nil
+78 -5
View File
@@ -1,6 +1,7 @@
package protocol
import (
"fmt"
"strings"
"testing"
)
@@ -43,21 +44,26 @@ func TestEncodeDecodeQuerySingleLabel(t *testing.T) {
}
}
func TestEncodeDecodeQueryDoubleLabel(t *testing.T) {
func TestEncodeDecodeQueryMultiLabel(t *testing.T) {
qk, _, err := DeriveKeys("test-key")
if err != nil {
t.Fatalf("DeriveKeys: %v", err)
}
domain := "t.example.com"
qname, err := EncodeQuery(qk, 3, 7, domain, QueryDoubleLabel)
qname, err := EncodeQuery(qk, 3, 7, domain, QueryMultiLabel)
if err != nil {
t.Fatal(err)
}
// Double label: two hex labels before domain
// Multi-label mode splits hex across labels; all must be DNS-safe.
subdomain := qname[:len(qname)-len(domain)-1]
parts := strings.Split(subdomain, ".")
if len(parts) != 2 {
t.Errorf("double-label query should have 2 parts, got %d: %q", len(parts), subdomain)
if len(parts) < 1 {
t.Errorf("multi-label query should have at least 1 part, got %d: %q", len(parts), subdomain)
}
for _, p := range parts {
if len(p) == 0 || len(p) > 63 {
t.Errorf("invalid label length %d in %q", len(p), p)
}
}
ch, blk, err := DecodeQuery(qk, qname, domain)
if err != nil {
@@ -68,6 +74,73 @@ func TestEncodeDecodeQueryDoubleLabel(t *testing.T) {
}
}
func TestEncodeQueryTooLongDomain(t *testing.T) {
qk, _, err := DeriveKeys("test-key")
if err != nil {
t.Fatalf("DeriveKeys: %v", err)
}
// 250-char domain should make qname exceed DNS 253-char limit.
longDomain := strings.Repeat("a", 250)
_, err = EncodeQuery(qk, 1, 1, longDomain, QueryMultiLabel)
if err == nil {
t.Fatal("expected error for too-long domain")
}
}
func TestEncodeDecodeQueryPlainLabel(t *testing.T) {
qk, _, err := DeriveKeys("test-key")
if err != nil {
t.Fatalf("DeriveKeys: %v", err)
}
domain := "t.example.com"
tests := []struct {
channel uint16
block uint16
}{
{0, 0},
{1, 42},
{255, 65535},
{3, 100},
}
for _, tt := range tests {
qname, err := EncodeQuery(qk, tt.channel, tt.block, domain, QueryPlainLabel)
if err != nil {
t.Fatalf("EncodeQuery(%d, %d): %v", tt.channel, tt.block, err)
}
// Label should be "c<channel>b<block>" — human readable, no padding hex.
want := fmt.Sprintf("c%db%d.%s", tt.channel, tt.block, domain)
if qname != want {
t.Errorf("got %q, want %q", qname, want)
}
// DecodeQuery must recover channel and block regardless of key.
ch, blk, err := DecodeQuery(qk, qname, domain)
if err != nil {
t.Fatalf("DecodeQuery: %v", err)
}
if ch != tt.channel || blk != tt.block {
t.Errorf("got ch=%d blk=%d, want ch=%d blk=%d", ch, blk, tt.channel, tt.block)
}
}
}
func TestPlainLabelNotConfusedWithEncrypted(t *testing.T) {
qk, _, err := DeriveKeys("test-key")
if err != nil {
t.Fatalf("DeriveKeys: %v", err)
}
domain := "t.example.com"
// Encode with single-label then check that DecodeQuery does NOT treat it as plain.
qname, _ := EncodeQuery(qk, 5, 10, domain, QuerySingleLabel)
ch, blk, err := DecodeQuery(qk, qname, domain)
if err != nil {
t.Fatalf("DecodeQuery single-label: %v", err)
}
if ch != 5 || blk != 10 {
t.Errorf("got ch=%d blk=%d, want ch=5 blk=10", ch, blk)
}
}
func TestDecodeQueryWrongKey(t *testing.T) {
qk1, _, _ := DeriveKeys("key1")
qk2, _, _ := DeriveKeys("key2")
+29 -11
View File
@@ -1,14 +1,21 @@
package protocol
import (
"crypto/rand"
"encoding/binary"
"fmt"
"math/big"
)
const (
// DefaultBlockPayload is the decrypted payload per DNS TXT block.
// Calculated to stay within 512-byte UDP DNS limit after encryption + base64 + padding overhead.
DefaultBlockPayload = 180
// MinBlockPayload is the minimum decrypted payload per DNS TXT block.
MinBlockPayload = 200
// MaxBlockPayload is the maximum decrypted payload per DNS TXT block.
// 600 bytes data + 28 GCM overhead + 2 prefix + 32 padding → ~856 base64 chars.
// Well within the 4096-byte EDNS0 UDP buffer the client advertises.
MaxBlockPayload = 600
// DefaultBlockPayload is kept for compatibility; equals MaxBlockPayload.
DefaultBlockPayload = MaxBlockPayload
// DefaultMaxPadding is the default random padding added to responses to vary DNS response size.
DefaultMaxPadding = 32
@@ -210,20 +217,31 @@ func ParseMessages(data []byte) ([]Message, error) {
return msgs, nil
}
// SplitIntoBlocks splits data into blocks of DefaultBlockPayload size.
// SplitIntoBlocks splits data into blocks of randomly varying size in [MinBlockPayload, MaxBlockPayload].
// Random sizes make traffic analysis harder; the client just concatenates all blocks to reassemble.
func SplitIntoBlocks(data []byte) [][]byte {
if len(data) == 0 {
return [][]byte{{}}
return [][]byte{{}} // channel 0 block 0 must always exist
}
var blocks [][]byte
for i := 0; i < len(data); i += DefaultBlockPayload {
end := i + DefaultBlockPayload
if end > len(data) {
end = len(data)
rem := data
for len(rem) > 0 {
size := randBlockSize()
if size > len(rem) {
size = len(rem)
}
block := make([]byte, end-i)
copy(block, data[i:end])
block := make([]byte, size)
copy(block, rem[:size])
blocks = append(blocks, block)
rem = rem[size:]
}
return blocks
}
func randBlockSize() int {
n, err := rand.Int(rand.Reader, big.NewInt(int64(MaxBlockPayload-MinBlockPayload+1)))
if err != nil {
return (MinBlockPayload + MaxBlockPayload) / 2
}
return MinBlockPayload + int(n.Int64())
}
+9 -9
View File
@@ -57,19 +57,19 @@ func TestSerializeParseMessages(t *testing.T) {
}
func TestSplitIntoBlocks(t *testing.T) {
data := bytes.Repeat([]byte("A"), DefaultBlockPayload*3+50)
// MaxBlockPayload*3+50 guarantees at least 4 blocks (ceil((MaxBlockPayload*3+50)/MaxBlockPayload) = 4).
data := bytes.Repeat([]byte("A"), MaxBlockPayload*3+50)
blocks := SplitIntoBlocks(data)
if len(blocks) != 4 {
t.Fatalf("blocks: got %d, want 4", len(blocks))
if len(blocks) < 4 {
t.Fatalf("expected at least 4 blocks for %d bytes, got %d", len(data), len(blocks))
}
for i, b := range blocks {
if i < 3 && len(b) != DefaultBlockPayload {
t.Errorf("block %d: size %d, want %d", i, len(b), DefaultBlockPayload)
// Every non-last block must be within [MinBlockPayload, MaxBlockPayload].
for i, b := range blocks[:len(blocks)-1] {
if len(b) < MinBlockPayload || len(b) > MaxBlockPayload {
t.Errorf("block %d: size %d, want [%d, %d]", i, len(b), MinBlockPayload, MaxBlockPayload)
}
}
if len(blocks[3]) != 50 {
t.Errorf("last block: size %d, want 50", len(blocks[3]))
}
// Reassembled data must equal original.
var reassembled []byte
for _, b := range blocks {
reassembled = append(reassembled, b...)
+3 -1
View File
@@ -71,7 +71,9 @@ func TestFeedGetBlockUnknownChannel(t *testing.T) {
func TestFeedLargeMessages(t *testing.T) {
feed := NewFeed([]string{"Test"})
largeText := make([]byte, 500)
// Use text large enough to span 2 blocks at DefaultBlockPayload (currently 700 bytes).
// Message serialization overhead is 10 bytes, so we need >690 bytes of text.
largeText := make([]byte, 750)
for i := range largeText {
largeText[i] = 65
}
+2 -2
View File
@@ -80,7 +80,7 @@ func NewTelegramReader(cfg TelegramConfig, channelUsernames []string, feed *Feed
channels: cleaned,
feed: feed,
cache: make(map[string]cachedMessages),
cacheTTL: 1 * time.Minute,
cacheTTL: 5 * time.Minute,
}
}
@@ -115,7 +115,7 @@ func (tr *TelegramReader) Run(ctx context.Context) error {
tr.fetchAll(ctx, api)
// Periodic fetch loop
ticker := time.NewTicker(1 * time.Minute)
ticker := time.NewTicker(3 * time.Minute)
defer ticker.Stop()
for {
+165 -17
View File
@@ -142,6 +142,44 @@
word-break: break-word;
}
.progress-panel {
height: 56px;
min-height: 56px;
background: var(--surface);
border-top: 1px solid var(--border);
border-bottom: 1px solid var(--border);
overflow-y: auto;
font-size: 12px;
font-family: 'Vazirmatn', monospace;
padding: 12px 12px;
direction: ltr;
text-align: left;
color: var(--text-dim);
}
.progress-item {
margin-bottom: 8px;
padding: 6px;
border-radius: 4px;
background: var(--bg);
}
.progress-label {
font-size: 11px;
margin-bottom: 3px;
color: var(--text-dim);
}
.progress-bar {
width: 100%;
height: 6px;
background: var(--border);
border-radius: 3px;
overflow: hidden;
}
.progress-fill {
height: 100%;
background: var(--accent);
transition: width 0.2s;
}
.log-panel {
height: 150px;
min-height: 100px;
@@ -149,13 +187,21 @@
border-top: 1px solid var(--border);
overflow-y: auto;
font-size: 12px;
font-family: 'Vazirmatn', monospace;
font-family: monospace;
padding: 8px 12px;
direction: ltr;
text-align: left;
color: var(--text-dim);
}
.log-line { padding: 2px 0; }
.log-line {
padding: 2px 0;
line-height: 1.3;
}
.log-line.info { color: #60a5fa; }
.log-line.progress { color: #fbbf24; }
.log-line.error { color: #ef4444; }
.log-line.success { color: #22c55e; }
.log-line.warning { color: #f97316; }
.modal-overlay {
display: none;
@@ -245,6 +291,7 @@
<button class="btn btn-primary" onclick="openSettings()">Configure</button>
</div>
</div>
<div class="progress-panel" id="progressPanel"></div>
<div class="log-panel" id="logPanel"></div>
</div>
</div>
@@ -268,12 +315,17 @@
<label>Query Mode</label>
<select id="cfgQueryMode">
<option value="single">Single label (base32, stealthier)</option>
<option value="double">Double label (hex)</option>
<option value="double">Multi-label (hex)</option>
<option value="plain">Plain text (no query encryption)</option>
</select>
</div>
<div class="form-group">
<label>Rate Limit (queries/sec, 0 = unlimited)</label>
<input type="number" id="cfgRateLimit" value="0" min="0" step="0.1">
<input type="number" id="cfgRateLimit" value="100" min="0" step="0.1">
</div>
<div class="form-group" style="flex-direction:row;align-items:center;gap:10px;">
<input type="checkbox" id="cfgDebug" style="width:auto;margin:0;">
<label for="cfgDebug" style="margin:0;cursor:pointer;">Debug mode (log generated query names)</label>
</div>
<div style="display:flex;gap:8px;justify-content:flex-end;margin-top:20px;">
<button class="btn" onclick="closeSettings()">Cancel</button>
@@ -287,6 +339,7 @@
let selectedChannel = 0;
let channels = [];
let eventSource = null;
let autoRefreshTimer = null;
async function init() {
try {
@@ -296,11 +349,16 @@
openSettings();
} else {
document.getElementById('statusDot').className = 'status-dot connected';
if (status.channels && status.channels.length > 0) {
channels = status.channels;
renderChannels();
selectChannel(1);
await doRefresh();
await loadChannels();
if (channels && channels.length > 0) {
await selectChannel(1);
}
autoRefreshTimer = setInterval(function() {
if (selectedChannel > 0) {
doRefresh();
}
}, 120000);
}
} catch (e) {}
connectSSE();
@@ -312,12 +370,23 @@
eventSource.addEventListener('log', function(e) {
addLogLine(JSON.parse(e.data));
});
eventSource.addEventListener('update', function(e) {
loadChannels();
if (selectedChannel > 0) loadMessages(selectedChannel);
eventSource.addEventListener('update', async function(e) {
var wasEmpty = channels.length === 0;
await loadChannels();
if (wasEmpty && channels.length > 0 && selectedChannel === 0) {
// Channels just appeared for the first time — auto-select the first one.
selectChannel(1);
} else if (selectedChannel > 0) {
loadMessages(selectedChannel);
}
});
eventSource.onerror = function() {
document.getElementById('statusDot').className = 'status-dot disconnected';
// EventSource.CLOSED (2) means the browser stopped retrying — reconnect manually.
if (eventSource.readyState === EventSource.CLOSED) {
eventSource.close();
setTimeout(connectSSE, 3000);
}
};
eventSource.onopen = function() {
document.getElementById('statusDot').className = 'status-dot connected';
@@ -343,6 +412,7 @@
async function selectChannel(num) {
selectedChannel = num;
renderChannels();
await doRefresh();
await loadMessages(num);
}
@@ -392,16 +462,79 @@
var el = document.getElementById('logPanel');
var div = document.createElement('div');
div.className = 'log-line';
// Parse log level from line content
var level = 'info';
var displayText = line;
if (typeof line === 'string') {
if (line.includes('Error:') || line.includes('error') || line.includes('Error')) {
level = 'error';
} else if (line.includes('Warning:') || line.includes('warning') || line.includes('Warning')) {
level = 'warning';
} else if (line.includes('OK:') || line.includes('OK') || line.includes('success') || line.includes('done')) {
level = 'success';
} else if (line.match(/\[\d+%\]|\d+\s*%/)) {
level = 'progress';
updateProgressDisplay(line);
return;
}
}
div.className = 'log-line ' + level;
div.textContent = line;
el.appendChild(div);
el.scrollTop = el.scrollHeight;
// Keep only last 200 lines
while (el.children.length > 200) {
el.removeChild(el.firstChild);
}
}
function updateProgressDisplay(line) {
// Parse progress from "Channel N [====> ] 45%"
var match = line.match(/Channel\s+(\d+)/);
if (!match) return;
var channelNum = match[1];
var percentMatch = line.match(/(\d+)%/);
var percent = percentMatch ? parseInt(percentMatch[1]) : 0;
var panel = document.getElementById('progressPanel');
var itemId = 'progress-current';
var item = document.getElementById(itemId);
if (!item) {
item = document.createElement('div');
item.id = itemId;
item.className = 'progress-item';
item.innerHTML = '<div class="progress-label">Channel ' + channelNum + '</div>' +
'<div class="progress-bar"><div class="progress-fill" style="width:0%"></div></div>';
panel.appendChild(item);
} else {
item.querySelector('.progress-label').textContent = 'Channel ' + channelNum;
}
var fill = item.querySelector('.progress-fill');
fill.style.width = percent + '%';
// Remove if complete
if (percent >= 100) {
setTimeout(function() {
if (item.parentNode) item.parentNode.removeChild(item);
}, 1000);
}
}
async function doRefresh() {
try { await fetch('/api/refresh', {method: 'POST'}); } catch (e) {}
try {
var url = '/api/refresh';
if (selectedChannel > 0) {
url += '?channel=' + selectedChannel;
}
await fetch(url, {method: 'POST'});
} catch (e) {}
}
function openSettings() {
@@ -411,7 +544,12 @@
if (cfg.key) document.getElementById('cfgKey').value = cfg.key;
if (cfg.resolvers) document.getElementById('cfgResolvers').value = cfg.resolvers.join('\n');
if (cfg.queryMode) document.getElementById('cfgQueryMode').value = cfg.queryMode;
if (cfg.rateLimit) document.getElementById('cfgRateLimit').value = cfg.rateLimit;
if (typeof cfg.rateLimit === 'number') {
document.getElementById('cfgRateLimit').value = cfg.rateLimit;
} else {
document.getElementById('cfgRateLimit').value = 100;
}
document.getElementById('cfgDebug').checked = !!cfg.debug;
}).catch(function() {});
}
@@ -432,7 +570,8 @@
key: document.getElementById('cfgKey').value,
resolvers: resolvers,
queryMode: document.getElementById('cfgQueryMode').value,
rateLimit: parseFloat(document.getElementById('cfgRateLimit').value) || 0
rateLimit: parseFloat(document.getElementById('cfgRateLimit').value) || 100,
debug: document.getElementById('cfgDebug').checked
};
if (!cfg.domain || !cfg.key || resolvers.length === 0) {
@@ -453,9 +592,18 @@
}
closeSettings();
document.getElementById('statusDot').className = 'status-dot connected';
setTimeout(function() {
loadChannels();
}, 2000);
await doRefresh();
await loadChannels();
if (channels && channels.length > 0) {
await selectChannel(1);
}
if (!autoRefreshTimer) {
autoRefreshTimer = setInterval(function() {
if (selectedChannel > 0) {
fetch('/api/refresh?channel=' + selectedChannel + '&quiet=1', {method: 'POST'});
}
}, 120000);
}
} catch (e) {
errEl.textContent = e.message;
errEl.style.display = 'block';
+176 -50
View File
@@ -1,6 +1,7 @@
package web
import (
"context"
"embed"
"encoding/json"
"fmt"
@@ -29,6 +30,11 @@ type Config struct {
Resolvers []string `json:"resolvers"`
QueryMode string `json:"queryMode"`
RateLimit float64 `json:"rateLimit"`
// Timeout is the per-query DNS timeout in seconds (0 = default 5 s).
// Also used as the resolver health-check probe timeout.
Timeout float64 `json:"timeout,omitempty"`
// Debug enables verbose query logging (shows generated DNS query names).
Debug bool `json:"debug,omitempty"`
}
// Server is the web UI server for thefeed client.
@@ -43,6 +49,16 @@ type Server struct {
channels []protocol.ChannelInfo
messages map[int][]protocol.Message
// fetcherCtx/fetcherCancel control the lifetime of the active fetcher's
// background goroutines (rate limiter, noise, resolver checker).
// They are cancelled and recreated each time the config changes.
fetcherCtx context.Context
fetcherCancel context.CancelFunc
// refreshMu / refreshCancel allow a new refresh to cancel an in-progress one.
refreshMu sync.Mutex
refreshCancel context.CancelFunc
logMu sync.RWMutex
logLines []string
@@ -96,7 +112,7 @@ func (s *Server) Run() error {
fmt.Printf("\n Open in browser: http://%s\n\n", addr)
if s.fetcher != nil {
s.startAutoRefresh()
go s.refreshMetadataOnly()
}
return http.ListenAndServe(addr, mux)
@@ -164,7 +180,7 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
http.Error(w, fmt.Sprintf("init fetcher: %v", err), 500)
return
}
s.startAutoRefresh()
go s.refreshMetadataOnly()
writeJSON(w, map[string]any{"ok": true})
default:
@@ -202,7 +218,28 @@ func (s *Server) handleRefresh(w http.ResponseWriter, r *http.Request) {
http.Error(w, "method not allowed", 405)
return
}
go s.refresh()
// Background (quiet) refreshes skip silently if one is already running,
// so the auto-refresh timer never cancels a slow in-progress fetch.
if r.URL.Query().Get("quiet") == "1" {
s.refreshMu.Lock()
running := s.refreshCancel != nil
s.refreshMu.Unlock()
if running {
writeJSON(w, map[string]any{"ok": true, "skipped": true})
return
}
}
chParam := r.URL.Query().Get("channel")
if chParam != "" {
chNum, err := strconv.Atoi(chParam)
if err != nil || chNum < 1 {
http.Error(w, "invalid channel", 400)
return
}
go s.refreshChannel(chNum)
} else {
go s.refreshMetadataOnly()
}
writeJSON(w, map[string]any{"ok": true})
}
@@ -278,6 +315,11 @@ func (s *Server) initFetcher() error {
s.mu.Lock()
defer s.mu.Unlock()
// Cancel goroutines from the previous fetcher configuration.
if s.fetcherCancel != nil {
s.fetcherCancel()
}
cfg := s.config
if cfg == nil {
return fmt.Errorf("no config")
@@ -295,57 +337,81 @@ func (s *Server) initFetcher() error {
}
if cfg.QueryMode == "double" {
fetcher.SetQueryMode(protocol.QueryDoubleLabel)
fetcher.SetQueryMode(protocol.QueryMultiLabel)
} else if cfg.QueryMode == "plain" {
fetcher.SetQueryMode(protocol.QueryPlainLabel)
}
fetcher.SetDebug(cfg.Debug)
if cfg.RateLimit > 0 {
fetcher.SetRateLimit(cfg.RateLimit)
}
timeout := 5 * time.Second
if cfg.Timeout > 0 {
timeout = time.Duration(cfg.Timeout * float64(time.Second))
}
fetcher.SetTimeout(timeout)
fetcher.SetLogFunc(func(msg string) {
s.addLog(msg)
})
// Create a shared context for this fetcher's lifetime.
ctx, cancel := context.WithCancel(context.Background())
s.fetcherCtx = ctx
s.fetcherCancel = cancel
// Start rate limiter and noise goroutines.
fetcher.Start(ctx)
// Start periodic resolver health checks.
checker := client.NewResolverChecker(fetcher, timeout)
checker.SetLogFunc(func(msg string) {
s.addLog(msg)
})
checker.Start(ctx)
s.fetcher = fetcher
s.cache = cache
return nil
}
func (s *Server) startAutoRefresh() {
if s.stopRefresh != nil {
close(s.stopRefresh)
func (s *Server) refreshMetadataOnly() {
// Cancel any in-progress refresh and start a new cancellable one.
s.refreshMu.Lock()
if s.refreshCancel != nil {
s.refreshCancel()
}
s.stopRefresh = make(chan struct{})
stop := s.stopRefresh
go s.refresh()
go func() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-stop:
return
case <-ticker.C:
s.refresh()
}
}
}()
}
func (s *Server) refresh() {
s.mu.RLock()
basectx := s.fetcherCtx
fetcher := s.fetcher
cache := s.cache
s.mu.RUnlock()
if fetcher == nil {
if fetcher == nil || basectx == nil {
s.refreshMu.Unlock()
return
}
// Child context: cancelled either by the next refresh call or by a config change.
ctx, cancel := context.WithCancel(basectx)
s.refreshCancel = cancel
s.refreshMu.Unlock()
defer func() {
cancel()
s.refreshMu.Lock()
s.refreshCancel = nil
s.refreshMu.Unlock()
}()
s.addLog("Fetching metadata...")
meta, err := fetcher.FetchMetadata()
meta, err := fetcher.FetchMetadata(ctx)
if err != nil {
if ctx.Err() != nil {
s.addLog("Refresh cancelled")
return
}
s.addLog(fmt.Sprintf("Error: %v", err))
return
}
@@ -359,31 +425,91 @@ func (s *Server) refresh() {
}
s.broadcast("event: update\ndata: \"channels\"\n\n")
}
for i, ch := range meta.Channels {
chNum := i + 1
blockCount := int(ch.Blocks)
if blockCount <= 0 {
continue
}
msgs, err := fetcher.FetchChannel(chNum, blockCount)
if err != nil {
s.addLog(fmt.Sprintf("Channel %s error: %v", ch.Name, err))
continue
}
s.mu.Lock()
s.messages[chNum] = msgs
s.mu.Unlock()
if cache != nil {
_ = cache.PutMessages(chNum, msgs)
}
s.addLog(fmt.Sprintf("Updated %s: %d messages", ch.Name, len(msgs)))
func (s *Server) refreshChannel(channelNum int) {
s.refreshMu.Lock()
if s.refreshCancel != nil {
s.refreshCancel()
}
s.mu.RLock()
basectx := s.fetcherCtx
fetcher := s.fetcher
cache := s.cache
channels := s.channels
s.mu.RUnlock()
if fetcher == nil || basectx == nil {
s.refreshMu.Unlock()
return
}
ctx, cancel := context.WithCancel(basectx)
s.refreshCancel = cancel
s.refreshMu.Unlock()
defer func() {
cancel()
s.refreshMu.Lock()
s.refreshCancel = nil
s.refreshMu.Unlock()
}()
meta, err := fetcher.FetchMetadata(ctx)
if err != nil {
if ctx.Err() != nil {
s.addLog("Refresh cancelled")
return
}
s.addLog(fmt.Sprintf("Error: %v", err))
return
}
s.mu.Lock()
s.channels = meta.Channels
s.mu.Unlock()
if cache != nil {
_ = cache.PutMetadata(meta)
}
s.broadcast("event: update\ndata: \"channels\"\n\n")
channels = meta.Channels
if channelNum < 1 || channelNum > len(channels) {
s.addLog(fmt.Sprintf("Warning: channel %d is not available", channelNum))
return
}
ch := channels[channelNum-1]
blockCount := int(ch.Blocks)
if blockCount <= 0 {
s.mu.Lock()
s.messages[channelNum] = nil
s.mu.Unlock()
s.addLog(fmt.Sprintf("Updated %s: 0 messages", ch.Name))
s.broadcast("event: update\ndata: \"messages\"\n\n")
return
}
msgs, err := fetcher.FetchChannel(ctx, channelNum, blockCount)
if err != nil {
if ctx.Err() != nil {
s.addLog("Refresh cancelled")
return
}
s.addLog(fmt.Sprintf("Channel %s error: %v", ch.Name, err))
return
}
s.mu.Lock()
s.messages[channelNum] = msgs
s.mu.Unlock()
if cache != nil {
_ = cache.PutMessages(channelNum, msgs)
}
s.addLog(fmt.Sprintf("Updated %s: %d messages", ch.Name, len(msgs)))
s.broadcast("event: update\ndata: \"messages\"\n\n")
}