Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions examples/language-model/make_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,17 @@ def one_hot(v, ndim):
'pride_and_prejudice.txt',
'shakespeare.txt'
]:
with open('books/%s' % book, 'r') as infile:
with open('books/%s' % book, 'r', encoding='utf-8') as infile:
chars = [
c for c in ' '.join(infile.read().lower().split())
if c in set(vocab)
if c in vocab
]
all_chars += [' ']
all_chars += chars

all_chars = list(' '.join(''.join(all_chars).split()))
num_chars = len(all_chars)
with open('cleaned.txt', 'w') as outfile:
with open('cleaned.txt', 'w', encoding='utf-8') as outfile:
outfile.write(''.join(all_chars))


Expand All @@ -71,10 +71,7 @@ def one_hot(v, ndim):
data_portions[i][1] * 0.1
)

max_i = sum([
int(round(len(all_chars) * fraction))
for name, fraction in data_portions
]) - seq_len
max_i = sum(int(round(len(all_chars) * fraction)) for (name, fraction) in data_portions) - seq_len

for i in range(max_i):

Expand Down Expand Up @@ -105,12 +102,12 @@ def one_hot(v, ndim):

start_i = end_i

with open('data/%s.jsonl' % name, 'w') as outfile:
with open('data/%s.jsonl' % name, 'w', encoding='utf-8') as outfile:
for sample_x, sample_y in zip(x0, y0):
outfile.write(json.dumps({
'in_seq': sample_x.tolist(),
'out_char': sample_y.tolist()
}))
outfile.write('\n')

del x0, y0
del x0, y0
13 changes: 5 additions & 8 deletions kur/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
import logging
import kur.__main__

from . import __version__, __homepage__
from . import Kurfile, __homepage__, __version__
from .utils import logcolor
from . import Kurfile

from .plugins import Plugin
from .engine import JinjaEngine

Expand Down Expand Up @@ -87,10 +87,7 @@ def build(args):
spec = parse_kurfile(args.kurfile, args.engine)

if args.compile == 'auto':
result = []
for section in ('train', 'test', 'evaluate'):
if section in spec.data:
result.append((section, 'data' in spec.data[section]))
result = [(section, 'data' in spec.data[section]) for section in ('train', 'test', 'evaluate') if section in spec.data]
if not result:
logger.info('Trying to build a bare model.')
args.compile = 'none'
Expand Down Expand Up @@ -118,7 +115,7 @@ def build(args):

if args.compile == 'none':
return
elif args.compile == 'train':
if args.compile == 'train':
target = spec.get_trainer(with_optimizer=True)
elif args.compile == 'test':
target = spec.get_trainer(with_optimizer=False)
Expand Down Expand Up @@ -452,7 +449,7 @@ def main():
if gotcha:
plugin_dir = arg
break
elif arg == '--plugin':
if arg == '--plugin':
gotcha = True
plugin_dir = plugin_dir or os.environ.get('KUR_PLUGIN')
load_plugins(plugin_dir)
Expand Down