Skip to content

Commit

Permalink
AI Sample app (#1955)
Browse files Browse the repository at this point in the history
  • Loading branch information
yschimke authored Jan 12, 2024
1 parent 619ad28 commit 3a35072
Show file tree
Hide file tree
Showing 95 changed files with 3,655 additions and 51 deletions.
148 changes: 148 additions & 0 deletions ai/sample/core/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Copyright 2022 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

plugins {
id("com.android.library")
id("com.google.protobuf")
kotlin("android")
id("com.google.devtools.ksp")
id("dagger.hilt.android.plugin")
}

android {
compileSdk = 34

defaultConfig {
minSdk = 23

testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
}

compileOptions {
sourceCompatibility = JavaVersion.VERSION_11
targetCompatibility = JavaVersion.VERSION_11
}

buildFeatures {
buildConfig = false
}

kotlinOptions {
jvmTarget = "11"
freeCompilerArgs = freeCompilerArgs +
listOf(
"-opt-in=kotlin.RequiresOptIn",
"-opt-in=com.google.android.horologist.annotations.ExperimentalHorologistApi",
)
}
packaging {
resources {
excludes +=
listOf(
"/META-INF/AL2.0",
"/META-INF/LGPL2.1",
)
}
}

testOptions {
unitTests {
isIncludeAndroidResources = true
}
}

lint {
checkReleaseBuilds = false
textReport = true
}

namespace = "com.google.android.horologist.ai.sample.core"
}

protobuf {
protoc {
artifact = "com.google.protobuf:protoc:3.25.1"
}
plugins {
create("javalite") {
artifact = "com.google.protobuf:protoc-gen-javalite:3.0.0"
}
create("grpc") {
artifact = "io.grpc:protoc-gen-grpc-java:1.60.0"
}
create("grpckt") {
artifact = "io.grpc:protoc-gen-grpc-kotlin:1.3.0:jdk8@jar"
}
}
generateProtoTasks {
all().forEach { task ->
task.builtins {
create("java") {
option("lite")
}
create("kotlin") {
option("lite")
}
}
task.plugins {
create("grpc") {
option("lite")
}
create("grpckt") {
option("lite")
}
}
}
}
}

dependencies {
api(projects.annotations)

implementation(libs.dagger.hiltandroid)
implementation(projects.datalayer.core)
implementation(projects.datalayer.grpc)
ksp(libs.dagger.hiltandroidcompiler)

implementation(libs.kotlin.stdlib)
implementation(libs.kotlinx.coroutines.core)

api(libs.playservices.wearable)
implementation(libs.kotlinx.coroutines.playservices)
api(libs.androidx.datastore.preferences)
api(libs.androidx.datastore)
api(libs.protobuf.kotlin.lite)
api(libs.androidx.lifecycle.runtime)
api(libs.androidx.wear.remote.interactions)
api(libs.androidx.lifecycle.service)
api(projects.datalayer.grpc)
api(libs.io.grpc.grpc.android)
api(libs.io.grpc.grpc.binder)

testImplementation(libs.junit)
testImplementation(libs.truth)
testImplementation(libs.androidx.test.ext.ktx)
testImplementation(libs.kotlinx.coroutines.test)
testImplementation(libs.robolectric)

androidTestImplementation(libs.compose.ui.test.junit4)
androidTestImplementation(libs.espresso.core)
androidTestImplementation(libs.junit)
androidTestImplementation(libs.truth)
}

tasks.maybeCreate("prepareKotlinIdeaImport")
.dependsOn("generateDebugProto")
28 changes: 28 additions & 0 deletions ai/sample/core/src/main/AndroidManifest.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
<?xml version="1.0" encoding="utf-8"?><!--
~ Copyright 2022 The Android Open Source Project
~
~ Licensed under the Apache License, Version 2.0 (the "License");
~ you may not use this file except in compliance with the License.
~ You may obtain a copy of the License at
~
~ https://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->

<manifest xmlns:android="http://schemas.android.com/apk/res/android">
<!--
Required to be able to identify the companion with NodeClient, via calls to
getCompanionPackageForNode().
-->
<queries>
<intent>
<action android:name="com.google.android.gms.wearable.CAPABILITY_CHANGED" />
<data android:scheme="wear" />
</intent>
</queries>
</manifest>
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright 2024 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.android.horologist.ai.core

import android.content.Context
import android.content.Intent
import io.grpc.binder.AndroidComponentAddress
import io.grpc.binder.BinderChannelBuilder
import io.grpc.binder.UntrustedSecurityPolicies

object AiGrpcClientLookup {
fun lookupInferenceService(
context: Context,
packageName: String,
): InferenceServiceGrpcKt.InferenceServiceCoroutineStub {
val channel = BinderChannelBuilder.forAddress(
AndroidComponentAddress.forBindIntent(
Intent().apply {
setAction("InferenceService")
setPackage(packageName)
},
),
context,
)
.securityPolicy(UntrustedSecurityPolicies.untrustedPublic())
.build()

return InferenceServiceGrpcKt.InferenceServiceCoroutineStub(channel)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright 2024 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.android.horologist.ai.core

import android.content.Intent
import android.os.IBinder
import androidx.annotation.CallSuper
import androidx.lifecycle.LifecycleService
import io.grpc.BindableService
import io.grpc.CompressorRegistry
import io.grpc.DecompressorRegistry
import io.grpc.Server
import io.grpc.binder.AndroidComponentAddress
import io.grpc.binder.BinderServerBuilder
import io.grpc.binder.IBinderReceiver
import io.grpc.binder.SecurityPolicy
import io.grpc.binder.ServerSecurityPolicy
import io.grpc.binder.UntrustedSecurityPolicies

abstract class BindableAiGrpcService : LifecycleService() {
private lateinit var server: Server
private val binderReceiver = IBinderReceiver()

open val securityPolicy: SecurityPolicy = UntrustedSecurityPolicies.untrustedPublic()

abstract val bindableService: BindableService

@CallSuper
override fun onCreate() {
super.onCreate()
val serverSecurityPolicy =
ServerSecurityPolicy.newBuilder()
.servicePolicy(InferenceServiceGrpc.SERVICE_NAME, securityPolicy)
.build()
server =
BinderServerBuilder.forAddress(AndroidComponentAddress.forContext(this), binderReceiver)
.securityPolicy(serverSecurityPolicy)
.addService(bindableService)
.decompressorRegistry(DecompressorRegistry.emptyInstance())
.compressorRegistry(CompressorRegistry.newEmptyInstance())
.build()

server.start()
}

override fun onBind(intent: Intent): IBinder {
super.onBind(intent)
return binderReceiver.get()!!
}

override fun onDestroy() {
server.shutdownNow()
super.onDestroy()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright 2024 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

@file:OptIn(ExperimentalCoroutinesApi::class)

package com.google.android.horologist.ai.core

import android.util.Log
import com.google.android.horologist.ai.core.registry.CombinedInferenceServiceRegistry
import com.google.protobuf.empty
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.combine
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.stateIn
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class InferenceService
@Inject
constructor(
val registry: CombinedInferenceServiceRegistry,
val coroutineScope: CoroutineScope,
) {
val connectedModel = MutableStateFlow<ModelId?>(null)

val models = flow {
val models = registry.models().first()
emit(
coroutineScope {
// TODO subscribe and update models dynamically
models.map { remote ->
async {
try {
val serviceInfo = remote.serviceInfo(empty { })
Pair(serviceInfo, remote)
} catch (e: Exception) {
Log.w("InferenceService", "Failing for $remote", e)
// skip and rely on filterNotNull
null
}
}
}
}.awaitAll().filterNotNull(),
)
}.stateIn(coroutineScope, SharingStarted.Eagerly, null)

val currentModelInfo: Flow<Pair<ModelInfo, ServiceInfo>?> =
combine(connectedModel, models) { currentId, currentModels ->
currentModels?.firstNotNullOfOrNull { (serviceInfo, _) ->
serviceInfo.modelsList.find {
it.modelId == currentId
}?.let {
Pair(it, serviceInfo)
}
}
}

suspend fun submit(prompt: Prompt): Response {
val currentModel = connectedModel.value ?: throw Exception("No model selected")

val (_, service) = models.value?.first { it.first.modelsList.find { it.modelId == currentModel } != null }
?: throw Exception("Service missing")

return service.answerPrompt(
promptRequest {
this.prompt = prompt
this.modelId = currentModel
},
)
}

fun selectModel(modelId: ModelId) {
connectedModel.value = modelId
}

fun clearModel() {
connectedModel.value = null
}
}
Loading

0 comments on commit 3a35072

Please sign in to comment.