:- module binstring.
%=============================================================================%
% Copyright (c) 2019-2020, AlaskanEmily, Transnat Games
%
% This software is provided 'as-is', without any express or implied
% warranty.  In no event will the authors be held liable for any damages
% arising from the use of this software.
%
% Permission is granted to anyone to use this software for any purpose,
% including commercial applications, and to alter it and redistribute it
% freely, subject to the following restrictions:
%
% 1. The origin of this software must not be misrepresented; you must not
%    claim that you wrote the original software. If you use this software
%    in a product, an acknowledgment in the product documentation would be
%    appreciated but is not required.
% 2. Altered source versions must be plainly marked as such, and must not be
%    misrepresented as being the original software.
% 3. This notice may not be removed or altered from any source distribution.
%=============================================================================%
% Really simple module to save strings to, and load strings from, binary
% streams. Supports NUL-terminated strings and Pascal-style strings with a
% leading size field of one, two or four bytes.
:- interface.
%=============================================================================%

:- use_module stream.

%-----------------------------------------------------------------------------%

% Reasons a load can fail.
:- type error --->
    % The stream reached end-of-file before the string was complete.
    unexpected_eof ;
    % The bytes read could not be turned into a string as UTF-8.
    encoding_error ;
    % The underlying stream reported an error; the argument is the message
    % obtained from stream.error_message.
    stream_error(string).

%-----------------------------------------------------------------------------%
% Insts that let the load predicates return a result whose top-level cell is
% unique: the string in the ok case is unique, while in the error case the
% partial string is only ground and the error value is unique down to its tag.

% A binstring.error whose top-level functor is unique; the message carried by
% stream_error is only required to be ground.
:- inst error_uniq for binstring.error/0 ==
    unique(unexpected_eof ; encoding_error ; stream_error(ground)).

% A stream.maybe_partial_res that is either ok with a unique value, or an
% error with a ground partial value and an error of inst T.
:- inst result_uniq(T) for stream.maybe_partial_res/2 ==
    unique(stream.ok(unique) ; stream.error(ground, T)).

% Output mode used for the Result argument of every load_* predicate.
:- mode result_uo == free >> result_uniq(error_uniq).

% Result of a load: ok(String), or error(PartialString, Error).
:- type partial_result == stream.maybe_partial_res(string, binstring.error).

%-----------------------------------------------------------------------------%
% Saves a string as NUL-terminated UTF-8: the UTF-8 bytes of each character
% are written in order, followed by a single 0 byte. Characters for which
% char.to_utf8 fails are skipped.
:- pred save_string_utf8(Stream, string, State, State)
    <= stream.writer(Stream, int, State).
:- mode save_string_utf8(in, in, di, uo) is det.

%-----------------------------------------------------------------------------%
% Loads a NUL-terminated UTF-8 string: bytes are read until a 0 byte is
% seen. If the stream ends or reports an error first, the result is
% error(Partial, Error), where Partial is the string decoded from the bytes
% read so far. If the bytes cannot be decoded, the result is
% error("", encoding_error).
:- pred load_string_utf8(Stream, partial_result, State, State)
    <= (stream.reader(Stream, int, State, Error), stream.error(Error)).
:- mode load_string_utf8(in, result_uo, di, uo) is det.

%-----------------------------------------------------------------------------%
% Saves a string as a non-NUL-terminated byte sequence preceded by a single
% size byte. The size is string.length of the string, capped at 0xFF. Only
% characters whose UTF-8 encoding fits entirely within that size are written;
% any remaining bytes are filled with spaces (byte 32).
:- pred save_string_utf8_len1(Stream, string, State, State)
    <= stream.writer(Stream, int, State).
:- mode save_string_utf8_len1(in, in, di, uo) is det.

%-----------------------------------------------------------------------------%
% Loads a string stored as a single size byte followed by that many bytes of
% UTF-8 (no NUL terminator). A size of 0 yields ok(""). If the size byte
% itself cannot be read, the result is error("", Error).
:- pred load_string_utf8_len1(Stream, partial_result, State, State)
    <= (stream.reader(Stream, int, State, Error), stream.error(Error)).
