commit 5bb655b7ded055395cd8c1ef281c76f91da159de Author: Ondřej Mekina Date: Tue Mar 4 11:22:53 2025 +0100 initial commit diff --git a/README.md b/README.md new file mode 100644 index 0000000..f95f275 --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +# Zprávy zumepro + +## Návrh systému + +![diagram návrhu sysému](./docs/exports/system_arch.svg) diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 0000000..3a41be3 --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1 @@ +dst/ diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..807e3a6 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,22 @@ +.PHONY: build +build: dst .WAIT \ + dst/system_arch.svg + +export: exports .WAIT \ + exports/system_arch.svg + +dst: + mkdir dst + +exports: + mkdir exports + +dst/%.svg: %.dot + dot -Tsvg $< > $@ + +exports/%.svg: dst/%.svg + svgcleaner $< $@ + +.PHONY: clean +clean: + rm -rf dst diff --git a/docs/exports/system_arch.svg b/docs/exports/system_arch.svg new file mode 100644 index 0000000..e971aa3 --- /dev/null +++ b/docs/exports/system_arch.svg @@ -0,0 +1 @@ +REST serverAPI na zprávyREST klient \ No newline at end of file diff --git a/docs/system_arch.dot b/docs/system_arch.dot new file mode 100644 index 0000000..393be60 --- /dev/null +++ b/docs/system_arch.dot @@ -0,0 +1,11 @@ +digraph sysdiag { + graph [fontname="arial"]; + node [fontname="arial"]; + edge [fontname="arial"]; + + RS [shape=box3d,label="REST server"] + ZA [label="API na zprávy"] + RS -> ZA + RC [shape=box,label="REST klient"] + RC -> RS +} diff --git a/lib/inferium/.gitignore b/lib/inferium/.gitignore new file mode 100644 index 0000000..a5ff07f --- /dev/null +++ b/lib/inferium/.gitignore @@ -0,0 +1,8 @@ +/target + + +# Added by cargo +# +# already existing elements were commented out + +#/target diff --git a/lib/inferium/Cargo.lock b/lib/inferium/Cargo.lock new file mode 100644 index 0000000..5424d08 --- /dev/null +++ b/lib/inferium/Cargo.lock @@ -0,0 +1,746 @@ +# This file is 
automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "addr2line" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "aws-lc-rs" +version = "1.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd755adf9707cf671e31d944a189be3deaaeee11c8bc1d669bb8022ac90fbd0" +dependencies = [ + "aws-lc-sys", + "paste", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f9dd2e03ee80ca2822dd6ea431163d2ef259f2066a4d6ccaca6d9dcb386aa43" +dependencies = [ + "bindgen", + "cc", + "cmake", + "dunce", + "fs_extra", + "paste", +] + +[[package]] +name = "backtrace" +version = "0.3.74" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +dependencies = [ + "addr2line", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", + "windows-targets", +] + +[[package]] +name = "bindgen" +version = "0.69.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +dependencies = [ + "bitflags", + "cexpr", + "clang-sys", + "itertools", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn", + "which", +] + +[[package]] +name = "bitflags" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" + +[[package]] +name = "bytes" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9" + +[[package]] +name = "cc" +version = "1.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c736e259eea577f443d5c86c304f9f4ae0295c43f3ba05c21f1d66b5f06001af" +dependencies = [ + "jobserver", + "libc", + "shlex", +] + +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + +[[package]] +name = "either" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7914353092ddf589ad78f25c5c1c21b7f80b0ff8621e7c814c3485b5306da9d" + +[[package]] +name = "errno" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "gimli" +version = "0.31.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" + +[[package]] +name = "glob" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" + +[[package]] +name = "home" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "inferium" +version = "0.1.0" +dependencies = [ + "proc", + "tokio", + "tokio-rustls", + "webpki-roots", +] + +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + +[[package]] +name = "jobserver" 
+version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +dependencies = [ + "libc", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + +[[package]] +name = "libc" +version = "0.2.169" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" + +[[package]] +name = "libloading" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" +dependencies = [ + "cfg-if", + "windows-targets", +] + +[[package]] +name = "linux-raw-sys" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" + +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" + +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "miniz_oxide" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3b1c9bd4fe1f0f8b387f6eb9eb3b4a1aa26185e5750efb9140301703f62cd1b" +dependencies = [ + "adler2", +] + +[[package]] +name = "mio" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.52.0", +] + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + +[[package]] +name = "object" +version = "0.36.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" +dependencies = [ + "memchr", +] + +[[package]] +name = "once_cell" +version = "1.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" + +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "prettyplease" +version = "0.2.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc" +version = "0.1.0" + +[[package]] +name = "proc-macro2" +version = "1.0.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" +dependencies = [ + "bitflags", +] + +[[package]] +name = "regex" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" + +[[package]] +name = "ring" +version = "0.17.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da5349ae27d3887ca812fb375b45a4fbb36d8d12d2df394968cd86e35683fe73" +dependencies = [ + "cc", + "cfg-if", + "getrandom", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + +[[package]] +name = "rustix" +version = "0.38.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.59.0", +] + +[[package]] +name = "rustls" +version = "0.23.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395" +dependencies = [ + "aws-lc-rs", + "log", + "once_cell", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" + +[[package]] +name = "rustls-webpki" +version = "0.102.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +dependencies = [ + "aws-lc-rs", + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + +[[package]] +name = "socket2" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "2.0.98" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tokio" +version = "1.43.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.52.0", +] + +[[package]] +name = "tokio-macros" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37" +dependencies = [ + "rustls", + "tokio", +] + +[[package]] +name = "unicode-ident" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "webpki-roots" +version = "0.26.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2210b291f7ea53617fbafcc4939f10914214ec15aace5ba62293a668f322c5c9" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + 
+[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" diff --git a/lib/inferium/Cargo.toml b/lib/inferium/Cargo.toml new file mode 100644 index 0000000..6946ca2 --- /dev/null +++ b/lib/inferium/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "inferium" +version = "0.1.0" +edition = "2021" +description = "A small HTTP library" +authors = ["Ondřej Mekina "] +keywords = ["http", "inferium", "zumepro"] +categories = ["network-programming", "web-programming:http-client", "web-programming:http-server"] + +[dependencies] +proc = { path = "./proc/" } +tokio = { version = "1.43.0", features = ["full"], optional = true } +tokio-rustls = { version = "0.26.1", optional = true } +webpki-roots = { version = "0.26.8", optional = true } + +[features] +full = ["async", "tokio-full"] +dev = ["full", "testing", "dep:webpki-roots"] +async = [] +tokio-full = ["async", "tokio-net", "tokio-unixsocks", "tokio-tls"] +tokio-net = ["async", "dep:tokio"] +tokio-unixsocks = ["async", "dep:tokio"] +tokio-tls = ["dep:tokio-rustls", "tokio-net"] +testing = [] diff --git a/lib/inferium/README.md b/lib/inferium/README.md new file mode 100644 index 0000000..d58dd1e --- /dev/null +++ b/lib/inferium/README.md @@ -0,0 +1,3 @@ +# inferium + +A small HTTP library written in Rust diff --git a/lib/inferium/benches/client.rs b/lib/inferium/benches/client.rs new file mode 100644 index 0000000..95ff914 --- /dev/null +++ b/lib/inferium/benches/client.rs @@ -0,0 +1,60 @@ +#![feature(test)] +extern crate test; +use std::collections::HashMap; + +use test::Bencher; + +extern crate inferium; +use inferium::h1::{SyncClient, Response, ResponseHead, ProtocolVariant}; +use inferium::{Status, HeaderValue}; +use inferium::TestSyncStream; + +fn 
parse_response_sync_inner() { + let src = "HTTP/1.1 200 OK\r\nserver: inferium\r\n\r\n".as_bytes().to_vec(); + let stream = TestSyncStream::<4>::new(&src); + let mut client = SyncClient::>::new(stream); + let target = Response::HeadersOnly(ResponseHead::new( + Status::Ok, + ProtocolVariant::HTTP1_1, + HashMap::from([ + ("server".into(), HeaderValue::new(vec!["inferium".to_string()])) + ]) + )); + assert_eq!(client.receive_response().unwrap(), target); +} + +fn parse_response_sync_inner_body() { + let mut src = "HTTP/1.1 200 OK\r\ncontent-length: 50\r\n\r\n".as_bytes().to_vec(); + src.extend_from_slice(b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + let stream = TestSyncStream::<4>::new(&src); + let mut client = SyncClient::>::new(stream); + let target_head = ResponseHead::new( + Status::Ok, + ProtocolVariant::HTTP1_1, + HashMap::from([ + ("content-length".into(), HeaderValue::new(vec!["50".to_string()])) + ]) + ); + let Response::WithSizedBody((h, mut b)) = client.receive_response().unwrap() else { + panic!(); + }; + let b = b.recv_all().unwrap(); + assert_eq!(h, target_head); + assert_eq!(b, b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); +} + +#[bench] +fn parse_response_sync(b: &mut Bencher) { + b.bytes = 37; + b.iter(|| { + parse_response_sync_inner(); + }); +} + +#[bench] +fn parse_response_sync_with_body(b: &mut Bencher) { + b.bytes = 39 + 13; + b.iter(|| { + parse_response_sync_inner_body(); + }); +} diff --git a/lib/inferium/benches/status_code_parsing.rs b/lib/inferium/benches/status_code_parsing.rs new file mode 100644 index 0000000..5d0fc18 --- /dev/null +++ b/lib/inferium/benches/status_code_parsing.rs @@ -0,0 +1,43 @@ +#![feature(test)] +extern crate test; +use test::Bencher; + +extern crate inferium; +use inferium::Status; + +#[bench] +fn baseline(b: &mut Bencher) { + fn status_test_from_slice(raw: &[u8]) -> Result { + match raw { + b"200" => Ok(Status::Ok), + b"404" => Ok(Status::NotFound), + _ => Err(()), + } + } + + b.iter(|| { + 
assert_eq!(status_test_from_slice(b"200".as_slice()), Ok(Status::Ok)); + }); +} + +#[bench] +fn valid_ok(b: &mut Bencher) { + b.iter(|| { + assert_eq!(Status::try_from(b"200".as_slice()), Ok(Status::Ok)); + }); + +} + +#[bench] +fn valid_internal_server_error(b: &mut Bencher) { + b.iter(|| { + assert_eq!(Status::try_from(b"500".as_slice()), Ok(Status::InternalServerError)); + }); +} + +#[bench] +fn invalid(b: &mut Bencher) { + b.iter(|| { + assert!(Status::try_from(b"690".as_slice()).is_err()); + }) +} diff --git a/lib/inferium/examples/going_async.rs b/lib/inferium/examples/going_async.rs new file mode 100644 index 0000000..946b6e6 --- /dev/null +++ b/lib/inferium/examples/going_async.rs @@ -0,0 +1,94 @@ +// This is an async port of `examples/simple_server.rs`. Please see that example first. +// +// Features `async` and `tokio-net` must be enabled for this example to compile. +// It is also possible to enable feature `full` (which will enable all the features). + +use std::{collections::HashMap, net::SocketAddr}; +use tokio::net::{TcpListener, TcpStream}; +use inferium::{ + h1::{ProtocolVariant, Request, ResponseHead, ServerSendError, AsyncServer}, + HeaderKey, Method, Status, TokioInet +}; + +#[tokio::main] +async fn main() { + let listener = TcpListener::bind("localhost:8080").await.unwrap(); + loop { + let (conn, addr) = listener.accept().await.unwrap(); + // Here we are creating a new asynchronous task for every client. + // This will fork off in an asynchronous manner and won't block our accept loop. + tokio::task::spawn(async move { + // We created this new async block, so we need to `.await` on this function to propagate + // the future from the function to the top of the spawned task (this async block). + handle_client(conn, addr).await; + }); + // You can now handle multiple clients at once... congratulations. 
+ } +} + +async fn handle_client(conn: TcpStream, addr: SocketAddr) { + println!("connection from {addr:?}"); + let mut server_handler = AsyncServer::::new(TokioInet::new(conn)); + // When receiving or sending - we call the same functions with `.await` appended (in an async + // context). This will automatically poll the returned future from the top of the context. + // The polling is handled by tokio here - so we don't need to worry about it. + while let Ok(request) = server_handler.receive_request().await { + match handle_request(request, addr) { + Ok((h, b)) => if let Err(_) = send_response(h, b, &mut server_handler).await { break; }, + Err(()) => break, + } + }; + println!("ended connection for {addr:?}"); +} + +fn handle_request( + req: Request, addr: SocketAddr +) -> Result<(ResponseHead, &'static [u8]), ()> { + let Request::HeadersOnly(headers) = req else { + return Err(()); + }; + + println!("req from {addr:?}: {headers}"); + + const OK_RESPONSE: &[u8] = b" + + +

Hello, world!

+

Hello from inferium.

+ +"; + const NOT_FOUND_RESPONSE: &[u8] = b" + + +

Not found

+

This page was not found

+ +"; + + Ok(match (headers.method(), headers.uri().path()) { + (&Method::GET, "/") => (ResponseHead::new( + Status::Ok, + ProtocolVariant::HTTP1_0, + HashMap::from([ + (HeaderKey::SERVER, "inferium".parse().unwrap()), + (HeaderKey::CONTENT_LENGTH, OK_RESPONSE.len().into()), + ]) + ), OK_RESPONSE), + + _ => (ResponseHead::new( + Status::NotFound, + ProtocolVariant::HTTP1_0, + HashMap::from([ + (HeaderKey::SERVER, "inferium".parse().unwrap()), + (HeaderKey::CONTENT_LENGTH, NOT_FOUND_RESPONSE.len().into()), + ]) + ), NOT_FOUND_RESPONSE), + }) +} + +async fn send_response( + response: ResponseHead, body: &[u8], conn: &mut AsyncServer +)-> Result<(), ServerSendError> { + conn.send_response(&response).await?; + conn.send_body_bytes(body).await.map_err(|e| e.try_into().unwrap()) +} diff --git a/lib/inferium/examples/https_client.rs b/lib/inferium/examples/https_client.rs new file mode 100644 index 0000000..7d6224c --- /dev/null +++ b/lib/inferium/examples/https_client.rs @@ -0,0 +1,88 @@ +// This is a port of a client from `examples/start_here.rs`. Please see that example first. +// Also... maybe brush up on some async tasks since we are going to need them here (there is an +// example on async with inferium in `examples/going_async.rs`). +// +// Features `async`, `tokio-net` and `webpki-roots` dependency must be enabled for this example to +// compile. We recommend enabling the `dev` feature when running this example. 
+ +use std::{collections::HashMap, sync::Arc}; +use tokio::net::TcpStream; +use tokio_rustls::{ + rustls::{ + pki_types::ServerName, + ClientConfig, + RootCertStore + }, + TlsConnector, + TlsStream +}; +use inferium::{ + h1::{ + ProtocolVariant, + RequestHead, + Response, + AsyncClient + }, + HeaderKey, + Method, + TokioRustls +}; + +async fn run_tls_handshake(raw_stream: TcpStream) -> TlsStream { + let mut root_certs = RootCertStore::empty(); + root_certs.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let config = ClientConfig::builder() + .with_root_certificates(root_certs) + .with_no_client_auth(); + let connector = TlsConnector::from(Arc::new(config)); + let verify_server_name = ServerName::try_from("zumepro.cz").unwrap(); + TlsStream::Client(connector.connect(verify_server_name, raw_stream).await.unwrap()) +} + +#[tokio::main] +async fn main() { + let stream = TcpStream::connect("zumepro.cz:443").await.unwrap(); + let stream = run_tls_handshake(stream).await; + let conn = TokioRustls::new(stream); + let mut client = AsyncClient::::new(conn); + + let to_send = RequestHead::new( + Method::GET, "/".parse().unwrap(), ProtocolVariant::HTTP1_1, + HashMap::from([ + (HeaderKey::USER_AGENT, "Mozilla/5.0 (inferium)".parse().unwrap()), + (HeaderKey::HOST, "zumepro.cz".parse().unwrap()), + (HeaderKey::CONNECTION, "close".parse().unwrap()) + ]) + ); + println!("----------> Sending\n\n{to_send}\n"); + client.send_request(&to_send).await.unwrap(); + let response = client.receive_response().await.unwrap(); + + let (header, body) = match response { + Response::HeadersOnly(h) => (h, None), + Response::WithSizedBody((_, _)) => panic!(), + Response::WithChunkedBody((h, b)) => (h, Some(b)), + }; + + println!("----------< Received\n\n{header}\n"); + + if let Some(mut body) = body { + + // Since our zumepro server sends bodies chunked - we will need to handle it. + // This simple loop just collects all the chunks into the res vector. 
+ let mut res = Vec::new(); + while let Some(mut chunk) = body.get_chunk_async().await.unwrap() { + // Now here is a difference between sync and async. + // + // For easy body manipulation and no redundant trait pollution, a body within an + // asynchronous stream can be sent/received using the same methods as a synchronous one, + // but with the suffix `_async`. + res.append(&mut chunk.recv_all_async().await.unwrap()); + } + + println!( + "----------< Body\n\n{:?}\n", + std::str::from_utf8(&res).unwrap() + ); + } +} diff --git a/lib/inferium/examples/simple_server.rs b/lib/inferium/examples/simple_server.rs new file mode 100644 index 0000000..ddbb821 --- /dev/null +++ b/lib/inferium/examples/simple_server.rs @@ -0,0 +1,84 @@ +use std::{collections::HashMap, net::{SocketAddr, TcpListener}}; +use inferium::{ + h1::{ProtocolVariant, Request, ResponseHead, ServerSendError, SyncServer}, + HeaderKey, Method, Status, StdInet +}; + +fn main() { + let listener = TcpListener::bind("localhost:8080").unwrap(); + loop { + let (conn, addr) = listener.accept().unwrap(); + println!("connection from {addr:?}"); + let mut server_handler = SyncServer::::new(StdInet::new(conn)); + // We'll serve the client as long as it sends valid requests. + // Note that this will effectively block other clients. + while let Ok(request) = server_handler.receive_request() { + // This matching is here to provide a way of controlling the while loop. + match handle_request(request, addr) { + Ok((h, b)) => if let Err(_) = send_response(h, b, &mut server_handler) { break; }, + Err(()) => break, + } + }; + println!("ended connection for {addr:?}"); + } +} + +fn handle_request( + req: Request, addr: SocketAddr +) -> Result<(ResponseHead, &'static [u8]), ()> { + let Request::HeadersOnly(headers) = req else { + // We will not handle POST requests with bodies - so let's tell the client to f*ck off. + return Err(()); + }; + + println!("req from {addr:?}: {headers}"); + + const OK_RESPONSE: &[u8] = b" + + +

Hello, world!

+

Hello from inferium.

+ +"; + const NOT_FOUND_RESPONSE: &[u8] = b" + + +

Not found

+

This page was not found

+ +"; + + // The URI can contain both path and parameters - so we're just getting the path here. + Ok(match (headers.method(), headers.uri().path()) { + // The ok response with our index page + (&Method::GET, "/") => (ResponseHead::new( + Status::Ok, + ProtocolVariant::HTTP1_0, + HashMap::from([ + (HeaderKey::SERVER, "inferium".parse().unwrap()), + (HeaderKey::CONTENT_LENGTH, OK_RESPONSE.len().into()), + ]) + ), OK_RESPONSE), + + // The not found response with an example not found page + _ => (ResponseHead::new( + Status::NotFound, + ProtocolVariant::HTTP1_0, + HashMap::from([ + (HeaderKey::SERVER, "inferium".parse().unwrap()), + (HeaderKey::CONTENT_LENGTH, NOT_FOUND_RESPONSE.len().into()), + ]) + ), NOT_FOUND_RESPONSE), + }) +} + +fn send_response( + response: ResponseHead, body: &[u8], conn: &mut SyncServer +)-> Result<(), ServerSendError> { + conn.send_response(&response)?; + // The send body can fail on an I/O error or if the content-length header does not match the + // actual sent length in this scenario. But we know that we have the correct length so with + // `.try_into().unwrap()` we tell inferium to convert the error and panic on (not so much) + // possible body length discrepancy. + conn.send_body_bytes(body).map_err(|e| e.try_into().unwrap()) +} diff --git a/lib/inferium/examples/start_here.rs b/lib/inferium/examples/start_here.rs new file mode 100644 index 0000000..2ce5dd2 --- /dev/null +++ b/lib/inferium/examples/start_here.rs @@ -0,0 +1,101 @@ +// Hello, and welcome to inferium. A performance-oriented small HTTP library written in Rust that +// keeps you (the user) in charge. + +// Let's first import some necessary things. + +// In inferium - HashMaps are used to store uri parameters and headers. +use std::collections::HashMap; +// TcpStream is needed if we want to connect to the internet. +use std::net::TcpStream; + +use inferium::{ + // The h1 module contains all the protocol specific things for HTTP/1.(0/1). 
+ h1::{ + // ProtocolVariant contains variants with the protocol versions supported in this module: + // - HTTP/1.1 + // - HTTP/1.0 + ProtocolVariant, + // RequestHead contains headers and the HTTP request headline (method, path, protocol). + RequestHead, + // Response here is a wrapper for a response that can have the following: + // - Headers only + // - Headers and body (with a known length or chunked) + // + // ! The body in the response is not yet collected. It is up to you if you wish to discard + // the connection or receive and collect the response body into some structure. + // + // The same things here go for the Request object which is nearly the same except that the + // headline contains protocol and status instead. + Response, + // Sync client is a stream wrapper that helps us keep track of the open connection and + // perform request/response operations. + // + // The server equivalent is SyncServer. + SyncClient + }, + // Header key contains various known header keys, but can also store arbitrary (unknown) header + // key in the OTHER variant. + HeaderKey, + Method, + // StdInet here is a stream wrapper that allows the TcpStream to be used by inferium. There is + // also a unix socket equivalent and some asynchronous io wrappers. + StdInet +}; + +fn main() { + // Let's first create a connection... nothing weird here. + let conn = StdInet::new(TcpStream::connect("zumepro.cz:80").unwrap()); + // And a client... + let mut client = SyncClient::::new(conn); + + // Now let's create a request to send + let to_send = RequestHead::new( + // The path here is parsed into an HTTP path object (which also supports parameters) + // I'm using HTTP/1.0 in this example as HTTP/1.1 automatically infers a compatibility with + // chunked encoding (which I'm not even trying to handle here). + Method::GET, "/".parse().unwrap(), ProtocolVariant::HTTP1_0, + HashMap::from([ + // All headers are HeaderKey - HeaderValue pairs. 
We can parse the header value into + // the desired object. + // + // Constructing arbitrary header key is supported using the OTHER variant - however + // it's not recommended as a violation of the HTTP protocol can happen. + // + // If you really want to construct an arbitrary header key - please carefully check + // that all of the symbols are valid. + (HeaderKey::USER_AGENT, "Mozilla/5.0 (inferium)".parse().unwrap()), + (HeaderKey::HOST, "zumepro.cz".parse().unwrap()), + ]) + ); + println!("----------> Sending\n\n{to_send}\n"); + // Let's send the request - this is pretty straightforward. + client.send_request(&to_send).unwrap(); + // As is receiving a response. + let response = client.receive_response().unwrap(); + + // Now (as we discussed earlier) - the response can have a body. + // In this example we'll try to handle a basic body with a known size. + let (header, body) = match response { + // Extracting the headers if no body is present. + Response::HeadersOnly(h) => (h, None), + // Extracting both the headers and the body if body is present. + Response::WithSizedBody((h, b)) => (h, Some(b)), + // We will not handle chunked responses in this example. + Response::WithChunkedBody((_, _)) => panic!(), + }; + + // inferium kindly provides a simple way to print the head of a request/response. + // It will be formatted pretty close to the actual protocol plaintext representation. + println!("----------< Received\n\n{header}\n"); + + // And finally... if we have a body, we'll print it. + if let Some(mut body) = body { + println!( + "----------< Body\n\n{:?}\n", + // A body is always returned in bytes. It's up to you to decode it however you see fit. + std::str::from_utf8(&mut body.recv_all().unwrap()).unwrap() + ); + } + + // And you're done. Come on... try to run it. 
+} diff --git a/lib/inferium/proc/.gitignore b/lib/inferium/proc/.gitignore new file mode 100644 index 0000000..2f7896d --- /dev/null +++ b/lib/inferium/proc/.gitignore @@ -0,0 +1 @@ +target/ diff --git a/lib/inferium/proc/Cargo.lock b/lib/inferium/proc/Cargo.lock new file mode 100644 index 0000000..ae45db4 --- /dev/null +++ b/lib/inferium/proc/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "proc" +version = "0.1.0" diff --git a/lib/inferium/proc/Cargo.toml b/lib/inferium/proc/Cargo.toml new file mode 100644 index 0000000..3577133 --- /dev/null +++ b/lib/inferium/proc/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "proc" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] diff --git a/lib/inferium/proc/src/lib.rs b/lib/inferium/proc/src/lib.rs new file mode 100644 index 0000000..ab59f8c --- /dev/null +++ b/lib/inferium/proc/src/lib.rs @@ -0,0 +1,175 @@ +use proc_macro::{Delimiter, TokenStream, TokenTree}; +extern crate proc_macro; + +#[proc_macro_derive(AutoimplHkeys)] +pub fn autoimpl_hkeys(item: TokenStream) -> TokenStream { + let mut iter = item.into_iter(); + let data = autoimpl_names_get_idents(&mut iter).expect("could not parse the name enum"); + + let obj_ident = data.obj_ident; + let mut cases_from = String::new(); + let mut cases_into = String::new(); + let mut cases_text = String::new(); + let mut cases_display = String::new(); + for entry in data.names { + let entry_transformed = autoimpl_hkeys_transform(&entry); + cases_from.push_str(&format!("{entry_transformed:?} => Self::{entry},\n")); + cases_into.push_str(&format!( + "{obj_ident}::{entry} => String::from({entry_transformed:?}),\n" + )); + cases_text.push_str(&format!("{obj_ident}::{entry} => b{entry_transformed:?},\n")); + cases_display.push_str( + &format!("{obj_ident}::{entry} => write!(f, {entry_transformed:?}),") + ); + } + + format!("impl From<&str> for 
{obj_ident} {{ + fn from(s: &str) -> Self {{ + match s.to_lowercase().as_str() {{ + {cases_from} + v => Self::OTHER(v.to_string()), + }} + }} + }} + + impl From<{obj_ident}> for String {{ + fn from(val: {obj_ident}) -> Self {{ + match val {{ + {cases_into} + {obj_ident}::OTHER(v) => v.clone(), + }} + }} + }} + + impl {obj_ident} {{ + pub(crate) fn text(&self) -> &[u8] {{ + match self {{ + {cases_text} + Self::OTHER(v) => v.as_bytes(), + }} + }} + }} + + impl std::fmt::Display for {obj_ident} {{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {{ + match self {{ + {cases_display} + Self::OTHER(v) => write!(f, \"{{v}}\"), + }} + }} + }}").parse().expect("could not parse the created autoimpl") +} + +#[proc_macro_derive(AutoimplMethods)] +pub fn autoimpl_methods(item: TokenStream) -> TokenStream { + let mut iter = item.into_iter(); + let data = autoimpl_names_get_idents(&mut iter).expect("could not parse the name enum"); + + let obj_ident = data.obj_ident; + let mut cases_from = String::new(); + let mut cases_into = String::new(); + let mut cases_text = String::new(); + for entry in data.names { + cases_from.push_str(&format!("{entry:?} => Ok(Self::{entry}),\n")); + cases_into.push_str(&format!("{obj_ident}::{entry} => {entry:?},\n")); + cases_text.push_str(&format!("{obj_ident}::{entry} => b{entry:?},\n")); + } + + format!("impl std::str::FromStr for {obj_ident} {{ + type Err = (); + + fn from_str(s: &str) -> Result {{ + match s {{ + {cases_from} + _ => Err(()), + }} + }} + }} + + impl From<{obj_ident}> for &'static str {{ + fn from(val: {obj_ident}) -> Self {{ + match val {{ + {cases_into} + }} + }} + }} + + impl {obj_ident} {{ + pub(crate) fn text(&self) -> &'static [u8] {{ + match self {{ + {cases_text} + }} + }} + }}").parse().expect("could not parse the created autoimpl") +} + +fn autoimpl_hkeys_transform(name: &str) -> String { + name.to_lowercase().replace("_", "-") +} + +macro_rules! 
get_tok { + (req $($rest:tt)+) => { + get_tok!($($rest)+)?.ok_or(())? + }; + + ($iter:ident == $t:ident $val:literal) => {{ + get_tok!($t $iter).map( + |v| v.map(|v| v.to_string() == $val) + ).unwrap_or_else(|_| Some(false)) + }}; + + ($iter:ident != $t:ident $val:literal) => { + get_tok!($t $iter).map( + |v| v.map(|v| v.to_string() != $val) + ).unwrap_or_else(|_| Some(false)) + }; + + ($t:ident $iter:ident) => {{ + match $iter.next() { + Some(TokenTree::$t(val)) => Ok(Some(val)), + None => Ok(None), + _ => Err(()), + } + }}; +} + +struct NameData { + obj_ident: String, + names: Vec, +} + +fn autoimpl_names_get_idents>( + iter: &mut I +) -> Result { + while let Some(is_enum_start) = get_tok!(iter == Ident "enum") { + if is_enum_start { + break; + } + } + let obj_ident = get_tok!(req Ident iter).to_string(); + let mut iter = { + let group = get_tok!(req Group iter); + if group.delimiter() != Delimiter::Brace { + return Err(()); + } + group.stream().into_iter() + }; + + let mut names = Vec::new(); + loop { + let Some(name) = get_tok!(Ident iter)? else { + break; + }; + if name.to_string() == "OTHER" { + break; + } + names.push(name.to_string()); + match get_tok!(iter == Punct ",") { + Some(true) => {}, + Some(false) => Err(())?, + None => break, + } + } + + Ok(NameData { obj_ident, names }) +} diff --git a/lib/inferium/src/body.rs b/lib/inferium/src/body.rs new file mode 100644 index 0000000..df486d1 --- /dev/null +++ b/lib/inferium/src/body.rs @@ -0,0 +1,1130 @@ +use crate::{ + io::{PrependableStream, ReaderError, ReaderValue, Receive, SyncReader}, + settings::BUF_SIZE_BODY +}; +#[cfg(feature = "async")] +use {crate::io::{AsyncReceive, AsyncSend, AsyncReader}, std::future::Future}; + +/// A chunked transport encoding incoming body. +/// +/// This object returns [`IncomingChunk`]s which can then be received on separately. 
+#[derive(Debug, PartialEq)] +pub struct ChunkedIn<'a, S> { + transport_stream: &'a mut S, + can_create_next: bool, +} + +/// An error that can be returned when trying to get information about a chunk +#[derive(Debug)] +pub enum ChunkedInReceiveError { + InvalidState, + InvalidLength, + InvalidTrailer, + NoData, + BufferOverflow, + IO(std::io::Error), +} + +impl From for ChunkedInReceiveError { + fn from(value: std::io::Error) -> Self { + Self::IO(value) + } +} + +/// Each incoming chunk has a prefixed length and when receiving tries to consume an `\r\n` at the +/// end. +#[cfg_attr(any(test, feature = "testing"), derive(PartialEq))] +#[derive(Debug)] +pub struct IncomingChunk<'a, S> { + /// The advertised size of this chunk + size_total: usize, + /// How many bytes of content to read in this chunk + remaining: usize, + buf: &'a mut [u8], + transport_stream: &'a mut S, + is_done: &'a mut bool, +} + +#[derive(Debug)] +pub enum ChunkFullReadError { + /// Returned if the trailer after the received chunk is not precisely `\r\n`. + InvalidTrailer, + IO(std::io::Error), +} + +impl From for ChunkFullReadError { + fn from(value: std::io::Error) -> Self { + Self::IO(value) + } +} + +impl<'a, S> ChunkedIn<'a, S> { + pub(crate) fn new(transport_stream: &'a mut S) -> Self { + Self { transport_stream, can_create_next: true } + } +} + +impl<'a, S> IncomingChunk<'a, S> { + pub(self) fn new( + transport_stream: &'a mut S, + buf: &'a mut [u8], + size_total: usize, + is_done: &'a mut bool, + ) -> Self { + Self { + transport_stream, + buf, + size_total, + is_done, + remaining: size_total + 2, + } + } +} + +impl IncomingChunk<'_, S> { + /// Get the advertised size of this chunk. + /// + /// Note that this does not have to match the actual received size on poll (the received size + /// can be smaller than this number). 
+ #[inline] + pub fn len(&self) -> &usize { + &self.size_total + } +} + +impl IncomingChunk<'_, PrependableStream> { + /// Receive the advertised length or less (if the transport stream ended prematurely). + pub fn recv_all(&mut self) -> Result, ChunkFullReadError> { + let mut res = Vec::with_capacity(self.size_total); + let mut last_chunk_size: usize = usize::MAX; + let mut to_prepend: &[u8] = &[]; + while self.remaining > 0 && last_chunk_size != 0 { + let chunk_size = self.transport_stream.recv(self.buf)?; + last_chunk_size = chunk_size; + let real_read = std::cmp::min(chunk_size, self.remaining); + res.extend_from_slice(&self.buf[0..real_read]); + to_prepend = &self.buf[real_read..std::cmp::max(chunk_size, self.remaining)]; + self.remaining = self.remaining.saturating_sub(chunk_size); + } + self.transport_stream.prepend_to_read(to_prepend); + *self.is_done = true; + if res.len() < 2 || (res.pop(), res.pop()) != (Some(b'\n'), Some(b'\r')) { + return Err(ChunkFullReadError::InvalidTrailer); + } + Ok(res) + } +} + +#[cfg(feature = "async")] +impl IncomingChunk<'_, PrependableStream> { + /// Receive the advertised length or less (if the transport stream ended prematurely). 
+ pub async fn recv_all_async(&mut self) -> Result, ChunkFullReadError> { + let mut res = Vec::with_capacity(self.size_total); + let mut last_chunk_size: usize = usize::MAX; + let mut to_prepend: &[u8] = &[]; + while self.remaining > 0 && last_chunk_size != 0 { + let chunk_size = self.transport_stream.recv(self.buf).await?; + last_chunk_size = chunk_size; + let real_read = std::cmp::min(chunk_size, self.remaining); + res.extend_from_slice(&self.buf[0..real_read]); + to_prepend = &self.buf[real_read..std::cmp::max(chunk_size, self.remaining)]; + self.remaining = self.remaining.saturating_sub(chunk_size); + } + self.transport_stream.prepend_to_read(to_prepend); + *self.is_done = true; + if res.len() < 2 || (res.pop(), res.pop()) != (Some(b'\n'), Some(b'\r')) { + return Err(ChunkFullReadError::InvalidTrailer); + } + Ok(res) + } +} + + +/// An incoming body with known (or at least advertised) length. +/// +/// This MUST NOT be a _chunked_ transport encoding and is usually denoted by the _content-length_ +/// header. +/// +/// Note that the advertised size may not correspond to the actually received data when polling the +/// connection. Inferium will just make sure that the received data does not exceed (or exceeds +/// just by at most one buffer) the advertised length. +/// +/// When polling using the [`Incoming::recv_all`] or [`Incoming::recv_all_async`] method, the data +/// exceeding the advertised length will be stripped to the desired length. +/// +/// The total data size after transfer can be less then the advertised length. Inferium will allow +/// you to receive the data, however your implementation may wish to check for data length +/// discrepancies and react accordingly. +#[derive(Debug, PartialEq)] +pub struct SizedIn<'a, S> { + size_total: usize, + remaining: usize, + transport_stream: &'a mut S, +} + +impl<'a, S> SizedIn<'a, S> { + /// Create a new body receiver with a known maximum size. 
+ pub(crate) fn new(transport_stream: &'a mut S, max_length: usize) -> Self { + Self { transport_stream, size_total: max_length, remaining: max_length } + } +} + +impl SizedIn<'_, PrependableStream> { + pub(crate) fn prepend_to_stream(&mut self, to_prenend: &[u8]) { + self.transport_stream.prepend_to_read(to_prenend); + } +} + +/// An outgoing body with known length. +/// +/// The underlying `data_source` will be polled until the desired length is reached or until the +/// iterator returns [`None`]. +/// +/// The data is continuously flushed to the `transport_stream` with the same segmentation as the +/// data returned by the `data_source` iterator. +pub(crate) struct SizedOut<'a, D, S> { + size_total: usize, + remaining: usize, + data_source: &'a mut D, + transport_stream: &'a mut S, +} + +impl<'a, D, S> SizedOut<'a, D, S> { + /// Create a new body sender with a known maximum size. + pub fn new(data_source: &'a mut D, transport_stream: &'a mut S, length: usize) -> Self { + Self { data_source, transport_stream, size_total: length, remaining: length } + } +} + +#[derive(Debug, PartialEq)] +enum BodyBuf<'a> { + Internal(Vec), + External(&'a mut [u8]), + Unitialized, +} + +/// Incoming HTTP body. +/// +/// This is returned upon a successful header transfer and parsing. The inner type denotes the type +/// of received body. Since the body can have an unknown size on transfer - this [`Incoming`] type +/// does allow to statically alter the read behavior depending on the transfer encoding. 
+#[derive(Debug, PartialEq)] +pub struct Incoming<'a, T> { + inner: T, + buf: BodyBuf<'a>, + buf_size: usize, + exhausted: &'a mut bool, +} + +impl <'a, 'b, S: Receive> Incoming<'b, SizedIn<'a, S>> { + pub(crate) fn new(inner: SizedIn<'a, S>, exhausted: &'b mut bool) -> Self { + Self { inner, buf: BodyBuf::Unitialized, buf_size: BUF_SIZE_BODY, exhausted } + } + + pub(crate) fn with_buf_size(inner: SizedIn<'a, S>, buf_size: usize, exhausted: &'b mut bool) -> Self { + Self { inner, buf: BodyBuf::Unitialized, buf_size, exhausted } + } + + pub(crate) fn with_buf(inner: SizedIn<'a, S>, buf: &'b mut [u8], exhausted: &'b mut bool) -> Self { + let buf_size = buf.len(); + Self { inner, buf: BodyBuf::External(buf), buf_size, exhausted } + } +} + +impl<'a, 'b, S: Receive> Incoming<'b, ChunkedIn<'a, S>> { + pub(crate) fn new(inner: ChunkedIn<'a, S>, exhausted: &'b mut bool) -> Self { + Self { inner, buf: BodyBuf::Unitialized, buf_size: BUF_SIZE_BODY, exhausted } + } + + pub(crate) fn with_buf_size( + inner: ChunkedIn<'a, S>, buf_size: usize, exhausted: &'b mut bool + ) -> Self { + Self { inner, buf: BodyBuf::Unitialized, buf_size, exhausted } + } + + pub(crate) fn with_buf(inner: ChunkedIn<'a, S>, buf: &'b mut [u8], exhausted: &'b mut bool) -> Self { + let buf_size = buf.len(); + Self { inner, buf: BodyBuf::External(buf), buf_size, exhausted } + } +} + +#[cfg(feature = "async")] +impl<'a, 'b, S: AsyncReceive> Incoming<'b, SizedIn<'a, S>> { + /// Create a new body receiver with a known maximum size. 
+ pub(crate) fn new_async(inner: SizedIn<'a, S>, exhausted: &'b mut bool) -> Self { + Self { inner, buf: BodyBuf::Unitialized, buf_size: BUF_SIZE_BODY, exhausted } + } + + pub(crate) fn with_buf_size_async( + inner: SizedIn<'a, S>, buf_size: usize, exhausted: &'b mut bool + ) -> Self { + Self { inner, buf: BodyBuf::Unitialized, buf_size, exhausted } + } +} + +#[cfg(feature = "async")] +impl<'a, 'b, S: AsyncReceive> Incoming<'b, ChunkedIn<'a, S>> { + /// Create a new body receiver with a chunked encoding. + pub(crate) fn new_async(inner: ChunkedIn<'a, S>, exhausted: &'b mut bool) -> Self { + Self { inner, buf: BodyBuf::Unitialized, buf_size: BUF_SIZE_BODY, exhausted } + } + + pub(crate) fn with_buf_size_async( + inner: ChunkedIn<'a, S>, buf_size: usize, exhausted: &'b mut bool + ) -> Self { + Self { inner, buf: BodyBuf::Unitialized, buf_size, exhausted } + } +} + +#[cfg(feature = "async")] +impl <'a, 'b, S: AsyncReceive> Incoming<'b, SizedIn<'a, S>> { + pub(crate) fn with_buf_async( + inner: SizedIn<'a, S>, buf: &'b mut [u8], exhausted: &'b mut bool + ) -> Self { + let buf_size = buf.len(); + Self { inner, buf: BodyBuf::External(buf), buf_size, exhausted } + } +} + +#[cfg(feature = "async")] +impl <'a, 'b, S: AsyncReceive> Incoming<'b, ChunkedIn<'a, S>> { + pub(crate) fn with_buf_async( + inner: ChunkedIn<'a, S>, buf: &'b mut [u8], exhausted: &'b mut bool + ) -> Self { + let buf_size = buf.len(); + Self { inner, buf: BodyBuf::External(buf), buf_size, exhausted } + } +} + +macro_rules! get_buf { + ($self: ident) => {{ + match $self.buf { + BodyBuf::Internal(ref mut val) => val.as_mut_slice(), + BodyBuf::External(ref mut val) => val, + BodyBuf::Unitialized => get_buf!(@make_buf $self), + } + }}; + + (@make_buf $self: ident) => {{ + $self.buf = BodyBuf::Internal(vec![0_u8; $self.buf_size]); + let BodyBuf::Internal(ref mut buf) = $self.buf else { + unreachable!(); + }; + buf.as_mut_slice() + }}; +} + +impl Incoming<'_, SizedIn<'_, S>> { + /// Get the advertised length. 
+ /// + /// Note that when reading from this resource - it can yield less data than is advertised via + /// this method. + #[inline] + pub fn len(&self) -> &usize { + &self.inner.size_total + } +} + +const PAT_CRLF: &[u8] = b"\r\n"; +const JUT_CRLF: &[usize] = &[0, 0]; + +enum ReceiveUntilError { + NoData, + BufferOverflow, + IO(std::io::Error), +} + +enum ConsumeNError { + IO(std::io::Error), +} + +impl From for ChunkedInReceiveError { + fn from(value: ReceiveUntilError) -> Self { + match value { + ReceiveUntilError::BufferOverflow => Self::BufferOverflow, + ReceiveUntilError::NoData => Self::NoData, + ReceiveUntilError::IO(e) => Self::IO(e), + } + } +} + +impl From for ChunkedInReceiveError { + fn from(value: ConsumeNError) -> Self { + match value { + ConsumeNError::IO(e) => Self::IO(e), + } + } +} + +impl From for ConsumeNError { + fn from(value: std::io::Error) -> Self { + Self::IO(value) + } +} + +fn sync_receive_until_exact<'a, S: Receive>( + buf: &'a mut [u8], pat: &[u8], jumptable: &[usize], transport_stream: &mut PrependableStream +) -> Result<&'a [u8], ReceiveUntilError> { + let mut reader = SyncReader::new(transport_stream); + match reader.recv_until(pat, jumptable, buf) { + Ok(ReaderValue::ExactRead { up_to_delimiter: u }) => { + Ok(u) + }, + Ok(ReaderValue::LeakyRead { up_to_delimiter: u, rest: r }) => Ok({ + transport_stream.prepend_to_read(r); + u + }), + Err(ReaderError::NoData) => Err(ReceiveUntilError::NoData), + Err(ReaderError::BufferOverflow) => Err(ReceiveUntilError::BufferOverflow), + Err(ReaderError::IO(e)) => Err(ReceiveUntilError::IO(e)) + } +} + +/// Consume the length of the buffer (or until an error is encountered). 
+fn sync_consume_n( + buf: &mut [u8], transport_stream: &mut S, +) -> Result<(), ConsumeNError> { + let mut ptr = 0; + let mut last_size = usize::MAX; + while ptr < buf.len() && last_size != 0 { + last_size = transport_stream.recv(&mut buf[ptr..])?; + ptr += last_size; + } + Ok(()) +} + +#[cfg(feature = "async")] +async fn async_receive_until_exact<'a, S: AsyncReceive>( + buf: &'a mut [u8], pat: &[u8], jumptable: &[usize], transport_stream: &mut PrependableStream +) -> Result<&'a [u8], ReceiveUntilError> { + let mut reader = AsyncReader::new(transport_stream); + match reader.recv_until(pat, jumptable, buf).await { + Ok(ReaderValue::ExactRead { up_to_delimiter: u }) => Ok(u), + Ok(ReaderValue::LeakyRead { up_to_delimiter: u, rest: r }) => Ok({ + transport_stream.prepend_to_read(r); + u + }), + Err(ReaderError::NoData) => Err(ReceiveUntilError::NoData), + Err(ReaderError::BufferOverflow) => Err(ReceiveUntilError::BufferOverflow), + Err(ReaderError::IO(e)) => Err(ReceiveUntilError::IO(e)) + } +} + +/// Consume the length of the buffer (or until an error is encountered). +#[cfg(feature = "async")] +async fn async_consume_n( + buf: &mut [u8], transport_stream: &mut S, +) -> Result<(), ConsumeNError> { + let mut ptr = 0; + while ptr < buf.len() { + ptr += transport_stream.recv(&mut buf[ptr..]).await?; + } + Ok(()) +} + +impl Incoming<'_, ChunkedIn<'_, PrependableStream>> { + /// Synchronously receive headers for a single chunk returning the chunk object or an error. + /// + /// The received chunk can be received on directly (as it contains a mutable borrow of the + /// transport stream). 
+ pub fn get_chunk( + &mut self + ) -> Result>>, ChunkedInReceiveError> { + if *self.exhausted { + return Ok(None); + } + if !self.inner.can_create_next { + return Err(ChunkedInReceiveError::InvalidState); + } + let buf = get_buf!(self); + let preface = sync_receive_until_exact( + buf, PAT_CRLF, JUT_CRLF, self.inner.transport_stream + )?; + let preface: usize = usize::from_str_radix(std::str::from_utf8(preface) + .map_err(|_| ChunkedInReceiveError::InvalidLength)?, + 16, + ).map_err(|_| ChunkedInReceiveError::InvalidLength)?; + self.inner.can_create_next = false; + if preface == 0 { + *self.exhausted = true; + sync_consume_n(&mut buf[..2], self.inner.transport_stream)?; + if buf[..2] != *b"\r\n" { return Err(ChunkedInReceiveError::InvalidTrailer); } + return Ok(None); + } + Ok(Some(IncomingChunk::new( + self.inner.transport_stream, + buf, + preface, + &mut self.inner.can_create_next + ))) + } +} + +#[cfg(feature = "async")] +impl Incoming<'_, ChunkedIn<'_, PrependableStream>> { + /// Asynchronously receive headers for a single chunk returning the chunk object or an error. + /// + /// The received chunk can be received on directly (as it contains a mutable borrow of the + /// transport stream). 
+ pub async fn get_chunk_async( + &mut self + ) -> Result>>, ChunkedInReceiveError> { + if *self.exhausted { + return Ok(None); + } + if !self.inner.can_create_next { + return Err(ChunkedInReceiveError::InvalidState); + } + let buf = get_buf!(self); + let preface = async_receive_until_exact( + buf, PAT_CRLF, JUT_CRLF, self.inner.transport_stream + ).await?; + let preface: usize = usize::from_str_radix(std::str::from_utf8(preface) + .map_err(|_| ChunkedInReceiveError::InvalidLength)?, + 16, + ).map_err(|_| ChunkedInReceiveError::InvalidLength)?; + self.inner.can_create_next = false; + if preface == 0 { + *self.exhausted = true; + async_consume_n(&mut buf[..2], self.inner.transport_stream).await?; + if buf[..2] != *b"\r\n" { return Err(ChunkedInReceiveError::InvalidTrailer); } + return Ok(None); + } + Ok(Some(IncomingChunk::new( + self.inner.transport_stream, + buf, + preface, + &mut self.inner.can_create_next + ))) + } +} + +impl Incoming<'_, SizedIn<'_, PrependableStream>> { + /// Receive the advertised length (OR LESS). + /// + /// This will yield a vector with maximum size of the advertised content-length. + /// But it's size can possibly be smaller. 
+ pub fn recv_all(&mut self) -> Result, std::io::Error> { + let mut res = Vec::new(); + let buf = get_buf!(self); + let mut last_chunk_size: usize = usize::MAX; + let mut to_prepend: &[u8] = &[]; + while self.inner.remaining > 0 && last_chunk_size != 0 { + let chunk_size = self.inner.transport_stream.recv(buf)?; + last_chunk_size = chunk_size; + let real_read = std::cmp::min(chunk_size, self.inner.remaining); + res.extend_from_slice(&buf[0..real_read]); + to_prepend = &buf[real_read..std::cmp::max(chunk_size, self.inner.remaining)]; + self.inner.remaining = self.inner.remaining.saturating_sub(chunk_size); + } + self.inner.prepend_to_stream(to_prepend); + *self.exhausted = true; + Ok(res) + } +} + +impl Incoming<'_, SizedIn<'_, PrependableStream>> { + /// Try to synchronously receive a chunk of a size given by `buf`'s size. + /// The values returned represent the actual read byte count. The returned value is represented + /// as `(, )`. The advised size is the number of bytes the caller + /// of this function is advised to interpret. The number of affected bytes is the number of + /// bytes actually written to the buffer. The caller is advised to ignore the bytes in the + /// range between the advised size and affected bytes. + /// + /// # Caution + /// This operation is exposed for performance reasons. If you don't have a reason to use this + /// exact function - consider using a differend read function. + /// + /// # Note + /// This function performs a raw operation on the underlying stream using the passed buffer and + /// thus the buffer can receive "junk" as the bytes beyond the advised size are prepended to + /// the transport stream for continuous successive reads. This "junk" should (nearly always) be + /// ignored. + /// + /// # Errors + /// Will fail if unable to read from the `transport_stream`. + /// `transport_stream`, updating the inner counter for later operations. 
+ pub fn recv(&mut self, buf: &mut [u8]) -> Result<(usize, usize), std::io::Error> { + if self.inner.remaining == 0 { + return Ok((0, 0)); + } + let chunk_size = self.inner.transport_stream.recv(buf)?; + let advised_size = self.inner.remaining; + self.inner.remaining = self.inner.remaining.saturating_sub(chunk_size); + if self.inner.remaining == 0 { + *self.exhausted = true; + self.inner.prepend_to_stream(&buf[advised_size..chunk_size]); + } + Ok((advised_size, chunk_size)) + } +} + +#[cfg(feature = "async")] +impl Incoming<'_, SizedIn<'_, PrependableStream>> { + /// Asynchronously receive the advertised length (OR LESS). + /// + /// This will yield a vector with maximum size of the advertised content-length. + /// But it's size can possibly be smaller. + pub async fn recv_all_async(&mut self) -> Result, std::io::Error> { + let mut res = Vec::new(); + let buf = get_buf!(self); + let mut last_chunk_size: usize = usize::MAX; + let mut to_prepend: &[u8] = &[]; + while self.inner.remaining > 0 && last_chunk_size != 0 { + let chunk_size = self.inner.transport_stream.recv(buf).await?; + last_chunk_size = chunk_size; + let real_read = std::cmp::min(chunk_size, self.inner.remaining); + res.extend_from_slice(&buf[0..real_read]); + to_prepend = &buf[real_read..std::cmp::max(chunk_size, self.inner.remaining)]; + self.inner.remaining = self.inner.remaining.saturating_sub(chunk_size); + } + self.inner.prepend_to_stream(to_prepend); + *self.exhausted = true; + Ok(res) + } +} + +#[cfg(feature = "async")] +impl Incoming<'_, SizedIn<'_, PrependableStream>> { + /// Try to asynchronously receive a chunk of a size given by `buf`'s size. + /// The values returned represent the actual read byte count. The returned value is represented + /// as `(, )`. The advised size is the number of bytes the caller + /// of this function is advised to interpret. The number of affected bytes is the number of + /// bytes actually written to the buffer. 
The caller is advised to ignore the bytes in the + /// range between the advised size and affected bytes. + /// + /// # Caution + /// This operation is exposed for performance reasons. If you don't have a reason to use this + /// exact function - consider using a differend read function. + /// + /// # Note + /// This function performs a raw operation on the underlying stream using the passed buffer and + /// thus the buffer can receive "junk" as the bytes beyond the advised size are prepended to + /// the transport stream for continuous successive reads. This "junk" should (nearly always) be + /// ignored. + /// + /// # Errors + /// Will fail if unable to read from the `transport_stream`. + /// `transport_stream`, updating the inner counter for later operations. + async fn recv_async<'a>( + &'a mut self, buf: &'a mut [u8] + ) -> Result<(usize, usize), std::io::Error> { + if self.inner.remaining == 0 { + return Ok((0, 0)); + } + let chunk_size = self.inner.transport_stream.recv(buf).await?; + let advised_size = self.inner.remaining; + self.inner.remaining = self.inner.remaining.saturating_sub(chunk_size); + if self.inner.remaining == 0 { + *self.exhausted = true; + self.inner.prepend_to_stream(&buf[advised_size..chunk_size]); + } + Ok((advised_size, chunk_size)) + } +} + +/// Outgoing HTTP body. +/// +/// This can be passed to a connection handler to automatically poll and send the data via the +/// underlying `transport_stream`. +pub struct Outgoing { + inner: T, +} + +impl<'a, D: Iterator, S> Outgoing> { + pub fn new(inner: SizedOut<'a, D, S>) -> Self { + Self { inner } + } +} + +#[cfg(feature = "async")] +impl<'a, F: Future, D: Iterator, S> Outgoing> { + pub fn new_async(inner: SizedOut<'a, D, S>) -> Self { + Self { inner } + } +} + +impl Outgoing> { + /// Target length to send. + /// + /// The data resource will be polled until this length is reached or until the data iterator + /// stream returns a [`None`]. 
+ #[inline] + pub fn len(&self) -> &usize { + &self.inner.size_total + } +} + +/// Outgoing body send error +#[derive(Debug)] +pub enum SendError { + /// The underlying data source has returned an unexpected number of bytes. + LengthDiscrepancy, + IO(std::io::Error), +} + +impl From for SendError { + fn from(value: std::io::Error) -> Self { + Self::IO(value) + } +} + +enum ShouldContSend { + Continue, + End, +} + +fn handle_sync_chunk( + chunk: &[u8], + remaining_size: &mut usize, + transport_stream: &mut S, +) -> Result { + let len_to_send = std::cmp::min(*remaining_size, chunk.len()); + *remaining_size -= len_to_send; + let mut ptr = 0_usize; + while ptr < len_to_send { + ptr += transport_stream.send(&chunk[ptr..len_to_send])?; + } + Ok(match *remaining_size { + 0 => ShouldContSend::End, + _ => ShouldContSend::Continue, + }) +} + +impl<'a, D: Iterator, S: super::io::Send> Outgoing> { + /// Send the advertised length or less. + /// + /// This will call `next` on the `data_source` iterator until the total length yielded is equal + /// to or greater than the `total_size` or until the `data_source` iterator has yielded a + /// [`None`]. + pub fn send_all(&mut self) -> Result<(), SendError> { + for chunk in self.inner.data_source.by_ref() { match handle_sync_chunk( + chunk, + &mut self.inner.remaining, + self.inner.transport_stream + )? 
{ + ShouldContSend::Continue => {}, + ShouldContSend::End => break, + }} + if self.inner.remaining > 0 { + return Err(SendError::LengthDiscrepancy); + } + Ok(()) + } +} + +#[cfg(feature = "async")] +async fn handle_async_chunk<'a, S: AsyncSend>( + chunk: impl Future, + remaining_size: &mut usize, + transport_stream: &mut S, +) -> Result { + let chunk = chunk.await; + let len_to_send = std::cmp::min(*remaining_size, chunk.len()); + *remaining_size -= len_to_send; + let mut ptr = 0_usize; + while ptr < len_to_send { + ptr += transport_stream.send(&chunk[ptr..len_to_send]).await?; + } + Ok(match *remaining_size { + 0 => ShouldContSend::End, + _ => ShouldContSend::Continue, + }) +} + +#[cfg(feature = "async")] +impl<'a, F: Future, D: Iterator, S: super::io::AsyncSend> + Outgoing> +{ + /// Asynchronously send the advertised length or less. + /// + /// This will call `next` and `await` on the `data_source` iterator until the total length + /// summed is equal to or greater than the `total_size` or until the `data_source` iterator + /// has yielded a [`None`]. + pub async fn send_all_async(&mut self) -> Result<(), SendError> { + for chunk in self.inner.data_source.by_ref() { match handle_async_chunk( + chunk, + &mut self.inner.remaining, + self.inner.transport_stream + ).await? 
{ + ShouldContSend::Continue => {}, + ShouldContSend::End => break, + }} + Ok(()) + } +} + +#[cfg(test)] +mod local_receive_sync { + use crate::{body::{Incoming, SizedIn}, io::{Receive, TestSyncStream, PrependableStream}}; + + #[test] + fn split_to_chunks() { + let src = "testtesttesttest".into(); + let mut src: TestSyncStream<4> = TestSyncStream::new(&src); + let mut res = Vec::new(); + let mut buf = [0_u8; 50]; + loop { + let read = src.recv(&mut buf).unwrap(); + if read == 0 { break; } + res.push(Vec::from(&buf[0..read])); + } + + assert_eq!(res, vec!["test".as_bytes(); 4]); + } + + #[test] + fn full_recv_till_end() { + let src = "testtesttesttest".into(); + let mut src = PrependableStream::new(TestSyncStream::<4>::new(&src)); + let mut exhausted = false; + let mut body = Incoming::>::new(SizedIn::new(&mut src, 16), &mut exhausted); + let res = body.recv_all(); + assert_eq!(res.unwrap(), "testtesttesttest".as_bytes().to_vec()); + } + + #[test] + fn partial_recv_till_end() { + let src = "testtesttesttest".into(); + let mut src = PrependableStream::new(TestSyncStream::<4>::new(&src)); + let mut exhausted = false; + let mut body = Incoming::>::new(SizedIn::new(&mut src, 8), &mut exhausted); + let res = body.recv_all(); + assert_eq!(res.unwrap(), "testtest".as_bytes().to_vec()); + } + + #[test] + fn overpromising_recv_till_end() { + let src = "test".into(); + let mut src = PrependableStream::new(TestSyncStream::<4>::new(&src)); + let mut exhausted = false; + let mut body = Incoming::>::new(SizedIn::new(&mut src, 8), &mut exhausted); + let res = body.recv_all(); + assert_eq!(res.unwrap(), "test".as_bytes().to_vec()); + } + + #[test] + fn check_exhaust() { + let src = "testtesttesttest".into(); + let mut src = PrependableStream::new(TestSyncStream::<4>::new(&src)); + let mut exhausted = false; + let mut body = Incoming::>::new(SizedIn::new(&mut src, 8), &mut exhausted); + body.recv_all().unwrap(); + assert_eq!(exhausted, true); + } +} + +#[cfg(test)] +mod 
local_send_sync { + use crate::io::Send; + use super::{Outgoing, SizedOut}; + + #[derive(Default)] + struct TestStream { + inner: Vec, + } + + impl Send for TestStream { + fn send(&mut self, buf: &[u8]) -> Result { + self.inner.extend_from_slice(&buf[0..std::cmp::min(CHUNK_SIZE, buf.len())]); + Ok(CHUNK_SIZE) + } + } + + #[test] + fn concat_chunks_with_exact_length() { + let data = vec!["hello", "world", "how", "are", "you"]; + let mut src = data.iter().map(|v| v.as_bytes()); + let mut stream = TestStream::<4>::default(); + Outgoing::new(SizedOut::new(&mut src, &mut stream, 19)).send_all().unwrap(); + assert_eq!(std::str::from_utf8(&stream.inner), Ok("helloworldhowareyou")) + } + + #[test] + fn partial_concat_chunks() { + let data = vec!["hello", "world", "how", "are", "you"]; + let mut src = data.iter().map(|v| v.as_bytes()); + let mut stream = TestStream::<3>::default(); + Outgoing::new(SizedOut::new(&mut src, &mut stream, 16)).send_all().unwrap(); + assert_eq!(std::str::from_utf8(&stream.inner), Ok("helloworldhoware")) + } + + #[test] + fn overpromising_concat_chunks() { + let data = vec!["hello", "world", "how", "are", "you"]; + let mut src = data.iter().map(|v| v.as_bytes()); + let mut stream = TestStream::<19>::default(); + match Outgoing::new(SizedOut::new(&mut src, &mut stream, 20)).send_all() { + Err(crate::body::SendError::LengthDiscrepancy) => assert!(true), + _ => assert!(false), + } + assert_eq!(std::str::from_utf8(&stream.inner), Ok("helloworldhowareyou")) + } +} + +#[cfg(all(test, feature = "async", feature = "tokio-full"))] +mod local_receive_async { + use super::{Incoming, SizedIn}; + use crate::io::{AsyncReceive, TestAsyncStream, PrependableStream}; + use tokio::test; + + #[test] + async fn split_to_chunks() { + let src = "testtesttesttest".into(); + let mut src: TestAsyncStream<4> = TestAsyncStream::new(&src); + let mut res = Vec::new(); + let mut buf = [0_u8; 50]; + loop { + let read = src.recv(&mut buf).await.unwrap(); + if read == 0 { break; } 
+ res.push(Vec::from(&buf[0..read])); + } + + assert_eq!(res, vec!["test".as_bytes(); 4]); + } + + #[test] + async fn full_recv_till_end() { + let src = "testtesttesttest".into(); + let mut src = PrependableStream::new(TestAsyncStream::<11>::new(&src)); + let mut exhausted = false; + let mut body = Incoming::>::new_async(SizedIn::new(&mut src, 16), &mut exhausted); + let res = body.recv_all_async().await; + assert_eq!(res.unwrap(), "testtesttesttest".as_bytes().to_vec()); + } + + #[test] + async fn partial_recv_till_end() { + let src = "testtesttesttest".into(); + let mut src = PrependableStream::new(TestAsyncStream::<7>::new(&src)); + let mut exhausted = false; + let mut body = Incoming::>::new_async(SizedIn::new(&mut src, 8), &mut exhausted); + let res = body.recv_all_async().await; + assert_eq!(res.unwrap(), "testtest".as_bytes().to_vec()); + } + + #[test] + async fn overpromising_recv_till_end() { + let src = "test".into(); + let mut src = PrependableStream::new(TestAsyncStream::<3>::new(&src)); + let mut exhausted = false; + let mut body = Incoming::>::new_async(SizedIn::new(&mut src, 8), &mut exhausted); + let res = body.recv_all_async().await; + assert_eq!(res.unwrap(), "test".as_bytes().to_vec()); + } + + #[test] + async fn check_exhaust() { + let src = "test".into(); + let mut src = PrependableStream::new(TestAsyncStream::<3>::new(&src)); + let mut exhausted = false; + let mut body = Incoming::>::new_async(SizedIn::new(&mut src, 8), &mut exhausted); + body.recv_all_async().await.unwrap(); + assert_eq!(exhausted, true); + } +} + +#[cfg(all(test, feature = "async", feature = "tokio-full"))] +mod local_send_async { + use tokio::{time::sleep, test}; + use std::{time::Duration, future::Future, task::{Poll, Context}, pin::Pin}; + use crate::io::AsyncSend; + use super::{Outgoing, SizedOut}; + + #[derive(Default)] + struct TestStream { + inner: Vec, + } + + impl AsyncSend for + TestStream + { + async fn send<'a>(&'a mut self, buf: &'a [u8]) -> Result { + 
sleep(Duration::from_millis(TIMEOUT)).await; + self.inner.extend_from_slice(&buf[0..std::cmp::min(CHUNK_SIZE, buf.len())]); + Ok(CHUNK_SIZE) + } + } + + struct Resolved { inner: T } + impl From for Resolved { + fn from(inner: T) -> Self { + Self { inner } + } + } + + impl Future for Resolved { + type Output = T; + + fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { + Poll::Ready(self.inner.clone()) + } + } + + #[test] + async fn concat_chunks_with_exact_length() { + let data = vec!["hello", "world", "how", "are", "you"]; + let src = data.iter().map(|v| Resolved::from(v.as_bytes())); + let mut src = src.collect::>().into_iter(); + let mut stream = TestStream::<10, 1>::default(); + Outgoing::new_async(SizedOut::new(&mut src, &mut stream, 19)).send_all_async() + .await.unwrap(); + assert_eq!(std::str::from_utf8(&stream.inner), Ok("helloworldhowareyou")) + } + + #[test] + async fn partial_concat_chunks() { + let data = vec!["hello", "world", "how", "are", "you"]; + let src = data.iter().map(|v| Resolved::from(v.as_bytes())); + let mut src = src.collect::>().into_iter(); + let mut stream = TestStream::<11, 1>::default(); + Outgoing::new_async(SizedOut::new(&mut src, &mut stream, 16)).send_all_async() + .await.unwrap(); + assert_eq!(std::str::from_utf8(&stream.inner), Ok("helloworldhoware")) + } + + #[test] + async fn overpromising_concat_chunks() { + let data = vec!["hello", "world", "how", "are", "you"]; + let src = data.iter().map(|v| Resolved::from(v.as_bytes())); + let mut src = src.collect::>().into_iter(); + let mut stream = TestStream::<13, 1>::default(); + Outgoing::new_async(SizedOut::new(&mut src, &mut stream, 20)).send_all_async() + .await.unwrap(); + assert_eq!(std::str::from_utf8(&stream.inner), Ok("helloworldhowareyou")) + } +} + +#[cfg(test)] +mod body_end_prepend { + use crate::io::{TestSyncStream, PrependableStream, Receive}; + use super::{Incoming, SizedIn}; + + #[test] + fn no_prepend() { + let src = "Hello, world!".into(); + let mut stream 
= PrependableStream::new(TestSyncStream::<5>::new(&src)); + let mut exhausted = false; + let mut body = Incoming::>::new(SizedIn::new(&mut stream, 5), &mut exhausted); + assert_eq!(body.recv_all().unwrap(), b"Hello"); + assert_eq!(exhausted, true); + let mut buf = [0_u8; 20]; + assert_eq!(stream.recv(&mut buf).unwrap(), 5); + assert_eq!(&buf[..6], b", wor\0"); + } + + #[test] + fn prepend() { + let src = "Hello, world!".into(); + let mut stream = PrependableStream::new(TestSyncStream::<7>::new(&src)); + let mut exhausted = false; + let mut body = Incoming::>::new(SizedIn::new(&mut stream, 5), &mut exhausted); + assert_eq!(body.recv_all().unwrap(), b"Hello"); + assert_eq!(exhausted, true); + let mut buf = [0_u8; 20]; + assert_eq!(stream.recv(&mut buf).unwrap(), 2); + assert_eq!(stream.recv(&mut buf[2..]).unwrap(), 6); + assert_eq!(&buf[..9], b", world!\0"); + } + + #[test] + fn prepend_raw_read() { + let src = "Hello, world!".into(); + let mut stream = PrependableStream::new(TestSyncStream::<7>::new(&src)); + let mut exhausted = false; + let mut body = Incoming::>::new(SizedIn::new(&mut stream, 5), &mut exhausted); + let mut buf = [0_u8; 20]; + assert_eq!(body.recv(&mut buf).unwrap(), (5, 7)); + assert_eq!(&buf[..8], b"Hello, \0"); + assert_eq!(exhausted, true); + assert_eq!(stream.recv(&mut buf).unwrap(), 2); + assert_eq!(stream.recv(&mut buf[2..]).unwrap(), 6); + assert_eq!(&buf[..9], b", world!\0"); + } + + #[test] + fn prepend_exhaust() { + let src = "Hello, world!".into(); + let mut stream = PrependableStream::new(TestSyncStream::<7>::new(&src)); + let mut exhausted = false; + let mut body = Incoming::>::new(SizedIn::new(&mut stream, 5), &mut exhausted); + assert_eq!(body.recv_all().unwrap(), b"Hello"); + assert_eq!(exhausted, true); + let mut buf = [0_u8; 20]; + assert_eq!(stream.recv(&mut buf).unwrap(), 2); + assert_eq!(stream.recv(&mut buf[2..]).unwrap(), 6); + assert_eq!(&buf[..9], b", world!\0"); + assert_eq!(stream.recv(&mut buf).unwrap(), 0); + 
assert_eq!(&buf[..1], b","); + } + + #[test] + fn prepend_exhaust_raw_read() { + let src = "Hello, world!".into(); + let mut stream = PrependableStream::new(TestSyncStream::<7>::new(&src)); + let mut exhausted = false; + let mut body = Incoming::>::new(SizedIn::new(&mut stream, 5), &mut exhausted); + let mut buf = [0_u8; 20]; + assert_eq!(body.recv(&mut buf).unwrap(), (5, 7)); + assert_eq!(&buf[..8], b"Hello, \0"); + assert_eq!(exhausted, true); + assert_eq!(stream.recv(&mut buf).unwrap(), 2); + assert_eq!(stream.recv(&mut buf[2..]).unwrap(), 6); + assert_eq!(&buf[..9], b", world!\0"); + assert_eq!(stream.recv(&mut buf).unwrap(), 0); + assert_eq!(&buf[..1], b","); + } +} + +#[cfg(test)] +mod local_receive_chunked_sync { + use crate::{body::ChunkedInReceiveError, io::{PrependableStream, Receive, TestSyncStream}}; + use super::{ChunkedIn, Incoming}; + + #[test] + fn valid_singlechunk() { + let src = b"D\r\nHello, world!\r\n0\r\n\r\n".to_vec(); + let mut stream = PrependableStream::new(TestSyncStream::<3>::new(&src)); + let mut exhausted = false; + let mut body = Incoming::>::new(ChunkedIn::new(&mut stream), &mut exhausted); + let mut chunk = body.get_chunk().unwrap().unwrap(); + assert_eq!(chunk.recv_all().unwrap(), b"Hello, world!"); + assert_eq!(body.get_chunk().unwrap(), None); + drop(body); + let mut buf = [0_u8; 5]; + assert_eq!(stream.recv(&mut buf).unwrap(), 0); + } + + #[test] + fn valid_multichunk() { + let src = b"D\r\nHello, world!\r\n1\r\n \r\n3\r\nHow\r\n9\r\n are you?\r\n0\r\n\r\n" + .to_vec(); + let mut stream = PrependableStream::new(TestSyncStream::<3>::new(&src)); + let mut exhausted = false; + let mut body = Incoming::>::new(ChunkedIn::new(&mut stream), &mut exhausted); + let mut chunk = body.get_chunk().unwrap().unwrap(); + assert_eq!(chunk.recv_all().unwrap(), b"Hello, world!"); + let mut chunk = body.get_chunk().unwrap().unwrap(); + assert_eq!(chunk.recv_all().unwrap(), b" "); + let mut chunk = body.get_chunk().unwrap().unwrap(); + 
assert_eq!(chunk.recv_all().unwrap(), b"How"); + let mut chunk = body.get_chunk().unwrap().unwrap(); + assert_eq!(chunk.recv_all().unwrap(), b" are you?"); + assert_eq!(body.get_chunk().unwrap(), None); + drop(body); + let mut buf = [0_u8; 5]; + assert_eq!(stream.recv(&mut buf).unwrap(), 0); + } + + #[test] + fn invalid_length() { + let src = b"z\r\ntest\r\n0\r\n\r\n".to_vec(); + let mut stream = PrependableStream::new(TestSyncStream::<3>::new(&src)); + let mut exhausted = false; + let mut body = Incoming::>::new(ChunkedIn::new(&mut stream), &mut exhausted); + match body.get_chunk() { + Err(ChunkedInReceiveError::InvalidLength) => assert!(true), + _ => assert!(false), + }; + } +} diff --git a/lib/inferium/src/headers.rs b/lib/inferium/src/headers.rs new file mode 100644 index 0000000..a19ad85 --- /dev/null +++ b/lib/inferium/src/headers.rs @@ -0,0 +1,200 @@ +use proc::AutoimplHkeys; + +/// Accepted valid HTTP header keys +/// +/// These can be sent by both the client and the server. +/// The calling implementation shall check and verify the validity of the headers, or may ignore +/// any invalid ones. 
#[derive(Clone, PartialEq, Eq, Hash)]
#[derive(AutoimplHkeys)]
#[allow(non_camel_case_types)]
#[derive(Debug)]
pub enum HeaderKey {
    // Content negotiation (request side).
    ACCEPT,
    ACCEPT_CHARSET,
    ACCEPT_ENCODING,
    ACCEPT_LANGUAGE,
    ACCEPT_RANGES,
    // CORS preflight / response headers.
    ACCESS_CONTROL_ALLOW_CREDENTIALS,
    ACCESS_CONTROL_ALLOW_HEADERS,
    ACCESS_CONTROL_ALLOW_METHODS,
    ACCESS_CONTROL_ALLOW_ORIGIN,
    ACCESS_CONTROL_EXPOSE_HEADERS,
    ACCESS_CONTROL_MAX_AGE,
    ACCESS_CONTROL_REQUEST_HEADERS,
    ACCESS_CONTROL_REQUEST_METHOD,
    // Caching and proxy metadata.
    AGE,
    ALLOW,
    ALT_SVC,
    AUTHORIZATION,
    CACHE_CONTROL,
    CACHE_STATUS,
    CDN_CACHE_CONTROL,
    CONNECTION,
    // Entity/representation headers.
    CONTENT_DISPOSITION,
    CONTENT_ENCODING,
    CONTENT_LANGUAGE,
    CONTENT_LENGTH,
    CONTENT_LOCATION,
    CONTENT_RANGE,
    CONTENT_SECURITY_POLICY,
    CONTENT_SECURITY_POLICY_REPORT_ONLY,
    CONTENT_TYPE,
    COOKIE,
    DNT,
    DATE,
    // Conditional requests / validators.
    ETAG,
    EXPECT,
    EXPIRES,
    FORWARDED,
    FROM,
    HOST,
    IF_MATCH,
    IF_MODIFIED_SINCE,
    IF_NONE_MATCH,
    IF_RANGE,
    IF_UNMODIFIED_SINCE,
    LAST_MODIFIED,
    LINK,
    LOCATION,
    MAX_FORWARDS,
    ORIGIN,
    PRAGMA,
    PROXY_AUTHENTICATE,
    PROXY_AUTHORIZATION,
    PUBLIC_KEY_PINS,
    PUBLIC_KEY_PINS_REPORT_ONLY,
    RANGE,
    REFERER,
    REFERRER_POLICY,
    REFRESH,
    RETRY_AFTER,
    // WebSocket upgrade handshake.
    SEC_WEBSOCKET_ACCEPT,
    SEC_WEBSOCKET_EXTENSIONS,
    SEC_WEBSOCKET_KEY,
    SEC_WEBSOCKET_PROTOCOL,
    SEC_WEBSOCKET_VERSION,
    SERVER,
    SET_COOKIE,
    STRICT_TRANSPORT_SECURITY,
    TE,
    TRAILER,
    TRANSFER_ENCODING,
    UPGRADE,
    UPGRADE_INSECURE_REQUESTS,
    USER_AGENT,
    VARY,
    VIA,
    WARNING,
    WWW_AUTHENTICATE,
    // Non-standard but widely used X- headers.
    X_CONTENT_TYPE_OPTIONS,
    X_DNS_PREFETCH_CONTROL,
    X_FRAME_OPTIONS,
    X_XSS_PROTECTION,
    // Catch-all for any header key not in the list above; carries the raw key string.
    OTHER(String),
}

