edo1z blog

プログラミングなどに関するブログです

手書き文字を作れるJavascriptをつくってTensorFlowで予測させてみた

MNISTでテストしているだけだと味気ないので、ブラウザで数字を手書きして、それを予測するようにしてみました。JavascriptのCanvasでお絵かきアプリみたいのを作って、そこに手書きで数字を書いてボタン押したら、28x28に縮小して、Base64の状態でサーバに送ります。サーバでPythonが、numpyの配列にして、TensorFlowに渡して推測しております。

多分やり方は大体あってるのではないかと思ってるのですが、いかんせん精度が超悪いです。そもそもまだTensorFlowがCNNになっていないのですが、それでもMNISTで学習・テストすると97%位にはなっているものです。背景黒・文字色白で0-255の明るさを28x28もつ配列にしているので、形式は合っているはずなのですが、やはり文字の太さとか大きさとかそういうのによって、全然違うということなのかなと思っております。ブラウザで割と適当に数字を書いてもいい感じの精度で答えを言ってくるのかなと期待していたので残念です。とりあえず、今度CNNで対応してみて、それでもダメだったら、Javascriptの画像自体で訓練をしていくようにしてみようかなと思ってます。(何か根本的な原因がありそうな気もしますが)

GitHub

一応GitHubに入れておきました。誰か精度よくしてくれたら嬉しいです。

追記(2017/01/25) 上記のリポジトリは、「手書き文字を作れるJavascriptをつくってTensorFlowで予測させてみた(2)」の内容に更新しました。

PythonでWEBサーバ起動して、必要なディレクトリ・ファイルを作成

$ mkdir mnist_test
$ cd mnist_test
$ python -m http.server --cgi
$ touch index.html
$ mkdir cgi-bin
$ touch cgi-bin/mnist.py
$ touch cgi-bin/mytensor.py
$ mkdir cgi-bin/mnist
$ mkdir cgi-bin/ckpt

各ファイルの中身を作成

index.html

<html>
<head>
<title>MNIST TEST</title>
</head>
<body>
<h1>MNIST TEST</h1>
<canvas id="canvas1" width="400" height="400" style="border: 1px solid #999;"></canvas><br><br>
<input id="clear" type="button" value="Clear" onclick="canvasClear();">
<input id="submit" type="button" value="Submit" onclick="saveImg();"><br><br>
<img id="preview"><span id="answer"></span>
<script src="http://code.jquery.com/jquery-3.1.1.min.js"></script>
<script>
var url = 'http://127.0.0.1:8000/cgi-bin/mnist.py';
var lineWidth = 50;
var lineColor = '#ffffff';
var imgW = imgH = 28;

var canvas = document.getElementById('canvas1');
var ctx = canvas.getContext('2d');
var cleft = canvas.getBoundingClientRect().left;
var ctop = canvas.getBoundingClientRect().top;
var mouseX = mouseY = null;

canvasClear();
canvas.addEventListener('mousemove', mmove, false);
canvas.addEventListener('mousedown', mdown, false);
canvas.addEventListener('mouseup', mouseInit, false);
canvas.addEventListener('mouseout', mouseInit, false);

function mmove(e){
    if (e.buttons == 1 || e.witch == 1) {
        draw(e.clientX - cleft, e.clientY - ctop);
    };
}

function mdown(e){
    draw(e.clientX - cleft, e.clientY - ctop);
}

function draw(x, y){
    ctx.beginPath();
    if(mouseX === null) ctx.moveTo(x, y);
    else ctx.moveTo(mouseX, mouseY);
    ctx.lineTo(x, y);
    ctx.lineCap = "round";
    ctx.lineWidth = lineWidth;
    ctx.strokeStyle = lineColor;
    ctx.stroke();
    mouseX = x;
    mouseY = y;
}

function mouseInit(){
    mouseX = mouseY = null;
}

function canvasClear(){
    ctx.clearRect(0, 0, canvas.width, canvas.height);
    ctx.fillStyle = '#000000';
    ctx.fillRect(0, 0, canvas.width, canvas.height);
    $('#preview').attr('src', '');
    $('#answer').empty();
}

function toImg(){
    var tmp = document.createElement('canvas');
    tmp.width = imgW;
    tmp.height = imgH;
    var tmpCtx = tmp.getContext('2d');
    tmpCtx.drawImage(canvas, 0, 0, canvas.width, canvas.height, 0, 0, imgW, imgH);
    var img = tmp.toDataURL('image/jpeg');
    $('#preview').attr('src', img)
    return img;
}

