Source code for axe_usd.usd.material_builders.base

"""Shared helpers for USD material builders."""

from dataclasses import dataclass
import logging
from typing import Dict, Iterable, Optional, Tuple

from pxr import Sdf, Usd, UsdShade

from ..material_model import TextureFormatOverrides, apply_texture_format_override
from ..types import MaterialTextureDict

# Renderer identifier strings.
# NOTE(review): presumably used by callers to pick a builder; no usage is
# visible in this part of the file.
RENDERER_ARNOLD = "arnold"
RENDERER_MTLX = "mtlx"
RENDERER_OPENPBR = "openpbr"

# Values for MaterialBuildContext.arnold_displacement_mode; the context
# defaults to ARNOLD_DISPLACEMENT_BUMP.
ARNOLD_DISPLACEMENT_BUMP = "bump"
ARNOLD_DISPLACEMENT_DISPLACEMENT = "displacement"

# Texture slot name -> type suffix for the image shader created for that slot.
# _MtlxLikeBuilder uses this mapping as its default ``image_signatures`` and
# turns each value into a shader id of the form ``ND_image_<signature>``.
MTLX_LIKE_IMAGE_SIGNATURE = {
    "basecolor": "color3",
    "emission": "color3",
    "normal": "vector3",
    "metalness": "float",
    "opacity": "color3",
    "roughness": "float",
    "displacement": "float",
}


@dataclass(frozen=True)
class MaterialBuildContext:
    """Immutable bundle of inputs shared by the USD material builders.

    Instances are frozen, so a context can be handed to several builders
    without any of them being able to rebind its fields.
    """

    # Stage on which the builders define their shader prims.
    stage: Usd.Stage
    # Texture slot name -> per-slot info mapping; builders read the "path"
    # entry of each slot's info.
    material_dict: MaterialTextureDict
    # When true, the metalness texture is not wired into the surface shader.
    is_transmissive: bool
    # NOTE(review): presumably per-renderer texture format overrides; not read
    # by the shared helpers in this module -- confirm against the builders.
    texture_format_overrides: TextureFormatOverrides
    # Receives warnings about unsupported slots and slots without a path.
    logger: logging.Logger
    # One of the ARNOLD_DISPLACEMENT_* constants; defaults to bump.
    arnold_displacement_mode: str = ARNOLD_DISPLACEMENT_BUMP
def _iter_textures(
    context: MaterialBuildContext,
    input_map: Dict[str, str],
    renderer_name: str,
) -> Iterable[Tuple[str, str, str]]:
    """Yield ``(slot, input_name, path)`` for every usable texture slot.

    Args:
        context: Build context providing ``material_dict`` and ``logger``.
        input_map: Texture slot name -> shader input name for the renderer.
        renderer_name: Renderer label used only in the warning message.

    Yields:
        Tuples of the slot name, the mapped shader input name and the texture
        path. Slots without a mapping in ``input_map`` and slots whose info
        has no truthy ``"path"`` are skipped with a logged warning.
    """
    for slot, info in context.material_dict.items():
        input_name = input_map.get(slot)
        if not input_name:
            context.logger.warning(
                "Texture slot '%s' not supported for %s.", slot, renderer_name
            )
            continue
        path = info.get("path")
        if not path:
            context.logger.warning("Texture slot '%s' missing path; skipping.", slot)
            continue
        yield slot, input_name, path


def _connect_nodegraph_output(
    nodegraph: UsdShade.NodeGraph,
    output_name: str,
    output_type: Sdf.ValueTypeName,
    source: UsdShade.Shader,
    source_output: str,
) -> UsdShade.Output:
    """Create an output on ``nodegraph`` and connect it to a shader output.

    Args:
        nodegraph: Node graph that receives the new output.
        output_name: Name of the output to create.
        output_type: Value type of the output (a single entry of
            ``Sdf.ValueTypeNames``, e.g. ``Sdf.ValueTypeNames.Token``).
        source: Shader whose output drives the node graph output.
        source_output: Name of the output on ``source`` to connect to.

    Returns:
        The created and connected node graph output.
    """
    output = nodegraph.CreateOutput(output_name, output_type)
    output.ConnectToSource(source.ConnectableAPI(), source_output)
    return output


