Skip to content

Commit 6bbe2b1

Browse files
authored
Add ability to modify headers (#170)
1 parent 0d38b16 commit 6bbe2b1

File tree

5 files changed

+141
-96
lines changed

5 files changed

+141
-96
lines changed

integration_tests/base_routes.py

+22
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,28 @@ def shutdown_handler():
163163
logger.log(logging.INFO, "Shutting down")
164164

165165

166+
@app.get("/redirect")
167+
async def redirect(request):
168+
return {"status_code": "307", "body": "", "type": "text"}
169+
170+
171+
@app.get("/redirect_route")
172+
async def redirect_route(request):
173+
return "This is the redirected route"
174+
175+
176+
@app.before_request("/redirect")
177+
async def redirect_before_request(request):
178+
request["headers"]["Location"] = "redirect_route"
179+
return ""
180+
181+
182+
@app.after_request("/redirect")
183+
async def redirect_after_request(request):
184+
request["headers"]["Location"] = "redirect_route"
185+
return ""
186+
187+
166188
if __name__ == "__main__":
167189
ROBYN_URL = os.getenv("ROBYN_URL", "0.0.0.0")
168190
app.add_header("server", "robyn")

integration_tests/test_status_code.py

+5
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,8 @@ def test_404_status_code(session):
1111
def test_404_post_request_status_code(session):
1212
r = requests.post(f"{BASE_URL}/404")
1313
assert r.status_code == 404
14+
15+
def test_307_get_request(session):
16+
r = requests.get(f"{BASE_URL}/redirect")
17+
assert r.text == "This is the redirected route"
18+

robyn/router.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ def _format_response(self, res):
2828
"body": res["body"],
2929
**res
3030
}
31-
print("Setting the response", response)
3231
else:
3332
response = {"status_code": "200", "body": res, "type": "text"}
3433

@@ -37,7 +36,6 @@ def _format_response(self, res):
3736
def add_route(self, route_type, endpoint, handler):
3837
async def async_inner_handler(*args):
3938
response = self._format_response(await handler(*args))
40-
print(f"This is the response in python: {response}")
4139
return response
4240

4341
def inner_handler(*args):
@@ -95,7 +93,7 @@ def add_route(self, route_type, endpoint, handler):
9593
def add_after_request(self, endpoint):
9694
def inner(handler):
9795
async def async_inner_handler(*args):
98-
await handler(args)
96+
await handler(*args)
9997
return args
10098

10199
def inner_handler(*args):
@@ -112,7 +110,7 @@ def inner_handler(*args):
112110
def add_before_request(self, endpoint):
113111
def inner(handler):
114112
async def async_inner_handler(*args):
115-
await handler(args)
113+
await handler(*args)
116114
return args
117115

118116
def inner_handler(*args):

src/processor.rs

+70-62
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use anyhow::{bail, Result};
88
use crate::types::{Headers, PyFunction};
99
use futures_util::stream::StreamExt;
1010
use pyo3::prelude::*;
11-
use pyo3::types::{PyDict, PyTuple};
11+
use pyo3::types::PyDict;
1212

1313
use std::fs::File;
1414
use std::io::Read;
@@ -17,9 +17,9 @@ use std::io::Read;
1717
const MAX_SIZE: usize = 10_000;
1818

