From 24967a5c25499f92b4e58b8d6f8a92a46a7acc7a Mon Sep 17 00:00:00 2001 From: Zitao <369815332@qq.com> Date: Tue, 22 Jun 2021 14:06:08 +0800 Subject: feat: stream write to zip directly (#863) --- file.go | 74 +++++++++++++++++++++++++++++++++++++++--------------------- file_test.go | 43 ++++++++++++++++++++++++++--------- 2 files changed, 80 insertions(+), 37 deletions(-) diff --git a/file.go b/file.go index 36b1a42..495718a 100644 --- a/file.go +++ b/file.go @@ -87,17 +87,55 @@ func (f *File) Write(w io.Writer) error { // WriteTo implements io.WriterTo to write the file. func (f *File) WriteTo(w io.Writer) (int64, error) { - buf, err := f.WriteToBuffer() - if err != nil { + if f.options != nil && f.options.Password != "" { + buf, err := f.WriteToBuffer() + if err != nil { + return 0, err + } + return buf.WriteTo(w) + } + if err := f.writeDirectToWriter(w); err != nil { return 0, err } - return buf.WriteTo(w) + return 0, nil } -// WriteToBuffer provides a function to get bytes.Buffer from the saved file. +// WriteToBuffer provides a function to get bytes.Buffer from the saved file. And it allocate space in memory. Be careful when the file size is large. func (f *File) WriteToBuffer() (*bytes.Buffer, error) { buf := new(bytes.Buffer) zw := zip.NewWriter(buf) + + if err := f.writeToZip(zw); err != nil { + return buf, zw.Close() + } + + if f.options != nil && f.options.Password != "" { + if err := zw.Close(); err != nil { + return buf, err + } + b, err := Encrypt(buf.Bytes(), f.options) + if err != nil { + return buf, err + } + buf.Reset() + buf.Write(b) + return buf, nil + } + return buf, zw.Close() +} + +// writeDirectToWriter provides a function to write to io.Writer. +func (f *File) writeDirectToWriter(w io.Writer) error { + zw := zip.NewWriter(w) + if err := f.writeToZip(zw); err != nil { + zw.Close() + return err + } + return zw.Close() +} + +// writeToZip provides a function to write to zip.Writer +func (f *File) writeToZip(zw *zip.Writer) error { f.calcChainWriter() f.commentsWriter() f.contentTypesWriter() @@ -112,19 +150,17 @@ func (f *File) WriteToBuffer() (*bytes.Buffer, error) { for path, stream := range f.streams { fi, err := zw.Create(path) if err != nil { - zw.Close() - return buf, err + return err } var from io.Reader from, err = stream.rawData.Reader() if err != nil { stream.rawData.Close() - return buf, err + return err } _, err = io.Copy(fi, from) if err != nil { - zw.Close() - return buf, err + return err } stream.rawData.Close() } @@ -135,27 +171,13 @@ func (f *File) WriteToBuffer() (*bytes.Buffer, error) { } fi, err := zw.Create(path) if err != nil { - zw.Close() - return buf, err + return err } _, err = fi.Write(content) if err != nil { - zw.Close() - return buf, err + return err } } - if f.options != nil && f.options.Password != "" { - if err := zw.Close(); err != nil { - return buf, err - } - b, err := Encrypt(buf.Bytes(), f.options) - if err != nil { - return buf, err - } - buf.Reset() - buf.Write(b) - return buf, nil - } - return buf, zw.Close() + return nil } diff --git a/file_test.go b/file_test.go index 656271f..dbbf75a 100644 --- a/file_test.go +++ b/file_test.go @@ -3,6 +3,7 @@ package excelize import ( "bufio" "bytes" + "os" "strings" "testing" @@ -33,16 +34,36 @@ func BenchmarkWrite(b *testing.B) { } func TestWriteTo(t *testing.T) { - f := File{} - buf := bytes.Buffer{} - f.XLSX = make(map[string][]byte) - f.XLSX["/d/"] = []byte("s") - _, err := f.WriteTo(bufio.NewWriter(&buf)) - assert.EqualError(t, err, "zip: write to directory") - delete(f.XLSX, "/d/") + // Test WriteToBuffer err + { + f := File{} + buf := bytes.Buffer{} + f.XLSX = make(map[string][]byte) + f.XLSX["/d/"] = []byte("s") + _, err := f.WriteTo(bufio.NewWriter(&buf)) + assert.EqualError(t, err, "zip: write to directory") + delete(f.XLSX, "/d/") + } // Test file path overflow - const maxUint16 = 1<<16 - 1 - f.XLSX[strings.Repeat("s", maxUint16+1)] = nil - _, err = f.WriteTo(bufio.NewWriter(&buf)) - assert.EqualError(t, err, "zip: FileHeader.Name too long") + { + f := File{} + buf := bytes.Buffer{} + f.XLSX = make(map[string][]byte) + const maxUint16 = 1<<16 - 1 + f.XLSX[strings.Repeat("s", maxUint16+1)] = nil + _, err := f.WriteTo(bufio.NewWriter(&buf)) + assert.EqualError(t, err, "zip: FileHeader.Name too long") + } + // Test StreamsWriter err + { + f := File{} + buf := bytes.Buffer{} + f.XLSX = make(map[string][]byte) + f.XLSX["s"] = nil + f.streams = make(map[string]*StreamWriter) + file, _ := os.Open("123") + f.streams["s"] = &StreamWriter{rawData: bufferedWriter{tmp: file}} + _, err := f.WriteTo(bufio.NewWriter(&buf)) + assert.Nil(t, err) + } } -- cgit v1.2.1