base64/engine/general_purpose/
decode.rs

1use crate::{
2    engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodeMetadata, DecodePaddingMode},
3    DecodeError, DecodeSliceError, PAD_BYTE,
4};
5
#[doc(hidden)]
pub struct GeneralPurposeEstimate {
    /// Number of input bytes left over after the last complete quad (input len % 4).
    rem: usize,
    /// Upper bound on the decoded length: 3 bytes for every (possibly partial) quad.
    conservative_decoded_len: usize,
}

impl GeneralPurposeEstimate {
    /// Builds an estimate for an encoded input that is `encoded_len` bytes long.
    ///
    /// Every started group of 4 input bytes is counted as 3 output bytes, so the estimate is
    /// never smaller than the true decoded length.
    pub(crate) fn new(encoded_len: usize) -> Self {
        let whole_quads = encoded_len / 4;
        let rem = encoded_len % 4;
        // A trailing partial quad is rounded up to a full one.
        let quads = if rem == 0 {
            whole_quads
        } else {
            whole_quads + 1
        };

        Self {
            rem,
            conservative_decoded_len: quads * 3,
        }
    }
}
22
23impl DecodeEstimate for GeneralPurposeEstimate {
24    fn decoded_len_estimate(&self) -> usize {
25        self.conservative_decoded_len
26    }
27}
28
29/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs.
30/// Returns the decode metadata, or an error.
31// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is
32// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment,
33// but this is fragile and the best setting changes with only minor code modifications.
34#[inline]
35pub(crate) fn decode_helper(
36    input: &[u8],
37    estimate: GeneralPurposeEstimate,
38    output: &mut [u8],
39    decode_table: &[u8; 256],
40    decode_allow_trailing_bits: bool,
41    padding_mode: DecodePaddingMode,
42) -> Result<DecodeMetadata, DecodeSliceError> {
43    let input_complete_nonterminal_quads_len =
44        complete_quads_len(input, estimate.rem, output.len(), decode_table)?;
45
46    const UNROLLED_INPUT_CHUNK_SIZE: usize = 32;
47    const UNROLLED_OUTPUT_CHUNK_SIZE: usize = UNROLLED_INPUT_CHUNK_SIZE / 4 * 3;
48
49    let input_complete_quads_after_unrolled_chunks_len =
50        input_complete_nonterminal_quads_len % UNROLLED_INPUT_CHUNK_SIZE;
51
52    let input_unrolled_loop_len =
53        input_complete_nonterminal_quads_len - input_complete_quads_after_unrolled_chunks_len;
54
55    // chunks of 32 bytes
56    for (chunk_index, chunk) in input[..input_unrolled_loop_len]
57        .chunks_exact(UNROLLED_INPUT_CHUNK_SIZE)
58        .enumerate()
59    {
60        let input_index = chunk_index * UNROLLED_INPUT_CHUNK_SIZE;
61        let chunk_output = &mut output[chunk_index * UNROLLED_OUTPUT_CHUNK_SIZE
62            ..(chunk_index + 1) * UNROLLED_OUTPUT_CHUNK_SIZE];
63
64        decode_chunk_8(
65            &chunk[0..8],
66            input_index,
67            decode_table,
68            &mut chunk_output[0..6],
69        )?;
70        decode_chunk_8(
71            &chunk[8..16],
72            input_index + 8,
73            decode_table,
74            &mut chunk_output[6..12],
75        )?;
76        decode_chunk_8(
77            &chunk[16..24],
78            input_index + 16,
79            decode_table,
80            &mut chunk_output[12..18],
81        )?;
82        decode_chunk_8(
83            &chunk[24..32],
84            input_index + 24,
85            decode_table,
86            &mut chunk_output[18..24],
87        )?;
88    }
89
90    // remaining quads, except for the last possibly partial one, as it may have padding
91    let output_unrolled_loop_len = input_unrolled_loop_len / 4 * 3;
92    let output_complete_quad_len = input_complete_nonterminal_quads_len / 4 * 3;
93    {
94        let output_after_unroll = &mut output[output_unrolled_loop_len..output_complete_quad_len];
95
96        for (chunk_index, chunk) in input
97            [input_unrolled_loop_len..input_complete_nonterminal_quads_len]
98            .chunks_exact(4)
99            .enumerate()
100        {
101            let chunk_output = &mut output_after_unroll[chunk_index * 3..chunk_index * 3 + 3];
102
103            decode_chunk_4(
104                chunk,
105                input_unrolled_loop_len + chunk_index * 4,
106                decode_table,
107                chunk_output,
108            )?;
109        }
110    }
111
112    super::decode_suffix::decode_suffix(
113        input,
114        input_complete_nonterminal_quads_len,
115        output,
116        output_complete_quad_len,
117        decode_table,
118        decode_allow_trailing_bits,
119        padding_mode,
120    )
121}
122
123/// Returns the length of complete quads, except for the last one, even if it is complete.
124///
125/// Returns an error if the output len is not big enough for decoding those complete quads, or if
126/// the input % 4 == 1, and that last byte is an invalid value other than a pad byte.
127///
128/// - `input` is the base64 input
129/// - `input_len_rem` is input len % 4
130/// - `output_len` is the length of the output slice
131pub(crate) fn complete_quads_len(
132    input: &[u8],
133    input_len_rem: usize,
134    output_len: usize,
135    decode_table: &[u8; 256],
136) -> Result<usize, DecodeSliceError> {
137    debug_assert!(input.len() % 4 == input_len_rem);
138
139    // detect a trailing invalid byte, like a newline, as a user convenience
140    if input_len_rem == 1 {
141        let last_byte = input[input.len() - 1];
142        // exclude pad bytes; might be part of padding that extends from earlier in the input
143        if last_byte != PAD_BYTE && decode_table[usize::from(last_byte)] == INVALID_VALUE {
144            return Err(DecodeError::InvalidByte(input.len() - 1, last_byte).into());
145        }
146    };
147
148    // skip last quad, even if it's complete, as it may have padding
149    let input_complete_nonterminal_quads_len = input
150        .len()
151        .saturating_sub(input_len_rem)
152        // if rem was 0, subtract 4 to avoid padding
153        .saturating_sub((input_len_rem == 0) as usize * 4);
154    debug_assert!(
155        input.is_empty() || (1..=4).contains(&(input.len() - input_complete_nonterminal_quads_len))
156    );
157
158    // check that everything except the last quad handled by decode_suffix will fit
159    if output_len < input_complete_nonterminal_quads_len / 4 * 3 {
160        return Err(DecodeSliceError::OutputSliceTooSmall);
161    };
162    Ok(input_complete_nonterminal_quads_len)
163}
164
165/// Decode 8 bytes of input into 6 bytes of output.
166///
167/// `input` is the 8 bytes to decode.
168/// `index_at_start_of_input` is the offset in the overall input (used for reporting errors
169/// accurately)
170/// `decode_table` is the lookup table for the particular base64 alphabet.
171/// `output` will have its first 6 bytes overwritten
172// yes, really inline (worth 30-50% speedup)
173#[inline(always)]
174fn decode_chunk_8(
175    input: &[u8],
176    index_at_start_of_input: usize,
177    decode_table: &[u8; 256],
178    output: &mut [u8],
179) -> Result<(), DecodeError> {
180    let morsel = decode_table[usize::from(input[0])];
181    if morsel == INVALID_VALUE {
182        return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
183    }
184    let mut accum = u64::from(morsel) << 58;
185
186    let morsel = decode_table[usize::from(input[1])];
187    if morsel == INVALID_VALUE {
188        return Err(DecodeError::InvalidByte(
189            index_at_start_of_input + 1,
190            input[1],
191        ));
192    }
193    accum |= u64::from(morsel) << 52;
194
195    let morsel = decode_table[usize::from(input[2])];
196    if morsel == INVALID_VALUE {
197        return Err(DecodeError::InvalidByte(
198            index_at_start_of_input + 2,
199            input[2],
200        ));
201    }
202    accum |= u64::from(morsel) << 46;
203
204    let morsel = decode_table[usize::from(input[3])];
205    if morsel == INVALID_VALUE {
206        return Err(DecodeError::InvalidByte(
207            index_at_start_of_input + 3,
208            input[3],
209        ));
210    }
211    accum |= u64::from(morsel) << 40;
212
213    let morsel = decode_table[usize::from(input[4])];
214    if morsel == INVALID_VALUE {
215        return Err(DecodeError::InvalidByte(
216            index_at_start_of_input + 4,
217            input[4],
218        ));
219    }
220    accum |= u64::from(morsel) << 34;
221
222    let morsel = decode_table[usize::from(input[5])];
223    if morsel == INVALID_VALUE {
224        return Err(DecodeError::InvalidByte(
225            index_at_start_of_input + 5,
226            input[5],
227        ));
228    }
229    accum |= u64::from(morsel) << 28;
230
231    let morsel = decode_table[usize::from(input[6])];
232    if morsel == INVALID_VALUE {
233        return Err(DecodeError::InvalidByte(
234            index_at_start_of_input + 6,
235            input[6],
236        ));
237    }
238    accum |= u64::from(morsel) << 22;
239
240    let morsel = decode_table[usize::from(input[7])];
241    if morsel == INVALID_VALUE {
242        return Err(DecodeError::InvalidByte(
243            index_at_start_of_input + 7,
244            input[7],
245        ));
246    }
247    accum |= u64::from(morsel) << 16;
248
249    output[..6].copy_from_slice(&accum.to_be_bytes()[..6]);
250
251    Ok(())
252}
253
254/// Like [decode_chunk_8] but for 4 bytes of input and 3 bytes of output.
255#[inline(always)]
256fn decode_chunk_4(
257    input: &[u8],
258    index_at_start_of_input: usize,
259    decode_table: &[u8; 256],
260    output: &mut [u8],
261) -> Result<(), DecodeError> {
262    let morsel = decode_table[usize::from(input[0])];
263    if morsel == INVALID_VALUE {
264        return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
265    }
266    let mut accum = u32::from(morsel) << 26;
267
268    let morsel = decode_table[usize::from(input[1])];
269    if morsel == INVALID_VALUE {
270        return Err(DecodeError::InvalidByte(
271            index_at_start_of_input + 1,
272            input[1],
273        ));
274    }
275    accum |= u32::from(morsel) << 20;
276
277    let morsel = decode_table[usize::from(input[2])];
278    if morsel == INVALID_VALUE {
279        return Err(DecodeError::InvalidByte(
280            index_at_start_of_input + 2,
281            input[2],
282        ));
283    }
284    accum |= u32::from(morsel) << 14;
285
286    let morsel = decode_table[usize::from(input[3])];
287    if morsel == INVALID_VALUE {
288        return Err(DecodeError::InvalidByte(
289            index_at_start_of_input + 3,
290            input[3],
291        ));
292    }
293    accum |= u32::from(morsel) << 8;
294
295    output[..3].copy_from_slice(&accum.to_be_bytes()[..3]);
296
297    Ok(())
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    use crate::engine::general_purpose::STANDARD;

    /// The two output bytes past the 6 decoded ones must keep their original values.
    #[test]
    fn decode_chunk_8_writes_only_6_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        decode_chunk_8(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        assert_eq!([b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], output);
    }

    /// The output byte past the 3 decoded ones must keep its original value.
    #[test]
    fn decode_chunk_4_writes_only_3_bytes() {
        let input = b"Zm9v"; // "foo"
        let mut output = [0_u8, 1, 2, 3];

        decode_chunk_4(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        assert_eq!([b'f', b'o', b'o', 3], output);
    }

    /// Each started quad of input contributes 3 bytes to the estimate.
    #[test]
    fn estimate_short_lengths() {
        for (range, decoded_len_estimate) in [
            (0..=0, 0),
            (1..=4, 3),
            (5..=8, 6),
            (9..=12, 9),
            (13..=16, 12),
            (17..=20, 15),
        ] {
            for encoded_len in range {
                let estimate = GeneralPurposeEstimate::new(encoded_len);
                assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate());
            }
        }
    }

    /// Checks the estimate against the straightforward formula evaluated in 128 bits.
    #[test]
    fn estimate_via_u128_inflation() {
        // cover both ends of usize
        (0..1000)
            .chain(usize::MAX - 1000..=usize::MAX)
            .for_each(|encoded_len| {
                // inflate to 128 bit type to be able to safely use the easy formulas
                let len_128 = encoded_len as u128;

                let estimate = GeneralPurposeEstimate::new(encoded_len);
                assert_eq!(
                    (len_128 + 3) / 4 * 3,
                    estimate.conservative_decoded_len as u128
                );
            })
    }
}