require('onmt.init')
local path = require('pl.path')
require('tds')
local cmd = torch.CmdLine()
cmd:text("")
cmd:text("**train.lua**")
cmd:text("")
cmd:option('-config', '', [[Read options from this file]])
cmd:text("")
cmd:text("**Data options**")
cmd:text("")
cmd:option('-data', '', [[Path to the training *-train.t7 file from preprocess.lua]])
cmd:option('-save_model', '', [[Model filename (the model will be saved as
  <save_model>_epochN_PPL.t7 where PPL is the validation perplexity)]])
cmd:option('-train_from', '', [[If training from a checkpoint then this is the path to the pretrained model.]])
cmd:option('-continue', false, [[If training from a checkpoint, whether to continue the training in the same configuration or not.]])
cmd:text("")
cmd:text("**Model options**")
cmd:text("")
cmd:option('-layers', 2, [[Number of layers in the RNN encoder/decoder]])
cmd:option('-rnn_size', 500, [[Size of RNN hidden states]])
cmd:option('-rnn_type', 'LSTM', [[Type of RNN cell: LSTM, GRU]])
cmd:option('-word_vec_size', 500, [[Word embedding sizes]])
cmd:option('-feat_merge', 'concat', [[Merge action for the feature embeddings: concat or sum]])
cmd:option('-feat_vec_exponent', 0.7, [[When using concatenation, if the feature takes N values
then the embedding dimension will be set to N^exponent]])
cmd:option('-feat_vec_size', 20, [[When using sum, the common embedding size of the features]])
cmd:option('-input_feed', 1, [[Feed the context vector at each time step as additional input (via concatenation with the word embeddings) to the decoder.]])
cmd:option('-residual', false, [[Add residual connections between RNN layers.]])
cmd:option('-brnn', false, [[Use a bidirectional encoder]])
cmd:option('-brnn_merge', 'sum', [[Merge action for the bidirectional hidden states: concat or sum]])
cmd:text("")
cmd:text("**Optimization options**")
cmd:text("")
cmd:option('-max_batch_size', 64, [[Maximum batch size]])
cmd:option('-end_epoch', 13, [[The final epoch of the training]])
cmd:option('-start_epoch', 1, [[If loading from a checkpoint, the epoch from which to start]])
cmd:option('-start_iteration', 1, [[If loading from a checkpoint, the iteration from which to start]])
cmd:option('-param_init', 0.1, [[Parameters are initialized over uniform distribution with support (-param_init, param_init)]])
cmd:option('-optim', 'sgd', [[Optimization method. Possible options are: sgd, adagrad, adadelta, adam]])
cmd:option('-learning_rate', 1, [[Starting learning rate. If adagrad/adadelta/adam is used,
then this is the global learning rate. Recommended settings are: sgd = 1,
adagrad = 0.1, adadelta = 1, adam = 0.0002]])
cmd:option('-max_grad_norm', 5, [[If the norm of the gradient vector exceeds this renormalize it to have the norm equal to max_grad_norm]])
cmd:option('-dropout', 0.3, [[Dropout probability. Dropout is applied between vertical LSTM stacks.]])
cmd:option('-learning_rate_decay', 0.5, [[Decay learning rate by this much if (i) perplexity does not decrease
  on the validation set or (ii) epoch has gone past the start_decay_at limit]])
cmd:option('-start_decay_at', 9, [[Start decay after this epoch]])
cmd:option('-curriculum', 0, [[For this many epochs, order the minibatches based on source
sequence length. Sometimes setting this to 1 will increase convergence speed.]])
cmd:option('-pre_word_vecs_enc', '', [[If a valid path is specified, then this will load
pretrained word embeddings on the encoder side.
See README for specific formatting instructions.]])
cmd:option('-pre_word_vecs_dec', '', [[If a valid path is specified, then this will load
pretrained word embeddings on the decoder side.
See README for specific formatting instructions.]])
cmd:option('-fix_word_vecs_enc', false, [[Fix word embeddings on the encoder side]])
cmd:option('-fix_word_vecs_dec', false, [[Fix word embeddings on the decoder side]])
cmd:text("")
cmd:text("**Other options**")
cmd:text("")
-- GPU
onmt.utils.Cuda.declareOpts(cmd)
cmd:option('-async_parallel', false, [[Use asynchronous parallel training.]])
cmd:option('-async_parallel_minbatch', 1000, [[In asynchronous parallel mode, the minimum number of batches to process before going parallel.]])
cmd:option('-no_nccl', false, [[Disable usage of nccl in parallel mode.]])
cmd:option('-disable_mem_optimization', false, [[Disable the sharing of internal buffers between clones - which is in general safe,
  except if you want to look inside the clones, for visualization purposes for instance.]])
-- bookkeeping
cmd:option('-save_every', 0, [[Save intermediate models every this many iterations within an epoch.
If = 0, will not save models within an epoch. ]])
cmd:option('-report_every', 50, [[Print stats every this many iterations within an epoch.]])
cmd:option('-seed', 3435, [[Seed for random initialization]])
cmd:option('-json_log', false, [[Outputs logs in JSON format.]])
onmt.utils.Logger.declareOpts(cmd)
local opt = cmd:parse(arg)
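
--[[
Example invocation (a minimal sketch; the file names below are hypothetical
placeholders, and the GPU flag is declared by onmt.utils.Cuda.declareOpts,
so -gpuid is an assumption here):

  th train.lua -data demo-train.t7 -save_model demo-model -gpuid 1

-data expects the *-train.t7 package produced by preprocess.lua; models are
then saved as <save_model>_epochN_PPL.t7 after each epoch.
]]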
local function initParams(model, verbose)
  local numParams = 0
  local params = {}
  local gradParams = {}

  if verbose then
    _G.logger:info('Initializing parameters...')
  end

  -- Order the model table because we need all replicas to have the same order.
  local orderedIndex = {}
  for key in pairs(model) do
    table.insert(orderedIndex, key)
  end
  table.sort(orderedIndex)

  for _, key in ipairs(orderedIndex) do
    local mod = model[key]
    local p, gp = mod:getParameters()

    if opt.train_from:len() == 0 then
      p:uniform(-opt.param_init, opt.param_init)

      mod:apply(function (m)
        if m.postParametersInitialization then
          m:postParametersInitialization()
        end
      end)
    end

    numParams = numParams + p:size(1)
    table.insert(params, p)
    table.insert(gradParams, gp)
  end

  if verbose then
    _G.logger:info(" * number of parameters: " .. numParams)
  end

  return params, gradParams
end
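
-- Note: getParameters() flattens each module into a single parameter tensor,
-- and sorting the keys above guarantees every replica flattens its modules in
-- the same order, keeping the flat tensors index-aligned across threads.
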
local function buildCriterion(vocabSize, features)
  local criterion = nn.ParallelCriterion(false)

  local function addNllCriterion(size)
    -- Ignores padding value.
    local w = torch.ones(size)
    w[onmt.Constants.PAD] = 0

    local nll = nn.ClassNLLCriterion(w)

    -- Let the training code manage loss normalization.
    nll.sizeAverage = false
    criterion:add(nll)
  end

  addNllCriterion(vocabSize)

  for j = 1, #features do
    addNllCriterion(features[j]:size())
  end

  return criterion
end
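
-- For a target vocabulary and k target features, buildCriterion() returns an
-- nn.ParallelCriterion(false) holding one nn.ClassNLLCriterion per output,
-- each with weight 0 on onmt.Constants.PAD so padding contributes no loss.
-- sizeAverage is disabled so the training loop controls loss normalization.
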
local function eval(model, criterion, data)
  local loss = 0
  local total = 0

  model.encoder:evaluate()
  model.decoder:evaluate()

  for i = 1, data:batchCount() do
    local batch = onmt.utils.Cuda.convert(data:getBatch(i))
    local encoderStates, context = model.encoder:forward(batch)
    loss = loss + model.decoder:computeLoss(batch, encoderStates, context, criterion)
    total = total + batch.targetNonZeros
  end

  model.encoder:training()
  model.decoder:training()

  return math.exp(loss / total)
end
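
-- eval() reports corpus-level perplexity:
--   ppl = exp(sum of per-batch NLL / total non-padding target tokens)
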
local function trainModel(model, trainData, validData, dataset, info)
  local params, gradParams = {}, {}
  local criterion

  onmt.utils.Parallel.launch(function(idx)
    -- Only log information from the first thread.
    local verbose = idx == 1 and not opt.json_log

    _G.params, _G.gradParams = initParams(_G.model, verbose)
    for _, mod in pairs(_G.model) do
      mod:training()
    end

    -- Define the criterion on each GPU.
    _G.criterion = onmt.utils.Cuda.convert(buildCriterion(dataset.dicts.tgt.words:size(),
                                                          dataset.dicts.tgt.features))

    -- Optimize the memory use of the first clone.
    if not opt.disable_mem_optimization then
      local batch = onmt.utils.Cuda.convert(trainData:getBatch(1))
      batch.totalSize = batch.size
      onmt.utils.Memory.optimize(_G.model, _G.criterion, batch, verbose)
    end

    return idx, _G.criterion, _G.params, _G.gradParams
  end, function(idx, thecriterion, theparams, thegradParams)
    if idx == 1 then
      criterion = thecriterion
    end
    params[idx] = theparams
    gradParams[idx] = thegradParams
  end)

  local optim = onmt.train.Optim.new({
    method = opt.optim,
    numModels = #params[1],
    learningRate = opt.learning_rate,
    learningRateDecay = opt.learning_rate_decay,
    startDecayAt = opt.start_decay_at,
    optimStates = opt.optim_states
  })

  local checkpoint = onmt.train.Checkpoint.new(opt, model, optim, dataset.dicts)
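
  -- Per the -learning_rate help text above: sgd = 1, adagrad = 0.1,
  -- adadelta = 1, adam = 0.0002 are the recommended starting points. Note
  -- that only sgd uses the validation-based decay schedule in the epoch
  -- loop below.
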
  local function trainEpoch(epoch, lastValidPpl)
    local epochState
    local batchOrder

    local startI = opt.start_iteration

    local numIterations = trainData:batchCount()
    -- In parallel mode, number of iterations is reduced to reflect larger batch size.
    if onmt.utils.Parallel.count > 1 and not opt.async_parallel then
      numIterations = math.ceil(numIterations / onmt.utils.Parallel.count)
    end

    if startI > 1 and info ~= nil then
      epochState = onmt.train.EpochState.new(epoch, numIterations, optim:getLearningRate(), lastValidPpl, info.epochStatus)
      batchOrder = info.batchOrder
    else
      epochState = onmt.train.EpochState.new(epoch, numIterations, optim:getLearningRate(), lastValidPpl)
      -- Shuffle mini batch order.
      batchOrder = torch.randperm(trainData:batchCount())
    end

    opt.start_iteration = 1
    local function trainNetwork(batch)
      optim:zeroGrad(_G.gradParams)

      local encStates, context = _G.model.encoder:forward(batch)
      local decOutputs = _G.model.decoder:forward(batch, encStates, context)
      local encGradStatesOut, gradContext, loss = _G.model.decoder:backward(batch, decOutputs, _G.criterion)
      _G.model.encoder:backward(batch, encGradStatesOut, gradContext)

      return loss
    end
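
    -- One training step: encoder forward -> decoder forward -> decoder
    -- backward (yielding gradients w.r.t. the encoder states and context) ->
    -- encoder backward. Gradient clipping and the parameter update are left
    -- to the caller.
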
    if not opt.async_parallel then
      local iter = 1
      for i = startI, trainData:batchCount(), onmt.utils.Parallel.count do
        local batches = {}
        local totalSize = 0

        for j = 1, math.min(onmt.utils.Parallel.count, trainData:batchCount() - i + 1) do
          local batchIdx = batchOrder[i + j - 1]
          if epoch <= opt.curriculum then
            batchIdx = i + j - 1
          end
          table.insert(batches, trainData:getBatch(batchIdx))
          totalSize = totalSize + batches[#batches].size
        end

        local losses = {}

        onmt.utils.Parallel.launch(function(idx)
          _G.batch = batches[idx]
          if _G.batch == nil then
            return idx, 0
          end

          -- Send batch data to the GPU.
          onmt.utils.Cuda.convert(_G.batch)
          _G.batch.totalSize = totalSize

          local loss = trainNetwork(_G.batch)

          return idx, loss
        end,
        function(idx, loss)
          losses[idx] = loss
        end)

        -- Accumulate the gradients from the different parallel threads.
        onmt.utils.Parallel.accGradParams(gradParams, batches)

        -- Update the parameters.
        optim:prepareGrad(gradParams[1], opt.max_grad_norm)
        optim:updateParams(params[1], gradParams[1])

        -- Synchronize the parameters with the different parallel threads.
        onmt.utils.Parallel.syncParams(params)

        for bi = 1, #batches do
          epochState:update(batches[bi], losses[bi])
        end

        if iter % opt.report_every == 0 then
          epochState:log(iter, opt.json_log)
        end
        if opt.save_every > 0 and iter % opt.save_every == 0 then
          checkpoint:saveIteration(iter, epochState, batchOrder, not opt.json_log)
        end
        iter = iter + 1
      end
    else
      -- Asynchronous parallel.
      local counter = onmt.utils.Parallel.getCounter()
      counter:set(startI)
      local masterGPU = onmt.utils.Cuda.gpuIds[1]
      local gradBuffer = onmt.utils.Parallel.gradBuffer
      local gmutexId = onmt.utils.Parallel.gmutexId()
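
      -- Worker threads pull batch indices from the shared counter and merge
      -- their gradients into the master parameters on the first GPU as soon
      -- as each batch finishes, serialized by the global mutex, instead of
      -- stepping in lockstep as in the synchronous branch above.
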
      while counter:get() <= trainData:batchCount() do
        local startCounter = counter:get()

        onmt.utils.Parallel.launch(function(idx)
          -- The first GPU is reserved for the master parameters; during the
          -- first epoch, threads beyond the second stay idle until
          -- opt.async_parallel_minbatch batches have been processed.
          if idx == 1 or (idx > 2 and epoch == 1 and counter:get() < opt.async_parallel_minbatch) then
            return
          end

          local lossThread = 0
          local batchThread = {
            size = 1,
            sourceLength = 0,
            targetLength = 0,
            targetNonZeros = 0
          }

          while true do
            -- Do not process more than 1000 batches (TODO - make this an option) in one shot.
            if counter:get() - startCounter >= 1000 then
              return lossThread, batchThread
            end

            local i = counter:inc()
            if i > trainData:batchCount() then
              return lossThread, batchThread
            end

            local batchIdx = batchOrder[i]
            if epoch <= opt.curriculum then
              batchIdx = i
            end

            _G.batch = trainData:getBatch(batchIdx)

            -- Send batch data to the GPU.
            onmt.utils.Cuda.convert(_G.batch)
            _G.batch.totalSize = _G.batch.size

            local loss = trainNetwork(_G.batch)

            -- Update the parameters.
            optim:prepareGrad(_G.gradParams, opt.max_grad_norm)

            -- Add up gradParams to params and synchronize them back to this thread.
            onmt.utils.Parallel.updateAndSync(params[1], _G.gradParams, _G.params, gradBuffer, masterGPU, gmutexId)

            batchThread.sourceLength = batchThread.sourceLength + _G.batch.sourceLength * _G.batch.size
            batchThread.targetLength = batchThread.targetLength + _G.batch.targetLength * _G.batch.size
            batchThread.targetNonZeros = batchThread.targetNonZeros + _G.batch.targetNonZeros
            lossThread = lossThread + loss

            -- We have no information about the other threads here, so we can only report this thread's progress.
            if i % opt.report_every == 0 then
              _G.logger:info('Epoch %d ; ... batch %d/%d', epoch, i, trainData:batchCount())
            end
          end
        end,
        function(theloss, thebatch)
          if theloss then
            epochState:update(thebatch, theloss)
          end
        end)

        if opt.report_every > 0 then
          epochState:log(counter:get(), opt.json_log)
        end
        if opt.save_every > 0 then
          checkpoint:saveIteration(counter:get(), epochState, batchOrder, not opt.json_log)
        end
      end
    end
    return epochState
  end

  local validPpl = 0

  if not opt.json_log then
    _G.logger:info('Start training...')
  end

  for epoch = opt.start_epoch, opt.end_epoch do
    if not opt.json_log then
      _G.logger:info('')
    end

    local epochState = trainEpoch(epoch, validPpl)

    validPpl = eval(model, criterion, validData)

    if not opt.json_log then
      _G.logger:info('Validation perplexity: %.2f', validPpl)
    end

    if opt.optim == 'sgd' then
      optim:updateLearningRate(validPpl, epoch)
    end

    checkpoint:saveEpoch(validPpl, epochState, not opt.json_log)
  end
end
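
-- main() below wires everything together: validate the required options,
-- load the preprocessed dataset, build or restore a model replica on each
-- GPU thread, then hand off to trainModel().
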
local function main()
  local requiredOptions = {
    "data",
    "save_model"
  }

  onmt.utils.Opt.init(opt, requiredOptions)
  _G.logger = onmt.utils.Logger.new(opt.log_file, opt.disable_logs, opt.log_level)
  onmt.utils.Cuda.init(opt)
  onmt.utils.Parallel.init(opt)

  local checkpoint = {}

  if opt.train_from:len() > 0 then
    assert(path.exists(opt.train_from), 'checkpoint path invalid')

    if not opt.json_log then
      _G.logger:info('Loading checkpoint \'' .. opt.train_from .. '\'...')
    end

    checkpoint = torch.load(opt.train_from)

    opt.layers = checkpoint.options.layers
    opt.rnn_size = checkpoint.options.rnn_size
    opt.brnn = checkpoint.options.brnn
    opt.brnn_merge = checkpoint.options.brnn_merge
    opt.input_feed = checkpoint.options.input_feed

    -- Resume training from the checkpoint.
    if opt.continue then
      opt.optim = checkpoint.options.optim
      opt.learning_rate_decay = checkpoint.options.learning_rate_decay
      opt.start_decay_at = checkpoint.options.start_decay_at
      opt.curriculum = checkpoint.options.curriculum

      opt.learning_rate = checkpoint.info.learningRate
      opt.optim_states = checkpoint.info.optimStates
      opt.start_epoch = checkpoint.info.epoch
      opt.start_iteration = checkpoint.info.iteration

      if not opt.json_log then
        _G.logger:info('Resuming training from epoch ' .. opt.start_epoch
                         .. ' at iteration ' .. opt.start_iteration .. '...')
      end
    end
  end
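
  -- Note: the model topology options (layers, rnn_size, brnn, ...) always
  -- come from the checkpoint; -continue additionally restores the optimizer
  -- state and schedule so training resumes where it stopped.
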
  -- Create the data loader class.
  if not opt.json_log then
    _G.logger:info('Loading data from \'' .. opt.data .. '\'...')
  end

  local dataset = torch.load(opt.data, 'binary', false)

  local trainData = onmt.data.Dataset.new(dataset.train.src, dataset.train.tgt)
  local validData = onmt.data.Dataset.new(dataset.valid.src, dataset.valid.tgt)

  trainData:setBatchSize(opt.max_batch_size)
  validData:setBatchSize(opt.max_batch_size)

  if not opt.json_log then
    _G.logger:info(' * vocabulary size: source = %d; target = %d',
                   dataset.dicts.src.words:size(), dataset.dicts.tgt.words:size())
    _G.logger:info(' * additional features: source = %d; target = %d',
                   #dataset.dicts.src.features, #dataset.dicts.tgt.features)
    _G.logger:info(' * maximum sequence length: source = %d; target = %d',
                   trainData.maxSourceLength, trainData.maxTargetLength)
    _G.logger:info(' * number of training sentences: %d', #trainData.src)
    _G.logger:info(' * maximum batch size: %d', opt.max_batch_size)
  else
    local metadata = {
      options = opt,
      vocabSize = {
        source = dataset.dicts.src.words:size(),
        target = dataset.dicts.tgt.words:size()
      },
      additionalFeatures = {
        source = #dataset.dicts.src.features,
        target = #dataset.dicts.tgt.features
      },
      sequenceLength = {
        source = trainData.maxSourceLength,
        target = trainData.maxTargetLength
      },
      trainingSentences = #trainData.src
    }

    onmt.utils.Log.logJson(metadata)
  end
  if not opt.json_log then
    _G.logger:info('Building model...')
  end

  local model

  onmt.utils.Parallel.launch(function(idx)
    _G.model = {}

    if checkpoint.models then
      _G.model.encoder = onmt.Models.loadEncoder(checkpoint.models.encoder, idx > 1)
      _G.model.decoder = onmt.Models.loadDecoder(checkpoint.models.decoder, idx > 1)
    else
      local verbose = idx == 1 and not opt.json_log
      _G.model.encoder = onmt.Models.buildEncoder(opt, dataset.dicts.src)
      _G.model.decoder = onmt.Models.buildDecoder(opt, dataset.dicts.tgt, verbose)
    end

    for _, mod in pairs(_G.model) do
      onmt.utils.Cuda.convert(mod)
    end

    return idx, _G.model
  end, function(idx, themodel)
    if idx == 1 then
      model = themodel
    end
  end)
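
  -- Thread 1 keeps the handle on its replica; that replica is the one used
  -- for validation in trainModel() and serialized into checkpoints.
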
  trainModel(model, trainData, validData, dataset, checkpoint.info)

  _G.logger:shutDown()
end
main()