:- mode load_string_utf8_len1(in, result_uo, di, uo) is det.

%-----------------------------------------------------------------------------%
% Saves a string as a non-NUL-terminated byte sequence preceded by two size
% bytes in big-endian order (high byte first). The size is string.length of
% the string, capped at 0xFFFF; unused bytes are filled with spaces (byte 32).
:- pred save_string_utf8_len2_be(Stream, string, State, State)
    <= stream.writer(Stream, int, State).
:- mode save_string_utf8_len2_be(in, in, di, uo) is det.

%-----------------------------------------------------------------------------%
% Loads a string stored as two big-endian size bytes followed by that many
% bytes of UTF-8 (no NUL terminator). A size of 0 yields ok(""). If the size
% cannot be read, the result is error("", Error).
:- pred load_string_utf8_len2_be(Stream, partial_result, State, State)
    <= (stream.reader(Stream, int, State, Error), stream.error(Error)).
:- mode load_string_utf8_len2_be(in, result_uo, di, uo) is det.

%-----------------------------------------------------------------------------%
% Saves a string as a non-NUL-terminated byte sequence preceded by two size
% bytes in little-endian order (low byte first). The size is string.length of
% the string, capped at 0xFFFF; unused bytes are filled with spaces (byte 32).
:- pred save_string_utf8_len2_le(Stream, string, State, State)
    <= stream.writer(Stream, int, State).
:- mode save_string_utf8_len2_le(in, in, di, uo) is det.

%-----------------------------------------------------------------------------%
% Loads a string stored as two little-endian size bytes followed by that many
% bytes of UTF-8 (no NUL terminator). A size of 0 yields ok(""). If the size
% cannot be read, the result is error("", Error).
:- pred load_string_utf8_len2_le(Stream, partial_result, State, State)
    <= (stream.reader(Stream, int, State, Error), stream.error(Error)).
:- mode load_string_utf8_len2_le(in, result_uo, di, uo) is det.

%-----------------------------------------------------------------------------%
% Saves a string as a non-NUL-terminated byte sequence preceded by four size
% bytes in big-endian order (most significant byte first). The size is
% string.length of the string, capped at 0xFFFFFFFF; unused bytes are filled
% with spaces (byte 32).
:- pred save_string_utf8_len4_be(Stream, string, State, State)
    <= stream.writer(Stream, int, State).
:- mode save_string_utf8_len4_be(in, in, di, uo) is det.

%-----------------------------------------------------------------------------%
% Loads a string stored as four big-endian size bytes followed by that many
% bytes of UTF-8 (no NUL terminator). A size of 0 yields ok(""). If the size
% cannot be read, the result is error("", Error).
:- pred load_string_utf8_len4_be(Stream, partial_result, State, State)
    <= (stream.reader(Stream, int, State, Error), stream.error(Error)).
:- mode load_string_utf8_len4_be(in, result_uo, di, uo) is det.

%-----------------------------------------------------------------------------%
% Saves a string as a non-NUL-terminated byte sequence preceded by four size
% bytes in little-endian order (least significant byte first). The size is
% string.length of the string, capped at 0xFFFFFFFF; unused bytes are filled
% with spaces (byte 32).
:- pred save_string_utf8_len4_le(Stream, string, State, State)
    <= stream.writer(Stream, int, State).
:- mode save_string_utf8_len4_le(in, in, di, uo) is det.

%-----------------------------------------------------------------------------%
% Loads a string stored as four little-endian size bytes followed by that
% many bytes of UTF-8 (no NUL terminator). A size of 0 yields ok(""). If the
% size cannot be read, the result is error("", Error).
:- pred load_string_utf8_len4_le(Stream, partial_result, State, State)
    <= (stream.reader(Stream, int, State, Error), stream.error(Error)).
:- mode load_string_utf8_len4_le(in, result_uo, di, uo) is det.

%=============================================================================%
:- implementation.
%=============================================================================%

:- use_module char.
:- use_module exception.
:- import_module int.
:- import_module list.
:- use_module string.

%-----------------------------------------------------------------------------%
% Private insts and modes used by the helper predicates so that the values
% they build can flow into the exported result_uo mode.

% The errors the integer readers can produce: they never report
% encoding_error, so this inst omits that functor.
:- inst err_uniq for binstring.error/0 ==
    unique(unexpected_eof ; stream_error(ground)).

% A stream.res that is either ok with a unique value or an error of inst T.
:- inst res_uniq(T) for stream.res/2 ==
    unique(stream.ok(unique) ; stream.error(T)).

% Output mode for the result of the integer readers.
:- mode res_uo == free >> res_uniq(err_uniq).

% Destructive-input and unique-output modes for binstring.error values.
:- mode error_di == error_uniq >> clobbered.
:- mode error_uo == free >> error_uniq.

% Result of reading a size field: ok(Int) or error(Error).
:- type int_res == stream.res(int, binstring.error).

%-----------------------------------------------------------------------------%

% Builds a string from a list of UTF-8 code units given in reverse order
% (last code unit first). Fails when string.from_utf8_code_unit_list rejects
% the reversed list.
:- pred string_from_utf8_code_unit_list_rev(list.list(int), string).
:- mode string_from_utf8_code_unit_list_rev(di, uo) is semidet.
string_from_utf8_code_unit_list_rev(RevUnits, String) :-
    list.reverse(RevUnits, Units),
    string.from_utf8_code_unit_list(Units, String).

%-----------------------------------------------------------------------------%

% Writes N space bytes (value 32) to Stream; writes nothing when N < 1.
:- pred pad(Stream, int, State, State)
    <= stream.writer(Stream, int, State).
:- mode pad(in, in, di, uo) is det.

pad(Stream, Count, !State) :-
    pad(Stream, Count, 32, !State).

% Writes the byte I to Stream N times; writes nothing when N < 1.
:- pred pad(Stream, int, int, State, State)
    <= stream.writer(Stream, int, State).
:- mode pad(in, in, in, di, uo) is det.

pad(Stream, Count, Byte, !State) :-
    ( if Count >= 1 then
        stream.put(Stream, Byte, !State),
        pad(Stream, Count - 1, Byte, !State)
    else
        true
    ).

%-----------------------------------------------------------------------------%

% Writes the UTF-8 bytes of Char to Stream while counting them: O is I plus
% the number of bytes written. If char.to_utf8 fails for Char, nothing is
% written and O equals I.
:- pred write_utf8(Stream, character, int, int, State, State)
    <= stream.writer(Stream, int, State).
:- mode write_utf8(in, in, di, uo, di, uo) is det.

write_utf8(Stream, Char, Count0, Count, !State) :-
    ( if char.to_utf8(Char, Bytes) then
        PutAndCount =
            ( pred(Byte::in, N0::di, N::uo, S0::di, S::uo) is det :-
                stream.put(Stream, Byte, S0, S),
                N = N0 + 1
            ),
        list.foldl2(PutAndCount, Bytes, Count0, Count, !State)
    else
        Count0 = Count
    ).

%-----------------------------------------------------------------------------%

% Size-limited write: emits the UTF-8 bytes of Char only when all of them fit
% within Max bytes in total, where I is the number of bytes written so far.
% O is the new byte count. When Char does not fit, or char.to_utf8 fails,
% nothing is written and O equals I. Because the caller's fold continues, a
% later, shorter character may still be written after one was skipped.
% NOTE(review): the `+ 0` presumably exists to produce a fresh value for the
% unique output mode -- confirm before removing.
:- pred write_utf8(Stream, int, character, int, int, State, State)
    <= stream.writer(Stream, int, State).
:- mode write_utf8(in, in, in, di, uo, di, uo) is det.

