✖
RNN/LSTM/GRU の入力と stateful 化 (keras)
入力データ
(batch_size, timesteps, features) が入力になる。batch_size はこのバッチ(学習・予測の1単位)中のサンプル数 (予測するデータの数) で stateless なら None (可変長) にできる、timesteps は与える時系列の長さ、features は特徴量。
features は 例えば1次元の波形データであれば 1 になる。3つのセンサーデータがある場合なら3になる。
timesteps は固定長にもできるし、可変長にもできる。可変長の場合は None を指定する。可変長といっても、単一バッチ中の timesteps の数は揃える必要がある。
timesteps を変えても RNN レイヤーのパラメータ数は変化しない。パラメータ数は features のサイズとRNNのユニット数に依存する。
timesteps とバッチ
stateless RNN の場合、RNN の内部状態はバッチごとにリセットされる。
1つのサンプルの出力は、与えた timesteps の数からしか影響されない。また、バッチ内の各サンプルは独立している。
例えば以下のような (3, 5, 1) なバッチを与えた場合、stateless RNN は、1つ目のサンプルに関しては [1, 2, 3, 4, 5] という情報をつかって 6 を予測するようになる。2つ目のサンプルも同様で、1つ目のサンプルとは内部状態が独立して (学習した重みはもちろん共有しているが) [ 2, 3, 4, 5, 6] から 7 を予測する。
# batch_input_shape = (3, 5, 1) 1 2 3 4 5 -> 6 2 3 4 5 6 -> 7 3 4 5 6 7 -> 8
stateful RNN
stateful の場合は手動でリセットしない限り、バッチの最後の内部状態はリセットされない。
極端な例だと1回のバッチで timesteps が1というこもありうる。混乱しやすいのでサンプル数1で例を示してみる。
以下のように3回のバッチにわけてstateful RNNへ入力を与える。すべてサンプル数 (batch_size) は固定。timesteps の数は任意。stateless ではリセットしていたバッチ最後の各サンプルの最後の状態を保持しているので、batch #2 では、1つのtimestepsを与えているだけだが、それまでの timesteps である 1 2 3 4 5 も考慮された状態で予測される。
# batch #1 shape = (1, 5, 1) 1 2 3 4 5 -> 6 # batch #2 shape = (1, 1, 1) 6 -> 7 # batch #3 shape = (1, 2, 1) 7, 8-> 9
stateless to stateful
statelessで学習させて stateful に予測させるということもできる。これらの違いは内部状態をバッチ間で共有するかどうかだけなので、十分に長い系列で学習しているならあまり問題はない。この場合は、学習時に与えたステップ数しか考慮されていないが、連続で推論して新しい系列を得たいときはstateless より早く予測できる (途中の状態を保持したままなため再度過去の時系列を与える必要がない)。
stateful で学習させるのは結構めんどうくさいので、stateless で十分に長い時系列を与えて学習させて、リアルタイム予測などで実際に使うときは stateful にするというやりかたはありかもしれない。
stateless to statefulの例
サンプルデータ
意味がある予測ではないけど、試しにやってみた。
space = np.linspace(0, 1, 1000)
data = signal.square(2 * np.pi * 50 * space) * signal.square(2 * np.pi * 30 * space) / 2 + 0.5
こういう一時データをRNN(GRU)で学習させてみる。急激な値の変化があると学習結果が面白いことが多いので矩形波にしてる。
stateless で学習させる
def create_model(batch_size=None, timesteps=None, stateful=False):
inputs = Input(batch_shape=(batch_size, None, 1))
x = inputs
x = GRU(32, stateful=stateful, return_sequences=False)(x)
x = Dense(1, activation='sigmoid')(x)
outputs = x
model = Model(inputs, outputs)
model.compile(optimizer=keras.optimizers.Adam(lr=0.01), loss='binary_crossentropy', metrics=['accuracy'])
return model
model = create_model(timesteps = 100)
model.summary()
gen = keras.preprocessing.sequence.TimeseriesGenerator(data[:-10].reshape( (len(data)-10), 1 ), data[10:], length=100)
model.fit_generator(gen, epochs=30, validation_data=gen, shuffle=False)
過去の状態から10ステップ先を1つ予測するという問題設定にした (特に問題に意味はない)
stateless でも stateful でも同じモデルを作れるようにして、まず stateless で学習させる。ここでは過去の系列は100ステップで学習させている。特に意味がある学習ではないので validation に同じ系列を指定してる。
stateless で予測させる
tstart = time.time()
predict_len = 500
start = 207
result = []
for n in range(predict_len):
predicted = model.predict( data[start+n-100:start+n].reshape( (1, 100, 1) ) )
result.append(predicted[0])
result = np.array( result)
elapsed = time.time() - tstart
print('predict {}ms'.format(elapsed * 1000))
plt.figure(figsize=(10,10))
plt.subplot(211)
plt.title('expected')
plt.plot(data[start+10:start+500+10])
plt.plot(result)
plt.subplot(212)
plt.title('predicted (stateless)')
plt.plot(result)
plt.show()
result_stateless = result
後々、stateful とのコードと合わせるため一括で predict せずサンプル数1つで予測させてる。
一応一通り予測できている。ちなみに "predict 6278.290271759033ms" かかった。
statefulモデルにする
model.save_weights('/tmp/param.hdf5')
model_stateful = create_model(batch_size = 1, stateful=True)
model_stateful.load_weights('/tmp/param.hdf5')
model_stateful.summary()
weight を保存して、stateful=True にしたモデルに作りなおして weight をロードする。つまり stateless のモデルと weight は一緒の状態になる。
stateful モデルで予測する
tstart = time.time()
start = 207
result = []
## preload same state with stateless
# model_stateful.predict( data[start-1-100:start-1].reshape( (1, 100, 1) ) )
for n in range(predict_len):
# give just new 1 timestep
predicted = model_stateful.predict( data[start+n-1:start+n].reshape( (1, 1, 1) ) )
result.append(predicted[0])
result = np.array( result)
elapsed = time.time() - tstart
print('predict {}ms'.format(elapsed * 1000))
plt.figure(figsize=(10,10))
plt.subplot(311)
plt.title('predicted (stateless)')
plt.plot(result_stateless)
plt.subplot(312)
plt.title('predicted (stateful)')
plt.plot(result)
plt.subplot(313)
plt.title('predicted (stateful+stateless)')
plt.plot(result_stateless)
plt.plot(result)
plt.show()
stateful では毎回1ステップだけ与えて予測させる。最初は過去の時系列がないため乱れているけど、ステップを与え続けると stateless と同じように予測できる。stateful でもコメンアウトしたところを実行すれば最初から stateless と全く同じ状態になる。
これは "predict 327.3053169250488ms" で終わる。