class _MtlxLikeBuilder:
    """Shared texture wiring for builders that use ``ND_*`` shader ids.

    Subclasses are expected to override the class attributes below; the
    defaults (empty map and empty strings) wire nothing useful on their own.
    """

    # Texture slot name -> input name on the surface shader.
    input_map: Dict[str, str] = {}
    # Renderer label used in warnings about unsupported slots.
    renderer_name = ""
    # Prefix for the names of all shader prims created by this builder.
    texture_prefix = ""
    # Texture slot name -> type suffix of the image shader id.
    image_signatures = MTLX_LIKE_IMAGE_SIGNATURE
    # Surface shader input set to 1 once an emission texture is wired.
    emission_intensity_input = "emission"

    def __init__(self, context: MaterialBuildContext) -> None:
        """Store the build context used by all helper methods."""
        self._context = context

    def _initialize_image_shader(
        self, image_path: str, signature: str = "color3"
    ) -> UsdShade.Shader:
        """Define an ``ND_image_<signature>`` shader with an asset ``file`` input.

        Args:
            image_path: Prim path at which the shader is defined.
            signature: Type suffix appended to the shader id.

        Returns:
            The new shader; its ``file`` input is created but left unset.
        """
        shader = UsdShade.Shader.Define(self._context.stage, image_path)
        shader.CreateIdAttr(f"ND_image_{signature}")
        shader.CreateInput("file", Sdf.ValueTypeNames.Asset)
        return shader

    def _initialize_color_correct_shader(
        self,
        color_correct_path: str,
        signature: str = "color3",
    ) -> UsdShade.Shader:
        """Define an ``ND_colorcorrect_<signature>`` shader at the given path."""
        shader = UsdShade.Shader.Define(self._context.stage, color_correct_path)
        shader.CreateIdAttr(f"ND_colorcorrect_{signature}")
        return shader

    def _initialize_range_shader(
        self, range_path: str, signature: str = "color3"
    ) -> UsdShade.Shader:
        """Define an ``ND_range_<signature>`` shader at the given path."""
        shader = UsdShade.Shader.Define(self._context.stage, range_path)
        shader.CreateIdAttr(f"ND_range_{signature}")
        return shader

    def _initialize_normal_map_shader(self, normal_map_path: str) -> UsdShade.Shader:
        """Define an ``ND_normalmap`` shader at the given path."""
        shader = UsdShade.Shader.Define(self._context.stage, normal_map_path)
        shader.CreateIdAttr("ND_normalmap")
        return shader

    def _enable_transmission(self, shader: UsdShade.Shader) -> None:
        """Set ``transmission`` to 0.9 and ``thin_walled`` to 1 on ``shader``.

        Both inputs are fetched with ``GetInput`` rather than created here.
        NOTE(review): this presumably relies on the caller having created
        them on the surface shader beforehand -- confirm in the subclasses.
        """
        shader.GetInput("transmission").Set(0.9)
        shader.GetInput("thin_walled").Set(1)

    def _connect_color_correct(
        self,
        collect_path: str,
        slot: str,
        texture_shader: UsdShade.Shader,
        std_surf_shader: UsdShade.Shader,
        input_name: str,
    ) -> None:
        """Route a texture through a color-correct shader into the surface.

        Creates ``<collect_path>/<prefix>_<slot>ColorCorrect``, connects the
        texture's ``out`` to its ``in`` and its ``out`` to the surface shader
        input ``input_name``. Both connections are typed ``Color3f``.
        """
        color_correct_path = f"{collect_path}/{self.texture_prefix}_{slot}ColorCorrect"
        color_correct_shader = self._initialize_color_correct_shader(color_correct_path)
        color_correct_shader.CreateInput(
            "in", Sdf.ValueTypeNames.Color3f
        ).ConnectToSource(
            texture_shader.ConnectableAPI(),
            "out",
        )
        std_surf_shader.CreateInput(
            input_name, Sdf.ValueTypeNames.Color3f
        ).ConnectToSource(
            color_correct_shader.ConnectableAPI(),
            "out",
        )

    def _connect_range(
        self,
        collect_path: str,
        slot: str,
        texture_shader: UsdShade.Shader,
        std_surf_shader: UsdShade.Shader,
        input_name: str,
        signature: str = "float",
    ) -> None:
        """Route a texture through a range shader into the surface shader.

        Creates ``<collect_path>/<prefix>_<slot>Range`` with the given
        signature. Connections are typed ``Float`` when ``signature`` is
        ``"float"`` and ``Color3f`` for any other signature.
        """
        range_path = f"{collect_path}/{self.texture_prefix}_{slot}Range"
        range_shader = self._initialize_range_shader(range_path, signature=signature)
        range_value_type = (
            Sdf.ValueTypeNames.Float
            if signature == "float"
            else Sdf.ValueTypeNames.Color3f
        )
        range_shader.CreateInput("in", range_value_type).ConnectToSource(
            texture_shader.ConnectableAPI(),
            "out",
        )
        std_surf_shader.CreateInput(
            input_name, range_value_type
        ).ConnectToSource(
            range_shader.ConnectableAPI(),
            "out",
        )

    def _connect_normal(
        self,
        collect_path: str,
        texture_shader: UsdShade.Shader,
        std_surf_shader: UsdShade.Shader,
        input_name: str,
    ) -> None:
        """Route a normal texture through a normal-map shader into the surface.

        Creates ``<collect_path>/<prefix>_NormalMap``; both connections are
        typed ``Float3``.
        """
        normal_map_path = f"{collect_path}/{self.texture_prefix}_NormalMap"
        normal_map_shader = self._initialize_normal_map_shader(normal_map_path)
        normal_map_shader.CreateInput("in", Sdf.ValueTypeNames.Float3).ConnectToSource(
            texture_shader.ConnectableAPI(),
            "out",
        )
        std_surf_shader.CreateInput(
            input_name, Sdf.ValueTypeNames.Float3
        ).ConnectToSource(
            normal_map_shader.ConnectableAPI(),
            "out",
        )

    def _connect_displacement(
        self,
        _nodegraph: UsdShade.NodeGraph,
        std_surf_shader: UsdShade.Shader,
        input_name: str,
        texture_shader: UsdShade.Shader,
    ) -> None:
        """Connect the displacement texture straight to a ``Float`` input.

        ``_nodegraph`` is not used by this implementation.
        NOTE(review): it is presumably kept so subclasses can override this
        hook and wire displacement through the node graph -- confirm.
        """
        std_surf_shader.CreateInput(
            input_name, Sdf.ValueTypeNames.Float
        ).ConnectToSource(
            texture_shader.ConnectableAPI(),
            "out",
        )

    def _wire_textures(
        self,
        nodegraph: UsdShade.NodeGraph,
        collect_path: str,
        std_surf_shader: UsdShade.Shader,
        override: Optional[str],
    ) -> None:
        """Create an image shader per texture slot and wire it to the surface.

        Args:
            nodegraph: Node graph forwarded to the displacement hook.
            collect_path: Parent prim path for all shaders created here.
            std_surf_shader: Surface shader that receives the connections.
            override: Texture format override passed to
                ``apply_texture_format_override`` for every texture path.

        Raises:
            KeyError: If a slot accepted by ``input_map`` has no entry in
                ``image_signatures``.
        """
        for slot, input_name, path in _iter_textures(
            self._context, self.input_map, self.renderer_name
        ):
            # Metalness is never wired on transmissive materials. Skip before
            # authoring anything so the stage is not left with an unconnected
            # image shader that still references the texture file.
            if slot == "metalness" and self._context.is_transmissive:
                continue

            tex_filepath = apply_texture_format_override(path, override)
            signature = self.image_signatures[slot]
            texture_prim_path = f"{collect_path}/{self.texture_prefix}_{slot}Texture"
            texture_shader = self._initialize_image_shader(
                texture_prim_path,
                signature=signature,
            )
            texture_shader.GetInput("file").Set(tex_filepath)

            if slot == "basecolor":
                self._connect_color_correct(
                    collect_path, slot, texture_shader, std_surf_shader, input_name
                )
            elif slot == "emission":
                # A float emission texture goes through a range shader;
                # any other signature is treated as a color.
                if signature == "float":
                    self._connect_range(
                        collect_path,
                        slot,
                        texture_shader,
                        std_surf_shader,
                        input_name,
                        signature="float",
                    )
                else:
                    self._connect_color_correct(
                        collect_path,
                        slot,
                        texture_shader,
                        std_surf_shader,
                        input_name,
                    )
                # Set the emission intensity only when that input already
                # exists on the surface shader.
                emission_input = std_surf_shader.GetInput(self.emission_intensity_input)
                if emission_input:
                    emission_input.Set(1)
            elif slot in {"metalness", "roughness"}:
                self._connect_range(
                    collect_path, slot, texture_shader, std_surf_shader, input_name
                )
            elif slot == "opacity":
                # Only "float" and "color3" range shaders are used for
                # opacity; any non-float signature falls back to "color3".
                opacity_signature = "float" if signature == "float" else "color3"
                self._connect_range(
                    collect_path,
                    slot,
                    texture_shader,
                    std_surf_shader,
                    input_name,
                    signature=opacity_signature,
                )
            elif slot == "normal":
                self._connect_normal(
                    collect_path, texture_shader, std_surf_shader, input_name
                )
            elif slot == "displacement":
                self._connect_displacement(
                    nodegraph, std_surf_shader, input_name, texture_shader
                )