42
main.go
42
main.go
@@ -12,6 +12,8 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
"sync"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -20,6 +22,9 @@ var (
|
||||
staticDir string
|
||||
expireDur time.Duration
|
||||
expireOnView bool
|
||||
limiters = make(map[string]*rate.Limiter)
|
||||
limMu sync.Mutex
|
||||
useHTTPS bool
|
||||
)
|
||||
|
||||
type meta struct {
|
||||
@@ -27,6 +32,32 @@ type meta struct {
|
||||
ExpireOnView bool `json:"expire_on_view"`
|
||||
}
|
||||
|
||||
func getLimiter(ip string) *rate.Limiter {
|
||||
limMu.Lock()
|
||||
defer limMu.Unlock()
|
||||
lim, ok := limiters[ip]
|
||||
if !ok {
|
||||
lim = rate.NewLimiter(1, 5) // 1 req/sec, burst of 5
|
||||
limiters[ip] = lim
|
||||
}
|
||||
return lim
|
||||
}
|
||||
|
||||
func rateLimitMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ip := r.RemoteAddr
|
||||
if i := strings.LastIndex(ip, ":"); i != -1 {
|
||||
ip = ip[:i]
|
||||
}
|
||||
lim := getLimiter(ip)
|
||||
if !lim.Allow() {
|
||||
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func randomID(n int) string {
|
||||
letters := []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
|
||||
b := make([]rune, n)
|
||||
@@ -140,7 +171,11 @@ func uploadHandler(w http.ResponseWriter, r *http.Request) {
|
||||
_ = os.WriteFile(path+".json", metaBytes, 0644)
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "http://%s/%s\n", domain, filename)
|
||||
scheme := "http"
|
||||
if useHTTPS {
|
||||
scheme = "https"
|
||||
}
|
||||
fmt.Fprintf(w, "%s://%s/%s\n", scheme, domain, filename)
|
||||
}
|
||||
|
||||
func viewHandler(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -192,15 +227,16 @@ func main() {
|
||||
flag.StringVar(&staticDir, "static", "static", "directory to save pastes")
|
||||
flag.DurationVar(&expireDur, "expire", 0, "time after which paste expires (e.g. 5m, 1h)")
|
||||
flag.BoolVar(&expireOnView, "expire-on-view", false, "delete paste after it's viewed once")
|
||||
flag.BoolVar(&useHTTPS, "https", false, "use https:// in generated URLs")
|
||||
flag.Parse()
|
||||
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.HandleFunc("/", rateLimitMiddleware(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodPost {
|
||||
uploadHandler(w, r)
|
||||
} else {
|
||||
viewHandler(w, r)
|
||||
}
|
||||
})
|
||||
}))
|
||||
|
||||
fmt.Printf("slenpaste running at http://%s, storing in %s\n", listenAddr, staticDir)
|
||||
http.ListenAndServe(listenAddr, nil)
|
||||
|
||||
Reference in New Issue
Block a user