1
use std::borrow::Cow;
2
use std::cmp::Ordering;
3
use std::fmt::{Debug, Display};
4
use std::io::Write;
5
use std::ops::Range;
6
use std::usize;
7

            
8
use byteorder::WriteBytesExt;
9
use serde::de::{SeqAccess, Visitor};
10
use serde::{ser, Deserialize, Serialize};
11
#[cfg(feature = "tracing")]
12
use tracing::instrument;
13

            
14
use crate::format::{self, Kind, Special, INITIAL_VERSION, V4_VERSION};
15
use crate::{Compatibility, Error, Result};
16

            
17
/// A Pot serializer.
18
pub struct Serializer<'a, W: WriteBytesExt> {
19
    symbol_map: SymbolMapRef<'a>,
20
    compatibility: Compatibility,
21
    output: W,
22
    bytes_written: usize,
23
}
24

            
25
impl<'a, W: WriteBytesExt> Debug for Serializer<'a, W> {
    /// Formats the symbol map and the byte counter. `output` and
    /// `compatibility` are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Serializer")
            .field("symbol_map", &self.symbol_map)
            .field("bytes_written", &self.bytes_written)
            .finish()
    }
}
33

            
34
impl<'a, W: WriteBytesExt> Serializer<'a, W> {
35
    /// Returns a new serializer outputting written bytes into `output`.
36
    #[inline]
37
1
    pub fn new(output: W) -> Result<Self> {
38
1
        Self::new_with_compatibility(output, Compatibility::default())
39
1
    }
40

            
41
    /// Returns a new serializer outputting written bytes into `output`.
42
    #[inline]
43
1157
    pub fn new_with_compatibility(output: W, compatibility: Compatibility) -> Result<Self> {
44
1157
        Self::new_with_symbol_map(
45
1157
            output,
46
1157
            SymbolMapRef::Ephemeral(EphemeralSymbolMap::default()),
47
1157
            compatibility,
48
1157
        )
49
1157
    }
50

            
51
1167
    fn new_with_symbol_map(
52
1167
        mut output: W,
53
1167
        symbol_map: SymbolMapRef<'a>,
54
1167
        compatibility: Compatibility,
55
1167
    ) -> Result<Self> {
56
1167
        let bytes_written = format::write_header(
57
1167
            &mut output,
58
1167
            match compatibility {
59
1166
                Compatibility::Full => INITIAL_VERSION,
60
1
                Compatibility::V4 => V4_VERSION,
61
            },
62
        )?;
63
1167
        Ok(Self {
64
1167
            compatibility,
65
1167
            symbol_map,
66
1167
            output,
67
1167
            bytes_written,
68
1167
        })
69
1167
    }
70

            
71
1121239
    #[cfg_attr(feature = "tracing", instrument)]
72
    fn write_symbol(&mut self, symbol: &'static str) -> Result<()> {
73
        let registered_symbol = self.symbol_map.find_or_add(symbol);
74
        if registered_symbol.new {
75
            // The arg is the length followed by a 0 bit.
76
            let arg = (symbol.len() as u64) << 1;
77
            self.bytes_written += format::write_atom_header(&mut self.output, Kind::Symbol, arg)?;
78
            self.output.write_all(symbol.as_bytes())?;
79
            self.bytes_written += symbol.len();
80
        } else {
81
            // When a symbol was already emitted, just emit the id followed by a 1 bit.
82
            self.bytes_written += format::write_atom_header(
83
                &mut self.output,
84
                Kind::Symbol,
85
                u64::from((registered_symbol.id << 1) | 1),
86
            )?;
87
        }
88
        Ok(())
89
    }
90
}
91

            
92
impl<'de, 'a: 'de, W: WriteBytesExt + 'a> ser::Serializer for &'de mut Serializer<'a, W> {
93
    type Error = Error;
94
    type Ok = ();
95
    type SerializeMap = MapSerializer<'de, 'a, W>;
96
    type SerializeSeq = Self;
97
    type SerializeStruct = MapSerializer<'de, 'a, W>;
98
    type SerializeStructVariant = MapSerializer<'de, 'a, W>;
99
    type SerializeTuple = Self;
100
    type SerializeTupleStruct = Self;
101
    type SerializeTupleVariant = Self;
102

            
103
    #[inline]
104
1
    fn is_human_readable(&self) -> bool {
105
1
        false
106
1
    }
107

            
108
4
    #[cfg_attr(feature = "tracing", instrument)]
109
    #[inline]
110
    fn serialize_bool(self, v: bool) -> Result<()> {
111
        self.bytes_written += format::write_bool(&mut self.output, v)?;
112
        Ok(())
113
    }
114

            
115
14
    #[cfg_attr(feature = "tracing", instrument)]
116
    #[inline]
117
    fn serialize_i8(self, v: i8) -> Result<()> {
118
        self.bytes_written += format::write_i8(&mut self.output, v)?;
119
        Ok(())
120
    }
121

            
122
14
    #[cfg_attr(feature = "tracing", instrument)]
123
    #[inline]
124
    fn serialize_i16(self, v: i16) -> Result<()> {
125
        self.bytes_written += format::write_i16(&mut self.output, v)?;
126
        Ok(())
127
    }
128

            
129
18
    #[cfg_attr(feature = "tracing", instrument)]
130
    #[inline]
131
    fn serialize_i32(self, v: i32) -> Result<()> {
132
        self.bytes_written += format::write_i32(&mut self.output, v)?;
133
        Ok(())
134
    }
135

            
136
14
    #[cfg_attr(feature = "tracing", instrument)]
137
    #[inline]
138
    fn serialize_i64(self, v: i64) -> Result<()> {
139
        self.bytes_written += format::write_i64(&mut self.output, v)?;
140
        Ok(())
141
    }
142

            
143
60
    #[cfg_attr(feature = "tracing", instrument)]
144
    #[inline]
145
    fn serialize_i128(self, v: i128) -> Result<()> {
146
        self.bytes_written += format::write_i128(&mut self.output, v)?;
147
        Ok(())
148
    }
149

            
150
26
    #[cfg_attr(feature = "tracing", instrument)]
151
    #[inline]
152
    fn serialize_u8(self, v: u8) -> Result<()> {
153
        self.bytes_written += format::write_u8(&mut self.output, v)?;
154
        Ok(())
155
    }
156

            
157
140017
    #[cfg_attr(feature = "tracing", instrument)]
158
    #[inline]
159
    fn serialize_u16(self, v: u16) -> Result<()> {
160
        self.bytes_written += format::write_u16(&mut self.output, v)?;
161
        Ok(())
162
    }
163

            
164
17
    #[cfg_attr(feature = "tracing", instrument)]
165
    #[inline]
166
    fn serialize_u32(self, v: u32) -> Result<()> {
167
        self.bytes_written += format::write_u32(&mut self.output, v)?;
168
        Ok(())
169
    }
170

            
171
140046
    #[cfg_attr(feature = "tracing", instrument)]
172
    #[inline]
173
    fn serialize_u64(self, v: u64) -> Result<()> {
174
        self.bytes_written += format::write_u64(&mut self.output, v)?;
175
        Ok(())
176
    }
177

            
178
40
    #[cfg_attr(feature = "tracing", instrument)]
179
    #[inline]
180
    fn serialize_u128(self, v: u128) -> Result<()> {
181
        self.bytes_written += format::write_u128(&mut self.output, v)?;
182
        Ok(())
183
    }
184

            
185
16
    #[cfg_attr(feature = "tracing", instrument)]
186
    #[inline]
187
    fn serialize_f32(self, v: f32) -> Result<()> {
188
        self.bytes_written += format::write_f32(&mut self.output, v)?;
189
        Ok(())
190
    }
191

            
192
24
    #[cfg_attr(feature = "tracing", instrument)]
193
    #[inline]
194
    fn serialize_f64(self, v: f64) -> Result<()> {
195
        self.bytes_written += format::write_f64(&mut self.output, v)?;
196
        Ok(())
197
    }
198

            
199
13
    #[cfg_attr(feature = "tracing", instrument)]
200
    #[inline]
201
    fn serialize_char(self, v: char) -> Result<()> {
202
        self.bytes_written += format::write_u32(&mut self.output, v as u32)?;
203
        Ok(())
204
    }
205

            
206
489879
    #[cfg_attr(feature = "tracing", instrument)]
207
    #[inline]
208
    fn serialize_str(self, v: &str) -> Result<()> {
209
        self.bytes_written += format::write_str(&mut self.output, v)?;
210
        Ok(())
211
    }
212

            
213
9
    #[cfg_attr(feature = "tracing", instrument)]
214
    #[inline]
215
    fn serialize_bytes(self, v: &[u8]) -> Result<()> {
216
        self.bytes_written += format::write_bytes(&mut self.output, v)?;
217
        Ok(())
218
    }
219

            
220
70158
    #[cfg_attr(feature = "tracing", instrument)]
221
    #[inline]
222
    fn serialize_none(self) -> Result<()> {
223
        self.bytes_written += format::write_none(&mut self.output)?;
224
        Ok(())
225
    }
226

            
227
69857
    #[cfg_attr(feature = "tracing", instrument(level = "trace", skip(value)))]
228
    #[inline]
229
    fn serialize_some<T>(self, value: &T) -> Result<()>
230
    where
231
        T: ?Sized + Serialize,
232
    {
233
        value.serialize(self)
234
    }
235

            
236
7
    #[cfg_attr(feature = "tracing", instrument)]
237
    #[inline]
238
    fn serialize_unit(self) -> Result<()> {
239
        self.bytes_written += format::write_unit(&mut self.output)?;
240
        Ok(())
241
    }
242

            
243
2
    #[cfg_attr(feature = "tracing", instrument)]
244
    #[inline]
245
    fn serialize_unit_struct(self, _name: &'static str) -> Result<()> {
246
        self.serialize_unit()
247
    }
248

            
249
140009
    #[cfg_attr(feature = "tracing", instrument)]
250
    #[inline]
251
    fn serialize_unit_variant(
252
        self,
253
        _name: &'static str,
254
        _variant_index: u32,
255
        variant: &'static str,
256
    ) -> Result<()> {
257
        if matches!(self.compatibility, Compatibility::Full) {
258
            self.bytes_written += format::write_named(&mut self.output)?;
259
        }
260
        self.write_symbol(variant)?;
261
        Ok(())
262
    }
263

            
264
    #[cfg_attr(feature = "tracing", instrument(level = "trace", skip(value)))]
265
    #[inline]
266
    fn serialize_newtype_struct<T>(self, _name: &'static str, value: &T) -> Result<()>
267
    where
268
        T: ?Sized + Serialize,
269
    {
270
        value.serialize(self)
271
    }
272

            
273
4
    #[cfg_attr(feature = "tracing", instrument(level = "trace", skip(value)))]
274
    #[inline]
275
    fn serialize_newtype_variant<T>(
276
        self,
277
        _name: &'static str,
278
        _variant_index: u32,
279
        variant: &'static str,
280
        value: &T,
281
    ) -> Result<()>
282
    where
283
        T: ?Sized + Serialize,
284
    {
285
        format::write_named(&mut self.output)?;
286
        self.write_symbol(variant)?;
287
        value.serialize(&mut *self)?;
288
        Ok(())
289
    }
290

            
291
1022
    #[cfg_attr(feature = "tracing", instrument)]
292
    #[inline]
293
    fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq> {
294
        let len = len.ok_or(Error::SequenceSizeMustBeKnown)?;
295
        self.bytes_written +=
296
            format::write_atom_header(&mut self.output, Kind::Sequence, len as u64)?;
297
        Ok(self)
298
    }
299

            
300
2
    #[cfg_attr(feature = "tracing", instrument)]
301
    #[inline]
302
    fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple> {
303
        self.serialize_seq(Some(len))
304
    }
305

            
306
2
    #[cfg_attr(feature = "tracing", instrument)]
307
    #[inline]
308
    fn serialize_tuple_struct(
309
        self,
310
        _name: &'static str,
311
        len: usize,
312
    ) -> Result<Self::SerializeTupleStruct> {
313
        self.serialize_seq(Some(len))
314
    }
315

            
316
2
    #[cfg_attr(feature = "tracing", instrument)]
317
    #[inline]
318
    fn serialize_tuple_variant(
319
        self,
320
        _name: &'static str,
321
        _variant_index: u32,
322
        variant: &'static str,
323
        len: usize,
324
    ) -> Result<Self::SerializeTupleVariant> {
325
        format::write_named(&mut self.output)?;
326
        self.write_symbol(variant)?;
327
        self.serialize_seq(Some(len))
328
    }
329

            
330
141035
    #[cfg_attr(feature = "tracing", instrument)]
331
    #[inline]
332
    fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMap> {
333
        if let Some(len) = len {
334
            self.bytes_written +=
335
                format::write_atom_header(&mut self.output, Kind::Map, len as u64)?;
336
            Ok(MapSerializer {
337
                serializer: self,
338
                known_length: true,
339
            })
340
        } else {
341
            self.bytes_written += format::write_special(&mut self.output, Special::DynamicMap)?;
342
            Ok(MapSerializer {
343
                serializer: self,
344
                known_length: false,
345
            })
346
        }
347
    }
348

            
349
141032
    #[cfg_attr(feature = "tracing", instrument)]
350
    #[inline]
351
    fn serialize_struct(self, _name: &'static str, len: usize) -> Result<Self::SerializeStruct> {
352
        self.serialize_map(Some(len))
353
    }
354

            
355
2
    #[cfg_attr(feature = "tracing", instrument)]
356
    #[inline]
357
    fn serialize_struct_variant(
358
        self,
359
        name: &'static str,
360
        _variant_index: u32,
361
        variant: &'static str,
362
        len: usize,
363
    ) -> Result<Self::SerializeStructVariant> {
364
        format::write_named(&mut self.output)?;
365
        self.write_symbol(variant)?;
366
        self.serialize_struct(name, len)
367
    }
368
}
369

            
370
impl<'de, 'a: 'de, W: WriteBytesExt + 'a> ser::SerializeSeq for &'de mut Serializer<'a, W> {
371
    type Error = Error;
372
    type Ok = ();
373

            
374
    #[inline]
375
140029
    fn serialize_element<T>(&mut self, value: &T) -> Result<()>
376
140029
    where
377
140029
        T: ?Sized + Serialize,
378
140029
    {
379
140029
        value.serialize(&mut **self)
380
140029
    }
381

            
382
    #[inline]
383
1016
    fn end(self) -> Result<()> {
384
1016
        Ok(())
385
1016
    }
386
}
387

            
388
impl<'de, 'a: 'de, W: WriteBytesExt + 'a> ser::SerializeTuple for &'de mut Serializer<'a, W> {
389
    type Error = Error;
390
    type Ok = ();
391

            
392
    #[inline]
393
6
    fn serialize_element<T>(&mut self, value: &T) -> Result<()>
394
6
    where
395
6
        T: ?Sized + Serialize,
396
6
    {
397
6
        value.serialize(&mut **self)
398
6
    }
399

            
400
    #[inline]
401
2
    fn end(self) -> Result<()> {
402
2
        Ok(())
403
2
    }
404
}
405

            
406
impl<'de, 'a: 'de, W: WriteBytesExt + 'a> ser::SerializeTupleStruct for &'de mut Serializer<'a, W> {
407
    type Error = Error;
408
    type Ok = ();
409

            
410
    #[inline]
411
4
    fn serialize_field<T>(&mut self, value: &T) -> Result<()>
412
4
    where
413
4
        T: ?Sized + Serialize,
414
4
    {
415
4
        value.serialize(&mut **self)
416
4
    }
417

            
418
    #[inline]
419
2
    fn end(self) -> Result<()> {
420
2
        Ok(())
421
2
    }
422
}
423

            
424
impl<'de, 'a: 'de, W: WriteBytesExt + 'a> ser::SerializeTupleVariant
425
    for &'de mut Serializer<'a, W>
426
{
427
    type Error = Error;
428
    type Ok = ();
429

            
430
    #[inline]
431
4
    fn serialize_field<T>(&mut self, value: &T) -> Result<()>
432
4
    where
433
4
        T: ?Sized + Serialize,
434
4
    {
435
4
        value.serialize(&mut **self)
436
4
    }
437

            
438
    #[inline]
439
2
    fn end(self) -> Result<()> {
440
2
        Ok(())
441
2
    }
442
}
443

            
444
/// Serializes map-like values.
445
pub struct MapSerializer<'de, 'a, W: WriteBytesExt> {
446
    serializer: &'de mut Serializer<'a, W>,
447
    known_length: bool,
448
}
449

            
450
impl<'de, 'a: 'de, W: WriteBytesExt + 'a> ser::SerializeMap for MapSerializer<'de, 'a, W> {
451
    type Error = Error;
452
    type Ok = ();
453

            
454
    #[inline]
455
5
    fn serialize_key<T>(&mut self, key: &T) -> Result<()>
456
5
    where
457
5
        T: ?Sized + Serialize,
458
5
    {
459
5
        key.serialize(&mut *self.serializer)
460
5
    }
461

            
462
    #[inline]
463
5
    fn serialize_value<T>(&mut self, value: &T) -> Result<()>
464
5
    where
465
5
        T: ?Sized + Serialize,
466
5
    {
467
5
        value.serialize(&mut *self.serializer)
468
5
    }
469

            
470
    #[inline]
471
3
    fn end(self) -> Result<()> {
472
3
        if !self.known_length {
473
2
            format::write_special(&mut self.serializer.output, Special::DynamicEnd)?;
474
1
        }
475
3
        Ok(())
476
3
    }
477
}
478

            
479
impl<'de, 'a: 'de, W: WriteBytesExt + 'a> ser::SerializeStruct for MapSerializer<'de, 'a, W> {
480
    type Error = Error;
481
    type Ok = ();
482

            
483
    #[inline]
484
981220
    fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()>
485
981220
    where
486
981220
        T: ?Sized + Serialize,
487
981220
    {
488
981220
        self.serializer.write_symbol(key)?;
489
981220
        value.serialize(&mut *self.serializer)
490
981220
    }
491

            
492
    #[inline]
493
141030
    fn end(self) -> Result<()> {
494
141030
        if !self.known_length {
495
            format::write_special(&mut self.serializer.output, Special::DynamicEnd)?;
496
141030
        }
497
141030
        Ok(())
498
141030
    }
499
}
500

            
501
impl<'de, 'a: 'de, W: WriteBytesExt + 'a> ser::SerializeStructVariant
502
    for MapSerializer<'de, 'a, W>
503
{
504
    type Error = Error;
505
    type Ok = ();
506

            
507
    #[inline]
508
2
    fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()>
509
2
    where
510
2
        T: ?Sized + Serialize,
511
2
    {
512
2
        self.serializer.write_symbol(key)?;
513
2
        value.serialize(&mut *self.serializer)
514
2
    }
515

            
516
    #[inline]
517
2
    fn end(self) -> Result<()> {
518
2
        if !self.known_length {
519
            format::write_special(&mut self.serializer.output, Special::DynamicEnd)?;
520
2
        }
521
2
        Ok(())
522
2
    }
523
}
524

            
525
/// Serializer-local table of the static symbols registered so far.
#[derive(Default)]
struct EphemeralSymbolMap {
    /// `(symbol, id)` pairs, kept sorted by the symbol's `(address, length)`
    /// so lookups can binary search without comparing string contents.
    symbols: Vec<(&'static str, u32)>,
}

/// Outcome of looking a symbol up in a symbol map.
struct RegisteredSymbol {
    /// Id assigned to the symbol.
    id: u32,
    /// `true` when this lookup registered the symbol for the first time.
    new: bool,
}

impl EphemeralSymbolMap {
    /// Returns the id for `symbol`, registering it with the next free id
    /// (the current number of symbols) if it has not been seen before.
    #[allow(clippy::cast_possible_truncation)]
    fn find_or_add(&mut self, symbol: &'static str) -> RegisteredSymbol {
        // Symbols have to be static strings, and so we can rely on the address
        // not changing. To avoid string comparisons, we key on the str's
        // address. The length is part of the key as well: two distinct static
        // strs may start at the same address (one being a prefix of the
        // other), and they must not be treated as the same symbol.
        let key = (symbol.as_ptr() as usize, symbol.len());
        // Perform a binary search to find this existing element.
        match self
            .symbols
            .binary_search_by(|check| (check.0.as_ptr() as usize, check.0.len()).cmp(&key))
        {
            Ok(position) => RegisteredSymbol {
                id: self.symbols[position].1,
                new: false,
            },
            Err(position) => {
                let id = self.symbols.len() as u32;
                // Inserting at the search position keeps the vec sorted.
                self.symbols.insert(position, (symbol, id));
                RegisteredSymbol { id, new: true }
            }
        }
    }
}
559

            
560
impl Debug for EphemeralSymbolMap {
561
497
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
562
497
        let mut set = f.debug_set();
563
7866
        for index in SymbolIdSorter::new(&self.symbols, |sym| sym.1) {
564
1984
            set.entry(&self.symbols[index].0);
565
1984
        }
566
497
        set.finish()
567
497
    }
568
}
569

            
570
/// A list of previously serialized symbols.
571
pub struct SymbolMap {
572
    symbols: String,
573
    entries: Vec<(Range<usize>, u32)>,
574
    static_lookup: Vec<(usize, u32)>,
575
    compatibility: Compatibility,
576
}
577

            
578
impl Debug for SymbolMap {
    /// Prints the symbol text of every entry as a set, in `entries` order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut s = f.debug_set();
        for entry in &self.entries {
            s.entry(&&self.symbols[entry.0.clone()]);
        }
        s.finish()
    }
}
587

            
588
impl Default for SymbolMap {
589
    #[inline]
590
5
    fn default() -> Self {
591
5
        Self::new()
592
5
    }
593
}
594

            
595
impl SymbolMap {
596
    /// Returns a new, empty symbol map.
597
    #[must_use]
598
13
    pub const fn new() -> Self {
599
13
        Self {
600
13
            symbols: String::new(),
601
13
            entries: Vec::new(),
602
13
            static_lookup: Vec::new(),
603
13
            compatibility: Compatibility::const_default(),
604
13
        }
605
13
    }
606

            
607
    /// Sets the compatibility mode for serializing and returns self.
608
    pub const fn with_compatibility(mut self, compatibility: Compatibility) -> Self {
609
        self.compatibility = compatibility;
610
        self
611
    }
612

            
613
    /// Sets the compatibility mode for serializing.
614
    pub fn set_compatibility(&mut self, compatibility: Compatibility) {
615
        self.compatibility = compatibility;
616
    }
617

            
618
    /// Returns a serializer that writes into `output` and persists symbols
619
    /// into `self`.
620
    #[inline]
621
10
    pub fn serializer_for<W: WriteBytesExt>(&mut self, output: W) -> Result<Serializer<'_, W>> {
622
10
        let compatibility = self.compatibility;
623
10
        Serializer::new_with_symbol_map(output, SymbolMapRef::Persistent(self), compatibility)
624
10
    }
