package org.apache;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Consumer;

import javax.servlet.ServletContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;

import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;

/**
 * MockUtil
 *
 * @author Tadashi Nakayama
 */
public final class WebMock {

    /** ContextAttribute */
    private static final Consumer<ServletContext> ATTRIBUTE_CONTEXT = sc -> {
        final Map<String, Object> attrMap = new HashMap<>();
        Mockito.when(sc.getAttribute(ArgumentMatchers.anyString())).thenAnswer(
            m -> attrMap.get(String.class.cast(m.getArgument(0)))
        );
        Mockito.when(sc.getAttributeNames()).thenAnswer(
            m -> Collections.enumeration(attrMap.keySet())
        );
        Mockito.doAnswer(m -> {
            attrMap.remove(String.class.cast(m.getArgument(0)));
            return null;
        }).when(sc).removeAttribute(ArgumentMatchers.anyString());
        Mockito.doAnswer(m -> {
            attrMap.put(m.getArgument(0), m.getArgument(1));
            return null;
        }).when(sc).setAttribute(
                ArgumentMatchers.anyString(), ArgumentMatchers.any());
    };

    /** SessionAttribute */
    private static final Consumer<HttpSession> ATTRIBUTE_SESSION = sess -> {
        final Map<String, Object> attrMap = new HashMap<>();
        Mockito.when(sess.getAttribute(ArgumentMatchers.anyString())).thenAnswer(
            m -> attrMap.get(String.class.cast(m.getArgument(0)))
        );
        Mockito.when(sess.getAttributeNames()).thenAnswer(
            m -> Collections.enumeration(attrMap.keySet())
        );
        Mockito.doAnswer(m -> {
            attrMap.remove(String.class.cast(m.getArgument(0)));
            return null;
        }).when(sess).removeAttribute(ArgumentMatchers.anyString());
        Mockito.doAnswer(m -> {
            attrMap.put(m.getArgument(0), m.getArgument(1));
            return null;
        }).when(sess).setAttribute(
                ArgumentMatchers.anyString(), ArgumentMatchers.any());
    };

    /** RequestAttribute */
    private static final Consumer<HttpServletRequest> ATTRIBUTE_REQUEST = req -> {
        final Map<String, Object> attrMap = new HashMap<>();
        Mockito.when(req.getAttribute(ArgumentMatchers.anyString())).thenAnswer(
            m -> attrMap.get(String.class.cast(m.getArgument(0)))
        );
        Mockito.when(req.getAttributeNames()).thenAnswer(
            m -> Collections.enumeration(attrMap.keySet())
        );
        Mockito.doAnswer(m -> {
            attrMap.remove(String.class.cast(m.getArgument(0)));
            return null;
        }).when(req).removeAttribute(ArgumentMatchers.anyString());
        Mockito.doAnswer(m -> {
            attrMap.put(m.getArgument(0), m.getArgument(1));
            return null;
        }).when(req).setAttribute(
                ArgumentMatchers.anyString(), ArgumentMatchers.any());
    };

    /**
     * Constructor
     */
    private WebMock() {
        throw new AssertionError();
    }

    /**
     * create ServletContext
     * @return ServletContext
     */
    public static ServletContext createServletContext() {
        final ServletContext sc = Mockito.mock(ServletContext.class);
        ATTRIBUTE_CONTEXT.accept(sc);

        final Map<String, String> initParam = new HashMap<>();
        Mockito.when(sc.getInitParameter(ArgumentMatchers.anyString())).thenAnswer(
            m -> initParam.get(String.class.cast(m.getArgument(0)))
        );
        Mockito.when(sc.getInitParameterNames()).thenAnswer(
            m -> Collections.enumeration(initParam.keySet())
        );
        Mockito.doAnswer(m -> {
            initParam.put(m.getArgument(0), m.getArgument(1));
            return null;
        }).when(sc).setInitParameter(
                ArgumentMatchers.anyString(), ArgumentMatchers.anyString());
        return sc;
    }

    /**
     * create HttpSession
     * @return HttpSession
     */
    public static HttpSession createHttpSession() {
        final HttpSession sess = Mockito.mock(HttpSession.class);
        ATTRIBUTE_SESSION.accept(sess);
        return sess;
    }

    /**
     * create HttpServletRequest
     * @param sess HttpSession
     * @param hMap HeaderMap
     * @param pMap ParameterMap
     * @return HttpServletRequest
     */
    public static HttpServletRequest createHttpServletRequest(final HttpSession sess,
            final Map<String, String[]> hMap, final Map<String, String[]> pMap) {

        final HttpServletRequest req = Mockito.mock(HttpServletRequest.class);
        ATTRIBUTE_REQUEST.accept(req);

        final Map<String, String[]> header = new HashMap<>(hMap);
        Mockito.when(req.getHeader(ArgumentMatchers.anyString())).thenAnswer(
            m -> {
                String[] val = header.get(String.class.cast(m.getArgument(0)));
                return val == null || val.length < 1 ? null : val[0];
            }
        );
        Mockito.when(req.getHeaders(ArgumentMatchers.anyString())).thenAnswer(
            m -> {
                String[] val = header.get(String.class.cast(m.getArgument(0)));
                return val == null ? Collections.emptyEnumeration()
                        : Collections.enumeration(Arrays.asList(val));
            }
        );
        Mockito.when(req.getHeaderNames()).thenAnswer(
            m -> Collections.enumeration(header.keySet())
        );

        final Map<String, String[]> param = new HashMap<>(pMap);
        Mockito.when(req.getParameter(ArgumentMatchers.anyString())).thenAnswer(
            m -> {
                String[] val = param.get(String.class.cast(m.getArgument(0)));
                return val == null || val.length < 1 ? null : val[0];
            }
        );
        Mockito.when(req.getParameterValues(ArgumentMatchers.anyString())).thenAnswer(
            m -> param.get(String.class.cast(m.getArgument(0)))
        );
        Mockito.when(req.getParameterNames()).thenAnswer(
            m -> Collections.enumeration(param.keySet())
        );

        final boolean[] called = new boolean[1];
        Mockito.when(req.getSession()).thenAnswer(m -> {
            called[0] = true;
            return sess;
        });
        Mockito.when(req.getSession(true)).thenAnswer(m -> {
            called[0] = true;
            return sess;
        });
        Mockito.when(req.getSession(false)).thenAnswer(m -> {
            return called[0] ? sess : null;
        });

        return req;
    }

}
