From 10370ebb3db53b84cf39a736bfe293edf0a6c270 Mon Sep 17 00:00:00 2001 From: Reza Behzadan Date: Tue, 6 Aug 2024 22:58:45 +0330 Subject: [PATCH] Add whitelist --- VERSION | 2 +- go.mod | 5 +++ go.sum | 4 +++ main.go | 109 ++++++++++++++++++++++++++++++++++++++++++++++++++------ 4 files changed, 109 insertions(+), 11 deletions(-) create mode 100644 go.sum diff --git a/VERSION b/VERSION index 9df886c..26ca594 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.4.2 +1.5.1 diff --git a/go.mod b/go.mod index cad92d6..83f6a2d 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,8 @@ module tcproxy go 1.22.0 + +require ( + github.com/fsnotify/fsnotify v1.7.0 // indirect + golang.org/x/sys v0.4.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..ccd7ce9 --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= +golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/main.go b/main.go index 82b9b31..48055cd 100644 --- a/main.go +++ b/main.go @@ -16,15 +16,30 @@ import ( "strings" "sync" "time" + + "github.com/fsnotify/fsnotify" ) //go:embed VERSION var versionFile embed.FS + var ( NAME = "tcproxy" VERSION string + + whitelist = make(map[string]struct{}) + whitelistFile = "whitelist.txt" + mu sync.RWMutex ) +func init() { + var err error + VERSION, err = readVersion() + if err != nil { + VERSION = "xx.yy.zz" + } +} + func printHelp() { helpText := fmt.Sprintf(`Usage: %s [options] [path] Options: @@ -40,8 +55,7 @@ func readVersion() (string, error) { if err != nil { return "", err } - VERSION = strings.TrimSpace(string(version)) - return VERSION, nil + return strings.TrimSpace(string(version)), nil } func printVersion() { @@ -77,12 +91,6 @@ func main() { bindAddrShort := flag.String("b", "localhost:8443", "The address to bind the server") flag.Parse() - _, err := readVersion() - if err != nil { - fmt.Fprintf(os.Stderr, "Error reading version: %v\n", err) - os.Exit(1) - } - if *bindAddrShort != "localhost:8443" { bindAddr = bindAddrShort } @@ -95,16 +103,23 @@ func main() { os.Exit(0) } + err := loadWhitelist() + if err != nil { + log.Fatalf("Failed to load whitelist: %v", err) + } + go watchWhitelist() + l, err := net.Listen("tcp", *bindAddr) if err != nil { - log.Fatal(err) + log.Fatalf("Failed to bind to address %s: %v", *bindAddr, err) } printVersion() log.Printf("Listening on %s for incoming HTTPS (or HTTP) connections...", *bindAddr) + for { conn, err := l.Accept() if err != nil { - log.Print(err) + log.Printf("Failed to accept connection: %v", err) continue } go handleConnection(conn) @@ -233,6 +248,17 @@ func handleConnection(clientConn net.Conn) { clientConn.Close() return } + clientIP := tcpAddr.IP.String() + + mu.RLock() + _, allowed := whitelist[clientIP] + mu.RUnlock() + + if !allowed { + log.Printf("Connection from %s denied", clientIP) + return + } + log.Printf("Received connection from %s:%d", tcpAddr.IP, tcpAddr.Port) // Create a buffer to peek the first byte @@ -256,3 +282,66 @@ func handleConnection(clientConn net.Conn) { } } + +func loadWhitelist() error { + file, err := os.Open(whitelistFile) + if err != nil { + return err + } + defer file.Close() + + newWhitelist := make(map[string]struct{}) + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + ip := strings.TrimSpace(scanner.Text()) + if ip != "" && !strings.HasPrefix(ip, "#") { + newWhitelist[ip] = struct{}{} + } + } + + if err := scanner.Err(); err != nil { + return err + } + + mu.Lock() + whitelist = newWhitelist + mu.Unlock() + + log.Printf("Whitelist updated: %v", newWhitelist) + + return nil +} + +func watchWhitelist() { + watcher, err := fsnotify.NewWatcher() + if err != nil { + log.Fatalf("Failed to create file watcher: %v", err) + } + defer watcher.Close() + + err = watcher.Add(whitelistFile) + if err != nil { + log.Fatalf("Failed to watch file: %v", err) + } + + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return + } + if event.Op&fsnotify.Write == fsnotify.Write { + log.Println("Whitelist file modified") + if err := loadWhitelist(); err != nil { + log.Printf("Failed to reload whitelist: %v", err) + } + } + case err, ok := <-watcher.Errors: + if !ok { + return + } + log.Printf("File watcher error: %v", err) + } + } +}