625

            
626
    /// Serializes `value` into `writer` while persisting symbols into `self`.
627
10
    pub fn serialize_to<T, W>(&mut self, writer: W, value: &T) -> Result<()>
628
10
    where
629
10
        W: Write,
630
10
        T: Serialize,
631
10
    {
632
10
        value.serialize(&mut self.serializer_for(writer)?)
633
10
    }
634

            
635
    /// Serializes `value` into a new `Vec<u8>` while persisting symbols into
636
    /// `self`.
637
6
    pub fn serialize_to_vec<T>(&mut self, value: &T) -> Result<Vec<u8>>
638
6
    where
639
6
        T: Serialize,
640
6
    {
641
6
        let mut output = Vec::new();
642
6
        self.serialize_to(&mut output, value)?;
643
6
        Ok(output)
644
6
    }
645

            
646
157
    fn find_or_add(&mut self, symbol: &'static str) -> RegisteredSymbol {
647
157
        // Symbols have to be static strings, and so we can rely on the addres
648
157
        // not changing. To avoid string comparisons, we're going to use the
649
157
        // address of the str in the map.
650
157
        let symbol_address = symbol.as_ptr() as usize;
651
157
        // Perform a binary search to find this existing element.
652
157
        match self
653
157
            .static_lookup
654
378
            .binary_search_by(|check| symbol_address.cmp(&check.0))
655
        {
656
69
            Ok(position) => RegisteredSymbol {
657
69
                id: self.static_lookup[position].1,
658
69
                new: false,
659
69
            },
660
88
            Err(position) => {
661
88
                // This static symbol hasn't been encountered before.
662
88
                let symbol = self.find_entry_by_str(symbol);
663
88
                self.static_lookup
664
88
                    .insert(position, (symbol_address, symbol.id));
665
88
                symbol
666
            }
667
        }
668
157
    }
