/*
 * Copyright (C) 2008-2011 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*
 * Copyright (C) 2011 OgakiSoft
 * 
 * Merged AbstractLearner and InstanceLearner.
 * Changed to cooperate with SvmWrapper.
 * Added some methods serialize(), deserialize(), saveSvmData(), etc. 
 */

package ogakisoft.android.gesture.reform;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.TreeMap;

import ogakisoft.android.util.LOG;
import ogakisoft.android.util.Utils;

/**
 * Learner: an in-memory store of gesture {@link Instance}s that can classify a
 * feature vector against the stored instances and persist itself to a file.
 * <p>
 * Besides the instance list, the learner keeps an index that maps a compact
 * int id to the long gesture-id of an instance (see {@link #containsKey(int)},
 * {@link #getLabel(int)} and {@link #getInstance(int)}).
 * <p>
 * Both collections are wrapped with the {@code Collections.synchronized*}
 * helpers. Compound operations additionally lock the collections themselves,
 * always the id index first and the instance list second.
 * 
 * @author The Android Open Source Project
 * @author noritoshi ogaki
 * @version 1.0
 */
public class Learner {
	/** Log tag */
	private static final String TAG = "Learner";
	/** List of instance */
	private final List<Instance> mInstances = Collections
			.synchronizedList(new ArrayList<Instance>());
	/** Index of integer-id: int id -> long gesture-id */
	private final Map<Integer, Long> mId = Collections
			.synchronizedMap(new HashMap<Integer, Long>());
	/** Max number of index */
	private static final int MAX_NUM_OF_INDEX = 100000;
	/** A difference stroke-count with target and Learner.Instance's threshold */
	public static final int STROKE_COUNT_THRESHOLD = 2;
	/** line-feed string */
	private static final String LINE_FEED = "\n";
	/** Character set used for the text written by saveSvmData() */
	private static final String CHARSET = "UTF-8";
	/** Source of replacement int ids when the derived id is already taken */
	private final Random mRandom = new Random();

	/**
	 * Creates an empty learner.
	 * 
	 * @param fileno
	 *            file-number; currently unused, kept so that existing callers
	 *            keep compiling
	 */
	public Learner(int fileno) {
		// nothing to initialize
	}

	/**
	 * Add an instance to the learner and register it in the id index.
	 * 
	 * @param instance
	 *            instance to add; {@code null} is ignored
	 */
	public void addInstance(Instance instance) {
		if (null == instance) {
			return;
		}
		// Choosing a free int id and registering it must be one atomic step,
		// otherwise two threads could pick the same int id.
		synchronized (mId) {
			mInstances.add(instance);
			final int intId = convertId(instance.id);
			mId.put(Integer.valueOf(intId), Long.valueOf(instance.id));
		}
	}

	/**
	 * Retrieve all the instances
	 * 
	 * @return the learner's own (synchronized) list, not a copy; callers that
	 *         iterate over it must synchronize on the returned list
	 */
	public List<Instance> getInstances() {
		return mInstances;
	}

	/**
	 * Remove an instance based on its id. The first instance with that id is
	 * removed from the list and its entry is removed from the id index.
	 * 
	 * @param id
	 *            gesture-id
	 */
	public void removeInstance(long id) {
		synchronized (mId) {
			synchronized (mInstances) {
				for (final Iterator<Instance> it = mInstances.iterator(); it
						.hasNext();) {
					if (id == it.next().id) {
						it.remove();
						break;
					}
				}
			}
			for (final Iterator<Map.Entry<Integer, Long>> it = mId.entrySet()
					.iterator(); it.hasNext();) {
				if (it.next().getValue().longValue() == id) {
					it.remove();
					break;
				}
			}
		}
	}

	/**
	 * Remove all the instances and clear the id index.
	 */
	public void removeAll() {
		synchronized (mId) {
			synchronized (mInstances) {
				final int count = mInstances.size();
				if (count > 0) {
					mInstances.clear();
					LOG.d(TAG, "removeAll: remove count={0,number,#}", count);
				}
			}
			mId.clear();
		}
	}

