diff --git a/flake.nix b/flake.nix index beaabaf..fc03010 100644 --- a/flake.nix +++ b/flake.nix @@ -15,7 +15,7 @@ packages = rec { slenpaste = pkgs.buildGoModule { pname = "slenpaste"; - version = "0.1.0"; + version = "0.1.1"; src = ./.; goPackagePath = "github.com/slendidev/slenpaste"; vendorHash = null; diff --git a/main.go b/main.go index a4e7861..ca34585 100644 --- a/main.go +++ b/main.go @@ -73,45 +73,50 @@ func uploadHandler(w http.ResponseWriter, r *http.Request) { } var reader io.Reader + var ext string + // handle multipart file upload contentType := r.Header.Get("Content-Type") if strings.HasPrefix(contentType, "multipart/form-data") { if err := r.ParseMultipartForm(10 << 20); err == nil { - if file, _, err := r.FormFile("file"); err == nil { + if file, header, err := r.FormFile("file"); err == nil { defer file.Close() reader = file + ext = filepath.Ext(header.Filename) } } } + + // fallback to body upload if reader == nil { reader = r.Body defer r.Body.Close() } - expVal := r.URL.Query().Get("expiry") - var dur time.Duration - var onView bool - switch expVal { - case "view": - onView = true - case "0": - // no expiry - default: - dur, _ = time.ParseDuration(expVal) + // default extension + if ext == "" { + ext = ".txt" } - id := randomID(6) + // generate ID and filename + id := randomID(6) + filename := id + ext + + // ensure storage dir if err := os.MkdirAll(staticDir, 0755); err != nil { http.Error(w, "Server error", http.StatusInternalServerError) return } - path := filepath.Join(staticDir, id) + path := filepath.Join(staticDir, filename) + + // save file out, err := os.Create(path) if err != nil { http.Error(w, "Save error", http.StatusInternalServerError) return } defer out.Close() + n, err := io.Copy(out, reader) if err != nil { http.Error(w, "Write error", http.StatusInternalServerError) @@ -123,16 +128,26 @@ func uploadHandler(w http.ResponseWriter, r *http.Request) { return } - if dur > 0 || onView { - m := meta{ExpireOnView: onView} - if dur > 0 { - m.Expiry = time.Now().Add(dur) + // write metadata if needed + expVal := r.URL.Query().Get("expiry") + var m meta + switch expVal { + case "view": + m.ExpireOnView = true + case "0": + // no expiry + default: + if d, err := time.ParseDuration(expVal); err == nil { + m.Expiry = time.Now().Add(d) } + } + if !m.Expiry.IsZero() || m.ExpireOnView { metaBytes, _ := json.Marshal(m) _ = os.WriteFile(path+".json", metaBytes, 0644) } - fmt.Fprintf(w, "http://%s/%s\n", domain, id) + // respond with URL including extension + fmt.Fprintf(w, "http://%s/%s\n", domain, filename) } func viewHandler(w http.ResponseWriter, r *http.Request) { @@ -141,7 +156,7 @@ func viewHandler(w http.ResponseWriter, r *http.Request) { indexHandler(w, r) return } - path := filepath.Join(staticDir, id) + path := filepath.Join(staticDir, id) metaPath := path + ".json" // load and enforce metadata @@ -172,11 +187,11 @@ func viewHandler(w http.ResponseWriter, r *http.Request) { } func main() { - flag.StringVar(&domain, "domain", "localhost:8080", "domain name for URLs") - flag.StringVar(&listenAddr, "listen", "0.0.0.0:8080", "listen address") - flag.StringVar(&staticDir, "static", "static", "directory to save pastes") - flag.DurationVar(&expireDur, "expire", 0, "time after which paste expires (e.g. 5m, 1h)") - flag.BoolVar(&expireOnView, "expire-on-view", false, "delete paste after it's viewed once") + flag.StringVar(&domain, "domain", "localhost:8080", "domain name for URLs") + flag.StringVar(&listenAddr, "listen", "0.0.0.0:8080", "listen address") + flag.StringVar(&staticDir, "static", "static", "directory to save pastes") + flag.DurationVar(&expireDur, "expire", 0, "time after which paste expires (e.g. 5m, 1h)") + flag.BoolVar(&expireOnView, "expire-on-view", false, "delete paste after it's viewed once") flag.Parse() http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {