Add whitelist
This commit is contained in:
		
							parent
							
								
									c13d83820a
								
							
						
					
					
						commit
						10370ebb3d
					
				
							
								
								
									
										5
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								go.mod
									
									
									
									
									
								
							@ -1,3 +1,8 @@
 | 
				
			|||||||
module tcproxy
 | 
					module tcproxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
go 1.22.0
 | 
					go 1.22.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					require (
 | 
				
			||||||
 | 
						github.com/fsnotify/fsnotify v1.7.0 // indirect
 | 
				
			||||||
 | 
						golang.org/x/sys v0.4.0 // indirect
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										4
									
								
								go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								go.sum
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,4 @@
 | 
				
			|||||||
 | 
					github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
 | 
				
			||||||
 | 
					github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
 | 
				
			||||||
 | 
					golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
 | 
				
			||||||
 | 
					golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
				
			||||||
							
								
								
									
										109
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										109
									
								
								main.go
									
									
									
									
									
								
							@ -16,15 +16,30 @@ import (
 | 
				
			|||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/fsnotify/fsnotify"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//go:embed VERSION
 | 
					//go:embed VERSION
 | 
				
			||||||
var versionFile embed.FS
 | 
					var versionFile embed.FS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
	NAME    = "tcproxy"
 | 
						NAME    = "tcproxy"
 | 
				
			||||||
	VERSION string
 | 
						VERSION string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						whitelist     = make(map[string]struct{})
 | 
				
			||||||
 | 
						whitelistFile = "whitelist.txt"
 | 
				
			||||||
 | 
						mu            sync.RWMutex
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						var err error
 | 
				
			||||||
 | 
						VERSION, err = readVersion()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							VERSION = "xx.yy.zz"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func printHelp() {
 | 
					func printHelp() {
 | 
				
			||||||
	helpText := fmt.Sprintf(`Usage: %s [options] [path]
 | 
						helpText := fmt.Sprintf(`Usage: %s [options] [path]
 | 
				
			||||||
Options:
 | 
					Options:
 | 
				
			||||||
@ -40,8 +55,7 @@ func readVersion() (string, error) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", err
 | 
							return "", err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	VERSION = strings.TrimSpace(string(version))
 | 
						return strings.TrimSpace(string(version)), nil
 | 
				
			||||||
	return VERSION, nil
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func printVersion() {
 | 
					func printVersion() {
 | 
				
			||||||
@ -77,12 +91,6 @@ func main() {
 | 
				
			|||||||
	bindAddrShort := flag.String("b", "localhost:8443", "The address to bind the server")
 | 
						bindAddrShort := flag.String("b", "localhost:8443", "The address to bind the server")
 | 
				
			||||||
	flag.Parse()
 | 
						flag.Parse()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	_, err := readVersion()
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		fmt.Fprintf(os.Stderr, "Error reading version: %v\n", err)
 | 
					 | 
				
			||||||
		os.Exit(1)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if *bindAddrShort != "localhost:8443" {
 | 
						if *bindAddrShort != "localhost:8443" {
 | 
				
			||||||
		bindAddr = bindAddrShort
 | 
							bindAddr = bindAddrShort
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -95,16 +103,23 @@ func main() {
 | 
				
			|||||||
		os.Exit(0)
 | 
							os.Exit(0)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err := loadWhitelist()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Fatalf("Failed to load whitelist: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						go watchWhitelist()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	l, err := net.Listen("tcp", *bindAddr)
 | 
						l, err := net.Listen("tcp", *bindAddr)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Fatal(err)
 | 
							log.Fatalf("Failed to bind to address %s: %v", *bindAddr, err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	printVersion()
 | 
						printVersion()
 | 
				
			||||||
	log.Printf("Listening on %s for incoming HTTPS (or HTTP) connections...", *bindAddr)
 | 
						log.Printf("Listening on %s for incoming HTTPS (or HTTP) connections...", *bindAddr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for {
 | 
						for {
 | 
				
			||||||
		conn, err := l.Accept()
 | 
							conn, err := l.Accept()
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Print(err)
 | 
								log.Printf("Failed to accept connection: %v", err)
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		go handleConnection(conn)
 | 
							go handleConnection(conn)
 | 
				
			||||||
@ -233,6 +248,17 @@ func handleConnection(clientConn net.Conn) {
 | 
				
			|||||||
		clientConn.Close()
 | 
							clientConn.Close()
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						clientIP := tcpAddr.IP.String()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						mu.RLock()
 | 
				
			||||||
 | 
						_, allowed := whitelist[clientIP]
 | 
				
			||||||
 | 
						mu.RUnlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if !allowed {
 | 
				
			||||||
 | 
							log.Printf("Connection from %s denied", clientIP)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Printf("Received connection from %s:%d", tcpAddr.IP, tcpAddr.Port)
 | 
						log.Printf("Received connection from %s:%d", tcpAddr.IP, tcpAddr.Port)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Create a buffer to peek the first byte
 | 
						// Create a buffer to peek the first byte
 | 
				
			||||||
@ -256,3 +282,66 @@ func handleConnection(clientConn net.Conn) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func loadWhitelist() error {
 | 
				
			||||||
 | 
						file, err := os.Open(whitelistFile)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer file.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						newWhitelist := make(map[string]struct{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						scanner := bufio.NewScanner(file)
 | 
				
			||||||
 | 
						for scanner.Scan() {
 | 
				
			||||||
 | 
							ip := strings.TrimSpace(scanner.Text())
 | 
				
			||||||
 | 
							if ip != "" && !strings.HasPrefix(ip, "#") {
 | 
				
			||||||
 | 
								newWhitelist[ip] = struct{}{}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := scanner.Err(); err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						mu.Lock()
 | 
				
			||||||
 | 
						whitelist = newWhitelist
 | 
				
			||||||
 | 
						mu.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.Printf("Whitelist updated: %v", newWhitelist)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func watchWhitelist() {
 | 
				
			||||||
 | 
						watcher, err := fsnotify.NewWatcher()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Fatalf("Failed to create file watcher: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer watcher.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = watcher.Add(whitelistFile)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Fatalf("Failed to watch file: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case event, ok := <-watcher.Events:
 | 
				
			||||||
 | 
								if !ok {
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if event.Op&fsnotify.Write == fsnotify.Write {
 | 
				
			||||||
 | 
									log.Println("Whitelist file modified")
 | 
				
			||||||
 | 
									if err := loadWhitelist(); err != nil {
 | 
				
			||||||
 | 
										log.Printf("Failed to reload whitelist: %v", err)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case err, ok := <-watcher.Errors:
 | 
				
			||||||
 | 
								if !ok {
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								log.Printf("File watcher error: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user