Skip to main content

nautilus_serialization/arrow/
json.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, sync::Arc};
17
18use arrow::{
19    array::{
20        Array, ArrayRef, BooleanArray, BooleanBuilder, Float64Array, Float64Builder, StringBuilder,
21        UInt64Array, UInt64Builder,
22    },
23    datatypes::{DataType, Field, Schema},
24    error::ArrowError,
25    record_batch::RecordBatch,
26};
27use serde::{Serialize, de::DeserializeOwned};
28use serde_json::{Map, Number, Value};
29
30use super::{EncodingError, StringColumnRef, extract_column, extract_column_string};
31
/// Arrow column representation chosen for one field of a JSON-serialized record.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JsonFieldEncoding {
    /// Arrow `Utf8` column; decoded back as a JSON string.
    Utf8,
    /// Arrow `Utf8` column holding JSON text; decoded back by parsing that text.
    Utf8Json,
    /// Arrow `UInt64` column.
    UInt64,
    /// Arrow `Float64` column.
    Float64,
    /// Arrow `Boolean` column.
    Boolean,
}
40
41#[derive(Clone, Copy, Debug, PartialEq, Eq)]
42pub struct JsonFieldSpec {
43    pub name: &'static str,
44    pub encoding: JsonFieldEncoding,
45    pub nullable: bool,
46}
47
48impl JsonFieldSpec {
49    #[must_use]
50    pub const fn utf8(name: &'static str, nullable: bool) -> Self {
51        Self {
52            name,
53            encoding: JsonFieldEncoding::Utf8,
54            nullable,
55        }
56    }
57
58    #[must_use]
59    pub const fn utf8_json(name: &'static str, nullable: bool) -> Self {
60        Self {
61            name,
62            encoding: JsonFieldEncoding::Utf8Json,
63            nullable,
64        }
65    }
66
67    #[must_use]
68    pub const fn u64(name: &'static str, nullable: bool) -> Self {
69        Self {
70            name,
71            encoding: JsonFieldEncoding::UInt64,
72            nullable,
73        }
74    }
75
76    #[must_use]
77    pub const fn f64(name: &'static str, nullable: bool) -> Self {
78        Self {
79            name,
80            encoding: JsonFieldEncoding::Float64,
81            nullable,
82        }
83    }
84
85    #[must_use]
86    pub const fn boolean(name: &'static str, nullable: bool) -> Self {
87        Self {
88            name,
89            encoding: JsonFieldEncoding::Boolean,
90            nullable,
91        }
92    }
93
94    fn field(self) -> Field {
95        let data_type = match self.encoding {
96            JsonFieldEncoding::Utf8 | JsonFieldEncoding::Utf8Json => DataType::Utf8,
97            JsonFieldEncoding::UInt64 => DataType::UInt64,
98            JsonFieldEncoding::Float64 => DataType::Float64,
99            JsonFieldEncoding::Boolean => DataType::Boolean,
100        };
101
102        Field::new(self.name, data_type, self.nullable)
103    }
104}
105
/// Builds schema metadata containing a single `"type"` entry set to `type_name`.
///
/// The name is copied into the map, so any string slice is accepted, not only
/// `'static` literals.
#[must_use]
pub fn metadata_for_type(type_name: &str) -> HashMap<String, String> {
    HashMap::from([("type".to_string(), type_name.to_string())])
}
110
111#[must_use]
112pub fn schema_for_type(
113    type_name: &'static str,
114    metadata: Option<HashMap<String, String>>,
115    fields: &[JsonFieldSpec],
116) -> Schema {
117    let mut merged = metadata.unwrap_or_default();
118    merged.insert("type".to_string(), type_name.to_string());
119
120    Schema::new_with_metadata(
121        fields
122            .iter()
123            .copied()
124            .map(JsonFieldSpec::field)
125            .collect::<Vec<_>>(),
126        merged,
127    )
128}
129
130/// Encodes typed records into an Arrow record batch with the supplied schema metadata.
131///
132/// # Errors
133///
134/// Returns an error if JSON serialization fails or if a field cannot be encoded into
135/// the requested Arrow column type.
136pub fn encode_batch<T: Serialize>(
137    type_name: &'static str,
138    metadata: &HashMap<String, String>,
139    data: &[T],
140    fields: &[JsonFieldSpec],
141) -> Result<RecordBatch, ArrowError> {
142    let rows = serialize_rows(data)?;
143    let arrays: Result<Vec<ArrayRef>, ArrowError> = fields
144        .iter()
145        .copied()
146        .map(|field| encode_column(field, &rows))
147        .collect();
148
149    RecordBatch::try_new(
150        Arc::new(schema_for_type(type_name, Some(metadata.clone()), fields)),
151        arrays?,
152    )
153}
154
155/// Decodes typed records from an Arrow record batch produced by encode_batch.
156///
157/// # Errors
158///
159/// Returns an error if a required column is missing, has the wrong type, contains
160/// invalid JSON, or cannot be deserialized into the target type.
161pub fn decode_batch<T: DeserializeOwned>(
162    metadata: &HashMap<String, String>,
163    record_batch: &RecordBatch,
164    fields: &[JsonFieldSpec],
165    fallback_type_name: Option<&'static str>,
166) -> Result<Vec<T>, EncodingError> {
167    let columns: Result<Vec<_>, EncodingError> = fields
168        .iter()
169        .enumerate()
170        .map(|(index, field)| decode_column_ref(record_batch.columns(), *field, index))
171        .collect();
172    let columns = columns?;
173
174    let mut decoded = Vec::with_capacity(record_batch.num_rows());
175    let type_name = metadata
176        .get("type")
177        .cloned()
178        .or_else(|| fallback_type_name.map(str::to_string));
179
180    for row in 0..record_batch.num_rows() {
181        let mut value = Map::new();
182        if let Some(type_name) = &type_name {
183            value.insert("type".to_string(), Value::String(type_name.clone()));
184        }
185
186        for column in &columns {
187            value.insert(column.name().to_string(), column.to_json(row)?);
188        }
189
190        let json = serde_json::to_vec(&Value::Object(value))
191            .map_err(|e| EncodingError::ParseError("record_batch", format!("row {row}: {e}")))?;
192        decoded.push(
193            serde_json::from_slice(&json).map_err(|e| {
194                EncodingError::ParseError("record_batch", format!("row {row}: {e}"))
195            })?,
196        );
197    }
198
199    Ok(decoded)
200}
201
202fn serialize_rows<T: Serialize>(data: &[T]) -> Result<Vec<Map<String, Value>>, ArrowError> {
203    data.iter()
204        .map(|item| match serde_json::to_value(item) {
205            Ok(Value::Object(map)) => Ok(map),
206            Ok(_) => Err(invalid_argument(
207                "Expected serialized value to be a JSON object".to_string(),
208            )),
209            Err(e) => Err(invalid_argument(e.to_string())),
210        })
211        .collect()
212}
213
214fn encode_column(
215    field: JsonFieldSpec,
216    rows: &[Map<String, Value>],
217) -> Result<ArrayRef, ArrowError> {
218    match field.encoding {
219        JsonFieldEncoding::Utf8 => encode_utf8_column(field, rows),
220        JsonFieldEncoding::Utf8Json => encode_utf8_json_column(field, rows),
221        JsonFieldEncoding::UInt64 => encode_u64_column(field, rows),
222        JsonFieldEncoding::Float64 => encode_f64_column(field, rows),
223        JsonFieldEncoding::Boolean => encode_bool_column(field, rows),
224    }
225}
226
227fn encode_utf8_column(
228    field: JsonFieldSpec,
229    rows: &[Map<String, Value>],
230) -> Result<ArrayRef, ArrowError> {
231    let mut builder = StringBuilder::new();
232
233    for row in rows {
234        match require_value(field, row.get(field.name))? {
235            Some(value) => builder.append_value(value_to_string(value)?),
236            None => builder.append_null(),
237        }
238    }
239
240    Ok(Arc::new(builder.finish()))
241}
242
243fn encode_utf8_json_column(
244    field: JsonFieldSpec,
245    rows: &[Map<String, Value>],
246) -> Result<ArrayRef, ArrowError> {
247    let mut builder = StringBuilder::new();
248
249    for row in rows {
250        match require_value(field, row.get(field.name))? {
251            Some(value) => builder.append_value(
252                serde_json::to_string(value).map_err(|e| invalid_argument(e.to_string()))?,
253            ),
254            None => builder.append_null(),
255        }
256    }
257
258    Ok(Arc::new(builder.finish()))
259}
260
261fn encode_u64_column(
262    field: JsonFieldSpec,
263    rows: &[Map<String, Value>],
264) -> Result<ArrayRef, ArrowError> {
265    let mut builder = UInt64Builder::new();
266
267    for row in rows {
268        match require_value(field, row.get(field.name))? {
269            Some(value) => builder.append_value(parse_u64(value)?),
270            None => builder.append_null(),
271        }
272    }
273
274    Ok(Arc::new(builder.finish()))
275}
276
277fn encode_f64_column(
278    field: JsonFieldSpec,
279    rows: &[Map<String, Value>],
280) -> Result<ArrayRef, ArrowError> {
281    let mut builder = Float64Builder::new();
282
283    for row in rows {
284        match require_value(field, row.get(field.name))? {
285            Some(value) => builder.append_value(parse_f64(value)?),
286            None => builder.append_null(),
287        }
288    }
289
290    Ok(Arc::new(builder.finish()))
291}
292
293fn encode_bool_column(
294    field: JsonFieldSpec,
295    rows: &[Map<String, Value>],
296) -> Result<ArrayRef, ArrowError> {
297    let mut builder = BooleanBuilder::new();
298
299    for row in rows {
300        match require_value(field, row.get(field.name))? {
301            Some(value) => builder.append_value(parse_bool(value)?),
302            None => builder.append_null(),
303        }
304    }
305
306    Ok(Arc::new(builder.finish()))
307}
308
309fn require_value(
310    field: JsonFieldSpec,
311    value: Option<&Value>,
312) -> Result<Option<&Value>, ArrowError> {
313    match value {
314        Some(Value::Null) | None if !field.nullable => Err(invalid_argument(format!(
315            "Missing required field `{}`",
316            field.name
317        ))),
318        Some(Value::Null) | None => Ok(None),
319        Some(value) => Ok(Some(value)),
320    }
321}
322
323fn value_to_string(value: &Value) -> Result<String, ArrowError> {
324    match value {
325        Value::String(value) => Ok(value.clone()),
326        Value::Null => Err(invalid_argument("Unexpected null value".to_string())),
327        Value::Bool(_) | Value::Number(_) => Ok(value.to_string()),
328        Value::Array(_) | Value::Object(_) => {
329            serde_json::to_string(value).map_err(|e| invalid_argument(e.to_string()))
330        }
331    }
332}
333
334fn parse_u64(value: &Value) -> Result<u64, ArrowError> {
335    match value {
336        Value::Number(number) => number
337            .as_u64()
338            .ok_or_else(|| invalid_argument(format!("Expected u64, found `{number}`"))),
339        Value::String(value) => value
340            .parse::<u64>()
341            .map_err(|e| invalid_argument(format!("Failed to parse u64 from `{value}`: {e}"))),
342        _ => Err(invalid_argument(format!(
343            "Expected u64-compatible value, found `{value}`"
344        ))),
345    }
346}
347
348fn parse_f64(value: &Value) -> Result<f64, ArrowError> {
349    match value {
350        Value::Number(number) => number
351            .as_f64()
352            .ok_or_else(|| invalid_argument(format!("Expected f64, found `{number}`"))),
353        Value::String(value) => value
354            .parse::<f64>()
355            .map_err(|e| invalid_argument(format!("Failed to parse f64 from `{value}`: {e}"))),
356        _ => Err(invalid_argument(format!(
357            "Expected f64-compatible value, found `{value}`"
358        ))),
359    }
360}
361
362fn parse_bool(value: &Value) -> Result<bool, ArrowError> {
363    match value {
364        Value::Bool(value) => Ok(*value),
365        Value::String(value) => value
366            .parse::<bool>()
367            .map_err(|e| invalid_argument(format!("Failed to parse bool from `{value}`: {e}"))),
368        _ => Err(invalid_argument(format!(
369            "Expected bool-compatible value, found `{value}`"
370        ))),
371    }
372}
373
374enum ColumnRef<'a> {
375    Utf8 {
376        name: &'static str,
377        values: StringColumnRef<'a>,
378    },
379    Utf8Json {
380        name: &'static str,
381        values: StringColumnRef<'a>,
382    },
383    UInt64 {
384        name: &'static str,
385        values: &'a UInt64Array,
386    },
387    Float64 {
388        name: &'static str,
389        values: &'a Float64Array,
390    },
391    Boolean {
392        name: &'static str,
393        values: &'a BooleanArray,
394    },
395}
396
397impl ColumnRef<'_> {
398    fn name(&self) -> &'static str {
399        match self {
400            Self::Utf8 { name, .. }
401            | Self::Utf8Json { name, .. }
402            | Self::UInt64 { name, .. }
403            | Self::Float64 { name, .. }
404            | Self::Boolean { name, .. } => name,
405        }
406    }
407
408    fn to_json(&self, row: usize) -> Result<Value, EncodingError> {
409        match self {
410            Self::Utf8 { values, .. } => {
411                if values_is_null(values, row) {
412                    Ok(Value::Null)
413                } else {
414                    Ok(Value::String(values.value(row).to_string()))
415                }
416            }
417            Self::Utf8Json { values, .. } => {
418                if values_is_null(values, row) {
419                    Ok(Value::Null)
420                } else {
421                    serde_json::from_str(values.value(row)).map_err(|e| {
422                        EncodingError::ParseError(self.name(), format!("row {row}: {e}"))
423                    })
424                }
425            }
426            Self::UInt64 { values, .. } => {
427                if values.is_null(row) {
428                    Ok(Value::Null)
429                } else {
430                    Ok(Value::Number(Number::from(values.value(row))))
431                }
432            }
433            Self::Float64 { values, .. } => {
434                if values.is_null(row) {
435                    Ok(Value::Null)
436                } else {
437                    Number::from_f64(values.value(row))
438                        .map(Value::Number)
439                        .ok_or_else(|| {
440                            EncodingError::ParseError(
441                                self.name(),
442                                format!("row {row}: invalid f64 value"),
443                            )
444                        })
445                }
446            }
447            Self::Boolean { values, .. } => {
448                if values.is_null(row) {
449                    Ok(Value::Null)
450                } else {
451                    Ok(Value::Bool(values.value(row)))
452                }
453            }
454        }
455    }
456}
457
458fn decode_column_ref(
459    columns: &[ArrayRef],
460    field: JsonFieldSpec,
461    index: usize,
462) -> Result<ColumnRef<'_>, EncodingError> {
463    match field.encoding {
464        JsonFieldEncoding::Utf8 => Ok(ColumnRef::Utf8 {
465            name: field.name,
466            values: extract_column_string(columns, field.name, index)?,
467        }),
468        JsonFieldEncoding::Utf8Json => Ok(ColumnRef::Utf8Json {
469            name: field.name,
470            values: extract_column_string(columns, field.name, index)?,
471        }),
472        JsonFieldEncoding::UInt64 => Ok(ColumnRef::UInt64 {
473            name: field.name,
474            values: extract_column::<UInt64Array>(columns, field.name, index, DataType::UInt64)?,
475        }),
476        JsonFieldEncoding::Float64 => Ok(ColumnRef::Float64 {
477            name: field.name,
478            values: extract_column::<Float64Array>(columns, field.name, index, DataType::Float64)?,
479        }),
480        JsonFieldEncoding::Boolean => Ok(ColumnRef::Boolean {
481            name: field.name,
482            values: extract_column::<BooleanArray>(columns, field.name, index, DataType::Boolean)?,
483        }),
484    }
485}
486
487fn values_is_null(values: &StringColumnRef<'_>, row: usize) -> bool {
488    match values {
489        StringColumnRef::Utf8(values) => values.is_null(row),
490        StringColumnRef::Utf8View(values) => values.is_null(row),
491    }
492}
493
494fn invalid_argument(message: String) -> ArrowError {
495    ArrowError::InvalidArgumentError(message)
496}