write_utf8(Stream, Max, Char, Used, NewUsed, !State) :-
    ( if
        Used < Max,
        char.to_utf8(Char, Bytes),
        list.length(Bytes, NumBytes),
        End = NumBytes + Used,
        End =< Max
    then
        list.foldl(stream.put(Stream), Bytes, !State),
        NewUsed = End + 0
    else
        NewUsed = Used + 0
    ).

%-----------------------------------------------------------------------------%

% Writes the UTF-8 bytes of Char to Stream. If char.to_utf8 fails for Char,
% nothing is written.
:- pred write_utf8(Stream, character, State, State)
    <= stream.writer(Stream, int, State).
:- mode write_utf8(in, in, di, uo) is det.

write_utf8(Stream, Char, !State) :-
    ( if char.to_utf8(Char, Bytes) then
        list.foldl(stream.put(Stream), Bytes, !State)
    else
        true
    ).

%-----------------------------------------------------------------------------%

% Builds the failure result for a read that stopped early. The bytes read so
% far (in reverse order) are decoded into the partial string carried by the
% error. If they cannot be decoded, the given Error is replaced by
% encoding_error and the partial string is empty.
:- func emit_error(binstring.error, list.list(int)) = partial_result.
:- mode emit_error(error_di, di) = (result_uo) is det.

emit_error(Error, RevBytes) = Outcome :-
    ( if string_from_utf8_code_unit_list_rev(RevBytes, Partial) then
        Outcome = stream.error(Partial, Error)
    else
        Outcome = stream.error("", encoding_error)
    ).

%-----------------------------------------------------------------------------%

% Reads bytes from Stream until a 0 byte, accumulating them in reverse order
% in List, then decodes them into a string. End-of-file or a stream error
% before the terminator produces an error result via emit_error; bytes that
% cannot be decoded produce error("", encoding_error).
:- pred read_utf8(Stream, list.list(int), partial_result, State, State)
    <= (stream.reader(Stream, int, State, Error), stream.error(Error)).
:- mode read_utf8(in, di, result_uo, di, uo) is det.

read_utf8(Stream, RevBytes, Result, !State) :-
    stream.get(Stream, Got, !State),
    (
        Got = stream.ok(Byte),
        ( if Byte = 0 then
            % Terminator reached: decode what has been collected.
            ( if string_from_utf8_code_unit_list_rev(RevBytes, String) then
                Result = stream.ok(String)
            else
                Result = stream.error("", encoding_error)
            )
        else
            read_utf8(Stream, [Byte + 0 | RevBytes], Result, !State)
        )
    ;
        Got = stream.eof,
        Failure = unexpected_eof,
        Result = emit_error(Failure, RevBytes)
    ;
        Got = stream.error(StreamError),
        Failure = stream_error(stream.error_message(StreamError)),
        Result = emit_error(Failure, RevBytes)
    ).

%-----------------------------------------------------------------------------%

% Reads exactly Remaining bytes from Stream, accumulating them in reverse
% order in List, then decodes them into a string. Throws a software_error if
% Remaining is negative. End-of-file or a stream error produces an error
% result via emit_error, as does a byte that is 0 or >= 0xFE
% (encoding_error).
:- pred read_utf8_len(Stream, int, list.list(int), partial_result, State, State)
    <= (stream.reader(Stream, int, State, Error), stream.error(Error)).
:- mode read_utf8_len(in, in, di, result_uo, di, uo) is det.