669

            
670
    #[allow(clippy::cast_possible_truncation)]
671
88
    fn find_entry_by_str(&mut self, symbol: &str) -> RegisteredSymbol {
672
88
        match self
673
88
            .entries
674
186
            .binary_search_by(|check| self.symbols[check.0.clone()].cmp(symbol))
675
        {
676
2
            Ok(index) => RegisteredSymbol {
677
2
                id: self.entries[index].1,
678
2
                new: false,
679
2
            },
680
86
            Err(insert_at) => {
681
86
                let id = self.entries.len() as u32;
682
86
                let start = self.symbols.len();
683
86
                self.symbols.push_str(symbol);
684
86
                self.entries
685
86
                    .insert(insert_at, (start..self.symbols.len(), id));
686
86
                RegisteredSymbol { id, new: true }
687
            }
688
        }
689
88
    }
690

            
691
    /// Inserts `symbol` into this map.
692
    ///
693
    /// Returns true if this symbol had not previously been registered. Returns
694
    /// false if the symbol was already included in the map.
695
    pub fn insert(&mut self, symbol: &str) -> bool {
696
        self.find_entry_by_str(symbol).new
697
    }
698

            
699
    /// Returns the number of entries in this map.
700
    #[must_use]
701
7
    pub fn len(&self) -> usize {
702
7
        self.entries.len()
703
7
    }
