summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJonham.Chen <me@jonham.cn>2022-01-08 10:32:13 +0800
committerGitHub <noreply@github.com>2022-01-08 10:32:13 +0800
commitaf5c4d00e81b62a3f6ff6cb34a89502400552a2d (patch)
tree024080e9f44a0e3a9e41af81c8c3b2877fc69832
parent9e64df6a96685afcfbc7295beda38739868a6871 (diff)
feat: implement SHA-512 algorithm to ProtectSheet (#1115)
-rw-r--r--chart.go4
-rw-r--r--crypt.go64
-rw-r--r--crypt_test.go20
-rw-r--r--datavalidation.go26
-rw-r--r--datavalidation_test.go4
-rw-r--r--errors.go50
-rw-r--r--excelize_test.go57
-rw-r--r--pivotTable_test.go2
-rw-r--r--sheet.go49
-rw-r--r--xmlWorksheet.go1
10 files changed, 223 insertions, 54 deletions
diff --git a/chart.go b/chart.go
index 755c160..b43f9f2 100644
--- a/chart.go
+++ b/chart.go
@@ -980,12 +980,12 @@ func (f *File) getFormatChart(format string, combo []string) (*formatChart, []*f
return formatSet, comboCharts, err
}
if _, ok := chartValAxNumFmtFormatCode[comboChart.Type]; !ok {
- return formatSet, comboCharts, newUnsupportChartType(comboChart.Type)
+ return formatSet, comboCharts, newUnsupportedChartType(comboChart.Type)
}
comboCharts = append(comboCharts, comboChart)
}
if _, ok := chartValAxNumFmtFormatCode[formatSet.Type]; !ok {
- return formatSet, comboCharts, newUnsupportChartType(formatSet.Type)
+ return formatSet, comboCharts, newUnsupportedChartType(formatSet.Type)
}
return formatSet, comboCharts, err
}
diff --git a/crypt.go b/crypt.go
index ae39bba..65b9956 100644
--- a/crypt.go
+++ b/crypt.go
@@ -43,6 +43,7 @@ var (
packageOffset = 8 // First 8 bytes are the size of the stream
packageEncryptionChunkSize = 4096
iterCount = 50000
+ sheetProtectionSpinCount = 1e5
oleIdentifier = []byte{
0xd0, 0xcf, 0x11, 0xe0, 0xa1, 0xb1, 0x1a, 0xe1,
}
@@ -146,7 +147,7 @@ func Decrypt(raw []byte, opt *Options) (packageBuf []byte, err error) {
case "standard":
return standardDecrypt(encryptionInfoBuf, encryptedPackageBuf, opt)
default:
- err = ErrUnsupportEncryptMechanism
+ err = ErrUnsupportedEncryptMechanism
}
return
}
@@ -307,7 +308,7 @@ func encryptionMechanism(buffer []byte) (mechanism string, err error) {
} else if (versionMajor == 3 || versionMajor == 4) && versionMinor == 3 {
mechanism = "extensible"
}
- err = ErrUnsupportEncryptMechanism
+ err = ErrUnsupportedEncryptMechanism
return
}
@@ -387,14 +388,14 @@ func standardConvertPasswdToKey(header StandardEncryptionHeader, verifier Standa
key = hashing("sha1", iterator, key)
}
var block int
- hfinal := hashing("sha1", key, createUInt32LEBuffer(block, 4))
+ hFinal := hashing("sha1", key, createUInt32LEBuffer(block, 4))
cbRequiredKeyLength := int(header.KeySize) / 8
cbHash := sha1.Size
buf1 := bytes.Repeat([]byte{0x36}, 64)
- buf1 = append(standardXORBytes(hfinal, buf1[:cbHash]), buf1[cbHash:]...)
+ buf1 = append(standardXORBytes(hFinal, buf1[:cbHash]), buf1[cbHash:]...)
x1 := hashing("sha1", buf1)
buf2 := bytes.Repeat([]byte{0x5c}, 64)
- buf2 = append(standardXORBytes(hfinal, buf2[:cbHash]), buf2[cbHash:]...)
+ buf2 = append(standardXORBytes(hFinal, buf2[:cbHash]), buf2[cbHash:]...)
x2 := hashing("sha1", buf2)
x3 := append(x1, x2...)
keyDerived := x3[:cbRequiredKeyLength]
@@ -417,7 +418,8 @@ func standardXORBytes(a, b []byte) []byte {
// ECMA-376 Agile Encryption
// agileDecrypt decrypt the CFB file format with ECMA-376 agile encryption.
-// Support cryptographic algorithm: MD4, MD5, RIPEMD-160, SHA1, SHA256, SHA384 and SHA512.
+// Support cryptographic algorithm: MD4, MD5, RIPEMD-160, SHA1, SHA256,
+// SHA384 and SHA512.
func agileDecrypt(encryptionInfoBuf, encryptedPackageBuf []byte, opt *Options) (packageBuf []byte, err error) {
var encryptionInfo Encryption
if encryptionInfo, err = parseEncryptionInfo(encryptionInfoBuf[8:]); err != nil {
@@ -605,11 +607,55 @@ func createIV(blockKey interface{}, encryption Encryption) ([]byte, error) {
return iv, nil
}
-// randomBytes returns securely generated random bytes. It will return an error if the system's
-// secure random number generator fails to function correctly, in which case the caller should not
-// continue.
+// randomBytes returns securely generated random bytes. It will return an
+// error if the system's secure random number generator fails to function
+// correctly, in which case the caller should not continue.
func randomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
return b, err
}
+
+// ISO Write Protection Method
+
+// genISOPasswdHash implements the ISO password hashing algorithm by given
+// plaintext password, name of the cryptographic hash algorithm, salt value
+// and spin count.
+func genISOPasswdHash(passwd, hashAlgorithm, salt string, spinCount int) (hashValue, saltValue string, err error) {
+ if len(passwd) < 1 || len(passwd) > MaxFieldLength {
+ err = ErrPasswordLengthInvalid
+ return
+ }
+ hash, ok := map[string]string{
+ "MD4": "md4",
+ "MD5": "md5",
+ "SHA-1": "sha1",
+ "SHA-256": "sha256",
+ "SHA-384": "sha384",
+ "SHA-512": "sha512",
+ }[hashAlgorithm]
+ if !ok {
+ err = ErrUnsupportedHashAlgorithm
+ return
+ }
+ var b bytes.Buffer
+ s, _ := randomBytes(16)
+ if salt != "" {
+ if s, err = base64.StdEncoding.DecodeString(salt); err != nil {
+ return
+ }
+ }
+ b.Write(s)
+ encoder := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewEncoder()
+ passwordBuffer, _ := encoder.Bytes([]byte(passwd))
+ b.Write(passwordBuffer)
+ // Generate the initial hash.
+ key := hashing(hash, b.Bytes())
+ // Now regenerate until spin count.
+ for i := 0; i < spinCount; i++ {
+ iterator := createUInt32LEBuffer(i, 4)
+ key = hashing(hash, key, iterator)
+ }
+ hashValue, saltValue = base64.StdEncoding.EncodeToString(key), base64.StdEncoding.EncodeToString(s)
+ return
+}
diff --git a/crypt_test.go b/crypt_test.go
index 0796482..0ad6f98 100644
--- a/crypt_test.go
+++ b/crypt_test.go
@@ -28,11 +28,27 @@ func TestEncrypt(t *testing.T) {
func TestEncryptionMechanism(t *testing.T) {
mechanism, err := encryptionMechanism([]byte{3, 0, 3, 0})
assert.Equal(t, mechanism, "extensible")
- assert.EqualError(t, err, ErrUnsupportEncryptMechanism.Error())
+ assert.EqualError(t, err, ErrUnsupportedEncryptMechanism.Error())
_, err = encryptionMechanism([]byte{})
assert.EqualError(t, err, ErrUnknownEncryptMechanism.Error())
}
func TestHashing(t *testing.T) {
- assert.Equal(t, hashing("unsupportHashAlgorithm", []byte{}), []uint8([]byte(nil)))
+ assert.Equal(t, hashing("unsupportedHashAlgorithm", []byte{}), []uint8([]byte(nil)))
+}
+
+func TestGenISOPasswdHash(t *testing.T) {
+ for hashAlgorithm, expected := range map[string][]string{
+ "MD4": {"2lZQZUubVHLm/t6KsuHX4w==", "TTHjJdU70B/6Zq83XGhHVA=="},
+ "MD5": {"HWbqyd4dKKCjk1fEhk2kuQ==", "8ADyorkumWCayIukRhlVKQ=="},
+ "SHA-1": {"XErQIV3Ol+nhXkyCxrLTEQm+mSc=", "I3nDtyf59ASaNX1l6KpFnA=="},
+ "SHA-256": {"7oqMFyfED+mPrzRIBQ+KpKT4SClMHEPOZldliP15xAA=", "ru1R/w3P3Jna2Qo+EE8QiA=="},
+ "SHA-384": {"nMODLlxsC8vr0btcq0kp/jksg5FaI3az5Sjo1yZk+/x4bFzsuIvpDKUhJGAk/fzo", "Zjq9/jHlgOY6MzFDSlVNZg=="},
+ "SHA-512": {"YZ6jrGOFQgVKK3rDK/0SHGGgxEmFJglQIIRamZc2PkxVtUBp54fQn96+jVXEOqo6dtCSanqksXGcm/h3KaiR4Q==", "p5s/bybHBPtusI7EydTIrg=="},
+ } {
+ hashValue, saltValue, err := genISOPasswdHash("password", hashAlgorithm, expected[1], int(sheetProtectionSpinCount))
+ assert.NoError(t, err)
+ assert.Equal(t, expected[0], hashValue)
+ assert.Equal(t, expected[1], saltValue)
+ }
}
diff --git a/datavalidation.go b/datavalidation.go
index 80a0295..205d948 100644
--- a/datavalidation.go
+++ b/datavalidation.go
@@ -29,7 +29,7 @@ const (
DataValidationTypeDate
DataValidationTypeDecimal
typeList // inline use
- DataValidationTypeTextLeng
+ DataValidationTypeTextLength
DataValidationTypeTime
// DataValidationTypeWhole Integer
DataValidationTypeWhole
@@ -116,7 +116,7 @@ func (dd *DataValidation) SetInput(title, msg string) {
func (dd *DataValidation) SetDropList(keys []string) error {
formula := strings.Join(keys, ",")
if MaxFieldLength < len(utf16.Encode([]rune(formula))) {
- return ErrDataValidationFormulaLenth
+ return ErrDataValidationFormulaLength
}
dd.Formula1 = fmt.Sprintf(`<formula1>"%s"</formula1>`, formulaEscaper.Replace(formula))
dd.Type = convDataValidationType(typeList)
@@ -155,7 +155,7 @@ func (dd *DataValidation) SetRange(f1, f2 interface{}, t DataValidationType, o D
}
dd.Formula1, dd.Formula2 = formula1, formula2
dd.Type = convDataValidationType(t)
- dd.Operator = convDataValidationOperatior(o)
+ dd.Operator = convDataValidationOperator(o)
return nil
}
@@ -192,22 +192,22 @@ func (dd *DataValidation) SetSqref(sqref string) {
// convDataValidationType get excel data validation type.
func convDataValidationType(t DataValidationType) string {
typeMap := map[DataValidationType]string{
- typeNone: "none",
- DataValidationTypeCustom: "custom",
- DataValidationTypeDate: "date",
- DataValidationTypeDecimal: "decimal",
- typeList: "list",
- DataValidationTypeTextLeng: "textLength",
- DataValidationTypeTime: "time",
- DataValidationTypeWhole: "whole",
+ typeNone: "none",
+ DataValidationTypeCustom: "custom",
+ DataValidationTypeDate: "date",
+ DataValidationTypeDecimal: "decimal",
+ typeList: "list",
+ DataValidationTypeTextLength: "textLength",
+ DataValidationTypeTime: "time",
+ DataValidationTypeWhole: "whole",
}
return typeMap[t]
}
-// convDataValidationOperatior get excel data validation operator.
-func convDataValidationOperatior(o DataValidationOperator) string {
+// convDataValidationOperator get excel data validation operator.
+func convDataValidationOperator(o DataValidationOperator) string {
typeMap := map[DataValidationOperator]string{
DataValidationOperatorBetween: "between",
DataValidationOperatorEqual: "equal",
diff --git a/datavalidation_test.go b/datavalidation_test.go
index 56e7d73..d07f1b1 100644
--- a/datavalidation_test.go
+++ b/datavalidation_test.go
@@ -94,7 +94,7 @@ func TestDataValidationError(t *testing.T) {
t.Errorf("data validation error. Formula1 must be empty!")
return
}
- assert.EqualError(t, err, ErrDataValidationFormulaLenth.Error())
+ assert.EqualError(t, err, ErrDataValidationFormulaLength.Error())
assert.EqualError(t, dvRange.SetRange(nil, 20, DataValidationTypeWhole, DataValidationOperatorBetween), ErrParameterInvalid.Error())
assert.EqualError(t, dvRange.SetRange(10, nil, DataValidationTypeWhole, DataValidationOperatorBetween), ErrParameterInvalid.Error())
assert.NoError(t, dvRange.SetRange(10, 20, DataValidationTypeWhole, DataValidationOperatorGreaterThan))
@@ -114,7 +114,7 @@ func TestDataValidationError(t *testing.T) {
err = dvRange.SetDropList(keys)
assert.Equal(t, prevFormula1, dvRange.Formula1,
"Formula1 should be unchanged for invalid input %v", keys)
- assert.EqualError(t, err, ErrDataValidationFormulaLenth.Error())
+ assert.EqualError(t, err, ErrDataValidationFormulaLength.Error())
}
assert.NoError(t, f.AddDataValidation("Sheet1", dvRange))
assert.NoError(t, dvRange.SetRange(
diff --git a/errors.go b/errors.go
index 9460803..9f39a7a 100644
--- a/errors.go
+++ b/errors.go
@@ -16,42 +16,50 @@ import (
"fmt"
)
-// newInvalidColumnNameError defined the error message on receiving the invalid column name.
+// newInvalidColumnNameError defined the error message on receiving the
+// invalid column name.
func newInvalidColumnNameError(col string) error {
return fmt.Errorf("invalid column name %q", col)
}
-// newInvalidRowNumberError defined the error message on receiving the invalid row number.
+// newInvalidRowNumberError defined the error message on receiving the invalid
+// row number.
func newInvalidRowNumberError(row int) error {
return fmt.Errorf("invalid row number %d", row)
}
-// newInvalidCellNameError defined the error message on receiving the invalid cell name.
+// newInvalidCellNameError defined the error message on receiving the invalid
+// cell name.
func newInvalidCellNameError(cell string) error {
return fmt.Errorf("invalid cell name %q", cell)
}
-// newInvalidExcelDateError defined the error message on receiving the data with negative values.
+// newInvalidExcelDateError defined the error message on receiving the data
+// with negative values.
func newInvalidExcelDateError(dateValue float64) error {
return fmt.Errorf("invalid date value %f, negative values are not supported", dateValue)
}
-// newUnsupportChartType defined the error message on receiving the chart type are unsupported.
-func newUnsupportChartType(chartType string) error {
+// newUnsupportedChartType defined the error message on receiving the chart
+// type are unsupported.
+func newUnsupportedChartType(chartType string) error {
return fmt.Errorf("unsupported chart type %s", chartType)
}
-// newUnzipSizeLimitError defined the error message on unzip size exceeds the limit.
+// newUnzipSizeLimitError defined the error message on unzip size exceeds the
+// limit.
func newUnzipSizeLimitError(unzipSizeLimit int64) error {
return fmt.Errorf("unzip size exceeds the %d bytes limit", unzipSizeLimit)
}
-// newInvalidStyleID defined the error message on receiving the invalid style ID.
+// newInvalidStyleID defined the error message on receiving the invalid style
+// ID.
func newInvalidStyleID(styleID int) error {
return fmt.Errorf("invalid style ID %d, negative values are not supported", styleID)
}
-// newFieldLengthError defined the error message on receiving the field length overflow.
+// newFieldLengthError defined the error message on receiving the field length
+// overflow.
func newFieldLengthError(name string) error {
return fmt.Errorf("field %s must be less or equal than 255 characters", name)
}
@@ -103,12 +111,18 @@ var (
ErrMaxFileNameLength = errors.New("file name length exceeds maximum limit")
// ErrEncrypt defined the error message on encryption spreadsheet.
ErrEncrypt = errors.New("not support encryption currently")
- // ErrUnknownEncryptMechanism defined the error message on unsupport
+ // ErrUnknownEncryptMechanism defined the error message on unsupported
// encryption mechanism.
ErrUnknownEncryptMechanism = errors.New("unknown encryption mechanism")
- // ErrUnsupportEncryptMechanism defined the error message on unsupport
+ // ErrUnsupportedEncryptMechanism defined the error message on unsupported
// encryption mechanism.
- ErrUnsupportEncryptMechanism = errors.New("unsupport encryption mechanism")
+ ErrUnsupportedEncryptMechanism = errors.New("unsupported encryption mechanism")
+ // ErrUnsupportedHashAlgorithm defined the error message on unsupported
+ // hash algorithm.
+ ErrUnsupportedHashAlgorithm = errors.New("unsupported hash algorithm")
+ // ErrPasswordLengthInvalid defined the error message on invalid password
+ // length.
+ ErrPasswordLengthInvalid = errors.New("password length invalid")
// ErrParameterRequired defined the error message on receive the empty
// parameter.
ErrParameterRequired = errors.New("parameter is required")
@@ -131,11 +145,17 @@ var (
// ErrSheetIdx defined the error message on receive the invalid worksheet
// index.
ErrSheetIdx = errors.New("invalid worksheet index")
+ // ErrUnprotectSheet defined the error message on worksheet has set no
+ // protection.
+ ErrUnprotectSheet = errors.New("worksheet has set no protect")
+ // ErrUnprotectSheetPassword defined the error message on remove sheet
+ // protection with password verification failed.
+ ErrUnprotectSheetPassword = errors.New("worksheet protect password not match")
// ErrGroupSheets defined the error message on group sheets.
ErrGroupSheets = errors.New("group worksheet must contain an active worksheet")
- // ErrDataValidationFormulaLenth defined the error message for receiving a
+ // ErrDataValidationFormulaLength defined the error message for receiving a
// data validation formula length that exceeds the limit.
- ErrDataValidationFormulaLenth = errors.New("data validation must be 0-255 characters")
+ ErrDataValidationFormulaLength = errors.New("data validation must be 0-255 characters")
// ErrDataValidationRange defined the error message on set decimal range
// exceeds limit.
ErrDataValidationRange = errors.New("data validation range exceeds limit")
@@ -164,5 +184,5 @@ var (
ErrSparkline = errors.New("must have the same number of 'Location' and 'Range' parameters")
// ErrSparklineStyle defined the error message on receive the invalid
// sparkline Style parameters.
- ErrSparklineStyle = errors.New("parameter 'Style' must betweent 0-35")
+ ErrSparklineStyle = errors.New("parameter 'Style' must between 0-35")
)
diff --git a/excelize_test.go b/excelize_test.go
index 9aaaae9..0edfe11 100644
--- a/excelize_test.go
+++ b/excelize_test.go
@@ -1160,13 +1160,44 @@ func TestHSL(t *testing.T) {
func TestProtectSheet(t *testing.T) {
f := NewFile()
- assert.NoError(t, f.ProtectSheet("Sheet1", nil))
- assert.NoError(t, f.ProtectSheet("Sheet1", &FormatSheetProtection{
+ sheetName := f.GetSheetName(0)
+ assert.NoError(t, f.ProtectSheet(sheetName, nil))
+ // Test protect worksheet with XOR hash algorithm
+ assert.NoError(t, f.ProtectSheet(sheetName, &FormatSheetProtection{
Password: "password",
EditScenarios: false,
}))
-
+ ws, err := f.workSheetReader(sheetName)
+ assert.NoError(t, err)
+ assert.Equal(t, "83AF", ws.SheetProtection.Password)
assert.NoError(t, f.SaveAs(filepath.Join("test", "TestProtectSheet.xlsx")))
+ // Test protect worksheet with SHA-512 hash algorithm
+ assert.NoError(t, f.ProtectSheet(sheetName, &FormatSheetProtection{
+ AlgorithmName: "SHA-512",
+ Password: "password",
+ }))
+ ws, err = f.workSheetReader(sheetName)
+ assert.NoError(t, err)
+ assert.Equal(t, 24, len(ws.SheetProtection.SaltValue))
+ assert.Equal(t, 88, len(ws.SheetProtection.HashValue))
+ assert.Equal(t, int(sheetProtectionSpinCount), ws.SheetProtection.SpinCount)
+ // Test remove sheet protection with an incorrect password
+ assert.EqualError(t, f.UnprotectSheet(sheetName, "wrongPassword"), ErrUnprotectSheetPassword.Error())
+ // Test remove sheet protection with password verification
+ assert.NoError(t, f.UnprotectSheet(sheetName, "password"))
+ // Test protect worksheet with empty password
+ assert.NoError(t, f.ProtectSheet(sheetName, &FormatSheetProtection{}))
+ assert.Equal(t, "", ws.SheetProtection.Password)
+ // Test protect worksheet with password exceeds the limit length
+ assert.EqualError(t, f.ProtectSheet(sheetName, &FormatSheetProtection{
+ AlgorithmName: "MD4",
+ Password: strings.Repeat("s", MaxFieldLength+1),
+ }), ErrPasswordLengthInvalid.Error())
+ // Test protect worksheet with unsupported hash algorithm
+ assert.EqualError(t, f.ProtectSheet(sheetName, &FormatSheetProtection{
+ AlgorithmName: "RIPEMD-160",
+ Password: "password",
+ }), ErrUnsupportedHashAlgorithm.Error())
// Test protect not exists worksheet.
assert.EqualError(t, f.ProtectSheet("SheetN", nil), "sheet SheetN is not exist")
}
@@ -1176,12 +1207,30 @@ func TestUnprotectSheet(t *testing.T) {
if !assert.NoError(t, err) {
t.FailNow()
}
- // Test unprotect not exists worksheet.
+ // Test remove protection on not exists worksheet.
assert.EqualError(t, f.UnprotectSheet("SheetN"), "sheet SheetN is not exist")
assert.NoError(t, f.UnprotectSheet("Sheet1"))
+ assert.EqualError(t, f.UnprotectSheet("Sheet1", "password"), ErrUnprotectSheet.Error())
assert.NoError(t, f.SaveAs(filepath.Join("test", "TestUnprotectSheet.xlsx")))
assert.NoError(t, f.Close())
+
+ f = NewFile()
+ sheetName := f.GetSheetName(0)
+ assert.NoError(t, f.ProtectSheet(sheetName, &FormatSheetProtection{Password: "password"}))
+ // Test remove sheet protection with an incorrect password
+ assert.EqualError(t, f.UnprotectSheet(sheetName, "wrongPassword"), ErrUnprotectSheetPassword.Error())
+ // Test remove sheet protection with password verification
+ assert.NoError(t, f.UnprotectSheet(sheetName, "password"))
+ // Test with invalid salt value
+ assert.NoError(t, f.ProtectSheet(sheetName, &FormatSheetProtection{
+ AlgorithmName: "SHA-512",
+ Password: "password",
+ }))
+ ws, err := f.workSheetReader(sheetName)
+ assert.NoError(t, err)
+ ws.SheetProtection.SaltValue = "YWJjZA====="
+ assert.EqualError(t, f.UnprotectSheet(sheetName, "wrongPassword"), "illegal base64 data at input byte 8")
}
func TestSetDefaultTimeStyle(t *testing.T) {
diff --git a/pivotTable_test.go b/pivotTable_test.go
index 3487793..d7a8eb1 100644
--- a/pivotTable_test.go
+++ b/pivotTable_test.go
@@ -222,7 +222,7 @@ func TestAddPivotTable(t *testing.T) {
PivotTableRange: "Sheet1!$G$2:$M$34",
Rows: []PivotTableField{{Data: "Month", DefaultSubtotal: true}, {Data: "Year"}},
Columns: []PivotTableField{{Data: "Type", DefaultSubtotal: true}},
- Data: []PivotTableField{{Data: "Sales", Subtotal: "-", Name: strings.Repeat("s", 256)}},
+ Data: []PivotTableField{{Data: "Sales", Subtotal: "-", Name: strings.Repeat("s", MaxFieldLength+1)}},
}))
// Test adjust range with invalid range
diff --git a/sheet.go b/sheet.go
index 17f6693..26baca8 100644
--- a/sheet.go
+++ b/sheet.go
@@ -1129,10 +1129,14 @@ func (f *File) SetHeaderFooter(sheet string, settings *FormatHeaderFooter) error
}
// ProtectSheet provides a function to prevent other users from accidentally
-// or deliberately changing, moving, or deleting data in a worksheet. For
-// example, protect Sheet1 with protection settings:
+// or deliberately changing, moving, or deleting data in a worksheet. The
+// optional field AlgorithmName specified hash algorithm, support XOR, MD4,
+// MD5, SHA1, SHA256, SHA384, and SHA512 currently, if no hash algorithm
+// specified, will be using the XOR algorithm as default. For example,
+// protect Sheet1 with protection settings:
//
// err := f.ProtectSheet("Sheet1", &excelize.FormatSheetProtection{
+// AlgorithmName: "SHA-512",
// Password: "password",
// EditScenarios: false,
// })
@@ -1168,22 +1172,55 @@ func (f *File) ProtectSheet(sheet string, settings *FormatSheetProtection) error
Sort: settings.Sort,
}
if settings.Password != "" {
- ws.SheetProtection.Password = genSheetPasswd(settings.Password)
+ if settings.AlgorithmName == "" {
+ ws.SheetProtection.Password = genSheetPasswd(settings.Password)
+ return err
+ }
+ hashValue, saltValue, err := genISOPasswdHash(settings.Password, settings.AlgorithmName, "", int(sheetProtectionSpinCount))
+ if err != nil {
+ return err
+ }
+ ws.SheetProtection.Password = ""
+ ws.SheetProtection.AlgorithmName = settings.AlgorithmName
+ ws.SheetProtection.SaltValue = saltValue
+ ws.SheetProtection.HashValue = hashValue
+ ws.SheetProtection.SpinCount = int(sheetProtectionSpinCount)
}
return err
}
-// UnprotectSheet provides a function to unprotect an Excel worksheet.
-func (f *File) UnprotectSheet(sheet string) error {
+// UnprotectSheet provides a function to remove protection for a sheet,
+// specified the second optional password parameter to remove sheet
+// protection with password verification.
+func (f *File) UnprotectSheet(sheet string, password ...string) error {
ws, err := f.workSheetReader(sheet)
if err != nil {
return err
}
+ // password verification
+ if len(password) > 0 {
+ if ws.SheetProtection == nil {
+ return ErrUnprotectSheet
+ }
+ if ws.SheetProtection.AlgorithmName == "" && ws.SheetProtection.Password != genSheetPasswd(password[0]) {
+ return ErrUnprotectSheetPassword
+ }
+ if ws.SheetProtection.AlgorithmName != "" {
+ // check with given salt value
+ hashValue, _, err := genISOPasswdHash(password[0], ws.SheetProtection.AlgorithmName, ws.SheetProtection.SaltValue, ws.SheetProtection.SpinCount)
+ if err != nil {
+ return err
+ }
+ if ws.SheetProtection.HashValue != hashValue {
+ return ErrUnprotectSheetPassword
+ }
+ }
+ }
ws.SheetProtection = nil
return err
}
-// trimSheetName provides a function to trim invaild characters by given worksheet
+// trimSheetName provides a function to trim invalid characters by given worksheet
// name.
func trimSheetName(name string) string {
if strings.ContainsAny(name, ":\\/?*[]") || utf8.RuneCountInString(name) > 31 {
diff --git a/xmlWorksheet.go b/xmlWorksheet.go
index 217f367..b09d630 100644
--- a/xmlWorksheet.go
+++ b/xmlWorksheet.go
@@ -838,6 +838,7 @@ type formatConditional struct {
// FormatSheetProtection directly maps the settings of worksheet protection.
type FormatSheetProtection struct {
+ AlgorithmName string
AutoFilter bool
DeleteColumns bool
DeleteRows bool