read_utf8_len(Stream, Remaining, RevBytes, Result, !State) :-
    ( if Remaining < 0 then
        exception.throw(
            exception.software_error("Invalid index in binstring.read_utf8_len"))
    else if Remaining = 0 then
        % Everything has been read: decode the collected bytes.
        ( if string_from_utf8_code_unit_list_rev(RevBytes, String) then
            Result = stream.ok(String)
        else
            Result = stream.error("", encoding_error)
        )
    else
        stream.get(Stream, Got, !State),
        (
            Got = stream.ok(Byte),
            ( if
                % Early-out for bytes that are obviously not valid here.
                ( Byte = 0 ; Byte >= 0xFE )
            then
                Result = emit_error(encoding_error, RevBytes)
            else
                read_utf8_len(Stream, Remaining - 1, [Byte + 0 | RevBytes],
                    Result, !State)
            )
        ;
            Got = stream.eof,
            Failure = unexpected_eof,
            Result = emit_error(Failure, RevBytes)
        ;
            Got = stream.error(StreamError),
            Failure = stream_error(stream.error_message(StreamError)),
            Result = emit_error(Failure, RevBytes)
        )
    ).

%-----------------------------------------------------------------------------%

% Write every character as UTF-8, then the terminating 0 byte.
save_string_utf8(Stream, String, !State) :-
    PutChar =
        ( pred(Char::in, S0::di, S::uo) is det :-
            write_utf8(Stream, Char, S0, S)
        ),
    string.foldl(PutChar, String, !State),
    Terminator = 0,
    stream.put(Stream, Terminator, !State).

%-----------------------------------------------------------------------------%

% Start the NUL-terminated reader with an empty byte accumulator.
load_string_utf8(Stream, Result) -->
    read_utf8(Stream, [], Result).

%-----------------------------------------------------------------------------%

% Write the one-byte size, then as many whole characters as fit in that many
% bytes, then pad the remainder with spaces.
save_string_utf8_len1(Stream, String, !State) :-
    Size = int.min(0xFF, string.length(String)),
    stream.put(Stream, Size, !State),
    string.foldl2(write_utf8(Stream, Size), String, 0, Written, !State),
    Missing = Size - Written,
    pad(Stream, Missing, !State).

%-----------------------------------------------------------------------------%

% Reads N bytes from Stream and combines them into an integer in big-endian
% order (first byte read is the most significant). Val is the value
% accumulated so far. Throws a software_error if N is negative. Returns
% error(unexpected_eof) or error(stream_error(Message)) if a byte cannot be
% read.
:- pred read_int_be(Stream, int, int, int_res, State, State)
    <= (stream.reader(Stream, int, State, Error), stream.error(Error)).
:- mode read_int_be(in, di, di, res_uo, di, uo) is det.
read_int_be(Stream, N, Val, Result, !State) :-
    builtin.compare(Cmp, N, 0),
    (
        Cmp = (<),
        exception.throw(
            exception.software_error("Invalid index in binstring.read_int"))
    ;
        Cmp = (=),
        Result = stream.ok(Val)
    ;
        Cmp = (>),
        stream.get(Stream, ByteResult, !State),
        (
            ByteResult = stream.error(StreamError),
            Result = stream.error(stream_error(stream.error_message(StreamError)))
        ;
            ByteResult = stream.eof,
            Result = stream.error(unexpected_eof)
        ;
            ByteResult = stream.ok(Byte),
            % Each element read is one byte, so the accumulator must move up
            % by 8 bits (not 3) before the next byte is added. This matches
            % the `>> 8` / `/\ 0xFF` layout written by the save predicates.
            read_int_be(Stream, N - 1, int.unchecked_left_shift(Val, 8) + Byte,
                Result, !State)
        )
    ).

% Convenience wrapper: runs the accumulating read_int_be for N bytes,
% starting from an accumulator of 0.
:- pred read_int_be(Stream, int, int_res, State, State)
    <= (stream.reader(Stream, int, State, Error), stream.error(Error)).
:- mode read_int_be(in, di, res_uo, di, uo) is det.

read_int_be(Stream, NumBytes, Result, !State) :-
    read_int_be(Stream, NumBytes, 0, Result, !State).

%-----------------------------------------------------------------------------%

% Reads bytes from Stream and combines them into an integer in little-endian
% order (first byte read is the least significant). N is the index of the
% next byte to read, Max is the total number of bytes, and Val is the value
% accumulated so far. Throws a software_error if N exceeds Max. Returns
% error(unexpected_eof) or error(stream_error(Message)) if a byte cannot be
% read.
:- pred read_int_le(Stream, int, int, int, int_res, State, State)
    <= (stream.reader(Stream, int, State, Error), stream.error(Error)).
