使用 Go 标准库的 multipart.Writer 来实现 POST，上传文件到我以前的一个 Java 版本的文件服务器上，关键代码如下：
for _, uf := range uploadFiles {
fileInfo, _ := os.Stat(uf)
file, _ := os.Open(uf)
fileWrite, _ := mpWrite.CreateFormFile("file", filepath.Base(uf))
writeLength, _ := io.Copy(fileWrite, file)
assert.Equal(t, fileInfo.Size(), writeLength, "拷贝文件的内容到内存缓冲区中")
_ = file.Close()
}
从 io.Copy(fileWrite, file) 来看，需要把文件的所有内容都拷贝到 bytes.Buffer 实例中，因此上传大文件时必然占用大量内存。实测也验证了这一点：占用的内存约为文件大小的 2 倍。
参考 Go SDK 中 mime/multipart/writer.go 源码里的 multipart.Writer，重写了一个 VirtualWriter，实现按需读取文件内容并发送，从而可以支持超过 4G 的文件上传。主要特点：
- 用法与 multipart.Writer 类似，支持多个 field 和 file；
- CreateFormFile 时不需要将文件内容读入内存，而是只保存文件信息，在 Read 时按需读取；
- 提供 OnProgressCallback 回调，从而可以知道文件上传进度（这个功能似乎不容易合并到 Go SDK 中）。

GitHub: https://github.com/fishjam/go-library
国内镜像: https://gitee.com/fishjam/go-library
import (
"bufio"
"bytes"
"crypto/rand"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"sync"
)
type (
	// OnProgressCallback is invoked by VirtualWriter.Read while the request body
	// is being streamed, which makes it possible to report progress when
	// uploading large files.
	//
	// part is the part that was just read from; it is nil for the calls that emit
	// the closing boundary, and err is io.EOF once that boundary has been sent.
	// readCount is the number of body bytes produced so far and totalCount the
	// expected size of the whole body.
	OnProgressCallback func(part VirtualPart, err error, readCount, totalCount int64)

	// VirtualPart is one section of the multipart body (a plain form field or a
	// file). It produces its bytes on demand and exposes enough state for
	// progress reporting through OnProgressCallback.
	VirtualPart interface {
		// Name identifies the part in progress reports (for plain fields this is
		// the form field name).
		Name() string
		// Len returns the total encoded size of the part, boundary included.
		Len() int64
		// Remain returns how many bytes of the part have not been read yet.
		Remain() int64
		// Close releases any resource held by the part (e.g. a file part's file).
		Close() error

		read(p []byte) (n int, err error)
	}
)
// fieldPart is the VirtualPart for a plain (non-file) form field. The complete
// encoded part is built up front and kept in fieldValue; read hands it out
// piece by piece.
type fieldPart struct {
	fieldName   string // form field name, reported by Name
	fieldValue  string // the fully encoded part bytes
	fieldLength int64  // len(fieldValue)
	readOffset  int64  // bytes of fieldValue already handed out
}

// Name returns the form field name.
func (fp *fieldPart) Name() string { return fp.fieldName }

// Len returns the size in bytes of the encoded part.
func (fp *fieldPart) Len() int64 { return fp.fieldLength }

// read copies the not-yet-sent bytes of the encoded part into p and advances
// the read offset. Reaching the end of the data is not reported as an error;
// callers detect completion through Remain.
//
// NOTE(review): every call converts the remaining string to []byte and
// allocates a fresh bufio.Reader; a plain copy(p, fp.fieldValue[fp.readOffset:])
// would give the same result without either allocation.
func (fp *fieldPart) read(p []byte) (n int, err error) {
	pending := []byte(fp.fieldValue[fp.readOffset:])
	n, err = bufio.NewReader(bytes.NewReader(pending)).Read(p)
	fp.readOffset += int64(n)
	if err == io.EOF {
		return n, nil
	}
	return n, err
}

// Remain returns how many bytes of the encoded part are still unread.
func (fp *fieldPart) Remain() int64 { return fp.fieldLength - fp.readOffset }

