Optical character recognition

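This example trains a convolutional stack followed by a recurrent stack and a CTC log-loss function to perform optical character recognition on synthetically generated text images. Word images are rendered with cairocffi from downloaded monogram/bigram word lists, perturbed with noise, and fed to the network by a generator; progress is reported as the edit distance (via the editdistance package) between decoded predictions and the ground truth. Beyond Keras and its backend, the script depends on cairocffi, editdistance, scipy, and matplotlib's pylab.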

    import os
    import itertools
    import codecs
    import re
    import datetime
    import cairocffi as cairo
    import editdistance
    import numpy as np
    from scipy import ndimage
    import pylab
    from keras import backend as K
    from keras.layers.convolutional import Conv2D, MaxPooling2D
    from keras.layers import Input, Dense, Activation
    from keras.layers import Reshape, Lambda
    from keras.layers.merge import add, concatenate
    from keras.models import Model
    from keras.layers.recurrent import GRU
    from keras.optimizers import SGD
    from keras.utils.data_utils import get_file
    from keras.preprocessing import image
    import keras.callbacks

    OUTPUT_DIR = 'image_ocr'

    # character classes and matching regex filter
    regex = r'^[a-z ]+$'
    alphabet = u'abcdefghijklmnopqrstuvwxyz '

    np.random.seed(55)

    # this creates larger "blotches" of noise which look
    # more realistic than just adding gaussian noise
    # assumes greyscale with pixels ranging from 0 to 1
    def speckle(img):
        severity = np.random.uniform(0, 0.6)
        blur = ndimage.gaussian_filter(np.random.randn(*img.shape) * severity, 1)
        img_speck = (img + blur)
        img_speck[img_speck > 1] = 1
        img_speck[img_speck <= 0] = 0
        return img_speck

    # paints the string in a random location within the bounding box
    # also uses a random font, a slight random rotation,
    # and a random amount of speckle noise
    def paint_text(text, w, h, rotate=False, ud=False, multi_fonts=False):
        surface = cairo.ImageSurface(cairo.FORMAT_RGB24, w, h)
        with cairo.Context(surface) as context:
            context.set_source_rgb(1, 1, 1)  # White
            context.paint()
            # this font list works in CentOS 7
            if multi_fonts:
                fonts = [
                    'Century Schoolbook', 'Courier', 'STIX',
                    'URW Chancery L', 'FreeMono']
                context.select_font_face(
                    np.random.choice(fonts),
                    cairo.FONT_SLANT_NORMAL,
                    np.random.choice([cairo.FONT_WEIGHT_BOLD,
                                      cairo.FONT_WEIGHT_NORMAL]))
            else:
                context.select_font_face('Courier',
                                         cairo.FONT_SLANT_NORMAL,
                                         cairo.FONT_WEIGHT_BOLD)
            context.set_font_size(25)
            box = context.text_extents(text)
            border_w_h = (4, 4)
            if box[2] > (w - 2 * border_w_h[1]) or box[3] > (h - 2 * border_w_h[0]):
                raise IOError(('Could not fit string into image. '
                               'Max char count is too large for given image width.'))
            # teach the RNN translational invariance by
            # fitting text box randomly on canvas, with some room to rotate
            max_shift_x = w - box[2] - border_w_h[0]
            max_shift_y = h - box[3] - border_w_h[1]
            top_left_x = np.random.randint(0, int(max_shift_x))
            if ud:
                top_left_y = np.random.randint(0, int(max_shift_y))
            else:
                top_left_y = h // 2
            context.move_to(top_left_x - int(box[0]), top_left_y - int(box[1]))
            context.set_source_rgb(0, 0, 0)
            context.show_text(text)
        buf = surface.get_data()
        a = np.frombuffer(buf, np.uint8)
        a.shape = (h, w, 4)
        a = a[:, :, 0]  # grab single channel
        a = a.astype(np.float32) / 255
        a = np.expand_dims(a, 0)
        if rotate:
            a = image.random_rotation(a, 3 * (w - top_left_x) / w + 1)
        a = speckle(a)
        return a
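
    # shuffles (in unison) the first `stop_ind` entries of each matrix or
    # list passed in, leaving indices >= stop_ind untouched; this is what
    # keeps the validation portion of the data fixed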
    def shuffle_mats_or_lists(matrix_list, stop_ind=None):
        ret = []
        assert all([len(i) == len(matrix_list[0]) for i in matrix_list])
        len_val = len(matrix_list[0])
        if stop_ind is None:
            stop_ind = len_val
        assert stop_ind <= len_val
        a = list(range(stop_ind))
        np.random.shuffle(a)
        a += list(range(stop_ind, len_val))
        for mat in matrix_list:
            if isinstance(mat, np.ndarray):
                ret.append(mat[a])
            elif isinstance(mat, list):
                ret.append([mat[i] for i in a])
            else:
                raise TypeError('`shuffle_mats_or_lists` only supports '
                                'numpy.array and list objects.')
        return ret

    # Translation of characters to unique integer values
    def text_to_labels(text):
        ret = []
        for char in text:
            ret.append(alphabet.find(char))
        return ret

    # Reverse translation of numerical classes back to characters
    def labels_to_text(labels):
        ret = []
        for c in labels:
            if c == len(alphabet):  # CTC Blank
                ret.append("")
            else:
                ret.append(alphabet[c])
        return "".join(ret)
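
    # round-trip example: text_to_labels('ab c') == [0, 1, 26, 2] and
    # labels_to_text([0, 1, 26, 2]) == 'ab c'; index 27 == len(alphabet)
    # is reserved for the CTC blank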

    # only a-z and space; probably not too difficult
    # to expand to uppercase and symbols
    def is_valid_str(in_str):
        search = re.compile(regex, re.UNICODE).search
        return bool(search(in_str))

    # Uses generator functions to supply train/test with
    # data. Image renderings and text are created on the fly
    # each time with random perturbations
    class TextImageGenerator(keras.callbacks.Callback):

        def __init__(self, monogram_file, bigram_file, minibatch_size,
                     img_w, img_h, downsample_factor, val_split,
                     absolute_max_string_len=16):
            self.minibatch_size = minibatch_size
            self.img_w = img_w
            self.img_h = img_h
            self.monogram_file = monogram_file
            self.bigram_file = bigram_file
            self.downsample_factor = downsample_factor
            self.val_split = val_split
            self.blank_label = self.get_output_size() - 1
            self.absolute_max_string_len = absolute_max_string_len

        def get_output_size(self):
            # every character in the alphabet plus the CTC blank
            return len(alphabet) + 1

        # num_words can be independent of the epoch size due to the use of generators
        # as max_string_len grows, num_words can grow
        def build_word_list(self, num_words, max_string_len=None, mono_fraction=0.5):
            assert max_string_len <= self.absolute_max_string_len
            assert num_words % self.minibatch_size == 0
            assert (self.val_split * num_words) % self.minibatch_size == 0
            self.num_words = num_words
            self.string_list = [''] * self.num_words
            tmp_string_list = []
            self.max_string_len = max_string_len
            self.Y_data = np.ones([self.num_words, self.absolute_max_string_len]) * -1
            self.X_text = []
            self.Y_len = [0] * self.num_words

            def _is_length_of_word_valid(word):
                return (max_string_len == -1 or
                        max_string_len is None or
                        len(word) <= max_string_len)

            # monogram file is sorted by frequency in English speech
            with codecs.open(self.monogram_file, mode='r', encoding='utf-8') as f:
                for line in f:
                    if len(tmp_string_list) == int(self.num_words * mono_fraction):
                        break
                    word = line.rstrip()
                    if _is_length_of_word_valid(word):
                        tmp_string_list.append(word)

            # bigram file contains common word pairings in English speech
            with codecs.open(self.bigram_file, mode='r', encoding='utf-8') as f:
                lines = f.readlines()
                for line in lines:
                    if len(tmp_string_list) == self.num_words:
                        break
                    columns = line.lower().split()
                    word = columns[0] + ' ' + columns[1]
                    if is_valid_str(word) and _is_length_of_word_valid(word):
                        tmp_string_list.append(word)
            if len(tmp_string_list) != self.num_words:
                raise IOError('Could not pull enough words '
                              'from supplied monogram and bigram files.')
            # interlace to mix up the easy and hard words
            self.string_list[::2] = tmp_string_list[:self.num_words // 2]
            self.string_list[1::2] = tmp_string_list[self.num_words // 2:]
            for i, word in enumerate(self.string_list):
                self.Y_len[i] = len(word)
                self.Y_data[i, 0:len(word)] = text_to_labels(word)
                self.X_text.append(word)
            self.Y_len = np.expand_dims(np.array(self.Y_len), 1)

            self.cur_val_index = self.val_split
            self.cur_train_index = 0

        # each time an image is requested from train/val/test, a new random
        # painting of the text is performed
        def get_batch(self, index, size, train):
            # width and height are backwards from typical Keras convention
            # because width is the time dimension when it gets fed into the RNN
            if K.image_data_format() == 'channels_first':
                X_data = np.ones([size, 1, self.img_w, self.img_h])
            else:
                X_data = np.ones([size, self.img_w, self.img_h, 1])

            labels = np.ones([size, self.absolute_max_string_len])
            input_length = np.zeros([size, 1])
            label_length = np.zeros([size, 1])
            source_str = []
            for i in range(size):
                # Mix in some blank inputs. This seems to be important for
                # achieving translational invariance
                if train and i > size - 4:
                    if K.image_data_format() == 'channels_first':
                        X_data[i, 0, 0:self.img_w, :] = self.paint_func('')[0, :, :].T
                    else:
                        X_data[i, 0:self.img_w, :, 0] = self.paint_func('')[0, :, :].T
                    labels[i, 0] = self.blank_label
                    input_length[i] = self.img_w // self.downsample_factor - 2
                    label_length[i] = 1
                    source_str.append('')
                else:
                    if K.image_data_format() == 'channels_first':
                        X_data[i, 0, 0:self.img_w, :] = (
                            self.paint_func(self.X_text[index + i])[0, :, :].T)
                    else:
                        X_data[i, 0:self.img_w, :, 0] = (
                            self.paint_func(self.X_text[index + i])[0, :, :].T)
                    labels[i, :] = self.Y_data[index + i]
                    input_length[i] = self.img_w // self.downsample_factor - 2
                    label_length[i] = self.Y_len[index + i]
                    source_str.append(self.X_text[index + i])
            inputs = {'the_input': X_data,
                      'the_labels': labels,
                      'input_length': input_length,
                      'label_length': label_length,
                      'source_str': source_str  # used for visualization only
                      }
            outputs = {'ctc': np.zeros([size])}  # dummy data for dummy loss function
            return (inputs, outputs)

        def next_train(self):
            while True:
                ret = self.get_batch(self.cur_train_index,
                                     self.minibatch_size, train=True)
                self.cur_train_index += self.minibatch_size
                if self.cur_train_index >= self.val_split:
                    self.cur_train_index = self.cur_train_index % 32
                    (self.X_text, self.Y_data, self.Y_len) = shuffle_mats_or_lists(
                        [self.X_text, self.Y_data, self.Y_len], self.val_split)
                yield ret

        def next_val(self):
            while True:
                ret = self.get_batch(self.cur_val_index,
                                     self.minibatch_size, train=False)
                self.cur_val_index += self.minibatch_size
                if self.cur_val_index >= self.num_words:
                    self.cur_val_index = self.val_split + self.cur_val_index % 32
                yield ret

        def on_train_begin(self, logs={}):
            self.build_word_list(16000, 4, 1)
            self.paint_func = lambda text: paint_text(
                text, self.img_w, self.img_h,
                rotate=False, ud=False, multi_fonts=False)
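
        # curriculum: epochs 0-2 use fixed-position monospace text, 3-5 add
        # vertical shifts, 6-8 add multiple fonts, 9+ add rotation; from
        # epoch 21 the word list grows to 32000 entries of up to 12 characters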
        def on_epoch_begin(self, epoch, logs={}):
            # rebind the paint function to implement curriculum learning
            if 3 <= epoch < 6:
                self.paint_func = lambda text: paint_text(
                    text, self.img_w, self.img_h,
                    rotate=False, ud=True, multi_fonts=False)
            elif 6 <= epoch < 9:
                self.paint_func = lambda text: paint_text(
                    text, self.img_w, self.img_h,
                    rotate=False, ud=True, multi_fonts=True)
            elif epoch >= 9:
                self.paint_func = lambda text: paint_text(
                    text, self.img_w, self.img_h,
                    rotate=True, ud=True, multi_fonts=True)
            if epoch >= 21 and self.max_string_len < 12:
                self.build_word_list(32000, 12, 0.5)

    # the actual loss calc occurs here despite it not being
    # an internal Keras loss function
    def ctc_lambda_func(args):
        y_pred, labels, input_length, label_length = args
        # the 2 is critical here since the first couple outputs of the RNN
        # tend to be garbage:
        y_pred = y_pred[:, 2:, :]
        return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
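
    # K.ctc_batch_cost takes the softmax output with shape
    # (samples, time_steps, num_categories) and returns one loss per sample;
    # cropping the first two steps is why get_batch sets
    # input_length = img_w // downsample_factor - 2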

    # For a real OCR application, this should be beam search with a dictionary
    # and language model. For this example, best path is sufficient.
    def decode_batch(test_func, word_batch):
        out = test_func([word_batch])[0]
        ret = []
        for j in range(out.shape[0]):
            out_best = list(np.argmax(out[j, 2:], 1))
            out_best = [k for k, g in itertools.groupby(out_best)]
            outstr = labels_to_text(out_best)
            ret.append(outstr)
        return ret
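
    # best-path decoding collapses repeats and then drops blanks: the
    # framewise argmax sequence [0, 0, 27, 0, 1] ('a', 'a', blank, 'a', 'b')
    # decodes to 'aab'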

    class VizCallback(keras.callbacks.Callback):

        def __init__(self, run_name, test_func, text_img_gen, num_display_words=6):
            self.test_func = test_func
            self.output_dir = os.path.join(
                OUTPUT_DIR, run_name)
            self.text_img_gen = text_img_gen
            self.num_display_words = num_display_words
            if not os.path.exists(self.output_dir):
                os.makedirs(self.output_dir)

        def show_edit_distance(self, num):
            num_left = num
            mean_norm_ed = 0.0
            mean_ed = 0.0
            while num_left > 0:
                word_batch = next(self.text_img_gen)[0]
                num_proc = min(word_batch['the_input'].shape[0], num_left)
                decoded_res = decode_batch(self.test_func,
                                           word_batch['the_input'][0:num_proc])
                for j in range(num_proc):
                    edit_dist = editdistance.eval(decoded_res[j],
                                                  word_batch['source_str'][j])
                    mean_ed += float(edit_dist)
                    mean_norm_ed += float(edit_dist) / len(word_batch['source_str'][j])
                num_left -= num_proc
            mean_norm_ed = mean_norm_ed / num
            mean_ed = mean_ed / num
            print('\nOut of %d samples: Mean edit distance: '
                  '%.3f Mean normalized edit distance: %0.3f'
                  % (num, mean_ed, mean_norm_ed))

        def on_epoch_end(self, epoch, logs={}):
            self.model.save_weights(
                os.path.join(self.output_dir, 'weights%02d.h5' % (epoch)))
            self.show_edit_distance(256)
            word_batch = next(self.text_img_gen)[0]
            res = decode_batch(self.test_func,
                               word_batch['the_input'][0:self.num_display_words])
            if word_batch['the_input'][0].shape[0] < 256:
                cols = 2
            else:
                cols = 1
            for i in range(self.num_display_words):
                pylab.subplot(self.num_display_words // cols, cols, i + 1)
                if K.image_data_format() == 'channels_first':
                    the_input = word_batch['the_input'][i, 0, :, :]
                else:
                    the_input = word_batch['the_input'][i, :, :, 0]
                pylab.imshow(the_input.T, cmap='Greys_r')
                pylab.xlabel(
                    'Truth = \'%s\'\nDecoded = \'%s\'' %
                    (word_batch['source_str'][i], res[i]))
            fig = pylab.gcf()
            fig.set_size_inches(10, 13)
            pylab.savefig(os.path.join(self.output_dir, 'e%02d.png' % (epoch)))
            pylab.close()
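
    # network: two conv+maxpool blocks, a reshape to (time, features), a
    # small dense layer, two stacked bidirectional GRU layers (outputs summed
    # after the first, concatenated after the second), then a dense+softmax
    # over the character classes, trained through the CTC lambda layer above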
    def train(run_name, start_epoch, stop_epoch, img_w):
        # Input Parameters
        img_h = 64
        words_per_epoch = 16000
        val_split = 0.2
        val_words = int(words_per_epoch * (val_split))

        # Network parameters
        conv_filters = 16
        kernel_size = (3, 3)
        pool_size = 2
        time_dense_size = 32
        rnn_size = 512
        minibatch_size = 32

        if K.image_data_format() == 'channels_first':
            input_shape = (1, img_w, img_h)
        else:
            input_shape = (img_w, img_h, 1)

        fdir = os.path.dirname(
            get_file('wordlists.tgz',
                     origin='http://www.mythic-ai.com/datasets/wordlists.tgz',
                     untar=True))
        img_gen = TextImageGenerator(
            monogram_file=os.path.join(fdir, 'wordlist_mono_clean.txt'),
            bigram_file=os.path.join(fdir, 'wordlist_bi_clean.txt'),
            minibatch_size=minibatch_size,
            img_w=img_w,
            img_h=img_h,
            downsample_factor=(pool_size ** 2),
            val_split=words_per_epoch - val_words)
        act = 'relu'
        input_data = Input(name='the_input', shape=input_shape, dtype='float32')
        inner = Conv2D(conv_filters, kernel_size, padding='same',
                       activation=act, kernel_initializer='he_normal',
                       name='conv1')(input_data)
        inner = MaxPooling2D(pool_size=(pool_size, pool_size), name='max1')(inner)
        inner = Conv2D(conv_filters, kernel_size, padding='same',
                       activation=act, kernel_initializer='he_normal',
                       name='conv2')(inner)
        inner = MaxPooling2D(pool_size=(pool_size, pool_size), name='max2')(inner)

        conv_to_rnn_dims = (img_w // (pool_size ** 2),
                            (img_h // (pool_size ** 2)) * conv_filters)
        inner = Reshape(target_shape=conv_to_rnn_dims, name='reshape')(inner)

        # cuts down input size going into RNN:
        inner = Dense(time_dense_size, activation=act, name='dense1')(inner)

        # Two layers of bidirectional GRUs
        # GRU seems to work as well, if not better than LSTM:
        gru_1 = GRU(rnn_size, return_sequences=True,
                    kernel_initializer='he_normal', name='gru1')(inner)
        gru_1b = GRU(rnn_size, return_sequences=True,
                     go_backwards=True, kernel_initializer='he_normal',
                     name='gru1_b')(inner)
        gru1_merged = add([gru_1, gru_1b])
        gru_2 = GRU(rnn_size, return_sequences=True,
                    kernel_initializer='he_normal', name='gru2')(gru1_merged)
        gru_2b = GRU(rnn_size, return_sequences=True, go_backwards=True,
                     kernel_initializer='he_normal', name='gru2_b')(gru1_merged)

        # transforms RNN output to character activations:
        inner = Dense(img_gen.get_output_size(), kernel_initializer='he_normal',
                      name='dense2')(concatenate([gru_2, gru_2b]))
        y_pred = Activation('softmax', name='softmax')(inner)
        Model(inputs=input_data, outputs=y_pred).summary()

        labels = Input(name='the_labels',
                       shape=[img_gen.absolute_max_string_len], dtype='float32')
        input_length = Input(name='input_length', shape=[1], dtype='int64')
        label_length = Input(name='label_length', shape=[1], dtype='int64')
        # Keras doesn't currently support loss funcs with extra parameters
        # so CTC loss is implemented in a lambda layer
        loss_out = Lambda(
            ctc_lambda_func, output_shape=(1,),
            name='ctc')([y_pred, labels, input_length, label_length])

        # clipnorm seems to speed up convergence
        sgd = SGD(lr=0.02, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)

        model = Model(inputs=[input_data, labels, input_length, label_length],
                      outputs=loss_out)

        # the loss calc occurs elsewhere, so use a dummy lambda func for the loss
        model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=sgd)
        if start_epoch > 0:
            weight_file = os.path.join(
                OUTPUT_DIR,
                os.path.join(run_name, 'weights%02d.h5' % (start_epoch - 1)))
            model.load_weights(weight_file)
        # captures output of softmax so we can decode the output during visualization
        test_func = K.function([input_data], [y_pred])

        viz_cb = VizCallback(run_name, test_func, img_gen.next_val())

        model.fit_generator(
            generator=img_gen.next_train(),
            steps_per_epoch=(words_per_epoch - val_words) // minibatch_size,
            epochs=stop_epoch,
            validation_data=img_gen.next_val(),
            validation_steps=val_words // minibatch_size,
            callbacks=[viz_cb, img_gen],
            initial_epoch=start_epoch)

    if __name__ == '__main__':
        run_name = datetime.datetime.now().strftime('%Y:%m:%d:%H:%M:%S')
        train(run_name, 0, 20, 128)
        # increase to wider images and start at epoch 20.
        train(run_name, 20, 25, 512)
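
Each epoch the callbacks save a weights checkpoint (weights%02d.h5) and a sheet of decoded samples (e%02d.png) under image_ocr/<run_name>/, and print the mean edit distance over 256 validation samples. As a closing illustration, the snippet below sketches single-word inference with the helpers above. It is only a sketch: it assumes `train()` has been modified to return its local `test_func` (the script as written keeps it local) and that the backend uses channels-last image format.

    # hypothetical: assumes the last line of train() was changed to
    # `return test_func`; everything else comes from the script above
    test_func = train(run_name, 0, 20, 128)

    img = paint_text('hello', 128, 64)   # shape (1, h, w) = (1, 64, 128)
    x = np.ones([1, 128, 64, 1])         # (batch, w, h, 1), channels-last
    x[0, :, :, 0] = img[0, :, :].T       # transpose (h, w) -> (w, h) as get_batch does
    print(decode_batch(test_func, x))    # prints something like ['hello'] once trained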