package org.deeplearning4j.examples.recurrent.character;

import org.apache.commons.io.FileUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.nio.charset.Charset;
import java.util.Random;

/**GravesLSTM - przykład modelowania znaków.
 * @autor Alex Black

    Trenowanie sieci LSTM RNN generującej tekst znak po znaku.
    Przykład jest inspirowany wpisem na blogu Andreja Karpathy'ego
    "The Unreasonable Effectiveness of Recurrent Neural Networks"
    http://karpathy.github.io/2015/05/21/rnn-effectiveness/

    Sieć jest trenowana na wszystkich dziełach Williama Szekspira pobranych z Projektu Gutenberg. Przystosowanie go do trenowania na innych tekstach powinno być stosunkowo łatwe.
    Więcej szczegółowych informacji o sieciach RNN i bibliotece DL4J znajdziesz na stronach:
    http://deeplearning4j.org/usingrnns
    http://deeplearning4j.org/lstm
    http://deeplearning4j.org/recurrentnetwork
 */
public class GravesLSTMCharModellingExample {
	public static void main( String[] args ) throws Exception {
		int lstmLayerSize = 200;					// Liczba jednostek w każdej warstwie GravesLSTM.
		int miniBatchSize = 32;						// Wielkość minipaczki wykorzystywanej w treningu.
		int exampleLength = 1000;					// Długość sekwencji treningowej, oczywiście można ją zwiększyć.
        int tbpttLength = 50;         // Długość przyciętej propagacji wstecznej (aktualizacja parametrów co 50 znaków).
		int numEpochs = 1;							  // Całkowita liczba epok treningowych.
        int generateSamplesEveryNMinibatches = 10;  // Częstotliwość generowania próbek: 1000 znaków / 50 dł. propagacji = 20 aktualizacji parametrów na minipaczkę.
		int nSamplesToGenerate = 4;					  // Liczba próbek generowanych po każdej epoce treningowej.
		int nCharactersToSample = 300;				// Długość generowanej próbki.
		String generationInitialization = null;		// Opcjonalna inicjalizacja znaków; jeżeli null, wybierany jest losowy znak.
		// Powyższe instrukcje inicjują sieć LSTM sekwencją znaków.
		// Wszystkie znaki inicjujące muszą być domyślnie zwracane przez metodę CharacterIterator.getMinimalCharacterSet().
		Random rng = new Random(12345);

		// Utworzenie iteratora przekształcającego tekst na format wektorowy, który można wykorzystać
		// do przetrenowania sieci GravesLSTM.
		CharacterIterator iter = getShakespeareIterator(miniBatchSize,exampleLength);
		int nOut = iter.totalOutcomes();

		// Konfiguracja sieci.
		MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
			.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
			.learningRate(0.1)
			.rmsDecay(0.95)
			.seed(12345)
			.regularization(true)
			.l2(0.001)
            .weightInit(WeightInit.XAVIER)
            .updater(Updater.RMSPROP)
			.list()
			.layer(0, new GravesLSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize)
					.activation(Activation.TANH).build())
			.layer(1, new GravesLSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
					.activation(Activation.TANH).build())
			.layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX)        // MCXENT + softmax do klasyfikacji.
					.nIn(lstmLayerSize).nOut(nOut).build())
            .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
			.pretrain(false).backprop(true)
			.build();

		MultiLayerNetwork net = new MultiLayerNetwork(conf);
		net.init();
		net.setListeners(new ScoreIterationListener(1));

		// Wyświetlenie liczby parametrów dla całej sieci i dla poszczególnych warstw.
		Layer[] layers = net.getLayers();
		int totalNumParams = 0;
		for( int i=0; i<layers.length; i++ ){
			int nParams = layers[i].numParams();
			System.out.println("Liczba parametrów warstwy " + i + ": " + nParams);
			totalNumParams += nParams;
		}
		System.out.println("Całkowita liczba parametrów sieci: " + totalNumParams);

		// Trening sieci, a następnie wygenerowanie i wyświetlenie próbek.
    int miniBatchNumber = 0;
		for( int i=0; i<numEpochs; i++ ){
            while(iter.hasNext()){
                DataSet ds = iter.next();
                net.fit(ds);
                if(++miniBatchNumber % generateSamplesEveryNMinibatches == 0){
                    System.out.println("--------------------");
                    System.out.println("Przetworzonych " + miniBatchNumber + " minipaczek o wielkości " + miniBatchSize + "x" + exampleLength + " znaków");
                    System.out.println("Znaki próbkujące zdefiniowane podczas inicjalizacji \"" + (generationInitialization == null ? "" : generationInitialization) + "\"");
                    String[] samples = sampleCharactersFromNetwork(generationInitialization,net,iter,rng,nCharactersToSample,nSamplesToGenerate);
                    for( int j=0; j<samples.length; j++ ){
                        System.out.println("----- Próbka " + j + " -----");
                        System.out.println(samples[j]);
                        System.out.println();
                    }
                }
            }

			iter.reset();	// Przygotowanie iteratora do następnej epoki.
		}

		System.out.println("\n\nKoniec przykładu");
	}

  /** Pobranie danych treningowych i zapisanie na lokalnym dysku (w katalogu temp).
   Następnie skonfigurowanie i zwrócenie prostego iteratora DataSetIterator wektoryzującego tekst.
   * @param miniBatchSize - liczba segmentów tekstu w każdej minipaczce treningowej.
   * @param sequenceLength - liczba znaków w każdym segmencie tekstu.
   */
	public static CharacterIterator getShakespeareIterator(int miniBatchSize, int sequenceLength) throws Exception{
    // Dzieła wszystkie Williama Szekspira.
    // Plik 5,3 MB w formacie UTF-8, ok. 5,4 mln znaków.
    // https://www.gutenberg.org/ebooks/100
		String url = "https://s3.amazonaws.com/dl4j-distribution/pg100.txt";
		String tempDir = System.getProperty("java.io.tmpdir");
		String fileLocation = tempDir + "/Shakespeare.txt";	//Storage location from downloaded file
		File f = new File(fileLocation);
		if( !f.exists() ){
			FileUtils.copyURLToFile(new URL(url), f);
			System.out.println("Pobrany plik zapisany w " + f.getAbsolutePath());
		} else {
			System.out.println("Użyty istniejący plik zapisany w " + f.getAbsolutePath());
		}

		if(!f.exists()) throw new IOException("Plik nie istnieje: " + fileLocation); // Problem z pobraniem?

		char[] validCharacters = CharacterIterator.getMinimalCharacterSet();	// Jakie znaki są dozwolone? Inne zostaną usunięte.
		return new CharacterIterator(fileLocation, Charset.forName("UTF-8"),
				miniBatchSize, sequenceLength, validCharacters, new Random(12345));
	}

  /** Wygenerowanie próbki przez sieć dla zadanych (opcjonalnych, w tym również pustych) parametrów inicjalizacyjnych.
   Inicjalizacja może być wykonana w celu przekazania sieci RNN sekwencji, którą ma rozszerzyć/kontynuować. Zwróć uwagę,
   że inicjalizowane są wszystkie próbki.
   * @param initialization - ciąg znaków; jeżeli null, znak inicjujący jest wybierany losowo dla wszystkich próbek.
   * @param charactersToSample - liczba znaków w próbce (z wyłączeniem znaków inicjujących).
   * @param net - sieć składająca się z jednej lub kilku warstw GravesLSTM/RNN i warstwy wyjściowej z funkcją softmax.
   * @param iter - iterator wykorzystywany do przekształcania indeksów z powrotem na znaki.
   */
	private static String[] sampleCharactersFromNetwork(String initialization, MultiLayerNetwork net,
                                                      CharacterIterator iter, Random rng, int charactersToSample, int numSamples ){
		// Inicjalizacja lub użycie losowego znaku.
		if( initialization == null ){
			initialization = String.valueOf(iter.getRandomCharacter());
		}

		// Utworzenie danych inicjujących.
		INDArray initializationInput = Nd4j.zeros(numSamples, iter.inputColumns(), initialization.length());
		char[] init = initialization.toCharArray();
		for( int i=0; i<init.length; i++ ){
			int idx = iter.convertCharacterToIndex(init[i]);
			for( int j=0; j<numSamples; j++ ){
				initializationInput.putScalar(new int[]{j,idx,i}, 1.0f);
			}
		}

		StringBuilder[] sb = new StringBuilder[numSamples];
		for( int i=0; i<numSamples; i++ ) sb[i] = new StringBuilder(initialization);

    // Próbki z sieci (i przesłanie ich ponownie na wejście), znak po znaku (dla wszystkich próbek).
    // Próbkowanie jest zrównoleglone.
		net.rnnClearPreviousState();
		INDArray output = net.rnnTimeStep(initializationInput);
		output = output.tensorAlongDimension(output.size(2)-1,1,0);	// Odczytanie wyników z ostatniego kroku czasowego.

		for( int i=0; i<charactersToSample; i++ ){
      // Przygotowanie następnych danych wejściowych (pojedynczego kroku czasowego) poprzez próbkowanie poprzedniego wyniku.
			INDArray nextInput = Nd4j.zeros(numSamples,iter.inputColumns());
			// Wynikiem jest rozkład prawdopodobieństwa. Próbkowanie każdego przykładu, który trzeba wygenerować, i dodanie go do nowego wejścia
			for( int s=0; s<numSamples; s++ ){
				double[] outputProbDistribution = new double[iter.totalOutcomes()];
				for( int j=0; j<outputProbDistribution.length; j++ ) outputProbDistribution[j] = output.getDouble(s,j);
				int sampledCharacterIdx = sampleFromDistribution(outputProbDistribution,rng);

				nextInput.putScalar(new int[]{s,sampledCharacterIdx}, 1.0f);		  // Przygotowanie następnego kroku czasowego.
				sb[s].append(iter.convertIndexToCharacter(sampledCharacterIdx));	// Dodanie próbkowanego znaku do obiektu StringBuilder (wyniku czytelnego dla człowieka).
			}

			output = net.rnnTimeStep(nextInput);	// Czy wykonać następny krok w przód?
		}

		String[] out = new String[numSamples];
		for( int i=0; i<numSamples; i++ ) out[i] = sb[i].toString();
		return out;
	}

  /** Próbkowanie zadanego rozkładu prawdopodobieństwa dyskretnych klas i zwrócenie wygenerowanego indeksu klasy.
   * @param distribution - rozkład prawdopodobieństwa klas; musi sumować się do 1,0.
   */
	public static int sampleFromDistribution( double[] distribution, Random rng ){
		double d = rng.nextDouble();
		double sum = 0.0;
		for( int i=0; i<distribution.length; i++ ){
			sum += distribution[i];
			if( d <= sum ) return i;
		}
		// Jeżeli rozkład prawdopodobieństwa jest poprawny, poniższa instrukcja nie powinna być nigdy wykonana.
		throw new IllegalArgumentException("Błędny rozkład prawdopodobieństwa? d="+d+", sum="+sum);
	}
}
