Skip to content
Snippets Groups Projects
Commit a43b2e94 authored by Jacob Vosmaer (GitLab)'s avatar Jacob Vosmaer (GitLab)
Browse files

Put multipart rewrite work in functions

parent c28f06f7
No related branches found
No related tags found
1 merge request!100Prometheus metrics for multipart file extraction
Pipeline #
Loading
Loading
@@ -37,6 +37,13 @@ var (
)
)
 
type rewriter struct {
writer *multipart.Writer
tempPath string
filter MultipartFormProcessor
directories []string
}
func init() {
prometheus.MustRegister(multipartUploadRequests)
prometheus.MustRegister(multipartFileUploadBytes)
Loading
Loading
@@ -56,10 +63,14 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te
 
multipartUploadRequests.Inc()
 
var directories []string
rew := &rewriter{
writer: writer,
tempPath: tempPath,
filter: filter,
}
 
cleanup = func() {
for _, dir := range directories {
for _, dir := range rew.directories {
os.RemoveAll(dir)
}
}
Loading
Loading
@@ -83,60 +94,77 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te
}
 
// Copy form field
if filename := p.FileName(); filename != "" {
multipartFiles.Inc()
if strings.Contains(filename, "/") || filename == "." || filename == ".." {
return cleanup, fmt.Errorf("illegal filename: %q", filename)
}
// Create temporary directory where the uploaded file will be stored
if err := os.MkdirAll(tempPath, 0700); err != nil {
return cleanup, fmt.Errorf("mkdir for tempfile: %v", err)
}
tempDir, err := ioutil.TempDir(tempPath, "multipart-")
if err != nil {
return cleanup, fmt.Errorf("create tempdir: %v", err)
}
directories = append(directories, tempDir)
file, err := os.OpenFile(path.Join(tempDir, filename), os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
return cleanup, fmt.Errorf("rewriteFormFilesFromMultipart: temp file: %v", err)
}
defer file.Close()
// Add file entry
writer.WriteField(name+".path", file.Name())
writer.WriteField(name+".name", filename)
written, err := io.Copy(file, p)
if err != nil {
return cleanup, fmt.Errorf("copy from multipart to tempfile: %v", err)
}
multipartFileUploadBytes.Add(float64(written))
file.Close()
if err := filter.ProcessFile(name, file.Name(), writer); err != nil {
return cleanup, err
}
if p.FileName() != "" {
err = rew.handleFilePart(name, p)
} else {
np, err := writer.CreatePart(p.Header)
if err != nil {
return cleanup, fmt.Errorf("create multipart field: %v", err)
}
_, err = io.Copy(np, p)
if err != nil {
return cleanup, fmt.Errorf("duplicate multipart field: %v", err)
}
if err := filter.ProcessField(name, writer); err != nil {
return cleanup, fmt.Errorf("process multipart field: %v", err)
}
err = rew.copyPart(name, p)
}
if err != nil {
return cleanup, err
}
}
return cleanup, nil
}
func (rew *rewriter) handleFilePart(name string, p *multipart.Part) error {
multipartFiles.Inc()
filename := p.FileName()
if strings.Contains(filename, "/") || filename == "." || filename == ".." {
return fmt.Errorf("illegal filename: %q", filename)
}
// Create temporary directory where the uploaded file will be stored
if err := os.MkdirAll(rew.tempPath, 0700); err != nil {
return fmt.Errorf("mkdir for tempfile: %v", err)
}
tempDir, err := ioutil.TempDir(rew.tempPath, "multipart-")
if err != nil {
return fmt.Errorf("create tempdir: %v", err)
}
rew.directories = append(rew.directories, tempDir)
file, err := os.OpenFile(path.Join(tempDir, filename), os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
return fmt.Errorf("rewriteFormFilesFromMultipart: temp file: %v", err)
}
defer file.Close()
// Add file entry
rew.writer.WriteField(name+".path", file.Name())
rew.writer.WriteField(name+".name", filename)
written, err := io.Copy(file, p)
if err != nil {
return fmt.Errorf("copy from multipart to tempfile: %v", err)
}
multipartFileUploadBytes.Add(float64(written))
file.Close()
if err := rew.filter.ProcessFile(name, file.Name(), rew.writer); err != nil {
return err
}
return nil
}
func (rew *rewriter) copyPart(name string, p *multipart.Part) error {
np, err := rew.writer.CreatePart(p.Header)
if err != nil {
return fmt.Errorf("create multipart field: %v", err)
}
if _, err := io.Copy(np, p); err != nil {
return fmt.Errorf("duplicate multipart field: %v", err)
}
if err := rew.filter.ProcessField(name, rew.writer); err != nil {
return fmt.Errorf("process multipart field: %v", err)
}
return nil
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment