ComplexAnalyserNode (WebAudio) を作った (IQ信号のFFT) | tech - 氾濫原 に続き、WebAssembly を使って複素 FIR Filter を行う AudioWorkletNode の実装を書いてみた。前回と同様 Rust と wasm_bindgen を使っている。Rust 側の実装はとても素朴。

Analyser と違うのは、直接信号に手を加える必要があるというところ。つまり AudioWorklet 内で wasm の実装を呼ぶ必要がある。

wasm_bindgen を使おうと思ったが…

wasm_bindgen の JS 側の実装は Uint8Array を文字列化するために TextEncoder を使っている。しかし TextEncoder は AudioWorkletGlobalScope には(今のところ)存在しておらず、エラーになってしまった。

自動生成されたコードに手を入れて使うのはあまりやりたくないのと、どっちにしろメモリ管理を自力でやる必要はあるので、wasm_bindgen の生成するJSコードは使わず、直接 wasm を呼びだすようにした。

なおエラーメッセージの転送などで、どうしても Uint8Array から文字列を生成したいケースはある。今回は ASCII 以外の文字が入ることは想定していないので TextEncoder の代わりに String.fromCharCode() でお茶を濁した。

wasm module を AudioWorkletGlobalScope に受け渡す

AudioWorkletGlobalScope にはそもそも fetch もないため、wasm のコードを AudioWorkletGlobalScope 内で直接読みこんでインスタンス化することができない。

どうするかというと、メインスレッド(など)で fetch を行い、wasm のモジュールを得てから、これを postMessage で transfer するという余計な手順が必要になる。

wasm module は postMessage ができるが、wasm の instance はできないので、instance 化は AudioWorkletGlobalScope 側でやる必要がある。

  1. トップ
  2. tech
  3. WebAudio ComplexFirFilterNode AudioWorklet

https://github.com/cho45/complex-analyser-node

WebAudio の AnalyserNode は実数計算しかしないタイプ (仕様でそう決まっている) ですが、IQの2ch入力の信号を FFT して表示したいので、ほぼ類似のAPIを持つComplexAnalyserNodeを作ってみました。

簡単な割に IQ 信号を WebAudio で処理する際のデバッグに便利です。

FFT部分はwasmにコンパイルしたrustfftを呼んでおり、自分では実装していません。

AudioWorkletNode と AudioWorkletProcessorの使いわけ

AudioWorkletNode と AudioWorkletProcessor をうまく使いわける必要があるのですが、今回は以下のようにしています

  • AudioWorkletProcessor (audio スレッドで動く)
    • 入力をそのまま出力にコピーするだけ
    • 入力をバッファとして貯めこみ、port 経由で AudioWorkletNode にそのまま転送する
  • AudioWorkletNode (メインスレッドで動く)
    • AudioWorkletProcessor からくるバッファを管理し、FFT 用のバッファを保持する
    • getFloatFrequencyData() に応じてバッファの内容をFFTして返す
    • 対数変換やUSB/LSBの並び換え(fftshift)も rust でやってます

こういう構成なので、wasm は普通にメインスレッドで読みこんでメインスレッドで使っています。Analyser の場合は audio スレッドで信号内容に手を加えるということはしないので、余計なことを audio スレッドでやらせたくないという気持ちがあります。

もともと WebAudio にある AnalyserNode とインターフェイスを似せようとすると同期的にしなくてはいけないので若干制約があります。全部非同期にすればもうちょいやりようがある気はします。

  1. トップ
  2. tech
  3. ComplexAnalyserNode (WebAudio) を作った (IQ信号のFFT)

習作としてFIR (Finite Impulse Response) フィルタの可視化をつくってみた。

FIRフィルタのcoefficient(係数)をJSONで張りつけると、係数のグラフと、その周波数特性を表示する。複素数対応

  1. トップ
  2. tech
  3. FIRフィルタの可視化

https://developer.mozilla.org/ja/docs/WebAssembly/Rust_to_wasm に書いてある通りで便利。alert 出しても面白くないのでFFT のベンチをとってみるというのをやってみた。

こういう感じ

