Merge pull request #12 from bboehmke/join-room

join room before sending message
This commit is contained in:
Guilhem Saurel 2021-09-28 10:50:27 +02:00 committed by GitHub
commit 85dd602f5c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 6 deletions

View file

@ -76,6 +76,11 @@ async def matrix_webhook(request):
else: else:
formatted_body = markdown(str(data["body"]), extensions=["extra"]) formatted_body = markdown(str(data["body"]), extensions=["extra"])
# try to join room first -> non none response means error
resp = await utils.join_room(data["room_id"])
if resp is not None:
return resp
content = { content = {
"msgtype": "m.text", "msgtype": "m.text",
"body": data["body"], "body": data["body"],

View file

@ -6,7 +6,7 @@ from http import HTTPStatus
from aiohttp import web from aiohttp import web
from nio import AsyncClient from nio import AsyncClient
from nio.exceptions import LocalProtocolError from nio.exceptions import LocalProtocolError
from nio.responses import RoomSendError from nio.responses import RoomSendError, JoinError
from . import conf from . import conf
@ -15,6 +15,13 @@ LOGGER = logging.getLogger("matrix_webhook.utils")
CLIENT = AsyncClient(conf.MATRIX_URL, conf.MATRIX_ID) CLIENT = AsyncClient(conf.MATRIX_URL, conf.MATRIX_ID)
def error_map(resp):
"""Map response errors to HTTP status."""
if resp.status_code == "M_UNKNOWN":
return resp.transport_response.status
return ERROR_MAP[resp.status_code]
def create_json_response(status, ret): def create_json_response(status, ret):
"""Create a JSON response.""" """Create a JSON response."""
LOGGER.debug(f"Creating json response: {status=}, {ret=}") LOGGER.debug(f"Creating json response: {status=}, {ret=}")
@ -22,6 +29,27 @@ def create_json_response(status, ret):
return web.json_response(response_data, status=status) return web.json_response(response_data, status=status)
async def join_room(room_id):
"""Try to join the room."""
LOGGER.debug(f"Join room {room_id=}")
for _ in range(10):
try:
resp = await CLIENT.join(room_id)
if isinstance(resp, JoinError):
if resp.status_code == "M_UNKNOWN_TOKEN":
LOGGER.warning("Reconnecting")
await CLIENT.login(conf.MATRIX_PW)
else:
return create_json_response(error_map(resp), resp.message)
else:
return None
except LocalProtocolError as e:
LOGGER.error(f"Send error: {e}")
LOGGER.warning("Trying again")
return create_json_response(HTTPStatus.GATEWAY_TIMEOUT, "Homeserver not responding")
async def send_room_message(room_id, content): async def send_room_message(room_id, content):
"""Send a message to a room.""" """Send a message to a room."""
LOGGER.debug(f"Sending room message in {room_id=}: {content=}") LOGGER.debug(f"Sending room message in {room_id=}: {content=}")
@ -36,9 +64,7 @@ async def send_room_message(room_id, content):
LOGGER.warning("Reconnecting") LOGGER.warning("Reconnecting")
await CLIENT.login(conf.MATRIX_PW) await CLIENT.login(conf.MATRIX_PW)
else: else:
return create_json_response( return create_json_response(error_map(resp), resp.message)
ERROR_MAP[resp.status_code], resp.message
)
else: else:
return create_json_response(HTTPStatus.OK, "OK") return create_json_response(HTTPStatus.OK, "OK")
except LocalProtocolError as e: except LocalProtocolError as e:

View file

@ -33,11 +33,11 @@ class BotTest(unittest.IsolatedAsyncioTestCase):
# this won't be a 403 from synapse, but a LocalProtocolError from matrix_webhook # this won't be a 403 from synapse, but a LocalProtocolError from matrix_webhook
self.assertEqual( self.assertEqual(
bot_req({"body": 3}, KEY, "wrong_room"), bot_req({"body": 3}, KEY, "wrong_room"),
{"status": 403, "ret": "Unknown room"}, {"status": 400, "ret": "wrong_room was not legal room ID or room alias"},
) )
self.assertEqual( self.assertEqual(
bot_req({"body": 3}, KEY, "wrong_room", key_as_param=True), bot_req({"body": 3}, KEY, "wrong_room", key_as_param=True),
{"status": 403, "ret": "Unknown room"}, {"status": 400, "ret": "wrong_room was not legal room ID or room alias"},
) )
async def test_message(self): async def test_message(self):