1919
#[inline]
20-
pub fn apply_headers(response: &mut HttpResponseBuilder, headers: &Arc<Headers>) {
21-
for a in headers.iter() {
22-
response.insert_header((a.key().clone(), a.value().clone()));
20+
pub fn apply_headers(response: &mut HttpResponseBuilder, headers: HashMap<String, String>) {
21+
for (key, val) in (headers).iter() {
22+
response.insert_header((key.clone(), val.clone()));
2323
}
2424
}
2525

@@ -37,7 +37,7 @@ pub fn apply_headers(response: &mut HttpResponseBuilder, headers: &Arc<Headers>)
3737
pub async fn handle_request(
3838
function: PyFunction,
3939
number_of_params: u8,
40-
headers: &Arc<Headers>,
40+
headers: HashMap<String, String>,
4141
payload: &mut web::Payload,
4242
req: &HttpRequest,
4343
route_params: HashMap<String, String>,
@@ -46,7 +46,7 @@ pub async fn handle_request(
4646
let contents = match execute_http_function(
4747
function,
4848
payload,
49-
headers,
49+
headers.clone(),
5050
req,
5151
route_params,
5252
queries,
@@ -58,17 +58,30 @@ pub async fn handle_request(
5858
Err(err) => {
5959
println!("Error: {:?}", err);
6060
let mut response = HttpResponse::InternalServerError();
61-
apply_headers(&mut response, headers);
61+
apply_headers(&mut response, headers.clone());
6262
return response.finish();
6363
}
6464
};
6565

66-
let mut response = HttpResponse::Ok();
66+
let body = contents.get("body").unwrap().to_owned();
6767
let status_code =
6868
actix_http::StatusCode::from_str(contents.get("status_code").unwrap()).unwrap();
69-
apply_headers(&mut response, headers);
70-
response.status(status_code);
71-
response.body(contents.get("body").unwrap().to_owned())
69+
70+
let mut response = HttpResponse::build(status_code);
71+
apply_headers(&mut response, headers.clone());
72+
let final_response = if body != "" {
73+
response.body(body)
74+
} else {
75+
response.finish()
76+
};
77+
78+
println!(
79+
"The status code is {} and the headers are {:?}",
80+
final_response.status(),
81+
final_response.headers()
82+
);
83+
// response.body(contents.get("body").unwrap().to_owned())
84+
final_response
7285
}
7386

7487
pub async fn handle_middleware_request(
@@ -79,7 +92,7 @@ pub async fn handle_middleware_request(
7992
req: &HttpRequest,
8093
route_params: HashMap<String, String>,
8194
queries: HashMap<String, String>,
82-
) -> Py<PyTuple> {
95+
) -> HashMap<String, HashMap<String, String>> {
8396
let contents = match execute_middleware_function(
8497
function,
8598
payload,
@@ -92,12 +105,10 @@ pub async fn handle_middleware_request(
92105
.await
93106
{
94107
Ok(res) => res,
95-
Err(err) => Python::with_gil(|py| {
96-
println!("{:?}", err);
97-
PyTuple::empty(py).into_py(py)
98-
}),
108+
Err(_err) => HashMap::new(),
99109
};
100110

111+
println!("These are the middleware response {:?}", contents);
101112
contents
102113
}
103114

@@ -123,12 +134,12 @@ async fn execute_middleware_function<'a>(
123134
route_params: HashMap<String, String>,
124135
queries: HashMap<String, String>,
125136
number_of_params: u8,
126-
) -> Result<Py<PyTuple>> {
137+
) -> Result<HashMap<String, HashMap<String, String>>> {
127138
// TODO:
128139
// try executing the first version of middleware(s) here
129140
// with just headers as params
130141

131-
let mut data: Option<Vec<u8>> = None;
142+
let mut data: Vec<u8> = Vec::new();
132143

133144
if req.method() == Method::POST
134145
|| req.method() == Method::PUT
@@ -145,13 +156,13 @@ async fn execute_middleware_function<'a>(
145156
body.extend_from_slice(&chunk);
146157
}
147158

148-
data = Some(body.to_vec())
159+
data = body.to_vec()
149160
}
150161

151162
// request object accessible while creating routes
152163
let mut request = HashMap::new();
153164
let mut headers_python = HashMap::new();
154-
for elem in headers.into_iter() {
165+
for elem in (*headers).iter() {
155166
headers_python.insert(elem.key().clone(), elem.value().clone());
156167
}
157168

@@ -162,7 +173,7 @@ async fn execute_middleware_function<'a>(
162173
request.insert("params", route_params.into_py(py));
163174
request.insert("queries", queries.into_py(py));
164175
request.insert("headers", headers_python.into_py(py));
165-
request.insert("body", data.into_py(py));
176+
// request.insert("body", data.into_py(py));
166177

167178
// this makes the request object to be accessible across every route
168179
let coro: PyResult<&PyAny> = match number_of_params {
@@ -176,10 +187,13 @@ async fn execute_middleware_function<'a>(
176187

177188
let output = output.await?;
178189

179-
let res = Python::with_gil(|py| -> PyResult<Py<PyTuple>> {
180-
let output: Py<PyTuple> = output.extract(py).unwrap();
181-
Ok(output)
182-
})?;
190+
let res =
191+
Python::with_gil(|py| -> PyResult<HashMap<String, HashMap<String, String>>> {
192+
let output: Vec<HashMap<String, HashMap<String, String>>> =
193+
output.extract(py).unwrap();
194+
let responses = output[0].clone();
195+
Ok(responses)
196+
})?;
183197

184198
Ok(res)
185199
}
@@ -200,9 +214,10 @@ async fn execute_middleware_function<'a>(
200214
2_u8..=u8::MAX => handler.call1((request,)),
201215
};
202216

203-
let output: Py<PyTuple> = output?.extract().unwrap();
217+
let output: Vec<HashMap<String, HashMap<String, String>>> =
218+
output?.extract().unwrap();
204219

205-
Ok(output)
220+
Ok(output[0].clone())
206221
})
207222
})
208223
.await?
@@ -215,15 +230,15 @@ async fn execute_middleware_function<'a>(
215230
async fn execute_http_function(
216231
function: PyFunction,
217232
payload: &mut web::Payload,
218-
headers: &Headers,
233+
headers: HashMap<String, String>,
219234
req: &HttpRequest,
220235
route_params: HashMap<String, String>,
221236
queries: HashMap<String, String>,
222237
number_of_params: u8,
223238
// need to change this to return a response struct
224239
// create a custom struct for this
225240
) -> Result<HashMap<String, String>> {
226-
let mut data: Option<Vec<u8>> = None;
241+
let mut data: Vec<u8> = Vec::new();
227242

228243
if req.method() == Method::POST
229244
|| req.method() == Method::PUT
@@ -240,28 +255,21 @@ async fn execute_http_function(
240255
body.extend_from_slice(&chunk);
241256
}
242257

243-
data = Some(body.to_vec())
258+
data = body.to_vec()
244259
}
245260

246261
// request object accessible while creating routes
247262
let mut request = HashMap::new();
248-
let mut headers_python = HashMap::new();
249-
for elem in headers.into_iter() {
250-
headers_python.insert(elem.key().clone(), elem.value().clone());
251-
}
252263

253264
match function {
254265
PyFunction::CoRoutine(handler) => {
255266
let output = Python::with_gil(|py| {
256267
let handler = handler.as_ref(py);
257268
request.insert("params", route_params.into_py(py));
258269
request.insert("queries", queries.into_py(py));
259-
request.insert("headers", headers_python.into_py(py));
260-
261-
if let Some(res) = data {
262-
let data = res.into_py(py);
263-
request.insert("body", data);
264-
};
270+
request.insert("headers", headers.into_py(py));
271+
let data = data.into_py(py);
272+
request.insert("body", data);
265273

266274
// this makes the request object to be accessible across every route
267275
let coro: PyResult<&PyAny> = match number_of_params {
@@ -298,11 +306,9 @@ async fn execute_http_function(
298306
Python::with_gil(|py| {
299307
let handler = handler.as_ref(py);
300308
request.insert("params", route_params.into_py(py));
301-
request.insert("headers", headers_python.into_py(py));
302-
if let Some(res) = data {
303-
let data = res.into_py(py);
304-
request.insert("body", data);
305-
};
309+
request.insert("headers", headers.into_py(py));
310+
let data = data.into_py(py);
311+
request.insert("body", data);
306312

307313
let output: PyResult<&PyAny> = match number_of_params {
308314
0 => handler.call0(),
@@ -325,22 +331,24 @@ pub async fn execute_event_handler(
325331
event_handler: Option<Arc<PyFunction>>,
326332
event_loop: Arc<Py<PyAny>>,
327333
) {
328-
if let Some(handler) = event_handler { match &(*handler) {
329-
PyFunction::SyncFunction(function) => {
330-
println!("Startup event handler");
331-
Python::with_gil(|py| {
332-
function.call0(py).unwrap();
333-
});
334-
}
335-
PyFunction::CoRoutine(function) => {
336-
let future = Python::with_gil(|py| {
337-
println!("Startup event handler async");
338-
339-
let coroutine = function.as_ref(py).call0().unwrap();
340-
pyo3_asyncio::into_future_with_loop((*event_loop).as_ref(py), coroutine)
341-
.unwrap()
342-
});
343-
future.await.unwrap();
334+
if let Some(handler) = event_handler {
335+
match &(*handler) {
336+
PyFunction::SyncFunction(function) => {
337+
println!("Startup event handler");
338+
Python::with_gil(|py| {
339+
function.call0(py).unwrap();
340+
});
341+
}
342+
PyFunction::CoRoutine(function) => {
343+
let future = Python::with_gil(|py| {
344+
println!("Startup event handler async");
345+
346+
let coroutine = function.as_ref(py).call0().unwrap();
347+
pyo3_asyncio::into_future_with_loop((*event_loop).as_ref(py), coroutine)
348+
.unwrap()
349+
});
350+
future.await.unwrap();
351+
}
344352
}
345-
} }
353+
}
346354
}

0 commit comments

Comments
 (0)