タスクは N=4096 の複素数のFFTをやることとした。rust 側のコードは rustfftを呼ぶだけ。

以下のような組合せで実行

  • [wasm] instance wasm_bindgen
    • wasm_bindgen が生成した struct のJSブリッジを使って普通にAPIを呼ぶ
      • API呼び出し時に必ずJSからWASMにTypedArray のコピーが発生する
  • [wasm] instance pointer
    • wasm_bindgen が生成した wasm を直接使って struct のポインタ、バッファのポインタを自力で管理する
    • WASMからメモリを確保して直接使っているのでコピーが減る
  • [wasm] one func pointer, one func wasm_bindgen
    • rust 側の FFT インスタンスを使い捨てるバージョン。あまり意味はないが事前計算がどれぐらい重いのかわかる。
  • [js] instance dsp.js
    • dsp.js の FFT インスタンスを使う場合。厳密には入力・出力が普通のFFTと違うので、参考程度 (入力が実数だけコピーもしてない、出力の振幅を余計に計算しているなど)
N = 4096
[wasm] instance pointer x 8,361 ops/sec ±0.23% (97 runs sampled)
[wasm] instance wasm_bindgen x 7,869 ops/sec ±0.23% (99 runs sampled)
[wasm] one func pointer x 2,730 ops/sec ±0.48% (95 runs sampled)
[wasm] one func wasm_bindgen x 2,753 ops/sec ±0.34% (97 runs sampled)
[js] instance dsp.js x 3,722 ops/sec ±0.74% (92 runs sampled)
Fastest is [wasm] instance pointer
+------------------------------+--------------------------+----------------------------+---------------------------------+-------------------------+---------------------------------+----------------------------+
| name                         | ops                      | vs [wasm] instance pointer | vs [wasm] instance wasm_bindgen | vs [js] instance dsp.js | vs [wasm] one func wasm_bindgen | vs [wasm] one func pointer |
+------------------------------+--------------------------+----------------------------+---------------------------------+-------------------------+---------------------------------+----------------------------+
| [wasm] instance pointer      | 8360.9ops/sec (+/-0.23%) | -                          | 6%                              | 125%                    | 204%                            | 206%                       |
+------------------------------+--------------------------+----------------------------+---------------------------------+-------------------------+---------------------------------+----------------------------+
| [wasm] instance wasm_bindgen | 7869.3ops/sec (+/-0.23%) | -6%                        | -                               | 111%                    | 186%                            | 188%                       |
+------------------------------+--------------------------+----------------------------+---------------------------------+-------------------------+---------------------------------+----------------------------+
| [js] instance dsp.js         | 3721.9ops/sec (+/-0.74%) | -55%                       | -53%                            | -                       | 35%                             | 36%                        |
+------------------------------+--------------------------+----------------------------+---------------------------------+-------------------------+---------------------------------+----------------------------+
| [wasm] one func wasm_bindgen | 2752.8ops/sec (+/-0.34%) | -67%                       | -65%                            | -26%                    | -                               | 1%                         |
+------------------------------+--------------------------+----------------------------+---------------------------------+-------------------------+---------------------------------+----------------------------+
| [wasm] one func pointer      | 2730.0ops/sec (+/-0.48%) | -67%                       | -65%                            | -27%                    | -1%                             | -                          |
+------------------------------+--------------------------+----------------------------+---------------------------------+-------------------------+---------------------------------+----------------------------+

https://github.com/cho45/wasm-fft-sketch/blob/master/sketch.js#L174

とりあえず何も考えずに pure rust のライブラリをコンパイルして呼んでいるだけなのに、wasm 版が早い (rustfft が良いのかもしれないが)。

計算の比重が高いからか、意外とメモリコピーしてていても差がでない。

wasm_bindgen の使い勝手がいい

wasm_bindgen と wasm-pack が大変使い勝手が良く、ほぼ悩むことなく即 Rust のコードを書きはじめて、またそれをすぐに JS から呼ぶことができる。内部的には Rust 側のブリッジ関数と JS 側のブリッジ関数を同時に作ってくれている。

少し効率的な実装に置き換える

