ensorFlow 智能移动项目:6~10(3)https://developer.aliyun.com/article/1426910
现在,让我们创建一个根据上一节中的数据加载代码进行修改的方法,该方法根据pred_len
和shift_pred
设置准备适当的训练和测试数据集:
def load_data(filename, seq_len, pred_len, shift_pred): f = open(filename, 'r').read() data = f.split('\n')[:-1] # get rid of the last '' so float(n) works data.reverse() d = [float(n) for n in data] lower = np.min(d) upper = np.max(d) scale = upper-lower normalized_d = [(x-lower)/scale for x in d] result = [] if shift_pred: pred_len = 1 for i in range((len(normalized_d) - seq_len - pred_len)/pred_len): result.append(normalized_d[i*pred_len: i*pred_len + seq_len + pred_len]) result = np.array(result) row = int(round(0.9 * result.shape[0])) train = result[:row, :] test = result[row:, :] np.random.shuffle(train) X_train = train[:, :-pred_len] X_test = test[:, :-pred_len] if shift_pred: y_train = train[:, 1:] y_test = test[:, 1:] else: y_train = train[:, -pred_len:] y_test = test[:, -pred_len:] X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1)) X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1)) return [X_train, y_train, X_test, y_test, lower, scale]
注意,在这里我们也使用归一化,使用与上一章相同的归一化方法,以查看它是否可以改善我们的模型。 当使用训练模型进行预测时,我们还返回lower
和scale
值,这是非规范化所需的值。
现在我们可以调用load_data
来获取训练和测试数据集,以及lower
和scale
值:
X_train, y_train, X_test, y_test, lower, scale = load_data(symbol + '.txt', seq_len, pred_len, shift_pred)
完整的模型构建代码如下:
model = Sequential() model.add(Bidirectional(LSTM(num_neurons, return_sequences=True, input_shape=(None, 1)), input_shape=(seq_len, 1))) model.add(Dropout(0.2)) model.add(LSTM(num_neurons, return_sequences=True)) model.add(Dropout(0.2)) model.add(LSTM(num_neurons, return_sequences=False)) model.add(Dropout(0.2)) if shift_pred: model.add(Dense(units=seq_len)) else: model.add(Dense(units=pred_len)) model.add(Activation('linear')) model.compile(loss='mse', optimizer='rmsprop') model.fit( X_train, y_train, batch_size=512, epochs=epochs, validation_split=0.05) print(model.output.op.name) print(model.input.op.name)
即使使用新添加的Bidirectional
,Dropout
,validation_split
和堆叠 LSTM 层,该代码也比 TensorFlow 中的模型构建代码更容易解释和简化。 请注意,LSTM 调用中的return_sequences
参数i
必须为True
,因此 LSTM 单元的输出将是完整的输出序列,而不仅仅是输出序列中的最后一个输出, 除非它是最后的堆叠层。 最后两个 print
语句将打印输入节点名称( bidirectional_1_input
)和输出节点名称(activation_1/Identity
),当我们冻结模型并在移动设备上运行模型时需要。
现在,如果您运行前面的代码,您将看到如下输出:
824/824 [==============================] - 7s 9ms/step - loss: 0.0833 - val_loss: 0.3831 Epoch 2/10 824/824 [==============================] - 2s 3ms/step - loss: 0.2546 - val_loss: 0.0308 Epoch 3/10 824/824 [==============================] - 2s 2ms/step - loss: 0.0258 - val_loss: 0.0098 Epoch 4/10 824/824 [==============================] - 2s 2ms/step - loss: 0.0085 - val_loss: 0.0035 Epoch 5/10 824/824 [==============================] - 2s 2ms/step - loss: 0.0044 - val_loss: 0.0026 Epoch 6/10 824/824 [==============================] - 2s 2ms/step - loss: 0.0038 - val_loss: 0.0022 Epoch 7/10 824/824 [==============================] - 2s 2ms/step - loss: 0.0033 - val_loss: 0.0019 Epoch 8/10 824/824 [==============================] - 2s 2ms/step - loss: 0.0030 - val_loss: 0.0019 Epoch 9/10 824/824 [==============================] - 2s 2ms/step - loss: 0.0028 - val_loss: 0.0017 Epoch 10/10 824/824 [==============================] - 2s 3ms/step - loss: 0.0027 - val_loss: 0.0019
训练损失和验证损失都可以通过简单调用model.fit
进行打印。
测试 Keras RNN 模型
现在该保存模型检查点并使用测试数据集来计算正确预测的数量,正如我们在上一节中所解释的那样:
saver = tf.train.Saver() saver.save(K.get_session(), '/tmp/keras_' + symbol + '.ckpt') predictions = [] correct = 0 total = pred_len*len(X_test) for i in range(len(X_test)): input = X_test[i] y_pred = model.predict(input.reshape(1, seq_len, 1)) predictions.append(scale * y_pred[0][-1] + lower) if shift_pred: if y_test[i][-1] >= input[-1][0] and y_pred[0][-1] >= input[-1] [0]: correct += 1 elif y_test[i][-1] < input[-1][0] and y_pred[0][-1] < input[-1][0]: correct += 1 else: for j in range(len(y_test[i])): if y_test[i][j] >= input[-1][0] and y_pred[0][j] >= input[-1][0]: correct += 1 elif y_test[i][j] < input[-1][0] and y_pred[0][j] < input[-1][0]: correct += 1
我们主要调用model.predict
来获取X_test
中每个实例的预测,并将其与真实值和前一天的价格一起使用,以查看在方向方面是否为正确的预测。 最后,让我们根据测试数据集和预测来绘制真实价格:
y_test = scale * y_test + lower y_test = y_test[:, -1] xs = [i for i, _ in enumerate(y_test)] plt.plot(xs, y_test, 'g-', label='true') plt.plot(xs, predictions, 'r-', label='prediction') plt.legend(loc=0) if shift_pred: plt.title("%s - epochs=%d, shift_pred=True, seq_len=%d: %d/%d=%.2f%%" %(symbol, epochs, seq_len, correct, total, 100*float(correct)/total)) else: plt.title("%s - epochs=%d, lens=%d,%d: %d/%d=%.2f%%" %(symbol, epochs, seq_len, pred_len, correct, total, 100*float(correct)/total)) plt.show()
您会看到类似图 8.2 的内容:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-tZ7FK7s8-1681653119036)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/b8227667-1d0d-4ea7-bf0b-fa7dac192064.png)]
图 8.2:使用 Keras 双向和堆叠 LSTM 层进行股价预测
很容易在栈中添加更多 LSTM 层,或者使用诸如学习率和丢弃率以及许多恒定设置之类的超参数。 但是,对于使用pred_len
和shift_pred
的不同设置,正确率的差异还没有发现。 也许我们现在应该对接近 60% 的正确率感到满意,并看看如何在 iOS 和 Android 上使用 TensorFlow 和 Keras 训练的模型-我们可以在以后继续尝试改进模型,但是,了解使用 TensorFlow 和 Keras 训练的 RNN 模型是否会遇到任何问题将非常有价值。
正如 FrançoisChollet 指出的那样,“深度学习更多的是艺术而不是科学……每个问题都是独特的,您将不得不尝试并经验地评估不同的策略。目前尚无理论可以提前准确地告诉您应该做什么。 以最佳方式解决问题。您必须尝试并进行迭代。” 希望我们为您使用 TensorFlow 和 Keras API 改善股票价格预测模型提供了一个很好的起点。
本节中最后要做的就是从检查点冻结 Keras 模型-因为我们在虚拟环境中安装了 TensorFlow 和 Keras,而 TensorFlow 是 VirtualEnv 中唯一安装并受支持的深度学习库,Keras 使用 TensorFlow 后端,并通过saver.save(K.get_session(), '/tmp/keras_' + symbol + '.ckpt')
调用以 TensorFlow 格式生成检查点。 现在运行以下命令冻结检查点(回想我们在训练期间从print(model.input.op.name)
获得output_node_name
):
python tensorflow/python/tools/freeze_graph.py --input_meta_graph=/tmp/keras_amzn.ckpt.meta --input_checkpoint=/tmp/keras_amzn.ckpt --output_graph=/tmp/amzn_keras_frozen.pb --output_node_names="activation_1/Identity" --input_binary=true
因为我们的模型非常简单明了,所以我们将直接在移动设备上尝试这两个冻结的模型,而无需像前两章中那样使用transform_graph
工具。
在 iOS 上运行 TensorFlow 和 Keras 模型
我们不会通过重复项目设置步骤来烦您-只需按照我们之前的操作即可创建一个名为 StockPrice 的新 Objective-C 项目,该项目将使用手动构建的 TensorFlow 库(请参阅第 7 章,“使用 CNN 和 LSTM 识别绘画”的 iOS 部分(如果需要详细信息)。 然后将两个模型文件amzn_tf_frozen.pb
和amzn_keras_frozen.pb
添加到项目中,您应该在 Xcode 中拥有 StockPrice 项目,如图 8.3 所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-VZGO2jEI-1681653119036)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/454020d6-1f32-480d-b159-b474b6878540.png)]
图 8.3:在 Xcode 中使用 TensorFlow 和 Keras 训练的模型的 iOS 应用
在ViewController.mm
中,我们将首先声明一些变量和一个常量:
unique_ptr<tensorflow::Session> tf_session; UITextView *_tv; UIButton *_btn; NSMutableArray *_closeprices; const int SEQ_LEN = 20;
然后创建一个按钮点击处理器,以使用户可以选择 TensorFlow 或 Keras 模型(该按钮在viewDidLoad
方法中像以前一样创建):
- (IBAction)btnTapped:(id)sender { UIAlertAction* tf = [UIAlertAction actionWithTitle:@"Use TensorFlow Model" style:UIAlertActionStyleDefault handler:^(UIAlertAction * action) { [self getLatestData:NO]; }]; UIAlertAction* keras = [UIAlertAction actionWithTitle:@"Use Keras Model" style:UIAlertActionStyleDefault handler:^(UIAlertAction * action) { [self getLatestData:YES]; }]; UIAlertAction* none = [UIAlertAction actionWithTitle:@"None" style:UIAlertActionStyleDefault handler:^(UIAlertAction * action) {}]; UIAlertController* alert = [UIAlertController alertControllerWithTitle:@"RNN Model Pick" message:nil preferredStyle:UIAlertControllerStyleAlert]; [alert addAction:tf]; [alert addAction:keras]; [alert addAction:none]; [self presentViewController:alert animated:YES completion:nil]; }
getLatestData
方法首先发出 URL 请求以获取紧凑型版本的 Alpha Vantage API,该 API 返回 Amazon 每日股票数据的最后 100 个数据点,然后解析结果并将最后 20 个收盘价保存在_closeprices
数组中:
-(void)getLatestData:(BOOL)useKerasModel { NSURLSession *session = [NSURLSession sharedSession]; [[session dataTaskWithURL:[NSURL URLWithString:@"https://www.alphavantage.co/query?function=TIME_SERIES_DAILY&symbol=amzn&apikey=<your_api_key>&datatype=csv&outputsize=compact"] completionHandler:^(NSData *data, NSURLResponse *response, NSError *error) { NSString *stockinfo = [[NSString alloc] initWithData:data encoding:NSASCIIStringEncoding]; NSArray *lines = [stockinfo componentsSeparatedByString:@"\n"]; _closeprices = [NSMutableArray array]; for (int i=0; i<SEQ_LEN; i++) { NSArray *items = [lines[i+1] componentsSeparatedByString:@","]; [_closeprices addObject:items[4]]; } if (useKerasModel) [self runKerasModel]; else [self runTFModel]; }] resume]; }
runTFModel
方法定义如下:
- (void) runTFModel { tensorflow::Status load_status; load_status = LoadModel(@"amzn_tf_frozen", @"pb", &tf_session); tensorflow::Tensor prices(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, SEQ_LEN, 1})); auto prices_map = prices.tensor<float, 3>(); NSString *txt = @"Last 20 Days:\n"; for (int i = 0; i < SEQ_LEN; i++){ prices_map(0,i,0) = [_closeprices[SEQ_LEN-i-1] floatValue]; txt = [NSString stringWithFormat:@"%@%@\n", txt, _closeprices[SEQ_LEN-i-1]]; } std::vector<tensorflow::Tensor> output; tensorflow::Status run_status = tf_session->Run({{"Placeholder", prices}}, {"preds"}, {}, &output); if (!run_status.ok()) { LOG(ERROR) << "Running model failed:" << run_status; } else { tensorflow::Tensor preds = output[0]; auto preds_map = preds.tensor<float, 2>(); txt = [NSString stringWithFormat:@"%@\nPrediction with TF RNN model:\n%f", txt, preds_map(0,SEQ_LEN-1)]; dispatch_async(dispatch_get_main_queue(), ^{ [_tv setText:txt]; [_tv sizeToFit]; }); } }
preds_map(0,SEQ_LEN-1)
是基于最近 20 天的第二天的预测价格; Placeholder
是“在 TensorFlow 中训练 RNN 模型”小节的第四步的X = tf.placeholder(tf.float32, [None, seq_len, num_inputs])
中定义的输入节点名称。 在模型生成预测后,我们将其与最近 20 天的价格一起显示在TextView
中。
runKeras
方法的定义与此类似,但具有反规范化以及不同的输入和输出节点名称。 由于我们的 Keras 模型经过训练只能输出一个预测价格,而不是一系列seq_len
价格,因此我们使用preds_map(0,0)
来获得预测:
- (void) runKerasModel { tensorflow::Status load_status; load_status = LoadModel(@"amzn_keras_frozen", @"pb", &tf_session); if (!load_status.ok()) return; tensorflow::Tensor prices(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, SEQ_LEN, 1})); auto prices_map = prices.tensor<float, 3>(); float lower = 5.97; float scale = 1479.37; NSString *txt = @"Last 20 Days:\n"; for (int i = 0; i < SEQ_LEN; i++){ prices_map(0,i,0) = ([_closeprices[SEQ_LEN-i-1] floatValue] - lower)/scale; txt = [NSString stringWithFormat:@"%@%@\n", txt, _closeprices[SEQ_LEN-i-1]]; } std::vector<tensorflow::Tensor> output; tensorflow::Status run_status = tf_session->Run({{"bidirectional_1_input", prices}}, {"activation_1/Identity"}, {}, &output); if (!run_status.ok()) { LOG(ERROR) << "Running model failed:" << run_status; } else { tensorflow::Tensor preds = output[0]; auto preds_map = preds.tensor<float, 2>(); txt = [NSString stringWithFormat:@"%@\nPrediction with Keras RNN model:\n%f", txt, scale * preds_map(0,0) + lower]; dispatch_async(dispatch_get_main_queue(), ^{ [_tv setText:txt]; [_tv sizeToFit]; }); } }
如果您现在运行该应用并点击Predict
按钮,您将看到模型选择消息(图 8.4):
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-VMADgb2i-1681653119036)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/067fa448-a497-4f13-944f-e976240694c9.png)]
图 8.4:选择 TensorFlow 或 Keras RNN 模型
如果选择 TensorFlow 模型,则可能会出现错误:
Could not create TensorFlow Graph: Invalid argument: No OpKernel was registered to support Op 'Less' with these attrs. Registered devices: [CPU], Registered kernels: device='CPU'; T in [DT_FLOAT] [[Node: rnn/while/Less = Less[T=DT_INT32, _output_shapes=[[]]](rnn/while/Merge, rnn/while/Less/Enter)]]
如果选择 Keras 模型,则可能会出现稍微不同的错误:
Could not create TensorFlow Graph: Invalid argument: No OpKernel was registered to support Op 'Less' with these attrs. Registered devices: [CPU], Registered kernels: device='CPU'; T in [DT_FLOAT] [[Node: bidirectional_1/while_1/Less = Less[T=DT_INT32, _output_shapes=[[]]](bidirectional_1/while_1/Merge, bidirectional_1/while_1/Less/Enter)]]
我们在上一章中已经看到RefSwitch
操作出现类似的错误,并且知道针对此类错误的解决方法是在启用 -D__ANDROID_TYPES_FULL__
的情况下构建 TensorFlow 库。 如果没有看到这些错误,则意味着您在上一章的 iOS 应用中已建立了这样的库; 否则,请按照“为 iOS 构建自定义 TensorFlow 库”的开头的说明。 上一章的内容构建新的 TensorFlow 库,然后再次运行该应用。
现在选择 TensorFlow 模型,您将看到如图 8.5 所示的结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-j0y2fPol-1681653119036)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/28715473-1f89-4f14-8e81-f98242a11c8d.png)]
图 8.5:使用 TensorFlow RNN 模型进行预测
使用 Keras 模型输出不同的预测,如图 8.6 所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-cET3Uh9U-1681653119037)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/4114a6b9-15da-41f7-b535-bf04c7fb4dad.png)]
图 8.6:使用 Keras RNN 模型进行预测
我们无法确定哪个模型能在没有进一步研究的情况下更好地工作,但是我们可以确定的是,我们的两个 RNN 模型都使用 TensorFlow 和 Keras API 从头开始训练了,其准确率接近 60%, 在 iOS 上运行良好,这很值得我们付出努力,因为我们正在尝试建立一个许多专家认为将达到与随机选择相同的表现的模型,并且在此过程中,我们学到了一些新奇的东西-使用 TensorFlow 和 Keras 构建 RNN 模型并在 iOS 上运行它们。 在下一章中,我们只剩下一件事了:如何在 Android 上使用模型? 我们会遇到新的障碍吗?
在 Android 上运行 TensorFlow 和 Keras 模型
事实证明,这就像使用 Android 上的模型在沙滩上散步一样-尽管我们必须使用自定义的 TensorFlow 库(而不是 TensorFlow pod),我们甚至不需要像上一章那样使用自定义的 TensorFlow Android 库。 截至 2018 年 2 月)。 与用于 iOS 的 TensorFlow Pod 相比,在build.gradle
文件中使用compile 'org.tensorflow:tensorflow-android:+'
构建的 TensorFlow Android 库必须对Less
操作具有更完整的数据类型支持。
要在 Android 中测试模型,请创建一个新的 Android 应用 StockPrice,并将两个模型文件添加到其assets
文件夹中。 然后在布局中添加几个按钮和一个TextView
并在MainActivity.java
中定义一些字段和常量:
private static final String TF_MODEL_FILENAME = "file:///android_asset/amzn_tf_frozen.pb"; private static final String KERAS_MODEL_FILENAME = "file:///android_asset/amzn_keras_frozen.pb"; private static final String INPUT_NODE_NAME_TF = "Placeholder"; private static final String OUTPUT_NODE_NAME_TF = "preds"; private static final String INPUT_NODE_NAME_KERAS = "bidirectional_1_input"; private static final String OUTPUT_NODE_NAME_KERAS = "activation_1/Identity"; private static final int SEQ_LEN = 20; private static final float LOWER = 5.97f; private static final float SCALE = 1479.37f; private TensorFlowInferenceInterface mInferenceInterface; private Button mButtonTF; private Button mButtonKeras; private TextView mTextView; private boolean mUseTFModel; private String mResult;
制作onCreate
如下:
protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); mButtonTF = findViewById(R.id.tfbutton); mButtonKeras = findViewById(R.id.kerasbutton); mTextView = findViewById(R.id.textview); mTextView.setMovementMethod(new ScrollingMovementMethod()); mButtonTF.setOnClickListener(new View.OnClickListener() { @Override public void onClick(View v) { mUseTFModel = true; Thread thread = new Thread(MainActivity.this); thread.start(); } }); mButtonKeras.setOnClickListener(new View.OnClickListener() { @Override public void onClick(View v) { mUseTFModel = false; Thread thread = new Thread(MainActivity.this); thread.start(); } }); }
其余代码全部在run
方法中,在点击TF PREDICTION
或KERAS PREDICTION
按钮时在工作线程中启动,需要一些解释,使用 Keras 模型需要在运行模型之前和之后规范化和非规范化:
public void run() { runOnUiThread( new Runnable() { @Override public void run() { mTextView.setText("Getting data..."); } }); float[] floatValues = new float[SEQ_LEN]; try { URL url = new URL("https://www.alphavantage.co/query?function=TIME_SERIES_DAILY&symbol=amzn&apikey=4SOSJM2XCRIB5IUS&datatype=csv&outputsize=compact"); HttpURLConnection urlConnection = (HttpURLConnection) url.openConnection(); InputStream in = new BufferedInputStream(urlConnection.getInputStream()); Scanner s = new Scanner(in).useDelimiter("\\n"); mResult = "Last 20 Days:\n"; if (s.hasNext()) s.next(); // get rid of the first title line List<String> priceList = new ArrayList<>(); while (s.hasNext()) { String line = s.next(); String[] items = line.split(","); priceList.add(items[4]); } for (int i=0; i<SEQ_LEN; i++) mResult += priceList.get(SEQ_LEN-i-1) + "\n"; for (int i=0; i<SEQ_LEN; i++) { if (mUseTFModel) floatValues[i] = Float.parseFloat(priceList.get(SEQ_LEN-i-1)); else floatValues[i] = (Float.parseFloat(priceList.get(SEQ_LEN-i-1)) - LOWER) / SCALE; } AssetManager assetManager = getAssets(); mInferenceInterface = new TensorFlowInferenceInterface(assetManager, mUseTFModel ? TF_MODEL_FILENAME : KERAS_MODEL_FILENAME); mInferenceInterface.feed(mUseTFModel ? INPUT_NODE_NAME_TF : INPUT_NODE_NAME_KERAS, floatValues, 1, SEQ_LEN, 1); float[] predictions = new float[mUseTFModel ? SEQ_LEN : 1]; mInferenceInterface.run(new String[] {mUseTFModel ? OUTPUT_NODE_NAME_TF : OUTPUT_NODE_NAME_KERAS}, false); mInferenceInterface.fetch(mUseTFModel ? OUTPUT_NODE_NAME_TF : OUTPUT_NODE_NAME_KERAS, predictions); if (mUseTFModel) { mResult += "\nPrediction with TF RNN model:\n" + predictions[SEQ_LEN - 1]; } else { mResult += "\nPrediction with Keras RNN model:\n" + (predictions[0] * SCALE + LOWER); } runOnUiThread( new Runnable() { @Override public void run() { mTextView.setText(mResult); } }); } catch (Exception e) { e.printStackTrace(); } }
现在运行该应用,然后点击TF PREDICTION
按钮,您将在图 8.7 中看到结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yZX2YsPg-1681653119037)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/8ce6016b-658a-46c4-a051-04c772d503fc.png)]
图 8.7:使用 TensorFlow 模型在亚马逊上进行股价预测
选择 KERAS 预测将为您提供如图 8.8 所示的结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1l5o0sai-1681653119037)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/ec5b48b2-0c2a-439a-af59-303b9c628019.png)]
图 8.8:使用 Keras 模型在亚马逊上进行股价预测
总结
在本章中,我们首先对表示不屑一顾,试图通过使用 TensorFlow 和 Keras RNN API 预测股价来击败市场。 我们首先讨论了 RNN 和 LSTM 模型是什么以及如何使用它们进行股价预测。 然后,我们使用 TensorFlow 和 Keras 从零开始构建了两个 RNN 模型,接近测试正确率的 60%。 最后,我们介绍了如何冻结模型并在 iOS 和 Android 上使用它们,并使用自定义 TensorFlow 库修复了 iOS 上可能出现的运行时错误。
如果您对我们尚未建立预测正确率为 80% 或 90% 的模型感到有些失望,则可能需要继续进行“尝试并迭代”过程,以查看是否可以以该正确率预测股票价格。 但是,您肯定会从使用 TensorFlow 和 Keras API 的 RNN 模型构建,训练和测试中学到的技能以及在 iOS 和 Android 上运行的技能而受益。
如果您对使用深度学习技术打败市场感兴趣并感到兴奋,让我们在 GAN(生成对抗网络)上的下一章中进行研究,该模型试图击败能够分辨真实数据与虚假数据之间差异的对手, 并且越来越擅长生成看起来像真实数据的数据,欺骗对手。 GAN 实际上被深度学习的一些顶级研究人员誉为是过去十年中深度学习中最有趣和令人兴奋的想法。
九、使用 GAN 生成和增强图像
自 2012 年深度学习起步以来,有人认为 Ian Goodfellow 在 2014 年提出的生成对抗网络(GAN)比这更有趣或更有前途。 实际上, Facebook AI 研究主管和之一,深度学习研究人员之一的 Yann LeCun 将 GAN 和对抗训练称为,“这是近十年来机器学习中最有趣的想法。” 因此,我们如何在这里不介绍它,以了解 GAN 为什么如此令人兴奋,以及如何构建 GAN 模型并在 iOS 和 Android 上运行它们?
在本章中,我们将首先概述 GAN 是什么,它如何工作以及为什么它具有如此巨大的潜力。 然后,我们将研究两个 GAN 模型:一个基本的 GAN 模型可用于生成类似人的手写数字,另一个更高级的 GAN 模型可将低分辨率的图像增强为高分辨率的图像。 我们将向您展示如何在 Python 和 TensorFlow 中构建和训练此类模型,以及如何为移动部署准备模型。 然后,我们将提供带有完整源代码的 iOS 和 Android 应用,它们使用这些模型来生成手写数字并增强图像。 在本章的最后,您应该准备好进一步探索各种基于 GAN 的模型,或者开始构建自己的模型,并了解如何在移动应用中运行它们。
总之,本章将涵盖以下主题:
- GAN – 什么以及为什么
- 使用 TensorFlow 构建和训练 GAN 模型
- 在 iOS 中使用 GAN 模型
- 在 Android 中使用 GAN 模型
GAN – 什么以及为什么
GAN 是学习生成类似于真实数据或训练集中数据的神经网络。 GAN 的关键思想是让生成器网络和判别器网络相互竞争:生成器试图生成看起来像真实数据的数据,而判别器试图分辨生成的数据是否真实(从已知真实数据)或伪造(由生成器生成)。 生成器和判别器是一起训练的,在训练过程中,生成器学会生成看起来越来越像真实数据的数据,而判别器则学会将真实数据与伪数据区分开。 生成器通过尝试使判别器的输出概率为真实数据来学习,当将生成器的输出作为判别器的输入时,生成器的输出概率尽可能接近 1.0,而判别器通过尝试实现两个目标来学习:
- 当以生成器的输出作为输入时,使其输出的可能性为实,尽可能接近 0.0,这恰好是生成器的相反目标
- 当输入真实数据作为输入时,使其输出的可能性为实数,尽可能接近 1.0
在下一节中,您将看到与生成器和判别器网络及其训练过程的给定描述相匹配的详细代码片段。 如果您想了解更多关于 GAN 的知识,除了这里的摘要概述之外,您还可以在 YouTube 上搜索“GAN 简介”,并观看 2016 年 NIPS(神经信息处理系统)和 ICCV(国际计算机视觉会议)2017 大会上的 Ian Goodfellow 的 GAN 入门和教程视频。 事实上,YouTube 上有 7 个 NIPS 2016 对抗训练训练班视频和 12 个 ICCV 2017 GAN 指导视频,您可以自己投入其中。
在生成器和判别器两个参与者的竞争目标下,GAN 是一个寻求两个对手之间保持平衡的系统。 如果两个玩家都具有无限的能力并且可以进行最佳训练,那么纳什均衡(继 1994 年诺贝尔经济学奖得主约翰·纳什和电影主题《美丽心灵》之后) 一种状态,在这种状态下,任何玩家都无法通过仅更改其自己的策略来获利,这对应于生成器生成数据的状态,该数据看起来像真实数据,而判别器无法从假数据中分辨真实数据。
如果您有兴趣了解有关纳什均衡的更多信息,请访问 Google “可汗学院纳什均衡”,并观看 Sal Khan 撰写的两个有趣的视频。 《经济学家》解释经济学的“纳什均衡”维基百科页面和文章“纳什均衡是什么,为什么重要?”也是不错的读物。 了解 GAN 的基本直觉和想法将有助于您进一步了解 GAN 具有巨大潜力的原因。
生成器能够生成看起来像真实数据的数据的潜力意味着可以使用 GAN 开发各种出色的应用,例如:
- 从劣质图像生成高质量图像
- 图像修复(修复丢失或损坏的图像)
- 翻译图像(例如,从边缘草图到照片,或者在人脸上添加或移除诸如眼镜之类的对象)
- 从文本生成图像(和第 6 章,“使用自然语言描述图像”的 Text2Image 相反)
- 撰写看起来像真实新闻的新闻文章
- 生成与训练集中的音频相似的音频波形
基本上,GAN 可以从随机输入生成逼真的图像,文本或音频数据; 如果您具有一组源数据和目标数据的训练集,则 GAN 还可从类似于源数据的输入中生成类似于目标数据的数据。 GAN 模型中的生成器和判别器以动态方式工作的这一通用特性,使 GAN 可以生成任何种类的现实输出,这使 GAN 十分令人兴奋。
但是,由于生成器和判别器的动态或竞争目标,训练 GAN 达到纳什均衡状态是一个棘手且困难的问题。 实际上,这仍然是一个开放的研究问题 – Ian Goodfellow 在 2017 年 8 月对 Andrew Ng 进行的“深度学习英雄”采访中(YouTube 上的搜索ian goodfellow andrew ng
)说,如果我们可以使 GAN 变得像深度学习一样可靠,我们将看到 GAN 取得更大的成功,否则我们最终将用其他形式的生成模型代替它们。
尽管在 GAN 的训练方面存在挑战,但是在训练期间您已经可以应用许多有效的已知技巧 – 我们在这里不会介绍它们,但是如果您有兴趣调整我们将在本章中描述的模型或许多其他 GAN 模型 ),或构建自己的 GAN 模型。
使用 TensorFlow 构建和训练 GAN 模型
通常,GAN 模型具有两个神经网络:G
用于生成器,D
用于判别器。 x
是来自训练集的一些实际数据输入,z
是随机输入噪声。 在训练过程中,D(x)
是x
为真实的概率,D
尝试使D(x)
接近 1;G(z)
是具有随机输入z
的生成的输出,并且D
试图使D(G(z))
接近 0,但同时G
试图使D(G(z))
接近 1。 现在,让我们首先来看一下如何在 TensorFlow 和 Python 中构建基本的 GAN 模型,该模型可以编写或生成手写数字。
生成手写数字的基本 GAN 模型
手写数字的训练模型基于仓库,这是这个页面的分支,并添加了显示生成的数字并使用输入占位符保存 TensorFlow 训练模型的脚本,因此我们的 iOS 和 Android 应用可以使用该模型。 是的您应该查看原始仓库的博客。在继续之前,需要对具有代码的 GAN 模型有基本的了解。
在研究定义生成器和判别器网络并进行 GAN 训练的核心代码片段之前,让我们先运行脚本以在克隆存储库并转到仓库目录之后训练和测试模型:
git clone https://github.com/jeffxtang/generative-adversarial-networks cd generative-adversarial-networks
该派生向gan-script-fast.py
脚本添加了检查点保存代码,还添加了新脚本gan-script-test.py
以使用随机输入的占位符测试和保存新的检查点–因此,使用新检查点冻结的模型可以在 iOS 和 Android 应用中使用。
运行命令python gan-script-fast.py
训练模型,在 Ubuntu 上的 GTX-1070 GPU 上花费不到一小时。 训练完成后,检查点文件将保存在模型目录中。 现在运行python gan-script-test.py
来查看一些生成的手写数字。 该脚本还从模型目录读取检查点文件,并在运行gan-script-fast.py
时保存该文件,然后将更新的检查点文件以及随机输入占位符重新保存在newmodel
目录中:
ls -lt newmodel -rw-r--r-- 1 jeffmbair staff 266311 Mar 5 16:43 ckpt.meta -rw-r--r-- 1 jeffmbair staff 65 Mar 5 16:42 checkpoint -rw-r--r-- 1 jeffmbair staff 69252168 Mar 5 16:42 ckpt.data-00000-of-00001 -rw-r--r-- 1 jeffmbair staff 2660 Mar 5 16:42 ckpt.index
gan-script-test.py
中的下一个代码片段显示了输入节点名称(z_placeholder
)和输出节点名称(Sigmoid_1
),如print(generated_images)
所示:
z_placeholder = tf.placeholder(tf.float32, [None, z_dimensions], name='z_placeholder') ... saver.restore(sess, 'model/ckpt') generated_images = generator(z_placeholder, 5, z_dimensions) print(generated_images) images = sess.run(generated_images, {z_placeholder: z_batch}) saver.save(sess, "newmodel/ckpt")
在gan-script-fast.py
脚本中,方法def discriminator(images, reuse_variables=None)
定义了一个判别器网络,该网络使用一个真实的手写图像输入或由生成器生成的一个手写输入,经过一个典型的小型 CNN 网络,该网络具有两层conv2d
层,每一层都带有relu
激活和平均池化层以及两个完全连接的层来输出一个标量值,该标量值将保持输入图像为真或假的概率。 另一种方法def generator(batch_size, z_dim)
定义了生成器网络,该网络采用随机输入的图像向量并将其转换为具有 3 个conv2d
层的28 x 28
图像。
现在可以使用这两种方法来定义三个输出:
Gz
,即随机图像输入的生成器输出:Gz = generator(batch_size, z_dimensions)
Dx
,是真实图像输入的判别器输出:Dx = discriminator(x_placeholder)
Dg
,Gz
的判别器输出:Dg = discriminator(Gz, reuse_variables=True)
和三个损失函数:
d_loss_real
,Dx
和 1 之差:d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dx, labels = tf.ones_like(Dx)))
d_loss_fake
,Dg
和 0 之差:d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dg, labels = tf.zeros_like(Dg)))
g_loss
,Dg
和 1 之差:g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dg, labels = tf.ones_like(Dg)))
请注意,判别器尝试使 d_loss_fake
最小化,而生成器尝试使g_loss
最小化,两种情况下Dg
之间的差分别为 0 和 1。
最后,现在可以为三个损失函数设置三个优化器:d_trainer_fake
,d_trainer_real
和g_trainer
,它们全部是通过tf.train.AdamOptimizer
的minimize
方法定义的。
现在,脚本仅创建一个 TensorFlow 会话,通过运行三个优化器将生成器和判别器进行 100,000 步训练,将随机图像输入馈入生成器,将真实和伪图像输入均馈入判别器。
在运行 gan-script-fast.py
和gan-script-test.py
之后,将检查点文件从newmodel
目录运至/tmp
,然后转到 TensorFlow 源根目录并运行:
python tensorflow/python/tools/freeze_graph.py \ --input_meta_graph=/tmp/ckpt.meta \ --input_checkpoint=/tmp/ckpt \ --output_graph=/tmp/gan_mnist.pb \ --output_node_names="Sigmoid_1" \ --input_binary=true
这将创建可用于移动应用的冻结模型gan_mnist.pb
。 但是在此之前,让我们看一下可以增强低分辨率图像的更高级的 GAN 模型。
增强图像分辨率的高级 GAN 模型
我们将用于增强低分辨率模糊图像的模型,基于论文《使用条件对抗网络的图像到图像转换》及其 TensorFlow 实现 pix2pix。 在仓库的分支中,我们添加了两个脚本:
tools/convert.py
从普通图像创建模糊图像pix2pix_runinference.py
添加了一个用于低分辨率图像输入的占位符和一个用于返回增强图像的操作,并保存了新的检查点文件,我们将冻结这些文件以生成在移动设备上使用的模型文件。
基本上,pix2pix 使用 GAN 将输入图像映射到输出图像。 您可以使用不同类型的输入图像和输出图像来创建许多有趣的图像转换:
- 地图到航拍
- 白天到黑夜
- 边界到照片
- 黑白图像到彩色图像
- 损坏的图像到原始图像
- 从低分辨率图像到高分辨率图像
在所有情况下,生成器都将输入图像转换为输出图像,试图使输出看起来像真实的目标图像,判别器将训练集中的样本或生成器的输出作为输入,并尝试告诉它是真实图像还是生成器生成的图像。 自然,与模型相比,pix2pix 中的生成器和判别器网络以更复杂的方式构建以生成手写数字,并且训练还应用了一些技巧来使过程稳定-有关详细信息,您可以阅读本文或较早提供的 TensorFlow 实现链接。 我们在这里仅向您展示如何设置训练集和训练 pix2pix 模型以增强低分辨率图像。
- 通过在终端上运行来克隆仓库:
git clone https://github.com/jeffxtang/pix2pix-tensorflow cd pix2pix-tensorflow
- 创建一个新目录
photos/original
并复制一些图像文件-例如,我们将所有拉布拉多犬的图片从斯坦福狗数据集(在第 2 章,“使用迁移学习的图像分类”中使用)复制到photos/original
目录 - 运行脚本
python tools/process.py --input_dir photos/original --operation resize --output_dir photos/resized
调整photo/original
目录中图像的大小并将调整后的图像保存到photos/resized
目录中 - 运行
mkdir photos/blurry
,然后运行python tools/convert.py
,以使用流行的 ImageMagick 的convert
命令将调整大小的图像转换为模糊的图像。convert.py
的代码如下:
import os file_names = os.listdir("photos/resized/") for f in file_names: if f.find(".png") != -1: os.system("convert photos/resized/" + f + " -blur 0x3 photos/blurry/" + f)
- 将
photos/resized
和photos/blurry
中的每个文件合并为一个对,并将所有配对的图像(一个调整大小的图像,另一个模糊的版本)保存到photos/resized_blurry
目录:
python tools/process.py --input_dir photos/resized --b_dir photos/blurry --operation combine --output_dir photos/resized_blurry
- 运行拆分工具
python tools/split.py --dir photos/resized_blurry
,将文件转换为train
目录和val
目录 - 通过运行以下命令训练
pix2pix
模型:
python pix2pix.py \ --mode train \ --output_dir photos/resized_blurry/ckpt_1000 \ --max_epochs 1000 \ --input_dir photos/resized_blurry/train \ --which_direction BtoA
方向BtoA
表示从模糊图像转换为原始图像。 在 GTX-1070 GPU 上进行的训练大约需要四个小时,并且photos/resized_blurry/ckpt_1000
目录中生成的检查点文件如下所示:
-rw-rw-r-- 1 jeff jeff 1721531 Mar 2 18:37 model-136000.meta -rw-rw-r-- 1 jeff jeff 81 Mar 2 18:37 checkpoint -rw-rw-r-- 1 jeff jeff 686331732 Mar 2 18:37 model-136000.data-00000-of-00001 -rw-rw-r-- 1 jeff jeff 10424 Mar 2 18:37 model-136000.index -rw-rw-r-- 1 jeff jeff 3807975 Mar 2 14:19 graph.pbtxt -rw-rw-r-- 1 jeff jeff 682 Mar 2 14:19 options.json
- (可选)您可以在测试模式下运行脚本,然后在
--output_dir
指定的目录中检查图像翻译结果:
python pix2pix.py \ --mode test \ --output_dir photos/resized_blurry/output_1000 \ --input_dir photos/resized_blurry/val \ --checkpoint photos/resized_blurry/ckpt_1000
- 运行
pix2pix_runinference.py
脚本以恢复在步骤 7 中保存的检查点,为图像输入创建一个新的占位符,为它提供测试图像ww.png
,将翻译输出为result.png
,最后将新的检查点文件保存在newckpt
目录:
python pix2pix_runinference.py \ --mode test \ --output_dir photos/blurry_output \ --input_dir photos/blurry_test \ --checkpoint photos/resized_blurry/ckpt_1000
以下pix2pix_runinference.py
中的代码段设置并打印输入和输出节点:
image_feed = tf.placeholder(dtype=tf.float32, shape=(1, 256, 256, 3), name="image_feed") print(image_feed) # Tensor("image_feed:0", shape=(1, 256, 256, 3), dtype=float32) with tf.variable_scope("generator", reuse=True): output_image = deprocess(create_generator(image_feed, 3)) print(output_image) #Tensor("generator_1/deprocess/truediv:0", shape=(1, 256, 256, 3), dtype=float32)
具有tf.variable_scope("generator", reuse=True):
的行非常重要,因为需要共享generator
变量,以便可以使用所有训练后的参数值。 否则,您会看到奇怪的翻译结果。
以下代码显示了如何在newckpt
目录中填充占位符,运行 GAN 模型并保存生成器的输出以及检查点文件:
if a.mode == "test": from scipy import misc image = misc.imread("ww.png").reshape(1, 256, 256, 3) image = (image / 255.0) * 2 - 1 result = sess.run(output_image, feed_dict={image_feed:image}) misc.imsave("result.png", result.reshape(256, 256, 3)) saver.save(sess, "newckpt/pix2pix")
图 9.1 显示了原始测试图像,其模糊版本以及经过训练的 GAN 模型的生成器输出。 结果并不理想,但是 GAN 模型确实具有更好的分辨率而没有模糊效果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PNdQLBQU-1681653119037)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/b73deef3-1598-4019-ac72-5c0212d53c74.png)]
图 9.1:原始的,模糊的和生成的
- 现在,将
newckpt
目录复制到/tmp
,我们可以如下冻结模型:
python tensorflow/python/tools/freeze_graph.py \ --input_meta_graph=/tmp/newckpt/pix2pix.meta \ --input_checkpoint=/tmp/newckpt/pix2pix \ --output_graph=/tmp/newckpt/pix2pix.pb \ --output_node_names="generator_1/deprocess/truediv" \ --input_binary=true
- 生成的
pix2pix.pb
模型文件很大,约为 217MB,将其加载到 iOS 或 Android 设备上时会崩溃或导致内存不足(OOM)错误。 我们必须像在第 6 章,“使用自然语言描述图像”的复杂 im2txt 模型中所做的那样,将其转换为 iOS 的映射格式。
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ --in_graph=/tmp/newckpt/pix2pix.pb \ --out_graph=/tmp/newckpt/pix2pix_transformed.pb \ --inputs="image_feed" \ --outputs="generator_1/deprocess/truediv" \ --transforms='strip_unused_nodes(type=float, shape="1,256,256,3") fold_constants(ignore_errors=true, clear_output_shapes=true) fold_batch_norms fold_old_batch_norms' bazel-bin/tensorflow/contrib/util/convert_graphdef_memmapped_format \ --in_graph=/tmp/newckpt/pix2pix_transformed.pb \ --out_graph=/tmp/newckpt/pix2pix_transformed_memmapped.pb
pix2pix_transformed_memmapped.pb
模型文件现在可以在 iOS 中使用。
- 要为 Android 构建模型,我们需要量化冻结的模型,以将模型大小从 217MB 减少到约 54MB:
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ --in_graph=/tmp/newckpt/pix2pix.pb \ --out_graph=/tmp/newckpt/pix2pix_transformed_quantized.pb --inputs="image_feed" \ --outputs="generator_1/deprocess/truediv" \ --transforms='quantize_weights'
现在,让我们看看如何在移动应用中使用两个 GAN 模型。
在 iOS 中使用 GAN 模型
如果您尝试在 iOS 应用中使用 TensorFlow 窗格并加载gan_mnist.pb
文件,则会收到错误消息:
Could not create TensorFlow Graph: Invalid argument: No OpKernel was registered to support Op 'RandomStandardNormal' with these attrs. Registered devices: [CPU], Registered kernels: <no registered kernels> [[Node: z_1/RandomStandardNormal = RandomStandardNormal[T=DT_INT32, _output_shapes=[[50,100]], dtype=DT_FLOAT, seed=0, seed2=0](z_1/shape)]]
将行添加到tf_op_files.txt
之后,请确保tensorflow/contrib/makefile/tf_op_files.txt
文件具有tensorflow/core/kernels/random_op.cc
,该文件实现了RandomStandardNormal
操作,并且libtensorflow-core.a
是由 tensorflow/contrib/makefile/build_all_ios.sh
构建的。
此外,如果即使在使用 TensorFlow 1.4 构建的自定义 TensorFlow 库中尝试加载pix2pix_transformed_memmapped.pb
,也会出现以下错误:
No OpKernel was registered to support Op 'FIFOQueueV2' with these attrs. Registered devices: [CPU], Registered kernels: <no registered kernels> [[Node: batch/fifo_queue = FIFOQueueV2[_output_shapes=[[]], capacity=32, component_types=[DT_STRING, DT_FLOAT, DT_FLOAT], container="", shapes=[[], [256,256,1], [256,256,2]], shared_name=""]()]]
您需要将tensorflow/core/kernels/fifo_queue_op.cc
添加到tf_op_files.txt
并重建 iOS 库。 但是,如果您使用 TensorFlow 1.5 或 1.6,则tensorflow/core/kernels/fifo_queue_op.cc
文件已经添加到tf_op_files.txt
文件中。 在每个新版本的 TensorFlow 中,默认情况下,越来越多的内核被添加到tf_op_files.txt
。
借助为模型构建的 TensorFlow iOS 库,让我们在 Xcode 中创建一个名为 GAN 的新项目,并像在第 8 章,“使用 RNN 预测股价”一样在该项目中设置 TensorFlow。 以及其他不使用 TensorFlow 窗格的章节。 然后将两个模型文件gan_mnist.pb
和pix2pix_transformed_memmapped.pb
以及一个测试图像拖放到项目中。 另外,将第 6 章,“使用自然语言描述图像”的 iOS 项目中的tensorflow_utils.h
, tensorflow_utils.mm
,ios_image_load.h
和 ios_image_load.mm
文件复制到 GAN 项目。 将ViewController.m
重命名为ViewController.mm
。
现在,您的 Xcode 应该类似于图 9.2:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-otj0WMJO-1681653119038)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/intel-mobi-proj-tf/img/9cb1fa66-5a4e-4a1b-aa51-a16fd9051f57.png)]
图 9.2:在 Xcode 中显示 GAN 应用
我们将创建一个按钮,在点击该按钮时,提示用户选择一个模型以生成数字或增强图像:
- (IBAction)btnTapped:(id)sender { UIAlertAction* mnist = [UIAlertAction actionWithTitle:@"Generate Digits" style:UIAlertActionStyleDefault handler:^(UIAlertAction * action) { _iv.image = NULL; dispatch_async(dispatch_get_global_queue(0, 0), ^{ NSArray *arrayGreyscaleValues = [self runMNISTModel]; dispatch_async(dispatch_get_main_queue(), ^{ UIImage *imgDigit = [self createMNISTImageInRect:_iv.frame values:arrayGreyscaleValues]; _iv.image = imgDigit; }); }); }]; UIAlertAction* pix2pix = [UIAlertAction actionWithTitle:@"Enhance Image" style:UIAlertActionStyleDefault handler:^(UIAlertAction * action) { _iv.image = [UIImage imageNamed:image_name]; dispatch_async(dispatch_get_global_queue(0, 0), ^{ NSArray *arrayRGBValues = [self runPix2PixBlurryModel]; dispatch_async(dispatch_get_main_queue(), ^{ UIImage *imgTranslated = [self createTranslatedImageInRect:_iv.frame values:arrayRGBValues]; _iv.image = imgTranslated; }); }); }]; UIAlertAction* none = [UIAlertAction actionWithTitle:@"None" style:UIAlertActionStyleDefault handler:^(UIAlertAction * action) {}]; UIAlertController* alert = [UIAlertController alertControllerWithTitle:@"Use GAN to" message:nil preferredStyle:UIAlertControllerStyleAlert]; [alert addAction:mnist]; [alert addAction:pix2pix]; [alert addAction:none]; [self presentViewController:alert animated:YES completion:nil]; }
这里的代码非常简单。 应用的主要功能通过以下四种方法实现: runMNISTModel
, runPix2PixBlurryModel
, createMNISTImageInRect
和 createTranslatedImageInRect
。
使用基本 GAN 模型
在runMNISTModel
中,我们调用辅助方法LoadModel
来加载 GAN 模型,然后将输入张量设置为具有正态分布(均值 0.0 和 std 1.0)的 100 个随机数的 6 批。 该模型期望具有正态分布的随机输入。 您可以将 6 更改为任何其他数字,然后取回该数字的生成位数:
- (NSArray*) runMNISTModel { tensorflow::Status load_status; load_status = LoadModel(@"gan_mnist", @"pb", &tf_session); if (!load_status.ok()) return NULL; std::string input_layer = "z_placeholder"; std::string output_layer = "Sigmoid_1"; tensorflow::Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({6, 100})); auto input_map = input_tensor.tensor<float, 2>(); unsigned seed = (unsigned)std::chrono::system_clock::now().time_since_epoch().count(); std::default_random_engine generator (seed); std::normal_distribution<double> distribution(0.0, 1.0); for (int i = 0; i < 6; i++){ for (int j = 0; j < 100; j++){ double number = distribution(generator); input_map(i,j) = number; } }
runMNISTModel
方法中的其余代码运行模型,获得6 * 28 * 28
浮点数的输出,表示每批像素大小为28 * 28
的图像在每个像素处的灰度值,并调用方法createMNISTImageInRect
,以便在将图像上下文转换为UIImage
之前,先使用 UIBezierPath
在图像上下文中呈现数字,然后将其返回并显示在UIImageView
中:
std::vector<tensorflow::Tensor> outputs; tensorflow::Status run_status = tf_session->Run({{input_layer, input_tensor}}, {output_layer}, {}, &outputs); if (!run_status.ok()) { LOG(ERROR) << "Running model failed: " << run_status; return NULL; } tensorflow::string status_string = run_status.ToString(); tensorflow::Tensor* output_tensor = &outputs[0]; const Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>, Eigen::Aligned>& output = output_tensor->flat<float>(); const long count = output.size(); NSMutableArray *arrayGreyscaleValues = [NSMutableArray array]; for (int i = 0; i < count; ++i) { const float value = output(i); [arrayGreyscaleValues addObject:[NSNumber numberWithFloat:value]]; } return arrayGreyscaleValues; }
createMNISTImageInRect
的定义如下-我们在第 7 章,“使用 CNN 和 LSTM 识别绘画”中使用了类似的技术:
- (UIImage *)createMNISTImageInRect:(CGRect)rect values:(NSArray*)greyscaleValues { UIGraphicsBeginImageContextWithOptions(CGSizeMake(rect.size.width, rect.size.height), NO, 0.0); int i=0; const int size = 3; for (NSNumber *val in greyscaleValues) { float c = [val floatValue]; int x = i%28; int y = i/28; i++; CGRect rect = CGRectMake(145+size*x, 50+y*size, size, size); UIBezierPath *path = [UIBezierPath bezierPathWithRect:rect]; UIColor *color = [UIColor colorWithRed:c green:c blue:c alpha:1.0]; [color setFill]; [path fill]; } UIImage *image = UIGraphicsGetImageFromCurrentImageContext(); UIGraphicsEndImageContext(); return image; }
对于每个像素,我们绘制一个宽度和高度均为 3 的小矩形,并为该像素返回灰度值。
使用高级 GAN 模型
在runPix2PixBlurryModel
方法中,我们使用LoadMemoryMappedModel
方法加载pix2pix_transformed_memmapped.pb
模型文件,并加载测试图像并设置输入张量,其方式与第 4 章,“以惊人的艺术样式迁移图片”相同:
- (NSArray*) runPix2PixBlurryModel { tensorflow::Status load_status; load_status = LoadMemoryMappedModel(@"pix2pix_transformed_memmapped", @"pb", &tf_session, &tf_memmapped_env); if (!load_status.ok()) return NULL; std::string input_layer = "image_feed"; std::string output_layer = "generator_1/deprocess/truediv"; NSString* image_path = FilePathForResourceName(@"ww", @"png"); int image_width; int image_height; int image_channels; std::vector<tensorflow::uint8> image_data = LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels);
然后我们运行模型,获得256 * 256 * 3
(图像大小为256 * 256
,RGB 具有 3 个值)浮点数的输出,并调用createTranslatedImageInRect
将数字转换为UIImage
:
std::vector<tensorflow::Tensor> outputs; tensorflow::Status run_status = tf_session->Run({{input_layer, image_tensor}}, {output_layer}, {}, &outputs); if (!run_status.ok()) { LOG(ERROR) << "Running model failed: " << run_status; return NULL; } tensorflow::string status_string = run_status.ToString(); tensorflow::Tensor* output_tensor = &outputs[0]; const Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>, Eigen::Aligned>& output = output_tensor->flat<float>(); const long count = output.size(); // 256*256*3 NSMutableArray *arrayRGBValues = [NSMutableArray array]; for (int i = 0; i < count; ++i) { const float value = output(i); [arrayRGBValues addObject:[NSNumber numberWithFloat:value]]; } return arrayRGBValues;
ensorFlow 智能移动项目:6~10(5)https://developer.aliyun.com/article/1426912