手書き文字を作れるJavascriptをつくってTensorFlowで予測させてみた
MNISTでテストしているだけだと味気ないので、ブラウザで数字を手書きして、それを予測するようにしてみました。JavascriptのCanvasでお絵かきアプリみたいのを作って、そこに手書きで数字を書いてボタン押したら、28x28に縮小して、Base64の状態でサーバに送ります。サーバでPythonが、numpyの配列にして、TensorFlowに渡して推測しております。

多分やり方は大体あってるのではないかと思ってるのですが、いかんせん精度が超悪いです。そもそもまだTensorFlowがCNNになっていないのですが、それでもMNISTで学習・テストすると97%位にはなっているものです。背景黒・文字色白で0-255の明るさを28x28もつ配列にしているので、形式は合っているはずなのですが、やはり文字の太さとか大きさとかそういうのによって、全然違うということなのかなと思っております。ブラウザで割と適当に数字を書いてもいい感じの精度で答えを言ってくるのかなと期待していたので残念です。とりあえず、今度CNNで対応してみて、それでもダメだったら、Javascriptの画像自体で訓練をしていくようにしてみようかなと思ってます。(何か根本的な原因がありそうな気もしますが)
GitHub
一応GitHubに入れておきました。誰か精度よくしてくれたら嬉しいです。
追記(2017/01/25) 上記のリポジトリは、「手書き文字を作れるJavascriptをつくってTensorFlowで予測させてみた(2)」の内容に更新しました。
PythonでWEBサーバ起動して、必要なディレクトリ・ファイルを作成
$ mkdir mnist_test$ cd mnist_test$ python -m http.server --cgi$ touch index.html$ mkdir cgi-bin$ touch cgi-bin/mnist.py$ touch cgi-bin/mytensor.py$ mkdir cgi-bin/mnist$ mkdir cgi-bin/ckpt各ファイルの中身を作成
index.html
<html><head><title>MNIST TEST</title></head><body><h1>MNIST TEST</h1><canvas id="canvas1" width="400" height="400" style="border: 1px solid #999;"></canvas><br><br><input id="clear" type="button" value="Clear" onclick="canvasClear();"><input id="submit" type="button" value="Submit" onclick="saveImg();"><br><br><img id="preview"><span id="answer"></span><script src="http://code.jquery.com/jquery-3.1.1.min.js"></script><script>var url = 'http://127.0.0.1:8000/cgi-bin/mnist.py';var lineWidth = 50;var lineColor = '#ffffff';var imgW = imgH = 28;
var canvas = document.getElementById('canvas1');var ctx = canvas.getContext('2d');var cleft = canvas.getBoundingClientRect().left;var ctop = canvas.getBoundingClientRect().top;var mouseX = mouseY = null;
canvasClear();canvas.addEventListener('mousemove', mmove, false);canvas.addEventListener('mousedown', mdown, false);canvas.addEventListener('mouseup', mouseInit, false);canvas.addEventListener('mouseout', mouseInit, false);
function mmove(e){ if (e.buttons == 1 || e.witch == 1) { draw(e.clientX - cleft, e.clientY - ctop); };}
function mdown(e){ draw(e.clientX - cleft, e.clientY - ctop);}
function draw(x, y){ ctx.beginPath(); if(mouseX === null) ctx.moveTo(x, y); else ctx.moveTo(mouseX, mouseY); ctx.lineTo(x, y); ctx.lineCap = "round"; ctx.lineWidth = lineWidth; ctx.strokeStyle = lineColor; ctx.stroke(); mouseX = x; mouseY = y;}
function mouseInit(){ mouseX = mouseY = null;}
function canvasClear(){ ctx.clearRect(0, 0, canvas.width, canvas.height); ctx.fillStyle = '#000000'; ctx.fillRect(0, 0, canvas.width, canvas.height); $('#preview').attr('src', ''); $('#answer').empty();}
function toImg(){ var tmp = document.createElement('canvas'); tmp.width = imgW; tmp.height = imgH; var tmpCtx = tmp.getContext('2d'); tmpCtx.drawImage(canvas, 0, 0, canvas.width, canvas.height, 0, 0, imgW, imgH); var img = tmp.toDataURL('image/jpeg'); $('#preview').attr('src', img) return img;}
function saveImg(){ var img = toImg(); console.log(img); $.ajax({ url: url, type: 'POST', data: {img: img}, dataType: 'json', success: function(data){ if(data.status){ $('#answer').html('は、' + data.num + 'です'); }else{ $('#answer').html('は、分かりません'); } }, });}</script></body></html>cgi-bin/mnist.py
#!/usr/bin/env python
import sysimport osimport cgiimport jsonimport cgitbcgitb.enable()
from PIL import Imageimport numpy as npfrom io import BytesIOfrom binascii import a2b_base64
from mytensor import MyTensor
import logging
print('Content-Type: text/json; charset=utf-8')print()
if os.environ['REQUEST_METHOD'] == 'POST': data = cgi.FieldStorage() img_str = data.getvalue('img', None) if img_str: b64_str = img_str.split(',')[1] img = Image.open(BytesIO(a2b_base64(b64_str))).convert('L') img_arr = np.array(img).reshape(1, -1) tf = MyTensor() result = tf.predict(img_arr) print(json.dumps({'status': True, 'num': result})) sys.exit()print(json.dumps({'status': False, 'num': False}))cgi-bin/mytensor.py
import tensorflow as tfimport numpy as npfrom tensorflow.examples.tutorials.mnist import input_datafrom PIL import Image
class MyTensor: H = 625 BATCH_SIZE = 100 DROP_OUT_RATE = 0.5
def __init__(self): self.x = tf.placeholder(tf.float32, [None, 784]) self.t = tf.placeholder(tf.float32, [None, 10]) self.w1 = tf.Variable(tf.random_normal([784, self.H], mean=0.0, stddev=0.05)) self.b1 = tf.Variable(tf.zeros([self.H])) self.w2 = tf.Variable(tf.random_normal([self.H, self.H], mean=0.0, stddev=0.05)) self.b2 = tf.Variable(tf.zeros([self.H])) self.w3 = tf.Variable(tf.random_normal([self.H, 10], mean=0.0, stddev=0.05)) self.b3 = tf.Variable(tf.zeros([10]))
self.a1 = tf.sigmoid(tf.matmul(self.x, self.w1) + self.b1) self.a2 = tf.sigmoid(tf.matmul(self.a1, self.w2) + self.b2) self.keep_prob = tf.placeholder(tf.float32) self.drop = tf.nn.dropout(self.a2, self.keep_prob) self.y = tf.nn.relu(tf.matmul(self.drop, self.w3) + self.b3) self.loss = tf.nn.l2_loss(self.y - self.t) / self.BATCH_SIZE
self.train_step = tf.train.AdamOptimizer(1e-4).minimize(self.loss) self.correct = tf.equal(tf.argmax(self.y, 1), tf.argmax(self.t, 1)) self.accuracy = tf.reduce_mean(tf.cast(self.correct, tf.float32)) self.saver = tf.train.Saver() if not self.ckpt(): self.train()
def ckpt(self): return tf.train.get_checkpoint_state('.\cgi-bin\ckpt')
def predict(self, img): with tf.Session() as sess: self.saver.restore(sess, self.ckpt().model_checkpoint_path) result = sess.run(self.y, feed_dict={self.x: img, self.keep_prob:1.0}) return int(np.argmax(result))
def train(self): mnist = input_data.read_data_sets('.\cgi-bin\mnist', one_hot=True, dtype=tf.uint8) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for _ in range(20000): batch_x, batch_t = mnist.train.next_batch(100) sess.run(self.train_step, feed_dict={self.x: batch_x, self.t: batch_t, self.keep_prob:(1-self.DROP_OUT_RATE)}) self.saver.save(sess, '.\cgi-bin\ckpt\model.ckpt')ブラウザでアクセス
下記URLにアクセスすると、MNIST TESTというお絵かきページみたいのがでてきます。ちなみに、最初にsubmitボタン押すと、MNISTのダウンロードから学習までをするので、時間かかるし、結果も表示されません。2回目からはちゃんと結果が(一応)表示されます。
http://127.0.0.1:8000