704

            
705
    /// Returns true if the map has no entries.
706
    #[must_use]
707
1
    pub fn is_empty(&self) -> bool {
708
1
        self.len() == 0
709
1
    }
710

            
711
    /// Adds all symbols encountered in `value`.
712
    ///
713
    /// Returns the number of symbols added.
714
    ///
715
    /// Due to how serde works, this function can only encounter symbols that
716
    /// are being used. For example, if `T` is an enum, only variant being
717
    /// passed in will have its name, and additional calls for each variant will
718
    /// be needed to ensure every symbol is added.
719
7
    pub fn populate_from<T>(&mut self, value: &T) -> Result<usize, SymbolMapPopulationError>
720
7
    where
721
7
        T: Serialize,
722
7
    {
723
7
        let start_count = self.entries.len();
724
7
        value.serialize(&mut SymbolMapPopulator(self))?;
725
7
        Ok(self.entries.len() - start_count)
726
7
    }
727
}
728

            
729
impl Serialize for SymbolMap {
730
    #[inline]
731
2
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
732
2
    where
733
2
        S: serde::Serializer,
734
2
    {
735
        use serde::ser::SerializeSeq;
736
2
        let mut seq = serializer.serialize_seq(Some(self.len()))?;
737
6
        for index in SymbolIdSorter::new(&self.entries, |entry| entry.1) {
738
4
            seq.serialize_element(&self.symbols[self.entries[index].0.clone()])?;
739
        }
740
2
        seq.end()
741
2
    }
742
}
743

            
744
impl<'de> Deserialize<'de> for SymbolMap {
745
    #[inline]
746
1
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
747
1
    where
748
1
        D: serde::Deserializer<'de>,
749
1
    {
750
1
        deserializer.deserialize_seq(SymbolMapVisitor)
751
1
    }
752
}
753

            
754
/// serde visitor that rebuilds a `SymbolMap` from a sequence of strings.
struct SymbolMapVisitor;
755

            
756
impl<'de> Visitor<'de> for SymbolMapVisitor {
757
    type Value = SymbolMap;
758

            
759
    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
760
        formatter.write_str("symbol map")
761
    }
