348 lines
8.1 KiB
Go
348 lines
8.1 KiB
Go
package main
|
|
|
|
import (
	"bufio"
	"bytes"
	"crypto/tls"
	"embed"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"sync"
	"time"

	"github.com/fsnotify/fsnotify"
)
|
|
|
|
//go:embed VERSION
|
|
var versionFile embed.FS
|
|
|
|
var (
|
|
NAME = "tcproxy"
|
|
VERSION string
|
|
|
|
whitelist = make(map[string]struct{})
|
|
whitelistFile = "whitelist.txt"
|
|
mu sync.RWMutex
|
|
)
|
|
|
|
func init() {
|
|
var err error
|
|
VERSION, err = readVersion()
|
|
if err != nil {
|
|
VERSION = "xx.yy.zz"
|
|
}
|
|
}
|
|
|
|
// printHelp writes the usage summary for the command-line options to
// standard output.
//
// The program reads no positional arguments (main only consults flags), so
// the usage line advertises options only.
func printHelp() {
	fmt.Printf(`Usage: %s [options]

Options:
  -b, --bind     The address to bind the server (default "localhost:8443").
  -v, --version  Display the version of the server.
  -h, --help     Display this help message.
`, NAME)
}
|
|
|
|
// readVersion returns the contents of the embedded VERSION file with
// leading and trailing whitespace removed, or the read error.
func readVersion() (string, error) {
	raw, err := versionFile.ReadFile("VERSION")
	if err == nil {
		return strings.TrimSpace(string(raw)), nil
	}
	return "", err
}
|
|
|
|
// printVersion prints one line to standard output: the program name, its
// release, then the Go runtime version, GOOS and GOARCH, space-separated.
func printVersion() {
	build := runtime.Version() + " " + runtime.GOOS + " " + runtime.GOARCH
	fmt.Printf("%s %s %s\n", NAME, VERSION, build)
}
|
|
|
|
type readOnlyConn struct {
|
|
reader io.Reader
|
|
}
|
|
|
|
func (conn readOnlyConn) Read(p []byte) (int, error) { return conn.reader.Read(p) }
|
|
func (conn readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe }
|
|
func (conn readOnlyConn) Close() error { return nil }
|
|
func (conn readOnlyConn) LocalAddr() net.Addr { return nil }
|
|
func (conn readOnlyConn) RemoteAddr() net.Addr { return nil }
|
|
func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil }
|
|
func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
|
|
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
|
|
|
|
func main() {
|
|
versionFlag := flag.Bool("version", false, "Display the version of the server")
|
|
versionFlagShort := flag.Bool("v", false, "Display the version")
|
|
helpFlag := flag.Bool("help", false, "Display help message")
|
|
helpFlagShort := flag.Bool("h", false, "Display help message")
|
|
bindAddr := flag.String("bind", "localhost:8443", "The address to bind the server")
|
|
bindAddrShort := flag.String("b", "localhost:8443", "The address to bind the server")
|
|
flag.Parse()
|
|
|
|
if *bindAddrShort != "localhost:8443" {
|
|
bindAddr = bindAddrShort
|
|
}
|
|
if *versionFlag || *versionFlagShort {
|
|
printVersion()
|
|
os.Exit(0)
|
|
}
|
|
if *helpFlag || *helpFlagShort {
|
|
printHelp()
|
|
os.Exit(0)
|
|
}
|
|
|
|
err := loadWhitelist()
|
|
if err != nil {
|
|
log.Fatalf("Failed to load whitelist: %v", err)
|
|
}
|
|
go watchWhitelist()
|
|
|
|
l, err := net.Listen("tcp", *bindAddr)
|
|
if err != nil {
|
|
log.Fatalf("Failed to bind to address %s: %v", *bindAddr, err)
|
|
}
|
|
printVersion()
|
|
log.Printf("Listening on %s for incoming HTTPS (or HTTP) connections...", *bindAddr)
|
|
|
|
for {
|
|
conn, err := l.Accept()
|
|
if err != nil {
|
|
log.Printf("Failed to accept connection: %v", err)
|
|
continue
|
|
}
|
|
go handleConnection(conn)
|
|
}
|
|
}
|
|
|
|
func peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) {
|
|
peekedBytes := new(bytes.Buffer)
|
|
hello, err := readClientHello(io.TeeReader(reader, peekedBytes))
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return hello, io.MultiReader(peekedBytes, reader), nil
|
|
}
|
|
|
|
// readClientHello runs a throwaway server-side TLS handshake over reader only
// to capture the client's ClientHello.
//
// GetConfigForClient records a copy of the hello and returns a nil config,
// so the handshake carries on with a config that has no certificates, over a
// connection (readOnlyConn) that rejects every write. The handshake result is
// therefore irrelevant: if a hello was captured it is returned and the
// handshake error is discarded; otherwise that error is returned.
func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
	var captured *tls.ClientHelloInfo
	cfg := &tls.Config{
		GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
			copied := *info
			captured = &copied
			return nil, nil
		},
	}
	handshakeErr := tls.Server(readOnlyConn{reader: reader}, cfg).Handshake()
	if captured != nil {
		return captured, nil
	}
	return nil, handshakeErr
}
|
|
|
|
// handleHTTPS proxies a TLS connection to the host named in its SNI without
// terminating TLS.
//
// reader is expected to yield the client's byte stream from its first byte
// (the caller may have consumed bytes from frontendConn and pushed them back
// in front). The ClientHello is peeked, the backend is dialed on port 443,
// and bytes are relayed in both directions. Each direction half-closes its
// destination when its source is exhausted, and the function returns once
// both directions are done.
func handleHTTPS(reader io.Reader, frontendConn net.Conn) {
	clientHello, clientReader, err := peekClientHello(reader)
	if err != nil {
		log.Printf("Failed to peek ClientHello: %v", err)
		return
	}

	serverName := clientHello.ServerName
	if serverName == "" {
		// Without SNI there is no target host. JoinHostPort("", "443") yields
		// ":443", which net.Dial resolves to the local system, so dialing it
		// would loop the client back onto this machine.
		log.Print("ClientHello has no SNI server name; dropping connection")
		return
	}
	log.Printf("Forwarding request for domain: %s", serverName)

	// Connect to the backend server as specified by the SNI.
	backendConn, err := net.DialTimeout("tcp", net.JoinHostPort(serverName, "443"), 10*time.Second)
	if err != nil {
		log.Printf("Failed to connect to backend %s: %v", serverName, err)
		return
	}
	defer backendConn.Close()

	// halfClose signals EOF to the peer while leaving the other direction
	// open. A conn without CloseWrite is closed fully instead of panicking
	// on a type assertion. Errors are ignored: this is best-effort teardown.
	halfClose := func(c net.Conn) {
		if cw, ok := c.(interface{ CloseWrite() error }); ok {
			_ = cw.CloseWrite()
			return
		}
		_ = c.Close()
	}

	var wg sync.WaitGroup
	wg.Add(2)

	// Client -> backend.
	go func() {
		defer wg.Done()
		io.Copy(backendConn, clientReader)
		halfClose(backendConn)
	}()

	// Backend -> client.
	go func() {
		defer wg.Done()
		io.Copy(frontendConn, backendConn)
		halfClose(frontendConn)
	}()

	wg.Wait()
	log.Printf("Completed forwarding for domain: %s", serverName)
}
|
|
|
|
// handleHTTP proxies a plain-text HTTP connection to the server named in the
// first request's Host header.
//
// reader is expected to yield the client's byte stream from its first byte.
// Only the first request is parsed. It is re-sent, header and body, with
// http.Request.Write, which serialises it as HTTP/1.1, so the bytes need not
// be identical to what the client sent. After that, the rest of the client
// stream (e.g. further keep-alive requests) and everything the backend sends
// are relayed verbatim until both directions are done.
func handleHTTP(reader io.Reader, frontendConn net.Conn) {
	// Buffer the reader so the request can be parsed; the same buffered
	// reader is used for the relay so no already-buffered bytes are lost.
	bufferedReader := bufio.NewReader(reader)
	req, err := http.ReadRequest(bufferedReader)
	if err != nil {
		log.Printf("Error reading HTTP request: %v", err)
		return
	}

	backendServer := req.Host
	if backendServer == "" {
		log.Println("No Host header found or empty Host header")
		return
	}

	// The Host value may carry an explicit port ("example.com:8080");
	// httpBackendAddr honours it and otherwise defaults to port 80.
	backendConn, err := net.DialTimeout("tcp", httpBackendAddr(backendServer), 10*time.Second)
	if err != nil {
		log.Printf("Failed to connect to backend %s: %v", backendServer, err)
		return
	}
	defer backendConn.Close()

	// Write also drains req.Body from bufferedReader, so the relay below
	// resumes exactly where the first request ends.
	if err := req.Write(backendConn); err != nil {
		log.Printf("Failed to forward HTTP request to %s: %v", backendServer, err)
		return
	}

	// halfClose signals EOF to the peer while leaving the other direction
	// open. A conn without CloseWrite is closed fully instead of panicking
	// on a type assertion. Errors are ignored: this is best-effort teardown.
	halfClose := func(c net.Conn) {
		if cw, ok := c.(interface{ CloseWrite() error }); ok {
			_ = cw.CloseWrite()
			return
		}
		_ = c.Close()
	}

	var wg sync.WaitGroup
	wg.Add(2)

	// Client -> backend: the remainder of the stream after the first request.
	go func() {
		defer wg.Done()
		io.Copy(backendConn, bufferedReader)
		halfClose(backendConn)
	}()

	// Backend -> client.
	go func() {
		defer wg.Done()
		io.Copy(frontendConn, backendConn)
		halfClose(frontendConn)
	}()

	wg.Wait()
	log.Printf("Completed forwarding HTTP traffic to %s", backendServer)
}

