ensorFlow 智能移动项目:6~10(2)https://developer.aliyun.com/article/1426909
接下来,我们定义输入和输出节点名称,并创建一个包含总点数的张量:
std::string input_name1 = "Reshape"; std::string input_name2 = "Squeeze"; std::string output_name1 = "dense/BiasAdd"; std::string output_name2 = "ArgMax" const int BATCH_SIZE = 8; tensorflow::Tensor seqlen_tensor(tensorflow::DT_INT64, tensorflow::TensorShape({BATCH_SIZE})); auto seqlen_mapped = seqlen_tensor.tensor<int64_t, 1>(); int64_t* seqlen_mapped_data = seqlen_mapped.data(); for (int i=0; i<BATCH_SIZE; i++) { seqlen_mapped_data[i] = total_points; }
请注意,在运行train_model.py
来训练模型时,我们必须使用与BATCH_SIZE
相同的BATCH_SIZE
,默认情况下为 8。
保存所有转换点值的另一个张量在这里创建:
tensorflow::Tensor points_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({8, total_points, 3})); auto points_tensor_mapped = points_tensor.tensor<float, 3>(); float* out = points_tensor_mapped.data(); for (int i=0; i<BATCH_SIZE; i++) { for (int j=0; j<total_points*3; j++) out[i*total_points*3+j] = normalized_points[j]; }
- 现在,我们运行模型并获得预期的输出:
std::vector<tensorflow::Tensor> outputs; tensorflow::Status run_status = tf_session->Run({{input_name1, points_tensor}, {input_name2, seqlen_tensor}}, {output_name1, output_name2}, {}, &outputs); if (!run_status.ok()) { LOG(ERROR) << "Getting model failed:" << run_status; return ""; } tensorflow::string status_string = run_status.ToString(); tensorflow::Tensor* logits_tensor = &outputs[0];
- 使用修改后的
GetTopN
版本并解析logits
获得最佳结果:
const int kNumResults = 5; const float kThreshold = 0.1f; std::vector<std::pair<float, int> > top_results; const Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>, Eigen::Aligned>& logits = logits_tensor->flat<float>(); GetTopN(logits, kNumResults, kThreshold, &top_results); string result = ""; for (int i=0; i<top_results.size(); i++) { std::pair<float, int> r = top_results[i]; if (result == "") result = classes[r.second]; else result += ", " + classes[r.second]; }
- 通过将
logits
值转换为 softmax 值来更改GetTopN
,然后返回顶部 softmax 值及其位置:
float sum = 0.0; for (int i = 0; i < CLASS_COUNT; ++i) { sum += expf(prediction(i)); } for (int i = 0; i < CLASS_COUNT; ++i) { const float value = expf(prediction(i)) / sum; if (value < threshold) { continue; } top_result_pq.push(std::pair<float, int>(value, i)); if (top_result_pq.size() > num_results) { top_result_pq.pop(); } }
- 最后,
normalizeScreenCoordinates
函数将其在触摸事件中捕获的屏幕坐标中的所有点转换为增量差异 – 这几乎是这个页面中的 Python 方法parse_line
的一部分:
void normalizeScreenCoordinates(NSMutableArray *allPoints, float *normalized) { float lowerx=MAXFLOAT, lowery=MAXFLOAT, upperx=-MAXFLOAT, uppery=-MAXFLOAT; for (NSArray *cp in allPoints) { for (NSValue *pointVal in cp) { CGPoint point = pointVal.CGPointValue; if (point.x < lowerx) lowerx = point.x; if (point.y < lowery) lowery = point.y; if (point.x > upperx) upperx = point.x; if (point.y > uppery) uppery = point.y; } } float scalex = upperx - lowerx; float scaley = uppery - lowery; int n = 0; for (NSArray *cp in allPoints) { int m=0; for (NSValue *pointVal in cp) { CGPoint point = pointVal.CGPointValue; normalized[n*3] = (point.x - lowerx) / scalex; normalized[n*3+1] = (point.y - lowery) / scaley; normalized[n*3+2] = (m ==cp.count-1 ? 1 : 0); n++; m++; } } for (int i=0; i<n-1; i++) { normalized[i*3] = normalized[(i+1)*3] - normalized[i*3]; normalized[i*3+1] = normalized[(i+1)*3+1] - normalized[i*3+1]; normalized[i*3+2] = normalized[(i+1)*3+2]; } }
现在,您可以在 iOS 模拟器或设备中运行该应用,开始绘画,并查看模型认为您正在绘画的内容。 图 7.8 显示了一些绘画和分类结果–不是最佳绘画,而是整个过程!
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MnGEbMeY-1681653119034)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/48678dd2-0490-4b97-a0a6-fa4ba52bc27b.png)]
图 7.8:在 iOS 上显示绘画和分类结果
在 Android 中使用绘画分类模型
现在该看看我们如何在 Android 中加载和使用该模型。 在之前的章节中,我们通过使用 Android 应用的build.gradle
文件并添加了一行 compile 'org.tensorflow:tensorflow-android:+'
仅添加了 TensorFlow 支持。 与 iOS 相比,我们必须构建一个自定义的 TensorFlow 库来修复不同的模型加载或运行错误(例如,在第 3 章,“检测对象及其位置”中,第四章,“变换具有惊人艺术风格的图片”和第五章,“了解简单的语音命令”),Android 的默认 TensorFlow 库对注册的操作和数据类型有更好的支持,这可能是因为 Android 是 Google 的一等公民,而 iOS 是第二名,甚至是第二名。
事实是,当我们处理各种惊人的模型时,我们不得不面对不可避免的问题只是时间问题:我们必须手动为 Android 构建 TensorFlow 库,以修复默认 TensorFlow 库中的一些根本无法应对的错误。 No OpKernel was registered to support Op 'RefSwitch' with these attrs.
错误就是这样的错误之一。 对于乐观的开发人员来说,这仅意味着另一种向您的技能组合中添加新技巧的机会。
为 Android 构建自定义 TensorFlow 库
请按照以下步骤手动为 Android 构建自定义的 TensorFlow 库:
- 在您的 TensorFlow 根目录中,有一个名为
WORKSPACE
的文件。 编辑它,并使android_sdk_repository
和android_ndk_repository
看起来像以下设置(用您自己的设置替换build_tools_version
以及 SDK 和 NDK 路径):
android_sdk_repository( name = "androidsdk", api_level = 23, build_tools_version = "26.0.1", path = "$HOME/Library/Android/sdk", ) android_ndk_repository( name="androidndk", path="$HOME/Downloads/android-ndk-r15c", api_level=14)
- 如果您还使用过本书中的 iOS 应用,并且已将
tensorflow/core/platform/default/mutex.h
从#include "nsync_cv.h"
和#include "nsync_mu.h"
更改为#include "nsync/public/nsync_cv.h"
和#include "nsync/public/nsync_mu.h"
,请参见第 3 章, “检测对象及其位置” 时,您需要将其更改回以成功构建 TensorFlow Android 库(此后,当您使用手动构建的 TensorFlow 库在 Xcode 和 iOS 应用上工作时,需要先添加nsync/public
这两个标头。
Changing tensorflow/core/platform/default/mutex.h
back and forth certainly is not an ideal solution. It’s supposed to be just as a workaround. As it only needs to be changed when you start using a manually built TensorFlow iOS library or when you build a custom TensorFlow library, we can live with it for now.
- 如果您具有支持 x86 CPU 的虚拟模拟器或 Android 设备,请运行以下命令来构建本机 TensorFlow 库:
bazel build -c opt --copt="-D__ANDROID_TYPES_FULL__" //tensorflow/contrib/android:libtensorflow_inference.so \ --crosstool_top=//external:android/crosstool \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ --cpu=x86_64
如果您的 Android 设备像大多数 Android 设备一样支持 armeabi-v7a,请运行以下命令:
bazel build -c opt --copt="-D__ANDROID_TYPES_FULL__" //tensorflow/contrib/android:libtensorflow_inference.so \ --crosstool_top=//external:android/crosstool \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ --cpu=armeabi-v7a
在 Android 应用中使用手动构建的本机库时,您需要让该应用知道该库是针对哪个 CPU 指令集(也称为应用二进制接口(ABI))构建的。 Android 支持两种主要的 ABI:ARM 和 X86,而armeabi-v7a
是 Android 上最受欢迎的 ABI。 要找出您的设备或仿真器使用的是哪个 ABI,请运行adb -s shell getprop ro.product.cpu.abi
。 例如,此命令为我的 Nexus 7 平板电脑返回armeabi-v7a
,为我的模拟器返回x86_64
。
如果您具有支持 x86_64 的虚拟仿真器以在开发过程中进行快速测试,并且在设备上进行最终性能测试,则可能要同时构建两者。
构建完成后,您将在bazel-bin/tensorflow/contrib/android
文件夹中看到 TensorFlow 本机库文件libtensorflow_inference.so
。 将其拖到android/app/src/main/jniLibs/armeabi-v7a
或 android/app/src/main/jniLibs/x86_64
的app
文件夹中,如图 7.9 所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-U37zgeVF-1681653119035)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/0865ad60-dece-4349-9825-24ca46220d07.png)]
图 7.9:显示 TensorFlow 本机库文件
- 通过运行以下命令构建 TensorFlow 本机库的 Java 接口:
bazel build //tensorflow/contrib/android:android_tensorflow_inference_java
这将在bazel-bin/tensorflow/contrib/android
处生成文件libandroid_tensorflow_inference_java.jar
。 将文件移动到 android/app/lib
文件夹,如图 7.10 所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-p6iE4MOg-1681653119035)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/51e4c597-7447-4623-8e07-396b56faddfa.png)]
图 7.10:将 Java 接口文件显示到 TensorFlow 库
现在,我们准备在 Android 中编码和测试模型。
开发一个 Android 应用来使用该模型
请按照以下步骤使用 TensorFlow 库和我们先前构建的模型创建一个新的 Android 应用:
- 在 Android Studio 中,创建一个名为 QuickDraw 的新 Android 应用,接受所有默认设置。 然后在应用的
build.gradle
中,将compile files('libs/libandroid_tensorflow_inference_java.jar')
添加到依赖项的末尾。 像以前一样创建一个新的assets
文件夹,并将quickdraw_frozen_long_blacklist_strip_transformed.pb
和classes.txt
拖放到该文件夹中。 - 创建一个名为
QuickDrawView
的新 Java 类,该类扩展了View
,并如下设置字段及其构造器:
public class QuickDrawView extends View { private Path mPath; private Paint mPaint, mCanvasPaint; private Canvas mCanvas; private Bitmap mBitmap; private MainActivity mActivity; private List<List<Pair<Float, Float>>> mAllPoints = new ArrayList<List<Pair<Float, Float>>>(); private List<Pair<Float, Float>> mConsecutivePoints = new ArrayList<Pair<Float, Float>>(); public QuickDrawView(Context context, AttributeSet attrs) { super(context, attrs); mActivity = (MainActivity) context; setPathPaint(); }
mAllPoints
用于保存mConsecutivePoints
的列表。 QuickDrawView
用于主要活动的布局中,以显示用户的绘画。
- 如下定义
setPathPaint
方法:
private void setPathPaint() { mPath = new Path(); mPaint = new Paint(); mPaint.setColor(0xFF000000); mPaint.setAntiAlias(true); mPaint.setStrokeWidth(18); mPaint.setStyle(Paint.Style.STROKE); mPaint.setStrokeJoin(Paint.Join.ROUND); mCanvasPaint = new Paint(Paint.DITHER_FLAG); }
添加两个实例化Bitmap
和Canvas
对象并向用户显示在画布上绘画的重写方法:
@Override protected void onSizeChanged(int w, int h, int oldw, int oldh) { super.onSizeChanged(w, h, oldw, oldh); mBitmap = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888); mCanvas = new Canvas(mBitmap); } @Override protected void onDraw(Canvas canvas) { canvas.drawBitmap(mBitmap, 0, 0, mCanvasPaint); canvas.drawPath(mPath, mPaint); }
- 覆盖方法
onTouchEvent
用于填充mConsecutivePoints
和mAllPoints
,调用画布的drawPath
方法,使图无效(以调用onDraw
方法),以及(每次使用MotionEvent.ACTION_UP
完成笔划线),以启动一个新线程以使用模型对绘画进行分类:
@Override public boolean onTouchEvent(MotionEvent event) { if (!mActivity.canDraw()) return true; float x = event.getX(); float y = event.getY(); switch (event.getAction()) { case MotionEvent.ACTION_DOWN: mConsecutivePoints.clear(); mConsecutivePoints.add(new Pair(x, y)); mPath.moveTo(x, y); break; case MotionEvent.ACTION_MOVE: mConsecutivePoints.add(new Pair(x, y)); mPath.lineTo(x, y); break; case MotionEvent.ACTION_UP: mConsecutivePoints.add(new Pair(x, y)); mAllPoints.add(new ArrayList<Pair<Float, Float>> (mConsecutivePoints)); mCanvas.drawPath(mPath, mPaint); mPath.reset(); Thread thread = new Thread(mActivity); thread.start(); break; default: return false; } invalidate(); return true; }
- 定义两个将由
MainActivity
调用的公共方法,以获取所有点并在用户点击重新启动按钮后重置绘画:
public List<List<Pair<Float, Float>>> getAllPoints() { return mAllPoints; } public void clearAllPointsAndRedraw() { mBitmap = Bitmap.createBitmap(mBitmap.getWidth(), mBitmap.getHeight(), Bitmap.Config.ARGB_8888); mCanvas = new Canvas(mBitmap); mCanvasPaint = new Paint(Paint.DITHER_FLAG); mCanvas.drawBitmap(mBitmap, 0, 0, mCanvasPaint); setPathPaint(); invalidate(); mAllPoints.clear(); }
- 现在打开
MainActivity
,并使其实现Runnable
及其字段,如下所示:
public class MainActivity extends AppCompatActivity implements Runnable { private static final String MODEL_FILE = "file:///android_asset/quickdraw_frozen_long_blacklist_strip_transformed.pb"; private static final String CLASSES_FILE = "file:///android_asset/classes.txt"; private static final String INPUT_NODE1 = "Reshape"; private static final String INPUT_NODE2 = "Squeeze"; private static final String OUTPUT_NODE1 = "dense/BiasAdd"; private static final String OUTPUT_NODE2 = "ArgMax"; private static final int CLASSES_COUNT = 345; private static final int BATCH_SIZE = 8; private String[] mClasses = new String[CLASSES_COUNT]; private QuickDrawView mDrawView; private Button mButton; private TextView mTextView; private String mResult = ""; private boolean mCanDraw = false; private TensorFlowInferenceInterface mInferenceInterface;
- 在主布局文件
activity_main.xml
中,除了我们之前所做的TextView
和Button
之外,还创建一个QuickDrawView
元素:
<com.ailabby.quickdraw.QuickDrawView android:id="@+id/drawview" android:layout_width="fill_parent" android:layout_height="fill_parent" app:layout_constraintBottom_toBottomOf="parent" app:layout_constraintLeft_toLeftOf="parent" app:layout_constraintRight_toRightOf="parent" app:layout_constraintTop_toTopOf="parent"/>
- 返回
MainActivity
; 在其onCreate
方法中,将 UI 元素 ID 与字段绑定,为启动/重启按钮设置点击监听器。 然后将classes.txt
文件读入字符串数组:
@Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); mDrawView = findViewById(R.id.drawview); mButton = findViewById(R.id.button); mTextView = findViewById(R.id.textview); mButton.setOnClickListener(new View.OnClickListener() { @Override public void onClick(View v) { mCanDraw = true; mButton.setText("Restart"); mTextView.setText(""); mDrawView.clearAllPointsAndRedraw(); } }); String classesFilename = CLASSES_FILE.split("file:///android_asset/")[1]; BufferedReader br = null; int linenum = 0; try { br = new BufferedReader(new InputStreamReader(getAssets().open(classesFilename))); String line; while ((line = br.readLine()) != null) { mClasses[linenum++] = line; } br.close(); } catch (IOException e) { throw new RuntimeException("Problem reading classes file!" , e); } }
- 然后从线程的
run
方法中调用同步方法classifyDrawing
:
public void run() { classifyDrawing(); } private synchronized void classifyDrawing() { try { double normalized_points[] = normalizeScreenCoordinates(); long total_points = normalized_points.length / 3; float[] floatValues = new float[normalized_points.length*BATCH_SIZE]; for (int i=0; i<normalized_points.length; i++) { for (int j=0; j<BATCH_SIZE; j++) floatValues[j*normalized_points.length + i] = (float)normalized_points[i]; } long[] seqlen = new long[BATCH_SIZE]; for (int i=0; i<BATCH_SIZE; i++) seqlen[i] = total_points;
即将实现的normalizeScreenCoordinates
方法将用户绘画点转换为模型期望的格式。 floatValues
和seqlen
将被输入模型。 请注意,由于模型需要这些确切的数据类型(float
和int64
),因此我们必须在floatValues
中使用float
在seqlen
中使用long
,否则在使用模型时会发生运行时错误。
- 创建一个与 TensorFlow 库的 Java 接口以加载模型,向模型提供输入并获取输出:
AssetManager assetManager = getAssets(); mInferenceInterface = new TensorFlowInferenceInterface(assetManager, MODEL_FILE); mInferenceInterface.feed(INPUT_NODE1, floatValues, BATCH_SIZE, total_points, 3); mInferenceInterface.feed(INPUT_NODE2, seqlen, BATCH_SIZE); float[] logits = new float[CLASSES_COUNT * BATCH_SIZE]; float[] argmax = new float[CLASSES_COUNT * BATCH_SIZE]; mInferenceInterface.run(new String[] {OUTPUT_NODE1, OUTPUT_NODE2}, false); mInferenceInterface.fetch(OUTPUT_NODE1, logits); mInferenceInterface.fetch(OUTPUT_NODE1, argmax);
- 归一化所提取的
logits
概率并以降序对其进行排序:
double sum = 0.0; for (int i=0; i<CLASSES_COUNT; i++) sum += Math.exp(logits[i]); List<Pair<Integer, Float>> prob_idx = new ArrayList<Pair<Integer, Float>>(); for (int j = 0; j < CLASSES_COUNT; j++) { prob_idx.add(new Pair(j, (float)(Math.exp(logits[j]) / sum) )); } Collections.sort(prob_idx, new Comparator<Pair<Integer, Float>>() { @Override public int compare(final Pair<Integer, Float> o1, final Pair<Integer, Float> o2) { return o1.second > o2.second ? -1 : (o1.second == o2.second ? 0 : 1); } });
获取前五个结果并将其显示在TextView
中:
mResult = ""; for (int i=0; i<5; i++) { if (prob_idx.get(i).second > 0.1) { if (mResult == "") mResult = "" + mClasses[prob_idx.get(i).first]; else mResult = mResult + ", " + mClasses[prob_idx.get(i).first]; } } runOnUiThread( new Runnable() { @Override public void run() { mTextView.setText(mResult); } });
- 最后,实现
normalizeScreenCoordinates
方法,它是 iOS 实现的便捷端口:
private double[] normalizeScreenCoordinates() { List<List<Pair<Float, Float>>> allPoints = mDrawView.getAllPoints(); int total_points = 0; for (List<Pair<Float, Float>> cp : allPoints) { total_points += cp.size(); } double[] normalized = new double[total_points * 3]; float lowerx=Float.MAX_VALUE, lowery=Float.MAX_VALUE, upperx=-Float.MAX_VALUE, uppery=-Float.MAX_VALUE; for (List<Pair<Float, Float>> cp : allPoints) { for (Pair<Float, Float> p : cp) { if (p.first < lowerx) lowerx = p.first; if (p.second < lowery) lowery = p.second; if (p.first > upperx) upperx = p.first; if (p.second > uppery) uppery = p.second; } } float scalex = upperx - lowerx; float scaley = uppery - lowery; int n = 0; for (List<Pair<Float, Float>> cp : allPoints) { int m = 0; for (Pair<Float, Float> p : cp) { normalized[n*3] = (p.first - lowerx) / scalex; normalized[n*3+1] = (p.second - lowery) / scaley; normalized[n*3+2] = (m ==cp.size()-1 ? 1 : 0); n++; m++; } } for (int i=0; i<n-1; i++) { normalized[i*3] = normalized[(i+1)*3] - normalized[i*3]; normalized[i*3+1] = normalized[(i+1)*3+1] - normalized[i*3+1]; normalized[i*3+2] = normalized[(i+1)*3+2]; } return normalized; }
在您的 Android 模拟器或设备上运行该应用,并享受分类结果的乐趣。 您应该看到类似图 7.11 的内容:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-o4rePOTz-1681653119035)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/560fa3a2-94a9-4f93-ae05-bd09921b8e0c.png)]
图 7.11:在 Android 上显示绘画和分类结果
既然您已经了解了训练 Quick Draw 模型的全过程,并在 iOS 和 Android 应用中使用了它,那么您当然可以微调训练方法,使其更加准确,并改善移动应用的乐趣。
在本章我们不得不结束有趣旅程之前的最后一个提示是,如果您使用错误的 ABI 构建适用于 Android 的 TensorFlow 本机库,您仍然可以从 Android Studio 构建和运行该应用,但将出现运行时错误java.lang.RuntimeException: Native TF methods not found; check that the correct native libraries are present in the APK.
,这意味着您的应用的jniLibs
文件夹中没有正确的 TensorFlow 本机库(图 7.9)。 要找出jniLibs
内特定 ABI 文件夹中是否缺少该文件,可以从Android Studio | View | Tool Windows
中打开Device File Explorer
,然后选择设备的data | app | package | lib
来查看,如图 7.12 所示。 如果您更喜欢命令行,则也可以使用adb
工具找出来。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-0X6N6LxN-1681653119035)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/16494075-d42a-4b86-a48a-4cb0bc2ec865.png)]
图 7.12:使用设备文件资源管理器检出 TensorFlow 本机库文件
总结
在本章中,我们首先描述了绘画分类模型的工作原理,然后介绍了如何使用高级 TensorFlow Estimator API 训练这种模型。 我们研究了如何编写 Python 代码以使用经过训练的模型进行预测,然后详细讨论了如何找到正确的输入和输出节点名称以及如何以正确的方式冻结和转换模型以使移动应用可以使用它。 我们还提供了一种新方法来构建新的 TensorFlow 自定义 iOS 库,并提供了一个逐步教程,以构建适用于 Android 的 TensorFlow 自定义库,以修复使用模型时的运行时错误。 最后,我们展示了 iOS 和 Android 代码,这些代码捕获并显示用户绘画,将其转换为模型所需的数据,并处理和呈现模型返回的分类结果。 希望您在漫长的旅途中学到了很多东西。
到目前为止,除了来自其他开放源代码项目的几个模型以外,所有由我们自己进行预训练或训练的模型,我们在 iOS 和 Android 应用中使用的都是 TensorFlow 开放源代码项目,当然,该项目提供了大量强大的模型,其中一些模型在强大的 GPU 上进行了数周的训练。 但是,如果您有兴趣从头开始构建自己的模型,并且还对本章中使用和应用的强大 RNN 模型以及概念感到困惑,那么下一章就是您所需要的:我们将讨论如何从头开始构建自己的 RNN 模型并在移动应用中使用它,从而带来另一种乐趣-从股市中赚钱-至少我们会尽力做到这一点。 当然,没有人能保证您每次都能从每次股票交易中获利,但是至少让我们看看我们的 RNN 模型如何帮助我们提高这样做的机会。
八、用 RNN 预测股价
如果在上一章中在移动设备上玩过涂鸦和构建(并运行模型以识别涂鸦),当您在股市上赚钱时会感到很开心,而如果您不认真的话会变得很认真。 一方面,股价是时间序列数据,一系列离散时间数据,而处理时间序列数据的最佳深度学习方法是 RNN,这是我们在前两章中使用的方法。 AurélienGéron 在他的畅销书《Scikit-Learn 和 TensorFlow 机器学习实战》中,建议使用 RNN“分析时间序列数据,例如股票价格,并告诉您何时买卖”。 另一方面,其他人则认为股票的过去表现无法预测其未来收益,因此,随机选择的投资组合的表现与专家精心挑选的股票一样好。 实际上,Keras(在 TensorFlow 和其他几个库之上运行的非常受欢迎的高级深度学习库)的作者 FrançoisChollet 在他的畅销书《Python 深度学习》中表示,使用 RNN。 仅用公开数据来击败市场是“一项非常困难的努力,您可能会浪费时间和资源,而无所作为。”
因此,冒着“可能”浪费我们时间和资源的风险,但是可以肯定的是,我们至少将了解更多有关 RNN 的知识,以及为什么有可能比随机 50% 的策略更好地预测股价,我们将首先概述如何使用 RNN 进行股票价格预测,然后讨论如何使用 TensorFlow API 构建 RNN 模型来预测股票价格,以及如何使用易于使用的 Keras API 来为价格预测构建 RNN LSTM 模型。 我们将测试这些模型是否可以击败随机的买入或卖出策略。 如果我们对我们的模型感到满意,以提高我们在市场上的领先优势,或者只是出于专有技术的目的,我们将了解如何冻结并准备 TensorFlow 和 Keras 模型以在 iOS 和 Android 应用上运行。 如果该模型可以提高我们的机会,那么我们支持该模型的移动应用可以在任何时候,无论何时何地做出买或卖决定。 感觉有点不确定和兴奋? 欢迎来到市场。
总之,本章将涵盖以下主题:
- RNN 和股价预测:什么以及如何
- 使用 TensorFlow RNN API 进行股价预测
- 使用 Keras RNN LSTM API 进行股价预测
- 在 iOS 上运行 TensorFlow 和 Keras 模型
- 在 Android 上运行 TensorFlow 和 Keras 模型
RNN 和股价预测 – 什么以及如何
前馈网络(例如密集连接的网络)没有内存,无法将每个输入视为一个整体。 例如,表示为像素向量的图像输入在单个步骤中由前馈网络处理。 但是,使用具有内存的网络可以更好地处理时间序列数据,例如最近 10 或 20 天的股价。 假设过去 10 天的价格为X1, X2, ..., X10
,其中X1
为最早的和X10
为最晚,然后将所有 10 天价格视为一个序列输入,并且当 RNN 处理此类输入时,将发生以下步骤:
- 按顺序连接到第一个元素
X1
的特定 RNN 单元处理X1
并获取其输出y1
- 在序列输入中,连接到下一个元素
X2
的另一个 RNN 单元使用X2
以及先前的输出y1
, 获得下一个输出y2
- 重复该过程:在时间步长使用 RNN 单元处理输入序列中的
Xi
元素时,先前的输出y[i-1]
,在时间步i-1
与Xi
一起使用,以在时间步i
生成新的输出yi
。
因此,在时间步长i
的每个yi
输出,都具有有关输入序列中直到时间步长i
以及包括时间步长i
的所有元素的信息:X1, X2, ..., X[i-1]
和Xi
。 在 RNN 训练期间,预测价格y1, y2, ..., y9
和y10
的每个时间步长与每个时间步长的真实目标价格进行比较,即X2, X3, ..., X10
和X11
和损失函数因此被定义并用于优化以更新网络参数。 训练完成后,在预测期间,将X11
用作输入序列的预测,X1, X2, ..., X10
。
这就是为什么我们说 RNN 有内存。 RNN 对于处理股票价格数据似乎很有意义,因为直觉是,今天(以及明天和后天等等)的股票价格可能会受其前N
天的价格影响。
LSTM 只是解决 RNN 已知梯度消失问题的一种 RNN,我们在第 6 章,“用自然语言描述图像”中引入了 LSTM。 基本上,在训练 RNN 模型的过程中,,如果到 RNN 的输入序列的时间步太长,则使用反向传播更新较早时间步的网络权重可能会得到 0 的梯度值, 导致没有学习发生。 例如,当我们使用 50 天的价格作为输入,并且如果使用 50 天甚至 40 天的时间步长变得太长,则常规 RNN 将是不可训练的。 LSTM 通过添加一个长期状态来解决此问题,该状态决定可以丢弃哪些信息以及需要在许多时间步骤中存储和携带哪些信息。
可以很好地解决梯度消失问题的另一种 RNN 被称为门控循环单元(GRU),它稍微简化了标准 LSTM 模型,并且越来越受欢迎。 TensorFlow 和 Keras API 均支持基本的 RNN 和 LSTM/GRU 模型。 在接下来的两部分中,您将看到使用 RNN 和标准 LSTM 的具体 TensorFlow 和 Keras API,并且可以在代码中简单地将LSTM
替换为GRU
,以将使用 GRU 模型的结果与 RNN 和标准 LSTM 模型比较。
三种常用技术可以使 LSTM 模型表现更好:
- 堆叠 LSTM 层并增加层中神经元的数量:如果不产生过拟合,通常这将导致功能更强大,更准确的网络模型。 如果还没有,那么您绝对应该玩 TensorFlow Playground来体验一下。
- 使用丢弃处理过拟合。 删除意味着随机删除层中的隐藏单元和输入单元。
- 使用双向 RNN 在两个方向(常规方向和反向方向)处理每个输入序列,希望检测出可能被常规单向 RNN 忽略的模式。
所有这些技术已经实现,并且可以在 TensorFlow 和 Keras API 中轻松访问。
那么,我们如何使用 RNN 和 LSTM 测试股价预测? 我们将在这个页面上使用免费的 API 收集特定股票代码的每日股票价格数据,将其解析为训练集和测试集,并每次向 RNN/LSTM 模型提供一批训练输入(每个训练输入有 20 个时间步长,即,连续 20 天的价格),对模型进行训练,然后进行测试以查看模型在测试数据集中的准确率。 我们将同时使用 TensorFlow 和 Keras API 进行测试,并比较常规 RNN 和 LSTM 模型之间的差异。 我们还将测试三个略有不同的序列输入和输出,看看哪个是最好的:
- 根据过去
N
天预测一天的价格 - 根据过去
N
天预测M
天的价格 - 基于将过去
N
天移动 1 并使用预测序列的最后输出作为第二天的预测价格进行预测
现在让我们深入研究 TensorFlow RNN API 并进行编码以训练模型来预测股票价格,以查看其准确率如何。
将 TensorFlow RNN API 用于股价预测
首先,您需要在这里索取免费的 API 密钥,以便获取任何股票代码的股价数据。 取得 API 密钥后,打开终端并运行以下命令(将替换为您自己的密钥后)以获取 Amazon(amzn)和 Google(goog)的每日股票数据,或将它们替换为你感兴趣的任何符号:
curl -o daily_amzn.csv "https://www.alphavantage.co/query?function=TIME_SERIES_DAILY&symbol=amzn&apikey=<your_api_key>&datatype=csv&outputsize=full" curl -o daily_goog.csv "https://www.alphavantage.co/query?function=TIME_SERIES_DAILY&symbol=goog&apikey=<your_api_key>&datatype=csv&outputsize=full"
这将生成一个daily_amzn.csv
或daily_goog.csv
csv 文件 ,其顶行为“时间戳,开盘,高位,低位,收盘,交易量”,这些行的其余部分作为每日股票信息。 我们只关心收盘价,因此运行以下命令以获取所有收盘价:
cut -d ',' -f 5 daily_amzn.csv | tail -n +2 > amzn.txt cut -d ',' -f 5 daily_goog.csv | tail -n +2 > goog.txt
截至 2018 年 2 月 26 日,amzn.txt
或goog.txt
中的行数为 4,566 或 987,这是亚马逊或 Google 的交易天数。 现在,让我们看一下使用 TensorFlow RNN API 训练和预测模型的完整 Python 代码。
在 TensorFlow 中训练 RNN 模型
- 导入所需的 Python 包并定义一些常量:
import numpy as np import tensorflow as tf from tensorflow.contrib.rnn import * import matplotlib.pyplot as plt num_neurons = 100 num_inputs = 1 num_outputs = 1 symbol = 'goog' # amzn epochs = 500 seq_len = 20 learning_rate = 0.001
NumPy 是用于 N 维数组操作的最受欢迎的 Python 库,而 Matplotlib 是领先的 Python 2D 绘图库。 我们将使用 numpy 处理数据集,并使用 Matplotlib 可视化股票价格和预测。 num_neurons
是 RNN(或更准确地说是 RNN 单元)在每个时间步长上的神经元数量-每个神经元在该时间步长上都接收输入序列的输入元素,并从前一个时间步长上接收输出。 num_inputs
和num_outputs
指定每个时间步长的输入和输出数量-我们将从每个时间步长的 20 天输入序列中将一个股票价格提供给带有num_neurons
神经元的 RNN 单元,并在每个步骤期望一个预测的股票输出。 seq_len
是时间步数。 因此,我们将使用 Google 的 20 天股票价格作为输入序列,并将这些输入发送给具有 100 个神经元的 RNN 单元。
- 打开并读取包含所有价格的文本文件,将价格解析为
float
数字列表,颠倒列表顺序,以便最早的价格首先开始,然后每次添加seq_len+1
值(第一个seq_len
值将是 RNN 的输入序列,最后的seq_len
值将是目标输出序列),从列表中的第一个开始,每次移动 1 直到列表的末尾,直到一个 numpyresult
数组:
f = open(symbol + '.txt', 'r').read() data = f.split('\n')[:-1] # get rid of the last '' so float(n) works data.reverse() d = [float(n) for n in data] result = [] for i in range(len(d) - seq_len - 1): result.append(d[i: i + seq_len + 1]) result = np.array(result)
result
数组现在包含我们模型的整个数据集,但是我们需要将其进一步处理为 RNN API 期望的格式。 首先,将其分为训练集(占整个数据集的 90%)和测试集(占 10%):
row = int(round(0.9 * result.shape[0])) train = result[:row, :] test = result[row:, :]
然后随机地随机排列训练集,作为机器学习模型训练中的标准做法:
np.random.shuffle(train)
制定训练集和测试集X_train
和X_test
的输入序列,以及训练集和测试集y_train
和y_test
的目标输出序列。 请注意,大写字母X
和小写字母y
是机器学习中常用的命名约定,分别代表输入和目标输出:
X_train = train[:, :-1] # all rows with all columns except the last one X_test = test[:, :-1] # each row contains seq_len + 1 columns y_train = train[:, 1:] y_test = test[:, 1:]
最后,将四个数组重塑为 3-D(批大小,时间步数以及输入或输出数),以完成训练和测试数据集的准备:
X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], num_inputs)) X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], num_inputs)) y_train = np.reshape(y_train, (y_train.shape[0], y_train.shape[1], num_outputs)) y_test = np.reshape(y_test, (y_test.shape[0], y_test.shape[1], num_outputs))
注意,X_train.shape[1]
,X_test.shape[1]
,y_train.shape[1]
和y_test.shape[1]
与seq_len
相同。
- 我们已经准备好构建模型。 创建两个占位符,以便在训练期间和
X_test
一起喂入X_train
和y_train
:
X = tf.placeholder(tf.float32, [None, seq_len, num_inputs]) y = tf.placeholder(tf.float32, [None, seq_len, num_outputs])
使用BasicRNNCell
创建一个 RNN 单元,每个时间步分别具有 num_neurons
神经元,:
cell = tf.contrib.rnn.OutputProjectionWrapper( tf.contrib.rnn.BasicRNNCell(num_units=num_neurons, activation=tf.nn.relu), output_size=num_outputs) outputs, _ = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
OutputProjectionWrapper
用于在每个单元的输出之上添加一个完全连接的层,因此,在每个时间步长处,RNN 单元的输出(将是num_neurons
值的序列)都会减小为单个值。 这就是 RNN 在每个时间步为输入序列中的每个值输出一个值,或为每个实例的seq_len
个数的值的每个输入序列输出总计seq_len
个数的值的方式。
dynamic_rnn
用于循环所有时间步长的 RNN 信元,总和为seq_len
(在X
形状中定义),它返回两个值:每个时间步长的输出列表,以及网络的最终状态。 接下来,我们将使用第一个outputs
返回的整形值来定义损失函数。
- 通过以标准方式指定预测张量,损失,优化器和训练操作来完成模型定义:
preds = tf.reshape(outputs, [1, seq_len], name="preds") loss = tf.reduce_mean(tf.square(outputs - y)) optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) training_op = optimizer.minimize(loss)
请注意,当我们使用freeze_graph
工具准备要在移动设备上部署的模型时,"preds"
将用作输出节点名称,它也将在 iOS 和 Android 中用于运行模型进行预测。 如您所见,在我们甚至开始训练模型之前一定要知道那条信息,这绝对是一件很高兴的事情,而这是我们从头开始构建的模型的好处。
- 开始训练过程。 对于每个周期,我们将
X_train
和y_train
数据输入以运行training_op
以最小化loss
,然后保存模型检查点文件,并每 10 个周期打印损失值:
init = tf.global_variables_initializer() saver = tf.train.Saver() with tf.Session() as sess: init.run() count = 0 for _ in range(epochs): n=0 sess.run(training_op, feed_dict={X: X_train, y: y_train}) count += 1 if count % 10 == 0: saver.save(sess, "/tmp/" + symbol + "_model.ckpt") loss_val = loss.eval(feed_dict={X: X_train, y: y_train}) print(count, "loss:", loss_val)
如果您运行上面的代码,您将看到如下输出:
(10, 'loss:', 243802.61) (20, 'loss:', 80629.57) (30, 'loss:', 40018.996) (40, 'loss:', 28197.496) (50, 'loss:', 24306.758) ... (460, 'loss:', 93.095985) (470, 'loss:', 92.864082) (480, 'loss:', 92.33461) (490, 'loss:', 92.09893) (500, 'loss:', 91.966286)
您可以在第 4 步中用BasicLSTMCell
替换BasicRNNCell
并运行训练代码,但是使用BasicLSTMCell
进行训练要慢得多,并且在 500 个周期之后损失值仍然很大。 在本节中,我们将不再对BasicLSTMCell
进行实验,但是为了进行比较,在使用 Keras 的下一部分中,您将看到堆叠 LSTM 层,丢弃法和双向 RNN 的详细用法。
测试 TensorFlow RNN 模型
要查看 500 个周期后的损失值是否足够好,让我们使用测试数据集添加以下代码,以计算总测试示例中正确预测的数量(正确的意思是,预测价格在目标价格的同一个方向上上下波动,相对于前一天的价格):
correct = 0 y_pred = sess.run(outputs, feed_dict={X: X_test}) targets = [] predictions = [] for i in range(y_pred.shape[0]): input = X_test[i] target = y_test[i] prediction = y_pred[i] targets.append(target[-1][0]) predictions.append(prediction[-1][0]) if target[-1][0] >= input[-1][0] and prediction[-1][0] >= input[-1][0]: correct += 1 elif target[-1][0] < input[-1][0] and prediction[-1][0] < input[-1][0]: correct += 1
现在我们可以使用plot
方法可视化预测正确率:
total = len(X_test) xs = [i for i, _ in enumerate(y_test)] plt.plot(xs, predictions, 'r-', label='prediction') plt.plot(xs, targets, 'b-', label='true') plt.legend(loc=0) plt.title("%s - %d/%d=%.2f%%" %(symbol, correct, total, 100*float(correct)/total)) plt.show()
现在运行代码将显示如图 8.1 所示,正确预测的比率为 56.25% :
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Q8Ngho3K-1681653119035)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/71f57975-9e80-4886-a13f-7a38b52dc84d.png)]
图 8.1:显示使用 TensorFlow RNN 训练的股价预测正确性
注意,每次运行此训练和测试代码时,您获得的比率可能都会有所不同。 通过微调模型的超参数,您可能会获得超过 60% 的比率,这似乎比随机预测要好。 如果您乐观的话,您可能会认为至少有 50% (56.25%)的东西要显示出来,并且可能希望看到该模型在移动设备上运行。 但首先让我们看看是否可以使用酷的 Keras 库来构建更好的模型-在执行此操作之前,让我们通过简单地运行来冻结经过训练的 TensorFlow 模型:
python tensorflow/python/tools/freeze_graph.py --input_meta_graph=/tmp/amzn_model.ckpt.meta --input_checkpoint=/tmp/amzn_model.ckpt --output_graph=/tmp/amzn_tf_frozen.pb --output_node_names="preds" --input_binary=true
将 Keras RNN LSTM API 用于股价预测
Keras 是一个非常易于使用的高级深度学习 Python 库,它运行在 TensorFlow,Theano 和 CNTK 等其他流行的深度学习库之上。 您很快就会看到,Keras 使构建和使用模型变得更加容易。 要安装和使用 Keras 以及 TensorFlow 作为 Keras 的后端,最好首先设置一个 VirtualEnv:
sudo pip install virtualenv
如果您的机器和 iOS 和 Android 应用上都有 TensorFlow 1.4 源,请运行以下命令;否则,请运行以下命令。 使用 TensorFlow 1.4 自定义库:
cd mkdir ~/tf14_keras virtualenv --system-site-packages ~/tf14_keras/ cd ~/tf14_keras/ source ./bin/activate easy_install -U pip pip install --upgrade https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.4.0-py2-none-any.whl pip install keras
如果您的机器上装有 TensorFlow 1.5 源,则应在 Keras 上安装 TensorFlow 1.5,因为使用 Keras 创建的模型需要具有与 TensorFlow 移动应用所使用的模型相同的 TensorFlow 版本,或者在尝试加载模型时发生错误:
cd mkdir ~/tf15_keras virtualenv --system-site-packages ~/tf15_keras/ cd ~/tf15_keras/ source ./bin/activate easy_install -U pip pip install --upgrade https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.5.0-py2-none-any.whl pip install keras
如果您的操作系统不是 Mac 或计算机具有 GPU,则您需要用正确的 URL 替换 TensorFlow Python 包 URL,您可以在这个页面上找到它。
在 Keras 中训练 RNN 模型
现在,让我们看看在 Keras 中建立和训练 LSTM 模型以预测股价的过程。 首先,一些导入和常量设置:
import keras from keras import backend as K from keras.layers.core import Dense, Activation, Dropout from keras.layers.recurrent import LSTM from keras.layers import Bidirectional from keras.models import Sequential import matplotlib.pyplot as plt import tensorflow as tf import numpy as np symbol = 'amzn' epochs = 10 num_neurons = 100 seq_len = 20 pred_len = 1 shift_pred = False
shift_pred
用于指示我们是否要预测价格的输出序列而不是单个输出价格。 如果是True
,我们将根据输入X1, X2, ..., Xn
来预测X2, X3, ..., X[n+1]
,就像我们在使用 TensorFlow API 的最后一部分中所做的那样。 如果shift_pred
为False
,我们将基于输入X1, X2, ..., Xn
来预测输出的pred_len
。 例如,如果pred_len
为 1,我们将预测X[n+1]
,如果pred_len
为 3,我们将预测X[n+1], X[n+2], X[n+3]
,这很有意义,因为我们很想知道价格是连续连续 3 天上涨还是仅上涨 1 天然后下降 2 天。
ensorFlow 智能移动项目:6~10(4)https://developer.aliyun.com/article/1426911