762

            
763
    #[inline]
764
1
    fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
765
1
    where
766
1
        A: SeqAccess<'de>,
767
1
    {
768
1
        let mut map = SymbolMap::new();
769
1
        if let Some(hint) = seq.size_hint() {
770
1
            map.entries.reserve(hint);
771
1
        }
772
1
        let mut id = 0;
773
3
        while let Some(element) = seq.next_element::<Cow<'_, str>>()? {
774
2
            let start = map.symbols.len();
775
2
            map.symbols.push_str(&element);
776
2
            map.entries.push((start..map.symbols.len(), id));
777
2
            id += 1;
778
2
        }
779

            
780
1
        map.entries
781
1
            .sort_by(|a, b| map.symbols[a.0.clone()].cmp(&map.symbols[b.0.clone()]));
782
1

            
783
1
        Ok(map)
784
1
    }
785
}
786

            
787
#[derive(Debug)]
788
enum SymbolMapRef<'a> {
789
    Ephemeral(EphemeralSymbolMap),
790
    Persistent(&'a mut SymbolMap),
791
}
792

            
793
impl SymbolMapRef<'_> {
794
4484341
    fn find_or_add(&mut self, symbol: &'static str) -> RegisteredSymbol {
795
4484341
        match self {
796
4484211
            SymbolMapRef::Ephemeral(map) => map.find_or_add(symbol),
797
130
            SymbolMapRef::Persistent(map) => map.find_or_add(symbol),
798
        }
799
4484341
    }