// Close is a no-op: a field part holds no external resources.
func (fp *fieldPart) Close() error { return nil }
// filePart is the VirtualPart for a file upload. Only the multipart header is
// kept in memory; the file is opened on the first read and its contents are
// streamed with ReadAt, followed by the CRLF that terminates the part.
type filePart struct {
	fieldValue  string    // encoded header: boundary, Content-Disposition, Content-Type, blank line
	fieldLength int64     // len(fieldValue)
	readOffset  int64     // bytes of the whole part (header + file + CRLF) already returned
	filePath    string    // file to upload; also what Name reports
	fileSize    int64     // file size recorded when the part was created
	file        *os.File  // opened lazily by read, released by Close
	openErr     error     // sticky result of the lazy open
	once        sync.Once // guards the lazy open
}

// Name returns the path of the file being uploaded (not the form field name).
func (fp *filePart) Name() string {
	return fp.filePath
}

// Len returns the total size of the part: header, file contents and the
// trailing "\r\n" (the final 2 bytes).
func (fp *filePart) Len() int64 {
	return fp.fieldLength + fp.fileSize + 2
}

// read copies the next bytes of the part into p. The part consists of three
// segments — the in-memory header, the file contents and a terminating CRLF.
// A single call may span several segments, and any segment may be split
// across calls. The file is opened on the first call.
//
// At most fileSize bytes of the file are sent, so the part never exceeds Len()
// even if the file grew after it was added. If the file shrank,
// io.ErrUnexpectedEOF is returned. Running out of data is not an error: once
// Remain() is 0, read returns 0, nil.
func (fp *filePart) read(p []byte) (n int, err error) {
	const trailer = "\r\n"

	fp.once.Do(func() {
		fp.file, fp.openErr = os.Open(fp.filePath)
	})
	if fp.openErr != nil {
		// e.g. the file was deleted after CreateFormFile; keep reporting it on every call
		return 0, fp.openErr
	}
	if fp.file == nil {
		// Close was called before the part was fully read
		return 0, &os.PathError{Op: "read", Path: fp.filePath, Err: os.ErrClosed}
	}

	// 1. the multipart header
	if fp.readOffset < fp.fieldLength {
		n = copy(p, fp.fieldValue[fp.readOffset:])
		fp.readOffset += int64(n)
	}

	// 2. the file contents, capped at the size recorded when the part was created
	fileEnd := fp.fieldLength + fp.fileSize
	if n < len(p) && fp.readOffset >= fp.fieldLength && fp.readOffset < fileEnd {
		buf := p[n:]
		if left := fileEnd - fp.readOffset; int64(len(buf)) > left {
			buf = buf[:left]
		}
		nFile, rErr := fp.file.ReadAt(buf, fp.readOffset-fp.fieldLength)
		n += nFile
		fp.readOffset += int64(nFile)
		if nFile < len(buf) {
			// ReadAt reports an error on every short read
			if rErr == io.EOF {
				// the file shrank: the advertised length can no longer be met
				rErr = io.ErrUnexpectedEOF
			}
			return n, rErr
		}
	}

	// 3. the trailing CRLF, resuming where a previous call stopped
	if n < len(p) && fp.readOffset >= fileEnd {
		if sent := fp.readOffset - fileEnd; sent < int64(len(trailer)) {
			c := copy(p[n:], trailer[sent:])
			n += c
			fp.readOffset += int64(c)
		}
	}
	return n, nil
}

// Remain returns how many bytes of the part are still unread.
func (fp *filePart) Remain() int64 {
	return fp.Len() - fp.readOffset
}

