golang http上传文件的用法以及对官方库的零拷贝优化

原创内容,转载请注明出处

Posted by Weakyon Blog on May 10, 2016

golang的http上传文件的实现细节严格遵循了multipart/form-data的RFC1867规范,细节网上已经分析了很多了

本篇只是讨论golang上如何使用官方库实现的框架


客户端上传很简单,利用”mime/multipart”官方库即可完成上传

可以看到这里上传文件时全部加载在内存的bodyBuf字符数组中,所以这里只是上传小文件,大文件这么上传会写爆内存

  
func postFile(filename, targetUrl, token string) error {                        
    var (                                                                       
        bodyBuf         *bytes.Buffer                                           
        bodyWriter      *multipart.Writer                                       
        file            *os.File                                                
        err             error                                                   
        contentType     string                                                  
        client          http.Client                                             
        req             *http.Request                                           
        resp            *http.Response                                          
        respBody        []byte                                                  
    )                                                                           
    bodyBuf = &bytes.Buffer{}                                                   
    bodyWriter = multipart.NewWriter(bodyBuf)                                   
                                                                                
    fileWriter, err := bodyWriter.CreateFormFile("PolkaFile",filename)          
    if err != nil {                                                             
        return err                                                              
    }                                                                           
                                                                                
    file, err = os.Open(filename)                                               
    defer file.Close()                                                          
    if err != nil {                                                             
        return err                                                              
    }                                                                           
                                                                                
    _, err = io.Copy(fileWriter, file)                                          
    if err != nil {                                                             
        return err                                                              
    }                                                                           
    
    //这里必须Close,否则不会向bodyBuf写入boundary分隔符
    err = bodyWriter.Close();                                                   
    if err != nil {                                                             
        return err                                                              
    }                                                                           
    contentType = bodyWriter.FormDataContentType()                              
                                                                                
    req, err = http.NewRequest("POST", targetUrl, bodyBuf)                      
    if err != nil {                                                             
        return err                                                              
    }                                                                           
    req.Header.Set("Content-Type",contentType)                                  
    req.Header.Set("Authorization",token)                                       
    resp, err = client.Do(req)                                                  
    if err != nil {                                                             
        return err                                                              
    }                                                                           
    defer resp.Body.Close()                                                     
                                                                                
    respBody, err = ioutil.ReadAll(resp.Body)                                   
    if err != nil {                                                             
        return err                                                              
    }                                                                           
    log.Infof("resp status: %s,resp body: %s",resp.Status, string(respBody))    
    return nil                                                                  
}

服务端的上传也很简单

在golang http框架用mux的HandleFunc方法绑定的回调函数中写入几行代码即可

  
func (this *proxy) upload(wr http.ResponseWriter, r *http.Request) {               
    file, handle, err := r.FormFile("PolkaFile")                                   
    defer file.Close()                                                             
    if err != nil {                                                                
        log.Errorf("%v",err)                                                       
    }                                                                              
    return                                                                         
}

这里的handle的Filename字段对应了客户端上传时”PolkaFile”这个字符串绑定的文件名

这里的file是一个接口

  
type File interface {                                                           
    io.Reader                                                                   
    io.ReaderAt                                                                 
    io.Seeker                                                                   
    io.Closer                                                                   
}

为什么要这么做呢,因为文件有可能存储在内存里或者是磁盘上

  
按照http.Request.FormFile()->http.Request.ParseMultipartForm()->multipart.Reader.ReadForm()这个调用链来看

multipart.Reader.ReadForm()中有这么一段
// file, store in memory or on disk                                     
fh := &FileHeader{                                                      
    Filename: filename,                                                 
    Header:   p.Header,                                                 
}                                                                       
n, err := io.CopyN(&b, p, maxMemory+1)                                  
if err != nil && err != io.EOF {                                        
    return nil, err                                                     
}                                                                       
if n > maxMemory {                                                      
    // too big, write to disk and flush buffer                          
    file, err := ioutil.TempFile("", "multipart-")                      
    if err != nil {                                                     
        return nil, err                                                 
    }                                                                   
    defer file.Close()                                                  
    _, err = io.Copy(file, io.MultiReader(&b, p))                       
    if err != nil {                                                     
        os.Remove(file.Name())                                          
        return nil, err                                                 
    }                                                                   
    fh.tmpfile = file.Name()                                            
} else {                                                                
    fh.content = b.Bytes()                                              
    maxMemory -= n                                                      
}

数据从p中读出,如果大于最大内存(默认值32MB),那么会用MultiReader重新读取,并且写入到TempFile中

这部分文件会在multipart.Reader.RemoveAll()中被销毁

  
// RemoveAll removes any temporary files associated with a Form.                   
func (f *Form) RemoveAll() error {                                                 
    var err error                                                                  
    for _, fhs := range f.File {                                                   
        for _, fh := range fhs {                                                   
            if fh.tmpfile != "" {                                                  
                e := os.Remove(fh.tmpfile)                                         
                if e != nil && err == nil {                                        
                    err = e                                                        
                }                                                                  
            }                                                                      
        }                                                                          
    }                                                                              
    return err                                                                     
}

