Optical character recognition

    This starts off with 4 letter words. For the first 12 epochs, the difficulty is gradually increased using the TextImageGenerator class, which is both a generator class for test/train data and a Keras callback class. After 20 epochs, longer sequences are thrown at it by recompiling the model to handle a wider image and rebuilding the word list to include two words separated by a space.

    Additional dependencies

    This requires the cairo (imported below through cairocffi) and editdistance packages:

    Then install Python dependencies:

    1. import os
    2. import itertools
    3. import codecs
    4. import re
    5. import datetime
    6. import cairocffi as cairo
    7. import editdistance
    8. import numpy as np
    9. from scipy import ndimage
    10. import pylab
    11. from keras import backend as K
    12. from keras.layers.convolutional import Conv2D, MaxPooling2D
    13. from keras.layers import Input, Dense, Activation
    14. from keras.layers import Reshape, Lambda
    15. from keras.layers.merge import add, concatenate
    16. from keras.models import Model
    17. from keras.layers.recurrent import GRU
    18. from keras.optimizers import SGD
    19. from keras.utils.data_utils import get_file
    20. from keras.preprocessing import image
    21. import keras.callbacks
    # Root directory; VizCallback writes weights and preview images to
    # OUTPUT_DIR/<run_name>, and train() reloads weights from there.
    OUTPUT_DIR = 'image_ocr'

    # character classes and matching regex filter
    regex = r'^[a-z ]+$'
    # A character's class id is its index in this string; the CTC blank
    # uses the next id, len(alphabet).
    alphabet = u'abcdefghijklmnopqrstuvwxyz '

    # Seed NumPy's global random generator with a fixed value.
    np.random.seed(55)
    def speckle(img):
        """Add blotchy noise to a greyscale image.

        Gaussian noise of a randomly drawn strength is smoothed with a
        Gaussian filter, which creates larger "blotches" that look more
        realistic than plain per-pixel Gaussian noise. Assumes greyscale
        with pixels ranging from 0 to 1; the result is limited to [0, 1].

        # Arguments
            img: array holding the image.

        # Returns
            A new array of the same shape; `img` is not modified.
        """
        noise_scale = np.random.uniform(0, 0.6)
        raw_noise = np.random.randn(*img.shape) * noise_scale
        blotches = ndimage.gaussian_filter(raw_noise, 1)
        return np.clip(img + blotches, 0, 1)
    def paint_text(text, w, h, rotate=False, ud=False, multi_fonts=False):
        """Paint `text` on a white `w` x `h` canvas and return it as an array.

        The string is painted at a random location inside the bounding
        box, optionally with a random font, a slight random rotation,
        and always with a random amount of speckle noise.

        # Arguments
            text: string to paint.
            w: canvas width in pixels.
            h: canvas height in pixels.
            rotate: if True, apply a small random rotation.
            ud: if True, the vertical position is random as well;
                otherwise the text is anchored at `h // 2`.
            multi_fonts: if True, a random font face and weight are used;
                otherwise bold Courier.

        # Returns
            Array of shape `(1, h, w)` with values in [0, 1].

        # Raises
            IOError: if the text extents exceed the canvas minus the border.
        """
        surface = cairo.ImageSurface(cairo.FORMAT_RGB24, w, h)
        with cairo.Context(surface) as context:
            context.set_source_rgb(1, 1, 1)  # White
            context.paint()
            # this font list works in CentOS 7
            if multi_fonts:
                fonts = [
                    'Century Schoolbook', 'Courier', 'STIX',
                    'URW Chancery L', 'FreeMono']
                context.select_font_face(
                    np.random.choice(fonts),
                    cairo.FONT_SLANT_NORMAL,
                    np.random.choice([cairo.FONT_WEIGHT_BOLD,
                                      cairo.FONT_WEIGHT_NORMAL]))
            else:
                context.select_font_face('Courier',
                                         cairo.FONT_SLANT_NORMAL,
                                         cairo.FONT_WEIGHT_BOLD)
            context.set_font_size(25)
            # box[2] / box[3] are used as the rendered text width / height
            box = context.text_extents(text)
            border_w_h = (4, 4)
            # index 0 is the horizontal border, index 1 the vertical one,
            # matching how the shifts are computed below
            if (box[2] > (w - 2 * border_w_h[0]) or
                    box[3] > (h - 2 * border_w_h[1])):
                raise IOError(('Could not fit string into image. '
                               'Max char count is too large for given image width.'))

            # teach the RNN translational invariance by
            # fitting text box randomly on canvas, with some room to rotate
            max_shift_x = w - box[2] - border_w_h[0]
            max_shift_y = h - box[3] - border_w_h[1]
            top_left_x = np.random.randint(0, int(max_shift_x))
            if ud:
                top_left_y = np.random.randint(0, int(max_shift_y))
            else:
                top_left_y = h // 2
            context.move_to(top_left_x - int(box[0]), top_left_y - int(box[1]))
            context.set_source_rgb(0, 0, 0)
            context.show_text(text)

        buf = surface.get_data()
        a = np.frombuffer(buf, np.uint8)
        a.shape = (h, w, 4)
        a = a[:, :, 0]  # grab single channel
        a = a.astype(np.float32) / 255
        a = np.expand_dims(a, 0)
        if rotate:
            # rotation range depends on where the text starts horizontally
            a = image.random_rotation(a, 3 * (w - top_left_x) / w + 1)
        a = speckle(a)

        return a
    def shuffle_mats_or_lists(matrix_list, stop_ind=None):
        """Reorder several equal-length containers with one shared permutation.

        Only the first `stop_ind` positions are shuffled; positions from
        `stop_ind` onwards keep their original order. The same index
        order is applied to every container, so rows stay aligned.

        # Arguments
            matrix_list: list of numpy arrays and/or Python lists, all of
                the same length.
            stop_ind: number of leading entries to shuffle; defaults to
                the full length.

        # Returns
            List of reordered copies, one per input container.

        # Raises
            TypeError: if a container is neither a numpy array nor a list.
        """
        total = len(matrix_list[0])
        assert all(len(container) == total for container in matrix_list)
        if stop_ind is None:
            stop_ind = total
        assert stop_ind <= total

        head = list(range(stop_ind))
        np.random.shuffle(head)
        order = head + list(range(stop_ind, total))

        def _reorder(container):
            if isinstance(container, np.ndarray):
                return container[order]
            if isinstance(container, list):
                return [container[k] for k in order]
            raise TypeError('`shuffle_mats_or_lists` only supports '
                            'numpy.array and list objects.')

        return [_reorder(container) for container in matrix_list]
    def text_to_labels(text):
        """Translate characters to unique integer values.

        Each character maps to its index in `alphabet`; a character that
        is not in `alphabet` maps to -1 (the `str.find` miss value).
        """
        return [alphabet.find(char) for char in text]
    def labels_to_text(labels):
        """Reverse translation of numerical classes back to characters.

        The class equal to `len(alphabet)` is the CTC blank and adds
        nothing to the output; every other class is looked up in
        `alphabet`.
        """
        ctc_blank = len(alphabet)
        return "".join(alphabet[c] for c in labels if c != ctc_blank)
    def is_valid_str(in_str):
        """Return True when `in_str` matches the module-level `regex`.

        Only a-z and space are accepted; it is probably not too difficult
        to expand this to uppercase and symbols.
        """
        pattern = re.compile(regex, re.UNICODE)
        return pattern.search(in_str) is not None
    # Uses generator functions to supply train/test with
    # data. Image renderings and text are created on the fly
    # each time with random perturbations
    class TextImageGenerator(keras.callbacks.Callback):
        """Data generator and curriculum callback for the OCR model.

        The instance serves two roles: `next_train` / `next_val` are
        generators yielding `(inputs, outputs)` batches, and the Keras
        callback hooks rebuild the word list and rebind the paint
        function as training progresses.

        # Arguments
            monogram_file: path to the single-word list (one word per line).
            bigram_file: path to the word-pair list (first two columns
                of each line are used).
            minibatch_size: number of samples per batch.
            img_w: image width in pixels (the time axis fed to the RNN).
            img_h: image height in pixels.
            downsample_factor: divisor applied to `img_w` to obtain the
                sequence length reported as `input_length`.
            val_split: index of the first validation word; words in
                `[0, val_split)` are served by `next_train`, words in
                `[val_split, num_words)` by `next_val`.
            absolute_max_string_len: width of each label row.
        """

        def __init__(self, monogram_file, bigram_file, minibatch_size,
                     img_w, img_h, downsample_factor, val_split,
                     absolute_max_string_len=16):
            super(TextImageGenerator, self).__init__()
            self.minibatch_size = minibatch_size
            self.img_w = img_w
            self.img_h = img_h
            self.monogram_file = monogram_file
            self.bigram_file = bigram_file
            self.downsample_factor = downsample_factor
            self.val_split = val_split
            self.blank_label = self.get_output_size() - 1
            self.absolute_max_string_len = absolute_max_string_len

        def get_output_size(self):
            """Return the number of output classes: the alphabet plus CTC blank."""
            return len(alphabet) + 1

        # num_words can be independent of the epoch size due to the use of generators
        # as max_string_len grows, num_words can grow
        def build_word_list(self, num_words, max_string_len=None, mono_fraction=0.5):
            """Load `num_words` strings and their label encodings.

            # Arguments
                num_words: total number of strings; must be a multiple of
                    the minibatch size.
                max_string_len: longest accepted string. `None` or -1
                    means no explicit limit, in which case strings are
                    still capped at `absolute_max_string_len` so they fit
                    a label row.
                mono_fraction: fraction of `num_words` taken from the
                    monogram file; the rest comes from the bigram file.

            # Raises
                IOError: if the two files do not provide enough strings.
            """
            no_limit = max_string_len is None or max_string_len == -1
            assert no_limit or max_string_len <= self.absolute_max_string_len
            assert num_words % self.minibatch_size == 0
            assert (self.val_split * num_words) % self.minibatch_size == 0
            self.num_words = num_words
            self.string_list = [''] * self.num_words
            tmp_string_list = []
            self.max_string_len = max_string_len
            self.Y_data = np.ones([self.num_words, self.absolute_max_string_len]) * -1
            self.X_text = []
            self.Y_len = [0] * self.num_words

            # every stored word must fit into one row of Y_data
            if no_limit:
                length_cap = self.absolute_max_string_len
            else:
                length_cap = max_string_len

            def _is_length_of_word_valid(word):
                return len(word) <= length_cap

            # monogram file is sorted by frequency in english speech
            mono_target = int(self.num_words * mono_fraction)
            with codecs.open(self.monogram_file, mode='r', encoding='utf-8') as f:
                for line in f:
                    if len(tmp_string_list) == mono_target:
                        break
                    word = line.rstrip()
                    if _is_length_of_word_valid(word):
                        tmp_string_list.append(word)

            # bigram file contains common word pairings in english speech
            with codecs.open(self.bigram_file, mode='r', encoding='utf-8') as f:
                for line in f:
                    if len(tmp_string_list) == self.num_words:
                        break
                    columns = line.lower().split()
                    if len(columns) < 2:
                        # blank or malformed line: no word pair to read
                        continue
                    word = columns[0] + ' ' + columns[1]
                    if is_valid_str(word) and _is_length_of_word_valid(word):
                        tmp_string_list.append(word)
            if len(tmp_string_list) != self.num_words:
                raise IOError('Could not pull enough words '
                              'from supplied monogram and bigram files.')
            # interlace to mix up the easy and hard words
            self.string_list[::2] = tmp_string_list[:self.num_words // 2]
            self.string_list[1::2] = tmp_string_list[self.num_words // 2:]

            for i, word in enumerate(self.string_list):
                self.Y_len[i] = len(word)
                self.Y_data[i, 0:len(word)] = text_to_labels(word)
                self.X_text.append(word)
            self.Y_len = np.expand_dims(np.array(self.Y_len), 1)

            self.cur_val_index = self.val_split
            self.cur_train_index = 0

        # each time an image is requested from train/val/test, a new random
        # painting of the text is performed
        def get_batch(self, index, size, train):
            """Render `size` words starting at `index` into one batch.

            When `train` is true, the last three samples of the batch are
            replaced by blank images labelled with the CTC blank.

            # Returns
                Tuple `(inputs, outputs)` of dicts keyed by the model's
                input and output names.
            """
            channels_first = K.image_data_format() == 'channels_first'
            # width and height are backwards from typical Keras convention
            # because width is the time dimension when it gets fed into the RNN
            if channels_first:
                X_data = np.ones([size, 1, self.img_w, self.img_h])
            else:
                X_data = np.ones([size, self.img_w, self.img_h, 1])

            labels = np.ones([size, self.absolute_max_string_len])
            input_length = np.zeros([size, 1])
            label_length = np.zeros([size, 1])
            source_str = []
            # identical for every sample, so computed once
            sequence_len = self.img_w // self.downsample_factor - 2

            def _store_image(row, text):
                # paint_func output is indexed [0, :, :] and transposed so
                # that width becomes the leading (time) axis
                rendered = self.paint_func(text)[0, :, :].T
                if channels_first:
                    X_data[row, 0, 0:self.img_w, :] = rendered
                else:
                    X_data[row, 0:self.img_w, :, 0] = rendered

            for i in range(size):
                # Mix in some blank inputs.  This seems to be important for
                # achieving translational invariance
                if train and i > size - 4:
                    _store_image(i, '')
                    labels[i, 0] = self.blank_label
                    input_length[i] = sequence_len
                    label_length[i] = 1
                    source_str.append('')
                else:
                    _store_image(i, self.X_text[index + i])
                    labels[i, :] = self.Y_data[index + i]
                    input_length[i] = sequence_len
                    label_length[i] = self.Y_len[index + i]
                    source_str.append(self.X_text[index + i])
            inputs = {'the_input': X_data,
                      'the_labels': labels,
                      'input_length': input_length,
                      'label_length': label_length,
                      'source_str': source_str  # used for visualization only
                      }
            outputs = {'ctc': np.zeros([size])}  # dummy data for dummy loss function
            return (inputs, outputs)

        def next_train(self):
            """Yield training batches forever, reshuffling after each pass."""
            while 1:
                ret = self.get_batch(self.cur_train_index,
                                     self.minibatch_size, train=True)
                self.cur_train_index += self.minibatch_size
                if self.cur_train_index >= self.val_split:
                    # wrap around using the configured batch size
                    self.cur_train_index = (self.cur_train_index %
                                            self.minibatch_size)
                    # shuffle only the training part; validation words stay put
                    (self.X_text, self.Y_data, self.Y_len) = shuffle_mats_or_lists(
                        [self.X_text, self.Y_data, self.Y_len], self.val_split)
                yield ret

        def next_val(self):
            """Yield validation batches forever, cycling over the validation words."""
            while 1:
                ret = self.get_batch(self.cur_val_index,
                                     self.minibatch_size, train=False)
                self.cur_val_index += self.minibatch_size
                if self.cur_val_index >= self.num_words:
                    self.cur_val_index = (self.val_split +
                                          self.cur_val_index % self.minibatch_size)
                yield ret

        def _set_paint_func(self, rotate, ud, multi_fonts):
            # Bind paint_text to this generator's image size and the given
            # perturbation switches.
            self.paint_func = lambda text: paint_text(
                text, self.img_w, self.img_h,
                rotate=rotate, ud=ud, multi_fonts=multi_fonts)

        def on_train_begin(self, logs=None):
            """Start with 16000 short words drawn only from the monogram file."""
            self.build_word_list(16000, 4, 1)
            self._set_paint_func(rotate=False, ud=False, multi_fonts=False)

        def on_epoch_begin(self, epoch, logs=None):
            """Raise the difficulty according to the epoch number."""
            # rebind the paint function to implement curriculum learning
            if 3 <= epoch < 6:
                self._set_paint_func(rotate=False, ud=True, multi_fonts=False)
            elif 6 <= epoch < 9:
                self._set_paint_func(rotate=False, ud=True, multi_fonts=True)
            elif epoch >= 9:
                self._set_paint_func(rotate=True, ud=True, multi_fonts=True)
            # a word list built without a length limit (None) is left as is
            if (epoch >= 21 and self.max_string_len is not None and
                    self.max_string_len < 12):
                self.build_word_list(32000, 12, 0.5)
    def ctc_lambda_func(args):
        """Compute the CTC batch cost inside a Lambda layer.

        The actual loss calculation occurs here despite it not being an
        internal Keras loss function.

        # Arguments
            args: sequence `(y_pred, labels, input_length, label_length)`.

        # Returns
            The result of `K.ctc_batch_cost` on the trimmed predictions.
        """
        softmax_out, labels, input_length, label_length = args
        # the 2 is critical here since the first couple outputs of the RNN
        # tend to be garbage:
        trimmed = softmax_out[:, 2:, :]
        return K.ctc_batch_cost(labels, trimmed, input_length, label_length)
    def decode_batch(test_func, word_batch):
        """Decode a batch of images into strings with best-path decoding.

        For a real OCR application, this should be beam search with a
        dictionary and language model. For this example, best path is
        sufficient.

        # Arguments
            test_func: callable taking `[word_batch]` and returning a
                sequence whose first element holds the per-step class
                activations.
            word_batch: batch of input images.

        # Returns
            List with one decoded string per sample.
        """
        activations = test_func([word_batch])[0]

        def _decode_one(sample_idx):
            # the first two time steps are skipped, as in ctc_lambda_func
            best_path = list(np.argmax(activations[sample_idx, 2:], 1))
            # collapse runs of repeated labels into a single label
            collapsed = [label for label, _ in itertools.groupby(best_path)]
            return labels_to_text(collapsed)

        return [_decode_one(j) for j in range(activations.shape[0])]
    class VizCallback(keras.callbacks.Callback):
        """Save weights, report edit distances and plot samples after each epoch.

        # Arguments
            run_name: sub-directory of `OUTPUT_DIR` used for all output.
            test_func: callable used by `decode_batch` to obtain the
                softmax activations for a batch of images.
            text_img_gen: generator yielding `(inputs, outputs)` batches.
            num_display_words: number of samples drawn in the preview image.
        """

        def __init__(self, run_name, test_func, text_img_gen, num_display_words=6):
            super(VizCallback, self).__init__()
            # stored because show_edit_distance and on_epoch_end decode with it
            self.test_func = test_func
            self.output_dir = os.path.join(
                OUTPUT_DIR, run_name)
            self.text_img_gen = text_img_gen
            self.num_display_words = num_display_words
            if not os.path.exists(self.output_dir):
                os.makedirs(self.output_dir)

        def show_edit_distance(self, num):
            """Print mean and mean normalized edit distance over `num` samples."""
            num_left = num
            mean_norm_ed = 0.0
            mean_ed = 0.0
            while num_left > 0:
                word_batch = next(self.text_img_gen)[0]
                num_proc = min(word_batch['the_input'].shape[0], num_left)
                decoded_res = decode_batch(self.test_func,
                                           word_batch['the_input'][0:num_proc])
                for j in range(num_proc):
                    truth = word_batch['source_str'][j]
                    edit_dist = editdistance.eval(decoded_res[j], truth)
                    mean_ed += float(edit_dist)
                    # an empty ground-truth string would divide by zero
                    mean_norm_ed += float(edit_dist) / max(len(truth), 1)
                num_left -= num_proc
            mean_norm_ed = mean_norm_ed / num
            mean_ed = mean_ed / num
            print('\nOut of %d samples: Mean edit distance: '
                  '%.3f Mean normalized edit distance: %0.3f'
                  % (num, mean_ed, mean_norm_ed))

        def on_epoch_end(self, epoch, logs=None):
            """Write `weightsNN.h5` and `eNN.png` for the finished epoch."""
            self.model.save_weights(
                os.path.join(self.output_dir, 'weights%02d.h5' % (epoch)))
            self.show_edit_distance(256)
            word_batch = next(self.text_img_gen)[0]
            res = decode_batch(self.test_func,
                               word_batch['the_input'][0:self.num_display_words])
            # narrow images are laid out in two columns, wide ones in one
            if word_batch['the_input'][0].shape[0] < 256:
                cols = 2
            else:
                cols = 1
            channels_first = K.image_data_format() == 'channels_first'
            for i in range(self.num_display_words):
                pylab.subplot(self.num_display_words // cols, cols, i + 1)
                if channels_first:
                    the_input = word_batch['the_input'][i, 0, :, :]
                else:
                    the_input = word_batch['the_input'][i, :, :, 0]
                pylab.imshow(the_input.T, cmap='Greys_r')
                pylab.xlabel(
                    'Truth = \'%s\'\nDecoded = \'%s\'' %
                    (word_batch['source_str'][i], res[i]))
            fig = pylab.gcf()
            fig.set_size_inches(10, 13)
            pylab.savefig(os.path.join(self.output_dir, 'e%02d.png' % (epoch)))
            pylab.close()
    def train(run_name, start_epoch, stop_epoch, img_w):
        """Build the OCR model and train it from `start_epoch` to `stop_epoch`.

        # Arguments
            run_name: sub-directory of `OUTPUT_DIR` used for weights and
                preview images of this run.
            start_epoch: first epoch to run; when greater than 0 the
                weights saved for epoch `start_epoch - 1` are loaded first.
            stop_epoch: value passed as `epochs` to `fit_generator`.
            img_w: input image width in pixels.
        """
        # Input Parameters
        img_h = 64
        words_per_epoch = 16000
        val_split = 0.2
        val_words = int(words_per_epoch * (val_split))

        # Network parameters
        conv_filters = 16
        kernel_size = (3, 3)
        pool_size = 2
        time_dense_size = 32
        rnn_size = 512
        minibatch_size = 32

        if K.image_data_format() == 'channels_first':
            input_shape = (1, img_w, img_h)
        else:
            input_shape = (img_w, img_h, 1)

        # fetch the word list archive and keep the directory it lands in
        fdir = os.path.dirname(
            get_file('wordlists.tgz',
                     origin='http://www.mythic-ai.com/datasets/wordlists.tgz',
                     untar=True))

        # val_split is passed as an index: the first validation word
        img_gen = TextImageGenerator(
            monogram_file=os.path.join(fdir, 'wordlist_mono_clean.txt'),
            bigram_file=os.path.join(fdir, 'wordlist_bi_clean.txt'),
            minibatch_size=minibatch_size,
            img_w=img_w,
            img_h=img_h,
            downsample_factor=(pool_size ** 2),
            val_split=words_per_epoch - val_words)
        act = 'relu'
        input_data = Input(name='the_input', shape=input_shape, dtype='float32')
        inner = Conv2D(conv_filters, kernel_size, padding='same',
                       activation=act, kernel_initializer='he_normal',
                       name='conv1')(input_data)
        inner = MaxPooling2D(pool_size=(pool_size, pool_size), name='max1')(inner)
        inner = Conv2D(conv_filters, kernel_size, padding='same',
                       activation=act, kernel_initializer='he_normal',
                       name='conv2')(inner)
        inner = MaxPooling2D(pool_size=(pool_size, pool_size), name='max2')(inner)

        # two 2x2 poolings shrink both axes by pool_size ** 2; the height
        # and filter axes are folded into one feature axis per time step
        conv_to_rnn_dims = (img_w // (pool_size ** 2),
                            (img_h // (pool_size ** 2)) * conv_filters)
        inner = Reshape(target_shape=conv_to_rnn_dims, name='reshape')(inner)

        # cuts down input size going into RNN:
        inner = Dense(time_dense_size, activation=act, name='dense1')(inner)

        # Two layers of bidirectional GRUs
        # GRU seems to work as well, if not better than LSTM:
        gru_1 = GRU(rnn_size, return_sequences=True,
                    kernel_initializer='he_normal', name='gru1')(inner)
        gru_1b = GRU(rnn_size, return_sequences=True,
                     go_backwards=True, kernel_initializer='he_normal',
                     name='gru1_b')(inner)
        gru1_merged = add([gru_1, gru_1b])
        gru_2 = GRU(rnn_size, return_sequences=True,
                    kernel_initializer='he_normal', name='gru2')(gru1_merged)
        gru_2b = GRU(rnn_size, return_sequences=True, go_backwards=True,
                     kernel_initializer='he_normal', name='gru2_b')(gru1_merged)

        # transforms RNN output to character activations:
        inner = Dense(img_gen.get_output_size(), kernel_initializer='he_normal',
                      name='dense2')(concatenate([gru_2, gru_2b]))
        y_pred = Activation('softmax', name='softmax')(inner)
        # print a summary of the prediction-only model
        Model(inputs=input_data, outputs=y_pred).summary()

        labels = Input(name='the_labels',
                       shape=[img_gen.absolute_max_string_len], dtype='float32')
        input_length = Input(name='input_length', shape=[1], dtype='int64')
        label_length = Input(name='label_length', shape=[1], dtype='int64')
        # Keras doesn't currently support loss funcs with extra parameters
        # so CTC loss is implemented in a lambda layer
        loss_out = Lambda(
            ctc_lambda_func, output_shape=(1,),
            name='ctc')([y_pred, labels, input_length, label_length])

        # NOTE(review): the previous comment here said "clipnorm seems to
        # speed up convergence", but no clipnorm argument is passed to SGD
        # below -- confirm whether it was dropped on purpose.
        sgd = SGD(learning_rate=0.02,
                  decay=1e-6,
                  momentum=0.9,
                  nesterov=True)

        model = Model(inputs=[input_data, labels, input_length, label_length],
                      outputs=loss_out)

        # the loss calc occurs elsewhere, so use a dummy lambda func for the loss
        model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=sgd)
        if start_epoch > 0:
            # resume from the weights saved for the previous epoch of this run
            weight_file = os.path.join(
                OUTPUT_DIR,
                os.path.join(run_name, 'weights%02d.h5' % (start_epoch - 1)))
            model.load_weights(weight_file)
        # captures output of softmax so we can decode the output during visualization
        test_func = K.function([input_data], [y_pred])

        # the visualization callback gets its own validation generator
        viz_cb = VizCallback(run_name, test_func, img_gen.next_val())

        # img_gen is also passed as a callback so its curriculum hooks run
        model.fit_generator(
            generator=img_gen.next_train(),
            steps_per_epoch=(words_per_epoch - val_words) // minibatch_size,
            epochs=stop_epoch,
            validation_data=img_gen.next_val(),
            validation_steps=val_words // minibatch_size,
            callbacks=[viz_cb, img_gen],
            initial_epoch=start_epoch)
    if __name__ == '__main__':
        run_name = datetime.datetime.now().strftime('%Y:%m:%d:%H:%M:%S')
        # Each stage is (start_epoch, stop_epoch, img_w). The second stage
        # increases to wider images and starts at epoch 20; the learned
        # weights are reloaded.
        for start_epoch, stop_epoch, img_w in ((0, 20, 128), (20, 25, 512)):
            train(run_name, start_epoch, stop_epoch, img_w)