// httpBackendAddr turns an HTTP Host value into a "host:port" dial address.
// An explicit, non-empty port is kept; otherwise port 80 is used. Bracketed
// IPv6 literals such as "[::1]" are accepted with or without a port.
func httpBackendAddr(host string) string {
	if h, port, err := net.SplitHostPort(host); err == nil {
		if port == "" {
			port = "80"
		}
		return net.JoinHostPort(h, port)
	}
	return net.JoinHostPort(strings.TrimSuffix(strings.TrimPrefix(host, "["), "]"), "80")
}
|
|
|
|
func handleConnection(clientConn net.Conn) {
|
|
defer clientConn.Close()
|
|
|
|
tcpAddr, ok := clientConn.RemoteAddr().(*net.TCPAddr)
|
|
if !ok {
|
|
log.Printf("Failed to obtain TCP address from connection")
|
|
clientConn.Close()
|
|
return
|
|
}
|
|
clientIP := tcpAddr.IP.String()
|
|
|
|
mu.RLock()
|
|
_, allowed := whitelist[clientIP]
|
|
mu.RUnlock()
|
|
|
|
if !allowed {
|
|
log.Printf("Connection from %s denied", clientIP)
|
|
return
|
|
}
|
|
|
|
log.Printf("Received connection from %s:%d", tcpAddr.IP, tcpAddr.Port)
|
|
|
|
// Create a buffer to peek the first byte
|
|
buf := make([]byte, 1)
|
|
_, err := clientConn.Read(buf)
|
|
if err != nil {
|
|
log.Print("Failed to read from client connection:", err)
|
|
return
|
|
}
|
|
|
|
// Use a MultiReader to put the peeked byte back in front of the clientConn stream
|
|
clientReader := io.MultiReader(bytes.NewReader(buf), clientConn)
|
|
|
|
// Check if the first byte indicates a TLS handshake (0x16)
|
|
if buf[0] == 0x16 {
|
|
// Handle as HTTPS
|
|
handleHTTPS(clientReader, clientConn)
|
|
} else {
|
|
// Handle as HTTP or other non-TLS traffic
|
|
handleHTTP(clientReader, clientConn)
|
|
}
|
|
|
|
}
|
|
|
|
// loadWhitelist reads whitelistFile and, if the whole file is read
// successfully, replaces the active whitelist with its contents under mu.
// On any error the current whitelist is left untouched and the error is
// returned.
//
// File format: one client IP address per line. Blank lines are ignored and
// '#' starts a comment that runs to the end of the line. Addresses are
// stored in canonical net.IP.String form, the same form handleConnection
// looks up. So "2001:DB8::1", "2001:0db8::1" and "2001:db8::1" all match the
// same client, as do "::ffff:10.0.0.1" and "10.0.0.1". Entries that do not
// parse as an IP are kept verbatim, so they can never match, and a warning is
// logged.
func loadWhitelist() error {
	file, err := os.Open(whitelistFile)
	if err != nil {
		return err
	}
	defer file.Close()

	newWhitelist := make(map[string]struct{})

	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		line := scanner.Text()
		if i := strings.IndexByte(line, '#'); i >= 0 {
			line = line[:i]
		}
		entry := strings.TrimSpace(line)
		if entry == "" {
			continue
		}
		if ip := net.ParseIP(entry); ip != nil {
			entry = ip.String()
		} else {
			log.Printf("Whitelist entry %q is not an IP address", entry)
		}
		newWhitelist[entry] = struct{}{}
	}
	if err := scanner.Err(); err != nil {
		return err
	}

	mu.Lock()
	whitelist = newWhitelist
	mu.Unlock()

	log.Printf("Whitelist updated: %v", newWhitelist)
	return nil
}
|
|
|
|
// watchWhitelist reloads the whitelist whenever whitelistFile is written to
// or (re)created. It runs until the watcher's channels are closed and exits
// the process if the watcher cannot be set up.
//
// The parent directory is watched rather than the file itself. Editors and
// deploy tools often replace a file by writing a new one and renaming it over
// the old. A watch on the file follows the old, now-removed inode, so every
// later change would be missed silently. Watching the directory and filtering
// on the event name survives such replacements.
func watchWhitelist() {
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		log.Fatalf("Failed to create file watcher: %v", err)
	}
	defer watcher.Close()

	target := filepath.Clean(whitelistFile)
	if err := watcher.Add(filepath.Dir(target)); err != nil {
		log.Fatalf("Failed to watch file: %v", err)
	}

	for {
		select {
		case event, ok := <-watcher.Events:
			if !ok {
				return
			}
			if filepath.Clean(event.Name) != target {
				continue // another file in the same directory
			}
			// Create covers a file renamed or moved into place; Write covers
			// in-place edits. Remove/Rename keep the last loaded whitelist.
			if event.Op&(fsnotify.Write|fsnotify.Create) == 0 {
				continue
			}
			log.Println("Whitelist file modified")
			if err := loadWhitelist(); err != nil {
				log.Printf("Failed to reload whitelist: %v", err)
			}
		case err, ok := <-watcher.Errors:
			if !ok {
				return
			}
			log.Printf("File watcher error: %v", err)
		}
	}
}
|