この前、「手書き文字を作れるJavascriptをつくってTensorFlowで予測させてみた」という投稿でブラウザ上で手書きした文字画像を、MNISTで訓練したモデルで予測してみましたが、ものすごく精度が悪かったです。今回改めて、CNNを使ってやってみたらかなり精度が上がりました。何%か測ったりしてませんが、自分の手書きだと90%は超える感じでした。やっぱりCNNはすごいなーと思いました。でももしかしたら前回のものにミスがあり、CNNではなくても精度は本当はもっと高い可能性はあります。
今後の課題としては、現状では文字を画像の中心に適度な大きさで書く必要があり、例えば右上に小さく2と書いても認識されない点があります。また、現在はMNISTに合わせて、手書き文字画像も背景黒、文字色白で作成するように固定していますが、これらの色を変えても認識するようにしたいです。今度やってみます。
Github
https://github.com/endoyuta/mnist_test
index.html
<html>
<head>
  <title>MNIST TEST</title>
</head>
<body>
  <h1>MNIST TEST</h1>
  <canvas id="canvas1" width="400" height="400" style="border: 1px solid #999;"></canvas><br><br>
  <input id="clear" type="button" value="Clear" onclick="canvasClear();">
  <input id="submit" type="button" value="Submit" onclick="saveImg();"><br><br>
  <img id="preview"><span id="answer"></span>
  <script src="http://code.jquery.com/jquery-3.1.1.min.js"></script>
  <script>
    // Endpoint of the CGI script that receives the drawn image.
    var url = 'http://127.0.0.1:8000/cgi-bin/mnist.py';

    // Stroke settings: white pen on the black background painted by canvasClear().
    var lineWidth = 40;
    var lineColor = '#ffffff';

    // Size of the image that is sent to the server.
    // Declared separately: `var a = b = 28` would leak `b` as an implicit global.
    var imgW = 28;
    var imgH = 28;

    var canvas = document.getElementById('canvas1');
    var ctx = canvas.getContext('2d');

    // End point of the previous stroke segment; null while no stroke is in progress.
    var mouseX = null;
    var mouseY = null;

    canvasClear();
    canvas.addEventListener('mousemove', mmove, false);
    canvas.addEventListener('mousedown', mdown, false);
    canvas.addEventListener('mouseup', mouseInit, false);
    canvas.addEventListener('mouseout', mouseInit, false);

    // Convert a mouse event's viewport coordinates into canvas coordinates.
    // The bounding rect is read on every event so the offset stays correct
    // after the page is scrolled or resized.
    function canvasPos(e) {
      var rect = canvas.getBoundingClientRect();
      return { x: e.clientX - rect.left, y: e.clientY - rect.top };
    }

    // True when only the primary (left) button is held during the event.
    // `buttons` is preferred; `which` is consulted only when `buttons` is
    // unavailable, because some browsers report which === 1 on mousemove
    // even when no button is pressed.
    function isPrimaryButtonDown(e) {
      if (e.buttons !== undefined) {
        return e.buttons === 1;
      }
      return e.which === 1;
    }

    // mousemove handler: extend the current stroke while the left button is held.
    function mmove(e) {
      if (isPrimaryButtonDown(e)) {
        var p = canvasPos(e);
        draw(p.x, p.y);
      }
    }

    // mousedown handler: start a stroke (draws a dot at the pressed position).
    function mdown(e) {
      var p = canvasPos(e);
      draw(p.x, p.y);
    }

    // Draw a round-capped segment from the previous point (or from (x, y)
    // itself when a stroke is just starting) to (x, y), then remember (x, y).
    function draw(x, y) {
      ctx.beginPath();
      if (mouseX === null) {
        ctx.moveTo(x, y);
      } else {
        ctx.moveTo(mouseX, mouseY);
      }
      ctx.lineTo(x, y);
      ctx.lineCap = "round";
      ctx.lineWidth = lineWidth;
      ctx.strokeStyle = lineColor;
      ctx.stroke();
      mouseX = x;
      mouseY = y;
    }

    // mouseup / mouseout handler: end the current stroke.
    function mouseInit() {
      mouseX = null;
      mouseY = null;
    }

    // Fill the canvas with black and reset the preview image and the answer text.
    function canvasClear() {
      ctx.clearRect(0, 0, canvas.width, canvas.height);
      ctx.fillStyle = '#000000';
      ctx.fillRect(0, 0, canvas.width, canvas.height);
      $('#preview').attr('src', '');
      $('#answer').empty();
    }

    // Scale the drawing down to imgW x imgH on an off-screen canvas, show it
    // in the preview <img>, and return it as a JPEG data URL.
    function toImg() {
      var tmp = document.createElement('canvas');
      tmp.width = imgW;
      tmp.height = imgH;
      var tmpCtx = tmp.getContext('2d');
      tmpCtx.drawImage(canvas, 0, 0, canvas.width, canvas.height, 0, 0, imgW, imgH);
      var img = tmp.toDataURL('image/jpeg');
      $('#preview').attr('src', img);
      return img;
    }

    // POST the scaled image to the server and display the JSON answer.
    function saveImg() {
      var img = toImg();
      console.log(img);
      $.ajax({
        url: url,
        type: 'POST',
        data: {img: img},
        dataType: 'json',
        success: function(data) {
          if (data.status) {
            $('#answer').html('は、' + data.num + 'です');
          } else {
            $('#answer').html('は、分かりません');
          }
        },
      });
    }
  </script>
