Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions train.lua
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ cmd:option('-model', 'model', 'neural network model')
cmd:option('-data', 'data.t7', 'training data')
cmd:option('-learningRate', 0.01, 'learning rate')
cmd:option('-initweights', '', 'initial weights')
cmd:option('-initstate', '', 'initial optimizer state')


params = cmd:parse(arg)

Expand Down Expand Up @@ -144,13 +146,20 @@ function trainModel(weights)
return {cost}, dw
end

-- create directory to save weights and videos
-- create directory to save weights, optimizer state and videos
lfs.mkdir('weights_' .. params.model)
lfs.mkdir('state_' .. params.model)
lfs.mkdir('video_' .. params.model)

local total_cost, config, state = 0, { learningRate = params.learningRate }, {}
collectgarbage()

if #params.initstate > 0 then
print('Loading optimizer state ' .. params.initstate)
state=torch.load(params.initstate)
end


for k = 1,params.iter do
xlua.progress(k, params.iter)
local _, cost = optim.adagrad(trainModel, w, config, state)
Expand All @@ -159,8 +168,10 @@ for k = 1,params.iter do
if k % 1000 == 0 then
print('Iteration ' .. k .. ', cost: ' .. total_cost / 1000)
total_cost = 0
-- save weights
-- save weights and optimizer state
torch.save('weights_' .. params.model .. '/' .. k .. '.dat', w:type('torch.FloatTensor'))
torch.save('state_' .. params.model .. '/' .. k .. '.dat', state)

-- visualise performance
evalModel(w)
end
Expand Down