-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbilstm_mnist.py
55 lines (38 loc) · 1.2 KB
/
bilstm_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# -*- coding: utf-8 -*-
"""
Created on Wed Feb 27 01:18:45 2019
@author: tanma
"""
import os
from keras.models import Model
from keras.layers import Input, CuDNNLSTM, CuDNNGRU, Bidirectional
from keras.layers import GlobalMaxPooling1D, Lambda, Concatenate, Dense
from keras.datasets import mnist
import keras.backend as K
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Fetch the MNIST train/test splits as (inputs, labels) pairs.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# D: side length of each input; the model's Input below has shape (D, D).
D = 28
# M: hidden units per direction in each bidirectional CuDNNLSTM.
M = 15

# Cast both splits to float32 and divide by 255
# (presumably uint8 pixels in 0-255, giving values in [0, 1]).
x_train, x_test = [split.astype('float32') / 255 for split in (x_train, x_test)]
# Functional two-branch model: the same D x D input is read as a sequence along
# each of its two non-batch axes, and each branch is summarized by a BiLSTM + max-pool.
# NOTE: CuDNNLSTM is the cuDNN-backed LSTM and only runs on a CUDA GPU.
input_ = Input(shape=(D, D))
# Branch 1: timesteps run along axis 1 of the input (presumably image rows);
# return_sequences=True keeps one output per timestep for the pooling below.
rnn1 = Bidirectional(CuDNNLSTM(M, return_sequences=True))
x1 = rnn1(input_)
# Max over the timestep axis -> a single feature vector per sample
# (2*M wide under Bidirectional's default 'concat' merge mode).
x1 = GlobalMaxPooling1D()(x1)
# Branch 2: same architecture as branch 1 but a separate layer instance
# (its own weights), fed the transposed input.
rnn2 = Bidirectional(CuDNNLSTM(M, return_sequences=True))
# Swap the two non-batch axes so branch 2 scans the other image axis
# (presumably columns); axis 0 (batch) is left in place.
permutor = Lambda(lambda t: K.permute_dimensions(t, pattern=(0, 2, 1)))
x2 = permutor(input_)
x2 = rnn2(x2)
x2 = GlobalMaxPooling1D()(x2)
# Join the two pooled branch summaries along the feature axis.
concatenator = Concatenate(axis=1)
x = concatenator([x1, x2])
# 10-way softmax classification head.
output = Dense(10, activation='softmax')(x)
model = Model(inputs=input_, outputs=output)
# sparse_categorical_crossentropy takes class ids directly as targets
# (assumes y_train/y_test hold integer labels, as loaded above - no one-hot step).
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
# Keras carves the validation_split fraction off the END of x_train/y_train
# (before shuffling), so only the first 75% is used for weight updates.
model.fit(x_train, y_train, batch_size=32, epochs=10, validation_split=0.25)

# The test split was loaded and scaled but previously never used. Score it
# so the run ends with performance on data excluded from both training and
# validation. With metrics=['accuracy'], evaluate returns [loss, accuracy].
test_loss, test_acc = model.evaluate(x_test, y_test, batch_size=32, verbose=0)
print('Test loss: {:.4f} - test accuracy: {:.4f}'.format(test_loss, test_acc))