"""
Thread-based TCP proxy that wraps Bicep.LangServer.dll.

Architecture:

    Editor (TCP:2088) ──► BicepProxy ──► Bicep.LangServer subprocess (stdio)

Uses subprocess.Popen + threads instead of asyncio subprocess — far more
reliable for stdin/stdout bridging of long-lived processes.

Intercepts:
  - textDocument/didOpen, didChange, didClose → tracks document content per URI
  - textDocument/completion requests  → detects the cursor context
    (module path / version / param / param value)
  - textDocument/completion responses → injects context-appropriate completions

Every other message is forwarded unchanged in both directions.

Context-aware injection:
  - Cursor in a module path string ('br/modules:...)  → inject LRU module names
  - Cursor after 'br/modules:NAME:                    → inject versions of NAME
  - Cursor inside a params: { ... } block             → inject params for module+version
  - Cursor after "param:" inside that block           → inject allowed values of param
"""

import json
import logging
import os
import re
import socket
import subprocess
import threading
import time
from typing import Any

from .modules import BicepModuleCatalog

logger = logging.getLogger(__name__)

BICEP_LS_PATH = os.getenv(
    "BICEP_LS_PATH",
    "/opt/bicep-langserver/Bicep.LangServer.dll",
)

# Seconds the LS gets to exit on its own once the client has gone away (and
# its stdin has been closed) before it is terminated.
_LS_EXIT_GRACE_SECONDS = 5.0
# Seconds the LS gets to exit at the end of a session before termination.
_LS_REAP_TIMEOUT_SECONDS = 3.0
# How often the connection handler re-checks whether the session has ended.
_SESSION_POLL_SECONDS = 0.5
# Back-off after a failed accept() so a persistent error (e.g. EMFILE) can't spin.
_ACCEPT_RETRY_SECONDS = 0.1

# How many lines above the cursor are scanned for an enclosing module declaration.
_PARAMS_LOOKBACK_LINES = 60

# LSP line terminators: only \r\n, \r and \n (str.splitlines() splits on more).
_LSP_EOL_RE = re.compile(r"\r\n|\r|\n")
# Cursor right after 'br/modules:NAME:<partial version>
_VERSION_CTX_RE = re.compile(r"'br/modules:([^:'\s]+):([^'\s]*)$")
# Cursor inside 'br/modules:<partial name> (no version colon yet)
_MODULE_PATH_CTX_RE = re.compile(r"'br/modules:([^:'\s]*)$")
# A complete module declaration: module <sym> 'br/modules:NAME:VERSION'
_MODULE_DECL_RE = re.compile(r"module\s+\w+\s+'br/modules:([^:]+):([^']*)'")
# Opening of a params block: params: {
_PARAMS_OPEN_RE = re.compile(r"\bparams\s*:\s*\{")
# Cursor after 'paramname: ' (optionally an opening quote) on the current line
_PARAM_VALUE_RE = re.compile(r"^\s*(\w+):\s*('?)([^'{}]*)$")
# Cursor inside an array value, e.g. "roles: ['KEY_VAULT_" or "roles: [ '"
_ARRAY_VALUE_RE = re.compile(r"^\s*(\w+):\s*\[[^\]]*?('?)([^',\]]*)$")
# Keys that look like "key:" on a line but are never module params.
_NON_PARAM_KEYS = frozenset({"params", "name", "module", "resource"})
# Contexts served exclusively by the private-registry catalog.
_PRIVATE_REGISTRY_CONTEXTS = frozenset({"version", "param", "param_value"})


# ── LSP framing helpers ────────────────────────────────────────────────────────

def _read_exact(fileobj, size: int) -> bytes:
    """Read exactly *size* bytes from *fileobj*, looping over short reads.

    A raw (unbuffered) socket file returns whatever one recv() yields, so a
    single ``read(n)`` may legitimately return fewer than *n* bytes.

    Raises:
        EOFError: if the stream ends before *size* bytes were read.
    """
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = fileobj.read(remaining)
        if not chunk:
            raise EOFError("Stream closed mid-message")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def _read_message(fileobj) -> bytes:
    """Read one LSP Content-Length framed message body from a file-like object.

    Raises:
        EOFError: if the stream ends before a complete message was read.
        ValueError: if the Content-Length header value is not an integer.
    """
    header = bytearray()
    while not header.endswith(b"\r\n\r\n"):
        ch = fileobj.read(1)
        if not ch:
            raise EOFError("Stream closed")
        header += ch

    content_length = 0
    for line in header.split(b"\r\n"):
        if line.lower().startswith(b"content-length:"):
            content_length = int(line.split(b":", 1)[1].strip())
    return _read_exact(fileobj, content_length)


