Add whitelist
This commit is contained in:
parent
c13d83820a
commit
10370ebb3d
5
go.mod
5
go.mod
@ -1,3 +1,8 @@
|
|||||||
module tcproxy
|
module tcproxy
|
||||||
|
|
||||||
go 1.22.0
|
go 1.22.0
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||||
|
golang.org/x/sys v0.4.0 // indirect
|
||||||
|
)
|
||||||
|
4
go.sum
Normal file
4
go.sum
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
||||||
|
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
|
||||||
|
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
|
||||||
|
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
109
main.go
109
main.go
@ -16,15 +16,30 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/fsnotify/fsnotify"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed VERSION
|
//go:embed VERSION
|
||||||
var versionFile embed.FS
|
var versionFile embed.FS
|
||||||
|
|
||||||
var (
|
var (
|
||||||
NAME = "tcproxy"
|
NAME = "tcproxy"
|
||||||
VERSION string
|
VERSION string
|
||||||
|
|
||||||
|
whitelist = make(map[string]struct{})
|
||||||
|
whitelistFile = "whitelist.txt"
|
||||||
|
mu sync.RWMutex
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
var err error
|
||||||
|
VERSION, err = readVersion()
|
||||||
|
if err != nil {
|
||||||
|
VERSION = "xx.yy.zz"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func printHelp() {
|
func printHelp() {
|
||||||
helpText := fmt.Sprintf(`Usage: %s [options] [path]
|
helpText := fmt.Sprintf(`Usage: %s [options] [path]
|
||||||
Options:
|
Options:
|
||||||
@ -40,8 +55,7 @@ func readVersion() (string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
VERSION = strings.TrimSpace(string(version))
|
return strings.TrimSpace(string(version)), nil
|
||||||
return VERSION, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func printVersion() {
|
func printVersion() {
|
||||||
@ -77,12 +91,6 @@ func main() {
|
|||||||
bindAddrShort := flag.String("b", "localhost:8443", "The address to bind the server")
|
bindAddrShort := flag.String("b", "localhost:8443", "The address to bind the server")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
_, err := readVersion()
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error reading version: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
if *bindAddrShort != "localhost:8443" {
|
if *bindAddrShort != "localhost:8443" {
|
||||||
bindAddr = bindAddrShort
|
bindAddr = bindAddrShort
|
||||||
}
|
}
|
||||||
@ -95,16 +103,23 @@ func main() {
|
|||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err := loadWhitelist()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to load whitelist: %v", err)
|
||||||
|
}
|
||||||
|
go watchWhitelist()
|
||||||
|
|
||||||
l, err := net.Listen("tcp", *bindAddr)
|
l, err := net.Listen("tcp", *bindAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatalf("Failed to bind to address %s: %v", *bindAddr, err)
|
||||||
}
|
}
|
||||||
printVersion()
|
printVersion()
|
||||||
log.Printf("Listening on %s for incoming HTTPS (or HTTP) connections...", *bindAddr)
|
log.Printf("Listening on %s for incoming HTTPS (or HTTP) connections...", *bindAddr)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
conn, err := l.Accept()
|
conn, err := l.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Printf("Failed to accept connection: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
go handleConnection(conn)
|
go handleConnection(conn)
|
||||||
@ -233,6 +248,17 @@ func handleConnection(clientConn net.Conn) {
|
|||||||
clientConn.Close()
|
clientConn.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
clientIP := tcpAddr.IP.String()
|
||||||
|
|
||||||
|
mu.RLock()
|
||||||
|
_, allowed := whitelist[clientIP]
|
||||||
|
mu.RUnlock()
|
||||||
|
|
||||||
|
if !allowed {
|
||||||
|
log.Printf("Connection from %s denied", clientIP)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
log.Printf("Received connection from %s:%d", tcpAddr.IP, tcpAddr.Port)
|
log.Printf("Received connection from %s:%d", tcpAddr.IP, tcpAddr.Port)
|
||||||
|
|
||||||
// Create a buffer to peek the first byte
|
// Create a buffer to peek the first byte
|
||||||
@ -256,3 +282,66 @@ func handleConnection(clientConn net.Conn) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func loadWhitelist() error {
|
||||||
|
file, err := os.Open(whitelistFile)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
newWhitelist := make(map[string]struct{})
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(file)
|
||||||
|
for scanner.Scan() {
|
||||||
|
ip := strings.TrimSpace(scanner.Text())
|
||||||
|
if ip != "" && !strings.HasPrefix(ip, "#") {
|
||||||
|
newWhitelist[ip] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
whitelist = newWhitelist
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
log.Printf("Whitelist updated: %v", newWhitelist)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func watchWhitelist() {
|
||||||
|
watcher, err := fsnotify.NewWatcher()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to create file watcher: %v", err)
|
||||||
|
}
|
||||||
|
defer watcher.Close()
|
||||||
|
|
||||||
|
err = watcher.Add(whitelistFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to watch file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case event, ok := <-watcher.Events:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if event.Op&fsnotify.Write == fsnotify.Write {
|
||||||
|
log.Println("Whitelist file modified")
|
||||||
|
if err := loadWhitelist(); err != nil {
|
||||||
|
log.Printf("Failed to reload whitelist: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case err, ok := <-watcher.Errors:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("File watcher error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user