mirror of
https://github.com/sartoopjj/thefeed.git
synced 2026-05-18 05:24:36 +03:00
feat: implement AES-256 block cipher for query encryption and decryption
This commit is contained in:
+25
-11
@@ -96,7 +96,7 @@ func TestE2E_FetchMetadataThroughDNS(t *testing.T) {
|
||||
t.Fatalf("create fetcher: %v", err)
|
||||
}
|
||||
|
||||
meta, err := fetcher.FetchMetadata()
|
||||
meta, err := fetcher.FetchMetadata(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("fetch metadata: %v", err)
|
||||
}
|
||||
@@ -136,7 +136,7 @@ func TestE2E_FetchChannelMessages(t *testing.T) {
|
||||
t.Fatalf("create fetcher: %v", err)
|
||||
}
|
||||
|
||||
meta, err := fetcher.FetchMetadata()
|
||||
meta, err := fetcher.FetchMetadata(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("fetch metadata: %v", err)
|
||||
}
|
||||
@@ -146,7 +146,7 @@ func TestE2E_FetchChannelMessages(t *testing.T) {
|
||||
t.Fatal("expected blocks > 0")
|
||||
}
|
||||
|
||||
fetchedMsgs, err := fetcher.FetchChannel(1, blockCount)
|
||||
fetchedMsgs, err := fetcher.FetchChannel(context.Background(), 1, blockCount)
|
||||
if err != nil {
|
||||
t.Fatalf("fetch channel: %v", err)
|
||||
}
|
||||
@@ -180,14 +180,14 @@ func TestE2E_FetchWithDoubleLabel(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("create fetcher: %v", err)
|
||||
}
|
||||
fetcher.SetQueryMode(protocol.QueryDoubleLabel)
|
||||
fetcher.SetQueryMode(protocol.QueryMultiLabel)
|
||||
|
||||
meta, err := fetcher.FetchMetadata()
|
||||
meta, err := fetcher.FetchMetadata(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("fetch metadata: %v", err)
|
||||
}
|
||||
|
||||
fetchedMsgs, err := fetcher.FetchChannel(1, int(meta.Channels[0].Blocks))
|
||||
fetchedMsgs, err := fetcher.FetchChannel(context.Background(), 1, int(meta.Channels[0].Blocks))
|
||||
if err != nil {
|
||||
t.Fatalf("fetch channel: %v", err)
|
||||
}
|
||||
@@ -216,7 +216,7 @@ func TestE2E_WrongPassphrase(t *testing.T) {
|
||||
t.Fatalf("create fetcher: %v", err)
|
||||
}
|
||||
|
||||
_, err = fetcher.FetchMetadata()
|
||||
_, err = fetcher.FetchMetadata(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error with wrong passphrase, got nil")
|
||||
}
|
||||
@@ -243,12 +243,12 @@ func TestE2E_LargeMessages(t *testing.T) {
|
||||
t.Fatalf("create fetcher: %v", err)
|
||||
}
|
||||
|
||||
meta, err := fetcher.FetchMetadata()
|
||||
meta, err := fetcher.FetchMetadata(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("fetch metadata: %v", err)
|
||||
}
|
||||
|
||||
fetchedMsgs, err := fetcher.FetchChannel(1, int(meta.Channels[0].Blocks))
|
||||
fetchedMsgs, err := fetcher.FetchChannel(context.Background(), 1, int(meta.Channels[0].Blocks))
|
||||
if err != nil {
|
||||
t.Fatalf("fetch channel: %v", err)
|
||||
}
|
||||
@@ -586,8 +586,15 @@ func TestE2E_FullRoundTrip(t *testing.T) {
|
||||
t.Fatalf("config POST status=%d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Wait for auto-refresh to fetch data
|
||||
time.Sleep(3 * time.Second)
|
||||
// Refresh channels via selected-channel API semantics.
|
||||
respRefresh1, err := http.Post(base+"/api/refresh?channel=1", "application/json", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("POST /api/refresh?channel=1: %v", err)
|
||||
}
|
||||
respRefresh1.Body.Close()
|
||||
// Give channel 1 refresh goroutine time to complete before refreshing channel 2,
|
||||
// because starting a new refresh cancels the previous in-flight refresh.
|
||||
time.Sleep(700 * time.Millisecond)
|
||||
|
||||
// Channels should be populated
|
||||
resp2, err := http.Get(base + "/api/channels")
|
||||
@@ -621,6 +628,13 @@ func TestE2E_FullRoundTrip(t *testing.T) {
|
||||
t.Errorf("msg[0].Text = %q, want %q", msgList[0].Text, "General message 1")
|
||||
}
|
||||
|
||||
respRefresh2, err := http.Post(base+"/api/refresh?channel=2", "application/json", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("POST /api/refresh?channel=2: %v", err)
|
||||
}
|
||||
respRefresh2.Body.Close()
|
||||
time.Sleep(700 * time.Millisecond)
|
||||
|
||||
// Messages for channel 2
|
||||
resp4, err := http.Get(base + "/api/messages/2")
|
||||
if err != nil {
|
||||
|
||||
+273
-92
@@ -1,6 +1,7 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
@@ -12,26 +13,36 @@ import (
|
||||
"github.com/sartoopjj/thefeed/internal/protocol"
|
||||
)
|
||||
|
||||
// LogFunc is a callback for logging DNS queries (for debug/TUI).
|
||||
// LogFunc is a callback for log messages.
|
||||
type LogFunc func(msg string)
|
||||
|
||||
// noiseDomains are popular domains used to blend feed queries into normal-looking DNS traffic.
|
||||
var noiseDomains = []string{
|
||||
"www.google.com", "www.cloudflare.com", "one.one.one.one",
|
||||
"www.youtube.com", "www.instagram.com", "www.amazon.com",
|
||||
"www.microsoft.com", "www.apple.com", "www.github.com",
|
||||
"www.wikipedia.org", "www.reddit.com", "www.twitter.com",
|
||||
}
|
||||
|
||||
// Fetcher fetches feed blocks over DNS.
|
||||
type Fetcher struct {
|
||||
domain string
|
||||
queryKey [protocol.KeySize]byte
|
||||
responseKey [protocol.KeySize]byte
|
||||
queryMode protocol.QueryEncoding
|
||||
timeout time.Duration
|
||||
|
||||
mu sync.RWMutex
|
||||
resolvers []string
|
||||
timeout time.Duration
|
||||
// Resolver pools — allResolvers is what the user configured;
|
||||
// activeResolvers is kept up-to-date by ResolverChecker (only healthy ones).
|
||||
mu sync.RWMutex
|
||||
allResolvers []string
|
||||
activeResolvers []string
|
||||
|
||||
// Rate limiting
|
||||
rateMu sync.Mutex
|
||||
queryDelay time.Duration
|
||||
lastQuery time.Time
|
||||
// Rate limiting via token bucket; nil means unlimited.
|
||||
rateQPS float64
|
||||
rateCh chan struct{}
|
||||
|
||||
// Debug logging
|
||||
debug bool
|
||||
logFunc LogFunc
|
||||
}
|
||||
|
||||
@@ -42,23 +53,28 @@ func NewFetcher(domain, passphrase string, resolvers []string) (*Fetcher, error)
|
||||
return nil, fmt.Errorf("derive keys: %w", err)
|
||||
}
|
||||
|
||||
r := make([]string, len(resolvers))
|
||||
copy(r, resolvers)
|
||||
|
||||
return &Fetcher{
|
||||
domain: strings.TrimSuffix(domain, "."),
|
||||
queryKey: qk,
|
||||
responseKey: rk,
|
||||
queryMode: protocol.QuerySingleLabel,
|
||||
resolvers: resolvers,
|
||||
timeout: 5 * time.Second,
|
||||
domain: strings.TrimSuffix(domain, "."),
|
||||
queryKey: qk,
|
||||
responseKey: rk,
|
||||
queryMode: protocol.QuerySingleLabel,
|
||||
allResolvers: r,
|
||||
activeResolvers: r,
|
||||
timeout: 10 * time.Second,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetRateLimit sets the maximum queries per second (0 = unlimited).
|
||||
// SetRateLimit sets the maximum queries per second (0 = unlimited). Must be called before Start.
|
||||
func (f *Fetcher) SetRateLimit(qps float64) {
|
||||
if qps <= 0 {
|
||||
f.queryDelay = 0
|
||||
return
|
||||
}
|
||||
f.queryDelay = time.Duration(float64(time.Second) / qps)
|
||||
f.rateQPS = qps
|
||||
}
|
||||
|
||||
// SetTimeout sets the per-query DNS timeout.
|
||||
func (f *Fetcher) SetTimeout(d time.Duration) {
|
||||
f.timeout = d
|
||||
}
|
||||
|
||||
// SetLogFunc sets the debug log callback.
|
||||
@@ -66,83 +82,225 @@ func (f *Fetcher) SetLogFunc(fn LogFunc) {
|
||||
f.logFunc = fn
|
||||
}
|
||||
|
||||
// SetDebug enables or disables debug logging of generated query names.
|
||||
func (f *Fetcher) SetDebug(debug bool) {
|
||||
f.debug = debug
|
||||
}
|
||||
|
||||
// SetQueryMode sets the DNS query encoding mode.
|
||||
func (f *Fetcher) SetQueryMode(mode protocol.QueryEncoding) {
|
||||
f.queryMode = mode
|
||||
}
|
||||
|
||||
// SetActiveResolvers updates the healthy resolver pool. Called by ResolverChecker.
|
||||
// If the new list is empty, the current pool is unchanged (to avoid blackout during a bad check).
|
||||
func (f *Fetcher) SetActiveResolvers(resolvers []string) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
if len(resolvers) > 0 {
|
||||
f.activeResolvers = make([]string, len(resolvers))
|
||||
copy(f.activeResolvers, resolvers)
|
||||
}
|
||||
}
|
||||
|
||||
// SetResolvers replaces the full resolver list and resets the active pool.
|
||||
func (f *Fetcher) SetResolvers(resolvers []string) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.allResolvers = make([]string, len(resolvers))
|
||||
copy(f.allResolvers, resolvers)
|
||||
f.activeResolvers = make([]string, len(resolvers))
|
||||
copy(f.activeResolvers, resolvers)
|
||||
}
|
||||
|
||||
// AllResolvers returns all user-configured resolvers.
|
||||
func (f *Fetcher) AllResolvers() []string {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
result := make([]string, len(f.allResolvers))
|
||||
copy(result, f.allResolvers)
|
||||
return result
|
||||
}
|
||||
|
||||
// Resolvers returns the currently active (healthy) resolver list.
|
||||
func (f *Fetcher) Resolvers() []string {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
result := make([]string, len(f.activeResolvers))
|
||||
copy(result, f.activeResolvers)
|
||||
return result
|
||||
}
|
||||
|
||||
// Start launches background goroutines (rate limiter and noise generator).
|
||||
// ctx controls their lifetime — cancel it to cleanly stop them.
|
||||
// Call once per fetcher configuration; creating a new fetcher replaces the old one.
|
||||
func (f *Fetcher) Start(ctx context.Context) {
|
||||
if f.rateQPS > 0 {
|
||||
f.rateCh = make(chan struct{}, 1)
|
||||
go f.runRateLimiter(ctx)
|
||||
go f.runNoise(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// runRateLimiter issues one token into rateCh every 1/QPS seconds.
|
||||
// The channel capacity is 1, so tokens do not accumulate (no burst).
|
||||
func (f *Fetcher) runRateLimiter(ctx context.Context) {
|
||||
interval := time.Duration(float64(time.Second) / f.rateQPS)
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
select {
|
||||
case f.rateCh <- struct{}{}:
|
||||
default: // bucket full; discard extra token to prevent burst
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runNoise sends decoy A-record queries to popular domains at a randomised
|
||||
// rate matching the configured QPS, to make feed traffic look like normal DNS usage.
|
||||
func (f *Fetcher) runNoise(ctx context.Context) {
|
||||
interval := time.Duration(float64(time.Second) / f.rateQPS)
|
||||
for {
|
||||
// Random delay: 1–3× the query interval.
|
||||
jitter := time.Duration(rand.Int63n(int64(2*interval) + 1))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(interval + jitter):
|
||||
}
|
||||
|
||||
resolvers := f.Resolvers()
|
||||
if len(resolvers) == 0 {
|
||||
continue
|
||||
}
|
||||
resolver := resolvers[rand.Intn(len(resolvers))]
|
||||
if !strings.Contains(resolver, ":") {
|
||||
resolver += ":53"
|
||||
}
|
||||
target := noiseDomains[rand.Intn(len(noiseDomains))]
|
||||
|
||||
go func(r, d string) {
|
||||
c := &dns.Client{Timeout: f.timeout}
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(dns.Fqdn(d), dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
c.Exchange(m, r) //nolint:errcheck — fire-and-forget noise query
|
||||
}(resolver, target)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Fetcher) log(format string, args ...any) {
|
||||
if f.logFunc != nil {
|
||||
f.logFunc(fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Fetcher) rateWait() {
|
||||
if f.queryDelay <= 0 {
|
||||
// logProgress logs a progress bar: "prefix [====> ] 45%"
|
||||
func (f *Fetcher) logProgress(prefix string, current, total float64) {
|
||||
if f.logFunc == nil || total <= 0 {
|
||||
return
|
||||
}
|
||||
f.rateMu.Lock()
|
||||
defer f.rateMu.Unlock()
|
||||
elapsed := time.Since(f.lastQuery)
|
||||
if elapsed < f.queryDelay {
|
||||
time.Sleep(f.queryDelay - elapsed)
|
||||
|
||||
percent := int((current / total) * 100)
|
||||
barLen := 20
|
||||
filled := int((current / total) * float64(barLen))
|
||||
empty := barLen - filled
|
||||
|
||||
bar := "["
|
||||
for i := 0; i < filled; i++ {
|
||||
bar += "="
|
||||
}
|
||||
f.lastQuery = time.Now()
|
||||
if filled < barLen {
|
||||
bar += ">"
|
||||
}
|
||||
for i := 0; i < empty-1; i++ {
|
||||
bar += " "
|
||||
}
|
||||
bar += "]"
|
||||
|
||||
f.logFunc(fmt.Sprintf("%s %s %d%%", prefix, bar, percent))
|
||||
}
|
||||
|
||||
// SetResolvers replaces the resolver list.
|
||||
func (f *Fetcher) SetResolvers(resolvers []string) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.resolvers = resolvers
|
||||
// rateWait blocks until a rate-limit token is available or ctx is cancelled.
|
||||
// Returns nil when a token was acquired, ctx.Err() when cancelled.
|
||||
func (f *Fetcher) rateWait(ctx context.Context) error {
|
||||
if f.rateCh == nil {
|
||||
// Unlimited: just propagate any existing cancellation.
|
||||
return ctx.Err()
|
||||
}
|
||||
select {
|
||||
case <-f.rateCh:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Resolvers returns the current resolver list.
|
||||
func (f *Fetcher) Resolvers() []string {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
result := make([]string, len(f.resolvers))
|
||||
copy(result, f.resolvers)
|
||||
return result
|
||||
}
|
||||
|
||||
// FetchBlock fetches a single block from a channel.
|
||||
func (f *Fetcher) FetchBlock(channel, block uint16) ([]byte, error) {
|
||||
f.rateWait()
|
||||
|
||||
qname, err := protocol.EncodeQuery(f.queryKey, channel, block, f.domain, f.queryMode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encode query: %w", err)
|
||||
}
|
||||
|
||||
f.log("Q ch=%d blk=%d → %s", channel, block, qname)
|
||||
|
||||
resolvers := f.Resolvers()
|
||||
if len(resolvers) == 0 {
|
||||
return nil, fmt.Errorf("no resolvers configured")
|
||||
}
|
||||
|
||||
// Shuffle resolvers to distribute load
|
||||
shuffled := make([]string, len(resolvers))
|
||||
copy(shuffled, resolvers)
|
||||
rand.Shuffle(len(shuffled), func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] })
|
||||
|
||||
// FetchBlock fetches a single encrypted block from the given channel.
|
||||
// It enqueues through the rate limiter and respects ctx cancellation.
|
||||
// On transient failure it retries up to 2 additional times with a short back-off.
|
||||
func (f *Fetcher) FetchBlock(ctx context.Context, channel, block uint16) ([]byte, error) {
|
||||
const maxAttempts = 3
|
||||
var lastErr error
|
||||
for _, resolver := range shuffled {
|
||||
data, err := f.queryResolver(resolver, qname)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Brief back-off before retry; bail immediately if ctx is done.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(time.Duration(attempt) * 500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("all resolvers failed, last error: %w", lastErr)
|
||||
if err := f.rateWait(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
qname, err := protocol.EncodeQuery(f.queryKey, channel, block, f.domain, f.queryMode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encode query: %w", err)
|
||||
}
|
||||
if f.debug {
|
||||
f.log("[debug] query ch=%d blk=%d attempt=%d qname=%s", channel, block, attempt+1, qname)
|
||||
}
|
||||
|
||||
resolvers := f.Resolvers()
|
||||
if len(resolvers) == 0 {
|
||||
return nil, fmt.Errorf("no active resolvers")
|
||||
}
|
||||
|
||||
// Shuffle to spread load across resolvers.
|
||||
shuffled := make([]string, len(resolvers))
|
||||
copy(shuffled, resolvers)
|
||||
rand.Shuffle(len(shuffled), func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] })
|
||||
|
||||
for _, resolver := range shuffled {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
data, err := f.queryResolver(ctx, resolver, qname)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
lastErr = fmt.Errorf("all resolvers failed: %w", lastErr)
|
||||
if attempt+1 < maxAttempts {
|
||||
f.log("block ch=%d blk=%d attempt %d/%d failed, retrying: %v", channel, block, attempt+1, maxAttempts, lastErr)
|
||||
}
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// FetchMetadata fetches and parses metadata (channel 0).
|
||||
func (f *Fetcher) FetchMetadata() (*protocol.Metadata, error) {
|
||||
data, err := f.FetchBlock(protocol.MetadataChannel, 0)
|
||||
// FetchMetadata fetches and parses the metadata block (channel 0).
|
||||
func (f *Fetcher) FetchMetadata(ctx context.Context) (*protocol.Metadata, error) {
|
||||
data, err := f.FetchBlock(ctx, protocol.MetadataChannel, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch metadata block 0: %w", err)
|
||||
}
|
||||
@@ -152,12 +310,15 @@ func (f *Fetcher) FetchMetadata() (*protocol.Metadata, error) {
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
// Metadata might span multiple blocks
|
||||
// Metadata may span multiple blocks.
|
||||
allData := make([]byte, len(data))
|
||||
copy(allData, data)
|
||||
|
||||
for blk := uint16(1); blk < 10; blk++ {
|
||||
block, fetchErr := f.FetchBlock(protocol.MetadataChannel, blk)
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
block, fetchErr := f.FetchBlock(ctx, protocol.MetadataChannel, blk)
|
||||
if fetchErr != nil {
|
||||
break
|
||||
}
|
||||
@@ -171,31 +332,41 @@ func (f *Fetcher) FetchMetadata() (*protocol.Metadata, error) {
|
||||
return nil, fmt.Errorf("could not parse metadata: %w", err)
|
||||
}
|
||||
|
||||
// FetchChannel fetches all blocks for a channel and parses messages.
|
||||
func (f *Fetcher) FetchChannel(channelNum int, blockCount int) ([]protocol.Message, error) {
|
||||
// FetchChannel fetches all blocks for a channel and returns the parsed messages.
|
||||
// Cancelling ctx immediately aborts any queued or in-flight block fetches.
|
||||
// Each block is retried individually via FetchBlock before the channel fetch fails.
|
||||
func (f *Fetcher) FetchChannel(ctx context.Context, channelNum int, blockCount int) ([]protocol.Message, error) {
|
||||
if blockCount <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type result struct {
|
||||
type blockResult struct {
|
||||
idx int
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
|
||||
results := make(chan result, blockCount)
|
||||
// Limit concurrency to 3 to reduce DNS burst traffic
|
||||
sem := make(chan struct{}, 3)
|
||||
results := make(chan blockResult, blockCount)
|
||||
// Cap concurrent DNS queries at 5; the token-bucket rate limiter provides
|
||||
// the actual throughput control regardless of this concurrency cap.
|
||||
sem := make(chan struct{}, 5)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < blockCount; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{}
|
||||
// Acquire semaphore or bail on cancellation.
|
||||
select {
|
||||
case sem <- struct{}{}:
|
||||
case <-ctx.Done():
|
||||
results <- blockResult{idx: idx, err: ctx.Err()}
|
||||
return
|
||||
}
|
||||
defer func() { <-sem }()
|
||||
data, err := f.FetchBlock(uint16(channelNum), uint16(idx))
|
||||
results <- result{idx: idx, data: data, err: err}
|
||||
|
||||
data, err := f.FetchBlock(ctx, uint16(channelNum), uint16(idx))
|
||||
results <- blockResult{idx: idx, data: data, err: err}
|
||||
}(i)
|
||||
}
|
||||
|
||||
@@ -205,24 +376,33 @@ func (f *Fetcher) FetchChannel(channelNum int, blockCount int) ([]protocol.Messa
|
||||
}()
|
||||
|
||||
ordered := make([][]byte, blockCount)
|
||||
completed := 0
|
||||
for r := range results {
|
||||
if r.err != nil {
|
||||
return nil, fmt.Errorf("fetch block %d: %w", r.idx, r.err)
|
||||
if r.err == ctx.Err() {
|
||||
// Context cancelled — abort immediately.
|
||||
return nil, r.err
|
||||
}
|
||||
// FetchBlock already retried internally; log and treat as fatal for this channel.
|
||||
f.log("Channel %d block %d permanently failed: %v", channelNum, r.idx, r.err)
|
||||
return nil, fmt.Errorf("channel %d block %d: %w", channelNum, r.idx, r.err)
|
||||
}
|
||||
ordered[r.idx] = r.data
|
||||
completed++
|
||||
f.logProgress(fmt.Sprintf("Channel %d", channelNum), float64(completed), float64(blockCount))
|
||||
}
|
||||
|
||||
var allData []byte
|
||||
for _, block := range ordered {
|
||||
allData = append(allData, block...)
|
||||
for _, b := range ordered {
|
||||
allData = append(allData, b...)
|
||||
}
|
||||
|
||||
return protocol.ParseMessages(allData)
|
||||
}
|
||||
|
||||
func (f *Fetcher) queryResolver(resolver, qname string) ([]byte, error) {
|
||||
func (f *Fetcher) queryResolver(ctx context.Context, resolver, qname string) ([]byte, error) {
|
||||
if !strings.Contains(resolver, ":") {
|
||||
resolver = resolver + ":53"
|
||||
resolver += ":53"
|
||||
}
|
||||
|
||||
c := new(dns.Client)
|
||||
@@ -231,8 +411,9 @@ func (f *Fetcher) queryResolver(resolver, qname string) ([]byte, error) {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(dns.Fqdn(qname), dns.TypeTXT)
|
||||
m.RecursionDesired = true
|
||||
m.SetEdns0(4096, false) // advertise 4 KiB UDP buffer to avoid response truncation
|
||||
|
||||
resp, _, err := c.Exchange(m, resolver)
|
||||
resp, _, err := c.ExchangeContext(ctx, m, resolver)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dns exchange with %s: %w", resolver, err)
|
||||
}
|
||||
|
||||
+85
-137
@@ -1,181 +1,129 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/sartoopjj/thefeed/internal/protocol"
|
||||
)
|
||||
|
||||
// ResolverScanner scans CIDR ranges to find working DNS resolvers.
|
||||
type ResolverScanner struct {
|
||||
fetcher *Fetcher
|
||||
concurrency int
|
||||
timeout time.Duration
|
||||
// ResolverChecker periodically probes the fetcher's configured resolvers and
|
||||
// updates the active (healthy) resolver pool. It replaces the old file/CIDR
|
||||
// scanner — no file I/O; just a plain DNS probe on channel 0.
|
||||
type ResolverChecker struct {
|
||||
fetcher *Fetcher
|
||||
timeout time.Duration
|
||||
logFunc LogFunc
|
||||
}
|
||||
|
||||
// NewResolverScanner creates a resolver scanner.
|
||||
func NewResolverScanner(fetcher *Fetcher, concurrency int) *ResolverScanner {
|
||||
if concurrency <= 0 {
|
||||
concurrency = 50
|
||||
// NewResolverChecker creates a health checker for the resolvers in fetcher.
|
||||
// timeout is the per-probe deadline; 0 uses a 5-second default.
|
||||
func NewResolverChecker(fetcher *Fetcher, timeout time.Duration) *ResolverChecker {
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
return &ResolverScanner{
|
||||
fetcher: fetcher,
|
||||
concurrency: concurrency,
|
||||
timeout: 3 * time.Second,
|
||||
return &ResolverChecker{
|
||||
fetcher: fetcher,
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// ScanCIDR scans a CIDR range for working DNS resolvers.
|
||||
func (rs *ResolverScanner) ScanCIDR(cidr string, onFound func(ip string)) error {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse CIDR %q: %w", cidr, err)
|
||||
}
|
||||
|
||||
ips := expandCIDR(ipNet)
|
||||
return rs.scanIPs(ips, onFound)
|
||||
// SetLogFunc sets the callback used to emit health-check results to the log panel.
|
||||
func (rc *ResolverChecker) SetLogFunc(fn LogFunc) {
|
||||
rc.logFunc = fn
|
||||
}
|
||||
|
||||
// ScanFile scans resolver IPs from a file (one per line, supports CIDR notation).
|
||||
func (rs *ResolverScanner) ScanFile(path string, onFound func(ip string)) error {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var ips []string
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(line, "/") {
|
||||
_, ipNet, err := net.ParseCIDR(line)
|
||||
if err != nil {
|
||||
log.Printf("[resolver] skip invalid CIDR: %s", line)
|
||||
continue
|
||||
// Start begins the periodic health-check loop in the background.
|
||||
// An initial check runs immediately; subsequent checks happen every 10 minutes.
|
||||
// ctx controls the lifetime — cancel it to stop the checker.
|
||||
func (rc *ResolverChecker) Start(ctx context.Context) {
|
||||
go func() {
|
||||
rc.runCheck()
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
rc.runCheck()
|
||||
}
|
||||
ips = append(ips, expandCIDR(ipNet)...)
|
||||
} else {
|
||||
ips = append(ips, line)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return rs.scanIPs(ips, onFound)
|
||||
}()
|
||||
}
|
||||
|
||||
// CheckResolver tests if a single resolver works by querying metadata.
|
||||
func (rs *ResolverScanner) CheckResolver(ip string) bool {
|
||||
if !strings.Contains(ip, ":") {
|
||||
ip = ip + ":53"
|
||||
func (rc *ResolverChecker) runCheck() {
|
||||
resolvers := rc.fetcher.AllResolvers()
|
||||
if len(resolvers) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Create a new fetcher with only this resolver to avoid copying the lock.
|
||||
tmpFetcher := &Fetcher{
|
||||
domain: rs.fetcher.domain,
|
||||
queryKey: rs.fetcher.queryKey,
|
||||
responseKey: rs.fetcher.responseKey,
|
||||
resolvers: []string{ip},
|
||||
timeout: rs.timeout,
|
||||
}
|
||||
rc.log("Checking %d resolver(s)...", len(resolvers))
|
||||
|
||||
_, err := tmpFetcher.FetchBlock(0, 0)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (rs *ResolverScanner) scanIPs(ips []string, onFound func(ip string)) error {
|
||||
if len(ips) == 0 {
|
||||
return fmt.Errorf("no IPs to scan")
|
||||
}
|
||||
|
||||
var found atomic.Int32
|
||||
sem := make(chan struct{}, rs.concurrency)
|
||||
var healthy []string
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
sem := make(chan struct{}, 10) // probe up to 10 resolvers concurrently
|
||||
|
||||
for _, ip := range ips {
|
||||
for _, r := range resolvers {
|
||||
wg.Add(1)
|
||||
go func(ip string) {
|
||||
go func(r string) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
|
||||
if rs.CheckResolver(ip) {
|
||||
found.Add(1)
|
||||
if onFound != nil {
|
||||
onFound(ip)
|
||||
}
|
||||
if rc.checkOne(r) {
|
||||
mu.Lock()
|
||||
healthy = append(healthy, r)
|
||||
mu.Unlock()
|
||||
rc.log("Resolver OK: %s", r)
|
||||
} else {
|
||||
rc.log("Resolver failed: %s", r)
|
||||
}
|
||||
}(ip)
|
||||
}(r)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if found.Load() == 0 {
|
||||
return fmt.Errorf("no working resolvers found among %d IPs", len(ips))
|
||||
}
|
||||
return nil
|
||||
rc.fetcher.SetActiveResolvers(healthy)
|
||||
rc.log("Resolver check done: %d/%d healthy", len(healthy), len(resolvers))
|
||||
}
|
||||
|
||||
// LoadResolversFile loads resolver IPs from a file (one per line).
|
||||
func LoadResolversFile(path string) ([]string, error) {
|
||||
f, err := os.Open(path)
|
||||
// checkOne probes a single resolver by sending a metadata channel query
|
||||
// (channel 0, block 0). A successful DNS response (any rcode that isn't a
|
||||
// network/timeout error) means the resolver is reachable and understands the domain.
|
||||
func (rc *ResolverChecker) checkOne(resolver string) bool {
|
||||
if !strings.Contains(resolver, ":") {
|
||||
resolver += ":53"
|
||||
}
|
||||
|
||||
qname, err := protocol.EncodeQuery(
|
||||
rc.fetcher.queryKey,
|
||||
protocol.MetadataChannel, 0,
|
||||
rc.fetcher.domain,
|
||||
rc.fetcher.queryMode,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return false
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var resolvers []string
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
resolvers = append(resolvers, line)
|
||||
}
|
||||
return resolvers, scanner.Err()
|
||||
c := &dns.Client{Timeout: rc.timeout}
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(dns.Fqdn(qname), dns.TypeTXT)
|
||||
m.RecursionDesired = true
|
||||
|
||||
resp, _, err := c.Exchange(m, resolver)
|
||||
// We consider the resolver healthy if we get any DNS response back
|
||||
// (even NXDOMAIN means the resolver forwarded the query to our server).
|
||||
return err == nil && resp != nil
|
||||
}
|
||||
|
||||
func expandCIDR(ipNet *net.IPNet) []string {
|
||||
var ips []string
|
||||
ip := ipNet.IP.Mask(ipNet.Mask)
|
||||
|
||||
for ip := cloneIP(ip); ipNet.Contains(ip); incIP(ip) {
|
||||
// Skip network and broadcast addresses for /24 and smaller
|
||||
ones, bits := ipNet.Mask.Size()
|
||||
if bits-ones <= 8 {
|
||||
last := ip[len(ip)-1]
|
||||
if last == 0 || last == 255 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
ips = append(ips, ip.String())
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
func cloneIP(ip net.IP) net.IP {
|
||||
dup := make(net.IP, len(ip))
|
||||
copy(dup, ip)
|
||||
return dup
|
||||
}
|
||||
|
||||
func incIP(ip net.IP) {
|
||||
for j := len(ip) - 1; j >= 0; j-- {
|
||||
ip[j]++
|
||||
if ip[j] > 0 {
|
||||
break
|
||||
}
|
||||
func (rc *ResolverChecker) log(format string, args ...any) {
|
||||
if rc.logFunc != nil {
|
||||
rc.logFunc(fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,3 +67,42 @@ func Decrypt(key [KeySize]byte, ciphertext []byte) ([]byte, error) {
|
||||
nonce := ciphertext[:gcm.NonceSize()]
|
||||
return gcm.Open(nil, nonce, ciphertext[gcm.NonceSize():], nil)
|
||||
}
|
||||
|
||||
// encryptQueryBlock encrypts an 8-byte query payload using a direct AES-256 block cipher.
|
||||
// The payload is expanded to one AES block (16 bytes) with 8 trailing zero bytes before
|
||||
// encryption. No nonce or auth tag needed: the 4 random bytes in the payload guarantee
|
||||
// unique ciphertext per query. Result is always 16 bytes.
|
||||
func encryptQueryBlock(key [KeySize]byte, payload []byte) ([]byte, error) {
|
||||
if len(payload) != QueryPayloadSize {
|
||||
return nil, fmt.Errorf("encryptQueryBlock: payload must be %d bytes, got %d", QueryPayloadSize, len(payload))
|
||||
}
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var buf [aes.BlockSize]byte
|
||||
copy(buf[:QueryPayloadSize], payload) // bytes 8-15 stay zero
|
||||
block.Encrypt(buf[:], buf[:])
|
||||
return buf[:], nil
|
||||
}
|
||||
|
||||
// decryptQueryBlock decrypts a query ciphertext produced by encryptQueryBlock.
|
||||
// Accepts ciphertext with optional random suffix bytes (≥ BlockSize); only the
|
||||
// first BlockSize bytes are used. Verifies the last 8 bytes of plaintext are zero.
|
||||
func decryptQueryBlock(key [KeySize]byte, ciphertext []byte) ([]byte, error) {
|
||||
if len(ciphertext) < aes.BlockSize {
|
||||
return nil, fmt.Errorf("decryptQueryBlock: need at least %d bytes, got %d", aes.BlockSize, len(ciphertext))
|
||||
}
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var buf [aes.BlockSize]byte
|
||||
block.Decrypt(buf[:], ciphertext[:aes.BlockSize]) // ignore suffix
|
||||
for i := QueryPayloadSize; i < aes.BlockSize; i++ {
|
||||
if buf[i] != 0 {
|
||||
return nil, fmt.Errorf("decryptQueryBlock: integrity check failed (wrong key?)")
|
||||
}
|
||||
}
|
||||
return buf[:QueryPayloadSize], nil
|
||||
}
|
||||
|
||||
+137
-29
@@ -7,26 +7,48 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
maxDNSLabelLen = 63
|
||||
maxDNSNameLen = 253 // without trailing dot
|
||||
)
|
||||
|
||||
// QueryEncoding controls how DNS query subdomains are encoded.
|
||||
type QueryEncoding int
|
||||
|
||||
const (
|
||||
// QuerySingleLabel uses base32 in a single DNS label (default, stealthier).
|
||||
QuerySingleLabel QueryEncoding = iota
|
||||
// QueryDoubleLabel uses hex split across two DNS labels.
|
||||
QueryDoubleLabel
|
||||
// QueryMultiLabel uses hex split across multiple DNS labels.
|
||||
QueryMultiLabel
|
||||
// QueryPlainLabel encodes channel and block as plain decimal text (no query encryption).
|
||||
// Responses are always encrypted regardless of this setting.
|
||||
QueryPlainLabel
|
||||
)
|
||||
|
||||
var b32 = base32.StdEncoding.WithPadding(base32.NoPadding)
|
||||
|
||||
// EncodeQuery creates an encrypted DNS query subdomain.
|
||||
// EncodeQuery creates a DNS query subdomain for the given channel and block.
|
||||
// Single-label (default): [base32_encrypted].domain
|
||||
// Double-label: [hex_part1].[hex_part2].domain
|
||||
// Payload: 4 random + 2 channel + 2 block = 8 bytes, encrypted with AES-GCM.
|
||||
// Multi-label: [hex_part1].[hex_part2].domain
|
||||
// Plain-label: c<channel>b<block>.domain (no query encryption)
|
||||
// Responses are always encrypted regardless of mode.
|
||||
func EncodeQuery(queryKey [KeySize]byte, channel, block uint16, domain string, mode QueryEncoding) (string, error) {
|
||||
domain = strings.TrimSuffix(domain, ".")
|
||||
if domain == "" {
|
||||
return "", fmt.Errorf("empty domain")
|
||||
}
|
||||
|
||||
// Plain text mode: no encryption, just human-readable label.
|
||||
if mode == QueryPlainLabel {
|
||||
label := fmt.Sprintf("c%db%d", channel, block)
|
||||
return joinQName([]string{label}, domain)
|
||||
}
|
||||
|
||||
payload := make([]byte, QueryPayloadSize)
|
||||
|
||||
if _, err := rand.Read(payload[:QueryPaddingSize]); err != nil {
|
||||
@@ -36,24 +58,88 @@ func EncodeQuery(queryKey [KeySize]byte, channel, block uint16, domain string, m
|
||||
binary.BigEndian.PutUint16(payload[QueryPaddingSize:], channel)
|
||||
binary.BigEndian.PutUint16(payload[QueryPaddingSize+QueryChannelSize:], block)
|
||||
|
||||
encrypted, err := Encrypt(queryKey, payload)
|
||||
encrypted, err := encryptQueryBlock(queryKey, payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("encrypt query: %w", err)
|
||||
}
|
||||
|
||||
// Append 0–4 random suffix bytes so query length varies per request.
|
||||
// The decoder strips these by only using the first aes.BlockSize bytes.
|
||||
suffixLen, _ := rand.Int(rand.Reader, big.NewInt(5)) // [0,4]
|
||||
suffix := make([]byte, int(suffixLen.Int64()))
|
||||
rand.Read(suffix) //nolint:errcheck — non-critical randomness
|
||||
ciphertext := append(encrypted, suffix...)
|
||||
|
||||
switch mode {
|
||||
case QueryDoubleLabel:
|
||||
h := hex.EncodeToString(encrypted)
|
||||
mid := len(h) / 2
|
||||
return fmt.Sprintf("%s.%s.%s", h[:mid], h[mid:], domain), nil
|
||||
case QueryMultiLabel:
|
||||
h := hex.EncodeToString(ciphertext)
|
||||
labels := splitMultiLabel(h)
|
||||
return joinQName(labels, domain)
|
||||
default:
|
||||
encoded := strings.ToLower(b32.EncodeToString(encrypted))
|
||||
return fmt.Sprintf("%s.%s", encoded, domain), nil
|
||||
encoded := strings.ToLower(b32.EncodeToString(ciphertext))
|
||||
return joinQName([]string{encoded}, domain)
|
||||
}
|
||||
}
|
||||
|
||||
func splitLabel(s string, size int) []string {
|
||||
if size <= 0 {
|
||||
size = maxDNSLabelLen
|
||||
}
|
||||
if size > maxDNSLabelLen {
|
||||
size = maxDNSLabelLen
|
||||
}
|
||||
|
||||
parts := make([]string, 0, (len(s)+size-1)/size)
|
||||
for len(s) > size {
|
||||
parts = append(parts, s[:size])
|
||||
s = s[size:]
|
||||
}
|
||||
if len(s) > 0 {
|
||||
parts = append(parts, s)
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// splitMultiLabel splits a hex string into two labels of randomised, unequal length.
|
||||
// The first label is between 12 and (len-4) chars so the second is at least 4 chars.
|
||||
// This makes query labels look less uniform across requests.
|
||||
func splitMultiLabel(h string) []string {
|
||||
if len(h) <= 8 {
|
||||
return []string{h}
|
||||
}
|
||||
// first label: random length in [minFirst, len-4]
|
||||
minFirst := 8
|
||||
maxFirst := len(h) - 4
|
||||
if maxFirst <= minFirst {
|
||||
maxFirst = minFirst + 1
|
||||
}
|
||||
// crypto/rand for the split point; fall back to midpoint on error
|
||||
split := (len(h) + 1) / 2 // default: slightly off-centre
|
||||
if n, err := rand.Int(rand.Reader, big.NewInt(int64(maxFirst-minFirst+1))); err == nil {
|
||||
split = minFirst + int(n.Int64())
|
||||
}
|
||||
return []string{h[:split], h[split:]}
|
||||
}
|
||||
|
||||
func joinQName(labels []string, domain string) (string, error) {
|
||||
for _, l := range labels {
|
||||
if len(l) == 0 {
|
||||
return "", fmt.Errorf("empty label")
|
||||
}
|
||||
if len(l) > maxDNSLabelLen {
|
||||
return "", fmt.Errorf("label too long: %d", len(l))
|
||||
}
|
||||
}
|
||||
|
||||
qname := strings.Join(append(labels, domain), ".")
|
||||
if len(qname) > maxDNSNameLen {
|
||||
return "", fmt.Errorf("query name too long: %d", len(qname))
|
||||
}
|
||||
return qname, nil
|
||||
}
|
||||
|
||||
// DecodeQuery parses and decrypts a DNS query subdomain.
|
||||
// Auto-detects single-label (base32) or double-label (hex) encoding.
|
||||
// Auto-detects plain-text (c<N>b<M>), single-label base32, or multi-label hex encoding.
|
||||
func DecodeQuery(queryKey [KeySize]byte, qname, domain string) (channel, block uint16, err error) {
|
||||
qname = strings.TrimSuffix(qname, ".")
|
||||
domain = strings.TrimSuffix(domain, ".")
|
||||
@@ -65,32 +151,54 @@ func DecodeQuery(queryKey [KeySize]byte, qname, domain string) (channel, block u
|
||||
|
||||
encoded := qname[:len(qname)-len(suffix)]
|
||||
|
||||
// Try base32 first (single label, no dots, or dots stripped)
|
||||
b32str := strings.ReplaceAll(encoded, ".", "")
|
||||
ciphertext, err := b32.DecodeString(strings.ToUpper(b32str))
|
||||
if err == nil {
|
||||
return decryptQuery(queryKey, ciphertext)
|
||||
// Try plain-label first: c<channel>b<block> (short, no dots, all decimal)
|
||||
if ch, blk, ok := parsePlainLabel(encoded); ok {
|
||||
return ch, blk, nil
|
||||
}
|
||||
|
||||
// Fall back to hex (double-label)
|
||||
// Try base32 (single-label: no dots or dots stripped)
|
||||
b32str := strings.ReplaceAll(encoded, ".", "")
|
||||
if ct, e := b32.DecodeString(strings.ToUpper(b32str)); e == nil {
|
||||
return parseQueryCiphertext(queryKey, ct)
|
||||
}
|
||||
|
||||
// Fall back to hex (multi-label: dots stripped)
|
||||
hexStr := strings.ReplaceAll(encoded, ".", "")
|
||||
ciphertext, err = hex.DecodeString(hexStr)
|
||||
if err != nil {
|
||||
ct, e := hex.DecodeString(hexStr)
|
||||
if e != nil {
|
||||
return 0, 0, fmt.Errorf("decode query: invalid encoding")
|
||||
}
|
||||
return decryptQuery(queryKey, ciphertext)
|
||||
return parseQueryCiphertext(queryKey, ct)
|
||||
}
|
||||
|
||||
func decryptQuery(queryKey [KeySize]byte, ciphertext []byte) (channel, block uint16, err error) {
|
||||
plaintext, err := Decrypt(queryKey, ciphertext)
|
||||
// parsePlainLabel parses the plain-text query format "c<channel>b<block>".
|
||||
// Returns ok=false if the string does not match this pattern.
|
||||
func parsePlainLabel(s string) (channel, block uint16, ok bool) {
|
||||
if len(s) < 3 || s[0] != 'c' {
|
||||
return 0, 0, false
|
||||
}
|
||||
bi := strings.IndexByte(s[1:], 'b')
|
||||
if bi < 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
bi++ // adjust for the slice offset
|
||||
chStr, bStr := s[1:bi], s[bi+1:]
|
||||
if len(chStr) == 0 || len(bStr) == 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
ch, err1 := strconv.ParseUint(chStr, 10, 16)
|
||||
blk, err2 := strconv.ParseUint(bStr, 10, 16)
|
||||
if err1 != nil || err2 != nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
return uint16(ch), uint16(blk), true
|
||||
}
|
||||
|
||||
func parseQueryCiphertext(queryKey [KeySize]byte, ciphertext []byte) (channel, block uint16, err error) {
|
||||
plaintext, err := decryptQueryBlock(queryKey, ciphertext)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("decrypt: %w", err)
|
||||
}
|
||||
|
||||
if len(plaintext) != QueryPayloadSize {
|
||||
return 0, 0, fmt.Errorf("invalid payload size: %d", len(plaintext))
|
||||
}
|
||||
|
||||
channel = binary.BigEndian.Uint16(plaintext[QueryPaddingSize:])
|
||||
block = binary.BigEndian.Uint16(plaintext[QueryPaddingSize+QueryChannelSize:])
|
||||
return channel, block, nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
@@ -43,21 +44,26 @@ func TestEncodeDecodeQuerySingleLabel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDecodeQueryDoubleLabel(t *testing.T) {
|
||||
func TestEncodeDecodeQueryMultiLabel(t *testing.T) {
|
||||
qk, _, err := DeriveKeys("test-key")
|
||||
if err != nil {
|
||||
t.Fatalf("DeriveKeys: %v", err)
|
||||
}
|
||||
domain := "t.example.com"
|
||||
qname, err := EncodeQuery(qk, 3, 7, domain, QueryDoubleLabel)
|
||||
qname, err := EncodeQuery(qk, 3, 7, domain, QueryMultiLabel)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Double label: two hex labels before domain
|
||||
// Multi-label mode splits hex across labels; all must be DNS-safe.
|
||||
subdomain := qname[:len(qname)-len(domain)-1]
|
||||
parts := strings.Split(subdomain, ".")
|
||||
if len(parts) != 2 {
|
||||
t.Errorf("double-label query should have 2 parts, got %d: %q", len(parts), subdomain)
|
||||
if len(parts) < 1 {
|
||||
t.Errorf("multi-label query should have at least 1 part, got %d: %q", len(parts), subdomain)
|
||||
}
|
||||
for _, p := range parts {
|
||||
if len(p) == 0 || len(p) > 63 {
|
||||
t.Errorf("invalid label length %d in %q", len(p), p)
|
||||
}
|
||||
}
|
||||
ch, blk, err := DecodeQuery(qk, qname, domain)
|
||||
if err != nil {
|
||||
@@ -68,6 +74,73 @@ func TestEncodeDecodeQueryDoubleLabel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeQueryTooLongDomain(t *testing.T) {
|
||||
qk, _, err := DeriveKeys("test-key")
|
||||
if err != nil {
|
||||
t.Fatalf("DeriveKeys: %v", err)
|
||||
}
|
||||
|
||||
// 250-char domain should make qname exceed DNS 253-char limit.
|
||||
longDomain := strings.Repeat("a", 250)
|
||||
_, err = EncodeQuery(qk, 1, 1, longDomain, QueryMultiLabel)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for too-long domain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDecodeQueryPlainLabel(t *testing.T) {
|
||||
qk, _, err := DeriveKeys("test-key")
|
||||
if err != nil {
|
||||
t.Fatalf("DeriveKeys: %v", err)
|
||||
}
|
||||
domain := "t.example.com"
|
||||
tests := []struct {
|
||||
channel uint16
|
||||
block uint16
|
||||
}{
|
||||
{0, 0},
|
||||
{1, 42},
|
||||
{255, 65535},
|
||||
{3, 100},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
qname, err := EncodeQuery(qk, tt.channel, tt.block, domain, QueryPlainLabel)
|
||||
if err != nil {
|
||||
t.Fatalf("EncodeQuery(%d, %d): %v", tt.channel, tt.block, err)
|
||||
}
|
||||
// Label should be "c<channel>b<block>" — human readable, no padding hex.
|
||||
want := fmt.Sprintf("c%db%d.%s", tt.channel, tt.block, domain)
|
||||
if qname != want {
|
||||
t.Errorf("got %q, want %q", qname, want)
|
||||
}
|
||||
// DecodeQuery must recover channel and block regardless of key.
|
||||
ch, blk, err := DecodeQuery(qk, qname, domain)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeQuery: %v", err)
|
||||
}
|
||||
if ch != tt.channel || blk != tt.block {
|
||||
t.Errorf("got ch=%d blk=%d, want ch=%d blk=%d", ch, blk, tt.channel, tt.block)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlainLabelNotConfusedWithEncrypted(t *testing.T) {
|
||||
qk, _, err := DeriveKeys("test-key")
|
||||
if err != nil {
|
||||
t.Fatalf("DeriveKeys: %v", err)
|
||||
}
|
||||
domain := "t.example.com"
|
||||
// Encode with single-label then check that DecodeQuery does NOT treat it as plain.
|
||||
qname, _ := EncodeQuery(qk, 5, 10, domain, QuerySingleLabel)
|
||||
ch, blk, err := DecodeQuery(qk, qname, domain)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeQuery single-label: %v", err)
|
||||
}
|
||||
if ch != 5 || blk != 10 {
|
||||
t.Errorf("got ch=%d blk=%d, want ch=5 blk=10", ch, blk)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeQueryWrongKey(t *testing.T) {
|
||||
qk1, _, _ := DeriveKeys("key1")
|
||||
qk2, _, _ := DeriveKeys("key2")
|
||||
|
||||
@@ -1,14 +1,21 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultBlockPayload is the decrypted payload per DNS TXT block.
|
||||
// Calculated to stay within 512-byte UDP DNS limit after encryption + base64 + padding overhead.
|
||||
DefaultBlockPayload = 180
|
||||
// MinBlockPayload is the minimum decrypted payload per DNS TXT block.
|
||||
MinBlockPayload = 200
|
||||
// MaxBlockPayload is the maximum decrypted payload per DNS TXT block.
|
||||
// 600 bytes data + 28 GCM overhead + 2 prefix + 32 padding → ~856 base64 chars.
|
||||
// Well within the 4096-byte EDNS0 UDP buffer the client advertises.
|
||||
MaxBlockPayload = 600
|
||||
// DefaultBlockPayload is kept for compatibility; equals MaxBlockPayload.
|
||||
DefaultBlockPayload = MaxBlockPayload
|
||||
|
||||
// DefaultMaxPadding is the default random padding added to responses to vary DNS response size.
|
||||
DefaultMaxPadding = 32
|
||||
@@ -210,20 +217,31 @@ func ParseMessages(data []byte) ([]Message, error) {
|
||||
return msgs, nil
|
||||
}
|
||||
|
||||
// SplitIntoBlocks splits data into blocks of DefaultBlockPayload size.
|
||||
// SplitIntoBlocks splits data into blocks of randomly varying size in [MinBlockPayload, MaxBlockPayload].
|
||||
// Random sizes make traffic analysis harder; the client just concatenates all blocks to reassemble.
|
||||
func SplitIntoBlocks(data []byte) [][]byte {
|
||||
if len(data) == 0 {
|
||||
return [][]byte{{}}
|
||||
return [][]byte{{}} // channel 0 block 0 must always exist
|
||||
}
|
||||
var blocks [][]byte
|
||||
for i := 0; i < len(data); i += DefaultBlockPayload {
|
||||
end := i + DefaultBlockPayload
|
||||
if end > len(data) {
|
||||
end = len(data)
|
||||
rem := data
|
||||
for len(rem) > 0 {
|
||||
size := randBlockSize()
|
||||
if size > len(rem) {
|
||||
size = len(rem)
|
||||
}
|
||||
block := make([]byte, end-i)
|
||||
copy(block, data[i:end])
|
||||
block := make([]byte, size)
|
||||
copy(block, rem[:size])
|
||||
blocks = append(blocks, block)
|
||||
rem = rem[size:]
|
||||
}
|
||||
return blocks
|
||||
}
|
||||
|
||||
func randBlockSize() int {
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(int64(MaxBlockPayload-MinBlockPayload+1)))
|
||||
if err != nil {
|
||||
return (MinBlockPayload + MaxBlockPayload) / 2
|
||||
}
|
||||
return MinBlockPayload + int(n.Int64())
|
||||
}
|
||||
|
||||
@@ -57,19 +57,19 @@ func TestSerializeParseMessages(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSplitIntoBlocks(t *testing.T) {
|
||||
data := bytes.Repeat([]byte("A"), DefaultBlockPayload*3+50)
|
||||
// MaxBlockPayload*3+50 guarantees at least 4 blocks (ceil((MaxBlockPayload*3+50)/MaxBlockPayload) = 4).
|
||||
data := bytes.Repeat([]byte("A"), MaxBlockPayload*3+50)
|
||||
blocks := SplitIntoBlocks(data)
|
||||
if len(blocks) != 4 {
|
||||
t.Fatalf("blocks: got %d, want 4", len(blocks))
|
||||
if len(blocks) < 4 {
|
||||
t.Fatalf("expected at least 4 blocks for %d bytes, got %d", len(data), len(blocks))
|
||||
}
|
||||
for i, b := range blocks {
|
||||
if i < 3 && len(b) != DefaultBlockPayload {
|
||||
t.Errorf("block %d: size %d, want %d", i, len(b), DefaultBlockPayload)
|
||||
// Every non-last block must be within [MinBlockPayload, MaxBlockPayload].
|
||||
for i, b := range blocks[:len(blocks)-1] {
|
||||
if len(b) < MinBlockPayload || len(b) > MaxBlockPayload {
|
||||
t.Errorf("block %d: size %d, want [%d, %d]", i, len(b), MinBlockPayload, MaxBlockPayload)
|
||||
}
|
||||
}
|
||||
if len(blocks[3]) != 50 {
|
||||
t.Errorf("last block: size %d, want 50", len(blocks[3]))
|
||||
}
|
||||
// Reassembled data must equal original.
|
||||
var reassembled []byte
|
||||
for _, b := range blocks {
|
||||
reassembled = append(reassembled, b...)
|
||||
|
||||
@@ -71,7 +71,9 @@ func TestFeedGetBlockUnknownChannel(t *testing.T) {
|
||||
|
||||
func TestFeedLargeMessages(t *testing.T) {
|
||||
feed := NewFeed([]string{"Test"})
|
||||
largeText := make([]byte, 500)
|
||||
// Use text large enough to span 2 blocks at DefaultBlockPayload (currently 700 bytes).
|
||||
// Message serialization overhead is 10 bytes, so we need >690 bytes of text.
|
||||
largeText := make([]byte, 750)
|
||||
for i := range largeText {
|
||||
largeText[i] = 65
|
||||
}
|
||||
|
||||
@@ -80,7 +80,7 @@ func NewTelegramReader(cfg TelegramConfig, channelUsernames []string, feed *Feed
|
||||
channels: cleaned,
|
||||
feed: feed,
|
||||
cache: make(map[string]cachedMessages),
|
||||
cacheTTL: 1 * time.Minute,
|
||||
cacheTTL: 5 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,7 +115,7 @@ func (tr *TelegramReader) Run(ctx context.Context) error {
|
||||
tr.fetchAll(ctx, api)
|
||||
|
||||
// Periodic fetch loop
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
ticker := time.NewTicker(3 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
|
||||
+165
-17
@@ -142,6 +142,44 @@
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
.progress-panel {
|
||||
height: 56px;
|
||||
min-height: 56px;
|
||||
background: var(--surface);
|
||||
border-top: 1px solid var(--border);
|
||||
border-bottom: 1px solid var(--border);
|
||||
overflow-y: auto;
|
||||
font-size: 12px;
|
||||
font-family: 'Vazirmatn', monospace;
|
||||
padding: 12px 12px;
|
||||
direction: ltr;
|
||||
text-align: left;
|
||||
color: var(--text-dim);
|
||||
}
|
||||
.progress-item {
|
||||
margin-bottom: 8px;
|
||||
padding: 6px;
|
||||
border-radius: 4px;
|
||||
background: var(--bg);
|
||||
}
|
||||
.progress-label {
|
||||
font-size: 11px;
|
||||
margin-bottom: 3px;
|
||||
color: var(--text-dim);
|
||||
}
|
||||
.progress-bar {
|
||||
width: 100%;
|
||||
height: 6px;
|
||||
background: var(--border);
|
||||
border-radius: 3px;
|
||||
overflow: hidden;
|
||||
}
|
||||
.progress-fill {
|
||||
height: 100%;
|
||||
background: var(--accent);
|
||||
transition: width 0.2s;
|
||||
}
|
||||
|
||||
.log-panel {
|
||||
height: 150px;
|
||||
min-height: 100px;
|
||||
@@ -149,13 +187,21 @@
|
||||
border-top: 1px solid var(--border);
|
||||
overflow-y: auto;
|
||||
font-size: 12px;
|
||||
font-family: 'Vazirmatn', monospace;
|
||||
font-family: monospace;
|
||||
padding: 8px 12px;
|
||||
direction: ltr;
|
||||
text-align: left;
|
||||
color: var(--text-dim);
|
||||
}
|
||||
.log-line { padding: 2px 0; }
|
||||
.log-line {
|
||||
padding: 2px 0;
|
||||
line-height: 1.3;
|
||||
}
|
||||
.log-line.info { color: #60a5fa; }
|
||||
.log-line.progress { color: #fbbf24; }
|
||||
.log-line.error { color: #ef4444; }
|
||||
.log-line.success { color: #22c55e; }
|
||||
.log-line.warning { color: #f97316; }
|
||||
|
||||
.modal-overlay {
|
||||
display: none;
|
||||
@@ -245,6 +291,7 @@
|
||||
<button class="btn btn-primary" onclick="openSettings()">Configure</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="progress-panel" id="progressPanel"></div>
|
||||
<div class="log-panel" id="logPanel"></div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -268,12 +315,17 @@
|
||||
<label>Query Mode</label>
|
||||
<select id="cfgQueryMode">
|
||||
<option value="single">Single label (base32, stealthier)</option>
|
||||
<option value="double">Double label (hex)</option>
|
||||
<option value="double">Multi-label (hex)</option>
|
||||
<option value="plain">Plain text (no query encryption)</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label>Rate Limit (queries/sec, 0 = unlimited)</label>
|
||||
<input type="number" id="cfgRateLimit" value="0" min="0" step="0.1">
|
||||
<input type="number" id="cfgRateLimit" value="100" min="0" step="0.1">
|
||||
</div>
|
||||
<div class="form-group" style="flex-direction:row;align-items:center;gap:10px;">
|
||||
<input type="checkbox" id="cfgDebug" style="width:auto;margin:0;">
|
||||
<label for="cfgDebug" style="margin:0;cursor:pointer;">Debug mode (log generated query names)</label>
|
||||
</div>
|
||||
<div style="display:flex;gap:8px;justify-content:flex-end;margin-top:20px;">
|
||||
<button class="btn" onclick="closeSettings()">Cancel</button>
|
||||
@@ -287,6 +339,7 @@
|
||||
let selectedChannel = 0;
|
||||
let channels = [];
|
||||
let eventSource = null;
|
||||
let autoRefreshTimer = null;
|
||||
|
||||
async function init() {
|
||||
try {
|
||||
@@ -296,11 +349,16 @@
|
||||
openSettings();
|
||||
} else {
|
||||
document.getElementById('statusDot').className = 'status-dot connected';
|
||||
if (status.channels && status.channels.length > 0) {
|
||||
channels = status.channels;
|
||||
renderChannels();
|
||||
selectChannel(1);
|
||||
await doRefresh();
|
||||
await loadChannels();
|
||||
if (channels && channels.length > 0) {
|
||||
await selectChannel(1);
|
||||
}
|
||||
autoRefreshTimer = setInterval(function() {
|
||||
if (selectedChannel > 0) {
|
||||
doRefresh();
|
||||
}
|
||||
}, 120000);
|
||||
}
|
||||
} catch (e) {}
|
||||
connectSSE();
|
||||
@@ -312,12 +370,23 @@
|
||||
eventSource.addEventListener('log', function(e) {
|
||||
addLogLine(JSON.parse(e.data));
|
||||
});
|
||||
eventSource.addEventListener('update', function(e) {
|
||||
loadChannels();
|
||||
if (selectedChannel > 0) loadMessages(selectedChannel);
|
||||
eventSource.addEventListener('update', async function(e) {
|
||||
var wasEmpty = channels.length === 0;
|
||||
await loadChannels();
|
||||
if (wasEmpty && channels.length > 0 && selectedChannel === 0) {
|
||||
// Channels just appeared for the first time — auto-select the first one.
|
||||
selectChannel(1);
|
||||
} else if (selectedChannel > 0) {
|
||||
loadMessages(selectedChannel);
|
||||
}
|
||||
});
|
||||
eventSource.onerror = function() {
|
||||
document.getElementById('statusDot').className = 'status-dot disconnected';
|
||||
// EventSource.CLOSED (2) means the browser stopped retrying — reconnect manually.
|
||||
if (eventSource.readyState === EventSource.CLOSED) {
|
||||
eventSource.close();
|
||||
setTimeout(connectSSE, 3000);
|
||||
}
|
||||
};
|
||||
eventSource.onopen = function() {
|
||||
document.getElementById('statusDot').className = 'status-dot connected';
|
||||
@@ -343,6 +412,7 @@
|
||||
async function selectChannel(num) {
|
||||
selectedChannel = num;
|
||||
renderChannels();
|
||||
await doRefresh();
|
||||
await loadMessages(num);
|
||||
}
|
||||
|
||||
@@ -392,16 +462,79 @@
|
||||
var el = document.getElementById('logPanel');
|
||||
var div = document.createElement('div');
|
||||
div.className = 'log-line';
|
||||
|
||||
// Parse log level from line content
|
||||
var level = 'info';
|
||||
var displayText = line;
|
||||
|
||||
if (typeof line === 'string') {
|
||||
if (line.includes('Error:') || line.includes('error') || line.includes('Error')) {
|
||||
level = 'error';
|
||||
} else if (line.includes('Warning:') || line.includes('warning') || line.includes('Warning')) {
|
||||
level = 'warning';
|
||||
} else if (line.includes('OK:') || line.includes('OK') || line.includes('success') || line.includes('done')) {
|
||||
level = 'success';
|
||||
} else if (line.match(/\[\d+%\]|\d+\s*%/)) {
|
||||
level = 'progress';
|
||||
updateProgressDisplay(line);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
div.className = 'log-line ' + level;
|
||||
div.textContent = line;
|
||||
el.appendChild(div);
|
||||
el.scrollTop = el.scrollHeight;
|
||||
|
||||
// Keep only last 200 lines
|
||||
while (el.children.length > 200) {
|
||||
el.removeChild(el.firstChild);
|
||||
}
|
||||
}
|
||||
|
||||
function updateProgressDisplay(line) {
|
||||
// Parse progress from "Channel N [====> ] 45%"
|
||||
var match = line.match(/Channel\s+(\d+)/);
|
||||
if (!match) return;
|
||||
|
||||
var channelNum = match[1];
|
||||
var percentMatch = line.match(/(\d+)%/);
|
||||
var percent = percentMatch ? parseInt(percentMatch[1]) : 0;
|
||||
|
||||
var panel = document.getElementById('progressPanel');
|
||||
var itemId = 'progress-current';
|
||||
var item = document.getElementById(itemId);
|
||||
|
||||
if (!item) {
|
||||
item = document.createElement('div');
|
||||
item.id = itemId;
|
||||
item.className = 'progress-item';
|
||||
item.innerHTML = '<div class="progress-label">Channel ' + channelNum + '</div>' +
|
||||
'<div class="progress-bar"><div class="progress-fill" style="width:0%"></div></div>';
|
||||
panel.appendChild(item);
|
||||
} else {
|
||||
item.querySelector('.progress-label').textContent = 'Channel ' + channelNum;
|
||||
}
|
||||
|
||||
var fill = item.querySelector('.progress-fill');
|
||||
fill.style.width = percent + '%';
|
||||
|
||||
// Remove if complete
|
||||
if (percent >= 100) {
|
||||
setTimeout(function() {
|
||||
if (item.parentNode) item.parentNode.removeChild(item);
|
||||
}, 1000);
|
||||
}
|
||||
}
|
||||
|
||||
async function doRefresh() {
|
||||
try { await fetch('/api/refresh', {method: 'POST'}); } catch (e) {}
|
||||
try {
|
||||
var url = '/api/refresh';
|
||||
if (selectedChannel > 0) {
|
||||
url += '?channel=' + selectedChannel;
|
||||
}
|
||||
await fetch(url, {method: 'POST'});
|
||||
} catch (e) {}
|
||||
}
|
||||
|
||||
function openSettings() {
|
||||
@@ -411,7 +544,12 @@
|
||||
if (cfg.key) document.getElementById('cfgKey').value = cfg.key;
|
||||
if (cfg.resolvers) document.getElementById('cfgResolvers').value = cfg.resolvers.join('\n');
|
||||
if (cfg.queryMode) document.getElementById('cfgQueryMode').value = cfg.queryMode;
|
||||
if (cfg.rateLimit) document.getElementById('cfgRateLimit').value = cfg.rateLimit;
|
||||
if (typeof cfg.rateLimit === 'number') {
|
||||
document.getElementById('cfgRateLimit').value = cfg.rateLimit;
|
||||
} else {
|
||||
document.getElementById('cfgRateLimit').value = 100;
|
||||
}
|
||||
document.getElementById('cfgDebug').checked = !!cfg.debug;
|
||||
}).catch(function() {});
|
||||
}
|
||||
|
||||
@@ -432,7 +570,8 @@
|
||||
key: document.getElementById('cfgKey').value,
|
||||
resolvers: resolvers,
|
||||
queryMode: document.getElementById('cfgQueryMode').value,
|
||||
rateLimit: parseFloat(document.getElementById('cfgRateLimit').value) || 0
|
||||
rateLimit: parseFloat(document.getElementById('cfgRateLimit').value) || 100,
|
||||
debug: document.getElementById('cfgDebug').checked
|
||||
};
|
||||
|
||||
if (!cfg.domain || !cfg.key || resolvers.length === 0) {
|
||||
@@ -453,9 +592,18 @@
|
||||
}
|
||||
closeSettings();
|
||||
document.getElementById('statusDot').className = 'status-dot connected';
|
||||
setTimeout(function() {
|
||||
loadChannels();
|
||||
}, 2000);
|
||||
await doRefresh();
|
||||
await loadChannels();
|
||||
if (channels && channels.length > 0) {
|
||||
await selectChannel(1);
|
||||
}
|
||||
if (!autoRefreshTimer) {
|
||||
autoRefreshTimer = setInterval(function() {
|
||||
if (selectedChannel > 0) {
|
||||
fetch('/api/refresh?channel=' + selectedChannel + '&quiet=1', {method: 'POST'});
|
||||
}
|
||||
}, 120000);
|
||||
}
|
||||
} catch (e) {
|
||||
errEl.textContent = e.message;
|
||||
errEl.style.display = 'block';
|
||||
|
||||
+176
-50
@@ -1,6 +1,7 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -29,6 +30,11 @@ type Config struct {
|
||||
Resolvers []string `json:"resolvers"`
|
||||
QueryMode string `json:"queryMode"`
|
||||
RateLimit float64 `json:"rateLimit"`
|
||||
// Timeout is the per-query DNS timeout in seconds (0 = default 5 s).
|
||||
// Also used as the resolver health-check probe timeout.
|
||||
Timeout float64 `json:"timeout,omitempty"`
|
||||
// Debug enables verbose query logging (shows generated DNS query names).
|
||||
Debug bool `json:"debug,omitempty"`
|
||||
}
|
||||
|
||||
// Server is the web UI server for thefeed client.
|
||||
@@ -43,6 +49,16 @@ type Server struct {
|
||||
channels []protocol.ChannelInfo
|
||||
messages map[int][]protocol.Message
|
||||
|
||||
// fetcherCtx/fetcherCancel control the lifetime of the active fetcher's
|
||||
// background goroutines (rate limiter, noise, resolver checker).
|
||||
// They are cancelled and recreated each time the config changes.
|
||||
fetcherCtx context.Context
|
||||
fetcherCancel context.CancelFunc
|
||||
|
||||
// refreshMu / refreshCancel allow a new refresh to cancel an in-progress one.
|
||||
refreshMu sync.Mutex
|
||||
refreshCancel context.CancelFunc
|
||||
|
||||
logMu sync.RWMutex
|
||||
logLines []string
|
||||
|
||||
@@ -96,7 +112,7 @@ func (s *Server) Run() error {
|
||||
fmt.Printf("\n Open in browser: http://%s\n\n", addr)
|
||||
|
||||
if s.fetcher != nil {
|
||||
s.startAutoRefresh()
|
||||
go s.refreshMetadataOnly()
|
||||
}
|
||||
|
||||
return http.ListenAndServe(addr, mux)
|
||||
@@ -164,7 +180,7 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, fmt.Sprintf("init fetcher: %v", err), 500)
|
||||
return
|
||||
}
|
||||
s.startAutoRefresh()
|
||||
go s.refreshMetadataOnly()
|
||||
writeJSON(w, map[string]any{"ok": true})
|
||||
|
||||
default:
|
||||
@@ -202,7 +218,28 @@ func (s *Server) handleRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "method not allowed", 405)
|
||||
return
|
||||
}
|
||||
go s.refresh()
|
||||
// Background (quiet) refreshes skip silently if one is already running,
|
||||
// so the auto-refresh timer never cancels a slow in-progress fetch.
|
||||
if r.URL.Query().Get("quiet") == "1" {
|
||||
s.refreshMu.Lock()
|
||||
running := s.refreshCancel != nil
|
||||
s.refreshMu.Unlock()
|
||||
if running {
|
||||
writeJSON(w, map[string]any{"ok": true, "skipped": true})
|
||||
return
|
||||
}
|
||||
}
|
||||
chParam := r.URL.Query().Get("channel")
|
||||
if chParam != "" {
|
||||
chNum, err := strconv.Atoi(chParam)
|
||||
if err != nil || chNum < 1 {
|
||||
http.Error(w, "invalid channel", 400)
|
||||
return
|
||||
}
|
||||
go s.refreshChannel(chNum)
|
||||
} else {
|
||||
go s.refreshMetadataOnly()
|
||||
}
|
||||
writeJSON(w, map[string]any{"ok": true})
|
||||
}
|
||||
|
||||
@@ -278,6 +315,11 @@ func (s *Server) initFetcher() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Cancel goroutines from the previous fetcher configuration.
|
||||
if s.fetcherCancel != nil {
|
||||
s.fetcherCancel()
|
||||
}
|
||||
|
||||
cfg := s.config
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("no config")
|
||||
@@ -295,57 +337,81 @@ func (s *Server) initFetcher() error {
|
||||
}
|
||||
|
||||
if cfg.QueryMode == "double" {
|
||||
fetcher.SetQueryMode(protocol.QueryDoubleLabel)
|
||||
fetcher.SetQueryMode(protocol.QueryMultiLabel)
|
||||
} else if cfg.QueryMode == "plain" {
|
||||
fetcher.SetQueryMode(protocol.QueryPlainLabel)
|
||||
}
|
||||
fetcher.SetDebug(cfg.Debug)
|
||||
if cfg.RateLimit > 0 {
|
||||
fetcher.SetRateLimit(cfg.RateLimit)
|
||||
}
|
||||
|
||||
timeout := 5 * time.Second
|
||||
if cfg.Timeout > 0 {
|
||||
timeout = time.Duration(cfg.Timeout * float64(time.Second))
|
||||
}
|
||||
fetcher.SetTimeout(timeout)
|
||||
|
||||
fetcher.SetLogFunc(func(msg string) {
|
||||
s.addLog(msg)
|
||||
})
|
||||
|
||||
// Create a shared context for this fetcher's lifetime.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s.fetcherCtx = ctx
|
||||
s.fetcherCancel = cancel
|
||||
|
||||
// Start rate limiter and noise goroutines.
|
||||
fetcher.Start(ctx)
|
||||
|
||||
// Start periodic resolver health checks.
|
||||
checker := client.NewResolverChecker(fetcher, timeout)
|
||||
checker.SetLogFunc(func(msg string) {
|
||||
s.addLog(msg)
|
||||
})
|
||||
checker.Start(ctx)
|
||||
|
||||
s.fetcher = fetcher
|
||||
s.cache = cache
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) startAutoRefresh() {
|
||||
if s.stopRefresh != nil {
|
||||
close(s.stopRefresh)
|
||||
func (s *Server) refreshMetadataOnly() {
|
||||
// Cancel any in-progress refresh and start a new cancellable one.
|
||||
s.refreshMu.Lock()
|
||||
if s.refreshCancel != nil {
|
||||
s.refreshCancel()
|
||||
}
|
||||
s.stopRefresh = make(chan struct{})
|
||||
stop := s.stopRefresh
|
||||
|
||||
go s.refresh()
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.refresh()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Server) refresh() {
|
||||
s.mu.RLock()
|
||||
basectx := s.fetcherCtx
|
||||
fetcher := s.fetcher
|
||||
cache := s.cache
|
||||
s.mu.RUnlock()
|
||||
|
||||
if fetcher == nil {
|
||||
if fetcher == nil || basectx == nil {
|
||||
s.refreshMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Child context: cancelled either by the next refresh call or by a config change.
|
||||
ctx, cancel := context.WithCancel(basectx)
|
||||
s.refreshCancel = cancel
|
||||
s.refreshMu.Unlock()
|
||||
defer func() {
|
||||
cancel()
|
||||
s.refreshMu.Lock()
|
||||
s.refreshCancel = nil
|
||||
s.refreshMu.Unlock()
|
||||
}()
|
||||
|
||||
s.addLog("Fetching metadata...")
|
||||
meta, err := fetcher.FetchMetadata()
|
||||
meta, err := fetcher.FetchMetadata(ctx)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
s.addLog("Refresh cancelled")
|
||||
return
|
||||
}
|
||||
s.addLog(fmt.Sprintf("Error: %v", err))
|
||||
return
|
||||
}
|
||||
@@ -359,31 +425,91 @@ func (s *Server) refresh() {
|
||||
}
|
||||
|
||||
s.broadcast("event: update\ndata: \"channels\"\n\n")
|
||||
}
|
||||
|
||||
for i, ch := range meta.Channels {
|
||||
chNum := i + 1
|
||||
blockCount := int(ch.Blocks)
|
||||
if blockCount <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
msgs, err := fetcher.FetchChannel(chNum, blockCount)
|
||||
if err != nil {
|
||||
s.addLog(fmt.Sprintf("Channel %s error: %v", ch.Name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.messages[chNum] = msgs
|
||||
s.mu.Unlock()
|
||||
|
||||
if cache != nil {
|
||||
_ = cache.PutMessages(chNum, msgs)
|
||||
}
|
||||
|
||||
s.addLog(fmt.Sprintf("Updated %s: %d messages", ch.Name, len(msgs)))
|
||||
func (s *Server) refreshChannel(channelNum int) {
|
||||
s.refreshMu.Lock()
|
||||
if s.refreshCancel != nil {
|
||||
s.refreshCancel()
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
basectx := s.fetcherCtx
|
||||
fetcher := s.fetcher
|
||||
cache := s.cache
|
||||
channels := s.channels
|
||||
s.mu.RUnlock()
|
||||
|
||||
if fetcher == nil || basectx == nil {
|
||||
s.refreshMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(basectx)
|
||||
s.refreshCancel = cancel
|
||||
s.refreshMu.Unlock()
|
||||
defer func() {
|
||||
cancel()
|
||||
s.refreshMu.Lock()
|
||||
s.refreshCancel = nil
|
||||
s.refreshMu.Unlock()
|
||||
}()
|
||||
|
||||
meta, err := fetcher.FetchMetadata(ctx)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
s.addLog("Refresh cancelled")
|
||||
return
|
||||
}
|
||||
s.addLog(fmt.Sprintf("Error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.channels = meta.Channels
|
||||
s.mu.Unlock()
|
||||
|
||||
if cache != nil {
|
||||
_ = cache.PutMetadata(meta)
|
||||
}
|
||||
s.broadcast("event: update\ndata: \"channels\"\n\n")
|
||||
|
||||
channels = meta.Channels
|
||||
if channelNum < 1 || channelNum > len(channels) {
|
||||
s.addLog(fmt.Sprintf("Warning: channel %d is not available", channelNum))
|
||||
return
|
||||
}
|
||||
|
||||
ch := channels[channelNum-1]
|
||||
blockCount := int(ch.Blocks)
|
||||
if blockCount <= 0 {
|
||||
s.mu.Lock()
|
||||
s.messages[channelNum] = nil
|
||||
s.mu.Unlock()
|
||||
s.addLog(fmt.Sprintf("Updated %s: 0 messages", ch.Name))
|
||||
s.broadcast("event: update\ndata: \"messages\"\n\n")
|
||||
return
|
||||
}
|
||||
|
||||
msgs, err := fetcher.FetchChannel(ctx, channelNum, blockCount)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
s.addLog("Refresh cancelled")
|
||||
return
|
||||
}
|
||||
s.addLog(fmt.Sprintf("Channel %s error: %v", ch.Name, err))
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.messages[channelNum] = msgs
|
||||
s.mu.Unlock()
|
||||
|
||||
if cache != nil {
|
||||
_ = cache.PutMessages(channelNum, msgs)
|
||||
}
|
||||
|
||||
s.addLog(fmt.Sprintf("Updated %s: %d messages", ch.Name, len(msgs)))
|
||||
s.broadcast("event: update\ndata: \"messages\"\n\n")
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user