	/**
	 * Classify a vector against all stored instances. For every label the
	 * highest weight among the instances carrying that label is kept.
	 * 
	 * @param sequenceType
	 *            selects the distance function, see
	 *            {@link #computeWeight(int, float[], float[])}
	 * @param vector
	 *            vector to classify
	 * @param feature
	 *            currently not used by the classification
	 * @return one prediction per label, in ascending label order (not sorted
	 *         by score)
	 */
	public List<Prediction> classify(int sequenceType, float[] vector,
			Feature feature) {
		final TreeMap<String, Double> label2score = new TreeMap<String, Double>();
		final List<Instance> instances = getInstances();
		// Iterating a synchronized list requires holding its lock.
		synchronized (instances) {
			for (final Instance instance : instances) {
				final double weight = computeWeight(sequenceType,
						instance.vector, vector);
				final Double best = label2score.get(instance.label);
				if (null == best || weight > best.doubleValue()) {
					label2score.put(instance.label, Double.valueOf(weight));
				}
			}
		}
		// Map keys are unique, so every entry yields exactly one prediction.
		final List<Prediction> result = new ArrayList<Prediction>(
				label2score.size());
		for (final Entry<String, Double> e : label2score.entrySet()) {
			result.add(new Prediction(e.getKey(), e.getValue()));
		}
		return result;
	}

	/**
	 * Compute the weight of a stored vector for a target vector as the inverse
	 * of their distance.
	 * 
	 * @param sequenceType
	 *            GestureStore.SEQUENCE_SENSITIVE uses the cosine distance, any
	 *            other value the squared euclidean distance
	 * @param v1
	 *            stored vector
	 * @param v2
	 *            target vector
	 * @return 1 / distance, or Double.MAX_VALUE when the distance is 0
	 */
	private double computeWeight(int sequenceType, float[] v1, float[] v2) {
		final double distance;
		if (sequenceType == GestureStore.SEQUENCE_SENSITIVE) {
			distance = GestureUtilities.cosineDistance(v1, v2);
		} else {
			distance = GestureUtilities.squaredEuclideanDistance(v1, v2);
		}
		return (0 == distance) ? Double.MAX_VALUE : 1 / distance;
	}

	/**
	 * Method containsKey.
	 * 
	 * @param key
	 *            convert value of gesture-id
	 * @return true if key exist
	 */
	public boolean containsKey(int key) {
		return mId.containsKey(Integer.valueOf(key));
	}

	/**
	 * Method getLabel.
	 * 
	 * @param key
	 *            convert value of gesture-id
	 * @return character to show by a gesture, or an empty string when the key
	 *         is unknown or no instance exists for it
	 */
	public String getLabel(int key) {
		final Instance instance = getInstance(key);
		return (null == instance) ? "" : instance.label;
	}

	/**
	 * Method convertId. Returns the int id already registered for the
	 * gesture-id, or picks an unused one. The first candidate is the absolute
	 * value of the gesture-id modulo MAX_NUM_OF_INDEX (its last five decimal
	 * digits); when that one is taken, random ids are tried until a free one is
	 * found.
	 * 
	 * @param id
	 *            gesture-id
	 * @return convert value of gesture-id
	 */
	private int convertId(long id) {
		synchronized (mId) {
			int intId = containsId(id);
			if (intId == -1) {
				// Arithmetic instead of substring parsing: also works for
				// ids with fewer than five digits and for negative ids.
				intId = (int) Math.abs(id % MAX_NUM_OF_INDEX);
				// NOTE: does not terminate once every possible id is taken.
				while (mId.containsKey(Integer.valueOf(intId))) {
					intId = randomNumber();
				}
			}
			return intId;
		}
	}

	/**
	 * Method containsId.
	 * 
	 * @param id
	 *            gesture-id
	 * @return convert value of gesture-id, -1 if id does not exist
	 */
	private int containsId(long id) {
		synchronized (mId) {
			for (final Map.Entry<Integer, Long> entry : mId.entrySet()) {
				if (entry.getValue().longValue() == id) {
					return entry.getKey().intValue();
				}
			}
		}
		return -1;
	}