RemoveAll()函数会在http.response.finishRequest()这个函数中被调用,也就是结束请求后被销毁


我们可以通过handle来获取这个结构体,但是无法操作content和tmpfile这两个未导出字段

  
type FileHeader struct {                                                           
    Filename string                                                                
    Header   textproto.MIMEHeader                                                  
                                                                                   
    content []byte                                                                 
    tmpfile string                                                                 
}

只能通过这个File接口,所以数据必须通过io.Copy等方式进行一次复制,这样效率会大大降低

假如是一个[]byte数组,那么这个数组可以直接被发送数据,执行一次复制会加大内存的开销和GC的负担

加入是一个tmp文件,那么这个文件可以直接使用sendfile调用发送出去,执行一次复制再写到其他文件或者写到内存,也会增加内存的开销

要做到零拷贝,就只能通过unsafe来直接操作未导出字段

我写了些测试代码如下

  
package common                                                                     
                                                                                   
import (                                                                           
    "fmt"                                                                          
    "net/textproto"                                                                
)                                                                                  
                                                                                   
type FileHeader struct {                                                           
    Filename string                                                                
    Header   textproto.MIMEHeader                                                  
                                                                                   
    content []byte                                                                 
    tmpfile string                                                                 
}                                                                                  
                                                                                   
func (this *FileHeader) Set() {                                                    
    this.Filename = "123"                                                          
    this.content = []byte{1,2,3,4,5,6,7,8,9,10}                                    
    this.tmpfile = "456"                                                           
}                                                                                  
                                                                                   
func (this *FileHeader) Get() {                                                    
    fmt.Println(this.Filename)                                                     
    fmt.Println(this.content)                                                      
    fmt.Println(this.tmpfile)                                                      
}

这是这个山寨FileHeader,用于进行简单的GetSet

然后是UnSafeMultipart

  
package main                                                                       
                                                                                   
import (                                                                           
    "unsafe"                                                                       
    "fmt"                                                                          
    "coding.net/tedcy/Polka/common"                                                
)                                                                                  
                                                                                   
type UnsafeMultipart struct {                                                      
    common.FileHeader                                                              
}                                                                                  
                                                                                   
func (this *UnsafeMultipart) GetContent() []byte{                                  
    contentPtr := (*[]byte)(unsafe.Pointer(uintptr(unsafe.Pointer(&this.Header)) + uintptr(unsafe.Sizeof(this.Header))))
    return *contentPtr                                                             
}                                                                                  
                                                                                   
func (this *UnsafeMultipart) GetTmpfileName() string{                              
    tmpfileNamePtr := (*string)(                                                   
        unsafe.Pointer(uintptr(unsafe.Pointer(&this.Header)) +                     
            uintptr(unsafe.Sizeof(this.Header)) +                                  
            uintptr(unsafe.Sizeof([]byte(nil)))))                                  
    return *tmpfileNamePtr                                                         
}                                                                                  
                                                                                   
func main() {                                                                      
    u := &UnsafeMultipart{}                                                        
    u.Set()                                                                        
    u.Get()                                                                        
    fmt.Println(u.GetContent())                                                    
    fmt.Println(u.GetTmpfileName())                                                
}

输出

123
[1 2 3 4 5 6 7 8 9 10]
456
[1 2 3 4 5 6 7 8 9 10]
456

可以看到,已经取得了我要的未导出字段

unsafe有几个小细节

1 unsafe.Pointer()传入的只能是指针

2 unsafe.Pointer要参与运算只能转换为uintptr

3 uintptr得到的地址只能转换为unsafe.Pointer

4 最终转换的unsafe.Pointer也只能转换成某个类型的指针,例如[]byte,string

5 加一段测试代码

fmt.Println(uintptr(unsafe.Pointer(u)))                                        
fmt.Println(uintptr(unsafe.Pointer(&u.FileHeader)))                            
fmt.Println(uintptr(unsafe.Pointer(&u.Filename)))                              
fmt.Println(uintptr(unsafe.Pointer(&u.Header)))

输出是

859530421504
859530421504
859530421504
859530421520

可以看到u的地址和u封装的FileHeader是同一个地址,和FileHeader内的Filename也是同一个地址

然后Filename虽然是string类型,但是实际占用是固定长度的大小


完毕


补充:

后来发现以上全是考虑太多了,go的http库有两种解析multipart上传的方式,另外一种可以满足需求

利用http request的MultipartReader()即可,这个函数是和ParseMultipartForm互斥的,无论哪个被调用,另外一个都不能再调用了

ParseMultipartForm是将表单的form全部解析好,而MultipartReader则是将表单解析成单独的Part内容,以供解析

reader, err := r.MultipartReader()
if err != nil {
    ...
}
for {
    part, err := reader.NextPart()
    if err == io.EOF {
        break
    }
    if part.FileName() != "" {
        //filecontent
        continue
    }
    data, err := ioutil.ReadAll(part)
    if err != nil {
        ...
    }
    // form data, key = part.FormName(),value = string(data)
}

可以看到,解析完全依赖自己的处理逻辑,也就是避免了二次拷贝

10 May 2016