1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
use serde::de::{self, Deserialize};
use serde::ser::{self, Serialize};

use crate::ByteUnit;

impl<'de> Deserialize<'de> for ByteUnit {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where D: serde::Deserializer<'de>
    {
        deserializer.deserialize_u64(Visitor)
    }
}

macro_rules! visit_integer_fn {
    ($name:ident: $T:ty) => (
        fn $name<E: de::Error>(self, v: $T) -> Result<Self::Value, E> {
            Ok(v.into())
        }
    )
}

struct Visitor;

impl<'de> de::Visitor<'de> for Visitor {
    type Value = ByteUnit;

    fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
        formatter.write_str("a byte unit as an integer or string")
    }

    visit_integer_fn!(visit_i8: i8);
    visit_integer_fn!(visit_i16: i16);
    visit_integer_fn!(visit_i32: i32);
    visit_integer_fn!(visit_i64: i64);
    visit_integer_fn!(visit_i128: i128);

    visit_integer_fn!(visit_u8: u8);
    visit_integer_fn!(visit_u16: u16);
    visit_integer_fn!(visit_u32: u32);
    visit_integer_fn!(visit_u64: u64);
    visit_integer_fn!(visit_u128: u128);

    fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
        v.parse().map_err(|_| E::invalid_value(de::Unexpected::Str(v), &"byte unit string"))
    }
}

impl Serialize for ByteUnit {
    fn serialize<S: ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.serialize_u64(self.as_u64())
    }
}

#[cfg(test)]
mod serde_tests {
    use serde_test::{assert_de_tokens, assert_ser_tokens, Token};
    use crate::ByteUnit;

    #[test]
    fn test_de() {
        let half_mib = ByteUnit::Kibibyte(512);
        assert_de_tokens(&half_mib, &[Token::Str("512 kib")]);
        assert_de_tokens(&half_mib, &[Token::Str("512 KiB")]);
        assert_de_tokens(&half_mib, &[Token::Str("512KiB")]);
        assert_de_tokens(&half_mib, &[Token::Str("524288")]);
        assert_de_tokens(&half_mib, &[Token::U32(524288)]);
        assert_de_tokens(&half_mib, &[Token::U64(524288)]);
        assert_de_tokens(&half_mib, &[Token::I32(524288)]);
        assert_de_tokens(&half_mib, &[Token::I64(524288)]);

        let one_mib = ByteUnit::Mebibyte(1);
        assert_de_tokens(&one_mib, &[Token::Str("1 mib")]);
        assert_de_tokens(&one_mib, &[Token::Str("1 MiB")]);
        assert_de_tokens(&one_mib, &[Token::Str("1mib")]);

        let zero = ByteUnit::Byte(0);
        assert_de_tokens(&zero, &[Token::Str("0")]);
        assert_de_tokens(&zero, &[Token::Str("0 B")]);
        assert_de_tokens(&zero, &[Token::U32(0)]);
        assert_de_tokens(&zero, &[Token::U64(0)]);
        assert_de_tokens(&zero, &[Token::I32(-34)]);
        assert_de_tokens(&zero, &[Token::I64(-2483)]);
    }

    #[test]
    fn test_ser() {
        let half_mib = ByteUnit::Kibibyte(512);
        assert_ser_tokens(&half_mib, &[Token::U64(512 << 10)]);

        let ten_bytes = ByteUnit::Byte(10);
        assert_ser_tokens(&ten_bytes, &[Token::U64(10)]);

        let zero = ByteUnit::Byte(0);
        assert_de_tokens(&zero, &[Token::U64(0)]);
    }
}