function saveImg(){
    var img = toImg();
    console.log(img);
    $.ajax({
        url: url,
        type: 'POST',
        data: {img: img},
        dataType: 'json',
        success: function(data){
            if(data.status){
                $('#answer').html('は、' + data.num + 'です');
            }else{
                $('#answer').html('は、分かりません');
            }
        },
    });
}
</script>
</body>
</html>

cgi-bin/mnist.py

#!/usr/bin/env python

import sys
import os
import cgi
import json
import cgitb
cgitb.enable()

from PIL import Image
import numpy as np
from io import BytesIO
from binascii import a2b_base64

from mytensor import MyTensor

import logging

print('Content-Type: text/json; charset=utf-8')
print()

if os.environ['REQUEST_METHOD'] == 'POST':
    data = cgi.FieldStorage()
    img_str = data.getvalue('img', None)
    if img_str:
        b64_str = img_str.split(',')[1]
        img = Image.open(BytesIO(a2b_base64(b64_str))).convert('L')
        img_arr = np.array(img).reshape(1, -1)
        tf = MyTensor()
        result = tf.predict(img_arr)
        print(json.dumps({'status': True, 'num': result}))
        sys.exit()
print(json.dumps({'status': False, 'num': False}))

cgi-bin/mytensor.py

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from PIL import Image

class MyTensor:
    H = 625
    BATCH_SIZE = 100
    DROP_OUT_RATE = 0.5

    def __init__(self):
        self.x = tf.placeholder(tf.float32, [None, 784])
        self.t = tf.placeholder(tf.float32, [None, 10])
        self.w1 = tf.Variable(tf.random_normal([784, self.H], mean=0.0, stddev=0.05))
        self.b1 = tf.Variable(tf.zeros([self.H]))
        self.w2 = tf.Variable(tf.random_normal([self.H, self.H], mean=0.0, stddev=0.05))
        self.b2 = tf.Variable(tf.zeros([self.H]))
        self.w3 = tf.Variable(tf.random_normal([self.H, 10], mean=0.0, stddev=0.05))
        self.b3 = tf.Variable(tf.zeros([10]))

        self.a1 = tf.sigmoid(tf.matmul(self.x, self.w1) + self.b1)
        self.a2 = tf.sigmoid(tf.matmul(self.a1, self.w2) + self.b2)
        self.keep_prob = tf.placeholder(tf.float32)
        self.drop = tf.nn.dropout(self.a2, self.keep_prob)
        self.y = tf.nn.relu(tf.matmul(self.drop, self.w3) + self.b3)
        self.loss = tf.nn.l2_loss(self.y - self.t) / self.BATCH_SIZE

        self.train_step = tf.train.AdamOptimizer(1e-4).minimize(self.loss)
        self.correct = tf.equal(tf.argmax(self.y, 1), tf.argmax(self.t, 1))
        self.accuracy = tf.reduce_mean(tf.cast(self.correct, tf.float32))
        self.saver = tf.train.Saver()
        if not self.ckpt():
            self.train()

    def ckpt(self):
        return tf.train.get_checkpoint_state('.\cgi-bin\ckpt')

    def predict(self, img):
        with tf.Session() as sess:
            self.saver.restore(sess, self.ckpt().model_checkpoint_path)
            result = sess.run(self.y, feed_dict={self.x: img, self.keep_prob:1.0})
            return int(np.argmax(result))

    def train(self):
        mnist = input_data.read_data_sets('.\cgi-bin\mnist', one_hot=True, dtype=tf.uint8)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for _ in range(20000):
                batch_x, batch_t = mnist.train.next_batch(100)
                sess.run(self.train_step, feed_dict={self.x: batch_x, self.t: batch_t, self.keep_prob:(1-self.DROP_OUT_RATE)})
            self.saver.save(sess, '.\cgi-bin\ckpt\model.ckpt')

ブラウザでアクセス

下記URLにアクセスすると、MNIST TESTというお絵かきページみたいのがでてきます。ちなみに、最初にsubmitボタン押すと、MNISTのダウンロードから学習までをするので、時間かかるし、結果も表示されません。2回目からはちゃんと結果が(一応)表示されます。

http://127.0.0.1:8000