Add whitelist
This commit is contained in:
parent
c13d83820a
commit
10370ebb3d
5
go.mod
5
go.mod
@ -1,3 +1,8 @@
|
||||
module tcproxy
|
||||
|
||||
go 1.22.0
|
||||
|
||||
require (
|
||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||
golang.org/x/sys v0.4.0 // indirect
|
||||
)
|
||||
|
4
go.sum
Normal file
4
go.sum
Normal file
@ -0,0 +1,4 @@
|
||||
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
||||
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
|
||||
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
|
||||
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
109
main.go
109
main.go
@ -16,15 +16,30 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
//go:embed VERSION
|
||||
var versionFile embed.FS
|
||||
|
||||
var (
|
||||
NAME = "tcproxy"
|
||||
VERSION string
|
||||
|
||||
whitelist = make(map[string]struct{})
|
||||
whitelistFile = "whitelist.txt"
|
||||
mu sync.RWMutex
|
||||
)
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
VERSION, err = readVersion()
|
||||
if err != nil {
|
||||
VERSION = "xx.yy.zz"
|
||||
}
|
||||
}
|
||||
|
||||
func printHelp() {
|
||||
helpText := fmt.Sprintf(`Usage: %s [options] [path]
|
||||
Options:
|
||||
@ -40,8 +55,7 @@ func readVersion() (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
VERSION = strings.TrimSpace(string(version))
|
||||
return VERSION, nil
|
||||
return strings.TrimSpace(string(version)), nil
|
||||
}
|
||||
|
||||
func printVersion() {
|
||||
@ -77,12 +91,6 @@ func main() {
|
||||
bindAddrShort := flag.String("b", "localhost:8443", "The address to bind the server")
|
||||
flag.Parse()
|
||||
|
||||
_, err := readVersion()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error reading version: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if *bindAddrShort != "localhost:8443" {
|
||||
bindAddr = bindAddrShort
|
||||
}
|
||||
@ -95,16 +103,23 @@ func main() {
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
err := loadWhitelist()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load whitelist: %v", err)
|
||||
}
|
||||
go watchWhitelist()
|
||||
|
||||
l, err := net.Listen("tcp", *bindAddr)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
log.Fatalf("Failed to bind to address %s: %v", *bindAddr, err)
|
||||
}
|
||||
printVersion()
|
||||
log.Printf("Listening on %s for incoming HTTPS (or HTTP) connections...", *bindAddr)
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
log.Printf("Failed to accept connection: %v", err)
|
||||
continue
|
||||
}
|
||||
go handleConnection(conn)
|
||||
@ -233,6 +248,17 @@ func handleConnection(clientConn net.Conn) {
|
||||
clientConn.Close()
|
||||
return
|
||||
}
|
||||
clientIP := tcpAddr.IP.String()
|
||||
|
||||
mu.RLock()
|
||||
_, allowed := whitelist[clientIP]
|
||||
mu.RUnlock()
|
||||
|
||||
if !allowed {
|
||||
log.Printf("Connection from %s denied", clientIP)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Received connection from %s:%d", tcpAddr.IP, tcpAddr.Port)
|
||||
|
||||
// Create a buffer to peek the first byte
|
||||
@ -256,3 +282,66 @@ func handleConnection(clientConn net.Conn) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func loadWhitelist() error {
|
||||
file, err := os.Open(whitelistFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
newWhitelist := make(map[string]struct{})
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
ip := strings.TrimSpace(scanner.Text())
|
||||
if ip != "" && !strings.HasPrefix(ip, "#") {
|
||||
newWhitelist[ip] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
whitelist = newWhitelist
|
||||
mu.Unlock()
|
||||
|
||||
log.Printf("Whitelist updated: %v", newWhitelist)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func watchWhitelist() {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create file watcher: %v", err)
|
||||
}
|
||||
defer watcher.Close()
|
||||
|
||||
err = watcher.Add(whitelistFile)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to watch file: %v", err)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Op&fsnotify.Write == fsnotify.Write {
|
||||
log.Println("Whitelist file modified")
|
||||
if err := loadWhitelist(); err != nil {
|
||||
log.Printf("Failed to reload whitelist: %v", err)
|
||||
}
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Printf("File watcher error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user