:- mode read_int_le(in, di, in, di, res_uo, di, uo) is det.

read_int_le(Stream, N, Max, Val, Result, !State) :-
    builtin.compare(Cmp, N, Max),
    (
        Cmp = (>),
        exception.throw(exception.software_error("Invalid index in binstring.read_int"))
    ;
        Cmp = (=),
        Result = stream.ok(Val)
    ;
        Cmp = (<),
        stream.get(Stream, ByteResult, !State),
        (
            ByteResult = stream.error(StreamError),
            Result = stream.error(stream_error(stream.error_message(StreamError)))
        ;
            ByteResult = stream.eof,
            Result = stream.error(unexpected_eof)
        ;
            ByteResult = stream.ok(Byte),
            % Byte number N occupies bits 8*N .. 8*N+7. The recursion must
            % stay in read_int_le and keep Max: calling read_int_be here would
            % reinterpret N + 1 as a remaining-byte count and read the wrong
            % number of bytes.
            read_int_le(Stream, N + 1, Max,
                int.unchecked_left_shift(Byte, 8 * N) + Val, Result, !State)
        )
    ).

%-----------------------------------------------------------------------------%

% Convenience wrapper: runs the accumulating read_int_le from byte index 0
% up to N bytes, starting from an accumulator of 0.
:- pred read_int_le(Stream, int, int_res, State, State)
    <= (stream.reader(Stream, int, State, Error), stream.error(Error)).
:- mode read_int_le(in, di, res_uo, di, uo) is det.

read_int_le(Stream, NumBytes, Result, !State) :-
    read_int_le(Stream, 0, NumBytes, 0, Result, !State).

%-----------------------------------------------------------------------------%

% Read the single size byte, then that many bytes of string data. A failure
% to read the size byte is reported with an empty partial string.
load_string_utf8_len1(Stream, Result, !State) :-
    stream.get(Stream, SizeResult, !State),
    (
        SizeResult = stream.ok(Size),
        ( if Size = 0 then
            Result = stream.ok("")
        else
            read_utf8_len(Stream, Size, [], Result, !State)
        )
    ;
        SizeResult = stream.eof,
        Result = stream.error("", unexpected_eof)
    ;
        SizeResult = stream.error(StreamError),
        Message = stream.error_message(StreamError),
        Result = stream.error("", stream_error(Message))
    ).

%-----------------------------------------------------------------------------%

% Write the two-byte size (high byte first), then as many whole characters
% as fit in that many bytes, then pad the remainder with spaces.
save_string_utf8_len2_be(Stream, String, !State) :-
    Size = int.min(0xFFFF, string.length(String)),
    HighByte = int.unchecked_right_shift(Size, 8),
    LowByte = Size /\ 0xFF,
    stream.put(Stream, HighByte, !State),
    stream.put(Stream, LowByte, !State),
    string.foldl2(write_utf8(Stream, Size), String, 0, Written, !State),
    pad(Stream, Size - Written, !State).

%-----------------------------------------------------------------------------%

% Read the two-byte size with read_int_be, then that many bytes of string
% data. A failure to read the size is reported with an empty partial string.
load_string_utf8_len2_be(Stream, Result, !State) :-
    read_int_be(Stream, 2, SizeResult, !State),
    (
        SizeResult = stream.ok(Size),
        ( if Size = 0 then
            Result = stream.ok("")
        else
            read_utf8_len(Stream, Size, [], Result, !State)
        )
    ;
        SizeResult = stream.error(Failure),
        Result = stream.error("", Failure)
    ).

%-----------------------------------------------------------------------------%

