Commit 141ef1c

Author: Philippe Rémy
Merge pull request #42 from philipperemy/paper
Implementation closer to the paper
2 parents (0ae28a0 + 12a1fc8), commit 141ef1c

File tree: 8 files changed (+47, -100 lines)

README.md
Lines changed: 1 addition & 1 deletion

@@ -95,7 +95,7 @@ model.fit(x, y) # Keras model.
 
 ### Arguments
 
-`tcn.TCN(nb_filters=64, kernel_size=2, nb_stacks=1, dilations=[1, 2, 4, 8, 16, 32], activation='norm_relu', padding='causal', use_skip_connections=True, dropout_rate=0.0, return_sequences=True, name='tcn')`
+`TCN(nb_filters=64, kernel_size=2, nb_stacks=1, dilations=[1, 2, 4, 8, 16, 32], activation='norm_relu', padding='causal', use_skip_connections=True, dropout_rate=0.0, return_sequences=True, name='tcn')`
 
 - `nb_filters`: Integer. The number of filters to use in the convolutional layers.
 - `kernel_size`: Integer. The size of the kernel to use in each convolutional layer.
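For context, this is roughly how the layer documented above is used after the interface change: `TCN(...)` is instantiated first and then called on a Keras tensor. A minimal sketch only; the input shape and the `Dense` head are illustrative assumptions, not part of this diff. Note that the `activation` argument is removed from `TCN` by this commit (see tcn/tcn.py below), so it is omitted here even though the README signature string still lists it.

```python
from keras.layers import Dense
from keras.models import Input, Model
from tcn import TCN

# Assumed toy setup: sequences of 100 timesteps with a single feature each.
inputs = Input(shape=(100, 1))
x = TCN(nb_filters=64,
        kernel_size=2,
        nb_stacks=1,
        dilations=[1, 2, 4, 8, 16, 32],
        padding='causal',
        use_skip_connections=True,
        dropout_rate=0.0,
        return_sequences=False)(inputs)  # build the layer, then apply it to the tensor
outputs = Dense(1, activation='linear')(x)
model = Model(inputs, outputs)
model.compile(optimizer='adam', loss='mse')
```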

setup.py
Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@
 
 setup(
     name='keras-tcn',
-    version='2.3.6',
+    version='2.5.6',
     description='Keras TCN',
     author='Philippe Remy',
     license='MIT',

tasks/adding_problem/main.py
Lines changed: 3 additions & 4 deletions

@@ -1,7 +1,7 @@
 import keras
-from utils import data_generator
 
 from tcn import compiled_tcn
+from utils import data_generator
 
 x_train, y_train = data_generator(n=200000, seq_length=600)
 x_test, y_test = data_generator(n=40000, seq_length=600)
@@ -22,9 +22,8 @@ def run_task():
                          nb_filters=24,
                          kernel_size=8,
                          dilations=[2 ** i for i in range(9)],
-                         nb_stacks=2,
+                         nb_stacks=1,
                          max_len=x_train.shape[1],
-                         activation='norm_relu',
                          use_skip_connections=True,
                          regression=True,
                          dropout_rate=0)
@@ -39,7 +38,7 @@ def run_task():
     model.summary()
 
     model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=500,
-              callbacks=[psv], batch_size=128)
+              callbacks=[psv], batch_size=256)
 
 
 if __name__ == '__main__':
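A quick sanity check on the `nb_stacks=2` → `nb_stacks=1` change above: with the refactored residual block (two dilated convolutions per block, see tcn/tcn.py below), a single stack still has a receptive field far beyond the 600-step sequences used here. The estimate below is a rough sketch under the usual stacked-dilated-convolution formula, which is an assumption of this note rather than code from the repository.

```python
# Rough receptive-field estimate for a TCN whose residual blocks each apply
# `convs_per_block` dilated causal convolutions, repeated over `nb_stacks` stacks.
def approx_receptive_field(kernel_size, dilations, nb_stacks, convs_per_block=2):
    # Each convolution with dilation d extends the receptive field by (kernel_size - 1) * d.
    return 1 + convs_per_block * nb_stacks * (kernel_size - 1) * sum(dilations)

dilations = [2 ** i for i in range(9)]  # [1, 2, 4, ..., 256], as in the task above
print(approx_receptive_field(kernel_size=8, dilations=dilations, nb_stacks=1))  # 7155
print(approx_receptive_field(kernel_size=8, dilations=dilations, nb_stacks=2))  # 14309
# Both comfortably exceed seq_length=600, so one stack is enough for this task.
```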

tasks/copy_memory/main.py
Lines changed: 2 additions & 3 deletions

@@ -1,7 +1,7 @@
 import keras
-from utils import data_generator
 
 from tcn import compiled_tcn
+from utils import data_generator
 
 x_train, y_train = data_generator(601, 10, 30000)
 x_test, y_test = data_generator(601, 10, 6000)
@@ -25,9 +25,8 @@ def run_task():
                          nb_filters=10,
                          kernel_size=8,
                          dilations=[2 ** i for i in range(9)],
-                         nb_stacks=2,
+                         nb_stacks=1,
                          max_len=x_train[0:1].shape[1],
-                         activation='norm_relu',
                          use_skip_connections=True,
                          return_sequences=True)
 

tasks/mnist_pixel/main.py
Lines changed: 3 additions & 4 deletions

@@ -9,12 +9,11 @@ def run_task():
     model = compiled_tcn(return_sequences=False,
                          num_feat=1,
                          num_classes=10,
-                         nb_filters=25,
-                         kernel_size=7,
+                         nb_filters=20,
+                         kernel_size=6,
                          dilations=[2 ** i for i in range(9)],
-                         nb_stacks=2,
+                         nb_stacks=1,
                          max_len=x_train[0:1].shape[1],
-                         activation='norm_relu',
                          use_skip_connections=True)
 
     print(f'x_train.shape = {x_train.shape}')
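For readers unfamiliar with this task, the `num_feat=1` and `max_len=x_train[0:1].shape[1]` settings above correspond to the usual pixel-MNIST setup, where each 28x28 image is fed to the TCN as a sequence of 784 single-channel pixels. The repository loads its data elsewhere (not shown in this diff); the snippet below is only a hedged sketch of the standard preparation, not the repository's code.

```python
import numpy as np
from keras.datasets import mnist

# Standard "sequential MNIST" preparation: flatten each 28x28 image into a
# sequence of 784 timesteps with one feature per step, scaled to [0, 1].
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 784, 1).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784, 1).astype('float32') / 255.0
y_train = y_train.reshape(-1, 1)  # sparse labels, shape (num_samples, 1)
y_test = y_test.reshape(-1, 1)

print(x_train.shape)  # (60000, 784, 1) -> max_len=784, num_feat=1
```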

tasks/receptive-field/main.py
Lines changed: 0 additions & 1 deletion

@@ -15,7 +15,6 @@ def run_task(sequence_length=8):
                          dilations=[1, 2, 4, 8, 16, 32],
                          nb_stacks=6,
                          max_len=x_train[0:1].shape[1],
-                         activation='norm_relu',
                          use_skip_connections=False)
 
     print(f'x_train.shape = {x_train.shape}')

tcn/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
 from tcn.tcn import TCN, compiled_tcn
 
-__version__ = '2.3.5'
+__version__ = '2.5.6'

tcn/tcn.py
Lines changed: 36 additions & 85 deletions

@@ -1,87 +1,44 @@
+from typing import List, Tuple
+
 import keras.backend as K
 import keras.layers
 from keras import optimizers
 from keras.engine.topology import Layer
-from keras.layers import Activation, Lambda
+from keras.layers import Activation, Lambda, BatchNormalization
 from keras.layers import Conv1D, SpatialDropout1D
 from keras.layers import Convolution1D, Dense
 from keras.models import Input, Model
-from typing import List, Tuple
-
-
-def channel_normalization(x):
-    # type: (Layer) -> Layer
-    """ Normalize a layer to the maximum activation
-
-    This keeps a layers values between zero and one.
-    It helps with relu's unbounded activation
-
-    Args:
-        x: The layer to normalize
-
-    Returns:
-        A maximal normalized layer
-    """
-    max_values = K.max(K.abs(x), 2, keepdims=True) + 1e-5
-    out = x / max_values
-    return out
 
 
-def wave_net_activation(x):
-    # type: (Layer) -> Layer
-    """This method defines the activation used for WaveNet
-
-    described in https://deepmind.com/blog/wavenet-generative-model-raw-audio/
-
-    Args:
-        x: The layer we want to apply the activation to
-
-    Returns:
-        A new layer with the wavenet activation applied
-    """
-    tanh_out = Activation('tanh')(x)
-    sigm_out = Activation('sigmoid')(x)
-    return keras.layers.multiply([tanh_out, sigm_out])
-
-
-def residual_block(x, s, i, c, activation, nb_filters, kernel_size, padding, dropout_rate=0, name=''):
-    # type: (Layer, int, int, int, str, int, int, str, float, str) -> Tuple[Layer, Layer]
+def residual_block(x, dilation_rate, nb_filters, kernel_size, padding, dropout_rate=0):
+    # type: (Layer, int, int, int, str, float) -> Tuple[Layer, Layer]
     """Defines the residual block for the WaveNet TCN
 
     Args:
         x: The previous layer in the model
-        s: The stack index i.e. which stack in the overall TCN
-        i: The dilation power of 2 we are using for this residual block
-        c: The dilation name to make it unique. In case we have same dilation twice: [1, 1, 2, 4].
-        activation: The name of the type of activation to use
+        dilation_rate: The dilation power of 2 we are using for this residual block
         nb_filters: The number of convolutional filters to use in this block
         kernel_size: The size of the convolutional kernel
         padding: The padding used in the convolutional layers, 'same' or 'causal'.
         dropout_rate: Float between 0 and 1. Fraction of the input units to drop.
-        name: Name of the model. Useful when having multiple TCN.
 
     Returns:
         A tuple where the first element is the residual model layer, and the second
        is the skip connection.
    """
+    prev_x = x
+    for k in range(2):
+        x = Conv1D(filters=nb_filters,
+                   kernel_size=kernel_size,
+                   dilation_rate=dilation_rate,
+                   padding=padding)(x)
+        # x = BatchNormalization()(x)  # TODO should be WeightNorm here.
+        x = Activation('relu')(x)
+        x = SpatialDropout1D(rate=dropout_rate)(x)
 
-    original_x = x
-    conv = Conv1D(filters=nb_filters, kernel_size=kernel_size,
-                  dilation_rate=i, padding=padding,
-                  name=name + '_d_%s_conv_%d-%d_tanh_s%d' % (padding, i, c, s))(x)
-    if activation == 'norm_relu':
-        x = Activation('relu')(conv)
-        x = Lambda(channel_normalization)(x)
-    elif activation == 'wavenet':
-        x = wave_net_activation(conv)
-    else:
-        x = Activation(activation)(conv)
-
-    x = SpatialDropout1D(dropout_rate, name=name + '_spatial_dropout1d_%d-%d_s%d_%f' % (i, c, s, dropout_rate))(x)
-
-    # 1x1 conv.
+    # 1x1 conv to match the shapes (channel dimension).
     x = Convolution1D(nb_filters, 1, padding='same')(x)
-    res_x = keras.layers.add([original_x, x])
+    res_x = keras.layers.add([prev_x, x])
     return res_x, x
 
 
@@ -109,7 +66,6 @@ class TCN:
         kernel_size: The size of the kernel to use in each convolutional layer.
         dilations: The list of the dilations. Example is: [1, 2, 4, 8, 16, 32, 64].
         nb_stacks : The number of stacks of residual blocks to use.
-        activation: The activations to use (norm_relu, wavenet, relu...).
         padding: The padding to use in the convolutional layers, 'causal' or 'same'.
         use_skip_connections: Boolean. If we want to add skip connections from input to each residual block.
         return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence.
@@ -125,7 +81,6 @@ def __init__(self,
                  kernel_size=2,
                  nb_stacks=1,
                  dilations=[1, 2, 4, 8, 16, 32],
-                 activation='norm_relu',
                  padding='causal',
                  use_skip_connections=True,
                  dropout_rate=0.0,
@@ -135,7 +90,6 @@ def __init__(self,
         self.return_sequences = return_sequences
         self.dropout_rate = dropout_rate
         self.use_skip_connections = use_skip_connections
-        self.activation = activation
         self.dilations = dilations
         self.nb_stacks = nb_stacks
         self.kernel_size = kernel_size
@@ -147,27 +101,29 @@ def __init__(self,
 
         if not isinstance(nb_filters, int):
             print('An interface change occurred after the version 2.1.2.')
-            print('Before: tcn.TCN(i, return_sequences=False, ...)')
-            print('Now should be: tcn.TCN(return_sequences=False, ...)(i)')
-            print('Second solution is to pip install keras-tcn==2.1.2 to downgrade.')
+            print('Before: tcn.TCN(x, return_sequences=False, ...)')
+            print('Now should be: tcn.TCN(return_sequences=False, ...)(x)')
+            print('The alternative is to downgrade to 2.1.2 (pip install keras-tcn==2.1.2).')
             raise Exception()
 
     def __call__(self, inputs):
         x = inputs
-        x = Convolution1D(self.nb_filters, 1, padding=self.padding, name=self.name + '_initial_conv')(x)
+        # 1D FCN.
+        x = Convolution1D(self.nb_filters, 1, padding=self.padding)(x)
         skip_connections = []
         for s in range(self.nb_stacks):
-            for i, d in enumerate(self.dilations):
-                x, skip_out = residual_block(x, s, d, i, self.activation, self.nb_filters,
-                                             self.kernel_size, self.padding, self.dropout_rate, name=self.name)
+            for d in self.dilations:
+                x, skip_out = residual_block(x,
+                                             dilation_rate=d,
+                                             nb_filters=self.nb_filters,
+                                             kernel_size=self.kernel_size,
+                                             padding=self.padding,
+                                             dropout_rate=self.dropout_rate)
                 skip_connections.append(skip_out)
         if self.use_skip_connections:
             x = keras.layers.add(skip_connections)
-            x = Activation('relu')(x)
-
         if not self.return_sequences:
-            output_slice_index = -1
-            x = Lambda(lambda tt: tt[:, output_slice_index, :])(x)
+            x = Lambda(lambda tt: tt[:, -1, :])(x)
         return x
 
 
@@ -178,7 +134,6 @@ def compiled_tcn(num_feat,  # type: int
                  dilations,  # type: List[int]
                  nb_stacks,  # type: int
                  max_len,  # type: int
-                 activation='norm_relu',  # type: str
                  padding='causal',  # type: str
                  use_skip_connections=True,  # type: bool
                  return_sequences=True,
@@ -197,7 +152,6 @@ def compiled_tcn(num_feat,  # type: int
         dilations: The list of the dilations. Example is: [1, 2, 4, 8, 16, 32, 64].
         nb_stacks : The number of stacks of residual blocks to use.
         max_len: The maximum sequence length, use None if the sequence length is dynamic.
-        activation: The activations to use.
         padding: The padding to use in the convolutional layers.
         use_skip_connections: Boolean. If we want to add skip connections from input to each residual block.
         return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence.
@@ -213,8 +167,8 @@ def compiled_tcn(num_feat,  # type: int
 
     input_layer = Input(shape=(max_len, num_feat))
 
-    x = TCN(nb_filters, kernel_size, nb_stacks, dilations, activation,
-            padding, use_skip_connections, dropout_rate, return_sequences, name)(input_layer)
+    x = TCN(nb_filters, kernel_size, nb_stacks, dilations, padding,
+            use_skip_connections, dropout_rate, return_sequences, name)(input_layer)
 
     print('x.shape=', x.shape)
 
@@ -223,13 +177,11 @@ def compiled_tcn(num_feat,  # type: int
         x = Dense(num_classes)(x)
         x = Activation('softmax')(x)
         output_layer = x
-        print(f'model.x = {input_layer.shape}')
-        print(f'model.y = {output_layer.shape}')
         model = Model(input_layer, output_layer)
 
         # https://github.com/keras-team/keras/pull/11373
         # It's now in Keras@master but still not available with pip.
-        # TODO To remove later.
+        # TODO remove later.
         def accuracy(y_true, y_pred):
             # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
             if K.ndim(y_true) == K.ndim(y_pred):
@@ -241,16 +193,15 @@ def accuracy(y_true, y_pred):
 
         adam = optimizers.Adam(lr=0.002, clipnorm=1.)
         model.compile(adam, loss='sparse_categorical_crossentropy', metrics=[accuracy])
-        print('Adam with norm clipping.')
     else:
         # regression
        x = Dense(1)(x)
        x = Activation('linear')(x)
        output_layer = x
-        print(f'model.x = {input_layer.shape}')
-        print(f'model.y = {output_layer.shape}')
        model = Model(input_layer, output_layer)
        adam = optimizers.Adam(lr=0.002, clipnorm=1.)
        model.compile(adam, loss='mean_squared_error')
-
+    print(f'model.x = {input_layer.shape}')
+    print(f'model.y = {output_layer.shape}')
+    print('Adam with norm clipping.')
     return model
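To see the refactored interface end to end, here is a minimal usage sketch of `compiled_tcn` after this commit (the `activation` argument is gone). The random regression data and the hyperparameter values are illustrative assumptions, not taken from the repository's tasks; `num_classes` is only a placeholder since the regression branch ignores it.

```python
import numpy as np
from tcn import compiled_tcn

# Toy regression data: 1000 sequences of 100 timesteps, one feature per step.
x = np.random.randn(1000, 100, 1)
y = np.random.randn(1000, 1)

model = compiled_tcn(num_feat=1,
                     num_classes=1,  # unused in the regression branch
                     nb_filters=24,
                     kernel_size=8,
                     dilations=[1, 2, 4, 8, 16, 32],
                     nb_stacks=1,
                     max_len=100,
                     use_skip_connections=True,
                     return_sequences=False,
                     regression=True,
                     dropout_rate=0.0)
model.summary()
model.fit(x, y, epochs=1, batch_size=32)
```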
