1use std::{collections::HashMap, sync::Arc};
17
18use arrow::{
19 array::{
20 Array, ArrayRef, BooleanArray, BooleanBuilder, Float64Array, Float64Builder, StringBuilder,
21 UInt64Array, UInt64Builder,
22 },
23 datatypes::{DataType, Field, Schema},
24 error::ArrowError,
25 record_batch::RecordBatch,
26};
27use serde::{Serialize, de::DeserializeOwned};
28use serde_json::{Map, Number, Value};
29
30use super::{EncodingError, StringColumnRef, extract_column, extract_column_string};
31
/// Storage strategy for one JSON field inside an Arrow column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JsonFieldEncoding {
    /// Stored in a `Utf8` column. On encode, strings are copied verbatim and
    /// other non-null JSON values are rendered as text; on decode the cell is
    /// returned as a JSON string.
    Utf8,
    /// Stored in a `Utf8` column as serialized JSON text, which is parsed back
    /// into a JSON value on decode.
    Utf8Json,
    /// Stored in a `UInt64` column.
    UInt64,
    /// Stored in a `Float64` column.
    Float64,
    /// Stored in a `Boolean` column.
    Boolean,
}
40
41#[derive(Clone, Copy, Debug, PartialEq, Eq)]
42pub struct JsonFieldSpec {
43 pub name: &'static str,
44 pub encoding: JsonFieldEncoding,
45 pub nullable: bool,
46}
47
48impl JsonFieldSpec {
49 #[must_use]
50 pub const fn utf8(name: &'static str, nullable: bool) -> Self {
51 Self {
52 name,
53 encoding: JsonFieldEncoding::Utf8,
54 nullable,
55 }
56 }
57
58 #[must_use]
59 pub const fn utf8_json(name: &'static str, nullable: bool) -> Self {
60 Self {
61 name,
62 encoding: JsonFieldEncoding::Utf8Json,
63 nullable,
64 }
65 }
66
67 #[must_use]
68 pub const fn u64(name: &'static str, nullable: bool) -> Self {
69 Self {
70 name,
71 encoding: JsonFieldEncoding::UInt64,
72 nullable,
73 }
74 }
75
76 #[must_use]
77 pub const fn f64(name: &'static str, nullable: bool) -> Self {
78 Self {
79 name,
80 encoding: JsonFieldEncoding::Float64,
81 nullable,
82 }
83 }
84
85 #[must_use]
86 pub const fn boolean(name: &'static str, nullable: bool) -> Self {
87 Self {
88 name,
89 encoding: JsonFieldEncoding::Boolean,
90 nullable,
91 }
92 }
93
94 fn field(self) -> Field {
95 let data_type = match self.encoding {
96 JsonFieldEncoding::Utf8 | JsonFieldEncoding::Utf8Json => DataType::Utf8,
97 JsonFieldEncoding::UInt64 => DataType::UInt64,
98 JsonFieldEncoding::Float64 => DataType::Float64,
99 JsonFieldEncoding::Boolean => DataType::Boolean,
100 };
101
102 Field::new(self.name, data_type, self.nullable)
103 }
104}
105
/// Returns a metadata map containing the single entry `"type" -> type_name`.
#[must_use]
pub fn metadata_for_type(type_name: &'static str) -> HashMap<String, String> {
    let mut metadata = HashMap::with_capacity(1);
    metadata.insert(String::from("type"), String::from(type_name));
    metadata
}
110
111#[must_use]
112pub fn schema_for_type(
113 type_name: &'static str,
114 metadata: Option<HashMap<String, String>>,
115 fields: &[JsonFieldSpec],
116) -> Schema {
117 let mut merged = metadata.unwrap_or_default();
118 merged.insert("type".to_string(), type_name.to_string());
119
120 Schema::new_with_metadata(
121 fields
122 .iter()
123 .copied()
124 .map(JsonFieldSpec::field)
125 .collect::<Vec<_>>(),
126 merged,
127 )
128}
129
130pub fn encode_batch<T: Serialize>(
137 type_name: &'static str,
138 metadata: &HashMap<String, String>,
139 data: &[T],
140 fields: &[JsonFieldSpec],
141) -> Result<RecordBatch, ArrowError> {
142 let rows = serialize_rows(data)?;
143 let arrays: Result<Vec<ArrayRef>, ArrowError> = fields
144 .iter()
145 .copied()
146 .map(|field| encode_column(field, &rows))
147 .collect();
148
149 RecordBatch::try_new(
150 Arc::new(schema_for_type(type_name, Some(metadata.clone()), fields)),
151 arrays?,
152 )
153}
154
155pub fn decode_batch<T: DeserializeOwned>(
162 metadata: &HashMap<String, String>,
163 record_batch: &RecordBatch,
164 fields: &[JsonFieldSpec],
165 fallback_type_name: Option<&'static str>,
166) -> Result<Vec<T>, EncodingError> {
167 let columns: Result<Vec<_>, EncodingError> = fields
168 .iter()
169 .enumerate()
170 .map(|(index, field)| decode_column_ref(record_batch.columns(), *field, index))
171 .collect();
172 let columns = columns?;
173
174 let mut decoded = Vec::with_capacity(record_batch.num_rows());
175 let type_name = metadata
176 .get("type")
177 .cloned()
178 .or_else(|| fallback_type_name.map(str::to_string));
179
180 for row in 0..record_batch.num_rows() {
181 let mut value = Map::new();
182 if let Some(type_name) = &type_name {
183 value.insert("type".to_string(), Value::String(type_name.clone()));
184 }
185
186 for column in &columns {
187 value.insert(column.name().to_string(), column.to_json(row)?);
188 }
189
190 let json = serde_json::to_vec(&Value::Object(value))
191 .map_err(|e| EncodingError::ParseError("record_batch", format!("row {row}: {e}")))?;
192 decoded.push(
193 serde_json::from_slice(&json).map_err(|e| {
194 EncodingError::ParseError("record_batch", format!("row {row}: {e}"))
195 })?,
196 );
197 }
198
199 Ok(decoded)
200}
201
202fn serialize_rows<T: Serialize>(data: &[T]) -> Result<Vec<Map<String, Value>>, ArrowError> {
203 data.iter()
204 .map(|item| match serde_json::to_value(item) {
205 Ok(Value::Object(map)) => Ok(map),
206 Ok(_) => Err(invalid_argument(
207 "Expected serialized value to be a JSON object".to_string(),
208 )),
209 Err(e) => Err(invalid_argument(e.to_string())),
210 })
211 .collect()
212}
213
214fn encode_column(
215 field: JsonFieldSpec,
216 rows: &[Map<String, Value>],
217) -> Result<ArrayRef, ArrowError> {
218 match field.encoding {
219 JsonFieldEncoding::Utf8 => encode_utf8_column(field, rows),
220 JsonFieldEncoding::Utf8Json => encode_utf8_json_column(field, rows),
221 JsonFieldEncoding::UInt64 => encode_u64_column(field, rows),
222 JsonFieldEncoding::Float64 => encode_f64_column(field, rows),
223 JsonFieldEncoding::Boolean => encode_bool_column(field, rows),
224 }
225}
226
227fn encode_utf8_column(
228 field: JsonFieldSpec,
229 rows: &[Map<String, Value>],
230) -> Result<ArrayRef, ArrowError> {
231 let mut builder = StringBuilder::new();
232
233 for row in rows {
234 match require_value(field, row.get(field.name))? {
235 Some(value) => builder.append_value(value_to_string(value)?),
236 None => builder.append_null(),
237 }
238 }
239
240 Ok(Arc::new(builder.finish()))
241}
242
243fn encode_utf8_json_column(
244 field: JsonFieldSpec,
245 rows: &[Map<String, Value>],
246) -> Result<ArrayRef, ArrowError> {
247 let mut builder = StringBuilder::new();
248
249 for row in rows {
250 match require_value(field, row.get(field.name))? {
251 Some(value) => builder.append_value(
252 serde_json::to_string(value).map_err(|e| invalid_argument(e.to_string()))?,
253 ),
254 None => builder.append_null(),
255 }
256 }
257
258 Ok(Arc::new(builder.finish()))
259}
260
261fn encode_u64_column(
262 field: JsonFieldSpec,
263 rows: &[Map<String, Value>],
264) -> Result<ArrayRef, ArrowError> {
265 let mut builder = UInt64Builder::new();
266
267 for row in rows {
268 match require_value(field, row.get(field.name))? {
269 Some(value) => builder.append_value(parse_u64(value)?),
270 None => builder.append_null(),
271 }
272 }
273
274 Ok(Arc::new(builder.finish()))
275}
276
277fn encode_f64_column(
278 field: JsonFieldSpec,
279 rows: &[Map<String, Value>],
280) -> Result<ArrayRef, ArrowError> {
281 let mut builder = Float64Builder::new();
282
283 for row in rows {
284 match require_value(field, row.get(field.name))? {
285 Some(value) => builder.append_value(parse_f64(value)?),
286 None => builder.append_null(),
287 }
288 }
289
290 Ok(Arc::new(builder.finish()))
291}
292
293fn encode_bool_column(
294 field: JsonFieldSpec,
295 rows: &[Map<String, Value>],
296) -> Result<ArrayRef, ArrowError> {
297 let mut builder = BooleanBuilder::new();
298
299 for row in rows {
300 match require_value(field, row.get(field.name))? {
301 Some(value) => builder.append_value(parse_bool(value)?),
302 None => builder.append_null(),
303 }
304 }
305
306 Ok(Arc::new(builder.finish()))
307}
308
309fn require_value(
310 field: JsonFieldSpec,
311 value: Option<&Value>,
312) -> Result<Option<&Value>, ArrowError> {
313 match value {
314 Some(Value::Null) | None if !field.nullable => Err(invalid_argument(format!(
315 "Missing required field `{}`",
316 field.name
317 ))),
318 Some(Value::Null) | None => Ok(None),
319 Some(value) => Ok(Some(value)),
320 }
321}
322
323fn value_to_string(value: &Value) -> Result<String, ArrowError> {
324 match value {
325 Value::String(value) => Ok(value.clone()),
326 Value::Null => Err(invalid_argument("Unexpected null value".to_string())),
327 Value::Bool(_) | Value::Number(_) => Ok(value.to_string()),
328 Value::Array(_) | Value::Object(_) => {
329 serde_json::to_string(value).map_err(|e| invalid_argument(e.to_string()))
330 }
331 }
332}
333
334fn parse_u64(value: &Value) -> Result<u64, ArrowError> {
335 match value {
336 Value::Number(number) => number
337 .as_u64()
338 .ok_or_else(|| invalid_argument(format!("Expected u64, found `{number}`"))),
339 Value::String(value) => value
340 .parse::<u64>()
341 .map_err(|e| invalid_argument(format!("Failed to parse u64 from `{value}`: {e}"))),
342 _ => Err(invalid_argument(format!(
343 "Expected u64-compatible value, found `{value}`"
344 ))),
345 }
346}
347
348fn parse_f64(value: &Value) -> Result<f64, ArrowError> {
349 match value {
350 Value::Number(number) => number
351 .as_f64()
352 .ok_or_else(|| invalid_argument(format!("Expected f64, found `{number}`"))),
353 Value::String(value) => value
354 .parse::<f64>()
355 .map_err(|e| invalid_argument(format!("Failed to parse f64 from `{value}`: {e}"))),
356 _ => Err(invalid_argument(format!(
357 "Expected f64-compatible value, found `{value}`"
358 ))),
359 }
360}
361
362fn parse_bool(value: &Value) -> Result<bool, ArrowError> {
363 match value {
364 Value::Bool(value) => Ok(*value),
365 Value::String(value) => value
366 .parse::<bool>()
367 .map_err(|e| invalid_argument(format!("Failed to parse bool from `{value}`: {e}"))),
368 _ => Err(invalid_argument(format!(
369 "Expected bool-compatible value, found `{value}`"
370 ))),
371 }
372}
373
374enum ColumnRef<'a> {
375 Utf8 {
376 name: &'static str,
377 values: StringColumnRef<'a>,
378 },
379 Utf8Json {
380 name: &'static str,
381 values: StringColumnRef<'a>,
382 },
383 UInt64 {
384 name: &'static str,
385 values: &'a UInt64Array,
386 },
387 Float64 {
388 name: &'static str,
389 values: &'a Float64Array,
390 },
391 Boolean {
392 name: &'static str,
393 values: &'a BooleanArray,
394 },
395}
396
397impl ColumnRef<'_> {
398 fn name(&self) -> &'static str {
399 match self {
400 Self::Utf8 { name, .. }
401 | Self::Utf8Json { name, .. }
402 | Self::UInt64 { name, .. }
403 | Self::Float64 { name, .. }
404 | Self::Boolean { name, .. } => name,
405 }
406 }
407
408 fn to_json(&self, row: usize) -> Result<Value, EncodingError> {
409 match self {
410 Self::Utf8 { values, .. } => {
411 if values_is_null(values, row) {
412 Ok(Value::Null)
413 } else {
414 Ok(Value::String(values.value(row).to_string()))
415 }
416 }
417 Self::Utf8Json { values, .. } => {
418 if values_is_null(values, row) {
419 Ok(Value::Null)
420 } else {
421 serde_json::from_str(values.value(row)).map_err(|e| {
422 EncodingError::ParseError(self.name(), format!("row {row}: {e}"))
423 })
424 }
425 }
426 Self::UInt64 { values, .. } => {
427 if values.is_null(row) {
428 Ok(Value::Null)
429 } else {
430 Ok(Value::Number(Number::from(values.value(row))))
431 }
432 }
433 Self::Float64 { values, .. } => {
434 if values.is_null(row) {
435 Ok(Value::Null)
436 } else {
437 Number::from_f64(values.value(row))
438 .map(Value::Number)
439 .ok_or_else(|| {
440 EncodingError::ParseError(
441 self.name(),
442 format!("row {row}: invalid f64 value"),
443 )
444 })
445 }
446 }
447 Self::Boolean { values, .. } => {
448 if values.is_null(row) {
449 Ok(Value::Null)
450 } else {
451 Ok(Value::Bool(values.value(row)))
452 }
453 }
454 }
455 }
456}
457
458fn decode_column_ref(
459 columns: &[ArrayRef],
460 field: JsonFieldSpec,
461 index: usize,
462) -> Result<ColumnRef<'_>, EncodingError> {
463 match field.encoding {
464 JsonFieldEncoding::Utf8 => Ok(ColumnRef::Utf8 {
465 name: field.name,
466 values: extract_column_string(columns, field.name, index)?,
467 }),
468 JsonFieldEncoding::Utf8Json => Ok(ColumnRef::Utf8Json {
469 name: field.name,
470 values: extract_column_string(columns, field.name, index)?,
471 }),
472 JsonFieldEncoding::UInt64 => Ok(ColumnRef::UInt64 {
473 name: field.name,
474 values: extract_column::<UInt64Array>(columns, field.name, index, DataType::UInt64)?,
475 }),
476 JsonFieldEncoding::Float64 => Ok(ColumnRef::Float64 {
477 name: field.name,
478 values: extract_column::<Float64Array>(columns, field.name, index, DataType::Float64)?,
479 }),
480 JsonFieldEncoding::Boolean => Ok(ColumnRef::Boolean {
481 name: field.name,
482 values: extract_column::<BooleanArray>(columns, field.name, index, DataType::Boolean)?,
483 }),
484 }
485}
486
487fn values_is_null(values: &StringColumnRef<'_>, row: usize) -> bool {
488 match values {
489 StringColumnRef::Utf8(values) => values.is_null(row),
490 StringColumnRef::Utf8View(values) => values.is_null(row),
491 }
492}
493
494fn invalid_argument(message: String) -> ArrowError {
495 ArrowError::InvalidArgumentError(message)
496}