Skip to content

Commit

Permalink
Add flag to disable high accuracy mode (#101 #136)
Browse files Browse the repository at this point in the history
  • Loading branch information
pemistahl committed Jun 2, 2022
1 parent 1901e6a commit a845fe4
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 54 deletions.
56 changes: 36 additions & 20 deletions src/main/kotlin/com/github/pemistahl/lingua/api/LanguageDetector.kt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class LanguageDetector internal constructor(
internal val languages: MutableSet<Language>,
internal val minimumRelativeDistance: Double,
isEveryLanguageModelPreloaded: Boolean,
internal val isHighAccuracyModeEnabled: Boolean,
internal val numberOfLoadedLanguages: Int = languages.size,
) {
private val languagesWithUniqueCharacters = languages.filterNot { it.uniqueCharacters.isNullOrBlank() }.asSequence()
Expand Down Expand Up @@ -124,7 +125,13 @@ class LanguageDetector internal constructor(
return values
}

val ngramSizeRange = if (cleanedUpText.length >= 120) (3..3) else (1..5)
val ngramSizeRange = if (cleanedUpText.length >= HIGH_ACCURACY_MODE_MAX_TEXT_LENGTH ||
!isHighAccuracyModeEnabled
) {
(3..3)
} else {
(1..5)
}
val tasks = ngramSizeRange.filter { i -> cleanedUpText.length >= i }.map { i ->
Callable {
val testDataModel = TestDataLanguageModel.fromText(cleanedUpText, ngramLength = i)
Expand Down Expand Up @@ -188,20 +195,22 @@ class LanguageDetector internal constructor(
* in parallel.
*/
fun unloadLanguageModels() {
synchronized(unigramLanguageModels) {
languages.forEach(unigramLanguageModels::remove)
}
synchronized(bigramLanguageModels) {
languages.forEach(bigramLanguageModels::remove)
}
synchronized(trigramLanguageModels) {
languages.forEach(trigramLanguageModels::remove)
}
synchronized(quadrigramLanguageModels) {
languages.forEach(quadrigramLanguageModels::remove)
}
synchronized(fivegramLanguageModels) {
languages.forEach(fivegramLanguageModels::remove)
if (isHighAccuracyModeEnabled) {
synchronized(unigramLanguageModels) {
languages.forEach(unigramLanguageModels::remove)
}
synchronized(bigramLanguageModels) {
languages.forEach(bigramLanguageModels::remove)
}
synchronized(quadrigramLanguageModels) {
languages.forEach(quadrigramLanguageModels::remove)
}
synchronized(fivegramLanguageModels) {
languages.forEach(fivegramLanguageModels::remove)
}
}
}

Expand Down Expand Up @@ -350,8 +359,8 @@ class LanguageDetector internal constructor(
val (mostFrequentLanguage, firstCharCount) = sortedTotalLanguageCounts[0]
val (_, secondCharCount) = sortedTotalLanguageCounts[1]

return when {
firstCharCount == secondCharCount -> UNKNOWN
return when (firstCharCount) {
secondCharCount -> UNKNOWN
else -> mostFrequentLanguage
}
}
Expand Down Expand Up @@ -460,27 +469,34 @@ class LanguageDetector internal constructor(
val tasks = mutableListOf<Callable<Object2FloatMap<String>>>()

for (language in languages) {
tasks.add(Callable { loadLanguageModels(unigramLanguageModels, language, 1) })
tasks.add(Callable { loadLanguageModels(bigramLanguageModels, language, 2) })
tasks.add(Callable { loadLanguageModels(trigramLanguageModels, language, 3) })
tasks.add(Callable { loadLanguageModels(quadrigramLanguageModels, language, 4) })
tasks.add(Callable { loadLanguageModels(fivegramLanguageModels, language, 5) })

if (isHighAccuracyModeEnabled) {
tasks.add(Callable { loadLanguageModels(unigramLanguageModels, language, 1) })
tasks.add(Callable { loadLanguageModels(bigramLanguageModels, language, 2) })
tasks.add(Callable { loadLanguageModels(quadrigramLanguageModels, language, 4) })
tasks.add(Callable { loadLanguageModels(fivegramLanguageModels, language, 5) })
}
}

ForkJoinPool.commonPool().invokeAll(tasks)
ForkJoinPool.commonPool().invokeAll(tasks).forEach { it.get() }
}

override fun equals(other: Any?) = when {
this === other -> true
other !is LanguageDetector -> false
languages != other.languages -> false
minimumRelativeDistance != other.minimumRelativeDistance -> false
isHighAccuracyModeEnabled != other.isHighAccuracyModeEnabled -> false
else -> true
}

override fun hashCode() = 31 * languages.hashCode() + minimumRelativeDistance.hashCode()
override fun hashCode() =
31 * languages.hashCode() + minimumRelativeDistance.hashCode() + isHighAccuracyModeEnabled.hashCode()

internal companion object {
private const val HIGH_ACCURACY_MODE_MAX_TEXT_LENGTH = 120

internal val unigramLanguageModels = enumMapOf<Language, Object2FloatMap<String>>()
internal val bigramLanguageModels = enumMapOf<Language, Object2FloatMap<String>>()
internal val trigramLanguageModels = enumMapOf<Language, Object2FloatMap<String>>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,18 @@ package com.github.pemistahl.lingua.api
class LanguageDetectorBuilder private constructor(
internal val languages: List<Language>,
internal var minimumRelativeDistance: Double = 0.0,
internal var isEveryLanguageModelPreloaded: Boolean = false
internal var isEveryLanguageModelPreloaded: Boolean = false,
internal var isHighAccuracyModeEnabled: Boolean = true
) {
/**
* Creates and returns the configured instance of [LanguageDetector].
*/
fun build() = LanguageDetector(languages.toMutableSet(), minimumRelativeDistance, isEveryLanguageModelPreloaded)
fun build() = LanguageDetector(
languages.toMutableSet(),
minimumRelativeDistance,
isEveryLanguageModelPreloaded,
isHighAccuracyModeEnabled
)

/**
* Sets the desired value for the minimum relative distance measure.
Expand Down Expand Up @@ -72,6 +78,24 @@ class LanguageDetectorBuilder private constructor(
return this
}

/**
* Disables the high accuracy mode in order to save memory and increase performance.
*
* By default, *Lingua's* high detection accuracy comes at the cost of
* loading large language models into memory which might not be feasible
* for systems running low on resources.
*
* This method disables the high accuracy mode so that only a small subset
* of language models is loaded into memory. The downside of this approach
* is that detection accuracy for short texts consisting of less than 120
* characters will drop significantly. However, detection accuracy for texts
* which are longer than 120 characters will remain mostly unaffected.
*/
fun withoutHighAccuracyMode(): LanguageDetectorBuilder {
this.isHighAccuracyModeEnabled = false
return this
}

companion object {
/**
* Creates and returns an instance of LanguageDetectorBuilder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

package com.github.pemistahl.lingua.api

import com.github.pemistahl.lingua.api.Language.ENGLISH
import com.github.pemistahl.lingua.api.Language.GERMAN
import com.github.pemistahl.lingua.api.Language.SWEDISH
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatIllegalArgumentException
import org.junit.jupiter.api.Test
Expand All @@ -31,11 +34,13 @@ class LanguageDetectorBuilderTest {
assertThat(builder.languages).isEqualTo(Language.all())
assertThat(builder.minimumRelativeDistance).isEqualTo(0.0)
assertThat(builder.isEveryLanguageModelPreloaded).isFalse
assertThat(builder.isHighAccuracyModeEnabled).isTrue
assertThat(builder.build()).isEqualTo(
LanguageDetector(
Language.all().toMutableSet(),
minimumRelativeDistance = 0.0,
isEveryLanguageModelPreloaded = false
isEveryLanguageModelPreloaded = false,
isHighAccuracyModeEnabled = true
)
)

Expand All @@ -44,7 +49,8 @@ class LanguageDetectorBuilderTest {
LanguageDetector(
Language.all().toMutableSet(),
minimumRelativeDistance = 0.2,
isEveryLanguageModelPreloaded = false
isEveryLanguageModelPreloaded = false,
isHighAccuracyModeEnabled = true
)
)
}
Expand All @@ -56,11 +62,13 @@ class LanguageDetectorBuilderTest {
assertThat(builder.languages).isEqualTo(Language.allSpokenOnes())
assertThat(builder.minimumRelativeDistance).isEqualTo(0.0)
assertThat(builder.isEveryLanguageModelPreloaded).isFalse
assertThat(builder.isHighAccuracyModeEnabled).isTrue
assertThat(builder.build()).isEqualTo(
LanguageDetector(
Language.allSpokenOnes().toMutableSet(),
minimumRelativeDistance = 0.0,
isEveryLanguageModelPreloaded = false
isEveryLanguageModelPreloaded = false,
isHighAccuracyModeEnabled = true
)
)

Expand All @@ -69,7 +77,8 @@ class LanguageDetectorBuilderTest {
LanguageDetector(
Language.allSpokenOnes().toMutableSet(),
minimumRelativeDistance = 0.2,
isEveryLanguageModelPreloaded = false
isEveryLanguageModelPreloaded = false,
isHighAccuracyModeEnabled = true
)
)
}
Expand Down Expand Up @@ -112,11 +121,13 @@ class LanguageDetectorBuilderTest {
assertThat(builder.languages).isEqualTo(expectedLanguages)
assertThat(builder.minimumRelativeDistance).isEqualTo(0.0)
assertThat(builder.isEveryLanguageModelPreloaded).isFalse
assertThat(builder.isHighAccuracyModeEnabled).isTrue
assertThat(builder.build()).isEqualTo(
LanguageDetector(
expectedLanguages.toMutableSet(),
minimumRelativeDistance = 0.0,
isEveryLanguageModelPreloaded = false
isEveryLanguageModelPreloaded = false,
isHighAccuracyModeEnabled = true
)
)

Expand All @@ -125,32 +136,35 @@ class LanguageDetectorBuilderTest {
LanguageDetector(
expectedLanguages.toMutableSet(),
minimumRelativeDistance = 0.2,
isEveryLanguageModelPreloaded = false
isEveryLanguageModelPreloaded = false,
isHighAccuracyModeEnabled = true
)
)
}
run {
val languages = Language.values().toSet().minus(arrayOf(Language.GERMAN, Language.ENGLISH)).toTypedArray()
val languages = Language.values().toSet().minus(arrayOf(GERMAN, ENGLISH)).toTypedArray()
assertThatIllegalArgumentException().isThrownBy {
LanguageDetectorBuilder.fromAllLanguagesWithout(Language.GERMAN, *languages)
LanguageDetectorBuilder.fromAllLanguagesWithout(GERMAN, *languages)
}.withMessage(minimumLanguagesErrorMessage)
}
}

@Test
fun `assert that LanguageDetector can be built from whitelist`() {
run {
val builder = LanguageDetectorBuilder.fromLanguages(Language.GERMAN, Language.ENGLISH)
val expectedLanguages = listOf(Language.GERMAN, Language.ENGLISH)
val builder = LanguageDetectorBuilder.fromLanguages(GERMAN, ENGLISH)
val expectedLanguages = listOf(GERMAN, ENGLISH)

assertThat(builder.languages).isEqualTo(expectedLanguages)
assertThat(builder.minimumRelativeDistance).isEqualTo(0.0)
assertThat(builder.isEveryLanguageModelPreloaded).isFalse
assertThat(builder.isHighAccuracyModeEnabled).isTrue
assertThat(builder.build()).isEqualTo(
LanguageDetector(
expectedLanguages.toMutableSet(),
minimumRelativeDistance = 0.0,
isEveryLanguageModelPreloaded = false
isEveryLanguageModelPreloaded = false,
isHighAccuracyModeEnabled = true
)
)

Expand All @@ -159,13 +173,14 @@ class LanguageDetectorBuilderTest {
LanguageDetector(
expectedLanguages.toMutableSet(),
minimumRelativeDistance = 0.2,
isEveryLanguageModelPreloaded = false
isEveryLanguageModelPreloaded = false,
isHighAccuracyModeEnabled = true
)
)
}
run {
assertThatIllegalArgumentException().isThrownBy {
LanguageDetectorBuilder.fromLanguages(Language.GERMAN)
LanguageDetectorBuilder.fromLanguages(GERMAN)
}.withMessage(minimumLanguagesErrorMessage)
}
}
Expand All @@ -174,16 +189,18 @@ class LanguageDetectorBuilderTest {
fun `assert that LanguageDetector can be built from iso codes`() {
run {
val builder = LanguageDetectorBuilder.fromIsoCodes639_1(IsoCode639_1.DE, IsoCode639_1.SV)
val expectedLanguages = listOf(Language.GERMAN, Language.SWEDISH)
val expectedLanguages = listOf(GERMAN, SWEDISH)

assertThat(builder.languages).isEqualTo(expectedLanguages)
assertThat(builder.minimumRelativeDistance).isEqualTo(0.0)
assertThat(builder.isEveryLanguageModelPreloaded).isFalse
assertThat(builder.isHighAccuracyModeEnabled).isTrue
assertThat(builder.build()).isEqualTo(
LanguageDetector(
expectedLanguages.toMutableSet(),
minimumRelativeDistance = 0.0,
isEveryLanguageModelPreloaded = false
isEveryLanguageModelPreloaded = false,
isHighAccuracyModeEnabled = true
)
)

Expand All @@ -192,15 +209,17 @@ class LanguageDetectorBuilderTest {
LanguageDetector(
expectedLanguages.toMutableSet(),
minimumRelativeDistance = 0.2,
isEveryLanguageModelPreloaded = false
isEveryLanguageModelPreloaded = false,
isHighAccuracyModeEnabled = true
)
)

assertThat(builder.build()).isEqualTo(
LanguageDetector(
expectedLanguages.toMutableSet(),
minimumRelativeDistance = 0.2,
isEveryLanguageModelPreloaded = false
isEveryLanguageModelPreloaded = false,
isHighAccuracyModeEnabled = true
)
)
}
Expand All @@ -225,4 +244,55 @@ class LanguageDetectorBuilderTest {
}.withMessage(errorMessage)
}
}

@Test
fun `assert that LanguageDetector can be built with preloaded language models`() {
val builder = LanguageDetectorBuilder.fromLanguages(ENGLISH, GERMAN).withPreloadedLanguageModels()
val expectedLanguages = listOf(ENGLISH, GERMAN)

assertThat(builder.languages).isEqualTo(expectedLanguages)
assertThat(builder.minimumRelativeDistance).isEqualTo(0.0)
assertThat(builder.isEveryLanguageModelPreloaded).isTrue
assertThat(builder.isHighAccuracyModeEnabled).isTrue
assertThat(builder.build()).isEqualTo(
LanguageDetector(
expectedLanguages.toMutableSet(),
minimumRelativeDistance = 0.0,
isEveryLanguageModelPreloaded = true,
isHighAccuracyModeEnabled = true
)
)
}

@Test
fun `assert that LanguageDetector can be built without high accuracy mode`() {
val builder = LanguageDetectorBuilder.fromLanguages(ENGLISH, GERMAN).withoutHighAccuracyMode()
val expectedLanguages = listOf(ENGLISH, GERMAN)

assertThat(builder.languages).isEqualTo(expectedLanguages)
assertThat(builder.minimumRelativeDistance).isEqualTo(0.0)
assertThat(builder.isEveryLanguageModelPreloaded).isFalse
assertThat(builder.isHighAccuracyModeEnabled).isFalse
assertThat(builder.build()).isEqualTo(
LanguageDetector(
expectedLanguages.toMutableSet(),
minimumRelativeDistance = 0.0,
isEveryLanguageModelPreloaded = false,
isHighAccuracyModeEnabled = false
)
)

builder.withPreloadedLanguageModels()

assertThat(builder.isEveryLanguageModelPreloaded).isTrue
assertThat(builder.isHighAccuracyModeEnabled).isFalse
assertThat(builder.build()).isEqualTo(
LanguageDetector(
expectedLanguages.toMutableSet(),
minimumRelativeDistance = 0.0,
isEveryLanguageModelPreloaded = true,
isHighAccuracyModeEnabled = false
)
)
}
}
Loading

0 comments on commit a845fe4

Please sign in to comment.