ただ、wasm はメモリ空間が JS のメモリ空間と分かれているため、wasm_bindgenが生成するJS側のブリッジ関数は (便利ではあるが) 若干非効率な実装になっており、TypedArray の受け渡しではコピーが多くなる。

これを防ぐには、やはり自力で wasm 側のメモリ空間からメモリを確保して TypedArray をインスタンス化して使用し、必要なくなったら free するという、メモリ管理を自分でやる必要がある。

これはまぁまぁ面倒くさいが、生成されたJSコードを読めばどう呼べば適切かは容易にわかるので、とりあえずは難しいことではない。

メモリ管理は必要

生成コードをただ使う場合、関数呼び出しだけなら生成コード内でメモリのfreeが行われるので、あまり気にする必要はないが、struct に関しては生成コードをただ使っている場合でも、明示的にオブジェクトの free を呼ぶ必要がある。JS にはデストラクタがないので仕方ないが、注意がいる。

  1. トップ
  2. tech
  3. Rust + wasm の環境が wasm_bindgen でめっちゃ簡単になっていた

Benchmark.js ちゃんと使えるので良いのですが、計測を頑張っている割に結果表示が貧弱というのが悲しいところです。

なので Perl の Benchmark.pm 風に表示する complete の関数を書いてみました。cli-table に依存します。

for x 115,102,309 ops/sec ±0.43% (95 runs sampled)
for of x 62,020,029 ops/sec ±0.23% (96 runs sampled)
Fastest is for
+--------+-------------------------------+--------+-----------+
| name   | ops                           | vs for | vs for of |
+--------+-------------------------------+--------+-----------+
| for    | 115102308.6ops/sec (+/-0.43%) | -      | 86%       |
+--------+-------------------------------+--------+-----------+
| for of | 62020028.7ops/sec (+/-0.23%)  | -46%   | -         |
+--------+-------------------------------+--------+-----------+
//#!/usr/bin/env node

"use strict";

const Benchmark = require('benchmark');
const Table = require('cli-table');
const array = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
new Benchmark.Suite().
	add('for', () => {
		let sum = 0;
		for (let i = 0, len = array.length; i < len; i++) {
			sum += array[i];
		}
		return sum;
	}).
	add('for of', () => {
		let sum = 0;
		for (let i of array) {
			sum += i;
		}
		return sum;
	}).
	on('cycle', function(event) {
		console.log(String(event.target));
	}).
	on('complete', function() {
		console.log('Fastest is ' + this.filter('fastest').map('name'));

		const array = this.slice(0).sort( (a, b) => b.hz - a.hz);
		const table = new Table({
			chars: {
				'top': '-' ,
				'top-mid': '+' ,
				'top-left': '+' ,
				'top-right': '+',
				'bottom': '-' ,
				'bottom-mid': '+' ,
				'bottom-left': '+' ,
				'bottom-right': '+',
				'left': '|' ,
				'left-mid': '+' ,
				'mid': '-' ,
				'mid-mid': '+',
				'right': '|' ,
				'right-mid': '+' ,
				'middle': '|'
			},
			head: ['name', 'ops'].concat( array.map( b => 'vs ' + b.name ) )
		});
		const comparison = array.map( (a, ia) => array.map( (b, ib) => {
			if (ia === ib) return "-";
			return Math.round((a.hz / b.hz - 1) * 100) + '%';
		}));
		array.forEach( (bench, i) => {
			table.push([
				bench.name,
				`${bench.hz.toFixed(1)}ops/sec (+/-${bench.stats.rme.toFixed(2)}%)`
			].concat(comparison[i]))
		});
		console.log(table.toString());
	}).
	run({});
  1. トップ
  2. tech
  3. Benchmark.js の結果表示を改善する

input channel 数はちゃんと2なのに同じデータが入っている

モノラルになるとかじゃなく、ステレオ入力すると、デフォルトではモノラル結合されて、同じデータが2チャンネルにコピーされて流れてくる。

解決方法

echoCancelation が有効だとこういう挙動になるらしく、これを止めさせるとちゃんとステレオデータがとれる。getUserMedia で以下のように指定する。余計なことをできるだけ止めさせたい場合はいろいろ指定する必要がある。