/// Type containing all the header values for a given header key.
///
/// Header entries are not omitted if duplicit, but chained to this type.
///
/// Some header keys (such as _cookie_) may require multiple entries. Inferium allows for all
/// header keys to have duplicit entries.
+/// +/// This type does however provide a way to query the first entry for a given key for easy +/// manipulation. +#[derive(Clone, PartialEq, Eq, Hash, Default, Debug)] +pub struct HeaderValue { + inner: Vec, +} + +impl std::str::FromStr for HeaderValue { + type Err = (); + + fn from_str(s: &str) -> Result { + if !is_valid(s) { + return Err(()); + } + Ok(Self { + inner: vec![s.to_string()], + }) + } +} + +macro_rules! autoimpl_valid_hval { +([$($from:ty),*]) => { + $(autoimpl_valid_hval!($from);)* +}; +($from: ty) => { + impl From<$from> for HeaderValue { + fn from(value: $from) -> Self { + Self { inner: vec![value.to_string()] } + } + } +}; +} + +autoimpl_valid_hval!([usize, u32, u64, u128, i32, i64, i128]); + +fn is_valid(v: &str) -> bool { + for ch in v.chars() { + match ch { + 'a'..='z' | 'A'..='Z' | '0'..='9' | ':' | ' ' | ',' | '.' | '=' | '&' | '*' | '/' | + '-' | '!' | '#' | '\'' | '(' | ')' | '+' | ';' | '@' | '[' | ']' | '~' => {}, + _ => return false, + } + } + true +} + +impl HeaderValue { + #[inline] + pub(crate) fn new(inner: Vec) -> Self { + Self { inner } + } + + #[inline] + pub(crate) fn add(&mut self, val: String) { + self.inner.push(val); + } + + /// Query the first entry for this header key. + /// + /// See the documentation of [`HeaderValue`] for more information. + /// + /// # Panics + /// This will panic if this instance of [`HeaderValue`] is in an invalid state, i.e. has no + /// value and hence should not exist as an instance at all. + #[inline] + pub fn get(&self) -> &str { + self.inner.first().expect("invalid header value state") + } + + /// Get all the entries for this header key. + /// + /// See the documentation of [`HeaderValue`] for more information. 
+ #[inline] + pub fn all(&self) -> &Vec { + &self.inner + } +} + +#[cfg(test)] +mod test { + use super::HeaderKey; + + #[test] + fn from_cache_control_raw() { + assert_eq!(HeaderKey::from("Cache-Control"), HeaderKey::CACHE_CONTROL); + } + + #[test] + fn into_user_agent_raw() { + assert_eq!(String::from(HeaderKey::USER_AGENT), "user-agent"); + } + + #[test] + fn invalid_raw() { + assert_eq!(HeaderKey::from("cache-agent"), HeaderKey::OTHER("cache-agent".to_string())); + } +} diff --git a/lib/inferium/src/io.rs b/lib/inferium/src/io.rs new file mode 100644 index 0000000..c2f7050 --- /dev/null +++ b/lib/inferium/src/io.rs @@ -0,0 +1,1013 @@ +use std::io::{Read, Write}; +#[cfg(feature = "async")] +use std::{future::Future, task::{Poll, Context}, pin::Pin}; +#[cfg(feature = "tokio-tls")] +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +/// A transport stream abstraction that allows for reading data from the underlying resource. +pub trait Receive { + fn recv(&mut self, buf: &mut [u8]) -> Result; +} + +/// A transport stream abstraction that allows for writing data to the underlying resource. +pub trait Send { + fn send(&mut self, buf: &[u8]) -> Result; +} + +/// A transport stream abstraction that allows for reading data from the underlying asynchronous +/// resource. +#[cfg(feature = "async")] +pub trait AsyncReceive { + fn recv<'a>(&'a mut self, buf: &'a mut [u8]) -> + impl Future>; +} + +/// A transport stream abstraction that allows for writing data to the underlying asynchronous +/// resource. 
+#[cfg(feature = "async")] +pub trait AsyncSend { + fn send<'a>(&'a mut self, buf: &'a [u8]) -> + impl Future>; +} + +#[cfg(any(feature = "testing", test))] +#[derive(Debug, PartialEq)] +pub struct TestSyncStream<'a, const CHUNK_SIZE: usize> { + inner: &'a Vec, + ptr: usize, + received: Vec, +} + +#[cfg(any(feature = "testing", test))] +impl<'a, const CHUNK_SIZE: usize> TestSyncStream<'a, CHUNK_SIZE> { + pub fn new(inner: &'a Vec) -> Self { + Self { inner, ptr: 0, received: Vec::default() } + } +} + +#[cfg(any(feature = "testing", test))] +impl<'a, const CHUNK_SIZE: usize> Receive for TestSyncStream<'a, CHUNK_SIZE> { + fn recv(&mut self, buf: &mut [u8]) -> Result { + let end_ptr = std::cmp::min(self.ptr + CHUNK_SIZE, self.inner.len()); + if self.ptr >= self.inner.len() { + return Ok(0); + } + let size = end_ptr - self.ptr; + buf[0..size].clone_from_slice(&self.inner[self.ptr..end_ptr]); + self.ptr += size; + Ok(size) + } +} + +#[cfg(any(feature = "testing", test))] +impl<'a, const CHUNK_SIZE: usize> Send for TestSyncStream<'a, CHUNK_SIZE> { + fn send(&mut self, buf: &[u8]) -> Result { + self.received.extend_from_slice(buf); + Ok(buf.len()) + } +} + +#[cfg(all(any(feature = "testing", test), feature = "async"))] +#[cfg_attr(test, derive(Debug))] +pub struct TestAsyncStream<'a, const CHUNK_SIZE: usize> { + inner: &'a Vec, + ptr: usize, + received: Vec, +} + +#[cfg(all(any(feature = "testing", test), feature = "async"))] +impl<'a, const CHUNK_SIZE: usize> TestAsyncStream<'a, CHUNK_SIZE> { + pub fn new(inner: &'a Vec) -> Self { + Self { inner, ptr: 0, received: Vec::default() } + } +} + +#[cfg(all(any(feature = "testing", test), feature = "async"))] +impl<'a, const CHUNK_SIZE: usize> AsyncReceive for TestAsyncStream<'a, CHUNK_SIZE> { + async fn recv<'b>(&'b mut self, buf: &'b mut [u8]) -> Result { + let end_ptr = std::cmp::min(self.ptr + CHUNK_SIZE, self.inner.len()); + if self.ptr >= self.inner.len() { + return Ok(0); + } + let size = end_ptr - self.ptr; + 
buf[0..size].clone_from_slice(&self.inner[self.ptr..end_ptr]); + self.ptr += size; + Ok(size) + } +} + +#[cfg(all(any(feature = "testing", test), feature = "async"))] +impl<'a, const CHUNK_SIZE: usize> AsyncSend for TestAsyncStream<'a, CHUNK_SIZE> { + async fn send<'b>(&'b mut self, buf: &'b [u8]) -> Result { + self.received.extend_from_slice(buf); + Ok(buf.len()) + } +} + +#[derive(Debug, PartialEq)] +pub struct PrependableStream { + inner: T, + prepend_read: Vec, + prepend_read_ptr: usize, +} + +impl PrependableStream { + pub(crate) fn new(inner: T) -> Self { + Self { inner, prepend_read: Vec::default(), prepend_read_ptr: 0 } + } + + pub(crate) fn prepend_to_read(&mut self, to_prepend: &[u8]) { + self.prepend_read.extend_from_slice(to_prepend); + } +} + +impl Receive for PrependableStream { + fn recv(&mut self, buf: &mut [u8]) -> Result { + if self.prepend_read_ptr >= self.prepend_read.len() { + return self.inner.recv(buf); + } + let last_ptr = self.prepend_read_ptr; + let take_length = std::cmp::min(buf.len(), self.prepend_read.len() - last_ptr); + self.prepend_read_ptr += take_length; + buf[..take_length].copy_from_slice(&self.prepend_read[last_ptr..self.prepend_read_ptr]); + if self.prepend_read_ptr >= self.prepend_read.len() { + self.prepend_read = Vec::default(); + self.prepend_read_ptr = 0; + } + Ok(take_length) + } +} + +impl Send for PrependableStream { + fn send(&mut self, buf: &[u8]) -> Result { + self.inner.send(buf) + } +} + +#[cfg(feature = "async")] +impl AsyncReceive for PrependableStream { + async fn recv<'a>(&'a mut self, buf: &'a mut [u8]) -> Result { + if self.prepend_read_ptr >= self.prepend_read.len() { + return self.inner.recv(buf).await; + } + let last_ptr = self.prepend_read_ptr; + let take_length = std::cmp::min(buf.len(), self.prepend_read.len() - last_ptr); + self.prepend_read_ptr += take_length; + buf[..take_length].copy_from_slice(&self.prepend_read[last_ptr..self.prepend_read_ptr]); + if self.prepend_read_ptr >= 
self.prepend_read.len() { + self.prepend_read = Vec::default(); + self.prepend_read_ptr = 0; + } + Ok(take_length) + } +} + +#[cfg(feature = "async")] +impl AsyncSend for PrependableStream { + fn send<'a>(&'a mut self, buf: &'a [u8]) -> + impl Future> { + self.inner.send(buf) + } +} + +/// [`std::net::TcpStream`] wrapper. +/// +/// Able to perform blocking I/O operations on this stream. +/// +/// The inner stream cannot be shared. However this structure can be reused. +#[derive(Debug)] +pub struct StdInet { + inner: std::net::TcpStream +} + +impl StdInet { + pub fn new(inner: std::net::TcpStream) -> Self { + Self { inner } + } +} + +/// [`std::os::unix::net::UnixStream`] wrapper. +/// +/// Able to perform blocking I/O operations on this stream. +/// +/// The inner stream cannot be shared. However this structure can be reused. +#[derive(Debug)] +pub struct StdUnix { + inner: std::os::unix::net::UnixStream +} + +impl StdUnix { + pub fn new(inner: std::os::unix::net::UnixStream) -> Self { + Self { inner } + } +} + +impl Receive for StdInet { + /// Perform a blocking read from the underlying tcp stream. + /// + /// This yields the number of bytes read into the buffer (or the corresponding + /// [`std::io::Error`] if unable to read from the resource). + fn recv(&mut self, buf: &mut [u8]) -> Result { + self.inner.read(buf) + } +} + +impl Send for StdInet { + /// Perform a blocking write to the underlying tcp stream. + /// + /// This yields the number of bytes written from the buffer (or the + /// corresponding [`std::io::Error`] if unable to read from the resource). + fn send(&mut self, buf: &[u8]) -> Result { + self.inner.write(buf) + } +} + +impl Receive for StdUnix { + /// Perform a blocking read from the underlying unix stream. + /// + /// This yields the number of bytes read into the buffer (or the corresponding + /// [`std::io::Error`] if unable to read from the resource). 
+ fn recv(&mut self, buf: &mut [u8]) -> Result { + self.inner.read(buf) + } +} + +impl Send for StdUnix { + /// Perform a blocking write to the underlying unix stream. + /// + /// This yields the number of bytes written from the buffer (or the + /// corresponding [`std::io::Error`] if unable to read from the resource). + fn send(&mut self, buf: &[u8]) -> Result { + self.inner.write(buf) + } +} + +pub struct SyncReader<'a, T: Receive> { + reader: &'a mut T, + offset: usize, +} + +impl<'a, T: Receive> SyncReader<'a, T> { + pub fn new(reader: &'a mut T) -> Self { + Self { reader, offset: 0 } + } +} + +#[derive(Debug)] +pub enum ReaderError { + BufferOverflow, + NoData, + IO(std::io::Error), +} + +#[cfg_attr(test, derive(PartialEq))] +#[derive(Debug)] +pub enum ReaderValue<'a> { + ExactRead { up_to_delimiter: &'a [u8] }, + LeakyRead { up_to_delimiter: &'a [u8], rest: &'a [u8] }, +} + +impl From for ReaderError { + fn from(value: std::io::Error) -> Self { + Self::IO(value) + } +} + +impl<'a, T: Receive> SyncReader<'_, T> { + /// Receive until a specific delimiter is encountered returning the chunk up to the delimiter + /// and rest of the received overhead on success or nothing if read matched exactly. + pub fn recv_until(&mut self, pat: &[u8], jumptable: &[usize], buf: &'a mut [u8]) -> + Result, ReaderError> { + let mut cur_ptr = 0; + let mut pat_ptr = 0_usize; + let dptr = loop { + let cur = self.recv_next(pat, &mut pat_ptr, jumptable, buf, &mut cur_ptr)?; + let Some(cur) = cur else { continue; }; + break cur; + }; + let till_delim = &buf[..dptr]; + let rest_start = dptr + pat.len(); + if rest_start >= buf.len() || rest_start == cur_ptr { + return Ok(ReaderValue::ExactRead { up_to_delimiter: till_delim }); + } + let from_delim = &buf[dptr+pat.len()..cur_ptr]; + Ok(ReaderValue::LeakyRead { up_to_delimiter: till_delim, rest: from_delim }) + } + + /// Receive a single chunk and check if the chunk contains the delimiter. 
+ /// + /// If the delimiter is found - this returns the position of the start of the delimiter. + #[inline] + fn recv_next( + &mut self, + pat: &[u8], + pat_ptr: &mut usize, + jumptable: &[usize], + buf: &mut [u8], + cur_ptr: &mut usize, + ) -> + Result, ReaderError> { + let cur_chunk = &mut buf[*cur_ptr..]; + let size = self.reader.recv(cur_chunk)?; + if size == 0 { + return Err(ReaderError::NoData); + } + let prev_ptr = *cur_ptr; + *cur_ptr += size; + let cur_chunk = &cur_chunk[..size]; + let Some(found) = try_find( + cur_chunk, pat, pat_ptr, jumptable, &(prev_ptr + self.offset + )) else { + if *cur_ptr < buf.len() { return Ok(None); } + return Err(ReaderError::BufferOverflow); + }; + Ok(Some(found)) + } +} + +/// Try to find the start of a specific match given global ofset and the position in the pattern to +/// start matching. +/// +/// # Panics +/// This can panic if the `total_ptr` is lower than `pat_ptr` and thus partial match before the +/// data start is implied - this is an invalid state and will not be checked. +/// This can panic if the `pat_ptr` is higher than the length of `pat` or the jumptable is invalid +/// in the context of the passed pattern. 
fn try_find(
    haystack: &[u8],
    pat: &[u8],
    pat_ptr: &mut usize,
    jumptable: &[usize],
    total_ptr: &usize,
) -> Option<usize> {
    for (offset, byte) in haystack.iter().enumerate() {
        // The whole pattern was consumed before this byte; the match therefore started
        // `pat.len()` bytes before the current absolute position.
        if *pat_ptr >= pat.len() {
            return Some(*total_ptr + offset - pat.len());
        }
        if *byte == pat[*pat_ptr] {
            *pat_ptr += 1;
        } else {
            // Mismatch: fall back through the jumptable (KMP-style) instead of rescanning.
            try_find_update_jumptable(byte, pat, pat_ptr, jumptable);
        }
    }
    // The pattern may also end exactly at the chunk boundary.
    (*pat_ptr >= pat.len()).then(|| *total_ptr + haystack.len() - pat.len())
}

#[cfg(feature = "async")]
pub struct AsyncReader<'a, T: AsyncReceive> {
    reader: &'a mut T,
    offset: usize,
}

