Skip to main content

nautilus_model/defi/
pool_identifier.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    fmt::{Debug, Display},
18    hash::{Hash, Hasher},
19    str::FromStr,
20};
21
22use alloy_primitives::Address;
23use nautilus_core::{correctness::FAILED, hex};
24use serde::{Deserialize, Deserializer, Serialize, Serializer};
25use ustr::Ustr;
26
27/// Protocol-aware pool identifier for DeFi liquidity pools.
28///
29/// This enum distinguishes between two types of pool identifiers:
30/// - **Address**: Used by V2/V3 protocols where pool identifier equals pool contract address (42 chars: "0x" + 40 hex)
31/// - **`PoolId`**: Used by V4 protocols where pool identifier is a bytes32 hash (66 chars: "0x" + 64 hex)
32///
33/// The type implements case-insensitive equality and hashing for address comparison,
34/// while preserving the original case for display purposes.
35#[derive(Clone, Copy, PartialOrd, Ord)]
36pub enum PoolIdentifier {
37    /// V2/V3 pool identifier (checksummed Ethereum address)
38    Address(Ustr),
39    /// V4 pool identifier (32-byte pool ID as hex string)
40    PoolId(Ustr),
41}
42
43impl PoolIdentifier {
44    /// Creates a new [`PoolIdentifier`] instance with correctness checking.
45    ///
46    /// Automatically detects variant based on string length:
47    /// - 42 characters (0x + 40 hex): Address variant
48    /// - 66 characters (0x + 64 hex): `PoolId` variant
49    ///
50    /// # Errors
51    ///
52    /// Returns an error if:
53    /// - String doesn't start with "0x"
54    /// - Length is neither 42 nor 66 characters
55    /// - Contains invalid hex characters
56    /// - Address checksum validation fails (for Address variant)
57    pub fn new_checked<T: AsRef<str>>(value: T) -> anyhow::Result<Self> {
58        let value = value.as_ref();
59
60        if !value.starts_with("0x") {
61            anyhow::bail!("Pool identifier must start with '0x', was: {value}");
62        }
63
64        match value.len() {
65            42 => {
66                validate_hex_string(value)?;
67
68                // Parse without strict checksum validation, then normalize to checksummed format
69                let addr = value
70                    .parse::<Address>()
71                    .map_err(|e| anyhow::anyhow!("Invalid address: {e}"))?;
72
73                // Store the checksummed version
74                Ok(Self::Address(Ustr::from(addr.to_checksum(None).as_str())))
75            }
76            66 => {
77                // PoolId variant (32 bytes)
78                validate_hex_string(value)?;
79
80                // Store lowercase version for consistency
81                Ok(Self::PoolId(Ustr::from(&value.to_lowercase())))
82            }
83            len => {
84                anyhow::bail!(
85                    "Pool identifier must be 42 chars (address) or 66 chars (pool ID), was {len} chars: {value}"
86                )
87            }
88        }
89    }
90
91    /// Creates a new [`PoolIdentifier`] instance.
92    ///
93    /// # Panics
94    ///
95    /// Panics if validation fails.
96    #[must_use]
97    pub fn new<T: AsRef<str>>(value: T) -> Self {
98        Self::new_checked(value).expect(FAILED)
99    }
100
101    /// Creates an Address variant from an alloy Address.
102    ///
103    /// Returns the checksummed representation.
104    #[must_use]
105    pub fn from_address(address: Address) -> Self {
106        Self::Address(Ustr::from(address.to_checksum(None).as_str()))
107    }
108
109    /// Creates a `PoolId` variant from raw bytes (32 bytes).
110    ///
111    /// # Errors
112    ///
113    /// Returns an error if bytes length is not 32.
114    pub fn from_pool_id_bytes(bytes: &[u8]) -> anyhow::Result<Self> {
115        anyhow::ensure!(
116            bytes.len() == 32,
117            "Pool ID must be 32 bytes, was {}",
118            bytes.len()
119        );
120
121        Ok(Self::PoolId(Ustr::from(&hex::encode_prefixed(bytes))))
122    }
123
124    /// Creates a `PoolId` variant from a hex string (with or without 0x prefix).
125    ///
126    /// # Errors
127    ///
128    /// Returns an error if the string is not valid 64-character hex.
129    pub fn from_pool_id_hex<T: AsRef<str>>(hex: T) -> anyhow::Result<Self> {
130        let hex = hex.as_ref();
131        let hex_str = hex.strip_prefix("0x").unwrap_or(hex);
132
133        anyhow::ensure!(
134            hex_str.len() == 64,
135            "Pool ID hex must be 64 characters (32 bytes), was {}",
136            hex_str.len()
137        );
138
139        validate_hex_string(&format!("0x{hex_str}"))?;
140
141        Ok(Self::PoolId(Ustr::from(&format!(
142            "0x{}",
143            hex_str.to_lowercase()
144        ))))
145    }
146
147    /// Returns the inner identifier value as a Ustr.
148    #[must_use]
149    pub fn inner(&self) -> Ustr {
150        match self {
151            Self::Address(s) | Self::PoolId(s) => *s,
152        }
153    }
154
155    /// Returns the inner identifier value as a string slice.
156    #[must_use]
157    pub fn as_str(&self) -> &str {
158        match self {
159            Self::Address(s) | Self::PoolId(s) => s.as_str(),
160        }
161    }
162
163    /// Returns true if this is an Address variant (V2/V3 pools).
164    #[must_use]
165    pub fn is_address(&self) -> bool {
166        matches!(self, Self::Address(_))
167    }
168
169    /// Returns true if this is a `PoolId` variant (V4 pools).
170    #[must_use]
171    pub fn is_pool_id(&self) -> bool {
172        matches!(self, Self::PoolId(_))
173    }
174
175    /// Converts to native Address type (V2/V3 pools only).
176    ///
177    /// Returns the underlying Address for use with alloy/ethers operations.
178    ///
179    /// # Errors
180    ///
181    /// Returns error if this is a `PoolId` variant or if parsing fails.
182    pub fn to_address(&self) -> anyhow::Result<Address> {
183        match self {
184            Self::Address(s) => Address::parse_checksummed(s.as_str(), None)
185                .map_err(|e| anyhow::anyhow!("Failed to parse address: {e}")),
186            Self::PoolId(_) => anyhow::bail!("Cannot convert PoolId variant to Address"),
187        }
188    }
189
190    /// Converts to native bytes array (V4 pools only).
191    ///
192    /// Returns the 32-byte pool ID for use in V4-specific operations.
193    ///
194    /// # Errors
195    ///
196    /// Returns error if this is an Address variant or if hex decoding fails.
197    pub fn to_pool_id_bytes(&self) -> anyhow::Result<[u8; 32]> {
198        match self {
199            Self::PoolId(s) => {
200                let hex_str = s.as_str().strip_prefix("0x").unwrap_or(s.as_str());
201                hex::decode_array::<32>(hex_str)
202                    .map_err(|e| anyhow::anyhow!("Failed to decode pool ID hex: {e}"))
203            }
204            Self::Address(_) => anyhow::bail!("Cannot convert Address variant to PoolId bytes"),
205        }
206    }
207}
208
209/// Validates that a string contains only valid hexadecimal characters after "0x" prefix.
210fn validate_hex_string(s: &str) -> anyhow::Result<()> {
211    let hex_part = &s[2..];
212    if !hex_part.chars().all(|c| c.is_ascii_hexdigit()) {
213        anyhow::bail!("Invalid hex characters in: {s}");
214    }
215    Ok(())
216}
217
218impl PartialEq for PoolIdentifier {
219    fn eq(&self, other: &Self) -> bool {
220        match (self, other) {
221            (Self::Address(a), Self::Address(b)) | (Self::PoolId(a), Self::PoolId(b)) => {
222                // Case-insensitive comparison
223                a.as_str().eq_ignore_ascii_case(b.as_str())
224            }
225            // Different variants are never equal
226            _ => false,
227        }
228    }
229}
230
231impl Eq for PoolIdentifier {}
232
233impl Hash for PoolIdentifier {
234    fn hash<H: Hasher>(&self, state: &mut H) {
235        // Hash the variant discriminant first
236        std::mem::discriminant(self).hash(state);
237
238        // Then hash the lowercase version of the string
239        match self {
240            Self::Address(s) | Self::PoolId(s) => {
241                for byte in s.as_str().bytes() {
242                    state.write_u8(byte.to_ascii_lowercase());
243                }
244            }
245        }
246    }
247}
248
249impl Display for PoolIdentifier {
250    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251        match self {
252            Self::Address(s) | Self::PoolId(s) => write!(f, "{s}"),
253        }
254    }
255}
256
257impl Debug for PoolIdentifier {
258    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259        match self {
260            Self::Address(s) => write!(f, "Address({s:?})"),
261            Self::PoolId(s) => write!(f, "PoolId({s:?})"),
262        }
263    }
264}
265
266impl Serialize for PoolIdentifier {
267    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
268    where
269        S: Serializer,
270    {
271        // Serialize as plain string (same as current String behavior)
272        match self {
273            Self::Address(s) | Self::PoolId(s) => s.serialize(serializer),
274        }
275    }
276}
277
278impl<'de> Deserialize<'de> for PoolIdentifier {
279    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
280    where
281        D: Deserializer<'de>,
282    {
283        let value_str: &str = Deserialize::deserialize(deserializer)?;
284        Self::new_checked(value_str).map_err(serde::de::Error::custom)
285    }
286}
287
288impl FromStr for PoolIdentifier {
289    type Err = anyhow::Error;
290
291    fn from_str(s: &str) -> Result<Self, Self::Err> {
292        Self::new_checked(s)
293    }
294}
295
296impl From<&str> for PoolIdentifier {
297    fn from(value: &str) -> Self {
298        Self::new(value)
299    }
300}
301
302impl From<String> for PoolIdentifier {
303    fn from(value: String) -> Self {
304        Self::new(value)
305    }
306}
307
308impl AsRef<str> for PoolIdentifier {
309    fn as_ref(&self) -> &str {
310        self.as_str()
311    }
312}
313
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use rstest::rstest;

    use super::*;

    /// Mixed-case (checksummed) form of the address fixture.
    const ADDRESS_CHECKSUMMED: &str = "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2";
    /// All-lowercase form of the same address.
    const ADDRESS_LOWER: &str = "0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2";
    /// Same address with every hex letter uppercased (prefix kept as `0x`).
    const ADDRESS_UPPER: &str = "0xC02AAA39B223FE8D0A0E5C4F27EAD9083C756CC2";
    /// A 32-byte V4 pool ID in lowercase hex.
    const POOL_ID_HEX: &str = "0xc9bc8043294146424a4e4607d8ad837d6a659142822bbaaabc83bb57e7447461";
    /// The raw bytes encoded by [`POOL_ID_HEX`].
    const POOL_ID_BYTES: [u8; 32] = [
        0xc9, 0xbc, 0x80, 0x43, 0x29, 0x41, 0x46, 0x42, 0x4a, 0x4e, 0x46, 0x07, 0xd8, 0xad,
        0x83, 0x7d, 0x6a, 0x65, 0x91, 0x42, 0x82, 0x2b, 0xba, 0xaa, 0xbc, 0x83, 0xbb, 0x57,
        0xe7, 0x44, 0x74, 0x61,
    ];

    #[rstest]
    #[case("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2", true)] // Checksummed address
    #[case("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2", true)] // Lowercase address
    #[case(
        "0xc9bc8043294146424a4e4607d8ad837d6a659142822bbaaabc83bb57e7447461",
        true
    )] // V4 pool ID
    fn test_valid_pool_identifiers(#[case] input: &str, #[case] expected_valid: bool) {
        let is_valid = PoolIdentifier::new_checked(input).is_ok();
        assert_eq!(is_valid, expected_valid, "Input: {input}");
    }

    #[rstest]
    #[case("C02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")] // Missing 0x prefix
    #[case("0xC02aaA39")] // Too short
    #[case("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2EXTRA")] // Too long
    #[case("0xGGGGGGGGb223FE8D0A0e5C4F27eAD9083C756Cc2")] // Non-hex characters
    fn test_invalid_pool_identifiers(#[case] input: &str) {
        assert!(
            PoolIdentifier::new_checked(input).is_err(),
            "Input should fail: {input}"
        );
    }

    #[rstest]
    fn test_case_insensitive_equality() {
        let [checksummed, lower, upper] =
            [ADDRESS_CHECKSUMMED, ADDRESS_LOWER, ADDRESS_UPPER].map(PoolIdentifier::new);

        assert_eq!(checksummed, lower);
        assert_eq!(lower, upper);
        assert_eq!(checksummed, upper);
    }

    #[rstest]
    fn test_case_insensitive_hashing() {
        let mut lookup = HashMap::new();
        lookup.insert(PoolIdentifier::new(ADDRESS_CHECKSUMMED), "value1");

        // A key spelled in a different case must find the same entry
        let other_case = PoolIdentifier::new(ADDRESS_LOWER);
        assert_eq!(lookup.get(&other_case), Some(&"value1"));
    }

    #[rstest]
    fn test_display_preserves_case() {
        let id = PoolIdentifier::new_checked(ADDRESS_CHECKSUMMED).unwrap();

        // Display must render the checksummed form, not a lowercased one
        assert_eq!(id.to_string(), ADDRESS_CHECKSUMMED);
    }

    #[rstest]
    fn test_variant_detection() {
        let address = PoolIdentifier::new(ADDRESS_CHECKSUMMED);
        let pool_id = PoolIdentifier::new(POOL_ID_HEX);

        assert!(address.is_address() && !address.is_pool_id());
        assert!(pool_id.is_pool_id() && !pool_id.is_address());
    }

    #[rstest]
    fn test_different_variants_not_equal() {
        assert_ne!(
            PoolIdentifier::new(ADDRESS_CHECKSUMMED),
            PoolIdentifier::new(POOL_ID_HEX)
        );
    }

    #[rstest]
    fn test_serialization_roundtrip() {
        let original = PoolIdentifier::new(ADDRESS_CHECKSUMMED);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: PoolIdentifier = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded, original);
    }

    #[rstest]
    fn test_from_address() {
        let address = Address::from_str(ADDRESS_CHECKSUMMED).unwrap();
        let id = PoolIdentifier::from_address(address);

        assert!(id.is_address());
        assert_eq!(id.to_string(), ADDRESS_CHECKSUMMED);
    }

    #[rstest]
    fn test_from_pool_id_bytes() {
        let id = PoolIdentifier::from_pool_id_bytes(&POOL_ID_BYTES).unwrap();

        assert!(id.is_pool_id());
        assert_eq!(id.to_string(), POOL_ID_HEX);
    }

    #[rstest]
    fn test_to_address() {
        let address = PoolIdentifier::new(ADDRESS_CHECKSUMMED)
            .to_address()
            .unwrap();

        assert_eq!(address.to_string(), ADDRESS_CHECKSUMMED);
    }

    #[rstest]
    fn test_to_address_fails_for_pool_id() {
        assert!(PoolIdentifier::new(POOL_ID_HEX).to_address().is_err());
    }

    #[rstest]
    fn test_to_pool_id_bytes() {
        let bytes = PoolIdentifier::new(POOL_ID_HEX)
            .to_pool_id_bytes()
            .unwrap();

        assert_eq!(bytes.len(), 32);
        assert_eq!(bytes.first(), Some(&0xc9));
        assert_eq!(bytes.last(), Some(&0x61));
    }

    #[rstest]
    fn test_to_pool_id_bytes_fails_for_address() {
        assert!(
            PoolIdentifier::new(ADDRESS_CHECKSUMMED)
                .to_pool_id_bytes()
                .is_err()
        );
    }

    #[rstest]
    fn test_conversion_roundtrip_address() {
        let original = Address::from_str(ADDRESS_CHECKSUMMED).unwrap();
        let roundtripped = PoolIdentifier::from_address(original)
            .to_address()
            .unwrap();

        assert_eq!(roundtripped, original);
    }

    #[rstest]
    fn test_conversion_roundtrip_pool_id() {
        let id = PoolIdentifier::from_pool_id_bytes(&POOL_ID_BYTES).unwrap();

        assert_eq!(id.to_pool_id_bytes().unwrap(), POOL_ID_BYTES);
    }
}