800
}
801

            
802
/// Serializer that produces no output; it only registers the struct field
/// names and enum variant names it is given with the wrapped [`SymbolMap`].
struct SymbolMapPopulator<'a>(&'a mut SymbolMap);
803

            
804
impl<'ser, 'a> serde::ser::Serializer for &'ser mut SymbolMapPopulator<'a> {
805
    type Error = SymbolMapPopulationError;
806
    type Ok = ();
807
    type SerializeMap = Self;
808
    type SerializeSeq = Self;
809
    type SerializeStruct = Self;
810
    type SerializeStructVariant = Self;
811
    type SerializeTuple = Self;
812
    type SerializeTupleStruct = Self;
813
    type SerializeTupleVariant = Self;
814

            
815
    #[inline]
816
    fn serialize_bool(self, _v: bool) -> std::result::Result<Self::Ok, Self::Error> {
817
        Ok(())
818
    }
819

            
820
    #[inline]
821
1
    fn serialize_i8(self, _v: i8) -> std::result::Result<Self::Ok, Self::Error> {
822
1
        Ok(())
823
1
    }
824

            
825
    #[inline]
826
1
    fn serialize_i16(self, _v: i16) -> std::result::Result<Self::Ok, Self::Error> {
827
1
        Ok(())
828
1
    }
829

            
830
    #[inline]
831
1
    fn serialize_i32(self, _v: i32) -> std::result::Result<Self::Ok, Self::Error> {
832
1
        Ok(())
833
1
    }
834

            
835
    #[inline]
836
1
    fn serialize_i64(self, _v: i64) -> std::result::Result<Self::Ok, Self::Error> {
837
1
        Ok(())
838
1
    }
839

            
840
    #[inline]
841
1
    fn serialize_u8(self, _v: u8) -> std::result::Result<Self::Ok, Self::Error> {
842
1
        Ok(())
843
1
    }
844

            
845
    #[inline]
846
1
    fn serialize_u16(self, _v: u16) -> std::result::Result<Self::Ok, Self::Error> {
847
1
        Ok(())
848
1
    }
849

            
850
    #[inline]
851
1
    fn serialize_u32(self, _v: u32) -> std::result::Result<Self::Ok, Self::Error> {
852
1
        Ok(())
853
1
    }
854

            
855
    #[inline]
856
6
    fn serialize_u64(self, _v: u64) -> std::result::Result<Self::Ok, Self::Error> {
857
6
        Ok(())
858
6
    }
859

            
860
    #[inline]
861
1
    fn serialize_i128(self, _v: i128) -> Result<Self::Ok, Self::Error> {
862
1
        Ok(())
863
1
    }
864

            
865
    #[inline]
866
1
    fn serialize_u128(self, _v: u128) -> Result<Self::Ok, Self::Error> {
867
1
        Ok(())
868
1
    }
869

            
870
    #[inline]
871
1
    fn serialize_f32(self, _v: f32) -> std::result::Result<Self::Ok, Self::Error> {
872
1
        Ok(())
873
1
    }
874

            
875
    #[inline]
876
1
    fn serialize_f64(self, _v: f64) -> std::result::Result<Self::Ok, Self::Error> {
877
1
        Ok(())
878
1
    }
879

            
880
    #[inline]
881
1
    fn serialize_char(self, _v: char) -> std::result::Result<Self::Ok, Self::Error> {
882
1
        Ok(())
883
1
    }
884

            
885
    #[inline]
886
1
    fn serialize_str(self, _v: &str) -> std::result::Result<Self::Ok, Self::Error> {
887
1
        Ok(())
888
1
    }
889

            
890
    #[inline]
891
    fn serialize_bytes(self, _v: &[u8]) -> std::result::Result<Self::Ok, Self::Error> {
892
        Ok(())
893
    }
894

            
895
    #[inline]
896
    fn serialize_none(self) -> std::result::Result<Self::Ok, Self::Error> {
897
        Ok(())
898
    }
899

            
900
    #[inline]
901
    fn serialize_some<T: ?Sized>(self, value: &T) -> std::result::Result<Self::Ok, Self::Error>
902
    where
903
        T: Serialize,
904
    {
905
        value.serialize(self)
906
    }
907

            
908
    #[inline]
909
    fn serialize_unit(self) -> std::result::Result<Self::Ok, Self::Error> {
910
        Ok(())
911
    }
912

            
913
    #[inline]
914
    fn serialize_unit_struct(
915
        self,
916
        _name: &'static str,
917
    ) -> std::result::Result<Self::Ok, Self::Error> {
918
        Ok(())
919
    }
920

            
921
    #[inline]
922
2
    fn serialize_unit_variant(
923
2
        self,
924
2
        _name: &'static str,
925
2
        _variant_index: u32,
926
2
        variant: &'static str,
927
2
    ) -> std::result::Result<Self::Ok, Self::Error> {
928
2
        self.0.find_or_add(variant);
929
2
        Ok(())
930
2
    }
931

            
932
    #[inline]
933
    fn serialize_newtype_struct<T: ?Sized>(
934
        self,
935
        _name: &'static str,
936
        value: &T,
937
    ) -> std::result::Result<Self::Ok, Self::Error>
938
    where
939
        T: Serialize,
940
    {
941
        value.serialize(self)
942
    }
943

            
944
    #[inline]
945
1
    fn serialize_newtype_variant<T: ?Sized>(
946
1
        self,
947
1
        _name: &'static str,
948
1
        _variant_index: u32,
949
1
        variant: &'static str,
950
1
        value: &T,
951
1
    ) -> std::result::Result<Self::Ok, Self::Error>
952
1
    where
953
1
        T: Serialize,
954
1
    {
955
1
        self.0.find_or_add(variant);
956
1
        value.serialize(self)
957
1
    }
958

            
959
    #[inline]
960
    fn serialize_seq(
961
        self,
962
        _len: Option<usize>,
963
    ) -> std::result::Result<Self::SerializeSeq, Self::Error> {
964
        Ok(self)
965
    }
966

            
967
    #[inline]
968
    fn serialize_tuple(
969
        self,
970
        _len: usize,
971
    ) -> std::result::Result<Self::SerializeTuple, Self::Error> {
972
        Ok(self)
973
    }
974

            
975
    #[inline]
976
    fn serialize_tuple_struct(
977
        self,
978
        _name: &'static str,
979
        _len: usize,
980
    ) -> std::result::Result<Self::SerializeTupleStruct, Self::Error> {
981
        Ok(self)
982
    }
983

            
984
    #[inline]
985
1
    fn serialize_tuple_variant(
986
1
        self,
987
1
        _name: &'static str,
988
1
        _variant_index: u32,
989
1
        variant: &'static str,
990
1
        _len: usize,
991
1
    ) -> std::result::Result<Self::SerializeTupleVariant, Self::Error> {
992
1
        self.0.find_or_add(variant);
993
1
        Ok(self)
994
1
    }
995

            
996
    #[inline]
997
    fn serialize_map(
998
        self,
999
        _len: Option<usize>,
    ) -> std::result::Result<Self::SerializeMap, Self::Error> {
        Ok(self)
    }

            
    #[inline]
2
    fn serialize_struct(
2
        self,
2
        _name: &'static str,
2
        _len: usize,
2
    ) -> std::result::Result<Self::SerializeStruct, Self::Error> {
2
        Ok(self)
2
    }

            
    #[inline]
1
    fn serialize_struct_variant(
1
        self,
1
        _name: &'static str,
1
        _variant_index: u32,
1
        variant: &'static str,
1
        _len: usize,
1
    ) -> std::result::Result<Self::SerializeStructVariant, Self::Error> {
1
        self.0.find_or_add(variant);
1
        Ok(self)
1
    }
}

            
/// Map support: keys and values are both visited with the same populator so
/// that any symbols nested inside them are registered.
impl serde::ser::SerializeMap for &mut SymbolMapPopulator<'_> {
    type Error = SymbolMapPopulationError;
    type Ok = ();

    #[inline]
    fn serialize_key<T>(&mut self, key: &T) -> std::result::Result<(), Self::Error>
    where
        T: ?Sized + Serialize,
    {
        // Reborrow the populator rather than wrapping the map a second time.
        key.serialize(&mut **self)
    }

    #[inline]
    fn serialize_value<T>(&mut self, value: &T) -> std::result::Result<(), Self::Error>
    where
        T: ?Sized + Serialize,
    {
        value.serialize(&mut **self)
    }

    #[inline]
    fn end(self) -> std::result::Result<Self::Ok, Self::Error> {
        Ok(())
    }
}

            
/// Sequence support: every element is visited with the same populator.
impl serde::ser::SerializeSeq for &mut SymbolMapPopulator<'_> {
    type Error = SymbolMapPopulationError;
    type Ok = ();

    #[inline]
    fn serialize_element<T>(&mut self, value: &T) -> std::result::Result<(), Self::Error>
    where
        T: ?Sized + Serialize,
    {
        // Reborrow the populator rather than wrapping the map a second time.
        value.serialize(&mut **self)
    }

    #[inline]
    fn end(self) -> std::result::Result<Self::Ok, Self::Error> {
        Ok(())
    }
}

            
impl serde::ser::SerializeStruct for &mut SymbolMapPopulator<'_> {
    type Error = SymbolMapPopulationError;
    type Ok = ();

            
    #[inline]