#[cfg(feature = "async")]
impl<'a, T: AsyncReceive> AsyncReader<'a, T> {
    pub fn new(reader: &'a mut T) -> Self {
        Self { reader, offset: 0 }
    }
}

#[cfg(feature = "async")]
impl<'a, T: AsyncReceive> AsyncReader<'_, T> {
    /// Receive until a specific delimiter is encountered returning the chunk up to the
    /// delimiter and rest of the received overhead on success or nothing if read matched
    /// exactly.
    pub async fn recv_until(&mut self, pat: &[u8], jumptable: &[usize], buf: &'a mut [u8]) ->
        Result<ReaderValue<'a>, ReaderError> {
        // `filled` tracks how many bytes of `buf` have been written so far; `pat_pos` is the
        // cross-chunk pattern-match state threaded through `recv_next`.
        let mut filled = 0;
        let mut pat_pos = 0_usize;
        let delim_start = loop {
            if let Some(found) =
                self.recv_next(pat, &mut pat_pos, jumptable, buf, &mut filled).await?
            {
                break found;
            }
        };
        let up_to_delimiter = &buf[..delim_start];
        let rest_start = delim_start + pat.len();
        // Nothing received beyond the delimiter - the read matched exactly.
        if rest_start >= buf.len() || rest_start == filled {
            return Ok(ReaderValue::ExactRead { up_to_delimiter });
        }
        Ok(ReaderValue::LeakyRead { up_to_delimiter, rest: &buf[rest_start..filled] })
    }