const stream = await navigator.mediaDevices.getUserMedia({
        audio: {
                channelCount: {ideal: 2, min: 1},
                echoCancellation: { exact: false },
                noiseSuppression: { exact: false },
                autoGainControl:{ exact:  false },  
        }
});

ref

  1. トップ
  2. tech
  3. Chrome の WebAudio でステレオ入力ができない場合

入力データ

(batch_size, timesteps, features) が入力になる。batch_size はこのバッチ(学習・予測の1単位)中のサンプル数 (予測するデータの数) で stateless なら None (可変長) にできる、timesteps は与える時系列の長さ、features は特徴量。

features は 例えば1次元の波形データであれば 1 になる。3つのセンサーデータがある場合なら3になる。

timesteps は固定長にもできるし、可変長にもできる。可変長の場合は None を指定する。可変長といっても、単一バッチ中の timesteps の数は揃える必要がある。

timesteps を変えても RNN レイヤーのパラメータ数は変化しない。パラメータ数は features のサイズとRNNのユニット数に依存する。

timesteps とバッチ

stateless RNN の場合、RNN の内部状態はバッチごとにリセットされる。

1つのサンプルの出力は、与えた timesteps の数からしか影響されない。また、バッチ内の各サンプルは独立している。

例えば以下のような (3, 5, 1) なバッチを与えた場合、stateless RNN は、1つ目のサンプルに関しては [1, 2, 3, 4, 5] という情報をつかって 6 を予測するようになる。2つ目のサンプルも同様で、1つ目のサンプルとは内部状態が独立して (学習した重みはもちろん共有しているが) [ 2, 3, 4, 5, 6] から 7 を予測する。

# batch_input_shape = (3, 5, 1)
1 2 3 4 5 -> 6
2 3 4 5 6 -> 7
3 4 5 6 7 -> 8

stateful RNN

stateful の場合は手動でリセットしない限り、バッチの最後の内部状態はリセットされない。

極端な例だと1回のバッチで timesteps が1というこもありうる。混乱しやすいのでサンプル数1で例を示してみる。

以下のように3回のバッチにわけてstateful RNNへ入力を与える。すべてサンプル数 (batch_size) は固定。timesteps の数は任意。stateless ではリセットしていたバッチ最後の各サンプルの最後の状態を保持しているので、batch #2 では、1つのtimestepsを与えているだけだが、それまでの timesteps である 1 2 3 4 5 も考慮された状態で予測される。

# batch #1 shape = (1, 5, 1)
1 2 3 4 5 -> 6

# batch #2 shape = (1, 1, 1)
6 -> 7

# batch #3 shape = (1, 2, 1)
7, 8-> 9

stateless to stateful

statelessで学習させて stateful に予測させるということもできる。これらの違いは内部状態をバッチ間で共有するかどうかだけなので、十分に長い系列で学習しているならあまり問題はない。この場合は、学習時に与えたステップ数しか考慮されていないが、連続で推論して新しい系列を得たいときはstateless より早く予測できる (途中の状態を保持したままなため再度過去の時系列を与える必要がない)。

stateful で学習させるのは結構めんどうくさいので、stateless で十分に長い時系列を与えて学習させて、リアルタイム予測などで実際に使うときは stateful にするというやりかたはありかもしれない。

stateless to statefulの例

サンプルデータ

意味がある予測ではないけど、試しにやってみた。

space = np.linspace(0, 1, 1000)
data = signal.square(2 * np.pi * 50 *  space) * signal.square(2 * np.pi * 30 *  space) / 2 + 0.5

こういう一時データをRNN(GRU)で学習させてみる。急激な値の変化があると学習結果が面白いことが多いので矩形波にしてる。

stateless で学習させる

def create_model(batch_size=None, timesteps=None, stateful=False):
    inputs = Input(batch_shape=(batch_size, None, 1))
    x = inputs
    x = GRU(32, stateful=stateful, return_sequences=False)(x)
    x = Dense(1, activation='sigmoid')(x)
    outputs = x
    model = Model(inputs, outputs)
    model.compile(optimizer=keras.optimizers.Adam(lr=0.01), loss='binary_crossentropy', metrics=['accuracy'])
    return model