</body>
</html>
cgi-bin/mnist.py
#!/usr/bin/env python
"""CGI endpoint that classifies a posted hand-drawn digit image.

Reads the ``img`` form field (a data URL such as ``data:image/jpeg;base64,...``),
decodes it to a grayscale image, scales the pixel values by 1/255, and passes
the flattened array to ``mytensor.predict``.  The response is a JSON object
``{"status": <bool>, "num": <digit or false>}``.
"""
import sys
import os
import cgi
import json
import cgitb
cgitb.enable()
from PIL import Image
import numpy as np
from io import BytesIO
from binascii import a2b_base64
import mytensor

# Width and height fed to the model; mytensor reshapes its input to 28x28.
IMAGE_SIZE = (28, 28)


def _predict_digit(img_str):
    """Return the digit predicted for the data URL *img_str*, or None.

    None is returned when the field is missing or malformed, when the payload
    cannot be decoded as an image, or when ``mytensor.predict`` returns False
    (which it does when it finds no checkpoint).
    """
    if not isinstance(img_str, str) or ',' not in img_str:
        return None
    b64_str = img_str.split(',')[1]
    try:
        img = Image.open(BytesIO(a2b_base64(b64_str))).convert('L')
    except (ValueError, OSError):
        # binascii.Error is a ValueError; PIL raises OSError subclasses for
        # data it cannot identify as an image.
        return None
    if img.size != IMAGE_SIZE:
        # The client is expected to send 28x28 already; resize defensively so
        # the flattened array always has the length the model expects.
        img = img.resize(IMAGE_SIZE)
    img_arr = np.array(img).reshape(1, -1) / 255.0
    result = mytensor.predict(img_arr)
    # Compare by identity: 0 is a valid digit and must not be treated as failure.
    if result is False:
        return None
    return result


print('Content-Type: text/json; charset=utf-8')
print()

# .get() so the script answers with the "unknown" JSON instead of raising
# KeyError when REQUEST_METHOD is not set.
if os.environ.get('REQUEST_METHOD') == 'POST':
    data = cgi.FieldStorage()
    num = _predict_digit(data.getvalue('img', None))
    if num is not None:
        print(json.dumps({'status': True, 'num': num}))
        sys.exit()

print(json.dumps({'status': False, 'num': False}))
cgi-bin/mytensor.py
"""Two-layer convolutional network for MNIST, with train and predict helpers.

Written against the TensorFlow graph/session API (``tf.placeholder``,
``tf.Session``).  Running the module as a script trains the network and saves
a checkpoint; ``predict`` restores that checkpoint and classifies an image.
"""
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data


def _weight_variable(shape):
    """Return a variable of *shape* initialised from a truncated normal (stddev 0.1)."""
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def _bias_variable(shape):
    """Return a variable of *shape* initialised to the constant 0.1."""
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


def _conv2d(x, W):
    """Apply a stride-1, 'SAME'-padded 2-D convolution of *x* with filter *W*."""
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


class CNNModel():
    """Plain attribute container for the graph nodes returned by interface()."""
    pass


