decimal.go 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107
  1. // Package decimal implements an arbitrary precision fixed-point decimal.
  2. //
  3. // To use as part of a struct:
  4. //
  5. // type Struct struct {
  6. // Number Decimal
  7. // }
  8. //
  9. // The zero-value of a Decimal is 0, as you would expect.
  10. //
  11. // The best way to create a new Decimal is to use decimal.NewFromString, ex:
  12. //
  13. // n, err := decimal.NewFromString("-123.4567")
  14. // n.String() // output: "-123.4567"
  15. //
  16. // NOTE: This can "only" represent numbers with a maximum of 2^31 digits
  17. // after the decimal point.
  18. package decimal
  19. import (
  20. "database/sql/driver"
  21. "encoding/binary"
  22. "fmt"
  23. "math"
  24. "math/big"
  25. "strconv"
  26. "strings"
  27. )
  28. // DivisionPrecision is the number of decimal places in the result when it
  29. // doesn't divide exactly.
  30. //
  31. // Example:
  32. //
  33. // d1 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(3)
  34. // d1.String() // output: "0.6666666666666667"
  35. // d2 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(30000)
  36. // d2.String() // output: "0.0000666666666667"
  37. // d3 := decimal.NewFromFloat(20000).Div(decimal.NewFromFloat(3)
  38. // d3.String() // output: "6666.6666666666666667"
  39. // decimal.DivisionPrecision = 3
  40. // d4 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(3)
  41. // d4.String() // output: "0.667"
  42. //
  43. var DivisionPrecision = 16
  44. // MarshalJSONWithoutQuotes should be set to true if you want the decimal to
  45. // be JSON marshaled as a number, instead of as a string.
  46. // WARNING: this is dangerous for decimals with many digits, since many JSON
  47. // unmarshallers (ex: Javascript's) will unmarshal JSON numbers to IEEE 754
  48. // double-precision floating point numbers, which means you can potentially
  49. // silently lose precision.
  50. var MarshalJSONWithoutQuotes = false
  51. // Zero constant, to make computations faster.
  52. var Zero = New(0, 1)
  53. // fiveDec used in Cash Rounding
  54. var fiveDec = New(5, 0)
  55. var zeroInt = big.NewInt(0)
  56. var oneInt = big.NewInt(1)
  57. var twoInt = big.NewInt(2)
  58. var fourInt = big.NewInt(4)
  59. var fiveInt = big.NewInt(5)
  60. var tenInt = big.NewInt(10)
  61. var twentyInt = big.NewInt(20)
  62. // Decimal represents a fixed-point decimal. It is immutable.
  63. // number = value * 10 ^ exp
  64. type Decimal struct {
  65. value *big.Int
  66. // NOTE(vadim): this must be an int32, because we cast it to float64 during
  67. // calculations. If exp is 64 bit, we might lose precision.
  68. // If we cared about being able to represent every possible decimal, we
  69. // could make exp a *big.Int but it would hurt performance and numbers
  70. // like that are unrealistic.
  71. exp int32
  72. }
  73. // New returns a new fixed-point decimal, value * 10 ^ exp.
  74. func New(value int64, exp int32) Decimal {
  75. return Decimal{
  76. value: big.NewInt(value),
  77. exp: exp,
  78. }
  79. }
  80. // NewFromBigInt returns a new Decimal from a big.Int, value * 10 ^ exp
  81. func NewFromBigInt(value *big.Int, exp int32) Decimal {
  82. return Decimal{
  83. value: big.NewInt(0).Set(value),
  84. exp: exp,
  85. }
  86. }
  87. // NewFromString returns a new Decimal from a string representation.
  88. //
  89. // Example:
  90. //
  91. // d, err := NewFromString("-123.45")
  92. // d2, err := NewFromString(".0001")
  93. //
  94. func NewFromString(value string) (Decimal, error) {
  95. originalInput := value
  96. var intString string
  97. var exp int64
  98. // Check if number is using scientific notation
  99. eIndex := strings.IndexAny(value, "Ee")
  100. if eIndex != -1 {
  101. expInt, err := strconv.ParseInt(value[eIndex+1:], 10, 32)
  102. if err != nil {
  103. if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange {
  104. return Decimal{}, fmt.Errorf("can't convert %s to decimal: fractional part too long", value)
  105. }
  106. return Decimal{}, fmt.Errorf("can't convert %s to decimal: exponent is not numeric", value)
  107. }
  108. value = value[:eIndex]
  109. exp = expInt
  110. }
  111. parts := strings.Split(value, ".")
  112. if len(parts) == 1 {
  113. // There is no decimal point, we can just parse the original string as
  114. // an int
  115. intString = value
  116. } else if len(parts) == 2 {
  117. // strip the insignificant digits for more accurate comparisons.
  118. decimalPart := strings.TrimRight(parts[1], "0")
  119. intString = parts[0] + decimalPart
  120. expInt := -len(decimalPart)
  121. exp += int64(expInt)
  122. } else {
  123. return Decimal{}, fmt.Errorf("can't convert %s to decimal: too many .s", value)
  124. }
  125. dValue := new(big.Int)
  126. _, ok := dValue.SetString(intString, 10)
  127. if !ok {
  128. return Decimal{}, fmt.Errorf("can't convert %s to decimal", value)
  129. }
  130. if exp < math.MinInt32 || exp > math.MaxInt32 {
  131. // NOTE(vadim): I doubt a string could realistically be this long
  132. return Decimal{}, fmt.Errorf("can't convert %s to decimal: fractional part too long", originalInput)
  133. }
  134. return Decimal{
  135. value: dValue,
  136. exp: int32(exp),
  137. }, nil
  138. }
  139. // RequireFromString returns a new Decimal from a string representation
  140. // or panics if NewFromString would have returned an error.
  141. //
  142. // Example:
  143. //
  144. // d := RequireFromString("-123.45")
  145. // d2 := RequireFromString(".0001")
  146. //
  147. func RequireFromString(value string) Decimal {
  148. dec, err := NewFromString(value)
  149. if err != nil {
  150. panic(err)
  151. }
  152. return dec
  153. }
  154. // NewFromFloat converts a float64 to Decimal.
  155. //
  156. // Example:
  157. //
  158. // NewFromFloat(123.45678901234567).String() // output: "123.4567890123456"
  159. // NewFromFloat(.00000000000000001).String() // output: "0.00000000000000001"
  160. //
  161. // NOTE: some float64 numbers can take up about 300 bytes of memory in decimal representation.
  162. // Consider using NewFromFloatWithExponent if space is more important than precision.
  163. //
  164. // NOTE: this will panic on NaN, +/-inf
  165. func NewFromFloat(value float64) Decimal {
  166. return NewFromFloatWithExponent(value, math.MinInt32)
  167. }
  168. // NewFromFloatWithExponent converts a float64 to Decimal, with an arbitrary
  169. // number of fractional digits.
  170. //
  171. // Example:
  172. //
  173. // NewFromFloatWithExponent(123.456, -2).String() // output: "123.46"
  174. //
  175. func NewFromFloatWithExponent(value float64, exp int32) Decimal {
  176. if math.IsNaN(value) || math.IsInf(value, 0) {
  177. panic(fmt.Sprintf("Cannot create a Decimal from %v", value))
  178. }
  179. bits := math.Float64bits(value)
  180. mant := bits & (1<<52 - 1)
  181. exp2 := int32((bits >> 52) & (1<<11 - 1))
  182. sign := bits >> 63
  183. if exp2 == 0 {
  184. // specials
  185. if mant == 0 {
  186. return Decimal{}
  187. } else {
  188. // subnormal
  189. exp2++
  190. }
  191. } else {
  192. // normal
  193. mant |= 1 << 52
  194. }
  195. exp2 -= 1023 + 52
  196. // normalizing base-2 values
  197. for mant&1 == 0 {
  198. mant = mant >> 1
  199. exp2++
  200. }
  201. // maximum number of fractional base-10 digits to represent 2^N exactly cannot be more than -N if N<0
  202. if exp < 0 && exp < exp2 {
  203. if exp2 < 0 {
  204. exp = exp2
  205. } else {
  206. exp = 0
  207. }
  208. }
  209. // representing 10^M * 2^N as 5^M * 2^(M+N)
  210. exp2 -= exp
  211. temp := big.NewInt(1)
  212. dMant := big.NewInt(int64(mant))
  213. // applying 5^M
  214. if exp > 0 {
  215. temp = temp.SetInt64(int64(exp))
  216. temp = temp.Exp(fiveInt, temp, nil)
  217. } else if exp < 0 {
  218. temp = temp.SetInt64(-int64(exp))
  219. temp = temp.Exp(fiveInt, temp, nil)
  220. dMant = dMant.Mul(dMant, temp)
  221. temp = temp.SetUint64(1)
  222. }
  223. // applying 2^(M+N)
  224. if exp2 > 0 {
  225. dMant = dMant.Lsh(dMant, uint(exp2))
  226. } else if exp2 < 0 {
  227. temp = temp.Lsh(temp, uint(-exp2))
  228. }
  229. // rounding and downscaling
  230. if exp > 0 || exp2 < 0 {
  231. halfDown := new(big.Int).Rsh(temp, 1)
  232. dMant = dMant.Add(dMant, halfDown)
  233. dMant = dMant.Quo(dMant, temp)
  234. }
  235. if sign == 1 {
  236. dMant = dMant.Neg(dMant)
  237. }
  238. return Decimal{
  239. value: dMant,
  240. exp: exp,
  241. }
  242. }
  243. // rescale returns a rescaled version of the decimal. Returned
  244. // decimal may be less precise if the given exponent is bigger
  245. // than the initial exponent of the Decimal.
  246. // NOTE: this will truncate, NOT round
  247. //
  248. // Example:
  249. //
  250. // d := New(12345, -4)
  251. // d2 := d.rescale(-1)
  252. // d3 := d2.rescale(-4)
  253. // println(d1)
  254. // println(d2)
  255. // println(d3)
  256. //
  257. // Output:
  258. //
  259. // 1.2345
  260. // 1.2
  261. // 1.2000
  262. //
  263. func (d Decimal) rescale(exp int32) Decimal {
  264. d.ensureInitialized()
  265. // NOTE(vadim): must convert exps to float64 before - to prevent overflow
  266. diff := math.Abs(float64(exp) - float64(d.exp))
  267. value := new(big.Int).Set(d.value)
  268. expScale := new(big.Int).Exp(tenInt, big.NewInt(int64(diff)), nil)
  269. if exp > d.exp {
  270. value = value.Quo(value, expScale)
  271. } else if exp < d.exp {
  272. value = value.Mul(value, expScale)
  273. }
  274. return Decimal{
  275. value: value,
  276. exp: exp,
  277. }
  278. }
  279. // Abs returns the absolute value of the decimal.
  280. func (d Decimal) Abs() Decimal {
  281. d.ensureInitialized()
  282. d2Value := new(big.Int).Abs(d.value)
  283. return Decimal{
  284. value: d2Value,
  285. exp: d.exp,
  286. }
  287. }
  288. // Add returns d + d2.
  289. func (d Decimal) Add(d2 Decimal) Decimal {
  290. baseScale := min(d.exp, d2.exp)
  291. rd := d.rescale(baseScale)
  292. rd2 := d2.rescale(baseScale)
  293. d3Value := new(big.Int).Add(rd.value, rd2.value)
  294. return Decimal{
  295. value: d3Value,
  296. exp: baseScale,
  297. }
  298. }
  299. // Sub returns d - d2.
  300. func (d Decimal) Sub(d2 Decimal) Decimal {
  301. baseScale := min(d.exp, d2.exp)
  302. rd := d.rescale(baseScale)
  303. rd2 := d2.rescale(baseScale)
  304. d3Value := new(big.Int).Sub(rd.value, rd2.value)
  305. return Decimal{
  306. value: d3Value,
  307. exp: baseScale,
  308. }
  309. }
  310. // Neg returns -d.
  311. func (d Decimal) Neg() Decimal {
  312. d.ensureInitialized()
  313. val := new(big.Int).Neg(d.value)
  314. return Decimal{
  315. value: val,
  316. exp: d.exp,
  317. }
  318. }
  319. // Mul returns d * d2.
  320. func (d Decimal) Mul(d2 Decimal) Decimal {
  321. d.ensureInitialized()
  322. d2.ensureInitialized()
  323. expInt64 := int64(d.exp) + int64(d2.exp)
  324. if expInt64 > math.MaxInt32 || expInt64 < math.MinInt32 {
  325. // NOTE(vadim): better to panic than give incorrect results, as
  326. // Decimals are usually used for money
  327. panic(fmt.Sprintf("exponent %v overflows an int32!", expInt64))
  328. }
  329. d3Value := new(big.Int).Mul(d.value, d2.value)
  330. return Decimal{
  331. value: d3Value,
  332. exp: int32(expInt64),
  333. }
  334. }
  335. // Div returns d / d2. If it doesn't divide exactly, the result will have
  336. // DivisionPrecision digits after the decimal point.
  337. func (d Decimal) Div(d2 Decimal) Decimal {
  338. return d.DivRound(d2, int32(DivisionPrecision))
  339. }
  340. // QuoRem does divsion with remainder
  341. // d.QuoRem(d2,precision) returns quotient q and remainder r such that
  342. // d = d2 * q + r, q an integer multiple of 10^(-precision)
  343. // 0 <= r < abs(d2) * 10 ^(-precision) if d>=0
  344. // 0 >= r > -abs(d2) * 10 ^(-precision) if d<0
  345. // Note that precision<0 is allowed as input.
  346. func (d Decimal) QuoRem(d2 Decimal, precision int32) (Decimal, Decimal) {
  347. d.ensureInitialized()
  348. d2.ensureInitialized()
  349. if d2.value.Sign() == 0 {
  350. panic("decimal division by 0")
  351. }
  352. scale := -precision
  353. e := int64(d.exp - d2.exp - scale)
  354. if e > math.MaxInt32 || e < math.MinInt32 {
  355. panic("overflow in decimal QuoRem")
  356. }
  357. var aa, bb, expo big.Int
  358. var scalerest int32
  359. // d = a 10^ea
  360. // d2 = b 10^eb
  361. if e < 0 {
  362. aa = *d.value
  363. expo.SetInt64(-e)
  364. bb.Exp(tenInt, &expo, nil)
  365. bb.Mul(d2.value, &bb)
  366. scalerest = d.exp
  367. // now aa = a
  368. // bb = b 10^(scale + eb - ea)
  369. } else {
  370. expo.SetInt64(e)
  371. aa.Exp(tenInt, &expo, nil)
  372. aa.Mul(d.value, &aa)
  373. bb = *d2.value
  374. scalerest = scale + d2.exp
  375. // now aa = a ^ (ea - eb - scale)
  376. // bb = b
  377. }
  378. var q, r big.Int
  379. q.QuoRem(&aa, &bb, &r)
  380. dq := Decimal{value: &q, exp: scale}
  381. dr := Decimal{value: &r, exp: scalerest}
  382. return dq, dr
  383. }
  384. // DivRound divides and rounds to a given precision
  385. // i.e. to an integer multiple of 10^(-precision)
  386. // for a positive quotient digit 5 is rounded up, away from 0
  387. // if the quotient is negative then digit 5 is rounded down, away from 0
  388. // Note that precision<0 is allowed as input.
  389. func (d Decimal) DivRound(d2 Decimal, precision int32) Decimal {
  390. // QuoRem already checks initialization
  391. q, r := d.QuoRem(d2, precision)
  392. // the actual rounding decision is based on comparing r*10^precision and d2/2
  393. // instead compare 2 r 10 ^precision and d2
  394. var rv2 big.Int
  395. rv2.Abs(r.value)
  396. rv2.Lsh(&rv2, 1)
  397. // now rv2 = abs(r.value) * 2
  398. r2 := Decimal{value: &rv2, exp: r.exp + precision}
  399. // r2 is now 2 * r * 10 ^ precision
  400. var c = r2.Cmp(d2.Abs())
  401. if c < 0 {
  402. return q
  403. }
  404. if d.value.Sign()*d2.value.Sign() < 0 {
  405. return q.Sub(New(1, -precision))
  406. }
  407. return q.Add(New(1, -precision))
  408. }
  409. // Mod returns d % d2.
  410. func (d Decimal) Mod(d2 Decimal) Decimal {
  411. quo := d.Div(d2).Truncate(0)
  412. return d.Sub(d2.Mul(quo))
  413. }
  414. // Pow returns d to the power d2
  415. func (d Decimal) Pow(d2 Decimal) Decimal {
  416. var temp Decimal
  417. if d2.IntPart() == 0 {
  418. return NewFromFloat(1)
  419. }
  420. temp = d.Pow(d2.Div(NewFromFloat(2)))
  421. if d2.IntPart()%2 == 0 {
  422. return temp.Mul(temp)
  423. }
  424. if d2.IntPart() > 0 {
  425. return temp.Mul(temp).Mul(d)
  426. }
  427. return temp.Mul(temp).Div(d)
  428. }
  429. // Cmp compares the numbers represented by d and d2 and returns:
  430. //
  431. // -1 if d < d2
  432. // 0 if d == d2
  433. // +1 if d > d2
  434. //
  435. func (d Decimal) Cmp(d2 Decimal) int {
  436. d.ensureInitialized()
  437. d2.ensureInitialized()
  438. if d.exp == d2.exp {
  439. return d.value.Cmp(d2.value)
  440. }
  441. baseExp := min(d.exp, d2.exp)
  442. rd := d.rescale(baseExp)
  443. rd2 := d2.rescale(baseExp)
  444. return rd.value.Cmp(rd2.value)
  445. }
  446. // Equal returns whether the numbers represented by d and d2 are equal.
  447. func (d Decimal) Equal(d2 Decimal) bool {
  448. return d.Cmp(d2) == 0
  449. }
  450. // Equals is deprecated, please use Equal method instead
  451. func (d Decimal) Equals(d2 Decimal) bool {
  452. return d.Equal(d2)
  453. }
  454. // GreaterThan (GT) returns true when d is greater than d2.
  455. func (d Decimal) GreaterThan(d2 Decimal) bool {
  456. return d.Cmp(d2) == 1
  457. }
  458. // GreaterThanOrEqual (GTE) returns true when d is greater than or equal to d2.
  459. func (d Decimal) GreaterThanOrEqual(d2 Decimal) bool {
  460. cmp := d.Cmp(d2)
  461. return cmp == 1 || cmp == 0
  462. }
  463. // LessThan (LT) returns true when d is less than d2.
  464. func (d Decimal) LessThan(d2 Decimal) bool {
  465. return d.Cmp(d2) == -1
  466. }
  467. // LessThanOrEqual (LTE) returns true when d is less than or equal to d2.
  468. func (d Decimal) LessThanOrEqual(d2 Decimal) bool {
  469. cmp := d.Cmp(d2)
  470. return cmp == -1 || cmp == 0
  471. }
  472. // Sign returns:
  473. //
  474. // -1 if d < 0
  475. // 0 if d == 0
  476. // +1 if d > 0
  477. //
  478. func (d Decimal) Sign() int {
  479. if d.value == nil {
  480. return 0
  481. }
  482. return d.value.Sign()
  483. }
  484. // Exponent returns the exponent, or scale component of the decimal.
  485. func (d Decimal) Exponent() int32 {
  486. return d.exp
  487. }
  488. // Coefficient returns the coefficient of the decimal. It is scaled by 10^Exponent()
  489. func (d Decimal) Coefficient() *big.Int {
  490. // we copy the coefficient so that mutating the result does not mutate the
  491. // Decimal.
  492. return big.NewInt(0).Set(d.value)
  493. }
  494. // IntPart returns the integer component of the decimal.
  495. func (d Decimal) IntPart() int64 {
  496. scaledD := d.rescale(0)
  497. return scaledD.value.Int64()
  498. }
  499. // Rat returns a rational number representation of the decimal.
  500. func (d Decimal) Rat() *big.Rat {
  501. d.ensureInitialized()
  502. if d.exp <= 0 {
  503. // NOTE(vadim): must negate after casting to prevent int32 overflow
  504. denom := new(big.Int).Exp(tenInt, big.NewInt(-int64(d.exp)), nil)
  505. return new(big.Rat).SetFrac(d.value, denom)
  506. }
  507. mul := new(big.Int).Exp(tenInt, big.NewInt(int64(d.exp)), nil)
  508. num := new(big.Int).Mul(d.value, mul)
  509. return new(big.Rat).SetFrac(num, oneInt)
  510. }
  511. // Float64 returns the nearest float64 value for d and a bool indicating
  512. // whether f represents d exactly.
  513. // For more details, see the documentation for big.Rat.Float64
  514. func (d Decimal) Float64() (f float64, exact bool) {
  515. return d.Rat().Float64()
  516. }
  517. // String returns the string representation of the decimal
  518. // with the fixed point.
  519. //
  520. // Example:
  521. //
  522. // d := New(-12345, -3)
  523. // println(d.String())
  524. //
  525. // Output:
  526. //
  527. // -12.345
  528. //
  529. func (d Decimal) String() string {
  530. return d.string(true)
  531. }
  532. // StringFixed returns a rounded fixed-point string with places digits after
  533. // the decimal point.
  534. //
  535. // Example:
  536. //
  537. // NewFromFloat(0).StringFixed(2) // output: "0.00"
  538. // NewFromFloat(0).StringFixed(0) // output: "0"
  539. // NewFromFloat(5.45).StringFixed(0) // output: "5"
  540. // NewFromFloat(5.45).StringFixed(1) // output: "5.5"
  541. // NewFromFloat(5.45).StringFixed(2) // output: "5.45"
  542. // NewFromFloat(5.45).StringFixed(3) // output: "5.450"
  543. // NewFromFloat(545).StringFixed(-1) // output: "550"
  544. //
  545. func (d Decimal) StringFixed(places int32) string {
  546. rounded := d.Round(places)
  547. return rounded.string(false)
  548. }
  549. // StringFixedBank returns a banker rounded fixed-point string with places digits
  550. // after the decimal point.
  551. //
  552. // Example:
  553. //
  554. // NewFromFloat(0).StringFixed(2) // output: "0.00"
  555. // NewFromFloat(0).StringFixed(0) // output: "0"
  556. // NewFromFloat(5.45).StringFixed(0) // output: "5"
  557. // NewFromFloat(5.45).StringFixed(1) // output: "5.4"
  558. // NewFromFloat(5.45).StringFixed(2) // output: "5.45"
  559. // NewFromFloat(5.45).StringFixed(3) // output: "5.450"
  560. // NewFromFloat(545).StringFixed(-1) // output: "550"
  561. //
  562. func (d Decimal) StringFixedBank(places int32) string {
  563. rounded := d.RoundBank(places)
  564. return rounded.string(false)
  565. }
  566. // StringFixedCash returns a Swedish/Cash rounded fixed-point string. For
  567. // more details see the documentation at function RoundCash.
  568. func (d Decimal) StringFixedCash(interval uint8) string {
  569. rounded := d.RoundCash(interval)
  570. return rounded.string(false)
  571. }
  572. // Round rounds the decimal to places decimal places.
  573. // If places < 0, it will round the integer part to the nearest 10^(-places).
  574. //
  575. // Example:
  576. //
  577. // NewFromFloat(5.45).Round(1).String() // output: "5.5"
  578. // NewFromFloat(545).Round(-1).String() // output: "550"
  579. //
  580. func (d Decimal) Round(places int32) Decimal {
  581. // truncate to places + 1
  582. ret := d.rescale(-places - 1)
  583. // add sign(d) * 0.5
  584. if ret.value.Sign() < 0 {
  585. ret.value.Sub(ret.value, fiveInt)
  586. } else {
  587. ret.value.Add(ret.value, fiveInt)
  588. }
  589. // floor for positive numbers, ceil for negative numbers
  590. _, m := ret.value.DivMod(ret.value, tenInt, new(big.Int))
  591. ret.exp++
  592. if ret.value.Sign() < 0 && m.Cmp(zeroInt) != 0 {
  593. ret.value.Add(ret.value, oneInt)
  594. }
  595. return ret
  596. }
  597. // RoundBank rounds the decimal to places decimal places.
  598. // If the final digit to round is equidistant from the nearest two integers the
  599. // rounded value is taken as the even number
  600. //
  601. // If places < 0, it will round the integer part to the nearest 10^(-places).
  602. //
  603. // Examples:
  604. //
  605. // NewFromFloat(5.45).Round(1).String() // output: "5.4"
  606. // NewFromFloat(545).Round(-1).String() // output: "540"
  607. // NewFromFloat(5.46).Round(1).String() // output: "5.5"
  608. // NewFromFloat(546).Round(-1).String() // output: "550"
  609. // NewFromFloat(5.55).Round(1).String() // output: "5.6"
  610. // NewFromFloat(555).Round(-1).String() // output: "560"
  611. //
  612. func (d Decimal) RoundBank(places int32) Decimal {
  613. round := d.Round(places)
  614. remainder := d.Sub(round).Abs()
  615. half := New(5, -places-1)
  616. if remainder.Cmp(half) == 0 && round.value.Bit(0) != 0 {
  617. if round.value.Sign() < 0 {
  618. round.value.Add(round.value, oneInt)
  619. } else {
  620. round.value.Sub(round.value, oneInt)
  621. }
  622. }
  623. return round
  624. }
  625. // RoundCash aka Cash/Penny/öre rounding rounds decimal to a specific
  626. // interval. The amount payable for a cash transaction is rounded to the nearest
  627. // multiple of the minimum currency unit available. The following intervals are
  628. // available: 5, 10, 15, 25, 50 and 100; any other number throws a panic.
  629. // 5: 5 cent rounding 3.43 => 3.45
  630. // 10: 10 cent rounding 3.45 => 3.50 (5 gets rounded up)
  631. // 15: 10 cent rounding 3.45 => 3.40 (5 gets rounded down)
  632. // 25: 25 cent rounding 3.41 => 3.50
  633. // 50: 50 cent rounding 3.75 => 4.00
  634. // 100: 100 cent rounding 3.50 => 4.00
  635. // For more details: https://en.wikipedia.org/wiki/Cash_rounding
  636. func (d Decimal) RoundCash(interval uint8) Decimal {
  637. var iVal *big.Int
  638. switch interval {
  639. case 5:
  640. iVal = twentyInt
  641. case 10:
  642. iVal = tenInt
  643. case 15:
  644. if d.exp < 0 {
  645. // TODO: optimize and reduce allocations
  646. orgExp := d.exp
  647. dOne := New(10^-int64(orgExp), orgExp)
  648. d2 := d
  649. d2.exp = 0
  650. if d2.Mod(fiveDec).Equal(Zero) {
  651. d2.exp = orgExp
  652. d2 = d2.Sub(dOne)
  653. d = d2
  654. }
  655. }
  656. iVal = tenInt
  657. case 25:
  658. iVal = fourInt
  659. case 50:
  660. iVal = twoInt
  661. case 100:
  662. iVal = oneInt
  663. default:
  664. panic(fmt.Sprintf("Decimal does not support this Cash rounding interval `%d`. Supported: 5, 10, 15, 25, 50, 100", interval))
  665. }
  666. dVal := Decimal{
  667. value: iVal,
  668. }
  669. // TODO: optimize those calculations to reduce the high allocations (~29 allocs).
  670. return d.Mul(dVal).Round(0).Div(dVal).Truncate(2)
  671. }
  672. // Floor returns the nearest integer value less than or equal to d.
  673. func (d Decimal) Floor() Decimal {
  674. d.ensureInitialized()
  675. if d.exp >= 0 {
  676. return d
  677. }
  678. exp := big.NewInt(10)
  679. // NOTE(vadim): must negate after casting to prevent int32 overflow
  680. exp.Exp(exp, big.NewInt(-int64(d.exp)), nil)
  681. z := new(big.Int).Div(d.value, exp)
  682. return Decimal{value: z, exp: 0}
  683. }
  684. // Ceil returns the nearest integer value greater than or equal to d.
  685. func (d Decimal) Ceil() Decimal {
  686. d.ensureInitialized()
  687. if d.exp >= 0 {
  688. return d
  689. }
  690. exp := big.NewInt(10)
  691. // NOTE(vadim): must negate after casting to prevent int32 overflow
  692. exp.Exp(exp, big.NewInt(-int64(d.exp)), nil)
  693. z, m := new(big.Int).DivMod(d.value, exp, new(big.Int))
  694. if m.Cmp(zeroInt) != 0 {
  695. z.Add(z, oneInt)
  696. }
  697. return Decimal{value: z, exp: 0}
  698. }
  699. // Truncate truncates off digits from the number, without rounding.
  700. //
  701. // NOTE: precision is the last digit that will not be truncated (must be >= 0).
  702. //
  703. // Example:
  704. //
  705. // decimal.NewFromString("123.456").Truncate(2).String() // "123.45"
  706. //
  707. func (d Decimal) Truncate(precision int32) Decimal {
  708. d.ensureInitialized()
  709. if precision >= 0 && -precision > d.exp {
  710. return d.rescale(-precision)
  711. }
  712. return d
  713. }
  714. // UnmarshalJSON implements the json.Unmarshaler interface.
  715. func (d *Decimal) UnmarshalJSON(decimalBytes []byte) error {
  716. if string(decimalBytes) == "null" {
  717. return nil
  718. }
  719. str, err := unquoteIfQuoted(decimalBytes)
  720. if err != nil {
  721. return fmt.Errorf("Error decoding string '%s': %s", decimalBytes, err)
  722. }
  723. decimal, err := NewFromString(str)
  724. *d = decimal
  725. if err != nil {
  726. return fmt.Errorf("Error decoding string '%s': %s", str, err)
  727. }
  728. return nil
  729. }
  730. // MarshalJSON implements the json.Marshaler interface.
  731. func (d Decimal) MarshalJSON() ([]byte, error) {
  732. var str string
  733. if MarshalJSONWithoutQuotes {
  734. str = d.String()
  735. } else {
  736. str = "\"" + d.String() + "\""
  737. }
  738. return []byte(str), nil
  739. }
  740. // UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. As a string representation
  741. // is already used when encoding to text, this method stores that string as []byte
  742. func (d *Decimal) UnmarshalBinary(data []byte) error {
  743. // Extract the exponent
  744. d.exp = int32(binary.BigEndian.Uint32(data[:4]))
  745. // Extract the value
  746. d.value = new(big.Int)
  747. return d.value.GobDecode(data[4:])
  748. }
  749. // MarshalBinary implements the encoding.BinaryMarshaler interface.
  750. func (d Decimal) MarshalBinary() (data []byte, err error) {
  751. // Write the exponent first since it's a fixed size
  752. v1 := make([]byte, 4)
  753. binary.BigEndian.PutUint32(v1, uint32(d.exp))
  754. // Add the value
  755. var v2 []byte
  756. if v2, err = d.value.GobEncode(); err != nil {
  757. return
  758. }
  759. // Return the byte array
  760. data = append(v1, v2...)
  761. return
  762. }
  763. // Scan implements the sql.Scanner interface for database deserialization.
  764. func (d *Decimal) Scan(value interface{}) error {
  765. // first try to see if the data is stored in database as a Numeric datatype
  766. switch v := value.(type) {
  767. case float32:
  768. *d = NewFromFloat(float64(v))
  769. return nil
  770. case float64:
  771. // numeric in sqlite3 sends us float64
  772. *d = NewFromFloat(v)
  773. return nil
  774. case int64:
  775. // at least in sqlite3 when the value is 0 in db, the data is sent
  776. // to us as an int64 instead of a float64 ...
  777. *d = New(v, 0)
  778. return nil
  779. default:
  780. // default is trying to interpret value stored as string
  781. str, err := unquoteIfQuoted(v)
  782. if err != nil {
  783. return err
  784. }
  785. *d, err = NewFromString(str)
  786. return err
  787. }
  788. }
  789. // Value implements the driver.Valuer interface for database serialization.
  790. func (d Decimal) Value() (driver.Value, error) {
  791. return d.String(), nil
  792. }
  793. // UnmarshalText implements the encoding.TextUnmarshaler interface for XML
  794. // deserialization.
  795. func (d *Decimal) UnmarshalText(text []byte) error {
  796. str := string(text)
  797. dec, err := NewFromString(str)
  798. *d = dec
  799. if err != nil {
  800. return fmt.Errorf("Error decoding string '%s': %s", str, err)
  801. }
  802. return nil
  803. }
  804. // MarshalText implements the encoding.TextMarshaler interface for XML
  805. // serialization.
  806. func (d Decimal) MarshalText() (text []byte, err error) {
  807. return []byte(d.String()), nil
  808. }
  809. // GobEncode implements the gob.GobEncoder interface for gob serialization.
  810. func (d Decimal) GobEncode() ([]byte, error) {
  811. return d.MarshalBinary()
  812. }
  813. // GobDecode implements the gob.GobDecoder interface for gob serialization.
  814. func (d *Decimal) GobDecode(data []byte) error {
  815. return d.UnmarshalBinary(data)
  816. }
  817. // StringScaled first scales the decimal then calls .String() on it.
  818. // NOTE: buggy, unintuitive, and DEPRECATED! Use StringFixed instead.
  819. func (d Decimal) StringScaled(exp int32) string {
  820. return d.rescale(exp).String()
  821. }
  822. func (d Decimal) string(trimTrailingZeros bool) string {
  823. if d.exp >= 0 {
  824. return d.rescale(0).value.String()
  825. }
  826. abs := new(big.Int).Abs(d.value)
  827. str := abs.String()
  828. var intPart, fractionalPart string
  829. // NOTE(vadim): this cast to int will cause bugs if d.exp == INT_MIN
  830. // and you are on a 32-bit machine. Won't fix this super-edge case.
  831. dExpInt := int(d.exp)
  832. if len(str) > -dExpInt {
  833. intPart = str[:len(str)+dExpInt]
  834. fractionalPart = str[len(str)+dExpInt:]
  835. } else {
  836. intPart = "0"
  837. num0s := -dExpInt - len(str)
  838. fractionalPart = strings.Repeat("0", num0s) + str
  839. }
  840. if trimTrailingZeros {
  841. i := len(fractionalPart) - 1
  842. for ; i >= 0; i-- {
  843. if fractionalPart[i] != '0' {
  844. break
  845. }
  846. }
  847. fractionalPart = fractionalPart[:i+1]
  848. }
  849. number := intPart
  850. if len(fractionalPart) > 0 {
  851. number += "." + fractionalPart
  852. }
  853. if d.value.Sign() < 0 {
  854. return "-" + number
  855. }
  856. return number
  857. }
  858. func (d *Decimal) ensureInitialized() {
  859. if d.value == nil {
  860. d.value = new(big.Int)
  861. }
  862. }
  863. // Min returns the smallest Decimal that was passed in the arguments.
  864. //
  865. // To call this function with an array, you must do:
  866. //
  867. // Min(arr[0], arr[1:]...)
  868. //
  869. // This makes it harder to accidentally call Min with 0 arguments.
  870. func Min(first Decimal, rest ...Decimal) Decimal {
  871. ans := first
  872. for _, item := range rest {
  873. if item.Cmp(ans) < 0 {
  874. ans = item
  875. }
  876. }
  877. return ans
  878. }
  879. // Max returns the largest Decimal that was passed in the arguments.
  880. //
  881. // To call this function with an array, you must do:
  882. //
  883. // Max(arr[0], arr[1:]...)
  884. //
  885. // This makes it harder to accidentally call Max with 0 arguments.
  886. func Max(first Decimal, rest ...Decimal) Decimal {
  887. ans := first
  888. for _, item := range rest {
  889. if item.Cmp(ans) > 0 {
  890. ans = item
  891. }
  892. }
  893. return ans
  894. }
  895. // Sum returns the combined total of the provided first and rest Decimals
  896. func Sum(first Decimal, rest ...Decimal) Decimal {
  897. total := first
  898. for _, item := range rest {
  899. total = total.Add(item)
  900. }
  901. return total
  902. }
  903. // Avg returns the average value of the provided first and rest Decimals
  904. func Avg(first Decimal, rest ...Decimal) Decimal {
  905. count := New(int64(len(rest)+1), 0)
  906. sum := Sum(first, rest...)
  907. return sum.Div(count)
  908. }
  909. func min(x, y int32) int32 {
  910. if x >= y {
  911. return y
  912. }
  913. return x
  914. }
  915. func unquoteIfQuoted(value interface{}) (string, error) {
  916. var bytes []byte
  917. switch v := value.(type) {
  918. case string:
  919. bytes = []byte(v)
  920. case []byte:
  921. bytes = v
  922. default:
  923. return "", fmt.Errorf("Could not convert value '%+v' to byte array of type '%T'",
  924. value, value)
  925. }
  926. // If the amount is quoted, strip the quotes
  927. if len(bytes) > 2 && bytes[0] == '"' && bytes[len(bytes)-1] == '"' {
  928. bytes = bytes[1 : len(bytes)-1]
  929. }
  930. return string(bytes), nil
  931. }
  932. // NullDecimal represents a nullable decimal with compatibility for
  933. // scanning null values from the database.
  934. type NullDecimal struct {
  935. Decimal Decimal
  936. Valid bool
  937. }
  938. // Scan implements the sql.Scanner interface for database deserialization.
  939. func (d *NullDecimal) Scan(value interface{}) error {
  940. if value == nil {
  941. d.Valid = false
  942. return nil
  943. }
  944. d.Valid = true
  945. return d.Decimal.Scan(value)
  946. }
  947. // Value implements the driver.Valuer interface for database serialization.
  948. func (d NullDecimal) Value() (driver.Value, error) {
  949. if !d.Valid {
  950. return nil, nil
  951. }
  952. return d.Decimal.Value()
  953. }
  954. // UnmarshalJSON implements the json.Unmarshaler interface.
  955. func (d *NullDecimal) UnmarshalJSON(decimalBytes []byte) error {
  956. if string(decimalBytes) == "null" {
  957. d.Valid = false
  958. return nil
  959. }
  960. d.Valid = true
  961. return d.Decimal.UnmarshalJSON(decimalBytes)
  962. }
  963. // MarshalJSON implements the json.Marshaler interface.
  964. func (d NullDecimal) MarshalJSON() ([]byte, error) {
  965. if !d.Valid {
  966. return []byte("null"), nil
  967. }
  968. return d.Decimal.MarshalJSON()
  969. }