suggester_phrase.go 14 KB


  1. // Copyright 2012-present Oliver Eilhard. All rights reserved.
  2. // Use of this source code is governed by a MIT-license.
  3. // See http://olivere.mit-license.org/license.txt for details.
  4. package elastic
  5. // PhraseSuggester provides an API to access word alternatives
  6. // on a per token basis within a certain string distance.
  7. // For more details, see
  8. // https://www.elastic.co/guide/en/elasticsearch/reference/5.2/search-suggesters-phrase.html.
  9. type PhraseSuggester struct {
  10. Suggester
  11. name string
  12. text string
  13. field string
  14. analyzer string
  15. size *int
  16. shardSize *int
  17. contextQueries []SuggesterContextQuery
  18. // fields specific to a phrase suggester
  19. maxErrors *float64
  20. separator *string
  21. realWordErrorLikelihood *float64
  22. confidence *float64
  23. generators map[string][]CandidateGenerator
  24. gramSize *int
  25. smoothingModel SmoothingModel
  26. forceUnigrams *bool
  27. tokenLimit *int
  28. preTag, postTag *string
  29. collateQuery *string
  30. collatePreference *string
  31. collateParams map[string]interface{}
  32. collatePrune *bool
  33. }
  34. // NewPhraseSuggester creates a new PhraseSuggester.
  35. func NewPhraseSuggester(name string) *PhraseSuggester {
  36. return &PhraseSuggester{
  37. name: name,
  38. collateParams: make(map[string]interface{}),
  39. }
  40. }
  41. func (q *PhraseSuggester) Name() string {
  42. return q.name
  43. }
  44. func (q *PhraseSuggester) Text(text string) *PhraseSuggester {
  45. q.text = text
  46. return q
  47. }
  48. func (q *PhraseSuggester) Field(field string) *PhraseSuggester {
  49. q.field = field
  50. return q
  51. }
  52. func (q *PhraseSuggester) Analyzer(analyzer string) *PhraseSuggester {
  53. q.analyzer = analyzer
  54. return q
  55. }
  56. func (q *PhraseSuggester) Size(size int) *PhraseSuggester {
  57. q.size = &size
  58. return q
  59. }
  60. func (q *PhraseSuggester) ShardSize(shardSize int) *PhraseSuggester {
  61. q.shardSize = &shardSize
  62. return q
  63. }
  64. func (q *PhraseSuggester) ContextQuery(query SuggesterContextQuery) *PhraseSuggester {
  65. q.contextQueries = append(q.contextQueries, query)
  66. return q
  67. }
  68. func (q *PhraseSuggester) ContextQueries(queries ...SuggesterContextQuery) *PhraseSuggester {
  69. q.contextQueries = append(q.contextQueries, queries...)
  70. return q
  71. }
  72. func (q *PhraseSuggester) GramSize(gramSize int) *PhraseSuggester {
  73. if gramSize >= 1 {
  74. q.gramSize = &gramSize
  75. }
  76. return q
  77. }
  78. func (q *PhraseSuggester) MaxErrors(maxErrors float64) *PhraseSuggester {
  79. q.maxErrors = &maxErrors
  80. return q
  81. }
  82. func (q *PhraseSuggester) Separator(separator string) *PhraseSuggester {
  83. q.separator = &separator
  84. return q
  85. }
  86. func (q *PhraseSuggester) RealWordErrorLikelihood(realWordErrorLikelihood float64) *PhraseSuggester {
  87. q.realWordErrorLikelihood = &realWordErrorLikelihood
  88. return q
  89. }
  90. func (q *PhraseSuggester) Confidence(confidence float64) *PhraseSuggester {
  91. q.confidence = &confidence
  92. return q
  93. }
  94. func (q *PhraseSuggester) CandidateGenerator(generator CandidateGenerator) *PhraseSuggester {
  95. if q.generators == nil {
  96. q.generators = make(map[string][]CandidateGenerator)
  97. }
  98. typ := generator.Type()
  99. if _, found := q.generators[typ]; !found {
  100. q.generators[typ] = make([]CandidateGenerator, 0)
  101. }
  102. q.generators[typ] = append(q.generators[typ], generator)
  103. return q
  104. }
  105. func (q *PhraseSuggester) CandidateGenerators(generators ...CandidateGenerator) *PhraseSuggester {
  106. for _, g := range generators {
  107. q = q.CandidateGenerator(g)
  108. }
  109. return q
  110. }
  111. func (q *PhraseSuggester) ClearCandidateGenerator() *PhraseSuggester {
  112. q.generators = nil
  113. return q
  114. }
  115. func (q *PhraseSuggester) ForceUnigrams(forceUnigrams bool) *PhraseSuggester {
  116. q.forceUnigrams = &forceUnigrams
  117. return q
  118. }
  119. func (q *PhraseSuggester) SmoothingModel(smoothingModel SmoothingModel) *PhraseSuggester {
  120. q.smoothingModel = smoothingModel
  121. return q
  122. }
  123. func (q *PhraseSuggester) TokenLimit(tokenLimit int) *PhraseSuggester {
  124. q.tokenLimit = &tokenLimit
  125. return q
  126. }
  127. func (q *PhraseSuggester) Highlight(preTag, postTag string) *PhraseSuggester {
  128. q.preTag = &preTag
  129. q.postTag = &postTag
  130. return q
  131. }
  132. func (q *PhraseSuggester) CollateQuery(collateQuery string) *PhraseSuggester {
  133. q.collateQuery = &collateQuery
  134. return q
  135. }
  136. func (q *PhraseSuggester) CollatePreference(collatePreference string) *PhraseSuggester {
  137. q.collatePreference = &collatePreference
  138. return q
  139. }
  140. func (q *PhraseSuggester) CollateParams(collateParams map[string]interface{}) *PhraseSuggester {
  141. q.collateParams = collateParams
  142. return q
  143. }
  144. func (q *PhraseSuggester) CollatePrune(collatePrune bool) *PhraseSuggester {
  145. q.collatePrune = &collatePrune
  146. return q
  147. }
  // phraseSuggesterRequest fixes the order in which the JSON elements are
  // sent to Elasticsearch. Plain maps caused trouble because they do not
  // guarantee that the "text" element is serialized before the "phrase"
  // element; struct fields are encoded in declaration order.
  type phraseSuggesterRequest struct {
  	// Text must stay the first field so it is serialized first.
  	Text string `json:"text"`
  	// Phrase holds the phrase suggester options.
  	Phrase interface{} `json:"phrase"`
  }
  156. // Source generates the source for the phrase suggester.
  157. func (q *PhraseSuggester) Source(includeName bool) (interface{}, error) {
  158. ps := &phraseSuggesterRequest{}
  159. if q.text != "" {
  160. ps.Text = q.text
  161. }
  162. suggester := make(map[string]interface{})
  163. ps.Phrase = suggester
  164. if q.analyzer != "" {
  165. suggester["analyzer"] = q.analyzer
  166. }
  167. if q.field != "" {
  168. suggester["field"] = q.field
  169. }
  170. if q.size != nil {
  171. suggester["size"] = *q.size
  172. }
  173. if q.shardSize != nil {
  174. suggester["shard_size"] = *q.shardSize
  175. }
  176. switch len(q.contextQueries) {
  177. case 0:
  178. case 1:
  179. src, err := q.contextQueries[0].Source()
  180. if err != nil {
  181. return nil, err
  182. }
  183. suggester["context"] = src
  184. default:
  185. var ctxq []interface{}
  186. for _, query := range q.contextQueries {
  187. src, err := query.Source()
  188. if err != nil {
  189. return nil, err
  190. }
  191. ctxq = append(ctxq, src)
  192. }
  193. suggester["context"] = ctxq
  194. }
  195. // Phase-specified parameters
  196. if q.realWordErrorLikelihood != nil {
  197. suggester["real_word_error_likelihood"] = *q.realWordErrorLikelihood
  198. }
  199. if q.confidence != nil {
  200. suggester["confidence"] = *q.confidence
  201. }
  202. if q.separator != nil {
  203. suggester["separator"] = *q.separator
  204. }
  205. if q.maxErrors != nil {
  206. suggester["max_errors"] = *q.maxErrors
  207. }
  208. if q.gramSize != nil {
  209. suggester["gram_size"] = *q.gramSize
  210. }
  211. if q.forceUnigrams != nil {
  212. suggester["force_unigrams"] = *q.forceUnigrams
  213. }
  214. if q.tokenLimit != nil {
  215. suggester["token_limit"] = *q.tokenLimit
  216. }
  217. if q.generators != nil && len(q.generators) > 0 {
  218. for typ, generators := range q.generators {
  219. var arr []interface{}
  220. for _, g := range generators {
  221. src, err := g.Source()
  222. if err != nil {
  223. return nil, err
  224. }
  225. arr = append(arr, src)
  226. }
  227. suggester[typ] = arr
  228. }
  229. }
  230. if q.smoothingModel != nil {
  231. src, err := q.smoothingModel.Source()
  232. if err != nil {
  233. return nil, err
  234. }
  235. x := make(map[string]interface{})
  236. x[q.smoothingModel.Type()] = src
  237. suggester["smoothing"] = x
  238. }
  239. if q.preTag != nil {
  240. hl := make(map[string]string)
  241. hl["pre_tag"] = *q.preTag
  242. if q.postTag != nil {
  243. hl["post_tag"] = *q.postTag
  244. }
  245. suggester["highlight"] = hl
  246. }
  247. if q.collateQuery != nil {
  248. collate := make(map[string]interface{})
  249. suggester["collate"] = collate
  250. if q.collateQuery != nil {
  251. collate["query"] = *q.collateQuery
  252. }
  253. if q.collatePreference != nil {
  254. collate["preference"] = *q.collatePreference
  255. }
  256. if len(q.collateParams) > 0 {
  257. collate["params"] = q.collateParams
  258. }
  259. if q.collatePrune != nil {
  260. collate["prune"] = *q.collatePrune
  261. }
  262. }
  263. if !includeName {
  264. return ps, nil
  265. }
  266. source := make(map[string]interface{})
  267. source[q.name] = ps
  268. return source, nil
  269. }
  270. // -- Smoothing models --
  // SmoothingModel is implemented by the smoothing models that can be
  // passed to PhraseSuggester.SmoothingModel.
  type SmoothingModel interface {
  	// Type returns the name under which the model is serialized.
  	Type() string
  	// Source returns the serializable settings of the model.
  	Source() (interface{}, error)
  }
  // StupidBackoffSmoothingModel implements a stupid backoff smoothing model.
  // See https://www.elastic.co/guide/en/elasticsearch/reference/5.2/search-suggesters-phrase.html#_smoothing_models
  // for details about smoothing models.
  type StupidBackoffSmoothingModel struct {
  	// discount is emitted as the "discount" setting.
  	discount float64
  }

  // NewStupidBackoffSmoothingModel creates a stupid backoff smoothing
  // model with the given discount.
  func NewStupidBackoffSmoothingModel(discount float64) *StupidBackoffSmoothingModel {
  	model := new(StupidBackoffSmoothingModel)
  	model.discount = discount
  	return model
  }

  // Type returns the name of the model, "stupid_backoff".
  func (sm *StupidBackoffSmoothingModel) Type() string {
  	return "stupid_backoff"
  }

  // Source returns the settings of the model as a map holding the
  // "discount" value. The returned error is always nil.
  func (sm *StupidBackoffSmoothingModel) Source() (interface{}, error) {
  	return map[string]interface{}{
  		"discount": sm.discount,
  	}, nil
  }
  294. // --
  // LaplaceSmoothingModel implements a laplace smoothing model.
  // See https://www.elastic.co/guide/en/elasticsearch/reference/5.2/search-suggesters-phrase.html#_smoothing_models
  // for details about smoothing models.
  type LaplaceSmoothingModel struct {
  	// alpha is emitted as the "alpha" setting.
  	alpha float64
  }

  // NewLaplaceSmoothingModel creates a laplace smoothing model with the
  // given alpha.
  func NewLaplaceSmoothingModel(alpha float64) *LaplaceSmoothingModel {
  	model := new(LaplaceSmoothingModel)
  	model.alpha = alpha
  	return model
  }

  // Type returns the name of the model, "laplace".
  func (sm *LaplaceSmoothingModel) Type() string {
  	return "laplace"
  }

  // Source returns the settings of the model as a map holding the
  // "alpha" value. The returned error is always nil.
  func (sm *LaplaceSmoothingModel) Source() (interface{}, error) {
  	return map[string]interface{}{
  		"alpha": sm.alpha,
  	}, nil
  }
  314. // --
  // LinearInterpolationSmoothingModel implements a linear interpolation
  // smoothing model.
  // See https://www.elastic.co/guide/en/elasticsearch/reference/5.2/search-suggesters-phrase.html#_smoothing_models
  // for details about smoothing models.
  type LinearInterpolationSmoothingModel struct {
  	// trigramLambda is emitted as "trigram_lambda".
  	trigramLambda float64
  	// bigramLambda is emitted as "bigram_lambda".
  	bigramLambda float64
  	// unigramLambda is emitted as "unigram_lambda".
  	unigramLambda float64
  }

  // NewLinearInterpolationSmoothingModel creates a linear interpolation
  // smoothing model from the three lambda values, given in the order
  // trigram, bigram, unigram. The values are stored as passed; no
  // validation is performed here.
  func NewLinearInterpolationSmoothingModel(trigramLambda, bigramLambda, unigramLambda float64) *LinearInterpolationSmoothingModel {
  	return &LinearInterpolationSmoothingModel{
  		trigramLambda: trigramLambda,
  		bigramLambda:  bigramLambda,
  		unigramLambda: unigramLambda,
  	}
  }

  // Type returns the name of the model, "linear_interpolation".
  func (sm *LinearInterpolationSmoothingModel) Type() string {
  	return "linear_interpolation"
  }

  // Source returns the settings of the model as a map holding the
  // "trigram_lambda", "bigram_lambda" and "unigram_lambda" values.
  // The returned error is always nil.
  func (sm *LinearInterpolationSmoothingModel) Source() (interface{}, error) {
  	source := make(map[string]interface{}, 3)
  	source["trigram_lambda"] = sm.trigramLambda
  	source["bigram_lambda"] = sm.bigramLambda
  	source["unigram_lambda"] = sm.unigramLambda
  	return source, nil
  }
  341. // -- CandidateGenerator --
  // CandidateGenerator is implemented by the candidate generators that
  // can be registered with PhraseSuggester.CandidateGenerator.
  type CandidateGenerator interface {
  	// Type returns the key under which generators of this kind are grouped.
  	Type() string
  	// Source returns the serializable settings of the generator.
  	Source() (interface{}, error)
  }
  // DirectCandidateGenerator implements a direct candidate generator.
  // See https://www.elastic.co/guide/en/elasticsearch/reference/5.2/search-suggesters-phrase.html#_direct_generators
  // for details about direct generators.
  //
  // Apart from field, every setting is optional and kept as a pointer so
  // that Source only emits what has been set explicitly.
  type DirectCandidateGenerator struct {
  	field          string
  	preFilter      *string
  	postFilter     *string
  	suggestMode    *string
  	accuracy       *float64
  	size           *int
  	sort           *string
  	stringDistance *string
  	maxEdits       *int
  	maxInspections *int
  	maxTermFreq    *float64
  	prefixLength   *int
  	minWordLength  *int
  	minDocFreq     *float64
  }

  // NewDirectCandidateGenerator creates a direct candidate generator for
  // the given field.
  func NewDirectCandidateGenerator(field string) *DirectCandidateGenerator {
  	return &DirectCandidateGenerator{
  		field: field,
  	}
  }

  // Type returns the key under which the generator is serialized,
  // "direct_generator".
  func (g *DirectCandidateGenerator) Type() string {
  	return "direct_generator"
  }

  // Field sets the value emitted as "field". An empty string leaves the
  // setting out of the generated source.
  func (g *DirectCandidateGenerator) Field(field string) *DirectCandidateGenerator {
  	g.field = field
  	return g
  }

  // PreFilter sets the value emitted as "pre_filter".
  func (g *DirectCandidateGenerator) PreFilter(preFilter string) *DirectCandidateGenerator {
  	g.preFilter = &preFilter
  	return g
  }

  // PostFilter sets the value emitted as "post_filter".
  func (g *DirectCandidateGenerator) PostFilter(postFilter string) *DirectCandidateGenerator {
  	g.postFilter = &postFilter
  	return g
  }

  // SuggestMode sets the value emitted as "suggest_mode".
  func (g *DirectCandidateGenerator) SuggestMode(suggestMode string) *DirectCandidateGenerator {
  	g.suggestMode = &suggestMode
  	return g
  }

  // Accuracy sets the value emitted as "accuracy".
  func (g *DirectCandidateGenerator) Accuracy(accuracy float64) *DirectCandidateGenerator {
  	g.accuracy = &accuracy
  	return g
  }

  // Size sets the value emitted as "size".
  func (g *DirectCandidateGenerator) Size(size int) *DirectCandidateGenerator {
  	g.size = &size
  	return g
  }

  // Sort sets the value emitted as "sort".
  func (g *DirectCandidateGenerator) Sort(sort string) *DirectCandidateGenerator {
  	g.sort = &sort
  	return g
  }

  // StringDistance sets the value emitted as "string_distance".
  func (g *DirectCandidateGenerator) StringDistance(stringDistance string) *DirectCandidateGenerator {
  	g.stringDistance = &stringDistance
  	return g
  }

  // MaxEdits sets the value emitted as "max_edits".
  func (g *DirectCandidateGenerator) MaxEdits(maxEdits int) *DirectCandidateGenerator {
  	g.maxEdits = &maxEdits
  	return g
  }

  // MaxInspections sets the value emitted as "max_inspections".
  func (g *DirectCandidateGenerator) MaxInspections(maxInspections int) *DirectCandidateGenerator {
  	g.maxInspections = &maxInspections
  	return g
  }

  // MaxTermFreq sets the value emitted as "max_term_freq".
  func (g *DirectCandidateGenerator) MaxTermFreq(maxTermFreq float64) *DirectCandidateGenerator {
  	g.maxTermFreq = &maxTermFreq
  	return g
  }

  // PrefixLength sets the value emitted as "prefix_length".
  func (g *DirectCandidateGenerator) PrefixLength(prefixLength int) *DirectCandidateGenerator {
  	g.prefixLength = &prefixLength
  	return g
  }

  // MinWordLength sets the value emitted as "min_word_length".
  func (g *DirectCandidateGenerator) MinWordLength(minWordLength int) *DirectCandidateGenerator {
  	g.minWordLength = &minWordLength
  	return g
  }

  // MinDocFreq sets the value emitted as "min_doc_freq".
  func (g *DirectCandidateGenerator) MinDocFreq(minDocFreq float64) *DirectCandidateGenerator {
  	g.minDocFreq = &minDocFreq
  	return g
  }

  // Source returns the settings of the generator as a map. The field is
  // included when non-empty; every other setting is included only if its
  // setter has been called. The returned error is always nil.
  func (g *DirectCandidateGenerator) Source() (interface{}, error) {
  	source := make(map[string]interface{})
  	if g.field != "" {
  		source["field"] = g.field
  	}
  	if g.suggestMode != nil {
  		source["suggest_mode"] = *g.suggestMode
  	}
  	if g.accuracy != nil {
  		source["accuracy"] = *g.accuracy
  	}
  	if g.size != nil {
  		source["size"] = *g.size
  	}
  	if g.sort != nil {
  		source["sort"] = *g.sort
  	}
  	if g.stringDistance != nil {
  		source["string_distance"] = *g.stringDistance
  	}
  	if g.maxEdits != nil {
  		source["max_edits"] = *g.maxEdits
  	}
  	if g.maxInspections != nil {
  		source["max_inspections"] = *g.maxInspections
  	}
  	if g.maxTermFreq != nil {
  		source["max_term_freq"] = *g.maxTermFreq
  	}
  	if g.prefixLength != nil {
  		source["prefix_length"] = *g.prefixLength
  	}
  	if g.minWordLength != nil {
  		source["min_word_length"] = *g.minWordLength
  	}
  	if g.minDocFreq != nil {
  		source["min_doc_freq"] = *g.minDocFreq
  	}
  	if g.preFilter != nil {
  		source["pre_filter"] = *g.preFilter
  	}
  	if g.postFilter != nil {
  		source["post_filter"] = *g.postFilter
  	}
  	return source, nil
  }