[GO] Using Multipart to Upload Large Files over 4GB (a World First)

Preface

  • Recently I have been learning Go and using it for some projects. I also tested POSTing file uploads with net/http plus the built-in multipart.Writer to an old Java-based file server of mine.
  • The articles found online are generally very simple. Following their pattern, I wrote the following test code for uploading multiple files:
    for _, uf := range uploadFiles {
        fileInfo, _ := os.Stat(uf)
        file, _ := os.Open(uf)
        fileWrite, _ := mpWrite.CreateFormFile("file", filepath.Base(uf))
        writeLength, _ := io.Copy(fileWrite, file)
        assert.Equal(t, fileInfo.Size(), writeLength, "拷贝文件的内容到内存缓冲区中")
        _ = file.Close()
    }
  • As io.Copy(fileWrite, file) shows, the entire file content is copied into the bytes.Buffer backing mpWrite (the buffer passed to multipart.NewWriter), which inevitably consumes a lot of memory when uploading large files. A quick test confirmed it: memory usage was about twice the file size.
  • A web search shows that this problem is well known, but the usual solutions come in two flavors:
    • Split the file into small chunks and upload them separately. This looks good performance-wise, but has its own problems:
      • It is not compatible with older servers that only accept a whole file in a single upload.
      • Uploading still consumes as much memory as one chunk, unlike the small memory footprint achievable in C++/Java and other languages.
      • After uploading, the chunks may need to be merged on the server side, which requires that every server instance can reach all uploaded chunks (e.g. via NFS, or an IP-hash load-balancing policy), so there are quite a few constraints.
    • Use io.Pipe and friends. But this creates one pipe plus one goroutine per uploaded file (it could be optimized to only apply to files larger than some threshold), which still feels awkward; see the sketch right after this list.
  • I searched Google, Stack Overflow, ChatGPT, GitHub, etc., and so far have not found any multipart solution with a memory footprint as small as the C++/Java ones. If anyone knows of one, please let me know.
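  • For reference, a minimal sketch of that io.Pipe approach for a single file (illustrative only; uploadURL and filePath are placeholder names, and the imports "io", "mime/multipart", "net/http", "os", "path/filepath" are assumed):
    func uploadOneFileWithPipe(uploadURL, filePath string) (*http.Response, error) {
        pr, pw := io.Pipe()
        mw := multipart.NewWriter(pw)
        go func() {
            //everything written to mw goes into the pipe and is consumed by the HTTP transport
            part, err := mw.CreateFormFile("file", filepath.Base(filePath))
            if err != nil {
                pw.CloseWithError(err)
                return
            }
            f, err := os.Open(filePath)
            if err != nil {
                pw.CloseWithError(err)
                return
            }
            defer f.Close()
            if _, err = io.Copy(part, f); err != nil {
                pw.CloseWithError(err)
                return
            }
            pw.CloseWithError(mw.Close()) //writes the closing boundary, then closes the pipe
        }()

        req, err := http.NewRequest(http.MethodPost, uploadURL, pr)
        if err != nil {
            return nil, err
        }
        req.Header.Set("Content-Type", mw.FormDataContentType())
        return http.DefaultClient.Do(req)
    }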

Solution

  • Referring to multipart.Writer in the Go source code, I rewrote a VirtualWriter that reads and sends file contents on demand, so it can upload files larger than 4G.
  • Implemented features (a condensed usage sketch follows this list):
    • The interface follows multipart.Writer as closely as possible and supports multiple fields and files;
    • CreateFormFile does not read the file contents into memory; it only records the file information and reads the data on demand in Read;
    • An onProgressCallback hook is added so the upload progress can be tracked (this part would probably be hard to merge into the Go SDK).
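  • A condensed usage sketch (derived from the unit test further below; error handling and imports are trimmed, and uploadUrl plus the file path are placeholders):
    mpWrite := NewVirtualWriter()
    defer mpWrite.Close()

    _ = mpWrite.WriteField("key", "value")
    _ = mpWrite.CreateFormFile("file0", "/path/to/big.iso") //only stats the file, nothing is read into memory yet

    mpWrite.SetProgressCallback(func(part VirtualPart, err error, readCount, totalCount int64) {
        fmt.Printf("progress: %d/%d\n", readCount, totalCount)
    })

    //the writer itself is the request body; file contents are read on demand while the request is being sent
    req, _ := http.NewRequest(http.MethodPost, uploadUrl, mpWrite)
    req.Header.Set("Content-Type", mpWrite.FormDataContentType())
    resp, err := http.DefaultClient.Do(req)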

Source code (can be copied directly into your project)

GitHub: https://github.com/fishjam/go-library
Mirror (China): https://gitee.com/fishjam/go-library

import (
	"bufio"
	"bytes"
	"crypto/rand"
	"errors"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"strings"
	"sync"
)

// OnProgressCallback provides a progress callback while the real POST is in progress; it is useful when uploading large files.
//
// Notice: part is nil and err is io.EOF when the final end boundary is sent.
type OnProgressCallback func(part VirtualPart, err error, readCount, totalCount int64)

// VirtualPart handles a single part's content and provides detailed information in OnProgressCallback
type VirtualPart interface {
	// Name returns the current field name
	Name() string

	// Len returns the current part's length, including the boundary and headers
	Len() int64

	// Remain returns the remaining length while reading
	Remain() int64

	// Close closes the part (the underlying file for file parts)
	Close() error

	read(p []byte) (n int, err error)
}

type fieldPart struct {
	fieldName   string
	fieldValue  string
	fieldLength int64
	readOffset  int64
}

func (fp *fieldPart) Name() string {
	return fp.fieldName
}

func (fp *fieldPart) Len() int64 {
	return fp.fieldLength
}

func (fp *fieldPart) read(p []byte) (n int, err error) {
	reader := bytes.NewReader([]byte(fp.fieldValue[fp.readOffset:]))
	bufReader := bufio.NewReader(reader)
	n, err = bufReader.Read(p)
	if err == io.EOF {
		err = nil
	}
	fp.readOffset += int64(n)
	return
}
func (fp *fieldPart) Remain() int64 {
	return fp.fieldLength - fp.readOffset
}

func (fp *fieldPart) Close() error {
	return nil
}

type filePart struct {
	fieldValue  string
	fieldLength int64
	readOffset  int64
	filePath    string
	fileSize    int64
	file        *os.File
	once        sync.Once
}

func (fp *filePart) Name() string {
	return fp.filePath
}

func (fp *filePart) Len() int64 {
	//the extra 2 bytes are the trailing \r\n after the file content
	return fp.fieldLength + fp.fileSize + 2
}

func (fp *filePart) read(p []byte) (n int, err error) {
	var (
		nField    int
		nFile     int
		nLastCrLf int
	)
	fp.once.Do(func() {
		//open file
		fp.file, err = os.Open(fp.filePath)
		if err != nil {
			//opening the file failed, e.g. the file was deleted after CreateFormFile
			return
		}
	})
	if err != nil {
		//once.Do error
		return 0, err
	}
	if fp.readOffset < fp.fieldLength {
		// read from field
		reader := bytes.NewReader([]byte(fp.fieldValue[fp.readOffset:]))
		nField, err = reader.Read(p)
		fp.readOffset += int64(nField)
		if err == io.EOF {
			err = nil
		}
	}
	if fp.readOffset >= fp.fieldLength {
		//read from file
		fileOffset := fp.readOffset - fp.fieldLength
		nFile, err = fp.file.ReadAt(p[nField:], fileOffset)
		if err == io.EOF {
			err = nil
			//the file content has been fully read; append the trailing \r\n
			reader := bytes.NewReader([]byte("\r\n"))
			nLastCrLf, err = reader.Read(p[(nField + nFile):])
			fp.readOffset += int64(nLastCrLf) // 2 char
		}
		if err == nil {
			//total read count
			fp.readOffset += int64(nFile)
		}
	}
	//current read count
	n = nField + nFile + nLastCrLf
	return n, err
}

func (fp *filePart) Remain() int64 {
	return fp.Len() - fp.readOffset
}

func (fp *filePart) Close() (err error) {
	if fp.file != nil {
		err = fp.file.Close()
		fp.file = nil
	}
	return err
}

// VirtualWriter is similar to the built-in [mime/multipart.Writer],
// but supports uploading many large files (4G+) with very little memory consumption.
type VirtualWriter struct {
	closeAfterRead bool
	boundary       string
	parts          []VirtualPart
	finished       bool
	readPartIndex  int
	readCount      int64
	totalCount     int64
	callback       OnProgressCallback
}

// NewVirtualWriter returns a new multipart writer with a random boundary.
func NewVirtualWriter() *VirtualWriter {
	boundary := randomBoundary()
	return &VirtualWriter{
		closeAfterRead: true,
		boundary:       boundary,
		parts:          make([]VirtualPart, 0),
		readPartIndex:  0,
		readCount:      0,
		totalCount:     int64(len(boundary) + 6), //init total count is for last boundary(--%s--\r\n)
	}
}

// SetCloseAfterRead controls whether each part is closed immediately after it has been fully read; it only matters for file parts. Default is true.
func (vw *VirtualWriter) SetCloseAfterRead(closeAfterRead bool) {
	vw.closeAfterRead = closeAfterRead
}

// SetBoundary overrides the VirtualWriter's default randomly-generated
// boundary separator with an explicit value.
//
// Copied from the built-in multipart.Writer and modified to update totalCount.
func (vw *VirtualWriter) SetBoundary(boundary string) error {
	if len(vw.parts) > 0 {
		return errors.New("mime: SetBoundary called after write")
	}

	// rfc2046#section-5.1.1
	if len(boundary) < 1 || len(boundary) > 70 {
		return errors.New("mime: invalid boundary length")
	}

	end := len(boundary) - 1
	for i, b := range boundary {
		if 'A' <= b && b <= 'Z' || 'a' <= b && b <= 'z' || '0' <= b && b <= '9' {
			continue
		}
		switch b {
		case '\'', '(', ')', '+', '_', ',', '-', '.', '/', ':', '=', '?':
			continue
		case ' ':
			if i != end {
				continue
			}
		}
		return errors.New("mime: invalid boundary character")
	}
	vw.boundary = boundary

	vw.totalCount = int64(len(boundary) + 6) //init total count is for last boundary(--%s--\r\n)
	return nil
}

// FormDataContentType returns the Content-Type for an HTTP
// multipart/form-data with this Writer's Boundary.
//
// Copied from the built-in multipart.Writer.FormDataContentType.
func (vw *VirtualWriter) FormDataContentType() string {
	b := vw.boundary
	// We must quote the boundary if it contains any of the
	// tspecials characters defined by RFC 2045, or space.
	if strings.ContainsAny(b, `()<>@,;:\"/[]?= `) {
		b = `"` + b + `"`
	}
	return "multipart/form-data; boundary=" + b
}

func randomBoundary() string {
	var buf [30]byte
	_, err := io.ReadFull(rand.Reader, buf[:])
	if err != nil {
		panic(err)
	}
	return fmt.Sprintf("%x", buf[:])
}

var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")

func escapeQuotes(s string) string {
	return quoteEscaper.Replace(s)
}

// SetProgressCallback sets the progress callback function.
func (vw *VirtualWriter) SetProgressCallback(callback OnProgressCallback) {
	vw.callback = callback
}

// CreateFormFile creates a new form-data header with the provided field name and file path.
// It only records the file information; the file contents are not read into memory until the actual POST happens.
func (vw *VirtualWriter) CreateFormFile(fieldName, filePath string) error {
	stat, err := os.Stat(filePath)
	if err != nil {
		return err
	}

	fieldValue := fmt.Sprintf("--%s\r\nContent-Disposition: form-data; name=\"%s\"; filename=\"%s\"\r\n"+
		"Content-Type: application/octet-stream\r\n\r\n",
		vw.boundary, escapeQuotes(fieldName), filepath.Base(filePath))

	part := filePart{
		fieldValue:  fieldValue,
		filePath:    filePath,
		fieldLength: int64(len(fieldValue)),
		fileSize:    stat.Size(),
		readOffset:  0,
	}
	vw.parts = append(vw.parts, &part)
	vw.totalCount += part.Len()
	return nil
}

// WriteField creates a new form-data header with the provided field name and value
func (vw *VirtualWriter) WriteField(fieldName, value string) error {
	fieldVal := fmt.Sprintf("--%s\r\nContent-Disposition: form-data; name=\"%s\"\r\n\r\n%s\r\n",
		vw.boundary,
		escapeQuotes(fieldName), value)
	part := fieldPart{
		fieldName:   fieldName,
		fieldValue:  fieldVal,
		fieldLength: int64(len(fieldVal)),
		readOffset:  0,
	}
	vw.parts = append(vw.parts, &part)
	vw.totalCount += part.Len()
	return nil
}

// Read implements io.Reader (so the writer can be used as the body parameter of http.NewRequest);
// users should NOT call it directly.
func (vw *VirtualWriter) Read(p []byte) (nRead int, err error) {
	var (
		nReadLastBoundary int
	)

	if vw.readPartIndex < len(vw.parts) {
		part := vw.parts[vw.readPartIndex]
		nRead, err = part.read(p) //p[nb:])
		//log.Printf("idx[%d], part read, nRead=%d, remain=%d, err=%+v",
		//	vw.readPartIndex, nRead, part.Remain(), err)

		if err == nil {
			//TODO: detect reading more data than expected (e.g. the file changed after CreateFormFile)
			//return -1, errors.New("read more data than available")
			if part.Remain() == 0 {
				if vw.closeAfterRead {
					_ = part.Close()
				}
				//the current part is fully read; move on to the next one
				vw.readPartIndex++
			}
			vw.readCount += int64(nRead)
		}
		if vw.callback != nil {
			vw.callback(part, err, vw.readCount, vw.totalCount)
		}
	}
	if vw.readPartIndex >= len(vw.parts) {
		//all parts have been read; append the closing boundary
		strLastBoundary := fmt.Sprintf("--%s--\r\n", vw.boundary)
		reader := bytes.NewReader([]byte(strLastBoundary))
		nReadLastBoundary, err = reader.Read(p[nRead:])
		if err == nil {
			nRead += nReadLastBoundary
			err = io.EOF
		}
		vw.readCount += int64(nReadLastBoundary)
		if vw.callback != nil {
			vw.callback(nil, err, vw.readCount, vw.totalCount)
		}
	}
	return nRead, err
}

// ContentLength returns the total content length of all parts; it should match the "Content-Length" value in the HTTP header.
func (vw *VirtualWriter) ContentLength() int64 {
	return vw.totalCount
}

// Close closes all the parts
func (vw *VirtualWriter) Close() error {
	var err error
	for _, part := range vw.parts {
		pErr := part.Close()
		if err == nil && pErr != nil {
			//just return first error
			err = pErr
		}
	}
	vw.readCount = 0
	vw.totalCount = 0
	return err
}

Unit test (demonstrates uploading 4G+ files)

// TestUploadFilesWithVirtualWriter
// after this test case runs, it creates an upload folder, uploads some files into it,
// and then compares the MD5 checksums of the source and uploaded files
func TestUploadFilesWithVirtualWriter(t *testing.T) {
	//local Fiddler proxy port; if not 0 (e.g. 8888), the local Fiddler proxy can be used to monitor the network traffic
	localProxyPort := 0

	//this test removes all uploaded files afterwards; if you want to keep them for comparison,
	//set removeUploadFiles to false
	removeUploadFiles := true

	uploadTempFolder := debugutil.VerifyWithResult(os.MkdirTemp(os.TempDir(), "virtual_writer"))
	t.Logf("uploadTempFolder=%s", uploadTempFolder)
	if removeUploadFiles {
		defer func() {
			_ = debugutil.Verify(os.RemoveAll(uploadTempFolder))
		}()
	}

	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mpReader := debugutil.VerifyWithResult(r.MultipartReader())

		uploadFields := make(map[string]string)
		_ = debugutil.Verify(os.MkdirAll(uploadTempFolder, 0755))

		for {
			part, err := mpReader.NextPart()
			//part, err := utils.VerifyWithResultEx[*multipart.Part](mpReader.NextPart())
			if err == io.EOF {
				break
			}

			name := part.FileName()
			if name == "" {
				buf := make([]byte, 1024)
				n, err := part.Read(buf)
				//t.Logf("part.Read error:%+v, n=%d", err, n)
				if err != nil { //err == io.EOF means the field value has been fully read; buf[0:n] holds the value
					uploadFields[part.FormName()] = string(buf[0:n])
				} else {
					uploadFields[part.FormName()] = string(buf[0:n]) //TODO: use a larger buffer for long values?
				}
				}
				continue
			}
			uploadFields[part.FormName()] = part.FileName()

			filename := part.FileName()
			filePath := path.Join(uploadTempFolder, filename)
			outFile := debugutil.VerifyWithResult(os.Create(filePath))
			defer outFile.Close()

			_ = debugutil.VerifyWithResult(io.Copy(outFile, part))
		}
		marshalResult := debugutil.VerifyWithResult(json.Marshal(uploadFields))
		w.WriteHeader(http.StatusOK)
		_ = debugutil.VerifyWithResult(w.Write(marshalResult))
	}))

	defer ts.Close()
	t.Logf("ts.URL=%s", ts.URL)

	uploadResultExpected := make(map[string]string)

	uploadUrl := fmt.Sprintf("%s%s", ts.URL, "/upload")
	uploadFiles := []string{
		"virtual_writer.go",
		"virtual_writer_test.go",
		//"not_exist",

		// put the path of a large file you want to test here; 4G+ is supported
		//"F:\\ISO\\Windows\\Win10\\Win10_21H2_x64_CN_20220412.iso",
	}
	params := map[string]string{
		"key":          "value",
		"type":         "data",
		"keepFileName": "on",
	}

	transport := http.DefaultTransport.(*http.Transport).Clone()
	if localProxyPort != 0 {
		proxyAddr := fmt.Sprintf("http://127.0.0.1:%d", localProxyPort)
		proxyUrl, _ := url.Parse(proxyAddr)
		proxyFunc := http.ProxyURL(proxyUrl)
		transport.Proxy = proxyFunc
	}

	client := http.Client{
		Transport: transport,
	}
	mpWrite := NewVirtualWriter()
	defer func() {
		_ = debugutil.Verify(mpWrite.Close())
	}()

	//_ = mpWrite.SetBoundary("----WebKitFormBoundary2HwfBAoBw2hJ33gD")
	mpWrite.SetProgressCallback(func(part VirtualPart, err error, readCount, totalCount int64) {
		name := ""
		if part != nil {
			name = part.Name()
		}
		_ = name

		t.Logf("on progress: part name=%s, err=%+v, readCount=%d, totalCount=%d, percent=%0.2f",
			name, err, readCount, totalCount, float64(readCount*100)/float64(totalCount))
		debugutil.GoAssertTrue(t, readCount <= totalCount, "progress")
	})

	for key, val := range params {
		_ = mpWrite.WriteField(key, val)
		uploadResultExpected[key] = val
	}

	fileIndex := 0
	for _, uf := range uploadFiles {
		fieldName := fmt.Sprintf("file%d", fileIndex)
		if err := debugutil.Verify(mpWrite.CreateFormFile(fieldName, uf)); err != nil {
			t.Logf("CreateFormFile error, %+v", err)
			//in real code, should handle this error and maybe return

			//return
		}
		uploadResultExpected[fieldName] = filepath.Base(uf)
		fileIndex++
	}
	t.Logf("content-length=%d", mpWrite.ContentLength())

	req, err := http.NewRequest(http.MethodPost, uploadUrl, mpWrite)
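	//note: mpWrite is a plain io.Reader, so net/http sends the body with chunked transfer encoding by default;
	//if the server requires an explicit Content-Length header, req.ContentLength = mpWrite.ContentLength() could be set here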

	req.Header.Set("Content-Type", mpWrite.FormDataContentType())
	resp, err := client.Do(req)
	debugutil.GoAssertTrue(t, err == nil, "client.Do should successful")

	if err == nil {
		body := debugutil.VerifyWithResult[[]byte](io.ReadAll(resp.Body))
		defer resp.Body.Close()

		var uploadResponse map[string]string
		_ = debugutil.Verify(json.Unmarshal(body, &uploadResponse))
		t.Logf("response=%s, err=%+v", string(body), err)

		debugutil.GoAssertEqual(t, uploadResultExpected, uploadResponse, "upload result")
	}

	//check the uploaded files: their checksums (MD5) should match the original files
	for idx, uf := range uploadFiles {
		srcSum := fileCheckSum(uf)

		dstPath := path.Join(uploadTempFolder, filepath.Base(uf))
		dstSum := fileCheckSum(dstPath)

		debugutil.GoAssertEqual(t, srcSum, dstSum, fmt.Sprintf("compare[%d] %s <=> %s", idx, uf, dstPath))
	}

	//sleep here to leave time to check memory usage manually (or use runtime.MemStats?)
	if false {
		time.Sleep(time.Second * 30)
	}

}

func fileCheckSum(fileName string) string {
	f, err := os.Open(fileName)
	if err != nil {
		log.Fatal(err)
	}

	defer f.Close()

	h := md5.New()
	//h := sha256.New()
	//h := sha1.New()
	//h := sha512.New()

	if _, err := io.Copy(h, f); err != nil {
		log.Fatal(err)
	}

	return fmt.Sprintf("%x", h.Sum(nil))
}
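To run this test locally (assuming the files live in the go-library repository and a Go toolchain is installed), something like the following should work:

    go test -v -run TestUploadFilesWithVirtualWriter ./...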

Remaining issues

  • I filed issue #65203 with golang/go, hoping this could later be merged into the SDK so everyone can use it more easily. However, because of compatibility concerns with the existing API, getting it submitted is quite cumbersome and far less convenient than maintaining it myself, so I have given up on that.