def interface():
    """Build the network in the current default graph and return its handles.

    The returned CNNModel carries these attributes:
      x          -- float32 placeholder, shape [None, 784] (flattened images)
      y_         -- float32 placeholder, shape [None, 10] (target labels)
      keep_prob  -- float32 placeholder for the dropout keep probability
      y_conv     -- output logits, shape [None, 10]
      train_step -- Adam (learning rate 1e-4) step minimising the cross entropy
      accuracy   -- mean of argmax(y_conv) == argmax(y_)
      saver      -- tf.train.Saver for the variables created here
    """
    x = tf.placeholder(tf.float32, shape=[None, 784])
    y_ = tf.placeholder(tf.float32, shape=[None, 10])
    x_image = tf.reshape(x, [-1, 28, 28, 1])

    # First convolution block: 5x5 filters, 1 -> 32 channels, then 2x2 max pool.
    W_conv1 = _weight_variable([5, 5, 1, 32])
    b_conv1 = _bias_variable([32])
    h_conv1 = tf.nn.relu(_conv2d(x_image, W_conv1) + b_conv1)
    h_pool1 = tf.nn.max_pool(h_conv1, ksize=[1, 2, 2, 1],
                             strides=[1, 2, 2, 1], padding='SAME')

    # Second convolution block: 5x5 filters, 32 -> 64 channels, then 2x2 max pool.
    W_conv2 = _weight_variable([5, 5, 32, 64])
    b_conv2 = _bias_variable([64])
    h_conv2 = tf.nn.relu(_conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = tf.nn.max_pool(h_conv2, ksize=[1, 2, 2, 1],
                             strides=[1, 2, 2, 1], padding='SAME')

    # Fully connected layer on the flattened 7x7x64 feature map, with dropout.
    W_fc1 = _weight_variable([7 * 7 * 64, 1024])
    b_fc1 = _bias_variable([1024])
    h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
    keep_prob = tf.placeholder(tf.float32)
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    # Output layer: 10 logits.
    W_fc2 = _weight_variable([1024, 10])
    b_fc2 = _bias_variable([10])
    y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

    # Keyword arguments are required by TensorFlow >= 1.0, which rejects the
    # positional form, and keep the logits/labels roles unambiguous.
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
    train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
    correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    saver = tf.train.Saver()

    model = CNNModel()
    model.x = x
    model.y_ = y_
    model.keep_prob = keep_prob
    model.y_conv = y_conv
    model.train_step = train_step
    model.accuracy = accuracy
    model.saver = saver
    return model


def predict(img):
    """Classify *img* with the checkpointed model.

    *img* is fed to the [None, 784] float32 input placeholder.  Returns the
    argmax over the logits as an int, or False when no checkpoint is found
    under './cgi-bin/ckpt'.  The checkpoint path is relative to the current
    working directory.

    NOTE: np.argmax is taken over the flattened logits, so the result is a
    class index only when a single image is passed.
    """
    ckpt = tf.train.get_checkpoint_state('./cgi-bin/ckpt')
    if not ckpt:
        return False
    # Build into a fresh graph so repeated calls in one process do not pile up
    # renamed duplicate variables that the checkpoint cannot be restored into.
    with tf.Graph().as_default():
        m = interface()
        with tf.Session() as sess:
            m.saver.restore(sess, ckpt.model_checkpoint_path)
            result = sess.run(m.y_conv,
                              feed_dict={m.x: img, m.keep_prob: 1.0})
    return int(np.argmax(result))


def train():
    """Train on MNIST for 20000 steps of batch size 100 and save a checkpoint.

    Does nothing except print 'train ok' when a checkpoint already exists
    under './ckpt'.  Otherwise the data set is read from './mnist', the
    training-batch accuracy is printed every 100 steps, and the model is saved
    to './ckpt/model.ckpt'.  Both paths are relative to the current working
    directory.
    """
    if tf.train.get_checkpoint_state('./ckpt'):
        print('train ok')
        return
    mnist = input_data.read_data_sets('./mnist', one_hot=True, dtype=tf.float32)
    m = interface()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(20000):
            batch = mnist.train.next_batch(100)
            # Dropout is active (keep_prob 0.5) only for the training step.
            m.train_step.run(feed_dict={m.x: batch[0], m.y_: batch[1],
                                        m.keep_prob: 0.5})
            if i % 100 == 0:
                train_accuracy = m.accuracy.eval(
                    feed_dict={m.x: batch[0], m.y_: batch[1],
                               m.keep_prob: 1.0})
                print("step %d, training accuracy %g"%(i, train_accuracy))
        m.saver.save(sess, './ckpt/model.ckpt')
    print('train ok')


if __name__ == '__main__':
    train()