% Write the two-byte size (low byte first), then as many whole characters as
% fit in that many bytes, then pad the remainder with spaces.
save_string_utf8_len2_le(Stream, String, !State) :-
    Size = int.min(0xFFFF, string.length(String)),
    LowByte = Size /\ 0xFF,
    HighByte = int.unchecked_right_shift(Size, 8),
    stream.put(Stream, LowByte, !State),
    stream.put(Stream, HighByte, !State),
    string.foldl2(write_utf8(Stream, Size), String, 0, Written, !State),
    pad(Stream, Size - Written, !State).

%-----------------------------------------------------------------------------%

% Read the two-byte size with read_int_le, then that many bytes of string
% data. A failure to read the size is reported with an empty partial string.
load_string_utf8_len2_le(Stream, Result, !State) :-
    read_int_le(Stream, 2, SizeResult, !State),
    (
        SizeResult = stream.ok(Size),
        ( if Size = 0 then
            Result = stream.ok("")
        else
            read_utf8_len(Stream, Size, [], Result, !State)
        )
    ;
        SizeResult = stream.error(Failure),
        Result = stream.error("", Failure)
    ).

%-----------------------------------------------------------------------------%

% Write the four-byte size (most significant byte first), then as many whole
% characters as fit in that many bytes, then pad the remainder with spaces.
save_string_utf8_len4_be(Stream, String, !State) :-
    Size = int.min(0xFFFFFFFF, string.length(String)),
    Byte3 = int.unchecked_right_shift(Size, 24) /\ 0xFF,
    Byte2 = int.unchecked_right_shift(Size, 16) /\ 0xFF,
    Byte1 = int.unchecked_right_shift(Size, 8) /\ 0xFF,
    Byte0 = Size /\ 0xFF,
    stream.put(Stream, Byte3, !State),
    stream.put(Stream, Byte2, !State),
    stream.put(Stream, Byte1, !State),
    stream.put(Stream, Byte0, !State),
    string.foldl2(write_utf8(Stream, Size), String, 0, Written, !State),
    pad(Stream, Size - Written, !State).

%-----------------------------------------------------------------------------%

% Read the four-byte size with read_int_be, then that many bytes of string
% data. A failure to read the size is reported with an empty partial string.
load_string_utf8_len4_be(Stream, Result, !State) :-
    read_int_be(Stream, 4, SizeResult, !State),
    (
        SizeResult = stream.ok(Size),
        ( if Size = 0 then
            Result = stream.ok("")
        else
            read_utf8_len(Stream, Size, [], Result, !State)
        )
    ;
        SizeResult = stream.error(Failure),
        Result = stream.error("", Failure)
    ).

%-----------------------------------------------------------------------------%

% Write the four-byte size (least significant byte first), then as many whole
% characters as fit in that many bytes, then pad the remainder with spaces
% (byte 32).
save_string_utf8_len4_le(Stream, String, !State) :-
    Len = int.min(string.length(String), 0xFFFFFFFF),
    stream.put(Stream, Len /\ 0xFF, !State),
    % Every size byte must be masked to 8 bits: without the mask, a length of
    % 0x10000 or more would emit a value larger than one byte here.
    stream.put(Stream, int.unchecked_right_shift(Len, 8) /\ 0xFF, !State),
    stream.put(Stream, int.unchecked_right_shift(Len, 16) /\ 0xFF, !State),
    stream.put(Stream, int.unchecked_right_shift(Len, 24) /\ 0xFF, !State),
    string.foldl2(write_utf8(Stream, Len), String, 0, End, !State),
    pad(Stream, Len - End, !State).

%-----------------------------------------------------------------------------%

% Read the four-byte size with read_int_le, then that many bytes of string
% data. A failure to read the size is reported with an empty partial string.
load_string_utf8_len4_le(Stream, Result, !State) :-
    read_int_le(Stream, 4, SizeResult, !State),
    (
        SizeResult = stream.ok(Size),
        ( if Size = 0 then
            Result = stream.ok("")
        else
            read_utf8_len(Stream, Size, [], Result, !State)
        )
    ;
        SizeResult = stream.error(Failure),
        Result = stream.error("", Failure)
    ).

