/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.shiro.web.filter;

import org.apache.shiro.util.StringUtils;
import org.apache.shiro.web.util.WebUtils;

import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Stream;

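/**
 * A request filter that blocks malicious requests. Invalid requests are rejected with a 400 (Bad Request) response code.
 * <p>
 * This filter inspects the raw request URI as well as the decoded servlet path and path info, and blocks the request
 * if any of the following are found:
 * <ul>
 *     <li>Semicolon - can be disabled by setting {@code blockSemicolon = false}</li>
 *     <li>Backslash - can be disabled by setting {@code blockBackslash = false}</li>
 *     <li>Non-ASCII characters (including ASCII control characters) - can be disabled by setting {@code blockNonAscii = false}; the ability to disable this check will be removed in a future version.</li>
 *     <li>Path traversals - can be disabled by setting {@code blockTraversal = false}</li>
 *     <li>Encoded periods ({@code %2e}) - can be disabled by setting {@code blockEncodedPeriod = false}</li>
 *     <li>Encoded forward slashes ({@code %2f}) - can be disabled by setting {@code blockEncodedForwardSlash = false}</li>
 *     <li>Rewrite traversals ({@code /..;} and {@code /.;}) - can be disabled by setting {@code blockRewriteTraversal = false}; only checked while path traversal blocking is enabled</li>
 * </ul>
 *
 * @see <a href="https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/web/firewall/StrictHttpFirewall.html">This class was inspired by Spring Security StrictHttpFirewall</a>
 * @since 1.6
 */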
public class InvalidRequestFilter extends AccessControlFilter {

    /** Raw and percent-encoded (lower/upper case) forms of the semicolon. */
    private static final List<String> SEMICOLON = Collections.unmodifiableList(Arrays.asList(";", "%3b", "%3B"));

    /** Raw and percent-encoded (lower/upper case) forms of the backslash. */
    private static final List<String> BACKSLASH = Collections.unmodifiableList(Arrays.asList("\\", "%5c", "%5C"));

    /** Percent-encoded (lower/upper case) forms of the forward slash. */
    private static final List<String> FORWARDSLASH = Collections.unmodifiableList(Arrays.asList("%2f", "%2F"));

    /** Percent-encoded (lower/upper case) forms of the period. */
    private static final List<String> PERIOD = Collections.unmodifiableList(Arrays.asList("%2e", "%2E"));

    private boolean blockSemicolon = true;

    // Enabled by default unless the system property named by WebUtils.ALLOW_BACKSLASH is set to "true".
    private boolean blockBackslash = !Boolean.getBoolean(WebUtils.ALLOW_BACKSLASH);

    private boolean blockNonAscii = true;

    private boolean blockTraversal = true;

    private boolean blockEncodedPeriod = true;

    private boolean blockEncodedForwardSlash = true;

    private boolean blockRewriteTraversal = true;

    /**
     * Allows the request only if its raw request URI, decoded servlet path and decoded path info all pass
     * {@link #isValid(String)}.
     */
    @Override
    protected boolean isAccessAllowed(ServletRequest req, ServletResponse response, Object mappedValue) throws Exception {
        HttpServletRequest request = WebUtils.toHttp(req);
        // check the original and decoded values
        return isValid(request.getRequestURI())      // user request string (not decoded)
                && isValid(request.getServletPath()) // decoded servlet part
                && isValid(request.getPathInfo());   // decoded path info (may be null)
    }

    /**
     * Returns {@code true} when {@code uri} passes every enabled check.
     * <p>
     * {@code null} and empty values are accepted because there is nothing to inspect (path info may be {@code null}).
     *
     * @param uri the value to inspect, may be {@code null}
     * @return {@code true} if the value is empty or contains nothing that an enabled check blocks
     */
    private boolean isValid(String uri) {
        // Only skip truly empty input. Using hasText here would also skip whitespace-only values (e.g. a lone tab
        // or a Unicode space), letting them bypass the non-ASCII/control-character check.
        return !StringUtils.hasLength(uri)
               || (!containsSemicolon(uri)
                   && !containsBackslash(uri)
                   && !containsNonAsciiCharacters(uri)
                   && !containsTraversal(uri)
                   && !containsEncodedPeriods(uri)
                   && !containsEncodedForwardSlash(uri));
    }

    /**
     * Rejects the request with a 400 (Bad Request) error and stops the filter chain.
     */
    @Override
    protected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws Exception {
        WebUtils.toHttp(response).sendError(400, "Invalid request");
        return false;
    }

    /** Returns {@code true} if {@code uri} contains any of the given character sequences. */
    private static boolean containsAny(String uri, List<String> sequences) {
        return sequences.stream().anyMatch(uri::contains);
    }

    /** Returns {@code true} if semicolon blocking is enabled and {@code uri} contains a raw or encoded semicolon. */
    private boolean containsSemicolon(String uri) {
        return isBlockSemicolon() && containsAny(uri, SEMICOLON);
    }

    /** Returns {@code true} if backslash blocking is enabled and {@code uri} contains a raw or encoded backslash. */
    private boolean containsBackslash(String uri) {
        return isBlockBackslash() && containsAny(uri, BACKSLASH);
    }

    /** Returns {@code true} if non-ASCII blocking is enabled and {@code uri} has a char outside 0x20-0x7E. */
    private boolean containsNonAsciiCharacters(String uri) {
        return isBlockNonAscii() && !containsOnlyPrintableAsciiCharacters(uri);
    }

    /**
     * Returns {@code true} if every char of {@code uri} is printable ASCII (space through tilde). Control characters
     * and DEL (0x7F) are therefore rejected along with all non-ASCII chars.
     */
    private static boolean containsOnlyPrintableAsciiCharacters(String uri) {
        int length = uri.length();
        for (int i = 0; i < length; i++) {
            char c = uri.charAt(i);
            if (c < '\u0020' || c > '\u007e') {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns {@code true} if traversal blocking is enabled and {@code uri} is not normalized or, when rewrite
     * traversal blocking is also enabled, contains {@code /..;} or {@code /.;}.
     */
    private boolean containsTraversal(String uri) {
        if (isBlockTraversal()) {
            return !isNormalized(uri)
                || (isBlockRewriteTraversal() && Stream.of("/..;", "/.;").anyMatch(uri::contains));
        }
        return false;
    }

    /** Returns {@code true} if encoded-period blocking is enabled and {@code uri} contains {@code %2e}/{@code %2E}. */
    private boolean containsEncodedPeriods(String uri) {
        return isBlockEncodedPeriod() && containsAny(uri, PERIOD);
    }

    /** Returns {@code true} if encoded-slash blocking is enabled and {@code uri} contains {@code %2f}/{@code %2F}. */
    private boolean containsEncodedForwardSlash(String uri) {
        return isBlockEncodedForwardSlash() && containsAny(uri, FORWARDSLASH);
    }

    /**
     * Checks whether a path is normalized, i.e. none of its {@code '/'}-separated segments is exactly {@code "."} or
     * {@code ".."} (this covers {@code "./"}, {@code "/../"}, a trailing {@code "/."} and a bare {@code "."}).
     *
     * @param path the path to test, may be {@code null}
     * @return true if the path is {@code null} or doesn't contain any path-traversal segment.
     */
    private boolean isNormalized(String path) {
        if (path == null) {
            return true;
        }
        // Walk segments from the end: i is the exclusive end of the current segment, slashIndex the '/' before it
        // (-1 for the first segment), so gap - 1 is the segment length.
        for (int i = path.length(); i > 0;) {
            int slashIndex = path.lastIndexOf('/', i - 1);
            int gap = i - slashIndex;
            if (gap == 2 && path.charAt(slashIndex + 1) == '.') {
                return false; // ".", "/./" or "/."
            }
            if (gap == 3 && path.charAt(slashIndex + 1) == '.' && path.charAt(slashIndex + 2) == '.') {
                return false; // "..", "/../" or "/.."
            }
            i = slashIndex;
        }
        return true;
    }

    /** @return whether raw and percent-encoded semicolons are blocked (default {@code true}). */
    public boolean isBlockSemicolon() {
        return blockSemicolon;
    }

    /** @param blockSemicolon whether raw and percent-encoded semicolons should be blocked. */
    public void setBlockSemicolon(boolean blockSemicolon) {
        this.blockSemicolon = blockSemicolon;
    }

    /** @return whether raw and percent-encoded backslashes are blocked. */
    public boolean isBlockBackslash() {
        return blockBackslash;
    }

    /** @param blockBackslash whether raw and percent-encoded backslashes should be blocked. */
    public void setBlockBackslash(boolean blockBackslash) {
        this.blockBackslash = blockBackslash;
    }

    /** @return whether characters outside printable ASCII are blocked (default {@code true}). */
    public boolean isBlockNonAscii() {
        return blockNonAscii;
    }

    /** @param blockNonAscii whether characters outside printable ASCII should be blocked. */
    public void setBlockNonAscii(boolean blockNonAscii) {
        this.blockNonAscii = blockNonAscii;
    }

    /** @return whether {@code .}/{@code ..} path segments are blocked (default {@code true}). */
    public boolean isBlockTraversal() {
        return blockTraversal;
    }

    /** @param blockTraversal whether {@code .}/{@code ..} path segments should be blocked. */
    public void setBlockTraversal(boolean blockTraversal) {
        this.blockTraversal = blockTraversal;
    }

    /** @return whether percent-encoded periods are blocked (default {@code true}). */
    public boolean isBlockEncodedPeriod() {
        return blockEncodedPeriod;
    }

    /** @param blockEncodedPeriod whether percent-encoded periods should be blocked. */
    public void setBlockEncodedPeriod(boolean blockEncodedPeriod) {
        this.blockEncodedPeriod = blockEncodedPeriod;
    }

    /** @return whether percent-encoded forward slashes are blocked (default {@code true}). */
    public boolean isBlockEncodedForwardSlash() {
        return blockEncodedForwardSlash;
    }

    /** @param blockEncodedForwardSlash whether percent-encoded forward slashes should be blocked. */
    public void setBlockEncodedForwardSlash(boolean blockEncodedForwardSlash) {
        this.blockEncodedForwardSlash = blockEncodedForwardSlash;
    }

    /**
     * @return whether {@code /..;} and {@code /.;} sequences are blocked (default {@code true}); only consulted while
     *         {@link #isBlockTraversal()} is {@code true}.
     */
    public boolean isBlockRewriteTraversal() {
        return blockRewriteTraversal;
    }

    /** @param blockRewriteTraversal whether {@code /..;} and {@code /.;} sequences should be blocked. */
    public void setBlockRewriteTraversal(boolean blockRewriteTraversal) {
        this.blockRewriteTraversal = blockRewriteTraversal;
    }
}
