Preface
As phones keep evolving, their screens keep getting bigger. One hand can no longer cover the whole display, so tapping something on the far side of the screen one-handed becomes awkward. In those moments we find ourselves wishing the button were simply closer to us.

So, is there a way to make those buttons automatically move toward the hand that is operating the phone?

There is: all we need is to determine whether the phone is currently being operated with the left or the right hand. If it is the left hand, shift the buttons to the left; if it is the right hand, shift them to the right.

With the rough idea in place, let's get to work!
Approach
- Option 1, a non-machine-learning approach: Recognizing the Operating Hand and the Hand-Changing Process for User Interface Adjustment on Smartphones - PMC (nih.gov)
- Option 2, a machine-learning approach (the one we use):
  - Train a binary-classification CNN model to recognize whether the user is operating the phone with the left or the right hand.
  - Input: the user's swipe trajectory on the screen
  - Output: left hand or right hand
From: [Android Client Track, Study Material 2] of the 4th ByteDance Youth Training Camp (第四届字节跳动青训营) - 掘金 (juejin.cn)
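To make that input/output concrete before diving into code: in the demo model used later in this post, a "swipe trajectory" is a fixed number of sampled touch points, each encoded as six floats, and the output is a single sigmoid score. The sketch below only illustrates that contract; the shape constants come from the classifier code shown further down, while the data class itself is not part of the demo.

// Illustration only: the shapes the demo model works with (see OperatingHandClassifier below).
data class TrackPoint(
    val x: Float,        // touch x coordinate
    val y: Float,        // touch y coordinate
    val screenW: Float,  // screen width in px
    val screenH: Float,  // screen height in px
    val density: Float,  // display density
    val dtime: Float     // milliseconds since ACTION_DOWN
)

// Input tensor:  [1, 9, 6] -> 9 sampled points x 6 features per point
// Output tensor: [1, 1]    -> sigmoid score; > 0.5 is read as "right hand", otherwise "left hand"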
Implementation
Model training is not covered here; we directly use the model from ahcyd008/OperatingHandRecognition: an on-device left/right-hand recognition Android demo plus model-training code (github.com).
Importing the dependencies
Since Option 2 is based on deep learning, we need a deep-learning framework that runs on Android. These frameworks almost all come from the big vendors; here we use Google's TensorFlow Lite. It is the flavor of TensorFlow tailored to Android: a normal model is compressed and converted, after which it can be used inside an Android app. The other two libraries are Google's Task library and Google's Guava library. The former provides the Task API we use to run the deep-learning work on a background thread and observe when it completes or fails, while Guava provides a richer set of Java utility classes.
//app/build.gradle
dependencies {
    // Task API
    implementation "com.google.android.gms:play-services-tasks:17.2.1"
    // TensorFlow Lite dependency
    implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly-SNAPSHOT'
    implementation("com.google.guava:guava:31.1-android")
}
//settings.gradle
pluginManagement {
    repositories {
        gradlePluginPortal()
        google()
        mavenCentral()
        maven { url "https://jitpack.io" }
        maven {
            name 'ossrh-snapshot'
            url 'https://oss.sonatype.org/content/repositories/snapshots'
        }
    }
}
dependencyResolutionManagement {
    repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
    repositories {
        google()
        mavenCentral()
        jcenter()
        maven {
            name 'ossrh-snapshot'
            url 'https://oss.sonatype.org/content/repositories/snapshots'
        }
    }
}
The dependency declarations are shown above; remember to add TensorFlow's snapshot repository (the ossrh-snapshot entry) at the end of settings.gradle, since the nightly build is only published there.
Finally, remember to add the packaged (converted) .tflite model file to the project; the classifier below loads it from the app's assets directory.
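One packaging detail worth checking (a general TensorFlow Lite note, not something the original demo calls out): the classifier below memory-maps the model straight out of the APK via openFd, which only works if the asset is stored uncompressed. Recent Android Gradle plugin versions already leave .tflite files uncompressed by default; on older setups you may need to declare it explicitly:

//app/build.gradle - only needed if your AGP version still compresses .tflite assets
android {
    aaptOptions {
        noCompress "tflite"
    }
}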
Wiring up the model
class OperatingHandClassifier(private val context: Context) {

    private var interpreter: Interpreter? = null
    private var modelInputSize = 0

    var isInitialized = false
        private set

    /** Executor to run inference task in the background */
    private val executorService: ExecutorService = Executors.newSingleThreadScheduledExecutor()

    private var hasInit = false

    fun checkAndInit() {
        if (hasInit) {
            return
        }
        hasInit = true
        val task = TaskCompletionSource<Void?>()
        executorService.execute {
            try {
                initializeInterpreter()
                task.setResult(null)
            } catch (e: IOException) {
                task.setException(e)
            }
        }
        task.task.addOnFailureListener { e ->
            Log.e(TAG, "Error setting up the classifier.", e)
        }
    }

    @Throws(IOException::class)
    private fun initializeInterpreter() {
        // Load the TF Lite model
        val assetManager = context.assets
        val model = loadModelFile(assetManager)

        // Initialize TF Lite Interpreter with NNAPI enabled
        val options = Interpreter.Options()
        // Testing showed NNAPI has issues with MaxPooling1D; if on-device predictions
        // differ from the Python results, try turning NNAPI off and check again.
        options.setUseNNAPI(true)
        val interpreter = Interpreter(model, options)

        // Read input shape from model file
        val inputShape = interpreter.getInputTensor(0).shape()
        val simpleCount = inputShape[1]
        val tensorSize = inputShape[2]
        modelInputSize = FLOAT_TYPE_SIZE * simpleCount * tensorSize * PIXEL_SIZE
        val outputShape = interpreter.getOutputTensor(0).shape()

        // Finish interpreter initialization
        this.interpreter = interpreter
        isInitialized = true
        Log.d(TAG, "Initialized TFLite interpreter. inputShape:${Arrays.toString(inputShape)}, outputShape:${Arrays.toString(outputShape)}")
    }

    @Throws(IOException::class)
    private fun loadModelFile(assetManager: AssetManager): ByteBuffer {
        val fileDescriptor = assetManager.openFd(MODEL_FILE) // use the fully connected model
        // val fileDescriptor = assetManager.openFd(MODEL_CNN_FILE) // use the CNN model
        val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
        val fileChannel = inputStream.channel
        val startOffset = fileDescriptor.startOffset
        val declaredLength = fileDescriptor.declaredLength
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
    }

    private fun classify(pointList: JSONArray): ClassifierLabelResult {
        if (!isInitialized) {
            throw IllegalStateException("TF Lite Interpreter is not initialized yet.")
        }
        try {
            // Preprocessing: resize the input
            var startTime: Long = System.nanoTime()
            val byteBuffer = convertFloatArrayToByteBuffer(pointList)
            var elapsedTime = (System.nanoTime() - startTime) / 1000000
            Log.d(TAG, "Preprocessing time = " + elapsedTime + "ms")

            startTime = System.nanoTime()
            val result = Array(1) { FloatArray(OUTPUT_CLASSES_COUNT) }
            interpreter?.run(byteBuffer, result)
            elapsedTime = (System.nanoTime() - startTime) / 1000000
            Log.d(TAG, "Inference time = " + elapsedTime + "ms result=" + result[0].contentToString())

            // Single sigmoid output: > 0.5 means right hand, otherwise left hand
            val output = result[0][0]
            return if (output > 0.5f) {
                ClassifierLabelResult(output, "right", labelRight)
            } else {
                ClassifierLabelResult(1.0f - output, "left", labelLeft)
            }
        } catch (e: Throwable) {
            Log.e(TAG, "Inference error", e)
        }
        return ClassifierLabelResult(-1f, "unknown", labelUnknown)
    }

    fun classifyAsync(pointList: JSONArray): Task<ClassifierLabelResult> {
        val task = TaskCompletionSource<ClassifierLabelResult>()
        executorService.execute {
            val result = classify(pointList)
            task.setResult(result)
        }
        return task.task
    }

    fun close() {
        executorService.execute {
            interpreter?.close()
            Log.d(TAG, "Closed TFLite interpreter.")
        }
    }

    private fun convertFloatArrayToByteBuffer(pointList: JSONArray): ByteBuffer {
        Log.d(TAG, "convertFloatArrayToByteBuffer pointList=$pointList")
        val byteBuffer = ByteBuffer.allocateDirect(modelInputSize)
        byteBuffer.order(ByteOrder.nativeOrder())
        val step = pointList.length().toFloat() / sampleCount
        for (i in 0 until sampleCount) {
            val e = pointList[(i * step).toInt()] as JSONArray
            for (j in 0 until tensorSize) {
                val value = (e[j] as Number).toFloat() // x y w h density dtime
                byteBuffer.putFloat(value)
            }
        }
        return byteBuffer
    }

    companion object {
        private const val TAG = "ClientAI#Classifier"
        private const val MODEL_FILE = "mymodel.tflite"
        private const val FLOAT_TYPE_SIZE = 4
        private const val PIXEL_SIZE = 1
        private const val OUTPUT_CLASSES_COUNT = 1
        const val sampleCount = 9
        const val tensorSize = 6
        const val labelLeft = 0
        const val labelRight = 1
        const val labelUnknown = -1
    }
}

class ClassifierLabelResult(var score: Float, var label: String, val labelInt: Int) {
    override fun toString(): String {
        val format = DecimalFormat("#.##")
        return "$label score:${format.format(score)}"
    }
}
class MotionEventTracker(var context: Context) {

    companion object {
        const val TAG = "ClientAI#tracker"
    }

    interface ITrackDataReadyListener {
        fun onTrackDataReady(dataList: JSONArray)
    }

    private var width = 0
    private var height = 0
    private var density = 1f
    private var listener: ITrackDataReadyListener? = null

    fun checkAndInit(listener: ITrackDataReadyListener) {
        this.listener = listener
        val metric = context.resources.displayMetrics
        width = min(metric.widthPixels, metric.heightPixels)
        height = max(metric.widthPixels, metric.heightPixels)
        density = metric.density
    }

    private var currentEvents: JSONArray? = null
    private var currentDownTime = 0L

    fun recordMotionEvent(ev: MotionEvent) {
        if (ev.pointerCount > 1) {
            currentEvents = null
            return
        }
        if (ev.action == MotionEvent.ACTION_DOWN) {
            currentEvents = JSONArray()
            currentDownTime = ev.eventTime
        }
        if (currentEvents != null) {
            if (ev.historySize > 0) {
                for (i in 0 until ev.historySize) {
                    currentEvents?.put(buildPoint(ev.getHistoricalX(i), ev.getHistoricalY(i), ev.getHistoricalEventTime(i)))
                }
            }
            currentEvents?.put(buildPoint(ev.x, ev.y, ev.eventTime))
        }
        if (ev.action == MotionEvent.ACTION_UP) {
            currentEvents?.let {
                if (it.length() >= 6) {
                    // Trigger a prediction
                    listener?.onTrackDataReady(it)
                    Log.i(TAG, "cache events, eventCount=${it.length()}, data=$it")
                } else {
                    // Filter out taps and accidental-touch tracks
                    Log.i(TAG, "skipped short events, eventCount=${it.length()}, data=$it")
                }
            }
            currentEvents = null
        }
    }

    private fun buildPoint(x: Float, y: Float, timestamp: Long): JSONArray {
        val point = JSONArray()
        point.put(x)
        point.put(y)
        point.put(width)
        point.put(height)
        point.put(density)
        point.put(timestamp - currentDownTime)
        return point
    }
}
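Finally, the tracker and the classifier need to be hooked up to real touch input and to the UI. The demo repository has its own wiring; the sketch below is a minimal version of my own, assuming an Activity that forwards every touch event to the tracker and then nudges a button toward the detected hand (R.layout.activity_main, R.id.action_button and adjustLayoutForHand are illustrative names, not part of the demo):

// A minimal wiring sketch of my own, not taken verbatim from the demo repo.
import android.os.Bundle
import android.view.Gravity
import android.view.MotionEvent
import android.widget.Button
import android.widget.FrameLayout
import androidx.appcompat.app.AppCompatActivity
import org.json.JSONArray

class MainActivity : AppCompatActivity(), MotionEventTracker.ITrackDataReadyListener {

    private val tracker by lazy { MotionEventTracker(this) }
    private val classifier by lazy { OperatingHandClassifier(this) }

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main) // layout name is illustrative
        tracker.checkAndInit(this)
        classifier.checkAndInit()
    }

    // Feed every touch event of the Activity into the tracker.
    override fun dispatchTouchEvent(ev: MotionEvent): Boolean {
        tracker.recordMotionEvent(ev)
        return super.dispatchTouchEvent(ev)
    }

    // Called by the tracker once a swipe track is long enough to classify.
    override fun onTrackDataReady(dataList: JSONArray) {
        classifier.classifyAsync(dataList).addOnSuccessListener { result ->
            when (result.labelInt) {
                OperatingHandClassifier.labelLeft -> adjustLayoutForHand(isLeft = true)
                OperatingHandClassifier.labelRight -> adjustLayoutForHand(isLeft = false)
                else -> Unit // unknown: keep the current layout
            }
        }
    }

    // Illustrative only: shift an action button toward the operating hand.
    private fun adjustLayoutForHand(isLeft: Boolean) {
        val button = findViewById<Button>(R.id.action_button)
        (button.layoutParams as? FrameLayout.LayoutParams)?.let { lp ->
            lp.gravity = Gravity.BOTTOM or (if (isLeft) Gravity.START else Gravity.END)
            button.layoutParams = lp
        }
    }

    override fun onDestroy() {
        classifier.close()
        super.onDestroy()
    }
}

Note that classifyAsync already hops onto the classifier's single-threaded background executor, so dispatchTouchEvent stays cheap, and the Task success listener is delivered back on the main thread, where it is safe to touch the UI.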