def _frame(body: bytes) -> bytes:
    """Prefix *body* with an LSP Content-Length header."""
    return f"Content-Length: {len(body)}\r\n\r\n".encode() + body


# ── Document text helpers ──────────────────────────────────────────────────────

def _split_lines(text: str) -> list[str]:
    """Split *text* on LSP line terminators; always returns at least one line."""
    return _LSP_EOL_RE.split(text)


def _apply_range_edit(lines: list[str], rng: dict, new_text: str) -> list[str]:
    """Return *lines* with the LSP range *rng* replaced by *new_text*.

    Character offsets past a line's end clamp to that line's end (per the LSP
    Position rules); lines past the end of the document clamp to its end.
    Offsets are counted in Python code points — the same unit _detect_context
    uses. LSP's default unit is UTF-16 code units, so the two only disagree on
    lines containing characters outside the BMP.
    """
    text = "\n".join(lines)

    def to_offset(pos: dict) -> int:
        line = max(0, pos.get("line", 0))
        if line >= len(lines):
            return len(text)
        line_start = sum(len(prev) + 1 for prev in lines[:line])
        return line_start + min(max(0, pos.get("character", 0)), len(lines[line]))

    start = to_offset(rng.get("start") or {})
    end = max(start, to_offset(rng.get("end") or {}))
    return _split_lines(text[:start] + new_text + text[end:])


# ── Per-connection session state ───────────────────────────────────────────────

class _ProxySession:
    """
    Tracks document content and pending completion requests for one editor session.

    Thread safety: the client→LS thread writes ``docs`` and ``pending``; the
    LS→client thread only pops from ``pending``. Each access is a single dict
    operation, which the CPython GIL makes atomic.
    """

    def __init__(self) -> None:
        # uri → document lines (split on LSP line terminators)
        self.docs: dict[str, list[str]] = {}
        # completion request id → context dict produced by _detect_context
        self.pending: dict[int | str, dict] = {}

    def update_doc(self, uri: str, text: str) -> None:
        """Replace the tracked content of *uri* with *text*."""
        self.docs[uri] = _split_lines(text)

    def apply_changes(self, uri: str, changes: list[dict]) -> None:
        """Apply didChange ``contentChanges`` to *uri* in order.

        A change without a ``range`` replaces the whole document; a ranged
        (incremental) change edits the tracked text and is skipped if the
        document was never opened (there is nothing to apply it to).
        """
        for change in changes:
            text = change.get("text") or ""
            rng = change.get("range")
            if rng is None:
                self.update_doc(uri, text)
            elif uri in self.docs:
                self.docs[uri] = _apply_range_edit(self.docs[uri], rng, text)

    def close_doc(self, uri: str) -> None:
        """Stop tracking *uri* (didClose)."""
        self.docs.pop(uri, None)

    def record_completion_request(self, msg: dict) -> None:
        """Remember the cursor context of a textDocument/completion request."""
        req_id = msg.get("id")
        if req_id is None:
            return
        params = msg.get("params", {})
        uri = params.get("textDocument", {}).get("uri", "")
        position = params.get("position", {})
        self.pending[req_id] = self._detect_context(uri, position)

    def _detect_context(self, uri: str, position: dict) -> dict:
        """Determine what kind of completion is being requested based on cursor position.

        Returns a dict whose "type" is one of 'version', 'module_path',
        'param_value', 'param' or 'unknown', plus the fields that type needs.
        """
        lines = self.docs.get(uri, [])
        line_idx = position.get("line", 0)
        char_idx = position.get("character", 0)
        if not lines or line_idx < 0 or line_idx >= len(lines):
            return {"type": "unknown"}

        # Text on the current line up to the cursor
        current = lines[line_idx][:char_idx]

        # 1. Version context: cursor is after 'br/modules:NAME:
        m = _VERSION_CTX_RE.search(current)
        if m:
            return {"type": "version", "module": m.group(1), "prefix": m.group(2)}

        # 2. Module path context: cursor is inside 'br/modules: (no version colon yet)
        m = _MODULE_PATH_CTX_RE.search(current)
        if m:
            return {"type": "module_path", "prefix": m.group(1)}

        # 3. Params context: walk up to find the enclosing module declaration
        return self._detect_params_context(lines, line_idx, current)

    @staticmethod
    def _detect_params_context(lines: list[str], line_idx: int, current: str) -> dict:
        """Detect a 'param' / 'param_value' context inside a module's params block."""
        lookback_start = max(0, line_idx - _PARAMS_LOOKBACK_LINES)
        context_text = "\n".join(lines[lookback_start:line_idx] + [current])

        # The last module declaration in the lookback window is the candidate owner.
        mod_matches = list(_MODULE_DECL_RE.finditer(context_text))
        if not mod_matches:
            return {"type": "unknown"}
        last_mod = mod_matches[-1]
        text_after_mod = context_text[last_mod.start():]

        params_m = _PARAMS_OPEN_RE.search(text_after_mod)
        if not params_m:
            return {"type": "unknown"}
        # Heuristic: more '{' than '}' since "params: {" means we are still inside.
        text_in_params = text_after_mod[params_m.start():]
        if text_in_params.count("{") <= text_in_params.count("}"):
            return {"type": "unknown"}

        mod_name, mod_ver = last_mod.group(1), last_mod.group(2)
        # Value context ('param: ') first, then array-element context ('param: [').
        for pattern in (_PARAM_VALUE_RE, _ARRAY_VALUE_RE):
            m = pattern.search(current)
            if m and m.group(1) not in _NON_PARAM_KEYS:
                return {
                    "type": "param_value",
                    "module": mod_name,
                    "version": mod_ver,
                    "param": m.group(1),
                    "has_open_quote": bool(m.group(2)),
                }
        return {"type": "param", "module": mod_name, "version": mod_ver}

    def pop_pending(self, msg_id) -> dict | None:
        """Remove and return the context recorded for *msg_id*, or None if none was."""
        return self.pending.pop(msg_id, None)

    def pop_context(self, msg_id) -> dict:
        """Remove and return the context for *msg_id* ({"type": "unknown"} if none)."""
        context = self.pop_pending(msg_id)
        return {"type": "unknown"} if context is None else context