    /// Receive a single chunk and check if the chunk contains the delimiter.
    ///
    /// If the delimiter is found - this returns the position of the start of the delimiter.
+ #[inline] + async fn recv_next( + &mut self, + pat: &[u8], + pat_ptr: &mut usize, + jumptable: &[usize], + buf: &mut [u8], + cur_ptr: &mut usize, + ) -> + Result, ReaderError> { + let cur_chunk = &mut buf[*cur_ptr..]; + let size = self.reader.recv(cur_chunk).await?; + if size == 0 { + return Err(ReaderError::NoData); + } + let prev_ptr = *cur_ptr; + *cur_ptr += size; + let cur_chunk = &cur_chunk[..size]; + let Some(found) = try_find( + cur_chunk, pat, pat_ptr, jumptable, &(prev_ptr + self.offset + )) else { + if *cur_ptr < buf.len() { return Ok(None); } + return Err(ReaderError::BufferOverflow); + }; + Ok(Some(found)) + } +} + +/// Do not use this method directly. +/// +/// This is a helper for [`try_find`] and it should be used instead. +fn try_find_update_jumptable( + hay: &u8, + pat: &[u8], + pat_ptr: &mut usize, + jumptable: &[usize], +) { + while *pat_ptr != 0 { + *pat_ptr = jumptable[*pat_ptr]; + if *hay == pat[*pat_ptr] { + *pat_ptr += 1; + break; + } + } +} + +/// [`tokio::net::TcpStream`] wrapper. +/// +/// Able to perform asynchronous I/O operations on this stream. +/// +/// The inner stream cannot be shared. However this structure can be reused. +#[cfg(feature = "tokio-net")] +#[derive(Debug)] +pub struct TokioInet { + inner: tokio::net::TcpStream +} + +#[cfg(feature = "tokio-net")] +impl TokioInet { + pub fn new(inner: tokio::net::TcpStream) -> Self { + Self { inner } + } +} + +/// [`tokio::net::UnixStream`] wrapper. +/// +/// Able to perform asynchronous I/O operations on this stream. +/// +/// The inner stream cannot be shared. However this structure can be reused. +#[cfg(feature = "tokio-unixsocks")] +#[derive(Debug)] +pub struct TokioUnix { + inner: tokio::net::UnixStream, +} + +#[cfg(feature = "tokio-unixsocks")] +impl TokioUnix { + pub fn new(inner: tokio::net::UnixStream) -> Self { + Self { inner } + } +} + +/// [`tokio_rustls::TlsStream`] wrapper. +/// +/// Able to perform encrypted asynchronous I/O operations on this stream. 
+/// +/// The inner stream cannot be shared. However this structure can be reused. +#[cfg(feature = "tokio-tls")] +#[derive(Debug)] +pub struct TokioRustls { + inner: tokio_rustls::TlsStream, +} + +#[cfg(feature = "tokio-tls")] +impl TokioRustls { + pub fn new(inner: tokio_rustls::TlsStream) -> Self { + Self { inner } + } +} + +#[cfg(any(feature = "tokio-net", feature = "tokio-unixsocks"))] +struct TokioReceive<'a, R> { + buf: &'a mut [u8], + reader: &'a mut R, +} + +#[cfg(any(feature = "tokio-net", feature = "tokio-unixsocks"))] +struct TokioSend<'a, W> { + buf: &'a [u8], + writer: &'a mut W, +} + +#[cfg(feature = "tokio-tls")] +struct TokioRustlsReceive<'a, R> { + buf: ReadBuf<'a>, + reader: &'a mut R, +} + +#[cfg(feature = "async")] +fn handle_async_poll_error(e: std::io::Error, cx: &mut Context<'_>) +-> Poll> { + if e.kind() == std::io::ErrorKind::WouldBlock { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + Poll::Ready(Err(e)) +} + +#[cfg(all(feature = "tokio-net", feature = "async"))] +impl Future for TokioReceive<'_, tokio::net::TcpStream> { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + let read_ready = this.reader.poll_read_ready(cx); + match read_ready { + Poll::Ready(Ok(())) => {}, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + match this.reader.try_read(this.buf) { + Ok(v) => Poll::Ready(Ok(v)), + Err(e) => handle_async_poll_error(e, cx), + } + } +} + +#[cfg(all(feature = "tokio-net", feature = "async"))] +impl Future for TokioSend<'_, tokio::net::TcpStream> { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + let write_ready = this.writer.poll_write_ready(cx); + match write_ready { + Poll::Ready(Ok(())) => {}, + Poll::Ready(Err(e)) => return handle_async_poll_error(e, cx), + Poll::Pending => return Poll::Pending, + } + match 
this.writer.try_write(this.buf) { + Ok(v) => Poll::Ready(Ok(v)), + Err(e) => handle_async_poll_error(e, cx), + } + } +} + +#[cfg(all(feature = "async", feature = "tokio-net"))] +impl AsyncReceive for TokioInet { + /// Perform an asynchronous read from the underlying tcp stream. + /// + /// When awaited - this yields the number of bytes read into the buffer (or the corresponding + /// [`std::io::Error`] if unable to read from the resource). + fn recv<'a>(&'a mut self, buf: &'a mut [u8]) -> + impl Future> + { + TokioReceive { buf, reader: &mut self.inner } + } +} + +#[cfg(all(feature = "async", feature = "tokio-net"))] +impl AsyncSend for TokioInet { + /// Perform an asynchronous write to the underlying tcp stream. + /// + /// When awaited - this yields the number of bytes written from the buffer (or the + /// corresponding [`std::io::Error`] if unable to read from the resource). + fn send<'a>(&'a mut self, buf: &'a [u8]) -> + impl Future> + { + TokioSend { buf, writer: &mut self.inner } + } +} + +#[cfg(all(feature = "tokio-unixsocks", feature = "async"))] +impl Future for TokioReceive<'_, tokio::net::UnixStream> { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + let read_ready = this.reader.poll_read_ready(cx); + match read_ready { + Poll::Ready(Ok(())) => {}, + Poll::Ready(Err(e)) => return handle_async_poll_error(e, cx), + Poll::Pending => return Poll::Pending, + } + match this.reader.try_read(this.buf) { + Ok(v) => Poll::Ready(Ok(v)), + Err(e) => handle_async_poll_error(e, cx), + } + } +} + +#[cfg(all(feature = "tokio-unixsocks", feature = "async"))] +impl Future for TokioSend<'_, tokio::net::UnixStream> { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + let write_ready = this.writer.poll_write_ready(cx); + match write_ready { + Poll::Ready(Ok(())) => {}, + Poll::Ready(Err(e)) => return handle_async_poll_error(e, cx), + 
Poll::Pending => return Poll::Pending, + } + match this.writer.try_write(this.buf) { + Ok(v) => Poll::Ready(Ok(v)), + Err(e) => handle_async_poll_error(e, cx), + } + } +} + +#[cfg(all(feature = "async", feature = "tokio-unixsocks"))] +impl AsyncReceive for TokioUnix { + /// Perform an asynchronous read from the underlying unix stream. + /// + /// When awaited - this yields the number of bytes read into the buffer (or the corresponding + /// [`std::io::Error`] if unable to read from the resource). + fn recv<'a>(&'a mut self, buf: &'a mut [u8]) -> + impl Future> + { + TokioReceive { buf, reader: &mut self.inner } + } +} + +#[cfg(all(feature = "async", feature = "tokio-unixsocks"))] +impl AsyncSend for TokioUnix { + /// Perform an asynchronous write to the underlying unix stream. + /// + /// When awaited - this yields the number of bytes written from the buffer (or the + /// corresponding [`std::io::Error`] if unable to read from the resource). + fn send<'a>(&'a mut self, buf: &'a [u8]) -> + impl Future> + { + TokioSend { buf, writer: &mut self.inner } + } +} + +#[cfg(all(feature = "async", feature = "tokio-tls"))] +impl Future for TokioRustlsReceive<'_, tokio_rustls::TlsStream> { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + let reader = Pin::new(&mut this.reader); + match reader.poll_read(cx, &mut this.buf) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(this.buf.filled().len())), + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } +} + +#[cfg(all(feature = "async", feature = "tokio-tls"))] +impl Future for TokioSend<'_, tokio_rustls::TlsStream> { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + let writer = Pin::new(&mut this.writer); + match writer.poll_write(cx, this.buf) { + Poll::Ready(Ok(l)) => Poll::Ready(Ok(l)), + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => 
Poll::Pending, + } + } +} + +#[cfg(all(feature = "async", feature = "tokio-tls"))] +impl AsyncReceive for TokioRustls { + /// Perform an asynchronous read from the underlying tcp stream. + /// + /// When awaited - this yields the number of bytes read into the buffer (or the corresponding + /// [`std::io::Error`] if unable to read from the resource). + fn recv<'a>(&'a mut self, buf: &'a mut [u8]) -> + impl Future> + { + TokioRustlsReceive { buf: ReadBuf::new(buf), reader: &mut self.inner } + } +} + +#[cfg(all(feature = "async", feature = "tokio-tls"))] +impl AsyncSend for TokioRustls { + /// Perform an asynchronous write to the underlying tcp stream. + /// + /// When awaited - this yields the number of bytes written from the buffer (or the + /// corresponding [`std::io::Error`] if unable to read from the resource). + fn send<'a>(&'a mut self, buf: &'a [u8]) -> + impl Future> + { + TokioSend { buf, writer: &mut self.inner } + } +} + +#[cfg(test)] +mod streams { + use crate::io::try_find_update_jumptable; + use super::try_find; + + #[test] + fn find_whole_pattern() { + let src = vec![1_u8, 2, 3, 4, 5, 6, 7, 8, 9]; + assert_eq!(try_find(src.as_slice(), &[2, 3], &mut 0, &[0, 0], &0), Some(1)); + } + + #[test] + fn find_whole_pattern_with_offset() { + let src = vec![1_u8, 2, 3, 4, 5, 6, 7, 8, 9]; + assert_eq!(try_find(src.as_slice(), &[2, 3], &mut 0, &[0, 0], &2), Some(3)); + } + + #[test] + fn find_partial_pattern() { + let src = vec![1_u8, 2, 3, 4, 5, 6, 7, 8, 9]; + assert_eq!(try_find(src.as_slice(), &[0, 1], &mut 1, &[0, 0], &1), Some(0)); + } + + #[test] + fn find_nonexistent_pattern() { + let src = vec![1_u8, 2, 3, 4, 5, 6, 7, 8, 9]; + assert_eq!(try_find(src.as_slice(), &[0, 1], &mut 0, &[0, 0], &0), None); + } + + #[test] + fn find_partial_pattern_check_update_ptr() { + let src = vec![1_u8, 2, 3, 4, 5, 6, 7, 8, 9]; + let mut ptr = 0; + assert_eq!(try_find(src.as_slice(), &[7, 8, 9, 10], &mut ptr, &[0, 0], &0), None); + assert_eq!(ptr, 3); + } + + #[test] + fn 
find_pattern_verify_jumptable_jumps() { + let src = vec![1_u8, 2, 3, 1, 2, 3, 1]; + let mut ptr = 0; + assert_eq!(try_find( + src.as_slice(), + &[1, 2, 3, 1, 2, 3, 5], + &mut ptr, + &[0, 0, 0, 0, 1, 2, 3], + &0 + ), None); + assert_eq!(ptr, 4); + } + + #[test] + fn jumptable_test_multiple_jumps_until_match() { + let pat = vec![1_u8, 2, 3, 4, 5, 6, 7]; + let jt = vec![0_usize, 0, 1, 2, 3, 4, 5]; + let mut ptr = 6; + try_find_update_jumptable(&2, &pat, &mut ptr, &jt); + assert_eq!(ptr, 2); + } + + #[test] + fn jumptable_test_multiple_jumps_until_match_on_zero() { + let pat = vec![1_u8, 2, 3, 4, 5, 6, 7]; + let jt = vec![0_usize, 0, 1, 2, 3, 4, 5]; + let mut ptr = 6; + try_find_update_jumptable(&1, &pat, &mut ptr, &jt); + assert_eq!(ptr, 1); + } + + #[test] + fn jumptable_test_multiple_jumps_until_fail() { + let pat = vec![1_u8, 2, 3, 4, 5, 6, 7]; + let jt = vec![0_usize, 0, 1, 2, 3, 4, 5]; + let mut ptr = 6; + try_find_update_jumptable(&8, &pat, &mut ptr, &jt); + assert_eq!(ptr, 0); + } +} + +#[cfg(test)] +mod prependable_stream_sync { + use super::{PrependableStream, Receive, TestSyncStream}; + + #[test] + fn prepend_nothing() { + let src = "Hello, world!\n".as_bytes().to_vec(); + let mut stream = PrependableStream::new(TestSyncStream::<5> { + inner: &src, ptr: 0, received: Vec::default() + }); + let mut buf = [0_u8; 1024]; + assert_eq!(stream.recv(&mut buf).unwrap(), 5); + assert_eq!(&buf[..6], b"Hello\0"); + } + + #[test] + fn prepend_something() { + let src = "Hello, world!\n".as_bytes().to_vec(); + let mut stream = PrependableStream::new(TestSyncStream::<5> { + inner: &src, ptr: 0, received: Vec::default() + }); + stream.prepend_to_read(b"Well... "); + let mut buf = [0_u8; 1024]; + assert_eq!(stream.recv(&mut buf).unwrap(), 8); + assert_eq!(stream.recv(&mut buf[8..]).unwrap(), 5); + assert_eq!(&buf[..14], b"Well... 
Hello\0"); + } + + #[test] + fn prepend_something_multiple_reads() { + let src = "Hello, world!\n".as_bytes().to_vec(); + let mut stream = PrependableStream::new(TestSyncStream::<5> { + inner: &src, ptr: 0, received: Vec::default() + }); + stream.prepend_to_read(b"Well... "); + let mut buf = [0_u8; 1024]; + assert_eq!(stream.recv(&mut buf).unwrap(), 8); + assert_eq!(stream.recv(&mut buf[8..]).unwrap(), 5); + assert_eq!(stream.recv(&mut buf[13..]).unwrap(), 5); + assert_eq!(&buf[..19], b"Well... Hello, wor\0"); + } +} + +#[cfg(all(test, feature = "async", any(feature = "tokio-net", feature = "tokio-unixsocks")))] +mod prependable_stream_async { + use super::{PrependableStream, AsyncReceive}; + use tokio::{time::sleep, test}; + use std::time::Duration; + + struct TestStream<'a, const CHUNK_SIZE: usize> { + inner: &'a Vec, + ptr: usize, + } + + impl<'a, const CHUNK_SIZE: usize> AsyncReceive for TestStream<'a, CHUNK_SIZE> { + async fn recv<'b>(&'b mut self, buf: &'b mut [u8]) -> Result { + sleep(Duration::from_millis(1)).await; + let end_ptr = std::cmp::min(self.ptr + CHUNK_SIZE, self.inner.len()); + if self.ptr >= self.inner.len() { + return Ok(0); + } + let size = end_ptr - self.ptr; + buf[0..size].clone_from_slice(&self.inner[self.ptr..end_ptr]); + self.ptr += size; + Ok(size) + } + } + + #[test] + async fn prepend_nothing() { + let src = "Hello, world!\n".as_bytes().to_vec(); + let mut stream = PrependableStream::new(TestStream::<5> { inner: &src, ptr: 0 }); + let mut buf = [0_u8; 1024]; + assert_eq!(stream.recv(&mut buf).await.unwrap(), 5); + assert_eq!(&buf[..6], b"Hello\0"); + } + + #[test] + async fn prepend_something() { + let src = "Hello, world!\n".as_bytes().to_vec(); + let mut stream = PrependableStream::new(TestStream::<5> { inner: &src, ptr: 0 }); + stream.prepend_to_read(b"Well... "); + let mut buf = [0_u8; 1024]; + assert_eq!(stream.recv(&mut buf).await.unwrap(), 8); + assert_eq!(&buf[..9], b"Well... 
\0"); + } + + #[test] + async fn prepend_something_multiple_reads() { + let src = "Hello, world!\n".as_bytes().to_vec(); + let mut stream = PrependableStream::new(TestStream::<5> { inner: &src, ptr: 0 }); + stream.prepend_to_read(b"Well... "); + let mut buf = [0_u8; 1024]; + assert_eq!(stream.recv(&mut buf).await.unwrap(), 8); + assert_eq!(stream.recv(&mut buf[8..]).await.unwrap(), 5); + assert_eq!(stream.recv(&mut buf[13..]).await.unwrap(), 5); + assert_eq!(&buf[..19], b"Well... Hello, wor\0"); + } +} + +#[cfg(test)] +mod sync_reader { + use crate::io::{ReaderError, ReaderValue}; + use super::{SyncReader, TestSyncStream}; + + #[test] + fn read_until_valid_onesymbol() { + let src = "Hello, world!\n".as_bytes().to_vec(); + let target = "Hello, world!".as_bytes().to_vec(); + let mut stream = TestSyncStream::<4> { inner: &src, ptr: 0, received: Vec::default() }; + let mut buf = [0_u8; 1024]; + let mut sync_reader = SyncReader::new(&mut stream); + let read = sync_reader.recv_until(&[b'\n'], &[0], &mut buf).unwrap(); + assert_eq!(read, ReaderValue::ExactRead{up_to_delimiter: target.as_slice()}); + } + + #[test] + fn read_until_valid_onesymbol_with_leak() { + let src = "Hello, world".as_bytes().to_vec(); + let target = ( + "Hello,".as_bytes().to_vec(), + "world".as_bytes().to_vec() + ); + let mut stream = TestSyncStream::<6> { inner: &src, ptr: 0, received: Vec::default() }; + let mut buf = [0_u8; 1024]; + let mut sync_reader = SyncReader::new(&mut stream); + let read = sync_reader.recv_until(&[b' '], &[0], &mut buf).unwrap(); + assert_eq!(read, ReaderValue::LeakyRead { + up_to_delimiter: target.0.as_slice(), + rest: target.1.as_slice(), + }); + } + + #[test] + fn read_until_valid_boundary_multisymbol_with_leak() { + let src = "Hello, world".as_bytes().to_vec(); + let target = ( + "Hello".as_bytes().to_vec(), + "world".as_bytes().to_vec() + ); + let mut stream = TestSyncStream::<6> { inner: &src, ptr: 0, received: Vec::default() }; + let mut buf = [0_u8; 1024]; + let mut 
sync_reader = SyncReader::new(&mut stream); + let read = sync_reader.recv_until(&[b',', b' '], &[0, 0], &mut buf).unwrap(); + assert_eq!(read, ReaderValue::LeakyRead { + up_to_delimiter: target.0.as_slice(), + rest: target.1.as_slice(), + }); + } + + #[test] + fn read_no_delimiter() { + let src = "Hello, world!".as_bytes().to_vec(); + let mut stream = TestSyncStream::<6> { inner: &src, ptr: 0, received: Vec::default() }; + let mut buf = [0_u8; 1024]; + let mut sync_reader = SyncReader::new(&mut stream); + match sync_reader.recv_until(&[b'\n'], &[0], &mut buf) { + Err(ReaderError::NoData) => assert!(true), + v @ _ => { println!("{:?}", v); assert!(false); }, + }; + } +} + +#[cfg(all(test, feature = "async", any(feature = "tokio-net", feature = "tokio-unixsocks")))] +mod async_reader { + use super::{ReaderError, ReaderValue, TestAsyncStream}; + use tokio::test; + use super::AsyncReader; + + #[test] + async fn read_until_valid_onesymbol() { + let src = "Hello, world!\n".as_bytes().to_vec(); + let target = "Hello, world!".as_bytes().to_vec(); + let mut stream = TestAsyncStream::<4> { inner: &src, ptr: 0, received: Vec::default() }; + let mut buf = [0_u8; 1024]; + let mut sync_reader = AsyncReader::new(&mut stream); + let read = sync_reader.recv_until(&[b'\n'], &[0], &mut buf).await.unwrap(); + assert_eq!(read, ReaderValue::ExactRead{up_to_delimiter: target.as_slice()}); + } + + #[test] + async fn read_until_valid_onesymbol_with_leak() { + let src = "Hello, world".as_bytes().to_vec(); + let target = ( + "Hello,".as_bytes().to_vec(), + "world".as_bytes().to_vec() + ); + let mut stream = TestAsyncStream::<6> { inner: &src, ptr: 0, received: Vec::default() }; + let mut buf = [0_u8; 1024]; + let mut sync_reader = AsyncReader::new(&mut stream); + let read = sync_reader.recv_until(&[b' '], &[0], &mut buf).await.unwrap(); + assert_eq!(read, ReaderValue::LeakyRead { + up_to_delimiter: target.0.as_slice(), + rest: target.1.as_slice(), + }); + } + + #[test] + async fn 
read_until_valid_boundary_multisymbol_with_leak() { + let src = "Hello, world".as_bytes().to_vec(); + let target = ( + "Hello".as_bytes().to_vec(), + "world".as_bytes().to_vec() + ); + let mut stream = TestAsyncStream::<6> { inner: &src, ptr: 0, received: Vec::default() }; + let mut buf = [0_u8; 1024]; + let mut sync_reader = AsyncReader::new(&mut stream); + let read = sync_reader.recv_until(&[b',', b' '], &[0, 0], &mut buf).await.unwrap(); + assert_eq!(read, ReaderValue::LeakyRead { + up_to_delimiter: target.0.as_slice(), + rest: target.1.as_slice(), + }); + } + + #[test] + async fn read_no_delimiter() { + let src = "Hello, world!".as_bytes().to_vec(); + let mut stream = TestAsyncStream::<6> { inner: &src, ptr: 0, received: Vec::default() }; + let mut buf = [0_u8; 1024]; + let mut sync_reader = AsyncReader::new(&mut stream); + match sync_reader.recv_until(&[b'\n'], &[0], &mut buf).await { + Err(ReaderError::NoData) => assert!(true), + v @ _ => { println!("{:?}", v); assert!(false); }, + }; + } +} diff --git a/lib/inferium/src/lib.rs b/lib/inferium/src/lib.rs new file mode 100644 index 0000000..b6c8930 --- /dev/null +++ b/lib/inferium/src/lib.rs @@ -0,0 +1,32 @@ +pub mod settings; + +mod io; +pub use io::{StdInet, StdUnix}; +#[cfg(feature = "tokio-net")] +pub use io::TokioInet; +#[cfg(feature = "tokio-unixsocks")] +pub use io::TokioUnix; +#[cfg(feature = "tokio-tls")] +pub use io::TokioRustls; +#[cfg(feature = "testing")] +pub use io::TestSyncStream; +#[cfg(all(feature = "testing", feature = "async"))] +pub use io::TestAsyncStream; + +mod proto; +pub use proto::h1; + +mod status; +pub use status::Status; + +mod headers; +pub use headers::{HeaderKey, HeaderValue}; + +mod method; +pub use method::Method; + +mod path; +pub use path::{HttpPath, HttpPathParseError}; + +mod body; +pub use body::{Incoming, SizedIn, Outgoing}; diff --git a/lib/inferium/src/method.rs b/lib/inferium/src/method.rs new file mode 100644 index 0000000..34a6637 --- /dev/null +++ 
b/lib/inferium/src/method.rs @@ -0,0 +1,51 @@ +use proc::AutoimplMethods; + +#[derive(AutoimplMethods)] +#[allow(non_camel_case_types)] +#[cfg_attr(test, derive(PartialEq))] +#[derive(Debug, Clone, Copy)] +pub enum Method { + GET, + HEAD, + POST, + PUT, + DELETE, + CONNECT, + OPTIONS, + TRACE, + PATCH, +} + +impl std::fmt::Display for Method { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", <&'static str>::from(*self)) + } +} + +#[cfg(test)] +mod test { + use crate::method::Method; + + #[test] + fn method_from_raw() { + let src = "GET"; + assert_eq!(src.parse::(), Ok(Method::GET)); + } + + #[test] + fn method_to_raw() { + assert_eq!(<&'static str>::from(Method::GET), "GET"); + } + + #[test] + fn method_lowercase() { + let src = "get"; + assert_eq!(src.parse::(), Err(())); + } + + #[test] + fn method_nonexistent() { + let src = "GOST"; + assert_eq!(src.parse::(), Err(())); + } +} diff --git a/lib/inferium/src/path.rs b/lib/inferium/src/path.rs new file mode 100644 index 0000000..f8fa300 --- /dev/null +++ b/lib/inferium/src/path.rs @@ -0,0 +1,286 @@ +use std::collections::HashMap; + +/// A valid HTTP path with possible GET parameters. +/// +/// This struct provides multiple guarantees: +/// - The path is a valid [`String`] +/// - The path's first character is "/" +/// - The path is not empty (the shortest allowed path is "/") +#[cfg_attr(test, derive(PartialEq))] +#[derive(Debug)] +pub struct HttpPath { + pub(crate) path: String, + pub(crate) params: Option>, +} + + +#[cfg_attr(test, derive(PartialEq))] +#[derive(Debug)] +pub enum HttpPathParseError { + /// The URI could not be converted into a valid [`String`]. + InvalidString, + /// The path does not begin with a slash. + NoslashStart, + /// The path contains invalid characters as per + /// [RFC3986](https://datatracker.ietf.org/doc/html/rfc3986#section-3.3). + InvalidPath, + /// The query could not be parsed. 
+ InvalidGetParams, +} + +impl std::fmt::Display for HttpPath { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.serialize_to_string()) + } +} + +impl std::str::FromStr for HttpPath { + type Err = HttpPathParseError; + + fn from_str(s: &str) -> Result { + if !s.starts_with('/') { + return Err(HttpPathParseError::NoslashStart); + } + let mut value = s.split('?'); + let Some(path) = validate_http_path_chars(&mut value.next().unwrap().chars()) else { + return Err(HttpPathParseError::InvalidPath); + }; + let params = parse_get_params(value.next()) + .map_err(|_| HttpPathParseError::InvalidGetParams)?; + Ok(HttpPath { + path, + params, + }) + } +} + +impl TryFrom<&[u8]> for HttpPath { + type Error = HttpPathParseError; + + fn try_from(value: &[u8]) -> Result { + let value = String::from_utf8(Vec::from(value)) + .map_err(|_| HttpPathParseError::InvalidString)?; + if !value.starts_with('/') { + return Err(HttpPathParseError::NoslashStart); + } + let mut value = value.split('?'); + let Some(path) = validate_http_path_chars(&mut value.next().unwrap().chars()) else { + return Err(HttpPathParseError::InvalidPath); + }; + let params = parse_get_params(value.next()) + .map_err(|_| HttpPathParseError::InvalidGetParams)?; + Ok(HttpPath { + path, + params, + }) + } +} + +fn serialize_params(params: &HashMap) -> Vec { + let mut res = Vec::new(); + let mut first = true; + for entry in params.iter() { + if !first { + res.push(b'&'); + } + first = false; + res.extend_from_slice(entry.0.as_bytes()); + res.push(b'='); + res.extend_from_slice(entry.1.as_bytes()); + } + res +} + +fn serialize_params_to_string(params: &HashMap) -> String { + let mut res = String::new(); + let mut first = true; + for entry in params.iter() { + if !first { + res.push('&'); + } + first = false; + res.push_str(entry.0); + res.push('='); + res.push_str(entry.1); + } + res +} + +impl HttpPath { + pub(crate) fn serialize(&self) -> Vec { + let mut res = Vec::new(); + 
res.extend_from_slice(self.path.as_bytes()); + if let Some(ref p) = self.params { + res.push(b'?'); + res.extend_from_slice(&serialize_params(p)); + } + res + } + + fn serialize_to_string(&self) -> String { + let mut res = String::new(); + res.push_str(&self.path); + if let Some(ref p) = self.params { + res.push('?'); + res.push_str(&serialize_params_to_string(p)); + } + res + } + + #[inline] + pub fn path(&self) -> &str { + &self.path + } + + #[inline] + pub fn params(&self) -> &Option> { + &self.params + } +} + +fn parse_get_params(raw: Option<&str>) -> +Result>, ()> { + let raw = match raw { + Some(v) => v, + None => return Ok(None), + }; + let mut res = HashMap::new(); + let mut tmp = String::new(); + let mut key = None; + for cur in raw.chars() { + parse_get_params_handle_char( + &mut tmp, + cur, + &mut key, + &mut res, + )?; + } + if let (Some(k), 1..) = (key, tmp.len()) { + res.insert(k.to_string(), tmp); + }; + Ok(Some(res)) +} + +#[inline] +fn parse_get_params_handle_char( + tmp: &mut String, + cur: char, + key: &mut Option, + res: &mut HashMap, +) -> Result<(), ()> { + println!("cur: {cur}, key: {key:?}"); + match (cur, &key) { + ('&', Some(k)) => { + res.insert(k.to_string(), std::mem::take(tmp)); + *key = None; + Ok(()) + }, + ('=', None) => { + *key = Some(std::mem::take(tmp)); + Ok(()) + }, + ('&', None) => Err(()), + ('=', Some(_)) => Err(()), + (c, _) => { + tmp.push(c); + Ok(()) + } + } +} + +fn validate_http_path_chars>(iter: &mut I) -> Option { + let mut res = String::with_capacity(iter.size_hint().1?); + for ch in iter { + match ch { + 'a'..='z' | + 'A'..='Z' | + '0'..='9' | + '/' | + '.' | '-' | '_' | '~' | '!' | '$' | '&' | '\'' | + '(' | ')' | '*' | '+' | ',' | ';' | '=' | ':' | '@' => res.push(ch), + _ => return None, + } + } + Some(res) +} + +#[cfg(test)] +mod parse_path { + use std::collections::HashMap; + use super::{HttpPath, HttpPathParseError}; + + macro_rules! 
test { + ($name: ident, $raw_uri: literal, ok path $path: literal) => { + test!($name, $raw_uri, Ok(HttpPath { path: $path.to_string(), params: None })); + }; + + ( + $name: ident, $raw_uri: literal, + ok path $path: literal params [$($key: ident : $value: literal),* $(,)?] + ) => { + test!($name, $raw_uri, Ok(HttpPath { + path: $path.to_string(), + params: Some(HashMap::from([$((stringify!($key).to_string(), $value.to_string())),*])), + })); + }; + + ($name: ident, $raw_uri: literal, $res: expr) => { + #[test] + fn $name() { + let raw = $raw_uri.as_bytes().to_vec(); + let raw = raw.as_slice(); + assert_eq!(HttpPath::try_from(raw), $res); + } + }; + + } + + test!(valid_path_noparams, "/hello/world", ok path "/hello/world"); + test!(invalid_path_noslash, "hello/world", Err(HttpPathParseError::NoslashStart)); + test!(valid_singleparam, "/hello/world?hello=world", ok path "/hello/world" params [ + hello: "world" + ]); + test!(valid_multiparam, "/hello/world?hello=world&how=areyou", ok path "/hello/world" params [ + hello: "world", + how: "areyou", + ]); + test!(invalid_params, "/hello/world?&", Err(HttpPathParseError::InvalidGetParams)); + test!(path_invalid_char, "/hč", Err(HttpPathParseError::InvalidPath)); +} + +#[cfg(test)] +mod serialize_path { + use std::collections::HashMap; + use super::HttpPath; + + macro_rules! 
test { + (@uri $path: literal, [$(,)?]) => { + HttpPath { + path: $path.to_string(), + params: None, + } + }; + (@uri $path: literal, [$($qk: ident : $qv: literal),+$(,)?]) => { + HttpPath { + path: $path.to_string(), + params: Some(HashMap::from([$((stringify!($qk).to_string(), $qv.to_string())),*])), + } + }; + + ($name: ident, $path: literal [$($qk: ident : $qv: literal),*$(,)?], $target: literal) => { + #[test] + fn $name() { + let src = test!(@uri $path, [$($qk:$qv),*]); + assert_eq!(src.serialize(), $target); + } + }; + } + + test!(simple_slash, "/"[], b"/"); + test!(slash_with_singleparam_query, "/"[ + action: "none", + ], b"/?action=none"); + test!(path_with_singleparam_query, "/hello/world"[ + hello: "world", + ], b"/hello/world?hello=world"); +} diff --git a/lib/inferium/src/proto/h1/exports.rs b/lib/inferium/src/proto/h1/exports.rs new file mode 100644 index 0000000..6afb61c --- /dev/null +++ b/lib/inferium/src/proto/h1/exports.rs @@ -0,0 +1,666 @@ +use crate::{ + body::{ChunkedIn, Incoming, Outgoing, SizedIn, SizedOut}, + headers::HeaderKey, io::{self, PrependableStream, Receive, Send} +}; +use super::{ + head::{BadRequest, BadResponse, RequestHead, ResponseHead}, + stream_handler::{ + ExpectedBody, + StreamHandler, + StreamHandlerReceiveError, + StreamHandlerSendError + } +}; +#[cfg(feature = "async")] +use {crate::io::{AsyncSend, AsyncReceive}, std::future::Future}; + +#[derive(Debug)] +enum OutgoingBody { + Sized(usize), + None, +} + +/// Synchronous HTTP client handler. +/// +/// Calling I/O operations on this will block. Refer to [`AsyncClient`] (with enabled `async` +/// feature) for the asynchronous equivalent. +/// +/// You can send request headers and receive responses from the server. Sending requests with body +/// is supported, but relies on the caller to send the body correctly - inferium does not force you +/// to respect the HTTP protocol completely. 
+/// +/// This structure owns the underlying stream and can send multiple requests and responses to the +/// other endpoint no matter the protocol (inferium will allow you to send multiple requests in a +/// single HTTP/1.0 connection if you wish to). +#[derive(Debug)] +pub struct SyncClient { + handler: StreamHandler, + should_send_body: OutgoingBody, +} + +/// Synchronous HTTP server handler. +/// +/// Calling I/O operations on this will block. Refer to [`AsyncServer`] (with enabled `async` +/// feature) for the asynchronous equivalent. +/// +/// You can receive requests and send response headers to the client. When sending responses with +/// body, the caller must send the body correctly - inferium does not force you to respect the HTTP +/// protocol completely. +/// +/// This structure owns the underlying stream and can send multiple requests and responses to the +/// other endpoint no matter the protocol (inferium will allow you to send multiple responses in a +/// single HTTP/1.0 connection if you wish to). +#[derive(Debug)] +pub struct SyncServer { + handler: StreamHandler, + should_send_body: OutgoingBody, +} + +/// Asynchronous HTTP client handler. +/// +/// Calling I/O operations on this will not block and return a future. Refer to [`SyncClient`] +/// for the synchronous equivalent. +/// +/// You can send request headers and receive responses from the server. Sending requests with body +/// is supported, but relies on the caller to send the body correctly - inferium does not force you +/// to respect the HTTP protocol completely. +/// +/// This structure owns the underlying stream and can send multiple requests and responses to the +/// other endpoint no matter the protocol (inferium will allow you to send multiple requests in a +/// single HTTP/1.0 connection if you wish to). +#[cfg(feature = "async")] +#[derive(Debug)] +pub struct AsyncClient { + handler: StreamHandler, + should_send_body: OutgoingBody, +} + +/// Asynchronous HTTP server handler. 
+/// +/// Calling I/O operations on this will not block and return a future. Refer to [`SyncServer`] +/// for the synchronous equivalent. +/// +/// You can receive requests and send response headers to the client. When sending responses with +/// body, the caller must send the body correctly - inferium does not force you to respect the HTTP +/// protocol completely. +/// +/// This structure owns the underlying stream and can send multiple requests and responses to the +/// other endpoint no matter the protocol (inferium will allow you to send multiple responses in a +/// single HTTP/1.0 connection if you wish to). +#[cfg(feature = "async")] +#[derive(Debug)] +pub struct AsyncServer { + handler: StreamHandler, + should_send_body: OutgoingBody, +} + +macro_rules! autoimpl_new { +($for: ident [$(#$attr: tt $of_type: ty),*$(,)?]) => { + $(#$attr impl $for<$of_type> { + pub fn new(stream: $of_type) -> Self { + Self { handler: StreamHandler::new(stream), should_send_body: OutgoingBody::None } + } + })* +}; +($for: ident [$($of_type: ty),*$(,)?]) => { + $(impl $for<$of_type> { + pub fn new(stream: $of_type) -> Self { + Self { handler: StreamHandler::new(stream), should_send_body: OutgoingBody::None } + } + })* +}; +($($for: ident $impl_list: tt),+$(,)?) => { + $(autoimpl_new!($for $impl_list);)+ +}; +} + +autoimpl_new! 
{ + // All supported synchronous I/O streams + SyncClient [ io::StdInet, io::StdUnix ], + SyncServer [ io::StdInet, io::StdUnix ], + + // All supported asynchronous I/O streams + AsyncClient [ + #[cfg(all(feature = "async", feature = "tokio-net"))] io::TokioInet, + #[cfg(all(feature = "async", feature = "tokio-unixsocks"))] io::TokioUnix, + #[cfg(all(feature = "async", feature = "tokio-tls"))] io::TokioRustls, + ], + AsyncServer [ + #[cfg(all(feature = "async", feature = "tokio-net"))] io::TokioInet, + #[cfg(all(feature = "async", feature = "tokio-unixsocks"))] io::TokioUnix, + #[cfg(all(feature = "async", feature = "tokio-tls"))] io::TokioRustls, + ], +} + +#[cfg(any(feature = "testing", test))] +impl<'a, const CHUNK_SIZE: usize> SyncClient> { + pub fn new(stream: io::TestSyncStream<'a, CHUNK_SIZE>) -> Self { + Self { handler: StreamHandler::new(stream), should_send_body: OutgoingBody::None } + } +} + +#[cfg(all(any(feature = "testing", test), feature = "async"))] +impl<'a, const CHUNK_SIZE: usize> AsyncClient> { + pub fn new(stream: io::TestAsyncStream<'a, CHUNK_SIZE>) -> Self { + Self { handler: StreamHandler::new(stream), should_send_body: OutgoingBody::None } + } +} + +#[cfg(any(feature = "testing", test))] +impl<'a, const CHUNK_SIZE: usize> SyncServer> { + pub fn new(stream: io::TestSyncStream<'a, CHUNK_SIZE>) -> Self { + Self { handler: StreamHandler::new(stream), should_send_body: OutgoingBody::None } + } +} + +#[cfg(all(any(feature = "testing", test), feature = "async"))] +impl<'a, const CHUNK_SIZE: usize> AsyncServer> { + pub fn new(stream: io::TestAsyncStream<'a, CHUNK_SIZE>) -> Self { + Self { handler: StreamHandler::new(stream), should_send_body: OutgoingBody::None } + } +} + +#[derive(Debug)] +pub enum ClientSendError { + /// This is a usage error. + /// + /// It is returned when the caller fails to either send the required body or receive the + /// advertised body from the other endpoint, and thus the protocol state is violated. 
+ StateViolated, + InvalidContentLength, + IO(std::io::Error), +} + +#[derive(Debug)] +pub enum ServerSendError { + /// This is a usage error. + /// + /// It is returned when the caller fails to either send the required body or receive the + /// advertised body from the other endpoint, and thus the protocol state is violated. + StateViolated, + InvalidContentLength, + IO(std::io::Error), +} + +impl From for ClientSendError { + fn from(value: StreamHandlerSendError) -> Self { + match value { + StreamHandlerSendError::RequiresBodyPolling => ClientSendError::StateViolated, + StreamHandlerSendError::IO(e) => ClientSendError::IO(e), + } + } +} + +impl From for ServerSendError { + fn from(value: StreamHandlerSendError) -> Self { + match value { + StreamHandlerSendError::RequiresBodyPolling => ServerSendError::StateViolated, + StreamHandlerSendError::IO(e) => ServerSendError::IO(e), + } + } +} + +#[derive(Debug)] +pub enum ClientReceiveError { + /// This is a usage error. + /// + /// It is returned when the caller fails to either send the required body or receive the + /// advertised body from the other endpoint, and thus the protocol state is violated. + StateViolated, + /// If the response head being received is too large to fit into the pre-allocated buffer. + HeadTooLarge, + /// The response head could not be parsed. If additional details are known - they will be + /// contained in the inner value. + InvalidHead(Option), + IO(std::io::Error), +} + +#[derive(Debug)] +pub enum ServerReceiveError { + /// This is a usage error. + /// + /// It is returned when the caller fails to either send the required body or receive the + /// advertised body from the other endpoint, and thus the protocol state is violated. + StateViolated, + /// If the request head being received is too large to fit into the pre-allocated buffer. + HeadTooLarge, + /// The request head could not be parsed. If additional details are known - they will be + /// contained in the inner value. 
+ InvalidHead(Option), + IO(std::io::Error), +} + +impl From> for ClientReceiveError { + fn from(value: StreamHandlerReceiveError) -> Self { + match value { + StreamHandlerReceiveError::IO(e) => Self::IO(e), + StreamHandlerReceiveError::NoData => Self::InvalidHead(None), + StreamHandlerReceiveError::ParsingError(e) => Self::InvalidHead(Some(e)), + StreamHandlerReceiveError::HeaderTooLarge => Self::HeadTooLarge, + StreamHandlerReceiveError::RequiresBodyPolling => Self::StateViolated, + StreamHandlerReceiveError::InvalidExpectedBody => Self::InvalidHead(None), + } + } +} + +impl From> for ServerReceiveError { + fn from(value: StreamHandlerReceiveError) -> Self { + match value { + StreamHandlerReceiveError::IO(e) => Self::IO(e), + StreamHandlerReceiveError::NoData => Self::InvalidHead(None), + StreamHandlerReceiveError::ParsingError(e) => Self::InvalidHead(Some(e)), + StreamHandlerReceiveError::HeaderTooLarge => Self::HeadTooLarge, + StreamHandlerReceiveError::RequiresBodyPolling => Self::StateViolated, + StreamHandlerReceiveError::InvalidExpectedBody => Self::InvalidHead(None), + } + } +} + +/// A generic error possibly returned when sending a body. +#[derive(Debug)] +pub enum BodySendError { + /// The real body input length to send did not match the body length advertised in the headers. 
+ LengthDiscrepancy, + IO(std::io::Error), +} + +impl TryFrom for ServerSendError { + type Error = (); + + fn try_from(value: BodySendError) -> Result { + match value { + BodySendError::LengthDiscrepancy => Err(()), + BodySendError::IO(e) => Ok(Self::IO(e)), + } + } +} + +impl TryFrom for ClientSendError { + type Error = (); + + fn try_from(value: BodySendError) -> Result { + match value { + BodySendError::LengthDiscrepancy => Err(()), + BodySendError::IO(e) => Ok(Self::IO(e)), + } + } +} + +impl From for BodySendError { + fn from(value: crate::body::SendError) -> Self { + match value { + crate::body::SendError::LengthDiscrepancy => Self::LengthDiscrepancy, + crate::body::SendError::IO(e) => Self::IO(e), + } + } +} + +impl From for BodySendError { + fn from(value: std::io::Error) -> Self { + Self::IO(value) + } +} + +/// A received response with possible incoming body. +#[derive(Debug, PartialEq)] +pub enum Response<'a, T> { + /// The response does not advertise an incoming body in any known way. + HeadersOnly(ResponseHead), + /// The response advertised a body using the `content-length` header. + WithSizedBody((ResponseHead, Incoming<'a, SizedIn<'a, PrependableStream>>)), + /// The response advertised a chunked body using the `transfer-encoding` header. + WithChunkedBody((ResponseHead, Incoming<'a, ChunkedIn<'a, PrependableStream>>)), +} + +/// A received request with possible incoming body. +pub enum Request<'a, T> { + /// The request does not advertise an incoming body in any known way. + HeadersOnly(RequestHead), + /// The request advertised a body using the `content-length` header. + WithSizedBody((RequestHead, Incoming<'a, SizedIn<'a, PrependableStream>>)), + /// The request advertised a chunked body using the `transfer-encoding` header. 
+ WithChunkedBody((RequestHead, Incoming<'a, ChunkedIn<'a, PrependableStream>>)), +} + +fn get_outgoing_req_content_length(head: &RequestHead) -> Result, ClientSendError> { + let Some(l) = head.headers.get(&HeaderKey::CONTENT_LENGTH) else { + return Ok(None); + }; + Ok(Some(l.get().parse().map_err(|_| ClientSendError::InvalidContentLength)?)) +} + +fn get_outgoing_res_content_length(head: &ResponseHead) -> Result, ServerSendError> { + let Some(l) = head.headers.get(&HeaderKey::CONTENT_LENGTH) else { + return Ok(None); + }; + Ok(Some(l.get().parse().map_err(|_| ServerSendError::InvalidContentLength)?)) +} + +fn send_sync<'a, D: Iterator, T: Send>( + data_source: &'a mut D, + stream: &'a mut PrependableStream, + length: usize, +) -> Result<(), crate::body::SendError> { + Outgoing::new(SizedOut::new(data_source, stream, length)).send_all() +} + +#[cfg(feature = "async")] +async fn send_async<'a, F: Future, D: Iterator, T: AsyncSend>( + data_source: &'a mut D, + stream: &'a mut PrependableStream, + length: usize, +) -> Result<(), crate::body::SendError> { + Outgoing::new_async(SizedOut::new(data_source, stream, length)).send_all_async().await +} + +fn send_sync_slc( + data_source: &[u8], stream: &mut PrependableStream +) -> Result<(), BodySendError> { + let mut ptr = 0; + while ptr < data_source.len() { + ptr += stream.send(&data_source[ptr..])?; + } + Ok(()) +} + +#[cfg(feature = "async")] +async fn send_async_slc( + data_source: &[u8], stream: &mut PrependableStream +) -> Result<(), BodySendError> { + let mut ptr = 0; + while ptr < data_source.len() { + ptr += stream.send(&data_source[ptr..]).await?; + } + Ok(()) +} + +impl SyncClient { + /// Send an HTTP request to the other endpoint (ideally a server). + /// + /// If the `content-length` header is set, this will require you (the caller) to send a body + /// with the advertised content length before sending the next request. 
+ pub fn send_request(&mut self, req_head: &RequestHead) -> Result<(), ClientSendError> { + match self.should_send_body { + OutgoingBody::None => {}, + _ => return Err(ClientSendError::StateViolated), + } + if let Some(l) = get_outgoing_req_content_length(req_head)? { + self.should_send_body = OutgoingBody::Sized(l); + } + Ok(self.handler.send_request(req_head)?) + } + + /// Send the body to the other endpoint. + /// + /// Note that this will only consume the iterator until the desired size is sent. + /// + /// If no body should be sent (no prior request has advertised a successive body), this will + /// immediatelly return with an empty Ok - will not send anything nor consume anything from the + /// iterator. + /// + /// If you wish to send a body from a loaded source (not an iterator) - refer to + /// [`SyncClient::send_body_bytes`]. + pub fn send_body<'a, D: Iterator>(&'a mut self, data_source: &'a mut D) -> + Result<(), BodySendError> { + match self.should_send_body { + OutgoingBody::Sized(l) => send_sync(data_source, &mut self.handler.inner, l)?, + OutgoingBody::None => {}, + } + self.should_send_body = OutgoingBody::None; + Ok(()) + } + + /// Send the body to the other endpoint. + /// + /// If no body should be sent (no prior request has advertised a successive body), this will + /// immediatelly return with an empty Ok - will not send anything nor consume anything from the + /// iterator. + /// + /// If you wish to send a body from a streamed source (an iterator) - refer to + /// [`SyncClient::send_body`]. 
+ pub fn send_body_bytes(&mut self, data_source: &[u8]) -> Result<(), BodySendError> { + let advertised_length = match self.should_send_body { + OutgoingBody::Sized(l) => l, + OutgoingBody::None => return Ok(()), + }; + if advertised_length != data_source.len() { + return Err(BodySendError::LengthDiscrepancy); + } + send_sync_slc(data_source, &mut self.handler.inner)?; + self.should_send_body = OutgoingBody::None; + Ok(()) + } + + /// Attempt to receive an HTTP response from the other endpoint (ideally a server). + /// + /// If a body is advertised by the other endpoint, you (the caller) will then have to poll the + /// returned body object until the expected content length is consumed before receiving another + /// response from the endpoint. + pub fn receive_response(&mut self) -> Result, ClientReceiveError> { + let received = self.handler.receive_response()?; + Ok(match received.1 { + None => Response::HeadersOnly(received.0), + Some(ExpectedBody::Sized(b)) => Response::WithSizedBody((received.0, b)), + Some(ExpectedBody::Chunked(b)) => Response::WithChunkedBody((received.0, b)), + }) + } +} + +impl SyncServer { + /// Send an HTTP response to the other endpoint (ideally a client). + /// + /// If the `content-length` header is set, this will require you (the caller) to send a body + /// with the advertised content length before sending the next response. + pub fn send_response(&mut self, req_head: &ResponseHead) -> Result<(), ServerSendError> { + match self.should_send_body { + OutgoingBody::None => {}, + _ => return Err(ServerSendError::StateViolated), + } + if let Some(l) = get_outgoing_res_content_length(req_head)? { + self.should_send_body = OutgoingBody::Sized(l); + } + Ok(self.handler.send_response(req_head)?) + } + + /// Send the body to the other endpoint. + /// + /// Note that this will only consume the iterator until the desired size is sent. 
+ /// + /// If no body should be sent (no prior response has advertised a successive body), this will + /// immediatelly return with an empty Ok - will not send anything nor consume anything from the + /// iterator. + /// + /// If you wish to send a body from a loaded source (not an iterator) - refer to + /// [`SyncServer::send_body_bytes`]. + pub fn send_body<'a, D: Iterator>(&'a mut self, data_source: &'a mut D) -> + Result<(), BodySendError> { + match self.should_send_body { + OutgoingBody::Sized(l) => send_sync(data_source, &mut self.handler.inner, l)?, + OutgoingBody::None => (), + } + self.should_send_body = OutgoingBody::None; + Ok(()) + } + + /// Send the body to the other endpoint. + /// + /// If no body should be sent (no prior response has advertised a successive body), this will + /// immediatelly return with an empty Ok - will not send anything nor consume anything from the + /// iterator. + /// + /// If you wish to send a body from a streamed source (an iterator) - refer to + /// [`SyncServer::send_body`]. + pub fn send_body_bytes(&mut self, data_source: &[u8]) -> Result<(), BodySendError> { + let advertised_length = match self.should_send_body { + OutgoingBody::Sized(l) => l, + OutgoingBody::None => return Ok(()), + }; + if advertised_length != data_source.len() { + return Err(BodySendError::LengthDiscrepancy); + } + send_sync_slc(data_source, &mut self.handler.inner)?; + self.should_send_body = OutgoingBody::None; + Ok(()) + } + + /// Attempt to receive an HTTP request from the other endpoint (ideally a client). + /// + /// If a body is advertised by the other endpoint, you (the caller) will then have to poll the + /// returned body object until the expected content length is consumed before receiving another + /// request from the endpoint. 
+ pub fn receive_request(&mut self) -> Result, ServerReceiveError> { + let received = self.handler.receive_request()?; + Ok(match received.1 { + None => Request::HeadersOnly(received.0), + Some(ExpectedBody::Sized(b)) => Request::WithSizedBody((received.0, b)), + Some(ExpectedBody::Chunked(b)) => Request::WithChunkedBody((received.0, b)), + }) + } +} + +#[cfg(feature = "async")] +impl AsyncClient { + /// Send an HTTP request to the other endpoint (ideally a server). + /// + /// If the `content-length` header is set, this will require you (the caller) to send a body + /// with the advertised content length before sending the next request. + pub async fn send_request(&mut self, req_head: &RequestHead) -> Result<(), ClientSendError> { + match self.should_send_body { + OutgoingBody::None => {}, + _ => return Err(ClientSendError::StateViolated), + } + if let Some(l) = get_outgoing_req_content_length(req_head)? { + self.should_send_body = OutgoingBody::Sized(l); + } + Ok(self.handler.send_request_async(req_head).await?) + } + + /// Send the body to the other endpoint. + /// + /// Note that this will only consume the iterator until the desired size is sent. + /// + /// If no body should be sent (no prior request has advertised a successive body), this will + /// immediatelly return with an empty Ok - will not send anything nor consume anything from the + /// iterator. + /// + /// If you wish to send a body from a loaded source (not an iterator) - refer to + /// [`AsyncClient::send_body_bytes`]. + pub async fn send_body<'a, F: Future, D: Iterator>( + &'a mut self, data_source: &'a mut D + ) -> Result<(), BodySendError> { + match self.should_send_body { + OutgoingBody::Sized(l) => send_async(data_source, &mut self.handler.inner, l).await?, + OutgoingBody::None => (), + } + Ok(()) + } + + /// Send the body to the other endpoint. 
+ /// + /// If no body should be sent (no prior request has advertised a successive body), this will + /// immediatelly return with an empty Ok - will not send anything nor consume anything from the + /// iterator. + /// + /// If you wish to send a body from a streamed source (an iterator) - refer to + /// [`AsyncClient::send_body`]. + pub async fn send_body_bytes(&mut self, data_source: &[u8]) -> Result<(), BodySendError> { + let advertised_length = match self.should_send_body { + OutgoingBody::Sized(l) => l, + OutgoingBody::None => return Ok(()), + }; + if advertised_length != data_source.len() { + return Err(BodySendError::LengthDiscrepancy); + } + send_async_slc(data_source, &mut self.handler.inner).await?; + self.should_send_body = OutgoingBody::None; + Ok(()) + } + + /// Attempt to receive an HTTP response from the other endpoint (ideally a server). + /// + /// If a body is advertised by the other endpoint, you (the caller) will then have to poll the + /// returned body object until the expected content length is consumed before receiving another + /// response from the endpoint. + pub async fn receive_response(&mut self) -> Result, ClientReceiveError> { + let received = self.handler.receive_response_async().await?; + Ok(match received.1 { + None => Response::HeadersOnly(received.0), + Some(ExpectedBody::Sized(b)) => Response::WithSizedBody((received.0, b)), + Some(ExpectedBody::Chunked(b)) => Response::WithChunkedBody((received.0, b)), + }) + } +} + +#[cfg(feature = "async")] +impl AsyncServer { + /// Send an HTTP response to the other endpoint (ideally a client). + /// + /// If the `content-length` header is set, this will require you (the caller) to send a body + /// with the advertised content length before sending the next response. 
+ pub async fn send_response(&mut self, req_head: &ResponseHead) -> Result<(), ServerSendError> { + match self.should_send_body { + OutgoingBody::None => {}, + _ => return Err(ServerSendError::StateViolated), + } + if let Some(l) = get_outgoing_res_content_length(req_head)? { + self.should_send_body = OutgoingBody::Sized(l); + } + Ok(self.handler.send_response_async(req_head).await?) + } + + /// Send the body to the other endpoint. + /// + /// Note that this will only consume the iterator until the desired size is sent. + /// + /// If no body should be sent (no prior response has advertised a successive body), this will + /// immediatelly return with an empty Ok - will not send anything nor consume anything from the + /// iterator. + /// + /// If you wish to send a body from a loaded source (not an iterator) - refer to + /// [`AsyncServer::send_body_bytes`]. + pub async fn send_body<'a, F: Future, D: Iterator>( + &'a mut self, data_source: &'a mut D + ) -> Result<(), BodySendError> { + match self.should_send_body { + OutgoingBody::Sized(l) => send_async(data_source, &mut self.handler.inner, l).await?, + OutgoingBody::None => (), + } + self.should_send_body = OutgoingBody::None; + Ok(()) + } + + /// Send the body to the other endpoint. + /// + /// If no body should be sent (no prior response has advertised a successive body), this will + /// immediatelly return with an empty Ok - will not send anything nor consume anything from the + /// iterator. + /// + /// If you wish to send a body from a streamed source (an iterator) - refer to + /// [`AsyncServer::send_body`]. 
+ pub async fn send_body_bytes(&mut self, data_source: &[u8]) -> Result<(), BodySendError> { + let advertised_length = match self.should_send_body { + OutgoingBody::Sized(l) => l, + OutgoingBody::None => return Ok(()), + }; + if advertised_length != data_source.len() { + return Err(BodySendError::LengthDiscrepancy); + } + send_async_slc(data_source, &mut self.handler.inner).await?; + self.should_send_body = OutgoingBody::None; + Ok(()) + } + + /// Attempt to receive an HTTP request from the other endpoint (ideally a client). + /// + /// If a body is advertised by the other endpoint, you (the caller) will then have to poll the + /// returned body object until the expected content length is consumed before receiving another + /// request from the endpoint. + pub async fn receive_request(&mut self) -> Result, ServerReceiveError> { + let received = self.handler.receive_request_async().await?; + Ok(match received.1 { + None => Request::HeadersOnly(received.0), + Some(ExpectedBody::Sized(b)) => Request::WithSizedBody((received.0, b)), + Some(ExpectedBody::Chunked(b)) => Request::WithChunkedBody((received.0, b)), + }) + } +} diff --git a/lib/inferium/src/proto/h1/head.rs b/lib/inferium/src/proto/h1/head.rs new file mode 100644 index 0000000..d5e3290 --- /dev/null +++ b/lib/inferium/src/proto/h1/head.rs @@ -0,0 +1,618 @@ +use std::collections::HashMap; +use crate::{ + headers::{HeaderKey, HeaderValue}, + method::Method, + path::{HttpPath, HttpPathParseError}, + status::Status +}; + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum ProtocolVariant { + HTTP1_0, + HTTP1_1, +} + +impl std::fmt::Display for ProtocolVariant { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::HTTP1_0 => write!(f, "HTTP/1.0"), + Self::HTTP1_1 => write!(f, "HTTP/1.1"), + } + } +} + +impl ProtocolVariant { + pub(crate) fn text(&self) -> &'static [u8] { + match self { + Self::HTTP1_1 => b"HTTP/1.1", + Self::HTTP1_0 => b"HTTP/1.0", + } + } +} + +impl 
TryFrom<&[u8]> for ProtocolVariant { + type Error = (); + + fn try_from(value: &[u8]) -> Result { + match value { + b"HTTP/1.0" => Ok(Self::HTTP1_0), + b"HTTP/1.1" => Ok(Self::HTTP1_1), + _ => Err(()), + } + } +} + +fn serialize_header(header: (&HeaderKey, &HeaderValue)) -> Vec { + let mut res = Vec::new(); + for value in header.1.all() { + res.extend_from_slice(header.0.text()); + res.extend_from_slice(b": "); + res.extend_from_slice(value.as_bytes()); + res.extend_from_slice(b"\r\n"); + } + res +} + +fn format_header(header: (&HeaderKey, &HeaderValue)) -> String { + let mut res = String::new(); + for value in header.1.all() { + res.push_str(&format!("{}: {}", header.0, value)); + } + res +} + +#[cfg_attr(test, derive(PartialEq))] +#[derive(Debug)] +pub struct RequestHead { + pub(crate) method: Method, + pub(crate) path: HttpPath, + pub(crate) protocol: ProtocolVariant, + pub(crate) headers: HashMap, +} + +impl RequestHead { + #[inline] + pub fn method(&self) -> &Method { + &self.method + } + + #[inline] + pub fn uri(&self) -> &HttpPath { + &self.path + } + + #[inline] + pub fn proto(&self) -> &ProtocolVariant { + &self.protocol + } + + #[inline] + pub fn headers(&self) -> &HashMap { + &self.headers + } +} + +impl std::fmt::Display for RequestHead { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "{} {} {}", self.method, self.path, self.protocol)?; + for entry in self.headers.iter() { + writeln!(f, "{}", format_header(entry))?; + } + Ok(()) + } +} + +impl RequestHead { + pub fn new( + method: Method, + path: HttpPath, + protocol: ProtocolVariant, + headers: HashMap, + ) -> Self { + Self { method, path, protocol, headers } + } +} + +impl RequestHead { + pub(crate) fn serialize(&self) -> Vec { + let mut res = Vec::new(); + res.extend_from_slice(self.method.text()); + res.push(b' '); + res.extend_from_slice(&self.path.serialize()); + res.push(b' '); + res.extend_from_slice(self.protocol.text()); + res.extend_from_slice(b"\r\n"); + 
for entry in self.headers.iter() { + res.extend_from_slice(&serialize_header(entry)); + } + res.extend_from_slice(b"\r\n"); + res + } +} + +#[cfg_attr(test, derive(PartialEq))] +#[derive(Debug)] +pub enum BadRequest { + HeadLine, + InvalidPath(HttpPathParseError), + InvalidMethod, + InvalidProtocol, + InvalidHeaders, +} + +impl From for BadRequest { + fn from(value: HttpPathParseError) -> Self { + Self::InvalidPath(value) + } +} + +pub(super) fn parse_request_head(raw: &[u8]) -> Result { + const PAT_SPACE: &[u8] = b" "; + const JUT_SPACE: &[usize] = &[0]; + const PAT_CRLF: &[u8] = b"\r\n"; + const JUT_CRLF: &[usize] = &[0, 0]; + + // Method + let Some(meth_end) = try_find(raw, PAT_SPACE, JUT_SPACE, &0) else { + return Err(BadRequest::HeadLine); + }; + let meth: Method = std::str::from_utf8(&raw[..meth_end]) + .map_err(|_| BadRequest::InvalidMethod)?.parse().map_err(|_| BadRequest::InvalidMethod)?; + + // Path + let path_start = meth_end + PAT_SPACE.len(); + let Some(path_end) = try_find(raw, PAT_SPACE, JUT_SPACE, &path_start) else { + return Err(BadRequest::HeadLine); + }; + let path = HttpPath::try_from(&raw[path_start..path_end])?; + + // Protocol + let proto_start = path_end + PAT_SPACE.len(); + let Some(proto_end) = try_find(raw, PAT_CRLF, JUT_CRLF, &proto_start) else { + return Err(BadRequest::HeadLine); + }; + let proto: ProtocolVariant = (&raw[proto_start..proto_end]).try_into() + .map_err(|_| BadRequest::InvalidProtocol)?; + + // Headers + let mut res = HashMap::new(); + let mut header_start = proto_end + PAT_CRLF.len(); + loop { + let Some(header_end) = try_find(raw, PAT_CRLF, JUT_CRLF, &header_start) else { + return Err(BadRequest::InvalidHeaders); + }; + if header_start == header_end { + break; + } + parse_header(&raw[header_start..header_end], &mut res) + .map_err(|_| BadRequest::InvalidHeaders)?; + header_start = header_end + PAT_CRLF.len(); + } + + Ok(RequestHead { + method: meth, + path, + protocol: proto, + headers: res + }) +} + + + 
+#[derive(Debug, PartialEq)] +pub struct ResponseHead { + pub(crate) status: Status, + pub(crate) protocol: ProtocolVariant, + pub(crate) headers: HashMap, +} + +impl ResponseHead { + #[inline] + pub fn status(&self) -> &Status { + &self.status + } + + #[inline] + pub fn proto(&self) -> &ProtocolVariant { + &self.protocol + } + + #[inline] + pub fn headers(&self) -> &HashMap { + &self.headers + } +} + +impl ResponseHead { + pub fn new( + status: Status, protocol: ProtocolVariant, headers: HashMap + ) -> Self { + Self { status, protocol, headers } + } +} + +impl std::fmt::Display for ResponseHead { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "{} {}", self.protocol, self.status)?; + for entry in self.headers.iter() { + writeln!(f, "{}", format_header(entry))?; + } + Ok(()) + } +} + +impl ResponseHead { + pub(crate) fn serialize(&self) -> Vec { + let mut res = Vec::new(); + res.extend_from_slice(self.protocol.text()); + res.push(b' '); + res.extend_from_slice(self.status.num()); + res.push(b' '); + res.extend_from_slice(self.status.text()); + res.extend_from_slice(b"\r\n"); + for header in self.headers.iter() { + res.extend_from_slice(&serialize_header(header)); + } + res.extend_from_slice(b"\r\n"); + res + } +} + +#[cfg_attr(test, derive(PartialEq))] +#[derive(Debug)] +pub enum BadResponse { + /// The headline (parseable example: `HTTP/1.1 200 OK`) could not be parsed. + HeadLine, + /// The response status code is unknown. + InvalidStatusCode, + /// The advertised protocol is not supported. + InvalidProtocol, + /// Some headers are invalid. Either they have unknown keys or invalid syntax. 
+ InvalidHeaders, +} + +pub(super) fn parse_response_head(raw: &[u8]) -> Result { + const PAT_SPACE: &[u8] = b" "; + const JUT_SPACE: &[usize] = &[0]; + const PAT_CRLF: &[u8] = b"\r\n"; + const JUT_CRLF: &[usize] = &[0, 0]; + + // Protocol + let Some(proto_end) = try_find(raw, PAT_SPACE, JUT_SPACE, &0) else { + return Err(BadResponse::HeadLine); + }; + let proto: ProtocolVariant = (&raw[..proto_end]).try_into() + .map_err(|_| BadResponse::InvalidProtocol)?; + + // Status code + let status_start = proto_end + PAT_SPACE.len(); + let Some(status_end) = try_find(raw, PAT_SPACE, JUT_SPACE, &status_start) else { + return Err(BadResponse::HeadLine); + }; + let status: Status = Status::try_from(&raw[status_start..status_end]) + .map_err(|_| BadResponse::InvalidStatusCode)?; + + let Some(headline_end) = try_find(raw, PAT_CRLF, JUT_CRLF, &(status_end+PAT_SPACE.len())) else { + return Err(BadResponse::HeadLine); + }; + + // Headers + let mut res = HashMap::new(); + let mut header_start = headline_end + PAT_CRLF.len(); + loop { + let Some(header_end) = try_find(raw, PAT_CRLF, JUT_CRLF, &header_start) else { + return Err(BadResponse::InvalidHeaders); + }; + if header_start == header_end { + break; + } + parse_header(&raw[header_start..header_end], &mut res) + .map_err(|_| BadResponse::InvalidHeaders)?; + header_start = header_end + PAT_CRLF.len(); + } + + Ok(ResponseHead { + status, + protocol: proto, + headers: res + }) +} + +fn parse_header( + raw: &[u8], + res: &mut HashMap, +) -> Result<(), ()> { + let delim = try_find(raw, b": ", &[0, 0], &0).ok_or(())?; + let hkey: HeaderKey = std::str::from_utf8(&raw[..delim]).map_err(|_| ())?.into(); + let hval: String = String::from_utf8(raw[delim+2..].to_vec()).map_err(|_| ())?; + if hval.is_empty() { + return Err(()); + } + let entry = match res.get_mut(&hkey) { + Some(v) => v, + None => { + res.insert(hkey.clone(), HeaderValue::default()); + res.get_mut(&hkey).unwrap() + } + }; + entry.add(hval); + Ok(()) +} + +fn try_find( + 
haystack: &[u8], + pat: &[u8], + jumptable: &[usize], + start_from: &usize, +) -> Option { + let mut pat_ptr = 0_usize; + for (idx, cur) in haystack[*start_from..].iter().enumerate() { + if pat_ptr >= pat.len() { + return Some(idx + start_from - pat.len()); + } + if *cur == pat[pat_ptr] { + pat_ptr += 1; + continue; + } + try_find_update_jumptable(cur, pat, &mut pat_ptr, jumptable); + } + if pat_ptr >= pat.len() { + return Some(haystack.len() - pat.len()); + } + None +} + +fn try_find_update_jumptable( + hay: &u8, + pat: &[u8], + pat_ptr: &mut usize, + jumptable: &[usize], +) { + while *pat_ptr != 0 { + *pat_ptr = jumptable[*pat_ptr]; + if *hay == pat[*pat_ptr] { + *pat_ptr += 1; + break; + } + } +} + +#[cfg(test)] +mod patfind { + use super::try_find; + + macro_rules! test { + ($name: ident, $src: literal, $pat: literal, $jt:tt, $start_from:literal, $res:expr) => { + #[test] + fn $name() { + let src = $src.as_bytes().to_vec(); + let src = src.as_slice(); + let pat = $pat.as_bytes().to_vec(); + let pat = pat.as_slice(); + let jt = vec!$jt; + let jt = jt.as_slice(); + assert_eq!(try_find(src, pat, jt, &$start_from), $res); + } + } + } + + test!(valid_singlebyte, "Hello, world!", ",", [0], 0, Some(5)); + test!(valid_multibyte, "Hello, world!", ", ", [0, 0], 0, Some(5)); + test!(valid_begin_nostart, "Hello, world!", ", ", [0, 0], 3, Some(5)); + test!(valid_begin_startpat, "Hello, world!", ", ", [0, 0], 5, Some(5)); + test!(invalid_begin_midpat, "Hello, world!", ", ", [0, 0], 6, None); + test!(recurse_jumptable_01, "AAAAAB", "AAAB", [0, 0, 1, 2], 0, Some(2)); + test!(recurse_jumptable_02, "ABABABC", "ABABC", [0, 0, 0, 1, 2], 0, Some(2)); +} + +#[cfg(test)] +mod parse_request_head { + use std::collections::HashMap; + + use crate::{ + headers::{HeaderKey, HeaderValue}, + method::Method, + path::HttpPath, + proto::h1::head::{BadRequest, ProtocolVariant} + }; + use super::{parse_request_head, RequestHead}; + + macro_rules! 
test_inner { + (@get_params [$($pk: ident : $pv: literal),+]$(,)?) => { + Some(HashMap::from([$((stringify!($pk).to_string(), $pv.to_string())),*])) + }; + + (@get_params []) => { + None + }; + } + + macro_rules! test { + ( + $name: ident, + $src: literal, + ok $method: ident $path: literal ? [$($pk: ident : $pv: literal),*$(,)?] $proto: ident, + [$($hk: ident : $hv: literal),*$(,)?] + ) => { + test!($name, $src, Ok(RequestHead { + method: Method::$method, + path: HttpPath{ path:$path.to_string(), params:test_inner!(@get_params[$($pk:$pv),*]) }, + protocol: ProtocolVariant::$proto, + headers: HashMap::from([$((HeaderKey::$hk,HeaderValue::new(vec![$hv.to_string()]))),*]), + })); + }; + + ($name: ident, $src: literal, $res: expr) => { + #[test] + fn $name() { + let src = $src.as_bytes().to_vec(); + let src = src.as_slice(); + assert_eq!(parse_request_head(src), $res); + } + } + } + + test!(valid_request, "GET /hello HTTP/1.0\r\n\r\n", + ok GET "/hello"?[] HTTP1_0, + [] + ); + + test!(valid_request_single_header, "GET / HTTP/1.1\r\nUser-Agent: Mozilla/5.0\r\n\r\n", + ok GET "/"?[] HTTP1_1, + [ + USER_AGENT: "Mozilla/5.0", + ] + ); + + test!( + valid_request_multi_header, + "GET /api?action=test HTTP/1.1\r\nUser-Agent: Mozilla/5.0\r\nConnection: close\r\n\r\n", + ok GET "/api"?[action: "test"] HTTP1_1, + [ + USER_AGENT: "Mozilla/5.0", + CONNECTION: "close", + ] + ); + + test!(invalid_method, "GHOST / HTTP/1.1\r\n\r\n", Err(BadRequest::InvalidMethod)); +} + +#[cfg(test)] +mod parse_response_head { + use std::collections::HashMap; + + use crate::{ + headers::{HeaderKey, HeaderValue}, + proto::h1::head::{BadResponse, ProtocolVariant} + }; + use super::{parse_response_head, ResponseHead, Status}; + + macro_rules! test { + ( + $name: ident, + $src: literal, + ok $proto: ident $status: ident, + [$($hk: ident : $hv: literal),*$(,)?] 
// Continuation of the `test!` macro opened in the previous chunk: the `ok ...` shorthand
// builds the expected `ResponseHead`, then recurses into the generic arm below.
        ) => {
            test!($name, $src, Ok(ResponseHead {
                status: Status::$status,
                protocol: ProtocolVariant::$proto,
                headers: HashMap::from([$((HeaderKey::$hk,HeaderValue::new(vec![$hv.to_string()]))),*]),
            }));
        };

        // Generic arm: emit the actual `#[test]` comparing the parse result with `$res`.
        ($name: ident, $src: literal, $res: expr) => {
            #[test]
            fn $name() {
                let src = $src.as_bytes().to_vec();
                let src = src.as_slice();
                assert_eq!(parse_response_head(src), $res);
            }
        };
    }

    test!(valid_request, "HTTP/1.1 200 OK\r\n\r\n",
        ok HTTP1_1 Ok,
        []
    );

    test!(valid_request_single_header, "HTTP/1.0 200 OK\r\nContent-Type: text/html\r\n\r\n",
        ok HTTP1_0 Ok,
        [
            CONTENT_TYPE: "text/html",
        ]
    );

    test!(
        valid_request_multi_header,
        "HTTP/1.1 200 OK\r\nServer: inferium\r\nConnection: close\r\n\r\n",
        ok HTTP1_1 Ok,
        [
            SERVER: "inferium",
            CONNECTION: "close",
        ]
    );

    test!(invalid_protocol, "PROTO 200 OK\r\n\r\n", Err(BadResponse::InvalidProtocol));
    test!(invalid_status_code, "HTTP/1.1 42069 OK\r\n\r\n", Err(BadResponse::InvalidStatusCode));
}

