mirror of
https://github.com/sartoopjj/thefeed.git
synced 2026-05-19 05:04:35 +03:00
feat: 🎉 first version
This commit is contained in:
@@ -0,0 +1,122 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sartoopjj/thefeed/internal/protocol"
|
||||
)
|
||||
|
||||
// Cache provides file-based caching for channel data.
|
||||
type Cache struct {
|
||||
dir string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type cachedChannel struct {
|
||||
Messages []protocol.Message `json:"messages"`
|
||||
FetchedAt int64 `json:"fetched_at"`
|
||||
}
|
||||
|
||||
type cachedMeta struct {
|
||||
Metadata *protocol.Metadata `json:"metadata"`
|
||||
FetchedAt int64 `json:"fetched_at"`
|
||||
}
|
||||
|
||||
// NewCache creates a file cache in the given directory.
|
||||
func NewCache(dir string) (*Cache, error) {
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return nil, fmt.Errorf("create cache dir: %w", err)
|
||||
}
|
||||
return &Cache{dir: dir}, nil
|
||||
}
|
||||
|
||||
// GetMessages returns cached messages for a channel, or nil if expired.
|
||||
func (c *Cache) GetMessages(channelNum int, maxAge time.Duration) []protocol.Message {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
path := c.channelPath(channelNum)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var cached cachedChannel
|
||||
if err := json.Unmarshal(data, &cached); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if maxAge > 0 && time.Since(time.Unix(cached.FetchedAt, 0)) > maxAge {
|
||||
return nil
|
||||
}
|
||||
|
||||
return cached.Messages
|
||||
}
|
||||
|
||||
// PutMessages stores messages for a channel.
|
||||
func (c *Cache) PutMessages(channelNum int, msgs []protocol.Message) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
cached := cachedChannel{
|
||||
Messages: msgs,
|
||||
FetchedAt: time.Now().Unix(),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(cached)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(c.channelPath(channelNum), data, 0600)
|
||||
}
|
||||
|
||||
// GetMetadata returns cached metadata, or nil if expired.
|
||||
func (c *Cache) GetMetadata(maxAge time.Duration) *protocol.Metadata {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
path := filepath.Join(c.dir, "metadata.json")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var cached cachedMeta
|
||||
if err := json.Unmarshal(data, &cached); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if maxAge > 0 && time.Since(time.Unix(cached.FetchedAt, 0)) > maxAge {
|
||||
return nil
|
||||
}
|
||||
|
||||
return cached.Metadata
|
||||
}
|
||||
|
||||
// PutMetadata stores metadata.
|
||||
func (c *Cache) PutMetadata(meta *protocol.Metadata) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
cached := cachedMeta{
|
||||
Metadata: meta,
|
||||
FetchedAt: time.Now().Unix(),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(cached)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(filepath.Join(c.dir, "metadata.json"), data, 0600)
|
||||
}
|
||||
|
||||
func (c *Cache) channelPath(channelNum int) string {
|
||||
return filepath.Join(c.dir, fmt.Sprintf("channel_%d.json", channelNum))
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sartoopjj/thefeed/internal/protocol"
|
||||
)
|
||||
|
||||
func TestCacheMessages(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cache, err := NewCache(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCache: %v", err)
|
||||
}
|
||||
msgs := []protocol.Message{
|
||||
{ID: 1, Timestamp: 1700000000, Text: "Hello"},
|
||||
{ID: 2, Timestamp: 1700000060, Text: "World"},
|
||||
}
|
||||
if err := cache.PutMessages(1, msgs); err != nil {
|
||||
t.Fatalf("PutMessages: %v", err)
|
||||
}
|
||||
cached := cache.GetMessages(1, 1*time.Hour)
|
||||
if cached == nil {
|
||||
t.Fatal("expected cached messages")
|
||||
}
|
||||
if len(cached) != 2 {
|
||||
t.Fatalf("got %d messages, want 2", len(cached))
|
||||
}
|
||||
if cached[0].Text != "Hello" || cached[1].Text != "World" {
|
||||
t.Error("cached message text mismatch")
|
||||
}
|
||||
if cache.GetMessages(2, 1*time.Hour) != nil {
|
||||
t.Error("expected nil for uncached channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMetadata(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cache, err := NewCache(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
meta := &protocol.Metadata{
|
||||
Marker: [3]byte{1, 2, 3},
|
||||
Timestamp: 1700000000,
|
||||
Channels: []protocol.ChannelInfo{
|
||||
{Name: "test", Blocks: 5, LastMsgID: 100},
|
||||
},
|
||||
}
|
||||
if err := cache.PutMetadata(meta); err != nil {
|
||||
t.Fatalf("PutMetadata: %v", err)
|
||||
}
|
||||
cached := cache.GetMetadata(1 * time.Hour)
|
||||
if cached == nil {
|
||||
t.Fatal("expected cached metadata")
|
||||
}
|
||||
if cached.Timestamp != 1700000000 {
|
||||
t.Errorf("timestamp: got %d, want 1700000000", cached.Timestamp)
|
||||
}
|
||||
if len(cached.Channels) != 1 || cached.Channels[0].Name != "test" {
|
||||
t.Error("metadata channel mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheDirCreation(t *testing.T) {
|
||||
dir := t.TempDir() + "/sub/dir"
|
||||
_, err := NewCache(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCache should create dirs: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
t.Error("cache dir should be created")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,252 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/sartoopjj/thefeed/internal/protocol"
|
||||
)
|
||||
|
||||
// LogFunc is a callback for logging DNS queries (for debug/TUI).
|
||||
type LogFunc func(msg string)
|
||||
|
||||
// Fetcher fetches feed blocks over DNS.
|
||||
type Fetcher struct {
|
||||
domain string
|
||||
queryKey [protocol.KeySize]byte
|
||||
responseKey [protocol.KeySize]byte
|
||||
queryMode protocol.QueryEncoding
|
||||
|
||||
mu sync.RWMutex
|
||||
resolvers []string
|
||||
timeout time.Duration
|
||||
|
||||
// Rate limiting
|
||||
rateMu sync.Mutex
|
||||
queryDelay time.Duration
|
||||
lastQuery time.Time
|
||||
|
||||
// Debug logging
|
||||
logFunc LogFunc
|
||||
}
|
||||
|
||||
// NewFetcher creates a new DNS block fetcher.
|
||||
func NewFetcher(domain, passphrase string, resolvers []string) (*Fetcher, error) {
|
||||
qk, rk, err := protocol.DeriveKeys(passphrase)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("derive keys: %w", err)
|
||||
}
|
||||
|
||||
return &Fetcher{
|
||||
domain: strings.TrimSuffix(domain, "."),
|
||||
queryKey: qk,
|
||||
responseKey: rk,
|
||||
queryMode: protocol.QuerySingleLabel,
|
||||
resolvers: resolvers,
|
||||
timeout: 5 * time.Second,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetRateLimit sets the maximum queries per second (0 = unlimited).
|
||||
func (f *Fetcher) SetRateLimit(qps float64) {
|
||||
if qps <= 0 {
|
||||
f.queryDelay = 0
|
||||
return
|
||||
}
|
||||
f.queryDelay = time.Duration(float64(time.Second) / qps)
|
||||
}
|
||||
|
||||
// SetLogFunc sets the debug log callback.
|
||||
func (f *Fetcher) SetLogFunc(fn LogFunc) {
|
||||
f.logFunc = fn
|
||||
}
|
||||
|
||||
// SetQueryMode sets the DNS query encoding mode.
|
||||
func (f *Fetcher) SetQueryMode(mode protocol.QueryEncoding) {
|
||||
f.queryMode = mode
|
||||
}
|
||||
|
||||
func (f *Fetcher) log(format string, args ...any) {
|
||||
if f.logFunc != nil {
|
||||
f.logFunc(fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Fetcher) rateWait() {
|
||||
if f.queryDelay <= 0 {
|
||||
return
|
||||
}
|
||||
f.rateMu.Lock()
|
||||
defer f.rateMu.Unlock()
|
||||
elapsed := time.Since(f.lastQuery)
|
||||
if elapsed < f.queryDelay {
|
||||
time.Sleep(f.queryDelay - elapsed)
|
||||
}
|
||||
f.lastQuery = time.Now()
|
||||
}
|
||||
|
||||
// SetResolvers replaces the resolver list.
|
||||
func (f *Fetcher) SetResolvers(resolvers []string) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.resolvers = resolvers
|
||||
}
|
||||
|
||||
// Resolvers returns the current resolver list.
|
||||
func (f *Fetcher) Resolvers() []string {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
result := make([]string, len(f.resolvers))
|
||||
copy(result, f.resolvers)
|
||||
return result
|
||||
}
|
||||
|
||||
// FetchBlock fetches a single block from a channel.
|
||||
func (f *Fetcher) FetchBlock(channel, block uint16) ([]byte, error) {
|
||||
f.rateWait()
|
||||
|
||||
qname, err := protocol.EncodeQuery(f.queryKey, channel, block, f.domain, f.queryMode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encode query: %w", err)
|
||||
}
|
||||
|
||||
f.log("Q ch=%d blk=%d → %s", channel, block, qname)
|
||||
|
||||
resolvers := f.Resolvers()
|
||||
if len(resolvers) == 0 {
|
||||
return nil, fmt.Errorf("no resolvers configured")
|
||||
}
|
||||
|
||||
// Shuffle resolvers to distribute load
|
||||
shuffled := make([]string, len(resolvers))
|
||||
copy(shuffled, resolvers)
|
||||
rand.Shuffle(len(shuffled), func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] })
|
||||
|
||||
var lastErr error
|
||||
for _, resolver := range shuffled {
|
||||
data, err := f.queryResolver(resolver, qname)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("all resolvers failed, last error: %w", lastErr)
|
||||
}
|
||||
|
||||
// FetchMetadata fetches and parses metadata (channel 0).
|
||||
func (f *Fetcher) FetchMetadata() (*protocol.Metadata, error) {
|
||||
data, err := f.FetchBlock(protocol.MetadataChannel, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch metadata block 0: %w", err)
|
||||
}
|
||||
|
||||
meta, err := protocol.ParseMetadata(data)
|
||||
if err == nil {
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
// Metadata might span multiple blocks
|
||||
allData := make([]byte, len(data))
|
||||
copy(allData, data)
|
||||
|
||||
for blk := uint16(1); blk < 10; blk++ {
|
||||
block, fetchErr := f.FetchBlock(protocol.MetadataChannel, blk)
|
||||
if fetchErr != nil {
|
||||
break
|
||||
}
|
||||
allData = append(allData, block...)
|
||||
meta, parseErr := protocol.ParseMetadata(allData)
|
||||
if parseErr == nil {
|
||||
return meta, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("could not parse metadata: %w", err)
|
||||
}
|
||||
|
||||
// FetchChannel fetches all blocks for a channel and parses messages.
|
||||
func (f *Fetcher) FetchChannel(channelNum int, blockCount int) ([]protocol.Message, error) {
|
||||
if blockCount <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type result struct {
|
||||
idx int
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
|
||||
results := make(chan result, blockCount)
|
||||
// Limit concurrency to 3 to reduce DNS burst traffic
|
||||
sem := make(chan struct{}, 3)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < blockCount; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
data, err := f.FetchBlock(uint16(channelNum), uint16(idx))
|
||||
results <- result{idx: idx, data: data, err: err}
|
||||
}(i)
|
||||
}
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(results)
|
||||
}()
|
||||
|
||||
ordered := make([][]byte, blockCount)
|
||||
for r := range results {
|
||||
if r.err != nil {
|
||||
return nil, fmt.Errorf("fetch block %d: %w", r.idx, r.err)
|
||||
}
|
||||
ordered[r.idx] = r.data
|
||||
}
|
||||
|
||||
var allData []byte
|
||||
for _, block := range ordered {
|
||||
allData = append(allData, block...)
|
||||
}
|
||||
|
||||
return protocol.ParseMessages(allData)
|
||||
}
|
||||
|
||||
func (f *Fetcher) queryResolver(resolver, qname string) ([]byte, error) {
|
||||
if !strings.Contains(resolver, ":") {
|
||||
resolver = resolver + ":53"
|
||||
}
|
||||
|
||||
c := new(dns.Client)
|
||||
c.Timeout = f.timeout
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(dns.Fqdn(qname), dns.TypeTXT)
|
||||
m.RecursionDesired = true
|
||||
|
||||
resp, _, err := c.Exchange(m, resolver)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dns exchange with %s: %w", resolver, err)
|
||||
}
|
||||
|
||||
if resp.Rcode != dns.RcodeSuccess {
|
||||
return nil, fmt.Errorf("dns error from %s: %s", resolver, dns.RcodeToString[resp.Rcode])
|
||||
}
|
||||
|
||||
for _, ans := range resp.Answer {
|
||||
if txt, ok := ans.(*dns.TXT); ok {
|
||||
encoded := strings.Join(txt.Txt, "")
|
||||
return protocol.DecodeResponse(f.responseKey, encoded)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no TXT record in response from %s", resolver)
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ResolverScanner scans CIDR ranges to find working DNS resolvers.
|
||||
type ResolverScanner struct {
|
||||
fetcher *Fetcher
|
||||
concurrency int
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewResolverScanner creates a resolver scanner.
|
||||
func NewResolverScanner(fetcher *Fetcher, concurrency int) *ResolverScanner {
|
||||
if concurrency <= 0 {
|
||||
concurrency = 50
|
||||
}
|
||||
return &ResolverScanner{
|
||||
fetcher: fetcher,
|
||||
concurrency: concurrency,
|
||||
timeout: 3 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// ScanCIDR scans a CIDR range for working DNS resolvers.
|
||||
func (rs *ResolverScanner) ScanCIDR(cidr string, onFound func(ip string)) error {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse CIDR %q: %w", cidr, err)
|
||||
}
|
||||
|
||||
ips := expandCIDR(ipNet)
|
||||
return rs.scanIPs(ips, onFound)
|
||||
}
|
||||
|
||||
// ScanFile scans resolver IPs from a file (one per line, supports CIDR notation).
|
||||
func (rs *ResolverScanner) ScanFile(path string, onFound func(ip string)) error {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var ips []string
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(line, "/") {
|
||||
_, ipNet, err := net.ParseCIDR(line)
|
||||
if err != nil {
|
||||
log.Printf("[resolver] skip invalid CIDR: %s", line)
|
||||
continue
|
||||
}
|
||||
ips = append(ips, expandCIDR(ipNet)...)
|
||||
} else {
|
||||
ips = append(ips, line)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return rs.scanIPs(ips, onFound)
|
||||
}
|
||||
|
||||
// CheckResolver tests if a single resolver works by querying metadata.
|
||||
func (rs *ResolverScanner) CheckResolver(ip string) bool {
|
||||
if !strings.Contains(ip, ":") {
|
||||
ip = ip + ":53"
|
||||
}
|
||||
|
||||
// Create a new fetcher with only this resolver to avoid copying the lock.
|
||||
tmpFetcher := &Fetcher{
|
||||
domain: rs.fetcher.domain,
|
||||
queryKey: rs.fetcher.queryKey,
|
||||
responseKey: rs.fetcher.responseKey,
|
||||
resolvers: []string{ip},
|
||||
timeout: rs.timeout,
|
||||
}
|
||||
|
||||
_, err := tmpFetcher.FetchBlock(0, 0)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (rs *ResolverScanner) scanIPs(ips []string, onFound func(ip string)) error {
|
||||
if len(ips) == 0 {
|
||||
return fmt.Errorf("no IPs to scan")
|
||||
}
|
||||
|
||||
var found atomic.Int32
|
||||
sem := make(chan struct{}, rs.concurrency)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, ip := range ips {
|
||||
wg.Add(1)
|
||||
go func(ip string) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
|
||||
if rs.CheckResolver(ip) {
|
||||
found.Add(1)
|
||||
if onFound != nil {
|
||||
onFound(ip)
|
||||
}
|
||||
}
|
||||
}(ip)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if found.Load() == 0 {
|
||||
return fmt.Errorf("no working resolvers found among %d IPs", len(ips))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadResolversFile loads resolver IPs from a file (one per line).
|
||||
func LoadResolversFile(path string) ([]string, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var resolvers []string
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
resolvers = append(resolvers, line)
|
||||
}
|
||||
return resolvers, scanner.Err()
|
||||
}
|
||||
|
||||
func expandCIDR(ipNet *net.IPNet) []string {
|
||||
var ips []string
|
||||
ip := ipNet.IP.Mask(ipNet.Mask)
|
||||
|
||||
for ip := cloneIP(ip); ipNet.Contains(ip); incIP(ip) {
|
||||
// Skip network and broadcast addresses for /24 and smaller
|
||||
ones, bits := ipNet.Mask.Size()
|
||||
if bits-ones <= 8 {
|
||||
last := ip[len(ip)-1]
|
||||
if last == 0 || last == 255 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
ips = append(ips, ip.String())
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
func cloneIP(ip net.IP) net.IP {
|
||||
dup := make(net.IP, len(ip))
|
||||
copy(dup, ip)
|
||||
return dup
|
||||
}
|
||||
|
||||
func incIP(ip net.IP) {
|
||||
for j := len(ip) - 1; j >= 0; j-- {
|
||||
ip[j]++
|
||||
if ip[j] > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user