# ── Completion injection ───────────────────────────────────────────────────────

def _inject_completions(msg: dict[str, Any], context: dict | None = None) -> bytes:
    """
    Inject LRU-aware completions into a completion response.

    Args:
        msg: Parsed LS response to a textDocument/completion request. Its
            ``result`` may be a CompletionItem list or a dict with an
            ``items`` list; any other shape is re-serialised unchanged.
        context: Context recorded for the request (see
            ``_ProxySession._detect_context``). None or {} means 'unknown'.

    Behaviour by context type:
      - 'version'     → replace LS items with versions of context["module"]
      - 'param'       → replace LS items with params of module+version
      - 'param_value' → replace LS items with allowed values of that param
      - 'module_path' / 'unknown' → prepend module names, keep LS items below

    Returns:
        The (possibly modified) message serialised as JSON bytes.
    """
    if context is None:
        context = {}

    result = msg.get("result")
    if isinstance(result, list):
        items = result
    elif isinstance(result, dict) and isinstance(result.get("items"), list):
        items = result["items"]
    else:
        return json.dumps(msg).encode()

    ctx_type = context.get("type", "unknown")
    if ctx_type == "version":
        lru_items = BicepModuleCatalog.version_completion_items(context["module"])
    elif ctx_type == "param":
        lru_items = BicepModuleCatalog.param_completion_items(
            context["module"], context["version"]
        )
    elif ctx_type == "param_value":
        lru_items = BicepModuleCatalog.param_value_completion_items(
            context["module"],
            context["version"],
            context["param"],
            context.get("has_open_quote", False),
        )
    else:
        # Default: module name completions
        lru_items = BicepModuleCatalog.as_completion_items()

    if ctx_type in _PRIVATE_REGISTRY_CONTEXTS:
        # Always replace LS completions for private-registry contexts — the
        # Bicep LS doesn't know about our ACR, so anything it returns is noise.
        # Even if lru_items is empty (no enum values for a param), suppress LS.
        if isinstance(result, list):
            msg["result"] = lru_items
        else:
            result["items"] = lru_items
            result["isIncomplete"] = False
    elif lru_items:
        # module_path / unknown: keep LS completions, sorted below ours
        for item in items:
            st = item.get("sortText", item.get("label", ""))
            item["sortText"] = f"1_az_{st}"
        if isinstance(result, list):
            msg["result"] = lru_items + items
        else:
            result["items"] = lru_items + items
            result["isIncomplete"] = True

    return json.dumps(msg).encode()