15
    fn serialize_field<T: ?Sized>(
15
        &mut self,
15
        key: &'static str,
15
        value: &T,
15
    ) -> std::result::Result<(), Self::Error>
15
    where
15
        T: Serialize,
15
    {
15
        self.0.find_or_add(key);
15
        value.serialize(&mut SymbolMapPopulator(&mut *self.0))
15
    }

            
    #[inline]
2
    fn end(self) -> std::result::Result<Self::Ok, Self::Error> {
2
        Ok(())
2
    }
}

            
impl serde::ser::SerializeStructVariant for &mut SymbolMapPopulator<'_> {
    type Error = SymbolMapPopulationError;
    type Ok = ();

            
    #[inline]
1
    fn serialize_field<T: ?Sized>(
1
        &mut self,
1
        key: &'static str,
1
        value: &T,
1
    ) -> std::result::Result<(), Self::Error>
1
    where
1
        T: Serialize,
1
    {
1
        self.0.find_or_add(key);
1
        value.serialize(&mut SymbolMapPopulator(&mut *self.0))
1
    }

            
    #[inline]
1
    fn end(self) -> std::result::Result<Self::Ok, Self::Error> {
1
        Ok(())
1
    }
}

            
/// Tuple support: every element is visited with the same populator.
impl serde::ser::SerializeTuple for &mut SymbolMapPopulator<'_> {
    type Error = SymbolMapPopulationError;
    type Ok = ();

    #[inline]
    fn serialize_element<T>(&mut self, value: &T) -> std::result::Result<(), Self::Error>
    where
        T: ?Sized + Serialize,
    {
        // Reborrow the populator rather than wrapping the map a second time.
        value.serialize(&mut **self)
    }

    #[inline]
    fn end(self) -> std::result::Result<Self::Ok, Self::Error> {
        Ok(())
    }
}
/// Tuple-struct support: every field is visited with the same populator.
impl serde::ser::SerializeTupleStruct for &mut SymbolMapPopulator<'_> {
    type Error = SymbolMapPopulationError;
    type Ok = ();

    #[inline]
    fn serialize_field<T>(&mut self, value: &T) -> std::result::Result<(), Self::Error>
    where
        T: ?Sized + Serialize,
    {
        // Reborrow the populator rather than wrapping the map a second time.
        value.serialize(&mut **self)
    }

    #[inline]
    fn end(self) -> std::result::Result<Self::Ok, Self::Error> {
        Ok(())
    }
}
impl serde::ser::SerializeTupleVariant for &mut SymbolMapPopulator<'_> {
    type Error = SymbolMapPopulationError;
    type Ok = ();

            
    #[inline]