	/**
	 * Method randomNumber
	 * 
	 * @return a random number from 1 to MAX_NUM_OF_INDEX (both inclusive)
	 */
	private int randomNumber() {
		// Bounded draw: sampling the full int range and rejecting values
		// outside of it would need tens of thousands of attempts on average.
		return 1 + mRandom.nextInt(MAX_NUM_OF_INDEX);
	}

	/**
	 * Method serialize. Writes the indexed instances to a file in binary form.
	 * I/O errors are logged, not thrown.
	 * 
	 * @param file
	 *            output-file
	 */
	public void serialize(File file) {
		DataOutputStream out = null;
		try {
			out = new DataOutputStream(new BufferedOutputStream(
					new FileOutputStream(file)));
			serialize(out);
		} catch (IOException e) {
			LOG.e(TAG, "serialize: {0}", e.getMessage());
		} finally {
			try {
				if (null != out) {
					// close() also flushes the buffered data
					out.close();
				}
			} catch (IOException e) {
				LOG.e(TAG, "serialize: {0}", e.getMessage());
			}
		}
	}

	/**
	 * serialize. Format: record count (int), then per record the gesture-id
	 * (long), the int id (int), the label (UTF), the vector length (int) and
	 * the vector elements (float).
	 * 
	 * @param out
	 *            stream to write to
	 * @throws IOException
	 *             if writing fails
	 */
	private void serialize(DataOutputStream out) throws IOException {
		final List<Integer> keys = new ArrayList<Integer>();
		final List<Instance> values = new ArrayList<Instance>();
		snapshotIndex(keys, values);
		// The count is taken from the snapshot so that it always matches the
		// number of records that follow.
		final int size = values.size();
		out.writeInt(size);
		for (int n = 0; n < size; n++) {
			final Instance value = values.get(n);
			out.writeLong(value.id);
			out.writeInt(keys.get(n).intValue());
			out.writeUTF(value.label);
			final int count = value.vector.length;
			out.writeInt(count);
			for (int i = 0; i < count; i++) {
				out.writeFloat(value.vector[i]);
			}
		}
	}

	/**
	 * Collect a consistent copy of the id index together with the instance of
	 * each entry, so that file output can happen without holding the locks.
	 * Index entries for which no instance exists are skipped.
	 * 
	 * @param keys
	 *            receives the int ids
	 * @param values
	 *            receives the instance for the int id at the same position
	 */
	private void snapshotIndex(List<Integer> keys, List<Instance> values) {
		synchronized (mId) {
			synchronized (mInstances) {
				// One pass over the list instead of a linear search per entry.
				final Map<Long, Instance> byId = new HashMap<Long, Instance>();
				for (final Instance instance : mInstances) {
					final Long id = Long.valueOf(instance.id);
					// keep the first instance of an id, like getInstance(long)
					if (!byId.containsKey(id)) {
						byId.put(id, instance);
					}
				}
				for (final Map.Entry<Integer, Long> entry : mId.entrySet()) {
					final Instance instance = byId.get(entry.getValue());
					if (null != instance) {
						keys.add(entry.getKey());
						values.add(instance);
					}
				}
			}
		}
	}

	/**
	 * Method getInstance.
	 * 
	 * @param id
	 *            gesture-id
	 * @return first Instance with that gesture-id, null if there is none
	 */
	private Instance getInstance(long id) {
		synchronized (mInstances) {
			for (final Instance candidate : mInstances) {
				if (id == candidate.id) {
					return candidate;
				}
			}
		}
		return null;
	}

	/**
	 * Get Instance by intId
	 * 
	 * @param intId
	 *            convert value of gesture-id
	 * @return Instance, null if the int id is unknown or no instance exists
	 *         for it
	 */
	public Instance getInstance(int intId) {
		// Single lookup: a separate containsKey() check could be outdated by
		// the time the value is read.
		final Long id = mId.get(Integer.valueOf(intId));
		return (null == id) ? null : getInstance(id.longValue());
	}