# ── Forwarding threads ─────────────────────────────────────────────────────────

def _track_client_message(body: bytes, session: _ProxySession) -> None:
    """Update *session* from one client→LS message.

    Best-effort: tracking must never block forwarding, so any failure is
    logged at debug level and otherwise ignored.
    """
    try:
        msg = json.loads(body)
        method = msg.get("method", "")
        params = msg.get("params", {})
        if method == "textDocument/didOpen":
            text_doc = params.get("textDocument", {})
            uri, text = text_doc.get("uri", ""), text_doc.get("text", "")
            if uri:
                session.update_doc(uri, text or "")
        elif method == "textDocument/didChange":
            uri = params.get("textDocument", {}).get("uri", "")
            changes = params.get("contentChanges", [])
            if uri and changes:
                session.apply_changes(uri, changes)
        elif method == "textDocument/didClose":
            uri = params.get("textDocument", {}).get("uri", "")
            if uri:
                session.close_doc(uri)
        elif method == "textDocument/completion":
            session.record_completion_request(msg)
    except Exception:
        logger.debug("Could not track client message", exc_info=True)


def _client_to_ls(
    conn_file,
    proc_stdin,
    session: _ProxySession,
) -> None:
    """Forward framed messages from the editor to the LS until either side closes.

    Closes the LS's stdin on exit so the LS sees EOF.
    """
    try:
        while True:
            body = _read_message(conn_file)
            logger.debug("Client→LS: %d bytes", len(body))
            _track_client_message(body, session)
            proc_stdin.write(_frame(body))
            proc_stdin.flush()
            logger.debug("Client→LS: flushed")
    except EOFError:
        logger.debug("Client write side closed — signalling EOF to LS")
    except Exception as exc:
        logger.debug("Client→LS error: %s", exc)
    finally:
        try:
            proc_stdin.close()
        except Exception:
            pass


def _rewrite_ls_message(body: bytes, session: _ProxySession) -> bytes:
    """Return the bytes to forward to the client for one LS→client message.

    Only responses to completion requests recorded by *session* are rewritten;
    notifications, server→client requests and other responses are forwarded
    byte-for-byte. If injection fails, the LS's original response is forwarded
    so the session keeps working.
    """
    try:
        msg = json.loads(body)
    except ValueError:
        return body
    # Server→client requests carry "method" and use their own id space, so
    # they must never consume a pending completion entry.
    if not isinstance(msg, dict) or "method" in msg:
        return body
    if "result" not in msg and "error" not in msg:
        return body
    msg_id = msg.get("id")
    if not isinstance(msg_id, (int, str)):
        return body

    # Pop on error responses too, so cancelled/failed requests don't linger.
    context = session.pop_pending(msg_id)
    if context is None or "result" not in msg:
        return body
    try:
        return _inject_completions(msg, context)
    except Exception:
        logger.warning(
            "Completion injection failed; forwarding LS response unchanged",
            exc_info=True,
        )
        return body


def _ls_to_client(
    proc_stdout,
    conn: socket.socket,
    session: _ProxySession,
) -> None:
    """Forward framed messages from the LS to the editor until either side closes."""
    try:
        while True:
            body = _read_message(proc_stdout)
            logger.debug("LS→Client: %d bytes", len(body))
            conn.sendall(_frame(_rewrite_ls_message(body, session)))
    except EOFError:
        logger.debug("Bicep LS stdout closed")
    except Exception as exc:
        logger.debug("LS→Client error: %s", exc)


# ── Connection lifecycle ───────────────────────────────────────────────────────

def _shutdown_socket(conn: socket.socket) -> None:
    """Shut down and close *conn*, ignoring errors from an already-dead peer.

    shutdown() signals EOF to the peer and wakes a thread blocked in recv() on
    this socket; close() alone does not release the descriptor while a
    makefile() object still references it.
    """
    try:
        conn.shutdown(socket.SHUT_RDWR)
    except OSError:
        pass
    try:
        conn.close()
    except OSError:
        pass