model = create_model(timesteps = 100)
model.summary()

gen = keras.preprocessing.sequence.TimeseriesGenerator(data[:-10].reshape( (len(data)-10), 1 ), data[10:], length=100)
model.fit_generator(gen, epochs=30, validation_data=gen, shuffle=False)

過去の状態から10ステップ先を1つ予測するという問題設定にした (特に問題に意味はない)

stateless でも stateful でも同じモデルを作れるようにして、まず stateless で学習させる。ここでは過去の系列は100ステップで学習させている。特に意味がある学習ではないので validation に同じ系列を指定してる。

stateless で予測させる

tstart = time.time()
predict_len = 500

start = 207
result = []
for n in range(predict_len):
    predicted = model.predict( data[start+n-100:start+n].reshape( (1, 100, 1) ) )
    result.append(predicted[0])

result = np.array( result)

elapsed = time.time() - tstart
print('predict {}ms'.format(elapsed * 1000))

plt.figure(figsize=(10,10))
plt.subplot(211)
plt.title('expected')
plt.plot(data[start+10:start+500+10])
plt.plot(result)
plt.subplot(212)
plt.title('predicted (stateless)')
plt.plot(result)
plt.show()

result_stateless = result

後々、stateful とのコードと合わせるため一括で predict せずサンプル数1つで予測させてる。

一応一通り予測できている。ちなみに "predict 6278.290271759033ms" かかった。

statefulモデルにする

model.save_weights('/tmp/param.hdf5')
model_stateful = create_model(batch_size = 1, stateful=True)
model_stateful.load_weights('/tmp/param.hdf5')
model_stateful.summary()

weight を保存して、stateful=True にしたモデルに作りなおして weight をロードする。つまり stateless のモデルと weight は一緒の状態になる。

stateful モデルで予測する

tstart = time.time()

start = 207
result = []
## preload same state with stateless
# model_stateful.predict( data[start-1-100:start-1].reshape( (1, 100, 1) ) )
for n in range(predict_len):
    # give just new 1 timestep
    predicted = model_stateful.predict( data[start+n-1:start+n].reshape( (1, 1, 1) ) )
    result.append(predicted[0])
    
result = np.array( result)

elapsed = time.time() - tstart
print('predict {}ms'.format(elapsed * 1000))

plt.figure(figsize=(10,10))
plt.subplot(311)
plt.title('predicted (stateless)')
plt.plot(result_stateless)
plt.subplot(312)
plt.title('predicted (stateful)')
plt.plot(result)
plt.subplot(313)
plt.title('predicted (stateful+stateless)')
plt.plot(result_stateless)
plt.plot(result)
plt.show()

stateful では毎回1ステップだけ与えて予測させる。最初は過去の時系列がないため乱れているけど、ステップを与え続けると stateless と同じように予測できる。stateful でもコメンアウトしたところを実行すれば最初から stateless と全く同じ状態になる。

これは "predict 327.3053169250488ms" で終わる。

  1. トップ
  2. tech
  3. RNN/LSTM/GRU の入力と stateful 化 (keras)

CW の最小単位である短点の長さ t は以下で求められる。w は符号速度、単位 wpm (通常は 10〜40wpm) 。

5 (トトトトト 短点5つ)や訂正信号 <HH> (トトトトトトトト 短点8つ) を送信しているときに最大の帯域幅になる。短点の長さ t の on/off の繰り返しであるので、波長 2t の矩形波となる。24wpm では t = 50ms なので波長100ms、すなわち 10Hz の矩形波。

これを搬送波に乗せると (AM変調なので) 両側波帯に帯域が広がるため最低でも 20Hz の帯域幅になる。矩波形なので奇数次数の高調波も発生し、5倍まで考慮するだけで100Hzになる。

なお10wpm で 4Hz、50wpm で 21Hz の矩形波。

コンスタレーション(信号空間ダイヤグラム)

普通CWのコンスタレーションを気にすることはないと思うが、一応確認しておくと、BPSK などと比べるときに想像しやすい (なぜ BPSK が CW/OOK と比べて 3dB 有利かとか)