	/**
	 * Method deserialize. Replaces the content of the learner with the content
	 * of a file written by {@link #serialize(File)}. I/O errors are logged, not
	 * thrown; in that case the learner keeps its previous content.
	 * 
	 * @param file
	 *            input-file
	 */
	public void deserialize(File file) {
		DataInputStream in = null;
		try {
			in = new DataInputStream(new BufferedInputStream(
					new FileInputStream(file)));
			deserialize(in);
		} catch (IOException e) {
			LOG.e(TAG, "deserialize: {0}", e.getMessage());
		} finally {
			try {
				if (null != in) {
					in.close();
				}
			} catch (IOException e) {
				LOG.e(TAG, "deserialize: {0}", e.getMessage());
			}
		}
	}

	/**
	 * deserialize. The stream is read completely before the learner is
	 * modified, so a truncated or corrupt stream does not leave a partially
	 * loaded learner behind.
	 * 
	 * @param in
	 *            stream to read from
	 * @throws IOException
	 *             if reading fails or a negative count is found
	 */
	private void deserialize(DataInputStream in) throws IOException {
		final int size = in.readInt();
		if (size < 0) {
			throw new IOException("invalid instance count: " + size);
		}
		final List<Instance> instances = new ArrayList<Instance>();
		final Map<Integer, Long> index = new HashMap<Integer, Long>();
		for (int i = 0; i < size; i++) {
			final long id = in.readLong();
			final int intId = in.readInt();
			final String label = in.readUTF();
			final int count = in.readInt();
			if (count < 0) {
				// would otherwise surface as NegativeArraySizeException
				throw new IOException("invalid vector length: " + count);
			}
			final float[] vector = new float[count];
			for (int v = 0; v < count; v++) {
				vector[v] = in.readFloat();
			}
			instances.add(new Instance(id, vector, label));
			index.put(Integer.valueOf(intId), Long.valueOf(id));
		}
		synchronized (mId) {
			synchronized (mInstances) {
				mId.clear();
				mInstances.clear();
				mInstances.addAll(instances);
				mId.putAll(index);
			}
		}
	}

	/**
	 * Method saveSvmData. Writes one text line per indexed instance: the int
	 * id followed by "index:value" items with 1-based indexes. Vector elements
	 * are handled in pairs and a pair is left out when both of its elements
	 * are 0; a trailing unpaired element is left out when it is 0. I/O errors
	 * are logged, not thrown.
	 * 
	 * @param file
	 *            output-file
	 */
	public void saveSvmData(File file) {
		final List<Integer> keys = new ArrayList<Integer>();
		final List<Instance> values = new ArrayList<Instance>();
		snapshotIndex(keys, values);
		BufferedOutputStream out = null;
		try {
			out = new BufferedOutputStream(new FileOutputStream(file));
			final int size = values.size();
			for (int n = 0; n < size; n++) {
				out.write(Utils.concat(
						String.valueOf(keys.get(n).intValue()), " ").getBytes(
						CHARSET));
				final float[] vector = values.get(n).vector;
				final int count = vector.length;
				for (int i = 0; i < count; i += 2) {
					// an odd-length vector has no partner for its last element
					final boolean paired = i + 1 < count;
					if (vector[i] == 0 && (!paired || vector[i + 1] == 0)) {
						continue;
					}
					out.write(Utils.concat(String.valueOf((i + 1)), ":",
							String.valueOf(vector[i]), " ").getBytes(CHARSET));
					if (paired) {
						out.write(Utils.concat(String.valueOf((i + 2)), ":",
								String.valueOf(vector[i + 1]), " ").getBytes(
								CHARSET));
					}
				}
				out.write(LINE_FEED.getBytes(CHARSET));
			}
		} catch (IOException e) {
			LOG.e(TAG, "saveSvmData: {0}", e.getMessage());
		} finally {
			try {
				if (null != out) {
					// close() also flushes the buffered data
					out.close();
				}
			} catch (IOException e) {
				LOG.e(TAG, "saveSvmData: {0}", e.getMessage());
			}
		}
	}
}