2
    fn serialize_field<T: ?Sized>(&mut self, value: &T) -> std::result::Result<(), Self::Error>
2
    where
2
        T: Serialize,
2
    {
2
        value.serialize(&mut SymbolMapPopulator(&mut *self.0))
2
    }

            
    #[inline]
1
    fn end(self) -> std::result::Result<Self::Ok, Self::Error> {
1
        Ok(())
1
    }
}

            
/// Error reported by a [`Serialize`] implementation while its symbols were
/// being collected. Holds the error's message text.
#[derive(Debug)]
pub struct SymbolMapPopulationError(String);

impl Display for SymbolMapPopulationError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(message) = self;
        f.write_str(message.as_str())
    }
}

impl std::error::Error for SymbolMapPopulationError {}

            
impl serde::ser::Error for SymbolMapPopulationError {
    /// Captures the `Display` output of `msg` as the error's message.
    fn custom<T>(msg: T) -> Self
    where
        T: Display,
    {
        let message = msg.to_string();
        Self(message)
    }
}

            
/// Iterator yielding indexes into `source` in ascending order of symbol id.
///
/// Every call to `next` looks for the entry whose id, as computed by `map`,
/// equals the next expected id. Expected ids start at 0 and increase by one
/// for each index yielded. Iteration ends as soon as no entry carries the
/// expected id.
struct SymbolIdSorter<'a, T, F> {
    /// The entries being ordered.
    source: &'a [T],
    /// Computes the symbol id of an entry.
    map: F,
    /// Index at which the next scan starts; earlier entries are not revisited.
    min: usize,
    /// The id the next yielded entry must have.
    id: u32,
}

impl<'a, T, F> SymbolIdSorter<'a, T, F>
where
    F: FnMut(&T) -> u32,
{
    /// Returns a sorter over `source` that starts at id 0 and scans from the
    /// first entry.
    pub fn new(source: &'a [T], map: F) -> Self {
        Self {
            source,
            map,
            min: 0,
            id: 0,
        }
    }
}

impl<'a, T, F> Iterator for SymbolIdSorter<'a, T, F>
where
    F: FnMut(&T) -> u32,
{
    type Item = usize;

    /// Returns the index of the entry holding the next expected id, or `None`
    /// when no entry at or after the current scan start has that id.
    fn next(&mut self) -> Option<Self::Item> {
        let scan_start = self.min;
        // Copy the slice reference out so the loop does not hold a borrow of `self`.
        let source = self.source;
        // Set once an entry with an id above the expected one has been passed.
        // From then on the scan start must not advance, because that entry
        // still has to be found by a later call.
        let mut passed_pending_entry = false;

        for (index, entry) in source.iter().enumerate().skip(scan_start) {
            match (self.map)(entry).cmp(&self.id) {
                Ordering::Equal => {
                    self.id += 1;
                    if !passed_pending_entry {
                        self.min = index + 1;
                    }
                    return Some(index);
                }
                Ordering::Greater => passed_pending_entry = true,
                Ordering::Less => {
                    // An id below the expected one needs no further visits, so
                    // the scan start may move up to it while nothing pending
                    // lies before it.
                    if !passed_pending_entry {
                        self.min = index;
                    }
                }
            }
        }

        None
    }
}

            
#[test]
fn symbol_map_debug() {
    let mut map = EphemeralSymbolMap::default();
    // To force the order, we're splitting a single string into multiple parts:
    // every symbol below is a slice of the same string.
    let full_source = "abcd";

    // Register the symbols out of order: "b", "a", "c", "d".
    map.find_or_add(&full_source[1..2]);
    map.find_or_add(&full_source[0..1]);
    map.find_or_add(&full_source[2..3]);
    map.find_or_add(&full_source[3..4]);

    // Verify the map sorted the symbols correctly (by memory address).
    assert_eq!(map.symbols[0].0, "a");
    assert_eq!(map.symbols[1].0, "b");
    assert_eq!(map.symbols[2].0, "c");
    assert_eq!(map.symbols[3].0, "d");

    // Verify the debug output printed the correct order, which is the order
    // in which the symbols were registered.
    assert_eq!(format!("{map:?}"), r#"{"b", "a", "c", "d"}"#);
}