/// Serialization tests: `RequestHead::serialize` / `ResponseHead::serialize`
/// (module continues in the next chunk).
#[cfg(test)]
mod construct_outgoing {
    use std::collections::HashMap;
    use crate::{headers::{HeaderKey, HeaderValue}, status::Status, path::HttpPath, method::Method};
    use super::{ResponseHead, RequestHead, ProtocolVariant};

    macro_rules!
// Continuation of `macro_rules!` opened in the previous chunk: `test!` builds serialization
// tests for outgoing request/response heads.
    test {
        // Build the expected headers map.
        (@headers [$($hk: ident : $hv: literal),*$(,)?]) => {
            HashMap::from([$((HeaderKey::$hk, HeaderValue::new(vec![$hv.to_string()]))),*])
        };

        // Build an `HttpPath` with query parameters.
        (@uri $path: literal [$($hk: ident : $hv: literal),+$(,)?]) => {
            HttpPath {
                path: $path.to_string(),
                params: Some(HashMap::from([$((stringify!($hk).to_string(), $hv.to_string())),+])),
            }
        };

        // Build an `HttpPath` without query parameters.
        (@uri $path: literal [$(,)?]) => {
            HttpPath {
                path: $path.to_string(),
                params: None,
            }
        };

        // `res ...` arm: serialize a ResponseHead and compare with `$target` bytes.
        ($name: ident, res $proto: ident $status: ident, [$($hk: ident : $hv: literal),*$(,)?],
            $target: literal
        ) => {
            #[test]
            fn $name() {
                let src = ResponseHead {
                    protocol: ProtocolVariant::$proto,
                    status: Status::$status,
                    headers: test!(@headers [$($hk:$hv),*]),
                };
                assert_eq!(src.serialize(), $target);
            }
        };

        // `req ...` arm: serialize a RequestHead and compare with `$target` bytes.
        (
            $name: ident,
            req $method: ident $path: literal [$($qk: ident : $qv: literal),*$(,)?] $proto: ident,
            [$($hk: ident : $hv: literal),*$(,)?],
            $target: literal
        ) => {
            #[test]
            fn $name() {
                let src = RequestHead {
                    method: Method::$method,
                    path: test!(@uri $path [$($qk : $qv),*]),
                    protocol: ProtocolVariant::$proto,
                    headers: test!(@headers [$($hk : $hv),*]),
                };
                assert_eq!(src.serialize(), $target);
            }
        };
    }

