TrainingContinuation

emacav

Hello,
first of all, thanks for this beautiful project.

I'm having trouble with TrainingContinuation. Here is my code:

...
if (createNew) {                      // first run: build a fresh network
    predictor.createNetwork();
} else {                              // later runs: reload the saved network
    predictor.loadNeuralNetwork();
}
predictor.trainResilientContinuation();
...

public void trainResilientContinuation() throws IOException {
    System.out.println("trainResilientContinuation start!");

    final ResilientPropagation train = new ResilientPropagation(network, this.trainDataSet);

    // resume a paused training session if one was saved
    try {
        TrainingContinuation state = (TrainingContinuation) SerializeObject.load(RPROP_SAVE);
        if (state != null && train.isValidResume(state)) {
            train.resume(state);
        }
    } catch (Exception exx) {
        // no saved state yet (first run) or the file could not be read
        exx.printStackTrace();
    }

    int epoch = 1;
    do {
        train.iteration();
        if (epoch % EP_COUNT == 0) {
            System.out.println("Iteration(ResilientCont) #" + epoch + " Error:" + train.getError());
        }
        epoch++;
    } while (epoch < EPOCH && train.getError() > MAX_ERROR);

    System.out.println("Iteration(ResilientCont) #" + epoch + " Last Error:" + train.getError());

    // save the training state so the next run can resume from it
    TrainingContinuation cont = train.pause();
    SerializeObject.save(RPROP_SAVE, cont);
}

My network trains for 50000 epochs and reaches an error of 0.06, but when I restart training, after 500 epochs the error is 0.21!
While debugging, I found that

System.out.println("LastGradient " + DumpMatrix.dumpArray((double[]) cont.get(ResilientPropagation.LAST_GRADIENTS))
        + " - UpdVals " + DumpMatrix.dumpArray((double[]) cont.get(ResilientPropagation.UPDATE_VALUES)));

always prints
LastGradient [0,0,0,0...] - UpdVals [0.1,0.1,0.1...]

Am I missing something?
Thanks in advance
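The check described above (every last gradient 0, every update value 0.1) can be written as a small helper. This is a hypothetical sketch, not part of Encog: the class and method names are made up, and it assumes 0.1 is the initial RPROP update value, which is the conventional RPROP default and matches the dump above.

```java
import java.util.Arrays;

/** Hypothetical helper, not an Encog API: does a paused RPROP state look untouched? */
public class ContinuationCheck {
    // Assumption: the trainer starts every update value at 0.1 (the usual RPROP default).
    static final double ASSUMED_INITIAL_UPDATE = 0.1;

    /** True when all last gradients are 0 and all update values still equal the initial step. */
    static boolean looksLikeFreshState(double[] lastGradients, double[] updateValues) {
        boolean gradientsAllZero = Arrays.stream(lastGradients).allMatch(g -> g == 0.0);
        boolean updatesAllInitial = Arrays.stream(updateValues).allMatch(u -> u == ASSUMED_INITIAL_UPDATE);
        return gradientsAllZero && updatesAllInitial;
    }

    public static void main(String[] args) {
        // Shaped like the dump in the post: nothing learned yet.
        System.out.println(looksLikeFreshState(new double[] {0, 0, 0}, new double[] {0.1, 0.1, 0.1})); // true
        // Made-up values standing in for a state after real training.
        System.out.println(looksLikeFreshState(new double[] {0.003, 0, -0.012}, new double[] {1e-4, 0.05, 2.3e-3})); // false
    }
}
```

In the method above it could be called on `(double[]) cont.get(ResilientPropagation.LAST_GRADIENTS)` and `(double[]) cont.get(ResilientPropagation.UPDATE_VALUES)` right after `train.pause()`; a `true` result means the paused state carries nothing that a newly constructed trainer would not already have.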

jeffheaton

It seems to work for me. I have checked an example in under "resume".

Here it is.

public static double[][] XOR_INPUT = { { 0.0, 0.0 }, { 1.0, 0.0 },
        { 0.0, 1.0 }, { 1.0, 1.0 } };

public static double[][] XOR_IDEAL = { { 0.0 }, { 1.0 }, { 1.0 }, { 0.0 } };

public static void main(String[] args)
{
    Logging.stopConsoleLogging();
    NeuralDataSet trainingSet = new BasicNeuralDataSet(XOR_INPUT, XOR_IDEAL);
    BasicNetwork network = EncogUtility.simpleFeedForward(2, 4, 0, 1, false);
    ResilientPropagation train = new ResilientPropagation(network, trainingSet);
    train.addStrategy(new RequiredImprovementStrategy(5));

    // first session: train XOR to 1% error, then capture RPROP's state
    System.out.println("Perform initial train.");
    EncogUtility.trainToError(train, network, trainingSet, 0.01);
    TrainingContinuation cont = train.pause();
    System.out.println(Arrays.toString((double[]) cont.getContents().get(ResilientPropagation.LAST_GRADIENTS)));
    System.out.println(Arrays.toString((double[]) cont.getContents().get(ResilientPropagation.UPDATE_VALUES)));

    // write the state to disk and read it back, as a restarted program would
    try
    {
        SerializeObject.save("resume.ser", cont);
        cont = (TrainingContinuation) SerializeObject.load("resume.ser");
    }
    catch (Exception ex)
    {
        ex.printStackTrace();
    }

    // second session: a new trainer resumes from the saved state
    System.out.println("Now trying a second train, with continue from the first. Should stop after one iteration");
    ResilientPropagation train2 = new ResilientPropagation(network, trainingSet);
    train2.resume(cont);
    EncogUtility.trainToError(train2, network, trainingSet, 0.01);
}
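The example above saves the continuation and immediately loads it back. The same round trip can be checked without Encog using plain `java.io` object serialization, which `SerializeObject` is assumed (not confirmed here) to use underneath. In this sketch a `HashMap` of `double[]` stands in for the continuation's contents; the class name and the key strings are hypothetical.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Arrays;
import java.util.HashMap;

/** Hypothetical stand-in for saving and reloading a training state through a file. */
public class RoundTripSketch {
    /** Serializes the map to bytes and reads it back with standard Java serialization. */
    @SuppressWarnings("unchecked")
    static HashMap<String, Object> roundTrip(HashMap<String, Object> state)
            throws IOException, ClassNotFoundException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(state); // closing the stream flushes everything into 'bytes'
        }
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
            return (HashMap<String, Object>) in.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        HashMap<String, Object> state = new HashMap<>();
        // Hypothetical keys and made-up values; Encog uses its own ResilientPropagation constants.
        state.put("LAST_GRADIENTS", new double[] {0.003, 0.0, -0.012});
        state.put("UPDATE_VALUES", new double[] {1e-4, 0.05, 2.3e-3});

        HashMap<String, Object> loaded = roundTrip(state);
        for (String key : state.keySet()) {
            boolean same = Arrays.equals((double[]) state.get(key), (double[]) loaded.get(key));
            System.out.println(key + " survived the round trip: " + same);
        }
    }
}
```

If the arrays come back intact here but the dump after `pause()` already shows only zeros and 0.1, the values were lost before saving, not during it.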

emacav

Hi Jeff,
your example works for me too.
I took a quick look at how the error evolves. After a first training run of 1000 epochs, this is what I got when I restarted training:

trainResilientContinuation start!
Resuming..
Network original error: 0.17342437045200296 - restart training...

Iteration(ResilientCont) #1 Error: 0.17342437045200296
Iteration(ResilientCont) #2 Error: 0.17342437045200296
Iteration(ResilientCont) #3 Error: 1.1358152596527957
Iteration(ResilientCont) #4 Error: 1.115804842553766
Iteration(ResilientCont) #5 Error: 1.085212389370948
Iteration(ResilientCont) #6 Error: 0.8472103169101813
Iteration(ResilientCont) #7 Error: 0.8910511705962119
Iteration(ResilientCont) #8 Error: 0.879940442056919
Iteration(ResilientCont) #9 Error: 0.6803667910355601
Iteration(ResilientCont) #10 Error: 0.47229788788210464
...
Iteration(ResilientCont) #994 Error: 0.21651797633733597
Iteration(ResilientCont) #995 Error: 0.21651698679797868
Iteration(ResilientCont) #996 Error: 0.2165125225059318
Iteration(ResilientCont) #997 Error: 0.21651110965863957
Iteration(ResilientCont) #998 Error: 0.21650792180297765
Iteration(ResilientCont) #999 Error: 0.2165051289375193
Iteration(ResilientCont) #1000 Last Error: 0.2165051289375193

Initially the error is the same as the old network's, but I can't understand why it jumps to 1.1358152596527957 on the 3rd iteration.
After 2000 iterations in total, my network is worse than it was after the first 1000.
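For reference, this is roughly how RPROP consumes the two arrays a `TrainingContinuation` stores (`LAST_GRADIENTS` and `UPDATE_VALUES`). The sketch is the textbook iRprop- rule on a one-weight toy problem, E(w) = (w - 3)^2, not Encog's source, whose details may differ; the constants (1.2, 0.5, step limits, initial step 0.1) are the usual published defaults, and all names are hypothetical. It trains for 100 steps, then compares one extra step taken with the saved per-weight state against one taken with the initial state (last gradient 0, update value 0.1), the values in the dump from the first post.

```java
import java.util.Arrays;

/** Textbook iRprop- on a toy problem; illustrative only, not Encog's implementation. */
public class RpropSketch {
    static final double ETA_PLUS = 1.2, ETA_MINUS = 0.5;
    static final double DELTA_MAX = 50.0, DELTA_MIN = 1e-6;
    static final double DELTA_INITIAL = 0.1; // conventional initial step size
    static final double TARGET = 3.0;        // toy problem: E(w) = (w - 3)^2

    static double error(double w) { return (w - TARGET) * (w - TARGET); }
    static double gradient(double w) { return 2.0 * (w - TARGET); }

    /** One iRprop- step on every weight; mutates w, lastGrad and delta in place. */
    static void step(double[] w, double[] lastGrad, double[] delta) {
        for (int i = 0; i < w.length; i++) {
            double g = gradient(w[i]);
            double change = g * lastGrad[i];
            if (change > 0) {            // same sign as last time: grow the step
                delta[i] = Math.min(delta[i] * ETA_PLUS, DELTA_MAX);
            } else if (change < 0) {     // sign flip: shrink the step and skip this update
                delta[i] = Math.max(delta[i] * ETA_MINUS, DELTA_MIN);
                g = 0.0;
            }
            w[i] -= Math.signum(g) * delta[i]; // move against the gradient by delta
            lastGrad[i] = g;
        }
    }

    /** Returns {error after a step with the saved state, error after a step with reset state}. */
    static double[] demo() {
        double[] w = {0.0};
        double[] lastGrad = {0.0};
        double[] delta = {DELTA_INITIAL};
        for (int epoch = 0; epoch < 100; epoch++) {
            step(w, lastGrad, delta);
        }

        // Continue with copies of the saved arrays (what a successful resume restores).
        double[] wResumed = w.clone();
        step(wResumed, lastGrad.clone(), delta.clone());

        // Continue with reset arrays: last gradients 0, update values 0.1.
        double[] wReset = w.clone();
        double[] zeroGrad = new double[w.length];
        double[] initialDelta = new double[w.length];
        Arrays.fill(initialDelta, DELTA_INITIAL);
        step(wReset, zeroGrad, initialDelta);

        return new double[] {error(wResumed[0]), error(wReset[0])};
    }

    public static void main(String[] args) {
        double[] r = demo();
        System.out.println("step with saved state:   error = " + r[0]);
        System.out.println("step with reset 0 / 0.1: error = " + r[1]);
    }
}
```

With the saved state the step stays tiny, while the reset state moves a nearly converged weight by a full 0.1, so the error rises instead of falling; in a real network every weight with a nonzero gradient takes such a step at once.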

Many thanks for your time


Copyright 2005 - 2012 by Heaton Research, Inc. Heaton Research™ and Encog™ are trademarks of Heaton Research.