
agents

acquisition_agents

ArxivAgent

Bases: BaseAcquisitionAgent

Drop-in replacement for your existing ArxivAgent that reuses the generic flow. Keeps the same behaviors (download PDFs, image processing, summarization/RAG).
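
A minimal usage sketch (not from the source): the chat model, query, and context below are placeholders, and the paths simply echo the constructor defaults shown in the source listing.

from langchain_openai import ChatOpenAI   # any BaseChatModel works; this one is a placeholder
from ursa.agents.acquisition_agents import ArxivAgent

llm = ChatOpenAI(model="gpt-4o-mini")
agent = ArxivAgent(
    llm,
    max_results=3,                       # number of arXiv hits to materialize
    database_path="arxiv_papers",        # where downloaded PDFs are cached
    summaries_path="arxiv_generated_summaries",
)

# 'query' (alias: 'arxiv_search_query') drives the search; 'context' steers summarization.
final_summary = agent.invoke(
    query="neural operators for PDE surrogates",
    context="Summarize recent approaches to PDE surrogate modeling.",
)
print(final_summary)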

Source code in src/ursa/agents/acquisition_agents.py, lines 717-805
class ArxivAgent(BaseAcquisitionAgent):
    """
    Drop-in replacement for your existing ArxivAgent that reuses the generic flow.
    Keeps the same behaviors (download PDFs, image processing, summarization/RAG).
    """

    def __init__(
        self,
        llm: BaseChatModel,
        *,
        process_images: bool = True,
        max_results: int = 3,
        download: bool = True,
        rag_embedding=None,
        database_path="arxiv_papers",
        summaries_path="arxiv_generated_summaries",
        vectorstore_path="arxiv_vectorstores",
        **kwargs,
    ):
        super().__init__(
            llm,
            rag_embedding=rag_embedding,
            process_images=process_images,
            max_results=max_results,
            database_path=database_path,
            summaries_path=summaries_path,
            vectorstore_path=vectorstore_path,
            download=download,
            **kwargs,
        )

    def _id(self, hit_or_item: dict[str, Any]) -> str:
        # hits from arXiv feed have 'id' like ".../abs/XXXX.YYYY"
        arxiv_id = hit_or_item.get("arxiv_id")
        if arxiv_id:
            return arxiv_id
        feed_id = hit_or_item.get("id", "")
        if "/abs/" in feed_id:
            return feed_id.split("/abs/")[-1]
        return _hash(json.dumps(hit_or_item))

    def _citation(self, item: ItemMetadata) -> str:
        return f"ArXiv ID: {item.get('id', '?')}"

    def _search(self, query: str) -> list[dict[str, Any]]:
        enc = quote(query)
        url = f"http://export.arxiv.org/api/query?search_query=all:{enc}&start=0&max_results={self.max_results}"
        try:
            resp = requests.get(url, timeout=15)
            resp.raise_for_status()
            feed = feedparser.parse(resp.content)
            entries = feed.entries if hasattr(feed, "entries") else []
            hits = []
            for e in entries:
                full_id = e.id.split("/abs/")[-1]
                hits.append({
                    "id": e.id,
                    "title": e.title.strip(),
                    "arxiv_id": full_id.split("/")[-1],
                })
            return hits
        except Exception as e:
            return [
                {
                    "id": _hash(query + str(time.time())),
                    "title": "Search error",
                    "error": str(e),
                }
            ]

    def _materialize(self, hit: dict[str, Any]) -> ItemMetadata:
        arxiv_id = self._id(hit)
        title = hit.get("title", "")
        pdf_url = f"https://arxiv.org/pdf/{arxiv_id}.pdf"
        local_path = os.path.join(self.database_path, f"{arxiv_id}.pdf")
        full_text = ""
        try:
            _download(pdf_url, local_path)
            full_text = read_pdf_text(local_path)
        except Exception as e:
            full_text = f"[Error loading ArXiv {arxiv_id}: {e}]"
        full_text = self._postprocess_text(full_text, local_path)
        return {
            "id": arxiv_id,
            "title": title,
            "url": pdf_url,
            "local_path": local_path,
            "full_text": full_text,
        }

BaseAcquisitionAgent

Bases: BaseAgent

A generic "acquire-then-summarize-or-RAG" agent.

Subclasses must implement:
  • _search(self, query) -> List[dict-like]: lightweight hits
  • _materialize(self, hit) -> ItemMetadata: download or scrape and return a populated item
  • _id(self, hit_or_item) -> str: stable id for caching/file naming
  • _citation(self, item) -> str: human-readable citation string
Optional hooks:
  • _postprocess_text(self, text, local_path) -> str (e.g., image interpretation)
  • _filter_hit(self, hit) -> bool
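
To make the contract concrete, here is a minimal, hypothetical subclass that "searches" a local JSON index of text files; the index format, the field names (doc_id, title, path), and the ItemMetadata import are assumptions for this sketch only.

import json
from typing import Any

from ursa.agents.acquisition_agents import BaseAcquisitionAgent, ItemMetadata


class LocalIndexAgent(BaseAcquisitionAgent):
    """Hypothetical subclass: looks up documents in a local JSON index."""

    def _search(self, query: str) -> list[dict[str, Any]]:
        # Lightweight hits: read a pre-built index and keep rough title matches.
        with open("index.json", encoding="utf-8") as f:
            index = json.load(f)
        return [h for h in index if query.lower() in h["title"].lower()]

    def _id(self, hit_or_item: dict[str, Any]) -> str:
        return str(hit_or_item["doc_id"])        # stable id for caching/file naming

    def _citation(self, item: ItemMetadata) -> str:
        return f"Local doc {item.get('id', '?')}: {item.get('title', '')}"

    def _materialize(self, hit: dict[str, Any]) -> ItemMetadata:
        path = hit["path"]
        with open(path, encoding="utf-8", errors="ignore") as f:
            text = f.read()
        return {
            "id": self._id(hit),
            "title": hit["title"],
            "local_path": path,
            "full_text": self._postprocess_text(text, path),
        }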
Source code in src/ursa/agents/acquisition_agents.py, lines 183-470
class BaseAcquisitionAgent(BaseAgent):
    """
    A generic "acquire-then-summarize-or-RAG" agent.

    Subclasses must implement:
      - _search(self, query) -> List[dict-like]: lightweight hits
      - _materialize(self, hit) -> ItemMetadata: download or scrape and return populated item
      - _id(self, hit_or_item) -> str: stable id for caching/file naming
      - _citation(self, item) -> str: human-readable citation string

    Optional hooks:
      - _postprocess_text(self, text, local_path) -> str (e.g., image interpretation)
      - _filter_hit(self, hit) -> bool
    """

    def __init__(
        self,
        llm: BaseChatModel,
        *,
        summarize: bool = True,
        rag_embedding=None,
        process_images: bool = True,
        max_results: int = 5,
        database_path: str = "acq_db",
        summaries_path: str = "acq_summaries",
        vectorstore_path: str = "acq_vectorstores",
        num_threads: int = 4,
        download: bool = True,
        **kwargs,
    ):
        super().__init__(llm, **kwargs)
        self.summarize = summarize
        self.rag_embedding = rag_embedding
        self.process_images = process_images
        self.max_results = max_results
        self.database_path = database_path
        self.summaries_path = summaries_path
        self.vectorstore_path = vectorstore_path
        self.download = download
        self.num_threads = num_threads

        os.makedirs(self.database_path, exist_ok=True)
        os.makedirs(self.summaries_path, exist_ok=True)

        self._action = self._build_graph()

    # ---- abstract-ish methods ----
    def _search(self, query: str) -> list[dict[str, Any]]:
        raise NotImplementedError

    def _materialize(self, hit: dict[str, Any]) -> ItemMetadata:
        raise NotImplementedError

    def _id(self, hit_or_item: dict[str, Any]) -> str:
        raise NotImplementedError

    def _citation(self, item: ItemMetadata) -> str:
        # Subclass should format its ideal citation; fallback is ID or URL.
        return item.get("id") or item.get("url", "Unknown Source")

    # ---- optional hooks ----
    def _filter_hit(self, hit: dict[str, Any]) -> bool:
        return True

    def _postprocess_text(self, text: str, local_path: Optional[str]) -> str:
        # Default: optionally add image descriptions for PDFs
        if (
            self.process_images
            and local_path
            and local_path.lower().endswith(".pdf")
        ):
            try:
                descs = extract_and_describe_images(local_path)
                if any(descs):
                    text += "\n\n[Image Interpretations]\n" + "\n".join(descs)
            except Exception:
                pass
        return text

    # ---- shared nodes ----
    def _fetch_items(self, query: str) -> list[ItemMetadata]:
        hits = self._search(query)[: self.max_results] if self.download else []
        items: list[ItemMetadata] = []

        # If not downloading/scraping, try to load whatever is cached in database_path.
        if not self.download:
            for fname in os.listdir(self.database_path):
                if fname.lower().endswith((".pdf", ".txt", ".html")):
                    item_id = os.path.splitext(fname)[0]
                    local_path = os.path.join(self.database_path, fname)
                    full_text = ""
                    try:
                        if fname.lower().endswith(".pdf"):
                            full_text = read_pdf_text(local_path)
                        else:
                            with open(
                                local_path,
                                "r",
                                encoding="utf-8",
                                errors="ignore",
                            ) as f:
                                full_text = f.read()
                    except Exception as e:
                        full_text = f"[Error reading cached file: {e}]"
                    full_text = self._postprocess_text(full_text, local_path)
                    items.append({
                        "id": item_id,
                        "local_path": local_path,
                        "full_text": full_text,
                    })
            return items

        # Normal path: search → materialize each
        with ThreadPoolExecutor(
            max_workers=min(self.num_threads, max(1, len(hits)))
        ) as ex:
            futures = [
                ex.submit(self._materialize, h)
                for h in hits
                if self._filter_hit(h)
            ]
            for fut in as_completed(futures):
                try:
                    item = fut.result()
                    items.append(item)
                except Exception as e:
                    items.append({
                        "id": _hash(str(time.time())),
                        "full_text": f"[Error: {e}]",
                    })
        return items

    def _fetch_node(self, state: AcquisitionState) -> AcquisitionState:
        items = self._fetch_items(state["query"])
        return {**state, "items": items}

    def _summarize_node(self, state: AcquisitionState) -> AcquisitionState:
        prompt = ChatPromptTemplate.from_template("""
        You are an assistant responsible for summarizing retrieved content in the context of this task: {context}

        Summarize the content below:

        {retrieved_content}
        """)
        chain = prompt | self.llm | StrOutputParser()

        if "items" not in state or not state["items"]:
            return {**state, "summaries": None}

        summaries: list[Optional[str]] = [None] * len(state["items"])

        def process(i: int, item: ItemMetadata):
            item_id = item.get("id", f"item_{i}")
            out_path = os.path.join(
                self.summaries_path, f"{_safe_filename(item_id)}_summary.txt"
            )
            try:
                cleaned = remove_surrogates(item.get("full_text", ""))
                summary = chain.invoke(
                    {"retrieved_content": cleaned, "context": state["context"]},
                    config=self.build_config(tags=["acq", "summarize_each"]),
                )
            except Exception as e:
                summary = f"[Error summarizing item {item_id}: {e}]"
            with open(out_path, "w", encoding="utf-8") as f:
                f.write(summary)
            return i, summary

        with ThreadPoolExecutor(
            max_workers=min(self.num_threads, len(state["items"]))
        ) as ex:
            futures = [
                ex.submit(process, i, it) for i, it in enumerate(state["items"])
            ]
            for fut in as_completed(futures):
                i, s = fut.result()
                summaries[i] = s

        return {**state, "summaries": summaries}  # type: ignore

    def _rag_node(self, state: AcquisitionState) -> AcquisitionState:
        new_state = state.copy()
        rag_agent = RAGAgent(
            llm=self.llm,
            embedding=self.rag_embedding,
            database_path=self.database_path,
        )
        new_state["final_summary"] = rag_agent.invoke(context=state["context"])[
            "summary"
        ]
        return new_state

    def _aggregate_node(self, state: AcquisitionState) -> AcquisitionState:
        if not state.get("summaries") or not state.get("items"):
            return {**state, "final_summary": None}

        blocks: list[str] = []
        for idx, (item, summ) in enumerate(
            zip(state["items"], state["summaries"])
        ):  # type: ignore
            cite = self._citation(item)
            blocks.append(f"[{idx + 1}] {cite}\n\nSummary:\n{summ}")

        combined = "\n\n" + ("\n\n" + "-" * 40 + "\n\n").join(blocks)
        with open(
            os.path.join(self.summaries_path, "summaries_combined.txt"),
            "w",
            encoding="utf-8",
        ) as f:
            f.write(combined)

        prompt = ChatPromptTemplate.from_template("""
        You are a scientific assistant extracting insights from multiple summaries.

        Here are the summaries:

        {Summaries}

        Your task is to read all the summaries and provide a response to this task: {context}
        """)
        chain = prompt | self.llm | StrOutputParser()

        final_summary = chain.invoke(
            {"Summaries": combined, "context": state["context"]},
            config=self.build_config(tags=["acq", "aggregate"]),
        )
        with open(
            os.path.join(self.summaries_path, "final_summary.txt"),
            "w",
            encoding="utf-8",
        ) as f:
            f.write(final_summary)

        return {**state, "final_summary": final_summary}

    def _build_graph(self):
        graph = StateGraph(AcquisitionState)
        self.add_node(graph, self._fetch_node)

        if self.summarize:
            if self.rag_embedding:
                self.add_node(graph, self._rag_node)
                graph.set_entry_point("_fetch_node")
                graph.add_edge("_fetch_node", "_rag_node")
                graph.set_finish_point("_rag_node")
            else:
                self.add_node(graph, self._summarize_node)
                self.add_node(graph, self._aggregate_node)

                graph.set_entry_point("_fetch_node")
                graph.add_edge("_fetch_node", "_summarize_node")
                graph.add_edge("_summarize_node", "_aggregate_node")
                graph.set_finish_point("_aggregate_node")
        else:
            graph.set_entry_point("_fetch_node")
            graph.set_finish_point("_fetch_node")

        return graph.compile(checkpointer=self.checkpointer)

    def _invoke(
        self,
        inputs: Mapping[str, Any],
        *,
        summarize: bool | None = None,
        recursion_limit: int = 1000,
        **_,
    ) -> str:
        config = self.build_config(
            recursion_limit=recursion_limit, tags=["graph"]
        )

        # alias support like your ArxivAgent
        if "query" not in inputs:
            if "arxiv_search_query" in inputs:
                inputs = dict(inputs)
                inputs["query"] = inputs.pop("arxiv_search_query")
            else:
                raise KeyError(
                    "Missing 'query' in inputs (alias 'arxiv_search_query' also accepted)."
                )

        result = self._action.invoke(inputs, config)
        use_summary = self.summarize if summarize is None else summarize
        return (
            result.get("final_summary", "No summary generated.")
            if use_summary
            else "\n\nFinished fetching items!"
        )

OSTIAgent

Bases: BaseAcquisitionAgent

Minimal OSTI.gov acquisition agent.

NOTE:
  • OSTI provides search endpoints that can return metadata including full-text links.
  • Depending on your environment, you may prefer the public API or site scraping.
  • Here we assume a JSON API that yields results with keys like
    {'osti_id': '12345', 'title': '...', 'pdf_url': 'https://...pdf', 'landing_page': 'https://...'};
    adapt field names if your OSTI integration differs.

Customize _search and _materialize to match your OSTI access path.
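
A brief usage sketch, assuming the default public API base and a placeholder chat model; the query and context are illustrative only.

from langchain_openai import ChatOpenAI
from ursa.agents.acquisition_agents import OSTIAgent

agent = OSTIAgent(
    ChatOpenAI(model="gpt-4o-mini"),   # placeholder model
    max_results=5,
    database_path="osti_db",           # downloaded PDFs/HTML are cached here
)
summary = agent.invoke(
    query="tritium breeding blanket neutronics",
    context="Summarize recent OSTI reports on breeding blanket design.",
)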

Source code in src/ursa/agents/acquisition_agents.py, lines 559-711
class OSTIAgent(BaseAcquisitionAgent):
    """
    Minimal OSTI.gov acquisition agent.

    NOTE:
      - OSTI provides search endpoints that can return metadata including full-text links.
      - Depending on your environment, you may prefer the public API or site scraping.
      - Here we assume a JSON API that yields results with keys like:
            {'osti_id': '12345', 'title': '...', 'pdf_url': 'https://...pdf', 'landing_page': 'https://...'}
        Adapt field names if your OSTI integration differs.

    Customize `_search` and `_materialize` to match your OSTI access path.
    """

    def __init__(
        self,
        *args,
        api_base: str = "https://www.osti.gov/api/v1/records",
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.api_base = api_base

    def _id(self, hit_or_item: dict[str, Any]) -> str:
        if "osti_id" in hit_or_item:
            return str(hit_or_item["osti_id"])
        if "id" in hit_or_item:
            return str(hit_or_item["id"])
        if "landing_page" in hit_or_item:
            return _hash(hit_or_item["landing_page"])
        return _hash(json.dumps(hit_or_item))

    def _citation(self, item: ItemMetadata) -> str:
        t = item.get("title", "") or ""
        oid = item.get("id", "")
        return f"OSTI {oid}: {t}" if t else f"OSTI {oid}"

    def _search(self, query: str) -> list[dict[str, Any]]:
        """
        Adjust params to your OSTI setup. This call is intentionally simple;
        add paging/auth as needed.
        """
        params = {
            "q": query,
            "size": self.max_results,
        }
        try:
            r = requests.get(self.api_base, params=params, timeout=25)
            r.raise_for_status()
            data = r.json()
            # Normalize to a list of hits; adapt key if your API differs.
            if isinstance(data, dict) and "records" in data:
                hits = data["records"]
            elif isinstance(data, list):
                hits = data
            else:
                hits = []
            return hits[: self.max_results]
        except Exception as e:
            return [
                {
                    "id": _hash(query + str(time.time())),
                    "title": "Search error",
                    "error": str(e),
                }
            ]

    def _materialize(self, hit: dict[str, Any]) -> ItemMetadata:
        item_id = self._id(hit)
        title = hit.get("title") or hit.get("title_public", "") or ""
        landing = None
        local_path = ""
        full_text = ""

        try:
            pdf_url, landing_used, _ = resolve_pdf_from_osti_record(
                hit,
                headers={"User-Agent": "Mozilla/5.0"},
                unpaywall_email=os.environ.get("UNPAYWALL_EMAIL"),  # optional
            )

            if pdf_url:
                # Try to download as PDF (validate headers)
                with requests.get(
                    pdf_url,
                    headers={"User-Agent": "Mozilla/5.0"},
                    timeout=25,
                    allow_redirects=True,
                    stream=True,
                ) as r:
                    r.raise_for_status()
                    if _is_pdf_response(r):
                        fname = _derive_filename_from_cd_or_url(
                            r, f"osti_{item_id}.pdf"
                        )
                        local_path = os.path.join(self.database_path, fname)
                        _download_stream_to(local_path, r)
                        # Extract PDF text
                        try:
                            full_text = read_pdf_text(local_path)
                        except Exception as e:
                            full_text = (
                                f"[Downloaded but text extraction failed: {e}]"
                            )
                    else:
                        # Not a PDF; treat as HTML landing and parse text
                        landing = r.url
                        r.close()
            # If we still have no text, try scraping the DOE PAGES landing or citation page
            if not full_text:
                # Prefer DOE PAGES landing if present, else OSTI biblio
                landing = (
                    landing
                    or landing_used
                    or next(
                        (
                            link.get("href")
                            for link in hit.get("links", [])
                            if link.get("rel")
                            in ("citation_doe_pages", "citation")
                        ),
                        None,
                    )
                )
                if landing:
                    soup = _get_soup(
                        landing,
                        timeout=25,
                        headers={"User-Agent": "Mozilla/5.0"},
                    )
                    html_text = soup.get_text(" ", strip=True)
                    full_text = html_text[:1_000_000]  # keep it bounded
                    # Save raw HTML for cache/inspection
                    local_path = os.path.join(
                        self.database_path, f"{item_id}.html"
                    )
                    with open(local_path, "w", encoding="utf-8") as f:
                        f.write(str(soup))
                else:
                    full_text = "[No PDF or landing page text available.]"

        except Exception as e:
            full_text = f"[Error materializing OSTI {item_id}: {e}]"

        full_text = self._postprocess_text(full_text, local_path)
        return {
            "id": item_id,
            "title": title,
            "url": landing,
            "local_path": local_path,
            "full_text": full_text,
            "extra": {"raw_hit": hit},
        }

WebSearchAgent

Bases: BaseAcquisitionAgent

Uses DuckDuckGo Search (ddgs) to find pages, downloads HTML or PDFs, extracts text, and then follows the same summarize/RAG path.
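
A brief usage sketch; it assumes the duckduckgo-search package (DDGS) is installed and uses a placeholder chat model and query.

from langchain_openai import ChatOpenAI
from ursa.agents.acquisition_agents import WebSearchAgent

agent = WebSearchAgent(
    ChatOpenAI(model="gpt-4o-mini"),   # placeholder model
    max_results=5,
    database_path="web_cache",         # downloaded HTML/PDFs are cached here
    user_agent="Mozilla/5.0",
)
answer = agent.invoke(
    query="open-source finite element solvers comparison",
    context="Compare actively maintained open-source FEM packages.",
)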

Source code in src/ursa/agents/acquisition_agents.py, lines 476-553
class WebSearchAgent(BaseAcquisitionAgent):
    """
    Uses DuckDuckGo Search (ddgs) to find pages, downloads HTML or PDFs,
    extracts text, and then follows the same summarize/RAG path.
    """

    def __init__(self, *args, user_agent: str = "Mozilla/5.0", **kwargs):
        super().__init__(*args, **kwargs)
        self.user_agent = user_agent
        if DDGS is None:
            raise ImportError(
                "duckduckgo-search (DDGS) is required for WebSearchAgentGeneric."
            )

    def _id(self, hit_or_item: dict[str, Any]) -> str:
        url = hit_or_item.get("href") or hit_or_item.get("url") or ""
        return (
            _hash(url)
            if url
            else hit_or_item.get("id", _hash(json.dumps(hit_or_item)))
        )

    def _citation(self, item: ItemMetadata) -> str:
        t = item.get("title", "") or ""
        u = item.get("url", "") or ""
        return f"{t} ({u})" if t else (u or item.get("id", "Web result"))

    def _search(self, query: str) -> list[dict[str, Any]]:
        results: list[dict[str, Any]] = []
        with DDGS() as ddgs:
            for r in ddgs.text(
                query, max_results=self.max_results, backend="auto"
            ):
                # r keys typically: title, href, body
                results.append(r)
        return results

    def _materialize(self, hit: dict[str, Any]) -> ItemMetadata:
        url = hit.get("href") or hit.get("url")
        title = hit.get("title", "")
        if not url:
            return {"id": self._id(hit), "title": title, "full_text": ""}

        headers = {"User-Agent": self.user_agent}
        local_path = ""
        full_text = ""
        item_id = self._id(hit)

        try:
            if _looks_like_pdf_url(url):
                local_path = os.path.join(
                    self.database_path, _safe_filename(item_id) + ".pdf"
                )
                _download(url, local_path)
                full_text = read_pdf_text(local_path)
            else:
                r = requests.get(url, headers=headers, timeout=20)
                r.raise_for_status()
                html = r.text
                local_path = os.path.join(
                    self.database_path, _safe_filename(item_id) + ".html"
                )
                with open(local_path, "w", encoding="utf-8") as f:
                    f.write(html)
                full_text = extract_main_text_only(html)
                # full_text = _basic_readable_text_from_html(html)
        except Exception as e:
            full_text = f"[Error retrieving {url}: {e}]"

        full_text = self._postprocess_text(full_text, local_path)
        return {
            "id": item_id,
            "title": title,
            "url": url,
            "local_path": local_path,
            "full_text": full_text,
            "extra": {"snippet": hit.get("body", "")},
        }

base

Base agent class providing telemetry, configuration, and execution abstractions.

This module defines the BaseAgent abstract class, which serves as the foundation for all agent implementations in the Ursa framework. It provides:

  • Standardized initialization with LLM configuration
  • Telemetry and metrics collection
  • Thread and checkpoint management
  • Input normalization and validation
  • Execution flow control with invoke/stream methods
  • Graph integration utilities for LangGraph compatibility
  • Runtime enforcement of the agent interface contract

Agents built on this base class benefit from consistent behavior, observability, and integration capabilities while only needing to implement the core _invoke method.

BaseAgent

Bases: ABC

Abstract base class for all agent implementations in the Ursa framework.

BaseAgent provides a standardized foundation for building LLM-powered agents with built-in telemetry, configuration management, and execution flow control. It handles common tasks like input normalization, thread management, metrics collection, and LangGraph integration.

Subclasses only need to implement the _invoke method to define their core functionality, while inheriting standardized invocation patterns, telemetry, and graph integration capabilities. The class enforces a consistent interface through runtime checks that prevent subclasses from overriding critical methods like invoke().

The agent supports both direct invocation with inputs and streaming responses, with automatic tracking of token usage, execution time, and other metrics. It also provides utilities for integrating with LangGraph through node wrapping and configuration.

Subclass Inheritance Guidelines
  • Must Override: _invoke() - Define your agent's core functionality
  • Can Override: _stream() - Enable streaming support; _normalize_inputs() - Customize input handling; various helper methods (_default_node_tags, _as_runnable, etc.)
  • Never Override: invoke() - Final method with runtime enforcement; stream() - Handles telemetry and delegates to _stream; __call__() - Delegates to invoke; other public methods (build_config, write_state, add_node)

To create a custom agent, inherit from this class and implement the _invoke method:

class MyAgent(BaseAgent):
    def _invoke(self, inputs: Mapping[str, Any], **config: Any) -> Any:
        # Process inputs and return results
        ...
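
Once _invoke is implemented, the inherited entry points can be used as below; this is a hypothetical sketch where MyAgent stands for any completed subclass and llm is any BaseChatModel instance.

agent = MyAgent(llm)

out = agent.invoke("plain string input")   # normalized into a HumanMessage
out = agent({"messages": []})              # __call__ delegates to invoke
out = agent.invoke(
    {"messages": []},
    save_json=True,                        # telemetry kwargs are split from inputs
    metrics_path="ursa_metrics/run.json",
)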
Source code in src/ursa/agents/base.py, lines 71-727
class BaseAgent(ABC):
    """Abstract base class for all agent implementations in the Ursa framework.

    BaseAgent provides a standardized foundation for building LLM-powered agents with
    built-in telemetry, configuration management, and execution flow control. It handles
    common tasks like input normalization, thread management, metrics collection, and
    LangGraph integration.

    Subclasses only need to implement the _invoke method to define their core
    functionality, while inheriting standardized invocation patterns, telemetry, and
    graph integration capabilities. The class enforces a consistent interface through
    runtime checks that prevent subclasses from overriding critical methods like
    invoke().

    The agent supports both direct invocation with inputs and streaming responses, with
    automatic tracking of token usage, execution time, and other metrics. It also
    provides utilities for integrating with LangGraph through node wrapping and
    configuration.

    Subclass Inheritance Guidelines:
        - Must Override: _invoke() - Define your agent's core functionality
        - Can Override: _stream() - Enable streaming support
                        _normalize_inputs() - Customize input handling
                        Various helper methods (_default_node_tags, _as_runnable, etc.)
        - Never Override: invoke() - Final method with runtime enforcement
                          stream() - Handles telemetry and delegates to _stream
                          __call__() - Delegates to invoke
                          Other public methods (build_config, write_state, add_node)

    To create a custom agent, inherit from this class and implement the _invoke method:

    ```python
    class MyAgent(BaseAgent):
        def _invoke(self, inputs: Mapping[str, Any], **config: Any) -> Any:
            # Process inputs and return results
            ...
    ```
    """

    # This will be shared across all BaseAgent instances.
    _invoke_depth: int = 0

    _TELEMETRY_KW = {
        "raw_debug",
        "save_json",
        "metrics_path",
        "save_raw_snapshot",
        "save_raw_records",
    }

    _CONTROL_KW = {"config", "recursion_limit", "tags", "metadata", "callbacks"}

    def __init__(
        self,
        llm: BaseChatModel,
        checkpointer: Optional[BaseCheckpointSaver] = None,
        enable_metrics: bool = True,
        metrics_dir: str = "ursa_metrics",  # dir to save metrics, with a default
        autosave_metrics: bool = True,
        thread_id: Optional[str] = None,
    ):
        self.llm = llm
        """Initializes the base agent with a language model and optional configurations.

        Args:
            llm: a BaseChatModel instance.
            checkpointer: Optional checkpoint saver for persisting agent state.
            enable_metrics: Whether to collect performance and usage metrics.
            metrics_dir: Directory path where metrics will be saved.
            autosave_metrics: Whether to automatically save metrics to disk.
            thread_id: Unique identifier for this agent instance. Generated if not
                       provided.
        """
        self.thread_id = thread_id or uuid4().hex
        self.checkpointer = checkpointer
        self.telemetry = Telemetry(
            enable=enable_metrics,
            output_dir=metrics_dir,
            save_json_default=autosave_metrics,
        )

    @property
    def name(self) -> str:
        """Agent name."""
        return self.__class__.__name__

    def add_node(
        self,
        graph: StateGraph,
        f: Callable[..., Mapping[str, Any]],
        node_name: Optional[str] = None,
        agent_name: Optional[str] = None,
    ) -> StateGraph:
        """Add a node to the state graph with token usage tracking.

        This method adds a function as a node to the state graph, wrapping it to track
        token usage during execution. The node is identified by either the provided
        node_name or the function's name.

        Args:
            graph: The StateGraph to add the node to.
            f: The function to add as a node. Should return a mapping of string keys to
                any values.
            node_name: Optional name for the node. If not provided, the function's name
                will be used.
            agent_name: Optional agent name for tracking. If not provided, the agent's
                name in snake_case will be used.

        Returns:
            The updated StateGraph with the new node added.
        """
        _node_name = node_name or f.__name__
        _agent_name = agent_name or _to_snake(self.name)
        wrapped_node = self._wrap_node(f, _node_name, _agent_name)

        return graph.add_node(_node_name, wrapped_node)

    def write_state(self, filename: str, state: dict) -> None:
        """Writes agent state to a JSON file.

        Serializes the provided state dictionary to JSON format and writes it to the
        specified file. The JSON is written with non-ASCII characters preserved.

        Args:
            filename: Path to the file where state will be written.
            state: Dictionary containing the agent state to be serialized.
        """
        json_state = dumps(state, ensure_ascii=False)
        with open(filename, "w") as f:
            f.write(json_state)

    def build_config(self, **overrides) -> dict:
        """Constructs a config dictionary for agent operations with telemetry support.

        This method creates a standardized configuration dictionary that includes thread
        identification, telemetry callbacks, and other metadata needed for agent
        operations. The configuration can be customized through override parameters.

        Args:
            **overrides: Optional configuration overrides that can include keys like
                'recursion_limit', 'configurable', 'metadata', 'tags', etc.

        Returns:
            dict: A complete configuration dictionary with all necessary parameters.
        """
        # Create the base configuration with essential fields.
        base = {
            "configurable": {"thread_id": self.thread_id},
            "metadata": {
                "thread_id": self.thread_id,
                "telemetry_run_id": self.telemetry.context.get("run_id"),
            },
            "tags": [self.name],
            "callbacks": self.telemetry.callbacks,
        }

        # Try to determine the model name from either direct or nested attributes
        model_name = getattr(self, "llm_model", None) or getattr(
            getattr(self, "llm", None), "model", None
        )

        # Add model name to metadata if available
        if model_name:
            base["metadata"]["model"] = model_name

        # Handle configurable dictionary overrides by merging with base configurable
        if "configurable" in overrides and isinstance(
            overrides["configurable"], dict
        ):
            base["configurable"].update(overrides.pop("configurable"))

        # Handle metadata dictionary overrides by merging with base metadata
        if "metadata" in overrides and isinstance(overrides["metadata"], dict):
            base["metadata"].update(overrides.pop("metadata"))

        # Merge tags from caller-provided overrides, avoid duplicates
        if "tags" in overrides and isinstance(overrides["tags"], list):
            base["tags"] = base["tags"] + [
                t for t in overrides.pop("tags") if t not in base["tags"]
            ]

        # Apply any remaining overrides directly to the base configuration
        base.update(overrides)

        return base

    def _invoke_engine(
        self,
        invoke_method,
        inputs: Optional[InputLike] = None,
        raw_debug: bool = False,
        save_json: Optional[bool] = None,
        metrics_path: Optional[str] = None,
        save_raw_snapshot: Optional[bool] = None,
        save_raw_records: Optional[bool] = None,
        config: Optional[dict] = None,
        **kwargs: Any,
    ):
        BaseAgent._invoke_depth += 1

        try:
            # Start telemetry tracking for the top-level invocation
            if BaseAgent._invoke_depth == 1:
                self.telemetry.begin_run(
                    agent=self.name, thread_id=self.thread_id
                )

            # Handle the case where inputs are provided as keyword arguments
            if inputs is None:
                # Separate kwargs into input parameters and control parameters
                kw_inputs: dict[str, Any] = {}
                control_kwargs: dict[str, Any] = {}
                for k, v in kwargs.items():
                    if k in self._TELEMETRY_KW or k in self._CONTROL_KW:
                        control_kwargs[k] = v
                    else:
                        kw_inputs[k] = v
                inputs = kw_inputs

                # Only control kwargs remain for further processing
                kwargs = control_kwargs

            # Handle the case where inputs are provided as a positional argument
            else:
                # Ensure no ambiguous keyword arguments are present
                for k in kwargs.keys():
                    if not (k in self._TELEMETRY_KW or k in self._CONTROL_KW):
                        raise TypeError(
                            f"Unexpected keyword argument '{k}'. "
                            "Pass inputs as a single mapping or omit the positional "
                            "inputs and pass them as keyword arguments."
                        )

            # Allow subclasses to normalize or transform the input format
            normalized = self._normalize_inputs(inputs)

            # Delegate to the subclass implementation with the normalized inputs
            # and any control parameters
            return invoke_method(normalized, config=config, **kwargs)

        finally:
            # Clean up the invocation depth tracking
            BaseAgent._invoke_depth -= 1

            # For the top-level invocation, finalize telemetry and generate outputs
            if BaseAgent._invoke_depth == 0:
                self.telemetry.render(
                    raw=raw_debug,
                    save_json=save_json,
                    filepath=metrics_path,
                    save_raw_snapshot=save_raw_snapshot,
                    save_raw_records=save_raw_records,
                )

    # NOTE: The `invoke` method uses the PEP 570 `/,*` notation to explicitly state which
    # arguments can and cannot be passed as positional or keyword arguments.
    @final
    def invoke(
        self,
        inputs: Optional[InputLike] = None,
        /,
        *,
        raw_debug: bool = False,
        save_json: Optional[bool] = None,
        metrics_path: Optional[str] = None,
        save_raw_snapshot: Optional[bool] = None,
        save_raw_records: Optional[bool] = None,
        config: Optional[dict] = None,
        **kwargs: Any,
    ) -> Any:
        """Executes the agent with the provided inputs and configuration.

        This is the main entry point for agent execution. It handles input normalization,
        telemetry tracking, and proper execution context management. The method supports
        flexible input formats - either as a positional argument or as keyword arguments.

        Args:
            inputs: Optional positional input to the agent. If provided, all non-control
                keyword arguments will be rejected to avoid ambiguity.
            raw_debug: If True, displays raw telemetry data for debugging purposes.
            save_json: If True, saves telemetry data as JSON.
            metrics_path: Optional file path where telemetry metrics should be saved.
            save_raw_snapshot: If True, saves a raw snapshot of the telemetry data.
            save_raw_records: If True, saves raw telemetry records.
            config: Optional configuration dictionary to override default settings.
            **kwargs: Additional keyword arguments that can be either:
                - Input parameters (when no positional input is provided)
                - Control parameters recognized by the agent

        Returns:
            The result of the agent's execution.

        Raises:
            TypeError: If both positional inputs and non-control keyword arguments are
                provided simultaneously.
        """
        return self._invoke_engine(
            invoke_method=self._invoke,
            inputs=inputs,
            raw_debug=raw_debug,
            save_json=save_json,
            metrics_path=metrics_path,
            save_raw_snapshot=save_raw_snapshot,
            save_raw_records=save_raw_records,
            config=config,
            **kwargs,
        )

    # NOTE: The `ainvoke` method uses the PEP 570 `/,*` notation to explicitly state which
    # arguments can and cannot be passed as positional or keyword arguments.
    @final
    def ainvoke(
        self,
        inputs: Optional[InputLike] = None,
        /,
        *,
        raw_debug: bool = False,
        save_json: Optional[bool] = None,
        metrics_path: Optional[str] = None,
        save_raw_snapshot: Optional[bool] = None,
        save_raw_records: Optional[bool] = None,
        config: Optional[dict] = None,
        **kwargs: Any,
    ) -> Any:
        """Asynchrnously executes the agent with the provided inputs and configuration.

        (Async version of `invoke`.)

        This is the main entry point for agent execution. It handles input normalization,
        telemetry tracking, and proper execution context management. The method supports
        flexible input formats - either as a positional argument or as keyword arguments.

        Args:
            inputs: Optional positional input to the agent. If provided, all non-control
                keyword arguments will be rejected to avoid ambiguity.
            raw_debug: If True, displays raw telemetry data for debugging purposes.
            save_json: If True, saves telemetry data as JSON.
            metrics_path: Optional file path where telemetry metrics should be saved.
            save_raw_snapshot: If True, saves a raw snapshot of the telemetry data.
            save_raw_records: If True, saves raw telemetry records.
            config: Optional configuration dictionary to override default settings.
            **kwargs: Additional keyword arguments that can be either:
                - Input parameters (when no positional input is provided)
                - Control parameters recognized by the agent

        Returns:
            The result of the agent's execution.

        Raises:
            TypeError: If both positional inputs and non-control keyword arguments are
                provided simultaneously.
        """
        return self._invoke_engine(
            invoke_method=self._ainvoke,
            inputs=inputs,
            raw_debug=raw_debug,
            save_json=save_json,
            metrics_path=metrics_path,
            save_raw_snapshot=save_raw_snapshot,
            save_raw_records=save_raw_records,
            config=config,
            **kwargs,
        )

    def _normalize_inputs(self, inputs: InputLike) -> Mapping[str, Any]:
        """Normalizes various input formats into a standardized mapping.

        This method converts different input types into a consistent dictionary format
        that can be processed by the agent. String inputs are wrapped as messages, while
        mappings are passed through unchanged.

        Args:
            inputs: The input to normalize. Can be a string (which will be converted to a
                message) or a mapping (which will be returned as-is).

        Returns:
            A mapping containing the normalized inputs, with keys appropriate for agent
            processing.

        Raises:
            TypeError: If the input type is not supported (neither string nor mapping).
        """
        if isinstance(inputs, str):
            # Adjust to your message type
            return {"messages": [HumanMessage(content=inputs)]}
        if isinstance(inputs, Mapping):
            return inputs
        raise TypeError(f"Unsupported input type: {type(inputs)}")

    @abstractmethod
    def _invoke(self, inputs: Mapping[str, Any], **config: Any) -> Any:
        """Subclasses implement the actual work against normalized inputs."""
        ...

    def _ainvoke(self, inputs: Mapping[str, Any], **config: Any) -> Any:
        """Subclasses implement the actual work against normalized inputs."""
        ...

    def __call__(self, inputs: InputLike, /, **kwargs: Any) -> Any:
        """Specify calling behavior for class instance."""
        return self.invoke(inputs, **kwargs)

    # Runtime enforcement: forbid subclasses from overriding invoke
    def __init_subclass__(cls, **kwargs):
        """Ensure subclass does not override key method."""
        super().__init_subclass__(**kwargs)
        if "invoke" in cls.__dict__:
            err_msg = (
                f"{cls.__name__} must not override BaseAgent.invoke(); "
                "implement _invoke() only."
            )
            raise TypeError(err_msg)

    def stream(
        self,
        inputs: InputLike,
        config: Any | None = None,  # allow positional/keyword like LangGraph
        /,
        *,
        raw_debug: bool = False,
        save_json: bool | None = None,
        metrics_path: str | None = None,
        save_raw_snapshot: bool | None = None,
        save_raw_records: bool | None = None,
        **kwargs: Any,
    ) -> Iterator[Any]:
        """Streams agent responses with telemetry tracking.

        This method serves as the public streaming entry point for agent interactions.
        It wraps the actual streaming implementation with telemetry tracking to capture
        metrics and debugging information.

        Args:
            inputs: The input to process, which will be normalized internally.
            config: Optional configuration for the agent, compatible with LangGraph
                positional/keyword argument style.
            raw_debug: If True, renders raw debug information in telemetry output.
            save_json: If True, saves telemetry data as JSON.
            metrics_path: Optional file path where metrics should be saved.
            save_raw_snapshot: If True, saves raw snapshot data in telemetry.
            save_raw_records: If True, saves raw record data in telemetry.
            **kwargs: Additional keyword arguments passed to the streaming
                implementation.

        Returns:
            An iterator yielding the agent's responses.

        Note:
            This method tracks invocation depth to properly handle nested agent calls
            and ensure telemetry is only rendered once at the top level.
        """
        # Track invocation depth to handle nested agent calls
        BaseAgent._invoke_depth += 1

        try:
            # Start telemetry tracking for top-level invocations only
            if BaseAgent._invoke_depth == 1:
                self.telemetry.begin_run(
                    agent=self.name, thread_id=self.thread_id
                )

            # Normalize inputs and delegate to the actual streaming implementation
            normalized = self._normalize_inputs(inputs)
            yield from self._stream(normalized, config=config, **kwargs)

        finally:
            # Decrement invocation depth when exiting
            BaseAgent._invoke_depth -= 1

            # Render telemetry data only for top-level invocations
            if BaseAgent._invoke_depth == 0:
                self.telemetry.render(
                    raw=raw_debug,
                    save_json=save_json,
                    filepath=metrics_path,
                    save_raw_snapshot=save_raw_snapshot,
                    save_raw_records=save_raw_records,
                )

    def _stream(
        self,
        inputs: Mapping[str, Any],
        *,
        config: Any | None = None,
        **kwargs: Any,
    ) -> Iterator[Any]:
        """Subclass method to be overwritten for streaming implementation."""
        raise NotImplementedError(
            f"{self.name} does not support streaming. "
            "Override _stream(...) in your agent to enable it."
        )

    def _default_node_tags(
        self, name: str, extra: Sequence[str] | None = None
    ) -> list[str]:
        """Generate default tags for a graph node.

        Args:
            name: The name of the node.
            extra: Optional sequence of additional tags to include.

        Returns:
            list[str]: A list of tags for the node, including the agent name, 'graph',
                the node name, and any extra tags provided.
        """
        # Start with standard tags: agent name, graph indicator, and node name
        tags = [self.name, "graph", name]

        # Add any extra tags if provided
        if extra:
            tags.extend(extra)

        return tags

    def _as_runnable(self, fn: Any):
        """Convert a function to a runnable if it isn't already.

        Args:
            fn: The function or object to convert to a runnable.

        Returns:
            A runnable object that can be used in the graph. If the input is already
            runnable (has .with_config and .invoke methods), it's returned as is.
            Otherwise, it's wrapped in a RunnableLambda.
        """
        # Check if the function already has the required runnable interface
        # If so, return it as is; otherwise wrap it in a RunnableLambda
        return (
            fn
            if hasattr(fn, "with_config") and hasattr(fn, "invoke")
            else RunnableLambda(fn)
        )

    def _node_cfg(self, name: str, *extra_tags: str) -> dict:
        """Build a consistent configuration for a node/runnable.

        Creates a configuration dict that can be reapplied after operations like
        .map(), subgraph compile, etc.

        Args:
            name: The name of the node.
            *extra_tags: Additional tags to include in the node configuration.

        Returns:
            dict: A configuration dictionary with run_name, tags, and metadata.
        """
        # Determine the namespace - use first extra tag if available, otherwise
        # convert agent name to snake_case
        ns = extra_tags[0] if extra_tags else _to_snake(self.name)

        # Combine all tags: agent name, graph indicator, node name, and any extra tags
        tags = [self.name, "graph", name, *extra_tags]

        # Return the complete configuration dictionary
        return dict(
            run_name="node",  # keep "node:" prefixing in the timer
            tags=tags,
            metadata={
                "langgraph_node": name,
                "ursa_ns": ns,
                "ursa_agent": self.name,
            },
        )

    def ns(self, runnable_or_fn, name: str, *extra_tags: str):
        """Return a runnable with node configuration applied.

        Applies the agent's node configuration to a runnable or callable. This method
        should be called again after operations like .map() or subgraph .compile() as
        these operations may drop configuration.

        Args:
            runnable_or_fn: A runnable or callable to configure.
            name: The name to assign to this node.
            *extra_tags: Additional tags to apply to the node.

        Returns:
            A configured runnable with the agent's node configuration applied.
        """
        # Convert input to a runnable if it's not already one
        r = self._as_runnable(runnable_or_fn)
        # Apply node configuration and return the configured runnable
        return r.with_config(**self._node_cfg(name, *extra_tags))

    def _wrap_node(self, fn_or_runnable, name: str, *extra_tags: str):
        """Wrap a function or runnable as a node in the graph.

        This is a convenience wrapper around the ns() method.

        Args:
            fn_or_runnable: A function or runnable to wrap as a node.
            name: The name to assign to this node.
            *extra_tags: Additional tags to apply to the node.

        Returns:
            A configured runnable with the agent's node configuration applied.
        """
        return self.ns(fn_or_runnable, name, *extra_tags)

    def _wrap_cond(self, fn: Any, name: str, *extra_tags: str):
        """Wrap a conditional function as a routing node in the graph.

        Creates a runnable lambda with routing-specific configuration.

        Args:
            fn: The conditional function to wrap.
            name: The name of the routing node.
            *extra_tags: Additional tags to apply to the node.

        Returns:
            A configured RunnableLambda with routing-specific metadata.
        """
        # Use the first extra tag as namespace, or fall back to agent name in snake_case
        ns = extra_tags[0] if extra_tags else _to_snake(self.name)

        # Create and return a configured RunnableLambda for routing
        return RunnableLambda(fn).with_config(
            run_name="node",
            tags=[
                self.name,
                "graph",
                f"route:{name}",
                *extra_tags,
            ],
            metadata={
                "langgraph_node": f"route:{name}",
                "ursa_ns": ns,
                "ursa_agent": self.name,
            },
        )

    def _named(self, runnable: Any, name: str, *extra_tags: str):
        """Apply a specific name and configuration to a runnable.

        Configures a runnable with a specific name and the agent's metadata.

        Args:
            runnable: The runnable to configure.
            name: The name to assign to this runnable.
            *extra_tags: Additional tags to apply to the runnable.

        Returns:
            A configured runnable with the specified name and agent metadata.
        """
        # Use the first extra tag as namespace, or fall back to agent name in snake_case
        ns = extra_tags[0] if extra_tags else _to_snake(self.name)

        # Apply configuration and return the configured runnable
        return runnable.with_config(
            run_name=name,
            tags=[self.name, "graph", name, *extra_tags],
            metadata={
                "langgraph_node": name,
                "ursa_ns": ns,
                "ursa_agent": self.name,
            },
        )
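As a rough sketch of how these helpers are meant to be used (the subgraph, state schema, and node names below are illustrative, not part of URSA), the node configuration is re-applied after operations that may drop it, such as a subgraph compile:

# Inside a hypothetical _build_graph(); "retrieval" and RetrievalState are placeholders.
subgraph = StateGraph(RetrievalState)
# ... add nodes and edges to the subgraph ...
compiled = subgraph.compile()

# compile() (like .map()) may drop run_name/tags/metadata, so re-apply them:
graph.add_node("retrieval", self.ns(compiled, "retrieval", "my_namespace"))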

llm = llm instance-attribute

Initializes the base agent with a language model and optional configurations.

Parameters:

Name Type Description Default
llm

a BaseChatModel instance.

required
checkpointer

Optional checkpoint saver for persisting agent state.

required
enable_metrics

Whether to collect performance and usage metrics.

required
metrics_dir

Directory path where metrics will be saved.

required
autosave_metrics

Whether to automatically save metrics to disk.

required
thread_id

Unique identifier for this agent instance. Generated if not provided.

required

name property

Agent name.

__call__(inputs, /, **kwargs)

Specify calling behavior for class instance.

Source code in src/ursa/agents/base.py
Lines 469-471.
def __call__(self, inputs: InputLike, /, **kwargs: Any) -> Any:
    """Specify calling behavior for class instance."""
    return self.invoke(inputs, **kwargs)

__init_subclass__(**kwargs)

Ensure subclass does not override key method.

Source code in src/ursa/agents/base.py
Lines 474-482.
def __init_subclass__(cls, **kwargs):
    """Ensure subclass does not override key method."""
    super().__init_subclass__(**kwargs)
    if "invoke" in cls.__dict__:
        err_msg = (
            f"{cls.__name__} must not override BaseAgent.invoke(); "
            "implement _invoke() only."
        )
        raise TypeError(err_msg)
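For example (a hedged sketch; MyAgent and its inputs are hypothetical), instances are called directly as a shorthand for invoke, and a subclass that tries to override invoke fails as soon as the class body is evaluated:

result = my_agent({"messages": []})      # equivalent to my_agent.invoke({"messages": []})

# Raises TypeError at class-creation time:
class BadAgent(MyAgent):
    def invoke(self, inputs, **kwargs):  # not allowed; implement _invoke() instead
        ...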

add_node(graph, f, node_name=None, agent_name=None)

Add a node to the state graph with token usage tracking.

This method adds a function as a node to the state graph, wrapping it to track token usage during execution. The node is identified by either the provided node_name or the function's name.

Parameters:

Name Type Description Default
graph StateGraph

The StateGraph to add the node to.

required
f Callable[..., Mapping[str, Any]]

The function to add as a node. Should return a mapping of string keys to any values.

required
node_name Optional[str]

Optional name for the node. If not provided, the function's name will be used.

None
agent_name Optional[str]

Optional agent name for tracking. If not provided, the agent's name in snake_case will be used.

None

Returns:

Type Description
StateGraph

The updated StateGraph with the new node added.

Source code in src/ursa/agents/base.py
Lines 157-186.
def add_node(
    self,
    graph: StateGraph,
    f: Callable[..., Mapping[str, Any]],
    node_name: Optional[str] = None,
    agent_name: Optional[str] = None,
) -> StateGraph:
    """Add a node to the state graph with token usage tracking.

    This method adds a function as a node to the state graph, wrapping it to track
    token usage during execution. The node is identified by either the provided
    node_name or the function's name.

    Args:
        graph: The StateGraph to add the node to.
        f: The function to add as a node. Should return a mapping of string keys to
            any values.
        node_name: Optional name for the node. If not provided, the function's name
            will be used.
        agent_name: Optional agent name for tracking. If not provided, the agent's
            name in snake_case will be used.

    Returns:
        The updated StateGraph with the new node added.
    """
    _node_name = node_name or f.__name__
    _agent_name = agent_name or _to_snake(self.name)
    wrapped_node = self._wrap_node(f, _node_name, _agent_name)

    return graph.add_node(_node_name, wrapped_node)
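A minimal sketch of how add_node might be used when building a graph (the state schema and node function are illustrative):

graph = StateGraph(MyState)              # MyState is a placeholder state schema

def plan(state: dict) -> dict:           # toy node function
    return {"messages": state["messages"]}

self.add_node(graph, plan)                       # node named "plan" via f.__name__
self.add_node(graph, plan, node_name="replan")   # explicit node name
graph.set_entry_point("plan")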

ainvoke(inputs=None, /, *, raw_debug=False, save_json=None, metrics_path=None, save_raw_snapshot=None, save_raw_records=None, config=None, **kwargs)

Asynchronously executes the agent with the provided inputs and configuration.

(Async version of invoke.)

This is the main entry point for agent execution. It handles input normalization, telemetry tracking, and proper execution context management. The method supports flexible input formats - either as a positional argument or as keyword arguments.

Parameters:

Name Type Description Default
inputs Optional[InputLike]

Optional positional input to the agent. If provided, all non-control keyword arguments will be rejected to avoid ambiguity.

None
raw_debug bool

If True, displays raw telemetry data for debugging purposes.

False
save_json Optional[bool]

If True, saves telemetry data as JSON.

None
metrics_path Optional[str]

Optional file path where telemetry metrics should be saved.

None
save_raw_snapshot Optional[bool]

If True, saves a raw snapshot of the telemetry data.

None
save_raw_records Optional[bool]

If True, saves raw telemetry records.

None
config Optional[dict]

Optional configuration dictionary to override default settings.

None
**kwargs Any

Additional keyword arguments, interpreted either as input parameters (when no positional input is provided) or as control parameters recognized by the agent.

{}

Returns:

Type Description
Any

The result of the agent's execution.

Raises:

Type Description
TypeError

If both positional inputs and non-control keyword arguments are provided simultaneously.

Source code in src/ursa/agents/base.py
Lines 381-433.
@final
def ainvoke(
    self,
    inputs: Optional[InputLike] = None,
    /,
    *,
    raw_debug: bool = False,
    save_json: Optional[bool] = None,
    metrics_path: Optional[str] = None,
    save_raw_snapshot: Optional[bool] = None,
    save_raw_records: Optional[bool] = None,
    config: Optional[dict] = None,
    **kwargs: Any,
) -> Any:
    """Asynchronously executes the agent with the provided inputs and configuration.

    (Async version of `invoke`.)

    This is the main entry point for agent execution. It handles input normalization,
    telemetry tracking, and proper execution context management. The method supports
    flexible input formats - either as a positional argument or as keyword arguments.

    Args:
        inputs: Optional positional input to the agent. If provided, all non-control
            keyword arguments will be rejected to avoid ambiguity.
        raw_debug: If True, displays raw telemetry data for debugging purposes.
        save_json: If True, saves telemetry data as JSON.
        metrics_path: Optional file path where telemetry metrics should be saved.
        save_raw_snapshot: If True, saves a raw snapshot of the telemetry data.
        save_raw_records: If True, saves raw telemetry records.
        config: Optional configuration dictionary to override default settings.
        **kwargs: Additional keyword arguments that can be either:
            - Input parameters (when no positional input is provided)
            - Control parameters recognized by the agent

    Returns:
        The result of the agent's execution.

    Raises:
        TypeError: If both positional inputs and non-control keyword arguments are
            provided simultaneously.
    """
    return self._invoke_engine(
        invoke_method=self._ainvoke,
        inputs=inputs,
        raw_debug=raw_debug,
        save_json=save_json,
        metrics_path=metrics_path,
        save_raw_snapshot=save_raw_snapshot,
        save_raw_records=save_raw_records,
        config=config,
        **kwargs,
    )

build_config(**overrides)

Constructs a config dictionary for agent operations with telemetry support.

This method creates a standardized configuration dictionary that includes thread identification, telemetry callbacks, and other metadata needed for agent operations. The configuration can be customized through override parameters.

Parameters:

Name Type Description Default
**overrides

Optional configuration overrides that can include keys like 'recursion_limit', 'configurable', 'metadata', 'tags', etc.

{}

Returns:

Name Type Description
dict dict

A complete configuration dictionary with all necessary parameters.

Source code in src/ursa/agents/base.py
Lines 202-255.
def build_config(self, **overrides) -> dict:
    """Constructs a config dictionary for agent operations with telemetry support.

    This method creates a standardized configuration dictionary that includes thread
    identification, telemetry callbacks, and other metadata needed for agent
    operations. The configuration can be customized through override parameters.

    Args:
        **overrides: Optional configuration overrides that can include keys like
            'recursion_limit', 'configurable', 'metadata', 'tags', etc.

    Returns:
        dict: A complete configuration dictionary with all necessary parameters.
    """
    # Create the base configuration with essential fields.
    base = {
        "configurable": {"thread_id": self.thread_id},
        "metadata": {
            "thread_id": self.thread_id,
            "telemetry_run_id": self.telemetry.context.get("run_id"),
        },
        "tags": [self.name],
        "callbacks": self.telemetry.callbacks,
    }

    # Try to determine the model name from either direct or nested attributes
    model_name = getattr(self, "llm_model", None) or getattr(
        getattr(self, "llm", None), "model", None
    )

    # Add model name to metadata if available
    if model_name:
        base["metadata"]["model"] = model_name

    # Handle configurable dictionary overrides by merging with base configurable
    if "configurable" in overrides and isinstance(
        overrides["configurable"], dict
    ):
        base["configurable"].update(overrides.pop("configurable"))

    # Handle metadata dictionary overrides by merging with base metadata
    if "metadata" in overrides and isinstance(overrides["metadata"], dict):
        base["metadata"].update(overrides.pop("metadata"))

    # Merge tags from caller-provided overrides, avoid duplicates
    if "tags" in overrides and isinstance(overrides["tags"], list):
        base["tags"] = base["tags"] + [
            t for t in overrides.pop("tags") if t not in base["tags"]
        ]

    # Apply any remaining overrides directly to the base configuration
    base.update(overrides)

    return base
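A short sketch of the merging behavior (values are illustrative): nested configurable/metadata dicts and the tags list are merged into the base configuration, while any remaining keys are applied as top-level overrides:

cfg = agent.build_config(
    recursion_limit=50,                  # applied as a top-level key
    configurable={"user_id": "abc"},     # merged with {"thread_id": ...}
    metadata={"experiment": "trial-1"},  # merged with the base metadata
    tags=["graph", "my-run"],            # appended, skipping duplicates
)
# cfg["configurable"] -> {"thread_id": agent.thread_id, "user_id": "abc"}
# cfg["tags"]         -> [agent.name, "graph", "my-run"]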

invoke(inputs=None, /, *, raw_debug=False, save_json=None, metrics_path=None, save_raw_snapshot=None, save_raw_records=None, config=None, **kwargs)

Executes the agent with the provided inputs and configuration.

This is the main entry point for agent execution. It handles input normalization, telemetry tracking, and proper execution context management. The method supports flexible input formats - either as a positional argument or as keyword arguments.

Parameters:

Name Type Description Default
inputs Optional[InputLike]

Optional positional input to the agent. If provided, all non-control keyword arguments will be rejected to avoid ambiguity.

None
raw_debug bool

If True, displays raw telemetry data for debugging purposes.

False
save_json Optional[bool]

If True, saves telemetry data as JSON.

None
metrics_path Optional[str]

Optional file path where telemetry metrics should be saved.

None
save_raw_snapshot Optional[bool]

If True, saves a raw snapshot of the telemetry data.

None
save_raw_records Optional[bool]

If True, saves raw telemetry records.

None
config Optional[dict]

Optional configuration dictionary to override default settings.

None
**kwargs Any

Additional keyword arguments, interpreted either as input parameters (when no positional input is provided) or as control parameters recognized by the agent.

{}

Returns:

Type Description
Any

The result of the agent's execution.

Raises:

Type Description
TypeError

If both positional inputs and non-control keyword arguments are provided simultaneously.

Source code in src/ursa/agents/base.py
Lines 327-377.
@final
def invoke(
    self,
    inputs: Optional[InputLike] = None,
    /,
    *,
    raw_debug: bool = False,
    save_json: Optional[bool] = None,
    metrics_path: Optional[str] = None,
    save_raw_snapshot: Optional[bool] = None,
    save_raw_records: Optional[bool] = None,
    config: Optional[dict] = None,
    **kwargs: Any,
) -> Any:
    """Executes the agent with the provided inputs and configuration.

    This is the main entry point for agent execution. It handles input normalization,
    telemetry tracking, and proper execution context management. The method supports
    flexible input formats - either as a positional argument or as keyword arguments.

    Args:
        inputs: Optional positional input to the agent. If provided, all non-control
            keyword arguments will be rejected to avoid ambiguity.
        raw_debug: If True, displays raw telemetry data for debugging purposes.
        save_json: If True, saves telemetry data as JSON.
        metrics_path: Optional file path where telemetry metrics should be saved.
        save_raw_snapshot: If True, saves a raw snapshot of the telemetry data.
        save_raw_records: If True, saves raw telemetry records.
        config: Optional configuration dictionary to override default settings.
        **kwargs: Additional keyword arguments that can be either:
            - Input parameters (when no positional input is provided)
            - Control parameters recognized by the agent

    Returns:
        The result of the agent's execution.

    Raises:
        TypeError: If both positional inputs and non-control keyword arguments are
            provided simultaneously.
    """
    return self._invoke_engine(
        invoke_method=self._invoke,
        inputs=inputs,
        raw_debug=raw_debug,
        save_json=save_json,
        metrics_path=metrics_path,
        save_raw_snapshot=save_raw_snapshot,
        save_raw_records=save_raw_records,
        config=config,
        **kwargs,
    )
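A usage sketch (the agent instance and message content are placeholders): inputs may be passed positionally or as keyword arguments, but mixing a positional input with non-control keyword arguments raises TypeError:

from langchain_core.messages import HumanMessage

result = agent.invoke({"messages": [HumanMessage("Run the experiment")]})   # positional
result = agent.invoke(messages=[HumanMessage("Run the experiment")])        # keyword form

# TypeError: positional inputs plus non-control keyword arguments are ambiguous
# agent.invoke({"messages": [...]}, messages=[...])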

ns(runnable_or_fn, name, *extra_tags)

Return a runnable with node configuration applied.

Applies the agent's node configuration to a runnable or callable. This method should be called again after operations like .map() or subgraph .compile() as these operations may drop configuration.

Parameters:

Name Type Description Default
runnable_or_fn

A runnable or callable to configure.

required
name str

The name to assign to this node.

required
*extra_tags str

Additional tags to apply to the node.

()

Returns:

Type Description

A configured runnable with the agent's node configuration applied.

Source code in src/ursa/agents/base.py
Lines 635-653.
def ns(self, runnable_or_fn, name: str, *extra_tags: str):
    """Return a runnable with node configuration applied.

    Applies the agent's node configuration to a runnable or callable. This method
    should be called again after operations like .map() or subgraph .compile() as
    these operations may drop configuration.

    Args:
        runnable_or_fn: A runnable or callable to configure.
        name: The name to assign to this node.
        *extra_tags: Additional tags to apply to the node.

    Returns:
        A configured runnable with the agent's node configuration applied.
    """
    # Convert input to a runnable if it's not already one
    r = self._as_runnable(runnable_or_fn)
    # Apply node configuration and return the configured runnable
    return r.with_config(**self._node_cfg(name, *extra_tags))

stream(inputs, config=None, /, *, raw_debug=False, save_json=None, metrics_path=None, save_raw_snapshot=None, save_raw_records=None, **kwargs)

Streams agent responses with telemetry tracking.

This method serves as the public streaming entry point for agent interactions. It wraps the actual streaming implementation with telemetry tracking to capture metrics and debugging information.

Parameters:

Name Type Description Default
inputs InputLike

The input to process, which will be normalized internally.

required
config Any | None

Optional configuration for the agent, compatible with LangGraph positional/keyword argument style.

None
raw_debug bool

If True, renders raw debug information in telemetry output.

False
save_json bool | None

If True, saves telemetry data as JSON.

None
metrics_path str | None

Optional file path where metrics should be saved.

None
save_raw_snapshot bool | None

If True, saves raw snapshot data in telemetry.

None
save_raw_records bool | None

If True, saves raw record data in telemetry.

None
**kwargs Any

Additional keyword arguments passed to the streaming implementation.

{}

Returns:

Type Description
Iterator[Any]

An iterator yielding the agent's responses.

Note

This method tracks invocation depth to properly handle nested agent calls and ensure telemetry is only rendered once at the top level.

Source code in src/ursa/agents/base.py
Lines 484-548.
def stream(
    self,
    inputs: InputLike,
    config: Any | None = None,  # allow positional/keyword like LangGraph
    /,
    *,
    raw_debug: bool = False,
    save_json: bool | None = None,
    metrics_path: str | None = None,
    save_raw_snapshot: bool | None = None,
    save_raw_records: bool | None = None,
    **kwargs: Any,
) -> Iterator[Any]:
    """Streams agent responses with telemetry tracking.

    This method serves as the public streaming entry point for agent interactions.
    It wraps the actual streaming implementation with telemetry tracking to capture
    metrics and debugging information.

    Args:
        inputs: The input to process, which will be normalized internally.
        config: Optional configuration for the agent, compatible with LangGraph
            positional/keyword argument style.
        raw_debug: If True, renders raw debug information in telemetry output.
        save_json: If True, saves telemetry data as JSON.
        metrics_path: Optional file path where metrics should be saved.
        save_raw_snapshot: If True, saves raw snapshot data in telemetry.
        save_raw_records: If True, saves raw record data in telemetry.
        **kwargs: Additional keyword arguments passed to the streaming
            implementation.

    Returns:
        An iterator yielding the agent's responses.

    Note:
        This method tracks invocation depth to properly handle nested agent calls
        and ensure telemetry is only rendered once at the top level.
    """
    # Track invocation depth to handle nested agent calls
    BaseAgent._invoke_depth += 1

    try:
        # Start telemetry tracking for top-level invocations only
        if BaseAgent._invoke_depth == 1:
            self.telemetry.begin_run(
                agent=self.name, thread_id=self.thread_id
            )

        # Normalize inputs and delegate to the actual streaming implementation
        normalized = self._normalize_inputs(inputs)
        yield from self._stream(normalized, config=config, **kwargs)

    finally:
        # Decrement invocation depth when exiting
        BaseAgent._invoke_depth -= 1

        # Render telemetry data only for top-level invocations
        if BaseAgent._invoke_depth == 0:
            self.telemetry.render(
                raw=raw_debug,
                save_json=save_json,
                filepath=metrics_path,
                save_raw_snapshot=save_raw_snapshot,
                save_raw_records=save_raw_records,
            )
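A streaming sketch (assuming LangGraph-style update chunks are yielded):

from langchain_core.messages import HumanMessage

for chunk in agent.stream({"messages": [HumanMessage("Summarize the results")]}):
    print(chunk)
# Telemetry is rendered once, after the top-level stream finishes.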

write_state(filename, state)

Writes agent state to a JSON file.

Serializes the provided state dictionary to JSON format and writes it to the specified file. The JSON is written with non-ASCII characters preserved.

Parameters:

Name Type Description Default
filename str

Path to the file where state will be written.

required
state dict

Dictionary containing the agent state to be serialized.

required
Source code in src/ursa/agents/base.py
Lines 188-200.
def write_state(self, filename: str, state: dict) -> None:
    """Writes agent state to a JSON file.

    Serializes the provided state dictionary to JSON format and writes it to the
    specified file. The JSON is written with non-ASCII characters preserved.

    Args:
        filename: Path to the file where state will be written.
        state: Dictionary containing the agent state to be serialized.
    """
    json_state = dumps(state, ensure_ascii=False)
    with open(filename, "w") as f:
        f.write(json_state)

code_review_agent

read_file(filename, state)

Reads the file with the given filename into a string.

Parameters:

Name Type Description Default
filename str

string filename to read in

required
Source code in src/ursa/agents/code_review_agent.py
Lines 256-270.
@tool
def read_file(filename: str, state: Annotated[dict, InjectedState]):
    """
    Reads in a file with a given filename into a string

    Args:
        filename: string filename to read in
    """
    workspace_dir = state["workspace"]
    full_filename = os.path.join(workspace_dir, filename)

    print("[READING]: ", full_filename)
    with open(full_filename, "r") as file:
        file_contents = file.read()
    return file_contents

run_cmd(query, state)

Runs a command from the command line.

Source code in src/ursa/agents/code_review_agent.py
Lines 234-253.
@tool
def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
    """Run command from commandline"""
    workspace_dir = state["workspace"]

    print("RUNNING: ", query)
    process = subprocess.Popen(
        query.split(" "),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        cwd=workspace_dir,
    )

    stdout, stderr = process.communicate(timeout=600)

    print("STDOUT: ", stdout)
    print("STDERR: ", stderr)

    return f"STDOUT: {stdout} and STDERR: {stderr}"

write_file(code, filename, state)

Writes text to a file in the given workspace as requested.

Parameters:

Name Type Description Default
code str

Text to write to a file

required
filename str

the filename to write to

required

Returns:

Type Description
str

Execution results

Source code in src/ursa/agents/code_review_agent.py
Lines 273-313.
@tool
def write_file(
    code: str, filename: str, state: Annotated[dict, InjectedState]
) -> str:
    """
    Writes text to a file in the given workspace as requested.

    Args:
        code: Text to write to a file
        filename: the filename to write to

    Returns:
        Execution results
    """
    workspace_dir = state["workspace"]

    print("[WRITING]: ", filename)
    try:
        # Extract code if wrapped in markdown code blocks
        if "```" in code:
            code_parts = code.split("```")
            if len(code_parts) >= 3:
                # Extract the actual code
                if "\n" in code_parts[1]:
                    code = "\n".join(code_parts[1].strip().split("\n")[1:])
                else:
                    code = code_parts[2].strip()

        # Write code to a file
        code_file = os.path.join(workspace_dir, filename)

        with open(code_file, "w") as f:
            f.write(code)
        print(f"Written code to file: {code_file}")

        return f"File {filename} written successfully."

    except Exception as e:
        print(f"Error generating code: {str(e)}")
        # Return minimal code that prints the error
        return f"Failed to write {filename} successfully."
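To illustrate the fence handling above (a sketch; the wrapped snippet is illustrative), text wrapped in triple backticks is unwrapped before being written to the workspace file:

code = """```python
print("hello")
```"""
# write_file strips the fence and language tag, so hello.py ends up containing:
# print("hello")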

execution_agent

Execution agent that builds a tool-enabled state graph to autonomously run tasks.

This module implements ExecutionAgent, a LangGraph-based agent that executes user instructions by invoking LLM tool calls and coordinating a controlled workflow.

Key features:
- Workspace management with optional symlinking for external sources.
- Safety-checked shell execution via run_command with output size budgeting.
- Code authoring and edits through write_code and edit_code with rich previews.
- Web search capability through DuckDuckGoSearchResults.
- Summarization of the session and optional memory logging.
- Configurable graph with nodes for agent, safety_check, action, and summarize.

Implementation notes:
- LLM prompts are sourced from prompt_library.execution_prompts.
- Outputs from subprocess are trimmed under MAX_TOOL_MSG_CHARS to fit tool messages.
- The agent uses ToolNode and LangGraph StateGraph to loop until no tool calls remain.
- Safety gates block unsafe shell commands and surface the rationale to the user.

Environment:
- MAX_TOOL_MSG_CHARS caps combined stdout/stderr in tool responses.

Entry points:
- ExecutionAgent._invoke(...) runs the compiled graph.
- main() shows a minimal demo that writes and runs a script; a usage sketch along those lines follows below.
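A minimal usage sketch (import paths, the model choice, and the task are assumptions, not taken from the demo itself):

from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage
from ursa.agents.execution_agent import ExecutionAgent

llm = init_chat_model("openai:gpt-4o")        # illustrative model
agent = ExecutionAgent(llm)
result = agent.invoke({
    "messages": [HumanMessage("Write and run a script that prints the first 10 primes.")],
    "workspace": "demo_workspace",
})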

ExecutionAgent

Bases: BaseAgent

Orchestrates model-driven code execution, tool calls, and state management.

Orchestrates model-driven code execution, tool calls, and state management for iterative program synthesis and shell interaction.

This agent wraps an LLM with a small execution graph that alternates between issuing model queries, invoking tools (read, run, write, edit, search), performing safety checks, and summarizing progress. It manages a workspace on disk, optional symlinks, and an optional memory backend to persist summaries.

Parameters:

Name Type Description Default
llm BaseChatModel

Model identifier or bound chat model instance. If a string is provided, the BaseAgent initializer will resolve it.

required
agent_memory Any | AgentMemory

Memory backend used to store summarized agent interactions. If provided, summaries are saved here.

None
log_state bool

When True, the agent writes intermediate json state to disk for debugging and auditability.

False
**kwargs

Passed through to the BaseAgent constructor (e.g., model configuration, checkpointer).

{}

Attributes:

Name Type Description
safe_codes list[str]

List of trusted programming languages for the agent. Defaults to python and julia.

executor_prompt str

Prompt used when invoking the executor LLM loop.

summarize_prompt str

Prompt used to request concise summaries for memory or final output.

tools list[Tool]

Tools available to the agent (run_command, write_code, edit_code, read_file, run_web_search, run_osti_search, run_arxiv_search).

tool_node ToolNode

Graph node that dispatches tool calls.

llm BaseChatModel

LLM instance bound to the available tools.

_action StateGraph

Compiled execution graph that implements the main loop and branching logic.

Methods:

Name Description
query_executor

Send messages to the executor LLM, ensure workspace exists, and handle symlink setup before returning the model response.

summarize

Produce and optionally persist a summary of recent interactions to the memory backend.

safety_check

Validate pending run_command calls via the safety prompt and append ToolMessages for unsafe commands.

get_safety_prompt

Build the LLM prompt for safety_check, including the editable list of trusted programming languages and the context of files the agent has generated and can trust.

_build_graph

Construct and compile the StateGraph for the agent loop.

_invoke

Internal entry that invokes the compiled graph with a given recursion limit.

action

Disabled; direct access is not supported. Use invoke or stream entry points instead.

Raises:

Type Description
AttributeError

Accessing the .action attribute raises to encourage using .stream(...) or .invoke(...).

Source code in src/ursa/agents/execution_agent.py
Lines 158-641.
class ExecutionAgent(BaseAgent):
    """Orchestrates model-driven code execution, tool calls, and state management.

    Orchestrates model-driven code execution, tool calls, and state management for
    iterative program synthesis and shell interaction.

    This agent wraps an LLM with a small execution graph that alternates
    between issuing model queries, invoking tools (read, run, write, edit, search),
    performing safety checks, and summarizing progress. It manages a
    workspace on disk, optional symlinks, and an optional memory backend to
    persist summaries.

    Args:
        llm (BaseChatModel): Model identifier or bound chat model
            instance. If a string is provided, the BaseAgent initializer will
            resolve it.
        agent_memory (Any | AgentMemory, optional): Memory backend used to
            store summarized agent interactions. If provided, summaries are
            saved here.
        log_state (bool): When True, the agent writes intermediate json state
            to disk for debugging and auditability.
        **kwargs: Passed through to the BaseAgent constructor (e.g., model
            configuration, checkpointer).

    Attributes:
        safe_codes (list[str]): List of trusted programming languages for the
            agent. Defaults to python and julia
        executor_prompt (str): Prompt used when invoking the executor LLM
            loop.
        summarize_prompt (str): Prompt used to request concise summaries for
            memory or final output.
        tools (list[Tool]): Tools available to the agent (run_command, write_code,
            edit_code, read_file, run_web_search, run_osti_search, run_arxiv_search).
        tool_node (ToolNode): Graph node that dispatches tool calls.
        llm (BaseChatModel): LLM instance bound to the available tools.
        _action (StateGraph): Compiled execution graph that implements the
            main loop and branching logic.

    Methods:
        query_executor(state): Send messages to the executor LLM, ensure
            workspace exists, and handle symlink setup before returning the
            model response.
        summarize(state): Produce and optionally persist a summary of recent
            interactions to the memory backend.
        safety_check(state): Validate pending run_command calls via the safety
            prompt and append ToolMessages for unsafe commands.
        get_safety_prompt(query, safe_codes, created_files): Get the LLM prompt for safety_check
            that includes an editable list of available programming languages and gets the context
            of files that the agent has generated and can trust.
        _build_graph(): Construct and compile the StateGraph for the agent
            loop.
        _invoke(inputs, recursion_limit=...): Internal entry that invokes the
            compiled graph with a given recursion limit.
        action (property): Disabled; direct access is not supported. Use
            invoke or stream entry points instead.

    Raises:
        AttributeError: Accessing the .action attribute raises to encourage
            using .stream(...) or .invoke(...).
    """

    def __init__(
        self,
        llm: BaseChatModel,
        agent_memory: Optional[Any | AgentMemory] = None,
        log_state: bool = False,
        extra_tools: Optional[list[Callable[..., Any]]] = None,
        tokens_before_summarize: int = 50000,
        messages_to_keep: int = 20,
        safe_codes: Optional[list[str]] = None,
        **kwargs,
    ):
        """ExecutionAgent class initialization."""
        super().__init__(llm, **kwargs)
        self.agent_memory = agent_memory
        self.safe_codes = safe_codes or ["python", "julia"]
        self.get_safety_prompt = get_safety_prompt
        self.executor_prompt = executor_prompt
        self.summarize_prompt = summarize_prompt
        self.tools = [
            run_command,
            write_code,
            edit_code,
            read_file,
            run_web_search,
            run_osti_search,
            run_arxiv_search,
        ]
        self.extra_tools = extra_tools
        if self.extra_tools is not None:
            self.tools.extend(self.extra_tools)
        self.tool_node = ToolNode(self.tools)
        self.llm = self.llm.bind_tools(self.tools)
        self.log_state = log_state
        self._action = self._build_graph()
        self.context_summarizer = SummarizationMiddleware(
            model=self.llm,
            max_tokens_before_summary=tokens_before_summarize,
            messages_to_keep=messages_to_keep,
        )

    # Check message history length and summarize to shorten the token usage:
    def _summarize_context(self, state: ExecutionState) -> ExecutionState:
        summarized_messages = self.context_summarizer.before_model(state, None)
        if summarized_messages:
            tokens_before_summarize = self.context_summarizer.token_counter(
                state["messages"]
            )
            state["messages"] = summarized_messages["messages"]
            tokens_after_summarize = self.context_summarizer.token_counter(
                state["messages"][1:]
            )
            console.print(
                Panel(
                    (
                        f"Summarized Conversation History:\n"
                        f"Approximate tokens before: {tokens_before_summarize}\n"
                        f"Approximate tokens after: {tokens_after_summarize}\n"
                    ),
                    title="[bold yellow1 on black]:clipboard: Plan",
                    border_style="yellow1",
                    style="bold yellow1 on black",
                )
            )
        else:
            tokens_after_summarize = self.context_summarizer.token_counter(
                state["messages"]
            )
        return state

    # Define the function that calls the model
    def query_executor(self, state: ExecutionState) -> ExecutionState:
        """Prepare workspace, handle optional symlinks, and invoke the executor LLM.

        This method copies the incoming state, ensures a workspace directory exists
        (creating one with a random name when absent), optionally creates a symlink
        described by state["symlinkdir"], sets or injects the executor system prompt
        as the first message, and invokes the bound LLM. When logging is enabled,
        it persists the pre-invocation state to disk.

        Args:
            state: The current execution state. Expected keys include:
                - "messages": Ordered list of System/Human/AI/Tool messages.
                - "workspace": Optional path to the working directory.
                - "symlinkdir": Optional dict with "source" and "dest" keys.

        Returns:
            ExecutionState: Partial state update containing:
                - "messages": A list with the model's response as the latest entry.
                - "workspace": The resolved workspace path.
        """
        # Add model to the state so it can be passed to tools like the URSA Arxiv or OSTI tools
        state.setdefault("model", self.llm)
        new_state = state.copy()

        # 1) Ensure a workspace directory exists, creating a named one if absent.
        if "workspace" not in new_state.keys():
            new_state["workspace"] = randomname.get_name()
            print(
                f"{RED}Creating the folder "
                f"{BLUE}{BOLD}{new_state['workspace']}{RESET}{RED} "
                f"for this project.{RESET}"
            )
        os.makedirs(new_state["workspace"], exist_ok=True)

        # 1.5) Check message history length and summarize to shorten the token usage:
        new_state = self._summarize_context(new_state)

        # 2) Optionally create a symlink if symlinkdir is provided and not yet linked.
        sd = new_state.get("symlinkdir")
        if isinstance(sd, dict) and "is_linked" not in sd:
            # symlinkdir structure: {"source": "/path/to/src", "dest": "link/name"}
            symlinkdir = sd

            src = Path(symlinkdir["source"]).expanduser().resolve()
            workspace_root = Path(new_state["workspace"]).expanduser().resolve()
            dst = (
                workspace_root / symlinkdir["dest"]
            )  # Link lives inside workspace.

            # If a file/link already exists at the destination, replace it.
            if dst.exists() or dst.is_symlink():
                dst.unlink()

            # Ensure parent directories for the link exist.
            dst.parent.mkdir(parents=True, exist_ok=True)

            # Create the symlink (tell pathlib if the target is a directory).
            dst.symlink_to(src, target_is_directory=src.is_dir())
            print(f"{RED}Symlinked {src} (source) --> {dst} (dest)")
            new_state["symlinkdir"]["is_linked"] = True

        # 3) Ensure the executor prompt is the first SystemMessage.
        if isinstance(new_state["messages"][0], SystemMessage):
            new_state["messages"][0] = SystemMessage(
                content=self.executor_prompt
            )
        else:
            new_state["messages"] = [
                SystemMessage(content=self.executor_prompt)
            ] + state["messages"]

        # 4) Invoke the LLM with the prepared message sequence.
        try:
            response = self.llm.invoke(
                new_state["messages"], self.build_config(tags=["agent"])
            )
            new_state["messages"].append(response)
        except Exception as e:
            print("Error: ", e, " ", new_state["messages"][-1].content)
            new_state["messages"].append(
                AIMessage(content=f"Response error {e}")
            )

        # 5) Optionally persist the pre-invocation state for audit/debugging.
        if self.log_state:
            self.write_state("execution_agent.json", new_state)

        # Return the model's response and the workspace path as a partial state update.
        return new_state

    def summarize(self, state: ExecutionState) -> ExecutionState:
        """Produce a concise summary of the conversation and optionally persist memory.

        This method builds a summarization prompt, invokes the LLM to obtain a compact
        summary of recent interactions, optionally logs salient details to the agent
        memory backend, and writes debug state when logging is enabled.

        Args:
            state (ExecutionState): The execution state containing message history.

        Returns:
            ExecutionState: A partial update with a single string message containing
                the summary.
        """
        new_state = state.copy()

        # 0) Check message history length and summarize to shorten the token usage:
        new_state = self._summarize_context(new_state)

        # 1) Construct the summarization message list (system prompt + prior messages).
        messages = (
            new_state["messages"]
            if isinstance(new_state["messages"][0], SystemMessage)
            else [SystemMessage(content=summarize_prompt)]
            + new_state["messages"]
        )

        # 2) Invoke the LLM to generate a summary; capture content even on failure.
        response_content = ""
        try:
            response = self.llm.invoke(
                messages, self.build_config(tags=["summarize"])
            )
            response_content = response.content
            new_state["messages"].append(response)
        except Exception as e:
            print("Error: ", e, " ", messages[-1].content)
            new_state["messages"].append(
                AIMessage(content=f"Response error {e}")
            )

        # 3) Optionally persist salient details to the memory backend.
        if self.agent_memory:
            memories: list[str] = []
            # Collect human/system/tool message content; for AI tool calls, store args.
            for msg in new_state["messages"]:
                if not isinstance(msg, AIMessage):
                    memories.append(msg.content)
                elif not msg.tool_calls:
                    memories.append(msg.content)
                else:
                    tool_strings = []
                    for tool in msg.tool_calls:
                        tool_strings.append("Tool Name: " + tool["name"])
                        for arg_name in tool["args"]:
                            tool_strings.append(
                                f"Arg: {str(arg_name)}\nValue: "
                                f"{str(tool['args'][arg_name])}"
                            )
                    memories.append("\n".join(tool_strings))
            memories.append(response_content)
            self.agent_memory.add_memories(memories)

        # 4) Optionally write state to disk for debugging/auditing.
        if self.log_state:
            self.write_state("execution_agent.json", new_state)

        # 5) Return a partial state update with only the summary content.
        return new_state

    def safety_check(self, state: ExecutionState) -> ExecutionState:
        """Assess pending shell commands for safety and inject ToolMessages with results.

        This method inspects the most recent AI tool calls, evaluates any run_command
        queries against the safety prompt, and constructs ToolMessages that either
        flag unsafe commands with reasons or confirm safe execution. If any command
        is unsafe, the generated ToolMessages are appended to the state so the agent
        can react without executing the command.

        Args:
            state (ExecutionState): Current execution state.

        Returns:
            ExecutionState: Either the unchanged state (all safe) or a copy with one
                or more ToolMessages appended when unsafe commands are detected.
        """
        # 1) Work on a shallow copy; inspect the most recent model message.
        new_state = state.copy()
        last_msg = new_state["messages"][-1]

        # 1.5) Check message history length and summarize to shorten the token usage:
        new_state = self._summarize_context(new_state)

        # 2) Evaluate any pending run_command tool calls for safety.
        tool_responses: list[ToolMessage] = []
        any_unsafe = False
        for tool_call in last_msg.tool_calls:
            if tool_call["name"] != "run_command":
                continue

            query = tool_call["args"]["query"]
            safety_result = self.llm.invoke(
                self.get_safety_prompt(
                    query, self.safe_codes, new_state.get("code_files", [])
                ),
                self.build_config(tags=["safety_check"]),
            )

            if "[NO]" in safety_result.content:
                any_unsafe = True
                tool_response = (
                    "[UNSAFE] That command `{q}` was deemed unsafe and cannot be run.\n"
                    "For reason: {r}"
                ).format(q=query, r=safety_result.content)
                console.print(
                    "[bold red][WARNING][/bold red] Command deemed unsafe:",
                    query,
                )
                # Also surface the model's rationale for transparency.
                console.print(
                    "[bold red][WARNING][/bold red] REASON:", tool_response
                )
            else:
                tool_response = f"Command `{query}` passed safety check."
                console.print(
                    f"[green]Command passed safety check:[/green] {query}"
                )

            tool_responses.append(
                ToolMessage(
                    content=tool_response,
                    tool_call_id=tool_call["id"],
                )
            )

        # 3) If any command is unsafe, append all tool responses; otherwise keep state.
        if any_unsafe:
            new_state["messages"].extend(tool_responses)

        return new_state

    def _build_graph(self):
        """Construct and compile the agent's LangGraph state machine."""
        # Create a graph over the agent's execution state.
        graph = StateGraph(ExecutionState)

        # Register nodes:
        # - "agent": LLM planning/execution step
        # - "action": tool dispatch (run_command, write_code, etc.)
        # - "summarize": summary/finalization step
        # - "safety_check": gate for shell command safety
        self.add_node(graph, self.query_executor, "agent")
        self.add_node(graph, self.tool_node, "action")
        self.add_node(graph, self.summarize, "summarize")
        self.add_node(graph, self.safety_check, "safety_check")

        # Set entrypoint: execution starts with the "agent" node.
        graph.set_entry_point("agent")

        # From "agent", either continue (tools) or finish (summarize),
        # based on presence of tool calls in the last message.
        graph.add_conditional_edges(
            "agent",
            self._wrap_cond(should_continue, "should_continue", "execution"),
            {"continue": "safety_check", "summarize": "summarize"},
        )

        # From "safety_check", route to tools if safe, otherwise back to agent
        # to revise the plan without executing unsafe commands.
        graph.add_conditional_edges(
            "safety_check",
            self._wrap_cond(command_safe, "command_safe", "execution"),
            {"safe": "action", "unsafe": "agent"},
        )

        # After tools run, return control to the agent for the next step.
        graph.add_edge("action", "agent")

        # The graph completes at the "summarize" node.
        graph.set_finish_point("summarize")

        # Compile and return the executable graph (optionally with a checkpointer).
        return graph.compile(checkpointer=self.checkpointer)

    async def add_mcp_tool(
        self, mcp_tools: Callable[..., Any] | list[Callable[..., Any]]
    ) -> None:
        client = MultiServerMCPClient(mcp_tools)
        tools = await client.get_tools()
        self.add_tool(tools)

    def add_tool(
        self, new_tools: Callable[..., Any] | list[Callable[..., Any]]
    ) -> None:
        if isinstance(new_tools, list):
            self.tools.extend([convert_to_tool(x) for x in new_tools])
        elif isinstance(new_tools, StructuredTool) or isinstance(
            new_tools, Callable
        ):
            self.tools.append(convert_to_tool(new_tools))
        else:
            raise TypeError("Expected a callable or a list of callables.")
        self.tool_node = ToolNode(self.tools)
        self.llm = self.llm.bind_tools(self.tools)
        self._action = self._build_graph()

    def list_tools(self) -> None:
        print(
            f"Available tool names are: {', '.join([x.name for x in self.tools])}."
        )

    def remove_tool(self, cut_tools: str | list[str]) -> None:
        if isinstance(cut_tools, str):
            self.remove_tool([cut_tools])
        elif isinstance(cut_tools, list):
            self.tools = [x for x in self.tools if x.name not in cut_tools]
            self.tool_node = ToolNode(self.tools)
            self.llm = self.llm.bind_tools(self.tools)
            self._action = self._build_graph()
        else:
            raise TypeError(
                "Expected a string or a list of strings describing the tools to remove."
            )

    def _invoke(
        self, inputs: Mapping[str, Any], recursion_limit: int = 999_999, **_
    ):
        """Invoke the compiled graph with inputs under a specified recursion limit.

        This method builds a LangGraph config with the provided recursion limit
        and a "graph" tag, then delegates to the compiled graph's invoke method.
        """
        # Build invocation config with a generous recursion limit for long runs.
        config = self.build_config(
            recursion_limit=recursion_limit, tags=["graph"]
        )

        # Delegate execution to the compiled graph.
        return self._action.invoke(inputs, config)

    def _ainvoke(
        self, inputs: Mapping[str, Any], recursion_limit: int = 999_999, **_
    ):
        """Invoke the compiled graph with inputs under a specified recursion limit.

        This method builds a LangGraph config with the provided recursion limit
        and a "graph" tag, then delegates to the compiled graph's invoke method.
        """
        # Build invocation config with a generous recursion limit for long runs.
        config = self.build_config(
            recursion_limit=recursion_limit, tags=["graph"]
        )

        # Delegate execution to the compiled graph.
        return self._action.ainvoke(inputs, config)

    # This property is trying to stop people bypassing invoke
    @property
    def action(self):
        """Property used to affirm `action` attribute is unsupported."""
        raise AttributeError(
            "Use .stream(...) or .invoke(...); direct .action access is unsupported."
        )
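A hedged sketch of managing the toolset at runtime (the helper function is illustrative):

def count_lines(path: str) -> int:
    """Return the number of lines in a text file."""
    with open(path) as f:
        return sum(1 for _ in f)

agent.add_tool(count_lines)      # converted to a tool; bindings and graph are rebuilt
agent.list_tools()
agent.remove_tool("count_lines")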

action property

Property that raises to indicate direct action access is unsupported.

__init__(llm, agent_memory=None, log_state=False, extra_tools=None, tokens_before_summarize=50000, messages_to_keep=20, safe_codes=None, **kwargs)

ExecutionAgent class initialization.

Source code in src/ursa/agents/execution_agent.py
Lines 219-257.
def __init__(
    self,
    llm: BaseChatModel,
    agent_memory: Optional[Any | AgentMemory] = None,
    log_state: bool = False,
    extra_tools: Optional[list[Callable[..., Any]]] = None,
    tokens_before_summarize: int = 50000,
    messages_to_keep: int = 20,
    safe_codes: Optional[list[str]] = None,
    **kwargs,
):
    """ExecutionAgent class initialization."""
    super().__init__(llm, **kwargs)
    self.agent_memory = agent_memory
    self.safe_codes = safe_codes or ["python", "julia"]
    self.get_safety_prompt = get_safety_prompt
    self.executor_prompt = executor_prompt
    self.summarize_prompt = summarize_prompt
    self.tools = [
        run_command,
        write_code,
        edit_code,
        read_file,
        run_web_search,
        run_osti_search,
        run_arxiv_search,
    ]
    self.extra_tools = extra_tools
    if self.extra_tools is not None:
        self.tools.extend(self.extra_tools)
    self.tool_node = ToolNode(self.tools)
    self.llm = self.llm.bind_tools(self.tools)
    self.log_state = log_state
    self._action = self._build_graph()
    self.context_summarizer = SummarizationMiddleware(
        model=self.llm,
        max_tokens_before_summary=tokens_before_summarize,
        messages_to_keep=messages_to_keep,
    )
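A construction sketch showing the optional arguments (values and the extra tool are illustrative):

agent = ExecutionAgent(
    llm,
    log_state=True,                     # persist intermediate state to execution_agent.json
    extra_tools=[my_custom_tool],       # hypothetical extra callable exposed as a tool
    tokens_before_summarize=30_000,     # summarize history past this approximate token budget
    messages_to_keep=10,                # recent messages kept verbatim after summarizing
    safe_codes=["python", "julia", "bash"],  # languages trusted by the safety check
)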

query_executor(state)

Prepare workspace, handle optional symlinks, and invoke the executor LLM.

This method copies the incoming state, ensures a workspace directory exists (creating one with a random name when absent), optionally creates a symlink described by state["symlinkdir"], sets or injects the executor system prompt as the first message, and invokes the bound LLM. When logging is enabled, it persists the pre-invocation state to disk.

Parameters:

Name Type Description Default
state ExecutionState

The current execution state. Expected keys include: - "messages": Ordered list of System/Human/AI/Tool messages. - "workspace": Optional path to the working directory. - "symlinkdir": Optional dict with "source" and "dest" keys.

required

Returns:

Name Type Description
ExecutionState ExecutionState

Partial state update containing: - "messages": A list with the model's response as the latest entry. - "workspace": The resolved workspace path.

Source code in src/ursa/agents/execution_agent.py
Lines 289-377.
def query_executor(self, state: ExecutionState) -> ExecutionState:
    """Prepare workspace, handle optional symlinks, and invoke the executor LLM.

    This method copies the incoming state, ensures a workspace directory exists
    (creating one with a random name when absent), optionally creates a symlink
    described by state["symlinkdir"], sets or injects the executor system prompt
    as the first message, and invokes the bound LLM. When logging is enabled,
    it persists the pre-invocation state to disk.

    Args:
        state: The current execution state. Expected keys include:
            - "messages": Ordered list of System/Human/AI/Tool messages.
            - "workspace": Optional path to the working directory.
            - "symlinkdir": Optional dict with "source" and "dest" keys.

    Returns:
        ExecutionState: Partial state update containing:
            - "messages": A list with the model's response as the latest entry.
            - "workspace": The resolved workspace path.
    """
    # Add model to the state so it can be passed to tools like the URSA Arxiv or OSTI tools
    state.setdefault("model", self.llm)
    new_state = state.copy()

    # 1) Ensure a workspace directory exists, creating a named one if absent.
    if "workspace" not in new_state.keys():
        new_state["workspace"] = randomname.get_name()
        print(
            f"{RED}Creating the folder "
            f"{BLUE}{BOLD}{new_state['workspace']}{RESET}{RED} "
            f"for this project.{RESET}"
        )
    os.makedirs(new_state["workspace"], exist_ok=True)

    # 1.5) Check message history length and summarize to shorten the token usage:
    new_state = self._summarize_context(new_state)

    # 2) Optionally create a symlink if symlinkdir is provided and not yet linked.
    sd = new_state.get("symlinkdir")
    if isinstance(sd, dict) and "is_linked" not in sd:
        # symlinkdir structure: {"source": "/path/to/src", "dest": "link/name"}
        symlinkdir = sd

        src = Path(symlinkdir["source"]).expanduser().resolve()
        workspace_root = Path(new_state["workspace"]).expanduser().resolve()
        dst = (
            workspace_root / symlinkdir["dest"]
        )  # Link lives inside workspace.

        # If a file/link already exists at the destination, replace it.
        if dst.exists() or dst.is_symlink():
            dst.unlink()

        # Ensure parent directories for the link exist.
        dst.parent.mkdir(parents=True, exist_ok=True)

        # Create the symlink (tell pathlib if the target is a directory).
        dst.symlink_to(src, target_is_directory=src.is_dir())
        print(f"{RED}Symlinked {src} (source) --> {dst} (dest)")
        new_state["symlinkdir"]["is_linked"] = True

    # 3) Ensure the executor prompt is the first SystemMessage.
    if isinstance(new_state["messages"][0], SystemMessage):
        new_state["messages"][0] = SystemMessage(
            content=self.executor_prompt
        )
    else:
        new_state["messages"] = [
            SystemMessage(content=self.executor_prompt)
        ] + state["messages"]

    # 4) Invoke the LLM with the prepared message sequence.
    try:
        response = self.llm.invoke(
            new_state["messages"], self.build_config(tags=["agent"])
        )
        new_state["messages"].append(response)
    except Exception as e:
        print("Error: ", e, " ", new_state["messages"][-1].content)
        new_state["messages"].append(
            AIMessage(content=f"Response error {e}")
        )

    # 5) Optionally persist the pre-invocation state for audit/debugging.
    if self.log_state:
        self.write_state("execution_agent.json", new_state)

    # Return the model's response and the workspace path as a partial state update.
    return new_state
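
The workspace and symlink handling above is driven entirely by the shape of the incoming state. A small sketch of a plausible input follows; the key names come from the code, the values are illustrative.

from langchain_core.messages import HumanMessage

# Illustrative input for query_executor; only the keys read above are shown.
state = {
    "messages": [HumanMessage(content="Run the benchmark and summarize the results.")],
    "workspace": "benchmark-run",          # omitted -> a random name is generated
    "symlinkdir": {
        "source": "~/datasets/benchmark",  # existing data exposed to the agent
        "dest": "data",                    # link created inside the workspace
    },
}
# After the node runs, state["symlinkdir"]["is_linked"] is True and
# the returned "messages" list ends with the model's response.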

safety_check(state)

Assess pending shell commands for safety and inject ToolMessages with results.

This method inspects the most recent AI tool calls, evaluates any run_command queries against the safety prompt, and constructs ToolMessages that either flag unsafe commands with reasons or confirm safe execution. If any command is unsafe, the generated ToolMessages are appended to the state so the agent can react without executing the command.

Parameters:

- state (ExecutionState), required: Current execution state.

Returns:

- ExecutionState: Either the unchanged state (all safe) or a copy with one or more ToolMessages appended when unsafe commands are detected.

Source code in src/ursa/agents/execution_agent.py
def safety_check(self, state: ExecutionState) -> ExecutionState:
    """Assess pending shell commands for safety and inject ToolMessages with results.

    This method inspects the most recent AI tool calls, evaluates any run_command
    queries against the safety prompt, and constructs ToolMessages that either
    flag unsafe commands with reasons or confirm safe execution. If any command
    is unsafe, the generated ToolMessages are appended to the state so the agent
    can react without executing the command.

    Args:
        state (ExecutionState): Current execution state.

    Returns:
        ExecutionState: Either the unchanged state (all safe) or a copy with one
            or more ToolMessages appended when unsafe commands are detected.
    """
    # 1) Work on a shallow copy; inspect the most recent model message.
    new_state = state.copy()
    last_msg = new_state["messages"][-1]

    # 1.5) Check message history length and summarize to shorten the token usage:
    new_state = self._summarize_context(new_state)

    # 2) Evaluate any pending run_command tool calls for safety.
    tool_responses: list[ToolMessage] = []
    any_unsafe = False
    for tool_call in last_msg.tool_calls:
        if tool_call["name"] != "run_command":
            continue

        query = tool_call["args"]["query"]
        safety_result = self.llm.invoke(
            self.get_safety_prompt(
                query, self.safe_codes, new_state.get("code_files", [])
            ),
            self.build_config(tags=["safety_check"]),
        )

        if "[NO]" in safety_result.content:
            any_unsafe = True
            tool_response = (
                "[UNSAFE] That command `{q}` was deemed unsafe and cannot be run.\n"
                "For reason: {r}"
            ).format(q=query, r=safety_result.content)
            console.print(
                "[bold red][WARNING][/bold red] Command deemed unsafe:",
                query,
            )
            # Also surface the model's rationale for transparency.
            console.print(
                "[bold red][WARNING][/bold red] REASON:", tool_response
            )
        else:
            tool_response = f"Command `{query}` passed safety check."
            console.print(
                f"[green]Command passed safety check:[/green] {query}"
            )

        tool_responses.append(
            ToolMessage(
                content=tool_response,
                tool_call_id=tool_call["id"],
            )
        )

    # 3) If any command is unsafe, append all tool responses; otherwise keep state.
    if any_unsafe:
        new_state["messages"].extend(tool_responses)

    return new_state
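
The verdict hinges on a "[NO]" marker in the safety model's reply. The default get_safety_prompt is defined elsewhere in the package; a hypothetical stand-in with the same call signature could look like this.

# Hypothetical stand-in for get_safety_prompt -- only the call signature used above
# and the "[NO]" marker checked above are taken from this page; the rest is illustrative.
def get_safety_prompt(query: str, safe_codes: list[str], code_files: list[str]) -> str:
    return (
        "You are a safety reviewer for shell commands.\n"
        f"Allowed interpreters: {', '.join(safe_codes)}\n"
        f"Files created so far: {', '.join(code_files) or 'none'}\n"
        f"Command under review: {query}\n"
        "Reply [YES] if the command is safe to run, or [NO] followed by a short reason."
    )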

summarize(state)

Produce a concise summary of the conversation and optionally persist memory.

This method builds a summarization prompt, invokes the LLM to obtain a compact summary of recent interactions, optionally logs salient details to the agent memory backend, and writes debug state when logging is enabled.

Parameters:

- state (ExecutionState), required: The execution state containing message history.

Returns:

- ExecutionState: A partial update with a single string message containing the summary.

Source code in src/ursa/agents/execution_agent.py
def summarize(self, state: ExecutionState) -> ExecutionState:
    """Produce a concise summary of the conversation and optionally persist memory.

    This method builds a summarization prompt, invokes the LLM to obtain a compact
    summary of recent interactions, optionally logs salient details to the agent
    memory backend, and writes debug state when logging is enabled.

    Args:
        state (ExecutionState): The execution state containing message history.

    Returns:
        ExecutionState: A partial update with a single string message containing
            the summary.
    """
    new_state = state.copy()

    # 0) Check message history length and summarize to shorten the token usage:
    new_state = self._summarize_context(new_state)

    # 1) Construct the summarization message list (system prompt + prior messages).
    messages = (
        new_state["messages"]
        if isinstance(new_state["messages"][0], SystemMessage)
        else [SystemMessage(content=summarize_prompt)]
        + new_state["messages"]
    )

    # 2) Invoke the LLM to generate a summary; capture content even on failure.
    response_content = ""
    try:
        response = self.llm.invoke(
            messages, self.build_config(tags=["summarize"])
        )
        response_content = response.content
        new_state["messages"].append(response)
    except Exception as e:
        print("Error: ", e, " ", messages[-1].content)
        new_state["messages"].append(
            AIMessage(content=f"Response error {e}")
        )

    # 3) Optionally persist salient details to the memory backend.
    if self.agent_memory:
        memories: list[str] = []
        # Collect human/system/tool message content; for AI tool calls, store args.
        for msg in new_state["messages"]:
            if not isinstance(msg, AIMessage):
                memories.append(msg.content)
            elif not msg.tool_calls:
                memories.append(msg.content)
            else:
                tool_strings = []
                for tool in msg.tool_calls:
                    tool_strings.append("Tool Name: " + tool["name"])
                    for arg_name in tool["args"]:
                        tool_strings.append(
                            f"Arg: {str(arg_name)}\nValue: "
                            f"{str(tool['args'][arg_name])}"
                        )
                memories.append("\n".join(tool_strings))
        memories.append(response_content)
        self.agent_memory.add_memories(memories)

    # 4) Optionally write state to disk for debugging/auditing.
    if self.log_state:
        self.write_state("execution_agent.json", new_state)

    # 5) Return a partial state update with only the summary content.
    return new_state
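
The only interface this method requires of the memory backend is add_memories(list[str]). A minimal in-memory stand-in, purely for illustration (the project's real backend is not shown here), could be:

class ListMemory:
    """Toy memory backend satisfying the add_memories() call used above."""

    def __init__(self) -> None:
        self.items: list[str] = []

    def add_memories(self, memories: list[str]) -> None:
        self.items.extend(memories)

# agent = ExecutionAgent(llm, agent_memory=ListMemory())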

ExecutionState

Bases: TypedDict

TypedDict representing the execution agent's mutable run state used by nodes.

Fields:

- messages: list of messages (System/Human/AI/Tool) with add_messages metadata.
- current_progress: short status string describing agent progress.
- code_files: list of filenames created or edited in the workspace.
- workspace: path to the working directory where files and commands run.
- symlinkdir: optional dict describing a symlink operation (source, dest, is_linked).

Source code in src/ursa/agents/execution_agent.py
class ExecutionState(TypedDict):
    """TypedDict representing the execution agent's mutable run state used by nodes.

    Fields:
    - messages: list of messages (System/Human/AI/Tool) with add_messages metadata.
    - current_progress: short status string describing agent progress.
    - code_files: list of filenames created or edited in the workspace.
    - workspace: path to the working directory where files and commands run.
    - symlinkdir: optional dict describing a symlink operation (source, dest,
      is_linked).
    """

    messages: Annotated[list[AnyMessage], add_messages]
    current_progress: str
    code_files: list[str]
    workspace: str
    symlinkdir: dict
    model: BaseChatModel
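
A freshly initialized state matching these fields might look like the following; the field names come from the TypedDict above, the values are illustrative.

from langchain_core.messages import HumanMessage

initial_state: ExecutionState = {
    "messages": [HumanMessage(content="Fit a line to data.csv and plot it")],
    "current_progress": "starting",
    "code_files": [],
    "workspace": "line-fit",
    "symlinkdir": {"source": "~/data", "dest": "data"},
    "model": llm,  # a BaseChatModel instance, forwarded to tools such as the ArXiv/OSTI search
}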

command_safe(state)

Return 'safe' if the last command was safe, otherwise 'unsafe'.

Parameters:

- state (ExecutionState), required: The current execution state containing messages and tool calls.

Returns:

- A literal "safe" if no '[UNSAFE]' tags are in the last command, otherwise "unsafe".

Source code in src/ursa/agents/execution_agent.py
def command_safe(state: ExecutionState) -> Literal["safe", "unsafe"]:
    """Return 'safe' if the last command was safe, otherwise 'unsafe'.

    Args:
        state: The current execution state containing messages and tool calls.
    Returns:
        A literal "safe" if no '[UNSAFE]' tags are in the last command,
        otherwise "unsafe".
    """
    index = -1
    message = state["messages"][index]
    # Loop through all the consecutive tool messages in reverse order
    while isinstance(message, ToolMessage):
        if "[UNSAFE]" in message.content:
            return "unsafe"

        index -= 1
        message = state["messages"][index]

    return "safe"

should_continue(state)

Return 'summarize' if no tool calls in the last message, else 'continue'.

Parameters:

- state (ExecutionState), required: The current execution state containing messages.

Returns:

- Literal['summarize', 'continue']: "summarize" if the last message has no tool calls, otherwise "continue".

Source code in src/ursa/agents/execution_agent.py
def should_continue(state: ExecutionState) -> Literal["summarize", "continue"]:
    """Return 'summarize' if no tool calls in the last message, else 'continue'.

    Args:
        state: The current execution state containing messages.

    Returns:
        A literal "summarize" if the last message has no tool calls,
        otherwise "continue".
    """
    messages = state["messages"]
    last_message = messages[-1]
    # If there is no tool call, then we finish
    if not last_message.tool_calls:
        return "summarize"
    # Otherwise if there is, we continue
    else:
        return "continue"

hypothesizer_agent

HypothesizerAgent

Bases: BaseAgent

Source code in src/ursa/agents/hypothesizer_agent.py
class HypothesizerAgent(BaseAgent):
    def __init__(
        self,
        llm: BaseChatModel,
        **kwargs,
    ):
        super().__init__(llm, **kwargs)
        self.hypothesizer_prompt = hypothesizer_prompt
        self.critic_prompt = critic_prompt
        self.competitor_prompt = competitor_prompt
        self.search_tool = DDGS()
        # self.search_tool = TavilySearchResults(
        #     max_results=10, search_depth="advanced", include_answer=False
        # )

        self._action = self._build_graph()

    def agent1_generate_solution(
        self, state: HypothesizerState
    ) -> HypothesizerState:
        """Agent 1: Hypothesizer."""
        print(
            f"[iteration {state['current_iteration']}] Entering agent1_generate_solution. Iteration: {state['current_iteration']}"
        )

        current_iter = state["current_iteration"]
        user_content = f"Question: {state['question']}\n"

        if current_iter > 0:
            user_content += (
                f"\nPrevious solution: {state['agent1_solution'][-1]}"
            )
            user_content += f"\nCritique: {state['agent2_critiques'][-1]}"
            user_content += (
                f"\nCompetitor perspective: {state['agent3_perspectives'][-1]}"
            )
            user_content += (
                "\n\n**You must explicitly list how this new solution differs from the previous solution,** "
                "point by point, explaining what changes were made in response to the critique and competitor perspective."
                "\nAfterward, provide your updated solution."
            )
        else:
            user_content += "Research this problem and generate a solution."

        search_query = self.llm.invoke(
            f"Here is a problem description: {state['question']}. Turn it into a short query to be fed into a search engine."
        ).content
        if '"' in search_query:
            search_query = search_query.split('"')[1]
        raw_search_results = self.search_tool.text(
            search_query, backend="duckduckgo"
        )

        # Parse the results if possible, so we can collect URLs
        new_state = state.copy()
        new_state["question_search_query"] = search_query
        if "visited_sites" not in new_state:
            new_state["visited_sites"] = []

        try:
            if isinstance(raw_search_results, str):
                results_list = ast.literal_eval(raw_search_results)
            else:
                results_list = raw_search_results
            # Each item typically might have "link", "title", "snippet"
            for item in results_list:
                link = item.get("link")
                if link:
                    # print(f"[DEBUG] Appending visited link: {link}")
                    new_state["visited_sites"].append(link)
        except (ValueError, SyntaxError, TypeError):
            # If it's not valid Python syntax or something else goes wrong
            print("[DEBUG] Could not parse search results as Python list.")
            print("[DEBUG] raw_search_results:", raw_search_results)

        user_content += f"\nSearch results: {raw_search_results}"

        # Provide a system message to define this agent's role
        messages = [
            SystemMessage(content=self.hypothesizer_prompt),
            HumanMessage(content=user_content),
        ]
        solution = self.llm.invoke(messages)

        new_state["agent1_solution"].append(solution.content)

        # Print the entire solution in green
        print(
            f"{GREEN}[Agent1 - Hypothesizer solution]\n{solution.content}{RESET}"
        )
        print(
            f"[iteration {state['current_iteration']}] Exiting agent1_generate_solution."
        )
        return new_state

    def agent2_critique(self, state: HypothesizerState) -> HypothesizerState:
        """Agent 2: Critic."""
        print(
            f"[iteration {state['current_iteration']}] Entering agent2_critique."
        )

        solution = state["agent1_solution"][-1]
        user_content = (
            f"Question: {state['question']}\n"
            f"Proposed solution: {solution}\n"
            "Provide a detailed critique of this solution. Identify potential flaws, assumptions, and areas for improvement."
        )

        fact_check_query = f"fact check {state['question_search_query']} solution effectiveness"

        raw_search_results = self.search_tool.text(
            fact_check_query, backend="duckduckgo"
        )

        # Parse the results if possible, so we can collect URLs
        new_state = state.copy()
        if "visited_sites" not in new_state:
            new_state["visited_sites"] = []

        try:
            if isinstance(raw_search_results, str):
                results_list = ast.literal_eval(raw_search_results)
            else:
                results_list = raw_search_results
            # Each item typically might have "link", "title", "snippet"
            for item in results_list:
                link = item.get("link")
                if link:
                    # print(f"[DEBUG] Appending visited link: {link}")
                    new_state["visited_sites"].append(link)
        except (ValueError, SyntaxError, TypeError):
            # If it's not valid Python syntax or something else goes wrong
            print("[DEBUG] Could not parse search results as Python list.")
            print("[DEBUG] raw_search_results:", raw_search_results)

        fact_check_results = raw_search_results
        user_content += f"\nFact check results: {fact_check_results}"

        messages = [
            SystemMessage(content=self.critic_prompt),
            HumanMessage(content=user_content),
        ]
        critique = self.llm.invoke(messages)

        new_state["agent2_critiques"].append(critique.content)

        # Print the entire critique in blue
        print(f"{BLUE}[Agent2 - Critic]\n{critique.content}{RESET}")
        print(
            f"[iteration {state['current_iteration']}] Exiting agent2_critique."
        )
        return new_state

    def agent3_competitor_perspective(
        self, state: HypothesizerState
    ) -> HypothesizerState:
        """Agent 3: Competitor/Stakeholder Simulator."""
        print(
            f"[iteration {state['current_iteration']}] Entering agent3_competitor_perspective."
        )

        solution = state["agent1_solution"][-1]
        critique = state["agent2_critiques"][-1]

        user_content = (
            f"Question: {state['question']}\n"
            f"Proposed solution: {solution}\n"
            f"Critique: {critique}\n"
            "Simulate how a competitor, government agency, or other stakeholder might respond to this solution."
        )

        competitor_search_query = (
            f"competitor responses to {state['question_search_query']}"
        )

        raw_search_results = self.search_tool.text(
            competitor_search_query, backend="duckduckgo"
        )

        # Parse the results if possible, so we can collect URLs
        new_state = state.copy()
        if "visited_sites" not in new_state:
            new_state["visited_sites"] = []

        try:
            if isinstance(raw_search_results, str):
                results_list = ast.literal_eval(raw_search_results)
            else:
                results_list = raw_search_results
            # Each item typically might have "link", "title", "snippet"
            for item in results_list:
                link = item.get("link")
                if link:
                    # print(f"[DEBUG] Appending visited link: {link}")
                    new_state["visited_sites"].append(link)
        except (ValueError, SyntaxError, TypeError):
            # If it's not valid Python syntax or something else goes wrong
            print("[DEBUG] Could not parse search results as Python list.")
            print("[DEBUG] raw_search_results:", raw_search_results)

        competitor_info = raw_search_results
        user_content += f"\nCompetitor information: {competitor_info}"

        messages = [
            SystemMessage(content=self.competitor_prompt),
            HumanMessage(content=user_content),
        ]
        perspective = self.llm.invoke(messages)

        new_state["agent3_perspectives"].append(perspective.content)

        # Print the entire perspective in red
        print(
            f"{RED}[Agent3 - Competitor/Stakeholder Perspective]\n{perspective.content}{RESET}"
        )
        print(
            f"[iteration {state['current_iteration']}] Exiting agent3_competitor_perspective."
        )
        return new_state

    def increment_iteration(
        self, state: HypothesizerState
    ) -> HypothesizerState:
        new_state = state.copy()
        new_state["current_iteration"] += 1
        print(
            f"[iteration {state['current_iteration']}] Iteration incremented to {new_state['current_iteration']}"
        )
        return new_state

    def generate_solution(self, state: HypothesizerState) -> HypothesizerState:
        """Generate the overall, refined solution based on all iterations."""
        print(
            f"[iteration {state['current_iteration']}] Entering generate_solution."
        )
        prompt = f"Original question: {state['question']}\n\n"
        prompt += "Evolution of solutions:\n"

        for i in range(state["max_iterations"]):
            prompt += f"\nIteration {i + 1}:\n"
            prompt += f"Solution: {state['agent1_solution'][i]}\n"
            prompt += f"Critique: {state['agent2_critiques'][i]}\n"
            prompt += (
                f"Competitor perspective: {state['agent3_perspectives'][i]}\n"
            )

        prompt += "\nBased on this iterative process, provide the overall, refined solution."

        print(
            f"[iteration {state['current_iteration']}] Generating overall solution with LLM..."
        )
        solution = self.llm.invoke(prompt)
        print(
            f"[iteration {state['current_iteration']}] Overall solution obtained. Preview:",
            solution.content[:200],
            "...",
        )

        new_state = state.copy()
        new_state["solution"] = solution.content

        print(
            f"[iteration {state['current_iteration']}] Exiting generate_solution."
        )
        return new_state

    def print_visited_sites(
        self, state: HypothesizerState
    ) -> HypothesizerState:
        new_state = state.copy()
        # all_sites = new_state.get("visited_sites", [])
        # print("[DEBUG] Visited Sites:")
        # for s in all_sites:
        #     print("  ", s)
        return new_state

    def summarize_process_as_latex(
        self, state: HypothesizerState
    ) -> HypothesizerState:
        """
        Summarize how the solution changed over time, referencing
        each iteration's critique and competitor perspective,
        then produce a final LaTeX document.
        """
        print("Entering summarize_process_as_latex.")
        llm_model = state.get("llm_model", "unknown_model")

        # Build a single string describing the entire iterative process
        iteration_details = ""
        for i, (sol, crit, comp) in enumerate(
            zip(
                state["agent1_solution"],
                state["agent2_critiques"],
                state["agent3_perspectives"],
            ),
            start=1,
        ):
            iteration_details += (
                f"\\subsection*{{Iteration {i}}}\n\n"
                f"\\textbf{{Solution:}}\\\\\n{sol}\n\n"
                f"\\textbf{{Critique:}}\\\\\n{crit}\n\n"
                f"\\textbf{{Competitor Perspective:}}\\\\\n{comp}\n\n"
            )

        # -----------------------------
        # Write iteration_details to disk as .txt
        # -----------------------------
        timestamp_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        txt_filename = (
            f"iteration_details_{llm_model}_{timestamp_str}_chat_history.txt"
        )
        with open(txt_filename, "w", encoding="utf-8") as f:
            f.write(iteration_details)

        print(f"Wrote iteration details to {txt_filename}.")

        # Prompt the LLM to produce a LaTeX doc
        # We'll just pass it as a single string to the LLM;
        # you could also do system+human messages if you prefer.
        prompt = f"""\
            You are a system that produces a FULL LaTeX document.
            Here is information about a multi-iteration process:

            Original question: {state["question"]}

            Below are the solutions, critiques, and competitor perspectives from each iteration:

            {iteration_details}

            The solution we arrived at was:

            {state["solution"]}

            Now produce a valid LaTeX document.  Be sure to use a table of contents.
            It must start with an Executive Summary (that may be multiple pages) which summarizes
            the entire iterative process.  Following that, we should include the solution in full,
            not summarized, but reformatted for appropriate LaTeX.  And then, finally (and this will be
            quite long), we must take all the steps - solutions, critiques, and competitor perspectives
            and *NOT SUMMARIZE THEM* but merely reformat them for the reader.  This will be in an Appendix
            of the full content of the steps.  Finally, include a listing of all of the websites we
            used in our research.

            You must ONLY RETURN LaTeX, nothing else.  It must be valid LaTeX syntax!

            Your output should start with:
            \\documentclass{{article}}
            \\usepackage[margin=1in]{{geometry}}
            etc.

            It must compile without errors under pdflatex. 
        """

        # Now produce a valid LaTeX document that nicely summarizes this entire iterative process.
        # It must include the overall solution in full, not summarized, but reformatted for appropriate
        # LaTeX. The summarization is for the other steps.

        # all_visited_sites = state.get("visited_sites", [])
        # (Optional) remove duplicates by converting to a set, then back to a list
        # visited_sites_unique = list(set(all_visited_sites))
        # if visited_sites_unique:
        #     websites_latex = "\\section*{Websites Visited}\\begin{itemize}\n"
        #     for url in visited_sites_unique:
        #         print(f"We visited: {url}")
        #         # Use \url{} to handle special characters in URLs
        #         websites_latex += f"\\item \\url{{{url}}}\n"
        #     websites_latex += "\\end{itemize}\n\n"
        # else:
        #     # If no sites visited, or the list is empty
        #     websites_latex = (
        #         "\\section*{Websites Visited}\nNo sites were visited.\n\n"
        #     )
        # print(websites_latex)
        websites_latex = ""

        # Ask the LLM to produce *only* LaTeX content
        latex_response = self.llm.invoke(prompt)

        latex_doc = latex_response.content

        def inject_into_latex(original_tex: str, injection: str) -> str:
            """
            Find the last occurrence of '\\end{document}' in 'original_tex'
            and insert 'injection' right before it.
            If '\\end{document}' is not found, just append the injection at the end.
            """
            injection_index = original_tex.rfind(r"\end{document}")
            if injection_index == -1:
                # If the LLM didn't include \end{document}, just append
                return original_tex + "\n" + injection
            else:
                # Insert right before \end{document}
                return (
                    original_tex[:injection_index]
                    + "\n"
                    + injection
                    + "\n"
                    + original_tex[injection_index:]
                )

        final_latex = inject_into_latex(latex_doc, websites_latex)

        new_state = state.copy()
        new_state["summary_report"] = final_latex

        print(
            f"[iteration {state['current_iteration']}] Received LaTeX from LLM. Preview:"
        )
        print(latex_response.content[:300], "...")
        print(
            f"[iteration {state['current_iteration']}] Exiting summarize_process_as_latex."
        )
        return new_state

    def _build_graph(self):
        # Initialize the graph
        graph = StateGraph(HypothesizerState)

        # Add nodes
        self.add_node(graph, self.agent1_generate_solution, "agent1")
        self.add_node(graph, self.agent2_critique, "agent2")
        self.add_node(graph, self.agent3_competitor_perspective, "agent3")
        self.add_node(graph, self.increment_iteration, "increment_iteration")
        self.add_node(graph, self.generate_solution, "finalize")
        self.add_node(graph, self.print_visited_sites, "print_sites")
        self.add_node(
            graph, self.summarize_process_as_latex, "summarize_as_latex"
        )
        # self.graph.add_node("compile_pdf",                compile_summary_to_pdf)

        # Add simple edges for the known flow
        graph.add_edge("agent1", "agent2")
        graph.add_edge("agent2", "agent3")
        graph.add_edge("agent3", "increment_iteration")

        # Then from increment_iteration, we have a conditional:
        # If we 'continue', we go back to agent1
        # If we 'finish', we jump to the finalize node
        graph.add_conditional_edges(
            "increment_iteration",
            should_continue,
            {"continue": "agent1", "finish": "finalize"},
        )

        graph.add_edge("finalize", "summarize_as_latex")
        graph.add_edge("summarize_as_latex", "print_sites")
        # self.graph.add_edge("summarize_as_latex", "compile_pdf")
        # self.graph.add_edge("compile_pdf", "print_sites")

        # Set the entry point
        graph.set_entry_point("agent1")
        graph.set_finish_point("print_sites")

        return graph.compile(checkpointer=self.checkpointer)
        # self.action.get_graph().draw_mermaid_png(output_file_path="hypothesizer_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)

    def _invoke(
        self, inputs: Mapping[str, Any], recursion_limit: int = 100000, **_
    ):
        config = self.build_config(
            recursion_limit=recursion_limit, tags=["graph"]
        )
        if "prompt" not in inputs:
            raise KeyError("'prompt' is a required argument")

        inputs["question"] = inputs["prompt"]
        inputs["max_iterations"] = inputs.get("max_iterations", 3)
        inputs["current_iteration"] = 0
        inputs["agent1_solution"] = []
        inputs["agent2_critiques"] = []
        inputs["agent3_perspectives"] = []
        inputs["solution"] = ""

        return self._action.invoke(inputs, config)
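
Required inputs mirror what _invoke expects: a "prompt" string plus an optional "max_iterations". A usage sketch follows; the import path and the public invoke() wrapper provided by BaseAgent are assumptions, not shown on this page.

from langchain_openai import ChatOpenAI
from ursa.agents.hypothesizer_agent import HypothesizerAgent

agent = HypothesizerAgent(ChatOpenAI(model="gpt-4o"))

result = agent.invoke({
    "prompt": "How could a mid-size city halve its peak electricity demand?",
    "max_iterations": 2,   # defaults to 3 when omitted
})
print(result["solution"])         # refined answer from the finalize node
print(result["summary_report"])   # LaTeX report from summarize_as_latex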

agent1_generate_solution(state)

Agent 1: Hypothesizer.

Source code in src/ursa/agents/hypothesizer_agent.py
def agent1_generate_solution(
    self, state: HypothesizerState
) -> HypothesizerState:
    """Agent 1: Hypothesizer."""
    print(
        f"[iteration {state['current_iteration']}] Entering agent1_generate_solution. Iteration: {state['current_iteration']}"
    )

    current_iter = state["current_iteration"]
    user_content = f"Question: {state['question']}\n"

    if current_iter > 0:
        user_content += (
            f"\nPrevious solution: {state['agent1_solution'][-1]}"
        )
        user_content += f"\nCritique: {state['agent2_critiques'][-1]}"
        user_content += (
            f"\nCompetitor perspective: {state['agent3_perspectives'][-1]}"
        )
        user_content += (
            "\n\n**You must explicitly list how this new solution differs from the previous solution,** "
            "point by point, explaining what changes were made in response to the critique and competitor perspective."
            "\nAfterward, provide your updated solution."
        )
    else:
        user_content += "Research this problem and generate a solution."

    search_query = self.llm.invoke(
        f"Here is a problem description: {state['question']}. Turn it into a short query to be fed into a search engine."
    ).content
    if '"' in search_query:
        search_query = search_query.split('"')[1]
    raw_search_results = self.search_tool.text(
        search_query, backend="duckduckgo"
    )

    # Parse the results if possible, so we can collect URLs
    new_state = state.copy()
    new_state["question_search_query"] = search_query
    if "visited_sites" not in new_state:
        new_state["visited_sites"] = []

    try:
        if isinstance(raw_search_results, str):
            results_list = ast.literal_eval(raw_search_results)
        else:
            results_list = raw_search_results
        # Each item typically might have "link", "title", "snippet"
        for item in results_list:
            link = item.get("link")
            if link:
                # print(f"[DEBUG] Appending visited link: {link}")
                new_state["visited_sites"].append(link)
    except (ValueError, SyntaxError, TypeError):
        # If it's not valid Python syntax or something else goes wrong
        print("[DEBUG] Could not parse search results as Python list.")
        print("[DEBUG] raw_search_results:", raw_search_results)

    user_content += f"\nSearch results: {raw_search_results}"

    # Provide a system message to define this agent's role
    messages = [
        SystemMessage(content=self.hypothesizer_prompt),
        HumanMessage(content=user_content),
    ]
    solution = self.llm.invoke(messages)

    new_state["agent1_solution"].append(solution.content)

    # Print the entire solution in green
    print(
        f"{GREEN}[Agent1 - Hypothesizer solution]\n{solution.content}{RESET}"
    )
    print(
        f"[iteration {state['current_iteration']}] Exiting agent1_generate_solution."
    )
    return new_state

agent2_critique(state)

Agent 2: Critic.

Source code in src/ursa/agents/hypothesizer_agent.py
def agent2_critique(self, state: HypothesizerState) -> HypothesizerState:
    """Agent 2: Critic."""
    print(
        f"[iteration {state['current_iteration']}] Entering agent2_critique."
    )

    solution = state["agent1_solution"][-1]
    user_content = (
        f"Question: {state['question']}\n"
        f"Proposed solution: {solution}\n"
        "Provide a detailed critique of this solution. Identify potential flaws, assumptions, and areas for improvement."
    )

    fact_check_query = f"fact check {state['question_search_query']} solution effectiveness"

    raw_search_results = self.search_tool.text(
        fact_check_query, backend="duckduckgo"
    )

    # Parse the results if possible, so we can collect URLs
    new_state = state.copy()
    if "visited_sites" not in new_state:
        new_state["visited_sites"] = []

    try:
        if isinstance(raw_search_results, str):
            results_list = ast.literal_eval(raw_search_results)
        else:
            results_list = raw_search_results
        # Each item typically might have "link", "title", "snippet"
        for item in results_list:
            link = item.get("link")
            if link:
                # print(f"[DEBUG] Appending visited link: {link}")
                new_state["visited_sites"].append(link)
    except (ValueError, SyntaxError, TypeError):
        # If it's not valid Python syntax or something else goes wrong
        print("[DEBUG] Could not parse search results as Python list.")
        print("[DEBUG] raw_search_results:", raw_search_results)

    fact_check_results = raw_search_results
    user_content += f"\nFact check results: {fact_check_results}"

    messages = [
        SystemMessage(content=self.critic_prompt),
        HumanMessage(content=user_content),
    ]
    critique = self.llm.invoke(messages)

    new_state["agent2_critiques"].append(critique.content)

    # Print the entire critique in blue
    print(f"{BLUE}[Agent2 - Critic]\n{critique.content}{RESET}")
    print(
        f"[iteration {state['current_iteration']}] Exiting agent2_critique."
    )
    return new_state

agent3_competitor_perspective(state)

Agent 3: Competitor/Stakeholder Simulator.

Source code in src/ursa/agents/hypothesizer_agent.py
def agent3_competitor_perspective(
    self, state: HypothesizerState
) -> HypothesizerState:
    """Agent 3: Competitor/Stakeholder Simulator."""
    print(
        f"[iteration {state['current_iteration']}] Entering agent3_competitor_perspective."
    )

    solution = state["agent1_solution"][-1]
    critique = state["agent2_critiques"][-1]

    user_content = (
        f"Question: {state['question']}\n"
        f"Proposed solution: {solution}\n"
        f"Critique: {critique}\n"
        "Simulate how a competitor, government agency, or other stakeholder might respond to this solution."
    )

    competitor_search_query = (
        f"competitor responses to {state['question_search_query']}"
    )

    raw_search_results = self.search_tool.text(
        competitor_search_query, backend="duckduckgo"
    )

    # Parse the results if possible, so we can collect URLs
    new_state = state.copy()
    if "visited_sites" not in new_state:
        new_state["visited_sites"] = []

    try:
        if isinstance(raw_search_results, str):
            results_list = ast.literal_eval(raw_search_results)
        else:
            results_list = raw_search_results
        # Each item typically might have "link", "title", "snippet"
        for item in results_list:
            link = item.get("link")
            if link:
                # print(f"[DEBUG] Appending visited link: {link}")
                new_state["visited_sites"].append(link)
    except (ValueError, SyntaxError, TypeError):
        # If it's not valid Python syntax or something else goes wrong
        print("[DEBUG] Could not parse search results as Python list.")
        print("[DEBUG] raw_search_results:", raw_search_results)

    competitor_info = raw_search_results
    user_content += f"\nCompetitor information: {competitor_info}"

    messages = [
        SystemMessage(content=self.competitor_prompt),
        HumanMessage(content=user_content),
    ]
    perspective = self.llm.invoke(messages)

    new_state["agent3_perspectives"].append(perspective.content)

    # Print the entire perspective in red
    print(
        f"{RED}[Agent3 - Competitor/Stakeholder Perspective]\n{perspective.content}{RESET}"
    )
    print(
        f"[iteration {state['current_iteration']}] Exiting agent3_competitor_perspective."
    )
    return new_state

generate_solution(state)

Generate the overall, refined solution based on all iterations.

Source code in src/ursa/agents/hypothesizer_agent.py
def generate_solution(self, state: HypothesizerState) -> HypothesizerState:
    """Generate the overall, refined solution based on all iterations."""
    print(
        f"[iteration {state['current_iteration']}] Entering generate_solution."
    )
    prompt = f"Original question: {state['question']}\n\n"
    prompt += "Evolution of solutions:\n"

    for i in range(state["max_iterations"]):
        prompt += f"\nIteration {i + 1}:\n"
        prompt += f"Solution: {state['agent1_solution'][i]}\n"
        prompt += f"Critique: {state['agent2_critiques'][i]}\n"
        prompt += (
            f"Competitor perspective: {state['agent3_perspectives'][i]}\n"
        )

    prompt += "\nBased on this iterative process, provide the overall, refined solution."

    print(
        f"[iteration {state['current_iteration']}] Generating overall solution with LLM..."
    )
    solution = self.llm.invoke(prompt)
    print(
        f"[iteration {state['current_iteration']}] Overall solution obtained. Preview:",
        solution.content[:200],
        "...",
    )

    new_state = state.copy()
    new_state["solution"] = solution.content

    print(
        f"[iteration {state['current_iteration']}] Exiting generate_solution."
    )
    return new_state

summarize_process_as_latex(state)

Summarize how the solution changed over time, referencing each iteration's critique and competitor perspective, then produce a final LaTeX document.

Source code in src/ursa/agents/hypothesizer_agent.py
def summarize_process_as_latex(
    self, state: HypothesizerState
) -> HypothesizerState:
    """
    Summarize how the solution changed over time, referencing
    each iteration's critique and competitor perspective,
    then produce a final LaTeX document.
    """
    print("Entering summarize_process_as_latex.")
    llm_model = state.get("llm_model", "unknown_model")

    # Build a single string describing the entire iterative process
    iteration_details = ""
    for i, (sol, crit, comp) in enumerate(
        zip(
            state["agent1_solution"],
            state["agent2_critiques"],
            state["agent3_perspectives"],
        ),
        start=1,
    ):
        iteration_details += (
            f"\\subsection*{{Iteration {i}}}\n\n"
            f"\\textbf{{Solution:}}\\\\\n{sol}\n\n"
            f"\\textbf{{Critique:}}\\\\\n{crit}\n\n"
            f"\\textbf{{Competitor Perspective:}}\\\\\n{comp}\n\n"
        )

    # -----------------------------
    # Write iteration_details to disk as .txt
    # -----------------------------
    timestamp_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    txt_filename = (
        f"iteration_details_{llm_model}_{timestamp_str}_chat_history.txt"
    )
    with open(txt_filename, "w", encoding="utf-8") as f:
        f.write(iteration_details)

    print(f"Wrote iteration details to {txt_filename}.")

    # Prompt the LLM to produce a LaTeX doc
    # We'll just pass it as a single string to the LLM;
    # you could also do system+human messages if you prefer.
    prompt = f"""\
        You are a system that produces a FULL LaTeX document.
        Here is information about a multi-iteration process:

        Original question: {state["question"]}

        Below are the solutions, critiques, and competitor perspectives from each iteration:

        {iteration_details}

        The solution we arrived at was:

        {state["solution"]}

        Now produce a valid LaTeX document.  Be sure to use a table of contents.
        It must start with an Executive Summary (that may be multiple pages) which summarizes
        the entire iterative process.  Following that, we should include the solution in full,
        not summarized, but reformatted for appropriate LaTeX.  And then, finally (and this will be
        quite long), we must take all the steps - solutions, critiques, and competitor perspectives
        and *NOT SUMMARIZE THEM* but merely reformat them for the reader.  This will be in an Appendix
        of the full content of the steps.  Finally, include a listing of all of the websites we
        used in our research.

        You must ONLY RETURN LaTeX, nothing else.  It must be valid LaTeX syntax!

        Your output should start with:
        \\documentclass{{article}}
        \\usepackage[margin=1in]{{geometry}}
        etc.

        It must compile without errors under pdflatex. 
    """

    # Now produce a valid LaTeX document that nicely summarizes this entire iterative process.
    # It must include the overall solution in full, not summarized, but reformatted for appropriate
    # LaTeX. The summarization is for the other steps.

    # all_visited_sites = state.get("visited_sites", [])
    # (Optional) remove duplicates by converting to a set, then back to a list
    # visited_sites_unique = list(set(all_visited_sites))
    # if visited_sites_unique:
    #     websites_latex = "\\section*{Websites Visited}\\begin{itemize}\n"
    #     for url in visited_sites_unique:
    #         print(f"We visited: {url}")
    #         # Use \url{} to handle special characters in URLs
    #         websites_latex += f"\\item \\url{{{url}}}\n"
    #     websites_latex += "\\end{itemize}\n\n"
    # else:
    #     # If no sites visited, or the list is empty
    #     websites_latex = (
    #         "\\section*{Websites Visited}\nNo sites were visited.\n\n"
    #     )
    # print(websites_latex)
    websites_latex = ""

    # Ask the LLM to produce *only* LaTeX content
    latex_response = self.llm.invoke(prompt)

    latex_doc = latex_response.content

    def inject_into_latex(original_tex: str, injection: str) -> str:
        """
        Find the last occurrence of '\\end{document}' in 'original_tex'
        and insert 'injection' right before it.
        If '\\end{document}' is not found, just append the injection at the end.
        """
        injection_index = original_tex.rfind(r"\end{document}")
        if injection_index == -1:
            # If the LLM didn't include \end{document}, just append
            return original_tex + "\n" + injection
        else:
            # Insert right before \end{document}
            return (
                original_tex[:injection_index]
                + "\n"
                + injection
                + "\n"
                + original_tex[injection_index:]
            )

    final_latex = inject_into_latex(latex_doc, websites_latex)

    new_state = state.copy()
    new_state["summary_report"] = final_latex

    print(
        f"[iteration {state['current_iteration']}] Received LaTeX from LLM. Preview:"
    )
    print(latex_response.content[:300], "...")
    print(
        f"[iteration {state['current_iteration']}] Exiting summarize_process_as_latex."
    )
    return new_state

mp_agent

MaterialsProjectAgent

Bases: BaseAgent

Source code in src/ursa/agents/mp_agent.py
class MaterialsProjectAgent(BaseAgent):
    def __init__(
        self,
        llm: BaseChatModel,
        summarize: bool = True,
        max_results: int = 3,
        database_path: str = "mp_database",
        summaries_path: str = "mp_summaries",
        **kwargs,
    ):
        super().__init__(llm, **kwargs)
        self.summarize = summarize
        self.max_results = max_results
        self.database_path = database_path
        self.summaries_path = summaries_path

        os.makedirs(self.database_path, exist_ok=True)
        os.makedirs(self.summaries_path, exist_ok=True)

        self._action = self._build_graph()

    def _fetch_node(self, state: dict) -> dict:
        f = state["query"]
        els = f["elements"]  # e.g. ["Ga","In"]
        bg = (f["band_gap_min"], f["band_gap_max"])
        e_above_hull = (0, 0)  # only on-hull (stable)
        mats = []
        with MPRester() as mpr:
            # get ALL matching materials…
            all_results = mpr.materials.summary.search(
                elements=els,
                band_gap=bg,
                energy_above_hull=e_above_hull,
                is_stable=True,  # equivalent filter
            )
            # …then take only the first `max_results`
            for doc in all_results[: self.max_results]:
                mid = doc.material_id
                data = doc.dict()
                # cache to disk
                path = os.path.join(self.database_path, f"{mid}.json")
                if not os.path.exists(path):
                    with open(path, "w") as f:
                        json.dump(data, f, indent=2)
                mats.append({"material_id": mid, "metadata": data})

        return {**state, "materials": mats}

    def _summarize_node(self, state: dict) -> dict:
        """Summarize each material via LLM over its metadata."""
        # prompt template
        prompt = ChatPromptTemplate.from_template("""
You are a materials-science assistant. Given the following metadata about a material, produce a concise summary focusing on its key properties:

{metadata}
        """)
        chain = prompt | self.llm | StrOutputParser()

        summaries = [None] * len(state["materials"])

        def process(i, mat):
            mid = mat["material_id"]
            meta = mat["metadata"]
            # flatten metadata to text
            text = "\n".join(f"{k}: {v}" for k, v in meta.items())
            # build or load summary
            summary_file = os.path.join(
                self.summaries_path, f"{mid}_summary.txt"
            )
            if os.path.exists(summary_file):
                with open(summary_file) as f:
                    return i, f.read()
            # optional: vectorize & retrieve, but here we just summarize full text
            result = chain.invoke({"metadata": text})
            with open(summary_file, "w") as f:
                f.write(result)
            return i, result

        with ThreadPoolExecutor(
            max_workers=min(8, len(state["materials"]))
        ) as exe:
            futures = [
                exe.submit(process, i, m)
                for i, m in enumerate(state["materials"])
            ]
            for future in tqdm(futures, desc="Summarizing materials"):
                i, summ = future.result()
                summaries[i] = summ

        return {**state, "summaries": summaries}

    def _aggregate_node(self, state: dict) -> dict:
        """Combine all summaries into a single, coherent answer."""
        combined = "\n\n----\n\n".join(
            f"[{i + 1}] {m['material_id']}\n\n{summary}"
            for i, (m, summary) in enumerate(
                zip(state["materials"], state["summaries"])
            )
        )

        prompt = ChatPromptTemplate.from_template("""
        You are a materials informatics assistant. Below are brief summaries of several materials:

        {summaries}

        Answer the user’s question in context:

        {context}
                """)
        chain = prompt | self.llm | StrOutputParser()
        final = chain.invoke({
            "summaries": combined,
            "context": state["context"],
        })
        return {**state, "final_summary": final}

    def _build_graph(self):
        graph = StateGraph(dict)  # using plain dict for state
        self.add_node(graph, self._fetch_node)
        if self.summarize:
            self.add_node(graph, self._summarize_node)
            self.add_node(graph, self._aggregate_node)

            graph.set_entry_point("_fetch_node")
            graph.add_edge("_fetch_node", "_summarize_node")
            graph.add_edge("_summarize_node", "_aggregate_node")
            graph.set_finish_point("_aggregate_node")
        else:
            graph.set_entry_point("_fetch_node")
            graph.set_finish_point("_fetch_node")
        return graph.compile(checkpointer=self.checkpointer)

    def _invoke(
        self,
        inputs: Mapping[str, Any],
        *,
        summarize: bool | None = None,
        recursion_limit: int = 1000,
        **_,
    ) -> str:
        config = self.build_config(
            recursion_limit=recursion_limit, tags=["graph"]
        )

        if "query" not in inputs:
            if "mp_query" in inputs:
                # make a shallow copy and rename the key
                inputs = dict(inputs)
                inputs["query"] = inputs.pop("mp_query")
            else:
                raise KeyError(
                    "Missing 'query' in inputs (alias 'mp_query' also accepted)."
                )

        result = self._action.invoke(inputs, config)

        use_summary = self.summarize if summarize is None else summarize
        return (
            result.get("final_summary", "No summary generated.")
            if use_summary
            else "\n\nFinished Fetching Materials Database Information!"
        )
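
A minimal usage sketch of this agent (the class name, chat model, and query below are assumptions for illustration, not taken from this page, and invoke is assumed to be the public wrapper BaseAgent provides around _invoke):

# Hypothetical usage sketch -- class name, model, and query are assumptions.
from langchain_openai import ChatOpenAI

agent = MaterialsProjectAgent(  # assumed name of the class documented above
    ChatOpenAI(model="gpt-4o"),
    summarize=True,
)

# _invoke accepts "mp_query" as an alias for "query".
print(agent.invoke({"mp_query": "stable lithium-ion cathode materials"}))

When summarize=True the call returns the aggregated final_summary; otherwise it returns a short completion notice.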

optimization_agent

run_cmd(query, state)

Run a command-line command using the subprocess package in Python.

Parameters:

Name   Type  Description                                                      Default
query  str   Command-line command to run, passed as a string to subprocess.  required
Source code in src/ursa/agents/optimization_agent.py
@tool
def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
    """
    Run a command-line command using the subprocess package in Python.

    Args:
        query: command-line command to run, passed as a string to subprocess.
    """
    workspace_dir = state["workspace"]
    print("RUNNING: ", query)
    try:
        process = subprocess.Popen(
            query.split(" "),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            cwd=workspace_dir,
        )

        stdout, stderr = process.communicate(timeout=60000)
    except KeyboardInterrupt:
        print("Keyboard Interrupt of command: ", query)
        stdout, stderr = "", "KeyboardInterrupt:"

    print("STDOUT: ", stdout)
    print("STDERR: ", stderr)

    return f"STDOUT: {stdout} and STDERR: {stderr}"

write_code(code, filename, state)

Writes Python or Julia code to a file in the given workspace as requested.

Parameters:

Name      Type  Description                                                                                                      Default
code      str   The code to write.                                                                                               required
filename  str   The filename, with an extension appropriate to the programming language (.py for Python, .jl for Julia, etc.).  required

Returns:

Type  Description
str   Status message indicating whether the file was written successfully.

Source code in src/ursa/agents/optimization_agent.py
@tool
def write_code(
    code: str, filename: str, state: Annotated[dict, InjectedState]
) -> str:
    """
    Writes Python or Julia code to a file in the given workspace as requested.

    Args:
        code: The code to write.
        filename: the filename, with an extension appropriate to the programming language (.py for Python, .jl for Julia, etc.).

    Returns:
        Status message indicating whether the file was written successfully.
    """
    workspace_dir = state["workspace"]
    print("Writing filename ", filename)
    try:
        # Extract code if wrapped in markdown code blocks
        if "```" in code:
            code_parts = code.split("```")
            if len(code_parts) >= 3:
                # Extract the actual code
                if "\n" in code_parts[1]:
                    code = "\n".join(code_parts[1].strip().split("\n")[1:])
                else:
                    # No leading language tag to strip; keep the fenced content itself
                    code = code_parts[1].strip()

        # Write code to a file
        code_file = os.path.join(workspace_dir, filename)

        with open(code_file, "w") as f:
            f.write(code)
        print(f"Written code to file: {code_file}")

        return f"File {filename} written successfully."

    except Exception as e:
        print(f"Error generating code: {str(e)}")
        # Report the failure back to the caller
        return f"Failed to write {filename} successfully."

planning_agent

PlanningAgent

Bases: BaseAgent

Source code in src/ursa/agents/planning_agent.py
class PlanningAgent(BaseAgent):
    def __init__(
        self,
        llm: BaseChatModel,
        max_reflection_steps: int = 1,
        **kwargs,
    ):
        super().__init__(llm, **kwargs)
        self.planner_prompt = planner_prompt
        self.reflection_prompt = reflection_prompt
        self._action = self._build_graph()
        self.max_reflection_steps = max_reflection_steps

    def generation_node(self, state: PlanningState) -> PlanningState:
        """
        Plan generation with structured output. Produces a JSON string in messages
        and a parsed list of steps in state["plan_steps"].
        """

        print("PlanningAgent: generating . . .")

        messages = cast(list, state.get("messages"))
        if isinstance(messages[0], SystemMessage):
            messages[0] = SystemMessage(content=self.planner_prompt)
        else:
            messages = [SystemMessage(content=self.planner_prompt)] + messages

        structured_llm = self.llm.with_structured_output(Plan)
        plan_obj = cast(
            Plan,
            structured_llm.invoke(
                messages, self.build_config(tags=["planner"])
            ),
        )

        try:
            json_text = plan_obj.model_dump_json(indent=2)

        except Exception as e:
            raise RuntimeError(
                f"Failed to serialize Plan object with Pydantic v2: {e}"
            )

        return {
            "messages": [AIMessage(content=json_text)],
            "plan_steps": [
                cast(PlanStep, step.model_dump()) for step in plan_obj.steps
            ],
            "reflection_steps": state["reflection_steps"],
        }

    def reflection_node(self, state: PlanningState) -> PlanningState:
        print("PlanningAgent: reflecting . . .")

        cls_map = {"ai": HumanMessage, "human": AIMessage}
        translated = [state["messages"][0]] + [
            cls_map[msg.type](content=msg.content)
            for msg in state["messages"][1:]
        ]
        translated = [SystemMessage(content=reflection_prompt)] + translated
        res = self.llm.invoke(
            translated,
            self.build_config(tags=["planner", "reflect"]),
        )
        return {
            "messages": [HumanMessage(content=res.content)],
            "reflection_steps": state["reflection_steps"] - 1,
        }

    def _build_graph(self):
        graph = StateGraph(PlanningState)
        self.add_node(graph, self.generation_node, "generate")
        self.add_node(graph, self.reflection_node, "reflect")
        graph.set_entry_point("generate")
        graph.add_conditional_edges(
            "generate",
            self._wrap_cond(
                _should_reflect, "should_reflect", "planning_agent"
            ),
            {"reflect": "reflect", "END": END},
        )
        graph.add_conditional_edges(
            "reflect",
            self._wrap_cond(
                _should_regenerate, "should_regenerate", "planning_agent"
            ),
            {"generate": "generate", "END": END},
        )
        return graph.compile(checkpointer=self.checkpointer)

    def _invoke(
        self, inputs: dict[str, Any], recursion_limit: int = 999999, **_
    ):
        config = self.build_config(
            recursion_limit=recursion_limit, tags=["planner"]
        )
        inputs.setdefault("reflection_steps", self.max_reflection_steps)
        return self._action.invoke(inputs, config)

    def _stream(
        self,
        inputs: dict[str, Any],
        *,
        config: dict | None = None,
        recursion_limit: int = 999999,
        **_,
    ) -> Iterator[dict]:
        # If you have defaults, merge them here:
        default = self.build_config(
            recursion_limit=recursion_limit, tags=["planner"]
        )
        if config:
            merged = {**default, **config}
            if "configurable" in config:
                merged["configurable"] = {
                    **default.get("configurable", {}),
                    **config["configurable"],
                }
        else:
            merged = default

        inputs.setdefault("reflection_steps", self.max_reflection_steps)
        # Delegate to the compiled graph's stream
        yield from self._action.stream(inputs, merged)
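
A minimal usage sketch (the chat model is an assumption, and invoke is assumed to be the public wrapper BaseAgent provides around _invoke):

# Hypothetical usage sketch of PlanningAgent.
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

planner = PlanningAgent(ChatOpenAI(model="gpt-4o"), max_reflection_steps=1)

final_state = planner.invoke({
    "messages": [HumanMessage(content="Plan a benchmark study of three optimizers.")]
})

# The last AI message carries the plan as JSON; the parsed steps live in "plan_steps".
for step in final_state["plan_steps"]:
    print(step)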

generation_node(state)

Plan generation with structured output. Produces a JSON string in messages and a parsed list of steps in state["plan_steps"].

Source code in src/ursa/agents/planning_agent.py
def generation_node(self, state: PlanningState) -> PlanningState:
    """
    Plan generation with structured output. Produces a JSON string in messages
    and a parsed list of steps in state["plan_steps"].
    """

    print("PlanningAgent: generating . . .")

    messages = cast(list, state.get("messages"))
    if isinstance(messages[0], SystemMessage):
        messages[0] = SystemMessage(content=self.planner_prompt)
    else:
        messages = [SystemMessage(content=self.planner_prompt)] + messages

    structured_llm = self.llm.with_structured_output(Plan)
    plan_obj = cast(
        Plan,
        structured_llm.invoke(
            messages, self.build_config(tags=["planner"])
        ),
    )

    try:
        json_text = plan_obj.model_dump_json(indent=2)

    except Exception as e:
        raise RuntimeError(
            f"Failed to serialize Plan object with Pydantic v2: {e}"
        )

    return {
        "messages": [AIMessage(content=json_text)],
        "plan_steps": [
            cast(PlanStep, step.model_dump()) for step in plan_obj.steps
        ],
        "reflection_steps": state["reflection_steps"],
    }
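
The Plan and PlanStep models are defined elsewhere in the package and are not reproduced on this page. As a rough mental model only, the code above relies on a steps list whose items can be model_dump()-ed; field names other than steps are illustrative assumptions, not the package's actual definitions:

# Illustrative shape only -- not the package's real Plan/PlanStep classes.
from pydantic import BaseModel


class PlanStep(BaseModel):
    description: str  # assumed field


class Plan(BaseModel):
    steps: list[PlanStep]  # generation_node iterates plan_obj.steps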

PlanningState

Bases: TypedDict

State dictionary for planning agent

Source code in src/ursa/agents/planning_agent.py
class PlanningState(TypedDict, total=False):
    """State dictionary for planning agent"""

    messages: Annotated[list, add_messages]

    # Ordered steps in the solution plan
    plan_steps: list[PlanStep]

    # Number of reflection steps
    reflection_steps: Required[int]
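
An example input state (a sketch): reflection_steps is marked Required in the TypedDict, but the agent's _invoke fills it in from max_reflection_steps when it is omitted.

from langchain_core.messages import HumanMessage

initial_state: PlanningState = {
    "messages": [HumanMessage(content="Draft a plan for a literature survey.")],
    "reflection_steps": 2,
}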

websearch_agent

WebSearchAgentLegacy

Bases: BaseAgent

Source code in src/ursa/agents/websearch_agent.py
class WebSearchAgentLegacy(BaseAgent):
    def __init__(
        self,
        llm: BaseChatModel,
        **kwargs,
    ):
        super().__init__(llm, **kwargs)
        self.websearch_prompt = websearch_prompt
        self.reflection_prompt = reflection_prompt
        self.tools = [search_tool, process_content]  # + cb_tools
        self.has_internet = self._check_for_internet(
            kwargs.get("url", "http://www.lanl.gov")
        )
        self._build_graph()

    def _review_node(self, state: WebSearchState) -> WebSearchState:
        if not self.has_internet:
            return {
                "messages": [
                    HumanMessage(
                        content="No internet for WebSearch Agent so no research to review."
                    )
                ],
                "urls_visited": [],
            }

        translated = [SystemMessage(content=reflection_prompt)] + state[
            "messages"
        ]
        res = self.llm.invoke(
            translated, {"configurable": {"thread_id": self.thread_id}}
        )
        return {"messages": [HumanMessage(content=res.content)]}

    def _response_node(self, state: WebSearchState) -> WebSearchState:
        if not self.has_internet:
            return {
                "messages": [
                    HumanMessage(
                        content="No internet for WebSearch Agent. No research carried out."
                    )
                ],
                "urls_visited": [],
            }

        messages = state["messages"] + [SystemMessage(content=summarize_prompt)]
        response = self.llm.invoke(
            messages, {"configurable": {"thread_id": self.thread_id}}
        )

        urls_visited = []
        for message in messages:
            if message.model_dump().get("tool_calls", []):
                if "url" in message.tool_calls[0]["args"]:
                    urls_visited.append(message.tool_calls[0]["args"]["url"])
        return {"messages": [response.content], "urls_visited": urls_visited}

    def _check_for_internet(self, url, timeout=2):
        """
        Checks for internet connectivity by attempting an HTTP GET request.
        """
        try:
            requests.get(url, timeout=timeout)
            return True
        except (requests.ConnectionError, requests.Timeout):
            return False

    def _state_store_node(self, state: WebSearchState) -> WebSearchState:
        state["thread_id"] = self.thread_id
        return state
        # return dict(**state, thread_id=self.thread_id)

    def _create_react(self, state: WebSearchState) -> WebSearchState:
        react_agent = create_agent(
            self.llm,
            self.tools,
            state_schema=WebSearchState,
            system_prompt=self.websearch_prompt,
        )
        return react_agent.invoke(state)

    def _build_graph(self):
        graph = StateGraph(WebSearchState)
        self.add_node(graph, self._state_store_node)
        self.add_node(graph, self._create_react)
        self.add_node(graph, self._review_node)
        self.add_node(graph, self._response_node)

        graph.set_entry_point("_state_store_node")
        graph.add_edge("_state_store_node", "_create_react")
        graph.add_edge("_create_react", "_review_node")
        graph.set_finish_point("_response_node")

        graph.add_conditional_edges(
            "_review_node",
            should_continue,
            {
                "_create_react": "_create_react",
                "_response_node": "_response_node",
            },
        )
        self._action = graph.compile(checkpointer=self.checkpointer)
        # self._action.get_graph().draw_mermaid_png(output_file_path="./websearch_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)

    def _invoke(
        self, inputs: Mapping[str, Any], recursion_limit: int = 1000, **_
    ):
        config = self.build_config(
            recursion_limit=recursion_limit, tags=["graph"]
        )
        return self._action.invoke(inputs, config)
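
A minimal usage sketch (the chat model and query are assumptions; invoke is assumed to be the public wrapper around _invoke):

# Hypothetical usage sketch of WebSearchAgentLegacy.
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

searcher = WebSearchAgentLegacy(ChatOpenAI(model="gpt-4o"))

result = searcher.invoke({
    "messages": [HumanMessage(content="Summarize recent work on solid-state electrolytes.")]
})
print(result["urls_visited"])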

process_content(url, context, state)

Processes content from a given webpage.

Parameters:

Name     Type  Description                                                                       Default
url      str   The URL to obtain text content from.                                              required
context  str   Description of the information wanted from the page, used to guide the summary.  required
Source code in src/ursa/agents/websearch_agent.py
def process_content(
    url: str, context: str, state: Annotated[dict, InjectedState]
) -> str:
    """
    Processes content from a given webpage.

    Args:
        url: the URL to obtain text content from.
        context: description of the information wanted from the page, used to guide the summary of salient information.
    """
    print("Parsing information from ", url)
    response = requests.get(url)
    soup = BeautifulSoup(response.content, "html.parser")

    content_prompt = f"""
    Here is the full content:
    {soup.get_text()}

    Carefully summarize the content in full detail, given the following context:
    {context}
    """
    summarized_information = (
        state["model"]
        .invoke(
            content_prompt, {"configurable": {"thread_id": state["thread_id"]}}
        )
        .content
    )
    return summarized_information
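
Because the source above shows process_content as a plain function taking an explicit state dict, a direct-call sketch looks like this (the model, URL, and thread_id are assumptions; inside the agent graph the state is supplied by the framework):

# Hypothetical direct call for illustration only.
from langchain_openai import ChatOpenAI

summary = process_content(
    url="https://www.lanl.gov",  # assumed target page
    context="What research areas does the lab highlight?",
    state={"model": ChatOpenAI(model="gpt-4o"), "thread_id": "demo"},
)
print(summary)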