    test!(response_simple_noheaders, res HTTP1_1 Ok, [], b"HTTP/1.1 200 OK\r\n\r\n");
    test!(response_with_single_header_01, res HTTP1_0 NotFound, [
        SERVER: "inferium",
    ], b"HTTP/1.0 404 Not Found\r\nserver: inferium\r\n\r\n");
    test!(response_with_single_header_02, res HTTP1_1 Forbidden, [
        SERVER: "inferium",
    ], b"HTTP/1.1 403 Forbidden\r\nserver: inferium\r\n\r\n");

    test!(request_simple_noheaders, req GET "/"[] HTTP1_1, [], b"GET / HTTP/1.1\r\n\r\n");
    test!(request_simple_single_header, req GET "/"[] HTTP1_1, [
        USER_AGENT: "inferium",
    ], b"GET / HTTP/1.1\r\nuser-agent: inferium\r\n\r\n");
    test!(request_path_single_header, req GET "/well/hello/there"[] HTTP1_0, [
        USER_AGENT: "inferium",
    ], b"GET /well/hello/there HTTP/1.0\r\nuser-agent: inferium\r\n\r\n");
    test!(request_path_with_query_single_header, req GET "/well/hello/there"[
        hello: "world"
    ] HTTP1_0, [
        USER_AGENT: "inferium",
    ], b"GET /well/hello/there?hello=world HTTP/1.0\r\nuser-agent: inferium\r\n\r\n");
}
diff --git a/lib/inferium/src/proto/h1/mod.rs b/lib/inferium/src/proto/h1/mod.rs
new file mode 100644
index 0000000..288a300
--- /dev/null
+++ b/lib/inferium/src/proto/h1/mod.rs
@@ -0,0 +1,26 @@
// HTTP/1.x module root: wires together head parsing, the stream handler, and the
// public client/server exports.
mod head;
mod stream_handler;
mod exports;