// Close closes the file if read opened it; calling it more than once is safe.
func (fp *filePart) Close() (err error) {
	if fp.file != nil {
		err = fp.file.Close()
		fp.file = nil
	}
	return err
}
// VirtualWriter is similar as builtin [mime/multipart.writer],
// but can support upload lots of files (4G+) with little memory consume.
type VirtualWriter struct {
closeAfterRead bool
boundary string
parts []VirtualPart
finished bool
readPartIndex int
readCount int64
totalCount int64
callback OnProgressCallback
}
// NewVirtualWriter creates an empty VirtualWriter that uses a randomly
// generated boundary and closes every part right after it has been read.
func NewVirtualWriter() *VirtualWriter {
	vw := &VirtualWriter{
		boundary:       randomBoundary(),
		closeAfterRead: true,
		parts:          make([]VirtualPart, 0),
	}
	// Even an empty body ends with the closing delimiter "--<boundary>--\r\n",
	// which is len(boundary)+6 bytes.
	vw.totalCount = int64(len(vw.boundary) + 6)
	return vw
}
// SetCloseAfterRead controls whether Read closes each part as soon as all of
// its bytes have been delivered. This matters for file parts, which keep their
// file open until closed. The default is true.
func (vw *VirtualWriter) SetCloseAfterRead(closeAfterRead bool) {
	vw.closeAfterRead = closeAfterRead
}
// SetBoundary replaces the randomly generated boundary with an explicit one.
// It must be called before any field or file is added. Per RFC 2046 the
// boundary must be 1 to 70 characters long. It may contain only letters,
// digits and the characters '()+_,-./:=? plus space, and a space may not be
// the last character.
//
// Adapted from the standard library's multipart.Writer.SetBoundary; it also
// resets the byte count reserved for the closing boundary.
func (vw *VirtualWriter) SetBoundary(boundary string) error {
	if len(vw.parts) != 0 {
		return errors.New("mime: SetBoundary called after write")
	}
	if n := len(boundary); n == 0 || n > 70 {
		return errors.New("mime: invalid boundary length")
	}

	lastIdx := len(boundary) - 1
	allowed := func(idx int, r rune) bool {
		switch {
		case 'A' <= r && r <= 'Z', 'a' <= r && r <= 'z', '0' <= r && r <= '9':
			return true
		case r == ' ':
			return idx != lastIdx
		}
		switch r {
		case '\'', '(', ')', '+', '_', ',', '-', '.', '/', ':', '=', '?':
			return true
		}
		return false
	}
	for idx, r := range boundary {
		if !allowed(idx, r) {
			return errors.New("mime: invalid boundary character")
		}
	}

	vw.boundary = boundary
	// room for the closing delimiter "--<boundary>--\r\n"
	vw.totalCount = int64(len(boundary) + 6)
	return nil
}
// FormDataContentType returns the value to use for the request's Content-Type
// header: multipart/form-data with this writer's boundary. The boundary is
// quoted when it contains an RFC 2045 tspecial character or a space.
//
// Adapted from the standard library's multipart.Writer.FormDataContentType.
func (vw *VirtualWriter) FormDataContentType() string {
	const prefix = "multipart/form-data; boundary="
	if !strings.ContainsAny(vw.boundary, `()<>@,;:\"/[]?= `) {
		return prefix + vw.boundary
	}
	return prefix + `"` + vw.boundary + `"`
}
// randomBoundary returns 30 bytes from crypto/rand encoded as 60 lowercase hex
// characters. It panics if the system random source fails.
func randomBoundary() string {
	raw := make([]byte, 30)
	if _, err := io.ReadFull(rand.Reader, raw); err != nil {
		panic(err)
	}
	return fmt.Sprintf("%x", raw)
}
// quoteEscaper prefixes every backslash and double quote with a backslash so
// the result can be embedded in a quoted-string header parameter.
var quoteEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`)

// escapeQuotes escapes s for use inside a double-quoted header parameter such
// as name="..." in Content-Disposition.
func escapeQuotes(s string) string {
	escaped := quoteEscaper.Replace(s)
	return escaped
}
// SetProgressCallback registers the function Read invokes to report upload
// progress. Passing nil disables progress reporting.
func (vw *VirtualWriter) SetProgressCallback(callback OnProgressCallback) {
	vw.callback = callback
}
// CreateFormFile adds a file part under the given form field name. The part's
// header is built immediately and the file's size is recorded with os.Stat.
// The file contents are not read until the body is streamed through Read, so
// adding large files costs almost no memory.
//
// The file name sent to the server is the base name of filePath. An error is
// returned if filePath cannot be stat'ed or is a directory.
func (vw *VirtualWriter) CreateFormFile(fieldName, filePath string) error {
	stat, err := os.Stat(filePath)
	if err != nil {
		return err
	}
	if stat.IsDir() {
		// a directory has no content to stream; fail now instead of midway through the upload
		return fmt.Errorf("mime: cannot upload directory %q", filePath)
	}
	// Both names are quoted-strings in the header, so quotes and backslashes must
	// be escaped (the standard library's CreateFormFile escapes both as well).
	fieldValue := fmt.Sprintf("--%s\r\nContent-Disposition: form-data; name=\"%s\"; filename=\"%s\"\r\n"+
		"Content-Type: application/octet-stream\r\n\r\n",
		vw.boundary, escapeQuotes(fieldName), escapeQuotes(filepath.Base(filePath)))
	part := &filePart{
		fieldValue:  fieldValue,
		filePath:    filePath,
		fieldLength: int64(len(fieldValue)),
		fileSize:    stat.Size(),
		readOffset:  0,
	}
	vw.parts = append(vw.parts, part)
	vw.totalCount += part.Len()
	return nil
}
// WriteField adds a plain form field. The complete encoded part is built in
// memory right away: opening boundary, Content-Disposition header, blank line,
// value and trailing CRLF. It currently never returns an error.
func (vw *VirtualWriter) WriteField(fieldName, value string) error {
	encoded := fmt.Sprintf("--%s\r\nContent-Disposition: form-data; name=\"%s\"\r\n\r\n%s\r\n",
		vw.boundary, escapeQuotes(fieldName), value)
	fp := &fieldPart{
		fieldName:   fieldName,
		fieldValue:  encoded,
		fieldLength: int64(len(encoded)),
	}
	vw.parts = append(vw.parts, fp)
	vw.totalCount += fp.Len()
	return nil
}
// Read implements io.Reader so the VirtualWriter can be passed directly as a
// request body (for example to http.NewRequest); user code should NOT call it
// directly.
//
// Each call serves bytes from at most one part, in the order the parts were
// added. Once every part is exhausted, the closing boundary "--<boundary>--\r\n"
// is emitted, split across calls if p is too small. io.EOF is returned together
// with its final byte. Any further call returns 0, io.EOF; net/http, for
// instance, reads past Content-Length to detect an oversized body.
//
// The progress callback, if set, is invoked after every part read and after
// every closing-boundary read (with a nil part).
func (vw *VirtualWriter) Read(p []byte) (nRead int, err error) {
	if vw.readPartIndex < len(vw.parts) {
		part := vw.parts[vw.readPartIndex]
		nRead, err = part.read(p)
		// bytes returned together with an error are still delivered to the caller
		vw.readCount += int64(nRead)
		// "<= 0" rather than "== 0": a part that over-delivers must not keep the
		// reader stuck on it forever
		if err == nil && part.Remain() <= 0 {
			if vw.closeAfterRead {
				_ = part.Close() // best effort: all of its bytes were already delivered
			}
			vw.readPartIndex++
		}
		if vw.callback != nil {
			vw.callback(part, err, vw.readCount, vw.totalCount)
		}
		if err != nil || vw.readPartIndex < len(vw.parts) {
			return nRead, err
		}
	}

	// All parts are done: emit the closing boundary. How much of it was already
	// sent is derived from the byte counters, because one call may not have room
	// for all of it (or for any of it, when the last part filled p exactly).
	lastBoundary := "--" + vw.boundary + "--\r\n"
	boundaryLen := int64(len(lastBoundary))
	sent := vw.readCount - (vw.totalCount - boundaryLen)
	if sent < 0 {
		sent = 0
	} else if sent > boundaryLen {
		sent = boundaryLen
	}
	n := copy(p[nRead:], lastBoundary[sent:])
	nRead += n
	vw.readCount += int64(n)
	if sent+int64(n) == boundaryLen {
		err = io.EOF
	}
	if vw.callback != nil {
		vw.callback(nil, err, vw.readCount, vw.totalCount)
	}
	return nRead, err
}
// ContentLength reports the number of bytes the body is expected to contain:
// every part plus the closing boundary. File parts are counted using the file
// sizes recorded when they were added. This is the value for the request's
// Content-Length header. After Close it returns 0.
func (vw *VirtualWriter) ContentLength() int64 {
	total := vw.totalCount
	return total
}
// Close closes every part, releasing any open files, and resets the byte
// counters so that ContentLength reports 0 afterwards. Every part is closed
// even if some fail; the first error encountered is returned.
func (vw *VirtualWriter) Close() error {
	var firstErr error
	for i := range vw.parts {
		if cerr := vw.parts[i].Close(); cerr != nil && firstErr == nil {
			firstErr = cerr
		}
	}
	vw.readCount, vw.totalCount = 0, 0
	return firstErr
}
// TestUploadFilesWithVirtualWriter uploads some form fields and files through a
// VirtualWriter to a local httptest server. The server saves the files into a
// temporary upload folder and echoes the received field values and file names
// back as JSON. The test checks that echo, then compares the MD5 of every
// source file with its uploaded copy.
func TestUploadFilesWithVirtualWriter(t *testing.T) {
	// local fiddler proxy port; if not 0 (example: 8888), requests go through the
	// local proxy so the network data can be monitored
	localProxyPort := 0
	// the uploaded files are removed after the test; set removeUploadFiles to
	// false to keep them for manual comparison
	removeUploadFiles := true

	uploadTempFolder := debugutil.VerifyWithResult(os.MkdirTemp(os.TempDir(), "virtual_writer"))
	t.Logf("uploadTempFolder=%s", uploadTempFolder)
	if removeUploadFiles {
		defer func() {
			_ = debugutil.Verify(os.RemoveAll(uploadTempFolder))
		}()
	}

	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mpReader := debugutil.VerifyWithResult(r.MultipartReader())
		uploadFields := make(map[string]string)
		_ = debugutil.Verify(os.MkdirAll(uploadTempFolder, 0755))
		for {
			part, err := mpReader.NextPart()
			if err == io.EOF {
				break
			}
			if err != nil {
				// a malformed body must fail the test, not dereference a nil part
				t.Errorf("NextPart error: %+v", err)
				http.Error(w, err.Error(), http.StatusBadRequest)
				return
			}
			if part.FileName() == "" {
				// plain field: read the whole value; a single Read may return only part of it
				value := debugutil.VerifyWithResult(io.ReadAll(part))
				uploadFields[part.FormName()] = string(value)
				continue
			}
			uploadFields[part.FormName()] = part.FileName()
			filePath := path.Join(uploadTempFolder, part.FileName())
			// closure, so each file is closed right after it is written instead of
			// when the handler returns
			func() {
				outFile := debugutil.VerifyWithResult(os.Create(filePath))
				defer func() {
					_ = debugutil.Verify(outFile.Close())
				}()
				_ = debugutil.VerifyWithResult(io.Copy(outFile, part))
			}()
		}
		marshalResult := debugutil.VerifyWithResult(json.Marshal(uploadFields))
		w.WriteHeader(http.StatusOK)
		_ = debugutil.VerifyWithResult(w.Write(marshalResult))
	}))
	defer ts.Close()
	t.Logf("ts.URL=%s", ts.URL)

	uploadResultExpected := make(map[string]string)
	uploadUrl := fmt.Sprintf("%s%s", ts.URL, "/upload")
	uploadFiles := []string{
		"virtual_writer.go",
		"virtual_writer_test.go",
		//"not_exist",
		// put the path of a large file here to test uploading it (4G+ is supported)
		//"F:\\ISO\\Windows\\Win10\\Win10_21H2_x64_CN_20220412.iso",
	}
	params := map[string]string{
		"key":          "value",
		"type":         "data",
		"keepFileName": "on",
	}

	transport := http.DefaultTransport.(*http.Transport).Clone()
	if localProxyPort != 0 {
		proxyAddr := fmt.Sprintf("http://127.0.0.1:%d", localProxyPort)
		proxyUrl, err := url.Parse(proxyAddr)
		if err != nil {
			t.Fatalf("parse proxy address %q: %+v", proxyAddr, err)
		}
		transport.Proxy = http.ProxyURL(proxyUrl)
	}
	client := http.Client{
		Transport: transport,
	}

	mpWrite := NewVirtualWriter()
	defer func() {
		_ = debugutil.Verify(mpWrite.Close())
	}()
	//_ = mpWrite.SetBoundary("----WebKitFormBoundary2HwfBAoBw2hJ33gD")
	mpWrite.SetProgressCallback(func(part VirtualPart, err error, readCount, totalCount int64) {
		name := ""
		if part != nil {
			name = part.Name()
		}
		t.Logf("on progress: part name=%s, err=%+v, readCount=%d, totalCount=%d, percent=%0.2f",
			name, err, readCount, totalCount, float64(readCount*100)/float64(totalCount))
		debugutil.GoAssertTrue(t, readCount <= totalCount, "progress")
	})
	for key, val := range params {
		_ = mpWrite.WriteField(key, val)
		uploadResultExpected[key] = val
	}
	fileIndex := 0
	for _, uf := range uploadFiles {
		fieldName := fmt.Sprintf("file%d", fileIndex)
		if err := debugutil.Verify(mpWrite.CreateFormFile(fieldName, uf)); err != nil {
			t.Logf("CreateFormFile error, %+v", err)
			//in real code, should handle this error and maybe return
			//return
		}
		uploadResultExpected[fieldName] = filepath.Base(uf)
		fileIndex++
	}
	t.Logf("content-length=%d", mpWrite.ContentLength())

	req, err := http.NewRequest(http.MethodPost, uploadUrl, mpWrite)
	if err != nil {
		t.Fatalf("http.NewRequest error: %+v", err)
	}
	req.Header.Set("Content-Type", mpWrite.FormDataContentType())
	resp, err := client.Do(req)
	if err != nil {
		// stop here: the checksum comparison below needs the uploaded files
		t.Fatalf("client.Do error: %+v", err)
	}
	defer resp.Body.Close()
	body := debugutil.VerifyWithResult[[]byte](io.ReadAll(resp.Body))
	t.Logf("response=%s", string(body))
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("unexpected status %d, body=%s", resp.StatusCode, string(body))
	}
	var uploadResponse map[string]string
	_ = debugutil.Verify(json.Unmarshal(body, &uploadResponse))
	debugutil.GoAssertEqual(t, uploadResultExpected, uploadResponse, "upload result")

	// check the uploaded files: they should be identical to the originals (MD5)
	for idx, uf := range uploadFiles {
		srcSum := fileCheckSum(uf)
		dstPath := path.Join(uploadTempFolder, filepath.Base(uf))
		dstSum := fileCheckSum(dstPath)
		debugutil.GoAssertEqual(t, srcSum, dstSum, fmt.Sprintf("compare[%d] %s <=> %s", idx, uf, dstPath))
	}
	// switch to true to keep the process alive for a while, e.g. to inspect memory usage
	// (runtime.MemStats would be an alternative)
	if false {
		time.Sleep(time.Second * 30)
	}
}
// fileCheckSum returns the hex-encoded MD5 digest of the file's contents. MD5
// is used only to compare an uploaded copy with its source, not for security.
//
// On failure it panics via log.Panicf instead of calling log.Fatal. A panic
// fails the running test but still lets deferred cleanup run (temporary
// folders, the test server). log.Fatal calls os.Exit, which skips deferred
// functions.
func fileCheckSum(fileName string) string {
	f, err := os.Open(fileName)
	if err != nil {
		log.Panicf("fileCheckSum: %v", err)
	}
	defer f.Close()

	h := md5.New() // sha256.New / sha1.New / sha512.New also work here
	if _, err := io.Copy(h, f); err != nil {
		log.Panicf("fileCheckSum %s: %v", fileName, err)
	}
	return fmt.Sprintf("%x", h.Sum(nil))
}