diff --git a/.gitignore b/.gitignore index af71d88..cce1800 100644 --- a/.gitignore +++ b/.gitignore @@ -161,4 +161,8 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -main.ipynb \ No newline at end of file +main.ipynb + +*.xml + +dump_* \ No newline at end of file diff --git a/hivemind_etl/activities.py b/hivemind_etl/activities.py index c4f75b3..25c231d 100644 --- a/hivemind_etl/activities.py +++ b/hivemind_etl/activities.py @@ -1,91 +1,26 @@ import logging from typing import Any -from temporalio import activity, workflow +from hivemind_etl.website.activities import ( + get_hivemind_website_comminities, + extract_website, + transform_website_data, + load_website_data, +) +from hivemind_etl.mediawiki.activities import ( + get_hivemind_mediawiki_platforms, + extract_mediawiki, + transform_mediawiki_data, + load_mediawiki_data, +) + +from temporalio import activity -with workflow.unsafe.imports_passed_through(): - from hivemind_etl.website.module import ModulesWebsite - from hivemind_etl.website.website_etl import WebsiteETL - from llama_index.core import Document logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -@activity.defn -async def get_communities(platform_id: str | None = None) -> list[dict[str, Any]]: - """ - Fetch all communities that need to be processed in case of no platform id given - Else, just process for one platform - - Parameters - ----------- - platform_id : str | None - A platform's community to be fetched - for default it is as `None` meaning to get all communities information - - Returns - --------- - communities : list[dict[str, Any]] - a list of communities holding website informations - """ - try: - if platform_id: - logger.info("Website ingestion is filtered for a single community!") - communities = ModulesWebsite().get_learning_platforms( - filter_platform_id=platform_id - ) - logger.info(f"Found {len(communities)} communities to process") - logging.info(f"communities: {communities}") - return communities - except Exception as e: - logger.error(f"Error fetching communities: {str(e)}") - raise - - -@activity.defn -async def extract_website(urls: list[str], community_id: str) -> list[dict]: - """Extract data from website URLs.""" - try: - logger.info( - f"Starting extraction for community {community_id} with {len(urls)} URLs" - ) - website_etl = WebsiteETL(community_id=community_id) - result = await website_etl.extract(urls=urls) - logger.info(f"Completed extraction for community {community_id}") - return result - except Exception as e: - logger.error(f"Error in extraction for community {community_id}: {str(e)}") - raise - - -@activity.defn -async def transform_data(raw_data: list[dict], community_id: str) -> list[Document]: - """Transform the extracted raw data.""" - try: - logger.info(f"Starting transformation for community {community_id}") - website_etl = WebsiteETL(community_id=community_id) - result = website_etl.transform(raw_data=raw_data) - logger.info(f"Completed transformation for community {community_id}") - return result - except Exception as e: - logger.error(f"Error in transformation for community {community_id}: {str(e)}") - raise - - -@activity.defn -async def load_data(documents: list[Document], community_id: str) -> None: - """Load the transformed data into the database.""" - try: - logger.info(f"Starting data load for community {community_id}") - website_etl = WebsiteETL(community_id=community_id) - website_etl.load(documents=documents) - 
logger.info(f"Completed data load for community {community_id}") - except Exception as e: - logger.error(f"Error in data load for community {community_id}: {str(e)}") - raise - - @activity.defn async def say_hello(): return 7 diff --git a/hivemind_etl/mediawiki/__init__.py b/hivemind_etl/mediawiki/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hivemind_etl/mediawiki/activities.py b/hivemind_etl/mediawiki/activities.py new file mode 100644 index 0000000..3037411 --- /dev/null +++ b/hivemind_etl/mediawiki/activities.py @@ -0,0 +1,98 @@ +import logging +from typing import Any + +from temporalio import activity, workflow + +with workflow.unsafe.imports_passed_through(): + from hivemind_etl.mediawiki.module import ModulesMediaWiki + from hivemind_etl.mediawiki.etl import MediawikiETL + from llama_index.core import Document + + +@activity.defn +async def get_hivemind_mediawiki_platforms( + platform_id: str | None = None, +) -> list[dict[str, Any]]: + """ + Fetch all MediaWiki communities that need to be processed in case of no platform id given + Else, just process for one platform + + Parameters + ----------- + platform_id : str | None + A platform's community to be fetched + for default it is as `None` meaning to get all platforms information + + example data output: + ``` + [{ + "community_id": "6579c364f1120850414e0dc5", + "base_url": "some_api_url", + "namespaces": [1, 2, 3], + }] + ``` + + Returns + --------- + platforms : list[dict[str, Any]] + a list of platforms holding MediaWiki informations + """ + try: + if platform_id: + logging.info("MediaWiki ingestion is filtered for a single platform!") + platforms = ModulesMediaWiki().get_learning_platforms( + platform_id_filter=platform_id + ) + logging.info(f"Found {len(platforms)} platforms to process") + logging.info(f"platforms: {platforms}") + return platforms + except Exception as e: + logging.error(f"Error fetching MediaWiki platforms: {str(e)}") + raise + + +@activity.defn +async def extract_mediawiki(mediawiki_platform: dict[str, Any]) -> None: + """Extract data from MediaWiki API URL.""" + try: + community_id = mediawiki_platform["community_id"] + api_url = mediawiki_platform["base_url"] + namespaces = mediawiki_platform["namespaces"] + + logging.info( + f"Starting extraction for community {community_id} with API URL: {api_url}" + ) + mediawiki_etl = MediawikiETL(community_id=community_id, namespaces=namespaces) + mediawiki_etl.extract(api_url=api_url) + logging.info(f"Completed extraction for community {community_id}") + except Exception as e: + community_id = mediawiki_platform["community_id"] + logging.error(f"Error in extraction for community {community_id}: {str(e)}") + raise + + +@activity.defn +async def transform_mediawiki_data(community_id: str) -> list[Document]: + """Transform the extracted MediaWiki data.""" + try: + logging.info(f"Starting transformation for community {community_id}") + mediawiki_etl = MediawikiETL(community_id=community_id) + result = mediawiki_etl.transform() + logging.info(f"Completed transformation for community {community_id}") + return result + except Exception as e: + logging.error(f"Error in transformation for community {community_id}: {str(e)}") + raise + + +@activity.defn +async def load_mediawiki_data(documents: list[Document], community_id: str) -> None: + """Load the transformed MediaWiki data into the database.""" + try: + logging.info(f"Starting data load for community {community_id}") + mediawiki_etl = MediawikiETL(community_id=community_id) + 
mediawiki_etl.load(documents=documents) + logging.info(f"Completed data load for community {community_id}") + except Exception as e: + logging.error(f"Error in data load for community {community_id}: {str(e)}") + raise diff --git a/hivemind_etl/mediawiki/etl.py b/hivemind_etl/mediawiki/etl.py new file mode 100644 index 0000000..3b45dde --- /dev/null +++ b/hivemind_etl/mediawiki/etl.py @@ -0,0 +1,85 @@ +import logging +import shutil + +from llama_index.core import Document +from tc_hivemind_backend.ingest_qdrant import CustomIngestionPipeline +from hivemind_etl.mediawiki.transform_xml import parse_mediawiki_xml +from hivemind_etl.mediawiki.wikiteam_crawler import WikiteamCrawler + + +class MediawikiETL: + def __init__( + self, + community_id: str, + namespaces: list[int], + delete_dump_after_load: bool = True, + ) -> None: + self.community_id = community_id + self.wikiteam_crawler = WikiteamCrawler(community_id, namespaces=namespaces) + + self.dump_dir = f"dump_{self.community_id}" + self.delete_dump_after_load = delete_dump_after_load + + def extract(self, api_url: str, dump_dir: str | None = None) -> None: + if dump_dir is None: + dump_dir = self.dump_dir + else: + self.dump_dir = dump_dir + + self.wikiteam_crawler.crawl(api_url, dump_dir) + + def transform(self) -> list[Document]: + pages = parse_mediawiki_xml(file_dir=self.dump_dir) + + documents: list[Document] = [] + for page in pages: + try: + documents.append( + Document( + doc_id=page.page_id, + text=page.revision.text, + metadata={ + "title": page.title, + "namespace": page.namespace, + "revision_id": page.revision.revision_id, + "parent_revision_id": page.revision.parent_revision_id, + "timestamp": page.revision.timestamp, + "comment": page.revision.comment, + "contributor_username": page.revision.contributor.username, + "contributor_user_id": page.revision.contributor.user_id, + "sha1": page.revision.sha1, + "model": page.revision.model, + }, + excluded_embed_metadata_keys=[ + "namespace", + "revision_id", + "parent_revision_id", + "sha1", + "model", + "contributor_user_id", + "comment", + "timestamp", + ], + excluded_llm_metadata_keys=[ + "namespace", + "revision_id", + "parent_revision_id", + "sha1", + "model", + "contributor_user_id", + ], + ) + ) + except Exception as e: + logging.error(f"Error transforming page {page.page_id}: {e}") + + return documents + + def load(self, documents: list[Document]) -> None: + ingestion_pipeline = CustomIngestionPipeline( + self.community_id, collection_name="mediawiki" + ) + ingestion_pipeline.run_pipeline(documents) + + if self.delete_dump_after_load: + shutil.rmtree(self.dump_dir) diff --git a/hivemind_etl/mediawiki/llama_xml_reader.py b/hivemind_etl/mediawiki/llama_xml_reader.py new file mode 100644 index 0000000..7e61d20 --- /dev/null +++ b/hivemind_etl/mediawiki/llama_xml_reader.py @@ -0,0 +1,95 @@ +"""XML Reader.""" + +"""Copied from https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/readers/llama-index-readers-file/llama_index/readers/file/xml/base.py""" + +import re +import xml.etree.ElementTree as ET +from pathlib import Path +from typing import Dict, List, Optional + +from llama_index.core.readers.base import BaseReader +from llama_index.core.schema import Document + + +def _get_leaf_nodes_up_to_level(root: ET.Element, level: int) -> List[ET.Element]: + """Get collection of nodes up to certain level including leaf nodes. 
+ + Args: + root (ET.Element): XML Root Element + level (int): Levels to traverse in the tree + + Returns: + List[ET.Element]: List of target nodes + """ + + def traverse(current_node, current_level): + if len(current_node) == 0 or level == current_level: + # Keep leaf nodes and target level nodes + nodes.append(current_node) + elif current_level < level: + # Move to the next level + for child in current_node: + traverse(child, current_level + 1) + + nodes = [] + traverse(root, 0) + return nodes + + +class XMLReader(BaseReader): + """XML reader. + + Reads XML documents with options to help suss out relationships between nodes. + + Args: + tree_level_split (int): From which level in the xml tree we split documents, + the default level is the root which is level 0 + + """ + + def __init__(self, tree_level_split: Optional[int] = 0) -> None: + """Initialize with arguments.""" + super().__init__() + self.tree_level_split = tree_level_split + + def _parse_xmlelt_to_document( + self, root: ET.Element, extra_info: Optional[Dict] = None + ) -> List[Document]: + """Parse the xml object into a list of Documents. + + Args: + root: The XML Element to be converted. + extra_info (Optional[Dict]): Additional information. Default is None. + + Returns: + Document: The documents. + """ + nodes = _get_leaf_nodes_up_to_level(root, self.tree_level_split) + documents = [] + for node in nodes: + content = ET.tostring(node, encoding="utf8").decode("utf-8") + content = re.sub(r"^<\?xml.*", "", content) + content = content.strip() + documents.append(Document(text=content, extra_info=extra_info or {})) + + return documents + + def load_data( + self, + file: Path, + extra_info: Optional[Dict] = None, + ) -> List[Document]: + """Load data from the input file. + + Args: + file (Path): Path to the input file. + extra_info (Optional[Dict]): Additional information. Default is None. + + Returns: + List[Document]: List of documents. + """ + if not isinstance(file, Path): + file = Path(file) + + tree = ET.parse(file) + return self._parse_xmlelt_to_document(tree.getroot(), extra_info) diff --git a/hivemind_etl/mediawiki/module.py b/hivemind_etl/mediawiki/module.py new file mode 100644 index 0000000..7dc4e70 --- /dev/null +++ b/hivemind_etl/mediawiki/module.py @@ -0,0 +1,91 @@ +import logging + +from tc_hivemind_backend.db.modules_base import ModulesBase + + +class ModulesMediaWiki(ModulesBase): + def __init__(self) -> None: + self.platform_name = "mediaWiki" + super().__init__() + + def get_learning_platforms( + self, + platform_id_filter: str | None = None, + ) -> list[dict[str, str | list[str]]]: + """ + Get all the MediaWiki communities with their page titles. 
+ + Parameters + ----------- + platform_id : str | None + the platform id to filter the results for + + Returns + --------- + communities_data : list[dict[str, str | list[str]]] + a list of MediaWiki platform information + + example data output: + ``` + [{ + "community_id": "6579c364f1120850414e0dc5", + "base_url": "some_api_url", + "namespaces": [1, 2, 3], + }] + ``` + """ + modules = self.query(platform=self.platform_name, projection={"name": 0}) + communities_data: list[dict[str, str | list[str]]] = [] + + for module in modules: + community = module["community"] + + # each platform of the community + for platform in module["options"]["platforms"]: + if platform["name"] != self.platform_name: + continue + + platform_id = platform["platform"] + + if platform_id_filter is not None and platform_id_filter != str( + platform_id + ): + continue + + try: + # TODO: retrieve baseURL and path in 1 db call + base_url = self.get_platform_metadata( + platform_id=platform_id, + metadata_name="baseURL", + ) + path = self.get_platform_metadata( + platform_id=platform_id, + metadata_name="path", + ) + + if not isinstance(path, str) or not isinstance(base_url, str): + raise ValueError("Wrong format for `path` or `base_url`!") + + modules_options = platform["metadata"] + namespaces = modules_options.get("namespaces", []) + + if not namespaces: + logging.warning( + f"No namespaces found for platform: {platform_id}" + ) + continue + + communities_data.append( + { + "community_id": str(community), + "namespaces": namespaces, + "base_url": base_url + path, # type: ignore + } + ) + except Exception as exp: + logging.error( + "Exception while fetching mediaWiki modules " + f"for platform: {platform_id} | exception: {exp}" + ) + + return communities_data diff --git a/hivemind_etl/mediawiki/schema.py b/hivemind_etl/mediawiki/schema.py new file mode 100644 index 0000000..8f5c326 --- /dev/null +++ b/hivemind_etl/mediawiki/schema.py @@ -0,0 +1,34 @@ +from typing import Optional + +from pydantic import BaseModel + + +class Contributor(BaseModel): + username: Optional[str] = None + user_id: Optional[str] = None + + +class Revision(BaseModel): + revision_id: Optional[str] = None + parent_revision_id: Optional[str] = None + timestamp: Optional[str] = None + comment: str = "" + contributor: Contributor = Contributor() + model: Optional[str] = None + format: Optional[str] = None + text: str = "" + sha1: Optional[str] = None + + +class Page(BaseModel): + title: Optional[str] = None + namespace: Optional[str] = None + page_id: Optional[str] = None + revision: Optional[Revision] = None + + +class SiteInfo(BaseModel): + sitename: Optional[str] = None + dbname: Optional[str] = None + base: Optional[str] = None + generator: Optional[str] = None diff --git a/hivemind_etl/mediawiki/transform_xml.py b/hivemind_etl/mediawiki/transform_xml.py new file mode 100644 index 0000000..bf32a1c --- /dev/null +++ b/hivemind_etl/mediawiki/transform_xml.py @@ -0,0 +1,134 @@ +import logging +import xml.etree.ElementTree as ET +import os +import glob + +from hivemind_etl.mediawiki.schema import Contributor, Page, Revision, SiteInfo + + +def parse_mediawiki_xml(file_dir: str) -> list[Page]: + """Parse a MediaWiki XML dump file and extract page information. + + This function processes a MediaWiki XML dump file, extracting detailed information + about pages, their revisions, and contributors. The data is structured into + Pydantic models for type safety and validation.
+ + Parameters + ---------- + file_dir : str + Path to the directory containing the MediaWiki XML dump file to be parsed. + + Returns + ------- + pages : list[Page] + A list of Page objects containing the parsed data. Each Page object includes: + - Basic page information (title, namespace, page_id) + - Revision details (revision_id, timestamp, text, etc.) + - Contributor information (username, user_id) + + Examples + -------- + >>> pages = parse_mediawiki_xml("wiki_dump_directory") + >>> for page in pages: + ... print(f"Page: {page.title}") + ... if page.revision: + ... print(f"Last edited by: {page.revision.contributor.username}") + + Notes + ----- + - The function handles optional fields gracefully, setting them to None when not present + - XML namespaces are automatically handled for MediaWiki export format + - The text content retains XML escapes (e.g., &lt; for <) + - The function logs the total number of pages processed + """ + # Find XML file in the directory + xml_files = glob.glob(os.path.join(file_dir, "*.xml")) + if not xml_files: + raise FileNotFoundError(f"No XML files found in directory: {file_dir}") + + # Use the first XML file found + # there should be only one xml file in the directory (wikiteam3 crawler settings) + xml_file = xml_files[0] + logging.info(f"Found XML file: {xml_file}") + + namespaces = {"mw": "http://www.mediawiki.org/xml/export-0.11/"} + # Parse the XML file + tree = ET.parse(xml_file) + root = tree.getroot() + + # --- Extract Site Information --- + siteinfo_el = root.find("mw:siteinfo", namespaces) + siteinfo = SiteInfo() + if siteinfo_el is not None: + for tag in ["sitename", "dbname", "base", "generator"]: + el = siteinfo_el.find(f"mw:{tag}", namespaces) + setattr(siteinfo, tag, el.text if el is not None else None) + + # --- Process Each Page --- + pages = [] + for page in root.findall("mw:page", namespaces): + page_data = Page() + # Extract basic page details: title, namespace, page id + title_el = page.find("mw:title", namespaces) + page_data.title = title_el.text if title_el is not None else None + + ns_el = page.find("mw:ns", namespaces) + page_data.namespace = ns_el.text if ns_el is not None else None + + id_el = page.find("mw:id", namespaces) + page_data.page_id = id_el.text if id_el is not None else None + + # Extract revision details + revision = page.find("mw:revision", namespaces) + if revision is not None: + rev_data = Revision() + rev_id_el = revision.find("mw:id", namespaces) + rev_data.revision_id = rev_id_el.text if rev_id_el is not None else None + + parentid_el = revision.find("mw:parentid", namespaces) + rev_data.parent_revision_id = ( + parentid_el.text if parentid_el is not None else None + ) + + timestamp_el = revision.find("mw:timestamp", namespaces) + rev_data.timestamp = timestamp_el.text if timestamp_el is not None else None + + # Revision comment (present only on some pages) + comment_el = revision.find("mw:comment", namespaces) + rev_data.comment = comment_el.text if comment_el is not None else "" + + # Contributor information + contributor = revision.find("mw:contributor", namespaces) + if contributor is not None: + cont_data = Contributor() + username_el = contributor.find("mw:username", namespaces) + cont_data.username = ( + username_el.text if username_el is not None else None + ) + + user_id_el = contributor.find("mw:id", namespaces) + cont_data.user_id = user_id_el.text if user_id_el is not None else None + + rev_data.contributor = cont_data + + # Other revision details like model and format + model_el =
revision.find("mw:model", namespaces) + rev_data.model = model_el.text if model_el is not None else None + + format_el = revision.find("mw:format", namespaces) + rev_data.format = format_el.text if format_el is not None else None + + # Extract the full text content; note that XML escapes are retained (e.g., <) + text_el = revision.find("mw:text", namespaces) + rev_data.text = text_el.text if text_el is not None else "" + + # Capture sha1 if needed + sha1_el = revision.find("mw:sha1", namespaces) + rev_data.sha1 = sha1_el.text if sha1_el is not None else None + + page_data.revision = rev_data + + pages.append(page_data) + + logging.info(f"Total pages processed: {len(pages)}\n") + return pages diff --git a/hivemind_etl/mediawiki/wikiteam_crawler.py b/hivemind_etl/mediawiki/wikiteam_crawler.py new file mode 100644 index 0000000..1ffa802 --- /dev/null +++ b/hivemind_etl/mediawiki/wikiteam_crawler.py @@ -0,0 +1,83 @@ +import logging +import os + +from wikiteam3.dumpgenerator import DumpGenerator + + +class WikiteamCrawler: + def __init__( + self, + community_id: str, + xml: bool = True, + force: bool = True, + curonly: bool = True, + namespaces: list[int] = [], + **kwargs, + ) -> None: + self.community_id = community_id + self.xml = xml + self.force = force + self.curonly = curonly + self.extra_params = kwargs + self.namespaces = namespaces + + def crawl(self, api_url: str, dump_path: str) -> None: + """ + Crawl the mediawiki dump from the given api url and save it to the given path + + Parameters + ---------- + api_url : str + The url of the mediawiki api + dump_path : str + The path to save the dump file + """ + # Create a list of parameters analogous to the terminal command: + params = [ + "--api", + api_url, + "--path", + dump_path, + ] + + # Add optional parameters based on configuration + if self.xml: + params.append("--xml") + if self.force: + params.append("--force") + if self.curonly: + params.append("--curonly") + if self.namespaces: + params.append(f"--namespaces") + params.append(f"{','.join(map(str, self.namespaces))}") + + # Add any extra parameters passed during initialization + for key, value in self.extra_params.items(): + if isinstance(value, bool): + if value: + params.append(f"--{key}") + else: + params.extend([f"--{key}", str(value)]) + + # Directly call the DumpGenerator static __init__ method which will parse these parameters, + # execute the dump generation process, and run through the rest of the workflow. + DumpGenerator(params) + + def delete_dump(self, dump_path: str) -> None: + """ + Delete the dumped file at the specified path. 
+ + Parameters + ---------- + dump_path : str + The path to the dump file to be deleted + """ + try: + if os.path.exists(dump_path): + os.remove(dump_path) + logging.info(f"Successfully deleted dump file at {dump_path}") + else: + logging.warning(f"Dump file not found at {dump_path}") + except Exception as e: + logging.error(f"Error deleting dump file at {dump_path}: {str(e)}") + raise diff --git a/hivemind_etl/mediawiki/workflows.py b/hivemind_etl/mediawiki/workflows.py new file mode 100644 index 0000000..3cd946a --- /dev/null +++ b/hivemind_etl/mediawiki/workflows.py @@ -0,0 +1,76 @@ +import logging +from datetime import timedelta + +from temporalio import workflow + +with workflow.unsafe.imports_passed_through(): + from hivemind_etl.mediawiki.activities import ( + get_hivemind_mediawiki_platforms, + extract_mediawiki, + transform_mediawiki_data, + load_mediawiki_data, + ) + + +@workflow.defn +class MediaWikiETLWorkflow: + @workflow.run + async def run(self, platform_id: str | None = None) -> None: + """ + Run the MediaWiki ETL workflow for all communities or a specific one. + + Parameters + ----------- + platform_id : str | None + The platform to process + defaults to `None`, meaning all platforms are processed + """ + try: + # Get all platforms that need to be processed + platforms = await workflow.execute_activity( + get_hivemind_mediawiki_platforms, + platform_id, + start_to_close_timeout=timedelta(minutes=1), + ) + + for platform in platforms: + try: + mediawiki_platform = { + "base_url": platform["base_url"], + "community_id": platform["community_id"], + "namespaces": platform["namespaces"], + } + # Extract data from MediaWiki + await workflow.execute_activity( + extract_mediawiki, + mediawiki_platform, + start_to_close_timeout=timedelta(days=5), + ) + + # Transform the extracted data + documents = await workflow.execute_activity( + transform_mediawiki_data, + platform["community_id"], + start_to_close_timeout=timedelta(minutes=30), + ) + + # Load the transformed data + await workflow.execute_activity( + load_mediawiki_data, + # multiple activity arguments must be passed via `args` + args=[documents, platform["community_id"]], + start_to_close_timeout=timedelta(minutes=30), + ) + + logging.info( + f"Successfully completed ETL for community id: {platform['community_id']}" + ) + except Exception as e: + logging.error( + f"Error processing community id: {platform['community_id']}: {str(e)}" + ) + continue + + except Exception as e: + logging.error(f"Error in MediaWiki ETL workflow: {str(e)}") + raise diff --git a/hivemind_etl/website/activities.py b/hivemind_etl/website/activities.py new file mode 100644 index 0000000..6c2d933 --- /dev/null +++ b/hivemind_etl/website/activities.py @@ -0,0 +1,87 @@ +import logging +from typing import Any + +from temporalio import activity, workflow + +with workflow.unsafe.imports_passed_through(): + from hivemind_etl.website.module import ModulesWebsite + from hivemind_etl.website.website_etl import WebsiteETL + from llama_index.core import Document + + +@activity.defn +async def get_hivemind_website_comminities( + platform_id: str | None = None, +) -> list[dict[str, Any]]: + """ + Fetch all website communities that need to be processed when no platform id is given + Else, process only the given platform + + Parameters + ----------- + platform_id : str | None + The platform whose community should be fetched + defaults to `None`, meaning all communities' information is fetched + + Returns + --------- + communities : list[dict[str, Any]] + a list of communities holding website information + """ + try: +
if platform_id: + logging.info("Website ingestion is filtered for a single community!") + communities = ModulesWebsite().get_learning_platforms( + filter_platform_id=platform_id + ) + logging.info(f"Found {len(communities)} communities to process") + logging.info(f"communities: {communities}") + return communities + except Exception as e: + logging.error(f"Error fetching communities: {str(e)}") + raise + + +@activity.defn +async def extract_website(urls: list[str], community_id: str) -> list[dict]: + """Extract data from website URLs.""" + try: + logging.info( + f"Starting extraction for community {community_id} with {len(urls)} URLs" + ) + website_etl = WebsiteETL(community_id=community_id) + result = await website_etl.extract(urls=urls) + logging.info(f"Completed extraction for community {community_id}") + return result + except Exception as e: + logging.error(f"Error in extraction for community {community_id}: {str(e)}") + raise + + +@activity.defn +async def transform_website_data( + raw_data: list[dict], community_id: str +) -> list[Document]: + """Transform the extracted raw data.""" + try: + logging.info(f"Starting transformation for community {community_id}") + website_etl = WebsiteETL(community_id=community_id) + result = website_etl.transform(raw_data=raw_data) + logging.info(f"Completed transformation for community {community_id}") + return result + except Exception as e: + logging.error(f"Error in transformation for community {community_id}: {str(e)}") + raise + + +@activity.defn +async def load_website_data(documents: list[Document], community_id: str) -> None: + """Load the transformed data into the database.""" + try: + logging.info(f"Starting data load for community {community_id}") + website_etl = WebsiteETL(community_id=community_id) + website_etl.load(documents=documents) + logging.info(f"Completed data load for community {community_id}") + except Exception as e: + logging.error(f"Error in data load for community {community_id}: {str(e)}") + raise diff --git a/hivemind_etl/website/workflows.py b/hivemind_etl/website/workflows.py new file mode 100644 index 0000000..126e2dd --- /dev/null +++ b/hivemind_etl/website/workflows.py @@ -0,0 +1,93 @@ +import logging +import asyncio +from datetime import timedelta + +from temporalio import workflow +from temporalio.common import RetryPolicy + +from hivemind_etl.website.activities import ( + extract_website, + get_hivemind_website_comminities, + load_website_data, + transform_website_data, +) + + +# Individual community workflow +@workflow.defn +class CommunityWebsiteWorkflow: + @workflow.run + async def run(self, community_info: dict) -> None: + community_id = community_info["community_id"] + platform_id = community_info["platform_id"] + urls = community_info["urls"] + + logging.info( + f"Starting workflow for community {community_id} | platform {platform_id}" + ) + + # Execute activities in sequence with retries + raw_data = await workflow.execute_activity( + extract_website, + args=[urls, community_id], + start_to_close_timeout=timedelta(minutes=30), + retry_policy=RetryPolicy( + initial_interval=timedelta(seconds=10), + maximum_interval=timedelta(minutes=5), + maximum_attempts=3, + ), + ) + + documents = await workflow.execute_activity( + transform_website_data, + args=[raw_data, community_id], + start_to_close_timeout=timedelta(minutes=10), + retry_policy=RetryPolicy( + initial_interval=timedelta(seconds=5), + maximum_interval=timedelta(minutes=2), + maximum_attempts=1, + ), + ) + + await workflow.execute_activity( + 
load_website_data, + args=[documents, community_id], + start_to_close_timeout=timedelta(minutes=60), + retry_policy=RetryPolicy( + initial_interval=timedelta(seconds=5), + maximum_interval=timedelta(minutes=2), + maximum_attempts=3, + ), + ) + + +# Main scheduler workflow +@workflow.defn +class WebsiteIngestionSchedulerWorkflow: + @workflow.run + async def run(self, platform_id: str | None = None) -> None: + # Get all communities + communities = await workflow.execute_activity( + get_hivemind_website_comminities, + platform_id, + start_to_close_timeout=timedelta(minutes=5), + retry_policy=RetryPolicy( + maximum_attempts=3, + ), + ) + + # Start a child workflow for each community + child_workflows = [] + for community in communities: + child_handle = await workflow.start_child_workflow( + CommunityWebsiteWorkflow.run, + args=[community], + id=f"website:ingestor:{community['community_id']}", + retry_policy=RetryPolicy( + maximum_attempts=1, + ), + ) + child_workflows.append(child_handle) + + # Wait for all child workflows to complete + await asyncio.gather(*[handle for handle in child_workflows]) diff --git a/registry.py b/registry.py index a35c2f4..b3f03ef 100644 --- a/registry.py +++ b/registry.py @@ -1,20 +1,36 @@ from hivemind_etl.activities import ( extract_website, - get_communities, - load_data, + get_hivemind_website_comminities, + load_website_data, say_hello, - transform_data, + transform_website_data, + extract_mediawiki, + get_hivemind_mediawiki_platforms, + transform_mediawiki_data, + load_mediawiki_data, ) from workflows import ( CommunityWebsiteWorkflow, SayHello, WebsiteIngestionSchedulerWorkflow, + MediaWikiETLWorkflow, ) WORKFLOWS = [ CommunityWebsiteWorkflow, SayHello, WebsiteIngestionSchedulerWorkflow, + MediaWikiETLWorkflow, ] -ACTIVITIES = [get_communities, extract_website, transform_data, load_data, say_hello] +ACTIVITIES = [ + get_hivemind_website_comminities, + extract_website, + transform_website_data, + load_website_data, + get_hivemind_mediawiki_platforms, + extract_mediawiki, + transform_mediawiki_data, + load_mediawiki_data, + say_hello, +] diff --git a/requirements.txt b/requirements.txt index 1bb66c0..fdfbd5b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ defusedxml==0.7.1 pydantic==2.9.2 motor>=3.6, <4.0.0 tc-temporal-backend==1.0.0 +wikiteam3==4.4.1 diff --git a/tests/integration/test_mediawiki_modules.py b/tests/integration/test_mediawiki_modules.py new file mode 100644 index 0000000..1332fd2 --- /dev/null +++ b/tests/integration/test_mediawiki_modules.py @@ -0,0 +1,234 @@ +from datetime import datetime +from unittest import TestCase + +from bson import ObjectId +from hivemind_etl.mediawiki.module import ModulesMediaWiki +from tc_hivemind_backend.db.mongo import MongoSingleton + + +class TestGetMediaWikiModules(TestCase): + def setUp(self): + client = MongoSingleton.get_instance().client + client["Core"].drop_collection("modules") + client["Core"].drop_collection("platforms") + self.client = client + self.modules_mediawiki = ModulesMediaWiki() + + def test_get_empty_data(self): + result = self.modules_mediawiki.get_learning_platforms() + self.assertEqual(result, []) + + def test_get_single_data(self): + platform_id = ObjectId("6579c364f1120850414e0dc6") + community_id = ObjectId("6579c364f1120850414e0dc5") + + self.client["Core"]["platforms"].insert_one( + { + "_id": platform_id, + "name": "mediaWiki", + "metadata": { + "baseURL": "http://example.com", + "path": "/api", + "namespaces": [0, 1, 2], + }, + "community": community_id, + 
"disconnectedAt": None, + "connectedAt": datetime.now(), + "createdAt": datetime.now(), + "updatedAt": datetime.now(), + } + ) + self.client["Core"]["modules"].insert_one( + { + "name": "hivemind", + "community": community_id, + "options": { + "platforms": [ + { + "platform": platform_id, + "name": "mediaWiki", + "metadata": { + "namespaces": [0, 1, 2], + }, + } + ] + }, + } + ) + + result = self.modules_mediawiki.get_learning_platforms() + + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertEqual(result[0]["community_id"], "6579c364f1120850414e0dc5") + self.assertEqual( + result[0]["namespaces"], + [0, 1, 2], + ) + self.assertEqual(result[0]["base_url"], "http://example.com/api") + + def test_get_mediawiki_communities_data_multiple_platforms(self): + """ + Two mediawiki platforms for one community + """ + platform_id1 = ObjectId("6579c364f1120850414e0dc6") + platform_id2 = ObjectId("6579c364f1120850414e0dc7") + community_id = ObjectId("1009c364f1120850414e0dc5") + + self.client["Core"]["modules"].insert_one( + { + "name": "hivemind", + "community": community_id, + "options": { + "platforms": [ + { + "platform": platform_id1, + "name": "mediaWiki", + "metadata": { + "namespaces": [0, 1, 2], + }, + }, + { + "platform": platform_id2, + "name": "mediaWiki", + "metadata": { + "namespaces": [3, 4, 5], + }, + }, + ] + }, + } + ) + + self.client["Core"]["platforms"].insert_one( + { + "_id": platform_id1, + "name": "mediaWiki", + "metadata": { + "baseURL": "http://example1.com", + "path": "/api", + }, + "community": community_id, + "disconnectedAt": None, + "connectedAt": datetime.now(), + "createdAt": datetime.now(), + "updatedAt": datetime.now(), + } + ) + + self.client["Core"]["platforms"].insert_one( + { + "_id": platform_id2, + "name": "mediaWiki", + "metadata": { + "baseURL": "http://example2.com", + "path": "/api", + }, + "community": community_id, + "disconnectedAt": None, + "connectedAt": datetime.now(), + "createdAt": datetime.now(), + "updatedAt": datetime.now(), + } + ) + + result = self.modules_mediawiki.get_learning_platforms() + + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + self.assertEqual( + result[0], + { + "community_id": str(community_id), + "namespaces": [0, 1, 2], + "base_url": "http://example1.com/api", + }, + ) + self.assertEqual( + result[1], + { + "community_id": str(community_id), + "namespaces": [3, 4, 5], + "base_url": "http://example2.com/api", + }, + ) + + def test_get_mediawiki_communities_data_filtered_platforms(self): + """ + Two mediawiki platforms for one community + """ + platform_id1 = ObjectId("6579c364f1120850414e0dc6") + platform_id2 = ObjectId("6579c364f1120850414e0dc7") + community_id = ObjectId("1009c364f1120850414e0dc5") + + self.client["Core"]["modules"].insert_one( + { + "name": "hivemind", + "community": community_id, + "options": { + "platforms": [ + { + "platform": platform_id1, + "name": "mediaWiki", + "metadata": { + "namespaces": [0, 1, 2], + }, + }, + { + "platform": platform_id2, + "name": "mediaWiki", + "metadata": { + "namespaces": [3, 4, 5], + }, + }, + ] + }, + } + ) + + self.client["Core"]["platforms"].insert_one( + { + "_id": platform_id1, + "name": "mediaWiki", + "metadata": { + "baseURL": "http://example1.com", + "path": "/api", + }, + "community": community_id, + "disconnectedAt": None, + "connectedAt": datetime.now(), + "createdAt": datetime.now(), + "updatedAt": datetime.now(), + } + ) + + self.client["Core"]["platforms"].insert_one( + { + "_id": platform_id2, + "name": 
"mediaWiki", + "metadata": { + "baseURL": "http://example2.com", + "path": "/api", + }, + "community": community_id, + "disconnectedAt": None, + "connectedAt": datetime.now(), + "createdAt": datetime.now(), + "updatedAt": datetime.now(), + } + ) + + result = self.modules_mediawiki.get_learning_platforms( + platform_id_filter=str(platform_id1) + ) + + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertEqual( + result[0], + { + "community_id": str(community_id), + "namespaces": [0, 1, 2], + "base_url": "http://example1.com/api", + }, + ) diff --git a/tests/unit/test_mediawiki_etl.py b/tests/unit/test_mediawiki_etl.py new file mode 100644 index 0000000..1441495 --- /dev/null +++ b/tests/unit/test_mediawiki_etl.py @@ -0,0 +1,170 @@ +import os +import unittest +from unittest.mock import Mock, patch +import shutil + +from llama_index.core import Document +from hivemind_etl.mediawiki.etl import MediawikiETL + + +class TestMediawikiETL(unittest.TestCase): + def setUp(self): + self.community_id = "test_community" + self.api_url = "https://example.com/api.php" + self.custom_path = "custom/path" + self.namespaces = [0, 1] # Main and Talk namespaces + + # Create a temporary dumps directory + os.makedirs(f"dump_{self.community_id}", exist_ok=True) + + def tearDown(self): + # Clean up any created files + if os.path.exists(f"dump_{self.community_id}"): + shutil.rmtree(f"dump_{self.community_id}") + if os.path.exists(self.custom_path): + shutil.rmtree(self.custom_path) + + def test_mediawiki_etl_initialization(self): + etl = MediawikiETL(community_id=self.community_id, namespaces=self.namespaces) + self.assertEqual(etl.community_id, self.community_id) + self.assertTrue(etl.delete_dump_after_load) + self.assertEqual(etl.dump_dir, f"dump_{self.community_id}") + + etl = MediawikiETL( + community_id=self.community_id, + namespaces=self.namespaces, + delete_dump_after_load=False, + ) + self.assertFalse(etl.delete_dump_after_load) + + def test_extract_with_default_path(self): + # Create a ETL instance with mocked wikiteam_crawler + etl = MediawikiETL(community_id=self.community_id, namespaces=self.namespaces) + etl.wikiteam_crawler = Mock() + + etl.extract(self.api_url) + + etl.wikiteam_crawler.crawl.assert_called_once_with( + self.api_url, f"dump_{self.community_id}" + ) + + def test_extract_with_custom_path(self): + # Create a ETL instance with mocked wikiteam_crawler + etl = MediawikiETL(community_id=self.community_id, namespaces=self.namespaces) + etl.wikiteam_crawler = Mock() + + etl.extract(self.api_url, self.custom_path) + + self.assertEqual(etl.dump_dir, self.custom_path) + etl.wikiteam_crawler.crawl.assert_called_once_with( + self.api_url, self.custom_path + ) + + @patch("hivemind_etl.mediawiki.etl.parse_mediawiki_xml") + def test_transform_success(self, mock_parse_mediawiki_xml): + etl = MediawikiETL(community_id=self.community_id, namespaces=self.namespaces) + + # Mock page data + mock_page = Mock() + mock_page.page_id = "123" + mock_page.title = "Test Page" + mock_page.namespace = 0 + mock_page.revision = Mock( + text="Test content", + revision_id="456", + parent_revision_id="455", + timestamp="2024-01-01T00:00:00Z", + comment="Test edit", + contributor=Mock(username="testuser", user_id="789"), + sha1="abc123", + model="wikitext", + ) + + mock_parse_mediawiki_xml.return_value = [mock_page] + + documents = etl.transform() + + self.assertEqual(len(documents), 1) + doc = documents[0] + self.assertIsInstance(doc, Document) + self.assertEqual(doc.doc_id, "123") + 
self.assertEqual(doc.text, "Test content") + self.assertEqual(doc.metadata["title"], "Test Page") + self.assertEqual(doc.metadata["namespace"], 0) + self.assertEqual(doc.metadata["revision_id"], "456") + self.assertEqual(doc.metadata["contributor_username"], "testuser") + + @patch("hivemind_etl.mediawiki.etl.logging") + @patch("hivemind_etl.mediawiki.etl.parse_mediawiki_xml") + def test_transform_error_handling(self, mock_parse_mediawiki_xml, mock_logging): + etl = MediawikiETL(community_id=self.community_id, namespaces=self.namespaces) + + # Mock page that will raise an exception + mock_page = Mock() + mock_page.page_id = "123" + + # Set up a side effect that raises an exception when accessing certain attributes + def get_attribute_error(*args, **kwargs): + raise Exception("Test error") + + # Configure the mock page to raise an exception + type(mock_page).revision = property(get_attribute_error) + + mock_parse_mediawiki_xml.return_value = [mock_page] + + documents = etl.transform() + + self.assertEqual(len(documents), 0) + mock_logging.error.assert_called_once_with( + "Error transforming page 123: Test error" + ) + + @patch("hivemind_etl.mediawiki.etl.CustomIngestionPipeline") + def test_load_with_dump_deletion(self, mock_ingestion_pipeline_class): + etl = MediawikiETL(community_id=self.community_id, namespaces=self.namespaces) + documents = [Document(text="Test content")] + + # Setup the mock + mock_pipeline = Mock() + mock_ingestion_pipeline_class.return_value = mock_pipeline + + # Create a temporary dump directory + os.makedirs(etl.dump_dir, exist_ok=True) + with open(os.path.join(etl.dump_dir, "test.xml"), "w") as f: + f.write("test content") + + etl.load(documents) + + # Verify that methods were called correctly + mock_ingestion_pipeline_class.assert_called_once_with( + self.community_id, collection_name="mediawiki" + ) + mock_pipeline.run_pipeline.assert_called_once_with(documents) + self.assertFalse(os.path.exists(etl.dump_dir)) + + @patch("hivemind_etl.mediawiki.etl.CustomIngestionPipeline") + def test_load_without_dump_deletion(self, mock_ingestion_pipeline_class): + etl = MediawikiETL( + community_id=self.community_id, + namespaces=self.namespaces, + delete_dump_after_load=False, + ) + documents = [Document(text="Test content")] + + # Setup the mock + mock_pipeline = Mock() + mock_ingestion_pipeline_class.return_value = mock_pipeline + + # Create a temporary dump directory + os.makedirs(etl.dump_dir, exist_ok=True) + with open(os.path.join(etl.dump_dir, "test.xml"), "w") as f: + f.write("test content") + + etl.load(documents) + + # Verify that methods were called correctly + mock_ingestion_pipeline_class.assert_called_once_with( + self.community_id, collection_name="mediawiki" + ) + mock_pipeline.run_pipeline.assert_called_once_with(documents) + self.assertTrue(os.path.exists(etl.dump_dir)) diff --git a/workflows.py b/workflows.py index df14af7..f55748f 100644 --- a/workflows.py +++ b/workflows.py @@ -2,101 +2,22 @@ import logging from datetime import timedelta -from hivemind_etl.activities import ( - extract_website, - get_communities, - load_data, - say_hello, - transform_data, +from hivemind_etl.activities import say_hello +from hivemind_etl.website.workflows import ( + CommunityWebsiteWorkflow, + WebsiteIngestionSchedulerWorkflow, ) +from hivemind_etl.mediawiki.workflows import ( + MediaWikiETLWorkflow, +) + from temporalio import workflow -from temporalio.common import RetryPolicy # Configure logging logging.basicConfig(level=logging.INFO) logger = 
logging.getLogger(__name__) -# Individual community workflow -@workflow.defn -class CommunityWebsiteWorkflow: - @workflow.run - async def run(self, community_info: dict) -> None: - community_id = community_info["community_id"] - platform_id = community_info["platform_id"] - urls = community_info["urls"] - - logger.info( - f"Starting workflow for community {community_id} | platform {platform_id}" - ) - - # Execute activities in sequence with retries - raw_data = await workflow.execute_activity( - extract_website, - args=[urls, community_id], - start_to_close_timeout=timedelta(minutes=30), - retry_policy=RetryPolicy( - initial_interval=timedelta(seconds=10), - maximum_interval=timedelta(minutes=5), - maximum_attempts=3, - ), - ) - - documents = await workflow.execute_activity( - transform_data, - args=[raw_data, community_id], - start_to_close_timeout=timedelta(minutes=10), - retry_policy=RetryPolicy( - initial_interval=timedelta(seconds=5), - maximum_interval=timedelta(minutes=2), - maximum_attempts=1, - ), - ) - - await workflow.execute_activity( - load_data, - args=[documents, community_id], - start_to_close_timeout=timedelta(minutes=60), - retry_policy=RetryPolicy( - initial_interval=timedelta(seconds=5), - maximum_interval=timedelta(minutes=2), - maximum_attempts=3, - ), - ) - - -# Main scheduler workflow -@workflow.defn -class WebsiteIngestionSchedulerWorkflow: - @workflow.run - async def run(self, platform_id: str | None = None) -> None: - # Get all communities - communities = await workflow.execute_activity( - get_communities, - platform_id, - start_to_close_timeout=timedelta(minutes=5), - retry_policy=RetryPolicy( - maximum_attempts=3, - ), - ) - - # Start a child workflow for each community - child_workflows = [] - for community in communities: - child_handle = await workflow.start_child_workflow( - CommunityWebsiteWorkflow.run, - args=[community], - id=f"website:ingestor:{community['community_id']}", - retry_policy=RetryPolicy( - maximum_attempts=1, - ), - ) - child_workflows.append(child_handle) - - # Wait for all child workflows to complete - await asyncio.gather(*[handle for handle in child_workflows]) - - # For test purposes # To be deleted in future @workflow.defn