pub use exports::{
    Request,
    Response,
    SyncClient,
    SyncServer,
    ClientSendError,
    ClientReceiveError,
    ServerSendError,
    ServerReceiveError,
    BodySendError,
};
#[cfg(feature = "async")]
pub use exports::{
    AsyncClient,
    AsyncServer,
};

pub use head::{
    RequestHead,
    ResponseHead,
};
pub use head::ProtocolVariant;
diff --git a/lib/inferium/src/proto/h1/stream_handler.rs b/lib/inferium/src/proto/h1/stream_handler.rs
new file mode 100644
index 0000000..6928366
--- /dev/null
+++ b/lib/inferium/src/proto/h1/stream_handler.rs
@@ -0,0 +1,600 @@
use crate::{
    body::{ChunkedIn, Incoming, SizedIn},
    headers::HeaderKey,
    io::{PrependableStream, ReaderError, ReaderValue, Receive, Send, SyncReader},
    proto::h1::head::{parse_request_head, parse_response_head},
    settings::BUF_SIZE_HEAD, HeaderValue
};
#[cfg(feature = "async")]
use crate::io::{AsyncReceive, AsyncSend, AsyncReader};
use super::head::{BadRequest, BadResponse, RequestHead, ResponseHead};

// Wraps a raw stream and tracks whether the previous message's body has been fully
// consumed before the next head may be read or written.
// NOTE(review): the generic parameter list appears to have been stripped in transit —
// presumably `StreamHandler<T>` over `PrependableStream<T>`; confirm against the repo.
#[derive(Debug)]
pub(super) struct StreamHandler {
    pub(super) inner: PrependableStream,
    /// Whether the caller has received the entire HTTP body from the other endpoint (if needed).
// Continuation of the `StreamHandler` struct opened in the previous chunk.
    has_exhausted_body: bool,
}

// NOTE(review): throughout this chunk several generic argument lists were lost in transit
// (e.g. `StreamHandlerReceiveError` taking a parse-error type parameter, `Option>>`,
// `Result,` and `Incoming::>::new`) — presumably `<...>` segments eaten by markup
// stripping. Tokens are preserved as found; restore the generics before compiling.

/// Error cases for receiving a message head; `ParsingError` wraps the head-specific
/// parse error (`BadRequest` / `BadResponse`).
#[cfg_attr(test, derive(Debug))]
pub(super) enum StreamHandlerReceiveError {
    HeaderTooLarge,
    RequiresBodyPolling,
    ParsingError(T),
    InvalidExpectedBody,
    NoData,
    IO(std::io::Error),
}

impl From for StreamHandlerReceiveError {
    fn from(value: BadRequest) -> Self {
        Self::ParsingError(value)
    }
}

impl From for StreamHandlerReceiveError {
    fn from(value: BadResponse) -> Self {
        Self::ParsingError(value)
    }
}

/// Error cases for sending a message head.
#[cfg_attr(test, derive(Debug))]
pub(super) enum StreamHandlerSendError {
    RequiresBodyPolling,
    IO(std::io::Error),
}

impl From for StreamHandlerSendError {
    fn from(value: std::io::Error) -> Self {
        Self::IO(value)
    }
}

impl StreamHandler {
    // A fresh handler starts with no pending body to drain.
    pub(super) fn new(inner: T) -> Self {
        Self { inner: PrependableStream::new(inner), has_exhausted_body: true }
    }
}

// Pushes over-read bytes (`$body`) back onto the stream and yields the head (`$head`).
macro_rules! leaky {
    ($stream: expr, $head: ident, $body: ident) => {{
        $stream.prepend_to_read($body);
        $head
    }};
}

/// Body framing detected from the head's headers: fixed `Content-Length` or chunked
/// `Transfer-Encoding`.
#[cfg_attr(test, derive(Debug))]
pub(super) enum ExpectedBody<'a, T> {
    Sized(Incoming<'a, SizedIn<'a, T>>),
    Chunked(Incoming<'a, ChunkedIn<'a, T>>),
}

type FullResponse<'a, T> = (ResponseHead, Option>>);
type FullRequest<'a, T> = (RequestHead, Option>>);

// Inspects the response head and, if it announces a body, marks the handler as having an
// unexhausted body and returns the matching `ExpectedBody` reader over `sh.inner`.
fn construct_response_body(
    sh: &mut StreamHandler,
    res: ResponseHead,
) -> Result, StreamHandlerReceiveError> {
    if let Some(transfer_encoding) = res.headers.get(&HeaderKey::TRANSFER_ENCODING) {
        // Only `Transfer-Encoding: chunked` is supported.
        if transfer_encoding != &HeaderValue::new(vec!["chunked".to_string()]) {
            return Err(StreamHandlerReceiveError::InvalidExpectedBody);
        }
        sh.has_exhausted_body = false;
        return Ok((res, Some(ExpectedBody::Chunked(Incoming::>::new(
            ChunkedIn::new(&mut sh.inner), &mut sh.has_exhausted_body
        )))));
    }
    if let Some(content_length) = res.headers.get(&HeaderKey::CONTENT_LENGTH) {
        // Reject non-numeric Content-Length values.
        let Ok(content_length): Result = content_length.get().parse() else {
            return Err(StreamHandlerReceiveError::InvalidExpectedBody);
        };
        sh.has_exhausted_body = false;
        return Ok((res, Some(ExpectedBody::Sized(Incoming::>::new(
            SizedIn::new(&mut sh.inner, content_length), &mut sh.has_exhausted_body
        )))));
    }
    // No body-framing headers: head only.
    Ok((res, None))
}

// Request-side twin of `construct_response_body` (same framing rules).
fn construct_request_body(
    sh: &mut StreamHandler,
    req: RequestHead,
) -> Result, StreamHandlerReceiveError> {
    if let Some(transfer_encoding) = req.headers.get(&HeaderKey::TRANSFER_ENCODING) {
        if transfer_encoding != &HeaderValue::new(vec!["chunked".to_string()]) {
            return Err(StreamHandlerReceiveError::InvalidExpectedBody);
        }
        sh.has_exhausted_body = false;
        return Ok((req, Some(ExpectedBody::Chunked(Incoming::>::new(
            ChunkedIn::new(&mut sh.inner), &mut sh.has_exhausted_body
        )))));
    }
    if let Some(content_length) = req.headers.get(&HeaderKey::CONTENT_LENGTH) {
        let Ok(content_length): Result = content_length.get().parse() else {
            return Err(StreamHandlerReceiveError::InvalidExpectedBody);
        };
        sh.has_exhausted_body = false;
        return Ok((req, Some(ExpectedBody::Sized(Incoming::>::new(
            SizedIn::new(&mut sh.inner, content_length), &mut sh.has_exhausted_body
        )))));
    }
    Ok((req, None))
}