上の図のように中央部 (off) と周辺部 (on)にわかれる。

これがたとえば BPSK の場合は、中央ではなく、左端と右端になる。すなわち信号空間的には距離はCWの2倍になる。2倍=3dBよくなるということはこういうこと

  1. トップ
  2. tech
  3. CW の信号帯域とコンスタレーション

RNN モールスデコーダの試作 | tech - 氾濫原

波形ではなくSFFTの結果を認識させる

モールスで必要なのはキャリア周波数の周辺帯域だけ。モールスは通信速度があまり早くなく、帯域幅も100Hz程度のため、周波数分解能と時間分解能にトレードオフのあるSFFTでも、十分解析可能なはずだと思った。

訓練データ

例によってここは node.js で作った。web-audio-engine を使って実際にモールスの信号をつくり、それを AnalyserNode で連続で FFT して画像をつくった。

ラベルデータは考えられるだけでいくつか作りかたがある

  1. 符号のon/off の正解データ (binary)
  2. どこからどこまでが、どの符号かのID (categorical)
  3. その符号を人間が認識できる最短の位置での符号 ID (categorical)

画像として処理するなら、どこからどこまでが符号でその符号が何かがわかればいいが、連続信号として処理したいなら、人間が認識できる最短の位置にラベル付けするのが正しそう、と考えた。

一応どの正解データも作れるのようにデータ生成コードを書いた。

符号のon/offはその後なんらかのアルゴリズムでさらにデコードする必要があるが、モールスの場合はクロックが固定ではないので割と面倒くさい。

ということで早々に符号列のパターンを直接認識させる方法をとることにした。認識させるのはあくまで符号列 (トツーやツートトト) であって、(A や B という文字ではない)

単語単位で認識させることができればもっと精度が上げられそうだけど、コールサイン (パターンはあるがほぼランダム) をとれないと意味がないので、ランダム精度を重視している。

備考

モールスは、人間の場合、速度によってやや異る方法で認識している。

  • 10wpm〜15wpm (低速) 「符号」をまとまって捉えられないので「長」「短」の組合せでデコードしている。符号表さえ覚えていればデコード可能。
  • 15〜30wpm (標準) ひとつの「符号」をまとめて捉えて直接文字で認識する (トツーときたら「A」と学習している) 音素の認識と似てる。聴覚受信の回路がある程度脳内でできていないと難しい。
  • 25〜40wpm (高速) ひとつの「単語」をまとめて捉えて文字列として認識する(ツー ト トトト ツー / TEST や ES / ト トトト) 。特に「E」は「ト」でとても短い符号なので、符号単位に認識していると間に合わない。

モデル

入力はタイムシリーズ型式 (None, timestep, features) timestep は 72、features はデコード対象周波数を中心とした magnitude。

出力はどのモールスの符号か?を表わす64次元のone-hotベクトル。

いろいろなモデルをつくっては壊して試した。

が、どうもこれではうまくいかなそうだという気がしてきた。2dB (ノイズ帯域500Hz) 程度のSNRでもほぼ認識できない。

  1. トップ
  2. tech
  3. モールスデコーダの続き

角度は周期があるのでよくよく考えると平均や分散を出すのがむずかしい。いろいろやりかたがあるみたいだけど「単位ベクトル合算法」で計算してみる。

#!/usr/bin/env python
import numpy as np

deg = np.array([80, 170, 175, 200, 265, 345])
rad = deg * np.pi / 180

# ベクトル化 (計算が楽になるので複素数に)
cmp = np.cos(rad) + np.sin(rad)*1j
# 平均をとる
mean_complex = np.mean(cmp)

# このときの複素数の角度が平均角度
avg = np.angle(mean_complex)
# 絶対値(ベクトル長)小さいほどばらつきは大きい ( [0, 1] )
var = 1 - np.absolute(mean_complex)
print(avg * 180 / np.pi + 360) #=> 190.65
print(var) #=> 0.68

平均角度は分散も考えないと意味がないことがある。90°と270°の平均は↑の計算で180°になるが、ベクトル長さが0なので、平均として妥当な角度は存在してない。

ref

  1. トップ
  2. tech
  3. 角度の平均・分散