def _stop_process(proc: subprocess.Popen, timeout: float) -> None:
    """Wait up to *timeout* seconds for *proc* to exit, then terminate, then kill.

    Always reaps the process so no zombie is left behind.
    """
    try:
        proc.wait(timeout=timeout)
        return
    except subprocess.TimeoutExpired:
        logger.debug("Bicep LS (pid=%d) still running — terminating", proc.pid)
    proc.terminate()
    try:
        proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()


def _wait_for_session_end(
    t1: threading.Thread,
    t2: threading.Thread,
    proc: subprocess.Popen,
) -> None:
    """Block until the LS→client thread *t2* finishes.

    The session normally ends when the LS closes its stdout. If the client goes
    away first (*t1* exits and closes the LS's stdin), the LS gets
    _LS_EXIT_GRACE_SECONDS to exit on its own and is then terminated, so that
    *t2* — and the calling handler — cannot block forever.
    """
    while t2.is_alive():
        t2.join(timeout=_SESSION_POLL_SECONDS)
        if t2.is_alive() and not t1.is_alive() and proc.poll() is None:
            _stop_process(proc, _LS_EXIT_GRACE_SECONDS)


def _handle_client(conn: socket.socket, addr: tuple) -> None:
    """Serve one editor connection with its own Bicep LS subprocess.

    Runs a client→LS and an LS→client forwarding thread and returns once the
    session is over, after closing the client socket and reaping the LS.
    """
    logger.info("New Bicep client: %s", addr)
    session = _ProxySession()

    try:
        proc = subprocess.Popen(
            ["dotnet", BICEP_LS_PATH, "--stdio"],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
        )
    except OSError as exc:
        # e.g. `dotnet` not on PATH: drop the client instead of leaking its socket.
        logger.error("Failed to start Bicep LS for %s: %s", addr, exc)
        _shutdown_socket(conn)
        return
    logger.info("Bicep LS subprocess started (pid=%d)", proc.pid)

    conn_file = None
    t1 = None
    try:
        # Unbuffered socket file; _read_message handles short reads itself.
        conn_file = conn.makefile("rb", buffering=0)

        # t1: client → LS (finishes when client closes write side)
        t1 = threading.Thread(
            target=_client_to_ls,
            args=(conn_file, proc.stdin, session),
            daemon=True,
        )
        # t2: LS → client (finishes when LS closes stdout)
        t2 = threading.Thread(
            target=_ls_to_client,
            args=(proc.stdout, conn, session),
            daemon=True,
        )
        t1.start()
        t2.start()

        # Session ends when LS is done (not when client closes write side)
        _wait_for_session_end(t1, t2, proc)
    finally:
        # Disconnect the client first: this also unblocks t1 if it is waiting
        # in recv(). Then make sure the LS is gone so t1 can't block on stdin.
        _shutdown_socket(conn)
        _stop_process(proc, _LS_REAP_TIMEOUT_SECONDS)
        if t1 is not None:
            t1.join(timeout=_LS_REAP_TIMEOUT_SECONDS)
        if conn_file is not None:
            conn_file.close()  # releases the socket's file descriptor
        if proc.stdout is not None:
            try:
                proc.stdout.close()
            except Exception:
                pass
    logger.info("Bicep client %s disconnected", addr)


def serve_bicep(port: int, host: str = "0.0.0.0") -> None:
    """Blocking TCP server — run in a daemon thread alongside asyncio.

    Args:
        port: TCP port to listen on.
        host: Interface to bind. Defaults to all interfaces (the historical
            behaviour); pass "127.0.0.1" to keep the proxy local-only.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((host, port))
        server.listen(10)
        logger.info("Bicep LSP proxy listening on TCP %s:%d", host, port)

        while True:
            try:
                conn, addr = server.accept()
            except Exception as exc:
                logger.error("Accept error: %s", exc)
                time.sleep(_ACCEPT_RETRY_SECONDS)
                continue
            try:
                threading.Thread(
                    target=_handle_client,
                    args=(conn, addr),
                    daemon=True,
                ).start()
            except Exception as exc:
                logger.error("Could not start handler for %s: %s", addr, exc)
                _shutdown_socket(conn)