// Async variant: identical framing logic, but constructs async-capable `Incoming` readers.
#[cfg(feature = "async")]
fn construct_response_body_async(
    sh: &mut StreamHandler,
    res: ResponseHead,
) -> Result, StreamHandlerReceiveError> {
    if let Some(transfer_encoding) = res.headers.get(&HeaderKey::TRANSFER_ENCODING) {
        if transfer_encoding != &HeaderValue::new(vec!["chunked".to_string()]) {
            return Err(StreamHandlerReceiveError::InvalidExpectedBody);
        }
        sh.has_exhausted_body = false;
        return Ok((res, Some(ExpectedBody::Chunked(Incoming::>::new_async(
            ChunkedIn::new(&mut sh.inner), &mut sh.has_exhausted_body
        )))));
    }
    if let Some(content_length) = res.headers.get(&HeaderKey::CONTENT_LENGTH) {
        let Ok(content_length): Result = content_length.get().parse() else {
            return Err(StreamHandlerReceiveError::InvalidExpectedBody);
        };
        sh.has_exhausted_body = false;
        return Ok((res, Some(ExpectedBody::Sized(Incoming::>::new_async(
            SizedIn::new(&mut sh.inner, content_length), &mut sh.has_exhausted_body
        )))));
    }
    Ok((res, None))
}

// Async variant of `construct_request_body`.
#[cfg(feature = "async")]
fn construct_request_body_async(
    sh: &mut StreamHandler,
    req: RequestHead,
) -> Result, StreamHandlerReceiveError> {
    if let Some(transfer_encoding) = req.headers.get(&HeaderKey::TRANSFER_ENCODING) {
        if transfer_encoding != &HeaderValue::new(vec!["chunked".to_string()]) {
            return Err(StreamHandlerReceiveError::InvalidExpectedBody);
        }
        sh.has_exhausted_body = false;
        return Ok((req, Some(ExpectedBody::Chunked(Incoming::>::new_async(
            ChunkedIn::new(&mut sh.inner), &mut sh.has_exhausted_body
        )))));
    }
    if let Some(content_length) = req.headers.get(&HeaderKey::CONTENT_LENGTH) {
        let Ok(content_length): Result = content_length.get().parse() else {
            return Err(StreamHandlerReceiveError::InvalidExpectedBody);
        };
        sh.has_exhausted_body = false;
        return Ok((req, Some(ExpectedBody::Sized(Incoming::>::new_async(
            SizedIn::new(&mut sh.inner, content_length), &mut sh.has_exhausted_body
        )))));
    }
    Ok((req, None))
}

impl StreamHandler {
    // Reads and parses a request head (delimited by CRLFCRLF), pushing any over-read body
    // bytes back onto the stream; refuses to run while a previous body is still pending.
    pub(super) fn receive_request(&mut self)
        -> Result, StreamHandlerReceiveError> {
        if !self.has_exhausted_body {
            return Err(StreamHandlerReceiveError::RequiresBodyPolling);
        }
        let mut buf = [0_u8; BUF_SIZE_HEAD];
        let received = {
            let mut reader = SyncReader::new(&mut self.inner);
            // `&[0, 0, 0, 1]` is the precomputed KMP jump table for "\r\n\r\n".
            reader.recv_until(b"\r\n\r\n", &[0, 0, 0, 1], &mut buf)
        };
        let header = match received {
            Ok(ReaderValue::ExactRead { up_to_delimiter: h }) => h,
            Ok(ReaderValue::LeakyRead { up_to_delimiter: h, rest: b }) => leaky!(self.inner, h, b),
            Err(ReaderError::IO(e)) => return Err(StreamHandlerReceiveError::IO(e)),
            Err(ReaderError::NoData) => return Err(StreamHandlerReceiveError::NoData),
            Err(ReaderError::BufferOverflow) => return Err(StreamHandlerReceiveError::HeaderTooLarge),
        }.len();
        // +4 keeps the CRLFCRLF terminator for the parser.
        let header = &buf[..header+4];
        let header = parse_request_head(header)?;
        construct_request_body(self, header)
    }
    // Response-side twin of `receive_request` (same CRLFCRLF-delimited head read).
    pub(super) fn receive_response(&mut self)
        -> Result, StreamHandlerReceiveError> {
        if !self.has_exhausted_body {
            return Err(StreamHandlerReceiveError::RequiresBodyPolling);
        }
        let mut buf = [0_u8; BUF_SIZE_HEAD];
        let received = {
            let mut reader = SyncReader::new(&mut self.inner);
            reader.recv_until(b"\r\n\r\n", &[0, 0, 0, 1], &mut buf)
        };
        let header = match received {
            Ok(ReaderValue::ExactRead { up_to_delimiter: h }) => h,
            Ok(ReaderValue::LeakyRead { up_to_delimiter: h, rest: b }) => leaky!(self.inner, h, b),
            Err(ReaderError::IO(e)) => return Err(StreamHandlerReceiveError::IO(e)),
            Err(ReaderError::NoData) => return Err(StreamHandlerReceiveError::NoData),
            Err(ReaderError::BufferOverflow) => return Err(StreamHandlerReceiveError::HeaderTooLarge),
        }.len();
        let header = &buf[..header+4];
        let header = parse_response_head(header)?;
        construct_response_body(self, header)
    }
}

impl StreamHandler {
    // Serializes and writes a request head, looping until the whole buffer is sent.
    pub(super) fn send_request(&mut self, req: &RequestHead) -> Result<(), StreamHandlerSendError> {
        if !self.has_exhausted_body {
            return Err(StreamHandlerSendError::RequiresBodyPolling);
        }
        let serialized = req.serialize();
        let serialized = serialized.as_slice();
        let mut ptr = 0;
        // `send` may write partially; advance by the number of bytes accepted.
        while ptr < serialized.len() {
            ptr += self.inner.send(&serialized[ptr..])?;
        }
        Ok(())
    }

    // Serializes and writes a response head (same partial-write loop).
    pub(super) fn send_response(
        &mut self, res: &ResponseHead
    ) -> Result<(), StreamHandlerSendError> {
        if !self.has_exhausted_body {
            return Err(StreamHandlerSendError::RequiresBodyPolling);
        }
        let serialized = res.serialize();
        let serialized = serialized.as_slice();
        let mut ptr = 0;
        while ptr < serialized.len() {
            ptr += self.inner.send(&serialized[ptr..])?;
        }
        Ok(())
    }
}

// Async mirrors of the receive impls above.
#[cfg(feature = "async")]
impl StreamHandler {
    pub(super) async fn receive_request_async(&mut self)
        -> Result, StreamHandlerReceiveError> {
        if !self.has_exhausted_body {
            return Err(StreamHandlerReceiveError::RequiresBodyPolling);
        }
        let mut buf = [0_u8; BUF_SIZE_HEAD];
        let received = {
            let mut reader = AsyncReader::new(&mut self.inner);
            reader.recv_until(b"\r\n\r\n", &[0, 0, 0, 1], &mut buf).await
        };
        let header = match received {
            Ok(ReaderValue::ExactRead { up_to_delimiter: h }) => h,
            Ok(ReaderValue::LeakyRead { up_to_delimiter: h, rest: b }) => leaky!(self.inner, h, b),
            Err(ReaderError::IO(e)) => return Err(StreamHandlerReceiveError::IO(e)),
            Err(ReaderError::NoData) => return Err(StreamHandlerReceiveError::NoData),
            Err(ReaderError::BufferOverflow) => return Err(StreamHandlerReceiveError::HeaderTooLarge),
        }.len();
        let header = &buf[..header+4];
        let header = parse_request_head(header)?;
        construct_request_body_async(self, header)
    }

    pub(super) async fn receive_response_async(&mut self)
        -> Result, StreamHandlerReceiveError> {
        if !self.has_exhausted_body {
            return Err(StreamHandlerReceiveError::RequiresBodyPolling);
        }
        let mut buf = [0_u8; BUF_SIZE_HEAD];
        let received = {
            let mut reader = AsyncReader::new(&mut self.inner);
            reader.recv_until(b"\r\n\r\n", &[0, 0, 0, 1], &mut buf).await
        };
        let header = match received {
            Ok(ReaderValue::ExactRead { up_to_delimiter: h }) => h,
            Ok(ReaderValue::LeakyRead { up_to_delimiter: h, rest: b }) => leaky!(self.inner, h, b),
            Err(ReaderError::IO(e)) => return Err(StreamHandlerReceiveError::IO(e)),
            Err(ReaderError::NoData) => return Err(StreamHandlerReceiveError::NoData),
            Err(ReaderError::BufferOverflow) => return Err(StreamHandlerReceiveError::HeaderTooLarge),
        }.len();
        let header = &buf[..header+4];
        let header = parse_response_head(header)?;
        construct_response_body_async(self, header)
    }
}

// Async mirrors of the send impls above.
#[cfg(feature = "async")]
impl StreamHandler {
    pub(super) async fn send_request_async(
        &mut self, req: &RequestHead
    ) -> Result<(), StreamHandlerSendError> {
        if !self.has_exhausted_body {
            return Err(StreamHandlerSendError::RequiresBodyPolling);
        }
        let serialized = req.serialize();
        let serialized = serialized.as_slice();
        let mut ptr = 0;
        while ptr < serialized.len() {
            ptr += self.inner.send(&serialized[ptr..]).await?;
        }
        Ok(())
    }

    pub(super) async fn send_response_async(
        &mut self, res: &ResponseHead
    ) -> Result<(), StreamHandlerSendError> {
        if !self.has_exhausted_body {
            return Err(StreamHandlerSendError::RequiresBodyPolling);
        }
        let serialized = res.serialize();
        let serialized = serialized.as_slice();
        let mut ptr = 0;
        while ptr < serialized.len() {
            ptr += self.inner.send(&serialized[ptr..]).await?;
        }
        Ok(())
    }
}

/// End-to-end receive tests over in-memory test streams (module continues in the
/// next chunk).
#[cfg(test)]
mod receive {
    use crate::{
        io::TestSyncStream,
        method::Method,
        path::HttpPath,
        headers::{HeaderKey, HeaderValue},
        proto::h1::head::{RequestHead, ResponseHead, ProtocolVariant},
        status::Status,
    };
    #[cfg(feature = "async")]
    use crate::io::TestAsyncStream;
    use std::collections::HashMap;
    use super::{StreamHandler, ExpectedBody};

    // Builders for the expected head values.
    macro_rules! test_inner {
        (
            @request_head $method: ident,
            $path: literal,
            $proto: ident, [$($qk: ident : $qv: literal),*], [$($hk: ident : $hv: literal),*]
        ) => {
            RequestHead {
                method: Method::$method,
                path: test_inner!(@http_path $path, [$($qk : $qv),*]),
                protocol: ProtocolVariant::$proto,
                headers: test_inner!(@headers [$($hk : $hv),*]),
            }
        };
        (
            @response_head
            $proto: ident, $status: ident, [$($hk: ident : $hv: literal),*]
        ) => {
            ResponseHead {
                status: Status::$status,
                protocol: ProtocolVariant::$proto,
                headers: test_inner!(@headers [$($hk : $hv),*]),
            }
        };

        // NOTE(review): unlike the sibling macro in head.rs, this arm uses `$hk`/`$hv`
        // directly (no `stringify!`/`.to_string()`); verify the intended map value types.
        (@http_path $path: literal, [$($hk: ident : $hv: literal),+]) => {
            HttpPath { path: $path.to_string(), params: Some(HashMap::from([$(($hk, $hv)),+])) }
        };
        (@http_path $path: literal, []) => {
            HttpPath { path: $path.to_string(), params: None }
        };

        (@headers [$($hk: ident : $hv :literal),*]) => {
            HashMap::from([$((HeaderKey::$hk, HeaderValue::new(Vec::from([$hv.to_string()])))),*])
        };
    }

    macro_rules!
// Continuation of `macro_rules!` opened in the previous chunk: `test!` generates paired
// sync/async receive tests; the four front arms normalize the input into the `@inner` arms.
    test {
        (
            $name: ident $name_async: ident,
            req $src: literal,
            $method: ident $path: literal [$($qk: ident : $qv: literal),*] $proto: ident,
            [$($hk: ident : $hv: literal),*$(,)?]
        ) => {
            test!(@inner $name, $name_async, req $src, test_inner!(
                @request_head $method, $path, $proto, [$($qk:$qv),*], [$($hk:$hv),*]
            ));
        };

        (
            $name: ident $name_async: ident,
            req $src: literal,
            $method: ident $path: literal [$($qk: ident : $qv: literal),*] $proto: ident,
            [$($hk: ident : $hv: literal),*$(,)?],
            $body: literal
        ) => {
            test!(@inner $name, $name_async, req $src, test_inner!(
                @request_head $method, $path, $proto, [$($qk:$qv),*], [$($hk:$hv),*]
            ), $body);
        };

        (
            $name: ident $name_async: ident,
            res $src: literal,
            $proto: ident $status: ident,
            [$($hk: ident : $hv: literal),*$(,)?]
        ) => {
            test!(@inner $name, $name_async, res $src, test_inner!(
                @response_head $proto, $status, [$($hk:$hv),*]
            ));
        };

        (
            $name: ident $name_async: ident,
            res $src: literal,
            $proto: ident $status: ident,
            [$($hk: ident : $hv: literal),*$(,)?],
            $body: literal
        ) => {
            test!(@inner $name, $name_async, res $src, test_inner!(
                @response_head $proto, $status, [$($hk:$hv),*]
            ), $body);
        };

        // Request without a body: head must match, body reader must be absent.
        (
            @inner $name: ident, $name_async: ident, req $src: literal, $res_head: expr
        ) => {
            #[test]
            fn $name() {
                let mut src = $src.into();
                let stream = TestSyncStream::<4>::new(&mut src);
                let mut handler = StreamHandler::new(stream);
                let (head, body) = handler.receive_request().unwrap();
                assert_eq!(head, $res_head);
                assert!(body.is_none());
            }

            #[cfg(all(feature = "async", any(feature = "tokio-net", feature = "tokio-unixsocks")))]
            #[tokio::test]
            async fn $name_async() {
                let mut src = $src.into();
                let stream = TestAsyncStream::<4>::new(&mut src);
                let mut handler = StreamHandler::new(stream);
                let (head, body) = handler.receive_request_async().await.unwrap();
                assert_eq!(head, $res_head);
                assert!(body.is_none());
            }
        };
        // Request with a sized body: drain the body reader and compare its bytes too.
        (
            @inner $name: ident, $name_async: ident, req $src: literal, $res_head: expr, $res_body: expr
        ) => {
            #[test]
            fn $name() {
                let mut src = $src.into();
                let stream = TestSyncStream::<4>::new(&mut src);
                let mut handler = StreamHandler::new(stream);
                let (head, body) = handler.receive_request().unwrap();
                assert_eq!(head, $res_head);
                let Some(ExpectedBody::Sized(mut body)) = body else { panic!("body mismatch"); };
                assert_eq!(body.recv_all().unwrap(), $res_body.to_vec());
            }

            #[cfg(all(feature = "async", any(feature = "tokio-net", feature = "tokio-unixsocks")))]
            #[tokio::test]
            async fn $name_async() {
                let mut src = $src.into();
                let stream = TestAsyncStream::<4>::new(&mut src);
                let mut handler = StreamHandler::new(stream);
                let (head, body) = handler.receive_request_async().await.unwrap();
                assert_eq!(head, $res_head);
                let Some(ExpectedBody::Sized(mut body)) = body else { panic!("body mismatch"); };
                assert_eq!(body.recv_all_async().await.unwrap(), $res_body.to_vec());
            }
        };

        // Response without a body.
        (
            @inner $name: ident, $name_async: ident, res $src: literal, $res_head: expr
        ) => {
            #[test]
            fn $name() {
                let mut src = $src.into();
                let stream = TestSyncStream::<4>::new(&mut src);
                let mut handler = StreamHandler::new(stream);
                let (head, body) = handler.receive_response().unwrap();
                assert_eq!(head, $res_head);
                assert!(body.is_none());
            }

            #[cfg(all(feature = "async", any(feature = "tokio-net", feature = "tokio-unixsocks")))]
            #[tokio::test]
            async fn $name_async() {
                let mut src = $src.into();
                let stream = TestAsyncStream::<4>::new(&mut src);
                let mut handler = StreamHandler::new(stream);
                let (head, body) = handler.receive_response_async().await.unwrap();
                assert_eq!(head, $res_head);
                assert!(body.is_none());
            }
        };
        // Response with a sized body.
        (
            @inner $name: ident, $name_async: ident, res $src: literal, $res_head: expr, $res_body: expr
        ) => {
            #[test]
            fn $name() {
                let mut src = $src.into();
                let stream = TestSyncStream::<4>::new(&mut src);
                let mut handler = StreamHandler::new(stream);
                let (head, body) = handler.receive_response().unwrap();
                assert_eq!(head, $res_head);
                let Some(ExpectedBody::Sized(mut body)) = body else { panic!("body mismatch"); };
                assert_eq!(body.recv_all().unwrap(), $res_body.to_vec());
            }

            #[cfg(all(feature = "async", any(feature = "tokio-net", feature = "tokio-unixsocks")))]
            #[tokio::test]
            async fn $name_async() {
                let mut src = $src.into();
                let stream = TestAsyncStream::<4>::new(&mut src);
                let mut handler = StreamHandler::new(stream);
                let (head, body) = handler.receive_response_async().await.unwrap();
                assert_eq!(head, $res_head);
                let Some(ExpectedBody::Sized(mut body)) = body else { panic!("body mismatch"); };
                assert_eq!(body.recv_all_async().await.unwrap(), $res_body.to_vec());
            }
        };
    }

    test!(
        request_valid_no_body async_request_valid_no_body,
        req "GET / HTTP/1.1\r\nserver: inferium\r\n\r\n",
        GET "/"[] HTTP1_1,
        [
            SERVER: "inferium",
        ]
    );

    test!(
        request_valid_body async_request_valid_body,
        req "GET / HTTP/1.1\r\ncontent-length: 4\r\n\r\ntest",
        GET "/"[] HTTP1_1,
        [
            CONTENT_LENGTH: "4",
        ],
        b"test"
    );

    test!(
        response_valid_no_body async_response_valid_no_body,
        res "HTTP/1.0 200 OK\r\nserver: inferium\r\n\r\n",
        HTTP1_0 Ok,
        [
            SERVER: "inferium",
        ]
    );

    test!(
        response_valid_body async_response_valid_body,
        res "HTTP/1.1 200 OK\r\nserver: inferium\r\nconnection: close\r\ncontent-length: 4\r\n\r\n\
        test",
        HTTP1_1 Ok,
        [
            SERVER: "inferium",
            CONNECTION: "close",
            CONTENT_LENGTH: "4",
        ],
        b"test"
    );

    test!(
        response_short_body async_response_short_body,
        res "HTTP/1.1 200 OK\r\nserver: inferium\r\nconnection: close\r\ncontent-length: 4\r\n\r\n\
        tes",
        HTTP1_1 Ok,
        [
            SERVER: "inferium",
            CONNECTION: "close",
            CONTENT_LENGTH: "4",
        ],
        b"tes"
    );

    test!(
        response_long_body async_response_long_body,
        res "HTTP/1.1 200 OK\r\nserver: inferium\r\nconnection: close\r\ncontent-length: 4\r\n\r\n\
        testing",
        HTTP1_1 Ok,
        [
            SERVER: "inferium",
            CONNECTION: "close",
            CONTENT_LENGTH: "4",
        ],
        b"test"
    );
}
diff --git a/lib/inferium/src/proto/mod.rs b/lib/inferium/src/proto/mod.rs
new file mode 100644
index 0000000..f9722e0
--- /dev/null
+++ b/lib/inferium/src/proto/mod.rs
@@ -0,0 +1 @@
pub mod h1;
diff --git a/lib/inferium/src/settings.rs b/lib/inferium/src/settings.rs
new file mode 100644
index 0000000..9fb2e5c
--- /dev/null
+++ b/lib/inferium/src/settings.rs
@@ -0,0 +1,2 @@
// Buffer sizes (bytes) for reading message heads and bodies.
pub const BUF_SIZE_HEAD: usize = 8192;
pub const BUF_SIZE_BODY: usize = 4096;
diff --git a/lib/inferium/src/status.rs b/lib/inferium/src/status.rs
new file mode 100644
index 0000000..c7aad41
--- /dev/null
+++ b/lib/inferium/src/status.rs
@@ -0,0 +1,121 @@
// Generates the `Status` enum plus byte-level conversions from a list of
// `Name = (b"code", b"reason phrase")` pairs.
macro_rules! http_status {
(
    $(#$objdoc: tt)+
    $($ident: ident = ($number: literal, $text: literal)),*$(,)?
) => {
    $(#$objdoc)+
    #[derive(Debug, PartialEq, Eq)]
    pub enum Status {
        $($ident),*
    }

    impl TryFrom<&[u8]> for Status {
        type Error = ();

        // Maps a raw status-code byte string (e.g. b"200") back to the enum.
        // NOTE(review): the return type reads `Result` bare here — its generic argument
        // list appears to have been stripped in transit; restore before compiling.
        fn try_from(value: &[u8]) -> Result {
            match value { $($number => Ok(Self::$ident),)* _ => Err(()) }
        }
    }

    impl Status {
        // Status code as bytes, e.g. b"404".
        pub fn num(&self) -> &[u8] {
            match self { $(Self::$ident => $number),* }
        }

        // Reason phrase as bytes, e.g. b"Not Found".
        pub fn text(&self) -> &[u8] {
            match self { $(Self::$ident => $text),* }
        }
    }

    impl std::fmt::Display for Status {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
            match self { $(Self::$ident => write_status(self, f)?),* }
            Ok(())
        }
    }
};
}

// Formats "<code> <reason>"; the byte tables only ever contain ASCII, hence the
// unchecked UTF-8 conversions.
fn write_status(this: &Status, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
    let num = unsafe { std::str::from_utf8_unchecked(this.num()) };
    let text = unsafe { std::str::from_utf8_unchecked(this.text()) };
    write!(f, "{num} {text}")
}

http_status!
{ + /// HTTP response status codes and their names + /// + /// Numbers and names are represented in byte arrays for faster response parsing and + /// construction. + + // Informational + Continue = (b"100", b"Continue"), + SwitchingProtocols = (b"101", b"Switching Protocols"), + Processing = (b"102", b"Processing"), + EarlyHints = (b"103", b"Early Hints"), + + // Successful + Ok = (b"200", b"OK"), + Created = (b"201", b"Created"), + Accepted = (b"202", b"Accepted"), + NonAuthoritativeInformation = (b"203", b"Non-Authoritative Information"), + NoContent = (b"204", b"No Content"), + ResetContent = (b"205", b"Reset Content"), + PartialContent = (b"206", b"Partial Content"), + MultiStatus = (b"207", b"Multi-Status"), + AlreadyReported = (b"208", b"AlreadyReported"), + ImUsed = (b"226", b"IM Used"), + + // Redirection + MultipleChoices = (b"300", b"Multiple Choices"), + MovedPermanently = (b"301", b"Moved Permanently"), + Found = (b"302", b"Found"), + SeeOther = (b"303", b"See Other"), + NotModified = (b"304", b"Not Modified"), + TemporaryRedirect = (b"307", b"Temporary Redirect"), + PermanentRedirect = (b"308", b"Permanent Redirect"), + + // Client errors + BadRequest = (b"400", b"Bad Request"), + Unauthorized = (b"401", b"Unauthorized"), + PaymentRequired = (b"402", b"Payment Required"), + Forbidden = (b"403", b"Forbidden"), + NotFound = (b"404", b"Not Found"), + MethodNotAllowed = (b"405", b"MethodNotAllowed"), + NotAcceptable = (b"406", b"Not Acceptable"), + ProxyAuthenticationRequired = (b"407", b"Proxy Authentication Required"), + RequestTimeout = (b"408", b"Request Timeout"), + Conflict = (b"409", b"Conflict"), + Gone = (b"410", b"Gone"), + LengthRequired = (b"411", b"Length Required"), + PreconditionFailed = (b"412", b"Precondition Failed"), + ContentTooLarge = (b"413", b"Content Too Large"), + UriTooLong = (b"414", b"URI Too Long"), + UnsupportedMediaType = (b"415", b"Unsupported Media Type"), + RangeNotSatisfiable = (b"416", b"Range Not Satisfiable"), + 
ExpectationFailed = (b"417", b"Expectation Failed"), + ImATeapot = (b"418", b"I'm a teapot"), + MisdirectedRequest = (b"421", b"Misdirected Request"), + UnprocessableContent = (b"422", b"Unprocessable Content"), + Locked = (b"423", b"Locked"), + FailedDependency = (b"424", b"Failed Dependency"), + TooEarly = (b"425", b"Too Early"), + UpgradeRequired = (b"426", b"Upgrade Required"), + PreconditionRequired = (b"428", b"Precondition Required"), + TooManyRequests = (b"429", b"Too Many Requests"), + RequestHeaderFieldsTooLarge = (b"431", b"Request Header Fields Too Large"), + UnavailableForLegalReasons = (b"451", b"Unavailable For Legal Reasons"), + + // Server errors + InternalServerError = (b"500", b"Internal Server Error"), + NotImplemented = (b"501", b"Not Implemented"), + BadGateway = (b"502", b"Bad Gateway"), + ServiceUnavailable = (b"503", b"Service Unavailable"), + GatewayTimeout = (b"504", b"Gateway Timeout"), + HttpVersionNotSupported = (b"505", b"HTTP Version Not Supported"), + VariantAlsoNegotiates = (b"506", b"Variant Also Negotiates"), + InsufficientStorage = (b"507", b"Insufficient Storage"), + LoopDetected = (b"508", b"Loop Detected"), + NotExtended = (b"510", b"Not Extended"), + NetworkAuthenticationRequired = (b"511", b"Network Authentication Required"), +}