1use std::{
17 collections::HashMap,
18 hash::{Hash, Hasher},
19};
20
21use derive_builder::Builder;
22use nautilus_core::{UnixNanos, correctness::FAILED};
23use serde::{Deserialize, Serialize};
24
25use crate::{
26 expressions::{Bindings, CompiledExpression, compile_numeric},
27 identifiers::{InstrumentId, Symbol, Venue},
28 types::Price,
29};
30
/// Largest component count for which `calculate_from_map` gathers its input prices into a
/// fixed-size stack buffer; instruments with more components fall back to a heap-allocated `Vec`.
const MAX_INLINE_COMPONENTS: usize = 8;
32
/// Represents a synthetic instrument whose price is derived from component instrument prices
/// through a formula.
///
/// The formula is compiled against the string forms of `components`, each bound to the slot
/// matching its position in `components`. Compilation happens in `new_checked`, in
/// deserialization, and in `change_formula`.
///
/// NOTE(review): `component_names` and `compiled_formula` are marked
/// `#[builder(setter(skip), default)]`, so an instance produced by the derived builder receives
/// their `Default` values rather than values derived from `components`/`formula`. Confirm that
/// builder users do not call `calculate`/`calculate_from_map` on such instances.
#[derive(Clone, Debug, Builder)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model", from_py_object)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.model")
)]
pub struct SyntheticInstrument {
    /// The instrument ID (placed on the synthetic venue when built via `new_checked`).
    pub id: InstrumentId,
    /// The number of decimal places used for prices this instrument produces.
    pub price_precision: u8,
    /// The minimum price increment (`10^-price_precision` when built via `new_checked`).
    pub price_increment: Price,
    /// The component instrument IDs whose prices feed the formula, in slot order.
    pub components: Vec<InstrumentId>,
    /// The formula source text, stored exactly as supplied.
    pub formula: String,
    /// Event timestamp (UNIX nanoseconds).
    pub ts_event: UnixNanos,
    /// Initialization timestamp (UNIX nanoseconds).
    pub ts_init: UnixNanos,
    /// Cached string forms of `components`, used as the formula's binding names.
    #[builder(setter(skip), default)]
    component_names: Vec<String>,
    /// The compiled form of `formula`, evaluated by `calculate`.
    #[builder(setter(skip), default)]
    compiled_formula: CompiledExpression,
}
66
67impl Serialize for SyntheticInstrument {
68 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
69 where
70 S: serde::Serializer,
71 {
72 use serde::ser::SerializeStruct;
73 let mut state = serializer.serialize_struct("SyntheticInstrument", 7)?;
74 state.serialize_field("id", &self.id)?;
75 state.serialize_field("price_precision", &self.price_precision)?;
76 state.serialize_field("price_increment", &self.price_increment)?;
77 state.serialize_field("components", &self.components)?;
78 state.serialize_field("formula", &self.formula)?;
79 state.serialize_field("ts_event", &self.ts_event)?;
80 state.serialize_field("ts_init", &self.ts_init)?;
81 state.end()
82 }
83}
84
85impl<'de> Deserialize<'de> for SyntheticInstrument {
86 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
87 where
88 D: serde::Deserializer<'de>,
89 {
90 #[derive(Deserialize)]
91 struct Fields {
92 id: InstrumentId,
93 price_precision: u8,
94 price_increment: Price,
95 components: Vec<InstrumentId>,
96 formula: String,
97 ts_event: UnixNanos,
98 ts_init: UnixNanos,
99 }
100
101 let fields = Fields::deserialize(deserializer)?;
102 let component_names = component_names_from_components(&fields.components);
103 let compiled_formula =
104 compile_formula(&fields.formula, &component_names).map_err(serde::de::Error::custom)?;
105
106 Ok(Self {
107 id: fields.id,
108 price_precision: fields.price_precision,
109 price_increment: fields.price_increment,
110 components: fields.components,
111 formula: fields.formula,
112 ts_event: fields.ts_event,
113 ts_init: fields.ts_init,
114 component_names,
115 compiled_formula,
116 })
117 }
118}
119
120impl SyntheticInstrument {
121 pub fn new_checked(
130 symbol: Symbol,
131 price_precision: u8,
132 components: Vec<InstrumentId>,
133 formula: &str,
134 ts_event: UnixNanos,
135 ts_init: UnixNanos,
136 ) -> anyhow::Result<Self> {
137 let price_increment =
138 Price::new_checked(10f64.powi(-i32::from(price_precision)), price_precision)?;
139 let component_names = component_names_from_components(&components);
140 let compiled_formula = compile_formula(formula, &component_names)?;
141
142 Ok(Self {
143 id: InstrumentId::new(symbol, Venue::synthetic()),
144 price_precision,
145 price_increment,
146 components,
147 formula: formula.to_string(),
148 component_names,
149 compiled_formula,
150 ts_event,
151 ts_init,
152 })
153 }
154
155 #[must_use]
157 pub fn is_valid_formula_for_components(formula: &str, components: &[InstrumentId]) -> bool {
158 let component_names = component_names_from_components(components);
159 compile_formula(formula, &component_names).is_ok()
160 }
161
162 #[must_use]
168 pub fn new(
169 symbol: Symbol,
170 price_precision: u8,
171 components: Vec<InstrumentId>,
172 formula: &str,
173 ts_event: UnixNanos,
174 ts_init: UnixNanos,
175 ) -> Self {
176 Self::new_checked(
177 symbol,
178 price_precision,
179 components,
180 formula,
181 ts_event,
182 ts_init,
183 )
184 .expect(FAILED)
185 }
186
187 #[must_use]
189 pub fn is_valid_formula(&self, formula: &str) -> bool {
190 Self::is_valid_formula_for_components(formula, &self.components)
191 }
192
193 pub fn change_formula(&mut self, formula: &str) -> anyhow::Result<()> {
199 let compiled_formula = compile_formula(formula, &self.component_names)?;
200 self.formula = formula.to_string();
201 self.compiled_formula = compiled_formula;
202 Ok(())
203 }
204
205 pub fn calculate_from_map(&self, inputs: &HashMap<String, f64>) -> anyhow::Result<Price> {
212 let n = self.component_names.len();
213 let mut buf = [0.0_f64; MAX_INLINE_COMPONENTS];
214 let input_values: &[f64] = if n <= MAX_INLINE_COMPONENTS {
215 for (i, component_name) in self.component_names.iter().enumerate() {
216 buf[i] = *inputs.get(component_name).ok_or_else(|| {
217 anyhow::anyhow!("Missing price for component: {component_name}")
218 })?;
219 }
220 &buf[..n]
221 } else {
222 let v: std::result::Result<Vec<f64>, _> = self
224 .component_names
225 .iter()
226 .map(|name| {
227 inputs
228 .get(name)
229 .copied()
230 .ok_or_else(|| anyhow::anyhow!("Missing price for component: {name}"))
231 })
232 .collect();
233 return self.calculate(&v?);
234 };
235
236 self.calculate(input_values)
237 }
238
239 pub fn calculate(&self, inputs: &[f64]) -> anyhow::Result<Price> {
247 if inputs.len() != self.component_names.len() {
248 anyhow::bail!(
249 "Expected {} input values, received {}",
250 self.component_names.len(),
251 inputs.len(),
252 );
253 }
254
255 for (i, value) in inputs.iter().enumerate() {
256 if !value.is_finite() {
257 anyhow::bail!(
258 "Non-finite input price for component {}: {value}",
259 self.component_names[i],
260 );
261 }
262 }
263
264 let price = self.compiled_formula.eval_number(inputs)?;
265 Price::new_checked(price, self.price_precision)
266 .map_err(|e| anyhow::anyhow!("Formula result produced invalid price: {e}"))
267 }
268}
269
270fn component_names_from_components(components: &[InstrumentId]) -> Vec<String> {
271 components.iter().map(ToString::to_string).collect()
272}
273
/// Builds the formula variable bindings for the given component names.
///
/// Each name is bound to its index (slot) in `component_names`. Any name containing `-` also
/// gets a legacy alias with `-` replaced by `_`, so formulas written with sanitized identifiers
/// still compile.
///
/// # Errors
///
/// Returns an error if binding a canonical component name fails.
fn build_bindings(component_names: &[String]) -> anyhow::Result<Bindings> {
    let mut bindings = Bindings::new();

    // Register every canonical name before any alias. The colliding-alias test expects a
    // canonical name (e.g. `FOO_BAR.VENUE`) to keep its own slot even when another
    // component's legacy alias spells the same text.
    for (slot, component_name) in component_names.iter().enumerate() {
        bindings.add(slot, component_name)?;
    }

    for (slot, component_name) in component_names.iter().enumerate() {
        let legacy_name = component_name.replace('-', "_");

        if legacy_name != *component_name {
            // Best effort: the alias result is deliberately discarded.
            // NOTE(review): presumably `add_alias` fails when the name is already bound; confirm.
            let _ = bindings.add_alias(slot, &legacy_name);
        }
    }

    Ok(bindings)
}
295
296fn compile_formula(
300 formula: &str,
301 component_names: &[String],
302) -> anyhow::Result<CompiledExpression> {
303 let bindings = build_bindings(component_names)?;
304 Ok(compile_numeric(formula, &bindings)?)
305}
306
307impl PartialEq<Self> for SyntheticInstrument {
308 fn eq(&self, other: &Self) -> bool {
309 self.id == other.id
310 }
311}
312
// Equality is decided by `id` alone, the same field `Hash` uses.
impl Eq for SyntheticInstrument {}
314
315impl Hash for SyntheticInstrument {
316 fn hash<H: Hasher>(&self, state: &mut H) {
317 self.id.hash(state);
318 }
319}
320
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use rstest::rstest;

    use super::*;

    #[rstest]
    fn test_calculate_from_map() {
        let synth = SyntheticInstrument::default();
        let mut inputs = HashMap::new();
        inputs.insert("BTC.BINANCE".to_string(), 100.0);
        inputs.insert("LTC.BINANCE".to_string(), 200.0);
        let price = synth.calculate_from_map(&inputs).unwrap();

        assert_eq!(price, Price::from("150.0"));
        assert_eq!(
            synth.formula,
            "(BTC.BINANCE + LTC.BINANCE) / 2.0".to_string()
        );
    }

    #[rstest]
    fn test_calculate() {
        let synth = SyntheticInstrument::default();
        let inputs = vec![100.0, 200.0];
        let price = synth.calculate(&inputs).unwrap();
        assert_eq!(price, Price::from("150.0"));
    }

    #[rstest]
    fn test_change_formula() {
        let mut synth = SyntheticInstrument::default();
        let new_formula = "(BTC.BINANCE + LTC.BINANCE) / 4";
        synth.change_formula(new_formula).unwrap();

        let mut inputs = HashMap::new();
        inputs.insert("BTC.BINANCE".to_string(), 100.0);
        inputs.insert("LTC.BINANCE".to_string(), 200.0);
        let price = synth.calculate_from_map(&inputs).unwrap();

        assert_eq!(price, Price::from("75.0"));
        assert_eq!(synth.formula, new_formula);
    }

    #[rstest]
    fn test_hyphenated_instrument_ids_preserve_raw_formula() {
        let comp1 = InstrumentId::from_str("ETHUSDC-PERP.BINANCE_FUTURES").unwrap();
        let comp2 = InstrumentId::from_str("ETH_USDC-PERP.HYPERLIQUID").unwrap();
        let components = vec![comp1, comp2];
        let raw_formula = format!("({comp1} + {comp2}) / 2.0");
        let symbol = Symbol::from("ETH-USDC");
        let synth =
            SyntheticInstrument::new(symbol, 2, components, &raw_formula, 0.into(), 0.into());
        let price = synth.calculate(&[100.0, 200.0]).unwrap();

        assert_eq!(price, Price::from("150.0"));
        assert_eq!(synth.formula, raw_formula);
    }

    #[rstest]
    fn test_hyphenated_instrument_ids_support_legacy_sanitized_formula() {
        let comp1 = InstrumentId::from_str("ETH-USDT-SWAP.OKX").unwrap();
        let comp2 = InstrumentId::from_str("ETH-USDC-PERP.HYPERLIQUID").unwrap();
        let components = vec![comp1, comp2];
        let legacy_formula = format!(
            "({} + {}) / 2.0",
            components[0].to_string().replace('-', "_"),
            components[1].to_string().replace('-', "_"),
        );
        let symbol = Symbol::from("ETH-USD");
        let synth = SyntheticInstrument::new(
            symbol,
            2,
            components.clone(),
            &legacy_formula,
            0.into(),
            0.into(),
        );
        let mut inputs = HashMap::new();
        inputs.insert(components[0].to_string(), 100.0);
        inputs.insert(components[1].to_string(), 200.0);
        let price = synth.calculate_from_map(&inputs).unwrap();

        assert_eq!(price, Price::from("150.0"));
        assert_eq!(synth.formula, legacy_formula);
    }

    #[rstest]
    fn test_slashed_instrument_ids_calculate_from_map() {
        let comp1 = InstrumentId::from_str("AUD/USD.SIM").unwrap();
        let comp2 = InstrumentId::from_str("NZD/USD.SIM").unwrap();
        let components = vec![comp1, comp2];
        let raw_formula = format!("({} + {}) / 2.0", components[0], components[1]);

        let synth = SyntheticInstrument::new(
            Symbol::from("FX-BASKET"),
            5,
            components.clone(),
            &raw_formula,
            0.into(),
            0.into(),
        );
        let mut inputs = HashMap::new();
        inputs.insert(components[0].to_string(), 0.65001);
        inputs.insert(components[1].to_string(), 0.59001);

        let price = synth.calculate_from_map(&inputs).unwrap();

        assert_eq!(price, Price::from("0.62001"));
        assert_eq!(synth.formula, raw_formula);
    }

    #[rstest]
    fn test_deserialize_rejects_unknown_formula_symbol() {
        let synth = SyntheticInstrument::default();
        let payload = serde_json::to_string(&synth).unwrap().replace(
            "\"(BTC.BINANCE + LTC.BINANCE) / 2.0\"",
            "\"BTC.BINANCE + missing\"",
        );

        let error = serde_json::from_str::<SyntheticInstrument>(&payload).unwrap_err();

        assert!(
            error.to_string().contains("Unknown symbol `missing`"),
            "{error}",
        );
    }

    #[rstest]
    fn test_serde_round_trip_recompiles_formula() {
        let synth = SyntheticInstrument::default();
        let json = serde_json::to_string(&synth).unwrap();

        let decoded: SyntheticInstrument = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded, synth);
        assert_eq!(decoded.price_precision, synth.price_precision);
        assert_eq!(decoded.price_increment, synth.price_increment);
        assert_eq!(decoded.components, synth.components);
        assert_eq!(decoded.formula, synth.formula);

        // The compiled formula is not serialized, so a working `calculate` on the decoded
        // instance proves deserialization rebuilt the component names and compiled formula.
        let price = decoded.calculate(&[100.0, 200.0]).unwrap();
        assert_eq!(price, Price::from("150.0"));
    }

    #[rstest]
    fn test_calculate_rejects_wrong_input_count() {
        let synth = SyntheticInstrument::default();
        let error = synth.calculate(&[100.0]).unwrap_err();

        assert!(
            error
                .to_string()
                .contains("Expected 2 input values, received 1"),
            "{error}",
        );
    }

    #[rstest]
    fn test_calculate_from_map_rejects_missing_component() {
        let synth = SyntheticInstrument::default();
        let mut inputs = HashMap::new();
        inputs.insert("BTC.BINANCE".to_string(), 100.0);

        let error = synth.calculate_from_map(&inputs).unwrap_err();

        assert!(
            error
                .to_string()
                .contains("Missing price for component: LTC.BINANCE"),
            "{error}",
        );
    }

    #[rstest]
    fn test_calculate_rejects_invalid_price_result() {
        let mut synth = SyntheticInstrument::default();
        synth
            .change_formula("BTC.BINANCE / (LTC.BINANCE - LTC.BINANCE)")
            .unwrap();

        let error = synth.calculate(&[100.0, 100.0]).unwrap_err();

        assert!(
            error
                .to_string()
                .contains("Formula result produced invalid price"),
            "{error}",
        );
    }

    #[rstest]
    fn test_is_valid_formula() {
        let synth = SyntheticInstrument::default();

        assert!(synth.is_valid_formula("(BTC.BINANCE + LTC.BINANCE) / 3"));
        assert!(!synth.is_valid_formula("UNKNOWN.VENUE + 1"));
        assert!(!synth.is_valid_formula(""));
    }

    #[rstest]
    #[case(f64::NAN, 100.0, "Non-finite input price")]
    #[case(100.0, f64::INFINITY, "Non-finite input price")]
    #[case(f64::NEG_INFINITY, 100.0, "Non-finite input price")]
    fn test_calculate_rejects_non_finite_inputs(
        #[case] a: f64,
        #[case] b: f64,
        #[case] expected_msg: &str,
    ) {
        let synth = SyntheticInstrument::default();
        let error = synth.calculate(&[a, b]).unwrap_err();

        assert!(error.to_string().contains(expected_msg), "{error}");
    }

    #[rstest]
    fn test_components_with_colliding_legacy_aliases_coexist() {
        let comp1 = InstrumentId::from_str("FOO-BAR.VENUE").unwrap();
        let comp2 = InstrumentId::from_str("FOO_BAR.VENUE").unwrap();
        let formula = format!("{comp1} + {comp2}");
        let synth = SyntheticInstrument::new(
            Symbol::from("TEST"),
            2,
            vec![comp1, comp2],
            &formula,
            0.into(),
            0.into(),
        );
        let price = synth.calculate(&[100.0, 200.0]).unwrap();

        assert_eq!(price, Price::from("300.0"));
    }

    #[rstest]
    fn test_calculate_from_map_fallback_for_many_components() {
        let count = MAX_INLINE_COMPONENTS + 2;
        let components: Vec<InstrumentId> = (0..count)
            .map(|i| InstrumentId::from(format!("C{i}.VENUE").as_str()))
            .collect();
        let terms: Vec<String> = components.iter().map(|c| c.to_string()).collect();
        let formula = terms.join(" + ");

        let synth = SyntheticInstrument::new(
            Symbol::from("BIG"),
            2,
            components.clone(),
            &formula,
            0.into(),
            0.into(),
        );

        let mut inputs = HashMap::new();
        for component in &components {
            inputs.insert(component.to_string(), 10.0);
        }

        let price = synth.calculate_from_map(&inputs).unwrap();

        assert_eq!(price, Price::from("100.0"));
    }
}